# Source code for easydel.layers.caching.transformer_cache

# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import typing as tp

import chex as cx
from eformer.escale import PartitionAxis
from eformer.jaximus import ImplicitArray
from jax import numpy as jnp
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

from easydel.infra.etils import EasyDeLQuantizationMethods

if tp.TYPE_CHECKING:
	from easydel.layers.quantization.quantizers import EasyQuantizer
else:
	EasyQuantizer = object


@cx.dataclass
class TransformerCacheMetaData:
	"""Metadata for transformer cache configuration.

	Describes the sizes used to allocate a per-layer key/value cache.
	Instances are normally built through :meth:`create`, which fills the
	per-key and per-value head counts and dimensions from ``num_heads`` /
	``head_dim`` when they are not given explicitly, and validates them.
	"""

	# Required fields
	batch_size: int
	sequence_length: int
	# Optional attention-related fields
	num_heads: tp.Optional[int]
	head_dim: tp.Optional[int]
	key_heads: tp.Optional[int]
	value_heads: tp.Optional[int]
	key_dim: tp.Optional[int]
	value_dim: tp.Optional[int]
	# Configuration flags
	update_causal_mask: bool
	create_attention_bias: bool

	@classmethod
	def create(
		cls,
		batch_size: int,
		sequence_length: int,
		num_heads: tp.Optional[int] = None,
		head_dim: tp.Optional[int] = None,
		key_heads: tp.Optional[int] = None,
		value_heads: tp.Optional[int] = None,
		key_dim: tp.Optional[int] = None,
		value_dim: tp.Optional[int] = None,
		update_causal_mask: bool = True,
		create_attention_bias: bool = True,
	) -> "TransformerCacheMetaData":
		"""
		Create a TransformerCacheMetaData instance with validation.

		Arguments:
			batch_size: Size of the batch. Must be positive.
			sequence_length: Length of the sequence. Must be positive.
			num_heads: Number of attention heads. Used as the default for
				``key_heads`` / ``value_heads`` when those are not given.
			head_dim: Dimension of each head. Used as the default for
				``key_dim`` / ``value_dim`` when those are not given.
			key_heads: Number of key heads.
			value_heads: Number of value heads.
			key_dim: Dimension of keys.
			value_dim: Dimension of values.
			update_causal_mask: Whether to update causal mask.
			create_attention_bias: Whether to create attention bias.

		Returns:
			TransformerCacheMetaData instance

		Raises:
			ValueError: If ``batch_size`` or ``sequence_length`` is not positive;
				if neither ``head_dim`` nor both ``key_dim`` and ``value_dim``
				are given; if neither ``num_heads`` nor both ``key_heads`` and
				``value_heads`` are given; or if any resolved head count or
				dimension is not positive.
		"""
		if batch_size <= 0:
			raise ValueError("batch_size must be positive")
		if sequence_length <= 0:
			raise ValueError("sequence_length must be positive")

		# Fill missing per-key/per-value dimensions from head_dim. The check
		# uses `is None` rather than `or`, so an explicitly passed value is
		# never silently replaced. Bad explicit values are rejected below.
		if head_dim is not None:
			if key_dim is None:
				key_dim = head_dim
			if value_dim is None:
				value_dim = head_dim
		elif key_dim is None or value_dim is None:
			raise ValueError(
				"Either head_dim or both key_dim and value_dim must be specified"
			)

		# Derive heads from num_heads if not specified
		if num_heads is not None:
			if key_heads is None:
				key_heads = num_heads
			if value_heads is None:
				value_heads = num_heads
		elif key_heads is None or value_heads is None:
			raise ValueError(
				"Either num_heads or both key_heads and value_heads must be specified"
			)

		# These four values are used directly as cache buffer dimensions.
		# Reject zero or negative sizes here instead of allocating an empty
		# or invalid buffer later.
		for name, value in (
			("key_heads", key_heads),
			("value_heads", value_heads),
			("key_dim", key_dim),
			("value_dim", value_dim),
		):
			if value <= 0:
				raise ValueError(f"{name} must be positive")

		return cls(
			batch_size=batch_size,
			sequence_length=sequence_length,
			num_heads=num_heads,
			head_dim=head_dim,
			key_heads=key_heads,
			value_heads=value_heads,
			key_dim=key_dim,
			value_dim=value_dim,
			update_causal_mask=update_causal_mask,
			create_attention_bias=create_attention_bias,
		)
