easydel.layers.caching.mamba.__init__

class easydel.layers.caching.mamba.__init__.MambaCache(views: List[Optional[easydel.layers.caching.mamba.mamba_cache.MambaCacheView]])

Bases: BaseCache

Container holding one MambaCacheView per model layer.

classmethod from_dict(data: Dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

classmethod init_cache(metadata: MambaCacheMetaData, dtype: Optional[dtype] = None, partition_specs: Optional[PartitionSpec] = None)

Initialize a complete cache with one view for every layer.

Parameters
  • metadata – Cache configuration metadata

  • dtype – Data type of the cache state arrays (optional)

  • partition_specs – Sharding specification for the cache state arrays (optional)

Returns

Fully initialized cache instance
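As a rough sketch of what a fully initialized cache holds, the snippet below allocates zero-filled state arrays for every layer from the metadata fields. The layouts `(batch_size, intermediate_size, conv_kernel_size)` for the convolution state and `(batch_size, intermediate_size, ssm_state_size)` for the SSM state are an assumption borrowed from the Hugging Face `transformers` Mamba cache; `allocate_states` and the plain dicts standing in for views are hypothetical, not EasyDeL API.

```python
import numpy as np


def allocate_states(num_hidden_layers, batch_size, intermediate_size,
                    ssm_state_size, conv_kernel_size, dtype=np.float32):
    """Hypothetical stand-in for init_cache: one (conv, ssm) pair per layer."""
    views = []
    for _ in range(num_hidden_layers):
        # Assumed layout: a window over the last `conv_kernel_size` inputs per channel.
        conv = np.zeros((batch_size, intermediate_size, conv_kernel_size), dtype=dtype)
        # Assumed layout: one recurrent state vector per channel.
        ssm = np.zeros((batch_size, intermediate_size, ssm_state_size), dtype=dtype)
        views.append({"conv_states": conv, "ssm_states": ssm})
    return views


views = allocate_states(num_hidden_layers=2, batch_size=1, intermediate_size=8,
                        ssm_state_size=4, conv_kernel_size=3)
print(len(views), views[0]["conv_states"].shape, views[0]["ssm_states"].shape)
# 2 (1, 8, 3) (1, 8, 4)
```

Unlike an attention KV cache, neither array depends on the sequence length, so the cache size stays fixed during generation.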

classmethod init_empty(num_hidden_layers)

Initialize an empty cache container.

Parameters
  • num_hidden_layers – Number of layers, i.e. the number of view slots to create

Returns

Cache instance with uninitialized views
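A minimal sketch of the empty-container idea, assuming an uninitialized view is represented by `None` (which the `List[Optional[MambaCacheView]]` annotation on `views` permits). `init_empty_sketch` and the lazy `ensure_view` helper are hypothetical illustrations, not EasyDeL functions.

```python
from typing import Any, Callable, List, Optional


def init_empty_sketch(num_hidden_layers: int) -> List[Optional[Any]]:
    # One slot per layer; None marks a view that has not been built yet.
    return [None] * num_hidden_layers


def ensure_view(views: List[Optional[Any]], layer_idx: int,
                factory: Callable[[int], Any]) -> List[Optional[Any]]:
    # Hypothetical lazy initialization: build the view on first use and
    # return a new list rather than mutating the input.
    if views[layer_idx] is not None:
        return views
    new_views = list(views)
    new_views[layer_idx] = factory(layer_idx)
    return new_views


empty = init_empty_sketch(3)
filled = ensure_view(empty, 1, lambda i: f"view-{i}")
print(empty)   # [None, None, None]
print(filled)  # [None, 'view-1', None]
```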

replace(**kwargs)

Creates a new instance with specified fields replaced.
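`replace` follows the immutable-update style used for PyTree objects: it returns a new instance instead of mutating fields. The standard-library analogue is `dataclasses.replace` on a frozen dataclass, sketched below with a hypothetical `MetaSketch` class; this illustrates the pattern only and is not EasyDeL's implementation.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class MetaSketch:
    # Hypothetical subset of MambaCacheMetaData's fields.
    batch_size: int
    intermediate_size: int


old = MetaSketch(batch_size=1, intermediate_size=8)
new = dataclasses.replace(old, batch_size=4)  # a new instance; `old` is untouched
print(old.batch_size, new.batch_size)         # 1 4

# Frozen instances reject in-place assignment.
try:
    old.batch_size = 2
except dataclasses.FrozenInstanceError:
    print("in-place update rejected")
```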

reset() → MambaCache

Reset all cache views to their initial state.

Returns

The reset MambaCache

to_dict() → Dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.

update_conv_state(layer_idx: int, new_conv_state: Union[Array, ndarray, bool, number], cache_position: Union[Array, ndarray, bool, number]) → MambaCache

Update the convolutional state for a specific layer.

Parameters
  • layer_idx – Index of the layer to update

  • new_conv_state – New state to be inserted

  • cache_position – Position in the cache to update

Returns

Updated MambaCache

update_ssm_state(layer_idx: int, new_ssm_state: Union[Array, ndarray, bool, number]) → MambaCache

Update the SSM state for a specific layer.

Parameters
  • layer_idx – Index of the layer to update

  • new_ssm_state – New SSM state to replace the current one

Returns

Updated MambaCache
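Both cache-level update methods take a `layer_idx` and return a new `MambaCache`. One plausible way to structure this, sketched with hypothetical stand-ins (`ViewSketch`, `CacheSketch`) rather than the real classes, is to update only the addressed view and rebuild the list, reusing every other layer's view unchanged. In this sketch, a conv-state update would follow the same pattern.

```python
import dataclasses
from typing import List

import numpy as np


@dataclasses.dataclass(frozen=True)
class ViewSketch:
    # Hypothetical stand-in for MambaCacheView.
    conv_states: np.ndarray
    ssm_states: np.ndarray


@dataclasses.dataclass(frozen=True)
class CacheSketch:
    # Hypothetical stand-in for MambaCache: one view per layer.
    views: List[ViewSketch]

    def update_ssm_state(self, layer_idx: int, new_ssm_state: np.ndarray) -> "CacheSketch":
        # Build a new view for the addressed layer only.
        view = dataclasses.replace(self.views[layer_idx], ssm_states=new_ssm_state)
        # Copy the list so the original cache is left unchanged.
        views = list(self.views)
        views[layer_idx] = view
        return dataclasses.replace(self, views=views)


def make_view() -> ViewSketch:
    return ViewSketch(np.zeros((1, 2, 3)), np.zeros((1, 2, 4)))


cache = CacheSketch(views=[make_view(), make_view()])
updated = cache.update_ssm_state(1, np.ones((1, 2, 4)))
```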

views: List[Optional[MambaCacheView]]

class easydel.layers.caching.mamba.__init__.MambaCacheMetaData(num_hidden_layers: int, partition_axis: PartitionAxis, batch_size: int, intermediate_size: int, ssm_state_size: int, conv_kernel_size: int)

Bases: BaseCacheMetadata

Metadata for Mamba cache configuration.

batch_size: int
conv_kernel_size: int
classmethod create(num_hidden_layers: int, partition_axis: PartitionAxis, batch_size: int, intermediate_size: int, ssm_state_size: int, conv_kernel_size: int) → MambaCacheMetaData

Create a MambaCacheMetaData instance with validation.

Parameters
  • num_hidden_layers – Number of hidden layers in the model

  • partition_axis – Partition axis configuration

  • batch_size – Size of the batch

  • intermediate_size – Model’s intermediate size

  • ssm_state_size – Model’s SSM state size

  • conv_kernel_size – Model’s convolution kernel size

Returns

MambaCacheMetaData instance

Raises

ValueError – If required parameters are invalid
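The docstring only states that invalid parameters raise `ValueError`. The sketch below shows the kind of check such a factory might perform; requiring every size to be a positive integer is an assumption for illustration, not EasyDeL's documented rule, `create_metadata_sketch` is hypothetical, and `partition_axis` is omitted.

```python
def create_metadata_sketch(num_hidden_layers, batch_size, intermediate_size,
                           ssm_state_size, conv_kernel_size):
    """Hypothetical validating factory; returns the sizes as a dict."""
    sizes = {
        "num_hidden_layers": num_hidden_layers,
        "batch_size": batch_size,
        "intermediate_size": intermediate_size,
        "ssm_state_size": ssm_state_size,
        "conv_kernel_size": conv_kernel_size,
    }
    for name, value in sizes.items():
        # Assumed rule for this sketch: every size must be a positive int.
        if not isinstance(value, int) or value <= 0:
            raise ValueError(f"{name} must be a positive integer, got {value!r}")
    return sizes


meta = create_metadata_sketch(2, 1, 8, 4, 3)
try:
    create_metadata_sketch(2, 0, 8, 4, 3)  # batch_size=0 is rejected
except ValueError as err:
    print(err)
```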

classmethod from_dict(data: Dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

intermediate_size: int
num_hidden_layers: int
partition_axis: PartitionAxis
replace(**kwargs)

Creates a new instance with specified fields replaced.

ssm_state_size: int
to_dict() → Dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.
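`to_dict`/`from_dict` and `to_json`/`from_json` are inverse pairs. A round trip for the scalar metadata fields can be sketched with the standard library alone; `MetaJsonSketch` is hypothetical, and `partition_axis` is left out because how EasyDeL serializes a `PartitionAxis` is not described here.

```python
import dataclasses
import json


@dataclasses.dataclass(frozen=True)
class MetaJsonSketch:
    # Hypothetical subset of MambaCacheMetaData (partition_axis omitted).
    num_hidden_layers: int
    batch_size: int
    intermediate_size: int
    ssm_state_size: int
    conv_kernel_size: int

    def to_dict(self):
        return dataclasses.asdict(self)

    @classmethod
    def from_dict(cls, data):
        return cls(**data)

    def to_json(self, **kwargs):
        # In this sketch, extra keyword arguments go to json.dumps (e.g. indent=2).
        return json.dumps(self.to_dict(), **kwargs)

    @classmethod
    def from_json(cls, json_str):
        return cls.from_dict(json.loads(json_str))


meta = MetaJsonSketch(24, 1, 1536, 16, 4)
restored = MetaJsonSketch.from_json(meta.to_json(indent=2))
```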

class easydel.layers.caching.mamba.__init__.MambaCacheView(conv_states: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, eformer.jaximus._imus.ImplicitArray], ssm_states: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, eformer.jaximus._imus.ImplicitArray], positions: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number], metadata: easydel.layers.caching.mamba.mamba_cache.MambaCacheMetaData, layer_index: Optional[int] = None)

Bases: BaseCacheView

Per-layer cache view holding the convolution and SSM states.

concatenate_to_cache(*args, **kwargs)

Update the cache with new states.

Parameters
  • *args – Typically includes new tensors

  • **kwargs – Additional parameters for the cache update

Returns

Tuple of updated values (its contents are not specified in this docstring)

Return type

tuple

conv_states: Union[Array, ndarray, bool, number, ImplicitArray]
classmethod from_dict(data: Dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

classmethod init(metadata: MambaCacheMetaData, partition_specs: PartitionSpec, dtype: dtype, layer_index: Optional[int] = None)

Initialize a new cache view instance.

Parameters
  • metadata – Configuration metadata for the cache

  • partition_specs – Sharding specification for the state arrays

  • dtype – Data type of the state arrays

  • layer_index – Index of the layer this view belongs to (optional)

Returns

Initialized cache view instance

layer_index: Optional[int] = None
metadata: MambaCacheMetaData
positions: Union[Array, ndarray, bool, number]
replace(**kwargs)

Creates a new instance with specified fields replaced.

reset() → MambaCacheView

Reset both the convolution and SSM states to zeros.

Returns

The reset MambaCacheView
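Resetting a view amounts to producing zero-filled arrays with the same shapes and dtypes as the current states. A minimal NumPy sketch, using a hypothetical `reset_states` helper on free-standing arrays in place of a `MambaCacheView`; with JAX arrays, `jax.numpy.zeros_like` plays the same role.

```python
import numpy as np


def reset_states(conv_states, ssm_states):
    # zeros_like keeps shape and dtype; the input arrays are not modified.
    return np.zeros_like(conv_states), np.zeros_like(ssm_states)


conv = np.random.default_rng(0).normal(size=(1, 8, 3)).astype(np.float32)
ssm = np.ones((1, 8, 4), dtype=np.float32)
conv0, ssm0 = reset_states(conv, ssm)
```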

ssm_states: Union[Array, ndarray, bool, number, ImplicitArray]
to_dict() → Dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.

update_conv_state(new_conv_state: Union[Array, ndarray, bool, number], cache_position: Union[Array, ndarray, bool, number]) → MambaCacheView

Update the convolutional state of the cache.

Parameters
  • new_conv_state – New state to be inserted

  • cache_position – Position in the cache to update

Returns

Updated MambaCacheView
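During token-by-token decoding, a Mamba convolution state typically acts as a sliding window over the last `conv_kernel_size` inputs of each channel: the window shifts left by one slot and the newest input is written in. The sketch below follows the rolling update in the Hugging Face `transformers` Mamba cache, assuming a `(batch, channels, kernel)` layout and that `cache_position` is clamped into the window; whether EasyDeL does exactly this is an assumption, and `roll_conv_state` is hypothetical.

```python
import numpy as np


def roll_conv_state(conv_state, new_column, cache_position):
    """Hypothetical sketch.

    conv_state: (batch, channels, kernel); new_column: (batch, channels).
    """
    kernel = conv_state.shape[-1]
    # Clamp the position into the window (HF-style behaviour, assumed here).
    pos = int(np.clip(cache_position, 0, kernel - 1))
    # Shift the window left by one slot; np.roll returns a new array.
    rolled = np.roll(conv_state, shift=-1, axis=-1)
    # Write the newest input at the clamped position.
    rolled[:, :, pos] = new_column
    return rolled


state = np.array([[[1.0, 2.0, 3.0]]])  # batch=1, channels=1, kernel=3
out = roll_conv_state(state, np.array([[4.0]]), cache_position=10)
print(out[0, 0])  # [2. 3. 4.]
```

Once decoding is past the first `conv_kernel_size` positions, the clamped position is always the last slot, so each step drops the oldest input and appends the newest.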

update_ssm_state(new_ssm_state: Union[Array, ndarray, bool, number]) → MambaCacheView

Update the SSM state of the cache.

Parameters
  • new_ssm_state – New SSM state to replace the current one

Returns

Updated MambaCacheView
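The SSM state is replaced rather than patched because each decoding step derives a complete new hidden state from the previous one. For reference, the sketch below performs one step of the discretized selective-scan recurrence used by Mamba-1 reference implementations, h ← exp(Δ·A) ⊙ h + Δ·B·x, followed by the readout y = C·h + D·x. The function `ssm_step` and its shapes are assumptions for illustration; its `h_new` is the kind of value that would be passed as `new_ssm_state`.

```python
import numpy as np


def ssm_step(h, x, delta, A, B, C, D):
    """One decoding step of a Mamba-1 style selective SSM (hypothetical sketch).

    h: (batch, d_inner, d_state)  previous SSM state
    x, delta: (batch, d_inner)    current input and step size
    A: (d_inner, d_state); B, C: (batch, d_state); D: (d_inner,)
    """
    dA = np.exp(delta[:, :, None] * A[None])                 # discretized A
    dBx = delta[:, :, None] * B[:, None, :] * x[:, :, None]  # discretized B times x
    h_new = dA * h + dBx                                      # state to store
    y = np.einsum("bds,bs->bd", h_new, C) + D * x            # readout
    return h_new, y


b, d, s = 1, 2, 4
h_new, y = ssm_step(
    h=np.zeros((b, d, s)), x=np.ones((b, d)), delta=np.ones((b, d)),
    A=-np.ones((d, s)), B=np.ones((b, s)), C=np.ones((b, s)), D=np.zeros(d),
)
```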

class easydel.layers.caching.mamba.__init__.MambaMetadata

Bases: BaseRunTimeMetadata