Source code for easydel.layers.caching.transformer.cache
# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer key-value caching implementation for EasyDeL.
This module provides the standard key-value caching system for transformer
models, supporting various attention patterns including full attention,
sliding window attention, and local attention.
The transformer cache is designed for efficient autoregressive generation
by storing previously computed key and value states, avoiding redundant
computation during inference.
Key Components:
- TransformerCacheMetaData: Configuration for cache dimensions and behavior
- TransformerCacheView: Per-layer cache storage and update logic
- TransformerCache: Multi-layer cache orchestration
- TransformerMetadata: Runtime metadata for cache operations
- AttnMaskDetail: Attention masking configuration
Features:
- Support for multiple attention patterns (full, sliding, local)
- Quantization support for memory efficiency
- Distributed caching with JAX sharding
- Functional cache updates for JAX compatibility
- Dynamic mask generation and caching
Example:
>>> # Initialize cache metadata
>>> metadata = TransformerCacheMetaData.create(
... batch_size=2,
... sequence_length=1024,
... num_hidden_layers=12,
... pad_token_id=0,
... num_heads=16,
... head_dim=64
... )
>>>
>>> # Create cache
>>> cache = TransformerCache.init_cache(
... mesh=mesh,
... metadata=metadata,
... partition_manager=pm,
... dtype=jnp.bfloat16
... )
>>>
>>> # Update cache during inference
>>> for layer_idx in range(12):
... key_cache, value_cache, mask, new_view = cache[layer_idx].concatenate_to_cache(
... query=query_states,
... key=key_states,
... value=value_states,
... attention_mask=attention_mask,
... quantizer=quantizer,
... partition_manager=pm
... )
... cache[layer_idx] = new_view
"""
from __future__ import annotations
import typing as tp
from enum import Enum
from functools import partial
import jax
from eformer import common_types
from eformer.escale import PartitionManager, apply_logical_sharding
from eformer.jaximus import ImplicitArray, register
from eformer.pytree import auto_pytree, field
from ejkernel.types import MaskInfo
from jax import lax
from jax import numpy as jnp
from jax.extend.core import Primitive
from jax.sharding import Mesh
from jax.sharding import NamedSharding as Ns
from jaxtyping import Array as JAXArray
from jaxtyping import Float, Int
from .._abstracts import BaseCache, BaseCacheMetadata, BaseCacheView, BaseRunTimeMetadata
if tp.TYPE_CHECKING:
from easydel.layers.quantization.quantizers import EasyQuantizer
else:
EasyQuantizer = object
[docs]@auto_pytree
class AttnMaskDetail:
"""Configuration for attention masking patterns.
Defines the type and parameters of attention masking to apply
during cache operations. Supports various masking strategies
including sliding windows, chunks, and custom patterns.
Attributes:
mask_type (Enum): Type of attention mask (e.g., FULL, SLIDING, CHUNKED).
size (int): Primary size parameter for the mask (window size, chunk size, etc.).
offset (int | None): Optional offset for mask positioning.
chunks (int | None): Number of chunks for chunked attention.
bricks (int | None): Number of bricks for blocked attention patterns.
"""
mask_type: Enum = field(pytree_node=False)
size: int = field(pytree_node=False)
offset: int | None = field(pytree_node=False, default=None)
chunks: int | None = field(pytree_node=False, default=None)
bricks: int | None = field(pytree_node=False, default=None)
NOT_GIVEN = common_types.NOT_GIVEN
RUNTIME_MODE_TYPES = common_types.RUNTIME_MODE_TYPES
BATCH = common_types.BATCH
QUERY_LENGTH = common_types.QUERY_LENGTH
KV_LENGTH = common_types.KV_LENGTH
HEAD = common_types.HEAD
KV_HEAD = common_types.KV_HEAD
HEAD_DIM = common_types.HEAD_DIM
KV_HEAD_DIM = common_types.KV_HEAD_DIM
BIAS_HEAD_SEQ = common_types.BIAS_HEAD_SEQ
BIAS_KV_SEQ = common_types.BIAS_KV_SEQ
MODE_PREFILL = common_types.MODE_PREFILL
@register("dynamic_update_slice")
def _(
primitive: Primitive,
operand: ImplicitArray,
update: tp.Any,
*args,
**kwargs,
) -> JAXArray:
"""Register handler for dynamic_update_slice with ImplicitArray operand.
Materializes the implicit array before performing the update operation.
This ensures compatibility with quantized or lazy-evaluated arrays.
Args:
primitive: JAX primitive for dynamic_update_slice.
operand: ImplicitArray to update (will be materialized).
update: Update values.
*args: Additional arguments for the primitive.
**kwargs: Additional keyword arguments.
Returns:
Result of the dynamic_update_slice operation.
"""
operand = operand.materialize()
return primitive.bind(operand, update, *args)
@register("dynamic_update_slice")
def _(
primitive: Primitive,
operand: tp.Any,
update: ImplicitArray,
*args,
**kwargs,
) -> JAXArray:
update = update.materialize()
return primitive.bind(operand, update, *args)
@register("dynamic_update_slice")
def _(
primitive: Primitive,
operand: ImplicitArray,
update: ImplicitArray,
*args,
**kwargs,
) -> JAXArray:
operand = operand.materialize()
update = update.materialize()
return primitive.bind(operand, update, *args)
[docs]@auto_pytree
class TransformerCacheMetaData(BaseCacheMetadata):
"""Metadata configuration for transformer key-value caching.
Stores all static configuration needed to initialize and operate
a transformer cache. Supports various attention head configurations
including multi-head, multi-query, and grouped-query attention.
The metadata defines:
- Cache dimensions (batch, sequence, layers)
- Attention head configuration
- Masking and bias settings
- Special attention patterns (sliding window)
Attributes:
batch_size (int): Number of sequences in batch.
sequence_length (int): Maximum sequence length to cache.
num_hidden_layers (int): Number of transformer layers.
pad_token_id (int): Token ID used for padding.
num_heads (int | None): Number of attention heads (for regular MHA).
head_dim (int | None): Dimension of each attention head.
key_heads (int | None): Number of key heads (for MQA/GQA).
value_heads (int | None): Number of value heads (for MQA/GQA).
key_dim (int | None): Dimension of key projections.
value_dim (int | None): Dimension of value projections.
sliding_window (int | None): Size of sliding attention window.
update_causal_mask (bool): Whether to update causal masks dynamically.
create_attention_bias (bool): Whether to create attention bias terms.
"""
batch_size: int = field(pytree_node=False)
sequence_length: int = field(pytree_node=False)
num_hidden_layers: int = field(pytree_node=False)
pad_token_id: int = field(pytree_node=False)
# Optional attention-related fields
num_heads: int | None = field(pytree_node=False)
head_dim: int | None = field(pytree_node=False)
key_heads: int | None = field(pytree_node=False)
value_heads: int | None = field(pytree_node=False)
key_dim: int | None = field(pytree_node=False)
value_dim: int | None = field(pytree_node=False)
sliding_window: int | None = field(pytree_node=False)
# Configuration flags
update_causal_mask: bool = field(pytree_node=False)
create_attention_bias: bool = field(pytree_node=False)
[docs] @classmethod
def create(
cls,
batch_size: int,
sequence_length: int,
num_hidden_layers: int,
pad_token_id: int,
num_heads: int | None = None,
head_dim: int | None = None,
key_heads: int | None = None,
value_heads: int | None = None,
key_dim: int | None = None,
value_dim: int | None = None,
update_causal_mask: bool = True,
create_attention_bias: bool = True,
sliding_window: int | None = None,
) -> TransformerCacheMetaData:
"""
Create a TransformerCacheMetaData instance with validation.
Arguments:
batch_size: Size of the batch.
sequence_length: Length of the sequence.
num_hidden_layers: number of hidden layers.
num_heads: Number of attention heads.
head_dim: Dimension of each head.
key_heads: Number of key heads.
value_heads: Number of value heads.
key_dim: Dimension of keys.
value_dim: Dimension of values.
update_causal_mask: Whether to update causal mask.
create_attention_bias: Whether to create attention bias.
Returns:
TransformerCacheMetaData instance
Raises:
ValueError: If required parameters are missing or invalid.
"""
if batch_size <= 0:
raise ValueError("batch_size must be positive")
if sequence_length <= 0:
raise ValueError("sequence_length must be positive")
if head_dim is not None:
key_dim = key_dim or head_dim
value_dim = value_dim or head_dim
else:
if key_dim is None or value_dim is None:
raise ValueError("Either head_dim or both key_dim and value_dim must be specified")
# Derive heads from num_heads if not specified
if num_heads is not None:
key_heads = key_heads or num_heads
value_heads = value_heads or num_heads
else:
if key_heads is None or value_heads is None:
raise ValueError("Either num_heads or both key_heads and value_heads must be specified")
return cls(
batch_size=batch_size,
sequence_length=sequence_length,
num_hidden_layers=num_hidden_layers,
pad_token_id=pad_token_id,
num_heads=num_heads,
head_dim=head_dim,
key_heads=key_heads,
value_heads=value_heads,
key_dim=key_dim,
value_dim=value_dim,
update_causal_mask=update_causal_mask,
create_attention_bias=create_attention_bias,
sliding_window=sliding_window,
)
[docs]@auto_pytree(frozen=False)
class TransformerCacheView(BaseCacheView):
"""Single-layer cache view for transformer key-value states.
Manages the cached key and value tensors for one transformer layer,
along with position tracking and masking information. Supports
various attention patterns and quantization strategies.
The view maintains:
- Key and value state tensors
- Current position indices for each sequence
- Starting positions for relative indexing
- Masking configuration for attention patterns
Attributes:
key (cx.Array | ImplicitArray): Cached key states.
Shape: [batch_size, seq_length, num_key_heads, key_dim]
value (cx.Array | ImplicitArray): Cached value states.
Shape: [batch_size, seq_length, num_value_heads, value_dim]
indexs (cx.Array | ImplicitArray): Current position index per sequence.
Shape: [batch_size]
starts (cx.Array | ImplicitArray): Starting position per sequence.
Shape: [batch_size]
metadata (TransformerCacheMetaData): Static cache configuration.
maximum_sequence_length (int): Maximum cacheable sequence length.
layer_index (int | None): Index of this layer in the model.
masking_details (AttnMaskDetail | None): Attention mask configuration.
"""
key: Float[JAXArray, "batch seq_len num_key_heads key_dim"] | ImplicitArray
value: Float[JAXArray, "batch seq_len num_value_heads value_dim"] | ImplicitArray
indexs: Int[JAXArray, "batch"] | ImplicitArray # noqa: F821
starts: Int[JAXArray, "batch"] | ImplicitArray # noqa: F821
metadata: TransformerCacheMetaData
maximum_sequence_length: int = field(pytree_node=False)
layer_index: int | None = field(pytree_node=False, default=None)
masking_details: AttnMaskDetail | None = field(pytree_node=False, default=None)
[docs] @classmethod
def init(
cls,
mesh: Mesh,
dtype: jnp.dtype,
metadata: TransformerCacheMetaData,
quantizer: EasyQuantizer,
partition_manager: PartitionManager,
starts: Int[JAXArray, "batch"] | None = None, # noqa: F821
layer_index: int | None = None,
masking_details: AttnMaskDetail | None = None,
):
"""Initialize a transformer cache view for a single layer.
Creates and allocates cache tensors with appropriate shapes,
dtypes, and sharding for distributed execution. Applies
quantization if configured.
Args:
mesh (Mesh): JAX device mesh for distributed execution.
dtype (jnp.dtype): Data type for cache tensors.
metadata (TransformerCacheMetaData): Cache configuration.
quantizer (EasyQuantizer): Quantization configuration.
partition_manager (PartitionManager): Sharding strategy manager.
starts (jax.Array | None): Initial starting positions per sequence.
Defaults to zeros if not provided.
layer_index (int | None): Index of this layer in the model.
masking_details (AttnMaskDetail | None): Attention mask configuration.
Returns:
TransformerCacheView: Initialized cache view with allocated tensors.
Note:
For sliding window attention, cache dimensions are adjusted
based on the window size specified in masking_details.
"""
from easydel.infra.utils import AttnMaskType
with jax.named_scope("easydel-transformer-cacheview-init"):
mt = metadata
kshape = (mt.batch_size, mt.sequence_length, mt.key_heads, mt.key_dim)
vshape = (mt.batch_size, mt.sequence_length, mt.value_heads, mt.value_dim)
kv_sharding_axes = [BATCH, KV_LENGTH, KV_HEAD, KV_HEAD_DIM]
if masking_details is not None:
if masking_details.mask_type == AttnMaskType.SLIDING:
kshape = (mt.batch_size, min(masking_details.size, mt.sequence_length), mt.key_heads, mt.key_dim)
vshape = (mt.batch_size, min(masking_details.size, mt.sequence_length), mt.value_heads, mt.value_dim)
kshardings = Ns(mesh, partition_manager.resolve(axes=kv_sharding_axes, mode=MODE_PREFILL, shape=kshape))
vshardings = Ns(mesh, partition_manager.resolve(axes=kv_sharding_axes, mode=MODE_PREFILL, shape=vshape))
ishardings = Ns(mesh, partition_manager.resolve(axes=[BATCH], mode=MODE_PREFILL, shape=(mt.batch_size,)))
if starts is None:
starts = jnp.zeros((mt.batch_size,), dtype=jnp.int32)
starts = apply_logical_sharding(starts, axes=[BATCH], mode=MODE_PREFILL, partition_manager=partition_manager)
out = cls(
key=quantizer(jnp.zeros(shape=kshape, dtype=dtype, device=kshardings)),
value=quantizer(jnp.zeros(shape=vshape, dtype=dtype, device=vshardings)),
indexs=jnp.zeros((metadata.batch_size,), dtype=jnp.int32, device=ishardings),
starts=starts,
metadata=metadata,
layer_index=layer_index,
masking_details=masking_details,
maximum_sequence_length=mt.sequence_length,
)
return out
[docs] @jax.named_scope("easydel-transformer-cacheview-concatenate-to-cache")
def concatenate_to_cache(
self,
query: Float[JAXArray, "batch query_len num_heads head_dim"],
key: Float[JAXArray, "batch query_len num_key_heads key_dim"],
value: Float[JAXArray, "batch query_len num_value_heads value_dim"],
mode: common_types.RUNTIME_MODE_TYPES, # type:ignore
quantizer: EasyQuantizer,
cache_metadata: TransformerMetadata | None,
mask_info: MaskInfo,
partition_manager: PartitionManager,
) -> tuple[
Float[JAXArray, "batch seq_len num_key_heads key_dim"],
Float[JAXArray, "batch seq_len num_value_heads value_dim"],
MaskInfo,
TransformerCacheView,
AttnMaskDetail | None,
]:
"""
Updates the KV cache functionally and returns the updated tensors along with the appropriate attention mask.
Args:
query: Current query states.
key: Current key states to add to the cache.
value: Current value states to add to the cache.
cache_metadata: Optional metadata. If provided and contains slot/length info, enables pooled caching.
attention_mask: Base attention mask.
quantizer: Quantizer for the cache.
causal_mask: Optional causal mask.
token_type_ids: Optional token type IDs for segment masking.
Returns:
Tuple[Array, Array, Array]:
- Updated key cache tensor (functional update).
- Updated value cache tensor (functional update).
- Final attention mask to be used (either original or calculated).
"""
from easydel.infra.utils import AttnMaskType
runtime_dtype = query.dtype
num_updated_cache_vectors = query.shape[1]
masking_details = self.masking_details
indexs = self.indexs
sharding_statics = dict(mode=MODE_PREFILL, partition_manager=partition_manager)
def _kv_struct_shard(
x: Float[JAXArray, "batch seq_len num_heads head_dim"],
) -> Float[JAXArray, "batch seq_len num_heads head_dim"]:
return apply_logical_sharding(x, axes=[BATCH, KV_LENGTH, KV_HEAD, KV_HEAD_DIM], **sharding_statics)
def _maybe_materialize(x: JAXArray | ImplicitArray) -> JAXArray:
if hasattr(x, "materialize"):
x = x.materialize()
return x
@partial(jax.vmap, in_axes=(0, 0, 0), out_axes=(0))
def _update_kv(
old: Float[JAXArray, "seq_len num_heads head_dim"],
new: Float[JAXArray, "query_len num_heads head_dim"],
slot: Int[JAXArray, ""],
) -> Float[JAXArray, "seq_len num_heads head_dim"]:
return lax.dynamic_update_slice(old, new.astype(old.dtype), (slot, 0, 0))
@partial(jax.vmap, in_axes=(0, 0, 0), out_axes=(0))
def _update_kv_sliding(
old_cache: Float[JAXArray, "window_size num_heads head_dim"],
new_values: Float[JAXArray, "query_len num_heads head_dim"],
current_index: Int[JAXArray, ""],
) -> Float[JAXArray, "window_size num_heads head_dim"]:
"""Update sliding window KV cache."""
new_len = new_values.shape[0]
window_size = old_cache.shape[0]
if new_len >= window_size:
return new_values[-window_size:, :, :].astype(old_cache.dtype)
total_tokens = current_index + new_len
def _fits_in_window():
return lax.dynamic_update_slice(old_cache, new_values.astype(old_cache.dtype), (current_index, 0, 0))
def _overflow_window():
return jnp.concatenate([old_cache[new_len:, :, :], new_values.astype(old_cache.dtype)], axis=0)
return lax.cond(total_tokens <= window_size, _fits_in_window, _overflow_window)
sliding_window = None
if masking_details is not None and masking_details.mask_type == AttnMaskType.SLIDING:
value_cache_updated = _update_kv_sliding(_maybe_materialize(self.value), value, indexs)
key_cache_updated = _update_kv_sliding(_maybe_materialize(self.key), key, indexs)
sliding_window = masking_details.size
else:
value_cache_updated = _update_kv(_maybe_materialize(self.value), value, indexs)
key_cache_updated = _update_kv(_maybe_materialize(self.key), key, indexs)
indexs = indexs + num_updated_cache_vectors
mask_info = mask_info.apply_kv_lengths(
kv_lengths=indexs,
q_len=num_updated_cache_vectors,
sliding_window=sliding_window,
)
value_cache_updated = _kv_struct_shard(value_cache_updated).astype(runtime_dtype)
key_cache_updated = _kv_struct_shard(key_cache_updated).astype(runtime_dtype)
indexs_updated = apply_logical_sharding(indexs, axes=[BATCH], **sharding_statics)
return (
key_cache_updated,
value_cache_updated,
mask_info,
self.replace(key=quantizer(key_cache_updated), value=quantizer(value_cache_updated), indexs=indexs_updated),
masking_details,
)
def __repr__(self):
try:
return (
self.__class__.__name__
+ f"(key={self.key.shape}, value={self.value.shape}, layer_index={self.layer_index})"
)
except AttributeError:
return self.__class__.__name__ + f"(key={self.key}, value={self.value}, layer_index={self.layer_index})"
@property
def is_empty(self) -> bool:
return self.key is None
__str__ = __repr__
[docs]@auto_pytree
class TransformerCache(BaseCache):
"""Multi-layer transformer cache container.
Orchestrates cache views across all transformer layers, providing
methods for initialization, access, and batch operations. Supports
serialization for checkpointing and cache transfer.
The cache maintains:
- Ordered list of per-layer cache views
- Consistent configuration across layers
- Batch update operations
- Serialization/deserialization support
Attributes:
views (list[TransformerCacheView | None]): Cache views for each layer.
None indicates uninitialized layer.
"""
views: list[TransformerCacheView | None]
[docs] @classmethod
def init_cache(
cls,
mesh: Mesh,
metadata: TransformerCacheMetaData,
partition_manager: PartitionManager,
dtype: jnp.dtype | None = None,
starts: Int[JAXArray, "batch"] | None = None, # noqa: F821
quantizer: EasyQuantizer | None = None,
mask_type_details: dict[int, AttnMaskDetail] | None = None,
):
from easydel.layers.quantization.quantizers import EasyQuantizer
quantizer = quantizer or EasyQuantizer(quantization_config=None)
if dtype is None:
dtype = jnp.bfloat16
with mesh:
return cls(
views=[
# i have to somehow fix my OCD
TransformerCacheView.init(
mesh=mesh,
dtype=dtype,
starts=starts,
metadata=metadata,
quantizer=quantizer,
layer_index=layer_index,
masking_details=mask_type_details.get(layer_index) if mask_type_details is not None else None,
partition_manager=partition_manager,
)
for layer_index in range(metadata.num_hidden_layers)
]
)
[docs] def to_pure(self) -> tuple[list[list[JAXArray | ImplicitArray]], TransformerCacheMetaData]:
"""Convert cache to pure Python data structure for serialization.
Extracts raw tensors and metadata for checkpointing or transfer.
The pure representation can be pickled or saved to disk.
Returns:
tuple: Pair of (cache_data, metadata) where:
- cache_data: List of [key, value, indexs, starts] per layer
- metadata: Cache configuration metadata
"""
return (
[[layer.key, layer.value, layer.indexs, layer.starts] for i, layer in enumerate(self.views)],
self.views[-1].metadata,
)
[docs] @classmethod
def from_pure(
cls, pure: list[list[JAXArray | ImplicitArray]], metadata: TransformerCacheMetaData
) -> TransformerCache:
"""Reconstruct cache from pure Python data structure.
Restores a cache from serialized tensors and metadata,
typically after loading from disk or receiving from transfer.
Args:
pure: List of [key, value, indexs, starts] per layer.
metadata: Cache configuration metadata.
Returns:
TransformerCache: Reconstructed cache instance.
"""
return cls(
views=[
TransformerCacheView(
key=layer[0],
value=layer[1],
indexs=layer[2],
starts=layer[3],
metadata=metadata,
)
for layer in pure
]
)
[docs] def insert_starts(
self, starts: Int[JAXArray, "..."], slot: int, partition_manager: PartitionManager
) -> TransformerCache:
"""Insert starting positions at specified batch slot.
Updates the starting position indices for a specific batch slot
across all layers. Used for dynamic batching and cache management.
Args:
starts: New starting positions to insert.
slot (int): Batch slot index to update.
partition_manager (PartitionManager): Sharding configuration.
Returns:
TransformerCache: Updated cache instance.
"""
for idx in range(len(self.views)):
view = self.views[idx]
starts = jnp.array(starts).reshape(-1)
self.views[idx] = self.views[idx].replace(
starts=apply_logical_sharding(
lax.dynamic_update_slice_in_dim(view.starts, starts, slot, 0),
axes=[BATCH],
mode=MODE_PREFILL,
partition_manager=partition_manager,
)
)
return self
[docs] def insert_index(
self, index: Int[JAXArray, "..."], slot: int, partition_manager: PartitionManager
) -> TransformerCache:
"""Insert position indices at specified batch slot.
Updates the current position indices for a specific batch slot
across all layers. Used for tracking generation progress.
Args:
index: New position index to insert.
slot (int): Batch slot index to update.
partition_manager (PartitionManager): Sharding configuration.
Returns:
TransformerCache: Updated cache instance.
"""
for idx in range(len(self.views)):
view = self.views[idx]
index = jnp.array(index).reshape(-1)
self.views[idx] = self.views[idx].replace(
indexs=apply_logical_sharding(
lax.dynamic_update_slice_in_dim(view.indexs, index, slot, 0),
axes=[BATCH],
mode=MODE_PREFILL,
partition_manager=partition_manager,
)
)
return self
[docs] def insert(
self,
other: TransformerCache,
slot: int,
quantizer: EasyQuantizer,
partition_manager: PartitionManager,
):
"""Insert another cache's contents at specified batch slot.
Copies key-value states and indices from another cache into
this cache at the specified batch position. Useful for
batched generation with different sequences.
Args:
other (TransformerCache): Source cache to copy from.
slot (int): Batch slot index to insert into.
quantizer (EasyQuantizer): Quantization configuration.
partition_manager (PartitionManager): Sharding configuration.
Returns:
TransformerCache: Updated cache instance.
"""
def _maybe_materialize(x: ImplicitArray | JAXArray) -> JAXArray:
if hasattr(x, "materialize"):
x = x.materialize()
return x
for idx in range(len(self.views)):
view = self.views[idx]
oview = other.views[idx]
new_val = lax.dynamic_update_slice(
_maybe_materialize(view.value),
_maybe_materialize(oview.value.astype(view.value.dtype)),
(slot, 0, 0, 0),
)
new_key = lax.dynamic_update_slice(
_maybe_materialize(view.key),
_maybe_materialize(oview.key.astype(view.key.dtype)),
(slot, 0, 0, 0),
)
self.views[idx] = self.views[idx].replace(
key=quantizer(
apply_logical_sharding(
new_key,
axes=[BATCH, KV_LENGTH, KV_HEAD, KV_HEAD_DIM],
mode=MODE_PREFILL,
partition_manager=partition_manager,
)
),
value=quantizer(
apply_logical_sharding(
new_val,
axes=[BATCH, KV_LENGTH, KV_HEAD, KV_HEAD_DIM],
mode=MODE_PREFILL,
partition_manager=partition_manager,
)
),
indexs=apply_logical_sharding(
lax.dynamic_update_slice_in_dim(view.indexs, oview.indexs, slot, 0),
axes=[BATCH],
mode=MODE_PREFILL,
partition_manager=partition_manager,
),
starts=apply_logical_sharding(
lax.dynamic_update_slice_in_dim(view.starts, oview.starts, slot, 0),
axes=[BATCH],
mode=MODE_PREFILL,
partition_manager=partition_manager,
),
metadata=view.metadata,
)
return self
[docs] @classmethod
def init_empty(cls, num_hidden_layers: int) -> TransformerCache:
return cls(views=[None for _ in range(num_hidden_layers)])
def __repr__(self):
return f"{self.__class__.__name__}(\n " + "\n ".join(str(view) for view in self.views) + "\n)"
__str__ = __repr__
[docs]@auto_pytree
class TransformerMetadata(BaseRunTimeMetadata):
"""Runtime metadata for transformer cache operations.
Holds dynamic information needed during cache updates that isn't
part of the permanent cache state. This includes temporary indices
and flags for specific computation modes.
Attributes:
postpadded (bool): Whether sequences are post-padded.
Affects mask generation and position calculations.
starts (jax.Array | None): Starting positions for sequences.
Used for relative position calculations.
indexs (jax.Array | None): Current position indices.
Tracks generation progress per sequence.
"""
postpadded: bool = field(pytree_node=False, default=False)
starts: Int[JAXArray, "batch"] | None = None # noqa: F821
indexs: Int[JAXArray, "batch"] | None = None # noqa: F821