easydel.layers.operations.modules.decode_attention

Autoregressive decode attention implementation for efficient token generation.

This module provides specialized attention implementations optimized for the autoregressive decoding phase of transformer models. During generation, models process one token at a time while attending to all previously generated tokens stored in a key-value cache.

Key optimizations:
  • Single-query processing (query sequence length = 1)
  • Efficient cache access with ragged (per-sequence) boundaries
  • Backend-specific kernels for TPU and GPU
  • Memory access patterns optimized for the decode phase
  • Support for variable sequence lengths per batch element
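"Ragged boundaries" means each batch element has its own valid slice of the cache. The following is a minimal sketch of the corresponding mask. It assumes that cache slot j of element b is valid when starts[b] <= j < indexs[b] (that is, indexs acts as an exclusive end position). The helper name ragged_valid_mask is hypothetical and is not part of EasyDeL.

```python
import jax.numpy as jnp


def ragged_valid_mask(starts, indexs, kv_seq_len):
    """Hypothetical helper: boolean mask [batch, kv_seq_len] of valid cache slots.

    Assumption: slot j of batch element b is valid when starts[b] <= j < indexs[b].
    """
    positions = jnp.arange(kv_seq_len)[None, :]  # [1, kv_seq_len]
    return (positions >= starts[:, None]) & (positions < indexs[:, None])


starts = jnp.array([0, 2])
indexs = jnp.array([3, 5])
mask = ragged_valid_mask(starts, indexs, kv_seq_len=6)
print(mask.astype(jnp.int32))
# [[1 1 1 0 0 0]
#  [0 0 1 1 1 0]]
```

A decode kernel can use such bounds to skip invalid cache blocks entirely rather than masking them after the fact.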

The implementation uses:
  • Pallas kernels for TPU acceleration
  • Triton kernels for GPU acceleration
  • Native JAX operations as a fallback

This is particularly important for:
  • Text generation and completion
  • Real-time inference serving
  • Streaming model outputs
  • Interactive applications

Example

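>>> import math
>>> import jax.numpy as jnp
>>> from easydel.layers.operations import OperationMetadata
>>> from easydel.layers.operations.modules.decode_attention import AutoRegressiveDecodeAttn
>>> from easydel.layers.caching.transformer import TransformerMetadata
>>>
>>> batch, num_heads, head_dim, kv_seq_len = 4, 8, 128, 16
>>>
>>> # Configure for decoding
>>> metadata = OperationMetadata(
...     runtime_dtype=jnp.float16,
...     softmax_scale=1.0 / math.sqrt(head_dim),
... )
>>> decode_attn = AutoRegressiveDecodeAttn(metadata)
>>>
>>> # Placeholder tensors: one query token per sequence plus the KV cache
>>> query = jnp.zeros((batch, 1, num_heads, head_dim), dtype=jnp.float16)
>>> cached_keys = jnp.zeros((batch, kv_seq_len, num_heads, head_dim), dtype=jnp.float16)
>>> cached_values = jnp.zeros_like(cached_keys)
>>>
>>> # Use with cache during generation
>>> cache_metadata = TransformerMetadata(
...     starts=jnp.array([0, 0, 0, 0]),  # Start of the valid cache region per batch element
...     indexs=jnp.array([10, 15, 8, 12]),  # Current sequence length per batch element
... )
>>> output = decode_attn(query, cached_keys, cached_values, cache_metadata)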
class easydel.layers.operations.modules.decode_attention.AutoRegressiveDecodeAttn(metadata: OperationMetadata)

Bases: OperationImpl

Attention implementation tailored for the autoregressive decoding step.

This class handles attention when tokens are generated one at a time, attending to the previously generated sequence stored in a cache. It uses shard_map to distribute the computation across devices and provides per-backend entry points (CPU, GPU, CUDA, ROCm, TPU); the TPU path may be backed by a Pallas-optimized kernel. It assumes the query sequence length is 1.
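The per-backend methods documented below follow a simple delegation pattern. CPU, GPU, and TPU call forward_native, while CUDA and ROCm call the GPU implementation. The following is a hypothetical sketch of that structure. The class name is illustrative, and the use of jax.default_backend() to select the hook is an assumption, not EasyDeL's confirmed dispatch logic.

```python
import jax


class DecodeAttnDispatchSketch:
    """Hypothetical illustration of per-backend delegation (not EasyDeL's code)."""

    def forward_native(self, *args, **kwargs):
        # Stand-in for the shared JAX/XLA implementation.
        return "native"

    def forward_cpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_tpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_gpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    # Vendor-specific GPU hooks fall back to the generic GPU path.
    def forward_cuda(self, *args, **kwargs):
        return self.forward_gpu(*args, **kwargs)

    def forward_rocm(self, *args, **kwargs):
        return self.forward_gpu(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        backend = jax.default_backend()  # typically "cpu", "gpu", or "tpu"
        hook = getattr(self, f"forward_{backend}", self.forward_native)
        return hook(*args, **kwargs)


print(DecodeAttnDispatchSketch()())  # native
```

Because every hook ends in forward_native, backend-specific kernels can be introduced later by overriding a single method, without changing callers.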

metadata

Configuration metadata for the attention mechanism.

Type

OperationMetadata

forward_cpu(*args, **kwargs) → AttentionOutput

CPU forward pass for autoregressive decoding attention.

Delegates to the native JAX/XLA implementation (forward_native).

Parameters
  • *args – Positional arguments for the attention calculation.

  • **kwargs – Keyword arguments for the attention calculation.

Returns

An AttentionOutput object containing the attention results.

forward_cuda(*args, **kwargs) → AttentionOutput

CUDA GPU forward pass for autoregressive decoding attention.

Delegates to the GPU implementation, which uses Triton kernels. CUDA-specific kernels may be added here in the future.

Parameters
  • *args – Positional arguments for the attention calculation.

  • **kwargs – Keyword arguments for the attention calculation.

Returns

An AttentionOutput object containing the attention results.

forward_gpu(*args, **kwargs) → AttentionOutput

GPU forward pass for autoregressive decoding attention.

Delegates to the native JAX/XLA implementation (forward_native).

Parameters
  • *args – Positional arguments for the attention calculation.

  • **kwargs – Keyword arguments for the attention calculation.

Returns

An AttentionOutput object containing the attention results.

forward_native(query: Float[Array, 'batch 1 num_q_heads head_dim'], key: Float[Array, 'batch kv_seq_len num_kv_heads head_dim'], value: Float[Array, 'batch kv_seq_len num_kv_heads head_dim'], cache_metadata: TransformerMetadata, softmax_scale: float | None = None, sliding_window: tuple[int, int] | None = None, logits_soft_cap: float | None = None, softmax_aux: jaxtyping.Float[Array, 'num_kv_heads num_sinks'] | jaxtyping.Float[Array, 'num_sinks'] | None = None, **ignores) → AttentionOutput

Performs the native JAX/XLA forward pass for autoregressive decoding attention.

This implementation uses shard_map to distribute the computation across devices and leverages the ragged_decode_attention kernel for efficient processing. It calculates attention between a single query token and all previous keys/values stored in the cache, respecting the valid range defined by cache metadata.
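For reference, the computation described above can be written naively in plain JAX. This is a sketch, not the ragged_decode_attention kernel. The function name is hypothetical, and the sketch materializes the full logits row over kv_seq_len instead of skipping invalid blocks. It assumes grouped-query attention, with num_q_heads a multiple of num_kv_heads, and the validity rule starts[b] <= j < indexs[b]. It omits sliding_window, logits_soft_cap, and softmax_aux.

```python
import math

import jax
import jax.numpy as jnp


def reference_decode_attention(query, key, value, starts, indexs, softmax_scale=None):
    """Hypothetical naive single-token decode attention over a ragged KV cache.

    query: [batch, 1, num_q_heads, head_dim]
    key/value: [batch, kv_seq_len, num_kv_heads, head_dim]
    """
    batch, _, num_q_heads, head_dim = query.shape
    kv_seq_len, num_kv_heads = key.shape[1], key.shape[2]
    group = num_q_heads // num_kv_heads  # query heads sharing one KV head
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(head_dim)

    # Group query heads under their KV head: [batch, num_kv_heads, group, head_dim].
    q = query[:, 0].reshape(batch, num_kv_heads, group, head_dim)
    # Logits against every cache slot: [batch, num_kv_heads, group, kv_seq_len].
    logits = jnp.einsum("bhgd,bkhd->bhgk", q, key) * softmax_scale

    # Mask slots outside [starts[b], indexs[b]) so they get zero weight.
    positions = jnp.arange(kv_seq_len)[None, :]
    valid = (positions >= starts[:, None]) & (positions < indexs[:, None])
    logits = jnp.where(valid[:, None, None, :], logits, -jnp.inf)

    weights = jax.nn.softmax(logits, axis=-1)
    out = jnp.einsum("bhgk,bkhd->bhgd", weights, value)
    # Back to [batch, 1, num_q_heads, head_dim].
    return out.reshape(batch, 1, num_q_heads, head_dim)


# Example: 2 sequences, 4 query heads sharing 2 KV heads, 6 cache slots.
rng_q, rng_k, rng_v = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(rng_q, (2, 1, 4, 8))
k = jax.random.normal(rng_k, (2, 6, 2, 8))
v = jax.random.normal(rng_v, (2, 6, 2, 8))
out = reference_decode_attention(q, k, v, starts=jnp.array([0, 1]), indexs=jnp.array([4, 6]))
print(out.shape)  # (2, 1, 4, 8)
```

A naive version like this is mainly useful as a numerical check for an optimized kernel on small inputs.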

Parameters
  • query – Query tensor [batch, 1, num_q_heads, head_dim]. Single token query for next-token prediction.

  • key – Key tensor from cache [batch, kv_seq_len, num_kv_heads, head_dim]. All previous keys in the sequence.

  • value – Value tensor from cache [batch, kv_seq_len, num_kv_heads, head_dim]. All previous values in the sequence.

  • cache_metadata – Cache metadata containing starts (start indices of the valid cache entries, shape [batch]) and indexs (current sequence length of each batch element, shape [batch]).

  • softmax_scale – Scaling factor for attention logits. Defaults to 1/sqrt(head_dim).

  • sliding_window – Window bounds (left, right) for local attention. Optional.

  • logits_soft_cap – Soft capping value to prevent extreme attention logits. Optional.

  • softmax_aux – Auxiliary logits for attention-sink tokens, shaped [num_kv_heads, num_sinks] or [num_sinks]. Optional.

  • **ignores – Additional ignored keyword arguments.
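The two optional logit modifiers are commonly formulated as shown below. This sketch assumes two things: logits_soft_cap applies cap * tanh(logits / cap), and softmax_aux supplies extra sink logits that join the softmax normalizer and are then discarded. Whether EasyDeL applies them in exactly this way is not confirmed here, and the helper names are hypothetical.

```python
import jax
import jax.numpy as jnp


def soft_cap(logits, cap):
    # Hypothetical helper: bounds logits to (-cap, cap), roughly linear near zero.
    return cap * jnp.tanh(logits / cap)


def softmax_with_sinks(logits, sink_logits):
    """Hypothetical helper: softmax with extra 'sink' logits in the normalizer.

    logits: [..., kv_len]; sink_logits: [..., num_sinks], broadcastable to the
    leading dims of logits. The sinks absorb probability mass but are dropped
    from the result, so the kept weights sum to less than 1.
    """
    target = logits.shape[:-1] + sink_logits.shape[-1:]
    sink_logits = jnp.broadcast_to(sink_logits, target)
    combined = jnp.concatenate([logits, sink_logits], axis=-1)
    probs = jax.nn.softmax(combined, axis=-1)
    return probs[..., : logits.shape[-1]]


print(soft_cap(jnp.array([0.5, 5.0, 500.0]), cap=30.0))  # large logits squashed below 30
w = softmax_with_sinks(jnp.zeros((2, 3)), jnp.zeros((1,)))
print(w.sum(axis=-1))  # each row sums to 0.75: a quarter of the mass went to the sink
```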

Returns

AttentionOutput containing:

  • attention_outputs – Tensor of shape [batch, 1, num_q_heads, head_dim]; the attended representation for the current query token.

  • attention_weights – None (not computed, to save memory).

Return type

AttentionOutput

forward_rocm(*args, **kwargs) → AttentionOutput

ROCm GPU forward pass for autoregressive decoding attention.

Delegates to the GPU implementation. Future optimizations might add ROCm-specific kernels here.

Parameters
  • *args – Positional arguments for the attention calculation.

  • **kwargs – Keyword arguments for the attention calculation.

Returns

An AttentionOutput object containing the attention results.

forward_tpu(*args, **kwargs) → AttentionOutput

TPU forward pass for autoregressive decoding attention.

Delegates to the native JAX/XLA implementation (forward_native).

Parameters
  • *args – Positional arguments for the attention calculation.

  • **kwargs – Keyword arguments for the attention calculation.

Returns

An AttentionOutput object containing the attention results.

get_impl_metadata() → OperationMetadata

Returns the metadata associated with this attention implementation instance.

Returns

The OperationMetadata provided during initialization.

classmethod get_impl_name() → str | tuple[str]

Returns the registered name of this attention implementation.

Returns

The string “autoregressive_decodeattn”.
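Because get_impl_name() may return either a single string or a tuple of names, implementations can be looked up by any of several aliases. The following is a hypothetical registry that illustrates this pattern. It is not EasyDeL's actual registry API, and all class names in it are illustrative.

```python
class HypotheticalRegistry:
    """Illustrative name -> implementation registry (not EasyDeL's API)."""

    def __init__(self):
        self._impls = {}

    def register(self, impl_cls):
        names = impl_cls.get_impl_name()
        # Normalize a single name to a one-element tuple of aliases.
        if isinstance(names, str):
            names = (names,)
        for name in names:
            self._impls[name] = impl_cls
        return impl_cls  # allows use as a class decorator

    def get(self, name):
        return self._impls[name]


registry = HypotheticalRegistry()


@registry.register
class DummyDecodeAttn:
    @classmethod
    def get_impl_name(cls):
        return "autoregressive_decodeattn"


@registry.register
class DummyAliasedAttn:
    @classmethod
    def get_impl_name(cls):
        return ("alias_a", "alias_b")  # several names for one implementation


print(registry.get("autoregressive_decodeattn").__name__)  # DummyDecodeAttn
```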