Source code for easydel.layers.caching.transformer.transformer_cache
# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import typing as tp
from functools import partial
import chex as cx
import jax
from eformer import common_types
from eformer.escale import PartitionManager, apply_logical_sharding
from eformer.jaximus import ImplicitArray
from eformer.pytree import auto_pytree
from flax import nnx as nn
from jax import lax
from jax import numpy as jnp
from jax.sharding import Mesh
from jax.sharding import NamedSharding as Ns
from .._abstracts import (
BaseCache,
BaseCacheMetadata,
BaseCacheView,
BaseRunTimeMetadata,
)
if tp.TYPE_CHECKING:
from easydel.layers.quantization.quantizers import EasyQuantizer
else:
EasyQuantizer = object
# Short module-level aliases for names exported by ``eformer.common_types``.
# Sentinel / typing helpers:
NOT_GIVEN = common_types.NOT_GIVEN
RUNTIME_MODE_TYPES = common_types.RUNTIME_MODE_TYPES
# Runtime mode used for every sharding decision in this module:
MODE_PREFILL = common_types.MODE_PREFILL
# Logical axis names handed to the PartitionManager / apply_logical_sharding:
BATCH = common_types.BATCH
QUERY_LENGTH = common_types.QUERY_LENGTH
KV_LENGTH = common_types.KV_LENGTH
HEAD = common_types.HEAD
KV_HEAD = common_types.KV_HEAD
HEAD_DIM = common_types.HEAD_DIM
KV_HEAD_DIM = common_types.KV_HEAD_DIM
BIAS_HEAD_SEQ = common_types.BIAS_HEAD_SEQ
BIAS_KV_SEQ = common_types.BIAS_KV_SEQ
@auto_pytree
class TransformerCacheMetaData(BaseCacheMetadata):
    """Metadata for transformer cache configuration.

    Instances are normally built through :meth:`create`, which validates the
    sizes and fills per-key / per-value head counts and head dimensions from
    the shared ``num_heads`` / ``head_dim`` values when those are omitted.
    """

    # Cache geometry.
    batch_size: int
    sequence_length: int
    num_hidden_layers: int
    pad_token_id: int
    # Optional attention-related fields.
    num_heads: tp.Optional[int]
    head_dim: tp.Optional[int]
    key_heads: tp.Optional[int]
    value_heads: tp.Optional[int]
    key_dim: tp.Optional[int]
    value_dim: tp.Optional[int]
    # Configuration flags (stored as given; not interpreted by this class).
    update_causal_mask: bool
    create_attention_bias: bool

    @classmethod
    def create(
        cls,
        batch_size: int,
        sequence_length: int,
        num_hidden_layers: int,
        pad_token_id: int,
        num_heads: tp.Optional[int] = None,
        head_dim: tp.Optional[int] = None,
        key_heads: tp.Optional[int] = None,
        value_heads: tp.Optional[int] = None,
        key_dim: tp.Optional[int] = None,
        value_dim: tp.Optional[int] = None,
        update_causal_mask: bool = True,
        create_attention_bias: bool = True,
    ) -> "TransformerCacheMetaData":
        """Validate the arguments and build a :class:`TransformerCacheMetaData`.

        Args:
            batch_size: Number of sequences held by the cache; must be > 0.
            sequence_length: Cache length per sequence; must be > 0.
            num_hidden_layers: Number of transformer layers (stored as-is).
            pad_token_id: Padding token id (stored as-is).
            num_heads: Shared head count used when ``key_heads`` /
                ``value_heads`` are falsy.
            head_dim: Shared head dimension used when ``key_dim`` /
                ``value_dim`` are falsy.
            key_heads: Number of key heads.
            value_heads: Number of value heads.
            key_dim: Dimension of each key head.
            value_dim: Dimension of each value head.
            update_causal_mask: Flag stored on the metadata.
            create_attention_bias: Flag stored on the metadata.

        Returns:
            A populated ``TransformerCacheMetaData``.

        Raises:
            ValueError: If a size is non-positive, or if neither the shared
                value nor both per-key/per-value values are supplied for the
                head dimension or the head count.
        """
        if batch_size <= 0:
            raise ValueError("batch_size must be positive")
        if sequence_length <= 0:
            raise ValueError("sequence_length must be positive")

        # Head dimensions: either a shared head_dim, or both explicit dims.
        if head_dim is None:
            if key_dim is None or value_dim is None:
                raise ValueError(
                    "Either head_dim or both key_dim and value_dim must be specified"
                )
        else:
            key_dim = key_dim if key_dim else head_dim
            value_dim = value_dim if value_dim else head_dim

        # Head counts: either a shared num_heads, or both explicit counts.
        if num_heads is None:
            if key_heads is None or value_heads is None:
                raise ValueError(
                    "Either num_heads or both key_heads and value_heads must be specified"
                )
        else:
            key_heads = key_heads if key_heads else num_heads
            value_heads = value_heads if value_heads else num_heads

        return cls(
            batch_size=batch_size,
            sequence_length=sequence_length,
            num_hidden_layers=num_hidden_layers,
            pad_token_id=pad_token_id,
            num_heads=num_heads,
            head_dim=head_dim,
            key_heads=key_heads,
            value_heads=value_heads,
            key_dim=key_dim,
            value_dim=value_dim,
            update_causal_mask=update_causal_mask,
            create_attention_bias=create_attention_bias,
        )
