# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import itertools
import typing as tp
import chex
import jax
import jax.numpy as jnp
from eformer.pytree import auto_pytree
from einops import repeat
from ejkernel.types import MaskInfo
from flax import nnx as nn
from jax import lax
from jax.ad_checkpoint import checkpoint_name
from easydel.infra.base_module import EasyDeLBaseModule
from easydel.infra.factory import TaskType, register_module
from easydel.infra.modeling_outputs import BaseModelOutput
from easydel.infra.utils import ACT2FN, ArrayParam, auto_remat, get_dot_general_by_bits
from easydel.layers.caching import MambaCache, MambaCacheMetaData, MambaCacheView
from easydel.layers.linear import ColumnParallelLinear
from easydel.layers.norms import RMSNorm as MambaRMSNorm
from .mamba_configuration import MambaConfig as MambaConfig
def init_to_value(x, dtype):
    """Return an initializer that fills a parameter with the constant ``x``.

    Args:
        x: Fill value (scalar or anything broadcastable to the requested shape).
        dtype: Fallback dtype, used only when the returned initializer is called
            without a dtype (or with ``dtype=None``). Previously this argument was
            shadowed by the initializer's own ``dtype`` parameter and ignored.

    Returns:
        A callable ``init(key, shape, dtype=None)`` that ignores ``key`` and returns
        ``x`` broadcast to ``shape``. An explicit ``dtype`` passed to the initializer
        takes precedence over the fallback.
    """
    fallback_dtype = dtype

    def init(_, shape, dtype=None):
        target_dtype = fallback_dtype if dtype is None else dtype
        return jnp.broadcast_to(jnp.asarray(x, dtype=target_dtype), shape)

    return init
@auto_pytree
class MambaOutput(BaseModelOutput):
    """Output container for the base Mamba model, including its cache.

    Attributes:
        last_hidden_state: Final hidden states produced by the model.
        cache: The ``MambaCache`` after the forward pass, or ``None``.
        hidden_states: Variable-length tuple of intermediate hidden states when
            they were requested, otherwise ``None``.
    """

    last_hidden_state: chex.Array | None = None
    cache: MambaCache | None = None
    hidden_states: tuple[chex.Array, ...] | None = None
@auto_pytree
class MambaCausalLMOutput(BaseModelOutput):
    """Causal-LM output for Mamba decoding: logits plus the updated cache.

    Attributes:
        logits: Language-model logits, or ``None`` when the LM head was skipped.
        cache: The ``MambaCache`` after the forward pass, or ``None``.
        hidden_states: Variable-length tuple of intermediate hidden states when
            they were requested, otherwise ``None``.
        last_hidden_state: Final hidden states from the backbone.
    """

    logits: chex.Array | None = None
    cache: MambaCache | None = None
    hidden_states: tuple[chex.Array, ...] | None = None
    last_hidden_state: chex.Array | None = None
_T = tp.TypeVar("_T")
[docs]def create_tuple_parser(
n: int,
) -> tp.Callable[[_T | tp.Sequence[_T]], tuple[_T, ...]]:
"""Normalize a scalar or sequence into a tuple of length ``n``."""
def parse(x: _T | tp.Sequence[_T]) -> tuple[_T, ...]:
if isinstance(x, tp.Sequence):
if len(x) == n:
return tuple(x)
else:
raise ValueError(f"x!=n ({x}!=({n}))")
else:
return tuple(itertools.repeat(x, n))
return parse
class Lambda(nn.Module):
    """Wrap an arbitrary callable so it can be inserted into module pipelines.

    Calling the module forwards ``x`` and any keyword arguments to ``fn``.
    """

    fn: tp.Callable

    def __init__(self, fn: tp.Callable):
        # nnx.Module is a plain Python class (not a dataclass), so the ``fn``
        # annotation alone generates no constructor; without this ``Lambda(fn)``
        # cannot be instantiated.
        self.fn = fn

    def __call__(self, x, **kwargs):
        return self.fn(x, **kwargs)
