# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import dataclasses
import typing as tp
from abc import abstractmethod
import einops
import jax
from eformer import common_types
from eformer import escale as es
from eformer.escale import PartitionAxis, PartitionManager
from eformer.pytree import auto_pytree
from jax import Array
from jax import numpy as jnp
from jax.sharding import PartitionSpec as Ps
from easydel.infra.base_config import EasyDeLBaseConfig
from easydel.infra.etils import EasyDeLBackends, EasyDeLPlatforms
from easydel.utils.helpers import get_logger
from ..ops import BaseOperation
logger = get_logger("EasyDeL-AttentionOperator")

# Sentinel and runtime-mode type aliases taken from eformer's shared vocabulary.
# RUNTIME_MODE_TYPES is used as the return annotation of `AttentionImpl.get_mode`.
NOT_GIVEN = common_types.NOT_GIVEN
RUNTIME_MODE_TYPES = common_types.RUNTIME_MODE_TYPES

# Dimension-name aliases from `eformer.common_types`.
# NOTE(review): presumably these are logical axis identifiers consumed by the
# partitioning helpers (PartitionManager / PartitionAxis) -- confirm at call sites.
BATCH = common_types.BATCH
QUERY_LENGTH, KV_LENGTH = common_types.QUERY_LENGTH, common_types.KV_LENGTH
HEAD, KV_HEAD = common_types.HEAD, common_types.KV_HEAD
HEAD_DIM, KV_HEAD_DIM = common_types.HEAD_DIM, common_types.KV_HEAD_DIM
BIAS_HEAD_SEQ, BIAS_KV_SEQ = common_types.BIAS_HEAD_SEQ, common_types.BIAS_KV_SEQ
@auto_pytree
class AttentionOutput:
    """
    Result bundle produced by an attention operation.

    Both fields are optional and default to ``None``, so an implementation may
    populate only what it computes (for example, fused kernels may not expose
    the attention probabilities).

    Attributes:
        attention_weights: Attention probabilities, typically laid out as
            (batch, num_heads, query_seq_len, key_value_seq_len).
        attention_outputs: Weighted sum of the values, typically laid out as
            (batch, query_seq_len, num_heads, head_dim) or
            (batch, num_heads, query_seq_len, head_dim), depending on the
            implementation.
    """

    # NOTE(review): `auto_pytree` comes from eformer; presumably it turns this
    # class into a JAX-pytree dataclass so instances can cross jit boundaries.
    attention_weights: tp.Optional[Array] = None
    attention_outputs: tp.Optional[Array] = None
