# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import dataclasses
import typing as tp
from abc import abstractmethod
from enum import Enum
import einops
import jax
from eformer.escale import PartitionAxis
from jax import Array
from jax import numpy as jnp
from jax.sharding import PartitionSpec as Ps
from eformer.pytree import auto_pytree
from easydel.infra.base_config import EasyDeLBaseConfig
from easydel.infra.etils import EasyDeLBackends, EasyDeLPlatforms
from easydel.utils.helpers import get_logger
from ..ops import BaseOperation
# Module-level logger; used for the backend-dispatch debug messages in
# AttentionImpl.__call__ and for AttentionRegistry's registration/overwrite messages.
logger = get_logger("EasyDeL-AttentionOperator")
@auto_pytree
class AttentionOutput:
    """
    Result container returned by attention implementations.

    Both fields default to ``None``, so either one may be left unset.

    Attributes:
        attention_weights: Attention probabilities, typically laid out as
            (batch, num_heads, query_seq_len, key_value_seq_len). Optional.
        attention_outputs: Attention-weighted combination of the values,
            typically laid out as (batch, query_seq_len, num_heads, head_dim)
            or (batch, num_heads, query_seq_len, head_dim). Optional.
    """

    # Attention probabilities; None when not provided.
    attention_weights: tp.Optional[Array] = None
    # Final attention result (weighted sum of values); None when not provided.
    attention_outputs: tp.Optional[Array] = None
class RuntimeType(Enum):
    """
    Execution mode of an attention call.

    Members:
        normal: Regular processing of a full query sequence (training or
            evaluation).
        generation: Autoregressive decoding mode. `AttentionImpl.get_runtime_type`
            selects it when the query sequence length is 1.
    """

    # Standard full-sequence mode.
    normal = "normal"
    # Single-step decoding mode (query length of 1).
    generation = "generation"
