easydel.layers.attention_operator.modules.ring
- class easydel.layers.attention_operator.modules.ring.Carry(numerator, denominator, max_so_far)
Bases:
NamedTuple
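The page does not say how the three fields of Carry are used. Its field names match the running state of an online (streaming) softmax, so the following is a minimal sketch under that assumption, not the library's code. The `update` rule, the helper names, and the scalar values are all illustrative.

```python
import math
from typing import List, NamedTuple


class Carry(NamedTuple):
    # Field names mirror the documented signature; their meaning here is an assumption.
    numerator: float    # running sum of exp(score - max_so_far) * value
    denominator: float  # running sum of exp(score - max_so_far)
    max_so_far: float   # largest score seen so far


def update(carry: Carry, scores: List[float], values: List[float]) -> Carry:
    """Fold one chunk of (score, value) pairs into the running carry."""
    new_max = max(carry.max_so_far, max(scores))
    # Rescale the old sums so they are expressed relative to the new maximum.
    scale = math.exp(carry.max_so_far - new_max)
    weights = [math.exp(s - new_max) for s in scores]
    numerator = carry.numerator * scale + sum(w * v for w, v in zip(weights, values))
    denominator = carry.denominator * scale + sum(weights)
    return Carry(numerator, denominator, new_max)


def chunked_softmax_average(scores: List[float], values: List[float], chunk_size: int) -> float:
    """Softmax-weighted average of values, computed one chunk at a time."""
    carry = Carry(0.0, 0.0, float("-inf"))
    for start in range(0, len(scores), chunk_size):
        stop = start + chunk_size
        carry = update(carry, scores[start:stop], values[start:stop])
    return carry.numerator / carry.denominator


def direct_softmax_average(scores: List[float], values: List[float]) -> float:
    """Reference result computed over the whole sequence at once."""
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)


scores = [0.5, 2.0, -1.0, 3.5, 0.0, 1.25]
values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
chunked = chunked_softmax_average(scores, values, chunk_size=2)
direct = direct_softmax_average(scores, values)
```

Under this assumption the chunked and direct results agree up to floating-point rounding, which is what allows attention to be accumulated block by block without materializing the full score matrix.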
- class easydel.layers.attention_operator.modules.ring.RingAttn(metadata: AttentionMetadata)
Bases:
AttentionImpl
Attention implementation using a ring-passing algorithm or a blockwise scan.
This implementation supports native (scan-based) blockwise attention via blockwise_attn and TPU-specific ring attention via the pallas_ring_attention kernel.
It is registered under the name “ring”.
- forward_cpu(*args, **kwargs) → AttentionOutput
CPU forward pass. Delegates to forward_native (scan-based).
- forward_cuda(*args, **kwargs) → AttentionOutput
CUDA GPU forward pass. Currently delegates to forward_native (scan-based).
- forward_gpu(*args, **kwargs) → AttentionOutput
GPU forward pass. Currently delegates to forward_native (scan-based).
- forward_native(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, deterministic: bool = False, dropout_rng: Optional[PRNGKey] = None, causal: bool = True, **ignore) → AttentionOutput
Computes attention using the scan-based blockwise_attn function.
Handles optional mask/bias, KV head repetition, and sharding constraints.
- Parameters
q – Query tensor (B, T, H, D).
k – Key tensor (B, S, H_kv, D).
v – Value tensor (B, S, H_kv, D).
mask – Optional boolean attention mask (broadcastable to B, 1, T, S).
bias – Optional attention bias (broadcastable to B, H, T, S).
init_bias – Optional callable to initialize bias if mask/bias are None.
deterministic – If False, enables dropout. Requires dropout_rng.
dropout_rng – JAX PRNG key for dropout if deterministic is False.
causal – Apply causal mask if True.
**ignore – Ignored keyword arguments.
- Returns
AttentionOutput containing the attention result.
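The description of forward_native mentions KV head repetition and optional mask/bias handling without showing either. The sketch below illustrates both ideas in plain Python under stated assumptions: heads are repeated consecutively (in the style of `jnp.repeat`), and a boolean mask becomes an additive bias with a large negative fill value. The helper names and the `NEG_FILL` constant are hypothetical and are not taken from the library.

```python
from typing import List

NEG_FILL = -1e9  # hypothetical fill value; the library's constant may differ


def repeat_kv_heads(kv_heads: List[str], num_q_heads: int) -> List[str]:
    """Repeat each KV head so that H_kv heads line up with H query heads."""
    num_kv_heads = len(kv_heads)
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("H must be a multiple of H_kv")
    reps = num_q_heads // num_kv_heads
    # Consecutive repetition is an assumption: kv0, kv0, kv1, kv1, ...
    return [head for head in kv_heads for _ in range(reps)]


def causal_mask(t: int, s: int) -> List[List[bool]]:
    """mask[q][k] is True when query position q may attend to key position k."""
    return [[k <= q for k in range(s)] for q in range(t)]


def mask_to_bias(mask: List[List[bool]]) -> List[List[float]]:
    """Turn a boolean mask into an additive bias: 0 where allowed, NEG_FILL elsewhere."""
    return [[0.0 if keep else NEG_FILL for keep in row] for row in mask]


expanded = repeat_kv_heads(["kv0", "kv1"], num_q_heads=4)
bias = mask_to_bias(causal_mask(3, 3))
```

Here `expanded` is `["kv0", "kv0", "kv1", "kv1"]`, and `bias` is lower-triangular zeros with `NEG_FILL` above the diagonal, so adding it to the attention scores before the softmax suppresses future positions.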
- forward_rocm(*args, **kwargs) → AttentionOutput
ROCm GPU forward pass. Currently delegates to forward_native (scan-based).
- forward_tpu(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, deterministic: bool = False, dropout_rng: Optional[PRNGKey] = None, causal: bool = True, **ignore) → AttentionOutput
Computes Ring Attention on TPU using the pallas_ring_attention kernel.
Handles optional mask/bias, sharding, and passes configuration to the kernel.
- Parameters
q – Query tensor (B, T, H, D).
k – Key tensor (B, S, H_kv, D).
v – Value tensor (B, S, H_kv, D).
mask – Optional boolean attention mask (broadcastable to B, 1, T, S).
bias – Optional attention bias (broadcastable to B, H, T, S).
init_bias – Optional callable to initialize bias if mask/bias are None.
deterministic – If False, potentially enables dropout within the kernel (if supported).
dropout_rng – JAX PRNG key (may be used by the kernel if dropout is enabled).
causal – Apply causal mask if True. Passed to the kernel.
**ignore – Ignored keyword arguments.
- Returns
AttentionOutput containing the attention result.
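The ring-passing idea behind forward_tpu can be pictured as a rotation schedule: each device keeps its own query block while key/value blocks travel around the ring. The simulation below is a rough sketch of that schedule only. It does not call pallas_ring_attention, and the passing direction and the `ring_schedule` helper are assumptions made for illustration.

```python
from typing import List


def ring_schedule(num_devices: int) -> List[List[int]]:
    """Return schedule[step][device] = index of the KV block held at that step."""
    held = list(range(num_devices))  # device i starts with its own KV block i
    schedule = []
    for _ in range(num_devices):
        schedule.append(list(held))
        # Each device hands its block to the next device in the ring,
        # so device d receives whatever device d - 1 was holding.
        held = [held[(d - 1) % num_devices] for d in range(num_devices)]
    return schedule


schedule = ring_schedule(4)
for step, blocks in enumerate(schedule):
    print(f"step {step}: {blocks}")
```

After as many steps as there are devices, every device has held every KV block exactly once, so each query block can accumulate attention over the full sequence while only one KV block is resident per device at a time.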
- get_impl_metadata() → AttentionMetadata
Returns the metadata associated with this instance.
- easydel.layers.attention_operator.modules.ring.blockwise_attn(query, key, value, bias=None, deterministic=True, dropout_rng=None, attn_pdrop=0.0, causal=True, query_chunk_size=2048, key_chunk_size=2048, dtype=<class 'jax.numpy.float32'>, policy=False, precision=None, float32_logits=True, prevent_cse=True)