easydel.layers.operations.modules.ring_attention

Ring Attention implementation for distributed and ultra-long sequence processing.

This module implements Ring Attention, a technique for computing attention over sequences too long to fit in a single device's memory. Ring Attention partitions the sequence across devices and uses a ring communication pattern to compute exact (not approximate) attention.

Key concepts:
  • Ring topology: devices are arranged in a ring, each holding one chunk of the sequence.

  • Blockwise processing: attention is computed block by block to minimize memory usage.

  • Communication pattern: each device passes its KV chunk to its neighbor around the ring.

  • Exact computation: the result matches standard attention (up to floating-point rounding).
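The ring pattern and blockwise online softmax described above can be illustrated with a single-process simulation in JAX. This is a hedged sketch, not EasyDeL's implementation: `dense_attention` and `simulated_ring_attention` are hypothetical helpers, the P "devices" are simply the leading axis of an array, and `jnp.roll` stands in for the neighbor-to-neighbor transfer that a multi-device version would typically perform with `jax.lax.ppermute`. Shapes are single-head (T, D) for brevity.

```python
import jax
import jax.numpy as jnp

NEG = -1e30  # large finite negative used instead of -inf to avoid NaNs


def dense_attention(q, k, v, causal=True):
    """Standard softmax attention over the full sequence; q, k, v are (T, D)."""
    logits = (q @ k.T) / jnp.sqrt(q.shape[-1])
    if causal:
        t = q.shape[0]
        logits = jnp.where(jnp.tril(jnp.ones((t, t), dtype=bool)), logits, -jnp.inf)
    return jax.nn.softmax(logits, axis=-1) @ v


def simulated_ring_attention(q, k, v, num_devices, causal=True):
    """Hypothetical: simulate ring attention with `num_devices` sequence chunks."""
    t, d = q.shape
    c = t // num_devices  # tokens held by each simulated device (N / P)
    scale = 1.0 / jnp.sqrt(d)
    qs, ks, vs = (x.reshape(num_devices, c, d) for x in (q, k, v))
    q_pos = jnp.arange(t).reshape(num_devices, c)  # global query positions
    kv_pos = q_pos  # each device starts with the KV chunk matching its queries

    m = jnp.full((num_devices, c), NEG)   # running max of the logits
    l = jnp.zeros((num_devices, c))       # running softmax denominator
    acc = jnp.zeros((num_devices, c, d))  # running unnormalized output

    for _ in range(num_devices):
        logits = jnp.einsum("pqd,pkd->pqk", qs, ks) * scale
        if causal:
            visible = q_pos[:, :, None] >= kv_pos[:, None, :]
        else:
            visible = jnp.ones(logits.shape, dtype=bool)
        logits = jnp.where(visible, logits, NEG)
        # Online softmax: fold this KV block into the running statistics.
        m_new = jnp.maximum(m, logits.max(axis=-1))
        p = jnp.where(visible, jnp.exp(logits - m_new[..., None]), 0.0)
        correction = jnp.exp(m - m_new)
        l = l * correction + p.sum(axis=-1)
        acc = acc * correction[..., None] + jnp.einsum("pqk,pkd->pqd", p, vs)
        m = m_new
        # Ring step: device i receives the KV chunk previously held by device i - 1.
        ks, vs, kv_pos = (jnp.roll(x, shift=1, axis=0) for x in (ks, vs, kv_pos))

    return (acc / l[..., None]).reshape(t, d)


q, k, v = jax.random.normal(jax.random.PRNGKey(0), (3, 16, 8))
out = simulated_ring_attention(q, k, v, num_devices=4)
print(jnp.max(jnp.abs(out - dense_attention(q, k, v))))  # close to zero
```

Because each device only rescales its running max, denominator, and accumulator when a new KV chunk arrives, the final result equals dense softmax attention even though each device holds only one query chunk and one KV chunk at a time.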

The implementation provides:
  1. A native JAX scan-based version for all backends

  2. A TPU-optimized Pallas kernel for maximum performance

  3. Support for sequences longer than 100K tokens

  4. Per-device memory usage of O(N/P), where N is the sequence length and P the number of devices
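As a rough illustration of the O(N/P) claim, the arithmetic below counts only one layer's K and V activations per device. The model dimensions (8 KV heads, head dimension 128), batch size 1, and 2-byte bf16 elements are assumptions chosen for illustration; queries, outputs, gradients, and the per-block logits tile are ignored, and `kv_bytes_per_device` is a hypothetical helper.

```python
def kv_bytes_per_device(seq_len, num_devices, num_kv_heads, head_dim,
                        batch=1, bytes_per_elem=2):
    """Bytes of one layer's K and V held by a single device (hypothetical helper)."""
    if seq_len % num_devices:
        raise ValueError("seq_len must be divisible by num_devices")
    chunk = seq_len // num_devices  # N / P tokens per device
    # Factor 2 accounts for storing both K and V.
    return 2 * batch * chunk * num_kv_heads * head_dim * bytes_per_elem


MiB = 2**20
full = kv_bytes_per_device(131_072, 1, num_kv_heads=8, head_dim=128)
ring = kv_bytes_per_device(131_072, 8, num_kv_heads=8, head_dim=128)
print(full // MiB, ring // MiB)  # 512 64
```

Doubling P halves the per-device figure, which is the sense in which memory scales as O(N/P).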

Ring Attention is well suited to:
  • Training on very long documents or books

  • Processing entire codebases as context

  • Multi-document reasoning tasks

  • Any scenario requiring exact attention over long sequences

Example

>>> import jax.numpy as jnp
>>> from easydel.layers.operations._operation_meta import OperationMetadata
>>> from easydel.layers.operations.modules.ring_attention import RingAttn
>>>
>>> # Configure for distributed execution
>>> metadata = OperationMetadata(
...     runtime_dtype=jnp.bfloat16,
...     blocksize_q=512,  # Query block size
...     blocksize_k=1024,  # Key block size
...     sequence_axis_name="sp",  # Sequence parallel axis
...     scan_ring_attention=True
... )
>>> ring_attn = RingAttn(metadata)
>>>
>>> # Process ultra-long sequences
>>> output = ring_attn(query, key, value, causal=True)


class easydel.layers.operations.modules.ring_attention.RingAttn(metadata: OperationMetadata)

Bases: OperationImpl

Ring attention implementation for distributed and memory-efficient processing.

Ring attention computes attention in a ring topology: each device holds one chunk of the sequence and exchanges KV chunks with its neighbors, so attention can be computed over sequences that do not fit in a single device's memory.

Features:
  • Memory-efficient chunked processing

  • Distributed computation across devices

  • Support for sequences > 100K tokens

  • Native JAX scan-based implementation

  • TPU-optimized Pallas kernels

Registered name: “ring”

metadata

OperationMetadata configuration

Type

easydel.layers.operations._operation_meta.OperationMetadata | None

forward_cpu(*args, **kwargs) → AttentionOutput

CPU forward pass. Delegates to forward_native (scan-based).

Parameters
  • *args – Positional arguments for the attention calculation.

  • **kwargs – Keyword arguments for the attention calculation.

Returns

An AttentionOutput object containing the attention results.

forward_cuda(*args, **kwargs) → AttentionOutput

CUDA GPU forward pass. Currently delegates to forward_native (scan-based).

Future versions may include CUDA-specific ring attention kernels.

Parameters
  • *args – Positional arguments for the attention calculation.

  • **kwargs – Keyword arguments for the attention calculation.

Returns

An AttentionOutput object containing the attention results.

forward_gpu(*args, **kwargs) → AttentionOutput

GPU forward pass. Currently delegates to forward_native (scan-based).

Future versions may include GPU-specific ring attention kernels.

Parameters
  • *args – Positional arguments for the attention calculation.

  • **kwargs – Keyword arguments for the attention calculation.

Returns

An AttentionOutput object containing the attention results.

forward_native(query: Float[Array, 'batch seq_len_q num_heads head_dim'], key: Float[Array, 'batch seq_len_k num_kv_heads head_dim'], value: Float[Array, 'batch seq_len_k num_kv_heads head_dim'], softmax_aux: jaxtyping.Float[Array, 'num_kv_heads num_sinks'] | jaxtyping.Float[Array, 'num_sinks'] | None = None, mask_info: ejkernel.types.mask.MaskInfo | None = None, logits_soft_cap: float | None = None, softmax_scale: float | None = None, sliding_window: int | tuple[int, int] | None = None, causal: bool = True, fused_backward: bool = False, **ignore) → AttentionOutput

Computes attention using the scan-based blockwise_attn function.

Handles the optional attention mask, repetition of KV heads when num_kv_heads is smaller than num_heads, and sharding constraints.
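KV head repetition for grouped-query attention can be sketched as follows, assuming the (batch, seq, heads, head_dim) layout from the signature above. `repeat_kv_heads` is a hypothetical helper; this page does not say whether RingAttn materializes repeated heads like this or handles the grouping inside its kernel.

```python
import jax.numpy as jnp


def repeat_kv_heads(x, num_heads):
    """Hypothetical: expand (batch, seq, num_kv_heads, head_dim) to num_heads heads."""
    num_kv_heads = x.shape[2]
    if num_heads % num_kv_heads:
        raise ValueError("num_heads must be a multiple of num_kv_heads")
    # Each KV head is shared by num_heads // num_kv_heads consecutive query heads.
    return jnp.repeat(x, num_heads // num_kv_heads, axis=2)


k = jnp.arange(2 * 4 * 2 * 3, dtype=jnp.float32).reshape(2, 4, 2, 3)  # 2 KV heads
expanded = repeat_kv_heads(k, num_heads=8)
print(expanded.shape)  # (2, 4, 8, 3)
```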

Parameters
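  • query – Query tensor of shape (batch, seq_len_q, num_heads, head_dim).

  • key – Key tensor of shape (batch, seq_len_k, num_kv_heads, head_dim).

  • value – Value tensor of shape (batch, seq_len_k, num_kv_heads, head_dim).

  • softmax_aux – Optional auxiliary softmax logits (attention sinks), of shape (num_kv_heads, num_sinks) or (num_sinks,).

  • mask_info – Optional ejkernel MaskInfo describing the attention mask.

  • logits_soft_cap – Optional soft cap applied to the attention logits.

  • softmax_scale – Optional scale applied to the attention logits before the softmax.

  • sliding_window – Optional local attention window, given as an int or a tuple of two ints.

  • causal – Apply a causal mask if True (default True).

  • fused_backward – Whether to use a fused backward pass (default False).

  • **ignore – Ignored keyword arguments.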

Returns

AttentionOutput containing the attention result.
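Since ring attention is described as exact, a dense reference is handy for checking its output on small inputs. The sketch below is hypothetical: it assumes the conventional meanings of `softmax_scale` (defaulting to 1/sqrt(head_dim)) and `logits_soft_cap` (cap · tanh(logits / cap)), assumes seq_len_q == seq_len_k for the causal mask and that key/value already have num_heads heads, and omits `mask_info`, `sliding_window`, and `softmax_aux`. EasyDeL's exact defaults are not confirmed by this page.

```python
import jax
import jax.numpy as jnp


def dense_attention_bthd(query, key, value, softmax_scale=None,
                         logits_soft_cap=None, causal=True):
    """Hypothetical dense reference; query/key/value are (B, T, H, D) with T == S."""
    if softmax_scale is None:
        softmax_scale = query.shape[-1] ** -0.5  # assumed default: 1/sqrt(head_dim)
    logits = jnp.einsum("bthd,bshd->bhts", query, key) * softmax_scale
    if logits_soft_cap is not None:
        # Soft capping smoothly bounds the logits to (-cap, cap).
        logits = logits_soft_cap * jnp.tanh(logits / logits_soft_cap)
    if causal:
        t, s = logits.shape[-2:]
        mask = jnp.tril(jnp.ones((t, s), dtype=bool))  # query i sees keys 0..i
        logits = jnp.where(mask, logits, -jnp.inf)
    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("bhts,bshd->bthd", weights, value)


q, k, v = jax.random.normal(jax.random.PRNGKey(0), (3, 1, 16, 4, 8))
out = dense_attention_bthd(q, k, v, logits_soft_cap=30.0)
print(out.shape)  # (1, 16, 4, 8)
```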

forward_rocm(*args, **kwargs) → AttentionOutput

ROCm GPU forward pass. Currently delegates to forward_native (scan-based).

Future versions may include ROCm-specific ring attention kernels.

Parameters
  • *args – Positional arguments for the attention calculation.

  • **kwargs – Keyword arguments for the attention calculation.

Returns

An AttentionOutput object containing the attention results.

forward_tpu(*args, **kwargs) → AttentionOutput

TPU-specific implementation of the operation.

Defaults to calling forward_native. Subclasses can override this for TPU-specific optimizations.

Parameters
  • *args – Positional arguments for the operation.

  • **kwargs – Keyword arguments for the operation.

Returns

The result of the operation, potentially optimized for TPU.
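The CPU, CUDA, GPU, ROCm, and TPU entry points above all fall back to forward_native. How OperationImpl chooses among them is not documented on this page; the toy class below is a hypothetical sketch of one way per-backend dispatch with a native fallback could work, using `jax.default_backend()`.

```python
import jax


class HypotheticalRingOp:
    """Toy stand-in for per-backend dispatch (not EasyDeL code)."""

    def forward_native(self, *args, **kwargs):
        return "native"

    # Backends without a specialized kernel simply reuse the native path.
    def forward_cpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_gpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_tpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        backend = jax.default_backend()  # e.g. "cpu", "gpu" or "tpu"
        # Fall back to the native path for any backend without a forward_* method.
        method = getattr(self, f"forward_{backend}", self.forward_native)
        return method(*args, **kwargs)


print(HypotheticalRingOp()())  # native
```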

get_impl_metadata() → OperationMetadata

Get the metadata configuration for this attention instance.

Returns

Configuration including dtype, mesh, etc.

Return type

OperationMetadata

classmethod get_impl_name() → str | tuple[str]

Get the registered name for this attention implementation.

Returns

The name “ring” used for registry lookup.

Return type

str