easydel.layers.attention_operator.modules.__init__#
- class easydel.layers.attention_operator.modules.__init__.FlashAttn(metadata: AttentionMetadata)[source]#
Bases: AttentionImpl
An implementation of Flash Attention V2 using specialized JAX primitives.
This class leverages jax.experimental.pallas.ops.tpu.flash_attention for TPUs and a Triton kernel (triton_flash_attention) for GPUs (CUDA). It is registered under the name “flash_attn2”. CPU execution is not supported and will raise an error.
- forward_cpu(*args, **kwargs) AttentionOutput[source]#
CPU forward pass. Delegates to forward_native, which raises an error.
- Raises
NotImplementedError – Via forward_native.
- forward_cuda(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, causal: bool = False, **ignore) AttentionOutput[source]#
Performs Flash Attention V2 on GPU (CUDA) using a Triton kernel.
Handles optional mask/bias and sharding based on metadata. KV head repetition is left to triton_flash_attention, which is assumed either to repeat KV heads internally or to accept broadcastable KV. Inputs and outputs of the Triton kernel are assumed to be in BTHD format (batch, sequence, heads, head_dim).
- Parameters
q – Query tensor, expected shape (batch, q_seq_len, num_q_heads, head_dim).
k – Key tensor, expected shape (batch, kv_seq_len, num_kv_heads, head_dim).
v – Value tensor, expected shape (batch, kv_seq_len, num_kv_heads, head_dim).
mask – Optional boolean attention mask. Used by the kernel if bias is not provided. Shape typically (batch, 1, q_seq_len, kv_seq_len) or broadcastable.
bias – Optional attention bias tensor. Added by the kernel. Shape typically (batch, num_heads, q_seq_len, kv_seq_len) or broadcastable. Takes precedence over mask if both are passed.
init_bias – Optional callable function to initialize bias if mask and bias are None.
causal – If True, instructs the kernel to apply a causal mask.
**ignore – Ignored keyword arguments.
- Returns
An AttentionOutput object containing the attention outputs. Attention weights are typically not computed or returned by Flash Attention implementations.
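The description above leaves KV head repetition to the kernel. For callers that need to repeat KV heads themselves before handing tensors to a kernel that expects matching head counts, a minimal sketch in BTHD layout could look like the following. The helper name `repeat_kv_heads` is hypothetical and is not part of EasyDeL.

```python
import jax.numpy as jnp


def repeat_kv_heads(k, v, num_q_heads):
    """Repeat KV heads along axis 2 (BTHD layout) to match the query heads.

    Hypothetical helper for illustration; not an EasyDeL function.
    """
    num_kv_heads = k.shape[2]
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("num_q_heads must be a multiple of num_kv_heads")
    reps = num_q_heads // num_kv_heads
    if reps == 1:
        return k, v
    # jnp.repeat duplicates each head consecutively: [h0, h0, ..., h1, h1, ...].
    return jnp.repeat(k, reps, axis=2), jnp.repeat(v, reps, axis=2)


# (batch=2, kv_seq_len=16, num_kv_heads=2, head_dim=8), repeated for 8 query heads.
k = jnp.ones((2, 16, 2, 8))
v = jnp.ones((2, 16, 2, 8))
k_rep, v_rep = repeat_kv_heads(k, v, num_q_heads=8)
```

With this ordering, query head `h` attends to the original KV head `h // reps`.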
- forward_gpu(*args, **kwargs) AttentionOutput[source]#
GPU forward pass. Delegates to the CUDA-specific implementation.
- Parameters
*args – Positional arguments for the attention calculation.
**kwargs – Keyword arguments for the attention calculation.
- Returns
An AttentionOutput object containing the attention results.
- forward_native(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, causal: bool = False, **ignore) AttentionOutput[source]#
Native (CPU) forward pass for Flash Attention. Not implemented.
- Raises
NotImplementedError – Flash Attention is not supported on CPU via this implementation.
- forward_rocm(*args, **kwargs) AttentionOutput[source]#
ROCm GPU forward pass. Not implemented; delegates to forward_native, which raises NotImplementedError.
- Raises
NotImplementedError – Via forward_native. ROCm support requires a specific kernel.
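The delegation chain described for these methods (forward_cpu and forward_rocm going through forward_native, forward_gpu going through forward_cuda) can be pictured with a toy class. This is only a sketch: `ToyFlashImpl`, its string return values, and the `__call__` dispatch by backend name are assumptions made for illustration, and the way EasyDeL actually selects a backend is not shown in this reference.

```python
class ToyFlashImpl:
    """Hypothetical stand-in for an AttentionImpl subclass; not EasyDeL code."""

    def forward_native(self, *args, **kwargs):
        # No CPU kernel exists in this sketch, mirroring the documented behaviour.
        raise NotImplementedError("Flash Attention is not supported on CPU.")

    def forward_cpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_rocm(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_cuda(self, *args, **kwargs):
        return "cuda-result"  # placeholder for the Triton kernel call

    def forward_gpu(self, *args, **kwargs):
        return self.forward_cuda(*args, **kwargs)

    def forward_tpu(self, *args, **kwargs):
        return "tpu-result"  # placeholder for the Pallas kernel call

    def __call__(self, backend, *args, **kwargs):
        # Assumed dispatch rule: look up forward_<backend> by name.
        return getattr(self, f"forward_{backend}")(*args, **kwargs)


impl = ToyFlashImpl()
gpu_result = impl("gpu")
```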
- forward_tpu(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, causal: bool = False, **ignore) AttentionOutput[source]#
Performs Flash Attention V2 on TPU using jax.experimental.pallas.ops.tpu.flash_attention.
Handles optional mask/bias, KV head repetition, and sharding based on metadata. Note: The Pallas kernel expects inputs in BHTD format (batch, heads, sequence, head_dim), whereas the parameters below are documented in BTHD format.
- Parameters
q – Query tensor, expected shape (batch, q_seq_len, num_q_heads, head_dim).
k – Key tensor, expected shape (batch, kv_seq_len, num_kv_heads, head_dim).
v – Value tensor, expected shape (batch, kv_seq_len, num_kv_heads, head_dim).
mask – Optional boolean attention mask. Bias will be generated from this if provided. Shape typically (batch, 1, q_seq_len, kv_seq_len) or broadcastable.
bias – Optional attention bias tensor. Added directly to attention scores. Shape typically (batch, num_heads, q_seq_len, kv_seq_len) or broadcastable. Takes precedence over mask.
init_bias – Optional callable function to initialize bias if mask and bias are None.
causal – If True, applies a causal mask. Ignored if q_seq_len is 1 (generation).
**ignore – Ignored keyword arguments.
- Returns
An AttentionOutput object containing the attention outputs. Attention weights are typically not computed or returned by Flash Attention implementations.
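The preprocessing described for this method (bias generated from the mask, bias taking precedence, causal ignored when q_seq_len is 1, BHTD layout for the kernel) can be sketched as follows. The helper `prepare_for_bhtd_kernel` is hypothetical, and the use of the dtype's most negative finite value for masked positions is an assumption; the actual EasyDeL code may differ.

```python
import jax.numpy as jnp


def prepare_for_bhtd_kernel(q, k, v, mask=None, bias=None, causal=False):
    """Hypothetical preprocessing sketch; inputs are BTHD, outputs are BHTD."""
    # An explicit bias takes precedence; otherwise derive one from the boolean mask.
    if bias is None and mask is not None:
        neg = jnp.finfo(q.dtype).min  # assumed fill value for masked positions
        bias = jnp.where(mask, 0.0, neg).astype(q.dtype)
    # A single query token (generation) needs no causal mask.
    causal = causal and q.shape[1] != 1

    def to_bhtd(x):
        return jnp.transpose(x, (0, 2, 1, 3))

    return to_bhtd(q), to_bhtd(k), to_bhtd(v), bias, causal


q = jnp.zeros((2, 1, 4, 8))    # (batch, q_seq_len=1, heads, head_dim)
k = jnp.zeros((2, 16, 4, 8))   # (batch, kv_seq_len, heads, head_dim)
v = jnp.zeros((2, 16, 4, 8))
mask = jnp.ones((2, 1, 1, 16), dtype=bool).at[..., -1].set(False)
q_t, k_t, v_t, bias, causal = prepare_for_bhtd_kernel(q, k, v, mask=mask, causal=True)
```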
- get_impl_metadata() AttentionMetadata[source]#
Returns the metadata associated with this attention implementation instance.
- Returns
The AttentionMetadata provided during initialization.
- class easydel.layers.attention_operator.modules.__init__.PagedAttn(metadata: AttentionMetadata)[source]#
Bases: AttentionImpl
- forward_cpu(*args, **kwargs) AttentionOutput[source]#
CPU forward pass. Not implemented for Paged Attention.
- forward_cuda(*args, **kwargs) AttentionOutput[source]#
CUDA GPU forward pass. Not implemented for Paged Attention.
- forward_gpu(*args, **kwargs) AttentionOutput[source]#
GPU forward pass. Not implemented for Paged Attention.
- forward_native(*args, **kwargs) AttentionOutput[source]#
Native (CPU) forward pass. Not implemented for Paged Attention.
- forward_rocm(*args, **kwargs) AttentionOutput[source]#
ROCm GPU forward pass. Not implemented for Paged Attention.
- forward_tpu(q: Array, k: Array, v: Array, cache_view: PagedAttentionCacheView, cache_metadata: PagedAttentionMetadata, **ignore) AttentionOutput[source]#
TPU forward pass for Paged Attention. Takes the paged cache view and its metadata as explicit arguments.
- Parameters
q – Query tensor.
k – Key tensor.
v – Value tensor.
cache_view – A PagedAttentionCacheView instance.
cache_metadata – A PagedAttentionMetadata instance.
**ignore – Ignored keyword arguments.
- Returns
An AttentionOutput object containing the attention results.
- get_impl_metadata() AttentionMetadata[source]#
Returns the metadata associated with this attention implementation instance.
- Returns
The AttentionMetadata provided during initialization.
- class easydel.layers.attention_operator.modules.__init__.RingAttn(metadata: AttentionMetadata)[source]#
Bases: AttentionImpl
Attention implementation using a ring-passing algorithm or a blockwise scan.
This implementation supports:
- Native (scan-based) blockwise attention via blockwise_attn.
- TPU-specific ring attention using the pallas_ring_attention kernel.
It is registered under the name “ring”.
- forward_cpu(*args, **kwargs) AttentionOutput[source]#
CPU forward pass. Delegates to forward_native (scan-based).
- forward_cuda(*args, **kwargs) AttentionOutput[source]#
CUDA GPU forward pass. Currently delegates to forward_native (scan-based).
- forward_gpu(*args, **kwargs) AttentionOutput[source]#
GPU forward pass. Currently delegates to forward_native (scan-based).
- forward_native(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, deterministic: bool = False, dropout_rng: Optional[PRNGKey] = None, causal: bool = True, **ignore) AttentionOutput[source]#
Computes attention using the scan-based blockwise_attn function.
Handles optional mask/bias, KV head repetition, and sharding constraints.
- Parameters
q – Query tensor (B, T, H, D).
k – Key tensor (B, S, H_kv, D).
v – Value tensor (B, S, H_kv, D).
mask – Optional boolean attention mask (broadcastable to B, 1, T, S).
bias – Optional attention bias (broadcastable to B, H, T, S).
init_bias – Optional callable to initialize bias if mask/bias are None.
deterministic – If False, enables dropout. Requires dropout_rng.
dropout_rng – JAX PRNG key for dropout if deterministic is False.
causal – Apply causal mask if True.
**ignore – Ignored keyword arguments.
- Returns
AttentionOutput containing the attention result.
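To illustrate what "scan-based blockwise attention" means in general, here is a minimal sketch that processes keys and values block by block with jax.lax.scan and an online softmax. It is not the blockwise_attn function itself: the name `blockwise_attention` is hypothetical, and the sketch omits masks, bias, causal masking, dropout, KV head repetition, and sharding. It also assumes the KV length is divisible by the block size.

```python
import jax
import jax.numpy as jnp


def blockwise_attention(q, k, v, block_size):
    """q: (B, T, H, D); k, v: (B, S, H, D); S must be divisible by block_size."""
    B, T, H, D = q.shape
    S = k.shape[1]
    n_blocks = S // block_size
    scale = D ** -0.5
    # Split KV along the sequence axis: (n_blocks, B, block_size, H, D).
    k_blocks = k.reshape(B, n_blocks, block_size, H, D).swapaxes(0, 1)
    v_blocks = v.reshape(B, n_blocks, block_size, H, D).swapaxes(0, 1)

    def step(carry, kv):
        acc, row_max, denom = carry
        k_blk, v_blk = kv
        s = jnp.einsum("bthd,bshd->bhts", q, k_blk) * scale  # (B, H, T, block)
        new_max = jnp.maximum(row_max, s.max(axis=-1))
        # Rescale the running statistics to the new maximum (online softmax).
        correction = jnp.exp(row_max - new_max)
        p = jnp.exp(s - new_max[..., None])
        denom = denom * correction + p.sum(axis=-1)
        acc = acc * correction[..., None] + jnp.einsum("bhts,bshd->bhtd", p, v_blk)
        return (acc, new_max, denom), None

    init = (
        jnp.zeros((B, H, T, D), q.dtype),        # unnormalised output accumulator
        jnp.full((B, H, T), -jnp.inf, q.dtype),  # running row maximum
        jnp.zeros((B, H, T), q.dtype),           # running softmax denominator
    )
    (acc, _, denom), _ = jax.lax.scan(step, init, (k_blocks, v_blocks))
    out = acc / denom[..., None]
    return out.transpose(0, 2, 1, 3)  # back to (B, T, H, D)


kq, kk, kv_ = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (2, 8, 4, 16))
k = jax.random.normal(kk, (2, 32, 4, 16))
v = jax.random.normal(kv_, (2, 32, 4, 16))
out = blockwise_attention(q, k, v, block_size=8)
```

Because only one KV block is materialised per step, the full (T, S) score matrix never exists at once, which is the memory saving this family of algorithms targets.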
- forward_rocm(*args, **kwargs) AttentionOutput[source]#
ROCm GPU forward pass. Currently delegates to forward_native (scan-based).
- forward_tpu(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, deterministic: bool = False, dropout_rng: Optional[PRNGKey] = None, causal: bool = True, **ignore) AttentionOutput[source]#
Computes Ring Attention on TPU using the pallas_ring_attention kernel.
Handles optional mask/bias, sharding, and passes configuration to the kernel.
- Parameters
q – Query tensor (B, T, H, D).
k – Key tensor (B, S, H_kv, D).
v – Value tensor (B, S, H_kv, D).
mask – Optional boolean attention mask (broadcastable to B, 1, T, S).
bias – Optional attention bias (broadcastable to B, H, T, S).
init_bias – Optional callable to initialize bias if mask/bias are None.
deterministic – If False, potentially enables dropout within the kernel (if supported).
dropout_rng – JAX PRNG key (may be used by the kernel if dropout is enabled).
causal – Apply causal mask if True. Passed to the kernel.
**ignore – Ignored keyword arguments.
- Returns
AttentionOutput containing the attention result.
- get_impl_metadata() AttentionMetadata[source]#
Returns the metadata associated with this instance.
- class easydel.layers.attention_operator.modules.__init__.ScaledDotProductAttn(metadata: AttentionMetadata)[source]#
Bases: AttentionImpl
An attention implementation that leverages jax.nn.dot_product_attention.
This class utilizes JAX’s optimized SDPA primitive, which can dispatch to different backend implementations (like XLA, cuDNN, or potentially Flash Attention emulation on CUDA depending on JAX version and hardware).
It handles sharding using shard_map and manages backend-specific dispatch (primarily distinguishing between CUDA/GPU and other backends like TPU/CPU).
Registered under the names “sdpa”, “cudnn”, and “cuda_flash_attn2”.
- forward_cpu(*args, **kwargs) AttentionOutput[source]#
CPU forward pass. Delegates to forward_native (XLA implementation).
- forward_cuda(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, causal: bool = False, **ignore) AttentionOutput[source]#
Computes attention using jax.nn.dot_product_attention with the “cudnn” implementation.
This is optimized for NVIDIA GPUs using cuDNN. It applies sharding via shard_map. Note: The cuDNN implementation might have specific requirements (e.g., dtype). Causal masking is disabled during generation mode (q_len=1) as it’s unnecessary.
- Parameters
q – Query tensor (B, T, H, D).
k – Key tensor (B, S, H_kv, D).
v – Value tensor (B, S, H_kv, D_v).
mask – Optional boolean attention mask (broadcastable to B, 1, T, S).
bias – Optional attention bias tensor (broadcastable to B, H, T, S).
init_bias – Optional callable to initialize bias if mask/bias are None.
causal – If True, applies causal masking within the primitive, unless in generation mode.
**ignore – Ignored keyword arguments.
- Returns
An AttentionOutput object. Weights are not returned.
- forward_gpu(*args, **kwargs) AttentionOutput[source]#
GPU forward pass. Delegates to the CUDA-specific implementation.
- forward_native(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, causal: bool = False, **ignore) AttentionOutput[source]#
Computes attention using jax.nn.dot_product_attention with the “xla” implementation.
This is typically used for CPU and TPU backends. It applies sharding via shard_map.
- Parameters
q – Query tensor (B, T, H, D).
k – Key tensor (B, S, H_kv, D).
v – Value tensor (B, S, H_kv, D_v).
mask – Optional boolean attention mask (broadcastable to B, 1, T, S). Passed directly to the primitive.
bias – Optional attention bias tensor (broadcastable to B, H, T, S). Passed directly to the primitive. If bias is provided, causal is forced to False.
init_bias – Optional callable to initialize bias if mask/bias are None.
causal – If True and bias is None, applies causal masking within the primitive.
**ignore – Ignored keyword arguments.
- Returns
An AttentionOutput object. Note that jax.nn.dot_product_attention typically does not return attention weights.
- forward_rocm(*args, **kwargs) AttentionOutput[source]#
ROCm GPU forward pass. Currently delegates to forward_native.
- forward_tpu(*args, **kwargs) AttentionOutput[source]#
TPU forward pass. Delegates to forward_native (XLA implementation).
- get_impl_metadata() AttentionMetadata[source]#
Returns the metadata associated with this attention implementation instance.
- Returns
The AttentionMetadata provided during initialization.
- class easydel.layers.attention_operator.modules.__init__.SplashAttn(metadata: AttentionMetadata)[source]#
Bases: AttentionImpl
An attention implementation using the Pallas Splash Attention kernel for TPUs.
Splash Attention is an optimized attention mechanism designed for TPUs. This implementation provides a wrapper around the make_splash_mqa_single_device primitive.
Note
- This implementation is primarily intended for TPUs.
- It falls back to VanillaAttn when any of the following holds:
  - Query sequence length is 1 (generation mode).
  - causal is False.
  - Query sequence length is not divisible by 128 (kernel constraint).
- Non-TPU forward methods (forward_native, forward_gpu, etc.) are not implemented and will raise NotImplementedError.
Registered under the name “splash”.
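The fallback conditions listed in the note can be written as a single predicate. This is a sketch under the assumption that any one condition is enough to trigger the fallback; the function name `should_fall_back_to_vanilla` and the `block` parameter are hypothetical.

```python
def should_fall_back_to_vanilla(q_seq_len: int, causal: bool, block: int = 128) -> bool:
    """Return True when the documented Splash fallback conditions apply."""
    return (
        q_seq_len == 1             # generation mode
        or not causal              # the kernel path is only used for causal attention
        or q_seq_len % block != 0  # kernel constraint on the query length
    )


decisions = {
    "prefill_1024": should_fall_back_to_vanilla(1024, causal=True),
    "decode_step": should_fall_back_to_vanilla(1, causal=True),
    "non_causal": should_fall_back_to_vanilla(1024, causal=False),
    "odd_length": should_fall_back_to_vanilla(1000, causal=True),
}
```

Padding the query length to a multiple of 128 before the call is one way to stay on the kernel path for prefill.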
- forward_cpu(*args, **kwargs) AttentionOutput[source]#
CPU forward pass. Not implemented for Splash Attention.
- forward_cuda(*args, **kwargs) AttentionOutput[source]#
CUDA GPU forward pass. Not implemented for Splash Attention.
- forward_gpu(*args, **kwargs) AttentionOutput[source]#
GPU forward pass. Not implemented for Splash Attention.
- forward_native(*args, **kwargs) AttentionOutput[source]#
Native (CPU) forward pass. Not implemented for Splash Attention.
- forward_rocm(*args, **kwargs) AttentionOutput[source]#
ROCm GPU forward pass. Not implemented for Splash Attention.
- forward_tpu(q: Array, k: Array, v: Array, mask: Optional[Array] = None, causal: bool = True, **ignore) AttentionOutput[source]#
Performs Splash Attention on TPU using the Pallas kernel.
Handles fallback logic, mask processing, block size configuration, and sharding via shard_map. Expects inputs in BTHD format and transposes them to BHTD for the kernel.
- Parameters
q – Query tensor (B, T, Hq, D).
k – Key tensor (B, S, Hkv, D).
v – Value tensor (B, S, Hkv, Dv).
mask – Optional boolean attention mask (broadcastable to B, 1, T, S). Used to generate segment IDs if provided.
causal – If True, applies causal masking via the kernel’s mask configuration. If False, falls back to VanillaAttn.
**ignore – Ignored keyword arguments.
- Returns
An AttentionOutput object containing the attention outputs. Attention weights are not computed or returned by Splash Attention.
- get_impl_metadata() AttentionMetadata[source]#
Returns the metadata associated with this attention implementation instance.
- Returns
The AttentionMetadata provided during initialization.
- class easydel.layers.attention_operator.modules.__init__.VanillaAttn(metadata: AttentionMetadata)[source]#
Bases: AttentionImpl
A standard, non-optimized implementation of multi-head attention.
This implementation uses basic JAX operations like jnp.einsum and standard softmax. It serves as a reference implementation and a fallback for platforms where optimized kernels (like Flash Attention) are not available or desired. It supports features like attention bias, masking, dropout, and Grouped Query Attention (GQA)/Multi-Query Attention (MQA) via reshaping.
Registered under the name “vanilla”.
- forward_cpu(*args, **kwargs) AttentionOutput[source]#
CPU forward pass. Delegates to forward_native.
- forward_cuda(*args, **kwargs) AttentionOutput[source]#
CUDA GPU forward pass. Delegates to forward_native.
- forward_gpu(*args, **kwargs) AttentionOutput[source]#
GPU forward pass. Delegates to forward_native.
- forward_native(q: Array, k: Array, v: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, init_bias: Optional[Callable[[], Array]] = None, deterministic: bool = True, dropout_rng: Optional[PRNGKey] = None, **ignore) AttentionOutput[source]#
Computes multi-head attention using standard JAX operations.
Supports GQA/MQA by reshaping the query tensor to match the number of key/value heads. Applies scaling, optional bias/mask, softmax (potentially in float32), and optional dropout.
- Parameters
q – Query tensor (B, T, H_q, D).
k – Key tensor (B, S, H_kv, D).
v – Value tensor (B, S, H_kv, D_v).
mask – Optional boolean attention mask (broadcastable to B, 1, T, S). Used if bias is not provided.
bias – Optional attention bias tensor (broadcastable to B, H_q, T, S). Takes precedence over mask.
init_bias – Optional callable to initialize bias if mask/bias are None.
deterministic – If True, disables dropout.
dropout_rng – JAX PRNG key for dropout. Required if deterministic is False and dropout_prob > 0.
**ignore – Ignored keyword arguments.
- Returns
An AttentionOutput object containing the attention weights (if computed) and the final attention outputs.
- Raises
NotImplementedError – If the bias head dimension cannot be reshaped correctly to match the query head structure for GQA/MQA.
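The GQA/MQA handling described above, where the query is reshaped to match the number of key/value heads, can be sketched with jnp.einsum as follows. This is a minimal illustration rather than the VanillaAttn code: the name `vanilla_gqa_attention` is hypothetical, dropout and init_bias are omitted, and a bias that cannot be broadcast to (B, H_q, T, S) raises a ValueError from jnp.broadcast_to here instead of the NotImplementedError documented above.

```python
import jax
import jax.numpy as jnp


def vanilla_gqa_attention(q, k, v, mask=None, bias=None):
    """q: (B, T, H_q, D); k: (B, S, H_kv, D); v: (B, S, H_kv, D_v)."""
    B, T, Hq, D = q.shape
    S, Hkv = k.shape[1], k.shape[2]
    G = Hq // Hkv  # query heads per KV head
    # Group the query heads so each group shares one KV head.
    qg = q.reshape(B, T, Hkv, G, D) * (D ** -0.5)
    scores = jnp.einsum("bthgd,bshd->bhgts", qg, k)  # (B, H_kv, G, T, S)
    if bias is not None:
        # Bias takes precedence over mask.
        full = jnp.broadcast_to(bias, (B, Hq, T, S))
        scores = scores + full.reshape(B, Hkv, G, T, S)
    elif mask is not None:
        neg = jnp.finfo(scores.dtype).min
        scores = jnp.where(mask[:, :, None, :, :], scores, neg)
    # Softmax in float32 for numerical stability, then cast back.
    weights = jax.nn.softmax(scores.astype(jnp.float32), axis=-1).astype(v.dtype)
    out = jnp.einsum("bhgts,bshd->bthgd", weights, v)
    return out.reshape(B, T, Hq, v.shape[-1]), weights.reshape(B, Hq, T, S)


kq, kk, kv_ = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (2, 6, 8, 16))   # 8 query heads
k = jax.random.normal(kk, (2, 10, 2, 16))  # 2 KV heads
v = jax.random.normal(kv_, (2, 10, 2, 4))  # D_v differs from D
out, weights = vanilla_gqa_attention(q, k, v)
```

Grouping the query instead of repeating K and V avoids materialising the repeated KV tensors.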
- forward_rocm(*args, **kwargs) AttentionOutput[source]#
ROCm GPU forward pass. Delegates to forward_native.
- forward_tpu(*args, **kwargs) AttentionOutput[source]#
TPU forward pass. Delegates to forward_native.
- get_impl_metadata() AttentionMetadata[source]#
Returns the metadata associated with this attention implementation instance.
- Returns
The AttentionMetadata provided during initialization.