# Source code for easydel.layers.caching.transformer.transformer_cache
# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import typing as tp
import chex as cx
import jax
from eformer.escale import PartitionAxis, with_sharding_constraint
from eformer.jaximus import ImplicitArray
from eformer.pytree import auto_pytree
from flax import nnx as nn
from jax import lax
from jax import numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from .._abstracts import (
BaseCache,
BaseCacheMetadata,
BaseCacheView,
BaseRunTimeMetadata,
)
if tp.TYPE_CHECKING:
from easydel.layers.quantization.quantizers import EasyQuantizer
else:
EasyQuantizer = object
@auto_pytree
class TransformerCacheMetaData(BaseCacheMetadata):
    """Configuration record describing the layout of a transformer KV cache.

    Instances are normally built through :meth:`create`, which checks the
    batch and sequence sizes and fills in per-key/per-value head counts and
    dimensions from the shared ``num_heads`` / ``head_dim`` values.
    """

    partition_axis: PartitionAxis

    # Sizes that always have to be supplied.
    batch_size: int
    sequence_length: int
    num_hidden_layers: int

    # Attention geometry. The shared values (num_heads / head_dim) act as
    # fallbacks for the key/value specific ones inside `create`.
    num_heads: tp.Optional[int]
    head_dim: tp.Optional[int]
    key_heads: tp.Optional[int]
    value_heads: tp.Optional[int]
    key_dim: tp.Optional[int]
    value_dim: tp.Optional[int]

    # Behaviour switches.
    update_causal_mask: bool
    create_attention_bias: bool

    @classmethod
    def create(
        cls,
        partition_axis: PartitionAxis,
        batch_size: int,
        sequence_length: int,
        num_hidden_layers: int,
        num_heads: tp.Optional[int] = None,
        head_dim: tp.Optional[int] = None,
        key_heads: tp.Optional[int] = None,
        value_heads: tp.Optional[int] = None,
        key_dim: tp.Optional[int] = None,
        value_dim: tp.Optional[int] = None,
        update_causal_mask: bool = True,
        create_attention_bias: bool = True,
    ) -> "TransformerCacheMetaData":
        """Build a validated ``TransformerCacheMetaData``.

        Arguments:
            partition_axis: Partition axis.
            batch_size: Size of the batch; must be greater than zero.
            sequence_length: Length of the sequence; must be greater than zero.
            num_hidden_layers: Number of hidden layers.
            num_heads: Number of attention heads; used as the fallback for
                ``key_heads`` and ``value_heads``.
            head_dim: Dimension of each head; used as the fallback for
                ``key_dim`` and ``value_dim``.
            key_heads: Number of key heads.
            value_heads: Number of value heads.
            key_dim: Dimension of keys.
            value_dim: Dimension of values.
            update_causal_mask: Whether to update causal mask.
            create_attention_bias: Whether to create attention bias.

        Returns:
            A populated ``TransformerCacheMetaData`` instance.

        Raises:
            ValueError: If ``batch_size`` or ``sequence_length`` is not
                positive, if neither ``head_dim`` nor both of
                ``key_dim``/``value_dim`` are given, or if neither
                ``num_heads`` nor both of ``key_heads``/``value_heads`` are
                given.
        """
        if batch_size <= 0:
            raise ValueError("batch_size must be positive")
        if sequence_length <= 0:
            raise ValueError("sequence_length must be positive")

        # Per-key / per-value dimensions fall back to the shared head_dim.
        if head_dim is None:
            if key_dim is None or value_dim is None:
                raise ValueError(
                    "Either head_dim or both key_dim and value_dim must be specified"
                )
        else:
            key_dim = key_dim if key_dim else head_dim
            value_dim = value_dim if value_dim else head_dim

        # Per-key / per-value head counts fall back to the shared num_heads.
        if num_heads is None:
            if key_heads is None or value_heads is None:
                raise ValueError(
                    "Either num_heads or both key_heads and value_heads must be specified"
                )
        else:
            key_heads = key_heads if key_heads else num_heads
            value_heads = value_heads if value_heads else num_heads

        resolved_fields = dict(
            partition_axis=partition_axis,
            batch_size=batch_size,
            sequence_length=sequence_length,
            num_hidden_layers=num_hidden_layers,
            num_heads=num_heads,
            head_dim=head_dim,
            key_heads=key_heads,
            value_heads=value_heads,
            key_dim=key_dim,
            value_dim=value_dim,
            update_causal_mask=update_causal_mask,
            create_attention_bias=create_attention_bias,
        )
        return cls(**resolved_fields)
