easydel.layers.caching.transformer_cache
- class easydel.layers.caching.transformer_cache.TransformerCache(views: List[Optional[easydel.layers.caching.transformer_cache.TransformerCacheView]])
Bases: Mapping
- from_tuple()
- classmethod init_layers_cache(num_hidden_layers: int, metadata: TransformerCacheMetaData, mesh: Mesh, quantizer: Optional[object] = None, dtype: Optional[dtype] = None, key_values_partition_specs: Optional[PartitionSpec] = None)
- items() → a set-like object providing a view on D's items
- keys() → a set-like object providing a view on D's keys
- replace(**kwargs)
- to_tuple()
- values() → an object providing a view on D's values
- views: List[Optional[TransformerCacheView]]
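The member list above (a `Mapping` base plus `replace`, `to_tuple` and `from_tuple`) describes a dataclass whose fields can also be read like a dictionary. It resembles what a "mappable" dataclass library such as `chex.dataclass` generates, although this page does not say which mechanism EasyDel uses. Judging from `init_layers_cache(num_hidden_layers, ...)`, `views` presumably holds one view per hidden layer. The sketch below reproduces that interface with the standard library only. `MappableDataclass`, `ViewSketch`, `CacheSketch` and `init_layers` are hypothetical names, not EasyDel's implementation. The sketch also ignores `metadata`, `mesh`, `quantizer`, `dtype` and the partition specs that the real constructor accepts.

```python
import dataclasses
from collections.abc import Mapping
from typing import Any, Callable, List, Optional


class MappableDataclass(Mapping):
    """Hypothetical mixin: exposes a dataclass's fields through the Mapping API."""

    @classmethod
    def _field_names(cls) -> List[str]:
        return [f.name for f in dataclasses.fields(cls)]

    def __getitem__(self, name: str) -> Any:
        if name not in self._field_names():
            raise KeyError(name)
        return getattr(self, name)

    def __iter__(self):
        return iter(self._field_names())

    def __len__(self) -> int:
        return len(self._field_names())

    def replace(self, **kwargs: Any):
        # Return a copy with some fields changed; the original object is untouched.
        return dataclasses.replace(self, **kwargs)

    def to_tuple(self) -> tuple:
        return tuple(getattr(self, name) for name in self._field_names())

    @classmethod
    def from_tuple(cls, values: tuple):
        return cls(*values)


@dataclasses.dataclass(frozen=True)
class ViewSketch(MappableDataclass):
    key: Any
    value: Any
    index: int
    layer_index: Optional[int] = None


@dataclasses.dataclass(frozen=True)
class CacheSketch(MappableDataclass):
    views: List[Optional[ViewSketch]]

    @classmethod
    def init_layers(cls, num_hidden_layers: int,
                    make_view: Callable[[int], ViewSketch]) -> "CacheSketch":
        # Hypothetical: build one view per hidden layer, tagged with its layer index.
        return cls(views=[make_view(i) for i in range(num_hidden_layers)])


cache = CacheSketch.init_layers(
    3, lambda i: ViewSketch(key=None, value=None, index=0, layer_index=i)
)
first = cache.views[0]
advanced = first.replace(index=5)
print(list(cache.keys()))           # ['views']
print(dict(first))                  # {'key': None, 'value': None, 'index': 0, 'layer_index': 0}
print(first.index, advanced.index)  # 0 5
```

Because `replace` returns a new object, a caller can advance one layer's view and rebuild the `views` list without mutating anything shared.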
- class easydel.layers.caching.transformer_cache.TransformerCacheMetaData(batch_size: int, sequence_length: int, num_heads: Optional[int], head_dim: Optional[int], key_heads: Optional[int], value_heads: Optional[int], key_dim: Optional[int], value_dim: Optional[int], update_causal_mask: bool, create_attention_bias: bool)
Bases: Mapping
Metadata for transformer cache configuration.
- batch_size: int
- classmethod create(batch_size: int, sequence_length: int, num_heads: Optional[int] = None, head_dim: Optional[int] = None, key_heads: Optional[int] = None, value_heads: Optional[int] = None, key_dim: Optional[int] = None, value_dim: Optional[int] = None, update_causal_mask: bool = True, create_attention_bias: bool = True) → TransformerCacheMetaData
Create a TransformerCacheMetaData instance with validation.
- Parameters
batch_size – Size of the batch.
sequence_length – Length of the sequence.
num_heads – Number of attention heads.
head_dim – Dimension of each head.
key_heads – Number of key heads.
value_heads – Number of value heads.
key_dim – Dimension of keys.
value_dim – Dimension of values.
update_causal_mask – Whether to update the causal mask.
create_attention_bias – Whether to create an attention bias.
- Returns
TransformerCacheMetaData instance
- Raises
ValueError – If required parameters are missing or invalid.
- create_attention_bias: bool
- from_tuple()
- head_dim: Optional[int]
- items() → a set-like object providing a view on D's items
- key_dim: Optional[int]
- key_heads: Optional[int]
- keys() → a set-like object providing a view on D's keys
- num_heads: Optional[int]
- replace(**kwargs)
- sequence_length: int
- to_tuple()
- update_causal_mask: bool
- value_dim: Optional[int]
- value_heads: Optional[int]
- values() → an object providing a view on D's values
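`create` is documented as validating its arguments and raising `ValueError` when required parameters are missing or invalid, but the exact rules are not spelled out here. The sketch below shows one plausible scheme under an explicit assumption: `key_heads` and `value_heads` fall back to `num_heads`, and `key_dim` and `value_dim` fall back to `head_dim`. This is what lets grouped-query attention pass fewer key/value heads than query heads. `MetaDataSketch` is a hypothetical stand-in; EasyDel's source is the authority on the real rules.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class MetaDataSketch:
    """Hypothetical stand-in mirroring TransformerCacheMetaData's fields."""
    batch_size: int
    sequence_length: int
    num_heads: Optional[int]
    head_dim: Optional[int]
    key_heads: int
    value_heads: int
    key_dim: int
    value_dim: int
    update_causal_mask: bool = True
    create_attention_bias: bool = True

    @classmethod
    def create(cls, batch_size: int, sequence_length: int,
               num_heads: Optional[int] = None, head_dim: Optional[int] = None,
               key_heads: Optional[int] = None, value_heads: Optional[int] = None,
               key_dim: Optional[int] = None, value_dim: Optional[int] = None,
               update_causal_mask: bool = True,
               create_attention_bias: bool = True) -> "MetaDataSketch":
        if batch_size <= 0 or sequence_length <= 0:
            raise ValueError("batch_size and sequence_length must be positive")
        # ASSUMPTION: per-tensor settings fall back to the shared num_heads / head_dim.
        resolved = {
            "key_heads": key_heads if key_heads is not None else num_heads,
            "value_heads": value_heads if value_heads is not None else num_heads,
            "key_dim": key_dim if key_dim is not None else head_dim,
            "value_dim": value_dim if value_dim is not None else head_dim,
        }
        for name, size in resolved.items():
            if size is None:
                raise ValueError(f"{name} is missing: pass it or num_heads/head_dim")
            if size <= 0:
                raise ValueError(f"{name} must be positive, got {size}")
        return cls(batch_size, sequence_length, num_heads, head_dim,
                   update_causal_mask=update_causal_mask,
                   create_attention_bias=create_attention_bias, **resolved)


mha = MetaDataSketch.create(batch_size=2, sequence_length=128, num_heads=8, head_dim=64)
gqa = MetaDataSketch.create(1, 16, num_heads=8, head_dim=64, key_heads=2, value_heads=2)
try:
    MetaDataSketch.create(1, 16)  # no head information at all
    missing_rejected = False
except ValueError:
    missing_rejected = True
print(mha.key_heads, mha.value_dim)  # 8 64
print(gqa.key_heads, gqa.num_heads)  # 2 8
print(missing_rejected)              # True
```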
- class easydel.layers.caching.transformer_cache.TransformerCacheView(key: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, eformer.jaximus._imus.ImplicitArray], value: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, eformer.jaximus._imus.ImplicitArray], index: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, eformer.jaximus._imus.ImplicitArray], metadata: easydel.layers.caching.transformer_cache.TransformerCacheMetaData, layer_index: Optional[int] = None)
Bases: Mapping
- from_tuple()
- classmethod init(metadata: TransformerCacheMetaData, quantizer: object, key_values_partition_specs: PartitionSpec, dtype: dtype, mesh: Mesh, layer_index: Optional[int] = None)
- items() → a set-like object providing a view on D's items
- keys() → a set-like object providing a view on D's keys
- layer_index: Optional[int] = None
- metadata: TransformerCacheMetaData
- replace(**kwargs)
- to_tuple()
- values() → an object providing a view on D's values
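A `TransformerCacheView` carries `key`, `value` and an `index` for a single layer. A common way such a view is used during autoregressive decoding is to write the new tokens' keys and values at the current `index` and then advance it. The NumPy sketch below illustrates that pattern under stated assumptions: the buffer layout `(batch, sequence_length, heads, head_dim)`, a plain-integer `index` (the real field is an array type), and the helpers `init_view` and `append_kv` are all hypothetical, not EasyDel's API. Under JAX, the slice assignment would typically become `jax.lax.dynamic_update_slice`, because JAX arrays are immutable.

```python
from dataclasses import dataclass, replace
from typing import Optional

import numpy as np


@dataclass(frozen=True)
class KVViewSketch:
    """Hypothetical single-layer KV cache view (not EasyDel's TransformerCacheView)."""
    key: np.ndarray    # assumed layout: (batch, sequence_length, key_heads, key_dim)
    value: np.ndarray  # assumed layout: (batch, sequence_length, value_heads, value_dim)
    index: int         # number of positions already filled
    layer_index: Optional[int] = None


def init_view(batch: int, max_len: int, heads: int, head_dim: int,
              dtype=np.float32, layer_index: Optional[int] = None) -> KVViewSketch:
    # Pre-allocate zero-filled buffers covering the whole sequence length.
    shape = (batch, max_len, heads, head_dim)
    return KVViewSketch(np.zeros(shape, dtype), np.zeros(shape, dtype), 0, layer_index)


def append_kv(view: KVViewSketch, new_key: np.ndarray,
              new_value: np.ndarray) -> KVViewSketch:
    """Write new_key/new_value (batch, t, heads, dim) at view.index; return a new view."""
    t = new_key.shape[1]
    start = view.index
    if start + t > view.key.shape[1]:
        raise ValueError("not enough room left in the cache")
    key, value = view.key.copy(), view.value.copy()  # leave the old view unchanged
    key[:, start:start + t] = new_key
    value[:, start:start + t] = new_value
    return replace(view, key=key, value=value, index=start + t)


view = init_view(batch=2, max_len=8, heads=4, head_dim=16, layer_index=0)
prompt = np.ones((2, 3, 4, 16), dtype=np.float32)     # 3 prompt tokens
view = append_kv(view, prompt, prompt)
step = np.full((2, 1, 4, 16), 2.0, dtype=np.float32)  # 1 decoded token
view = append_kv(view, step, step)
print(view.index)  # 4
```

Returning a fresh view instead of mutating the buffers in place mirrors the `replace(**kwargs)` method listed above, and it is also what immutable JAX arrays require.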