easydel.layers.caching.transformer.cache
Transformer key-value caching implementation for EasyDeL.
This module provides the standard key-value caching system for transformer models, supporting various attention patterns including full attention, sliding window attention, and local attention.
The transformer cache is designed for efficient autoregressive generation by storing previously computed key and value states, avoiding redundant computation during inference.
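The saving comes from reuse: at each decoding step only the new token's key and value projections are computed and written into the cache, while earlier positions are read back unchanged. The minimal pure-Python sketch below shows that write-once, functional update pattern. `ToyLayerCache` is a hypothetical illustration, not EasyDeL's API. In JAX the equivalent write would use `array.at[...].set(...)`, which returns a new array instead of mutating in place.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyLayerCache:
    """Hypothetical single-layer cache: fixed-length key/value slots plus a write index."""

    keys: tuple  # one entry per position; None marks an empty slot
    values: tuple
    index: int = 0  # next position to write

    @classmethod
    def empty(cls, max_len: int) -> "ToyLayerCache":
        return cls(keys=(None,) * max_len, values=(None,) * max_len)

    def append(self, key, value) -> "ToyLayerCache":
        """Return a NEW cache with (key, value) stored at `index`; `self` is left untouched."""
        if self.index >= len(self.keys):
            raise ValueError("cache is full")
        i = self.index
        keys = self.keys[:i] + (key,) + self.keys[i + 1:]
        values = self.values[:i] + (value,) + self.values[i + 1:]
        return replace(self, keys=keys, values=values, index=i + 1)


# Decode loop: each step computes K/V for the new token only and rebinds the cache.
cache = ToyLayerCache.empty(max_len=4)
for k, v in [("k0", "v0"), ("k1", "v1")]:
    cache = cache.append(k, v)
```

Rebinding `cache` on every step mirrors how a JAX program threads an immutable cache through a jitted decode loop.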
- Key Components:
TransformerCacheMetaData: Configuration for cache dimensions and behavior
TransformerCacheView: Per-layer cache storage and update logic
TransformerCache: Multi-layer cache orchestration
TransformerMetadata: Runtime metadata for cache operations
AttnMaskDetail: Attention masking configuration
- Features:
Support for multiple attention patterns (full, sliding, local)
Quantization support for memory efficiency
Distributed caching with JAX sharding
Functional cache updates for JAX compatibility
Dynamic mask generation and caching
Example
>>> # Initialize cache metadata
>>> metadata = TransformerCacheMetaData.create(
...     batch_size=2,
...     sequence_length=1024,
...     num_hidden_layers=12,
...     pad_token_id=0,
...     num_heads=16,
...     head_dim=64,
... )
>>>
>>> # Create cache (mesh is a JAX Mesh, pm a PartitionManager)
>>> cache = TransformerCache.init_cache(
...     mesh=mesh,
...     metadata=metadata,
...     partition_manager=pm,
...     dtype=jnp.bfloat16,
... )
>>>
>>> # Update cache during autoregressive decoding
>>> # (query_states, key_states, value_states, quantizer and mask_info come from the model)
>>> for layer_idx in range(12):
...     key_cache, value_cache, layer_mask_info, new_view, masking_details = cache[layer_idx].concatenate_to_cache(
...         query=query_states,
...         key=key_states,
...         value=value_states,
...         mode="__autoregressive__",
...         quantizer=quantizer,
...         cache_metadata=None,
...         mask_info=mask_info,
...         partition_manager=pm,
...     )
...     cache[layer_idx] = new_view
- class easydel.layers.caching.transformer.cache.AttnMaskDetail(mask_type: Enum, size: int, offset: int | None = None, chunks: int | None = None, bricks: int | None = None)
Bases: object
Configuration for attention masking patterns.
Defines the type and parameters of attention masking to apply during cache operations. Supports various masking strategies including sliding windows, chunks, and custom patterns.
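To illustrate how `mask_type` and `size` might combine, the sketch below assumes that `size` is the window length for a sliding mask and the chunk length for a chunked mask, both applied on top of causal masking. `MaskType`, `allowed`, and `mask_matrix` are hypothetical names; this page does not list the real enum members or their exact semantics.

```python
from enum import Enum
from typing import List


class MaskType(Enum):  # hypothetical members; the real enum is not listed on this page
    FULL = "full"
    SLIDING = "sliding"
    CHUNKED = "chunked"


def allowed(mask_type: MaskType, size: int, q: int, k: int) -> bool:
    """Whether query position q may attend to key position k (always causal)."""
    if k > q:  # never attend to future positions
        return False
    if mask_type is MaskType.FULL:
        return True
    if mask_type is MaskType.SLIDING:
        return q - k < size  # the last `size` positions, including q itself
    if mask_type is MaskType.CHUNKED:
        return q // size == k // size  # only keys in the same chunk of length `size`
    raise ValueError(f"unsupported mask type: {mask_type}")


def mask_matrix(mask_type: MaskType, size: int, length: int) -> List[List[int]]:
    """Dense [length, length] 0/1 mask; row = query position, column = key position."""
    return [[int(allowed(mask_type, size, q, k)) for k in range(length)] for q in range(length)]
```

For example, `mask_matrix(MaskType.SLIDING, 2, 4)` lets each query see only itself and the immediately preceding position.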
- mask_type
Type of attention mask (e.g., FULL, SLIDING, CHUNKED).
- Type
Enum
- size
Primary size parameter for the mask (window size, chunk size, etc.).
- Type
int
- offset
Optional offset for mask positioning.
- Type
int | None
- chunks
Number of chunks for chunked attention.
- Type
int | None
- bricks
Number of bricks for blocked attention patterns.
- Type
int | None
- classmethod from_dict(data: dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- mask_type: Enum
- replace(**kwargs)
Creates a new instance with the specified fields replaced.
- size: int
- to_dict() → dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
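The `to_dict`/`from_dict`, `to_json`/`from_json`, and `replace` helpers listed above are meant to round-trip. The sketch below mimics that contract for a frozen dataclass with the same five fields as AttnMaskDetail. `MaskDetailSketch` and its `MaskType` enum are hypothetical. Encoding the enum by its value is only one way to make such a field JSON-safe and is not a description of EasyDeL's PyTree serializer.

```python
import dataclasses
import json
from enum import Enum
from typing import Optional


class MaskType(Enum):  # hypothetical stand-in for the real mask enum
    FULL = "full"
    SLIDING = "sliding"


@dataclasses.dataclass(frozen=True)
class MaskDetailSketch:
    """Hypothetical stand-in mirroring AttnMaskDetail's fields."""

    mask_type: MaskType
    size: int
    offset: Optional[int] = None
    chunks: Optional[int] = None
    bricks: Optional[int] = None

    def replace(self, **kwargs) -> "MaskDetailSketch":
        return dataclasses.replace(self, **kwargs)  # new instance; self is unchanged

    def to_dict(self) -> dict:
        data = dataclasses.asdict(self)
        data["mask_type"] = self.mask_type.value  # Enum members are not JSON-serializable
        return data

    @classmethod
    def from_dict(cls, data: dict) -> "MaskDetailSketch":
        return cls(**{**data, "mask_type": MaskType(data["mask_type"])})

    def to_json(self, **kwargs) -> str:
        return json.dumps(self.to_dict(), **kwargs)

    @classmethod
    def from_json(cls, json_str: str) -> "MaskDetailSketch":
        return cls.from_dict(json.loads(json_str))


detail = MaskDetailSketch(MaskType.SLIDING, size=4096)
wider = detail.replace(size=8192)
```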
- class easydel.layers.caching.transformer.cache.TransformerCache(views: list[easydel.layers.caching.transformer.cache.TransformerCacheView | None])
Bases: BaseCache
Multi-layer transformer cache container.
Orchestrates cache views across all transformer layers, providing methods for initialization, access, and batch operations. Supports serialization for checkpointing and cache transfer.
The cache maintains:
- an ordered list of per-layer cache views
- consistent configuration across layers
- batch update operations
- serialization/deserialization support
- views
Cache views for each layer. None indicates an uninitialized layer.
- Type
list[TransformerCacheView | None]
- classmethod from_dict(data: dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- classmethod from_pure(pure: list[list[jax.Array | eformer.jaximus._imus.ImplicitArray]], metadata: TransformerCacheMetaData) → TransformerCache
Reconstructs a cache from a pure Python data structure.
Restores a cache from serialized tensors and metadata, typically after loading from disk or receiving it from a transfer.
- Parameters
pure – One [key, value, indexs, starts] list per layer, as produced by to_pure().
metadata – Cache configuration metadata.
- Returns
Reconstructed cache instance.
- Return type
TransformerCache
- classmethod init_cache(mesh: Mesh, metadata: TransformerCacheMetaData, partition_manager: PartitionManager, dtype: numpy.dtype | None = None, starts: jaxtyping.Int[Array, 'batch'] | None = None, quantizer: object | None = None, mask_type_details: dict[int, easydel.layers.caching.transformer.cache.AttnMaskDetail] | None = None)
Initialize a complete cache with views for all layers.
This factory method creates a fully initialized cache with allocated storage for all layers. It is the primary way to create a cache for inference, setting up every view with a consistent configuration.
The initialization process:
1. Validates the metadata configuration.
2. Determines the resource allocation strategy.
3. Creates a view for each layer.
4. Applies sharding and quantization.
5. Returns a ready-to-use cache.
- Parameters
mesh (Mesh) – JAX device mesh for distributed execution.
metadata (TransformerCacheMetaData) – Configuration metadata defining cache dimensions, number of layers, and behavior.
partition_manager (PartitionManager) – Sharding configuration.
dtype (numpy.dtype | None) – Data type for cache tensors.
starts (Int[Array, 'batch'] | None) – Optional initial starting positions per sequence.
quantizer (object | None) – Optional quantization configuration.
mask_type_details (dict[int, AttnMaskDetail] | None) – Optional attention mask configuration per layer, keyed by layer index.
- Returns
A fully initialized cache with views for all layers, ready for use in model inference.
- Return type
TransformerCache
- Raises
ValueError – If the metadata is incompatible with the cache type.
MemoryError – If there is insufficient memory for allocation.
RuntimeError – If the device/sharding configuration fails.
Example
>>> cache = TransformerCache.init_cache(
...     metadata=metadata,
...     mesh=mesh,
...     dtype=jnp.bfloat16,
...     partition_manager=pm,
... )
>>> print(f"Initialized cache with {len(cache)} layers")
- classmethod init_empty(num_hidden_layers: int) → TransformerCache
Initialize an empty cache container without allocated storage.
Creates a cache structure with placeholder views that can be populated later. This is useful for:
- gradual cache building during training
- memory-efficient initialization
- dynamic cache allocation
- testing and debugging
The empty cache has the correct structure but no allocated tensors, allowing the shape and configuration to be determined dynamically.
- Parameters
num_hidden_layers (int) – Number of layers to create placeholder (None) views for.
- Returns
A cache instance with uninitialized (None) views. Views must be populated before use.
- Return type
TransformerCache
Example
>>> cache = TransformerCache.init_empty(num_hidden_layers=12)
>>> # Populate views gradually
>>> for i in range(12):
...     cache[i] = TransformerCacheView.init(...)
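The placeholder pattern described for init_empty can be modeled as a plain list with one slot per layer. The sketch below is a hypothetical stand-in (`ToyMultiLayerCache` is not an EasyDeL class). It only shows the container semantics the example relies on: indexing by layer and filling `None` slots one at a time.

```python
from dataclasses import dataclass, field
from typing import Any, List, Optional


@dataclass
class ToyMultiLayerCache:
    """Hypothetical container with one view per layer; None means 'not allocated yet'."""

    views: List[Optional[Any]] = field(default_factory=list)

    @classmethod
    def init_empty(cls, num_hidden_layers: int) -> "ToyMultiLayerCache":
        return cls(views=[None] * num_hidden_layers)

    def __len__(self) -> int:
        return len(self.views)

    def __getitem__(self, layer_idx: int) -> Optional[Any]:
        return self.views[layer_idx]

    def __setitem__(self, layer_idx: int, view: Any) -> None:
        self.views[layer_idx] = view

    def is_ready(self) -> bool:
        """True once every layer has a populated view."""
        return all(view is not None for view in self.views)


cache = ToyMultiLayerCache.init_empty(num_hidden_layers=3)
cache[0] = {"layer": 0}  # populate views gradually
```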
- insert(other: TransformerCache, slot: int, quantizer: object, partition_manager: PartitionManager)
Insert another cache’s contents at the specified batch slot.
Copies key-value states and indices from another cache into this cache at the specified batch position. Useful for batched generation with different sequences.
- Parameters
other (TransformerCache) – Source cache to copy from.
slot (int) – Batch slot index to insert into.
quantizer (EasyQuantizer) – Quantization configuration.
partition_manager (PartitionManager) – Sharding configuration.
- Returns
Updated cache instance.
- Return type
TransformerCache
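Per layer, inserting into a batch slot amounts to overwriting one row of each batched array. The hypothetical helper below models a layer's key, value, indexs, and starts as Python lists. It copies row 0 of a single-sequence source into row `slot` of the destination and returns a new dict. It ignores the quantizer and partition_manager arguments of the real method. In JAX the row write would typically be expressed as `x.at[slot].set(row)`.

```python
from copy import deepcopy

FIELDS = ("key", "value", "indexs", "starts")


def insert_at_slot(dst: dict, src: dict, slot: int) -> dict:
    """Return a copy of `dst` whose row `slot` holds row 0 of `src` for every field.

    Each dict maps 'key', 'value', 'indexs', 'starts' to a per-sequence list;
    `src` is assumed to hold exactly one sequence (e.g. a freshly prefilled prompt).
    """
    if not 0 <= slot < len(dst["indexs"]):
        raise IndexError(f"slot {slot} is out of range")
    out = deepcopy(dst)  # functional update: the destination is left untouched
    for name in FIELDS:
        out[name][slot] = deepcopy(src[name][0])
    return out


batch = {"key": [[0, 0, 0], [0, 0, 0]], "value": [[0, 0, 0], [0, 0, 0]], "indexs": [0, 0], "starts": [0, 0]}
prefilled = {"key": [[7, 8, 0]], "value": [[5, 6, 0]], "indexs": [2], "starts": [0]}
updated = insert_at_slot(batch, prefilled, slot=1)
```

By the same reasoning, insert_index and insert_starts presumably correspond to writing only the indexs or starts row of that slot.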
- insert_index(index: Int[Array, '...'], slot: int, partition_manager: PartitionManager) → TransformerCache
Insert position indices at specified batch slot.
Updates the current position indices for a specific batch slot across all layers. Used for tracking generation progress.
- Parameters
index – New position index to insert.
slot (int) – Batch slot index to update.
partition_manager (PartitionManager) – Sharding configuration.
- Returns
Updated cache instance.
- Return type
TransformerCache
- insert_starts(starts: Int[Array, '...'], slot: int, partition_manager: PartitionManager) → TransformerCache
Insert starting positions at specified batch slot.
Updates the starting position indices for a specific batch slot across all layers. Used for dynamic batching and cache management.
- Parameters
starts – New starting positions to insert.
slot (int) – Batch slot index to update.
partition_manager (PartitionManager) – Sharding configuration.
- Returns
Updated cache instance.
- Return type
TransformerCache
- replace(**kwargs)
Creates a new instance with the specified fields replaced.
- to_dict() → dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
- to_pure() → tuple[list[list[jax.Array | eformer.jaximus._imus.ImplicitArray]], easydel.layers.caching.transformer.cache.TransformerCacheMetaData]
Convert cache to pure Python data structure for serialization.
Extracts raw tensors and metadata for checkpointing or transfer. The pure representation can be pickled or saved to disk.
- Returns
Pair of (cache_data, metadata) where:
- cache_data: list of [key, value, indexs, starts] per layer
- metadata: cache configuration metadata
- Return type
tuple
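Because the pure form is just nested per-layer lists plus a metadata object, a round trip reduces to flattening each view into [key, value, indexs, starts] and rebuilding the views in the same order. The sketch below uses hypothetical stand-ins (`ViewSketch`, `MetaSketch`) and plain lists in place of JAX arrays. It pickles only the per-layer data.

```python
import pickle
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class MetaSketch:  # hypothetical stand-in for TransformerCacheMetaData
    num_hidden_layers: int


@dataclass(frozen=True)
class ViewSketch:  # hypothetical stand-in for TransformerCacheView
    key: list
    value: list
    indexs: list
    starts: list


def to_pure(views: List[ViewSketch], metadata: MetaSketch):
    """Return (cache_data, metadata), with one [key, value, indexs, starts] list per layer."""
    return [[v.key, v.value, v.indexs, v.starts] for v in views], metadata


def from_pure(pure: list, metadata: MetaSketch) -> List[ViewSketch]:
    """Rebuild views from the layout produced by to_pure."""
    if len(pure) != metadata.num_hidden_layers:
        raise ValueError("number of layers does not match the metadata")
    return [ViewSketch(*layer) for layer in pure]


views = [ViewSketch(key=[[1.0]], value=[[2.0]], indexs=[1], starts=[0]) for _ in range(2)]
cache_data, meta = to_pure(views, MetaSketch(num_hidden_layers=2))
restored = from_pure(pickle.loads(pickle.dumps(cache_data)), meta)
```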
- class easydel.layers.caching.transformer.cache.TransformerCacheMetaData(batch_size: int, sequence_length: int, num_hidden_layers: int, pad_token_id: int, num_heads: int | None, head_dim: int | None, key_heads: int | None, value_heads: int | None, key_dim: int | None, value_dim: int | None, sliding_window: int | None, update_causal_mask: bool, create_attention_bias: bool)
Bases: BaseCacheMetadata
Metadata configuration for transformer key-value caching.
Stores all static configuration needed to initialize and operate a transformer cache. Supports various attention head configurations including multi-head, multi-query, and grouped-query attention.
The metadata defines:
- cache dimensions (batch, sequence, layers)
- attention head configuration
- masking and bias settings
- special attention patterns (sliding window)
- batch_size
Number of sequences in the batch.
- Type
int
- sequence_length
Maximum sequence length to cache.
- Type
int
- num_hidden_layers
Number of transformer layers.
- Type
int
- pad_token_id
Token ID used for padding.
- Type
int
- num_heads
Number of attention heads (for regular MHA).
- Type
int | None
- head_dim
Dimension of each attention head.
- Type
int | None
- key_heads
Number of key heads (for MQA/GQA).
- Type
int | None
- value_heads
Number of value heads (for MQA/GQA).
- Type
int | None
- key_dim
Dimension of key projections.
- Type
int | None
- value_dim
Dimension of value projections.
- Type
int | None
- sliding_window
Size of the sliding attention window.
- Type
int | None
- update_causal_mask
Whether to update causal masks dynamically.
- Type
bool
- create_attention_bias
Whether to create attention bias terms.
- Type
bool
- batch_size: int
- classmethod create(batch_size: int, sequence_length: int, num_hidden_layers: int, pad_token_id: int, num_heads: int | None = None, head_dim: int | None = None, key_heads: int | None = None, value_heads: int | None = None, key_dim: int | None = None, value_dim: int | None = None, update_causal_mask: bool = True, create_attention_bias: bool = True, sliding_window: int | None = None) → TransformerCacheMetaData
Create a TransformerCacheMetaData instance with validation.
- Parameters
batch_size – Number of sequences in the batch.
sequence_length – Maximum sequence length to cache.
num_hidden_layers – Number of transformer layers.
pad_token_id – Token ID used for padding.
num_heads – Number of attention heads.
head_dim – Dimension of each head.
key_heads – Number of key heads.
value_heads – Number of value heads.
key_dim – Dimension of keys.
value_dim – Dimension of values.
update_causal_mask – Whether to update the causal mask.
create_attention_bias – Whether to create attention bias.
sliding_window – Size of the sliding attention window, if any.
- Returns
TransformerCacheMetaData instance
- Raises
ValueError – If required parameters are missing or invalid.
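The attributes above cover plain multi-head attention (num_heads/head_dim) as well as MQA/GQA (separate key_heads, value_heads, key_dim, value_dim). The sketch below shows one plausible way to resolve those fields into the [batch, seq_len, heads, dim] cache shapes documented for TransformerCacheView. `resolve_kv_shapes` is hypothetical. Both the fallback from key/value-specific fields to num_heads/head_dim and the ValueError checks are assumptions, not create()'s confirmed logic.

```python
from typing import Optional, Tuple

Shape = Tuple[int, int, int, int]


def resolve_kv_shapes(
    batch_size: int,
    sequence_length: int,
    num_heads: Optional[int] = None,
    head_dim: Optional[int] = None,
    key_heads: Optional[int] = None,
    value_heads: Optional[int] = None,
    key_dim: Optional[int] = None,
    value_dim: Optional[int] = None,
) -> Tuple[Shape, Shape]:
    """Return (key_shape, value_shape), each laid out as [batch, seq_len, heads, dim].

    Assumption: a key/value-specific field left as None falls back to the shared
    num_heads / head_dim, which reproduces ordinary multi-head attention.
    """
    resolved = {
        "key_heads": key_heads if key_heads is not None else num_heads,
        "value_heads": value_heads if value_heads is not None else num_heads,
        "key_dim": key_dim if key_dim is not None else head_dim,
        "value_dim": value_dim if value_dim is not None else head_dim,
    }
    for name, size in resolved.items():
        if size is None or size <= 0:
            raise ValueError(f"{name} is missing or not positive")
    return (
        (batch_size, sequence_length, resolved["key_heads"], resolved["key_dim"]),
        (batch_size, sequence_length, resolved["value_heads"], resolved["value_dim"]),
    )


mha = resolve_kv_shapes(2, 1024, num_heads=16, head_dim=64)
gqa = resolve_kv_shapes(2, 1024, num_heads=32, head_dim=128, key_heads=8, value_heads=8)
```

In the GQA call, storing 8 key/value heads instead of 32 makes each cached tensor four times smaller.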
- create_attention_bias: bool
- classmethod from_dict(data: dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- num_hidden_layers: int
- pad_token_id: int
- replace(**kwargs)
Creates a new instance with the specified fields replaced.
- sequence_length: int
- to_dict() → dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
- update_causal_mask: bool
- class easydel.layers.caching.transformer.cache.TransformerCacheView(key: jaxtyping.Float[Array, 'batch seq_len num_key_heads key_dim'] | eformer.jaximus._imus.ImplicitArray, value: jaxtyping.Float[Array, 'batch seq_len num_value_heads value_dim'] | eformer.jaximus._imus.ImplicitArray, indexs: jaxtyping.Int[Array, 'batch'] | eformer.jaximus._imus.ImplicitArray, starts: jaxtyping.Int[Array, 'batch'] | eformer.jaximus._imus.ImplicitArray, metadata: TransformerCacheMetaData, maximum_sequence_length: int, layer_index: int | None = None, masking_details: easydel.layers.caching.transformer.cache.AttnMaskDetail | None = None)
Bases: BaseCacheView
Single-layer cache view for transformer key-value states.
Manages the cached key and value tensors for one transformer layer, along with position tracking and masking information. Supports various attention patterns and quantization strategies.
The view maintains:
- key and value state tensors
- current position indices for each sequence
- starting positions for relative indexing
- masking configuration for attention patterns
- key
Cached key states. Shape: [batch_size, seq_length, num_key_heads, key_dim]
- Type
cx.Array | ImplicitArray
- value
Cached value states. Shape: [batch_size, seq_length, num_value_heads, value_dim]
- Type
cx.Array | ImplicitArray
- indexs
Current position index per sequence. Shape: [batch_size]
- Type
cx.Array | ImplicitArray
- starts
Starting position per sequence. Shape: [batch_size]
- Type
cx.Array | ImplicitArray
- metadata
Static cache configuration.
- Type
TransformerCacheMetaData
- maximum_sequence_length
Maximum cacheable sequence length.
- Type
int
- layer_index
Index of this layer in the model.
- Type
int | None
- masking_details
Attention mask configuration.
- Type
AttnMaskDetail | None
- concatenate_to_cache(query: Float[Array, 'batch query_len num_heads head_dim'], key: Float[Array, 'batch query_len num_key_heads key_dim'], value: Float[Array, 'batch query_len num_value_heads value_dim'], mode: Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], quantizer: object, cache_metadata: easydel.layers.caching.transformer.cache.TransformerMetadata | None, mask_info: MaskInfo, partition_manager: PartitionManager) → tuple[jaxtyping.Float[Array, 'batch seq_len num_key_heads key_dim'], jaxtyping.Float[Array, 'batch seq_len num_value_heads value_dim'], ejkernel.types.mask.MaskInfo, easydel.layers.caching.transformer.cache.TransformerCacheView, easydel.layers.caching.transformer.cache.AttnMaskDetail | None]
Updates the KV cache functionally and returns the updated tensors, the attention mask information to use, and the updated view.
- Parameters
query – Current query states.
key – Current key states to add to the cache.
value – Current value states to add to the cache.
mode – Operation mode: one of '__autoregressive__', '__prefill__', '__train__', or '__insert__'.
quantizer – Quantizer for the cache.
cache_metadata – Optional runtime metadata (TransformerMetadata) carrying the postpadded flag and optional starts/indexs.
mask_info – Base attention mask information (MaskInfo).
partition_manager (PartitionManager) – Sharding configuration.
- Returns
Updated key cache tensor (functional update).
Updated value cache tensor (functional update).
Attention mask information to use (MaskInfo).
Updated TransformerCacheView for this layer.
The layer's masking details (AttnMaskDetail), or None.
- Return type
tuple[Array, Array, MaskInfo, TransformerCacheView, AttnMaskDetail | None]
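In the '__autoregressive__' mode, the core of such an update is writing one new key/value per sequence at that sequence's current index and then advancing the index. The pure-Python sketch below illustrates this for keys only; values would be handled identically. `write_step` is hypothetical. It assumes that `starts` marks the first valid position, so the returned mask covers starts[b] through the newly written slot. That is an assumption about how starts and indexs interact, not documented behavior.

```python
from typing import Any, List, Tuple


def write_step(
    keys: List[List[Any]], indexs: List[int], starts: List[int], new_keys: List[Any]
) -> Tuple[List[List[Any]], List[int], List[List[bool]]]:
    """Write one new key per sequence at its current index.

    Returns (updated keys, advanced indexs, validity mask); the inputs are not mutated.
    """
    out_keys, out_indexs, mask = [], [], []
    for row, idx, start, new in zip(keys, indexs, starts, new_keys):
        if idx >= len(row):
            raise ValueError("sequence has reached the cache length")
        row = list(row)  # copy so the caller's cache is unchanged
        row[idx] = new
        out_keys.append(row)
        out_indexs.append(idx + 1)
        # A position is visible if it lies between the sequence start and the new token.
        mask.append([start <= pos <= idx for pos in range(len(row))])
    return out_keys, out_indexs, mask


cache_keys = [[None] * 4, [None] * 4]
new_cache, new_indexs, visible = write_step(cache_keys, indexs=[0, 2], starts=[0, 1], new_keys=["a", "b"])
```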
- classmethod from_dict(data: dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- indexs: jaxtyping.Int[Array, 'batch'] | eformer.jaximus._imus.ImplicitArray
- classmethod init(mesh: Mesh, dtype: dtype, metadata: TransformerCacheMetaData, quantizer: object, partition_manager: PartitionManager, starts: jaxtyping.Int[Array, 'batch'] | None = None, layer_index: int | None = None, masking_details: easydel.layers.caching.transformer.cache.AttnMaskDetail | None = None)
Initialize a transformer cache view for a single layer.
Creates and allocates cache tensors with appropriate shapes, dtypes, and sharding for distributed execution. Applies quantization if configured.
- Parameters
mesh (Mesh) – JAX device mesh for distributed execution.
dtype (jnp.dtype) – Data type for cache tensors.
metadata (TransformerCacheMetaData) – Cache configuration.
quantizer (EasyQuantizer) – Quantization configuration.
partition_manager (PartitionManager) – Sharding strategy manager.
starts (jax.Array | None) – Initial starting positions per sequence. Defaults to zeros if not provided.
layer_index (int | None) – Index of this layer in the model.
masking_details (AttnMaskDetail | None) – Attention mask configuration.
- Returns
Initialized cache view with allocated tensors.
- Return type
TransformerCacheView
Note
For sliding window attention, cache dimensions are adjusted based on the window size specified in masking_details.
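One common way to bound memory for sliding-window attention is a ring buffer whose length equals the window. Absolute position p is written to slot p mod window, so the oldest entries are overwritten. This page does not say whether TransformerCacheView uses exactly this layout, so the helpers below (`ring_write`, `ring_visible`) are a hypothetical illustration only.

```python
from typing import Any, List


def ring_write(buffer: List[Any], position: int, item: Any) -> List[Any]:
    """Return a copy of `buffer` with `item` stored for absolute `position`."""
    out = list(buffer)
    out[position % len(out)] = item
    return out


def ring_visible(buffer: List[Any], position: int) -> List[Any]:
    """Items for the last len(buffer) positions ending at `position`, oldest first."""
    window = len(buffer)
    first = max(0, position - window + 1)
    return [buffer[pos % window] for pos in range(first, position + 1)]


window_cache = [None] * 3
for pos, token in enumerate("abcde"):
    window_cache = ring_write(window_cache, pos, token)
```

After five writes into a window of three, positions 3 and 4 have overwritten slots 0 and 1, while position 2 still occupies slot 2.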
- property is_empty: bool
- key: jaxtyping.Float[Array, 'batch seq_len num_key_heads key_dim'] | eformer.jaximus._imus.ImplicitArray
- masking_details: easydel.layers.caching.transformer.cache.AttnMaskDetail | None = None
- maximum_sequence_length: int
- metadata: TransformerCacheMetaData
- replace(**kwargs)
Creates a new instance with the specified fields replaced.
- starts: jaxtyping.Int[Array, 'batch'] | eformer.jaximus._imus.ImplicitArray
- to_dict() → dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
- value: jaxtyping.Float[Array, 'batch seq_len num_value_heads value_dim'] | eformer.jaximus._imus.ImplicitArray
- class easydel.layers.caching.transformer.cache.TransformerMetadata(postpadded: bool = False, starts: jaxtyping.Int[Array, 'batch'] | None = None, indexs: jaxtyping.Int[Array, 'batch'] | None = None)
Bases: BaseRunTimeMetadata
Runtime metadata for transformer cache operations.
Holds dynamic information needed during cache updates that isn’t part of the permanent cache state. This includes temporary indices and flags for specific computation modes.
- postpadded
Whether sequences are post-padded. Affects mask generation and position calculations.
- Type
bool
- starts
Starting positions for sequences. Used for relative position calculations.
- Type
jax.Array | None
- indexs
Current position indices for sequences.
- Type
jax.Array | None
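The postpadded flag matters because the padding side changes both the mask and where each sequence's real tokens begin. The sketch below contrasts the two layouts. `padding_mask` and `first_real_token` are hypothetical helpers, and reading `starts` as the index of the first real token under pre-padding is an assumption.

```python
from typing import List


def padding_mask(lengths: List[int], max_len: int, postpadded: bool) -> List[List[int]]:
    """1 marks a real token, 0 marks padding (appended if postpadded, else prepended)."""
    rows = []
    for n in lengths:
        real, pad = [1] * n, [0] * (max_len - n)
        rows.append(real + pad if postpadded else pad + real)
    return rows


def first_real_token(mask: List[List[int]]) -> List[int]:
    """Index of the first real token in each row (len(row) if the row is all padding)."""
    return [row.index(1) if 1 in row else len(row) for row in mask]


post = padding_mask([2, 4], max_len=4, postpadded=True)
pre = padding_mask([2, 4], max_len=4, postpadded=False)
```

With post-padding every sequence begins at position 0 and only the next write position varies; with pre-padding the first real token varies per sequence while all sequences end at the same position.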
- classmethod from_dict(data: dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- postpadded: bool = False
- replace(**kwargs)
Creates a new instance with the specified fields replaced.
- to_dict() → dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.