easydel.layers.caching.mamba.cache
Mamba state-space model cache implementation for EasyDeL.
This module provides caching for Mamba models, which use state-space formulations instead of attention mechanisms. Mamba caches maintain convolutional and SSM (State Space Model) states rather than key-value pairs.
Mamba models process sequences using:
- Convolutional states for local context
- SSM states for long-range dependencies
Together these give linear-time complexity in sequence length.
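To make the contrast with key-value caching concrete, the sketch below counts cached elements per layer as the sequence grows. The Mamba shapes follow the conv_states and ssm_states shapes documented for MambaCacheView, with batch 2, intermediate_size 2048, conv_kernel_size 4 and ssm_state_size 16. The attention-side dimensions (num_heads=16, head_dim=64) and both helper functions are hypothetical and exist only for comparison.

```python
def kv_cache_elements(batch, seq_len, num_heads, head_dim):
    # Hypothetical attention cache: keys and values, each [batch, seq_len, num_heads, head_dim].
    # Grows linearly with seq_len.
    return 2 * batch * seq_len * num_heads * head_dim


def mamba_state_elements(batch, intermediate_size, conv_kernel_size, ssm_state_size):
    # conv_states: [batch, intermediate_size, conv_kernel_size]
    # ssm_states:  [batch, intermediate_size, ssm_state_size]
    # Neither shape depends on how many tokens have been processed.
    conv = batch * intermediate_size * conv_kernel_size
    ssm = batch * intermediate_size * ssm_state_size
    return conv + ssm


for seq_len in (1_024, 8_192, 65_536):
    kv = kv_cache_elements(batch=2, seq_len=seq_len, num_heads=16, head_dim=64)
    mamba = mamba_state_elements(
        batch=2, intermediate_size=2048, conv_kernel_size=4, ssm_state_size=16
    )
    print(f"seq_len={seq_len:>6}: kv={kv:>12,} mamba={mamba:>8,}")
```

The attention cache grows with seq_len, while the Mamba state stays at 81,920 elements per layer.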
- Key Components:
MambaCacheMetaData: Configuration for Mamba cache dimensions
MambaCacheView: Per-layer state storage for conv and SSM
MambaCache: Multi-layer Mamba cache orchestration
MambaMetadata: Runtime metadata (placeholder)
- Features:
Separate convolutional and SSM state management
Rolling buffer for convolutional states
Direct SSM state updates
Memory-efficient state representation
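Example
>>> import jax.numpy as jnp
>>> from easydel.layers.caching.mamba.cache import MambaCache, MambaCacheMetaData
>>> # partition_axis is assumed to be a PartitionAxis instance configured elsewhere
>>> metadata = MambaCacheMetaData.create(
...     num_hidden_layers=24,
...     partition_axis=partition_axis,
...     batch_size=2,
...     intermediate_size=2048,
...     ssm_state_size=16,
...     conv_kernel_size=4,
... )
>>> cache = MambaCache.init_cache(
...     metadata=metadata,
...     dtype=jnp.float32,
... )
>>> # Update the conv state for layer 0; shape [batch_size, intermediate_size]
>>> conv_state = jnp.zeros((2, 2048), dtype=jnp.float32)
>>> cache = cache.update_conv_state(
...     layer_idx=0,
...     new_conv_state=conv_state,
...     cache_position=jnp.array([0]),
... )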
- class easydel.layers.caching.mamba.cache.MambaCache(views: list[easydel.layers.caching.mamba.cache.MambaCacheView | None])
Bases: BaseCache
Multi-layer cache container for Mamba models.
Orchestrates cache views across all Mamba layers, providing a unified interface for state management during inference. Each layer maintains independent conv and SSM states.
- views
Ordered list of cache views, one per model layer. None values indicate uninitialized layers.
- Type
list[MambaCacheView | None]
- classmethod from_dict(data: dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- classmethod init_cache(metadata: MambaCacheMetaData, dtype: numpy.dtype | None = None, partition_specs: jax.sharding.PartitionSpec | None = None) → MambaCache
Initialize a complete Mamba cache with views for all layers.
Creates a fully initialized cache with allocated storage for all layers specified in the metadata. Each layer gets its own view with independent state tensors.
- Parameters
metadata (MambaCacheMetaData) – Configuration defining cache dimensions and number of layers.
dtype (jnp.dtype | None) – Data type for cache tensors. Defaults to bfloat16 if not specified.
partition_specs (PartitionSpec | None) – Sharding specification for distributed execution. If None, creates default spec with batch, head, and sequence axes.
- Returns
Fully initialized cache ready for inference.
- Return type
MambaCache
Example
>>> from jax.sharding import PartitionSpec
>>> cache = MambaCache.init_cache(
...     metadata=metadata,
...     dtype=jnp.float32,
...     partition_specs=PartitionSpec('dp', None, None),
... )
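As a rough picture of what init_cache allocates, the NumPy sketch below builds one zero-filled set of states per layer using the shapes documented for MambaCacheView. The function allocate_mamba_states is hypothetical. It uses float32 because NumPy has no native bfloat16, and zero-initialized positions are an assumption; the real method creates JAX arrays and applies partition_specs.

```python
import numpy as np


def allocate_mamba_states(num_hidden_layers, batch_size, intermediate_size,
                          ssm_state_size, conv_kernel_size, dtype=np.float32):
    """Hypothetical stand-in for MambaCache.init_cache: one dict per layer."""
    views = []
    for layer_index in range(num_hidden_layers):
        views.append({
            "conv_states": np.zeros((batch_size, intermediate_size, conv_kernel_size), dtype=dtype),
            "ssm_states": np.zeros((batch_size, intermediate_size, ssm_state_size), dtype=dtype),
            # Assumption: every sequence in the batch starts at position 0.
            "positions": np.zeros((batch_size,), dtype=np.int32),
            "layer_index": layer_index,
        })
    return views


views = allocate_mamba_states(24, 2, 2048, 16, 4)
total_bytes = sum(v["conv_states"].nbytes + v["ssm_states"].nbytes for v in views)
print(len(views), views[0]["conv_states"].shape, views[0]["ssm_states"].shape, total_bytes)
```

For this 24-layer configuration the states occupy 7,864,320 bytes (7.5 MiB) in float32, independent of sequence length.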
- classmethod init_empty(num_hidden_layers: int) → MambaCache
Initialize an empty Mamba cache without allocated storage.
Creates a cache structure with None placeholders for each layer. Useful for gradual initialization or when cache allocation is deferred.
- Parameters
num_hidden_layers (int) – Number of layers to create placeholders for.
- Returns
Cache instance with uninitialized (None) views.
- Return type
MambaCache
Example
>>> cache = MambaCache.init_empty(num_hidden_layers=24)
>>> # Populate individual layers later
>>> cache.views[0] = MambaCacheView.init(...)
- replace(**kwargs)
Creates a new instance with the specified fields replaced.
- reset() → MambaCache
Reset all cache layers to zero states.
Clears the entire cache by resetting each layer’s conv and SSM states to zeros. Useful for sequence boundaries or reinitialization.
- Returns
New cache instance with all states zeroed.
- Return type
MambaCache
Note
Preserves cache structure; only clears state values.
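A minimal sketch of the reset semantics described above, assuming each layer is a dict of NumPy arrays: every state is replaced by zeros of the same shape and dtype, and a new list is returned instead of mutating the old one. Passing None layers through unchanged is an assumption, and reset_views is a hypothetical helper, not EasyDeL code.

```python
import numpy as np


def reset_views(views):
    """Return new per-layer dicts with zeroed states; shapes and dtypes are kept.

    Assumption: uninitialized (None) layers are passed through unchanged.
    """
    new_views = []
    for view in views:
        if view is None:
            new_views.append(None)
            continue
        new_views.append({
            "conv_states": np.zeros_like(view["conv_states"]),
            "ssm_states": np.zeros_like(view["ssm_states"]),
        })
    return new_views


views = [
    {"conv_states": np.ones((2, 8, 4)), "ssm_states": np.ones((2, 8, 16))},
    None,
]
fresh = reset_views(views)
print(fresh[0]["ssm_states"].sum(), fresh[0]["ssm_states"].shape, fresh[1])
print(views[0]["ssm_states"].sum())  # the original arrays are left untouched
```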
- to_dict() → dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
- update_conv_state(layer_idx: int, new_conv_state: Float[Array, 'batch intermediate_size'], cache_position: Int[Array, '...']) → MambaCache
Update convolutional state for a specific layer.
Delegates to the specified layer’s view to update its conv state, then returns a new cache instance with the updated view.
- Parameters
layer_idx (int) – Index of the layer to update.
new_conv_state (cx.Array) – New convolutional state. Shape: [batch_size, intermediate_size]
cache_position (cx.Array) – Position at which to insert the new state; clamped to [0, conv_kernel_size - 1].
- Returns
New cache instance with updated layer.
- Return type
MambaCache
- Raises
ValueError – If specified layer view is None.
Example
>>> cache = cache.update_conv_state(
...     layer_idx=5,
...     new_conv_state=hidden_states,
...     cache_position=jnp.array([2]),
... )
- update_ssm_state(layer_idx: int, new_ssm_state: Float[Array, 'batch intermediate_size ssm_state_size']) → MambaCache
Update SSM state for a specific layer.
Replaces the SSM state for the specified layer with new values, returning a new cache instance with the update.
- Parameters
layer_idx (int) – Index of the layer to update.
new_ssm_state (cx.Array) – New SSM state tensor. Shape: [batch_size, intermediate_size, ssm_state_size]
- Returns
New cache instance with updated layer.
- Return type
MambaCache
- Raises
ValueError – If specified layer view is None.
Example
>>> cache = cache.update_ssm_state(
...     layer_idx=5,
...     new_ssm_state=ssm_output,
... )
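Both MambaCache update methods follow the same copy-on-write pattern: look up the layer's view, raise ValueError if it is None, and return a new cache that holds the updated view. The toy classes below (ToyView, ToyCache) are hypothetical stand-ins built on frozen dataclasses and tuples; they model only the SSM-state path and are not EasyDeL's implementation.

```python
from dataclasses import dataclass, replace
from typing import Optional, Tuple


@dataclass(frozen=True)
class ToyView:
    # Hypothetical stand-in for MambaCacheView holding only an SSM state.
    ssm_state: Tuple[float, ...]


@dataclass(frozen=True)
class ToyCache:
    # Hypothetical stand-in for MambaCache: one entry per layer, None if unallocated.
    views: Tuple[Optional[ToyView], ...]

    def update_ssm_state(self, layer_idx: int, new_ssm_state) -> "ToyCache":
        view = self.views[layer_idx]
        if view is None:
            raise ValueError(f"Cache view for layer {layer_idx} is None")
        new_view = replace(view, ssm_state=tuple(new_ssm_state))
        views = self.views[:layer_idx] + (new_view,) + self.views[layer_idx + 1:]
        # Return a new cache; the original instance is left unchanged.
        return replace(self, views=views)


cache = ToyCache(views=(ToyView((0.0, 0.0)), None))
updated = cache.update_ssm_state(0, [1.0, 2.0])
print(updated.views[0].ssm_state, cache.views[0].ssm_state)
try:
    cache.update_ssm_state(1, [1.0, 2.0])
except ValueError as err:
    print("error:", err)
```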
- views: list[easydel.layers.caching.mamba.cache.MambaCacheView | None]
- class easydel.layers.caching.mamba.cache.MambaCacheMetaData(num_hidden_layers: int, partition_axis: PartitionAxis, batch_size: int, intermediate_size: int, ssm_state_size: int, conv_kernel_size: int)
Bases: BaseCacheMetadata
Metadata for Mamba cache configuration.
Stores static configuration for Mamba model caching, including dimensions for state-space model states and convolutional buffers. Mamba models use a combination of SSM states for long-range modeling and convolutional states for local context.
- num_hidden_layers
Number of Mamba layers in the model.
- Type
int
- partition_axis
Configuration for tensor partitioning in distributed settings.
- Type
PartitionAxis
- batch_size
Number of sequences in batch.
- Type
int
- intermediate_size
Dimension of the intermediate representations in Mamba blocks (typically an expansion of the model dimension).
- Type
int
- ssm_state_size
Dimension of the SSM state vectors. Controls the model’s memory capacity.
- Type
int
- conv_kernel_size
Size of the convolutional kernel used for local mixing. Typically 3 to 7, capturing short-range dependencies.
- Type
int
- batch_size: int
- conv_kernel_size: int
- classmethod create(num_hidden_layers: int, partition_axis: PartitionAxis, batch_size: int, intermediate_size: int, ssm_state_size: int, conv_kernel_size: int) → MambaCacheMetaData
Create a MambaCacheMetaData instance with validation.
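- Parameters
num_hidden_layers (int) – Number of Mamba layers in the model.
partition_axis (PartitionAxis) – Partition axis configuration.
batch_size (int) – Size of the batch.
intermediate_size (int) – Model’s intermediate size.
ssm_state_size (int) – Model’s SSM state size.
conv_kernel_size (int) – Model’s convolution kernel size.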
- Returns
MambaCacheMetaData instance
- Raises
ValueError – If required parameters are invalid
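The create docstring does not list the exact checks it performs. Purely as an illustration, the sketch below assumes that every dimension must be a positive integer and raises ValueError otherwise; validate_mamba_dims is a hypothetical helper, and the real validation rules may differ.

```python
def validate_mamba_dims(num_hidden_layers, batch_size, intermediate_size,
                        ssm_state_size, conv_kernel_size):
    """Hypothetical validation in the spirit of MambaCacheMetaData.create.

    Assumption: every dimension must be a positive int (bool is rejected).
    """
    dims = {
        "num_hidden_layers": num_hidden_layers,
        "batch_size": batch_size,
        "intermediate_size": intermediate_size,
        "ssm_state_size": ssm_state_size,
        "conv_kernel_size": conv_kernel_size,
    }
    for name, value in dims.items():
        if not isinstance(value, int) or isinstance(value, bool) or value <= 0:
            raise ValueError(f"{name} must be a positive integer, got {value!r}")
    return dims


print(validate_mamba_dims(24, 2, 2048, 16, 4))
try:
    validate_mamba_dims(24, 0, 2048, 16, 4)
except ValueError as err:
    print(err)
```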
- classmethod from_dict(data: dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- intermediate_size: int
- num_hidden_layers: int
- partition_axis: PartitionAxis
- replace(**kwargs)
Creates a new instance with the specified fields replaced.
- ssm_state_size: int
- to_dict() → dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
- class easydel.layers.caching.mamba.cache.MambaCacheView(conv_states: jaxtyping.Float[Array, 'batch intermediate_size conv_kernel_size'] | eformer.jaximus._imus.ImplicitArray, ssm_states: jaxtyping.Float[Array, 'batch intermediate_size ssm_state_size'] | eformer.jaximus._imus.ImplicitArray, positions: Int[Array, 'batch'], metadata: MambaCacheMetaData, layer_index: int | None = None)
Bases: BaseCacheView
Single-layer cache view for Mamba state management.
Manages both convolutional and SSM states for one Mamba layer. The convolutional states provide local context through a sliding window, while SSM states encode the full sequence history.
- conv_states
Rolling buffer of convolutional states. Shape: [batch_size, intermediate_size, conv_kernel_size]. Stores the last conv_kernel_size timesteps used by the convolution.
- Type
Array | ImplicitArray
- ssm_states
State-space model hidden states. Shape: [batch_size, intermediate_size, ssm_state_size]. Encodes the full sequence history in compressed form.
- Type
Array | ImplicitArray
- positions
Current position index for each batch element. Shape: [batch_size]. Tracks how far each sequence has progressed in generation.
- Type
Array
- metadata
Static configuration metadata.
- Type
MambaCacheMetaData
- layer_index
Index of this layer in the model.
- Type
int | None
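Because conv_states keeps the most recent conv_kernel_size inputs for every channel, a single decoding step of a depthwise causal convolution reduces to an elementwise product with the kernel followed by a sum over the window axis. The sketch below assumes a kernel laid out as [intermediate_size, conv_kernel_size] with an optional per-channel bias; this mirrors common Mamba implementations but is an assumption, and conv_step is a hypothetical helper rather than EasyDeL code.

```python
import numpy as np


def conv_step(conv_states, weight, bias=None):
    """One causal depthwise-conv output per channel from a full rolling window.

    conv_states: [batch, intermediate_size, conv_kernel_size]
    weight:      [intermediate_size, conv_kernel_size] (assumed layout)
    returns:     [batch, intermediate_size]
    """
    out = np.sum(conv_states * weight[None, :, :], axis=-1)
    if bias is not None:
        out = out + bias[None, :]
    return out


batch, channels, kernel = 1, 2, 4
conv_states = np.arange(batch * channels * kernel, dtype=np.float32).reshape(batch, channels, kernel)
weight = np.ones((channels, kernel), dtype=np.float32)
print(conv_step(conv_states, weight))
```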
- concatenate_to_cache(*args, **kwargs) → tuple
Not implemented for Mamba cache.
Mamba uses separate update methods for conv and SSM states rather than a unified concatenation interface.
- Raises
NotImplementedError – Always raised as this method is not applicable to Mamba caching strategy.
Note
Use update_conv_state() and update_ssm_state() instead.
- conv_states: jaxtyping.Float[Array, 'batch intermediate_size conv_kernel_size'] | eformer.jaximus._imus.ImplicitArray
- classmethod from_dict(data: dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- classmethod init(metadata: MambaCacheMetaData, partition_specs: PartitionSpec, dtype: dtype, layer_index: int | None = None) → MambaCacheView
Initialize a Mamba cache view with zero states.
Creates and allocates cache tensors for both convolutional and SSM states. All states are initialized to zeros, representing a fresh start with no prior context.
- Parameters
metadata (MambaCacheMetaData) – Configuration for cache dimensions.
partition_specs (PartitionSpec) – Sharding specification for distributed execution. Applied to both conv and SSM states.
dtype (jnp.dtype) – Data type for state tensors (e.g., float32, bfloat16).
layer_index (int | None) – Optional index of this layer in the model.
- Returns
Initialized cache view with allocated zero tensors.
- Return type
MambaCacheView
Example
>>> from jax.sharding import PartitionSpec
>>> view = MambaCacheView.init(
...     metadata=metadata,
...     partition_specs=PartitionSpec('dp', None, None),
...     dtype=jnp.float32,
...     layer_index=0,
... )
- metadata: MambaCacheMetaData
- positions: Int[Array, 'batch']
- replace(**kwargs)
Creates a new instance with the specified fields replaced.
- reset() → MambaCacheView
Reset all cache states to zeros.
Clears both convolutional and SSM states, effectively resetting the model’s memory. Useful for:
- Starting new sequences
- Clearing context between batches
- Debugging and testing
- Returns
Reset view with zeroed states.
- Return type
MambaCacheView
- ssm_states: jaxtyping.Float[Array, 'batch intermediate_size ssm_state_size'] | eformer.jaximus._imus.ImplicitArray
- to_dict() → dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
- update_conv_state(new_conv_state: Float[Array, 'batch intermediate_size'], cache_position: Int[Array, '...']) → MambaCacheView
Update the convolutional state with new values.
Implements a rolling buffer for convolutional states, where new states replace old ones in a circular fashion. This maintains a fixed-size window of recent states.
The update process:
1. Roll the existing states to make room.
2. Insert the new state at the specified position.
3. Return the updated view (functional update).
- Parameters
new_conv_state (cx.Array) – New convolutional state to insert. Shape: [batch_size, intermediate_size]
cache_position (cx.Array) – Position index for insertion. Clamped to valid range [0, conv_kernel_size-1].
- Returns
Updated view with new conv state.
- Return type
MambaCacheView
Note
Position is automatically clamped to prevent out-of-bounds access.
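The three update steps and the position clamping can be sketched in NumPy as follows. Rolling by one slot toward the start of the window is an assumption about the roll direction (it keeps the window in chronological order when the write lands in the last slot), and update_conv_state here is a standalone illustration on a plain array that takes an integer position, not the MambaCacheView method.

```python
import numpy as np


def update_conv_state(conv_states, new_conv_state, cache_position):
    """Functional rolling-buffer update following the documented steps.

    conv_states:    [batch, intermediate_size, conv_kernel_size]
    new_conv_state: [batch, intermediate_size]
    cache_position: int, clamped to [0, conv_kernel_size - 1]
    """
    kernel = conv_states.shape[-1]
    pos = int(np.clip(cache_position, 0, kernel - 1))
    # 1. Roll the window left by one; the oldest entry wraps to the last slot.
    rolled = np.roll(conv_states, shift=-1, axis=-1)
    # 2. Write the new state at the clamped position (np.roll returned a copy).
    rolled[:, :, pos] = new_conv_state
    # 3. Return the new buffer; the input array is not modified.
    return rolled


buf = np.zeros((1, 2, 3), dtype=np.float32)  # batch=1, intermediate_size=2, kernel=3
for step in range(5):
    new = np.full((1, 2), float(step + 1), dtype=np.float32)
    # Decoding positions lie past the window, so they clamp to the last slot.
    buf = update_conv_state(buf, new, cache_position=10 + step)
print(buf[0, 0].tolist())
```

Because every decoding position clamps to the last slot, after five steps the window holds the three most recent inputs, oldest first.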
- update_ssm_state(new_ssm_state: Float[Array, 'batch intermediate_size ssm_state_size']) → MambaCacheView
Update the SSM (State Space Model) state.
Replaces the entire SSM state with new values. Unlike conv states, which use a rolling buffer, SSM states are replaced wholesale because they represent the full model state at each timestep.
- Parameters
new_ssm_state (cx.Array) – New SSM state tensor. Shape: [batch_size, intermediate_size, ssm_state_size]
- Returns
Updated view with new SSM state.
- Return type
MambaCacheView
Note
SSM states encode the full history up to current position, so replacement (not accumulation) is the correct operation.
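The note can be checked numerically on a simplified diagonal recurrence h_t = a * h_{t-1} + b * x_t, used here as a stand-in for Mamba's discretized SSM update. Treating a and b as constant scalars is an assumption made for brevity (in Mamba these terms depend on the input), and the batch axis is dropped. Replacing the stored state at every step reproduces the closed-form sum over the whole history, so nothing needs to be accumulated on top of the cached value.

```python
import numpy as np


def step(h, x_t, a, b):
    # One recurrence step; the returned state replaces h entirely.
    return a * h + b * x_t


def closed_form(xs, a, b):
    # h_T = sum_t a**(T-1-t) * b * x_t: the whole history folded into one state.
    T = len(xs)
    return sum(a ** (T - 1 - t) * b * x for t, x in enumerate(xs))


rng = np.random.default_rng(0)
xs = rng.normal(size=(6, 3, 4))  # [time, intermediate_size, ssm_state_size]
a, b = 0.9, 0.5

h = np.zeros((3, 4))
for x_t in xs:
    h = step(h, x_t, a, b)  # replace the cached state, do not add onto it

print(np.allclose(h, closed_form(xs, a, b)))
```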
- class easydel.layers.caching.mamba.cache.MambaMetadata
Bases: BaseRunTimeMetadata
Runtime metadata for Mamba cache operations.
Placeholder for future Mamba-specific runtime metadata. May include sequence positions, segment boundaries, etc.