easydel.layers.caching.__init__

Contents

easydel.layers.caching.__init__#

class easydel.layers.caching.__init__.LightningCache(views: List[Optional[LightningCacheView]])[source]#

Bases: Mapping

from_tuple()#
classmethod init_empty(num_hidden_layers)[source]#
classmethod init_layers_cache(num_hidden_layers: int, metadata: LightningCacheMetaData)[source]#
items() a set-like object providing a view on D's items#
keys() a set-like object providing a view on D's keys#
replace(**kwargs)#
to_tuple()#
values() an object providing a view on D's values#
views: List[Optional[LightningCacheView]]#
class easydel.layers.caching.__init__.LightningCacheMetaData(batch_size: Optional[int], num_heads: Optional[int], head_dim: Optional[int], key_heads: Optional[int], value_heads: Optional[int], key_dim: Optional[int], value_dim: Optional[int])[source]#

Bases: Mapping

Metadata for Lightning cache configuration.

batch_size: Optional[int]#
classmethod create(batch_size: Optional[int] = None, num_heads: Optional[int] = None, head_dim: Optional[int] = None, key_heads: Optional[int] = None, value_heads: Optional[int] = None, key_dim: Optional[int] = None, value_dim: Optional[int] = None) LightningCacheMetaData[source]#

Create a LightningCacheMetaData instance with validation.

Returns

LightningCacheMetaData instance

Raises

ValueError – If required parameters are missing or invalid.

from_tuple()#
head_dim: Optional[int]#
items() a set-like object providing a view on D's items#
key_dim: Optional[int]#
key_heads: Optional[int]#
keys() a set-like object providing a view on D's keys#
num_heads: Optional[int]#
replace(**kwargs)#
to_tuple()#
value_dim: Optional[int]#
value_heads: Optional[int]#
values() an object providing a view on D's values#
class easydel.layers.caching.__init__.LightningCacheView(key_value: Union[Array, ndarray, bool, number, ImplicitArray], metadata: LightningCacheMetaData, layer_index: Optional[int] = None)[source]#

Bases: Mapping

from_tuple()#
classmethod init(metadata: LightningCacheMetaData, layer_index: Optional[int] = None)[source]#
items() a set-like object providing a view on D's items#
key_value: Union[Array, ndarray, bool, number, ImplicitArray]#
keys() a set-like object providing a view on D's keys#
layer_index: Optional[int] = None#
metadata: LightningCacheMetaData#
replace(**kwargs)#
to_tuple()#
values() an object providing a view on D's values#
class easydel.layers.caching.__init__.Mamba2Cache(views: List[Optional[easydel.layers.caching.mamba2_cache.Mamba2CacheView]])[source]#

Bases: Mapping

from_tuple()#
classmethod init_empty(num_hidden_layers)[source]#
classmethod init_layers_cache(num_hidden_layers: int, metadata: Mamba2CacheMetaData, dtype: Optional[dtype] = None, partition_specs: Optional[PartitionSpec] = None)[source]#
items() a set-like object providing a view on D's items#
keys() a set-like object providing a view on D's keys#
replace(**kwargs)#
reset() Mamba2Cache[source]#

Reset all cache views to their initial state.

Returns

Reset Mamba2Cache

to_tuple()#
update_conv_state(layer_idx: int, new_conv_state: Union[Array, ndarray, bool, number], cache_position: Union[Array, ndarray, bool, number]) Mamba2Cache[source]#

Update the convolutional state for a specific layer.

Parameters
  • layer_idx – Index of the layer to update

  • new_conv_state – New state to be inserted

  • cache_position – Position in the cache to update

Returns

Updated Mamba2Cache

update_seq(num)[source]#
update_ssm_state(layer_idx: int, new_ssm_state: Union[Array, ndarray, bool, number]) Mamba2Cache[source]#

Update the SSM state for a specific layer.

Parameters
  • layer_idx – Index of the layer to update

  • new_ssm_state – New SSM state to replace the current one

Returns

Updated Mamba2Cache

values() an object providing a view on D's values#
views: List[Optional[Mamba2CacheView]]#
class easydel.layers.caching.__init__.Mamba2CacheMetaData(batch_size: int, intermediate_size: int, num_heads: int, head_dim: int, state_size: int, conv_kernel_size: int, n_groups: int)[source]#

Bases: Mapping

Metadata for Mamba2 cache configuration.

batch_size: int#
conv_kernel_size: int#
classmethod create(batch_size: int, intermediate_size: int, num_heads: int, head_dim: int, state_size: int, conv_kernel_size: int, n_groups: int) Mamba2CacheMetaData[source]#

Create a Mamba2CacheMetaData instance with validation.

