easydel.layers.caching.mamba2.cache#
Mamba2 enhanced state-space model cache implementation.
This module provides caching for Mamba2 models, which extend the original Mamba architecture with additional features, including:
- Multi-head state space models
- Group normalization capabilities
- Enhanced convolutional processing
- Improved sequence modeling
Mamba2 introduces structured state spaces with head-based organization, similar to multi-head attention but for state-space models.
- Key Components:
Mamba2CacheMetaData: Enhanced configuration with head/group support
Mamba2CacheView: Per-layer state storage with sequence tracking
Mamba2Cache: Multi-layer orchestration with sequence updates
Mamba2Metadata: Runtime metadata (placeholder)
- Features:
Head-based SSM organization
Group normalization support
Sequence position tracking
Extended convolutional states
Example
>>> metadata = Mamba2CacheMetaData.create(
... partition_axis=partition_axis,
... num_hidden_layers=32,
... batch_size=2,
... intermediate_size=2816,
... num_heads=16,
... head_dim=64,
... state_size=128,
... conv_kernel_size=4,
... n_groups=8
... )
>>> cache = Mamba2Cache.init_cache(
... num_hidden_layers=32,
... metadata=metadata,
... dtype=jnp.float32
... )
- class easydel.layers.caching.mamba2.cache.Mamba2Cache(views: list[easydel.layers.caching.mamba2.cache.Mamba2CacheView | None])[source]#
Bases: BaseCache
Multi-layer Mamba2 cache container.
Orchestrates Mamba2 cache views across all model layers, with additional support for sequence position tracking and batch sequence updates.
- views#
Per-layer cache views.
- Type
list[Mamba2CacheView | None]
- classmethod from_dict(data: dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- classmethod init_cache(num_hidden_layers: int, metadata: Mamba2CacheMetaData, dtype: numpy.dtype | None = None, partition_specs: jax.sharding.PartitionSpec | None = None) Mamba2Cache[source]#
Initialize a complete Mamba2 cache for all layers.
Creates a fully initialized cache with allocated storage for the specified number of layers. Each layer gets independent multi-head SSM states and convolutional buffers.
- Parameters
num_hidden_layers (int) – Number of Mamba2 layers to initialize.
metadata (Mamba2CacheMetaData) – Configuration for cache dimensions.
dtype (jnp.dtype | None) – Data type for tensors. Defaults to bfloat16.
partition_specs (PartitionSpec | None) – Sharding specification. If None, creates default spec with batch, head, and sequence axes.
- Returns
Fully initialized multi-layer cache.
- Return type
Mamba2Cache
Example
>>> cache = Mamba2Cache.init_cache(
... num_hidden_layers=32,
... metadata=metadata,
... dtype=jnp.float32
... )
- classmethod init_empty(num_hidden_layers: int) Mamba2Cache[source]#
Initialize an empty Mamba2 cache structure.
Creates a cache with None placeholders for gradual initialization.
- Parameters
num_hidden_layers (int) – Number of layer placeholders.
- Returns
Cache with uninitialized views.
- Return type
Mamba2Cache
Example
>>> cache = Mamba2Cache.init_empty(32)
>>> # Initialize layers individually later
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- reset() Mamba2Cache[source]#
Reset all cache views to their initial state.
- Returns
Reset Mamba2Cache
- to_dict() dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- update_conv_state(layer_idx: int, new_conv_state: Float[Array, 'batch extended_size'], cache_position: Int[Array, '...']) Mamba2Cache[source]#
Update the convolutional state for a specific layer.
- Parameters
layer_idx – Index of the layer to update
new_conv_state – New state to be inserted
cache_position – Position in the cache to update
- Returns
Updated Mamba2Cache
- update_seq(num: int) None[source]#
Update sequence positions across all layers.
Increments position tracking for sequence continuation, useful when processing long sequences in chunks.
- Parameters
num (int) – Number of positions to advance.
Note
Updates both positions and seqlen_offset for each layer.
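The bookkeeping described for update_seq can be pictured with a small pure-Python sketch. `SeqView` and `advance` are hypothetical stand-ins, not EasyDeL classes; plain lists replace the JAX arrays, and skipping `None` placeholders is an assumption rather than documented behavior.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class SeqView:
    """Hypothetical stand-in for the position fields of a per-layer view."""

    positions: List[int]    # one position per sequence in the batch
    seqlen_offset: int = 0  # global offset used for continuation


def advance(views: List[Optional[SeqView]], num: int) -> None:
    """Advance every initialized view by `num` positions, in place."""
    for view in views:
        if view is None:  # assumption: uninitialized layers are skipped
            continue
        view.positions = [p + num for p in view.positions]
        view.seqlen_offset += num


views = [
    SeqView(positions=[0, 0]),
    None,
    SeqView(positions=[5, 7], seqlen_offset=5),
]
advance(views, 3)
print(views[0])  # SeqView(positions=[3, 3], seqlen_offset=3)
print(views[2])  # SeqView(positions=[8, 10], seqlen_offset=8)
```

Like the documented method, the sketch returns None and changes the views it is given, which is why it suits chunked processing: call it once after each chunk with the chunk length.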
- update_ssm_state(layer_idx: int, new_ssm_state: Float[Array, 'batch num_heads head_dim state_size']) Mamba2Cache[source]#
Update the SSM state for a specific layer.
- Parameters
layer_idx – Index of the layer to update
new_ssm_state – New SSM state to replace the current one
- Returns
Updated Mamba2Cache
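Both update_conv_state and update_ssm_state are documented as returning a cache, so callers presumably need to keep the returned object rather than rely on in-place mutation. The sketch below illustrates that calling pattern with hypothetical frozen dataclasses (`SsmView`, `SsmCache`); tuples of floats stand in for the real state tensors, and nothing here is taken from the EasyDeL implementation.

```python
from dataclasses import dataclass, replace
from typing import Optional, Sequence, Tuple


@dataclass(frozen=True)
class SsmView:
    """Hypothetical per-layer view holding only an SSM state."""

    ssm_state: Tuple[float, ...]


@dataclass(frozen=True)
class SsmCache:
    """Hypothetical multi-layer container with functional updates."""

    views: Tuple[Optional[SsmView], ...]

    def update_ssm_state(self, layer_idx: int, new_ssm_state: Sequence[float]) -> "SsmCache":
        view = self.views[layer_idx]
        if view is None:
            raise ValueError(f"layer {layer_idx} has no initialized view")
        # Build a new view and a new container; the originals stay untouched.
        new_view = replace(view, ssm_state=tuple(new_ssm_state))
        views = self.views[:layer_idx] + (new_view,) + self.views[layer_idx + 1:]
        return replace(self, views=views)


cache = SsmCache(views=(SsmView((0.0, 0.0)), SsmView((0.0, 0.0))))
updated = cache.update_ssm_state(1, [0.5, 0.25])
print(updated.views[1].ssm_state)  # (0.5, 0.25)
print(cache.views[1].ssm_state)    # (0.0, 0.0)
```

In a decoding loop this means rebinding on every step, for example `cache = cache.update_ssm_state(layer_idx, new_state)`.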
- views: list[easydel.layers.caching.mamba2.cache.Mamba2CacheView | None]#
- class easydel.layers.caching.mamba2.cache.Mamba2CacheMetaData(partition_axis: PartitionAxis, num_hidden_layers: int, batch_size: int, intermediate_size: int, num_heads: int, head_dim: int, state_size: int, conv_kernel_size: int, n_groups: int)[source]#
Bases: BaseCacheMetadata
Metadata configuration for Mamba2 state-space cache.
Extends the original Mamba cache with support for multi-head state-space models and group normalization. The head-based organization allows for more expressive state representations.
- partition_axis#
Tensor partitioning configuration.
- Type
PartitionAxis
- num_hidden_layers#
Number of Mamba2 layers in the model.
- Type
int
- batch_size#
Number of sequences in batch.
- Type
int
- intermediate_size#
Hidden dimension of MLP layers.
- Type
int
- num_heads#
Number of SSM heads (similar to attention heads).
- Type
int
- head_dim#
Dimension per SSM head.
- Type
int
- state_size#
Size of the state-space representation.
- Type
int
- conv_kernel_size#
Size of convolutional kernel.
- Type
int
- n_groups#
Number of groups for group operations.
- Type
int
- batch_size: int#
- conv_kernel_size: int#
- classmethod create(parition_axis: PartitionAxis, num_hidden_layers: int, batch_size: int, intermediate_size: int, num_heads: int, head_dim: int, state_size: int, conv_kernel_size: int, n_groups: int) Mamba2CacheMetaData[source]#
Create and validate Mamba2 cache metadata.
Factory method that validates all parameters before creating the metadata instance. Ensures all dimensions are positive and compatible.
- Parameters
partition_axis (PartitionAxis) – Sharding configuration.
num_hidden_layers (int) – Number of model layers.
batch_size (int) – Batch size for cache allocation.
intermediate_size (int) – MLP hidden dimension.
num_heads (int) – Number of SSM heads.
head_dim (int) – Dimension per head.
state_size (int) – State-space size per head.
conv_kernel_size (int) – Convolution kernel size.
n_groups (int) – Number of normalization groups.
- Returns
Validated metadata instance.
- Return type
Mamba2CacheMetaData
- Raises
ValueError – If any parameter is non-positive.
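A validating factory of this kind can be sketched as follows. `SketchMetaData` is a hypothetical stand-in that covers only the integer dimensions (the partition axis is omitted), and the error message format is an assumption; the sketch shows the pattern, not the library's code.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SketchMetaData:
    """Hypothetical stand-in for the integer fields of the metadata."""

    num_hidden_layers: int
    batch_size: int
    intermediate_size: int
    num_heads: int
    head_dim: int
    state_size: int
    conv_kernel_size: int
    n_groups: int

    @classmethod
    def create(cls, **dims: int) -> "SketchMetaData":
        # Reject any non-positive dimension before building the instance.
        for name, value in dims.items():
            if value <= 0:
                raise ValueError(f"{name} must be positive, got {value}")
        return cls(**dims)


DIMS = dict(
    num_hidden_layers=32,
    batch_size=2,
    intermediate_size=2816,
    num_heads=16,
    head_dim=64,
    state_size=128,
    conv_kernel_size=4,
    n_groups=8,
)
meta = SketchMetaData.create(**DIMS)
print(meta.n_groups)  # 8
```

Routing construction through the factory keeps invalid dimensions from reaching tensor allocation, where the resulting error would be harder to trace back to a single argument.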
- classmethod from_dict(data: dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- head_dim: int#
- intermediate_size: int#
- n_groups: int#
- num_heads: int#
- num_hidden_layers: int#
- partition_axis: PartitionAxis#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- state_size: int#
- to_dict() dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- class easydel.layers.caching.mamba2.cache.Mamba2CacheView(conv_states: jaxtyping.Float[Array, 'batch extended_size conv_kernel_size'] | eformer.jaximus._imus.ImplicitArray, ssm_states: jaxtyping.Float[Array, 'batch num_heads head_dim state_size'] | eformer.jaximus._imus.ImplicitArray, positions: Int[Array, 'batch'], seqlen_offset: int, metadata: Mamba2CacheMetaData, layer_index: int | None = None)[source]#
Bases: BaseCacheView
Single-layer cache view for Mamba2 state-space model.
Manages both convolutional and SSM states for one Mamba2 layer, with additional tracking for sequence positions and offsets. The multi-head organization allows for richer state representations.
- conv_states#
Convolutional states buffer. Shape: [batch, intermediate_size + 2*n_groups*state_size, kernel_size]
- Type
cx.Array | ImplicitArray
- ssm_states#
State-space model states. Shape: [batch, num_heads, head_dim, state_size]
- Type
cx.Array | ImplicitArray
- positions#
Current position per sequence. Shape: [batch_size]
- Type
cx.Array
- seqlen_offset#
Global sequence offset for continuation.
- Type
int
- metadata#
Static configuration.
- Type
Mamba2CacheMetaData
- layer_index#
Layer index in model.
- Type
int | None
- concatenate_to_cache(*args, **kwargs) tuple[source]#
Not implemented for Mamba2 cache.
Mamba2 uses separate update methods for conv and SSM states.
- Raises
NotImplementedError – Always raised.
Note
Use update_conv_state() and update_ssm_state() instead.
- conv_states: jaxtyping.Float[Array, 'batch extended_size conv_kernel_size'] | eformer.jaximus._imus.ImplicitArray#
- classmethod from_dict(data: dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- classmethod init(metadata: Mamba2CacheMetaData, partition_specs: PartitionSpec, dtype: dtype, layer_index: int | None = None) Mamba2CacheView[source]#
Initialize a Mamba2 cache view with zero states.
Creates cache tensors for the extended Mamba2 architecture, including multi-head SSM states and expanded convolutional buffers that incorporate group normalization dimensions.
- Parameters
metadata (Mamba2CacheMetaData) – Configuration for cache dimensions.
partition_specs (PartitionSpec) – Sharding specification for distributed execution.
dtype (jnp.dtype) – Data type for state tensors.
layer_index (int | None) – Optional layer index in model.
- Returns
Initialized cache view with allocated tensors.
- Return type
Mamba2CacheView
Note
Conv state size is extended by 2*n_groups*state_size to accommodate group normalization features.
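As a worked example of the shapes documented for conv_states and ssm_states, the helper below computes them from the metadata dimensions. `mamba2_cache_shapes` is a hypothetical function written for this illustration; it only does the arithmetic and allocates nothing.

```python
def mamba2_cache_shapes(
    batch_size: int,
    intermediate_size: int,
    num_heads: int,
    head_dim: int,
    state_size: int,
    conv_kernel_size: int,
    n_groups: int,
):
    """Return (conv_states shape, ssm_states shape) as documented above."""
    # The conv buffer is widened by 2 * n_groups * state_size channels.
    conv_dim = intermediate_size + 2 * n_groups * state_size
    conv_shape = (batch_size, conv_dim, conv_kernel_size)
    ssm_shape = (batch_size, num_heads, head_dim, state_size)
    return conv_shape, ssm_shape


conv_shape, ssm_shape = mamba2_cache_shapes(
    batch_size=2,
    intermediate_size=2816,
    num_heads=16,
    head_dim=64,
    state_size=128,
    conv_kernel_size=4,
    n_groups=8,
)
print(conv_shape)  # (2, 4864, 4)
print(ssm_shape)   # (2, 16, 64, 128)
```

With the values from the module-level example, the extension adds 2 * 8 * 128 = 2048 channels to the 2816 intermediate channels, giving 4864.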
- metadata: Mamba2CacheMetaData#
- positions: Int[Array, 'batch']#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- reset() Mamba2CacheView[source]#
Reset all cache states to zeros.
Clears both convolutional and SSM states while preserving structure and shape. Position tracking is maintained.
- Returns
Reset view with zeroed states.
- Return type
Mamba2CacheView
- seqlen_offset: int#
- ssm_states: jaxtyping.Float[Array, 'batch num_heads head_dim state_size'] | eformer.jaximus._imus.ImplicitArray#
- to_dict() dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- update_conv_state(new_conv_state: Float[Array, 'batch extended_size'], cache_position: Int[Array, '...']) Mamba2CacheView[source]#
Update convolutional state with new values.
Maintains a rolling buffer of convolutional states, with support for the extended state size used in Mamba2.
- Parameters
new_conv_state (cx.Array) – New conv state to insert. Shape: [batch, intermediate_size + 2*n_groups*state_size]
cache_position (cx.Array) – Position index for insertion.
- Returns
Updated view with new conv state.
- Return type
Mamba2CacheView
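One common way to maintain such a rolling buffer is to shift every channel left by one step and write the newest value into the last slot. The sketch below shows that idea for a single sequence using nested lists; `roll_and_insert` is a hypothetical helper, it ignores the batch axis and the `cache_position` argument, and the actual EasyDeL update may differ.

```python
from typing import List


def roll_and_insert(conv_state: List[List[float]], new_column: List[float]) -> List[List[float]]:
    """Shift each channel left by one step and append its newest value.

    conv_state: one list per channel, each holding kernel_size values.
    new_column: one new value per channel.
    """
    if len(conv_state) != len(new_column):
        raise ValueError("new_column needs one value per channel")
    # Drop the oldest entry of each channel and append the new one.
    return [row[1:] + [value] for row, value in zip(conv_state, new_column)]


state = [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]  # 2 channels, kernel size 4
state = roll_and_insert(state, [1.0, 10.0])
state = roll_and_insert(state, [2.0, 20.0])
print(state)  # [[0.0, 0.0, 1.0, 2.0], [0.0, 0.0, 10.0, 20.0]]
```

The buffer always holds the most recent kernel_size inputs per channel, which is what a causal convolution needs when decoding one token at a time.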
- update_ssm_state(new_ssm_state: Float[Array, 'batch num_heads head_dim state_size']) Mamba2CacheView[source]#
Update SSM state with head-structured representation.
Replaces the multi-head SSM state with new values, maintaining the head-based organization.
- Parameters
new_ssm_state (cx.Array) – New SSM state. Shape: [batch, num_heads, head_dim, state_size]
- Returns
Updated view with new SSM state.
- Return type
Mamba2CacheView
- class easydel.layers.caching.mamba2.cache.Mamba2Metadata[source]#
Bases: BaseRunTimeMetadata
Runtime metadata for Mamba2 cache operations.
Placeholder for future Mamba2-specific runtime state. May include head masks, group indices, etc.