@auto_pytree
class TransformerCacheView(BaseCacheView):
    """Key/value cache buffers for a single transformer layer.

    A view owns the key and value buffers, a per-batch write index, the
    metadata it was created from, and (optionally) the index of the layer it
    belongs to.
    """

    # Cached keys; `init` allocates (batch_size, sequence_length, key_heads, key_dim).
    key: tp.Union[cx.Array, ImplicitArray]
    # Cached values; `init` allocates (batch_size, sequence_length, value_heads, value_dim).
    value: tp.Union[cx.Array, ImplicitArray]
    # Write position; `init` allocates int32 zeros of shape (batch_size,).
    index: tp.Union[cx.Array, ImplicitArray]
    metadata: TransformerCacheMetaData
    layer_index: tp.Optional[int] = None

    @classmethod
    def init(
        cls,
        metadata: TransformerCacheMetaData,
        quantizer: EasyQuantizer,
        key_values_partition_specs: PartitionSpec,
        dtype: jnp.dtype,
        mesh: Mesh,
        layer_index: tp.Optional[int] = None,
    ):
        """Allocate a zero-filled cache view for one layer.

        Args:
            metadata: Cache configuration supplying the buffer sizes.
            quantizer: Callable applied to the freshly allocated key and
                value buffers.
            key_values_partition_specs: Partition spec used, together with
                ``mesh``, to build the ``NamedSharding`` the buffers are
                created on.
            dtype: Data type of the key and value buffers.
            mesh: Device mesh for the sharding.
            layer_index: Optional index of the layer this view belongs to.

        Returns:
            A new ``TransformerCacheView`` whose ``index`` is all zeros.
        """
        with jax.named_scope("easydel-transformer-cacheview-init"):
            device = NamedSharding(mesh=mesh, spec=key_values_partition_specs)
            out = cls(
                key=quantizer(
                    jnp.zeros(
                        shape=(
                            metadata.batch_size,
                            metadata.sequence_length,
                            metadata.key_heads,
                            metadata.key_dim,
                        ),
                        dtype=dtype,
                        device=device,
                    ),
                ),
                value=quantizer(
                    jnp.zeros(
                        shape=(
                            metadata.batch_size,
                            metadata.sequence_length,
                            metadata.value_heads,
                            metadata.value_dim,
                        ),
                        dtype=dtype,
                        device=device,
                    )
                ),
                index=jnp.zeros((metadata.batch_size,), dtype=jnp.int32),
                metadata=metadata,
                layer_index=layer_index,
            )
        return out

    @jax.named_scope("easydel-transformer-cacheview-concatenate-to-cache")
    def concatenate_to_cache(
        self,
        query: cx.Array,
        key: cx.Array,
        value: cx.Array,
        cache_metadata: tp.Optional[TransformerMetadata],
        attention_mask: cx.Array,
        kv_sharding: NamedSharding,  # Ensure this is NamedSharding
        quantizer: EasyQuantizer,
        causal_mask: tp.Optional[cx.Array] = None,
        token_type_ids: tp.Optional[cx.Array] = None,
    ) -> tp.Tuple[cx.Array, cx.Array, cx.Array]:
        """Write new key/value states into the cache and build the attention mask.

        The new states are written along the sequence axis at offset
        ``self.index[0] % max_length`` (only the first entry of ``self.index``
        is consulted). Besides returning the updated tensors, this method
        rebinds ``self.key`` / ``self.value`` to the quantized updated buffers
        and advances ``self.index`` by the number of query positions.

        Args:
            query: Current query states; ``query.shape[1]`` is taken as the
                number of positions being written and ``query.shape[0]`` as
                the batch size for the causal mask.
            key: Current key states to add to the cache.
            value: Current value states to add to the cache.
            cache_metadata: Optional runtime metadata. Accepted for interface
                compatibility; not read by this method.
            attention_mask: Base attention mask. A 2-D mask gets two singleton
                axes inserted before its last axis.
            kv_sharding: Sharding constraint applied to the updated cache
                tensors.
            quantizer: Quantizer applied to the updated tensors before they
                are stored back on ``self``.
            causal_mask: Optional causal mask, indexed as a 4-D array. If it
                exposes a ``.value`` attribute, that attribute is used.
            token_type_ids: Optional token type IDs, indexed as a 2-D array
                (assumed ``(batch, sequence)`` - TODO confirm with callers).
                Only used when a causal mask is given and more than one
                position is written.

        Returns:
            Tuple[Array, Array, Array]:
                - Updated key cache tensor (before quantization).
                - Updated value cache tensor (before quantization).
                - Final attention mask, restricted to cache positions below
                  the advanced write index.
        """
        num_updated_cache_vectors = query.shape[1]
        end_index = self.index[0]
        *batch_dims, max_length, _, _ = self.value.shape

        if attention_mask.ndim == 2:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

        if causal_mask is not None:
            if hasattr(causal_mask, "value"):
                causal_mask = causal_mask.value
            causal_mask_slice = lax.dynamic_slice(
                causal_mask,
                (0, 0, end_index % max_length, 0),
                (1, 1, num_updated_cache_vectors, max_length),
            )
            # Broadcast to the batch *before* merging token-type information:
            # the token-type mask is per batch item, so writing it into a
            # batch-1 slice fails as soon as the batch is larger than one.
            causal_mask_expanded = jnp.broadcast_to(
                causal_mask_slice,
                (query.shape[0],) + causal_mask_slice.shape[1:],
            )
            if token_type_ids is not None and num_updated_cache_vectors != 1:
                # True where two positions of the same batch item share a type.
                token_type_mask = jnp.equal(
                    jnp.expand_dims(token_type_ids, 2),
                    jnp.expand_dims(token_type_ids, 1),
                )
                # Drop pairs that involve a type-0 token. The condition needs an
                # explicit singleton axis so it lines up with the pairwise mask
                # for any batch size; since the mask is symmetric, masking along
                # the last axis also clears the matching rows.
                token_type_mask = jnp.where(
                    jnp.expand_dims(token_type_ids == 0, 1),
                    False,
                    token_type_mask,
                )
                token_type_mask = jnp.expand_dims(token_type_mask, 1)
                sequence_length = token_type_ids.shape[1]
                masked_portion = jnp.logical_or(
                    token_type_mask[:, :, :num_updated_cache_vectors, :],
                    causal_mask_expanded[:, :, :, :sequence_length],
                )
                causal_mask_expanded = causal_mask_expanded.at[
                    :, :, :, :sequence_length
                ].set(masked_portion)
            final_attention_mask = nn.combine_masks(attention_mask, causal_mask_expanded)
        else:
            final_attention_mask = attention_mask

        slice_indices = (0, end_index % max_length, 0, 0)
        value_cache_updated = lax.dynamic_update_slice(
            self.value,
            value.astype(self.value.dtype),
            slice_indices,
        )
        key_cache_updated = lax.dynamic_update_slice(
            self.key,
            key.astype(self.key.dtype),
            slice_indices,
        )
        value_cache_updated = with_sharding_constraint(value_cache_updated, kv_sharding)
        key_cache_updated = with_sharding_constraint(key_cache_updated, kv_sharding)

        self.key = quantizer(key_cache_updated)
        self.value = quantizer(value_cache_updated)
        self.index = self.index + num_updated_cache_vectors

        # Only cache positions strictly below the advanced index are attendable.
        pad_mask = jnp.broadcast_to(
            jnp.arange(max_length) < self.index[0],
            tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
        )
        final_attention_mask = jnp.logical_and(pad_mask, final_attention_mask)
        return key_cache_updated, value_cache_updated, final_attention_mask

    @staticmethod
    def _update_cache_logic(
        key_tensor_layer: jax.Array,
        value_tensor_layer: jax.Array,
        new_keys: jax.Array,
        new_values: jax.Array,
        micro_batch_slot_indices: jax.Array,
        micro_batch_write_pos: jax.Array,
    ) -> tp.Tuple[jax.Array, jax.Array]:
        """Internal logic to update cache based on new keys/values for pooled cache.

        Args:
            key_tensor_layer: Key cache tensor to update.
            value_tensor_layer: Value cache tensor to update.
            new_keys: New keys, 3D ``[B, H, D]`` or 4D ``[B, S, H, D]``.
            new_values: New values, laid out like ``new_keys``.
            micro_batch_slot_indices: For each batch item, the index along
                axis 0 of the cache tensors to write to.
            micro_batch_write_pos: For each batch item, the index along
                axis 1 of the cache tensors at which writing starts.

        Returns:
            The updated ``(key, value)`` cache tensors.

        Raises:
            ValueError: If ``new_keys`` is neither 3D nor 4D.
        """
        if new_keys.ndim == 4 and new_keys.shape[1] > 1:
            # Prefill case: write one sequence position per scan step.
            seq_len_delta = new_keys.shape[1]

            def prefill_scan_body(carry, seq_idx_delta):
                k_tensor, v_tensor, current_write_pos = carry
                step_keys = lax.dynamic_slice_in_dim(
                    new_keys,
                    seq_idx_delta,
                    1,
                    axis=1,
                ).squeeze(1)
                step_values = lax.dynamic_slice_in_dim(
                    new_values,
                    seq_idx_delta,
                    1,
                    axis=1,
                ).squeeze(1)
                updated_k, updated_v = TransformerCacheView._update_single_token_step(
                    k_tensor,
                    v_tensor,
                    step_keys,
                    step_values,
                    micro_batch_slot_indices,
                    current_write_pos,
                )
                # Every batch item advances its write position by one token.
                return (updated_k, updated_v, current_write_pos + 1), None

            initial_carry = (key_tensor_layer, value_tensor_layer, micro_batch_write_pos)
            (final_k_tensor, final_v_tensor, _), _ = lax.scan(
                prefill_scan_body, initial_carry, xs=jnp.arange(seq_len_delta)
            )
        elif new_keys.ndim == 3 or (new_keys.ndim == 4 and new_keys.shape[1] == 1):
            # Decode case (or prefill of length 1)
            # Squeeze the sequence dimension if it exists and is 1
            if new_keys.ndim == 4:
                new_keys = new_keys.squeeze(1)
                new_values = new_values.squeeze(1)
            final_k_tensor, final_v_tensor = TransformerCacheView._update_single_token_step(
                key_tensor_layer,
                value_tensor_layer,
                new_keys,
                new_values,
                micro_batch_slot_indices,
                micro_batch_write_pos,
            )
        else:
            raise ValueError(
                f"Unexpected shape for new_keys: {new_keys.shape}. Expected 3D [B, H, D] or 4D [B, S, H, D]."
            )
        return final_k_tensor, final_v_tensor

    @staticmethod
    def _update_single_token_step(
        key_tensor_layer,
        value_tensor_layer,
        new_keys,  # Shape [micro_batch_size, num_heads, dim]
        new_values,  # Shape [micro_batch_size, num_heads, dim]
        micro_batch_slot_indices,
        micro_batch_write_pos,
    ):
        """Internal logic for updating cache for a single token step across the batch.

        Loops over the batch items and, for each one, writes its key/value
        vectors into the cache tensors at
        ``(slot_index, write_position, 0, 0)``.

        Returns:
            The updated ``(key, value)`` cache tensors.
        """
        micro_batch_size = micro_batch_slot_indices.shape[0]
        # Insert a length-1 sequence axis so that a single batch item sliced
        # out below has shape [1, 1, num_heads, dim] for dynamic_update_slice.
        new_keys_update = jnp.expand_dims(new_keys, axis=1)
        new_values_update = jnp.expand_dims(new_values, axis=1)

        def body_fn_batch_item(batch_item_idx, current_k_v_tensors):
            k_tensor, v_tensor = current_k_v_tensors
            slot_idx = micro_batch_slot_indices[batch_item_idx]
            seq_pos = micro_batch_write_pos[batch_item_idx]
            # Slice the update for the specific batch item
            # Shape [1, 1, num_heads, dim]
            item_key_update = lax.dynamic_slice_in_dim(
                new_keys_update,
                batch_item_idx,
                1,
                axis=0,
            )
            item_value_update = lax.dynamic_slice_in_dim(
                new_values_update,
                batch_item_idx,
                1,
                axis=0,
            )
            start_indices = (slot_idx, seq_pos, 0, 0)
            updated_k_tensor = lax.dynamic_update_slice(
                k_tensor,
                item_key_update.astype(k_tensor.dtype),
                start_indices,
            )
            updated_v_tensor = lax.dynamic_update_slice(
                v_tensor,
                item_value_update.astype(v_tensor.dtype),
                start_indices,
            )
            return (updated_k_tensor, updated_v_tensor)

        final_k_tensor, final_v_tensor = lax.fori_loop(
            0,
            micro_batch_size,
            body_fn_batch_item,
            (key_tensor_layer, value_tensor_layer),
        )
        return final_k_tensor, final_v_tensor

    def __repr__(self):
        """Summarise the view by buffer shapes, or by raw values if a buffer has no ``shape``."""
        try:
            return (
                self.__class__.__name__
                + f"(key={self.key.shape}, value={self.value.shape}, layer_index={self.layer_index})"
            )
        except AttributeError:
            return (
                self.__class__.__name__
                + f"(key={self.key}, value={self.value}, layer_index={self.layer_index})"
            )

    __str__ = __repr__
