easydel.layers.operations.modules.ring_attention
Ring Attention implementation for distributed and ultra-long sequence processing.
This module implements Ring Attention, a technique for computing attention over extremely long sequences that exceed single-device memory capacity. Ring Attention partitions sequences across devices and uses a ring communication pattern to compute exact attention without approximation.
Key concepts:
- Ring Topology: Devices arranged in a ring, each holding a chunk of the sequence
- Blockwise Processing: Attention computed in blocks to minimize memory usage
- Communication Pattern: Each device passes its KV chunks around the ring
- Exact Computation: Produces the same result as standard attention
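The concepts above can be illustrated with a small single-host simulation. This is a minimal sketch, not the module's implementation: it uses NumPy in place of JAX, handles one head without masking, and replaces the device-to-device transfer with a list lookup. The function names `reference_attention` and `simulated_ring_attention` are hypothetical. Each simulated device keeps its query chunk fixed, sees one KV chunk per ring step, and merges the partial results with a running softmax so the final output matches dense attention.

```python
import numpy as np


def reference_attention(q, k, v):
    """Dense softmax attention for one head: q is (T, D), k and v are (S, D)."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    return (weights / weights.sum(axis=-1, keepdims=True)) @ v


def simulated_ring_attention(q, k, v, num_devices):
    """Simulate non-causal ring attention; lengths must divide by num_devices."""
    q_chunks = np.split(q, num_devices)
    kv_chunks = list(zip(np.split(k, num_devices), np.split(v, num_devices)))
    scale = 1.0 / np.sqrt(q.shape[-1])
    outputs = []
    for dev, q_blk in enumerate(q_chunks):
        rows = q_blk.shape[0]
        running_max = np.full((rows, 1), -np.inf)  # largest score seen so far
        denom = np.zeros((rows, 1))                # running softmax denominator
        numer = np.zeros((rows, v.shape[-1]))      # running weighted sum of values
        for step in range(num_devices):
            # After `step` hops, device `dev` holds the KV chunk that
            # started on device (dev - step) mod num_devices.
            k_blk, v_blk = kv_chunks[(dev - step) % num_devices]
            scores = (q_blk @ k_blk.T) * scale
            new_max = np.maximum(running_max, scores.max(axis=-1, keepdims=True))
            correction = np.exp(running_max - new_max)  # rescale old partial sums
            probs = np.exp(scores - new_max)
            denom = denom * correction + probs.sum(axis=-1, keepdims=True)
            numer = numer * correction + probs @ v_blk
            running_max = new_max
        outputs.append(numer / denom)
    return np.concatenate(outputs)


rng = np.random.default_rng(0)
q, k, v = rng.standard_normal((3, 16, 8))
ring_out = simulated_ring_attention(q, k, v, num_devices=4)
dense_out = reference_attention(q, k, v)
```

In this sketch, `ring_out` and `dense_out` agree to floating-point tolerance, which is the sense in which the computation is exact rather than approximate.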
The implementation provides:
1. Native JAX scan-based version for all backends
2. TPU-optimized Pallas kernel for maximum performance
3. Support for sequences longer than 100K tokens
4. Per-device memory usage of O(N/P), where N = sequence length and P = number of devices
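To make the O(N/P) figure concrete, the following back-of-the-envelope sketch counts only the bytes of the key and value chunks resident on one device, for one layer and batch size 1. The helper `kv_bytes_per_device` and the model dimensions are illustrative assumptions, not values taken from the library.

```python
def kv_bytes_per_device(seq_len, num_devices, num_kv_heads, head_dim, bytes_per_element=2):
    """Bytes for one device's key and value chunks (one layer, batch size 1)."""
    tokens_per_device = seq_len // num_devices
    # Factor 2 accounts for storing both keys and values.
    return 2 * tokens_per_device * num_kv_heads * head_dim * bytes_per_element


# Assumed example: 131072 tokens, 8 KV heads of dimension 128, bfloat16 (2 bytes).
single_device = kv_bytes_per_device(131072, 1, num_kv_heads=8, head_dim=128)
eight_devices = kv_bytes_per_device(131072, 8, num_kv_heads=8, head_dim=128)
print(single_device // 2**20, "MiB on one device")        # 512 MiB
print(eight_devices // 2**20, "MiB per device with P=8")  # 64 MiB
```

Activations and attention scratch space are ignored here; the point is only that the resident KV footprint shrinks linearly with the number of devices.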
Ring Attention is ideal for:
- Training on very long documents or books
- Processing entire codebases as context
- Multi-document reasoning tasks
- Any scenario requiring exact attention over long sequences
Example
>>> import jax.numpy as jnp
>>> from easydel.layers.attention_operator import OperationMetadata
>>> from easydel.layers.operations.modules.ring_attention import RingAttn
>>>
>>> # Configure for distributed execution
>>> metadata = OperationMetadata(
...     runtime_dtype=jnp.bfloat16,
...     blocksize_q=512,  # Query block size
...     blocksize_k=1024,  # Key block size
...     sequence_axis_name="sp",  # Sequence parallel axis
...     scan_ring_attention=True,
... )
>>> ring_attn = RingAttn(metadata)
>>>
>>> # Process ultra-long sequences
>>> output = ring_attn(query, key, value, causal=True)
References
Liu et al., “Ring Attention with Blockwise Transformers for Near-Infinite Context” (2023)
- class easydel.layers.operations.modules.ring_attention.RingAttn(metadata: OperationMetadata)
Bases: OperationImpl
Ring attention implementation for distributed and memory-efficient processing.
Ring attention computes attention over a ring topology: each device holds one chunk of the sequence and exchanges key/value chunks with its neighbors until every chunk has been seen. This is particularly useful for sequences that do not fit in a single device's memory.
- Features:
  - Memory-efficient chunked processing
  - Distributed computation across devices
  - Support for sequences > 100K tokens
  - Native JAX scan-based implementation
  - TPU-optimized Pallas kernels
Registered name: “ring”
- metadata
OperationMetadata configuration
- forward_cpu(*args, **kwargs) → AttentionOutput
CPU forward pass. Delegates to forward_native (scan-based).
- Parameters
*args – Positional arguments for the attention calculation.
**kwargs – Keyword arguments for the attention calculation.
- Returns
An AttentionOutput object containing the attention results.
- forward_cuda(*args, **kwargs) → AttentionOutput
CUDA GPU forward pass. Currently delegates to forward_native (scan-based).
Future versions may include CUDA-specific ring attention kernels.
- Parameters
*args – Positional arguments for the attention calculation.
**kwargs – Keyword arguments for the attention calculation.
- Returns
An AttentionOutput object containing the attention results.
- forward_gpu(*args, **kwargs) → AttentionOutput
GPU forward pass. Currently delegates to forward_native (scan-based).
Future versions may include GPU-specific ring attention kernels.
- Parameters
*args – Positional arguments for the attention calculation.
**kwargs – Keyword arguments for the attention calculation.
- Returns
An AttentionOutput object containing the attention results.
- forward_native(query: Float[Array, 'batch seq_len_q num_heads head_dim'], key: Float[Array, 'batch seq_len_k num_kv_heads head_dim'], value: Float[Array, 'batch seq_len_k num_kv_heads head_dim'], softmax_aux: jaxtyping.Float[Array, 'num_kv_heads num_sinks'] | jaxtyping.Float[Array, 'num_sinks'] | None = None, mask_info: ejkernel.types.mask.MaskInfo | None = None, logits_soft_cap: float | None = None, softmax_scale: float | None = None, sliding_window: int | tuple[int, int] | None = None, causal: bool = True, fused_backward: bool = False, **ignore) → AttentionOutput
Computes attention using the scan-based blockwise_attn function.
Handles optional mask/bias, KV head repetition, and sharding constraints.
- Parameters
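query – Query tensor of shape (batch, seq_len_q, num_heads, head_dim).
key – Key tensor of shape (batch, seq_len_k, num_kv_heads, head_dim).
value – Value tensor of shape (batch, seq_len_k, num_kv_heads, head_dim).
softmax_aux – Optional auxiliary softmax tensor of shape (num_kv_heads, num_sinks) or (num_sinks,).
mask_info – Optional ejkernel.types.mask.MaskInfo describing the attention mask.
logits_soft_cap – Optional soft cap applied to the attention logits.
softmax_scale – Optional scaling factor for the attention logits.
sliding_window – Optional sliding-window size, given as an int or a pair of ints.
causal – Apply causal mask if True (default: True).
fused_backward – Boolean flag, presumably selecting a fused backward pass (default: False).
**ignore – Ignored keyword arguments.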
- Returns
AttentionOutput containing the attention result.
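As a point of comparison for the tensor layout and the KV head repetition mentioned above, here is a dense reference sketch in NumPy (the calls used have the same names in `jax.numpy`). It is not the library's `blockwise_attn`: the function name `reference_gqa_attention` is hypothetical, the default scale of `head_dim ** -0.5` and the tanh form of the soft cap are assumptions based on common practice, and sharding, `softmax_aux`, `mask_info`, and `sliding_window` are left out.

```python
import numpy as np


def reference_gqa_attention(query, key, value, softmax_scale=None,
                            logits_soft_cap=None, causal=True):
    """Dense attention: query is (B, T, H, D); key and value are (B, S, H_kv, D)."""
    num_heads, num_kv_heads = query.shape[2], key.shape[2]
    repeats = num_heads // num_kv_heads
    # KV head repetition: each KV head serves `repeats` consecutive query heads.
    key = np.repeat(key, repeats, axis=2)
    value = np.repeat(value, repeats, axis=2)
    if softmax_scale is None:
        softmax_scale = query.shape[-1] ** -0.5  # assumed default
    logits = np.einsum("bthd,bshd->bhts", query, key) * softmax_scale
    if logits_soft_cap is not None:
        # Assumed soft-cap form: squash logits smoothly into (-cap, cap).
        logits = logits_soft_cap * np.tanh(logits / logits_soft_cap)
    if causal:
        t, s = logits.shape[-2:]
        allowed = np.tril(np.ones((t, s), dtype=bool))  # key j visible iff j <= i
        logits = np.where(allowed, logits, -np.inf)
    logits = logits - logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return np.einsum("bhts,bshd->bthd", weights, value)


rng = np.random.default_rng(0)
query = rng.standard_normal((2, 6, 4, 8))   # 4 query heads
key = rng.standard_normal((2, 6, 2, 8))     # 2 KV heads
value = rng.standard_normal((2, 6, 2, 8))
out = reference_gqa_attention(query, key, value, logits_soft_cap=30.0)
```

With `causal=True`, the first query position can only attend to the first key, so `out[:, 0]` equals the repeated value vectors at position 0. A dense function like this is also a convenient oracle when testing a blockwise or ring implementation on short sequences.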
- forward_rocm(*args, **kwargs) → AttentionOutput
ROCm GPU forward pass. Currently delegates to forward_native (scan-based).
Future versions may include ROCm-specific ring attention kernels.
- Parameters
*args – Positional arguments for the attention calculation.
**kwargs – Keyword arguments for the attention calculation.
- Returns
An AttentionOutput object containing the attention results.
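The delegation described for the CPU, GPU, CUDA, and ROCm entry points can be pictured with a toy class. This is only a sketch of the pattern: `DelegatingOperation` is a hypothetical stand-in, its `forward_native` merely echoes its inputs, and how the library actually selects a backend method is not shown in this page.

```python
class DelegatingOperation:
    """Hypothetical stand-in for an operation with per-backend entry points."""

    def forward_native(self, *args, **kwargs):
        # Placeholder for the scan-based computation; it only echoes its inputs.
        return {"impl": "native", "args": args, "kwargs": kwargs}

    def forward_cpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_gpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_cuda(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_rocm(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)


op = DelegatingOperation()
result = op.forward_rocm("q", "k", "v", causal=True)
print(result["impl"])  # native
```

Keeping one entry point per backend means a backend-specific kernel can later replace a single method without changing callers.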
- forward_tpu(*args, **kwargs) → AttentionOutput
TPU-specific implementation of the operation.
Defaults to calling forward_native. Subclasses can override this for TPU-specific optimizations.
- Parameters
*args – Positional arguments for the operation.
**kwargs – Keyword arguments for the operation.
- Returns
The result of the operation, potentially optimized for TPU.
- get_impl_metadata() → OperationMetadata
Get the metadata configuration for this attention instance.
- Returns
Configuration including dtype, mesh, etc.
- Return type