easydel.layers.caching.lightning.cache

Lightning attention cache implementation for EasyDeL.

This module provides a specialized caching system for Lightning attention. Instead of storing keys and values in separate buffers, Lightning attention keeps them in a single unified key-value tensor, which reduces memory bandwidth requirements.

Key Components:
  • LightningCacheMetaData: Configuration for Lightning cache

  • LightningCacheView: Per-layer Lightning cache storage

  • LightningCache: Multi-layer Lightning cache container

  • LightningMetadata: Runtime metadata (placeholder)

Features:
  • Unified KV tensor representation

  • Reduced memory bandwidth usage

  • Compatible with Lightning attention kernels

  • Supports standard transformer operations
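The idea of a single fused key-value cache can be illustrated with linear attention, where every key/value pair is folded into one running state matrix S_t = Σ_{i≤t} k_iᵀ v_i instead of being cached separately. The sketch below is an illustration of that general idea under stated assumptions, not EasyDeL's implementation: it uses NumPy, a single head, and no softmax, normalisation or decay, and its state has shape (head_dim, head_dim), which differs from the 'batch seq_len num_heads head_dim' annotation given for LightningCacheView.key_value. It checks that reading from the fused state gives the same result as causal attention computed directly.

```python
import numpy as np

def causal_linear_attention(q, k, v):
    """Reference: causal attention scores without softmax, computed all at once."""
    scores = q @ k.T                      # (seq, seq) dot products q_t . k_i
    mask = np.tril(np.ones_like(scores))  # keep only positions i <= t
    return (scores * mask) @ v            # (seq, d_v)

def recurrent_with_kv_state(q, k, v):
    """Same result, but only one fused key-value state matrix is cached."""
    d_k, d_v = k.shape[1], v.shape[1]
    state = np.zeros((d_k, d_v))              # illustrative fused K/V cache
    outputs = []
    for t in range(q.shape[0]):
        state = state + np.outer(k[t], v[t])  # fold the new key/value pair in
        outputs.append(q[t] @ state)          # read out with the current query
    return np.stack(outputs)

rng = np.random.default_rng(0)
seq_len, head_dim = 5, 4
q, k, v = (rng.standard_normal((seq_len, head_dim)) for _ in range(3))

direct = causal_linear_attention(q, k, v)
recurrent = recurrent_with_kv_state(q, k, v)
print(np.allclose(direct, recurrent))  # True
```

In this formulation the state stays (head_dim, head_dim) no matter how many tokens have been processed, which is where the memory saving of a fused representation comes from.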

Example

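>>> from eformer.escale import PartitionAxis
>>> from easydel.layers.caching.lightning.cache import (
...     LightningCache,
...     LightningCacheMetaData,
... )
>>> partition_axis = PartitionAxis()  # default sharding axes
>>> metadata = LightningCacheMetaData.create(
...     partition_axis=partition_axis,
...     batch_size=2,
...     num_heads=16,
...     head_dim=64,
... )
>>> cache = LightningCache.init_cache(
...     num_hidden_layers=12,
...     metadata=metadata,
... )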
class easydel.layers.caching.lightning.cache.LightningCache(views: list[easydel.layers.caching.lightning.cache.LightningCacheView | None])

Bases: BaseCache

Multi-layer Lightning attention cache container.

Holds one LightningCacheView (or None) per model layer and manages them together as a single Lightning attention cache.

views

Per-layer cache views. None for uninitialized layers.

Type

list[LightningCacheView | None]

classmethod from_dict(data: dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

classmethod init_cache(num_hidden_layers: int, metadata: LightningCacheMetaData) → LightningCache

Initialize Lightning cache for all model layers.

Creates cache views for each layer with consistent configuration. Views are initialized with placeholders; actual tensors are allocated on first use.

Parameters
  • num_hidden_layers (int) – Number of layers in the model.

  • metadata (LightningCacheMetaData) – Cache configuration.

Returns

Initialized cache with views for all layers.

Return type

LightningCache

classmethod init_empty(num_hidden_layers: int) → LightningCache

Initialize empty Lightning cache structure.

Creates a cache container with None placeholders for all layers. Useful for gradual cache building or testing.

Parameters

num_hidden_layers (int) – Number of layer slots to create.

Returns

Empty cache structure.

Return type

LightningCache
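The difference between init_cache and init_empty can be sketched with plain dataclasses. MiniMeta, MiniView and MiniCache below are hypothetical stand-ins, not EasyDeL classes, and passing each layer's index to its view is an assumption suggested by the layer_index field on LightningCacheView.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class MiniMeta:  # hypothetical stand-in for LightningCacheMetaData
    batch_size: Optional[int] = None
    num_heads: Optional[int] = None
    head_dim: Optional[int] = None

@dataclass(frozen=True)
class MiniView:  # hypothetical stand-in for LightningCacheView
    key_value: Optional[object]  # placeholder until first use
    metadata: MiniMeta
    layer_index: Optional[int] = None

    @classmethod
    def init(cls, metadata, layer_index=None):
        # No tensor is allocated yet; its shape is decided on first write.
        return cls(key_value=None, metadata=metadata, layer_index=layer_index)

@dataclass(frozen=True)
class MiniCache:  # hypothetical stand-in for LightningCache
    views: list

    @classmethod
    def init_cache(cls, num_hidden_layers, metadata):
        # One configured view per layer, all sharing the same metadata.
        return cls(views=[MiniView.init(metadata, layer_index=i)
                          for i in range(num_hidden_layers)])

    @classmethod
    def init_empty(cls, num_hidden_layers):
        # Only the layer slots exist; views are filled in later.
        return cls(views=[None] * num_hidden_layers)

meta = MiniMeta(batch_size=2, num_heads=16, head_dim=64)
full = MiniCache.init_cache(num_hidden_layers=12, metadata=meta)
empty = MiniCache.init_empty(num_hidden_layers=12)
print(len(full.views), full.views[3].layer_index, full.views[3].key_value)  # 12 3 None
print(empty.views[:2])  # [None, None]
```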

replace(**kwargs)

Creates a new instance with specified fields replaced.

to_dict() → dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.
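The replace, to_dict/from_dict and to_json/from_json helpers follow the usual immutable-dataclass pattern. The sketch below reproduces that pattern with the standard library (dataclasses.replace, dataclasses.asdict, json) on a hypothetical MiniMeta class; it is an analogy only, and EasyDeL's own PyTree serialization may encode fields such as arrays or PartitionAxis differently.

```python
import dataclasses
import json
from typing import Optional

@dataclasses.dataclass(frozen=True)
class MiniMeta:  # hypothetical stand-in for a cache PyTree
    batch_size: Optional[int] = None
    num_heads: Optional[int] = None
    head_dim: Optional[int] = None

    def replace(self, **kwargs):
        # Returns a new instance; the original is left untouched.
        return dataclasses.replace(self, **kwargs)

    def to_dict(self):
        return dataclasses.asdict(self)

    def to_json(self, **kwargs):
        return json.dumps(self.to_dict(), **kwargs)

    @classmethod
    def from_dict(cls, data):
        return cls(**data)

    @classmethod
    def from_json(cls, json_str):
        return cls.from_dict(json.loads(json_str))

meta = MiniMeta(batch_size=2, num_heads=16, head_dim=64)
bigger = meta.replace(batch_size=8)
restored = MiniMeta.from_json(meta.to_json())
print(meta.batch_size, bigger.batch_size)  # 2 8
print(restored == meta)                    # True
```

Returning a new object from replace instead of mutating in place is what lets cache objects be passed through jitted JAX functions as ordinary PyTrees.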

views: list[easydel.layers.caching.lightning.cache.LightningCacheView | None]
class easydel.layers.caching.lightning.cache.LightningCacheMetaData(partition_axis: PartitionAxis, batch_size: int | None, num_heads: int | None, head_dim: int | None, key_heads: int | None, value_heads: int | None, key_dim: int | None, value_dim: int | None)

Bases: BaseCacheMetadata

Metadata configuration for Lightning attention cache.

Stores configuration parameters specific to Lightning attention, which uses a unified key-value representation. Similar to standard transformer cache but optimized for Lightning’s memory access patterns.

partition_axis

Axis configuration for tensor partitioning. Defines how tensors are sharded across devices.

Type

es.PartitionAxis

batch_size

Number of sequences in batch. None allows dynamic batch sizes.

Type

int | None

num_heads

Number of attention heads. Used for standard multi-head attention.

Type

int | None

head_dim

Dimension of each attention head. Defines the feature size per head.

Type

int | None

key_heads

Number of key heads. For multi-query or grouped-query attention.

Type

int | None

value_heads

Number of value heads. For multi-query or grouped-query attention.

Type

int | None

key_dim

Dimension of key projections. Can differ from head_dim for asymmetric attention.

Type

int | None

value_dim

Dimension of value projections. Can differ from head_dim for asymmetric attention.

Type

int | None
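When key_heads or value_heads is smaller than num_heads (multi-query or grouped-query attention), each key/value head is shared by a group of query heads. A common way to line the shapes up is to repeat each KV head num_heads // key_heads times along the head axis. The NumPy sketch below shows that convention; expand_kv_heads is a hypothetical helper, it assumes num_heads is a multiple of the KV head count, and it is not taken from EasyDeL's kernels.

```python
import numpy as np

def expand_kv_heads(kv, num_heads):
    """Repeat KV heads so [batch, seq, kv_heads, dim] becomes [batch, seq, num_heads, dim]."""
    kv_heads = kv.shape[2]
    if num_heads % kv_heads != 0:
        raise ValueError("num_heads must be a multiple of the number of KV heads")
    group_size = num_heads // kv_heads
    # Consecutive query heads share one KV head: heads 0..g-1 use KV head 0, etc.
    return np.repeat(kv, group_size, axis=2)

batch, seq_len, num_heads, key_heads, key_dim = 2, 3, 16, 4, 64
keys = np.random.default_rng(0).standard_normal((batch, seq_len, key_heads, key_dim))
expanded = expand_kv_heads(keys, num_heads)
print(expanded.shape)                                    # (2, 3, 16, 64)
print(np.array_equal(expanded[:, :, 5], keys[:, :, 1]))  # True: query head 5 uses KV head 1
```

Multi-query attention is the special case key_heads = 1, where every query head reads the same key head.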

batch_size: int | None
classmethod create(partition_axis: PartitionAxis, batch_size: int | None = None, num_heads: int | None = None, head_dim: int | None = None, key_heads: int | None = None, value_heads: int | None = None, key_dim: int | None = None, value_dim: int | None = None) → LightningCacheMetaData

Create and validate Lightning cache metadata.

Factory method for creating Lightning cache configuration. Unlike the standard transformer cache, every parameter except partition_axis is optional, because Lightning handles unified KV tensors.

Parameters
  • partition_axis (es.PartitionAxis) – Tensor partitioning configuration.

  • batch_size (int | None) – Batch size for cache allocation. None for dynamic batching.

  • num_heads (int | None) – Number of attention heads. Defaults to None for flexible head configuration.

  • head_dim (int | None) – Dimension per attention head. Defaults to None for flexible dimensions.

  • key_heads (int | None) – Number of key heads for MQA/GQA. Defaults to None (same as num_heads).

  • value_heads (int | None) – Number of value heads for MQA/GQA. Defaults to None (same as num_heads).

  • key_dim (int | None) – Key projection dimension. Defaults to None (same as head_dim).

  • value_dim (int | None) – Value projection dimension. Defaults to None (same as head_dim).

Returns

Configured metadata instance.

Return type

LightningCacheMetaData

Note

Lightning attention’s unified representation means some parameters may be handled differently than standard cache.
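The parameter descriptions say that key_heads and value_heads fall back to num_heads, and key_dim and value_dim fall back to head_dim, when left as None. A hypothetical helper that makes those fallbacks explicit might look like this; resolve_lightning_dims is not part of EasyDeL, and this page does not say whether create fills the fields in itself or stores them as None for later resolution.

```python
from typing import NamedTuple, Optional

class ResolvedDims(NamedTuple):
    key_heads: Optional[int]
    value_heads: Optional[int]
    key_dim: Optional[int]
    value_dim: Optional[int]

def resolve_lightning_dims(num_heads=None, head_dim=None, key_heads=None,
                           value_heads=None, key_dim=None, value_dim=None):
    """Hypothetical: apply the documented 'None means same as ...' fallbacks."""
    return ResolvedDims(
        key_heads=num_heads if key_heads is None else key_heads,
        value_heads=num_heads if value_heads is None else value_heads,
        key_dim=head_dim if key_dim is None else key_dim,
        value_dim=head_dim if value_dim is None else value_dim,
    )

print(resolve_lightning_dims(num_heads=16, head_dim=64))
# ResolvedDims(key_heads=16, value_heads=16, key_dim=64, value_dim=64)
print(resolve_lightning_dims(num_heads=16, head_dim=64, key_heads=4))
# ResolvedDims(key_heads=4, value_heads=16, key_dim=64, value_dim=64)
```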

classmethod from_dict(data: dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

head_dim: int | None
key_dim: int | None
key_heads: int | None
num_heads: int | None
partition_axis: PartitionAxis
replace(**kwargs)

Creates a new instance with specified fields replaced.

to_dict() → dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.

value_dim: int | None
value_heads: int | None
class easydel.layers.caching.lightning.cache.LightningCacheView(key_value: jaxtyping.Float[Array, 'batch seq_len num_heads head_dim'] | eformer.jaximus._imus.ImplicitArray | None, metadata: LightningCacheMetaData, layer_index: int | None = None)

Bases: BaseCacheView

Single-layer cache view for Lightning attention.

Manages the unified key-value cache for one layer using Lightning’s optimized representation. Unlike standard caches that store keys and values separately, Lightning combines them for better memory efficiency.

key_value

Unified key-value tensor. Lightning's special representation combining K and V.

Type

Float[Array, 'batch seq_len num_heads head_dim'] | ImplicitArray | None

metadata

Static cache configuration.

Type

LightningCacheMetaData

layer_index

Index of this layer in the model.

Type

int | None

Note

The unified representation requires special handling during concatenation and may not be compatible with standard attention.

concatenate_to_cache(query: Float[Array, 'batch query_len num_heads head_dim'], key: Float[Array, 'batch query_len num_key_heads key_dim'], value: Float[Array, 'batch query_len num_value_heads value_dim'], attention_mask: jaxtyping.Bool[Array, 'batch 1 query_len seq_len'] | jaxtyping.Float[Array, 'batch 1 query_len seq_len'], kv_sharding: PartitionSpec, quantizer: object, causal_mask: jaxtyping.Bool[Array, 'batch 1 query_len seq_len'] | bool | None = None, token_type_ids: jaxtyping.Int[Array, 'batch query_len'] | None = None) → tuple[jaxtyping.Float[Array, 'batch seq_len num_key_heads key_dim'], jaxtyping.Float[Array, 'batch seq_len num_value_heads value_dim'], jaxtyping.Bool[Array, 'batch 1 query_len seq_len']]

Update cache with new key/value states for Lightning attention.

Concatenates new key and value states to the cache using Lightning’s unified KV representation. This method is called during each forward pass to update the cache with newly computed states.

Note: The current implementation needs refactoring: it references attributes (self.index, self.key, self.value) that are not defined on this class. A Lightning-specific implementation would instead operate on the unified self.key_value tensor.

Parameters
  • query – Query tensor with shape [batch, query_len, num_heads, head_dim]. Used to determine update dimensions.

  • key – Key tensor with shape [batch, query_len, num_key_heads, key_dim]. New keys to add to the cache.

  • value – Value tensor with shape [batch, query_len, num_value_heads, value_dim]. New values to add to the cache.

  • attention_mask – Boolean or float mask with shape [batch, 1, query_len, seq_len]. Defines which positions can attend to which.

  • kv_sharding – JAX PartitionSpec for sharding the KV cache.

  • quantizer – Quantization function for cache compression.

  • causal_mask – Optional causal mask for autoregressive attention. Can be boolean array or boolean value.

  • token_type_ids – Optional token type IDs for segment-level masking. Shape [batch, query_len].

Returns

A tuple containing:

  • Updated key cache: Float[Array, "batch seq_len num_key_heads key_dim"]

  • Updated value cache: Float[Array, "batch seq_len num_value_heads value_dim"]

  • Updated attention mask: Bool[Array, "batch 1 query_len seq_len"]

Return type

tuple

Raises

NotImplementedError – Current implementation needs refactoring to properly use Lightning’s unified KV representation.
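For comparison, the separate-buffer update implied by the attributes named in the Note (an index plus separate key and value caches) usually writes the new tokens into preallocated buffers at the current index and then builds a mask that is causal and hides slots not yet filled. The NumPy sketch below shows that pattern; write_to_cache is a hypothetical name, and the scalar index shared across the batch and the mask convention (True = may attend) are assumptions. It does not implement Lightning's unified key_value path and ignores the attention_mask, kv_sharding and quantizer arguments.

```python
import numpy as np

def write_to_cache(cache_k, cache_v, index, key, value):
    """Write query_len new tokens at position index; return buffers, new index and mask."""
    batch, max_len = cache_k.shape[:2]
    query_len = key.shape[1]
    cache_k = cache_k.copy()
    cache_v = cache_v.copy()
    cache_k[:, index:index + query_len] = key    # analogous to jax.lax.dynamic_update_slice
    cache_v[:, index:index + query_len] = value
    # Query i sits at absolute position index + i and may see slots 0..index + i.
    q_pos = index + np.arange(query_len)[:, None]  # (query_len, 1)
    k_pos = np.arange(max_len)[None, :]            # (1, max_len)
    mask = k_pos <= q_pos                          # causal; also hides unfilled slots
    mask = np.broadcast_to(mask, (batch, 1, query_len, max_len))
    return cache_k, cache_v, index + query_len, mask

batch, max_len, heads, dim = 2, 8, 4, 16
cache_k = np.zeros((batch, max_len, heads, dim))
cache_v = np.zeros_like(cache_k)
rng = np.random.default_rng(0)

# Prefill 3 tokens, then decode 1 token.
k3, v3 = rng.standard_normal((2, batch, 3, heads, dim))
cache_k, cache_v, index, mask = write_to_cache(cache_k, cache_v, 0, k3, v3)
k1, v1 = rng.standard_normal((2, batch, 1, heads, dim))
cache_k, cache_v, index, mask = write_to_cache(cache_k, cache_v, index, k1, v1)
print(index, mask.shape)          # 4 (2, 1, 1, 8)
print(mask[0, 0, 0].astype(int))  # [1 1 1 1 0 0 0 0]
```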

classmethod from_dict(data: dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

classmethod init(metadata: LightningCacheMetaData, layer_index: int | None = None) → LightningCacheView

Initialize a Lightning cache view for a single layer.

Creates a cache view with a None placeholder for the unified KV tensor. The actual tensor is allocated on first use to allow for dynamic shapes.

Parameters
  • metadata (LightningCacheMetaData) – Cache configuration.

  • layer_index (int | None) – Layer index in the model.

Returns

Initialized view with None placeholder.

Return type

LightningCacheView
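Because init leaves key_value as None and defers allocation to first use, the first write has to decide the buffer's shape. One way to express that deferral is sketched below: the buffer takes its shape from the first incoming tensor, and later writes append along the sequence axis. LazyView and append are hypothetical names, and NumPy concatenation stands in for whatever update the real kernels perform.

```python
from dataclasses import dataclass, replace
from typing import Optional

import numpy as np

@dataclass(frozen=True)
class LazyView:  # hypothetical stand-in for LightningCacheView
    key_value: Optional[np.ndarray] = None  # None until the first write
    layer_index: Optional[int] = None

    def append(self, new_kv):
        """Return a new view with new_kv ([batch, seq, heads, dim]) added."""
        if self.key_value is None:
            # First use: batch, heads and dim are taken from the input itself.
            return replace(self, key_value=new_kv)
        return replace(self, key_value=np.concatenate([self.key_value, new_kv], axis=1))

view = LazyView(layer_index=0)
view = view.append(np.ones((2, 3, 16, 64)))  # prefill: 3 tokens
view = view.append(np.ones((2, 1, 16, 64)))  # decode: 1 token
print(view.key_value.shape)  # (2, 4, 16, 64)
```

Growing the buffer by concatenation changes its shape on every step, which under jax.jit triggers recompilation; that is one reason fixed-size, preallocated buffers are common for decoding.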

key_value: jaxtyping.Float[Array, 'batch seq_len num_heads head_dim'] | eformer.jaximus._imus.ImplicitArray | None
layer_index: int | None = None
metadata: LightningCacheMetaData
replace(**kwargs)

Creates a new instance with specified fields replaced.

to_dict() → dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.

class easydel.layers.caching.lightning.cache.LightningMetadata

Bases: BaseRunTimeMetadata

Runtime metadata for Lightning attention cache operations.

Placeholder class for future Lightning-specific runtime metadata. It currently defines no fields.