easydel.layers.caching.mamba.cache

Mamba state-space model cache implementation for EasyDeL.

This module provides caching for Mamba models, which use state-space formulations instead of attention mechanisms. Mamba caches maintain convolutional and SSM (State Space Model) states rather than key-value pairs.

Mamba models process sequences in time linear in the sequence length, using:
  • Convolutional states for local context

  • SSM states for long-range dependencies

Key Components:
  • MambaCacheMetaData: Configuration for Mamba cache dimensions

  • MambaCacheView: Per-layer state storage for conv and SSM

  • MambaCache: Multi-layer Mamba cache orchestration

  • MambaMetadata: Runtime metadata (placeholder)

Features:
  • Separate convolutional and SSM state management

  • Rolling buffer for convolutional states

  • Direct SSM state updates

  • Memory-efficient state representation
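As a rough illustration of the memory-efficient state representation, the sketch below counts cache elements from the state shapes documented for MambaCacheView: [batch_size, intermediate_size, conv_kernel_size] for conv states and [batch_size, intermediate_size, ssm_state_size] for SSM states. The helper name mamba_cache_bytes is hypothetical and not part of EasyDeL, and the small per-layer positions array is ignored. Note that the result has no sequence-length term.

```python
def mamba_cache_bytes(
    num_hidden_layers,
    batch_size,
    intermediate_size,
    ssm_state_size,
    conv_kernel_size,
    bytes_per_element,
):
    """Hypothetical helper: bytes held by conv and SSM states across all layers."""
    conv_elems = batch_size * intermediate_size * conv_kernel_size
    ssm_elems = batch_size * intermediate_size * ssm_state_size
    return num_hidden_layers * (conv_elems + ssm_elems) * bytes_per_element


# Same dimensions as the module-level example, with 4-byte float32 elements.
total = mamba_cache_bytes(
    num_hidden_layers=24,
    batch_size=2,
    intermediate_size=2048,
    ssm_state_size=16,
    conv_kernel_size=4,
    bytes_per_element=4,
)
print(total)  # 7864320 bytes, i.e. 7.5 MiB
```

Because both state shapes are fixed, this figure stays the same however many tokens have been generated.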

Example

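>>> import jax.numpy as jnp
>>> from easydel.layers.caching.mamba.cache import MambaCache, MambaCacheMetaData
>>> # partition_axis, conv_state, and position are assumed to be defined by the caller
>>> metadata = MambaCacheMetaData.create(
...     num_hidden_layers=24,
...     partition_axis=partition_axis,
...     batch_size=2,
...     intermediate_size=2048,
...     ssm_state_size=16,
...     conv_kernel_size=4,
... )
>>> cache = MambaCache.init_cache(
...     metadata=metadata,
...     dtype=jnp.float32,
... )
>>> # Update conv state for layer 0 (returns a new cache instance)
>>> cache = cache.update_conv_state(
...     layer_idx=0,
...     new_conv_state=conv_state,
...     cache_position=position,
... )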
class easydel.layers.caching.mamba.cache.MambaCache(views: list[easydel.layers.caching.mamba.cache.MambaCacheView | None])[source]#

Bases: BaseCache

Multi-layer cache container for Mamba models.

Orchestrates cache views across all Mamba layers, providing a unified interface for state management during inference. Each layer maintains independent conv and SSM states.

views#

Ordered list of cache views, one per model layer. None values indicate uninitialized layers.

Type

list[MambaCacheView | None]

classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

classmethod init_cache(metadata: MambaCacheMetaData, dtype: numpy.dtype | None = None, partition_specs: jax.sharding.PartitionSpec | None = None) MambaCache[source]#

Initialize a complete Mamba cache with views for all layers.

Creates a fully initialized cache with allocated storage for all layers specified in the metadata. Each layer gets its own view with independent state tensors.

Parameters
  • metadata (MambaCacheMetaData) – Configuration defining cache dimensions and number of layers.

  • dtype (jnp.dtype | None) – Data type for cache tensors. Defaults to bfloat16 if not specified.

  • partition_specs (PartitionSpec | None) – Sharding specification for distributed execution. If None, a default spec with batch, head, and sequence axes is created.

Returns

Fully initialized cache ready for inference.

Return type

MambaCache

Example

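>>> import jax.numpy as jnp
>>> from jax.sharding import PartitionSpec
>>> # metadata is assumed to be an existing MambaCacheMetaData instance
>>> cache = MambaCache.init_cache(
...     metadata=metadata,
...     dtype=jnp.float32,
...     partition_specs=PartitionSpec('dp', None, None)
... )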
classmethod init_empty(num_hidden_layers: int) MambaCache[source]#

Initialize an empty Mamba cache without allocated storage.

Creates a cache structure with None placeholders for each layer. Useful for gradual initialization or when cache allocation is deferred.

Parameters

num_hidden_layers (int) – Number of layers to create placeholders for.

Returns

Cache instance with uninitialized (None) views.

Return type

MambaCache

Example

>>> cache = MambaCache.init_empty(num_hidden_layers=24)
>>> # Populate individual layers later
>>> cache.views[0] = MambaCacheView.init(...)
replace(**kwargs)#

Creates a new instance with specified fields replaced.

reset() MambaCache[source]#

Reset all cache layers to zero states.

Clears the entire cache by resetting each layer’s conv and SSM states to zeros. Useful for sequence boundaries or reinitialization.

Returns

New cache instance with all states zeroed.

Return type

MambaCache

Note

Preserves cache structure; only clears state values.

to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

update_conv_state(layer_idx: int, new_conv_state: Float[Array, 'batch intermediate_size'], cache_position: Int[Array, '...']) MambaCache[source]#

Update convolutional state for a specific layer.

Delegates to the specified layer’s view to update its conv state, then returns a new cache instance with the updated view.

Parameters
  • layer_idx (int) – Index of the layer to update.

  • new_conv_state (cx.Array) – New convolutional state. Shape: [batch_size, intermediate_size]

  • cache_position (cx.Array) – Position for insertion.

Returns

New cache instance with updated layer.

Return type

MambaCache

Raises

ValueError – If specified layer view is None.

Example

>>> cache = cache.update_conv_state(
...     layer_idx=5,
...     new_conv_state=hidden_states,
...     cache_position=jnp.array([2])
... )
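The delegate-then-rebuild behaviour described for update_conv_state can be modeled with frozen dataclasses. This is a minimal sketch, not EasyDeL's implementation: ViewSketch and CacheSketch are hypothetical stand-ins, and plain tuples replace JAX arrays so the snippet runs without dependencies.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ViewSketch:
    """Hypothetical stand-in for a per-layer cache view."""

    conv_state: tuple

    def update_conv_state(self, new_conv_state):
        # Return a new view; the existing one is never mutated.
        return replace(self, conv_state=tuple(new_conv_state))


@dataclass(frozen=True)
class CacheSketch:
    """Hypothetical stand-in for the multi-layer container."""

    views: tuple  # one entry per layer; None marks an uninitialized layer

    def update_conv_state(self, layer_idx, new_conv_state):
        view = self.views[layer_idx]
        if view is None:
            raise ValueError(f"Cache view for layer {layer_idx} is None")
        updated = view.update_conv_state(new_conv_state)
        views = self.views[:layer_idx] + (updated,) + self.views[layer_idx + 1:]
        return replace(self, views=views)


old = CacheSketch(views=(ViewSketch((0.0, 0.0)), None))
new = old.update_conv_state(0, (1.0, 2.0))

# Layer 1 was never initialized, so updating it is rejected.
try:
    old.update_conv_state(1, (1.0, 2.0))
    raised = False
except ValueError:
    raised = True
```

Since the update returns a new object and leaves old untouched, callers need to rebind the result, as the cache = cache.update_conv_state(...) pattern in the example does.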
update_ssm_state(layer_idx: int, new_ssm_state: Float[Array, 'batch intermediate_size ssm_state_size']) MambaCache[source]#

Update SSM state for a specific layer.

Replaces the SSM state for the specified layer with new values, returning a new cache instance with the update.

Parameters
  • layer_idx (int) – Index of the layer to update.

  • new_ssm_state (cx.Array) – New SSM state tensor. Shape: [batch_size, intermediate_size, ssm_state_size]

Returns

New cache instance with updated layer.

Return type

MambaCache

Raises

ValueError – If specified layer view is None.

Example

>>> cache = cache.update_ssm_state(
...     layer_idx=5,
...     new_ssm_state=ssm_output
... )
views: list[easydel.layers.caching.mamba.cache.MambaCacheView | None]#
class easydel.layers.caching.mamba.cache.MambaCacheMetaData(num_hidden_layers: int, partition_axis: PartitionAxis, batch_size: int, intermediate_size: int, ssm_state_size: int, conv_kernel_size: int)[source]#

Bases: BaseCacheMetadata

Metadata for Mamba cache configuration.

Stores static configuration for Mamba model caching, including dimensions for state-space model states and convolutional buffers. Mamba models use a combination of SSM states for long-range modeling and convolutional states for local context.

num_hidden_layers#

Number of Mamba layers in the model.

Type

int

partition_axis#

Configuration for tensor partitioning in distributed settings.

Type

PartitionAxis

batch_size#

Number of sequences in batch.

Type

int

intermediate_size#

Dimension of intermediate representations in Mamba blocks (typically expansion of model dimension).

Type

int

ssm_state_size#

Dimension of the SSM state vectors. Controls the model’s memory capacity.

Type

int

conv_kernel_size#

Size of convolutional kernel for local mixing. Typically 3-7 for short-range dependencies.

Type

int

batch_size: int#
conv_kernel_size: int#
classmethod create(num_hidden_layers: int, partition_axis: PartitionAxis, batch_size: int, intermediate_size: int, ssm_state_size: int, conv_kernel_size: int) MambaCacheMetaData[source]#

Create a MambaCacheMetaData instance with validation.

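Parameters
  • num_hidden_layers (int) – Number of Mamba layers in the model.

  • partition_axis (PartitionAxis) – Configuration for tensor partitioning in distributed settings.

  • batch_size (int) – Size of the batch.

  • intermediate_size (int) – Model’s intermediate size.

  • ssm_state_size (int) – Model’s state size.

  • conv_kernel_size (int) – Model’s convolution kernel size.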

Returns

MambaCacheMetaData instance

Raises

ValueError – If required parameters are invalid
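The exact checks performed by create are not listed here. A plausible minimal version, assuming each dimension simply has to be a positive integer, could look like the sketch below. Both function names are hypothetical and are not EasyDeL APIs.

```python
def check_positive_int(name, value):
    """Hypothetical check: reject anything that is not a positive integer."""
    if isinstance(value, bool) or not isinstance(value, int) or value <= 0:
        raise ValueError(f"{name} must be a positive integer, got {value!r}")
    return value


def validate_mamba_dims(**dims):
    """Validate every named dimension and return them unchanged."""
    return {name: check_positive_int(name, value) for name, value in dims.items()}


dims = validate_mamba_dims(
    num_hidden_layers=24,
    batch_size=2,
    intermediate_size=2048,
    ssm_state_size=16,
    conv_kernel_size=4,
)

# An invalid dimension surfaces as a ValueError that names the offending field.
try:
    validate_mamba_dims(conv_kernel_size=0)
except ValueError as err:
    message = str(err)
```

Validating at construction time means a bad dimension fails early with a readable message rather than later as a shape error during cache allocation.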

classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

intermediate_size: int#
num_hidden_layers: int#
partition_axis: PartitionAxis#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

ssm_state_size: int#
to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

class easydel.layers.caching.mamba.cache.MambaCacheView(conv_states: jaxtyping.Float[Array, 'batch intermediate_size conv_kernel_size'] | eformer.jaximus._imus.ImplicitArray, ssm_states: jaxtyping.Float[Array, 'batch intermediate_size ssm_state_size'] | eformer.jaximus._imus.ImplicitArray, positions: Int[Array, 'batch'], metadata: MambaCacheMetaData, layer_index: int | None = None)[source]#

Bases: BaseCacheView

Single-layer cache view for Mamba state management.

Manages both convolutional and SSM states for one Mamba layer. The convolutional states provide local context through a sliding window, while SSM states encode the full sequence history.

conv_states#

Rolling buffer of convolutional states. Shape: [batch_size, intermediate_size, conv_kernel_size]. Stores the last conv_kernel_size timesteps for convolution.

Type

Array | ImplicitArray
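To see why storing only the last conv_kernel_size timesteps is enough, the sketch below compares a causal 1-D convolution over a whole sequence with a step-by-step version that keeps just a fixed-size window. It is a single-channel, pure-Python illustration with made-up kernel values. The function names are hypothetical, and a real Mamba layer would apply this per channel on JAX arrays.

```python
def causal_conv_full(xs, kernel):
    """Causal 1-D convolution over a whole sequence, with zero left padding."""
    k = len(kernel)
    padded = [0.0] * (k - 1) + list(xs)
    return [
        sum(w * v for w, v in zip(kernel, padded[t:t + k]))
        for t in range(len(xs))
    ]


def causal_conv_step(window, x_t, kernel):
    """One decode step: drop the oldest input, append x_t, apply the kernel."""
    window = window[1:] + [x_t]
    return window, sum(w * v for w, v in zip(kernel, window))


kernel = [0.5, -1.0, 2.0, 1.0]  # illustrative weights, kernel size 4
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]

window = [0.0] * len(kernel)  # zero-initialized, like a fresh cache view
stepwise = []
for x_t in xs:
    window, y_t = causal_conv_step(window, x_t, kernel)
    stepwise.append(y_t)

full = causal_conv_full(xs, kernel)
```

Both paths produce the same outputs, so during generation the layer never needs inputs older than the window.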

ssm_states#

State-space model hidden states. Shape: [batch_size, intermediate_size, ssm_state_size]. Encodes full sequence history in compressed form.

Type

Array | ImplicitArray

positions#

Current position index per batch element. Shape: [batch_size]. Tracks where each sequence is in generation.

Type

Array

metadata#

Static configuration metadata.

Type

MambaCacheMetaData

layer_index#

Index of this layer in the model.

Type

int | None

concatenate_to_cache(*args, **kwargs) tuple[source]#

Not implemented for Mamba cache.

Mamba uses separate update methods for conv and SSM states rather than a unified concatenation interface.

Raises

NotImplementedError – Always raised as this method is not applicable to Mamba caching strategy.

Note

Use update_conv_state() and update_ssm_state() instead.

conv_states: jaxtyping.Float[Array, 'batch intermediate_size conv_kernel_size'] | eformer.jaximus._imus.ImplicitArray#
classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

classmethod init(metadata: MambaCacheMetaData, partition_specs: PartitionSpec, dtype: dtype, layer_index: int | None = None) MambaCacheView[source]#

Initialize a Mamba cache view with zero states.

Creates and allocates cache tensors for both convolutional and SSM states. All states are initialized to zeros, representing a fresh start with no prior context.

Parameters
  • metadata (MambaCacheMetaData) – Configuration for cache dimensions.

  • partition_specs (PartitionSpec) – Sharding specification for distributed execution. Applied to both conv and SSM states.

  • dtype (jnp.dtype) – Data type for state tensors (e.g., float32, bfloat16).

  • layer_index (int | None) – Optional index of this layer in the model.

Returns

Initialized cache view with allocated zero tensors.

Return type

MambaCacheView

Example

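>>> import jax.numpy as jnp
>>> from jax.sharding import PartitionSpec
>>> # metadata is assumed to be an existing MambaCacheMetaData instance
>>> view = MambaCacheView.init(
...     metadata=metadata,
...     partition_specs=PartitionSpec('dp', None, None),
...     dtype=jnp.float32,
...     layer_index=0
... )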
layer_index: int | None = None#
metadata: MambaCacheMetaData#
positions: Int[Array, 'batch']#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

reset() MambaCacheView[source]#

Reset all cache states to zeros.

Clears both convolutional and SSM states, effectively resetting the model’s memory. Useful for:
  • Starting new sequences

  • Clearing context between batches

  • Debugging and testing

Returns

Reset view with zeroed states.

Return type

MambaCacheView

ssm_states: jaxtyping.Float[Array, 'batch intermediate_size ssm_state_size'] | eformer.jaximus._imus.ImplicitArray#
to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

update_conv_state(new_conv_state: Float[Array, 'batch intermediate_size'], cache_position: Int[Array, '...']) MambaCacheView[source]#

Update the convolutional state with new values.

Implements a rolling buffer for convolutional states, where new states replace old ones in a circular fashion. This maintains a fixed-size window of recent states.

The update process:
  1. Roll existing states to make room

  2. Insert new state at specified position

  3. Return updated view (functional update)

Parameters
  • new_conv_state (cx.Array) – New convolutional state to insert. Shape: [batch_size, intermediate_size]

  • cache_position (cx.Array) – Position index for insertion. Clamped to valid range [0, conv_kernel_size-1].

Returns

Updated view with new conv state.

Return type

MambaCacheView

Note

Position is automatically clamped to prevent out-of-bounds access.
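The roll, clamp, and insert steps can be modeled without any array library. The sketch below uses nested lists shaped [batch][intermediate_size][conv_kernel_size] and assumes the roll shifts each window one slot toward index 0 along the last axis. That direction is an assumption, since this page does not state it. update_conv_window is a hypothetical name, not the library method.

```python
def update_conv_window(conv_states, new_conv_state, cache_position):
    """Roll each channel window left by one, then write the new value.

    conv_states:    nested lists, [batch][intermediate_size][conv_kernel_size]
    new_conv_state: nested lists, [batch][intermediate_size]
    """
    kernel_size = len(conv_states[0][0])
    # Clamp so the write index always lies in [0, kernel_size - 1].
    pos = max(0, min(cache_position, kernel_size - 1))
    updated = []
    for b, channels in enumerate(conv_states):
        new_channels = []
        for c, window in enumerate(channels):
            rolled = window[1:] + window[:1]  # circular shift left by one
            rolled[pos] = new_conv_state[b][c]
            new_channels.append(rolled)
        updated.append(new_channels)
    return updated  # the input lists are left untouched (functional update)


state = [[[1, 2, 3, 4]]]  # batch=1, one channel, kernel size 4
out = update_conv_window(state, [[5]], cache_position=9)
print(out)  # [[[2, 3, 4, 5]]]
```

Here position 9 is clamped to 3, so the new value lands in the last slot after the oldest entry has rolled out of the way.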

update_ssm_state(new_ssm_state: Float[Array, 'batch intermediate_size ssm_state_size']) MambaCacheView[source]#

Update the SSM (State Space Model) state.

Replaces the entire SSM state with new values. Unlike conv states which use a rolling buffer, SSM states are completely replaced as they represent the full model state at each timestep.

Parameters

new_ssm_state (cx.Array) – New SSM state tensor. Shape: [batch_size, intermediate_size, ssm_state_size]

Returns

Updated view with new SSM state.

Return type

MambaCacheView

Note

SSM states encode the full history up to current position, so replacement (not accumulation) is the correct operation.
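A small recurrence shows why replacement is sufficient. The sketch below uses a diagonal linear update h_t = a_bar * h_{t-1} + b_bar * x_t with fixed, made-up coefficients. Real Mamba layers compute input-dependent coefficients, so this is a simplification, and ssm_step is a hypothetical name. Resuming from a stored state gives the same result as processing the whole sequence in one pass.

```python
def ssm_step(h, x_t, a_bar, b_bar):
    """One recurrence step per state dimension: a * h_prev + b * x_t."""
    return [a * h_i + b * x_t for a, b, h_i in zip(a_bar, b_bar, h)]


a_bar = [0.5, 0.25]  # illustrative decay coefficients, state size 2
b_bar = [1.0, 2.0]   # illustrative input coefficients
xs = [1.0, 2.0, 3.0, 4.0]

# Process the whole sequence in one pass.
h_full = [0.0, 0.0]
for x_t in xs:
    h_full = ssm_step(h_full, x_t, a_bar, b_bar)

# Process a prefix, keep only the resulting state, then continue from it.
cached = [0.0, 0.0]
for x_t in xs[:2]:
    cached = ssm_step(cached, x_t, a_bar, b_bar)

h_resumed = cached
for x_t in xs[2:]:
    h_resumed = ssm_step(h_resumed, x_t, a_bar, b_bar)

print(h_full)  # [6.125, 9.78125]
```

Each new state already folds in every earlier input, so the previous state can be overwritten once the step is computed.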

class easydel.layers.caching.mamba.cache.MambaMetadata[source]#

Bases: BaseRunTimeMetadata

Runtime metadata for Mamba cache operations.

Placeholder for future Mamba-specific runtime metadata. May include sequence positions, segment boundaries, etc.