@auto_pytree(frozen=False)
class TransformerCacheView(BaseCacheView):
    """Key/value cache storage for a single transformer layer.

    Fields:
        key / value: cache buffers allocated by :meth:`init` with shapes
            ``(batch_size, sequence_length, key_heads, key_dim)`` and
            ``(batch_size, sequence_length, value_heads, value_dim)``; they may
            be wrapped by the quantizer (e.g. as an ``ImplicitArray``).
        index: int32 vector with one entry per batch row; it is the offset at
            which :meth:`concatenate_to_cache` writes new entries and it is
            advanced by the number of written positions.
        starts: int32 vector with one entry per batch row (zeros unless given
            to :meth:`init`); not read by this class.
        metadata: the metadata the view was built from.
        layer_index: optional layer position, used in ``repr``.
    """

    key: tp.Union[cx.Array, ImplicitArray]
    value: tp.Union[cx.Array, ImplicitArray]
    index: tp.Union[cx.Array, ImplicitArray]
    starts: tp.Union[cx.Array, ImplicitArray]
    metadata: TransformerCacheMetaData
    layer_index: tp.Optional[int] = None

    @classmethod
    def init(
        cls,
        mesh: Mesh,
        dtype: jnp.dtype,
        metadata: TransformerCacheMetaData,
        quantizer: EasyQuantizer,
        partition_manager: PartitionManager,
        starts: tp.Optional[jax.Array] = None,
        layer_index: tp.Optional[int] = None,
    ):
        """Allocate a zero-filled, sharded cache view for one layer.

        Args:
            mesh: Device mesh the buffers are placed on.
            dtype: dtype of the key/value buffers.
            metadata: Supplies batch size, sequence length, heads and dims.
            quantizer: Applied to the freshly allocated key/value buffers.
            partition_manager: Resolves the logical axes
                ``[BATCH, KV_LENGTH, KV_HEAD, KV_HEAD_DIM]`` (and ``[BATCH]``
                for the index) into shardings in ``MODE_PREFILL``.
            starts: Optional per-row starts vector; zeros (int32) if omitted.
            layer_index: Optional layer position stored on the view.

        Returns:
            A new ``TransformerCacheView`` whose ``index`` is all zeros.
        """
        with jax.named_scope("easydel-transformer-cacheview-init"):
            mt = metadata
            kshape = (mt.batch_size, mt.sequence_length, mt.key_heads, mt.key_dim)
            vshape = (mt.batch_size, mt.sequence_length, mt.value_heads, mt.value_dim)
            kshardings = Ns(
                mesh,
                partition_manager.resolve(
                    axes=[BATCH, KV_LENGTH, KV_HEAD, KV_HEAD_DIM],
                    mode=MODE_PREFILL,
                    shape=kshape,
                ),
            )
            vshardings = Ns(
                mesh,
                partition_manager.resolve(
                    axes=[BATCH, KV_LENGTH, KV_HEAD, KV_HEAD_DIM],
                    mode=MODE_PREFILL,
                    shape=vshape,
                ),
            )
            ishardings = Ns(
                mesh,
                partition_manager.resolve(
                    axes=[BATCH],
                    mode=MODE_PREFILL,
                    shape=(mt.batch_size,),
                ),
            )
            if starts is None:
                starts = jnp.zeros((mt.batch_size,), dtype=jnp.int32)
            starts = apply_logical_sharding(
                starts,
                axes=[BATCH],
                mode=MODE_PREFILL,
                partition_manager=partition_manager,
            )
            out = cls(
                key=quantizer(jnp.zeros(shape=kshape, dtype=dtype, device=kshardings)),
                value=quantizer(jnp.zeros(shape=vshape, dtype=dtype, device=vshardings)),
                index=jnp.zeros((metadata.batch_size,), dtype=jnp.int32, device=ishardings),
                starts=starts,
                metadata=metadata,
                layer_index=layer_index,
            )
        return out

    @jax.named_scope("easydel-transformer-cacheview-concatenate-to-cache")
    def concatenate_to_cache(
        self,
        query: cx.Array,
        key: cx.Array,
        value: cx.Array,
        quantizer: EasyQuantizer,
        cache_metadata: tp.Optional[TransformerMetadata],
        attention_mask: cx.Array,
        partition_manager: PartitionManager,
        causal_mask: tp.Optional[cx.Array] = None,
        token_type_ids: tp.Optional[cx.Array] = None,
    ) -> tp.Tuple[cx.Array, cx.Array, cx.Array, "TransformerCacheView"]:
        """Write new key/value entries into the cache and build the attention mask.

        The update is functional: ``self`` is not modified; an updated view
        is returned instead.

        Args:
            query: Current query states; only ``query.shape[1]`` (the number of
                new positions) and ``query.shape[0]`` (batch) are used.
            key: New key states, written per row at offset ``self.index``.
            value: New value states, written per row at offset ``self.index``.
            quantizer: Applied to the updated buffers stored in the new view.
            cache_metadata: Accepted for interface compatibility; not read by
                this implementation.
            attention_mask: Base attention mask. A 2-D mask ``(B, K)`` is
                expanded to ``(B, 1, 1, K)``.
            partition_manager: Resolves logical sharding of the outputs.
            causal_mask: Optional causal mask. The rows for the positions
                being written are sliced out and combined with
                ``attention_mask``.
            token_type_ids: Optional 2-D ids. When given with more than one
                new position, positions that share a non-zero id may attend
                to each other regardless of the causal mask.

        Returns:
            A 4-tuple:
            - updated key buffer (not passed through ``quantizer``),
            - updated value buffer (not passed through ``quantizer``),
            - final attention mask: ``attention_mask`` ANDed with a mask that
              hides cache slots at or beyond the new per-row index,
            - a new ``TransformerCacheView`` holding the quantized buffers and
              the advanced index.

        Note:
            ``lax.dynamic_update_slice`` clamps out-of-range start indices, so
            writing past ``max_length`` does not raise; callers must keep
            ``index + query.shape[1] <= max_length``.
        """
        num_updated_cache_vectors = query.shape[1]
        index = self.index
        batch_dims, max_length = self.value.shape[:2]

        if attention_mask.ndim == 2:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

        if causal_mask is not None:
            # nnx variables wrap the array in `.value`.
            if hasattr(causal_mask, "value"):
                causal_mask = causal_mask.value
            if causal_mask.shape[0] != query.shape[0]:
                causal_mask = jnp.broadcast_to(
                    causal_mask,
                    (query.shape[0],) + causal_mask.shape[1:],
                )

            @partial(jax.vmap, in_axes=(0, 0), out_axes=0)
            def _mask_slice(mask, slot):
                # Per row: keep the causal-mask rows for positions
                # [slot, slot + num_updated_cache_vectors).
                return lax.dynamic_slice(
                    mask,
                    (0, slot, 0),
                    (1, num_updated_cache_vectors, max_length),
                )

            causal_mask = _mask_slice(causal_mask, index)

            if token_type_ids is not None and num_updated_cache_vectors != 1:
                token_type_mask = jnp.equal(
                    jnp.expand_dims(token_type_ids, 2),
                    jnp.expand_dims(token_type_ids, 1),
                )
                # Positions with token type 0 get no same-segment allowance.
                # The condition is (B, S) while the mask is (B, S, S); the
                # explicit trailing axis keeps broadcasting aligned on batch.
                # Without it, S is aligned against B, which fails (or silently
                # mis-masks) whenever the batch size is greater than 1.
                token_type_mask = jnp.where(
                    jnp.expand_dims(token_type_ids == 0, -1),
                    False,
                    token_type_mask,
                )
                token_type_mask = jnp.expand_dims(token_type_mask, 1)
                sequence_length = token_type_ids.shape[1]
                masked_portion = jnp.logical_or(
                    token_type_mask[:, :, :num_updated_cache_vectors, :],
                    causal_mask[:, :, :, :sequence_length],
                )
                causal_mask = causal_mask.at[:, :, :, :sequence_length].set(masked_portion)

            attention_mask = nn.combine_masks(attention_mask, causal_mask)

        def _maybe_materialize(x):
            # Quantized buffers expose `.materialize()`; plain arrays pass through.
            if hasattr(x, "materialize"):
                x = x.materialize()
            return x

        @partial(jax.vmap, in_axes=(0, 0, 0), out_axes=(0))
        def _update_kv(old, new, slot):
            return lax.dynamic_update_slice(old, new.astype(old.dtype), (slot, 0, 0))

        value_cache_updated = _update_kv(_maybe_materialize(self.value), value, index)
        key_cache_updated = _update_kv(_maybe_materialize(self.key), key, index)
        value_cache_updated = apply_logical_sharding(
            value_cache_updated,
            axes=[BATCH, KV_LENGTH, KV_HEAD, KV_HEAD_DIM],
            mode=MODE_PREFILL,
            partition_manager=partition_manager,
        )
        key_cache_updated = apply_logical_sharding(
            key_cache_updated,
            axes=[BATCH, KV_LENGTH, KV_HEAD, KV_HEAD_DIM],
            mode=MODE_PREFILL,
            partition_manager=partition_manager,
        )

        index = index + num_updated_cache_vectors
        # Hide cache slots that have not been written yet (>= new index).
        pad_mask = jnp.broadcast_to(
            (jnp.arange(max_length)[None, :] < index[:, None])[:, None, None, :],
            (batch_dims, 1, num_updated_cache_vectors, max_length),
        )
        return (
            key_cache_updated,
            value_cache_updated,
            # NOTE(review): this 4-D mask (batch, heads-or-1, queries, kv_len)
            # is sharded with the KV-cache axis names; confirm that is intended.
            apply_logical_sharding(
                jnp.logical_and(pad_mask, attention_mask),
                axes=[BATCH, KV_LENGTH, KV_HEAD, KV_HEAD_DIM],
                mode=MODE_PREFILL,
                partition_manager=partition_manager,
            ),
            self.replace(
                key=quantizer(key_cache_updated),
                value=quantizer(value_cache_updated),
                index=apply_logical_sharding(
                    index,
                    axes=[BATCH],
                    mode=MODE_PREFILL,
                    partition_manager=partition_manager,
                ),
            ),
        )

    def __repr__(self):
        try:
            return (
                self.__class__.__name__
                + f"(key={self.key.shape}, value={self.value.shape}, layer_index={self.layer_index})"
            )
        except AttributeError:
            return (
                self.__class__.__name__
                + f"(key={self.key}, value={self.value}, layer_index={self.layer_index})"
            )

    @property
    def is_empty(self):
        """True when the view carries no key buffer."""
        return self.key is None

    __str__ = __repr__
