# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import typing as tp
from abc import ABC, abstractmethod
from enum import Enum
import einops
import jax
from eformer.escale import PartitionAxis
from jax import Array
from jax import numpy as jnp
from jax.sharding import PartitionSpec as Ps
from easydel.infra.base_config import EasyDeLBaseConfig
from easydel.infra.etils import EasyDeLBackends, EasyDeLPlatforms
from easydel.utils import traversals as etr
from easydel.utils.helpers import get_logger
# Module-level logger; used below for the backend-dispatch debug messages in
# AttentionImpl.__call__ and for registration messages in AttentionRegistry.
logger = get_logger("EasyDeL-AttentionOperator")
@etr.auto_pytree
class AttentionOutput:
    """Container for the result of an attention forward pass.

    This is the declared return type of every ``forward_*`` method and of
    ``__call__`` on attention implementations in this module. Both fields are
    optional and default to ``None``, so an implementation may populate only
    the ones it actually computes.
    """

    # Attention weights, when the implementation exposes them; otherwise None.
    attention_weights: tp.Optional[Array] = None
    # Attention result array; None when not set.
    attention_outputs: tp.Optional[Array] = None
class RuntimeType(Enum):
    """Execution mode of an attention call.

    ``AttentionImpl.get_runtime_type`` returns ``generation`` when the query
    has a sequence length of exactly one, and ``normal`` otherwise.
    """

    # Query sequence length is anything other than 1.
    normal = "normal"
    # Query sequence length is exactly 1.
    generation = "generation"
