easydel.layers.caching.mamba2_cache
- class easydel.layers.caching.mamba2_cache.Mamba2Cache(views: List[Optional[easydel.layers.caching.mamba2_cache.Mamba2CacheView]])
Bases: Mapping
- from_tuple()
- classmethod init_layers_cache(num_hidden_layers: int, metadata: Mamba2CacheMetaData, dtype: Optional[dtype] = None, partition_specs: Optional[PartitionSpec] = None)
- items() – a set-like object providing a view on D's items
- keys() – a set-like object providing a view on D's keys
- replace(**kwargs)
- reset() → Mamba2Cache
Reset all cache views to their initial state.
- Returns
The reset Mamba2Cache
- to_tuple()
- update_conv_state(layer_idx: int, new_conv_state: Union[Array, ndarray, bool, number], cache_position: Union[Array, ndarray, bool, number]) → Mamba2Cache
Update the convolutional state for a specific layer.
- Parameters
layer_idx – Index of the layer to update
new_conv_state – New convolutional state to insert
cache_position – Position in the cache to update
- Returns
Updated Mamba2Cache
- update_ssm_state(layer_idx: int, new_ssm_state: Union[Array, ndarray, bool, number]) → Mamba2Cache
Update the SSM state for a specific layer.
- Parameters
layer_idx – Index of the layer to update
new_ssm_state – New SSM state that replaces the current one
- Returns
Updated Mamba2Cache
- values() – an object providing a view on D's values
- views: List[Optional[Mamba2CacheView]]
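Mamba2Cache keeps one (possibly None) Mamba2CacheView per layer, and its update_conv_state and update_ssm_state methods take a layer_idx and return a Mamba2Cache. A minimal sketch of that pattern follows, assuming the cache delegates to the selected view and rebuilds its views list functionally instead of mutating it in place. ToyView and ToyCache are hypothetical stand-ins (not EasyDeL classes), and NumPy arrays replace JAX arrays so the sketch runs on its own.

```python
from dataclasses import dataclass, replace
from typing import List, Optional

import numpy as np


@dataclass(frozen=True)
class ToyView:
    """Hypothetical stand-in for Mamba2CacheView: holds one layer's SSM state."""

    ssm_states: np.ndarray

    def update_ssm_state(self, new_ssm_state: np.ndarray) -> "ToyView":
        # Functional update: return a new view instead of mutating this one.
        return replace(self, ssm_states=np.asarray(new_ssm_state))


@dataclass(frozen=True)
class ToyCache:
    """Hypothetical stand-in for Mamba2Cache: one view (or None) per layer."""

    views: List[Optional[ToyView]]

    def update_ssm_state(self, layer_idx: int, new_ssm_state: np.ndarray) -> "ToyCache":
        view = self.views[layer_idx]
        if view is None:
            raise ValueError(f"layer {layer_idx} has no initialized view")
        new_views = list(self.views)  # shallow copy: untouched layers are shared
        new_views[layer_idx] = view.update_ssm_state(new_ssm_state)
        return replace(self, views=new_views)


# Two layers, each with a zero SSM state of shape (batch, heads, head_dim, state).
cache = ToyCache(views=[ToyView(np.zeros((1, 2, 2, 3))) for _ in range(2)])
updated = cache.update_ssm_state(1, np.ones((1, 2, 2, 3)))
print(updated.views[1].ssm_states.sum())  # 12.0
print(cache.views[1].ssm_states.sum())    # 0.0, the original cache is untouched
```

In this sketch, returning a new object fits JAX's functional style, where pytrees are rebuilt rather than mutated; layers that were not updated are shared between the old and the new cache.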
- class easydel.layers.caching.mamba2_cache.Mamba2CacheMetaData(batch_size: int, intermediate_size: int, num_heads: int, head_dim: int, state_size: int, conv_kernel_size: int, n_groups: int)
Bases: Mapping
Metadata for Mamba2 cache configuration.
- batch_size: int
- conv_kernel_size: int
- classmethod create(batch_size: int, intermediate_size: int, num_heads: int, head_dim: int, state_size: int, conv_kernel_size: int, n_groups: int) → Mamba2CacheMetaData
Create a Mamba2CacheMetaData instance with validation.
- from_tuple()
- head_dim: int
- intermediate_size: int
- items() – a set-like object providing a view on D's items
- keys() – a set-like object providing a view on D's keys
- n_groups: int
- num_heads: int
- replace(**kwargs)
- state_size: int
- to_tuple()
- values() – an object providing a view on D's values
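The metadata fields are enough to size both state buffers. The sketch below assumes the Hugging Face Transformers Mamba2 layout, where the conv buffer has shape (batch_size, intermediate_size + 2 * n_groups * state_size, conv_kernel_size) and the SSM state has shape (batch_size, num_heads, head_dim, state_size); whether EasyDeL uses exactly this layout is an assumption. The page does not say what create() validates, so the positive-integer check is a hypothetical rule, and ToyMetaData, conv_state_shape and ssm_state_shape are illustrative names, not part of EasyDeL.

```python
from dataclasses import dataclass, fields
from typing import Optional, Tuple


@dataclass(frozen=True)
class ToyMetaData:
    """Hypothetical mirror of the Mamba2CacheMetaData fields."""

    batch_size: int
    intermediate_size: int
    num_heads: int
    head_dim: int
    state_size: int
    conv_kernel_size: int
    n_groups: int

    @classmethod
    def create(cls, **kwargs: int) -> "ToyMetaData":
        meta = cls(**kwargs)
        # Hypothetical validation rule: every dimension must be a positive int.
        for f in fields(meta):
            value = getattr(meta, f.name)
            if not isinstance(value, int) or value <= 0:
                raise ValueError(f"{f.name} must be a positive int, got {value!r}")
        return meta

    def conv_state_shape(self) -> Tuple[int, int, int]:
        # Assumed (HF Transformers) layout: x channels plus the B and C projections.
        conv_dim = self.intermediate_size + 2 * self.n_groups * self.state_size
        return (self.batch_size, conv_dim, self.conv_kernel_size)

    def ssm_state_shape(self) -> Tuple[int, int, int, int]:
        return (self.batch_size, self.num_heads, self.head_dim, self.state_size)


meta = ToyMetaData.create(
    batch_size=2, intermediate_size=256, num_heads=8, head_dim=32,
    state_size=16, conv_kernel_size=4, n_groups=1,
)
print(meta.conv_state_shape())  # (2, 288, 4)
print(meta.ssm_state_shape())   # (2, 8, 32, 16)

error_message: Optional[str] = None
try:
    ToyMetaData.create(
        batch_size=2, intermediate_size=256, num_heads=8, head_dim=32,
        state_size=16, conv_kernel_size=4, n_groups=0,
    )
except ValueError as err:
    error_message = str(err)
print(error_message)  # n_groups must be a positive int, got 0
```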
- class easydel.layers.caching.mamba2_cache.Mamba2CacheView(conv_states: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, eformer.jaximus._imus.ImplicitArray], ssm_states: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, eformer.jaximus._imus.ImplicitArray], positions: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number], seqlen_offset: int, metadata: easydel.layers.caching.mamba2_cache.Mamba2CacheMetaData, layer_index: Optional[int] = None)
Bases: Mapping
- from_tuple()
- classmethod init(metadata: Mamba2CacheMetaData, partition_specs: PartitionSpec, dtype: dtype, layer_index: Optional[int] = None)
- items() – a set-like object providing a view on D's items
- keys() – a set-like object providing a view on D's keys
- layer_index: Optional[int] = None
- metadata: Mamba2CacheMetaData
- replace(**kwargs)
- reset() → Mamba2CacheView
Reset both the convolutional and SSM states to zeros.
- seqlen_offset: int
- to_tuple()
- update_conv_state(new_conv_state: Union[Array, ndarray, bool, number], cache_position: Union[Array, ndarray, bool, number]) → Mamba2CacheView
Update the convolutional state of the cache.
- update_ssm_state(new_ssm_state: Union[Array, ndarray, bool, number]) → Mamba2CacheView
Update the SSM state of the cache.
- values() – an object providing a view on D's values
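The view's docstrings only say that update_conv_state and update_ssm_state "update" the states. A common way to cache a causal convolution is a shift register: roll the kernel axis left by one and write the newest column at cache_position, clamped to the last slot, while the SSM state is simply replaced. The sketch below assumes those semantics; it operates on bare NumPy arrays rather than the EasyDeL classes, and whether EasyDeL matches it exactly is an assumption.

```python
import numpy as np


def update_conv_state(conv_states, new_conv_state, cache_position):
    """Shift-register update of a (batch, channels, kernel) buffer (assumed semantics)."""
    kernel_size = conv_states.shape[-1]
    # Clamp so decode steps past the kernel width always write the last slot.
    pos = int(np.clip(cache_position, 0, kernel_size - 1))
    # np.roll returns a new array; the oldest column wraps around to the end.
    rolled = np.roll(conv_states, shift=-1, axis=-1)
    rolled[..., pos] = new_conv_state  # overwrite slot `pos` with the newest column
    return rolled


def update_ssm_state(ssm_states, new_ssm_state):
    # Assumed: the SSM state is replaced wholesale, keeping the buffer's dtype.
    return np.asarray(new_ssm_state, dtype=ssm_states.dtype)


def reset(conv_states, ssm_states):
    # Zero both buffers, matching the documented reset() behavior.
    return np.zeros_like(conv_states), np.zeros_like(ssm_states)


# Pretend prefill left inputs 1..4 in a single-channel buffer of kernel width 4.
conv = np.arange(1.0, 5.0).reshape(1, 1, 4)
step1 = update_conv_state(conv, np.array([[5.0]]), cache_position=4)
step2 = update_conv_state(step1, np.array([[6.0]]), cache_position=5)
print(step1[0, 0])  # [2. 3. 4. 5.]
print(step2[0, 0])  # [3. 4. 5. 6.]

ssm = update_ssm_state(np.zeros((1, 2, 2, 3), dtype=np.float32), np.ones((1, 2, 2, 3)))
conv0, ssm0 = reset(step2, ssm)
print(conv0.any(), ssm0.any())  # False False
```

The clamp means that once decoding runs past the kernel width, every step evicts the oldest input and appends the newest one, so the buffer always holds the last conv_kernel_size inputs the convolution needs.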