@cx.dataclass
class TransformerCacheView:
	"""Key/value cache buffers belonging to a single transformer layer."""

	# Key buffer, possibly wrapped by the quantizer (hence ImplicitArray).
	key: tp.Union[cx.Array, ImplicitArray]
	# Value buffer, possibly wrapped by the quantizer (hence ImplicitArray).
	value: tp.Union[cx.Array, ImplicitArray]
	# int32 array of shape (batch_size,), zero-initialised by `init`.
	# NOTE(review): presumably the per-sequence fill position; confirm with callers.
	index: tp.Union[cx.Array, ImplicitArray]
	# Sizes this view was allocated with.
	metadata: TransformerCacheMetaData
	# Position of this view's layer in the model, if known.
	layer_index: tp.Optional[int] = None

	@classmethod
	def init(
		cls,
		metadata: TransformerCacheMetaData,
		quantizer: EasyQuantizer,
		key_values_partition_specs: PartitionSpec,
		dtype: jnp.dtype,
		mesh: Mesh,
		layer_index: tp.Optional[int] = None,
	):
		"""Allocate zero-filled key and value buffers for one layer.

		Both buffers are created directly with a ``NamedSharding`` built from
		``mesh`` and ``key_values_partition_specs``, then passed through
		``quantizer``. The ``index`` array is created without an explicit
		sharding.

		Arguments:
			metadata: Sizes for the buffers (batch, sequence, heads, dims).
			quantizer: Callable applied to each freshly allocated buffer.
			key_values_partition_specs: Partition spec used for both buffers.
			dtype: Element dtype of the key/value buffers.
			mesh: Device mesh for the sharding.
			layer_index: Optional layer position stored on the view.

		Returns:
			A new TransformerCacheView.
		"""
		with jax.named_scope("easydel-transformer-cacheview-init"):
			sharding = NamedSharding(mesh=mesh, spec=key_values_partition_specs)

			def _quantized_zeros(heads, width):
				# Layout: (batch_size, sequence_length, heads, width).
				raw = jnp.zeros(
					shape=(metadata.batch_size, metadata.sequence_length, heads, width),
					dtype=dtype,
					device=sharding,
				)
				return quantizer(raw)

			key_buffer = _quantized_zeros(metadata.key_heads, metadata.key_dim)
			value_buffer = _quantized_zeros(metadata.value_heads, metadata.value_dim)
			position = jnp.zeros((metadata.batch_size,), dtype=jnp.int32)
			return cls(
				key=key_buffer,
				value=value_buffer,
				index=position,
				metadata=metadata,
				layer_index=layer_index,
			)

	def __repr__(self):
		# Prefer compact shapes. Fall back to the raw objects when `.shape`
		# is unavailable on either buffer.
		try:
			key_part, value_part = self.key.shape, self.value.shape
		except AttributeError:
			key_part, value_part = self.key, self.value
		return (
			f"{self.__class__.__name__}"
			f"(key={key_part}, value={value_part}, layer_index={self.layer_index})"
		)

	__str__ = __repr__
@cx.dataclass
class TransformerCache:
	"""Container holding one cache view per transformer layer."""

	# One entry per layer. Entries are None when built via `init_empty`.
	views: tp.List[tp.Optional[TransformerCacheView]]

	@classmethod
	def init_layers_cache(
		cls,
		num_hidden_layers: int,
		metadata: TransformerCacheMetaData,
		mesh: Mesh,
		quantizer: tp.Optional[EasyQuantizer] = None,
		dtype: tp.Optional[jnp.dtype] = None,
		key_values_partition_specs: tp.Optional[PartitionSpec] = None,
	):
		"""Allocate a TransformerCacheView for each of ``num_hidden_layers`` layers.

		Arguments:
			num_hidden_layers: Number of views to create (layer indices
				``0 .. num_hidden_layers - 1``).
			metadata: Sizes shared by every layer's view.
			mesh: Device mesh used to shard the buffers.
			quantizer: Applied to every key/value buffer. Defaults to an
				``EasyQuantizer`` with ``EasyDeLQuantizationMethods.NONE``.
			dtype: Buffer dtype. Defaults to ``jnp.bfloat16``.
			key_values_partition_specs: Sharding for the key/value buffers.
				Defaults to ``PartitionSpec(batch_axis, key_sequence_axis,
				head_axis, attention_dim_axis)`` from a default ``PartitionAxis``.
				Any explicitly passed spec, including an empty ``PartitionSpec()``,
				is used as given.

		Returns:
			A TransformerCache holding one view per layer.
		"""
		# Imported at call time: at module level, outside type checking,
		# EasyQuantizer is bound to `object`.
		from easydel.layers.quantization.quantizers import EasyQuantizer

		# Explicit `is None` checks: `arg or default` would also discard
		# falsy-but-valid arguments. For example, `PartitionSpec()` (full
		# replication) has length 0 and would be overridden by the default.
		if quantizer is None:
			quantizer = EasyQuantizer(EasyDeLQuantizationMethods.NONE)
		if key_values_partition_specs is None:
			paxis = PartitionAxis()
			key_values_partition_specs = PartitionSpec(
				paxis.batch_axis,
				paxis.key_sequence_axis,
				paxis.head_axis,
				paxis.attention_dim_axis,
			)
		if dtype is None:
			dtype = jnp.bfloat16
		return cls(
			views=[
				TransformerCacheView.init(
					metadata=metadata,
					quantizer=quantizer,
					key_values_partition_specs=key_values_partition_specs,
					dtype=dtype,
					mesh=mesh,
					layer_index=layer_index,
				)
				for layer_index in range(num_hidden_layers)
			]
		)

	@classmethod
	def init_empty(cls, num_hidden_layers):
		"""Return a cache with ``num_hidden_layers`` unallocated (None) views."""
		return cls(views=[None for _ in range(num_hidden_layers)])

	def __repr__(self):
		# One line per view, using each view's own string form.
		return (
			f"{self.__class__.__name__}(\n "
			+ "\n ".join(str(view) for view in self.views)
			+ "\n)"
		)

	__str__ = __repr__