class MambaConv1D(nn.Module):
    """Minimal depthwise 1D convolution backing the Mamba mixer.

    The kernel is stored as ``(kernel_size, 1, features)`` and transposed at call
    time to ``(features, 1, kernel_size)``. That is the ``OIW`` layout
    ``lax.conv_general_dilated`` uses by default. Because the per-group input
    dimension is fixed to 1, the number of input channels must equal ``groups``
    (depthwise use). Only one spatial dimension is supported: stride, padding and
    dilation are given as 1-element tuples.
    """

    def __init__(
        self,
        features: int,
        kernel_size: int = 1,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        use_bias: bool = True,
        num_spatial_dims: int = 1,
        dtype: jnp.dtype = jnp.bfloat16,
        param_dtype: jnp.dtype = jnp.bfloat16,
        precision: str | lax.Precision | None = None,
        *,
        rngs: nn.Rngs,
    ):
        """Create the kernel (lecun-normal) and, when ``use_bias``, a zero bias.

        Args:
            features: Number of output channels.
            kernel_size: Width of the convolution window.
            stride: Window stride.
            padding: Zero padding applied to both ends of the spatial axis.
            dilation: Kernel (rhs) dilation.
            groups: ``feature_group_count`` passed to the convolution.
            use_bias: Whether to add a per-channel bias.
            num_spatial_dims: Number of spatial dims; used only for the rank check.
            dtype: Computation dtype.
            param_dtype: Parameter storage dtype.
            precision: Convolution precision.
            rngs: RNG streams for initialization.
        """
        kernel_shape = (kernel_size, 1, features)
        self.kernel = ArrayParam.bound(
            shape=kernel_shape,
            dtype=param_dtype,
            init_method="lecun_normal",
            key=rngs.params(),
        )
        if use_bias:
            self.bias = ArrayParam.bound(
                shape=(features,),
                dtype=param_dtype,
                init_method="zeros",
                key=rngs.params(),
            )
        self.features = features
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.use_bias = use_bias
        self.num_spatial_dims = num_spatial_dims
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.precision = precision

    def __call__(self, x):
        """Apply the convolution.

        Args:
            x: Input of rank ``num_spatial_dims + 2`` in ``lax.conv_general_dilated``'s
                default channel-first layout, i.e. ``(batch, channels, length)``.

        Returns:
            The convolved array, cast back to ``x``'s original dtype.

        Raises:
            ValueError: If ``x`` does not have rank ``num_spatial_dims + 2``.
        """
        # Rank including the batch and channel axes.
        expected_rank = self.num_spatial_dims + 2
        if x.ndim != expected_rank:
            raise ValueError(
                f"Input to `Conv` needs to have rank {expected_rank}, but input has shape {x.shape}.",
            )
        org_x_dtype = x.dtype
        kernel = jnp.asarray(jnp.swapaxes(self.kernel.value, 0, 2), dtype=self.dtype)
        x = lax.conv_general_dilated(
            lhs=x.astype(self.dtype),
            rhs=kernel,
            window_strides=(self.stride,),
            padding=((self.padding, self.padding),),
            rhs_dilation=(self.dilation,),
            feature_group_count=self.groups,
            # Previously stored but never forwarded to the convolution.
            precision=self.precision,
        )
        if self.use_bias:
            # Bias broadcasts over the channel axis (axis 1).
            x = x + jnp.asarray(self.bias.value.reshape(1, -1, 1), dtype=self.dtype)
        return x.astype(org_x_dtype)
