easydel.layers.operations.modules.vanilla_attention

Vanilla (standard) attention implementation for EasyDeL.

This module provides a reference implementation of multi-head attention using standard JAX operations. It serves as both a baseline for comparison with optimized implementations and a fallback for platforms where specialized kernels are unavailable.

The vanilla attention implementation:
  • Uses standard matrix multiplication and softmax operations
  • Supports all standard attention features (masking, bias, dropout)
  • Works on all platforms (TPU, GPU, CPU) without specialized kernels
  • Provides full attention weights for inspection when needed
  • Supports Grouped Query Attention (GQA) and Multi-Query Attention (MQA)

Key characteristics:
  • Memory complexity: O(N²) in the sequence length N, because the full query-by-key logits matrix is materialized for every head
  • Computation: uses einsum for batched matrix multiplication
  • Flexibility: supports various mask and bias shapes
  • Compatibility: works with any JAX backend without modification
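The einsum-based core described above can be sketched in a few lines of JAX. This is a minimal illustration under stated assumptions, not VanillaAttn's source: the function name `vanilla_attention` is hypothetical, masking, bias and dropout are omitted, and the default `1/sqrt(head_dim)` scale is a conventional choice assumed here. It makes the O(N²) term visible, since the logits tensor has shape [batch, heads, q_len, kv_len].

```python
import jax
import jax.numpy as jnp


def vanilla_attention(query, key, value, softmax_scale=None):
    """Hypothetical sketch of plain softmax attention (not EasyDeL's code).

    query/key/value: [batch, seq_len, num_heads, head_dim]
    """
    head_dim = query.shape[-1]
    if softmax_scale is None:
        # Assumed default: the conventional 1/sqrt(head_dim) scaling.
        softmax_scale = 1.0 / jnp.sqrt(head_dim)
    # Logits: [batch, heads, q_len, kv_len]; this is the O(N^2) memory term.
    logits = jnp.einsum("bqhd,bkhd->bhqk", query, key) * softmax_scale
    weights = jax.nn.softmax(logits, axis=-1)
    # Weighted sum of values, back to [batch, q_len, heads, head_dim].
    output = jnp.einsum("bhqk,bkhd->bqhd", weights, value)
    return output, weights


batch, seq_len, heads, head_dim = 2, 8, 4, 16
rng_q, rng_k, rng_v = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(rng_q, (batch, seq_len, heads, head_dim))
k = jax.random.normal(rng_k, (batch, seq_len, heads, head_dim))
v = jax.random.normal(rng_v, (batch, seq_len, heads, head_dim))
out, w = vanilla_attention(q, k, v)
```

Returning `weights` alongside the output is what makes inspection possible, and it is also why memory grows quadratically with sequence length.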

This implementation is ideal for:
  • Debugging and development
  • Small sequence lengths where memory is not a constraint
  • Platforms without optimized attention kernels
  • Cases where attention weights need to be inspected

Example

>>> import jax.numpy as jnp
>>> from easydel.layers.operations import OperationMetadata
>>> from easydel.layers.operations.modules.vanilla_attention import VanillaAttn
>>>
>>> metadata = OperationMetadata(
...     runtime_dtype=jnp.float16,
...     runtime_softmax_dtype=jnp.float32,  # Higher precision for softmax
...     dropout_prob=0.1
... )
>>> vanilla_attn = VanillaAttn(metadata)
>>> # query/key/value: [batch, seq_len, num_heads, head_dim]; mask_info: a MaskInfo or None.
>>> # forward_native expects mask_info; unrecognized keywords such as mask= fall into **ignore.
>>> output = vanilla_attn(query, key, value, mask_info=mask_info)
>>> attention_weights = output.attention_weights  # Available for inspection
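The example pairs `runtime_dtype=jnp.float16` with a float32 `runtime_softmax_dtype`. A plausible reading, sketched below purely as an assumption, is that the matrix multiplications run in the runtime dtype while the logits are upcast for the softmax and the weights are cast back for the value product. The helper name `attention_mixed_precision` is hypothetical, and how VanillaAttn actually applies these metadata fields is not documented on this page.

```python
import jax
import jax.numpy as jnp


def attention_mixed_precision(query, key, value,
                              runtime_dtype=jnp.float16,
                              softmax_dtype=jnp.float32):
    """Hypothetical sketch: matmuls in runtime_dtype, softmax in softmax_dtype."""
    query, key, value = (x.astype(runtime_dtype) for x in (query, key, value))
    scale = 1.0 / (query.shape[-1] ** 0.5)
    logits = jnp.einsum("bqhd,bkhd->bhqk", query, key)
    # Upcast before the softmax: exp and the row sums lose precision in float16.
    weights = jax.nn.softmax(logits.astype(softmax_dtype) * scale, axis=-1)
    # Cast back so the value matmul runs in the runtime dtype.
    output = jnp.einsum("bhqk,bkhd->bqhd", weights.astype(runtime_dtype), value)
    return output, weights


q = k = v = jnp.ones((1, 4, 2, 8))
out, w = attention_mixed_precision(q, k, v)
```

With identical keys every weight is 1/4, so `out` is all ones in float16 while `w` stays in float32.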
class easydel.layers.operations.modules.vanilla_attention.VanillaAttn(metadata: OperationMetadata)

Bases: OperationImpl

A standard, non-optimized implementation of multi-head attention.

This implementation uses basic JAX operations like jnp.einsum and standard softmax. It serves as a reference implementation and a fallback for platforms where optimized kernels (like Flash Attention) are not available or desired. It supports features like attention bias, masking, dropout, and Grouped Query Attention (GQA)/Multi-Query Attention (MQA) via reshaping.
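GQA/MQA "via reshaping" can be done without copying keys and values. One way, sketched below as an assumption rather than the actual code, is to fold the query heads into [num_kv_heads, groups] so each group of query heads shares a single key/value head. The name `gqa_attention` is hypothetical; the sketch assumes query head h is served by KV head h // groups, the same pairing that `jnp.repeat` along the head axis produces.

```python
import jax
import jax.numpy as jnp


def gqa_attention(query, key, value):
    """Hypothetical GQA sketch via reshaping (not EasyDeL's code).

    query:      [batch, q_len, num_q_heads, head_dim]
    key, value: [batch, kv_len, num_kv_heads, head_dim]
    MQA is the special case num_kv_heads == 1.
    """
    b, q_len, num_q_heads, d = query.shape
    num_kv_heads = key.shape[2]
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("num_q_heads must be divisible by num_kv_heads")
    groups = num_q_heads // num_kv_heads
    # Split the head axis into [num_kv_heads, groups]; no K/V copies are made.
    q = query.reshape(b, q_len, num_kv_heads, groups, d)
    logits = jnp.einsum("bqhgd,bkhd->bhgqk", q, key) / jnp.sqrt(d)
    weights = jax.nn.softmax(logits, axis=-1)
    out = jnp.einsum("bhgqk,bkhd->bqhgd", weights, value)
    # Merge the group axis back into the head axis.
    return out.reshape(b, q_len, num_q_heads, d)


q = jax.random.normal(jax.random.PRNGKey(1), (2, 5, 8, 4))
k = jax.random.normal(jax.random.PRNGKey(2), (2, 6, 2, 4))
v = jax.random.normal(jax.random.PRNGKey(3), (2, 6, 2, 4))
out = gqa_attention(q, k, v)
```

Compared with repeating K/V to match the query heads, the reshape keeps the key/value tensors at their original size.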

Registered under the name “vanilla”.

forward_cpu(*args, **kwargs) → AttentionOutput

CPU forward pass. Delegates to forward_native.

Parameters
  • *args – Positional arguments for the attention calculation.

  • **kwargs – Keyword arguments for the attention calculation.

Returns

An AttentionOutput object containing the attention results.
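forward_cpu and its CUDA, GPU, ROCm and TPU siblings all delegate to forward_native. The sketch below shows one way such per-backend hooks can be wired, assuming dispatch on `jax.default_backend()`. `HypotheticalOperation`, `EchoOp` and the `__call__` rule are illustrative only and are not OperationImpl's actual dispatch logic.

```python
import jax


class HypotheticalOperation:
    """Illustrative stand-in for a backend-dispatching op (not OperationImpl)."""

    def forward_native(self, *args, **kwargs):
        raise NotImplementedError

    # Each backend hook reuses the portable implementation, as documented
    # for VanillaAttn (forward_cuda/forward_rocm would look the same).
    def forward_cpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_gpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_tpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        # Assumed rule: pick forward_<backend>, falling back to forward_native.
        backend = jax.default_backend()  # e.g. "cpu", "gpu" or "tpu"
        method = getattr(self, f"forward_{backend}", self.forward_native)
        return method(*args, **kwargs)


class EchoOp(HypotheticalOperation):
    def forward_native(self, x):
        return ("native", x)


print(EchoOp()(3))  # ('native', 3) on any backend
```

Because every hook routes to the same function, a vanilla implementation behaves identically on all platforms; an optimized operation would override only the hooks for which it has a kernel.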

forward_cuda(*args, **kwargs) → AttentionOutput

CUDA GPU forward pass. Delegates to forward_native.

Parameters
  • *args – Positional arguments for the attention calculation.

  • **kwargs – Keyword arguments for the attention calculation.

Returns

An AttentionOutput object containing the attention results.

forward_gpu(*args, **kwargs) → AttentionOutput

GPU forward pass. Delegates to forward_native.

Parameters
  • *args – Positional arguments for the attention calculation.

  • **kwargs – Keyword arguments for the attention calculation.

Returns

An AttentionOutput object containing the attention results.

forward_native(query: Float[Array, 'batch seq_len num_q_heads head_dim'], key: Float[Array, 'batch kv_len num_kv_heads head_dim'], value: Float[Array, 'batch kv_len num_kv_heads head_dim'], mask_info: ejkernel.types.mask.MaskInfo | None = None, bias: jaxtyping.Float[Array, 'batch num_heads seq_len kv_len'] | None = None, init_bias: Optional[Callable[[], Float[Array, 'batch num_heads seq_len kv_len']]] = None, deterministic: bool = True, dropout_rng: Optional[Union[Key[Array, ''], UInt32[Array, '2']]] = None, softmax_aux: jaxtyping.Float[Array, 'num_heads num_sinks'] | jaxtyping.Float[Array, 'num_sinks'] | None = None, softmax_scale: float | None = None, logits_soft_cap: float | None = None, dropout_prob: float = 0.0, causal: bool = False, sliding_window: int | tuple[int, int] | None = None, **ignore) → AttentionOutput

Standard multi-head attention implementation using basic JAX operations.

Parameters
  • query – Query tensor [batch, seq_len, num_q_heads, head_dim].

  • key – Key tensor [batch, kv_len, num_kv_heads, head_dim].

  • value – Value tensor [batch, kv_len, num_kv_heads, head_dim].

  • mask_info – Optional mask information for attention.

  • bias – Optional attention bias [batch, num_heads, seq_len, kv_len].

  • init_bias – Optional callable to initialize bias if mask_info and bias are None.

  • deterministic – If True, disables dropout.

  • dropout_rng – JAX PRNG key for dropout.

  • softmax_aux – Auxiliary softmax tensor (e.g., for attention-sink tokens), shaped [num_heads, num_sinks] or [num_sinks].

  • softmax_scale – Scaling factor for attention logits.

  • logits_soft_cap – Soft capping value for attention logits.

  • dropout_prob – Dropout probability.

  • causal – Apply causal masking.

  • sliding_window – Sliding window size for local attention, given as a single int or a pair of ints.

  • **ignore – Additional ignored arguments.
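The causal and sliding_window parameters both restrict which keys a query may see. The sketch below builds such a boolean mask and applies it before the softmax. The helper names are hypothetical, and the window semantics are an assumption (only the past is limited, and queries are aligned with the end of the key sequence); VanillaAttn's exact handling of an int versus a pair of ints is not specified here.

```python
import jax
import jax.numpy as jnp


def build_mask(q_len, kv_len, causal=False, sliding_window=None):
    """Boolean [q_len, kv_len] mask; True means the query may attend to the key.

    Hypothetical semantics: query i sits at absolute position
    i + (kv_len - q_len), and an int sliding_window keeps only keys with
    query_pos - key_pos < sliding_window.
    """
    q_pos = jnp.arange(q_len)[:, None] + (kv_len - q_len)
    k_pos = jnp.arange(kv_len)[None, :]
    mask = jnp.ones((q_len, kv_len), dtype=bool)
    if causal:
        # Block attention to future positions.
        mask = mask & (k_pos <= q_pos)
    if sliding_window is not None:
        # Limit how far back each query can look.
        mask = mask & ((q_pos - k_pos) < sliding_window)
    return mask


def masked_softmax(logits, mask):
    # Fill masked slots with the most negative finite value instead of -inf,
    # so a fully masked row yields uniform weights rather than NaNs.
    big_neg = jnp.finfo(logits.dtype).min
    return jax.nn.softmax(jnp.where(mask, logits, big_neg), axis=-1)


mask = build_mask(4, 4, causal=True, sliding_window=2)
weights = masked_softmax(jnp.zeros((4, 4)), mask)
```

With `causal=True` and a window of 2, each query keeps itself and the immediately preceding key; during single-token decoding (`q_len=1`) the causal mask admits every cached key.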

Returns

AttentionOutput containing attention outputs and weights.
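Two parameters change the logits before the returned weights are formed: logits_soft_cap and softmax_aux. The sketch below assumes the widely used tanh soft cap, `cap * tanh(logits / cap)`, and treats softmax_aux as per-head "sink" logits that take part in the softmax normalization but are dropped afterwards. Both readings, and the helper names `soft_cap` and `softmax_with_sinks`, are assumptions for illustration, not confirmed VanillaAttn behavior.

```python
import jax
import jax.numpy as jnp


def soft_cap(logits, cap):
    # Assumed tanh capping: smoothly bounds logits to (-cap, cap) and is
    # close to the identity when |logits| is much smaller than cap.
    return cap * jnp.tanh(logits / cap)


def softmax_with_sinks(logits, sinks):
    """Hypothetical softmax with extra sink logits that absorb probability mass.

    logits: [batch, heads, q_len, kv_len]
    sinks:  [heads, num_sinks], or [num_sinks] shared across heads
    Sink columns join the normalization and are then dropped, so each
    returned row sums to less than 1.
    """
    b, h, q, kv = logits.shape
    if sinks.ndim == 1:
        sinks = jnp.broadcast_to(sinks, (h, sinks.shape[0]))
    sink_cols = jnp.broadcast_to(sinks[None, :, None, :], (b, h, q, sinks.shape[-1]))
    combined = jnp.concatenate([logits, sink_cols.astype(logits.dtype)], axis=-1)
    probs = jax.nn.softmax(combined, axis=-1)
    return probs[..., :kv]


logits = soft_cap(jnp.zeros((1, 2, 3, 4)), cap=50.0)
weights = softmax_with_sinks(logits, jnp.zeros((2, 1)))
row_sums = weights.sum(axis=-1)  # about 0.8: one sink shares mass with 4 keys
```

Here five equal logits each receive 1/5 of the mass, and dropping the sink column leaves rows that sum to 4/5.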

forward_rocm(*args, **kwargs) → AttentionOutput

ROCm GPU forward pass. Delegates to forward_native.

Parameters
  • *args – Positional arguments for the attention calculation.

  • **kwargs – Keyword arguments for the attention calculation.

Returns

An AttentionOutput object containing the attention results.

forward_tpu(*args, **kwargs) → AttentionOutput

TPU forward pass. Delegates to forward_native.

Parameters
  • *args – Positional arguments for the attention calculation.

  • **kwargs – Keyword arguments for the attention calculation.

Returns

An AttentionOutput object containing the attention results.

get_impl_metadata() → OperationMetadata

Returns the metadata associated with this attention implementation instance.

Returns

The OperationMetadata provided during initialization.

classmethod get_impl_name() → str | tuple[str]

Returns the registered name of this attention implementation.

Returns

The string “vanilla”.