easydel.layers.attention_operator.modules.__init__
- class easydel.layers.attention_operator.modules.__init__.AutoRegressiveDecodeAttn(metadata: AttentionMetadata)
Bases: AttentionImpl
Attention implementation tailored for the autoregressive decoding step.
This class handles attention when tokens are generated one at a time: the single new query token attends to the previously generated sequence stored in the KV cache. It uses shard_map for distributed computation and supports several backends, including an optional Pallas-optimized path for TPUs. The query sequence length is assumed to be 1.
- metadata
Configuration metadata for the attention mechanism.
- Type
AttentionMetadata
- forward_cpu(*args, **kwargs) → AttentionOutput
CPU forward pass for autoregressive decoding attention.
Delegates to the native JAX/XLA implementation (forward_native).
- forward_cuda(*args, **kwargs) → AttentionOutput
CUDA GPU forward pass for autoregressive decoding attention.
Delegates to the native JAX/XLA implementation (forward_native). Future optimizations might add CUDA-specific kernels here.
- forward_gpu(*args, **kwargs) → AttentionOutput
GPU forward pass for autoregressive decoding attention.
Currently delegates to forward_cuda.
- forward_native(q: Array, k: Array, v: Array, cache_view: TransformerCacheView, **ignores) → AttentionOutput
Performs the native JAX/XLA forward pass for autoregressive decoding attention.
This implementation uses shard_map to distribute the computation and relies on standard JAX operations (einsum, softmax). It computes attention scores between the single query token and the cached keys, masks out positions outside the valid range given by cache_view, applies softmax, and takes the weighted sum of the values.
- Parameters
q (Array) – Query tensor of shape (batch_size, 1, num_query_heads, head_dim).
k (Array) – Key tensor (from cache) of shape (batch_size, kv_sequence_length, num_kv_heads, head_dim).
v (Array) – Value tensor (from cache) of shape (batch_size, kv_sequence_length, num_kv_heads, head_dim).
cache_view (TransformerCacheView) – Contains metadata about the cache, specifically starts (start index for valid keys/values) and index (current index or length).
**ignores – Ignored keyword arguments.
- Returns
An object containing the attention outputs (attention_outputs) of shape (batch_size, 1, num_query_heads, head_dim). Attention weights are not returned.
- Return type
AttentionOutput
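The masked single-query attention described above can be sketched in NumPy as follows. This is an illustrative reference, not EasyDeL's implementation: it omits shard_map, takes starts and index as plain (batch,) integer arrays, and assumes the valid keys for row b form the half-open range [starts[b], index[b]); the exact semantics of TransformerCacheView.index may differ. The function name decode_attention is hypothetical.

```python
import numpy as np

def decode_attention(q, k, v, starts, index):
    """Single-token attention over a KV cache (NumPy reference sketch).

    q: (B, 1, Hq, D); k, v: (B, S, Hkv, D); starts, index: (B,) integer arrays.
    Assumption: key position j is valid for row b when starts[b] <= j < index[b].
    """
    b, _, hq, d = q.shape
    s, hkv = k.shape[1], k.shape[2]
    groups = hq // hkv  # query heads sharing one KV head (GQA/MQA)
    # Query head h is mapped to KV head h // groups.
    qg = q[:, 0].reshape(b, hkv, groups, d)                       # (B, Hkv, G, D)
    scores = np.einsum("bhgd,bshd->bhgs", qg, k) / np.sqrt(d)     # (B, Hkv, G, S)
    pos = np.arange(s)[None, :]                                   # (1, S)
    valid = (pos >= starts[:, None]) & (pos < index[:, None])     # (B, S)
    scores = np.where(valid[:, None, None, :], scores, -np.inf)
    scores = scores - scores.max(axis=-1, keepdims=True)          # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    out = np.einsum("bhgs,bshd->bhgd", weights, v)                # (B, Hkv, G, D)
    return out.reshape(b, 1, hq, d)
```

Because masked positions receive a score of negative infinity, their softmax weight is exactly zero, so whatever stale data sits outside the valid range cannot leak into the output.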
- forward_rocm(*args, **kwargs) → AttentionOutput
ROCm GPU forward pass for autoregressive decoding attention.
Delegates to the native JAX/XLA implementation (forward_native). Future optimizations might add ROCm-specific kernels here.
- forward_tpu(*args, **kwargs) → AttentionOutput
TPU forward pass for autoregressive decoding attention.
Currently delegates to the native JAX/XLA implementation (forward_native). The private _forward_tpu method offers a potential Pallas-optimized alternative where it is available and enabled.
- get_impl_metadata() → AttentionMetadata
Returns the metadata associated with this attention implementation instance.
- Returns
The AttentionMetadata provided during initialization.
- class easydel.layers.attention_operator.modules.__init__.FlashAttn(metadata: AttentionMetadata)
Bases: AttentionImpl
An implementation of Flash Attention V2 using specialized JAX primitives.
This class uses jax.experimental.pallas.ops.tpu.flash_attention on TPUs and a Triton kernel (triton_flash_attention) on GPUs (CUDA). It is registered under the name “flash_attn2”. CPU execution is not supported and raises NotImplementedError.
- forward_cpu(*args, **kwargs) → AttentionOutput
CPU forward pass. Delegates to forward_native, which raises an error.
- Raises
NotImplementedError – Via forward_native.
- forward_cuda(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, causal: bool = False, **ignore) → AttentionOutput
Performs Flash Attention V2 on GPU (CUDA) using a Triton kernel.
Handles optional mask/bias, KV head repetition, and sharding based on metadata. The Triton kernel (triton_flash_attention) is assumed either to repeat KV heads internally or to accept broadcastable KV heads, and to use the BTHD layout for both inputs and outputs.
- Parameters
q – Query tensor, expected shape (batch, q_seq_len, num_q_heads, head_dim).
k – Key tensor, expected shape (batch, kv_seq_len, num_kv_heads, head_dim).
v – Value tensor, expected shape (batch, kv_seq_len, num_kv_heads, head_dim).
mask – Optional boolean attention mask. Used by the kernel if bias is not provided. Shape typically (batch, 1, q_seq_len, kv_seq_len) or broadcastable.
bias – Optional attention bias tensor. Added by the kernel. Shape typically (batch, num_heads, q_seq_len, kv_seq_len) or broadcastable. Takes precedence over mask if both are passed.
init_bias – Optional callable function to initialize bias if mask and bias are None.
causal – If True, instructs the kernel to apply a causal mask.
**ignore – Ignored keyword arguments.
- Returns
An AttentionOutput object containing the attention outputs. Attention weights are typically not computed or returned by Flash Attention implementations.
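When KV heads are repeated before the kernel rather than inside it, GQA/MQA keys and values are expanded so that every query head has a matching KV head. The helper repeat_kv below is hypothetical and written in NumPy for illustration; it assumes the common convention that query head h reads KV head h // (num_q_heads // num_kv_heads).

```python
import numpy as np

def repeat_kv(x: np.ndarray, num_q_heads: int) -> np.ndarray:
    """Expand (B, S, Hkv, D) keys or values to (B, S, Hq, D) for GQA/MQA."""
    num_kv_heads = x.shape[2]
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("num_q_heads must be a multiple of num_kv_heads")
    # Each KV head is repeated consecutively, so query head h reads KV head
    # h // (num_q_heads // num_kv_heads).
    return np.repeat(x, num_q_heads // num_kv_heads, axis=2)
```

Repeating in memory is simple but multiplies the KV footprint by the group size, which is one reason kernels often prefer to index the shared KV head directly.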
- forward_gpu(*args, **kwargs) → AttentionOutput
GPU forward pass. Delegates to the CUDA-specific implementation.
- Parameters
*args – Positional arguments for the attention calculation.
**kwargs – Keyword arguments for the attention calculation.
- Returns
An AttentionOutput object containing the attention results.
- forward_native(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, causal: bool = False, **ignore) → AttentionOutput
Native (CPU) forward pass for Flash Attention. Not implemented.
- Raises
NotImplementedError – Flash Attention is not supported on CPU via this implementation.
- forward_rocm(*args, **kwargs) → AttentionOutput
ROCm GPU forward pass. Not implemented; it delegates to forward_native, which raises an error.
- Raises
NotImplementedError – Via forward_native. ROCm support requires a specific kernel.
- forward_tpu(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, causal: bool = False, **ignore) → AttentionOutput
Performs Flash Attention V2 on TPU using jax.experimental.pallas.ops.tpu.flash_attention.
Handles optional mask/bias, KV head repetition, and sharding based on metadata. Note: the Pallas kernel expects inputs in BHTD layout, so the BTHD inputs are transposed before the kernel call.
- Parameters
q – Query tensor, expected shape (batch, q_seq_len, num_q_heads, head_dim).
k – Key tensor, expected shape (batch, kv_seq_len, num_kv_heads, head_dim).
v – Value tensor, expected shape (batch, kv_seq_len, num_kv_heads, head_dim).
mask – Optional boolean attention mask. Bias will be generated from this if provided. Shape typically (batch, 1, q_seq_len, kv_seq_len) or broadcastable.
bias – Optional attention bias tensor. Added directly to attention scores. Shape typically (batch, num_heads, q_seq_len, kv_seq_len) or broadcastable. Takes precedence over mask.
init_bias – Optional callable function to initialize bias if mask and bias are None.
causal – If True, applies a causal mask. Ignored if q_seq_len is 1 (generation).
**ignore – Ignored keyword arguments.
- Returns
An AttentionOutput object containing the attention outputs. Attention weights are typically not computed or returned by Flash Attention implementations.
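Two preprocessing steps described above, turning a boolean mask into an additive bias and switching between the BTHD and BHTD layouts, look roughly like this in NumPy. Both helpers are hypothetical, and filling blocked positions with the dtype's most negative finite value is a common convention assumed here rather than a documented detail of this class.

```python
import numpy as np

def mask_to_bias(mask, dtype=np.float32):
    """Turn a boolean mask (True = may attend) into an additive attention bias."""
    # Allowed positions get 0; blocked ones get the most negative finite value,
    # which drives their softmax weight to zero.
    return np.where(mask, 0.0, np.finfo(dtype).min).astype(dtype)

def swap_seq_and_heads(x):
    """Convert BTHD <-> BHTD; swapping axes 1 and 2 is its own inverse."""
    return np.transpose(x, (0, 2, 1, 3))

# Example: a (B, T, H, D) query becomes (B, H, T, D) for the kernel.
q = np.zeros((2, 16, 8, 64), dtype=np.float32)
print(swap_seq_and_heads(q).shape)  # (2, 8, 16, 64)
```

Using a large finite negative value instead of negative infinity avoids NaNs when a whole row happens to be masked.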
- get_impl_metadata() → AttentionMetadata
Returns the metadata associated with this attention implementation instance.
- Returns
The AttentionMetadata provided during initialization.
- class easydel.layers.attention_operator.modules.__init__.PagedAttn(metadata: AttentionMetadata)
Bases: AttentionImpl
Attention implementation using the Paged Attention mechanism with TPU Pallas kernels.
This class provides an attention mechanism suitable for scenarios where the Key-Value cache is managed in non-contiguous pages (Paged KV Cache). It leverages specialized kernels for efficient execution on TPUs, handling prefill and decode phases separately or in a mixed mode.
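The page-table indirection at the heart of a paged KV cache can be illustrated with a plain gather. The layout below (one shared pool of fixed-size pages plus a per-sequence table of physical page ids) is a common convention assumed for illustration; it is not necessarily the layout of EasyDeL's PagedAttentionCacheView, the function name is hypothetical, and the Pallas kernels read pages in place rather than materializing a contiguous copy.

```python
import numpy as np

def gather_paged_kv(pages, page_table, seq_len):
    """Rebuild one sequence's contiguous keys (or values) from a paged cache.

    pages:      (num_pages, page_size, num_kv_heads, head_dim), shared by all sequences
    page_table: (pages_per_seq,) physical page ids for this sequence, in logical order
    seq_len:    number of valid tokens in this sequence
    """
    gathered = pages[page_table]                   # (pages_per_seq, page_size, H, D)
    flat = gathered.reshape(-1, *pages.shape[2:])  # (pages_per_seq * page_size, H, D)
    return flat[:seq_len]                          # drop the unused tail of the last page
```

For example, a 10-token sequence with page_size 4 could occupy physical pages 5, 2 and 0; the table [5, 2, 0] restores its logical order, and only the last page is partially filled.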
- metadata
Configuration metadata for the attention mechanism. While this class uses AttentionMetadata, it primarily relies on the additional PagedAttentionMetadata passed during the forward call for paged-specific information.
- Type
AttentionMetadata
- forward_cpu(*args, **kwargs) → AttentionOutput
CPU forward pass.
- Raises
NotImplementedError – Paged Attention currently relies on Pallas kernels for TPUs and has no CPU implementation.
- forward_cuda(*args, **kwargs) → AttentionOutput
CUDA GPU forward pass.
- Raises
NotImplementedError – Paged Attention currently relies on Pallas kernels for TPUs and does not have a specific CUDA implementation (future work might add one).
- forward_gpu(*args, **kwargs) → AttentionOutput
Generic GPU forward pass.
- Raises
NotImplementedError – Paged Attention relies on specific kernels (currently Pallas for TPU) and does not have a generic GPU implementation.
- forward_native(*args, **kwargs) → AttentionOutput
Native (CPU) forward pass.
- Raises
NotImplementedError – Paged Attention requires specialized kernels and does not have a native CPU implementation.
- forward_rocm(*args, **kwargs) → AttentionOutput
ROCm GPU forward pass. Not implemented for Paged Attention.
- forward_tpu(q: Array, k: Array, v: Array, cache_view: PagedAttentionCacheView, cache_metadata: PagedAttentionMetadata, **ignore) → AttentionOutput
TPU forward pass for Paged Attention.
Determines the execution mode (prefill, decode, or mixed) based on the provided cache_metadata and dispatches the computation to the corresponding internal TPU method (_prefill_tpu, _decode_tpu, _mixed_tpu).
- Parameters
q (Array) – Query tensor. Shape depends on mode (prefill/decode/mixed).
k (Array) – Key tensor (ignored).
v (Array) – Value tensor (ignored).
cache_view (PagedAttentionCacheView) – Contains the paged KV cache tensors.
cache_metadata (PagedAttentionMetadata) – Contains metadata describing the state and mode (prefill/decode/mixed) of the current batch.
**ignore – Ignored keyword arguments.
- Returns
An object containing the computed attention outputs. Attention weights are typically not computed or returned in paged attention.
- Return type
AttentionOutput
- get_impl_metadata() → AttentionMetadata
Retrieves the metadata associated with this attention implementation instance.
- Returns
The metadata object provided during initialization.
- Return type
AttentionMetadata
- class easydel.layers.attention_operator.modules.__init__.RingAttn(metadata: AttentionMetadata)
Bases: AttentionImpl
Attention implementation using a ring-passing algorithm or a blockwise scan.
This implementation supports:
- Native (scan-based) blockwise attention via blockwise_attn.
- TPU-specific ring attention using the pallas_ring_attention kernel.
It is registered under the name “ring”.
- forward_cpu(*args, **kwargs) → AttentionOutput
CPU forward pass. Delegates to forward_native (scan-based).
- forward_cuda(*args, **kwargs) → AttentionOutput
CUDA GPU forward pass. Currently delegates to forward_native (scan-based).
- forward_gpu(*args, **kwargs) → AttentionOutput
GPU forward pass. Currently delegates to forward_native (scan-based).
- forward_native(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, deterministic: bool = False, dropout_rng: Optional[PRNGKey] = None, causal: bool = True, **ignore) → AttentionOutput
Computes attention using the scan-based blockwise_attn function.
Handles optional mask/bias, KV head repetition, and sharding constraints.
- Parameters
q – Query tensor (B, T, H, D).
k – Key tensor (B, S, H_kv, D).
v – Value tensor (B, S, H_kv, D).
mask – Optional boolean attention mask (broadcastable to B, 1, T, S).
bias – Optional attention bias (broadcastable to B, H, T, S).
init_bias – Optional callable to initialize bias if mask/bias are None.
deterministic – If False, enables dropout. Requires dropout_rng.
dropout_rng – JAX PRNG key for dropout if deterministic is False.
causal – Apply causal mask if True.
**ignore – Ignored keyword arguments.
- Returns
AttentionOutput containing the attention result.
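The scan-based path computes exact attention one KV block at a time, carrying a running maximum and softmax denominator (an "online softmax") so the full T × S score matrix is never built at once. The single-head NumPy sketch below illustrates that technique; it is not EasyDeL's blockwise_attn, the function name is hypothetical, and it assumes causal queries are aligned to the end of the key sequence (query i may see keys up to i + S − T).

```python
import numpy as np

def blockwise_attention(q, k, v, block_size, causal=True):
    """Exact attention computed one KV block at a time with an online softmax.

    q: (T, D); k: (S, D); v: (S, Dv). Single head, no batch, for clarity.
    """
    t, d = q.shape
    s = k.shape[0]
    m = np.full(t, -np.inf)             # running max of scores per query
    denom = np.zeros(t)                 # running softmax denominator
    acc = np.zeros((t, v.shape[1]))     # running unnormalized output
    q_pos = np.arange(t) + (s - t)      # queries aligned to the end of the keys
    for start in range(0, s, block_size):
        kb, vb = k[start:start + block_size], v[start:start + block_size]
        scores = q @ kb.T / np.sqrt(d)
        if causal:
            k_pos = np.arange(start, start + kb.shape[0])
            scores = np.where(k_pos[None, :] <= q_pos[:, None], scores, -np.inf)
        m_new = np.maximum(m, scores.max(axis=1))
        # Rescale what was accumulated so far to the new running max.
        scale = np.exp(m - m_new)
        p = np.exp(scores - m_new[:, None])
        denom = denom * scale + p.sum(axis=1)
        acc = acc * scale[:, None] + p @ vb
        m = m_new
    return acc / denom[:, None]
```

With this alignment the first block always contains key 0, which every query may see, so the running maximum is finite after the first step and later fully masked blocks simply contribute zero.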
- forward_rocm(*args, **kwargs) → AttentionOutput
ROCm GPU forward pass. Currently delegates to forward_native (scan-based).
- forward_tpu(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, deterministic: bool = False, dropout_rng: Optional[PRNGKey] = None, causal: bool = True, **ignore) → AttentionOutput
Computes Ring Attention on TPU using the pallas_ring_attention kernel.
Handles optional mask/bias, sharding, and passes configuration to the kernel.
- Parameters
q – Query tensor (B, T, H, D).
k – Key tensor (B, S, H_kv, D).
v – Value tensor (B, S, H_kv, D).
mask – Optional boolean attention mask (broadcastable to B, 1, T, S).
bias – Optional attention bias (broadcastable to B, H, T, S).
init_bias – Optional callable to initialize bias if mask/bias are None.
deterministic – If False, potentially enables dropout within the kernel (if supported).
dropout_rng – JAX PRNG key (may be used by the kernel if dropout is enabled).
causal – Apply causal mask if True. Passed to the kernel.
**ignore – Ignored keyword arguments.
- Returns
AttentionOutput containing the attention result.
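The ring-passing schedule can be simulated on a single host: each "device" keeps its query block, attends to the KV block it currently holds, then hands that block to its neighbour, so after N steps every query block has seen every KV block. This NumPy sketch is non-causal, ignores sharding and communication, and uses a hypothetical function name; it shows the general technique, not the pallas_ring_attention kernel.

```python
import numpy as np

def ring_attention_sim(q_blocks, k_blocks, v_blocks):
    """Simulate non-causal ring attention across len(q_blocks) devices."""
    n = len(q_blocks)
    d = q_blocks[0].shape[1]
    m = [np.full(qb.shape[0], -np.inf) for qb in q_blocks]    # running max per device
    denom = [np.zeros(qb.shape[0]) for qb in q_blocks]        # running denominator
    acc = [np.zeros((qb.shape[0], v_blocks[0].shape[1])) for qb in q_blocks]
    kv = list(zip(k_blocks, v_blocks))  # kv[i] is the block device i holds now
    for _ in range(n):
        for i in range(n):
            kb, vb = kv[i]
            scores = q_blocks[i] @ kb.T / np.sqrt(d)
            m_new = np.maximum(m[i], scores.max(axis=1))
            scale = np.exp(m[i] - m_new)                      # rescale old partials
            p = np.exp(scores - m_new[:, None])
            denom[i] = denom[i] * scale + p.sum(axis=1)
            acc[i] = acc[i] * scale[:, None] + p @ vb
            m[i] = m_new
        kv = kv[-1:] + kv[:-1]  # rotate: device i receives the block of device i - 1
    return np.concatenate([a / s[:, None] for a, s in zip(acc, denom)])
```

Because each step only needs the neighbour's KV block, communication can overlap with computation, which is what lets ring attention scale the sequence length with the number of devices.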
- get_impl_metadata() → AttentionMetadata
Returns the metadata associated with this instance.
- class easydel.layers.attention_operator.modules.__init__.ScaledDotProductAttn(metadata: AttentionMetadata)
Bases: AttentionImpl
An attention implementation that leverages jax.nn.dot_product_attention.
This class uses JAX’s optimized SDPA primitive, which can dispatch to different backend implementations (XLA, or on NVIDIA GPUs cuDNN’s fused, Flash-Attention-style kernel), depending on the JAX version and hardware.
It handles sharding using shard_map and manages backend-specific dispatch (primarily distinguishing between CUDA/GPU and other backends like TPU/CPU).
Registered under the names “sdpa”, “cudnn”, and “cuda_flash_attn2”.
- forward_cpu(*args, **kwargs) → AttentionOutput
CPU forward pass. Delegates to forward_native (XLA implementation).
- forward_cuda(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, causal: bool = False, **ignore) → AttentionOutput
Computes attention using jax.nn.dot_product_attention with the “cudnn” implementation.
This is optimized for NVIDIA GPUs using cuDNN. It applies sharding via shard_map. Note: The cuDNN implementation might have specific requirements (e.g., dtype). Causal masking is disabled during generation mode (q_len=1) as it’s unnecessary.
- Parameters
q – Query tensor (B, T, H, D).
k – Key tensor (B, S, H_kv, D).
v – Value tensor (B, S, H_kv, D_v).
mask – Optional boolean attention mask (broadcastable to B, 1, T, S).
bias – Optional attention bias tensor (broadcastable to B, H, T, S).
init_bias – Optional callable to initialize bias if mask/bias are None.
causal – If True, applies causal masking within the primitive, unless in generation mode.
**ignore – Ignored keyword arguments.
- Returns
An AttentionOutput object. Weights are not returned.
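Skipping the causal mask when q_len is 1 is safe because the single decode query is the newest token and may attend to every cached key. How a triangular mask is aligned matters here: aligned top-left (np.tril with offset 0), a lone query would see only the first key, while aligned bottom-right it sees all keys, so the mask is a no-op. The helper causal_mask is hypothetical and does not describe how jax.nn.dot_product_attention builds its mask internally.

```python
import numpy as np

def causal_mask(q_len: int, kv_len: int, bottom_right: bool = False) -> np.ndarray:
    """Boolean (q_len, kv_len) mask; True where a query may attend to a key."""
    # Top-left alignment lets query i see keys 0..i; bottom-right alignment
    # shifts that by kv_len - q_len so the last query sees every key.
    offset = kv_len - q_len if bottom_right else 0
    return np.tril(np.ones((q_len, kv_len), dtype=bool), k=offset)

print(causal_mask(1, 5).astype(int))                     # [[1 0 0 0 0]]
print(causal_mask(1, 5, bottom_right=True).astype(int))  # [[1 1 1 1 1]]
```

For square prefill inputs (q_len equal to kv_len) both alignments coincide, which is why the distinction only surfaces during generation.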
- forward_gpu(*args, **kwargs) → AttentionOutput
GPU forward pass. Delegates to the CUDA-specific implementation.
- forward_native(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, causal: bool = False, **ignore) → AttentionOutput
Computes attention using jax.nn.dot_product_attention with the “xla” implementation.
This is typically used for CPU and TPU backends. It applies sharding via shard_map.
- Parameters
q – Query tensor (B, T, H, D).
k – Key tensor (B, S, H_kv, D).
v – Value tensor (B, S, H_kv, D_v).
mask – Optional boolean attention mask (broadcastable to B, 1, T, S). Passed directly to the primitive.
bias – Optional attention bias tensor (broadcastable to B, H, T, S). Passed directly to the primitive. If bias is provided, causal is forced to False.
init_bias – Optional callable to initialize bias if mask/bias are None.
causal – If True and bias is None, applies causal masking within the primitive.
**ignore – Ignored keyword arguments.
- Returns
An AttentionOutput object. Note that jax.nn.dot_product_attention typically does not return attention weights.
- forward_rocm(*args, **kwargs) → AttentionOutput
ROCm GPU forward pass. Currently delegates to forward_native.
- forward_tpu(*args, **kwargs) → AttentionOutput
TPU forward pass. Delegates to forward_native (XLA implementation).
- get_impl_metadata() → AttentionMetadata
Returns the metadata associated with this attention implementation instance.
- Returns
The AttentionMetadata provided during initialization.
- class easydel.layers.attention_operator.modules.__init__.SplashAttn(metadata: AttentionMetadata)
Bases: AttentionImpl
An attention implementation using the Pallas Splash Attention kernel for TPUs.
Splash Attention is an optimized attention mechanism designed for TPUs. This implementation provides a wrapper around the make_splash_mqa_single_device primitive.
Note
This implementation is primarily intended for TPUs. It falls back to VanillaAttn when any of the following holds:
- The query sequence length is 1 (generation mode).
- causal is False.
- The query sequence length is not divisible by 128 (a kernel constraint).
Non-TPU forward methods (forward_native, forward_gpu, etc.) are not implemented and raise NotImplementedError.
Registered under the name “splash”.
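The fallback rules listed in the note amount to a single predicate. splash_should_fall_back is a hypothetical helper written for illustration; the real check lives inside forward_tpu and may involve additional conditions.

```python
def splash_should_fall_back(q_seq_len: int, causal: bool, block_size: int = 128) -> bool:
    """Return True when the documented rules route the call to VanillaAttn."""
    generation_step = q_seq_len == 1            # single-token decode
    non_causal = not causal                     # only the causal path uses the kernel
    misaligned = q_seq_len % block_size != 0    # kernel needs multiples of 128
    return generation_step or non_causal or misaligned

for q_len, causal in [(1, True), (512, True), (512, False), (300, True)]:
    choice = "vanilla" if splash_should_fall_back(q_len, causal) else "splash"
    print(q_len, causal, choice)
```

A prompt of length 300 therefore runs on VanillaAttn unless it is padded to a multiple of 128 beforehand.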
- forward_cpu(*args, **kwargs) → AttentionOutput
CPU forward pass. Not implemented for Splash Attention.
- forward_cuda(*args, **kwargs) → AttentionOutput
CUDA GPU forward pass. Not implemented for Splash Attention.
- forward_gpu(*args, **kwargs) → AttentionOutput
GPU forward pass. Not implemented for Splash Attention.
- forward_native(*args, **kwargs) → AttentionOutput
Native (CPU) forward pass. Not implemented for Splash Attention.
- forward_rocm(*args, **kwargs) → AttentionOutput
ROCm GPU forward pass. Not implemented for Splash Attention.
- forward_tpu(q: Array, k: Array, v: Array, mask: Optional[Array] = None, causal: bool = True, cache_view: Optional[TransformerCacheView] = None, **ignore) → AttentionOutput
Performs Splash Attention on TPU using the Pallas kernel.
Handles fallback logic, mask processing, block size configuration, and sharding via shard_map. Expects inputs in BTHD format and transposes them to BHTD for the kernel.
- Parameters
q – Query tensor (B, T, Hq, D).
k – Key tensor (B, S, Hkv, D).
v – Value tensor (B, S, Hkv, Dv).
mask – Optional boolean attention mask (broadcastable to B, 1, T, S). Used to generate segment IDs if provided.
causal – If True, applies causal masking via the kernel’s mask configuration. If False, falls back to VanillaAttn.
**ignore – Ignored keyword arguments.
- Returns
An AttentionOutput object containing the attention outputs. Attention weights are not computed or returned by Splash Attention.
- get_impl_metadata() → AttentionMetadata
Returns the metadata associated with this attention implementation instance.
- Returns
The AttentionMetadata provided during initialization.
- class easydel.layers.attention_operator.modules.__init__.VanillaAttn(metadata: AttentionMetadata)
Bases: AttentionImpl
A standard, non-optimized implementation of multi-head attention.
This implementation uses basic JAX operations like jnp.einsum and standard softmax. It serves as a reference implementation and a fallback for platforms where optimized kernels (like Flash Attention) are not available or desired. It supports features like attention bias, masking, dropout, and Grouped Query Attention (GQA)/Multi-Query Attention (MQA) via reshaping.
Registered under the name “vanilla”.
- forward_cpu(*args, **kwargs) → AttentionOutput
CPU forward pass. Delegates to forward_native.
- forward_cuda(*args, **kwargs) → AttentionOutput
CUDA GPU forward pass. Delegates to forward_native.
- forward_gpu(*args, **kwargs) → AttentionOutput
GPU forward pass. Delegates to forward_native.
- forward_native(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, deterministic: bool = True, dropout_rng: Optional[PRNGKey] = None, **ignore) → AttentionOutput
Computes multi-head attention using standard JAX operations.
Supports GQA/MQA by reshaping the query tensor to match the number of key/value heads. Applies scaling, optional bias/mask, softmax (potentially in float32), and optional dropout.
- Parameters
q – Query tensor (B, T, H_q, D).
k – Key tensor (B, S, H_kv, D).
v – Value tensor (B, S, H_kv, D_v).
mask – Optional boolean attention mask (broadcastable to B, 1, T, S). Used if bias is not provided.
bias – Optional attention bias tensor (broadcastable to B, H_q, T, S). Takes precedence over mask.
init_bias – Optional callable to initialize bias if mask/bias are None.
deterministic – If True, disables dropout.
dropout_rng – JAX PRNG key for dropout. Required if deterministic is False and dropout_prob > 0.
**ignore – Ignored keyword arguments.
- Returns
An AttentionOutput object containing the attention weights (if computed) and the final attention outputs.
- Raises
NotImplementedError – If the bias head dimension cannot be reshaped correctly to match the query head structure for GQA/MQA.
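A NumPy sketch of the computation described above: query heads are reshaped into (KV head, group) so GQA/MQA needs no KV copies, bias takes precedence over mask, softmax runs in float32, and dropout is optional. The function name and the dropout_prob/rng parameters are assumptions for illustration (the real method is written in JAX and reads its settings from metadata), and the inverted-dropout scaling by 1 / (1 − p) is a common convention that the docstring does not specify.

```python
import numpy as np

def vanilla_attention(q, k, v, mask=None, bias=None, dropout_prob=0.0, rng=None):
    """Reference attention. q: (B, T, Hq, D); k: (B, S, Hkv, D); v: (B, S, Hkv, Dv).

    mask: bool, broadcastable to (B, 1, T, S); bias: broadcastable to (B, Hq, T, S).
    Returns (weights (B, Hq, T, S), outputs (B, T, Hq, Dv)).
    """
    b, t, hq, d = q.shape
    s, hkv = k.shape[1], k.shape[2]
    g = hq // hkv
    qg = q.reshape(b, t, hkv, g, d)  # query head h -> (KV head h // g, group h % g)
    scores = (np.einsum("bthgd,bshd->bhgts", qg, k) / np.sqrt(d)).astype(np.float32)
    if bias is not None:  # bias takes precedence over mask
        bias = np.broadcast_to(bias, (b, hq, t, s)).astype(np.float32)
        scores = scores + bias.reshape(b, hkv, g, t, s)
    elif mask is not None:
        mask = np.broadcast_to(mask, (b, 1, t, s))[:, :, None]  # (B, 1, 1, T, S)
        scores = np.where(mask, scores, np.finfo(np.float32).min)
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)  # softmax in float32
    if dropout_prob > 0.0:
        keep = rng.random(weights.shape) >= dropout_prob
        weights = np.where(keep, weights / (1.0 - dropout_prob), 0.0)
    out = np.einsum("bhgts,bshd->bthgd", weights, v)
    return weights.reshape(b, hq, t, s), out.reshape(b, t, hq, v.shape[-1])
```

Reshaping the query rather than repeating keys and values keeps memory proportional to the number of KV heads; the bias must then be reshaped the same way, which is where a head count that does not split into (KV head, group) would fail.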
- forward_rocm(*args, **kwargs) → AttentionOutput
ROCm GPU forward pass. Delegates to forward_native.
- forward_tpu(*args, **kwargs) → AttentionOutput
TPU forward pass. Delegates to forward_native.
- get_impl_metadata() → AttentionMetadata
Returns the metadata associated with this attention implementation instance.
- Returns
The AttentionMetadata provided during initialization.