easydel.layers.operations.modules.decode_attention
Autoregressive decode attention implementation for efficient token generation.
This module provides specialized attention implementations optimized for the autoregressive decoding phase of transformer models. During generation, models process one token at a time while attending to all previously generated tokens stored in a key-value cache.
Key optimizations:
- Single query token processing (query sequence length = 1)
- Efficient cache access with ragged boundaries
- Backend-specific kernels for TPU and GPU
- Optimized memory access patterns for the decode phase
- Support for variable sequence lengths per batch element
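A small validity mask illustrates the idea of ragged boundaries. This is a minimal sketch, not EasyDeL code. `ragged_valid_mask` is a hypothetical helper, and the sketch assumes that each sequence's valid cache slots form the half-open range `[starts[b], indexs[b])`. The values reuse the `starts`/`indexs` from the example on this page.

```python
import jax.numpy as jnp


def ragged_valid_mask(starts, indexs, kv_seq_len):
    """Hypothetical helper: bool mask [batch, kv_seq_len], True where starts[b] <= pos < indexs[b]."""
    pos = jnp.arange(kv_seq_len)[None, :]  # [1, kv_seq_len], broadcast over batch
    return (pos >= starts[:, None]) & (pos < indexs[:, None])


starts = jnp.array([0, 0, 0, 0])      # first valid slot per sequence
indexs = jnp.array([10, 15, 8, 12])   # current length per sequence (assumed exclusive end)
mask = ragged_valid_mask(starts, indexs, kv_seq_len=16)
print(mask.sum(axis=-1))  # valid slots per sequence: 10, 15, 8, 12
```

Positions outside each range are masked before the softmax, so padded or stale cache slots receive zero attention weight even though every sequence shares one fixed-size cache buffer.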
The implementation uses:
- Pallas kernels for TPU acceleration
- Triton kernels for GPU acceleration
- Native JAX operations as a fallback
This is particularly important for:
- Text generation and completion
- Real-time inference serving
- Streaming model outputs
- Interactive applications
Example
>>> import math
>>> import jax.numpy as jnp
>>> from easydel.layers.attention_operator import OperationMetadata
>>> from easydel.layers.operations.modules.decode_attention import AutoRegressiveDecodeAttn
>>> from easydel.layers.caching.transformer import TransformerMetadata
>>>
>>> # Configure for decoding
>>> head_dim = 128  # illustrative value
>>> metadata = OperationMetadata(
...     runtime_dtype=jnp.float16,
...     softmax_scale=1.0 / math.sqrt(head_dim),
... )
>>> decode_attn = AutoRegressiveDecodeAttn(metadata)
>>>
>>> # Use with cache during generation. query, cached_keys and cached_values are
>>> # assumed to exist with the shapes documented for forward_native below.
>>> cache_metadata = TransformerMetadata(
...     starts=jnp.array([0, 0, 0, 0]),  # Start indices per batch element
...     indexs=jnp.array([10, 15, 8, 12]),  # Current lengths per batch element
... )
>>> output = decode_attn(query, cached_keys, cached_values, cache_metadata)
- class easydel.layers.operations.modules.decode_attention.AutoRegressiveDecodeAttn(metadata: OperationMetadata)
Bases: OperationImpl
Attention implementation tailored for the autoregressive decoding step.
This class handles the attention mechanism when generating tokens one by one, attending to the previously generated sequence stored in a cache. It utilizes shard_map for distributed computation and supports different backends, including a potential Pallas-optimized version for TPUs. It assumes the query sequence length is 1.
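The general shard_map pattern can be sketched on a 1-D device mesh: the batch axis is split across devices and each shard runs single-token attention locally. This is an illustration under stated assumptions, not the library's implementation. The mesh axis name `"dp"`, the helper `decode_attn_local`, and the batch-only partitioning are all hypothetical, and the per-shard function is a plain-JAX stand-in for `ragged_decode_attention`. It assumes equal numbers of query and KV heads and a batch size divisible by the device count.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer JAX releases expose shard_map at the top level
    from jax import shard_map
except ImportError:  # older releases keep it under jax.experimental
    from jax.experimental.shard_map import shard_map


def decode_attn_local(q, k, v, starts, indexs):
    """Hypothetical per-shard kernel: single-token attention, equal query/KV heads."""
    scale = q.shape[-1] ** -0.5
    logits = jnp.einsum("bqhd,bkhd->bhk", q, k) * scale  # [b, heads, kv_seq_len]
    pos = jnp.arange(k.shape[1])
    valid = (pos[None, :] >= starts[:, None]) & (pos[None, :] < indexs[:, None])
    logits = jnp.where(valid[:, None, :], logits, jnp.finfo(logits.dtype).min)
    probs = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("bhk,bkhd->bhd", probs, v)[:, None]  # [b, 1, heads, head_dim]


mesh = Mesh(np.array(jax.devices()), ("dp",))  # assumed 1-D mesh over all devices
spec = P("dp")  # shard axis 0 (batch); remaining axes replicated
sharded_decode = jax.jit(
    shard_map(decode_attn_local, mesh=mesh, in_specs=(spec,) * 5, out_specs=spec)
)

n_dev = jax.device_count()
batch, kv_len, heads, head_dim = 2 * n_dev, 16, 4, 8
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(k1, (batch, 1, heads, head_dim))
k = jax.random.normal(k2, (batch, kv_len, heads, head_dim))
v = jax.random.normal(k3, (batch, kv_len, heads, head_dim))
starts = jnp.zeros((batch,), dtype=jnp.int32)
indexs = (jnp.arange(batch, dtype=jnp.int32) % kv_len) + 1  # ragged lengths 1, 2, ...
out = sharded_decode(q, k, v, starts, indexs)
print(out.shape)
```

Every batch element attends only to its own cache, so splitting along the batch axis needs no cross-device communication. Head- or sequence-sharded layouts would require different partition specs.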
- metadata
Configuration metadata for the attention mechanism.
- Type
OperationMetadata
- forward_cpu(*args, **kwargs) → AttentionOutput
CPU forward pass for autoregressive decoding attention.
Delegates to the native JAX/XLA implementation (forward_native).
- Parameters
*args – Positional arguments for the attention calculation.
**kwargs – Keyword arguments for the attention calculation.
- Returns
An AttentionOutput object containing the attention results.
- forward_cuda(*args, **kwargs) → AttentionOutput
CUDA GPU forward pass for autoregressive decoding attention.
Delegates to the GPU implementation, which uses Triton kernels. Future optimizations might add CUDA-specific kernels here.
- Parameters
*args – Positional arguments for the attention calculation.
**kwargs – Keyword arguments for the attention calculation.
- Returns
An AttentionOutput object containing the attention results.
- forward_gpu(*args, **kwargs) → AttentionOutput
GPU forward pass for autoregressive decoding attention.
Delegates to the native JAX/XLA implementation (forward_native).
- Parameters
*args – Positional arguments for the attention calculation.
**kwargs – Keyword arguments for the attention calculation.
- Returns
An AttentionOutput object containing the attention results.
- forward_native(query: Float[Array, 'batch 1 num_q_heads head_dim'], key: Float[Array, 'batch kv_seq_len num_kv_heads head_dim'], value: Float[Array, 'batch kv_seq_len num_kv_heads head_dim'], cache_metadata: TransformerMetadata, softmax_scale: float | None = None, sliding_window: tuple[int, int] | None = None, logits_soft_cap: float | None = None, softmax_aux: jaxtyping.Float[Array, 'num_kv_heads num_sinks'] | jaxtyping.Float[Array, 'num_sinks'] | None = None, **ignores) → AttentionOutput
Performs the native JAX/XLA forward pass for autoregressive decoding attention.
This implementation uses shard_map to distribute the computation across devices and leverages the ragged_decode_attention kernel for efficient processing. It calculates attention between a single query token and all previous keys/values stored in the cache, respecting the valid range defined by cache metadata.
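A plain-JAX reference can make the per-step computation concrete. It is written without `shard_map` or any custom kernel and is intended for checking shapes and semantics; it is not the `ragged_decode_attention` kernel. Assumptions: `reference_decode_attention` is a hypothetical name, valid cache slots are `[starts[b], indexs[b])`, and grouped-query attention maps query head `h` to KV head `h // (num_q_heads // num_kv_heads)`.

```python
import math
import jax
import jax.numpy as jnp


def reference_decode_attention(query, key, value, starts, indexs, softmax_scale=None):
    """query: [B, 1, Hq, D]; key/value: [B, S, Hkv, D]; returns [B, 1, Hq, D]."""
    b, _, hq, d = query.shape
    hkv = key.shape[2]
    group = hq // hkv  # query heads sharing one KV head (GQA); 1 means plain MHA
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(d)  # documented default: 1/sqrt(head_dim)
    q = query.reshape(b, hkv, group, d)  # drop the length-1 axis, group query heads
    logits = jnp.einsum("bhgd,bshd->bhgs", q, key) * softmax_scale
    pos = jnp.arange(key.shape[1])
    valid = (pos[None, :] >= starts[:, None]) & (pos[None, :] < indexs[:, None])  # [B, S]
    logits = jnp.where(valid[:, None, None, :], logits, jnp.finfo(logits.dtype).min)
    probs = jax.nn.softmax(logits.astype(jnp.float32), axis=-1).astype(value.dtype)
    out = jnp.einsum("bhgs,bshd->bhgd", probs, value)
    return out.reshape(b, 1, hq, d)


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
query = jax.random.normal(k1, (2, 1, 4, 16))  # 4 query heads
key = jax.random.normal(k2, (2, 8, 2, 16))    # 2 KV heads, so groups of 2
value = jax.random.normal(k3, (2, 8, 2, 16))
out = reference_decode_attention(
    query, key, value, starts=jnp.array([0, 2]), indexs=jnp.array([5, 8])
)
print(out.shape)  # (2, 1, 4, 16)
```

The softmax runs in float32 and is cast back, a common choice for half-precision inference; whether the real kernel does the same is not stated on this page.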
- Parameters
query – Query tensor [batch, 1, num_q_heads, head_dim]. Single token query for next-token prediction.
key – Key tensor from cache [batch, kv_seq_len, num_kv_heads, head_dim]. All previous keys in the sequence.
value – Value tensor from cache [batch, kv_seq_len, num_kv_heads, head_dim]. All previous values in the sequence.
cache_metadata – Cache metadata containing:
  - starts: Start indices of the valid cache entries for each batch element [batch].
  - indexs: Current sequence lengths for each batch element [batch].
softmax_scale – Scaling factor for attention logits. Defaults to 1/sqrt(head_dim).
sliding_window – Window bounds (left, right) for local attention. Optional.
logits_soft_cap – Soft capping value to prevent extreme attention logits. Optional.
softmax_aux – Auxiliary tensor for sink token attention. Optional.
**ignores – Additional ignored keyword arguments.
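The optional arguments modify the logits before or during the softmax. The following sketch shows one common reading of each, under assumptions this page does not confirm: `logits_soft_cap` as Gemma-2-style `cap * tanh(logits / cap)`, `sliding_window=(left, right)` as inclusive offsets around the query position, and `softmax_aux` as per-head sink logits that join the softmax denominator and are then discarded. All three helper functions are hypothetical.

```python
import jax
import jax.numpy as jnp


def apply_soft_cap(logits, cap):
    """Hypothetical: smoothly bound logits to the open interval (-cap, cap)."""
    return cap * jnp.tanh(logits / cap)


def sliding_window_mask(query_pos, kv_seq_len, window):
    """Hypothetical: True where query_pos - left <= key position <= query_pos + right."""
    left, right = window
    pos = jnp.arange(kv_seq_len)
    return (pos >= query_pos - left) & (pos <= query_pos + right)


def softmax_with_sinks(logits, sink_logits):
    """Hypothetical: softmax over [keys | sinks], then keep only the key columns.

    logits: [heads, kv_seq_len]; sink_logits: [heads, num_sinks].
    """
    combined = jnp.concatenate([logits, sink_logits], axis=-1)
    probs = jax.nn.softmax(combined, axis=-1)
    return probs[..., : logits.shape[-1]]


logits = jnp.array([[2.0, 0.5, -1.0, 3.0]])  # one head, four cached keys
capped = apply_soft_cap(logits, cap=2.0)      # every entry now lies in (-2, 2)
window = sliding_window_mask(query_pos=3, kv_seq_len=4, window=(1, 0))  # keeps positions 2 and 3
masked = jnp.where(window, capped, jnp.finfo(capped.dtype).min)
weights = softmax_with_sinks(masked, sink_logits=jnp.zeros((1, 1)))
print(weights.sum())  # below 1: part of the mass went to the sink
```

Because the sink columns are dropped after normalization, the remaining weights sum to less than one, which lets a head attend "nowhere" when no cached token is relevant.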
- Returns
AttentionOutput containing:
  - attention_outputs: [batch, 1, num_q_heads, head_dim]. Attended representation for the current query token.
  - attention_weights: None (not computed, to save memory).
- Return type
AttentionOutput
- forward_rocm(*args, **kwargs) → AttentionOutput
ROCm GPU forward pass for autoregressive decoding attention.
Delegates to the GPU implementation. Future optimizations might add ROCm-specific kernels here.
- Parameters
*args – Positional arguments for the attention calculation.
**kwargs – Keyword arguments for the attention calculation.
- Returns
An AttentionOutput object containing the attention results.
- forward_tpu(*args, **kwargs) → AttentionOutput
TPU forward pass for autoregressive decoding attention.
Delegates to the native JAX/XLA implementation (forward_native).
- Parameters
*args – Positional arguments for the attention calculation.
**kwargs – Keyword arguments for the attention calculation.
- Returns
An AttentionOutput object containing the attention results.
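A toy class can model the delegation described by the backend-specific methods. The hypothetical `DecodeDispatchSketch` only records which methods a call passes through. The chains mirror the docstrings on this page: CPU, TPU and GPU go to `forward_native`, while CUDA and ROCm go through the GPU path. The selection rule in `__call__` (an explicit `backend` name, else `jax.default_backend()`) is an assumption, not EasyDeL's actual dispatcher.

```python
import jax


class DecodeDispatchSketch:
    """Hypothetical stand-in: each method returns the chain of methods it passed through."""

    def forward_native(self, *args, **kwargs):
        return ("forward_native",)

    def forward_cpu(self, *args, **kwargs):
        return ("forward_cpu",) + self.forward_native(*args, **kwargs)

    def forward_tpu(self, *args, **kwargs):
        return ("forward_tpu",) + self.forward_native(*args, **kwargs)

    def forward_gpu(self, *args, **kwargs):
        return ("forward_gpu",) + self.forward_native(*args, **kwargs)

    def forward_cuda(self, *args, **kwargs):
        return ("forward_cuda",) + self.forward_gpu(*args, **kwargs)

    def forward_rocm(self, *args, **kwargs):
        return ("forward_rocm",) + self.forward_gpu(*args, **kwargs)

    def __call__(self, *args, backend=None, **kwargs):
        # Assumed rule: explicit backend name, else JAX's default platform name.
        name = backend or jax.default_backend()
        method = getattr(self, f"forward_{name}", self.forward_native)
        return method(*args, **kwargs)


dispatch = DecodeDispatchSketch()
print(dispatch(backend="cuda"))  # ('forward_cuda', 'forward_gpu', 'forward_native')
print(dispatch())                # chain for whatever platform JAX is running on
```

In this sketch, vendor-specific methods are reached only when named explicitly, since `jax.default_backend()` reports a generic `"gpu"` platform on NVIDIA hardware.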
- get_impl_metadata() → OperationMetadata
Returns the metadata associated with this attention implementation instance.
- Returns
The OperationMetadata provided during initialization.