class MambaMixer(nn.Module):
    """Core selective state-space mixer used inside each Mamba block.

    Pipeline: gated input projection -> causal depthwise convolution ->
    input-dependent ``dt``/``B``/``C`` selection -> discretized linear recurrence
    (computed in float32) -> gating and output projection.
    """

    def __init__(
        self,
        config: MambaConfig,
        layer_idx: int,
        dtype: jnp.dtype = jnp.bfloat16,
        param_dtype: jnp.dtype = jnp.bfloat16,
        precision: str | lax.Precision | None = None,
        *,
        rngs: nn.Rngs,
    ) -> None:
        """Build the projections, convolution and SSM parameters.

        Args:
            config: Mamba configuration (sizes, init scheme, activation).
            layer_idx: Index of this layer in the stack.
            dtype: Computation dtype.
            param_dtype: Parameter storage dtype.
            precision: Matmul/conv precision.
            rngs: RNG streams for initialization.
        """
        self.config = config
        self.layer_idx = layer_idx
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.precision = precision
        hidden_size = config.hidden_size
        ssm_state_size = config.state_size
        intermediate_size = config.intermediate_size
        time_step_rank = config.time_step_rank
        conv_kernel_size = config.conv_kernel

        # Depthwise conv with (k - 1) padding on both sides. Callers keep the first
        # seq_len outputs, so output t only sees inputs t-k+1..t (causal).
        self.conv1d = MambaConv1D(
            features=intermediate_size,
            use_bias=config.use_conv_bias,
            kernel_size=conv_kernel_size,
            groups=intermediate_size,
            padding=conv_kernel_size - 1,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        self.activation = config.hidden_act
        self.act = ACT2FN[config.hidden_act]

        dt_init_std = time_step_rank**-0.5 * config.time_step_scale
        if config.time_step_init_scheme == "constant":
            init_kernel_dt = nn.initializers.constant(dt_init_std, dtype=param_dtype)
        elif config.time_step_init_scheme == "random":

            def init_kernel_dt(key, _shape, _dtype):
                # Uniform in [0, 2 * std) shifted to [-std, std).
                return (
                    jax.nn.initializers.uniform(scale=dt_init_std * 2, dtype=param_dtype)(key, _shape, _dtype)
                    - dt_init_std
                )

        else:
            init_kernel_dt = nn.initializers.normal(config.initializer_range, param_dtype)

        # Sample dt log-uniformly in [time_step_min, time_step_max). This must be a
        # *uniform* draw; the previous normal draw landed far outside that range
        # and was only clipped afterwards.
        log_dt_min = jnp.log(config.time_step_min)
        log_dt_max = jnp.log(config.time_step_max)
        dt = jnp.exp(
            jax.random.uniform(
                key=rngs.params(),
                shape=(intermediate_size,),
                dtype=jnp.float32,
            )
            * (log_dt_max - log_dt_min)
            + log_dt_min
        )
        dt = jax.lax.clamp(config.time_step_floor, dt, config.time_step_max)
        # Inverse of softplus, so softplus(dt_proj.bias) == dt at initialization.
        inv_dt = dt + jnp.log(-jnp.expm1(-dt))

        linear_class = functools.partial(
            ColumnParallelLinear,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            **get_dot_general_by_bits(config.bits, config.easy_method),
        )
        self.in_proj = linear_class(
            hidden_size,
            intermediate_size * 2,
            use_bias=config.use_bias,
            rngs=rngs,
        )
        self.x_proj = linear_class(
            intermediate_size,
            time_step_rank + ssm_state_size * 2,
            use_bias=False,
            rngs=rngs,
        )
        self.dt_proj = linear_class(
            time_step_rank,
            intermediate_size,
            use_bias=True,
            kernel_init=init_kernel_dt,
            # Honor the dtype requested by the layer instead of always returning float32.
            bias_init=lambda _, shape, dtype: jnp.asarray(inv_dt, dtype=dtype),
            rngs=rngs,
        )
        self.out_proj = linear_class(
            intermediate_size,
            hidden_size,
            use_bias=config.use_bias,
            rngs=rngs,
        )

        # A = [1..state_size] for every channel. Built as float32 (jnp.arange would
        # give int32) so A_log is declared with a floating dtype matching its value.
        A = repeat(
            jnp.arange(1, ssm_state_size + 1, dtype=jnp.float32),
            "n -> d n",
            d=intermediate_size,
        )
        self.A_log = ArrayParam.bound(
            shape=A.shape,
            dtype=A.dtype,
            init_method="zeros",
            key=None,
            value=jnp.log(A),
        )
        self.D = ArrayParam.bound(
            shape=(intermediate_size,),
            dtype=param_dtype,
            init_method="ones",
            key=None,
        )
        self.ssm_state_size = ssm_state_size
        self.intermediate_size = intermediate_size
        self.conv_kernel_size = conv_kernel_size
        self.time_step_rank = time_step_rank

    def __call__(
        self,
        input_states: chex.Array,
        cache: MambaCacheView | None = None,
        position_ids: chex.Array | None = None,
        attention_mask: chex.Array | None = None,
    ):
        """Run the selective-scan mixer.

        Args:
            input_states: Hidden states shaped ``(batch, seq_len, hidden)``.
            cache: Optional per-layer cache view. When given, its SSM state seeds
                the recurrence, its conv state is updated, and its ``ssm_states``
                is overwritten with the final state.
            position_ids: Forwarded to ``cache.update_conv_state``; required when
                ``cache`` is given.
            attention_mask: Optional ``(batch, seq_len)`` mask. Masked positions are
                zeroed before and after the convolution.

        Returns:
            ``(contextualized_states, cache)``, where ``contextualized_states`` is
            shaped ``(batch, seq_len, hidden_size)``.
        """
        batch_size, seq_len, _ = input_states.shape
        dtype = input_states.dtype

        # 1. Gated projection -> [batch, 2 * intermediate_size, seq_len]
        projected_states = checkpoint_name(self.in_proj(input_states), name="ssm_input_proj")
        projected_states = jnp.swapaxes(projected_states, 2, 1)
        hidden_states, gate = jnp.split(projected_states, 2, axis=1)
        if attention_mask is not None:
            hidden_states = hidden_states * jnp.expand_dims(attention_mask, 1)

        # 2. Causal depthwise convolution -> [batch, intermediate_size, seq_len]
        if cache is not None:
            ssm_state = jnp.array(cache.ssm_states)
            # NOTE(review): this prefill test looks ported from a check on a 1-D
            # cache-position vector; confirm position_ids.shape[0] has that meaning here.
            if position_ids.shape[0] == self.conv_kernel_size:
                conv_state = jnp.pad(
                    hidden_states,
                    (
                        (0, 0),
                        (0, 0),
                        (self.conv_kernel_size - hidden_states.shape[-1], 0),
                    ),
                )
                cache.update_conv_state(conv_state, position_ids)
                hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])
            else:
                # Single-step decode: apply the kernel directly to the rolling conv
                # state (presumably (batch, intermediate_size, conv_kernel) — confirm).
                conv_state = cache.update_conv_state(hidden_states, position_ids)
                # The kernel is stored as (kernel_size, 1, features); use it as (features, kernel_size).
                conv_weight = jnp.swapaxes(self.conv1d.kernel.value[:, 0, :], 0, 1)
                hidden_states = jnp.sum(conv_state * conv_weight, axis=-1)
                if self.conv1d.use_bias:
                    hidden_states = hidden_states + self.conv1d.bias.value
                # [batch, intermediate_size, 1]
                hidden_states = jnp.expand_dims(self.act(hidden_states).astype(dtype), -1)
        else:
            ssm_state = jnp.zeros((batch_size, self.intermediate_size, self.ssm_state_size), dtype=dtype)
            hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])
        if attention_mask is not None:
            hidden_states = hidden_states * jnp.expand_dims(attention_mask, 1)

        # 3.a Selection: input-dependent time step, B and C.
        ssm_parameters = checkpoint_name(self.x_proj(jnp.swapaxes(hidden_states, 2, 1)), name="ssm_x_proj")
        time_step, B, C = jnp.split(
            ssm_parameters,
            [self.time_step_rank, self.time_step_rank + self.ssm_state_size],
            axis=-1,
        )
        discrete_time_step = checkpoint_name(self.dt_proj(time_step), name="ssm_dt_proj")
        # [batch, seq_len, intermediate_size] -> [batch, intermediate_size, seq_len]
        discrete_time_step = jnp.swapaxes(jax.nn.softplus(discrete_time_step), 2, 1)

        # 3.b Discretization, in float32.
        A = -jnp.exp(self.A_log.value.astype(jnp.float32))  # [intermediate_size, ssm_state_size]
        modified_a = A[None, :, None, :]
        modified_time_step = discrete_time_step[..., None]
        # [batch, intermediate_size, seq_len, ssm_state_size]
        discrete_A = jnp.exp(modified_a * modified_time_step)
        discrete_B = modified_time_step * B[:, None, :, :].astype(jnp.float32)
        deltaB_u = discrete_B * hidden_states[:, :, :, None].astype(jnp.float32)

        # 3.c Recurrence via lax.scan. A Python loop here would unroll seq_len
        # steps into the traced program.
        def _scan_step(state, step_inputs):
            dA_t, dBu_t, C_t = step_inputs
            state = dA_t * state + dBu_t
            # [batch, intermediate_size, state] @ [batch, state, 1] -> [batch, intermediate_size]
            y_t = jax.lax.batch_matmul(state.astype(dtype), C_t[..., None])[..., 0]
            return state, y_t

        ssm_state, scan_output = jax.lax.scan(
            _scan_step,
            # The carry must keep one dtype across steps; the update is float32.
            ssm_state.astype(jnp.float32),
            (
                jnp.moveaxis(discrete_A, 2, 0),
                jnp.moveaxis(deltaB_u, 2, 0),
                jnp.moveaxis(C, 1, 0),
            ),
        )
        # [seq_len, batch, intermediate_size] -> [batch, intermediate_size, seq_len]
        scan_output = jnp.moveaxis(scan_output, 0, -1)
        scan_output = scan_output + (hidden_states * self.D.value[None, :, None])
        scan_output = scan_output * self.act(gate)
        if cache is not None:
            cache.ssm_states = ssm_state

        # 4. Final linear projection -> [batch, seq_len, hidden_size]
        contextualized_states = checkpoint_name(
            self.out_proj(jnp.swapaxes(scan_output, 2, 1)),
            name="ssm_output_proj",
        )
        return contextualized_states, cache
