easydel.layers.caching.lightning.cache
Lightning attention cache implementation for EasyDeL.
This module provides a specialized caching system for Lightning attention, which uses a unified key-value representation for improved efficiency. Lightning attention combines keys and values into a single tensor, reducing memory bandwidth requirements.
- Key Components:
  - LightningCacheMetaData: Configuration for Lightning cache
  - LightningCacheView: Per-layer Lightning cache storage
  - LightningCache: Multi-layer Lightning cache container
  - LightningMetadata: Runtime metadata (placeholder)
- Features:
  - Unified KV tensor representation
  - Reduced memory bandwidth usage
  - Compatible with Lightning attention kernels
  - Supports standard transformer operations
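Lightning attention belongs to the linear-attention family: the softmax is dropped, so keys and values can be folded into one running state instead of being kept per token. The NumPy sketch below illustrates that idea under simplifying assumptions (no decay factor, no normalization, no block tiling, and a per-head state of shape [batch, heads, key_dim, value_dim]). The function names are hypothetical, and the annotated shape of LightningCacheView.key_value ('batch seq_len num_heads head_dim') differs, so treat this as an illustration of the concept, not as EasyDeL's layout.

```python
import numpy as np


def causal_linear_attention_parallel(q, k, v):
    """Reference form: o_t = sum over s <= t of (q_t . k_s) * v_s.

    Assumed shapes: q, k -> [batch, seq, heads, key_dim]; v -> [batch, seq, heads, value_dim].
    """
    seq = q.shape[1]
    scores = np.einsum("bthd,bshd->bhts", q, k)        # [batch, heads, seq, seq]
    causal = np.tril(np.ones((seq, seq), dtype=bool))  # True where s <= t
    scores = np.where(causal, scores, 0.0)
    return np.einsum("bhts,bshe->bthe", scores, v)     # [batch, seq, heads, value_dim]


def causal_linear_attention_recurrent(q, k, v):
    """Same outputs, but K and V are folded into one fixed-size state per head."""
    batch, seq, heads, key_dim = q.shape
    value_dim = v.shape[-1]
    state = np.zeros((batch, heads, key_dim, value_dim))  # unified KV state
    outputs = []
    for t in range(seq):
        # Add the outer product k_t^T v_t to the shared state.
        state = state + np.einsum("bhd,bhe->bhde", k[:, t], v[:, t])
        # Read the state with the current query: o_t = q_t @ state.
        outputs.append(np.einsum("bhd,bhde->bhe", q[:, t], state))
    return np.stack(outputs, axis=1), state


rng = np.random.default_rng(0)
q = rng.normal(size=(2, 5, 4, 8))
k = rng.normal(size=(2, 5, 4, 8))
v = rng.normal(size=(2, 5, 4, 16))
out, state = causal_linear_attention_recurrent(q, k, v)
assert np.allclose(out, causal_linear_attention_parallel(q, k, v))
print(state.shape)  # (2, 4, 8, 16): does not grow with the sequence length
```

Both functions produce the same outputs, but the recurrent form carries only the fixed-size state between steps instead of re-reading every cached key and value, which is the kind of saving a unified KV tensor aims for.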
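Example
>>> from eformer import escale as es
>>> from easydel.layers.caching.lightning.cache import (
...     LightningCache,
...     LightningCacheMetaData,
... )
>>> partition_axis = es.PartitionAxis()  # assumes the default axis layout
>>> metadata = LightningCacheMetaData.create(
...     partition_axis=partition_axis,
...     batch_size=2,
...     num_heads=16,
...     head_dim=64,
... )
>>> cache = LightningCache.init_cache(
...     num_hidden_layers=12,
...     metadata=metadata,
... )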
- class easydel.layers.caching.lightning.cache.LightningCache(views: list[easydel.layers.caching.lightning.cache.LightningCacheView | None])
Bases: BaseCache
Multi-layer Lightning attention cache container.
Orchestrates Lightning cache views across all model layers, providing unified management of the specialized Lightning attention cache format.
- views
Per-layer cache views. None for uninitialized layers.
- Type
list[LightningCacheView | None]
- classmethod from_dict(data: dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- classmethod init_cache(num_hidden_layers: int, metadata: LightningCacheMetaData) → LightningCache
Initialize Lightning cache for all model layers.
Creates cache views for each layer with consistent configuration. Views are initialized with placeholders; actual tensors are allocated on first use.
- Parameters
num_hidden_layers (int) – Number of layers in the model.
metadata (LightningCacheMetaData) – Cache configuration.
- Returns
Initialized cache with views for all layers.
- Return type
LightningCache
- classmethod init_empty(num_hidden_layers: int) → LightningCache
Initialize empty Lightning cache structure.
Creates cache container with None placeholders for all layers. Useful for gradual cache building or testing.
- Parameters
num_hidden_layers (int) – Number of layer slots to create.
- Returns
Empty cache structure.
- Return type
LightningCache
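To make the difference between init_cache and init_empty concrete, here is a plain-Python sketch. The Stub* classes are hypothetical stand-ins that copy only the documented fields; they assume, as the descriptions above state, that init_cache builds one view per layer with a None key_value placeholder and that init_empty leaves bare None slots.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Optional


@dataclass(frozen=True)
class StubMetadata:
    """Hypothetical stand-in for LightningCacheMetaData (only a few fields)."""
    batch_size: Optional[int] = None
    num_heads: Optional[int] = None
    head_dim: Optional[int] = None


@dataclass
class StubView:
    """Hypothetical stand-in for LightningCacheView."""
    key_value: Any
    metadata: StubMetadata
    layer_index: Optional[int] = None

    @classmethod
    def init(cls, metadata: StubMetadata, layer_index: Optional[int] = None) -> StubView:
        # The unified KV tensor starts as a None placeholder.
        return cls(key_value=None, metadata=metadata, layer_index=layer_index)


@dataclass
class StubCache:
    """Hypothetical stand-in for LightningCache."""
    views: list[Optional[StubView]]

    @classmethod
    def init_cache(cls, num_hidden_layers: int, metadata: StubMetadata) -> StubCache:
        # One view per layer, all sharing the same metadata object.
        return cls(views=[StubView.init(metadata, layer_index=i) for i in range(num_hidden_layers)])

    @classmethod
    def init_empty(cls, num_hidden_layers: int) -> StubCache:
        # Only None slots; views can be filled in later.
        return cls(views=[None] * num_hidden_layers)


meta = StubMetadata(batch_size=2, num_heads=16, head_dim=64)
full = StubCache.init_cache(num_hidden_layers=12, metadata=meta)
empty = StubCache.init_empty(num_hidden_layers=12)
print(full.views[3].layer_index, full.views[3].key_value, empty.views[3])  # 3 None None
```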
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- to_dict() → dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
- views: list[easydel.layers.caching.lightning.cache.LightningCacheView | None]
- class easydel.layers.caching.lightning.cache.LightningCacheMetaData(partition_axis: PartitionAxis, batch_size: int | None, num_heads: int | None, head_dim: int | None, key_heads: int | None, value_heads: int | None, key_dim: int | None, value_dim: int | None)
Bases: BaseCacheMetadata
Metadata configuration for Lightning attention cache.
Stores configuration parameters specific to Lightning attention, which uses a unified key-value representation. Similar to standard transformer cache but optimized for Lightning’s memory access patterns.
- partition_axis
Axis configuration for tensor partitioning. Defines how tensors are sharded across devices.
- Type
es.PartitionAxis
- batch_size
Number of sequences in batch. None allows dynamic batch sizes.
- Type
int | None
- num_heads
Number of attention heads. Used for standard multi-head attention.
- Type
int | None
- head_dim
Dimension of each attention head. Defines the feature size per head.
- Type
int | None
- key_heads
Number of key heads. For multi-query or grouped-query attention.
- Type
int | None
- value_heads
Number of value heads. For multi-query or grouped-query attention.
- Type
int | None
- key_dim
Dimension of key projections. Can differ from head_dim for asymmetric attention.
- Type
int | None
- value_dim
Dimension of value projections. Can differ from head_dim for asymmetric attention.
- Type
int | None
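When key_heads or value_heads is smaller than num_heads, several query heads share one key/value head (a single shared head is multi-query attention; several groups is grouped-query attention). One common way to align the shapes before computing attention is to repeat each KV head. The NumPy sketch below assumes the [batch, seq, heads, dim] layout from the concatenate_to_cache annotations; expand_kv_heads is a hypothetical helper, not an EasyDeL function.

```python
import numpy as np


def expand_kv_heads(kv: np.ndarray, num_heads: int) -> np.ndarray:
    """Repeat each KV head so [batch, seq, kv_heads, dim] lines up with num_heads query heads.

    Query head j ends up paired with KV head j // (num_heads // kv_heads).
    """
    kv_heads = kv.shape[2]
    if num_heads % kv_heads != 0:
        raise ValueError(f"num_heads={num_heads} is not a multiple of kv_heads={kv_heads}")
    groups = num_heads // kv_heads  # query heads per KV head
    return np.repeat(kv, groups, axis=2)


# MQA: one KV head shared by 16 query heads.
print(expand_kv_heads(np.zeros((2, 10, 1, 64)), num_heads=16).shape)  # (2, 10, 16, 64)
# GQA: 4 KV heads, each shared by 4 query heads.
print(expand_kv_heads(np.zeros((2, 10, 4, 64)), num_heads=16).shape)  # (2, 10, 16, 64)
```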
- classmethod create(partition_axis: PartitionAxis, batch_size: int | None = None, num_heads: int | None = None, head_dim: int | None = None, key_heads: int | None = None, value_heads: int | None = None, key_dim: int | None = None, value_dim: int | None = None) → LightningCacheMetaData
Create and validate Lightning cache metadata.
Factory method for creating a Lightning cache configuration. Unlike the standard transformer cache, every parameter except partition_axis is optional, because Lightning handles unified KV tensors.
- Parameters
partition_axis (es.PartitionAxis) – Tensor partitioning configuration.
batch_size (int | None) – Batch size for cache allocation. None for dynamic batching.
num_heads (int | None) – Number of attention heads. Defaults to None for flexible head configuration.
head_dim (int | None) – Dimension per attention head. Defaults to None for flexible dimensions.
key_heads (int | None) – Number of key heads for MQA/GQA. Defaults to None (same as num_heads).
value_heads (int | None) – Number of value heads for MQA/GQA. Defaults to None (same as num_heads).
key_dim (int | None) – Key projection dimension. Defaults to None (same as head_dim).
value_dim (int | None) – Value projection dimension. Defaults to None (same as head_dim).
- Returns
Configured metadata instance.
- Return type
LightningCacheMetaData
Note
Because Lightning attention uses a unified KV representation, some of these parameters may be handled differently than in the standard transformer cache.
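The parameter descriptions define fallbacks: key_heads and value_heads default to num_heads, and key_dim and value_dim default to head_dim. The documentation does not say whether create fills these in or keeps None, so the helper below (hypothetical name resolve_kv_dims) only illustrates the documented rule.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class ResolvedDims:
    key_heads: Optional[int]
    value_heads: Optional[int]
    key_dim: Optional[int]
    value_dim: Optional[int]


def resolve_kv_dims(
    num_heads: Optional[int] = None,
    head_dim: Optional[int] = None,
    key_heads: Optional[int] = None,
    value_heads: Optional[int] = None,
    key_dim: Optional[int] = None,
    value_dim: Optional[int] = None,
) -> ResolvedDims:
    """Hypothetical helper: heads fall back to num_heads, dims fall back to head_dim."""
    return ResolvedDims(
        key_heads=num_heads if key_heads is None else key_heads,
        value_heads=num_heads if value_heads is None else value_heads,
        key_dim=head_dim if key_dim is None else key_dim,
        value_dim=head_dim if value_dim is None else value_dim,
    )


# GQA with 4 KV heads and a wider value projection.
print(resolve_kv_dims(num_heads=16, head_dim=64, key_heads=4, value_heads=4, value_dim=128))
# ResolvedDims(key_heads=4, value_heads=4, key_dim=64, value_dim=128)
```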
- classmethod from_dict(data: dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- partition_axis: PartitionAxis
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- to_dict() → dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
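The replace, to_dict/from_dict and to_json/from_json methods follow an immutable-record pattern. The sketch below mirrors those method names on a hypothetical frozen dataclass holding a few scalar metadata fields; it assumes a flat JSON object of field values and does not cover how EasyDeL serializes non-scalar fields such as partition_axis.

```python
import dataclasses
import json
from dataclasses import dataclass
from typing import Any, Optional


@dataclass(frozen=True)
class MetadataStandIn:
    """Hypothetical stand-in with a subset of LightningCacheMetaData's scalar fields."""
    batch_size: Optional[int] = None
    num_heads: Optional[int] = None
    head_dim: Optional[int] = None
    key_heads: Optional[int] = None

    def replace(self, **kwargs) -> "MetadataStandIn":
        # Returns a new instance; the original is left unchanged.
        return dataclasses.replace(self, **kwargs)

    def to_dict(self) -> dict[str, Any]:
        return dataclasses.asdict(self)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "MetadataStandIn":
        return cls(**data)

    def to_json(self, **kwargs) -> str:
        # Extra kwargs are forwarded to json.dumps (for example indent=2).
        return json.dumps(self.to_dict(), **kwargs)

    @classmethod
    def from_json(cls, json_str: str) -> "MetadataStandIn":
        return cls.from_dict(json.loads(json_str))


meta = MetadataStandIn(batch_size=2, num_heads=16, head_dim=64)
bigger = meta.replace(batch_size=8)
print(meta.batch_size, bigger.batch_size)  # 2 8
print(MetadataStandIn.from_json(meta.to_json()) == meta)  # True
```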
- class easydel.layers.caching.lightning.cache.LightningCacheView(key_value: jaxtyping.Float[Array, 'batch seq_len num_heads head_dim'] | eformer.jaximus._imus.ImplicitArray | None, metadata: LightningCacheMetaData, layer_index: int | None = None)
Bases: BaseCacheView
Single-layer cache view for Lightning attention.
Manages the unified key-value cache for one layer using Lightning’s optimized representation. Unlike standard caches that store keys and values separately, Lightning combines them for better memory efficiency.
- key_value
Unified key-value tensor. Lightning’s special representation combining K and V.
- Type
cx.Array | ImplicitArray
- metadata
Static cache configuration.
- Type
LightningCacheMetaData
- layer_index
Index of this layer in the model.
- Type
int | None
Note
The unified representation requires special handling during concatenation and may not be compatible with standard attention.
- concatenate_to_cache(query: Float[Array, 'batch query_len num_heads head_dim'], key: Float[Array, 'batch query_len num_key_heads key_dim'], value: Float[Array, 'batch query_len num_value_heads value_dim'], attention_mask: jaxtyping.Bool[Array, 'batch 1 query_len seq_len'] | jaxtyping.Float[Array, 'batch 1 query_len seq_len'], kv_sharding: PartitionSpec, quantizer: object, causal_mask: jaxtyping.Bool[Array, 'batch 1 query_len seq_len'] | bool | None = None, token_type_ids: jaxtyping.Int[Array, 'batch query_len'] | None = None) → tuple[jaxtyping.Float[Array, 'batch seq_len num_key_heads key_dim'], jaxtyping.Float[Array, 'batch seq_len num_value_heads value_dim'], jaxtyping.Bool[Array, 'batch 1 query_len seq_len']]
Update cache with new key/value states for Lightning attention.
Concatenates new key and value states to the cache using Lightning’s unified KV representation. This method is called during each forward pass to update the cache with newly computed states.
Note: The current implementation needs refactoring: it references attributes (self.index, self.key, self.value) that are not defined on this class. A Lightning-specific implementation would instead operate on the unified self.key_value tensor.
- Parameters
query – Query tensor with shape [batch, query_len, num_heads, head_dim]. Used to determine update dimensions.
key – Key tensor with shape [batch, query_len, num_key_heads, key_dim]. New keys to add to the cache.
value – Value tensor with shape [batch, query_len, num_value_heads, value_dim]. New values to add to the cache.
attention_mask – Boolean or float mask with shape [batch, 1, query_len, seq_len]. Defines which positions can attend to which.
kv_sharding – JAX PartitionSpec for sharding the KV cache.
quantizer – Quantization function for cache compression.
causal_mask – Optional causal mask for autoregressive attention. Can be a boolean array or a single boolean value.
token_type_ids – Optional token type IDs for segment-level masking. Shape [batch, query_len].
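- Returns
Tuple containing:
  - Updated key cache: Float[Array, 'batch seq_len num_key_heads key_dim']
  - Updated value cache: Float[Array, 'batch seq_len num_value_heads value_dim']
  - Updated attention mask: Bool[Array, 'batch 1 query_len seq_len']
- Return type
tuple[Float[Array, 'batch seq_len num_key_heads key_dim'], Float[Array, 'batch seq_len num_value_heads value_dim'], Bool[Array, 'batch 1 query_len seq_len']]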
- Raises
NotImplementedError – Current implementation needs refactoring to properly use Lightning’s unified KV representation.
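For reference, the contract described above (write new keys and values into the cache, then return the full key cache, value cache, and an attention mask) can be sketched functionally for the separate key/value layout that the current code appears to target. This is a NumPy sketch with a hypothetical name and an explicit write index; it ignores quantizer, kv_sharding, token_type_ids, and the caller's padding mask. A JAX version would typically replace the slice assignment with jax.lax.dynamic_update_slice.

```python
import numpy as np


def update_kv_cache(key_cache, value_cache, key, value, index):
    """Hypothetical helper: write key/value ([batch, query_len, heads, dim]) at `index`.

    Returns new caches, the next write index, and a [batch, 1, query_len, seq_len]
    boolean mask letting each new query see every filled slot up to its own position.
    """
    batch, query_len = key.shape[:2]
    seq_len = key_cache.shape[1]
    if index + query_len > seq_len:
        raise ValueError("cache is full")
    # Copy instead of mutating, mirroring JAX's immutable arrays.
    key_cache = key_cache.copy()
    value_cache = value_cache.copy()
    key_cache[:, index:index + query_len] = key
    value_cache[:, index:index + query_len] = value
    # Absolute position of each new query token versus each cache slot.
    query_pos = index + np.arange(query_len)[:, None]  # [query_len, 1]
    slot_pos = np.arange(seq_len)[None, :]             # [1, seq_len]
    mask = slot_pos <= query_pos                       # causal; also hides empty slots
    mask = np.broadcast_to(mask[None, None], (batch, 1, query_len, seq_len))
    return key_cache, value_cache, index + query_len, mask


kc = np.zeros((1, 6, 2, 4))
vc = np.zeros_like(kc)
# Prefill three tokens, then decode one.
kc, vc, idx, m = update_kv_cache(kc, vc, np.ones((1, 3, 2, 4)), np.ones((1, 3, 2, 4)), 0)
kc, vc, idx, m = update_kv_cache(kc, vc, 2 * np.ones((1, 1, 2, 4)), 2 * np.ones((1, 1, 2, 4)), idx)
print(idx)                     # 4
print(m[0, 0, 0].astype(int))  # [1 1 1 1 0 0]
```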
- classmethod from_dict(data: dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- classmethod init(metadata: LightningCacheMetaData, layer_index: int | None = None) → LightningCacheView
Initialize a Lightning cache view for a single layer.
Creates a cache view with a None placeholder for the unified KV tensor. The actual tensor is allocated on first use, which allows for dynamic shapes.
- Parameters
metadata (LightningCacheMetaData) – Cache configuration.
layer_index (int | None) – Layer index in the model.
- Returns
Initialized view with None placeholder.
- Return type
LightningCacheView
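The lazy allocation described for init can be sketched as follows. ViewSketch is a hypothetical stand-in; it assumes the buffer takes the annotated key_value shape ('batch seq_len num_heads head_dim') and that seq_len is supplied at first use.

```python
from dataclasses import dataclass, replace
from typing import Optional

import numpy as np


@dataclass(frozen=True)
class ViewSketch:
    """Hypothetical stand-in for LightningCacheView: key_value stays None until first use."""
    key_value: Optional[np.ndarray]
    batch_size: int
    num_heads: int
    head_dim: int
    layer_index: Optional[int] = None

    def ensure_allocated(self, seq_len: int) -> "ViewSketch":
        # Already allocated: return the same view unchanged.
        if self.key_value is not None:
            return self
        # First use: allocate a zero buffer now that seq_len is known.
        buf = np.zeros((self.batch_size, seq_len, self.num_heads, self.head_dim), dtype=np.float32)
        return replace(self, key_value=buf)


view = ViewSketch(key_value=None, batch_size=2, num_heads=16, head_dim=64, layer_index=0)
allocated = view.ensure_allocated(seq_len=128)
print(view.key_value is None, allocated.key_value.shape)  # True (2, 128, 16, 64)
```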
- key_value: jaxtyping.Float[Array, 'batch seq_len num_heads head_dim'] | eformer.jaximus._imus.ImplicitArray | None
- metadata: LightningCacheMetaData
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- to_dict() → dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
- class easydel.layers.caching.lightning.cache.LightningMetadata
Bases: BaseRunTimeMetadata
Runtime metadata for Lightning attention cache operations.
Placeholder reserved for future Lightning-specific runtime state; the class currently defines no fields.