[docs]class AttentionImpl(ABC):
def __init__(self, metadata: AttentionMetadata) -> None:
self.metadata = metadata
[docs] @abstractmethod
def forward_native(self, *args, **kwargs) -> AttentionOutput: ...
[docs] @abstractmethod
def forward_tpu(self, *args, **kwargs) -> AttentionOutput: ...
[docs] @abstractmethod
def forward_cpu(self, *args, **kwargs) -> AttentionOutput: ...
[docs] @abstractmethod
def forward_gpu(self, *args, **kwargs) -> AttentionOutput: ...
[docs] @abstractmethod
def forward_rocm(self, *args, **kwargs) -> AttentionOutput: ...
[docs] @abstractmethod
def forward_cuda(self, *args, **kwargs) -> AttentionOutput: ...
[docs] @classmethod
@abstractmethod
def get_impl_name(cls) -> tp.Union[str, tp.Tuple[str]]: ...
[docs] def get_runtime_type(self, q: jax.Array, BTHD: bool = True) -> RuntimeType:
ingeneration = q.shape[1] == 1 if BTHD else q.shape[2] == 1
return RuntimeType.generation if ingeneration else RuntimeType.normal
[docs] def current_backend(self) -> tp.Literal["tpu", "gpu", "cpu"]:
return jax.default_backend()
@staticmethod
def _split_attention_mask(attn_mask: Array) -> tp.Tuple[Array, Array]:
"""
Takes an attention mask and splits it into query mask and key-value mask.
"""
if attn_mask.ndim == 4:
attn_mask = attn_mask[:, -1, :, :]
return (
jnp.any(attn_mask, axis=-1),
jnp.any(attn_mask, axis=-2),
)
@staticmethod
def _combine_query_kv_masks(q_mask: Array, kv_mask: Array) -> Array:
"""
Takes separate query and key-value masks and combines them into an attention mask.
"""
if kv_mask.ndim == 2:
kv_mask = kv_mask[:, None, :]
if q_mask.ndim == 2:
q_mask = q_mask[:, :, None]
return q_mask * kv_mask
@staticmethod
def _create_causal_mask(qseq) -> Array:
return jnp.tril(jnp.ones((qseq, qseq), dtype="b1"))
[docs] @staticmethod
def repeat_kv_heads(
k: Array,
v: Array,
num_reps: int,
) -> tp.Tuple[Array, Array]:
"""Repeats k and v heads to match q heads."""
return (
einops.repeat(k, "b s h d -> b s (h r) d", r=num_reps),
einops.repeat(v, "b s h d -> b s (h r) d", r=num_reps),
)
def _handle_kvhead(
self,
array: Array,
num_q_heads: int,
num_kv_heads: int,
) -> tp.Optional[Array]:
"""Processes attention bias based on head configuration."""
if array is None:
return None
if array.shape[1] == num_q_heads or array.shape[1] == 1:
return array
elif array.shape[1] == num_kv_heads:
return einops.repeat(
array,
"b h q k -> b (h r) q k",
r=num_q_heads // array.shape[1],
)
else:
raise ValueError(
f"Incompatible array shape. Got {array.shape[1]} heads, "
f"expected {num_q_heads}, {num_kv_heads}, or 1"
)
[docs] def create_stable_sharding(
self,
state_ps: tp.Optional[Ps] = None,
preserved_indices: tp.List[int] = None,
clone_ps: tp.Optional[Ps] = None,
dep: tp.Optional[tp.Union[Ps, bool]] = True,
) -> tp.Optional[Ps]:
if dep is None:
return None
if state_ps is None:
return None
if preserved_indices is None:
return state_ps
new_spec = [None] * len(state_ps)
for idx in preserved_indices:
new_spec[idx] = state_ps[idx] if clone_ps is None else clone_ps[idx]
return Ps(*new_spec)
def __call__(self, *args, **kwargs) -> AttentionOutput:
match self.metadata.backend:
case EasyDeLBackends.TPU:
logger.debug("Calling into TPU exec")
return self.forward_tpu(*args, **kwargs)
case EasyDeLBackends.GPU:
logger.debug("Calling into GPU exec")
return self.forward_gpu(*args, **kwargs)
case EasyDeLBackends.CPU:
logger.debug("Calling into CPU exec")
return self.forward_native(*args, **kwargs)
case _:
raise RuntimeError(f"unknown backend at AttentionImpl! {self.metadata.backend}")
class AttentionRegistry:
    """Registry mapping implementation names to ``AttentionImpl`` subclasses.

    The mapping is stored on the class itself, so every registration is
    shared by all users of ``AttentionRegistry``.
    """

    # name -> implementation class; filled by ``register``.
    _registry: tp.Dict[str, tp.Type[AttentionImpl]] = {}

    @classmethod
    def register(cls, impl_cls: tp.Type[AttentionImpl]) -> tp.Type[AttentionImpl]:
        """Register an attention implementation; usable as a class decorator.

        The names come from ``impl_cls.get_impl_name()``, which may return a
        single name or a list/tuple of names. Re-registering an existing name
        logs a warning and overwrites the previous entry.

        Example usage:
            @AttentionRegistry.register
            class CustomAttention(AttentionImpl):
                ...

        Args:
            impl_cls: Implementation class to register.

        Returns:
            ``impl_cls`` unchanged, so decoration leaves the class intact.
        """
        names = impl_cls.get_impl_name()
        if not isinstance(names, (list, tuple)):
            # A single name is normalized to a one-element list.
            names = [names]
        for impl_name in names:
            if impl_name in cls._registry:
                logger.warning(
                    f"Attention implementation '{impl_name}' already registered. Overwriting."
                )
            cls._registry[impl_name] = impl_cls
            logger.debug(f"Registered attention implementation: {impl_name}")
        return impl_cls

    @classmethod
    def get(cls, impl_name: str) -> tp.Type[AttentionImpl]:
        """Look up an implementation class by name.

        Args:
            impl_name: Name the implementation was registered under.

        Returns:
            The registered implementation class.

        Raises:
            ValueError: If no implementation is registered under
                ``impl_name``; the message lists the available names.
        """
        if impl_name not in cls._registry:
            raise ValueError(
                f"Attention implementation '{impl_name}' not found. Available implementations: {list(cls._registry.keys())}"
            )
        return cls._registry[impl_name]

    @classmethod
    def create(cls, impl_name: str, metadata: "AttentionMetadata") -> AttentionImpl:
        """Instantiate the implementation registered under ``impl_name``.

        Args:
            impl_name: Name the implementation was registered under.
            metadata: Passed as the sole constructor argument.

        Returns:
            A new instance of the registered implementation class.

        Raises:
            ValueError: If ``impl_name`` is not registered.
        """
        impl_cls = cls.get(impl_name)
        return impl_cls(metadata)

    @classmethod
    def list_implementations(cls) -> tp.List[str]:
        """Return the registered implementation names as a new list."""
        return list(cls._registry.keys())