Source code for easydel.layers.operations.modules.vanilla_attention

# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Vanilla (standard) attention implementation for EasyDeL.

This module provides a reference implementation of multi-head attention using
standard JAX operations. It serves as both a baseline for comparison with optimized
implementations and a fallback for platforms where specialized kernels are unavailable.

The vanilla attention implementation:
- Uses standard matrix multiplication and softmax operations
- Supports all standard attention features (masking, bias, dropout)
- Works on all platforms (TPU, GPU, CPU) without specialized kernels
- Provides full attention weights for inspection when needed
- Supports Grouped Query Attention (GQA) and Multi-Query Attention (MQA)

Key characteristics:
- Memory complexity: O(N²) where N is sequence length
- Computation: Uses einsum for efficient batch matrix multiplication
- Flexibility: Supports various mask and bias shapes
- Compatibility: Works with any JAX backend without modification

This implementation is ideal for:
- Debugging and development
- Small sequence lengths where memory is not a constraint
- Platforms without optimized attention kernels
- Cases where attention weights need to be inspected

Example:
    >>> import jax.numpy as jnp
    >>> from easydel.layers.operations._operation_impl import OperationMetadata
    >>> from easydel.layers.operations.modules.vanilla_attention import VanillaAttn
    >>>
    >>> metadata = OperationMetadata(
    ...     runtime_dtype=jnp.float16,
    ...     runtime_softmax_dtype=jnp.float32,  # Higher precision for softmax
    ...     dropout_prob=0.1
    ... )
    >>> vanilla_attn = VanillaAttn(metadata)
    >>> output = vanilla_attn(query, key, value, mask_info=mask_info)
    >>> attention_weights = output.attention_weights  # Available for inspection