[docs]class AttentionImpl(BaseOperation):
"""
Abstract Base Class for specific attention implementations.
Inherits from `BaseOperation` to leverage backend-specific dispatching.
Subclasses must implement the core attention logic (`forward_native`) and
potentially provide optimized versions for TPU (`forward_tpu`), GPU (`forward_gpu`),
etc. They also need to declare their name and associated metadata.
Provides common helper methods for attention processing like mask manipulation,
head repeating (for GQA/MQA), and determining runtime mode.
"""
def __init__(self, metadata: AttentionMetadata) -> None:
"""
Initializes the attention implementation with its metadata.
Args:
metadata: An `AttentionMetadata` instance containing configuration
and context for this attention operation.
"""
self.metadata = metadata
[docs] @classmethod
@abstractmethod
def get_impl_name(cls) -> tp.Union[str, tp.Tuple[str, ...]]:
"""
Returns the unique name(s) identifying this attention implementation.
Used by the `AttentionRegistry`. Can return a single string or a tuple/list
of strings if the implementation has multiple aliases.
Returns:
A string or tuple/list of strings representing the implementation name(s).
"""
[docs] def get_runtime_type(self, q: jax.Array, BTHD: bool = True) -> RuntimeType:
"""
Determines the runtime mode (normal or generation) based on query shape.
Assumes generation mode if the query sequence length dimension is 1.
Args:
q: The query tensor.
BTHD: Boolean indicating tensor layout (True for B, T, H, D; False for B, H, T, D).
Returns:
RuntimeType.generation if query sequence length is 1, else RuntimeType.normal.
"""
ingeneration = q.shape[1] == 1 if BTHD else q.shape[2] == 1
return RuntimeType.generation if ingeneration else RuntimeType.normal
[docs] def current_backend(self) -> tp.Literal["tpu", "gpu", "cpu"]:
"""
Returns the current JAX default backend as a lowercase string literal.
Returns:
"tpu", "gpu", or "cpu".
"""
return jax.default_backend()
@staticmethod
def _split_attention_mask(attn_mask: Array) -> tp.Tuple[Array, Array]:
"""
Splits a combined attention mask into separate query and key-value masks.
Assumes the input mask `attn_mask` might be 4D (batch, head, q_seq, kv_seq)
or 3D (batch, q_seq, kv_seq). It derives the query mask by checking which
query positions can attend to *any* key position, and the key-value mask
by checking which key positions *can be attended to* by any query position.
Args:
attn_mask: The combined attention mask (3D or 4D). If 4D, the last head dim
is used. Shape (..., q_seq, kv_seq).
Returns:
A tuple `(q_mask, kv_mask)`:
- `q_mask`: Boolean array of shape (..., q_seq). True for valid query tokens.
- `kv_mask`: Boolean array of shape (..., kv_seq). True for valid key/value tokens.
"""
if attn_mask.ndim == 4:
attn_mask = attn_mask[:, -1, :, :]
return (
jnp.any(attn_mask, axis=-1),
jnp.any(attn_mask, axis=-2),
)
@staticmethod
def _combine_query_kv_masks(q_mask: Array, kv_mask: Array) -> Array:
"""
Combines separate query and key-value masks into a standard attention mask.
Creates a broadcastable mask where `mask[b, i, j]` is True if both
`q_mask[b, i]` and `kv_mask[b, j]` are True.
Args:
q_mask: Boolean array of shape (..., q_seq). True for valid query tokens.
kv_mask: Boolean array of shape (..., kv_seq). True for valid key/value tokens.
Returns:
A boolean attention mask of shape (..., q_seq, kv_seq).
"""
if kv_mask.ndim == 2:
kv_mask = kv_mask[:, None, :]
if q_mask.ndim == 2:
q_mask = q_mask[:, :, None]
return q_mask * kv_mask
@staticmethod
def _create_causal_mask(qseq: int) -> Array:
"""
Creates a causal attention mask (lower triangular).
Args:
qseq: The sequence length .
Returns:
A boolean array of shape (qseq, qseq) where `mask[i, j]` is
True if `j <= i`, representing causal visibility.
"""
return jnp.tril(jnp.ones((qseq, qseq), dtype="b1"))
[docs] @staticmethod
def repeat_kv_heads(
k: Array,
v: Array,
num_reps: int,
) -> tp.Tuple[Array, Array]:
"""
Repeats Key and Value heads for Grouped Query Attention (GQA) or Multi-Query Attention (MQA).
Expands the head dimension of K and V tensors to match the number of query heads.
Args:
k: Key tensor, assumes shape (batch, seq_len, num_kv_heads, head_dim).
v: Value tensor, assumes shape (batch, seq_len, num_kv_heads, head_dim).
num_reps: The number of times to repeat each KV head (num_q_heads // num_kv_heads).
Returns:
A tuple `(k_repeated, v_repeated)` with shapes
(batch, seq_len, num_q_heads, head_dim).
"""
return (
einops.repeat(k, "b s h d -> b s (h r) d", r=num_reps),
einops.repeat(v, "b s h d -> b s (h r) d", r=num_reps),
)
def _handle_kvhead(
self,
array: tp.Optional[Array],
num_q_heads: int,
num_kv_heads: int,
) -> tp.Optional[Array]:
"""
Processes an attention bias or similar array based on head configuration (GQA/MQA).
If the array's head dimension matches `num_kv_heads`, it repeats the heads
to match `num_q_heads`. If it matches `num_q_heads` or is 1 (broadcastable),
it's returned as is.
Args:
array: The input array, typically an attention bias. Assumes head dimension
is at index 1. Shape (batch, num_heads, q_seq, kv_seq) or similar.
Can be None.
num_q_heads: The number of query heads.
num_kv_heads: The number of key/value heads.
Returns:
The processed array with head dimension matching `num_q_heads`, or None
if the input was None.
Raises:
ValueError: If the array's head dimension is incompatible.
"""
if array is None:
return None
if array.shape[1] == num_q_heads or array.shape[1] == 1:
return array
elif array.shape[1] == num_kv_heads:
return einops.repeat(
array,
"b h q k -> b (h r) q k",
r=num_q_heads // array.shape[1],
)
else:
raise ValueError(
f"Incompatible array shape. Got {array.shape[1]} heads, "
f"expected {num_q_heads}, {num_kv_heads}, or 1"
)
[docs] def create_stable_sharding(
self,
state_ps: tp.Optional[Ps] = None,
preserved_indices: tp.List[int] = None,
clone_ps: tp.Optional[Ps] = None,
dep: tp.Optional[tp.Union[Ps, bool]] = True,
) -> tp.Optional[Ps]:
"""
Helper to create a PartitionSpec, potentially preserving only certain axes.
This might be used for ensuring intermediate tensors or states have compatible
sharding, possibly replicating across axes not specified in `preserved_indices`.
Args:
state_ps: The base PartitionSpec to modify.
preserved_indices: A list of dimension indices whose partitioning should be
kept from `state_ps` (or `clone_ps` if provided). Other dimensions
will be set to None (replicated). If None, `state_ps` is returned.
clone_ps: An optional PartitionSpec to copy axis names from for the
preserved indices, instead of using `state_ps`.
dep: A dependency flag or PartitionSpec. If None, returns None. Defaults to True.
(The exact purpose might be context-specific, potentially for control flow).
Returns:
A new PartitionSpec with only specified axes partitioned, or None based on `dep`.
Returns `state_ps` directly if `preserved_indices` is None.
"""
if dep is None:
return None
if state_ps is None:
return None
if preserved_indices is None:
return state_ps
new_spec = [None] * len(state_ps)
for idx in preserved_indices:
new_spec[idx] = state_ps[idx] if clone_ps is None else clone_ps[idx]
return Ps(*new_spec)
def __call__(self, *args, **kwargs) -> AttentionOutput:
"""
Executes the appropriate forward method based on the backend in metadata.
Overrides `BaseOperation.__call__` to dispatch based on `self.metadata.backend`
instead of the global `jax.default_backend()`. This allows forcing a specific
path (e.g., GPU path even if JAX defaults to CPU) based on the configuration
stored in `AttentionMetadata`.
Args:
*args: Positional arguments to pass to the forward method.
**kwargs: Keyword arguments to pass to the forward method.
Returns:
An `AttentionOutput` object containing the results.
Raises:
RuntimeError: If the backend specified in `self.metadata` is unknown.
"""
match self.metadata.backend:
case EasyDeLBackends.TPU:
logger.debug("Calling into TPU exec")
return self.forward_tpu(*args, **kwargs)
case EasyDeLBackends.GPU:
logger.debug("Calling into GPU exec")
return self.forward_gpu(*args, **kwargs)
case EasyDeLBackends.CPU:
logger.debug("Calling into CPU exec")
return self.forward_native(*args, **kwargs)
case _:
raise RuntimeError(f"unknown backend at AttentionImpl! {self.metadata.backend}")
# Type variable bound to AttentionImpl. AttentionRegistry.register uses it so
# the decorator's return type is the exact subclass it received. The string
# name must match the variable name, or static type checkers reject it.
_I = tp.TypeVar("_I", bound=AttentionImpl)
class AttentionRegistry:
    """
    Name-based registry of `AttentionImpl` subclasses.

    Implementations are added with the `register` class decorator. They can
    then be looked up (`get`), instantiated (`create`) or enumerated
    (`list_implementations`) by name. The mapping is a single class-level
    dictionary, shared by every user of this class.
    """

    # Registered name/alias -> implementation class.
    _registry: tp.Dict[str, tp.Type[AttentionImpl]] = {}

    @classmethod
    def register(cls, impl_cls: tp.Type[_I]) -> tp.Type[_I]:
        """
        Decorator that records an `AttentionImpl` subclass in the registry.

        The class is stored under every name returned by its `get_impl_name()`
        classmethod. If a list or tuple is returned, each element is registered
        as an alias. An existing name is overwritten, with a warning logged.

        Example:
            ```python
            @AttentionRegistry.register
            class FlashAttentionImpl(AttentionImpl):
                @classmethod
                def get_impl_name(cls) -> str:
                    return "flash"
                # ... implementation ...
            ```

        Args:
            impl_cls: The `AttentionImpl` subclass to register.

        Returns:
            `impl_cls` unchanged, so this can be used as a class decorator.
        """
        declared = impl_cls.get_impl_name()
        aliases = declared if isinstance(declared, (list, tuple)) else (declared,)
        registry = cls._registry
        for alias in aliases:
            if alias in registry:
                logger.warning(
                    f"Attention implementation '{alias}' already registered. Overwriting."
                )
            registry[alias] = impl_cls
            logger.debug(f"Registered attention implementation: {alias}")
        return impl_cls

    @classmethod
    def get(cls, impl_name: str) -> tp.Type[AttentionImpl]:
        """
        Look up a registered implementation class by name.

        Args:
            impl_name: Registered name (or alias) of the implementation.

        Returns:
            The `AttentionImpl` subclass stored under `impl_name`.

        Raises:
            ValueError: If `impl_name` is not registered. The message lists the
                available names.
        """
        registry = cls._registry
        if impl_name in registry:
            return registry[impl_name]
        available = list(registry.keys())
        raise ValueError(
            f"Attention implementation '{impl_name}' not found. Available implementations: {available}"
        )

    @classmethod
    def create(cls, impl_name: str, metadata: AttentionMetadata) -> AttentionImpl:
        """
        Instantiate a registered implementation.

        Args:
            impl_name: Registered name (or alias) of the implementation.
            metadata: The `AttentionMetadata` passed to the class constructor.

        Returns:
            A new instance of the implementation registered as `impl_name`.

        Raises:
            ValueError: If `impl_name` is not registered.
        """
        return cls.get(impl_name)(metadata)

    @classmethod
    def list_implementations(cls) -> tp.List[str]:
        """
        List every registered name, including aliases.

        Returns:
            Registered names in insertion order.
        """
        return list(cls._registry)