easydel.layers.caching.lightning.lightning_cache#
- class easydel.layers.caching.lightning.lightning_cache.LightningCache(views: 'tp.List[tp.Optional[LightningCacheView]]')[source]#
Bases: BaseCache
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- classmethod init_cache(num_hidden_layers: int, metadata: LightningCacheMetaData)[source]#
Initialize a complete cache with views for all layers.
- Parameters
num_hidden_layers – Number of hidden layers to create views for
metadata – Configuration metadata
- Returns
Fully initialized cache instance
- classmethod init_empty(num_hidden_layers)[source]#
Initialize an empty cache container.
- Parameters
num_hidden_layers – Number of hidden layers to reserve view slots for
- Returns
Cache instance with uninitialized views
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- views: List[Optional[LightningCacheView]]#
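The difference between init_empty and init_cache can be illustrated with a minimal, library-free sketch. SketchCache and SketchView are hypothetical stand-ins, not EasyDeL classes, and passing the loop index as layer_index is an assumption; the sketch only mirrors the documented shape of views (a list with one optional view per layer).

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class SketchView:
    """Hypothetical stand-in for LightningCacheView."""
    layer_index: Optional[int] = None


@dataclass
class SketchCache:
    """Hypothetical stand-in for LightningCache: one optional view per layer."""
    views: List[Optional[SketchView]]

    @classmethod
    def init_empty(cls, num_hidden_layers: int) -> "SketchCache":
        # Every layer gets a slot, but no view has been allocated yet.
        return cls(views=[None for _ in range(num_hidden_layers)])

    @classmethod
    def init_cache(cls, num_hidden_layers: int) -> "SketchCache":
        # Allocate one view per layer up front (layer_index assignment is assumed).
        return cls(views=[SketchView(layer_index=i) for i in range(num_hidden_layers)])


empty = SketchCache.init_empty(4)
full = SketchCache.init_cache(4)
```

In this sketch both containers hold four slots; only the second one has a view in each slot.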
- class easydel.layers.caching.lightning.lightning_cache.LightningCacheMetaData(partition_axis: PartitionAxis, batch_size: Optional[int], num_heads: Optional[int], head_dim: Optional[int], key_heads: Optional[int], value_heads: Optional[int], key_dim: Optional[int], value_dim: Optional[int])[source]#
Bases: BaseCacheMetadata
Metadata for transformer cache configuration.
- batch_size: Optional[int]#
- classmethod create(partition_axis: PartitionAxis, batch_size: Optional[int] = None, num_heads: Optional[int] = None, head_dim: Optional[int] = None, key_heads: Optional[int] = None, value_heads: Optional[int] = None, key_dim: Optional[int] = None, value_dim: Optional[int] = None) LightningCacheMetaData[source]#
Create a LightningCacheMetaData instance with validation.
- Returns
LightningCacheMetaData instance
- Raises
ValueError – If required parameters are missing or invalid.
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- head_dim: Optional[int]#
- key_dim: Optional[int]#
- key_heads: Optional[int]#
- num_heads: Optional[int]#
- partition_axis: PartitionAxis#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- value_dim: Optional[int]#
- value_heads: Optional[int]#
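The create, replace, and serialization members follow a common pattern for immutable configuration objects. The sketch below reproduces that pattern with the standard library only. SketchMetaData is a hypothetical stand-in that omits partition_axis, and the validation rule (any provided value must be a positive integer) is an assumption, since this page does not state which checks create performs.

```python
import json
from dataclasses import asdict, dataclass, replace
from typing import Any, Dict, Optional


@dataclass(frozen=True)
class SketchMetaData:
    """Hypothetical stand-in for LightningCacheMetaData (partition_axis omitted)."""
    batch_size: Optional[int] = None
    num_heads: Optional[int] = None
    head_dim: Optional[int] = None

    @classmethod
    def create(cls, **fields: Optional[int]) -> "SketchMetaData":
        # Assumed rule: any value that is provided must be a positive integer.
        for name, value in fields.items():
            if value is not None and (not isinstance(value, int) or value <= 0):
                raise ValueError(f"{name} must be a positive integer, got {value!r}")
        return cls(**fields)

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)

    def to_json(self, **kwargs: Any) -> str:
        return json.dumps(self.to_dict(), **kwargs)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "SketchMetaData":
        return cls(**data)

    @classmethod
    def from_json(cls, json_str: str) -> "SketchMetaData":
        return cls.from_dict(json.loads(json_str))


meta = SketchMetaData.create(batch_size=2, num_heads=8, head_dim=64)
wider = replace(meta, num_heads=16)  # new instance; meta itself is unchanged
restored = SketchMetaData.from_json(meta.to_json())
```

Because the stand-in is frozen, replace returns a new object instead of mutating the original, which matches the documented behaviour of replace(**kwargs).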
- class easydel.layers.caching.lightning.lightning_cache.LightningCacheView(key_value: 'tp.Union[cx.Array, ImplicitArray]', metadata: 'LightningCacheMetaData', layer_index: 'tp.Optional[int]' = None)[source]#
Bases: BaseCacheView
- concatenate_to_cache(query: Union[Array, ndarray, bool, number], key: Union[Array, ndarray, bool, number], value: Union[Array, ndarray, bool, number], attention_mask: Union[Array, ndarray, bool, number], kv_sharding: PartitionSpec, quantizer: object, causal_mask: Optional[Union[Array, ndarray, bool, number, bool]] = None, token_type_ids: Optional[Union[Array, ndarray, bool, number]] = None) Tuple[Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number]][source]#
Updates the KV cache with new key/value states and adjusts the attention mask.
Internal helper function used when KV caching is enabled.
- Parameters
query (Array) – Current query states.
key (Array) – Current key states.
value (Array) – Current value states.
attention_mask (Array) – Base attention mask.
kv_sharding (PartitionSpec) – Partition spec for the key/value states.
quantizer (object) – Quantizer object for the key/value states.
causal_mask (tp.Optional[Array], optional) – Causal mask. Defaults to None.
token_type_ids (tp.Optional[Array], optional) – Token type IDs for segment-based masking. Defaults to None.
- Returns
Updated key cache tensor.
Updated value cache tensor.
Updated attention mask reflecting the cached sequence length.
- Return type
tp.Tuple[Array, Array, Array]
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- classmethod init(metadata: LightningCacheMetaData, layer_index: Optional[int] = None)[source]#
Initialize a new cache view instance.
- Parameters
metadata – Configuration metadata for the cache
layer_index – Index of the layer this view belongs to. Defaults to None.
- Returns
Initialized cache view instance
- layer_index: Optional[int] = None#
- metadata: LightningCacheMetaData#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
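The update described for concatenate_to_cache (write the new key/value states into the cache, then adjust the attention mask to the cached length) can be sketched generically. This is not EasyDeL's implementation: the function name, the separate key and value buffers, the explicit index argument, and the array layout are all assumptions, and the sketch uses NumPy to stay dependency-light. In JAX the slice assignments would be functional updates such as jax.lax.dynamic_update_slice.

```python
import numpy as np


def update_kv_cache(cached_key, cached_value, key, value, attention_mask, index):
    """Write new key/value states at `index` and mask out unwritten slots.

    cached_key, cached_value: (batch, max_len, heads, dim) preallocated buffers
    key, value:               (batch, new_len, heads, dim) current states
    attention_mask:           (batch, max_len) boolean, True where attention is allowed
    """
    end = index + key.shape[1]
    # Copy so the caller's buffers are not modified, mimicking a functional update.
    cached_key = cached_key.copy()
    cached_value = cached_value.copy()
    cached_key[:, index:end] = key
    cached_value[:, index:end] = value
    # Only positions that have been written so far may be attended to.
    filled = np.arange(cached_key.shape[1]) < end
    return cached_key, cached_value, attention_mask & filled[None, :]


batch, max_len, heads, dim = 1, 8, 2, 4
k_cache = np.zeros((batch, max_len, heads, dim))
v_cache = np.zeros((batch, max_len, heads, dim))
mask = np.ones((batch, max_len), dtype=bool)
new_k = np.ones((batch, 3, heads, dim))
new_v = np.full((batch, 3, heads, dim), 2.0)
k_cache, v_cache, mask = update_kv_cache(k_cache, v_cache, new_k, new_v, mask, index=0)
```

After the call, the first three cache positions hold the new states and the returned mask permits attention to those three positions only.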
- class easydel.layers.caching.lightning.lightning_cache.LightningMetadata[source]#
Bases: BaseRunTimeMetadata