easydel.layers.caching.mamba.__init__#

class easydel.layers.caching.mamba.__init__.MambaCache(views: List[Optional[easydel.layers.caching.mamba.mamba_cache.MambaCacheView]])[source]#

Bases: BaseCache

classmethod from_dict(data)#

Create an instance from a dictionary (deserialization).

classmethod from_json(json_str)#

Create an instance from a JSON string.

classmethod init_cache(metadata: MambaCacheMetaData, dtype: Optional[dtype] = None, partition_specs: Optional[PartitionSpec] = None)[source]#

Initialize a complete cache with views for all layers.

Parameters
  • metadata – Configuration metadata

  • dtype – Optional data type to use for the cache (defaults to None)

  • partition_specs – Optional partition specification for the cache (defaults to None)

Returns

Fully initialized cache instance

classmethod init_empty(num_hidden_layers)[source]#

Initialize an empty cache container.

Parameters
  • num_hidden_layers – Number of hidden layers

Returns

Cache instance with uninitialized views

replace(**kwargs)#
reset() MambaCache[source]#

Reset all cache views to their initial state.

Returns

Reset MambaCache

to_dict()#

Convert the instance to a dictionary for JSON serialization.

to_json(**kwargs)#

Convert the instance to a JSON string.

update_conv_state(layer_idx: int, new_conv_state: Union[Array, ndarray, bool, number], cache_position: Union[Array, ndarray, bool, number]) MambaCache[source]#

Update the convolutional state for a specific layer.

Parameters
  • layer_idx – Index of the layer to update

  • new_conv_state – New state to be inserted

  • cache_position – Position in the cache to update

Returns

Updated MambaCache

update_ssm_state(layer_idx: int, new_ssm_state: Union[Array, ndarray, bool, number]) MambaCache[source]#

Update the SSM state for a specific layer.

Parameters
  • layer_idx – Index of the layer to update

  • new_ssm_state – New SSM state to replace the current one

Returns

Updated MambaCache

views: List[Optional[MambaCacheView]]#
class easydel.layers.caching.mamba.__init__.MambaCacheMetaData(num_hidden_layers: int, partition_axis: PartitionAxis, batch_size: int, intermediate_size: int, ssm_state_size: int, conv_kernel_size: int)[source]#

Bases: BaseCacheMetadata

Metadata for Mamba cache configuration.

batch_size: int#
conv_kernel_size: int#
classmethod create(num_hidden_layers: int, partition_axis: PartitionAxis, batch_size: int, intermediate_size: int, ssm_state_size: int, conv_kernel_size: int) MambaCacheMetaData[source]#

Create a MambaCacheMetaData instance with validation.

Parameters

  • num_hidden_layers – Number of hidden layers

  • partition_axis – Partition axis

  • batch_size – Size of the batch

  • intermediate_size – Model’s intermediate size

  • ssm_state_size – Model’s state size

  • conv_kernel_size – Model’s convolution kernel size

Returns

MambaCacheMetaData instance

Raises

ValueError – If required parameters are invalid

classmethod from_dict(data)#

Create an instance from a dictionary (deserialization).

classmethod from_json(json_str)#

Create an instance from a JSON string.

intermediate_size: int#
num_hidden_layers: int#
partition_axis: PartitionAxis#
replace(**kwargs)#
ssm_state_size: int#
to_dict()#

Convert the instance to a dictionary for JSON serialization.

to_json(**kwargs)#

Convert the instance to a JSON string.

class easydel.layers.caching.mamba.__init__.MambaCacheView(conv_states: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, eformer.jaximus._imus.ImplicitArray], ssm_states: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, eformer.jaximus._imus.ImplicitArray], positions: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number], metadata: easydel.layers.caching.mamba.mamba_cache.MambaCacheMetaData, layer_index: Optional[int] = None)[source]#

Bases: BaseCacheView

concatenate_to_cache(*args, **kwargs)[source]#

Update cache with new states.

Parameters
  • *args – Typically includes new tensors

  • **kwargs – Additional parameters for cache update

Returns

  • anything

Return type

Tuple containing

conv_states: Union[Array, ndarray, bool, number, ImplicitArray]#
classmethod from_dict(data)#

Create an instance from a dictionary (deserialization).

classmethod from_json(json_str)#

Create an instance from a JSON string.

classmethod init(metadata: MambaCacheMetaData, partition_specs: PartitionSpec, dtype: dtype, layer_index: Optional[int] = None)[source]#

Initialize a new cache view instance.

Parameters
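  • metadata – Configuration metadata for the cache

  • partition_specs – Partition specification (PartitionSpec) to use for the cache

  • dtype – Data type to use for the cache

  • layer_index – Optional index of the layer this view belongs to (defaults to None)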

Returns

Initialized cache view instance

layer_index: Optional[int] = None#
metadata: MambaCacheMetaData#
positions: Union[Array, ndarray, bool, number]#
replace(**kwargs)#
reset() MambaCacheView[source]#

Reset both conv and ssm states to zeros.

Returns

Reset MambaCacheView

ssm_states: Union[Array, ndarray, bool, number, ImplicitArray]#
to_dict()#

Convert the instance to a dictionary for JSON serialization.

to_json(**kwargs)#

Convert the instance to a JSON string.

update_conv_state(new_conv_state: Union[Array, ndarray, bool, number], cache_position: Union[Array, ndarray, bool, number]) MambaCacheView[source]#

Update the convolutional state of the cache.

Parameters
  • new_conv_state – New state to be inserted

  • cache_position – Position in the cache to update

Returns

Updated MambaCacheView

update_ssm_state(new_ssm_state: Union[Array, ndarray, bool, number]) MambaCacheView[source]#

Update the SSM state of the cache.

Parameters

new_ssm_state – New SSM state to replace the current one

Returns

Updated MambaCacheView

class easydel.layers.caching.mamba.__init__.MambaMetadata[source]#

Bases: BaseRunTimeMetadata