easydel.kernels.gpu_ops.flash_attention_triton._utils

easydel.kernels.gpu_ops.flash_attention_triton._utils.attention_pack_with_static_shape(x: Array, attention_mask: Array, max_tokens: int = None) Array

Packs an attention tensor by removing the padding positions indicated by the attention mask. The packed output uses a static maximum shape, so the function stays compatible with JIT compilation.
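The packing idea can be illustrated with a minimal sketch. It assumes `x` has layout [batch, seqlen, num_heads, head_dim], that `attention_mask` is [batch, seqlen] with nonzero entries for real tokens, and that the result is [1, max_tokens, num_heads, head_dim]; the function name `pack_with_static_shape` is hypothetical and this is not the library's implementation. The key trick is `jnp.nonzero` with a static `size`, which keeps the output shape fixed under `jax.jit`.

```python
import jax
import jax.numpy as jnp


def pack_with_static_shape(x, attention_mask, max_tokens=None):
    """Hypothetical sketch: gather non-padding tokens into a fixed-size buffer.

    x: [batch, seqlen, num_heads, head_dim] (assumed layout)
    attention_mask: [batch, seqlen], nonzero for real tokens, 0 for padding
    Returns [1, max_tokens, num_heads, head_dim]; unused slots are zero.
    """
    batch, seqlen = attention_mask.shape
    if max_tokens is None:
        # Static upper bound: the case where every token is real.
        max_tokens = batch * seqlen
    flat_mask = attention_mask.reshape(-1).astype(bool)
    flat_x = x.reshape((batch * seqlen,) + x.shape[2:])
    # A static `size` fixes the output shape; missing entries use fill_value.
    (idx,) = jnp.nonzero(flat_mask, size=max_tokens, fill_value=0)
    packed = flat_x[idx]
    # Zero the slots past the number of real tokens (they point at index 0).
    valid = jnp.arange(max_tokens) < flat_mask.sum()
    packed = jnp.where(valid[:, None, None], packed, 0)
    return packed[None]


# max_tokens must be static under jit.
pack_jit = jax.jit(pack_with_static_shape, static_argnames="max_tokens")
```

Because `max_tokens` determines the output shape, it has to be a static argument when the function is jitted; choosing it smaller than `batch * seqlen` saves memory but silently truncates if more real tokens are present.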

easydel.kernels.gpu_ops.flash_attention_triton._utils.attention_unpack_with_static_shape(x: Array, cum_seqlens: Array, batch_size: int, seqlen: int) Array

Unpacks a packed attention tensor by scattering the packed values back to their original positions in the padded layout.

Parameters
  • x – Packed tensor of shape [1, packed_tokens, num_heads, head_dim]

  • cum_seqlens – Cumulative sequence lengths, shape [batch_size+1]

  • batch_size – Batch size (number of sequences) of the unpacked output

  • seqlen – Maximum (padded) sequence length of the unpacked output

Returns

Unpacked tensor of shape [batch_size, seqlen, num_heads, head_dim]
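A minimal sketch of the inverse operation, assuming `cum_seqlens` has the form [0, len_0, len_0 + len_1, ...] and that packed slots beyond `cum_seqlens[-1]` are padding. The name `unpack_with_static_shape` is hypothetical; the sketch only illustrates the scatter, not the library's code. Each packed token's sequence index comes from `jnp.searchsorted` on the cumulative lengths, and its in-sequence position is its offset from that sequence's start.

```python
import jax.numpy as jnp


def unpack_with_static_shape(x, cum_seqlens, batch_size, seqlen):
    """Hypothetical sketch: scatter packed tokens back into a padded layout.

    x: packed tensor [1, packed_tokens, num_heads, head_dim]
    cum_seqlens: [batch_size + 1], starting at 0
    Returns [batch_size, seqlen, num_heads, head_dim], zeros at padding.
    """
    packed = x[0]
    token_ids = jnp.arange(packed.shape[0])
    # Sequence index of each token: number of boundaries <= token id, minus 1.
    seq_ids = jnp.searchsorted(cum_seqlens, token_ids, side="right") - 1
    # Offset of each token from the start of its sequence.
    starts = cum_seqlens[jnp.clip(seq_ids, 0, batch_size - 1)]
    pos = token_ids - starts
    # Static padding slots past the last sequence get an out-of-bounds row.
    valid = token_ids < cum_seqlens[-1]
    seq_ids = jnp.where(valid, seq_ids, batch_size)
    out = jnp.zeros((batch_size, seqlen) + packed.shape[1:], packed.dtype)
    # mode="drop" discards the out-of-bounds (padding) updates.
    return out.at[seq_ids, pos].set(packed, mode="drop")
```

Redirecting padding slots to an out-of-bounds row and letting `mode="drop"` discard them keeps every shape static, which is what makes this style of scatter usable under `jax.jit`.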

easydel.kernels.gpu_ops.flash_attention_triton._utils.basic_attention_refrence(q: Array, k: Array, v: Array, attn_bias: Optional[Array] = None, query_padding_mask: Optional[Array] = None, key_padding_mask: Optional[Array] = None, dropout_prob: float = 0.0, dropout_key: Optional[PRNGKey] = None, window_size: Tuple[int, int] = (-1, -1), causal: bool = False, softcap: float = 0.0)

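No description is given for `basic_attention_refrence`. Judging only from its name and parameters, it appears to be a dense (non-flash) reference attention used to check the Triton kernel. The sketch below shows what such a reference typically computes, under stated assumptions: layout [batch, seq, num_heads, head_dim], logit soft-capping as `softcap * tanh(s / softcap)`, bias added after capping, and a bottom-right aligned causal mask. Dropout, padding masks, and `window_size` are omitted, and `naive_attention` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def naive_attention(q, k, v, attn_bias=None, causal=False, softcap=0.0):
    """Hypothetical dense reference attention (no dropout, padding, or window).

    q: [batch, q_len, num_heads, head_dim]; k, v: [batch, k_len, num_heads, head_dim]
    attn_bias: optional, broadcastable to [batch, num_heads, q_len, k_len]
    """
    scale = q.shape[-1] ** -0.5
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) * scale
    if softcap > 0.0:
        # Soft-capping squashes logits into (-softcap, softcap).
        scores = softcap * jnp.tanh(scores / softcap)
    if attn_bias is not None:
        # Assumption: bias is added after soft-capping.
        scores = scores + attn_bias
    if causal:
        q_len, k_len = q.shape[1], k.shape[1]
        row = jnp.arange(q_len)[:, None]
        col = jnp.arange(k_len)[None, :]
        # Bottom-right aligned mask; reduces to col <= row when q_len == k_len.
        allowed = col <= row + (k_len - q_len)
        # Finite minimum instead of -inf avoids NaN on fully masked rows.
        scores = jnp.where(allowed, scores, jnp.finfo(scores.dtype).min)
    probs = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", probs, v)
```

A dense version like this materializes the full [batch, heads, q_len, k_len] score matrix, so it is only practical for small test shapes, which is exactly what a correctness reference needs.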
easydel.kernels.gpu_ops.flash_attention_triton._utils.calc_bias_strides(bias: Optional[Array], batch: int, nheads_q: int, QSeq: int, KSeq: int) Tuple[int, ...]
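The exact return value of `calc_bias_strides` is not documented here. Triton kernels generally address tensors through per-dimension strides, and a broadcastable bias is usually handled by giving size-1 dimensions a stride of 0. The sketch below illustrates that idea only; it assumes a row-major 4-D bias of shape [batch or 1, nheads_q or 1, QSeq, KSeq], and both the name `bias_strides_sketch` and its 4-tuple return layout are hypothetical.

```python
from typing import Optional, Tuple

import jax.numpy as jnp


def bias_strides_sketch(bias: Optional[jnp.ndarray], batch: int, nheads_q: int,
                        q_seq: int, k_seq: int) -> Tuple[int, ...]:
    """Hypothetical: element strides of a row-major 4-D bias, 0 on broadcast dims."""
    if bias is None:
        # No bias: a kernel would skip the load, so strides are irrelevant.
        return (0, 0, 0, 0)
    expected = (batch, nheads_q, q_seq, k_seq)
    for got, want in zip(bias.shape, expected):
        if got not in (1, want):
            raise ValueError(f"bias shape {bias.shape} not broadcastable to {expected}")
    _, h, q, k = bias.shape
    # Contiguous (row-major) strides, measured in elements.
    strides = (h * q * k, q * k, k, 1)
    # A broadcast size-1 dim gets stride 0, so every index reads the same data.
    return tuple(0 if dim == 1 and want != 1 else s
                 for dim, want, s in zip(bias.shape, expected, strides))
```

For example, a bias shared across the batch, of shape (1, 4, 8, 16) with `batch=2`, yields strides (0, 128, 16, 1): moving along the batch axis never advances the pointer.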