easydel.layers.caching.mamba.mamba_cache#
- class easydel.layers.caching.mamba.mamba_cache.MambaCache(views: List[Optional[easydel.layers.caching.mamba.mamba_cache.MambaCacheView]])[source]#
Bases:
BaseCache
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- classmethod init_cache(metadata: MambaCacheMetaData, dtype: Optional[dtype] = None, partition_specs: Optional[PartitionSpec] = None)[source]#
Initialize a complete cache with views for all layers.
- Parameters
metadata – Configuration metadata
dtype – Optional data type for the cache arrays
partition_specs – Optional PartitionSpec describing how the cache arrays are sharded
- Returns
Fully initialized cache instance
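To make "views for all layers" concrete, here is a minimal sketch of how the metadata fields could determine the state shapes of each view. The layouts `(batch_size, intermediate_size, conv_kernel_size)` and `(batch_size, intermediate_size, ssm_state_size)` follow the common Mamba convention and are an assumption; this page does not confirm them. `SketchMeta` and `state_shapes` are hypothetical names, not part of easydel.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass(frozen=True)
class SketchMeta:
    """Hypothetical stand-in mirroring the integer fields of MambaCacheMetaData."""
    num_hidden_layers: int
    batch_size: int
    intermediate_size: int
    ssm_state_size: int
    conv_kernel_size: int


def state_shapes(meta: SketchMeta) -> List[Dict[str, Tuple[int, int, int]]]:
    """Return one shape record per layer (assumed layout, not confirmed by the docs)."""
    conv = (meta.batch_size, meta.intermediate_size, meta.conv_kernel_size)
    ssm = (meta.batch_size, meta.intermediate_size, meta.ssm_state_size)
    return [{"conv_states": conv, "ssm_states": ssm}
            for _ in range(meta.num_hidden_layers)]


meta = SketchMeta(num_hidden_layers=2, batch_size=1, intermediate_size=8,
                  ssm_state_size=16, conv_kernel_size=4)
shapes = state_shapes(meta)
```

Under this assumption every layer gets one convolution window and one SSM state per batch element, both of fixed size, so the cache does not grow with sequence length.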
- classmethod init_empty(num_hidden_layers)[source]#
Initialize an empty cache container.
- Parameters
num_hidden_layers – Number of layers, and therefore view slots, in the cache
- Returns
Cache instance with uninitialized views
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- reset() MambaCache[source]#
Reset all cache views to their initial state.
- Returns
Reset MambaCache
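As an illustration of a reset over a list of optional views, the sketch below zeroes every initialized state and leaves uninitialized (`None`) slots alone. Keeping `None` entries as they are is an assumption; the page only says that all views return to their initial state. Nested lists stand in for arrays, and `zeros_like` and `reset_views` are hypothetical helpers.

```python
from typing import List, Optional

State = List[List[float]]  # nested lists stand in for a 2-D array


def zeros_like(state: State) -> State:
    """Build a new all-zero state with the same shape as the input."""
    return [[0.0] * len(row) for row in state]


def reset_views(views: List[Optional[State]]) -> List[Optional[State]]:
    """Return a new list; initialized views are zeroed, None slots are kept."""
    return [None if view is None else zeros_like(view) for view in views]


views = [[[1.5, -2.0], [0.3, 4.0]], None]
fresh = reset_views(views)
```

The input list is not modified, which matches the documented behavior of returning a reset cache rather than mutating in place.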
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- update_conv_state(layer_idx: int, new_conv_state: Union[Array, ndarray, bool, number], cache_position: Union[Array, ndarray, bool, number]) MambaCache[source]#
Update the convolutional state for a specific layer.
- Parameters
layer_idx – Index of the layer to update
new_conv_state – New state to be inserted
cache_position – Position in the cache to update
- Returns
Updated MambaCache
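One common way to implement this kind of update during decoding is to roll the convolution window by one step and write the new value at a clamped position. The sketch below shows that pattern on a single 1-D window in plain Python. Whether easydel uses this exact scheme is an assumption, and `roll_and_insert` is a hypothetical helper.

```python
from typing import List


def roll_and_insert(window: List[float], new_value: float,
                    cache_position: int) -> List[float]:
    """Shift the window left by one, then write new_value at the clamped position."""
    # Clamp so positions beyond the kernel size land on the last slot.
    pos = max(0, min(cache_position, len(window) - 1))
    rolled = window[1:] + window[:1]  # roll by -1 along the kernel axis
    rolled[pos] = new_value
    return rolled  # a new list; the input window is left untouched


window = [1.0, 2.0, 3.0, 4.0]
updated = roll_and_insert(window, 9.0, cache_position=10)
# updated == [2.0, 3.0, 4.0, 9.0]
```

Once decoding has passed the kernel size, every step drops the oldest entry and appends the newest, so the window always holds the most recent `conv_kernel_size` inputs.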
- update_ssm_state(layer_idx: int, new_ssm_state: Union[Array, ndarray, bool, number]) MambaCache[source]#
Update the SSM state for a specific layer.
- Parameters
layer_idx – Index of the layer to update
new_ssm_state – New SSM state to replace the current one
- Returns
Updated MambaCache
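Because the method returns an updated MambaCache, a functional update style is a natural fit: build a new view, build a new container, leave the old one intact. The following sketch shows that pattern with frozen dataclasses. `SketchView` and `SketchCache` are hypothetical stand-ins, tuples replace arrays, and nothing here is taken from the easydel implementation.

```python
from dataclasses import dataclass, replace
from typing import Optional, Tuple


@dataclass(frozen=True)
class SketchView:
    """Hypothetical per-layer view holding only an SSM state."""
    ssm_states: Tuple[float, ...]
    layer_index: Optional[int] = None


@dataclass(frozen=True)
class SketchCache:
    """Hypothetical container of per-layer views."""
    views: Tuple[Optional[SketchView], ...]

    def update_ssm_state(self, layer_idx: int,
                         new_ssm_state: Tuple[float, ...]) -> "SketchCache":
        # Assumes 0 <= layer_idx < len(self.views).
        view = self.views[layer_idx]
        if view is None:
            raise ValueError(f"layer {layer_idx} has no initialized view")
        new_view = replace(view, ssm_states=new_ssm_state)
        views = self.views[:layer_idx] + (new_view,) + self.views[layer_idx + 1:]
        return replace(self, views=views)


cache = SketchCache(views=(SketchView((0.0, 0.0), 0), SketchView((0.0, 0.0), 1)))
new_cache = cache.update_ssm_state(1, (0.5, 0.25))
```

The original `cache` still holds the old state after the call, and untouched layers are shared between the two instances rather than copied.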
- views: List[Optional[MambaCacheView]]#
- class easydel.layers.caching.mamba.mamba_cache.MambaCacheMetaData(num_hidden_layers: int, partition_axis: PartitionAxis, batch_size: int, intermediate_size: int, ssm_state_size: int, conv_kernel_size: int)[source]#
Bases:
BaseCacheMetadata
Metadata for Mamba cache configuration.
- batch_size: int#
- conv_kernel_size: int#
- classmethod create(num_hidden_layers: int, partition_axis: PartitionAxis, batch_size: int, intermediate_size: int, ssm_state_size: int, conv_kernel_size: int) MambaCacheMetaData[source]#
Create a MambaCacheMetaData instance with validation.
- Parameters
num_hidden_layers – Number of hidden layers in the model
partition_axis – Partition axis
batch_size – Size of the batch
intermediate_size – Model’s intermediate size
ssm_state_size – Model’s state size
conv_kernel_size – Model’s convolution kernel size
- Returns
MambaCacheMetaData instance
- Raises
ValueError – If required parameters are invalid
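The page does not say which checks `create` performs. As a rough sketch of the "validate, then construct" pattern, the example below rejects any size field that is not a positive integer. That rule is an assumption, `SketchMetaData` is a hypothetical class, and the `partition_axis` field is left out to keep the example self-contained.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SketchMetaData:
    """Hypothetical stand-in with only the integer fields."""
    num_hidden_layers: int
    batch_size: int
    intermediate_size: int
    ssm_state_size: int
    conv_kernel_size: int

    @classmethod
    def create(cls, **fields: int) -> "SketchMetaData":
        # Assumed rule: every size must be a positive integer.
        for name, value in fields.items():
            if not isinstance(value, int) or value <= 0:
                raise ValueError(f"{name} must be a positive integer, got {value!r}")
        return cls(**fields)


ok = SketchMetaData.create(num_hidden_layers=2, batch_size=2, intermediate_size=8,
                           ssm_state_size=16, conv_kernel_size=4)

try:
    SketchMetaData.create(num_hidden_layers=2, batch_size=0, intermediate_size=8,
                          ssm_state_size=16, conv_kernel_size=4)
except ValueError as exc:
    error_message = str(exc)
```

Validating in a classmethod keeps the plain constructor cheap while giving callers one entry point that fails early with a descriptive message.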
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- intermediate_size: int#
- partition_axis: PartitionAxis#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- ssm_state_size: int#
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- class easydel.layers.caching.mamba.mamba_cache.MambaCacheView(conv_states: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, eformer.jaximus._imus.ImplicitArray], ssm_states: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, eformer.jaximus._imus.ImplicitArray], positions: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number], metadata: easydel.layers.caching.mamba.mamba_cache.MambaCacheMetaData, layer_index: Optional[int] = None)[source]#
Bases:
BaseCacheView
- concatenate_to_cache(*args, **kwargs)[source]#
Update cache with new states.
- Parameters
*args – Typically includes new tensors
**kwargs – Additional parameters for cache update
- Returns
A tuple; its contents are left unspecified
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- classmethod init(metadata: MambaCacheMetaData, partition_specs: PartitionSpec, dtype: dtype, layer_index: Optional[int] = None)[source]#
Initialize a new cache view instance.
- Parameters
metadata – Configuration metadata for the cache
partition_specs – PartitionSpec describing how the state arrays are sharded
dtype – Data type of the state arrays
layer_index – Optional index of the layer this view belongs to
- Returns
Initialized cache view instance
- layer_index: Optional[int] = None#
- metadata: MambaCacheMetaData#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- reset() MambaCacheView[source]#
Reset both conv and ssm states to zeros.
- Returns
Reset MambaCacheView
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- update_conv_state(new_conv_state: Union[Array, ndarray, bool, number], cache_position: Union[Array, ndarray, bool, number]) MambaCacheView[source]#
Update the convolutional state of the cache.
- Parameters
new_conv_state – New state to be inserted
cache_position – Position in the cache to update
- Returns
Updated MambaCacheView
- class easydel.layers.caching.mamba.mamba_cache.MambaMetadata[source]#
Bases:
BaseRunTimeMetadata