@auto_pytree
class TransformerCache(BaseCache):
    """Per-layer collection of :class:`TransformerCacheView` objects.

    ``views[i]`` holds the cache for layer ``i``; entries are ``None`` for a
    cache built with :meth:`init_empty`.
    """

    views: tp.List[tp.Optional[TransformerCacheView]]

    @classmethod
    def init_cache(
        cls,
        mesh: Mesh,
        metadata: TransformerCacheMetaData,
        partition_manager: PartitionManager,
        dtype: tp.Optional[jnp.dtype] = None,
        starts: tp.Optional[jax.Array] = None,
        quantizer: tp.Optional[EasyQuantizer] = None,
    ):
        """Allocate one zero-filled view per layer.

        Args:
            mesh: Device mesh; entered as a context while allocating.
            metadata: Cache geometry; ``num_hidden_layers`` views are created.
            partition_manager: Resolves the sharding of every buffer.
            dtype: Key/value dtype; defaults to ``jnp.bfloat16``.
            starts: Optional starts vector shared by all views.
            quantizer: Defaults to a no-op ``EasyQuantizer``
                (``EasyDeLQuantizationMethods.NONE``).

        Returns:
            A ``TransformerCache`` whose view ``i`` has ``layer_index == i``.
        """
        from easydel.infra.etils import EasyDeLQuantizationMethods
        from easydel.layers.quantization.quantizers import EasyQuantizer

        quantizer = quantizer or EasyQuantizer(EasyDeLQuantizationMethods.NONE)
        if dtype is None:
            dtype = jnp.bfloat16
        with mesh:
            return cls(
                views=[
                    TransformerCacheView.init(
                        mesh=mesh,
                        dtype=dtype,
                        starts=starts,
                        metadata=metadata,
                        quantizer=quantizer,
                        layer_index=layer_index,
                        partition_manager=partition_manager,
                    )
                    for layer_index in range(metadata.num_hidden_layers)
                ]
            )

    def to_pure(self):
        """Split the cache into raw per-layer arrays and shared metadata.

        Returns:
            ``(layers, metadata)`` where ``layers[i]`` is
            ``[key, value, index, starts]`` of view ``i`` and ``metadata`` is
            taken from the last view.
        """
        return (
            [[layer.key, layer.value, layer.index, layer.starts] for layer in self.views],
            self.views[-1].metadata,
        )

    @classmethod
    def from_pure(cls, pure, metadata):
        """Inverse of :meth:`to_pure`.

        The position of each entry in ``pure`` is restored as the view's
        ``layer_index`` so a round-trip preserves what ``init_cache`` set.
        """
        return cls(
            views=[
                TransformerCacheView(
                    key=layer[0],
                    value=layer[1],
                    index=layer[2],
                    starts=layer[3],
                    metadata=metadata,
                    layer_index=layer_index,
                )
                for layer_index, layer in enumerate(pure)
            ]
        )

    def insert_starts(self, starts, slot: int, partition_manager: PartitionManager):
        """Write ``starts`` into every view's starts vector beginning at ``slot``.

        Views are replaced in place; ``self`` is returned for chaining.
        """
        # Conversion is loop-invariant, so do it once.
        starts = jnp.array(starts).reshape(-1)
        for idx, view in enumerate(self.views):
            self.views[idx] = view.replace(
                starts=apply_logical_sharding(
                    lax.dynamic_update_slice_in_dim(view.starts, starts, slot, 0),
                    axes=[BATCH],
                    mode=MODE_PREFILL,
                    partition_manager=partition_manager,
                )
            )
        return self

    def insert_index(self, index, slot: int, partition_manager: PartitionManager):
        """Write ``index`` into every view's index vector beginning at ``slot``.

        Views are replaced in place; ``self`` is returned for chaining.
        """
        # Conversion is loop-invariant, so do it once.
        index = jnp.array(index).reshape(-1)
        for idx, view in enumerate(self.views):
            self.views[idx] = view.replace(
                index=apply_logical_sharding(
                    lax.dynamic_update_slice_in_dim(view.index, index, slot, 0),
                    axes=[BATCH],
                    mode=MODE_PREFILL,
                    partition_manager=partition_manager,
                )
            )
        return self

    def insert(
        self,
        other: TransformerCache,
        slot: int,
        quantizer: EasyQuantizer,
        partition_manager: PartitionManager,
    ):
        """Copy every layer of ``other`` into this cache starting at batch ``slot``.

        Key/value buffers on both sides are materialized if quantized, and the
        source is cast to the destination dtype. The result is re-sharded and
        re-quantized with ``quantizer``. ``index`` and ``starts`` are copied
        the same way. Views are replaced in place; ``self`` is returned.
        """

        def _maybe_materialize(x):
            # Quantized buffers expose `.materialize()`; plain arrays pass through.
            if hasattr(x, "materialize"):
                x = x.materialize()
            return x

        def _insert_kv(dst, src):
            dst = _maybe_materialize(dst)
            src = _maybe_materialize(src)
            updated = lax.dynamic_update_slice(dst, src.astype(dst.dtype), (slot, 0, 0, 0))
            return quantizer(
                apply_logical_sharding(
                    updated,
                    axes=[BATCH, KV_LENGTH, KV_HEAD, KV_HEAD_DIM],
                    mode=MODE_PREFILL,
                    partition_manager=partition_manager,
                )
            )

        def _insert_vector(dst, src):
            return apply_logical_sharding(
                lax.dynamic_update_slice_in_dim(dst, src, slot, 0),
                axes=[BATCH],
                mode=MODE_PREFILL,
                partition_manager=partition_manager,
            )

        for idx, view in enumerate(self.views):
            oview = other.views[idx]
            self.views[idx] = view.replace(
                key=_insert_kv(view.key, oview.key),
                value=_insert_kv(view.value, oview.value),
                index=_insert_vector(view.index, oview.index),
                starts=_insert_vector(view.starts, oview.starts),
                metadata=view.metadata,
            )
        return self

    @classmethod
    def init_empty(cls, num_hidden_layers):
        """Create a cache with ``num_hidden_layers`` placeholder (``None``) views."""
        return cls(views=[None for _ in range(num_hidden_layers)])

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(\n "
            + "\n ".join(str(view) for view in self.views)
            + "\n)"
        )

    __str__ = __repr__
@auto_pytree
class TransformerMetadata(BaseRunTimeMetadata):
    """Optional runtime metadata that may accompany attention calls.

    Both fields default to "unset" values; their consumers live outside
    this class.
    """

    # Boolean flag, False unless explicitly provided.
    postpadded: bool = False
    # Optional integer index, None unless explicitly provided.
    index: tp.Optional[int] = None