easydel.layers.caching.transformer.transformer_cache
- class easydel.layers.caching.transformer.transformer_cache.TransformerCache(views: 'tp.List[tp.Optional[TransformerCacheView]]')
Bases: BaseCache
- classmethod from_dict(data: Dict[str, Any]) -> T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) -> T
Deserializes a JSON string into a PyTree object.
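- classmethod init_cache(mesh: Mesh, metadata: TransformerCacheMetaData, partition_manager: PartitionManager, dtype: Optional[dtype] = None, starts: Optional[Array] = None, quantizer: Optional[object] = None)
Initialize a complete cache with views for all layers.
- Parameters
mesh – Device mesh on which the cache arrays are placed.
metadata – Configuration metadata.
partition_manager – Partition manager that determines how the cache is sharded.
dtype – Optional data type for the cached keys and values.
starts – Optional array of start positions.
quantizer – Optional quantizer applied to the cache.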
- Returns
Fully initialized cache instance
- classmethod init_empty(num_hidden_layers)
Initialize an empty cache container.
- Parameters
num_hidden_layers – Number of hidden layers in the model.
- Returns
Cache instance with uninitialized views
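A minimal sketch of the container model implied by the `views: List[Optional[TransformerCacheView]]` field: `init_empty` is assumed to reserve one unset slot per layer, while `init_cache` fills every slot. `ToyCache` and `ToyView` are hypothetical stand-ins, not EasyDeL classes, and they skip meshes, sharding, dtypes, and quantization entirely.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ToyView:
    """Hypothetical stand-in for a single layer's cache view."""
    layer_index: int


@dataclass
class ToyCache:
    """Hypothetical stand-in for TransformerCache: one optional view per layer."""
    views: List[Optional[ToyView]] = field(default_factory=list)

    @classmethod
    def init_empty(cls, num_hidden_layers: int) -> "ToyCache":
        # Assumption: reserve one slot per layer but allocate no views yet.
        return cls(views=[None] * num_hidden_layers)

    @classmethod
    def init_cache(cls, num_hidden_layers: int) -> "ToyCache":
        # Assumption: build a view for every layer, tagged with its index.
        return cls(views=[ToyView(layer_index=i) for i in range(num_hidden_layers)])


empty = ToyCache.init_empty(4)
full = ToyCache.init_cache(4)
print(empty.views)                          # [None, None, None, None]
print([v.layer_index for v in full.views])  # [0, 1, 2, 3]
```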
- insert(other: TransformerCache, slot: int, quantizer: object, partition_manager: PartitionManager)
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- to_dict() -> Dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) -> str
Serializes the PyTree object to a JSON string.
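The `to_dict`/`from_dict`/`to_json`/`from_json`/`replace` helpers listed on every class in this module follow a common round-trip pattern. The sketch below reproduces that pattern with the standard `dataclasses` and `json` modules; `SerializableMixin` and `ToyRuntimeMetadata` are hypothetical, and the real PyTree base may handle array leaves differently, since JAX arrays are not plain-JSON values.

```python
import dataclasses
import json
from typing import Any, Dict, Optional, Type, TypeVar

T = TypeVar("T", bound="SerializableMixin")


class SerializableMixin:
    """Hypothetical stand-in for the serialization helpers listed above."""

    def to_dict(self) -> Dict[str, Any]:
        return dataclasses.asdict(self)

    @classmethod
    def from_dict(cls: Type[T], data: Dict[str, Any]) -> T:
        return cls(**data)

    def to_json(self, **kwargs) -> str:
        # Extra keyword arguments are forwarded to json.dumps (e.g. indent=2).
        return json.dumps(self.to_dict(), **kwargs)

    @classmethod
    def from_json(cls: Type[T], json_str: str) -> T:
        return cls.from_dict(json.loads(json_str))

    def replace(self: T, **kwargs) -> T:
        # Returns a new instance; the original object is left unchanged.
        return dataclasses.replace(self, **kwargs)


@dataclasses.dataclass(frozen=True)
class ToyRuntimeMetadata(SerializableMixin):
    postpadded: bool = False
    index: Optional[int] = None


meta = ToyRuntimeMetadata(index=3)
text = meta.to_json()
restored = ToyRuntimeMetadata.from_json(text)
updated = meta.replace(postpadded=True)
print(text)                                              # {"postpadded": false, "index": 3}
print(restored == meta, updated.postpadded, meta.postpadded)  # True True False
```

Because `replace` builds a new object instead of mutating the old one, the pattern matches JAX's functional style.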
- views: List[Optional[TransformerCacheView]]
- class easydel.layers.caching.transformer.transformer_cache.TransformerCacheMetaData(batch_size: int, sequence_length: int, num_hidden_layers: int, pad_token_id: int, num_heads: Optional[int], head_dim: Optional[int], key_heads: Optional[int], value_heads: Optional[int], key_dim: Optional[int], value_dim: Optional[int], update_causal_mask: bool, create_attention_bias: bool)
Bases: BaseCacheMetadata
Metadata for transformer cache configuration.
- batch_size: int
- classmethod create(batch_size: int, sequence_length: int, num_hidden_layers: int, pad_token_id: int, num_heads: Optional[int] = None, head_dim: Optional[int] = None, key_heads: Optional[int] = None, value_heads: Optional[int] = None, key_dim: Optional[int] = None, value_dim: Optional[int] = None, update_causal_mask: bool = True, create_attention_bias: bool = True) -> TransformerCacheMetaData
Create a TransformerCacheMetaData instance with validation.
- Parameters
batch_size – Size of the batch.
sequence_length – Length of the sequence.
num_hidden_layers – Number of hidden layers.
pad_token_id – ID of the padding token.
num_heads – Number of attention heads.
head_dim – Dimension of each head.
key_heads – Number of key heads.
value_heads – Number of value heads.
key_dim – Dimension of keys.
value_dim – Dimension of values.
update_causal_mask – Whether to update the causal mask.
create_attention_bias – Whether to create an attention bias.
- Returns
TransformerCacheMetaData instance
- Raises
ValueError – If required parameters are missing or invalid.
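`create` is documented as validating its inputs and raising `ValueError` when required parameters are missing or invalid. The exact rules are not stated here, so the sketch below assumes plausible ones: `key_heads`/`value_heads` fall back to `num_heads`, `key_dim`/`value_dim` fall back to `head_dim`, and every size must end up positive. `create_metadata` and `ToyCacheMetaData` are hypothetical names, and the real method may validate differently.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class ToyCacheMetaData:
    """Hypothetical, simplified stand-in for TransformerCacheMetaData."""
    batch_size: int
    sequence_length: int
    num_hidden_layers: int
    pad_token_id: int
    key_heads: int
    value_heads: int
    key_dim: int
    value_dim: int


def create_metadata(
    batch_size: int,
    sequence_length: int,
    num_hidden_layers: int,
    pad_token_id: int,
    num_heads: Optional[int] = None,
    head_dim: Optional[int] = None,
    key_heads: Optional[int] = None,
    value_heads: Optional[int] = None,
    key_dim: Optional[int] = None,
    value_dim: Optional[int] = None,
) -> ToyCacheMetaData:
    # Assumed fallback rules: per-tensor values default to the shared ones.
    key_heads = key_heads if key_heads is not None else num_heads
    value_heads = value_heads if value_heads is not None else num_heads
    key_dim = key_dim if key_dim is not None else head_dim
    value_dim = value_dim if value_dim is not None else head_dim

    sizes = {
        "batch_size": batch_size,
        "sequence_length": sequence_length,
        "num_hidden_layers": num_hidden_layers,
        "key_heads": key_heads,
        "value_heads": value_heads,
        "key_dim": key_dim,
        "value_dim": value_dim,
    }
    for name, value in sizes.items():
        if value is None:
            raise ValueError(f"{name} is required (directly or via num_heads/head_dim)")
        if value <= 0:
            raise ValueError(f"{name} must be positive, got {value}")

    return ToyCacheMetaData(pad_token_id=pad_token_id, **sizes)


meta = create_metadata(
    batch_size=2, sequence_length=128, num_hidden_layers=4, pad_token_id=0,
    num_heads=8, head_dim=64, key_heads=2, value_heads=2,
)
print(meta.key_heads, meta.key_dim)  # 2 64

try:
    create_metadata(batch_size=2, sequence_length=128, num_hidden_layers=4,
                    pad_token_id=0, num_heads=8)
except ValueError as err:
    print(err)  # key_dim is required (directly or via num_heads/head_dim)
```

Passing `key_heads=2` with `num_heads=8` mirrors a grouped-query attention setup, where fewer key/value heads than query heads shrink the cache.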
- create_attention_bias: bool
- classmethod from_dict(data: Dict[str, Any]) -> T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) -> T
Deserializes a JSON string into a PyTree object.
- head_dim: Optional[int]
- key_dim: Optional[int]
- key_heads: Optional[int]
- num_heads: Optional[int]
- pad_token_id: int
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- sequence_length: int
- to_dict() -> Dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) -> str
Serializes the PyTree object to a JSON string.
- update_causal_mask: bool
- value_dim: Optional[int]
- value_heads: Optional[int]
- class easydel.layers.caching.transformer.transformer_cache.TransformerCacheView(key: 'tp.Union[cx.Array, ImplicitArray]', value: 'tp.Union[cx.Array, ImplicitArray]', index: 'tp.Union[cx.Array, ImplicitArray]', starts: 'tp.Union[cx.Array, ImplicitArray]', metadata: 'TransformerCacheMetaData', layer_index: 'tp.Optional[int]' = None)
Bases: BaseCacheView
- concatenate_to_cache(query: Union[Array, ndarray, bool, number], key: Union[Array, ndarray, bool, number], value: Union[Array, ndarray, bool, number], quantizer: object, cache_metadata: Optional[TransformerMetadata], attention_mask: Union[Array, ndarray, bool, number], partition_manager: PartitionManager, causal_mask: Optional[Union[Array, ndarray, bool, number]] = None, token_type_ids: Optional[Union[Array, ndarray, bool, number]] = None) -> Tuple[Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number]]
Updates the KV cache functionally and returns the updated tensors along with the appropriate attention mask.
- Parameters
query – Current query states.
key – Current key states to add to the cache.
value – Current value states to add to the cache.
quantizer – Quantizer for the cache.
cache_metadata – Optional runtime metadata. If provided and it contains slot/length information, pooled caching is enabled.
attention_mask – Base attention mask.
partition_manager – Partition manager that determines how the updated tensors are sharded.
causal_mask – Optional causal mask.
token_type_ids – Optional token type IDs for segment masking.
- Returns
Updated key cache tensor (functional update).
Updated value cache tensor (functional update).
Final attention mask to be used (either original or calculated).
- Return type
Tuple[Array, Array, Array]
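The core of a functional KV-cache update is writing the new keys and values at each sequence's current write position, then building a mask over the slots filled so far. The sketch below shows that idea on plain Python lists, with one scalar standing in for each key or value vector. It is not EasyDeL's implementation: it ignores sharding, quantization, causal and segment masks, and pooled caching, and it returns the advanced write index explicitly (the view has an `index` field, but how it is updated is not shown on this page). `concatenate_to_cache_sketch` is a hypothetical name.

```python
from typing import List, Sequence, Tuple

# [batch][sequence_length]; one scalar stands in for a key/value vector.
Cache = List[List[float]]


def concatenate_to_cache_sketch(
    key_cache: Cache,
    value_cache: Cache,
    index: Sequence[int],
    new_keys: Cache,
    new_values: Cache,
) -> Tuple[Cache, Cache, List[int], List[List[int]]]:
    """Write new tokens at each row's current index without mutating the inputs."""
    out_k, out_v, out_index, mask = [], [], [], []
    for b, start in enumerate(index):
        n = len(new_keys[b])
        capacity = len(key_cache[b])
        if start + n > capacity:
            raise ValueError(f"row {b}: cache overflow ({start} + {n} > {capacity})")
        # Copy the row, then overwrite the slice [start, start + n).
        row_k = list(key_cache[b])
        row_v = list(value_cache[b])
        row_k[start:start + n] = new_keys[b]
        row_v[start:start + n] = new_values[b]
        end = start + n
        out_k.append(row_k)
        out_v.append(row_v)
        out_index.append(end)
        # Attend only to slots that have been written so far.
        mask.append([1 if pos < end else 0 for pos in range(capacity)])
    return out_k, out_v, out_index, mask


keys = [[0.0] * 4, [0.0] * 4]
values = [[0.0] * 4, [0.0] * 4]
k, v, idx, mask = concatenate_to_cache_sketch(
    keys, values, [0, 1], [[1.0, 2.0], [3.0]], [[5.0, 6.0], [7.0]]
)
print(k)     # [[1.0, 2.0, 0.0, 0.0], [0.0, 3.0, 0.0, 0.0]]
print(idx)   # [2, 2]
print(mask)  # [[1, 1, 0, 0], [1, 1, 0, 0]]
print(keys)  # [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]
```

Returning fresh rows instead of editing the inputs is what "functional update" means here; in JAX the same write is typically expressed with an indexed update that produces a new array.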
- classmethod from_dict(data: Dict[str, Any]) -> T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) -> T
Deserializes a JSON string into a PyTree object.
- classmethod init(mesh: Mesh, dtype: dtype, metadata: TransformerCacheMetaData, quantizer: object, partition_manager: PartitionManager, starts: Optional[Array] = None, layer_index: Optional[int] = None)
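Initialize a new cache view instance.
- Parameters
mesh – Device mesh on which the view's arrays are placed.
dtype – Data type of the cached keys and values.
metadata – Configuration metadata for the cache.
quantizer – Quantizer for the cache.
partition_manager – Partition manager that determines how the view is sharded.
starts – Optional array of start positions.
layer_index – Optional index of the layer this view belongs to.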
- Returns
Initialized cache view instance
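`init` has to allocate key and value buffers whose sizes follow from the metadata. Assuming a `(batch_size, sequence_length, heads, head_dim)` layout (this page does not document the actual layout), the sketch below computes the buffer shapes and the resulting memory footprint. `kv_buffer_shapes` and `kv_cache_bytes` are hypothetical helpers.

```python
from math import prod
from typing import Dict, Tuple


def kv_buffer_shapes(
    batch_size: int,
    sequence_length: int,
    key_heads: int,
    value_heads: int,
    key_dim: int,
    value_dim: int,
) -> Dict[str, Tuple[int, int, int, int]]:
    # Assumed layout: (batch, sequence, heads, head_dim) for both buffers.
    return {
        "key": (batch_size, sequence_length, key_heads, key_dim),
        "value": (batch_size, sequence_length, value_heads, value_dim),
    }


def kv_cache_bytes(num_hidden_layers: int, bytes_per_element: int, **shape_kwargs: int) -> int:
    # Every layer holds one key buffer and one value buffer of the shapes above.
    shapes = kv_buffer_shapes(**shape_kwargs)
    per_layer = sum(prod(shape) for shape in shapes.values()) * bytes_per_element
    return per_layer * num_hidden_layers


sizes = dict(batch_size=2, sequence_length=128, key_heads=2, value_heads=2, key_dim=64, value_dim=64)
shapes = kv_buffer_shapes(**sizes)
print(shapes["key"])              # (2, 128, 2, 64)
# bfloat16 uses 2 bytes per element.
print(kv_cache_bytes(4, 2, **sizes))  # 524288
```

With these assumed sizes, one layer needs 2 × 128 × (2 × 64 + 2 × 64) × 2 bytes = 131,072 bytes, so four layers take 524,288 bytes (512 KiB) in bfloat16.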
- property is_empty
- layer_index: Optional[int] = None
- metadata: TransformerCacheMetaData
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- to_dict() -> Dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) -> str
Serializes the PyTree object to a JSON string.
- class easydel.layers.caching.transformer.transformer_cache.TransformerMetadata(postpadded: bool = False, index: int | None = None)
Bases: BaseRunTimeMetadata
Holds optional runtime metadata for attention.
- classmethod from_dict(data: Dict[str, Any]) -> T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) -> T
Deserializes a JSON string into a PyTree object.
- postpadded: bool = False
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- to_dict() -> Dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) -> str
Serializes the PyTree object to a JSON string.