easydel.layers.caching.transformer.__init__


class easydel.layers.caching.transformer.__init__.TransformerCache(views: List[Optional[TransformerCacheView]])[source]#

Bases: BaseCache

classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

classmethod from_pure(pure, metadata)[source]#
classmethod init_cache(mesh: Mesh, metadata: TransformerCacheMetaData, partition_manager: PartitionManager, dtype: Optional[dtype] = None, starts: Optional[Array] = None, quantizer: Optional[object] = None)[source]#

Initialize a complete cache with views for all layers.

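Parameters
  • mesh – Device mesh on which the cache is created.

  • metadata – Configuration metadata.

  • partition_manager – Partition manager that handles sharding of the cache.

  • dtype – Optional data type for the cache arrays.

  • starts – Optional array of start indices.

  • quantizer – Optional quantizer for the cache.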

Returns

Fully initialized cache instance

classmethod init_empty(num_hidden_layers)[source]#

Initialize an empty cache container.

Parameters
  • num_hidden_layers – Number of hidden layers, i.e. the number of views the cache holds.

Returns

Cache instance with uninitialized views
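
As a rough illustration of what "uninitialized views" could look like, the sketch below uses a hypothetical stand-in class (`SketchCache`, not the EasyDeL class). It assumes one `None` slot per layer, which is consistent with the `views: List[Optional[TransformerCacheView]]` attribute but is not confirmed by this page.

```python
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class SketchCache:
    """Hypothetical stand-in for a per-layer cache container."""

    views: List[Optional[Any]]

    @classmethod
    def init_empty(cls, num_hidden_layers: int) -> "SketchCache":
        # Assumption: an "empty" cache holds one unset (None) slot per layer.
        return cls(views=[None for _ in range(num_hidden_layers)])


empty_cache = SketchCache.init_empty(num_hidden_layers=4)
# Layers whose view has not been initialized yet.
unset_layers = [i for i, view in enumerate(empty_cache.views) if view is None]
```

Under this assumption, a caller can tell which layers still need a view by checking for `None` entries.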

insert(other: TransformerCache, slot: int, quantizer: object, partition_manager: PartitionManager)[source]#
insert_index(index, slot: int, partition_manager: PartitionManager)[source]#
insert_starts(starts, slot: int, partition_manager: PartitionManager)[source]#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

to_pure()[source]#
views: List[Optional[TransformerCacheView]]#
class easydel.layers.caching.transformer.__init__.TransformerCacheMetaData(batch_size: int, sequence_length: int, num_hidden_layers: int, pad_token_id: int, num_heads: Optional[int], head_dim: Optional[int], key_heads: Optional[int], value_heads: Optional[int], key_dim: Optional[int], value_dim: Optional[int], update_causal_mask: bool, create_attention_bias: bool)[source]#

Bases: BaseCacheMetadata

Metadata for transformer cache configuration.

batch_size: int#
classmethod create(batch_size: int, sequence_length: int, num_hidden_layers: int, pad_token_id: int, num_heads: Optional[int] = None, head_dim: Optional[int] = None, key_heads: Optional[int] = None, value_heads: Optional[int] = None, key_dim: Optional[int] = None, value_dim: Optional[int] = None, update_causal_mask: bool = True, create_attention_bias: bool = True) TransformerCacheMetaData[source]#

Create a TransformerCacheMetaData instance with validation.

Parameters
  • batch_size – Size of the batch.

  • sequence_length – Length of the sequence.

  • num_hidden_layers – Number of hidden layers.

  • pad_token_id – ID of the padding token.

  • num_heads – Number of attention heads.

  • head_dim – Dimension of each head.

  • key_heads – Number of key heads.

  • value_heads – Number of value heads.

  • key_dim – Dimension of keys.

  • value_dim – Dimension of values.

  • update_causal_mask – Whether to update causal mask.

  • create_attention_bias – Whether to create attention bias.

Returns

TransformerCacheMetaData instance

Raises

ValueError – If required parameters are missing or invalid.
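
The "create with validation" pattern can be sketched with a plain dataclass. Everything here is a hypothetical stand-in: `SketchCacheMetaData` is not the EasyDeL class, it covers only a subset of the fields, and the positivity checks are assumed rules, since this page does not state what the real validation covers.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class SketchCacheMetaData:
    """Hypothetical stand-in; the real validation rules may differ."""

    batch_size: int
    sequence_length: int
    num_hidden_layers: int
    pad_token_id: int
    num_heads: Optional[int] = None
    head_dim: Optional[int] = None

    @classmethod
    def create(
        cls,
        batch_size: int,
        sequence_length: int,
        num_hidden_layers: int,
        pad_token_id: int,
        num_heads: Optional[int] = None,
        head_dim: Optional[int] = None,
    ) -> "SketchCacheMetaData":
        # Assumed rule: the required sizes must be positive.
        required = {
            "batch_size": batch_size,
            "sequence_length": sequence_length,
            "num_hidden_layers": num_hidden_layers,
        }
        for name, value in required.items():
            if value <= 0:
                raise ValueError(f"{name} must be positive, got {value}")
        return cls(
            batch_size=batch_size,
            sequence_length=sequence_length,
            num_hidden_layers=num_hidden_layers,
            pad_token_id=pad_token_id,
            num_heads=num_heads,
            head_dim=head_dim,
        )


meta_ok = SketchCacheMetaData.create(
    batch_size=2, sequence_length=128, num_hidden_layers=4, pad_token_id=0
)
```

Routing construction through a validating classmethod keeps invalid configurations from reaching cache allocation.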

create_attention_bias: bool#
classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

head_dim: Optional[int]#
key_dim: Optional[int]#
key_heads: Optional[int]#
num_heads: Optional[int]#
num_hidden_layers: int#
pad_token_id: int#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

sequence_length: int#
to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.
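
The serialization methods above suggest a dictionary/JSON round trip. The sketch below shows that idea with a hypothetical stand-in dataclass (`SketchMetaData`) built on the standard library. It is an assumption about the general shape of the behavior, not the EasyDeL implementation, and it covers only a few of the fields.

```python
import json
from dataclasses import asdict, dataclass
from typing import Any, Dict


@dataclass(frozen=True)
class SketchMetaData:
    """Hypothetical stand-in with plain integer fields."""

    batch_size: int
    sequence_length: int
    num_hidden_layers: int
    pad_token_id: int

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)

    def to_json(self, **kwargs: Any) -> str:
        # Extra keyword arguments are forwarded to json.dumps (an assumption).
        return json.dumps(self.to_dict(), **kwargs)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "SketchMetaData":
        return cls(**data)

    @classmethod
    def from_json(cls, json_str: str) -> "SketchMetaData":
        return cls.from_dict(json.loads(json_str))


meta = SketchMetaData(batch_size=2, sequence_length=128, num_hidden_layers=4, pad_token_id=0)
# Serialize to JSON and rebuild an equal instance from the string.
restored = SketchMetaData.from_json(meta.to_json())
```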

update_causal_mask: bool#
value_dim: Optional[int]#
value_heads: Optional[int]#
class easydel.layers.caching.transformer.__init__.TransformerCacheView(key: 'tp.Union[cx.Array, ImplicitArray]', value: 'tp.Union[cx.Array, ImplicitArray]', index: 'tp.Union[cx.Array, ImplicitArray]', starts: 'tp.Union[cx.Array, ImplicitArray]', metadata: 'TransformerCacheMetaData', layer_index: 'tp.Optional[int]' = None)[source]#

Bases: BaseCacheView

concatenate_to_cache(query: Union[Array, ndarray, bool, number], key: Union[Array, ndarray, bool, number], value: Union[Array, ndarray, bool, number], quantizer: object, cache_metadata: Optional[TransformerMetadata], attention_mask: Union[Array, ndarray, bool, number], partition_manager: PartitionManager, causal_mask: Optional[Union[Array, ndarray, bool, number]] = None, token_type_ids: Optional[Union[Array, ndarray, bool, number]] = None) Tuple[Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number]][source]#

Updates the KV cache functionally and returns the updated tensors along with the appropriate attention mask.

Parameters
  • query – Current query states.

  • key – Current key states to add to the cache.

  • value – Current value states to add to the cache.

  • cache_metadata – Optional metadata. If provided and contains slot/length info, enables pooled caching.

  • attention_mask – Base attention mask.

  • quantizer – Quantizer for the cache.

  • partition_manager – Partition manager that handles sharding of the cache.

  • causal_mask – Optional causal mask.

  • token_type_ids – Optional token type IDs for segment masking.

Returns

  • Updated key cache tensor (functional update).

  • Updated value cache tensor (functional update).

  • Final attention mask to be used (either original or calculated).

Return type

Tuple[Array, Array, Array]
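
To make "functional update" concrete, here is a minimal pure-Python sketch. The function `append_to_cache` and its parameters are hypothetical, flat lists stand in for the key and value tensors, and the mask rule (positions written so far are attendable) is an assumption. The real method works on arrays and also takes the attention and causal masks into account.

```python
from typing import List, Tuple


def append_to_cache(
    cached_keys: List[float],
    cached_values: List[float],
    index: int,
    new_keys: List[float],
    new_values: List[float],
) -> Tuple[List[float], List[float], List[bool]]:
    """Write new entries at `index`; return fresh buffers and a validity mask."""
    if len(new_keys) != len(new_values):
        raise ValueError("keys and values must have the same length")
    end = index + len(new_keys)
    if end > len(cached_keys):
        raise ValueError("new entries do not fit in the cache")
    # Build new lists instead of assigning into the inputs (functional update).
    updated_keys = cached_keys[:index] + new_keys + cached_keys[end:]
    updated_values = cached_values[:index] + new_values + cached_values[end:]
    # Assumed mask rule: every position before `end` may be attended to.
    mask = [position < end for position in range(len(cached_keys))]
    return updated_keys, updated_values, mask


old_keys = [0.0, 0.0, 0.0, 0.0]
old_values = [0.0, 0.0, 0.0, 0.0]
keys, values, mask = append_to_cache(old_keys, old_values, 0, [1.0, 2.0], [3.0, 4.0])
```

Because the inputs are never mutated, the caller keeps the old buffers and receives the new ones as return values, which is the property that matters under JAX tracing.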

classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

index: Union[Array, ndarray, bool, number, ImplicitArray]#
classmethod init(mesh: Mesh, dtype: dtype, metadata: TransformerCacheMetaData, quantizer: object, partition_manager: PartitionManager, starts: Optional[Array] = None, layer_index: Optional[int] = None)[source]#

Initialize a new cache view instance.

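Parameters
  • mesh – Device mesh on which the cache view is created.

  • dtype – Data type for the cache arrays.

  • metadata – Configuration metadata for the cache.

  • quantizer – Quantizer for the cache.

  • partition_manager – Partition manager that handles sharding of the cache.

  • starts – Optional array of start indices.

  • layer_index – Optional index of the layer this view belongs to.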

Returns

Initialized cache view instance

property is_empty#
key: Union[Array, ndarray, bool, number, ImplicitArray]#
layer_index: Optional[int] = None#
metadata: TransformerCacheMetaData#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

starts: Union[Array, ndarray, bool, number, ImplicitArray]#
to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

value: Union[Array, ndarray, bool, number, ImplicitArray]#
class easydel.layers.caching.transformer.__init__.TransformerMetadata(postpadded: bool = False, index: int | None = None)[source]#

Bases: BaseRunTimeMetadata

Holds optional runtime metadata for attention.

classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

index: int | None = None#
postpadded: bool = False#
replace(**kwargs)#

Creates a new instance with specified fields replaced.
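
The "new instance with specified fields replaced" behavior resembles `dataclasses.replace` from the standard library. The sketch below uses a hypothetical stand-in (`SketchRuntimeMetadata`) that borrows the `postpadded` and `index` fields from the signature above. Whether the real `replace` is built this way is an assumption.

```python
from dataclasses import dataclass, replace
from typing import Optional


@dataclass(frozen=True)
class SketchRuntimeMetadata:
    """Hypothetical stand-in mirroring the two documented fields."""

    postpadded: bool = False
    index: Optional[int] = None


original = SketchRuntimeMetadata()
# Copy the instance, overriding only `index`; `original` is left unchanged.
updated = replace(original, index=3)
```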

to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.