Source code for easydel.__init__.layers.attention

# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import typing as tp
import warnings
from enum import Enum
from functools import cached_property

import einops
import flax.nnx as nn
import jax
import jax.experimental
import jax.extend
import jax.lib
import jax.tree_util
from chex import Array
from eformer.escale import with_sharding_constraint
from jax import NamedSharding, lax, random
from jax import numpy as jnp
from jax.sharding import PartitionSpec
from jax import tree_util as jtu
from easydel.infra.base_module import EasyDeLBaseConfig
from easydel.layers.caching import TransformerCacheView
from easydel.layers.quantization.quantizers import EasyQuantizer
from easydel.utils.helpers import get_logger
from .attention_operator import AttentionMetadata, AttentionOutput, AttentionRegistry

logger = get_logger(__name__)


def _get_jax_dtype_from_string(dtype_string):
	"""Translate the string form of a `jax.numpy` scalar type back into the type.

	Only the exact `"<class 'jax.numpy.NAME'>"` spellings listed in the table
	below are recognised.

	Args:
		dtype_string: the string to translate.

	Returns:
		The matching `jax.numpy` scalar type, or `dtype_string` itself
		(unchanged) when it is not one of the known spellings.
	"""
	dtype_mapping = {
		# Half-precision types: bfloat16 is what `get_optimal_config` selects on
		# most backends, so it has to be resolvable from its string form as well.
		"<class 'jax.numpy.float16'>": jnp.float16,
		"<class 'jax.numpy.bfloat16'>": jnp.bfloat16,
		"<class 'jax.numpy.float32'>": jnp.float32,
		"<class 'jax.numpy.float64'>": jnp.float64,
		"<class 'jax.numpy.int32'>": jnp.int32,
		"<class 'jax.numpy.int64'>": jnp.int64,
		"<class 'jax.numpy.bool_'>": jnp.bool_,
		"<class 'jax.numpy.complex64'>": jnp.complex64,
		"<class 'jax.numpy.complex128'>": jnp.complex128,
	}
	# Unknown inputs fall through untouched so callers keep the previous behaviour.
	return dtype_mapping.get(dtype_string, dtype_string)


class AttentionMechanisms(str, Enum):
	"""String-valued names of the attention implementations a config can select.

	The enum mixes in `str`, so each member compares equal to its plain string
	value (for example `AttentionMechanisms.AUTO == "auto"`).
	"""

	# Placeholder that `FlexibleAttentionModule.__init__` swaps for the result
	# of `get_optimal_config()`.
	AUTO = "auto"

	FLASH_ATTN2 = "flash_attn2"
	RING = "ring"
	VANILLA = "vanilla"
	SPLASH = "splash"
	CUDNN = "cudnn"
	BLOCKWISE = "blockwise"
	SDPA = "sdpa"
	CUDA_FLASH_ATTN2 = "cuda_flash_attn2"
def tpu_version_check(version: str = "v4"):
	"""Tell whether `version` occurs in the first local device's kind string.

	Args:
		version: substring to look for (compared against the lower-cased
			`device_kind`; an empty string is used when the device has no
			such attribute).

	Returns:
		bool: True when the substring is present, False otherwise.
	"""
	first_device = jax.local_devices()[0]
	device_kind = getattr(first_device, "device_kind", "")
	return version in device_kind.lower()


def get_optimal_config() -> tp.Tuple[AttentionMechanisms, jnp.dtype]:
	"""
	Returns the attention mechanism and dtype chosen for the current JAX backend.

	The choice is keyed on `jax.default_backend()`:

	* "tpu": FLASH_ATTN2 with float32 when the device kind contains "v3",
	  otherwise SPLASH with bfloat16.
	* "gpu": FLASH_ATTN2 with bfloat16.
	* anything else: VANILLA with bfloat16.

	Returns:
	    A tuple of (attention_mechanism, dtype)
	"""
	backend = jax.default_backend()
	if backend == "tpu":
		if tpu_version_check("v3"):
			return AttentionMechanisms.FLASH_ATTN2, jnp.float32
		return AttentionMechanisms.SPLASH, jnp.bfloat16
	if backend == "gpu":
		return AttentionMechanisms.FLASH_ATTN2, jnp.bfloat16
	return AttentionMechanisms.VANILLA, jnp.bfloat16


