easydel.layers.operations.modules.vanilla_attention
Vanilla (standard) attention implementation for EasyDeL.
This module provides a reference implementation of multi-head attention using standard JAX operations. It serves as both a baseline for comparison with optimized implementations and a fallback for platforms where specialized kernels are unavailable.
The vanilla attention implementation:
- Uses standard matrix multiplication and softmax operations
- Supports all standard attention features (masking, bias, dropout)
- Works on all platforms (TPU, GPU, CPU) without specialized kernels
- Provides full attention weights for inspection when needed
- Supports Grouped Query Attention (GQA) and Multi-Query Attention (MQA)
Key characteristics:
- Memory complexity: O(N²), where N is the sequence length
- Computation: Uses einsum for efficient batch matrix multiplication
- Flexibility: Supports various mask and bias shapes
- Compatibility: Works with any JAX backend without modification
This implementation is ideal for:
- Debugging and development
- Small sequence lengths where memory is not a constraint
- Platforms without optimized attention kernels
- Cases where attention weights need to be inspected
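The computation described above can be sketched in a few lines of JAX. This is a minimal illustration, not EasyDeL's actual code: the function name `reference_attention`, its argument names, and the default `head_dim ** -0.5` scale are assumptions made for the example.

```python
import jax
import jax.numpy as jnp


def reference_attention(query, key, value, mask=None, bias=None, softmax_scale=None):
    """Plain softmax attention (illustrative sketch, not the EasyDeL source).

    query:      [batch, q_len, num_heads, head_dim]
    key, value: [batch, kv_len, num_heads, head_dim]
    mask:       optional bool array broadcastable to [batch, num_heads, q_len, kv_len]
    bias:       optional float array broadcastable to the same shape
    """
    head_dim = query.shape[-1]
    # Assumed default: the usual 1/sqrt(head_dim) scaling.
    scale = head_dim ** -0.5 if softmax_scale is None else softmax_scale

    # The full [batch, num_heads, q_len, kv_len] score matrix is materialized,
    # which is where the O(N^2) memory cost comes from.
    logits = jnp.einsum("bqhd,bkhd->bhqk", query, key) * scale
    if bias is not None:
        logits = logits + bias
    if mask is not None:
        # Masked positions get the most negative finite value before softmax.
        logits = jnp.where(mask, logits, jnp.finfo(logits.dtype).min)

    # Softmax in float32 for stability, then cast back to the input dtype.
    weights = jax.nn.softmax(logits.astype(jnp.float32), axis=-1).astype(query.dtype)
    output = jnp.einsum("bhqk,bkhd->bqhd", weights, value)
    return output, weights


q_rng, k_rng, v_rng = jax.random.split(jax.random.PRNGKey(0), 3)
query = jax.random.normal(q_rng, (2, 8, 4, 16))
key = jax.random.normal(k_rng, (2, 8, 4, 16))
value = jax.random.normal(v_rng, (2, 8, 4, 16))
causal_mask = jnp.tril(jnp.ones((8, 8), dtype=bool))[None, None]

output, weights = reference_attention(query, key, value, mask=causal_mask)
```

Because the weights are returned alongside the output, they can be inspected directly, which is the main practical advantage over fused kernels that never materialize them.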
Example
>>> import jax.numpy as jnp
>>> from easydel.layers.attention_operator import OperationMetadata
>>> from easydel.layers.attention_operator.modules import VanillaAttn
>>>
>>> metadata = OperationMetadata(
... runtime_dtype=jnp.float16,
... runtime_softmax_dtype=jnp.float32, # Higher precision for softmax
... dropout_prob=0.1
... )
>>> vanilla_attn = VanillaAttn(metadata)
>>> output = vanilla_attn(query, key, value, mask=attention_mask)
>>> attention_weights = output.attention_weights # Available for inspection
- class easydel.layers.operations.modules.vanilla_attention.VanillaAttn(metadata: OperationMetadata)
Bases: OperationImpl
A standard, non-optimized implementation of multi-head attention.
This implementation uses basic JAX operations like jnp.einsum and standard softmax. It serves as a reference implementation and a fallback for platforms where optimized kernels (like Flash Attention) are not available or desired. It supports features like attention bias, masking, dropout, and Grouped Query Attention (GQA)/Multi-Query Attention (MQA) via reshaping.
Registered under the name “vanilla”.
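One way to support GQA/MQA "via reshaping" is to split the query heads into groups that each share one key/value head. The sketch below shows that idea for the attention weights only; the function name `gqa_attention_weights` and the exact head layout are assumptions for illustration and may differ from what `VanillaAttn` does internally.

```python
import jax
import jax.numpy as jnp


def gqa_attention_weights(query, key):
    """Attention weights where several query heads share one key head (sketch).

    query:   [batch, q_len, num_q_heads, head_dim]
    key:     [batch, kv_len, num_kv_heads, head_dim]
    returns: [batch, num_q_heads, q_len, kv_len]
    """
    batch, q_len, num_q_heads, head_dim = query.shape
    kv_len, num_kv_heads = key.shape[1], key.shape[2]
    assert num_q_heads % num_kv_heads == 0, "query heads must divide evenly into groups"
    groups = num_q_heads // num_kv_heads

    # Assumed layout: query head i belongs to kv head i // groups.
    query = query.reshape(batch, q_len, num_kv_heads, groups, head_dim)
    logits = jnp.einsum("bqhgd,bkhd->bhgqk", query, key) * head_dim ** -0.5
    weights = jax.nn.softmax(logits, axis=-1)
    # Merge (kv_head, group) back into a single query-head axis.
    return weights.reshape(batch, num_q_heads, q_len, kv_len)


q_rng, k_rng = jax.random.split(jax.random.PRNGKey(0))
gqa_query = jax.random.normal(q_rng, (1, 4, 8, 16))  # 8 query heads
gqa_key = jax.random.normal(k_rng, (1, 6, 2, 16))    # 2 key/value heads
gqa_weights = gqa_attention_weights(gqa_query, gqa_key)
```

MQA is the special case `num_kv_heads == 1`. The reshape avoids physically repeating the key tensor once per query head, while giving the same result as doing so.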
- forward_cpu(*args, **kwargs) → AttentionOutput
CPU forward pass. Delegates to forward_native.
- Parameters
*args – Positional arguments for the attention calculation.
**kwargs – Keyword arguments for the attention calculation.
- Returns
An AttentionOutput object containing the attention results.
- forward_cuda(*args, **kwargs) → AttentionOutput
CUDA GPU forward pass. Delegates to forward_native.
- Parameters
*args – Positional arguments for the attention calculation.
**kwargs – Keyword arguments for the attention calculation.
- Returns
An AttentionOutput object containing the attention results.
- forward_gpu(*args, **kwargs) → AttentionOutput
GPU forward pass. Delegates to forward_native.
- Parameters
*args – Positional arguments for the attention calculation.
**kwargs – Keyword arguments for the attention calculation.
- Returns
An AttentionOutput object containing the attention results.
- forward_native(query: Float[Array, 'batch seq_len num_q_heads head_dim'], key: Float[Array, 'batch kv_len num_kv_heads head_dim'], value: Float[Array, 'batch kv_len num_kv_heads head_dim'], mask_info: ejkernel.types.mask.MaskInfo | None = None, bias: jaxtyping.Float[Array, 'batch num_heads seq_len kv_len'] | None = None, init_bias: Optional[Callable[[], Float[Array, 'batch num_heads seq_len kv_len']]] = None, deterministic: bool = True, dropout_rng: Optional[Union[Key[Array, ''], UInt32[Array, '2']]] = None, softmax_aux: jaxtyping.Float[Array, 'num_heads num_sinks'] | jaxtyping.Float[Array, 'num_sinks'] | None = None, softmax_scale: float | None = None, logits_soft_cap: float | None = None, dropout_prob: float = 0.0, causal: bool = False, sliding_window: int | tuple[int, int] | None = None, **ignore) → AttentionOutput
Standard multi-head attention implementation using basic JAX operations.
- Parameters
query – Query tensor [batch, seq_len, num_q_heads, head_dim].
key – Key tensor [batch, kv_len, num_kv_heads, head_dim].
value – Value tensor [batch, kv_len, num_kv_heads, head_dim].
mask_info – Optional mask information for attention.
bias – Optional attention bias [batch, num_heads, seq_len, kv_len].
init_bias – Optional callable to initialize bias if mask_info and bias are None.
deterministic – If True, disables dropout.
dropout_rng – JAX PRNG key for dropout.
softmax_aux – Auxiliary softmax tensor (e.g., for sink tokens).
softmax_scale – Scaling factor for attention logits.
logits_soft_cap – Soft capping value for attention logits.
dropout_prob – Dropout probability.
causal – Apply causal masking.
sliding_window – Sliding window size for local attention.
**ignore – Additional ignored arguments.
- Returns
AttentionOutput containing attention outputs and weights.
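To make `logits_soft_cap`, `causal`, and `sliding_window` more concrete, here is a hedged sketch of how such options are commonly realized. The helper names `apply_soft_cap` and `local_causal_mask` are hypothetical, the tanh form of soft capping is the widely used formulation rather than a confirmed detail of this method, and the interpretation of an integer window as a symmetric `(left, right)` pair is an assumption.

```python
import jax.numpy as jnp


def apply_soft_cap(logits, logits_soft_cap):
    # Common tanh soft capping: smoothly bounds logits to [-cap, cap].
    return logits_soft_cap * jnp.tanh(logits / logits_soft_cap)


def local_causal_mask(q_len, kv_len, causal=True, sliding_window=None):
    """Boolean [q_len, kv_len] mask; True marks positions that may be attended.

    Assumes query i and key j share the same position origin, and that an
    int window w means (left=w, right=w). Both are illustrative choices.
    """
    q_pos = jnp.arange(q_len)[:, None]
    kv_pos = jnp.arange(kv_len)[None, :]
    mask = jnp.ones((q_len, kv_len), dtype=bool)
    if causal:
        # A query may not look at keys that come after it.
        mask = mask & (kv_pos <= q_pos)
    if sliding_window is not None:
        if isinstance(sliding_window, int):
            left, right = sliding_window, sliding_window
        else:
            left, right = sliding_window
        # Keep only keys within [i - left, i + right] of query i.
        mask = mask & (kv_pos >= q_pos - left) & (kv_pos <= q_pos + right)
    return mask


capped = apply_soft_cap(jnp.array([-1000.0, 0.0, 1000.0]), 30.0)
window_mask = local_causal_mask(4, 4, causal=True, sliding_window=1)
```

In this sketch the cap would be applied to the scaled logits before masking and softmax, and the boolean mask would then be used to exclude disallowed positions.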
- forward_rocm(*args, **kwargs) → AttentionOutput
ROCm GPU forward pass. Delegates to forward_native.
- Parameters
*args – Positional arguments for the attention calculation.
**kwargs – Keyword arguments for the attention calculation.
- Returns
An AttentionOutput object containing the attention results.
- forward_tpu(*args, **kwargs) → AttentionOutput
TPU forward pass. Delegates to forward_native.
- Parameters
*args – Positional arguments for the attention calculation.
**kwargs – Keyword arguments for the attention calculation.
- Returns
An AttentionOutput object containing the attention results.
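The per-platform methods above all forward to `forward_native`. The toy classes below sketch that delegation pattern in isolation; `ToyOperation`, `ToyVanilla`, and the `backend` keyword are hypothetical and are not part of EasyDeL, whose actual dispatch mechanism is not described here.

```python
class ToyOperation:
    """Hypothetical base class that picks a method by backend name."""

    def __call__(self, *args, backend="cpu", **kwargs):
        # Resolve forward_<backend>, e.g. forward_tpu or forward_gpu.
        return getattr(self, f"forward_{backend}")(*args, **kwargs)


class ToyVanilla(ToyOperation):
    """Hypothetical implementation with a single portable code path."""

    def forward_native(self, x):
        return x * 2  # stands in for the real attention computation

    def forward_cpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_gpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_tpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)


op = ToyVanilla()
results = [op(3, backend=name) for name in ("cpu", "gpu", "tpu")]
```

Since every backend runs the same code, results are identical across platforms, which is what makes such an implementation useful as a baseline.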
- get_impl_metadata() → OperationMetadata
Returns the metadata associated with this attention implementation instance.
- Returns
The OperationMetadata provided during initialization.