easydel.layers.caching.transformer.cache

Transformer key-value caching implementation for EasyDeL.

This module provides the standard key-value caching system for transformer models, supporting various attention patterns including full attention, sliding window attention, and local attention.

The transformer cache is designed for efficient autoregressive generation by storing previously computed key and value states, avoiding redundant computation during inference.
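To show where the saving comes from, the following framework-free Python sketch (not EasyDeL code) counts key/value projections for a decoder that caches and for one that recomputes the whole prefix at every step. `TinyKVCache`, `project_k` and `project_v` are hypothetical names, and floats stand in for the JAX arrays of shape [batch, seq_len, heads, dim] that a real cache stores.

```python
# Framework-free sketch of a preallocated KV cache for ONE layer and ONE
# sequence. Keys/values are floats standing in for vectors; project_k and
# project_v are hypothetical stand-ins for the model's K/V projections.


def project_k(token: int) -> float:
    return token * 0.5  # hypothetical key projection


def project_v(token: int) -> float:
    return token * 2.0  # hypothetical value projection


class TinyKVCache:
    def __init__(self, max_len: int):
        self.keys = [0.0] * max_len    # preallocated, like sequence_length
        self.values = [0.0] * max_len
        self.index = 0                 # next write position (cf. `indexs`)
        self.projections = 0           # number of K/V projections performed

    def append(self, token: int) -> None:
        if self.index >= len(self.keys):
            raise ValueError("cache is full")
        self.keys[self.index] = project_k(token)
        self.values[self.index] = project_v(token)
        self.index += 1
        self.projections += 1

    def valid(self):
        # only the positions written so far take part in attention
        return self.keys[: self.index], self.values[: self.index]


def decode_with_cache(tokens, max_len):
    cache = TinyKVCache(max_len)
    for t in tokens:      # one new token per decoding step
        cache.append(t)   # only the new token is projected
    return cache.projections


def decode_without_cache(tokens):
    projections = 0
    for step in range(1, len(tokens) + 1):
        # without a cache, every step re-projects the whole prefix
        projections += len(tokens[:step])
    return projections
```

With a cache, decoding n tokens costs n projections instead of n(n+1)/2; the preallocated buffer plus a write index corresponds to the `sequence_length` and `indexs` fields documented below.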

Key Components:
  • TransformerCacheMetaData: Configuration for cache dimensions and behavior

  • TransformerCacheView: Per-layer cache storage and update logic

  • TransformerCache: Multi-layer cache orchestration

  • TransformerMetadata: Runtime metadata for cache operations

  • AttnMaskDetail: Attention masking configuration

Features:
  • Support for multiple attention patterns (full, sliding, local)

  • Quantization support for memory efficiency

  • Distributed caching with JAX sharding

  • Functional cache updates for JAX compatibility

  • Dynamic mask generation and caching

Example

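>>> import jax.numpy as jnp
>>> from easydel.layers.caching.transformer.cache import (
...     TransformerCache,
...     TransformerCacheMetaData,
... )
>>>
>>> # Initialize cache metadata
>>> metadata = TransformerCacheMetaData.create(
...     batch_size=2,
...     sequence_length=1024,
...     num_hidden_layers=12,
...     pad_token_id=0,
...     num_heads=16,
...     head_dim=64,
... )
>>>
>>> # Create the cache; `mesh` is a jax.sharding.Mesh and `pm` a PartitionManager
>>> cache = TransformerCache.init_cache(
...     mesh=mesh,
...     metadata=metadata,
...     partition_manager=pm,
...     dtype=jnp.bfloat16,
... )
>>>
>>> # Update each layer's cache during decoding; the query/key/value states,
>>> # `quantizer` and `mask_info` come from the attention layer
>>> for layer_idx in range(metadata.num_hidden_layers):
...     key_cache, value_cache, layer_mask_info, new_view, mask_detail = cache[layer_idx].concatenate_to_cache(
...         query=query_states,
...         key=key_states,
...         value=value_states,
...         mode="__autoregressive__",
...         quantizer=quantizer,
...         cache_metadata=None,
...         mask_info=mask_info,
...         partition_manager=pm,
...     )
...     cache[layer_idx] = new_view
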
class easydel.layers.caching.transformer.cache.AttnMaskDetail(mask_type: Enum, size: int, offset: int | None = None, chunks: int | None = None, bricks: int | None = None)

Bases: object

Configuration for attention masking patterns.

Defines the type and parameters of attention masking to apply during cache operations. Supports various masking strategies including sliding windows, chunks, and custom patterns.

mask_type

Type of attention mask (e.g., FULL, SLIDING, CHUNKED).

Type

Enum

size

Primary size parameter for the mask (window size, chunk size, etc.).

Type

int

offset

Optional offset for mask positioning.

Type

int | None

chunks

Number of chunks for chunked attention.

Type

int | None

bricks

Number of bricks for blocked attention patterns.

Type

int | None
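This page does not spell out how each mask type is evaluated. As an illustration only, the sketch below assumes common conventions: a sliding window of `size` keeps the last `size` positions, and a chunked mask with chunk length `size` allows attention only within the same chunk, both combined with a causal mask. `MaskType` and `MaskDetail` are hypothetical stand-ins for the real `Enum` and `AttnMaskDetail`; `offset`, `chunks` and `bricks` are carried but ignored.

```python
from __future__ import annotations

from dataclasses import dataclass
from enum import Enum


class MaskType(Enum):
    # hypothetical enum; the real `mask_type` Enum may use other names
    FULL = "full"
    SLIDING = "sliding"
    CHUNKED = "chunked"


@dataclass(frozen=True)
class MaskDetail:
    # mirrors AttnMaskDetail's fields for illustration only
    mask_type: MaskType
    size: int
    offset: int | None = None
    chunks: int | None = None
    bricks: int | None = None


def allowed(q: int, k: int, detail: MaskDetail) -> bool:
    """Can query position q attend to key position k? (causal base mask)"""
    if k > q:  # never attend to future positions
        return False
    if detail.mask_type is MaskType.SLIDING:
        return q - k < detail.size  # keep only the last `size` positions
    if detail.mask_type is MaskType.CHUNKED:
        return q // detail.size == k // detail.size  # same chunk only
    return True  # FULL: plain causal attention


def mask_matrix(n: int, detail: MaskDetail) -> list[list[int]]:
    # row q lists which key positions query q may attend to
    return [[int(allowed(q, k, detail)) for k in range(n)] for q in range(n)]
```

For four positions, a sliding window of 2 yields a banded lower-triangular matrix, while chunks of length 2 yield two independent causal blocks.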

bricks: int | None = None
chunks: int | None = None
classmethod from_dict(data: dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

mask_type: Enum
offset: int | None = None
replace(**kwargs)

Creates a new instance with specified fields replaced.

size: int
to_dict() → dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.

class easydel.layers.caching.transformer.cache.TransformerCache(views: list[easydel.layers.caching.transformer.cache.TransformerCacheView | None])

Bases: BaseCache

Multi-layer transformer cache container.

Orchestrates cache views across all transformer layers, providing methods for initialization, access, and batch operations. Supports serialization for checkpointing and cache transfer.

The cache maintains:
  • An ordered list of per-layer cache views

  • Consistent configuration across layers

  • Batch update operations

  • Serialization/deserialization support

views

Cache views for each layer. None indicates an uninitialized layer.

Type

list[TransformerCacheView | None]

classmethod from_dict(data: dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

classmethod from_pure(pure: list[list[jax.Array | eformer.jaximus._imus.ImplicitArray]], metadata: TransformerCacheMetaData) → TransformerCache

Reconstructs a cache from a pure Python data structure.

Restores a cache from serialized tensors and metadata, typically after loading it from disk or receiving it through a cache transfer.

Parameters
  • pure – List of [key, value, indexs, starts] per layer.

  • metadata – Cache configuration metadata.

Returns

Reconstructed cache instance.

Return type

TransformerCache

classmethod init_cache(mesh: Mesh, metadata: TransformerCacheMetaData, partition_manager: PartitionManager, dtype: numpy.dtype | None = None, starts: jaxtyping.Int[Array, 'batch'] | None = None, quantizer: object | None = None, mask_type_details: dict[int, easydel.layers.caching.transformer.cache.AttnMaskDetail] | None = None)

Initialize a complete cache with views for all layers.

This factory method creates a fully initialized cache with allocated storage for all layers. It’s the primary way to create a cache for inference, setting up all necessary views with consistent configuration.

The initialization process:
  1. Validates the metadata configuration.

  2. Determines the resource allocation strategy.

  3. Creates views for each layer.

  4. Applies sharding and quantization.

  5. Returns a ready-to-use cache.
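The steps listed above can be sketched without JAX. This is a hypothetical outline, not EasyDeL's implementation: `Meta`, `LayerView` and `build_views` are invented names, the per-layer window dictionary loosely plays the role of `mask_type_details`, and sharding and quantization (step 4) are left out. It also assumes, following the note under `TransformerCacheView.init`, that sliding-window layers need fewer cache slots.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class Meta:
    # hypothetical, trimmed-down stand-in for TransformerCacheMetaData
    batch_size: int
    sequence_length: int
    num_hidden_layers: int
    kv_heads: int
    head_dim: int


@dataclass(frozen=True)
class LayerView:
    # hypothetical stand-in for a per-layer view: only its planned shape
    layer_index: int
    shape: tuple[int, int, int, int]  # (batch, cache_len, kv_heads, head_dim)
    window: int | None


def build_views(meta: Meta, window_by_layer: dict[int, int] | None = None) -> list[LayerView]:
    # Step 1: validate the metadata
    if meta.batch_size <= 0 or meta.sequence_length <= 0:
        raise ValueError("batch_size and sequence_length must be positive")
    window_by_layer = window_by_layer or {}
    views = []
    for i in range(meta.num_hidden_layers):  # Step 3: one view per layer
        window = window_by_layer.get(i)
        # Step 2 (assumed): sliding-window layers need at most `window` slots
        cache_len = meta.sequence_length if window is None else min(meta.sequence_length, window)
        views.append(LayerView(i, (meta.batch_size, cache_len, meta.kv_heads, meta.head_dim), window))
    # Step 4 (sharding, quantization) is omitted in this sketch
    return views
```

For example, `build_views(Meta(2, 1024, 4, 8, 64), {1: 256, 3: 256})` plans full-length buffers for layers 0 and 2 and 256-slot buffers for layers 1 and 3.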

Parameters
  • mesh (Mesh) – JAX device mesh for distributed execution.

  • metadata (TransformerCacheMetaData) – Configuration metadata defining cache dimensions, number of layers, and behavior.

  • partition_manager (PartitionManager) – Sharding configuration.

  • dtype (jnp.dtype | None) – Data type for the cache tensors.

  • starts (jax.Array | None) – Optional initial starting positions per sequence.

  • quantizer (EasyQuantizer | None) – Optional quantization configuration.

  • mask_type_details (dict[int, AttnMaskDetail] | None) – Optional mapping from layer index to that layer's attention mask configuration.

Returns

A fully initialized cache with views for all layers, ready for use in model inference.

Return type

TransformerCache

Raises
  • ValueError – If metadata is incompatible with cache type.

  • MemoryError – If insufficient memory for allocation.

  • RuntimeError – If device/sharding configuration fails.

Example

>>> cache = TransformerCache.init_cache(
...     metadata=metadata,
...     mesh=mesh,
...     dtype=jnp.bfloat16,
...     partition_manager=pm
... )
>>> print(f"Initialized cache with {len(cache)} layers")

classmethod init_empty(num_hidden_layers: int) → TransformerCache

Initialize an empty cache container without allocated storage.

Creates a cache structure with placeholder views that can be populated later. This is useful for:
  • Gradual cache building during training

  • Memory-efficient initialization

  • Dynamic cache allocation

  • Testing and debugging

The empty cache has the correct structure but no allocated tensors, allowing the shape and configuration to be determined dynamically.
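The placeholder idea can be sketched as a container whose views start as None and are filled in one layer at a time. In this hypothetical sketch (`SketchCache`, `is_ready` and `set_view` are invented names), each fill returns a new container rather than mutating the old one, which matches the functional-update style JAX favours.

```python
from __future__ import annotations

from dataclasses import dataclass, replace


@dataclass(frozen=True)
class SketchCache:
    # hypothetical stand-in for TransformerCache; None marks an unallocated layer
    views: tuple

    @classmethod
    def init_empty(cls, num_hidden_layers: int) -> SketchCache:
        return cls(views=(None,) * num_hidden_layers)  # structure only, no storage

    @property
    def is_ready(self) -> bool:
        # usable only once every layer has a view
        return all(v is not None for v in self.views)

    def set_view(self, i: int, view) -> SketchCache:
        # functional update: build a new container instead of mutating this one
        views = list(self.views)
        views[i] = view
        return replace(self, views=tuple(views))
```

Chaining `SketchCache.init_empty(3).set_view(0, a).set_view(1, b).set_view(2, c)` yields a ready cache while the empty original stays unchanged.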

Parameters
  • num_hidden_layers (int) – Number of layers to create placeholder views for.

Returns

A cache instance whose views are all uninitialized (None). The views must be populated before use.

Return type

TransformerCache

Example

>>> cache = TransformerCache.init_empty(num_hidden_layers=12)
>>> # Populate views gradually
>>> for i in range(12):
...     cache[i] = TransformerCacheView.init(...)

insert(other: TransformerCache, slot: int, quantizer: object, partition_manager: PartitionManager)

Inserts another cache's contents at the specified batch slot.

Copies the key-value states and position indices from another cache into this cache at the specified batch position. This is useful for batched generation in which each batch slot holds a different sequence.
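A minimal sketch of slot insertion, assuming the source cache holds a single sequence (batch size 1, for example after prefilling one prompt) whose row is copied into one slot of every layer of the destination. `LayerState` and `insert_slot` are hypothetical; tuples stand in for batch-major arrays, and quantization and sharding are omitted.

```python
from __future__ import annotations

from dataclasses import dataclass, replace


@dataclass(frozen=True)
class LayerState:
    # hypothetical per-layer state; entry b of each tuple belongs to batch slot b
    key: tuple
    value: tuple
    indexs: tuple
    starts: tuple


def _set(items: tuple, slot: int, item) -> tuple:
    # functional "scatter": a new tuple with one entry replaced
    return items[:slot] + (item,) + items[slot + 1:]


def insert_slot(dst: list[LayerState], src: list[LayerState], slot: int) -> list[LayerState]:
    """Copy batch row 0 of every source layer into `slot` of the destination."""
    if len(dst) != len(src):
        raise ValueError("layer count mismatch")
    return [
        replace(
            d,
            key=_set(d.key, slot, s.key[0]),
            value=_set(d.value, slot, s.value[0]),
            indexs=_set(d.indexs, slot, s.indexs[0]),
            starts=_set(d.starts, slot, s.starts[0]),
        )
        for d, s in zip(dst, src)
    ]
```

The same per-slot scatter applied only to `indexs` or only to `starts` corresponds to what `insert_index` and `insert_starts` describe.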

Parameters
  • other (TransformerCache) – Source cache to copy from.

  • slot (int) – Batch slot index to insert into.

  • quantizer (EasyQuantizer) – Quantization configuration.

  • partition_manager (PartitionManager) – Sharding configuration.

Returns

Updated cache instance.

Return type

TransformerCache

insert_index(index: Int[Array, '...'], slot: int, partition_manager: PartitionManager) → TransformerCache

Inserts position indices at the specified batch slot.

Updates the current position indices for a specific batch slot across all layers. Used for tracking generation progress.

Parameters
  • index – New position index to insert.

  • slot (int) – Batch slot index to update.

  • partition_manager (PartitionManager) – Sharding configuration.

Returns

Updated cache instance.

Return type

TransformerCache

insert_starts(starts: Int[Array, '...'], slot: int, partition_manager: PartitionManager) → TransformerCache

Inserts starting positions at the specified batch slot.

Updates the starting position indices for a specific batch slot across all layers. Used for dynamic batching and cache management.

Parameters
  • starts – New starting positions to insert.

  • slot (int) – Batch slot index to update.

  • partition_manager (PartitionManager) – Sharding configuration.

Returns

Updated cache instance.

Return type

TransformerCache

replace(**kwargs)

Creates a new instance with specified fields replaced.

to_dict() → dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.

to_pure() → tuple[list[list[jax.Array | eformer.jaximus._imus.ImplicitArray]], easydel.layers.caching.transformer.cache.TransformerCacheMetaData]

Converts the cache to a pure Python data structure for serialization.

Extracts raw tensors and metadata for checkpointing or transfer. The pure representation can be pickled or saved to disk.

Returns

Pair of (cache_data, metadata) where:
  • cache_data: List of [key, value, indexs, starts] per layer

  • metadata: Cache configuration metadata

Return type

tuple
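The pure format described above (one [key, value, indexs, starts] entry per layer, paired with the metadata) can be illustrated with a pickle round trip. This is a sketch with hypothetical names (`View`, `pure_from_views`, `views_from_pure`, `roundtrip`); it uses plain lists instead of JAX arrays and a dict instead of TransformerCacheMetaData.

```python
import pickle
from dataclasses import dataclass


@dataclass(frozen=True)
class View:
    # hypothetical per-layer view holding plain lists instead of JAX arrays
    key: list
    value: list
    indexs: list
    starts: list


def pure_from_views(views, metadata):
    # one [key, value, indexs, starts] entry per layer, paired with the metadata
    return [[v.key, v.value, v.indexs, v.starts] for v in views], metadata


def views_from_pure(pure, metadata):
    # reject data that does not match the recorded layer count
    if len(pure) != metadata["num_hidden_layers"]:
        raise ValueError("layer count does not match metadata")
    return [View(*layer) for layer in pure]


def roundtrip(views, metadata):
    # the pure form contains only lists and a dict, so it pickles directly
    blob = pickle.dumps(pure_from_views(views, metadata))
    return views_from_pure(*pickle.loads(blob))
```

Keeping the pure form free of custom classes is what makes it straightforward to save, load, or send to another process.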

views: list[easydel.layers.caching.transformer.cache.TransformerCacheView | None]

class easydel.layers.caching.transformer.cache.TransformerCacheMetaData(batch_size: int, sequence_length: int, num_hidden_layers: int, pad_token_id: int, num_heads: int | None, head_dim: int | None, key_heads: int | None, value_heads: int | None, key_dim: int | None, value_dim: int | None, sliding_window: int | None, update_causal_mask: bool, create_attention_bias: bool)

Bases: BaseCacheMetadata

Metadata configuration for transformer key-value caching.

Stores all static configuration needed to initialize and operate a transformer cache. Supports various attention head configurations including multi-head, multi-query, and grouped-query attention.

The metadata defines:
  • Cache dimensions (batch, sequence, layers)

  • Attention head configuration

  • Masking and bias settings

  • Special attention patterns (sliding window)
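The head configuration largely determines the cache's memory footprint. As back-of-the-envelope arithmetic (not an EasyDeL API), each layer stores one key and one value tensor of batch_size × sequence_length × kv_heads × head_dim elements, where kv_heads equals num_heads for MHA, 1 for MQA, and something in between for GQA. The function name is hypothetical, and the default of 2 bytes per element assumes bfloat16.

```python
def kv_cache_bytes(batch_size: int, sequence_length: int, num_hidden_layers: int,
                   kv_heads: int, head_dim: int, bytes_per_element: int = 2) -> int:
    """Bytes for an unquantized K+V cache (2 bytes/element = bfloat16)."""
    per_tensor = batch_size * sequence_length * kv_heads * head_dim  # one K or V, one layer
    return 2 * num_hidden_layers * per_tensor * bytes_per_element  # K and V, all layers
```

With the module example's configuration, `kv_cache_bytes(2, 1024, 12, 16, 64)` is 100,663,296 bytes (96 MiB). With grouped-query attention using 4 key/value heads instead of 16, it drops to 24 MiB. Quantization and sliding-window layers, which this formula ignores, can reduce it further.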

batch_size

Number of sequences in batch.

Type

int

sequence_length

Maximum sequence length to cache.

Type

int

num_hidden_layers

Number of transformer layers.

Type

int

pad_token_id

Token ID used for padding.

Type

int

num_heads

Number of attention heads (for regular MHA).

Type

int | None

head_dim

Dimension of each attention head.

Type

int | None

key_heads

Number of key heads (for MQA/GQA).

Type

int | None

value_heads

Number of value heads (for MQA/GQA).

Type

int | None

key_dim

Dimension of key projections.

Type

int | None

value_dim

Dimension of value projections.

Type

int | None

sliding_window

Size of sliding attention window.

Type

int | None

update_causal_mask

Whether to update causal masks dynamically.

Type

bool

create_attention_bias

Whether to create attention bias terms.

Type

bool

batch_size: int
classmethod create(batch_size: int, sequence_length: int, num_hidden_layers: int, pad_token_id: int, num_heads: int | None = None, head_dim: int | None = None, key_heads: int | None = None, value_heads: int | None = None, key_dim: int | None = None, value_dim: int | None = None, update_causal_mask: bool = True, create_attention_bias: bool = True, sliding_window: int | None = None) → TransformerCacheMetaData

Create a TransformerCacheMetaData instance with validation.

Parameters
  • batch_size – Number of sequences in the batch.

  • sequence_length – Maximum sequence length to cache.

  • num_hidden_layers – Number of hidden layers.

  • pad_token_id – Token ID used for padding.

  • num_heads – Number of attention heads.

  • head_dim – Dimension of each head.

  • key_heads – Number of key heads.

  • value_heads – Number of value heads.

  • key_dim – Dimension of keys.

  • value_dim – Dimension of values.

  • update_causal_mask – Whether to update the causal mask.

  • create_attention_bias – Whether to create an attention bias.

  • sliding_window – Optional size of the sliding attention window.

Returns

TransformerCacheMetaData instance

Raises

ValueError – If required parameters are missing or invalid.
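The docstring does not say which combinations of head arguments `create` accepts. One plausible rule, assumed here and not taken from EasyDeL's source, is to fall back to `num_heads`/`head_dim` when the key- or value-specific arguments are omitted, and to raise ValueError when something still cannot be resolved. `resolve_heads` is a hypothetical helper.

```python
def resolve_heads(num_heads=None, head_dim=None, key_heads=None,
                  value_heads=None, key_dim=None, value_dim=None):
    """Hypothetical defaulting rule: fall back to num_heads / head_dim."""
    key_heads = key_heads if key_heads is not None else num_heads
    value_heads = value_heads if value_heads is not None else num_heads
    key_dim = key_dim if key_dim is not None else head_dim
    value_dim = value_dim if value_dim is not None else head_dim
    resolved = dict(key_heads=key_heads, value_heads=value_heads,
                    key_dim=key_dim, value_dim=value_dim)
    # anything still None could not be derived from the given arguments
    missing = [name for name, v in resolved.items() if v is None]
    if missing:
        raise ValueError(f"cannot resolve: {', '.join(missing)}")
    if any(v <= 0 for v in resolved.values()):
        raise ValueError("head counts and dimensions must be positive")
    return resolved
```

Under this rule, a plain MHA configuration only needs `num_heads` and `head_dim`, while GQA overrides `key_heads` and `value_heads`.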

create_attention_bias: bool
classmethod from_dict(data: dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

head_dim: int | None
key_dim: int | None
key_heads: int | None
num_heads: int | None
num_hidden_layers: int
pad_token_id: int
replace(**kwargs)

Creates a new instance with specified fields replaced.

sequence_length: int
sliding_window: int | None
to_dict() → dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.

update_causal_mask: bool
value_dim: int | None
value_heads: int | None

class easydel.layers.caching.transformer.cache.TransformerCacheView(key: jaxtyping.Float[Array, 'batch seq_len num_key_heads key_dim'] | eformer.jaximus._imus.ImplicitArray, value: jaxtyping.Float[Array, 'batch seq_len num_value_heads value_dim'] | eformer.jaximus._imus.ImplicitArray, indexs: jaxtyping.Int[Array, 'batch'] | eformer.jaximus._imus.ImplicitArray, starts: jaxtyping.Int[Array, 'batch'] | eformer.jaximus._imus.ImplicitArray, metadata: TransformerCacheMetaData, maximum_sequence_length: int, layer_index: int | None = None, masking_details: easydel.layers.caching.transformer.cache.AttnMaskDetail | None = None)

Bases: BaseCacheView

Single-layer cache view for transformer key-value states.

Manages the cached key and value tensors for one transformer layer, along with position tracking and masking information. Supports various attention patterns and quantization strategies.

The view maintains:
  • Key and value state tensors

  • Current position indices for each sequence

  • Starting positions for relative indexing

  • Masking configuration for attention patterns

key

Cached key states. Shape: [batch_size, seq_length, num_key_heads, key_dim]

Type

cx.Array | ImplicitArray

value

Cached value states. Shape: [batch_size, seq_length, num_value_heads, value_dim]

Type

cx.Array | ImplicitArray

indexs

Current position index per sequence. Shape: [batch_size]

Type

cx.Array | ImplicitArray

starts

Starting position per sequence. Shape: [batch_size]

Type

cx.Array | ImplicitArray
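The attribute descriptions suggest that, for each sequence, `starts` and `indexs` bracket the cached region. Assuming that slot b holds real tokens at positions starts[b] through indexs[b] - 1 (an interpretation this page does not confirm), a per-slot key mask could be derived as follows; `valid_kv_mask` is a hypothetical helper.

```python
def valid_kv_mask(starts, indexs, max_len):
    # Assumption: slot b holds real tokens at positions starts[b] .. indexs[b] - 1;
    # positions before starts[b] are padding, positions >= indexs[b] are unwritten.
    return [[int(s <= p < i) for p in range(max_len)] for s, i in zip(starts, indexs)]
```

A slot whose start equals its index has no valid keys yet, so its mask row is all zeros.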

metadata

Static cache configuration.

Type

TransformerCacheMetaData

maximum_sequence_length

Maximum cacheable sequence length.

Type

int

layer_index

Index of this layer in the model.

Type

int | None

masking_details

Attention mask configuration.

Type

AttnMaskDetail | None

concatenate_to_cache(query: Float[Array, 'batch query_len num_heads head_dim'], key: Float[Array, 'batch query_len num_key_heads key_dim'], value: Float[Array, 'batch query_len num_value_heads value_dim'], mode: Literal['__autoregressive__', '__prefill__', '__train__', '__insert__'], quantizer: object, cache_metadata: easydel.layers.caching.transformer.cache.TransformerMetadata | None, mask_info: MaskInfo, partition_manager: PartitionManager) → tuple[jaxtyping.Float[Array, 'batch seq_len num_key_heads key_dim'], jaxtyping.Float[Array, 'batch seq_len num_value_heads value_dim'], ejkernel.types.mask.MaskInfo, easydel.layers.caching.transformer.cache.TransformerCacheView, easydel.layers.caching.transformer.cache.AttnMaskDetail | None]

Updates the KV cache functionally and returns the updated key/value tensors, the attention mask information to use, and the updated cache view.

Parameters
  • query – Current query states.

  • key – Current key states to add to the cache.

  • value – Current value states to add to the cache.

  • mode – Execution mode: one of '__autoregressive__', '__prefill__', '__train__', or '__insert__'.

  • quantizer – Quantizer for the cache.

  • cache_metadata – Optional runtime metadata (TransformerMetadata), such as per-sequence starts and indexs.

  • mask_info – Base attention mask information (MaskInfo).

  • partition_manager (PartitionManager) – Sharding configuration.

Returns

  • Updated key cache tensor (functional update).

  • Updated value cache tensor (functional update).

  • Mask information to use for attention (either the original or a calculated one).

  • Updated TransformerCacheView for this layer.

  • Masking details (AttnMaskDetail) for this layer, or None.

Return type

tuple[Array, Array, MaskInfo, TransformerCacheView, AttnMaskDetail | None]
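The functional update can be illustrated without JAX: each batch slot writes its new keys and values starting at its current index, the index advances, and a new view is returned while the old one is left untouched. This sketch is a simplification (tuples instead of arrays; no quantization, sharding, or mask computation), and `KVView` and `concatenate` are hypothetical names.

```python
from __future__ import annotations

from dataclasses import dataclass, replace


@dataclass(frozen=True)
class KVView:
    # hypothetical single-layer view; key[b][t] is the cached key of slot b at position t
    key: tuple
    value: tuple
    indexs: tuple  # next write position per slot


def concatenate(view: KVView, new_keys, new_values):
    """Write new_keys[b]/new_values[b] at view.indexs[b] and advance the index."""
    keys, values, indexs = [], [], []
    for b, (row_k, row_v, idx) in enumerate(zip(view.key, view.value, view.indexs)):
        n = len(new_keys[b])
        if idx + n > len(row_k):
            raise ValueError(f"slot {b} would overflow the cache")
        keys.append(row_k[:idx] + tuple(new_keys[b]) + row_k[idx + n:])
        values.append(row_v[:idx] + tuple(new_values[b]) + row_v[idx + n:])
        indexs.append(idx + n)
    new_view = replace(view, key=tuple(keys), value=tuple(values), indexs=tuple(indexs))
    # mirrors the key, value and view elements of concatenate_to_cache's result
    return new_view.key, new_view.value, new_view
```

In JAX the same pattern is typically written with `jax.lax.dynamic_update_slice` or `.at[...].set(...)` on fixed-shape arrays, which is why the method returns new tensors and a new view rather than modifying them in place.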

classmethod from_dict(data: dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

indexs: jaxtyping.Int[Array, 'batch'] | eformer.jaximus._imus.ImplicitArray
classmethod init(mesh: Mesh, dtype: dtype, metadata: TransformerCacheMetaData, quantizer: object, partition_manager: PartitionManager, starts: jaxtyping.Int[Array, 'batch'] | None = None, layer_index: int | None = None, masking_details: easydel.layers.caching.transformer.cache.AttnMaskDetail | None = None)

Initialize a transformer cache view for a single layer.

Creates and allocates cache tensors with appropriate shapes, dtypes, and sharding for distributed execution. Applies quantization if configured.

Parameters
  • mesh (Mesh) – JAX device mesh for distributed execution.

  • dtype (jnp.dtype) – Data type for cache tensors.

  • metadata (TransformerCacheMetaData) – Cache configuration.

  • quantizer (EasyQuantizer) – Quantization configuration.

  • partition_manager (PartitionManager) – Sharding strategy manager.

  • starts (jax.Array | None) – Initial starting positions per sequence. Defaults to zeros if not provided.

  • layer_index (int | None) – Index of this layer in the model.

  • masking_details (AttnMaskDetail | None) – Attention mask configuration.

Returns

Initialized cache view with allocated tensors.

Return type

TransformerCacheView

Note

For sliding window attention, cache dimensions are adjusted based on the window size specified in masking_details.
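The note does not state how a window-sized buffer is indexed. A common technique, assumed here rather than confirmed for EasyDeL, is a ring buffer in which absolute position p is written to slot p % window, so only the most recent `window` keys survive. All three helpers are hypothetical.

```python
from __future__ import annotations


def cache_len(sequence_length: int, window: int | None) -> int:
    # a sliding-window layer never needs more than `window` slots
    return sequence_length if window is None else min(sequence_length, window)


def ring_slot(position: int, window: int) -> int:
    # absolute position p reuses slot p % window
    return position % window


def write_tokens(tokens: list, window: int) -> list:
    buf = [None] * window
    for pos, tok in enumerate(tokens):
        buf[ring_slot(pos, window)] = tok  # once full, overwrites the oldest entry
    return buf
```

Writing five tokens "a" to "e" into a window of 3 leaves ["d", "e", "c"]: the three most recent tokens, stored out of order, so position bookkeeping (such as `indexs`) is needed to recover their order.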

property is_empty: bool
key: jaxtyping.Float[Array, 'batch seq_len num_key_heads key_dim'] | eformer.jaximus._imus.ImplicitArray
layer_index: int | None = None
masking_details: easydel.layers.caching.transformer.cache.AttnMaskDetail | None = None
maximum_sequence_length: int
metadata: TransformerCacheMetaData
replace(**kwargs)

Creates a new instance with specified fields replaced.

starts: jaxtyping.Int[Array, 'batch'] | eformer.jaximus._imus.ImplicitArray
to_dict() → dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.

value: jaxtyping.Float[Array, 'batch seq_len num_value_heads value_dim'] | eformer.jaximus._imus.ImplicitArray

class easydel.layers.caching.transformer.cache.TransformerMetadata(postpadded: bool = False, starts: jaxtyping.Int[Array, 'batch'] | None = None, indexs: jaxtyping.Int[Array, 'batch'] | None = None)

Bases: BaseRunTimeMetadata

Runtime metadata for transformer cache operations.

Holds dynamic information needed during cache updates that isn’t part of the permanent cache state. This includes temporary indices and flags for specific computation modes.

postpadded

Whether sequences are post-padded. Affects mask generation and position calculations.

Type

bool

starts

Starting positions for sequences. Used for relative position calculations.

Type

jax.Array | None

indexs

Current position indices. Tracks generation progress per sequence.

Type

jax.Array | None
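How `postpadded` changes position calculations is not detailed here. As an illustration only, assuming contiguous padding and an attention mask with 1 for real tokens and 0 for padding, the start of the real tokens differs between pre-padded (left-padded) and post-padded (right-padded) batches. `starts_and_lengths` is a hypothetical helper.

```python
def starts_and_lengths(attention_mask, postpadded):
    """attention_mask[b][t] is 1 for real tokens and 0 for padding (contiguous)."""
    starts, lengths = [], []
    for row in attention_mask:
        n_real = sum(row)
        if postpadded:
            starts.append(0)                  # real tokens first, padding after
        else:
            starts.append(len(row) - n_real)  # padding first, real tokens after
        lengths.append(n_real)
    return starts, lengths
```

For the pre-padded batch [[0, 0, 1, 1], [1, 1, 1, 1]] this gives starts [2, 0]; the post-padded equivalent [[1, 1, 0, 0], [1, 1, 1, 1]] gives starts [0, 0] with the same lengths [2, 4].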

classmethod from_dict(data: dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

indexs: jaxtyping.Int[Array, 'batch'] | None = None
postpadded: bool = False
replace(**kwargs)

Creates a new instance with specified fields replaced.

starts: jaxtyping.Int[Array, 'batch'] | None = None
to_dict() → dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.