# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Mamba2 enhanced state-space model cache implementation.
This module provides caching for Mamba2 models, which extend the original
Mamba architecture with additional features including:
- Multi-head state space models
- Group normalization capabilities
- Enhanced convolutional processing
- Improved sequence modeling
Mamba2 introduces structured state spaces with head-based organization,
similar to multi-head attention but for state-space models.
Key Components:
- Mamba2CacheMetaData: Enhanced configuration with head/group support
- Mamba2CacheView: Per-layer state storage with sequence tracking
- Mamba2Cache: Multi-layer orchestration with sequence updates
- Mamba2Metadata: Runtime metadata (placeholder)
Features:
- Head-based SSM organization
- Group normalization support
- Sequence position tracking
- Extended convolutional states
Example:
>>> metadata = Mamba2CacheMetaData.create(
... partition_axis=partition_axis,
... num_hidden_layers=32,
... batch_size=2,
... intermediate_size=2816,
... num_heads=16,
... head_dim=64,
... state_size=128,
... conv_kernel_size=4,
... n_groups=8
... )
>>> cache = Mamba2Cache.init_cache(
... num_hidden_layers=32,
... metadata=metadata,
... dtype=jnp.float32
... )
"""
from __future__ import annotations
from eformer.escale import PartitionAxis, with_sharding_constraint
from eformer.jaximus import ImplicitArray
from eformer.pytree import auto_pytree, field
from jax import numpy as jnp
from jax.sharding import PartitionSpec
from jaxtyping import Array, Float, Int
from .._abstracts import BaseCache, BaseCacheMetadata, BaseCacheView, BaseRunTimeMetadata
[docs]@auto_pytree
class Mamba2CacheView(BaseCacheView):
"""Single-layer cache view for Mamba2 state-space model.
Manages both convolutional and SSM states for one Mamba2 layer,
with additional tracking for sequence positions and offsets.
The multi-head organization allows for richer state representations.
Attributes:
conv_states (cx.Array | ImplicitArray): Convolutional states buffer.
Shape: [batch, intermediate_size + 2*n_groups*state_size, kernel_size]
ssm_states (cx.Array | ImplicitArray): State-space model states.
Shape: [batch, num_heads, head_dim, state_size]
positions (cx.Array): Current position per sequence.
Shape: [batch_size]
seqlen_offset (int): Global sequence offset for continuation.
metadata (Mamba2CacheMetaData): Static configuration.
layer_index (int | None): Layer index in model.
"""
conv_states: Float[Array, "batch extended_size conv_kernel_size"] | ImplicitArray
ssm_states: Float[Array, "batch num_heads head_dim state_size"] | ImplicitArray
positions: Int[Array, "batch"] # noqa: F821
seqlen_offset: int = field(pytree_node=False)
metadata: Mamba2CacheMetaData
layer_index: int | None = field(pytree_node=False, default=None)
[docs] @classmethod
def init(
cls,
metadata: Mamba2CacheMetaData,
partition_specs: PartitionSpec,
dtype: jnp.dtype,
layer_index: int | None = None,
) -> Mamba2CacheView:
"""Initialize a Mamba2 cache view with zero states.
Creates cache tensors for the extended Mamba2 architecture,
including multi-head SSM states and expanded convolutional
buffers that incorporate group normalization dimensions.
Args:
metadata (Mamba2CacheMetaData): Configuration for cache dimensions.
partition_specs (PartitionSpec): Sharding specification for
distributed execution.
dtype (jnp.dtype): Data type for state tensors.
layer_index (int | None): Optional layer index in model.
Returns:
Mamba2CacheView: Initialized cache view with allocated tensors.
Note:
Conv state size is extended by 2*n_groups*state_size to
accommodate group normalization features.
"""
return cls(
conv_states=with_sharding_constraint(
arr=jnp.zeros(
shape=(
metadata.batch_size,
metadata.intermediate_size + 2 * metadata.n_groups * metadata.state_size,
metadata.conv_kernel_size,
),
dtype=dtype,
),
sharding=partition_specs,
),
ssm_states=with_sharding_constraint(
arr=jnp.zeros(
shape=(
metadata.batch_size,
metadata.num_heads,
metadata.head_dim,
metadata.state_size,
),
dtype=dtype,
),
sharding=partition_specs,
),
positions=jnp.zeros((metadata.batch_size,), "i4"),
metadata=metadata,
layer_index=layer_index,
seqlen_offset=0,
)
[docs] def concatenate_to_cache(self, *args, **kwargs) -> tuple:
"""Not implemented for Mamba2 cache.
Mamba2 uses separate update methods for conv and SSM states.
Raises:
NotImplementedError: Always raised.
Note:
Use `update_conv_state()` and `update_ssm_state()` instead.
"""
raise NotImplementedError()
[docs] def update_conv_state(
self,
new_conv_state: Float[Array, "batch extended_size"],
cache_position: Int[Array, "..."],
) -> Mamba2CacheView:
"""Update convolutional state with new values.
Maintains a rolling buffer of convolutional states, with
support for the extended state size used in Mamba2.
Args:
new_conv_state (cx.Array): New conv state to insert.
Shape: [batch, intermediate_size + 2*n_groups*state_size]
cache_position (cx.Array): Position index for insertion.
Returns:
Mamba2CacheView: Updated view with new conv state.
"""
cache_position = jnp.clip(cache_position, 0, self.metadata.conv_kernel_size - 1)
conv_state = jnp.roll(self.conv_states, shift=-1, axis=-1)
updated_conv_states = conv_state.at[:, :, cache_position].set(new_conv_state)
self.conv_states = updated_conv_states
return self
[docs] def update_ssm_state(
self,
new_ssm_state: Float[Array, "batch num_heads head_dim state_size"],
) -> Mamba2CacheView:
"""Update SSM state with head-structured representation.
Replaces the multi-head SSM state with new values,
maintaining the head-based organization.
Args:
new_ssm_state (cx.Array): New SSM state.
Shape: [batch, num_heads, head_dim, state_size]
Returns:
Mamba2CacheView: Updated view with new SSM state.
"""
self.ssm_states = new_ssm_state
return self
[docs] def reset(self) -> Mamba2CacheView:
"""Reset all cache states to zeros.
Clears both convolutional and SSM states while preserving
structure and shape. Position tracking is maintained.
Returns:
Mamba2CacheView: Reset view with zeroed states.
"""
self.conv_states = jnp.zeros_like(self.conv_states)
self.ssm_states = jnp.zeros_like(self.ssm_states)
return self
[docs]@auto_pytree
class Mamba2Cache(BaseCache):
"""Multi-layer Mamba2 cache container.
Orchestrates Mamba2 cache views across all model layers,
with additional support for sequence position tracking
and batch sequence updates.
Attributes:
views (list[Mamba2CacheView | None]): Per-layer cache views.
"""
views: list[Mamba2CacheView | None]
[docs] @classmethod
def init_cache(
cls,
num_hidden_layers: int,
metadata: Mamba2CacheMetaData,
dtype: jnp.dtype | None = None,
partition_specs: PartitionSpec | None = None,
) -> Mamba2Cache:
"""Initialize a complete Mamba2 cache for all layers.
Creates a fully initialized cache with allocated storage for
the specified number of layers. Each layer gets independent
multi-head SSM states and convolutional buffers.
Args:
num_hidden_layers (int): Number of Mamba2 layers to initialize.
metadata (Mamba2CacheMetaData): Configuration for cache dimensions.
dtype (jnp.dtype | None): Data type for tensors. Defaults to bfloat16.
partition_specs (PartitionSpec | None): Sharding specification.
If None, creates default spec with batch, head, and sequence axes.
Returns:
Mamba2Cache: Fully initialized multi-layer cache.
Example:
>>> cache = Mamba2Cache.init_cache(
... num_hidden_layers=32,
... metadata=metadata,
... dtype=jnp.float32
... )
"""
paxis = PartitionAxis()
partition_specs = partition_specs or PartitionSpec(
paxis.batch_axis,
paxis.head_axis,
paxis.sequence_axis,
)
if dtype is None:
dtype = jnp.bfloat16
return cls(
views=[
Mamba2CacheView.init(
metadata=metadata,
partition_specs=partition_specs,
dtype=dtype,
layer_index=layer_index,
)
for layer_index in range(num_hidden_layers)
]
)
[docs] def update_conv_state(
self,
layer_idx: int,
new_conv_state: Float[Array, "batch extended_size"],
cache_position: Int[Array, "..."],
) -> Mamba2Cache:
"""
Update the convolutional state for a specific layer.
Arguments:
layer_idx: Index of the layer to update
new_conv_state: New state to be inserted
cache_position: Position in the cache to update
Returns:
Updated MambaCache
"""
if self.views[layer_idx] is None:
raise ValueError(f"Cache view for layer {layer_idx} is None")
updated_view = self.views[layer_idx].update_conv_state(
new_conv_state=new_conv_state,
cache_position=cache_position,
)
new_views = list(self.views)
new_views[layer_idx] = updated_view
return self.replace(views=new_views)
[docs] def update_ssm_state(
self,
layer_idx: int,
new_ssm_state: Float[Array, "batch num_heads head_dim state_size"],
) -> Mamba2Cache:
"""
Update the SSM state for a specific layer.
Arguments:
layer_idx: Index of the layer to update
new_ssm_state: New SSM state to replace the current one
Returns:
Updated MambaCache
"""
if self.views[layer_idx] is None:
raise ValueError(f"Cache view for layer {layer_idx} is None")
updated_view = self.views[layer_idx].update_ssm_state(
new_ssm_state=new_ssm_state,
)
new_views = list(self.views)
new_views[layer_idx] = updated_view
return self.replace(views=new_views)
[docs] def reset(self) -> Mamba2Cache:
"""
Reset all cache views to their initial state.
Returns:
Reset MambaCache
"""
new_views = [view.reset() if view is not None else None for view in self.views]
return self.replace(views=new_views)
[docs] @classmethod
def init_empty(cls, num_hidden_layers: int) -> Mamba2Cache:
"""Initialize an empty Mamba2 cache structure.
Creates a cache with None placeholders for gradual initialization.
Args:
num_hidden_layers (int): Number of layer placeholders.
Returns:
Mamba2Cache: Cache with uninitialized views.
Example:
>>> cache = Mamba2Cache.init_empty(32)
>>> # Initialize layers individually later
"""
return cls(views=[None for _ in range(num_hidden_layers)])
[docs] def update_seq(self, num: int) -> None:
"""Update sequence positions across all layers.
Increments position tracking for sequence continuation,
useful when processing long sequences in chunks.
Args:
num (int): Number of positions to advance.
Note:
Updates both positions and seqlen_offset for each layer.
"""
for view in self.views:
if view is not None:
view.positions += num
view.seqlen_offset += num
def __repr__(self) -> str:
return f"{self.__class__.__name__}(\n " + "\n ".join(str(view) for view in self.views) + "\n)"
__str__ = __repr__