# Mechanism name used as the module-level default.
DEFAULT_ATTENTION_MECHANISM = "auto"
class FlexibleAttentionModule(nn.Module):
	"""
	Single entry point for running attention with a configurable implementation.

	The module does not compute attention itself. At construction time it
	builds an `AttentionMetadata` from the model config and asks
	`AttentionRegistry` for the implementation named by
	`base_config.attn_mechanism`; calling the module forwards the inputs to
	that implementation and casts every array in the result to the
	implementation's runtime dtype.

	Construction side effects on `base_config` (it is modified in place):

	* `attn_dtype` / `attn_softmax_dtype` given as strings are translated with
	  `_get_jax_dtype_from_string`.
	* When `attn_mechanism` is `AttentionMechanisms.AUTO`, both
	  `attn_mechanism` and `attn_dtype` are overwritten with the pair returned
	  by `get_optimal_config()`.

	Attributes set here:
	    impl: the implementation object returned by `AttentionRegistry.create`.
	    deterministic: always initialised to True and passed to `impl` on
	        every call.
	"""

	def __init__(
		self,
		base_config: EasyDeLBaseConfig,
		softmax_scale: float,
		dropout_prob: float = 0.0,
	):
		# Normalise dtypes that arrived as strings, attention dtype first.
		for dtype_field in ("attn_dtype", "attn_softmax_dtype"):
			configured = getattr(base_config, dtype_field)
			if isinstance(configured, str):
				setattr(base_config, dtype_field, _get_jax_dtype_from_string(configured))

		# Resolve the "auto" placeholder into a concrete mechanism and dtype.
		if base_config.attn_mechanism == AttentionMechanisms.AUTO:
			impl_name, runtime_dtype = get_optimal_config()
			logger.debug(f"Automatically select AttentionImpl {impl_name} | {runtime_dtype}")
			base_config.attn_mechanism = impl_name
			base_config.attn_dtype = runtime_dtype

		self.impl = AttentionRegistry.create(
			impl_name=base_config.attn_mechanism,
			metadata=AttentionMetadata.from_config(
				config=base_config,
				softmax_scale=softmax_scale,
				dropout_prob=dropout_prob,
			),
		)
		self.deterministic = True

	@jax.named_scope("easydel-flexible-attention")
	def forward(
		self,
		query_states: Array,
		key_states: Array,
		value_states: Array,
		bias: tp.Optional[Array] = None,
		init_bias: tp.Optional[tp.Callable[[], Array]] = None,
		attention_mask: tp.Optional[Array] = None,
		segment_ids: tp.Optional[Array] = None,
		causal: bool = True,
		dropout_rng: tp.Optional[random.PRNGKey] = None,
	) -> AttentionOutput:
		"""Run the selected implementation and cast its outputs.

		All arguments are handed to `self.impl` under its own keyword names
		(`q`, `k`, `v`, `bias`, `init_bias`, `mask`, `segment_ids`, `causal`,
		`dropout_rng`) together with `deterministic=self.deterministic`.

		Returns:
		    The implementation's output with every leaf cast to
		    `self.impl.metadata.runtime_dtype`.
		"""

		def _to_runtime_dtype(leaf):
			return leaf.astype(self.impl.metadata.runtime_dtype)

		raw_output = self.impl(
			q=query_states,
			k=key_states,
			v=value_states,
			bias=bias,
			init_bias=init_bias,
			mask=attention_mask,
			segment_ids=segment_ids,
			causal=causal,
			deterministic=self.deterministic,
			dropout_rng=dropout_rng,
		)
		return jtu.tree_map(_to_runtime_dtype, raw_output)

	__call__ = forward
SC = tp.TypeVar("SC")