@auto_pytree
class TransformerCache(BaseCache):
    """Container holding one (optional) ``TransformerCacheView`` per layer."""

    views: tp.List[tp.Optional[TransformerCacheView]]

    @classmethod
    def init_cache(
        cls,
        metadata: TransformerCacheMetaData,
        mesh: Mesh,
        quantizer: tp.Optional[EasyQuantizer] = None,
        dtype: tp.Optional[jnp.dtype] = None,
        key_values_partition_specs: tp.Optional[PartitionSpec] = None,
    ):
        """Create a cache with an initialised view for every hidden layer.

        Args:
            metadata: Cache configuration; ``num_hidden_layers`` determines
                how many views are built.
            mesh: Device mesh forwarded to each view.
            quantizer: Quantizer forwarded to each view. When omitted, an
                ``EasyQuantizer`` built with
                ``EasyDeLQuantizationMethods.NONE`` is used.
            dtype: Buffer dtype; ``jnp.bfloat16`` when ``None``.
            key_values_partition_specs: Partition spec for the buffers. When
                omitted, one is assembled from the batch, key-sequence, head
                and attention-dim axes of ``metadata.partition_axis``.

        Returns:
            A ``TransformerCache`` with ``metadata.num_hidden_layers`` views.
        """
        # Imported here rather than at module level, as in the original code.
        from easydel.infra.etils import EasyDeLQuantizationMethods
        from easydel.layers.quantization.quantizers import EasyQuantizer

        paxis = metadata.partition_axis

        if not quantizer:
            quantizer = EasyQuantizer(EasyDeLQuantizationMethods.NONE)

        if not key_values_partition_specs:
            key_values_partition_specs = PartitionSpec(
                paxis.batch_axis,
                paxis.key_sequence_axis,
                paxis.head_axis,
                paxis.attention_dim_axis,
            )

        dtype = jnp.bfloat16 if dtype is None else dtype

        layer_views = []
        for layer_index in range(metadata.num_hidden_layers):
            layer_views.append(
                TransformerCacheView.init(
                    metadata=metadata,
                    quantizer=quantizer,
                    key_values_partition_specs=key_values_partition_specs,
                    dtype=dtype,
                    mesh=mesh,
                    layer_index=layer_index,
                )
            )
        return cls(views=layer_views)

    @classmethod
    def init_empty(cls, num_hidden_layers):
        """Create a cache whose ``num_hidden_layers`` view slots are all ``None``."""
        return cls(views=[None] * num_hidden_layers)

    def __repr__(self):
        """Render the class name followed by one line per view."""
        rendered_views = "\n ".join(map(str, self.views))
        return f"{self.__class__.__name__}(\n {rendered_views}\n)"

    __str__ = __repr__
@auto_pytree
class TransformerMetadata(BaseRunTimeMetadata):
    """Holds optional metadata for the attention runtime."""