easydel.layers.attention_operator.modules.flash

class easydel.layers.attention_operator.modules.flash.FlashAttn(metadata: AttentionMetadata)

Bases: AttentionImpl

An implementation of Flash Attention V2 using specialized JAX primitives.

This class uses jax.experimental.pallas.ops.tpu.flash_attention on TPUs and a Triton kernel (triton_flash_attention) on CUDA GPUs. It is registered under the name “flash_attn2”. CPU and ROCm execution are not supported and raise NotImplementedError.
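The per-backend behavior described for this class can be sketched as follows. This is only an illustration of the delegation pattern: BackendDispatchSketch, its forward method, and the "forward_" + backend naming rule are hypothetical and are not EasyDeL's AttentionImpl; the kernel calls are replaced by placeholder strings.

```python
class BackendDispatchSketch:
    """Hypothetical stand-in mirroring the documented method names."""

    def forward_native(self, *args, **kwargs):
        # CPU is unsupported, so the native path always raises.
        raise NotImplementedError("Flash Attention is not supported on CPU.")

    def forward_cpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_rocm(self, *args, **kwargs):
        # No ROCm kernel: delegate to the native path, which raises.
        return self.forward_native(*args, **kwargs)

    def forward_cuda(self, *args, **kwargs):
        return "cuda-kernel"  # placeholder for the Triton kernel call

    def forward_gpu(self, *args, **kwargs):
        return self.forward_cuda(*args, **kwargs)

    def forward_tpu(self, *args, **kwargs):
        return "tpu-kernel"  # placeholder for the Pallas kernel call

    def forward(self, backend, *args, **kwargs):
        # Assumed rule for this sketch: look up "forward_<backend>".
        return getattr(self, f"forward_{backend}")(*args, **kwargs)
```

With this sketch, forward("gpu") and forward("tpu") reach a kernel placeholder, while forward("cpu") and forward("rocm") raise NotImplementedError.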

forward_cpu(*args, **kwargs) -> AttentionOutput

CPU forward pass. Delegates to forward_native, which raises NotImplementedError.

Raises

NotImplementedError – Via forward_native.

forward_cuda(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, causal: bool = False, **ignore) -> AttentionOutput

Performs Flash Attention V2 on GPU (CUDA) using a Triton kernel.

Handles the optional mask and bias, and shards the computation according to the metadata. KV head repetition is either performed before the call or left to triton_flash_attention, which is assumed to handle it internally or to accept broadcastable KV. The Triton kernel is assumed to take and return tensors in BTHD layout (batch, seq_len, num_heads, head_dim).

Parameters
  • q – Query tensor, expected shape (batch, q_seq_len, num_q_heads, head_dim).

  • k – Key tensor, expected shape (batch, kv_seq_len, num_kv_heads, head_dim).

  • v – Value tensor, expected shape (batch, kv_seq_len, num_kv_heads, head_dim).

  • mask – Optional boolean attention mask. Used by the kernel if bias is not provided. Shape typically (batch, 1, q_seq_len, kv_seq_len) or broadcastable.

  • bias – Optional attention bias tensor. Added by the kernel. Shape typically (batch, num_heads, q_seq_len, kv_seq_len) or broadcastable. Takes precedence over mask within the kernel logic if both are passed.

  • init_bias – Optional callable function to initialize bias if mask and bias are None.

  • causal – If True, instructs the kernel to apply a causal mask.

  • **ignore – Ignored keyword arguments.

Returns

An AttentionOutput object containing the attention outputs. Attention weights are typically not computed or returned by Flash Attention implementations.
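As a reading aid for the parameters above, here is a plain (non-flash) reference computation in JAX. It is not the Triton kernel and makes several assumptions that the text does not confirm: a softmax scale of 1/sqrt(head_dim), KV heads repeated with jnp.repeat, and a causal mask aligned to the top-left corner. reference_attention is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def reference_attention(q, k, v, mask=None, bias=None, causal=False):
    """Plain softmax attention on BTHD inputs; returns a BTHD array."""
    num_q_heads, num_kv_heads = q.shape[2], k.shape[2]
    if num_q_heads != num_kv_heads:
        # Repeat each KV head so every query head has a matching KV head
        # (assumed head ordering for this sketch).
        reps = num_q_heads // num_kv_heads
        k = jnp.repeat(k, reps, axis=2)
        v = jnp.repeat(v, reps, axis=2)

    scale = q.shape[-1] ** -0.5  # assumed 1/sqrt(head_dim) scaling
    # scores has shape (batch, num_q_heads, q_seq_len, kv_seq_len).
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) * scale
    neg = jnp.finfo(scores.dtype).min

    if bias is not None:
        scores = scores + bias  # bias takes precedence over mask
    elif mask is not None:
        scores = jnp.where(mask, scores, neg)

    if causal:
        q_len, kv_len = scores.shape[-2:]
        keep = jnp.tril(jnp.ones((q_len, kv_len), dtype=bool))
        scores = jnp.where(keep, scores, neg)

    weights = jax.nn.softmax(scores, axis=-1)
    # Back to BTHD: (batch, q_seq_len, num_q_heads, head_dim).
    return jnp.einsum("bhqk,bkhd->bqhd", weights, v)
```

A flash kernel computes the same quantity block by block without materializing the full weights array, which is why attention weights are not returned.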

forward_gpu(*args, **kwargs) -> AttentionOutput

GPU forward pass. Delegates to the CUDA-specific implementation, forward_cuda.

Parameters
  • *args – Positional arguments for the attention calculation.

  • **kwargs – Keyword arguments for the attention calculation.

Returns

An AttentionOutput object containing the attention results.

forward_native(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, causal: bool = False, **ignore) -> AttentionOutput

Native (CPU) forward pass for Flash Attention. Not implemented.

Raises

NotImplementedError – Flash Attention is not supported on CPU via this implementation.

forward_rocm(*args, **kwargs) -> AttentionOutput

ROCm GPU forward pass. Not implemented; delegates to forward_native, which raises an error.

Raises

NotImplementedError – Via forward_native. ROCm support requires a specific kernel.

forward_tpu(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, causal: bool = False, **ignore) -> AttentionOutput

Performs Flash Attention V2 on TPU using jax.experimental.pallas.ops.tpu.flash_attention.

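Handles the optional mask and bias, KV head repetition, and sharding based on metadata. Note: the Pallas implementation expects inputs in BHTD layout (batch, num_heads, seq_len, head_dim), whereas the parameters below are documented in BTHD layout (batch, seq_len, num_heads, head_dim).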

Parameters
  • q – Query tensor, expected shape (batch, q_seq_len, num_q_heads, head_dim).

  • k – Key tensor, expected shape (batch, kv_seq_len, num_kv_heads, head_dim).

  • v – Value tensor, expected shape (batch, kv_seq_len, num_kv_heads, head_dim).

  • mask – Optional boolean attention mask. If provided, an additive bias is generated from it. Shape typically (batch, 1, q_seq_len, kv_seq_len) or broadcastable.

  • bias – Optional attention bias tensor. Added directly to attention scores. Shape typically (batch, num_heads, q_seq_len, kv_seq_len) or broadcastable. Takes precedence over mask.

  • init_bias – Optional callable function to initialize bias if mask and bias are None.

  • causal – If True, applies a causal mask. Ignored if q_seq_len is 1 (generation).

  • **ignore – Ignored keyword arguments.

Returns

An AttentionOutput object containing the attention outputs. Attention weights are typically not computed or returned by Flash Attention implementations.
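The input preparation implied by the TPU path (layout change, bias generated from a mask, causal flag ignored for single-token queries) might look like the sketch below. All three helper names are hypothetical, and using the dtype's minimum value as the masked bias is an assumption; the actual implementation may differ.

```python
import jax.numpy as jnp


def bthd_to_bhtd(x):
    """(batch, seq_len, num_heads, head_dim) -> (batch, num_heads, seq_len, head_dim)."""
    return jnp.transpose(x, (0, 2, 1, 3))


def mask_to_bias(mask, dtype=jnp.float32):
    """Boolean mask -> additive bias: 0 where attended, very negative elsewhere."""
    return jnp.where(mask, 0.0, jnp.finfo(dtype).min).astype(dtype)


def effective_causal(causal, q_seq_len):
    """Causal masking is skipped for single-token queries (generation)."""
    return bool(causal) and q_seq_len != 1
```

Because swapping axes 1 and 2 is its own inverse, applying bthd_to_bhtd to the kernel output restores the BTHD layout of the inputs.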

get_impl_metadata() -> AttentionMetadata

Returns the metadata associated with this attention implementation instance.

Returns

The AttentionMetadata provided during initialization.

classmethod get_impl_name() -> Union[str, Tuple[str]]

Returns the registered name of this attention implementation.

Returns

The string “flash_attn2”.