from_tuple()#
head_dim: int#
intermediate_size: int#
items() a set-like object providing a view on D's items#
keys() a set-like object providing a view on D's keys#
n_groups: int#
num_heads: int#
replace(**kwargs)#
state_size: int#
to_tuple()#
values() an object providing a view on D's values#
class easydel.layers.caching.__init__.Mamba2CacheView(conv_states: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, eformer.jaximus._imus.ImplicitArray], ssm_states: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, eformer.jaximus._imus.ImplicitArray], positions: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number], seqlen_offset: int, metadata: easydel.layers.caching.mamba2_cache.Mamba2CacheMetaData, layer_index: Optional[int] = None)[source]#

Bases: Mapping

conv_states: Union[Array, ndarray, bool, number, ImplicitArray]#
from_tuple()#
classmethod init(metadata: Mamba2CacheMetaData, partition_specs: PartitionSpec, dtype: dtype, layer_index: Optional[int] = None)[source]#
items() a set-like object providing a view on D's items#
keys() a set-like object providing a view on D's keys#
layer_index: Optional[int] = None#
metadata: Mamba2CacheMetaData#
positions: Union[Array, ndarray, bool, number]#
replace(**kwargs)#
reset() Mamba2CacheView[source]#

Reset both conv and ssm states to zeros.

seqlen_offset: int#
ssm_states: Union[Array, ndarray, bool, number, ImplicitArray]#
to_tuple()#
update_conv_state(new_conv_state: Union[Array, ndarray, bool, number], cache_position: Union[Array, ndarray, bool, number]) Mamba2CacheView[source]#

Update the convolutional state of the cache.

update_ssm_state(new_ssm_state: Union[Array, ndarray, bool, number]) Mamba2CacheView[source]#

Update the SSM state of the cache.

values() an object providing a view on D's values#
class easydel.layers.caching.__init__.MambaCache(views: List[Optional[easydel.layers.caching.mamba_cache.MambaCacheView]])[source]#

Bases: Mapping

from_tuple()#
classmethod init_empty(num_hidden_layers)[source]#
classmethod init_layers_cache(num_hidden_layers: int, metadata: MambaCacheMetaData, dtype: Optional[dtype] = None, partition_specs: Optional[PartitionSpec] = None)[source]#
items() a set-like object providing a view on D's items#
keys() a set-like object providing a view on D's keys#
replace(**kwargs)#
reset() MambaCache[source]#

Reset all cache views to their initial state.

Returns

Reset MambaCache

to_tuple()#
update_conv_state(layer_idx: int, new_conv_state: Union[Array, ndarray, bool, number], cache_position: Union[Array, ndarray, bool, number]) MambaCache[source]#

Update the convolutional state for a specific layer.

Parameters
  • layer_idx – Index of the layer to update

  • new_conv_state – New state to be inserted

  • cache_position – Position in the cache to update

Returns

Updated MambaCache

update_ssm_state(layer_idx: int, new_ssm_state: Union[Array, ndarray, bool, number]) MambaCache[source]#

Update the SSM state for a specific layer.

Parameters
  • layer_idx – Index of the layer to update

  • new_ssm_state – New SSM state to replace the current one

Returns

Updated MambaCache

values() an object providing a view on D's values#
views: List[Optional[MambaCacheView]]#
class easydel.layers.caching.__init__.MambaCacheMetaData(batch_size: int, intermediate_size: int, ssm_state_size: int, conv_kernel_size: int)[source]#

Bases: Mapping

Metadata for Mamba cache configuration.

batch_size: int#
conv_kernel_size: int#
classmethod create(batch_size: int, intermediate_size: int, ssm_state_size: int, conv_kernel_size: int) MambaCacheMetaData[source]#

Create a MambaCacheMetaData instance with validation.

Parameters
  • batch_size – Size of the batch

  • intermediate_size – Model’s intermediate size

  • ssm_state_size – Model’s state size

  • conv_kernel_size – Model’s convolution kernel size

Returns

MambaCacheMetaData instance

Raises

ValueError – If required parameters are invalid

from_tuple()#
intermediate_size: int#
items() a set-like object providing a view on D's items#
keys() a set-like object providing a view on D's keys#
replace(**kwargs)#
ssm_state_size: int#
to_tuple()#
values() an object providing a view on D's values#
class easydel.layers.caching.__init__.MambaCacheView(conv_states: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, eformer.jaximus._imus.ImplicitArray], ssm_states: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, eformer.jaximus._imus.ImplicitArray], positions: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number], metadata: easydel.layers.caching.mamba_cache.MambaCacheMetaData, layer_index: Optional[int] = None)[source]#

Bases: Mapping