class MambaBlock(nn.Module):
    """Single Mamba layer computing ``residual + mixer(norm(x))``.

    When ``config.residual_in_fp32`` is set, the residual is upcast to float32
    before the addition, so the block's output is float32 as well.
    """

    def __init__(
        self,
        config: MambaConfig,
        layer_idx: int,
        dtype: jnp.dtype = jnp.bfloat16,
        param_dtype: jnp.dtype = jnp.bfloat16,
        precision: str | lax.Precision | None = None,
        *,
        rngs: nn.Rngs,
    ):
        """Build the pre-norm and the (optionally rematerialized) mixer.

        Args:
            config: Mamba configuration.
            layer_idx: Index of this layer in the stack.
            dtype: Computation dtype.
            param_dtype: Parameter storage dtype.
            precision: Matmul precision forwarded to the mixer.
            rngs: RNG streams for initialization.
        """
        self.config = config
        self.layer_idx = layer_idx
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.precision = precision
        self.residual_in_fp32 = config.residual_in_fp32
        self.norm = MambaRMSNorm(
            config.hidden_size,
            eps=config.layer_norm_epsilon,
            dtype=dtype,
            param_dtype=param_dtype,
        )
        # Wrap the mixer class with the configured gradient-checkpointing policy.
        block = auto_remat(
            MambaMixer,
            policy=config.gradient_checkpointing,
            save_names=config.gradient_checkpointing_targets,
            exclude_names=config.gradient_checkpointing_targets,
        )
        self.mixer = block(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
            layer_idx=layer_idx,
        )

    def __call__(
        self,
        hidden_states: chex.Array,
        cache: MambaCacheView | None = None,
        position_ids: chex.Array | None = None,
        attention_mask: chex.Array | None = None,
    ) -> tuple[chex.Array, MambaCacheView | None]:
        """Apply norm -> mixer -> residual add.

        Args:
            hidden_states: Layer input.
            cache: Optional per-layer cache view passed through the mixer.
            position_ids: Forwarded to the mixer.
            attention_mask: Forwarded to the mixer.

        Returns:
            ``(hidden_states, cache)``: the layer output and the cache view
            returned by the mixer.
        """
        residual = hidden_states
        hidden_states = self.norm(hidden_states)
        if self.residual_in_fp32:
            residual = residual.astype(jnp.float32)
        hidden_states, cache = self.mixer(
            hidden_states,
            cache,
            position_ids,
            attention_mask,
        )
        hidden_states = residual + hidden_states
        return hidden_states, cache