"""

import typing as tp

import jax
from eformer import common_types
from eformer.escale import with_sharding_constraint
from ejkernel.modules import attention
from ejkernel.types import MaskInfo
from jax import numpy as jnp
from jax import random as jr
from jaxtyping import Array, Float, PRNGKeyArray

from .._attention_outputs import AttentionOutput
from .._operation_impl import OperationImpl, OperationMetadata, OperationRegistry


@OperationRegistry.register
class VanillaAttn(OperationImpl):
    """A standard, non-fused implementation of multi-head attention.

    The attention math itself is delegated to ``ejkernel.modules.attention``.
    This class is responsible for:

    * applying the mesh sharding constraints to the inputs and the output,
    * materializing a lazily-built bias (``init_bias``) when needed,
    * disabling causal masking while in decode mode.

    It serves as a reference implementation and as a fallback for platforms
    where optimized kernels (e.g. Flash Attention) are unavailable or
    undesired. Supported features (forwarded to ``attention``): attention
    bias, masking via ``MaskInfo``, dropout, softmax sinks (``softmax_aux``),
    logit soft-capping, sliding windows, and differing query/KV head counts
    (GQA/MQA). NOTE(review): the GQA/MQA head handling happens inside
    ``attention``; confirm its exact semantics there.

    Registered under the name ``"vanilla"``.
    """

    @classmethod
    def get_impl_name(cls) -> str | tuple[str, ...]:
        """Return the registry name of this implementation.

        Returns:
            The string ``"vanilla"``.
        """
        return "vanilla"

    def get_impl_metadata(self) -> OperationMetadata:
        """Return the metadata this instance was constructed with.

        Returns:
            The ``OperationMetadata`` stored on ``self.metadata``.
        """
        return self.metadata

    @jax.named_scope("easydel-vanillaimpl-native-xla")
    def forward_native(
        self,
        query: Float[Array, "batch seq_len num_q_heads head_dim"],
        key: Float[Array, "batch kv_len num_kv_heads head_dim"],
        value: Float[Array, "batch kv_len num_kv_heads head_dim"],
        mask_info: MaskInfo | None = None,
        bias: Float[Array, "batch num_heads seq_len kv_len"] | None = None,
        init_bias: tp.Callable[[], Float[Array, "batch num_heads seq_len kv_len"]] | None = None,
        deterministic: bool = True,
        dropout_rng: PRNGKeyArray | None = None,
        softmax_aux: Float[Array, "num_heads num_sinks"] | Float[Array, "num_sinks"] | None = None,  # noqa
        softmax_scale: float | None = None,
        logits_soft_cap: float | None = None,
        dropout_prob: float = 0.0,
        causal: bool = False,
        sliding_window: int | tuple[int, int] | None = None,
        **ignore,
    ) -> AttentionOutput:
        """Compute multi-head attention with plain JAX (XLA) operations.

        Args:
            query: Query tensor ``[batch, seq_len, num_q_heads, head_dim]``.
            key: Key tensor ``[batch, kv_len, num_kv_heads, head_dim]``.
            value: Value tensor ``[batch, kv_len, num_kv_heads, head_dim]``.
            mask_info: Optional mask information. When given, ``init_bias``
                is never called.
            bias: Optional additive attention bias
                ``[batch, num_heads, seq_len, kv_len]``.
            init_bias: Optional zero-arg callable that builds the bias. It is
                only invoked when both ``mask_info`` and ``bias`` are ``None``.
            deterministic: If True, dropout is disabled.
            dropout_rng: PRNG key used for dropout.
            softmax_aux: Auxiliary softmax logits (e.g. attention sinks).
            softmax_scale: Scale applied to the attention logits.
            logits_soft_cap: Soft-capping value for the attention logits.
            dropout_prob: Dropout probability.
            causal: Whether to apply causal masking. Forced to False in
                decode mode.
            sliding_window: Window size for local attention.
            **ignore: Extra keyword arguments; accepted and ignored.

        Returns:
            An ``AttentionOutput`` holding the (output-sharded) attention
            outputs and the attention weights returned by ``attention``.
        """
        with self.metadata.mesh:
            model_mode = self.get_mode(query=query, BTHD=True)
            shardings = self.metadata.get_shardings(model_mode, layout="bthd")

            # An explicit mask or bias takes precedence over the lazy initializer.
            if mask_info is None and bias is None and init_bias is not None:
                bias = init_bias()

            query = with_sharding_constraint(arr=query, sharding=shardings.query)
            key = with_sharding_constraint(arr=key, sharding=shardings.key)
            value = with_sharding_constraint(arr=value, sharding=shardings.value)
            if bias is not None:
                bias = with_sharding_constraint(arr=bias, sharding=shardings.bias)

            runtime_dtype: jnp.dtype = self.metadata.runtime_dtype
            softmax_dtype: jnp.dtype | None = self.metadata.runtime_softmax_dtype

            # Causal masking is switched off while decoding.
            # NOTE(review): presumably because the query attends to the full
            # KV cache in that mode; confirm against get_mode's semantics.
            is_decode_mode: bool = model_mode == common_types.MODE_DECODE
            effective_causal: bool = causal if not is_decode_mode else False

            outputs, weights = attention(
                query,
                key,
                value,
                bias,
                dropout_rng,
                softmax_aux,
                mask_info=mask_info,
                deterministic=deterministic,
                dropout_prob=dropout_prob,
                dtype=runtime_dtype,
                sliding_window=sliding_window,
                softmax_dtype=softmax_dtype,
                softmax_scale=softmax_scale,
                # Any lazily-built bias has already been materialized above.
                init_bias=None,
                causal=effective_causal,
                logits_soft_cap=logits_soft_cap,
            )

            outputs = with_sharding_constraint(arr=outputs, sharding=shardings.output)
            return AttentionOutput(attention_weights=weights, attention_outputs=outputs)

    def forward_gpu(self, *args, **kwargs) -> AttentionOutput:
        """GPU forward pass. Delegates to ``forward_native``.

        Args:
            *args: Positional arguments for the attention calculation.
            **kwargs: Keyword arguments for the attention calculation.

        Returns:
            An ``AttentionOutput`` object containing the attention results.
        """
        return self.forward_native(*args, **kwargs)

    def forward_tpu(self, *args, **kwargs) -> AttentionOutput:
        """TPU forward pass. Delegates to ``forward_native``.

        Args:
            *args: Positional arguments for the attention calculation.
            **kwargs: Keyword arguments for the attention calculation.

        Returns:
            An ``AttentionOutput`` object containing the attention results.
        """
        return self.forward_native(*args, **kwargs)

    def forward_cpu(self, *args, **kwargs) -> AttentionOutput:
        """CPU forward pass. Delegates to ``forward_native``.

        Args:
            *args: Positional arguments for the attention calculation.
            **kwargs: Keyword arguments for the attention calculation.

        Returns:
            An ``AttentionOutput`` object containing the attention results.
        """
        return self.forward_native(*args, **kwargs)

    def forward_cuda(self, *args, **kwargs) -> AttentionOutput:
        """CUDA GPU forward pass. Delegates to ``forward_native``.

        Args:
            *args: Positional arguments for the attention calculation.
            **kwargs: Keyword arguments for the attention calculation.

        Returns:
            An ``AttentionOutput`` object containing the attention results.
        """
        return self.forward_native(*args, **kwargs)

    def forward_rocm(self, *args, **kwargs) -> AttentionOutput:
        """ROCm GPU forward pass. Delegates to ``forward_native``.

        Args:
            *args: Positional arguments for the attention calculation.
            **kwargs: Keyword arguments for the attention calculation.

        Returns:
            An ``AttentionOutput`` object containing the attention results.
        """
        return self.forward_native(*args, **kwargs)

    def __call__(
        self,
        query: Float[Array, "batch seq_len num_q_heads head_dim"],
        key: Float[Array, "batch kv_len num_kv_heads head_dim"],
        value: Float[Array, "batch kv_len num_kv_heads head_dim"],
        mask_info: MaskInfo | None = None,
        bias: Float[Array, "batch num_heads seq_len kv_len"] | None = None,
        init_bias: tp.Callable[[], Float[Array, "batch num_heads seq_len kv_len"]] | None = None,
        deterministic: bool = True,
        dropout_rng: PRNGKeyArray | None = None,
        softmax_aux: Float[Array, "num_heads num_sinks"] | Float[Array, "num_sinks"] | None = None,  # noqa
        softmax_scale: float | None = None,
        logits_soft_cap: float | None = None,
        dropout_prob: float = 0.0,
        causal: bool = False,
        sliding_window: int | tuple[int, int] | None = None,
        **ignore,
    ) -> AttentionOutput:
        """Run the vanilla attention computation.

        Dispatches through ``super().__call__`` to the backend-specific
        ``forward_*`` method. Every backend method here delegates to
        ``forward_native``, so the native JAX path is always used.

        Args:
            query: Query tensor ``[batch, seq_len, num_q_heads, head_dim]``.
            key: Key tensor ``[batch, kv_len, num_kv_heads, head_dim]``.
            value: Value tensor ``[batch, kv_len, num_kv_heads, head_dim]``.
            mask_info: Optional mask information.
            bias: Optional additive attention bias.
            init_bias: Optional callable that builds the bias. It is used only
                when ``mask_info`` and ``bias`` are both ``None``.
            deterministic: If True, disables dropout.
            dropout_rng: PRNG key for dropout when ``deterministic`` is False.
            softmax_aux: Auxiliary softmax logits (e.g. attention sinks).
            softmax_scale: Scale applied to the attention logits.
            logits_soft_cap: Soft-capping value for the attention logits.
            dropout_prob: Dropout probability.
            causal: Whether to apply causal masking.
            sliding_window: Window size for local attention.
            **ignore: Additional keyword arguments, forwarded unchanged.

        Returns:
            An ``AttentionOutput`` object containing the attention results.
        """
        return super().__call__(
            query=query,
            key=key,
            value=value,
            mask_info=mask_info,
            bias=bias,
            deterministic=deterministic,
            dropout_prob=dropout_prob,
            dropout_rng=dropout_rng,
            sliding_window=sliding_window,
            softmax_aux=softmax_aux,
            softmax_scale=softmax_scale,
            init_bias=init_bias,
            logits_soft_cap=logits_soft_cap,
            causal=causal,
            **ignore,
        )
if __name__ == "__main__":
    from easydel.infra import EasyDeLBaseConfig

    # Smoke test with an MLA-like layout: the value head dim (128 + 64) differs
    # from the query/key head dim (128), and there are fewer KV heads (8) than
    # query heads (32).
    batch, q_len, kv_len = 1, 1024, 1024
    num_q_heads, num_kv_heads = 32, 8
    qk_dim, v_dim = 128, 128 + 64

    q = jr.normal(jr.key(0), (batch, q_len, num_q_heads, qk_dim), "f2")
    k = jr.normal(jr.key(1), (batch, kv_len, num_kv_heads, qk_dim), "f2")
    v = jr.normal(jr.key(2), (batch, kv_len, num_kv_heads, v_dim), "f2")
    random_mask = MaskInfo.from_random(batch, q_len, kv_len)

    demo_metadata = OperationMetadata(
        runtime_dtype=jnp.float16,
        runtime_softmax_dtype=jnp.float32,
        base_config=EasyDeLBaseConfig(),
    )
    result = VanillaAttn(demo_metadata)(query=q, key=k, value=v, mask_info=random_mask)
    print(result.attention_outputs)