easydel.layers.attention_operator.modules.ring

class easydel.layers.attention_operator.modules.ring.Carry(numerator, denominator, max_so_far)

Bases: NamedTuple

denominator: Array

Alias for field number 1

max_so_far: Array

Alias for field number 2

numerator: Array

Alias for field number 0
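The field names suggest that Carry holds the running state of an online (streaming) softmax. Under that reading, numerator is the unnormalized weighted sum of values, denominator is the running sum of exponentiated scores, and max_so_far is the running row-wise maximum used for numerical stability. The sketch below is written under that assumption and is not taken from EasyDeL; init_carry and update are hypothetical helpers. It folds key blocks into the carry one at a time and checks the result against a full softmax.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Carry(NamedTuple):
    # Assumed meaning of each field (field order matches the documented Carry).
    numerator: jax.Array    # sum_j exp(s_j - max) * v_j, shape (T, D)
    denominator: jax.Array  # sum_j exp(s_j - max), shape (T,)
    max_so_far: jax.Array   # running row-wise max of the scores, shape (T,)


def init_carry(t: int, d: int) -> Carry:
    """Empty state: no keys seen yet, so the running max starts at -inf."""
    return Carry(jnp.zeros((t, d)), jnp.zeros((t,)), jnp.full((t,), -jnp.inf))


def update(carry: Carry, scores: jax.Array, v_block: jax.Array) -> Carry:
    """Fold one key block (scores: (T, S_blk), v_block: (S_blk, D)) into the carry."""
    new_max = jnp.maximum(carry.max_so_far, scores.max(axis=-1))
    # Rescale the old accumulators so they are relative to the new maximum.
    correction = jnp.exp(carry.max_so_far - new_max)
    p = jnp.exp(scores - new_max[:, None])
    return Carry(
        numerator=carry.numerator * correction[:, None] + p @ v_block,
        denominator=carry.denominator * correction + p.sum(axis=-1),
        max_so_far=new_max,
    )


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (4, 8))
k = jax.random.normal(kk, (12, 8))
v = jax.random.normal(kv, (12, 8))
scores = q @ k.T / jnp.sqrt(8.0)

carry = init_carry(4, 8)
for start in range(0, 12, 4):  # three key blocks of size 4
    carry = update(carry, scores[:, start:start + 4], v[start:start + 4])

out = carry.numerator / carry.denominator[:, None]
ref = jax.nn.softmax(scores, axis=-1) @ v  # full-softmax reference
```

Only the final division by the denominator normalizes the result. Every intermediate step needs just one block of scores.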

class easydel.layers.attention_operator.modules.ring.RingAttn(metadata: AttentionMetadata)

Bases: AttentionImpl

Attention implementation based on either the ring-passing algorithm or a blockwise scan.

This implementation supports:
  • Native (scan-based) blockwise attention via blockwise_attn.

  • TPU-specific ring attention via the pallas_ring_attention kernel.

It is registered under the name “ring”.
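The ring-passing idea can be simulated in a single process. Each of n_dev simulated devices keeps its query block fixed, attends to whichever key/value block it currently holds, and merges the partial result into a running state. It then passes its key/value block to its neighbour, so after n_dev steps every query block has seen every key/value block. This is a hypothetical, non-causal sketch: simulated_ring_attention and merge are illustrative names, not EasyDeL functions. On real hardware the jnp.roll step would presumably be a collective such as jax.lax.ppermute over a mesh axis. That is an assumption about how such kernels are commonly structured, not a description of pallas_ring_attention.

```python
import jax
import jax.numpy as jnp


def merge(carry, q_blk, k_blk, v_blk):
    """Attend one query block to one key/value block and fold it into the running state."""
    num, den, m = carry
    s = q_blk @ k_blk.T / jnp.sqrt(q_blk.shape[-1])
    new_m = jnp.maximum(m, s.max(axis=-1))
    corr = jnp.exp(m - new_m)  # rescale earlier partial sums to the new max
    p = jnp.exp(s - new_m[:, None])
    return num * corr[:, None] + p @ v_blk, den * corr + p.sum(axis=-1), new_m


def simulated_ring_attention(q, k, v, n_dev):
    """Non-causal attention for (T, D) inputs, simulating n_dev devices in a ring."""
    t, d = q.shape
    blk = t // n_dev  # assumes T == S and T divisible by n_dev
    qs, ks, vs = (x.reshape(n_dev, blk, d) for x in (q, k, v))
    carry = (jnp.zeros((n_dev, blk, d)),
             jnp.zeros((n_dev, blk)),
             jnp.full((n_dev, blk), -jnp.inf))
    step = jax.vmap(merge)  # leading axis = simulated device
    for _ in range(n_dev):
        carry = step(carry, qs, ks, vs)
        # Each device hands its K/V block to the next one; across real devices
        # this would be a collective permute rather than jnp.roll.
        ks, vs = jnp.roll(ks, 1, axis=0), jnp.roll(vs, 1, axis=0)
    num, den, _ = carry
    return (num / den[..., None]).reshape(t, d)


kq, kk, kv = jax.random.split(jax.random.PRNGKey(1), 3)
q, k, v = (jax.random.normal(key_, (16, 8)) for key_ in (kq, kk, kv))
out = simulated_ring_attention(q, k, v, n_dev=4)
ref = jax.nn.softmax(q @ k.T / jnp.sqrt(8.0), axis=-1) @ v
```

Because queries never move, each device only needs memory for its own query block plus one key/value block at a time.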

forward_cpu(*args, **kwargs) → AttentionOutput

CPU forward pass. Delegates to forward_native (scan-based).

forward_cuda(*args, **kwargs) → AttentionOutput

CUDA GPU forward pass. Currently delegates to forward_native (scan-based).

forward_gpu(*args, **kwargs) → AttentionOutput

GPU forward pass. Currently delegates to forward_native (scan-based).

forward_native(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, deterministic: bool = False, dropout_rng: Optional[PRNGKey] = None, causal: bool = True, **ignore) → AttentionOutput

Computes attention using the scan-based blockwise_attn function.

Handles optional mask/bias, KV head repetition, and sharding constraints.

Parameters
  • q – Query tensor (B, T, H, D).

  • k – Key tensor (B, S, H_kv, D).

  • v – Value tensor (B, S, H_kv, D).

  • mask – Optional boolean attention mask (broadcastable to B, 1, T, S).

  • bias – Optional attention bias (broadcastable to B, H, T, S).

  • init_bias – Optional callable to initialize bias if mask/bias are None.

  • deterministic – If False, enables dropout. Requires dropout_rng.

  • dropout_rng – JAX PRNG key for dropout if deterministic is False.

  • causal – Apply causal mask if True.

  • **ignore – Ignored keyword arguments.

Returns

AttentionOutput containing the attention result.
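Two of the preprocessing steps mentioned for forward_native, KV head repetition and mask handling, can be sketched as follows. The sketch assumes grouped-query attention where each of the H_kv key/value heads is shared by H // H_kv adjacent query heads. It also assumes a boolean mask where True means "may attend". The helpers repeat_kv_heads and mask_to_bias are hypothetical and only illustrate the shapes listed in the parameter descriptions. EasyDeL's actual head-grouping order and masking value may differ.

```python
import jax.numpy as jnp


def repeat_kv_heads(x, num_q_heads):
    """(B, S, H_kv, D) -> (B, S, H, D): each KV head serves H // H_kv adjacent query heads."""
    h_kv = x.shape[2]
    if num_q_heads % h_kv != 0:
        raise ValueError("H must be a multiple of H_kv")
    return jnp.repeat(x, num_q_heads // h_kv, axis=2)


def mask_to_bias(mask, dtype=jnp.float32):
    """Boolean mask (True = may attend) -> additive bias: 0 where allowed, very negative elsewhere."""
    return jnp.where(mask, jnp.zeros((), dtype), jnp.finfo(dtype).min)


b, t, s, h, h_kv, d = 2, 5, 5, 8, 2, 4
k = jnp.ones((b, s, h_kv, d))
k_rep = repeat_kv_heads(k, h)                                     # (2, 5, 8, 4)
causal_mask = jnp.tril(jnp.ones((t, s), dtype=bool))[None, None]  # (1, 1, T, S)
bias = mask_to_bias(causal_mask)                                  # broadcasts to (B, H, T, S)
```

Converting the mask to an additive bias lets masking and any user-supplied bias be applied with a single addition to the attention scores.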

forward_rocm(*args, **kwargs) → AttentionOutput

ROCm GPU forward pass. Currently delegates to forward_native (scan-based).

forward_tpu(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, deterministic: bool = False, dropout_rng: Optional[PRNGKey] = None, causal: bool = True, **ignore) → AttentionOutput

Computes Ring Attention on TPU using the pallas_ring_attention kernel.

Handles optional mask/bias, sharding, and passes configuration to the kernel.

Parameters
  • q – Query tensor (B, T, H, D).

  • k – Key tensor (B, S, H_kv, D).

  • v – Value tensor (B, S, H_kv, D).

  • mask – Optional boolean attention mask (broadcastable to B, 1, T, S).

  • bias – Optional attention bias (broadcastable to B, H, T, S).

  • init_bias – Optional callable to initialize bias if mask/bias are None.

  • deterministic – If False, dropout may be enabled within the kernel, provided the kernel supports it.

  • dropout_rng – JAX PRNG key (may be used by the kernel if dropout is enabled).

  • causal – Apply causal mask if True. Passed to the kernel.

  • **ignore – Ignored keyword arguments.

Returns

AttentionOutput containing the attention result.
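When the sequence is processed in blocks, a causal mask has to be built from global token positions rather than block-local ones. Ring attention is one such case, since key/value blocks arrive from other devices. The hypothetical helper below is not part of pallas_ring_attention's interface; it only shows the idea for equal-sized blocks. It also shows why block pairs with k_block_idx > q_block_idx contribute nothing under a causal mask.

```python
import jax.numpy as jnp


def block_causal_mask(q_block_idx, k_block_idx, block_size):
    """Causal mask for one (query block, key block) pair, using global token positions."""
    q_pos = q_block_idx * block_size + jnp.arange(block_size)
    k_pos = k_block_idx * block_size + jnp.arange(block_size)
    return q_pos[:, None] >= k_pos[None, :]  # True = key is visible to the query


n_blocks, block_size = 3, 4
# Reassemble every block pair into one (12, 12) mask.
full = jnp.concatenate(
    [jnp.concatenate([block_causal_mask(i, j, block_size) for j in range(n_blocks)], axis=1)
     for i in range(n_blocks)],
    axis=0,
)
```

Reassembling all block pairs recovers the ordinary lower-triangular mask. Diagonal pairs are triangular, pairs below the diagonal are fully visible, and pairs above it are fully masked.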

get_impl_metadata() → AttentionMetadata

Returns the metadata associated with this instance.

classmethod get_impl_name() → Union[str, Tuple[str]]

Returns the registered name: โ€œringโ€.

easydel.layers.attention_operator.modules.ring.blockwise_attn(query, key, value, bias=None, deterministic=True, dropout_rng=None, attn_pdrop=0.0, causal=True, query_chunk_size=2048, key_chunk_size=2048, dtype=<class 'jax.numpy.float32'>, policy=False, precision=None, float32_logits=True, prevent_cse=True)
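blockwise_attn has no docstring here. The sketch below illustrates the general memory-efficient technique suggested by its query_chunk_size and key_chunk_size parameters. Queries are processed chunk by chunk with jax.lax.map. For each query chunk, the key/value chunks are folded in with jax.lax.scan, using the kind of running numerator, denominator, and maximum that Carry appears to hold. This is a simplified, hypothetical stand-in, not EasyDeL's code; blockwise_attn_sketch and reference_attn are illustrative names. It ignores bias, dropout, dtype, policy, precision, float32_logits and prevent_cse. It also assumes query and key sequences of equal length that divide evenly into chunks.

```python
import jax
import jax.numpy as jnp


def blockwise_attn_sketch(q, k, v, query_chunk_size, key_chunk_size, causal=True):
    """Attention over (B, T, H, D) inputs, one (query chunk, key chunk) pair at a time."""
    b, t, h, d = q.shape
    s = k.shape[1]
    nq, nk = t // query_chunk_size, s // key_chunk_size  # assumes exact divisibility
    neg = jnp.finfo(jnp.float32).min  # finite "masked" score avoids inf - inf = NaN
    q = q / jnp.sqrt(jnp.float32(d))
    # Move the chunk index to the front: (num_chunks, B, chunk, H, D)
    qc = q.reshape(b, nq, query_chunk_size, h, d).swapaxes(0, 1)
    kc = k.reshape(b, nk, key_chunk_size, h, d).swapaxes(0, 1)
    vc = v.reshape(b, nk, key_chunk_size, h, d).swapaxes(0, 1)

    def attend_query_chunk(args):
        qi, q_blk = args  # q_blk: (B, Cq, H, D)

        def scan_body(carry, xs):
            num, den, m = carry
            ki, k_blk, v_blk = xs
            scores = jnp.einsum("bqhd,bkhd->bhqk", q_blk, k_blk)
            if causal:  # global positions keep the mask correct across chunk boundaries
                q_pos = qi * query_chunk_size + jnp.arange(query_chunk_size)
                k_pos = ki * key_chunk_size + jnp.arange(key_chunk_size)
                scores = jnp.where(q_pos[:, None] >= k_pos[None, :], scores, neg)
            new_m = jnp.maximum(m, scores.max(axis=-1))  # (B, H, Cq)
            corr = jnp.exp(m - new_m)
            p = jnp.exp(scores - new_m[..., None])      # (B, H, Cq, Ck)
            num = num * corr[..., None] + jnp.einsum("bhqk,bkhd->bhqd", p, v_blk)
            den = den * corr + p.sum(axis=-1)
            return (num, den, new_m), None

        init = (jnp.zeros((b, h, query_chunk_size, d)),
                jnp.zeros((b, h, query_chunk_size)),
                jnp.full((b, h, query_chunk_size), neg))
        (num, den, _), _ = jax.lax.scan(scan_body, init, (jnp.arange(nk), kc, vc))
        return (num / den[..., None]).transpose(0, 2, 1, 3)  # (B, Cq, H, D)

    out = jax.lax.map(attend_query_chunk, (jnp.arange(nq), qc))  # (nq, B, Cq, H, D)
    return out.swapaxes(0, 1).reshape(b, t, h, d)


def reference_attn(q, k, v, causal=True):
    """Plain attention that materializes the full (B, H, T, S) scores; used only as a check."""
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(jnp.float32(q.shape[-1]))
    if causal:
        scores = jnp.where(jnp.tril(jnp.ones(scores.shape[-2:], dtype=bool)), scores, -jnp.inf)
    return jnp.einsum("bhqk,bkhd->bqhd", jax.nn.softmax(scores, axis=-1), v)


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (2, 16, 4, 8))
k = jax.random.normal(kk, (2, 16, 4, 8))
v = jax.random.normal(kv, (2, 16, 4, 8))
out = blockwise_attn_sketch(q, k, v, query_chunk_size=4, key_chunk_size=8, causal=True)
ref = reference_attn(q, k, v, causal=True)
```

In the forward pass, only one score tile of shape (B, H, query_chunk_size, key_chunk_size) is live per iteration, instead of the full (B, H, T, S) matrix. Keeping memory this low during training would additionally require rematerialization, which the policy and prevent_cse parameters may relate to; that link is an assumption.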