conv_states: Union[Array, ndarray, bool, number, ImplicitArray]#
from_tuple()#
classmethod init(metadata: MambaCacheMetaData, partition_specs: PartitionSpec, dtype: dtype, layer_index: Optional[int] = None)[source]#
items() a set-like object providing a view on D's items#
keys() a set-like object providing a view on D's keys#
layer_index: Optional[int] = None#
metadata: MambaCacheMetaData#
positions: Union[Array, ndarray, bool, number]#
replace(**kwargs)#
reset() MambaCacheView[source]#

Reset both conv and ssm states to zeros.

Returns

Reset MambaCacheView

ssm_states: Union[Array, ndarray, bool, number, ImplicitArray]#
to_tuple()#
update_conv_state(new_conv_state: Union[Array, ndarray, bool, number], cache_position: Union[Array, ndarray, bool, number]) MambaCacheView[source]#

Update the convolutional state of the cache.

Parameters
  • new_conv_state – New state to be inserted

  • cache_position – Position in the cache to update

Returns

Updated MambaCacheView

update_ssm_state(new_ssm_state: Union[Array, ndarray, bool, number]) MambaCacheView[source]#

Update the SSM state of the cache.

Parameters

new_ssm_state – New SSM state to replace the current one

Returns

Updated MambaCacheView

values() an object providing a view on D's values#
class easydel.layers.caching.__init__.TransformerCache(views: List[Optional[easydel.layers.caching.transformer_cache.TransformerCacheView]])[source]#

Bases: Mapping

from_tuple()#
classmethod init_empty(num_hidden_layers)[source]#
classmethod init_layers_cache(num_hidden_layers: int, metadata: TransformerCacheMetaData, mesh: Mesh, quantizer: Optional[object] = None, dtype: Optional[dtype] = None, key_values_partition_specs: Optional[PartitionSpec] = None)[source]#
items() a set-like object providing a view on D's items#
keys() a set-like object providing a view on D's keys#
replace(**kwargs)#
to_tuple()#
values() an object providing a view on D's values#
views: List[Optional[TransformerCacheView]]#
class easydel.layers.caching.__init__.TransformerCacheMetaData(batch_size: int, sequence_length: int, num_heads: Optional[int], head_dim: Optional[int], key_heads: Optional[int], value_heads: Optional[int], key_dim: Optional[int], value_dim: Optional[int], update_causal_mask: bool, create_attention_bias: bool)[source]#

Bases: Mapping

Metadata for transformer cache configuration.

batch_size: int#
classmethod create(batch_size: int, sequence_length: int, num_heads: Optional[int] = None, head_dim: Optional[int] = None, key_heads: Optional[int] = None, value_heads: Optional[int] = None, key_dim: Optional[int] = None, value_dim: Optional[int] = None, update_causal_mask: bool = True, create_attention_bias: bool = True) TransformerCacheMetaData[source]#

Create a TransformerCacheMetaData instance with validation.

Parameters
  • batch_size – Size of the batch.

  • sequence_length – Length of the sequence.

  • num_heads – Number of attention heads.

  • head_dim – Dimension of each head.

  • key_heads – Number of key heads.

  • value_heads – Number of value heads.

  • key_dim – Dimension of keys.

  • value_dim – Dimension of values.

  • update_causal_mask – Whether to update causal mask.

  • create_attention_bias – Whether to create attention bias.

Returns

TransformerCacheMetaData instance

Raises

ValueError – If required parameters are missing or invalid.

create_attention_bias: bool#
from_tuple()#
head_dim: Optional[int]#
items() a set-like object providing a view on D's items#
key_dim: Optional[int]#
key_heads: Optional[int]#
keys() a set-like object providing a view on D's keys#
num_heads: Optional[int]#
replace(**kwargs)#
sequence_length: int#
to_tuple()#
update_causal_mask: bool#
value_dim: Optional[int]#
value_heads: Optional[int]#
values() an object providing a view on D's values#
class easydel.layers.caching.__init__.TransformerCacheView(key: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, eformer.jaximus._imus.ImplicitArray], value: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, eformer.jaximus._imus.ImplicitArray], index: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, eformer.jaximus._imus.ImplicitArray], metadata: easydel.layers.caching.transformer_cache.TransformerCacheMetaData, layer_index: Optional[int] = None)[source]#

Bases: Mapping

from_tuple()#
index: Union[Array, ndarray, bool, number, ImplicitArray]#
classmethod init(metadata: TransformerCacheMetaData, quantizer: object, key_values_partition_specs: PartitionSpec, dtype: dtype, mesh: Mesh, layer_index: Optional[int] = None)[source]#
items() a set-like object providing a view on D's items#
key: Union[Array, ndarray, bool, number, ImplicitArray]#
keys() a set-like object providing a view on D's keys#
layer_index: Optional[int] = None#
metadata: TransformerCacheMetaData#
replace(**kwargs)#
to_tuple()#
value: Union[Array, ndarray, bool, number, ImplicitArray]#
values() an object providing a view on D's values#