@register_module(TaskType.BASE_MODULE, config=MambaConfig, model_type="mamba")
class MambaModel(EasyDeLBaseModule):
    """Mamba backbone: token embeddings, stacked ``MambaBlock`` layers and a final RMSNorm."""

    def __init__(
        self,
        config: MambaConfig,
        dtype: jnp.dtype = jnp.bfloat16,
        param_dtype: jnp.dtype = jnp.bfloat16,
        precision: str | lax.Precision | None = None,
        *,
        rngs: nn.Rngs,
    ) -> None:
        """Build embeddings, ``config.num_hidden_layers`` blocks and the final norm.

        Args:
            config: Mamba configuration.
            dtype: Computation dtype.
            param_dtype: Parameter storage dtype.
            precision: Matmul precision forwarded to the Mamba blocks.
            rngs: RNG streams for initialization.
        """
        super().__init__(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        self.embeddings = nn.Embed(
            num_embeddings=config.vocab_size,
            features=config.hidden_size,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        self.layers = [
            MambaBlock(
                config=config,
                layer_idx=layer_idx,
                dtype=dtype,
                param_dtype=param_dtype,
                precision=precision,
                rngs=rngs,
            )
            for layer_idx in range(config.num_hidden_layers)
        ]
        self.norm_f = MambaRMSNorm(
            config.hidden_size,
            eps=config.layer_norm_epsilon,
            dtype=dtype,
            param_dtype=param_dtype,
        )

    def __call__(
        self,
        input_ids: chex.Array | None = None,
        inputs_embeds: chex.Array | None = None,
        cache: MambaCache | None = None,
        position_ids: chex.Array | None = None,
        attention_mask: chex.Array | None = None,
        output_hidden_states: bool | None = None,
        **kwargs,
    ) -> MambaOutput:
        """Run the backbone.

        Args:
            input_ids: Token ids; mutually exclusive with ``inputs_embeds``.
            inputs_embeds: Precomputed embeddings; mutually exclusive with ``input_ids``.
            cache: Existing cache. A fresh empty one is created when ``None``.
            position_ids: Positions. Derived from the mask when ``None``.
            attention_mask: Optional mask. Non-bool masks are converted with ``== 1``;
                defaults to all-ones.
            output_hidden_states: Whether to collect hidden states. Defaults to
                ``config.output_hidden_states``.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            ``MambaOutput`` with the normalized final hidden states and the updated
            cache. When requested, ``hidden_states`` holds each block's output
            followed by the final normalized output.

        Raises:
            ValueError: If both or neither of ``input_ids``/``inputs_embeds`` are given.
        """
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        # Exactly one of input_ids / inputs_embeds must be provided.
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )
        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)

        batch_size, sequence_length = inputs_embeds.shape[:2]
        if attention_mask is None:
            attention_mask = jnp.ones((batch_size, sequence_length), "b1")
        elif attention_mask.dtype != jnp.bool:
            attention_mask = jnp.astype(attention_mask == 1, "b1")

        mask_info = MaskInfo.dynamic_init(
            mask_info=None,
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
        )
        if position_ids is None:
            position_ids = mask_info.q_position_ids
        if cache is None:
            cache = MambaCache.init_empty(len(self.layers))

        hidden_states = inputs_embeds
        all_hidden_states = () if output_hidden_states else None
        for idx, block in enumerate(self.layers):
            hidden_states, cache_view = block(
                hidden_states=hidden_states,
                cache=cache.views[idx],
                attention_mask=attention_mask,
                position_ids=position_ids,
            )
            cache[idx] = cache_view
            if output_hidden_states:
                all_hidden_states = (*all_hidden_states, hidden_states)

        hidden_states = self.norm_f(hidden_states)
        if output_hidden_states:
            all_hidden_states = (*all_hidden_states, hidden_states)
        return MambaOutput(
            last_hidden_state=hidden_states,
            cache=cache,
            hidden_states=all_hidden_states,
        )

    def init_cache(
        self,
        batch_size: int,
        max_length: int,
        starts: int | None = None,
        shardings: dict | None = None,
        pad_token_id: int | None = None,
    ):
        """Create a ``MambaCache`` for ``batch_size`` sequences of up to ``max_length``.

        ``starts``, ``shardings`` and ``pad_token_id`` are accepted for interface
        compatibility but are not used.
        """
        # NOTE(review): num_key_value_heads / head_dim are attention-style fields;
        # confirm MambaConfig defines them and MambaCacheMetaData.create expects them.
        return MambaCache.init_cache(
            dtype=self.dtype,
            partition_specs=jax.sharding.PartitionSpec(
                self.config.partition_axis.batch_axis,
                self.config.partition_axis.key_sequence_axis,
                self.config.partition_axis.head_axis,
                self.config.partition_axis.attention_dim_axis,
            ),
            metadata=MambaCacheMetaData.create(
                num_hidden_layers=self.config.num_hidden_layers,
                partition_axis=self.config.partition_axis,
                batch_size=batch_size,
                sequence_length=max_length,
                num_heads=self.config.num_key_value_heads,
                head_dim=self.config.head_dim,
            ),
        )

    def get_encoder(self):
        """
        Returns the encoder part of the model's graph definition.
        Decoder-Only models don't have an encoder.
        """
        raise NotImplementedError("This is a decoder-only model and does not have an encoder.")

    def get_decoder(self):
        """
        Returns the decoder part of the model's graph definition.
        """
        return self

    def get_lm_head(self):
        """
        Returns the language model head of the module.
        Base Models don't have a Language Model Head.
        """
        raise NotImplementedError("The base model does not have a language model head.")

    def get_embedding(self):
        """
        Returns the embedding layer of the module.
        """
        return self.embeddings