class FlaxAttentionModule(nn.Module):
	"""Base class bundling the mask, KV-cache and sharding helpers of attention layers.

	The helpers read `self.config` (partition axes, mesh and KV-cache
	quantization settings). The bias builders additionally read `self.dtype`,
	which is not assigned in `__init__` here; presumably subclasses define it
	(TODO confirm against the concrete attention layers).
	"""

	def __init__(
		self,
		config: SC,
	):
		super().__init__()
		self.config: SC | EasyDeLBaseConfig = config

		# Cache slots start empty; nothing in this class fills them.
		self.cached_key: nn.Cache[Array] | None = None
		self.cached_value: nn.Cache[Array] | None = None
		self.cache_index: nn.Cache[Array] | None = None

	def make_flexible_sliding_window(
		self,
		attention_mask: jax.Array,
		cache_view: TransformerCacheView,
		sliding_window: int,
	):
		"""Restrict `attention_mask` to a sliding window over the cache positions.

		Args:
			attention_mask: boolean mask; indexed as a 4-D array by
				`build_cache_pos`.
			cache_view: cache view whose `index[0]` is the current write
				position, or None (position 0 is used then).
			sliding_window: window size forwarded to `_create_sliding_mask`.

		Returns:
			A tuple `(attention_mask, init_attention_bias)`: the combined
			mask and a zero-argument callable that builds the additive bias
			(0 where the mask is set, `finfo(self.dtype).min` elsewhere).
		"""
		attention_mask = jnp.logical_and(
			self._create_sliding_mask(
				cache_pos=self.build_cache_pos(attention_mask, cache_view),
				curr_index=cache_view.index[0] if cache_view is not None else 0,
				cache_length=attention_mask.shape[-1],
				sliding_windows=sliding_window,
			),
			attention_mask,
		)

		def init_attention_bias():
			return jax.lax.select(
				attention_mask > 0,
				jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
				jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
			)

		return attention_mask, init_attention_bias

	@staticmethod
	def build_cache_pos(
		attention_mask: jax.Array,
		cache_view: tp.Optional[TransformerCacheView] = None,
	) -> jax.Array:
		"""Derive per-query positions from a mask and the cache write index.

		The mask is reduced with `any` over its last axis, the last entry of
		axis 1 is taken, and a running count is formed along the remaining
		last axis. The count is shifted to start at zero and offset by
		`cache_view.index[0]` (0 when no cache view is given).
		"""
		end_index = cache_view.index[0] if cache_view is not None else 0
		inipos = jnp.cumsum(jnp.any(attention_mask, -1)[:, -1, :], axis=-1)
		# Subtract one wherever the running count is at least one (0-based positions).
		return (inipos - (inipos >= 1)) + end_index

	@cached_property
	def quantizer(self):
		"""`EasyQuantizer` built once from the KV-cache quantization settings in the config."""
		return EasyQuantizer(
			quantization_method=self.config.kv_cache_quantization_method,
			block_size=self.config.kv_cache_quantization_blocksize,
		)

	@property
	def default_key_value_sharding(self):
		"""`NamedSharding` over the config mesh using the batch, key-sequence, head and attention-dim axes."""
		paxis = self.config.partition_axis
		return NamedSharding(
			mesh=self.config.mesh,
			spec=PartitionSpec(
				paxis.batch_axis,
				paxis.key_sequence_axis,
				paxis.head_axis,
				paxis.attention_dim_axis,
			),
		)

	def get_sharding_safely(self, tensor: jax.Array) -> PartitionSpec:
		"""Return the partition spec of `tensor`, falling back to the default one.

		The default key/value spec is used when `tensor` exposes no `sharding`
		attribute, or when its sharding carries no `spec` (a sharding type
		without a partition spec previously raised `AttributeError` here).
		The default sharding is only constructed when it is actually needed.
		"""
		spec = getattr(getattr(tensor, "sharding", None), "spec", None)
		if spec is None:
			spec = self.default_key_value_sharding.spec
		return spec

	@staticmethod
	def _transpose_sequence_head(*args):
		"""Swap axes 1 and 2 of every 4-D array passed in.

		Args:
			*args: arrays to transpose

		Returns:
			A lazy `map` object yielding the transposed arrays in order.
		"""
		return map(lambda x: jnp.transpose(x, (0, 2, 1, 3)), args)

	@jax.named_scope("easydel-flax-attention-concatenate-to-cache")
	def _concatenate_to_cache(
		self,
		query: Array,
		key: Array,
		value: Array,
		cache_view: TransformerCacheView,
		attention_mask: Array,
		causal_mask: tp.Optional[Array] = None,
		token_type_ids: tp.Optional[Array] = None,
	) -> tp.Tuple[Array, Array, Array]:
		"""Write the new key/value states into the cache and build the matching mask.

		Side effects: `cache_view.key`, `cache_view.value` (both passed through
		the quantizer after a sharding constraint) and `cache_view.index`
		(advanced by the number of new query positions) are reassigned.

		Args:
			query: new query states; only `shape[0]` and `shape[1]` are read.
			key: new key states, cast to the cache dtype before being written.
			value: new value states, cast to the cache dtype before being written.
			cache_view: cache holding `key`, `value` and `index`.
			attention_mask: boolean mask; a 2-D mask gets two singleton axes
				inserted before its last axis.
			causal_mask: optional causal mask (or an object exposing it as
				`.value`), sliced to the rows being written.
			token_type_ids: optional token types, assumed (batch, seq);
				positions sharing a non-zero type may attend to each other
				regardless of causality.

		Returns:
			`(key_cache, value_cache, attention_mask)`: the updated,
			un-quantized cache arrays and the combined mask.
		"""
		num_updated_cache_vectors = query.shape[1]
		end_index = cache_view.index[0]

		*batch_dims, max_length, num_heads, depth_per_head = cache_view.value.shape

		if attention_mask.ndim == 2:
			attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

		if causal_mask is not None:
			if hasattr(causal_mask, "value"):
				causal_mask = causal_mask.value
			# Keep only the causal rows for the positions written in this call.
			causal_mask = lax.dynamic_slice(
				causal_mask,
				(0, 0, end_index, 0),
				(1, 1, num_updated_cache_vectors, max_length),
			)
			if token_type_ids is not None and num_updated_cache_vectors != 1:
				token_type_mask = jnp.equal(
					jnp.expand_dims(token_type_ids, 2),
					jnp.expand_dims(token_type_ids, 1),
				)
				# Broadcast the `== 0` test along the last axis explicitly. A bare
				# (batch, seq) condition against the (batch, seq, seq) mask only
				# broadcasts when batch is 1 or happens to equal seq.
				token_type_mask = jnp.where(
					jnp.expand_dims(token_type_ids == 0, -1),
					False,
					token_type_mask,
				)
				token_type_mask = jnp.expand_dims(token_type_mask, 1)
				sequence_length = token_type_ids.shape[1]
				masked_portion = jnp.logical_or(
					token_type_mask[:, :, :num_updated_cache_vectors, :],
					causal_mask[:, :, :, :sequence_length],
				)
				causal_mask = causal_mask.at[:, :, :, :sequence_length].set(masked_portion)

			causal_mask = jnp.broadcast_to(
				causal_mask,
				(query.shape[0],) + causal_mask.shape[1:],
			)

			attention_mask = jnp.broadcast_to(attention_mask, causal_mask.shape)
			attention_mask = jnp.logical_and(attention_mask, causal_mask)

		# The write offset wraps around the cache length.
		slice_indices = (0, end_index % cache_view.value.shape[1], 0, 0)

		value_cache = lax.dynamic_update_slice(
			cache_view.value,
			value.astype(cache_view.value.dtype),
			slice_indices,
		)
		key_cache = lax.dynamic_update_slice(
			cache_view.key,
			key.astype(cache_view.key.dtype),
			slice_indices,
		)
		# Hide cache slots beyond what has been written so far (including this call).
		pad_mask = jnp.broadcast_to(
			jnp.arange(max_length) < end_index + num_updated_cache_vectors,
			tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
		)
		attention_mask = jnp.logical_and(pad_mask, attention_mask)

		cache_view.key = self.quantizer(
			with_sharding_constraint(
				arr=key_cache,
				sharding=self.get_sharding_safely(cache_view.key),
			)
		)
		cache_view.value = self.quantizer(
			with_sharding_constraint(
				arr=value_cache,
				sharding=self.get_sharding_safely(cache_view.value),
			)
		)
		cache_view.index = cache_view.index + num_updated_cache_vectors
		return key_cache, value_cache, attention_mask

	@staticmethod
	def _create_sliding_mask(
		cache_pos: jnp.ndarray,
		curr_index: int,
		cache_length: int,
		sliding_windows: int,
	):
		"""Build a boolean window mask between query positions and cache slots.

		Args:
			cache_pos: 2-D array of query positions (see `build_cache_pos`).
			curr_index: cache write index before the current tokens.
			cache_length: number of cache slots.
			sliding_windows: window size; a slot is kept when its position is
				strictly within `sliding_windows` of the query position on
				either side.

		Returns:
			Boolean array of shape `cache_pos.shape + (cache_length,)`.
		"""
		total_tokens = curr_index + cache_pos.shape[1]

		def _reconstruct_rotated_cache_positions():
			# Positions of the last `cache_length` tokens, scattered to the slot
			# each of them occupies (position modulo cache length).
			cache_positions = jnp.arange(cache_length) + total_tokens - cache_length
			cache_positions = (
				jnp.zeros_like(cache_positions)
				.at[cache_positions % cache_length]
				.set(cache_positions)
			)
			return cache_positions

		# While everything still fits, slot i simply holds position i.
		cache_positions = jax.lax.cond(
			total_tokens <= cache_length,
			lambda: jnp.arange(cache_length),
			_reconstruct_rotated_cache_positions,
		)

		cache_positions = cache_positions[None, None, :]
		cache_pos = cache_pos[:, :, None]
		sliding_mask = cache_positions > cache_pos - sliding_windows
		sliding_mask *= cache_positions < cache_pos + sliding_windows
		return sliding_mask

	@jax.named_scope("easydel-flax-attention-concatenate")
	def concatenate(
		self,
		*,
		query: Array,
		key: Array,
		value: Array,
		attention_mask: Array,
		cache_view: tp.Optional[TransformerCacheView] = None,
		causal_mask: tp.Optional[Array] = None,
		token_type_ids: tp.Optional[Array] = None,
		fcm_mask: tp.Optional[Array] = None,
		sliding_windows: tp.Optional[int] = None,
	) -> tp.Tuple[Array, Array, Array, tp.Callable[[], Array]]:
		"""Prepare key, value and mask for attention, with or without a KV cache.

		Without a cache view the mask is combined with the (sliced) causal
		mask and `fcm_mask`, or, when no causal mask is given, expanded and
		repeated over the query length. With a cache view the work is
		delegated to `_concatenate_to_cache`, which also updates the cache.
		An optional sliding window is applied to the resulting mask afterwards.

		Args:
			query: query states; `shape[0]` and `shape[1]` are read as batch
				and query length.
			key: key states; `shape[1]` is read as key length.
			value: value states.
			attention_mask: mask; non-boolean input triggers a warning and is
				converted with `== 1`.
			cache_view: optional cache view.
			causal_mask: optional causal mask.
			token_type_ids: optional token types, assumed (batch, seq);
				positions sharing a non-zero type may attend to each other
				regardless of causality.
			fcm_mask: optional extra mask, used only on the cache-less causal path.
			sliding_windows: optional window size.

		Returns:
			`(key, value, attention_mask, init_attention_bias)` where the last
			item is a zero-argument callable building the additive bias
			(0 where the mask is set, `finfo(self.dtype).min` elsewhere).
		"""
		if attention_mask is not None:
			if attention_mask.dtype != jnp.bool:
				warnings.warn("attention_mask should be a boolean array", stacklevel=1)
				attention_mask = (attention_mask == 1).astype("b1")
		if cache_view is None:
			query_length = query.shape[1]
			key_length = key.shape[1]
			if causal_mask is not None:
				causal_mask = causal_mask[:, :, :query_length, :key_length]
				if token_type_ids is not None and query_length != 1:
					token_type_mask = jnp.equal(
						jnp.expand_dims(token_type_ids, 2),
						jnp.expand_dims(token_type_ids, 1),
					)
					# Clear every row belonging to a type-0 token. `jnp.where` with an
					# explicitly expanded condition replaces boolean-mask indexing,
					# which cannot be traced under `jit`, and matches the cached path.
					token_type_mask = jnp.where(
						jnp.expand_dims(token_type_ids == 0, -1),
						False,
						token_type_mask,
					)
					token_type_mask = jnp.expand_dims(token_type_mask, 1)
					sequence_length = token_type_ids.shape[1]
					masked_portion = jnp.logical_or(
						token_type_mask,
						causal_mask[:, :, :, :sequence_length],
					)
					causal_mask = causal_mask.at[:, :, :, :sequence_length].set(masked_portion)

				causal_mask = jnp.broadcast_to(
					causal_mask, (query.shape[0],) + causal_mask.shape[1:]
				)
				if attention_mask.ndim == 2:
					attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
				attention_mask = jnp.broadcast_to(attention_mask, causal_mask.shape)
				attention_mask = nn.combine_masks(attention_mask, causal_mask, fcm_mask)
			else:
				attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
				attention_mask = jnp.repeat(attention_mask, query.shape[1], -2)
		else:
			key, value, attention_mask = self._concatenate_to_cache(
				query=query,
				key=key,
				value=value,
				cache_view=cache_view,
				attention_mask=attention_mask,
				causal_mask=causal_mask,
				token_type_ids=token_type_ids,
			)

		if sliding_windows is not None and attention_mask is not None:
			# Lower-triangular region at or below diagonal `-sliding_windows`
			# marks the entries that get masked out.
			sliding_window_mask = jnp.tril(
				jnp.ones_like(attention_mask, dtype=jnp.bool),
				k=-sliding_windows,
			)
			window_mask = jnp.where(sliding_window_mask, 0, 1)
			attention_mask = jnp.logical_and(window_mask, attention_mask)
			if attention_mask.shape[-1] <= 1:
				attention_mask = attention_mask[:, :, :, -sliding_windows:]

		def init_attention_bias():
			return lax.select(
				attention_mask > 0,
				jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
				jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
			)

		return key, value, attention_mask, init_attention_bias

	def shard_attention_prod(self, attn_output: jax.Array) -> jax.Array:
		"""
		shards attention output before passing that to output_proj

		The sequence axis is left unsharded (None) when `attn_output.shape[1]`
		is 1.

		Args:
		    attn_output (jax.Array): merged output of dot product attention with 3 dims, (batch, seqlen, hidden_size).

		Returns:
		    jax.Array: sharded version of `attn_output`
		"""
		return with_sharding_constraint(
			arr=attn_output,
			sharding=PartitionSpec(
				self.config.partition_axis.batch_axis,
				(
					self.config.partition_axis.sequence_axis
					if attn_output.shape[1] != 1
					else None
				),
				self.config.partition_axis.hidden_state_axis,
			),
		)

	def _merge_heads(self, hidden_states: jax.Array) -> jax.Array:
		"""
		Merges the attention heads into a single hidden state tensor.

		The first two axes are kept and all remaining axes are flattened into one.

		Args:
		    hidden_states (jax.Array): The hidden states with separate head dimensions.

		Returns:
		    jax.Array: The hidden states with merged head dimensions.
		"""
		return hidden_states.reshape(hidden_states.shape[:2] + (-1,))

	@staticmethod
	def repeat_key_value(key, value, num_reps: int):
		"""Repeat every head of `key` and `value` `num_reps` times along the head axis.

		Both inputs are treated as 4-D "b s h d" arrays by the einops
		pattern; the result has `h * num_reps` heads.

		Returns:
			The `(key, value)` pair with repeated heads.
		"""
		with jax.named_scope("easydel-flax-attention-repeat-kvheads"):
			key = einops.repeat(key, "b s h d -> b s (h r) d", r=num_reps)
			value = einops.repeat(value, "b s h d -> b s (h r) d", r=num_reps)
		return key, value