easydel.layers.caching.mamba2.cache

Mamba2 enhanced state-space model cache implementation.

This module provides caching for Mamba2 models, which extend the original Mamba architecture with additional features including:

  • Multi-head state space models

  • Group normalization capabilities

  • Enhanced convolutional processing

  • Improved sequence modeling

Mamba2 introduces structured state spaces with head-based organization, similar to multi-head attention but for state-space models.

Key Components:
  • Mamba2CacheMetaData: Enhanced configuration with head/group support

  • Mamba2CacheView: Per-layer state storage with sequence tracking

  • Mamba2Cache: Multi-layer orchestration with sequence updates

  • Mamba2Metadata: Runtime metadata (placeholder)

Features:
  • Head-based SSM organization

  • Group normalization support

  • Sequence position tracking

  • Extended convolutional states

Example

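>>> import jax.numpy as jnp
>>> from easydel.layers.caching.mamba2.cache import (
...     Mamba2Cache,
...     Mamba2CacheMetaData,
... )
>>> # `partition_axis` is assumed to be an existing PartitionAxis instance.
>>> # It is passed positionally because the create() signature names it `parition_axis`.
>>> metadata = Mamba2CacheMetaData.create(
...     partition_axis,
...     num_hidden_layers=32,
...     batch_size=2,
...     intermediate_size=1024,  # num_heads * head_dim
...     num_heads=16,
...     head_dim=64,
...     state_size=128,
...     conv_kernel_size=4,
...     n_groups=8,
... )
>>> cache = Mamba2Cache.init_cache(
...     num_hidden_layers=32,
...     metadata=metadata,
...     dtype=jnp.float32,
... )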

class easydel.layers.caching.mamba2.cache.Mamba2Cache(views: list[easydel.layers.caching.mamba2.cache.Mamba2CacheView | None])

Bases: BaseCache

Multi-layer Mamba2 cache container.

Orchestrates Mamba2 cache views across all model layers, with additional support for sequence position tracking and batch sequence updates.

views

Per-layer cache views.

Type

list[Mamba2CacheView | None]

classmethod from_dict(data: dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

classmethod init_cache(num_hidden_layers: int, metadata: Mamba2CacheMetaData, dtype: numpy.dtype | None = None, partition_specs: jax.sharding.PartitionSpec | None = None) → Mamba2Cache

Initialize a complete Mamba2 cache for all layers.

Creates a fully initialized cache with allocated storage for the specified number of layers. Each layer gets independent multi-head SSM states and convolutional buffers.

Parameters
  • num_hidden_layers (int) – Number of Mamba2 layers to initialize.

  • metadata (Mamba2CacheMetaData) – Configuration for cache dimensions.

  • dtype (jnp.dtype | None) – Data type for tensors. Defaults to bfloat16.

  • partition_specs (PartitionSpec | None) – Sharding specification. If None, creates default spec with batch, head, and sequence axes.

Returns

Fully initialized multi-layer cache.

Return type

Mamba2Cache

Example

>>> cache = Mamba2Cache.init_cache(
...     num_hidden_layers=32,
...     metadata=metadata,
...     dtype=jnp.float32
... )
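To make the allocation concrete, the following NumPy sketch mimics what init_cache is described as producing: one independent, zero-filled pair of conv and SSM buffers per layer. The function `init_mamba2_layers` and the dict layout are hypothetical stand-ins, not EasyDeL code; zero-initialized `positions` and `seqlen_offset` are an assumption, and float32 replaces the documented bfloat16 default because NumPy has no native bfloat16.

```python
import numpy as np


def init_mamba2_layers(num_hidden_layers, batch_size, intermediate_size, num_heads,
                       head_dim, state_size, conv_kernel_size, n_groups,
                       dtype=np.float32):
    """Hypothetical stand-in for init_cache: one zeroed state set per layer."""
    # Conv buffer width covers the inner activations plus the grouped B and C channels.
    conv_dim = intermediate_size + 2 * n_groups * state_size
    layers = []
    for _ in range(num_hidden_layers):
        layers.append({
            # Fresh arrays for every layer, so no two layers share storage.
            "conv_states": np.zeros((batch_size, conv_dim, conv_kernel_size), dtype=dtype),
            "ssm_states": np.zeros((batch_size, num_heads, head_dim, state_size), dtype=dtype),
            # Assumed initial position tracking.
            "positions": np.zeros((batch_size,), dtype=np.int32),
            "seqlen_offset": 0,
        })
    return layers


layers = init_mamba2_layers(4, 2, 1024, 16, 64, 128, 4, 8)
```

Because each layer's state is a fixed-size summary, the buffer sizes depend on batch size and model dimensions but not on how many tokens have been processed.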

classmethod init_empty(num_hidden_layers: int) → Mamba2Cache

Initialize an empty Mamba2 cache structure.

Creates a cache with None placeholders for gradual initialization.

Parameters

num_hidden_layers (int) – Number of layer placeholders.

Returns

Cache with uninitialized views.

Return type

Mamba2Cache

Example

>>> cache = Mamba2Cache.init_empty(32)
>>> # Initialize layers individually later
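The gradual-initialization pattern can be sketched in plain Python. `init_empty` and `fill_layer` below are hypothetical helpers operating on a plain list, not EasyDeL functions; returning a new list instead of mutating mirrors the immutable-PyTree style suggested by replace().

```python
from typing import Optional


def init_empty(num_hidden_layers: int) -> list[Optional[dict]]:
    # One None placeholder per layer, mirroring views: list[Mamba2CacheView | None].
    return [None] * num_hidden_layers


def fill_layer(views: list, layer_idx: int, view) -> list:
    # Build a new list rather than mutating the old one.
    new_views = list(views)
    new_views[layer_idx] = view
    return new_views


views = init_empty(4)
views = fill_layer(views, 2, {"layer_index": 2})
print(views)  # [None, None, {'layer_index': 2}, None]
```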

replace(**kwargs)

Creates a new instance with specified fields replaced.

reset() → Mamba2Cache

Reset all cache views to their initial state.

Returns

Reset Mamba2Cache

to_dict() → dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.

update_conv_state(layer_idx: int, new_conv_state: Float[Array, 'batch extended_size'], cache_position: Int[Array, '...']) → Mamba2Cache

Update the convolutional state for a specific layer.

Parameters
  • layer_idx – Index of the layer to update

  • new_conv_state – New state to be inserted

  • cache_position – Position in the cache to update

Returns

Updated Mamba2Cache

update_seq(num: int) → None

Update sequence positions across all layers.

Increments position tracking for sequence continuation, useful when processing long sequences in chunks.

Parameters

num (int) – Number of positions to advance.

Note

Updates both positions and seqlen_offset for each layer.
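A minimal sketch of the bookkeeping update_seq is described as doing, using NumPy arrays in plain dicts as hypothetical stand-ins for Mamba2CacheView. Skipping None placeholders is an assumption about how uninitialized layers are handled.

```python
import numpy as np


def update_seq(layers, num):
    """Advance positions and seqlen_offset of every initialized layer by num (in place)."""
    for layer in layers:
        if layer is None:  # assumed: placeholders from init_empty are left alone
            continue
        layer["positions"] = layer["positions"] + num
        layer["seqlen_offset"] += num


layers = [
    {"positions": np.zeros(2, dtype=np.int32), "seqlen_offset": 0},
    None,
]
update_seq(layers, 16)  # after a 16-token prefill chunk
update_seq(layers, 1)   # after one decode step
print(layers[0]["positions"], layers[0]["seqlen_offset"])  # [17 17] 17
```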

update_ssm_state(layer_idx: int, new_ssm_state: Float[Array, 'batch num_heads head_dim state_size']) → Mamba2Cache

Update the SSM state for a specific layer.

Parameters
  • layer_idx – Index of the layer to update

  • new_ssm_state – New SSM state to replace the current one

Returns

Updated Mamba2Cache

views: list[easydel.layers.caching.mamba2.cache.Mamba2CacheView | None]

class easydel.layers.caching.mamba2.cache.Mamba2CacheMetaData(partition_axis: PartitionAxis, num_hidden_layers: int, batch_size: int, intermediate_size: int, num_heads: int, head_dim: int, state_size: int, conv_kernel_size: int, n_groups: int)

Bases: BaseCacheMetadata

Metadata configuration for Mamba2 state-space cache.

Extends the original Mamba cache with support for multi-head state-space models and group normalization. Each of the num_heads heads keeps its own head_dim × state_size state matrix.

partition_axis

Tensor partitioning configuration.

Type

PartitionAxis

num_hidden_layers

Number of Mamba2 layers in the model.

Type

int

batch_size

Number of sequences in batch.

Type

int

intermediate_size

Inner (expanded) dimension of the Mamba2 mixer; in a consistent configuration it equals num_heads * head_dim.

Type

int

num_heads

Number of SSM heads (similar to attention heads).

Type

int

head_dim

Dimension per SSM head.

Type

int

state_size

Size of the state-space representation.

Type

int

conv_kernel_size

Size of convolutional kernel.

Type

int

n_groups

Number of head groups; each group shares one state_size-wide B and C projection and one group-normalization group.

Type

int

batch_size: int
conv_kernel_size: int
classmethod create(parition_axis: PartitionAxis, num_hidden_layers: int, batch_size: int, intermediate_size: int, num_heads: int, head_dim: int, state_size: int, conv_kernel_size: int, n_groups: int) → Mamba2CacheMetaData

Create and validate Mamba2 cache metadata.

Factory method that validates all parameters before creating the metadata instance. Ensures all dimensions are positive and compatible.

Parameters
  • partition_axis (PartitionAxis) – Sharding configuration.

  • num_hidden_layers (int) – Number of model layers.

  • batch_size (int) – Batch size for cache allocation.

  • intermediate_size (int) – Inner (expanded) mixer dimension, typically num_heads * head_dim.

  • num_heads (int) – Number of SSM heads.

  • head_dim (int) – Dimension per head.

  • state_size (int) – State-space size per head.

  • conv_kernel_size (int) – Convolution kernel size.

  • n_groups (int) – Number of head groups sharing the B and C projections (also used for group normalization).

Returns

Validated metadata instance.

Return type

Mamba2CacheMetaData

Raises

ValueError – If any parameter is non-positive.
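create() is documented to reject non-positive values only. The hypothetical validator below (not an EasyDeL function) performs that check and adds two further checks that are assumptions about a consistent Mamba2 configuration, not documented behavior of create(): num_heads * head_dim must equal intermediate_size, and num_heads must be divisible by n_groups.

```python
def validate_mamba2_dims(num_hidden_layers, batch_size, intermediate_size, num_heads,
                         head_dim, state_size, conv_kernel_size, n_groups):
    """Hypothetical validator: documented positivity check plus assumed consistency checks."""
    dims = dict(locals())  # captures exactly the eight parameters above
    for name, value in dims.items():
        if value <= 0:
            raise ValueError(f"{name} must be positive, got {value}")
    # Assumed consistency rules (not documented for Mamba2CacheMetaData.create).
    if num_heads * head_dim != intermediate_size:
        raise ValueError(
            f"num_heads * head_dim ({num_heads * head_dim}) must equal "
            f"intermediate_size ({intermediate_size})"
        )
    if num_heads % n_groups != 0:
        raise ValueError(f"num_heads ({num_heads}) must be divisible by n_groups ({n_groups})")


validate_mamba2_dims(32, 2, 1024, 16, 64, 128, 4, 8)  # passes silently
try:
    validate_mamba2_dims(32, 2, 2816, 16, 64, 128, 4, 8)
except ValueError as err:
    print(err)  # num_heads * head_dim (1024) must equal intermediate_size (2816)
```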

classmethod from_dict(data: dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

head_dim: int
intermediate_size: int
n_groups: int
num_heads: int
num_hidden_layers: int
partition_axis: PartitionAxis
replace(**kwargs)

Creates a new instance with specified fields replaced.

state_size: int
to_dict() → dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.

class easydel.layers.caching.mamba2.cache.Mamba2CacheView(conv_states: jaxtyping.Float[Array, 'batch extended_size conv_kernel_size'] | eformer.jaximus._imus.ImplicitArray, ssm_states: jaxtyping.Float[Array, 'batch num_heads head_dim state_size'] | eformer.jaximus._imus.ImplicitArray, positions: Int[Array, 'batch'], seqlen_offset: int, metadata: Mamba2CacheMetaData, layer_index: int | None = None)

Bases: BaseCacheView

Single-layer cache view for Mamba2 state-space model.

Manages both convolutional and SSM states for one Mamba2 layer, with additional tracking for sequence positions and offsets. The multi-head organization allows for richer state representations.

conv_states

Convolutional states buffer. Shape: [batch, intermediate_size + 2*n_groups*state_size, kernel_size]

Type

cx.Array | ImplicitArray

ssm_states

State-space model states. Shape: [batch, num_heads, head_dim, state_size]

Type

cx.Array | ImplicitArray

positions

Current position per sequence. Shape: [batch_size]

Type

cx.Array

seqlen_offset

Global sequence offset for continuation.

Type

int

metadata

Static configuration.

Type

Mamba2CacheMetaData

layer_index

Layer index in model.

Type

int | None

concatenate_to_cache(*args, **kwargs) → tuple

Not implemented for Mamba2 cache.

Mamba2 uses separate update methods for conv and SSM states.

Raises

NotImplementedError – Always raised.

Note

Use update_conv_state() and update_ssm_state() instead.

conv_states: jaxtyping.Float[Array, 'batch extended_size conv_kernel_size'] | eformer.jaximus._imus.ImplicitArray
classmethod from_dict(data: dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

classmethod init(metadata: Mamba2CacheMetaData, partition_specs: PartitionSpec, dtype: dtype, layer_index: int | None = None) → Mamba2CacheView

Initialize a Mamba2 cache view with zero states.

Creates cache tensors for the extended Mamba2 architecture: multi-head SSM states and convolutional buffers widened to hold the grouped B and C projections alongside the inner activations.

Parameters
  • metadata (Mamba2CacheMetaData) – Configuration for cache dimensions.

  • partition_specs (PartitionSpec) – Sharding specification for distributed execution.

  • dtype (jnp.dtype) – Data type for state tensors.

  • layer_index (int | None) – Optional layer index in model.

Returns

Initialized cache view with allocated tensors.

Return type

Mamba2CacheView

Note

Conv state width is intermediate_size + 2*n_groups*state_size because the B and C projections (n_groups*state_size channels each) pass through the same causal convolution as the inner activations.
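The buffer shapes follow directly from the metadata fields. A minimal plain-Python sketch, assuming the conv_states and ssm_states shapes listed for this class; the helper name `mamba2_cache_shapes` is hypothetical and not part of EasyDeL:

```python
def mamba2_cache_shapes(batch_size, intermediate_size, num_heads, head_dim,
                        state_size, conv_kernel_size, n_groups):
    """Return (conv_states_shape, ssm_states_shape) for one layer."""
    # Inner activations plus B and C, each n_groups * state_size channels wide.
    conv_dim = intermediate_size + 2 * n_groups * state_size
    conv_shape = (batch_size, conv_dim, conv_kernel_size)
    # One head_dim x state_size state matrix per head.
    ssm_shape = (batch_size, num_heads, head_dim, state_size)
    return conv_shape, ssm_shape


conv_shape, ssm_shape = mamba2_cache_shapes(2, 1024, 16, 64, 128, 4, 8)
print(conv_shape)  # (2, 3072, 4)
print(ssm_shape)   # (2, 16, 64, 128)
```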

layer_index: int | None = None
metadata: Mamba2CacheMetaData
positions: Int[Array, 'batch']
replace(**kwargs)

Creates a new instance with specified fields replaced.

reset() → Mamba2CacheView

Reset all cache states to zeros.

Clears both convolutional and SSM states while preserving structure and shape. Position tracking is maintained.

Returns

Reset view with zeroed states.

Return type

Mamba2CacheView
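The reset semantics (zero the states, keep shapes and position tracking) can be sketched with NumPy on a plain dict. `reset_view` and the dict layout are hypothetical stand-ins for Mamba2CacheView, not EasyDeL code.

```python
import numpy as np


def reset_view(view: dict) -> dict:
    """Return a copy with zeroed conv/SSM states; shapes, dtypes and positions are kept."""
    return {
        **view,  # positions and seqlen_offset are carried over unchanged
        "conv_states": np.zeros_like(view["conv_states"]),
        "ssm_states": np.zeros_like(view["ssm_states"]),
    }


view = {
    "conv_states": np.ones((2, 8, 4), dtype=np.float32),
    "ssm_states": np.ones((2, 4, 8, 16), dtype=np.float32),
    "positions": np.array([5, 7], dtype=np.int32),
    "seqlen_offset": 7,
}
fresh = reset_view(view)
```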

seqlen_offset: int
ssm_states: jaxtyping.Float[Array, 'batch num_heads head_dim state_size'] | eformer.jaximus._imus.ImplicitArray
to_dict() → dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.

update_conv_state(new_conv_state: Float[Array, 'batch extended_size'], cache_position: Int[Array, '...']) → Mamba2CacheView

Update convolutional state with new values.

Maintains a rolling buffer of convolutional states, with support for the extended state size used in Mamba2.

Parameters
  • new_conv_state (cx.Array) – New conv state to insert. Shape: [batch, intermediate_size + 2*n_groups*state_size]

  • cache_position (cx.Array) – Position index for insertion.

Returns

Updated view with new conv state.

Return type

Mamba2CacheView
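A rolling-buffer update can be sketched in NumPy following the roll-and-write pattern used by the Hugging Face Transformers Mamba2 cache: shift the window left by one step, then write the newest column at the clamped cache_position. Whether EasyDeL's update_conv_state follows exactly these steps is an assumption, `update_conv_state` here is a hypothetical free function, and cache_position is simplified to a scalar. In JAX the in-place write would instead be expressed as `rolled.at[:, :, pos].set(new_conv_state)`.

```python
import numpy as np


def update_conv_state(conv_states, new_conv_state, cache_position):
    """Roll-and-write sketch of a conv cache update.

    conv_states:    [batch, conv_dim, kernel_size]
    new_conv_state: [batch, conv_dim], the newest input column
    cache_position: scalar write index (clamped into the window)
    """
    kernel_size = conv_states.shape[-1]
    # Positions past the window always map to the last slot.
    pos = int(np.clip(cache_position, 0, kernel_size - 1))
    # np.roll returns a new array: the oldest column wraps around to the end...
    rolled = np.roll(conv_states, shift=-1, axis=-1)
    # ...and is then overwritten with the newest input.
    rolled[:, :, pos] = new_conv_state
    return rolled


state = np.zeros((1, 2, 4), dtype=np.float32)
for t in range(5):  # five decode steps with inputs 1.0 .. 5.0
    state = update_conv_state(state, np.full((1, 2), t + 1.0), cache_position=3)
print(state[0, 0])  # [2. 3. 4. 5.]
```

Only the last kernel_size inputs survive, which is exactly the window a causal convolution of that width needs.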

update_ssm_state(new_ssm_state: Float[Array, 'batch num_heads head_dim state_size']) → Mamba2CacheView

Update SSM state with head-structured representation.

Replaces the multi-head SSM state with new values, maintaining the head-based organization.

Parameters

new_ssm_state (cx.Array) – New SSM state. Shape: [batch, num_heads, head_dim, state_size]

Returns

Updated view with new SSM state.

Return type

Mamba2CacheView
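Because the view is an immutable PyTree with a replace() method, an SSM update returns a new view rather than modifying the old one. The frozen dataclass below is a hypothetical stand-in for Mamba2CacheView with only the fields needed here; the shape check is an assumption added for illustration, not documented behavior.

```python
from dataclasses import dataclass, replace

import numpy as np


@dataclass(frozen=True)
class ToyMamba2View:
    """Hypothetical stand-in for Mamba2CacheView."""
    ssm_states: np.ndarray  # [batch, num_heads, head_dim, state_size]
    seqlen_offset: int = 0

    def update_ssm_state(self, new_ssm_state: np.ndarray) -> "ToyMamba2View":
        # The whole state is replaced; its head-structured shape must not change.
        if new_ssm_state.shape != self.ssm_states.shape:
            raise ValueError(f"expected {self.ssm_states.shape}, got {new_ssm_state.shape}")
        return replace(self, ssm_states=new_ssm_state)


view = ToyMamba2View(np.zeros((2, 4, 8, 16), dtype=np.float32))
new_view = view.update_ssm_state(np.ones((2, 4, 8, 16), dtype=np.float32))
```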

class easydel.layers.caching.mamba2.cache.Mamba2Metadata

Bases: BaseRunTimeMetadata

Runtime metadata for Mamba2 cache operations.

Placeholder for future Mamba2-specific runtime state. May include head masks, group indices, etc.