[docs]class AttentionImpl(BaseOperation):
"""
Abstract Base Class for specific attention implementations.
Inherits from `BaseOperation` to leverage backend-specific dispatching.
Subclasses must implement the core attention logic (`forward_native`) and
potentially provide optimized versions for TPU (`forward_tpu`), GPU (`forward_gpu`),
etc. They also need to declare their name and associated metadata.
Provides common helper methods for attention processing like mask manipulation,
head repeating (for GQA/MQA), and determining runtime mode.
"""
def __init__(self, metadata: AttentionMetadata) -> None:
"""
Initializes the attention implementation with its metadata.
Args:
metadata: An `AttentionMetadata` instance containing configuration
and context for this attention operation.
"""
self.metadata = metadata
[docs] @classmethod
@abstractmethod
def get_impl_name(cls) -> tp.Union[str, tp.Tuple[str, ...]]:
"""
Returns the unique name(s) identifying this attention implementation.
Used by the `AttentionRegistry`. Can return a single string or a tuple/list
of strings if the implementation has multiple aliases.
Returns:
A string or tuple/list of strings representing the implementation name(s).
"""
[docs] def get_mode(self, q: jax.Array, BTHD: bool = True) -> RUNTIME_MODE_TYPES: # type:ignore
"""
Determines the runtime mode (normal or generation) based on query shape.
Assumes generation mode if the query sequence length dimension is 1.
Args:
q: The query tensor.
BTHD: Boolean indicating tensor layout (True for B, T, H, D; False for B, H, T, D).
"""
ingeneration = q.shape[1] == 1 if BTHD else q.shape[2] == 1
return common_types.MODE_DECODE if ingeneration else common_types.MODE_TRAIN
[docs] def current_backend(self) -> tp.Literal["tpu", "gpu", "cpu"]:
"""
Returns the current JAX default backend as a lowercase string literal.
Returns:
"tpu", "gpu", or "cpu".
"""
return jax.default_backend()
@staticmethod
def _split_attention_mask(attn_mask: Array) -> tp.Tuple[Array, Array]:
"""
Splits a combined attention mask into separate query and key-value masks.
Assumes the input mask `attn_mask` might be 4D (batch, head, q_seq, kv_seq)
or 3D (batch, q_seq, kv_seq). It derives the query mask by checking which
query positions can attend to *any* key position, and the key-value mask
by checking which key positions *can be attended to* by any query position.
Args:
attn_mask: The combined attention mask (3D or 4D). If 4D, the last head dim
is used. Shape (..., q_seq, kv_seq).
Returns:
A tuple `(q_mask, kv_mask)`:
- `q_mask`: Boolean array of shape (..., q_seq). True for valid query tokens.
- `kv_mask`: Boolean array of shape (..., kv_seq). True for valid key/value tokens.
"""
if attn_mask.ndim == 4:
attn_mask = attn_mask[:, -1, :, :]
return (
jnp.any(attn_mask, axis=-1),
jnp.any(attn_mask, axis=-2),
)
@staticmethod
def _combine_query_kv_masks(q_mask: Array, kv_mask: Array) -> Array:
"""
Combines separate query and key-value masks into a standard attention mask.
Creates a broadcastable mask where `mask[b, i, j]` is True if both
`q_mask[b, i]` and `kv_mask[b, j]` are True.
Args:
q_mask: Boolean array of shape (..., q_seq). True for valid query tokens.
kv_mask: Boolean array of shape (..., kv_seq). True for valid key/value tokens.
Returns:
A boolean attention mask of shape (..., q_seq, kv_seq).
"""
if kv_mask.ndim == 2:
kv_mask = kv_mask[:, None, :]
if q_mask.ndim == 2:
q_mask = q_mask[:, :, None]
return q_mask * kv_mask
@staticmethod
def _create_causal_mask(qseq: int) -> Array:
"""
Creates a causal attention mask (lower triangular).
Args:
qseq: The sequence length .
Returns:
A boolean array of shape (qseq, qseq) where `mask[i, j]` is
True if `j <= i`, representing causal visibility.
"""
return jnp.tril(jnp.ones((qseq, qseq), dtype="b1"))
[docs] @staticmethod
def repeat_kv_heads(
k: Array,
v: Array,
num_reps: int,
) -> tp.Tuple[Array, Array]:
"""
Repeats Key and Value heads for Grouped Query Attention (GQA) or Multi-Query Attention (MQA).
Expands the head dimension of K and V tensors to match the number of query heads.
Args:
k: Key tensor, assumes shape (batch, seq_len, num_kv_heads, head_dim).
v: Value tensor, assumes shape (batch, seq_len, num_kv_heads, head_dim).
num_reps: The number of times to repeat each KV head (num_q_heads // num_kv_heads).
Returns:
A tuple `(k_repeated, v_repeated)` with shapes
(batch, seq_len, num_q_heads, head_dim).
"""
return (
einops.repeat(k, "b s h d -> b s (h r) d", r=num_reps),
einops.repeat(v, "b s h d -> b s (h r) d", r=num_reps),
)
def _handle_kvhead(
self,
array: tp.Optional[Array],
num_q_heads: int,
num_kv_heads: int,
) -> tp.Optional[Array]:
"""
Processes an attention bias or similar array based on head configuration (GQA/MQA).
If the array's head dimension matches `num_kv_heads`, it repeats the heads
to match `num_q_heads`. If it matches `num_q_heads` or is 1 (broadcastable),
it's returned as is.
Args:
array: The input array, typically an attention bias. Assumes head dimension
is at index 1. Shape (batch, num_heads, q_seq, kv_seq) or similar.
Can be None.
num_q_heads: The number of query heads.
num_kv_heads: The number of key/value heads.
Returns:
The processed array with head dimension matching `num_q_heads`, or None
if the input was None.
Raises:
ValueError: If the array's head dimension is incompatible.
"""
if array is None:
return None
if array.shape[1] == num_q_heads or array.shape[1] == 1:
return array
elif array.shape[1] == num_kv_heads:
return einops.repeat(
array,
"b h q k -> b (h r) q k",
r=num_q_heads // array.shape[1],
)
else:
raise ValueError(
f"Incompatible array shape. Got {array.shape[1]} heads, "
f"expected {num_q_heads}, {num_kv_heads}, or 1"
)
[docs] def create_stable_sharding(
self,
state_ps: tp.Optional[Ps] = None,
preserved_indices: tp.List[int] = None,
clone_ps: tp.Optional[Ps] = None,
dep: tp.Optional[tp.Union[Ps, bool]] = True,
tensor: tp.Optional[jax.Array] = None,
) -> tp.Optional[Ps]:
"""
Helper to create a PartitionSpec, potentially preserving only certain axes.
This might be used for ensuring intermediate tensors or states have compatible
sharding, possibly replicating across axes not specified in `preserved_indices`.
Args:
state_ps: The base PartitionSpec to modify.
preserved_indices: A list of dimension indices whose partitioning should be
kept from `state_ps` (or `clone_ps` if provided). Other dimensions
will be set to None (replicated). If None, `state_ps` is returned.
clone_ps: An optional PartitionSpec to copy axis names from for the
preserved indices, instead of using `state_ps`.
dep: A dependency flag or PartitionSpec. If None, returns None. Defaults to True.
(The exact purpose might be context-specific, potentially for control flow).
Returns:
A new PartitionSpec with only specified axes partitioned, or None based on `dep`.
Returns `state_ps` directly if `preserved_indices` is None.
"""
if dep is None:
return None
if state_ps is None:
return None
if preserved_indices is None:
if tensor is None:
return state_ps
return es.get_corrected_named_sharding(tensor.shape, state_ps).spec
new_spec = [None] * len(state_ps)
for idx in preserved_indices:
new_spec[idx] = state_ps[idx] if clone_ps is None else clone_ps[idx]
sharding = Ps(*new_spec)
if tensor is None:
return sharding
else:
return es.get_corrected_named_sharding(tensor.shape, sharding).spec
def __call__(self, *args, **kwargs) -> AttentionOutput:
"""
Executes the appropriate forward method based on the backend in metadata.
Overrides `BaseOperation.__call__` to dispatch based on `self.metadata.backend`
instead of the global `jax.default_backend()`. This allows forcing a specific
path (e.g., GPU path even if JAX defaults to CPU) based on the configuration
stored in `AttentionMetadata`.
Args:
*args: Positional arguments to pass to the forward method.
**kwargs: Keyword arguments to pass to the forward method.
Returns:
An `AttentionOutput` object containing the results.
Raises:
RuntimeError: If the backend specified in `self.metadata` is unknown.
"""
match self.metadata.backend:
case EasyDeLBackends.TPU:
logger.debug("Calling into TPU exec")
return self.forward_tpu(*args, **kwargs)
case EasyDeLBackends.GPU:
logger.debug("Calling into GPU exec")
return self.forward_gpu(*args, **kwargs)
case EasyDeLBackends.CPU:
logger.debug("Calling into CPU exec")
return self.forward_native(*args, **kwargs)
case _:
raise RuntimeError(f"unknown backend at AttentionImpl! {self.metadata.backend}")
# Type variable bound to AttentionImpl. It lets `AttentionRegistry.register`
# return the exact subclass it was given, so decorated classes keep their
# precise type. The string name must match the variable name, or type
# checkers reject the declaration.
_I = tp.TypeVar("_I", bound=AttentionImpl)
class AttentionRegistry:
    """
    Name-based registry of `AttentionImpl` subclasses.

    Implementations are added with the `register` class decorator. They can
    then be looked up (`get`) or instantiated (`create`) by any of their names.

    The name -> class mapping lives in the class attribute `_registry`, so it
    is shared process-wide by every user of `AttentionRegistry`.
    """

    _registry: tp.Dict[str, tp.Type[AttentionImpl]] = {}

    @classmethod
    def register(cls, impl_cls: tp.Type[_I]) -> tp.Type[_I]:
        """
        Class decorator that records an `AttentionImpl` subclass.

        The class is stored under every name returned by its
        `get_impl_name()`. A bare string is treated as a single name; a list
        or tuple is treated as a set of aliases. Re-using an existing name
        overwrites the earlier entry and logs a warning.

        Example:
            ```python
            @AttentionRegistry.register
            class FlashAttentionImpl(AttentionImpl):
                @classmethod
                def get_impl_name(cls) -> str:
                    return "flash"
                # ... implementation ...
            ```

        Args:
            impl_cls: The `AttentionImpl` subclass to register.

        Returns:
            `impl_cls` itself, unchanged, so this can be used as a decorator.
        """
        declared = impl_cls.get_impl_name()
        aliases = declared if isinstance(declared, (list, tuple)) else [declared]
        for alias in aliases:
            if alias in cls._registry:
                logger.warning(f"Attention implementation '{alias}' already registered. Overwriting.")
            cls._registry[alias] = impl_cls
            logger.debug(f"Registered attention implementation: {alias}")
        return impl_cls

    @classmethod
    def get(cls, impl_name: str) -> tp.Type[AttentionImpl]:
        """
        Looks up a registered implementation class by name.

        Args:
            impl_name: Name the implementation was registered under.

        Returns:
            The `AttentionImpl` subclass stored under `impl_name`.

        Raises:
            ValueError: If nothing is registered under `impl_name`. The
                message lists the names that are available.
        """
        try:
            return cls._registry[impl_name]
        except KeyError:
            raise ValueError(
                f"Attention implementation '{impl_name}' not found. Available implementations: {list(cls._registry)}"
            ) from None

    @classmethod
    def create(cls, impl_name: str, metadata: AttentionMetadata) -> AttentionImpl:
        """
        Instantiates a registered implementation by name.

        Args:
            impl_name: Name of the implementation to build.
            metadata: Passed positionally to the implementation's constructor.

        Returns:
            A new instance of the class registered under `impl_name`.

        Raises:
            ValueError: If nothing is registered under `impl_name`.
        """
        return cls.get(impl_name)(metadata)

    @classmethod
    def list_implementations(cls) -> tp.List[str]:
        """
        Lists every registered implementation name.

        Returns:
            The registered names, in insertion order.
        """
        return [*cls._registry]