easydel.layers.operations.modules.flash_attention

Flash Attention V2 implementation for EasyDeL.

This module provides optimized Flash Attention V2 implementations for TPU and GPU backends, using JAX’s Pallas operations and Triton kernels, respectively. Flash Attention is an exact, memory-efficient attention algorithm: it computes attention block by block and never materializes the full seq_len_q × seq_len_k score matrix, so the extra memory spent on attention scores grows as O(N) instead of O(N²) in the sequence length N.
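A quick back-of-the-envelope calculation shows why avoiding the full score matrix matters. The sketch below is plain arithmetic under illustrative assumptions (batch 1, 32 heads, 2-byte float16 scores, and one float32 running max plus one float32 running sum per query row for the blockwise kernel). It ignores the Q, K, V, and output tensors, which every attention variant has to store anyway.

```python
def score_matrix_bytes(batch, heads, seq_len, bytes_per_elem=2):
    """Bytes for a fully materialized [batch, heads, seq_len, seq_len] score matrix."""
    return batch * heads * seq_len * seq_len * bytes_per_elem


def online_softmax_state_bytes(batch, heads, seq_len, bytes_per_elem=4):
    """Bytes for the per-query-row running max and running sum that a
    blockwise (online-softmax) kernel keeps instead: two values per row."""
    return batch * heads * seq_len * 2 * bytes_per_elem


for n in (2_048, 8_192, 32_768):
    full = score_matrix_bytes(1, 32, n)
    state = online_softmax_state_bytes(1, 32, n)
    print(f"N={n:>6}: full scores {full / 2**20:10.1f} MiB, "
          f"softmax state {state / 2**20:6.2f} MiB")
```

Quadrupling N multiplies the score-matrix term by 16 but the running-statistics term only by 4; at N = 32,768 the score matrix alone would need 64 GiB under these assumptions.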

Key Features:
  • TPU implementation using JAX Pallas operations
  • GPU implementation using Triton kernels
  • Support for causal masking
  • Support for attention bias
  • Efficient handling of multi-query and grouped-query attention
  • Automatic sharding for distributed computation

The implementation follows the Flash Attention V2 algorithm, which:
  1. Processes attention in blocks to minimize HBM (high-bandwidth memory) accesses
  2. Uses an online softmax computation to avoid storing the full attention matrix
  3. Achieves significant speedups and memory savings compared to standard attention
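Steps 1 and 2 can be sketched in a few lines of JAX. This is an illustration of the online-softmax idea, not the Pallas or Triton kernel this module uses: it handles a single head without masking, loops only over key/value blocks (a real kernel also tiles queries and keeps tiles in on-chip memory), and assumes seq_k is divisible by block_k. The function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def blockwise_attention(q, k, v, block_k=128):
    """Single-head attention over key/value blocks with an online softmax.

    q: [seq_q, d]; k, v: [seq_k, d]. Only one [seq_q, block_k] score tile
    exists at a time; the full [seq_q, seq_k] matrix is never built.
    """
    seq_q, d = q.shape
    scale = 1.0 / (d ** 0.5)
    num_blocks = k.shape[0] // block_k
    k_blocks = k.reshape(num_blocks, block_k, d)
    v_blocks = v.reshape(num_blocks, block_k, d)

    def step(carry, kv):
        acc, row_max, row_sum = carry
        k_blk, v_blk = kv
        s = (q @ k_blk.T) * scale                       # score tile for this block
        new_max = jnp.maximum(row_max, s.max(axis=-1))  # running row maximum
        correction = jnp.exp(row_max - new_max)         # rescale earlier partial results
        p = jnp.exp(s - new_max[:, None])               # unnormalized probabilities
        row_sum = row_sum * correction + p.sum(axis=-1)
        acc = acc * correction[:, None] + p @ v_blk
        return (acc, new_max, row_sum), None

    init = (
        jnp.zeros((seq_q, d), q.dtype),
        jnp.full((seq_q,), -jnp.inf, q.dtype),
        jnp.zeros((seq_q,), q.dtype),
    )
    (acc, _, row_sum), _ = jax.lax.scan(step, init, (k_blocks, v_blocks))
    return acc / row_sum[:, None]  # normalize once, after the last block


def naive_attention(q, k, v):
    """Reference that materializes the full score matrix."""
    s = (q @ k.T) / (q.shape[-1] ** 0.5)
    return jax.nn.softmax(s, axis=-1) @ v


key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(key_q, (64, 32))
k = jax.random.normal(key_k, (256, 32))
v = jax.random.normal(key_v, (256, 32))
out = blockwise_attention(q, k, v, block_k=64)
print(jnp.max(jnp.abs(out - naive_attention(q, k, v))))
```

The printed difference should be at float32 rounding level, since the online softmax is an exact rewrite of the standard softmax rather than an approximation.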

Note: This implementation does not support CPU execution as Flash Attention relies on specialized hardware features available only on TPUs and GPUs.

Example

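>>> import math
>>> import jax
>>> import jax.numpy as jnp
>>> from easydel.layers.attention_operator import OperationMetadata
>>> from easydel.layers.operations.modules.flash_attention import FlashAttn
>>>
>>> batch, seq_len, num_heads, head_dim = 1, 2048, 8, 128
>>> rng_q, rng_k, rng_v = jax.random.split(jax.random.PRNGKey(0), 3)
>>> shape = (batch, seq_len, num_heads, head_dim)
>>> query = jax.random.normal(rng_q, shape, dtype=jnp.float16)
>>> key = jax.random.normal(rng_k, shape, dtype=jnp.float16)
>>> value = jax.random.normal(rng_v, shape, dtype=jnp.float16)
>>>
>>> metadata = OperationMetadata(
...     runtime_dtype=jnp.float16,
...     softmax_scale=1.0 / math.sqrt(head_dim),
...     blocksize_q=512,
...     blocksize_k=1024,
... )
>>> flash_attn = FlashAttn(metadata)
>>> output = flash_attn(query, key, value, causal=True)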
class easydel.layers.operations.modules.flash_attention.FlashAttn(metadata: OperationMetadata)

Bases: OperationImpl

An implementation of Flash Attention V2 using specialized JAX primitives.

This class leverages jax.experimental.pallas.ops.tpu.flash_attention for TPUs and a Triton kernel (triton_flash_attention) for GPUs (CUDA). It is registered under the name “flash_attn2”. CPU execution is not supported and will raise an error.

forward_cpu(*args, **kwargs) → AttentionOutput

CPU forward pass. Delegates to forward_native, which raises an error.

Raises

NotImplementedError – Via forward_native.

forward_cuda(*args, **kwargs) → AttentionOutput

GPU forward pass. Delegates to the CUDA-specific implementation.

Parameters
  • *args – Positional arguments for the attention calculation.

  • **kwargs – Keyword arguments for the attention calculation.

Returns

An AttentionOutput object containing the attention results.

forward_gpu(*args, **kwargs) → AttentionOutput

GPU forward pass. Delegates to the CUDA-specific implementation.

Parameters
  • *args – Positional arguments for the attention calculation.

  • **kwargs – Keyword arguments for the attention calculation.

Returns

An AttentionOutput object containing the attention results.

forward_native(query: Float[Array, 'batch seq_len_q num_heads head_dim'], key: Float[Array, 'batch seq_len_k num_kv_heads head_dim'], value: Float[Array, 'batch seq_len_k num_kv_heads head_dim'], mask_info: ejkernel.types.mask.MaskInfo | None = None, bias: jaxtyping.Float[Array, 'batch num_heads seq_len_q seq_len_k'] | None = None, softmax_scale: float | None = None, dropout_prob: float = 0.0, causal: bool = False, dropout_seed: int | None = None, cum_seqlens_q: jaxtyping.Int[Array, 'batch_plus_one'] | None = None, cum_seqlens_k: jaxtyping.Int[Array, 'batch_plus_one'] | None = None, sliding_window: int | tuple[int, int] | None = None, logits_soft_cap: float | None = None, softmax_aux: jaxtyping.Float[Array, 'num_heads num_sinks'] | jaxtyping.Float[Array, 'num_sinks'] | None = None, normalize_output: bool = True, precision: Union[None, str, Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision], DotAlgorithm, DotAlgorithmPreset] = Precision.DEFAULT, **ignore) → AttentionOutput

Performs Flash Attention V2 using optimized kernels (TPU Pallas or GPU Triton).

Flash Attention V2 is a memory-efficient attention mechanism that reduces memory usage from O(N²) to O(N) by computing attention in blocks and avoiding materialization of the full attention matrix. This implementation uses specialized kernels for different hardware backends.

Parameters
  • query – Query tensor [batch, seq_len_q, num_heads, head_dim].

  • key – Key tensor [batch, seq_len_k, num_kv_heads, head_dim].

  • value – Value tensor [batch, seq_len_k, num_kv_heads, head_dim].

  • mask_info – Optional ejkernel.types.mask.MaskInfo describing the attention mask. Used by the kernel if bias is not provided. The signature has no separate attention_mask argument; an attention_mask keyword would be absorbed by **ignore.

  • bias – Optional attention bias tensor [batch, num_heads, seq_len_q, seq_len_k]. Added to attention logits before softmax. Takes precedence over mask_info.

  • softmax_scale – Scaling factor for attention logits. Defaults to 1/sqrt(head_dim).

  • dropout_prob – Dropout probability for attention weights. Defaults to 0.0.

  • causal – If True, applies causal (autoregressive) masking. Defaults to False.

  • dropout_seed – Random seed for dropout. Optional.

  • cum_seqlens_q – Cumulative sequence lengths for queries (for variable-length sequences).

  • cum_seqlens_k – Cumulative sequence lengths for keys (for variable-length sequences).

  • sliding_window – Sliding window size for local attention. Optional.

  • logits_soft_cap – Soft capping value for attention logits. Optional.

  • softmax_aux – Auxiliary softmax tensor (e.g., for sink tokens). Optional.

  • normalize_output – Whether to normalize the output. Defaults to True.

  • precision – JAX precision setting for matmul operations.

  • q_segment_ids – Segment IDs for queries. Optional.

  • kv_segment_ids – Segment IDs for keys/values. Optional.

  • **ignore – Additional ignored keyword arguments.
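When checking a Flash Attention output, an unfused reference with the same [batch, seq_len, heads, head_dim] layout is handy. The sketch below is not EasyDeL code: it materializes the full score matrix (exactly what Flash Attention avoids) and covers only softmax_scale, bias, causal, an integer sliding_window, logits_soft_cap, and grouped-query broadcasting. Several semantics are assumptions: the cap * tanh(logits / cap) soft-capping form, applying the soft cap before the bias, the meaning of an integer window, and bottom-right alignment of the causal mask when seq_len_q differs from seq_len_k.

```python
import jax
import jax.numpy as jnp


def reference_attention(query, key, value, bias=None, softmax_scale=None,
                        causal=False, sliding_window=None, logits_soft_cap=None):
    """Unfused attention, same layout as forward_native.

    query: [batch, seq_q, num_heads, head_dim]
    key, value: [batch, seq_k, num_kv_heads, head_dim], num_heads % num_kv_heads == 0
    bias: optional [batch, num_heads, seq_q, seq_k]
    """
    num_heads, head_dim = query.shape[2], query.shape[3]
    if softmax_scale is None:
        softmax_scale = head_dim ** -0.5  # documented default: 1/sqrt(head_dim)

    # Grouped-/multi-query attention: repeat each KV head for its group of query heads.
    group = num_heads // key.shape[2]
    key = jnp.repeat(key, group, axis=2)
    value = jnp.repeat(value, group, axis=2)

    logits = jnp.einsum("bqhd,bkhd->bhqk", query, key) * softmax_scale
    if logits_soft_cap is not None:
        # Assumed soft-capping form: cap * tanh(logits / cap).
        logits = logits_soft_cap * jnp.tanh(logits / logits_soft_cap)
    if bias is not None:
        logits = logits + bias

    seq_q, seq_k = query.shape[1], key.shape[1]
    # Assumption: the last query position lines up with the last key position.
    q_pos = jnp.arange(seq_q)[:, None] + (seq_k - seq_q)
    k_pos = jnp.arange(seq_k)[None, :]
    allowed = jnp.ones((seq_q, seq_k), dtype=bool)
    if causal:
        allowed = allowed & (k_pos <= q_pos)
    if sliding_window is not None:
        # Assumption: an int window W keeps keys with q_pos - k_pos < W.
        allowed = allowed & (q_pos - k_pos < sliding_window)
    logits = jnp.where(allowed, logits, jnp.finfo(logits.dtype).min)

    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", weights, value)


# Usage: 8 query heads sharing 2 KV heads (grouped-query attention).
rng_q, rng_k, rng_v = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(rng_q, (2, 16, 8, 32))
k = jax.random.normal(rng_k, (2, 16, 2, 32))
v = jax.random.normal(rng_v, (2, 16, 2, 32))
out = reference_attention(q, k, v, causal=True)
print(out.shape)
```

Because it builds the full [batch, num_heads, seq_len_q, seq_len_k] logits, this reference is only practical for short sequences, which is also where comparing it against a Flash Attention result is cheap.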

Returns

Object containing attention outputs [batch, seq_len_q, num_heads, head_dim].

Attention weights are not computed for efficiency.

Return type

AttentionOutput

forward_rocm(*args, **kwargs) → AttentionOutput

ROCm GPU forward pass.

Currently delegates to the standard GPU implementation. Future versions may include ROCm-specific optimizations using hipFlashAttention or similar.

Parameters
  • *args – Positional arguments for the attention calculation.

  • **kwargs – Keyword arguments for the attention calculation.

Returns

An AttentionOutput object containing the attention results.

forward_tpu(*args, **kwargs) → AttentionOutput

TPU forward pass. Runs the Flash Attention kernel based on jax.experimental.pallas.ops.tpu.flash_attention.

Parameters
  • *args – Positional arguments for the attention calculation.

  • **kwargs – Keyword arguments for the attention calculation.

Returns

An AttentionOutput object containing the attention results.

get_impl_metadata() → OperationMetadata

Returns the metadata associated with this attention implementation instance.

Returns

The OperationMetadata provided during initialization.

classmethod get_impl_name() → str | tuple[str]

Returns the registered name of this attention implementation.

Returns

The string “flash_attn2”.