@register_module(TaskType.CAUSAL_LM, config=MambaConfig, model_type="mamba")
class MambaForCausalLM(EasyDeLBaseModule):
    """Causal language model: a ``MambaModel`` backbone followed by a linear LM head."""

    def __init__(
        self,
        config: MambaConfig,
        dtype: jnp.dtype = jnp.bfloat16,
        param_dtype: jnp.dtype = jnp.bfloat16,
        precision: str | lax.Precision | None = None,
        *,
        rngs: nn.Rngs,
    ):
        """Build the backbone and a bias-free ``hidden_size -> vocab_size`` head.

        Args:
            config: Mamba configuration.
            dtype: Computation dtype.
            param_dtype: Parameter storage dtype.
            precision: Matmul precision.
            rngs: RNG streams for initialization.
        """
        super().__init__(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        self.backbone = MambaModel(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        # LM head wrapped with the configured gradient-checkpointing policy.
        lm_head_block = auto_remat(
            ColumnParallelLinear,
            policy=config.gradient_checkpointing,
            save_names=config.gradient_checkpointing_targets,
            exclude_names=config.gradient_checkpointing_targets,
        )
        self.lm_head = lm_head_block(
            config.hidden_size,
            config.vocab_size,
            use_bias=False,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )

    def __call__(
        self,
        input_ids: chex.Array | None = None,
        inputs_embeds: chex.Array | None = None,
        cache: MambaCache | None = None,
        position_ids: chex.Array | None = None,
        apply_lm_head: bool = True,
        attention_mask: chex.Array | None = None,
        output_hidden_states: bool | None = None,
        **kwargs,
    ) -> MambaCausalLMOutput:
        """Run the backbone and, optionally, the LM head.

        Args:
            input_ids: Token ids; mutually exclusive with ``inputs_embeds``.
            inputs_embeds: Precomputed embeddings.
            cache: Optional existing cache.
            position_ids: Optional positions.
            apply_lm_head: When ``False``, ``logits`` is returned as ``None``.
            attention_mask: Optional attention mask.
            output_hidden_states: Whether to collect hidden states.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            ``MambaCausalLMOutput`` with logits, the updated cache, optional hidden
            states and the backbone's last hidden state.
        """
        mamba_outputs = self.backbone(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            cache=cache,
            output_hidden_states=output_hidden_states,
        )
        logits = self.apply_lm_head(mamba_outputs.last_hidden_state) if apply_lm_head else None
        return MambaCausalLMOutput(
            logits=logits,
            cache=mamba_outputs.cache,
            hidden_states=mamba_outputs.hidden_states,
            last_hidden_state=mamba_outputs.last_hidden_state,
        )

    def init_cache(
        self,
        batch_size: int,
        max_length: int,
        starts: int | None = None,
        shardings: dict | None = None,
        pad_token_id: int | None = None,
    ):
        """Create a ``MambaCache``; delegates to the backbone, built with the same config and dtype."""
        return self.backbone.init_cache(
            batch_size=batch_size,
            max_length=max_length,
            starts=starts,
            shardings=shardings,
            pad_token_id=pad_token_id,
        )

    def get_encoder(self):
        """
        Returns the encoder part of the model's graph definition.
        Decoder-Only models don't have an encoder.
        """
        raise NotImplementedError("This is a decoder-only model and does not have an encoder.")

    def get_decoder(self):
        """
        Returns the decoder part of the model's graph definition.
        """
        return self.backbone.get_decoder()

    def get_lm_head(self):
        """
        Returns the language model head of the module.
        """
        return self.lm_head

    def get_embedding(self):
        """
        Returns the embedding layer of the module.
        """
        return self.backbone.get_embedding()