# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import typing as tp
import chex
import jax
import jax.numpy as jnp
from flax import nnx as nn
from jax import lax
from easydel.infra.base_module import EasyDeLBaseModule
from easydel.infra.factory import TaskType, register_module
from easydel.infra.modeling_outputs import FlaxBaseModelOutput
from easydel.infra.utils import (
ACT2FN,
auto_remat,
)
from easydel.layers.caching.mamba2_cache import (
Mamba2Cache,
Mamba2CacheMetaData,
Mamba2CacheView,
)
from easydel.layers.norms import RMSNorm as FlaxMamba2RMSNorm
from easydel.modules.mamba2.mamba2_configuration import Mamba2Config as Mamba2Config
from easydel.utils import traversals as etr
def init_to_value(x, dtype):
    """Build an initializer-style callable that always yields ``x`` cast to ``dtype``.

    The returned function accepts (and ignores) any positional arguments, so it can
    be used where an ``init(key, shape, ...)`` callable is expected.
    """

    def _constant_init(*_ignored):
        return x.astype(dtype)

    return _constant_init
@etr.auto_pytree
class Mamba2Output(FlaxBaseModelOutput):
    """Output container returned by ``Mamba2Model``.

    Attributes:
        last_hidden_state: Hidden states after the final RMSNorm.
        cache_params: The cache object that was threaded through the layers.
        hidden_states: Optional tuple of intermediate hidden states (one entry per
            layer plus the final normalized states) when requested.
    """

    last_hidden_state: chex.Array = None
    cache_params: tp.Optional[Mamba2Cache] = None
    hidden_states: tp.Optional[tp.Tuple[chex.Array, ...]] = None
@etr.auto_pytree
class Mamba2CausalLMOutput(FlaxBaseModelOutput):
    """Output container returned by ``Mamba2ForCausalLM``.

    Attributes:
        logits: Language-model logits (cast to float32 by the model).
        cache_params: The cache object forwarded from the backbone.
        hidden_states: Optional tuple of intermediate hidden states forwarded from
            the backbone.
    """

    logits: chex.Array = None
    cache_params: tp.Optional[Mamba2Cache] = None
    hidden_states: tp.Optional[tp.Tuple[chex.Array, ...]] = None
def pad_tensor_by_size(input_tensor: jnp.ndarray, pad_size: int):
    """Zero-pad ``input_tensor`` at the end of its sequence axis (axis 1).

    Works for any tensor of rank >= 2; every axis other than axis 1 is left
    unpadded.

    Args:
        input_tensor: Array whose axis 1 is the sequence axis.
        pad_size: Number of zero entries appended along axis 1 (0 is a no-op).

    Returns:
        The padded array, with ``input_tensor.shape[1] + pad_size`` on axis 1.
    """
    # Only axis 1 is padded (at the end); all other axes keep their size.
    pad_width = [(0, 0), (0, pad_size)] + [(0, 0)] * (input_tensor.ndim - 2)
    return jnp.pad(input_tensor, pad_width, mode="constant", constant_values=0)
def reshape_into_chunks(input_tensor, pad_size, chunk_size):
    """Pad the sequence axis (axis 1) and split it into chunks of ``chunk_size``.

    A tensor of shape ``(batch, seq_len, *rest)`` becomes
    ``(batch, (seq_len + pad_size) // chunk_size, chunk_size, *rest)``; any number
    of trailing axes is preserved.

    Args:
        input_tensor: Array whose axis 1 is the sequence axis.
        pad_size: Zeros appended on axis 1 so its length divides ``chunk_size``.
        chunk_size: Length of each chunk.
    """
    padded = pad_tensor_by_size(input_tensor, pad_size)
    return padded.reshape(padded.shape[0], -1, chunk_size, *padded.shape[2:])
def segment_sum(input_tensor):
    """Stable segment sum over the last axis.

    For an input whose last axis has length ``T`` the result has an extra trailing
    axis of length ``T``: entry ``[..., i, j]`` equals ``sum(x[..., j+1 : i+1])`` for
    ``i >= j`` (so the diagonal is 0) and ``-inf`` for ``i < j``. This is built from
    a masked cumulative sum rather than differences of prefix sums.
    """
    length = input_tensor.shape[-1]
    # Row i of the broadcast matrix holds x[i] in every column.
    broadcast = jnp.broadcast_to(input_tensor[..., None], input_tensor.shape + (length,))
    strictly_below = jnp.tril(jnp.ones((length, length), dtype=bool), k=-1)
    running = jnp.cumsum(jnp.where(strictly_below, broadcast, 0), axis=-2)
    on_or_below = jnp.tril(jnp.ones((length, length), dtype=bool), k=0)
    return jnp.where(on_or_below, running, -jnp.inf)
_T = tp.TypeVar("_T")


def create_tuple_parser(
    n: int,
) -> tp.Callable[[tp.Union[_T, tp.Sequence[_T]]], tuple[_T, ...]]:
    """Return a function that normalizes a value into an ``n``-tuple.

    The returned function converts a sequence of exactly ``n`` items to a tuple
    and repeats a non-sequence value ``n`` times. A sequence of any other length
    raises ``ValueError``. Note that strings count as sequences.
    """

    def parse(x: tp.Union[_T, tp.Sequence[_T]]) -> tuple[_T, ...]:
        if not isinstance(x, tp.Sequence):
            return (x,) * n
        if len(x) != n:
            raise ValueError(f"x!=n ({x}!=({n}))")
        return tuple(x)

    return parse
class Conv1D(nn.Module):
    """Minimal grouped 1D convolution over channels-first inputs.

    The kernel is stored as ``(kernel_size, 1, features)`` and transposed to
    ``(features, 1, kernel_size)`` at call time. Because the kernel has a single
    input channel per group, the input must have ``groups`` channels (e.g. a
    depthwise conv with ``groups == features``). Inputs use ``lax``'s default
    channels-first layout, ``(batch, channels, length)``. ``padding`` zeros are
    added on both ends of the length axis.
    """

    def __init__(
        self,
        features: int,
        kernel_size: int = 1,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        use_bias: bool = True,
        num_spatial_dims: int = 1,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: tp.Optional[tp.Union[str, lax.Precision]] = None,
        *,
        rngs: nn.Rngs,
    ):
        self.kernel = nn.Param(
            nn.initializers.lecun_normal(dtype=param_dtype)(
                rngs.params(),
                (kernel_size, 1, features),
                param_dtype,
            ),
        )
        if use_bias:
            self.bias = nn.Param(
                nn.initializers.zeros(
                    rngs.params(),
                    shape=(features,),
                    dtype=param_dtype,
                )
            )
        self.features = features
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.use_bias = use_bias
        self.num_spatial_dims = num_spatial_dims
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.precision = precision

    def __call__(self, x):
        """Convolve ``x`` of shape ``(batch, channels, length)``.

        Raises:
            ValueError: if ``x`` does not have rank ``num_spatial_dims + 2``.
        """
        expected_rank = self.num_spatial_dims + 2
        if x.ndim != expected_rank:
            raise ValueError(
                f"Input to `Conv` needs to have rank {expected_rank},"
                f" but input has shape {x.shape}.",
            )
        rhs = jnp.asarray(jnp.swapaxes(self.kernel.value, 0, 2), dtype=self.dtype)
        x = lax.conv_general_dilated(
            # lax convolutions require lhs and rhs to share a dtype, so cast the
            # input to the compute dtype just like the kernel.
            lhs=jnp.asarray(x, dtype=self.dtype),
            rhs=rhs,
            window_strides=(self.stride,),
            padding=((self.padding, self.padding),),
            rhs_dilation=(self.dilation,),
            feature_group_count=self.groups,
            precision=self.precision,
        )
        if self.use_bias:
            # One bias per output channel, broadcast over batch and length.
            x = x + jnp.asarray(self.bias.value.reshape(1, -1, 1), dtype=self.dtype)
        return x
class MambaRMSNormGated(nn.Module):
    """RMS normalization with an optional SiLU gate applied before normalizing.

    When ``gate`` is given, ``hidden_states * silu(gate)`` is normalized. The math
    runs in float32, and the scaled result is cast back to the input dtype.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        dtype: jnp.dtype = jnp.float32,
    ):
        self.hidden_size = hidden_size
        self.eps = eps
        self.dtype = dtype
        self.kernel = nn.Param(jnp.ones((hidden_size,), dtype))

    def __call__(self, hidden_states, gate=None):
        original_dtype = hidden_states.dtype
        values = hidden_states.astype(jnp.float32)
        if gate is not None:
            values = values * jax.nn.silu(gate.astype(jnp.float32))
        mean_square = jnp.mean(jnp.square(values), axis=-1, keepdims=True)
        normalized = values * jax.lax.rsqrt(mean_square + self.eps)
        scaled = self.kernel.value * normalized
        return scaled.astype(original_dtype)
class Mamba2Mixer(nn.Module):
    """Mamba-2 (SSD) token mixer.

    ``in_proj`` produces a gate, a convolution stream holding ``(x, B, C)`` and a
    per-head time step ``dt``. The stream goes through a depthwise causal 1D
    convolution plus activation. Then one of two things happens:

    * If ``cache_params`` is given and ``seqlen_offset > 0``, a single recurrent
      state update is performed (decoding).
    * Otherwise, the chunked SSD scan runs over the whole sequence.

    The result passes through a gated RMSNorm and ``out_proj``.
    """

    def __init__(
        self,
        config: Mamba2Config,
        layer_idx: int,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: tp.Optional[tp.Union[str, lax.Precision]] = None,
        *,
        rngs: nn.Rngs,
    ) -> None:
        self.config = config
        self.layer_idx = layer_idx
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.precision = precision
        self.num_heads = config.num_heads
        self.hidden_size = config.hidden_size
        self.ssm_state_size = config.state_size
        self.conv_kernel_size = config.conv_kernel
        self.intermediate_size = int(config.expand * self.hidden_size)
        self.time_step_rank = int(config.time_step_rank)
        self.use_conv_bias = config.use_conv_bias
        self.activation = config.hidden_act
        self.act = ACT2FN[config.hidden_act]
        self.norm_before_gate = config.norm_before_gate
        self.layer_norm_epsilon = config.layer_norm_epsilon
        self.rms_norm = config.rms_norm
        self.n_groups = config.n_groups
        self.head_dim = config.head_dim
        self.chunk_size = config.chunk_size
        self.time_step_limit = config.time_step_limit
        self.time_step_min = config.time_step_min
        self.time_step_max = config.time_step_max
        # The convolution stream carries x (intermediate_size) plus B and C
        # (n_groups * state_size each).
        self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size
        # Depthwise conv padded by (kernel - 1) on both sides; the causal part is
        # recovered by keeping the first seq_len outputs in __call__.
        self.conv1d = Conv1D(
            features=self.conv_dim,
            kernel_size=self.config.conv_kernel,
            groups=self.conv_dim,
            stride=1,
            padding=self.config.conv_kernel - 1,
            use_bias=self.config.use_conv_bias,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        # in_proj emits [gate | conv stream | dt] along the last axis.
        projection_size = self.intermediate_size + self.conv_dim + self.num_heads
        self.in_proj = nn.Linear(
            self.hidden_size,
            projection_size,
            use_bias=self.config.use_bias,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        # dt is sampled log-uniformly in [time_step_min, time_step_max] (reference
        # Mamba initialization) and floored at time_step_floor.
        log_dt_min = jnp.log(self.config.time_step_min)
        log_dt_max = jnp.log(self.config.time_step_max)
        dt = jnp.exp(
            jax.random.uniform(
                key=rngs.params(),
                shape=(self.config.num_heads,),
                dtype=jnp.float32,
            )
            * (log_dt_max - log_dt_min)
            + log_dt_min
        )
        dt = jnp.clip(dt, self.config.time_step_floor, 1e9)
        # Inverse softplus, so softplus(dt_bias) == dt at initialization.
        inv_dt = dt + jnp.log(-jnp.expm1(-dt))
        self.dt_bias = nn.Param(inv_dt.astype(self.param_dtype))
        # A = -exp(A_log) starts at -[1, 2, ..., num_heads].
        self.A_log = nn.Param(
            jnp.log(
                jnp.arange(1, self.num_heads + 1, dtype=jnp.float32),
            ).astype(self.param_dtype),
        )
        self.D = nn.Param(jnp.ones(self.num_heads, dtype=self.param_dtype))
        self.norm = MambaRMSNormGated(
            self.intermediate_size,
            eps=self.layer_norm_epsilon,
            dtype=self.param_dtype,
        )
        self.out_proj = nn.Linear(
            self.intermediate_size,
            self.hidden_size,
            use_bias=self.config.use_bias,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )

    def __call__(
        self,
        input_states: chex.Array,
        cache_params: tp.Optional[Mamba2CacheView] = None,
        cache_position: tp.Optional[chex.Array] = None,
        attention_mask: tp.Optional[chex.Array] = None,
    ):
        """Mix ``input_states`` of shape ``(batch, seq_len, hidden_size)``.

        Args:
            input_states: Input activations.
            cache_params: Optional cache view whose ``conv_states`` and
                ``ssm_states`` are read and written. A positive ``seqlen_offset``
                selects the single-token decoding path.
            cache_position: Accepted for interface compatibility; not used here.
            attention_mask: Optional ``(batch, seq_len)`` mask. It is applied only
                when both the batch size and the sequence length exceed 1.

        Returns:
            Array of shape ``(batch, seq_len, hidden_size)``.
        """
        dtype = input_states.dtype
        # Padding masking only matters for batched, multi-token inputs.
        use_mask = (
            attention_mask is not None
            and attention_mask.shape[1] > 1
            and attention_mask.shape[0] > 1
        )
        if use_mask:
            input_states = (input_states * attention_mask[:, :, None]).astype(dtype)
        batch_size, seq_len, _ = input_states.shape

        # Gated MLP's linear projection.
        projected_states = self.in_proj(input_states)
        d_mlp = (
            projected_states.shape[-1]
            - 2 * self.intermediate_size
            - 2 * self.n_groups * self.ssm_state_size
            - self.num_heads
        ) // 2
        _, _, gate, hidden_states, dt = jnp.split(
            projected_states,
            [
                d_mlp,
                d_mlp * 2,
                d_mlp * 2 + self.intermediate_size,
                d_mlp * 2 + self.intermediate_size + self.conv_dim,
            ],
            axis=-1,
        )
        is_decoding = cache_params is not None and cache_params.seqlen_offset > 0

        # ---- Convolution over the (x, B, C) stream -------------------------------
        if is_decoding:
            # Shift the rolling window one step left and append the newest column.
            conv_state = jnp.roll(cache_params.conv_states, shift=-1, axis=-1)
            newest = hidden_states[:, 0, :] if hidden_states.ndim == 3 else hidden_states
            conv_state = conv_state.at[:, :, -1].set(newest.astype(conv_state.dtype))
            cache_params.conv_states = jax.lax.dynamic_update_slice(
                cache_params.conv_states,
                conv_state,
                (0,) * conv_state.ndim,
            )
            # The kernel is stored as (kernel, 1, conv_dim); the window is
            # (..., conv_dim, kernel), so transpose the kernel to (conv_dim, kernel).
            kernel = jnp.swapaxes(self.conv1d.kernel.value[:, 0, :], 0, 1)
            hidden_states = jnp.sum(conv_state * kernel, axis=-1)
            if self.use_conv_bias:
                hidden_states = hidden_states + self.conv1d.bias.value
            # [batch, 1, conv_dim] : decoding
            hidden_states = self.act(hidden_states).astype(dtype)[:, None, ...]
        elif cache_params is not None:
            channels_first = jnp.swapaxes(hidden_states, 2, 1)
            # Cache the last conv_kernel_size positions, left-padding short prompts
            # with zeros (jnp.pad does not accept negative widths).
            left_pad = max(self.conv_kernel_size - channels_first.shape[-1], 0)
            conv_state = jnp.pad(
                channels_first, ((0, 0), (0, 0), (left_pad, 0))
            )[..., -self.conv_kernel_size :]
            cache_params.conv_states = jax.lax.dynamic_update_slice(
                cache_params.conv_states,
                conv_state.astype(cache_params.conv_states.dtype),
                (0,) * conv_state.ndim,
            )
            hidden_states = self.act(jnp.swapaxes(self.conv1d(channels_first), 2, 1))
            hidden_states = hidden_states[:, :seq_len, :]
            if use_mask:
                hidden_states = hidden_states * attention_mask[:, :, None]
        else:
            conv_out = self.conv1d(jnp.swapaxes(hidden_states, 2, 1))[..., :seq_len]
            hidden_states = self.act(jnp.swapaxes(conv_out, 2, 1))

        hidden_states, B, C = jnp.split(
            hidden_states,
            [
                self.intermediate_size,
                self.intermediate_size + self.n_groups * self.ssm_state_size,
            ],
            axis=-1,
        )
        A = -jnp.exp(self.A_log.value.astype(jnp.float32))  # [num_heads]
        heads_per_group = self.num_heads // self.n_groups

        if is_decoding:
            # ---- Single recurrent step -------------------------------------------
            dt_step = dt if dt.ndim == 2 else dt[:, 0, :]  # [batch, num_heads]
            dt = jnp.broadcast_to(
                dt_step[:, :, None], (batch_size, self.num_heads, self.head_dim)
            )
            dt_bias = jnp.broadcast_to(
                self.dt_bias.value[:, None], (self.num_heads, self.head_dim)
            )
            dt = jax.nn.softplus(dt + dt_bias.astype(dt.dtype))
            dt = jnp.maximum(dt, self.time_step_min)
            A = jnp.broadcast_to(
                A[:, None, None],
                (self.num_heads, self.head_dim, self.ssm_state_size),
            ).astype(jnp.float32)
            # [batch, num_heads, head_dim, state_size]
            dA = jnp.exp(dt[..., None] * A)
            # [batch, n_groups * state] -> [batch, n_groups, state] -> [batch, num_heads, state]
            B = jnp.repeat(
                B.reshape(batch_size, self.n_groups, -1), heads_per_group, axis=1
            )
            C = jnp.repeat(
                C.reshape(batch_size, self.n_groups, -1), heads_per_group, axis=1
            )
            x = hidden_states.reshape(batch_size, -1, self.head_dim)  # [b, h, head_dim]
            dBx = (dt[..., None] * B[:, :, None, :]) * x[..., None]
            new_ssm_state = cache_params.ssm_states[self.layer_idx] * dA + dBx
            # NOTE(review): item assignment requires `ssm_states` to be a mutable
            # container indexed by layer; confirm against Mamba2CacheView.
            cache_params.ssm_states[self.layer_idx] = new_ssm_state
            # y = state @ C  -> [batch, num_heads, head_dim]
            y = jnp.matmul(new_ssm_state, C[..., None])[..., 0]
            # D skip connection.
            y = y + x * self.D.value[:, None]
            y = y.reshape(batch_size, -1)[:, None, ...]
        else:
            # ---- Chunked SSD scan (naive implementation without einsums) ----------
            dt = jax.nn.softplus(dt + self.dt_bias.value)
            dt = jnp.maximum(dt, self.time_step_min)
            hidden_states = hidden_states.reshape(
                batch_size, seq_len, -1, self.head_dim
            ).astype(jnp.float32)
            B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).astype(jnp.float32)
            C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).astype(jnp.float32)
            B = jnp.repeat(B, heads_per_group, axis=2)
            C = jnp.repeat(C, heads_per_group, axis=2)
            # Pad only up to the next multiple of chunk_size (0 when already aligned).
            pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size
            D_residual = self.D.value[:, None] * pad_tensor_by_size(hidden_states, pad_size)

            # Discretize x and A.
            hidden_states = hidden_states * dt[..., None]
            A = A.astype(hidden_states.dtype) * dt

            # Rearrange into chunks.
            hidden_states, A, B, C = [
                reshape_into_chunks(t, pad_size, self.chunk_size)
                for t in (hidden_states, A, B, C)
            ]
            # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size]
            A = jnp.transpose(A, axes=(0, 3, 1, 2))
            A_cumsum = jnp.cumsum(A, axis=-1)

            # 1. Intra-chunk output (diagonal blocks); L acts as a causal decay mask.
            L = jnp.exp(segment_sum(A))
            G = (C[:, :, :, None, :, :] * B[:, :, None, :, :, :]).sum(axis=-1)
            M = G * jnp.transpose(L, (0, 2, 3, 4, 1))
            Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(axis=3)

            # 2. Per-chunk final states (right term of the off-diagonal factorization).
            decay_states = jnp.exp(A_cumsum[:, :, :, -1:] - A_cumsum)
            B_decay = B * jnp.transpose(decay_states, (0, 2, 3, 1))[..., None]
            states = jnp.transpose(
                (
                    jnp.transpose(B_decay, axes=(0, 1, 3, 2, 4))[..., None]
                    * jnp.transpose(hidden_states, axes=(0, 1, 3, 2, 4))[..., None, :]
                ).sum(axis=3),
                axes=(0, 1, 2, 4, 3),
            )

            # 3. Inter-chunk recurrence. This branch never runs with a positive
            # seqlen_offset, so the initial state is always zero.
            previous_states = jnp.zeros_like(states[:, :1])
            states = jnp.concatenate([previous_states, states], axis=1)
            decay_chunk = jnp.exp(
                segment_sum(jnp.pad(A_cumsum[:, :, :, -1], ((0, 0), (0, 0), (1, 0))))
            )
            states_permuted = jnp.transpose(states, axes=(0, 2, 1, 3, 4))
            new_states = jnp.transpose(
                (decay_chunk[..., None, None] * states_permuted[:, :, None, ...]).sum(
                    axis=2
                ),
                (0, 2, 1, 3, 4),
            )
            states, ssm_state = new_states[:, :-1], new_states[:, -1]

            # 4. State -> output per chunk (left term of the off-diagonal factorization).
            state_decay_out = jnp.exp(A_cumsum)
            C_times_states = C[..., None, :] * states[:, :, None, ...]
            Y_off = (
                C_times_states.sum(-1)
                * jnp.transpose(state_decay_out, axes=(0, 2, 3, 1))[..., None]
            )

            # [bsz, chunks, chunk_size, num_heads, head_dim] -> [bsz, padded_len, num_heads, head_dim]
            y = (Y_diag + Y_off).reshape(batch_size, -1, self.num_heads, self.head_dim)
            y = y + D_residual
            # Cut off padded positions.
            y = y[:, :seq_len].reshape(batch_size, seq_len, -1)
            if cache_params is not None:
                # NOTE(review): same mutable-container assumption as the decode path.
                cache_params.ssm_states[self.layer_idx] = ssm_state

        scan_output = self.norm(y, gate)
        # [batch, seq_len, hidden_size]
        return self.out_proj(scan_output.astype(dtype))
class Mamba2Block(nn.Module):
    """Pre-norm residual wrapper around ``Mamba2Mixer``.

    Computes ``residual + mixer(norm(hidden_states))``. When
    ``config.residual_in_fp32`` is set, the residual branch is promoted to float32
    first. The mixer class is wrapped with ``auto_remat`` according to
    ``config.gradient_checkpointing``.
    """

    def __init__(
        self,
        config: Mamba2Config,
        layer_idx: int,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: tp.Optional[tp.Union[str, lax.Precision]] = None,
        *,
        rngs: nn.Rngs,
    ) -> None:
        self.config = config
        self.layer_idx = layer_idx
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.precision = precision
        self.residual_in_fp32 = config.residual_in_fp32
        self.norm = FlaxMamba2RMSNorm(
            dim=config.hidden_size,
            eps=config.layer_norm_epsilon,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        (mixer_cls,) = auto_remat(
            Mamba2Mixer,
            policy=config.gradient_checkpointing,
        )
        self.mixer = mixer_cls(
            config=config,
            layer_idx=layer_idx,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )

    def __call__(
        self,
        hidden_states: chex.Array,
        cache_params: tp.Optional[Mamba2CacheView] = None,
        cache_position: tp.Optional[chex.Array] = None,
        attention_mask: tp.Optional[chex.Array] = None,
    ) -> chex.Array:
        normed = self.norm(hidden_states)
        if self.residual_in_fp32:
            skip = hidden_states.astype(jnp.float32)
        else:
            skip = hidden_states
        mixed = self.mixer(
            normed,
            cache_params,
            cache_position,
            attention_mask,
        )
        return skip + mixed
@register_module(
    TaskType.BASE_MODULE,
    config=Mamba2Config,
    model_type="mamba2",
)
class Mamba2Model(EasyDeLBaseModule):
    """Mamba-2 backbone: token embeddings, a stack of ``Mamba2Block`` and a final RMSNorm."""

    def __init__(
        self,
        config: Mamba2Config,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: tp.Optional[tp.Union[str, lax.Precision]] = None,
        *,
        rngs: nn.Rngs,
    ) -> None:
        super().__init__(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        self.embeddings = nn.Embed(
            config.vocab_size,
            config.hidden_size,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )
        self.layers = [
            Mamba2Block(
                config=config,
                layer_idx=i,
                dtype=dtype,
                param_dtype=param_dtype,
                precision=precision,
                rngs=rngs,
            )
            for i in range(config.num_hidden_layers)
        ]
        self.norm_f = FlaxMamba2RMSNorm(
            config.hidden_size,
            eps=config.layer_norm_epsilon,
            dtype=dtype,
            param_dtype=param_dtype,
            rngs=rngs,
        )

    def __call__(
        self,
        input_ids: tp.Optional[chex.Array] = None,
        inputs_embeds: tp.Optional[chex.Array] = None,
        cache_params: tp.Optional[Mamba2Cache] = None,
        output_hidden_states: tp.Optional[bool] = None,
        return_dict: tp.Optional[bool] = None,
        cache_position: tp.Optional[chex.Array] = None,
        attention_mask: tp.Optional[chex.Array] = None,
        **kwargs,
    ) -> tp.Union[tp.Tuple, Mamba2Output]:
        """Run the backbone.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given. When
        ``cache_params`` is missing, an empty cache is created, and a missing
        ``attention_mask`` defaults to all ones. The result is returned either as a
        ``Mamba2Output`` or as a tuple of its non-None fields.
        """
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        collected_states = () if output_hidden_states else None
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Equivalent to "(input_ids is None) XOR (inputs_embeds is not None)".
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )
        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)
        if cache_params is None:
            cache_params = Mamba2Cache.init_empty(len(self.layers))
        if attention_mask is None:
            attention_mask = jnp.ones(inputs_embeds.shape[:2], dtype="i4")

        hidden_states = inputs_embeds
        for layer_index, layer in enumerate(self.layers):
            hidden_states = layer(
                hidden_states=hidden_states,
                cache_params=cache_params.views[layer_index],
                cache_position=cache_position,
                attention_mask=attention_mask,
            )
            if output_hidden_states:
                collected_states += (hidden_states,)

        hidden_states = self.norm_f(hidden_states)
        if output_hidden_states:
            collected_states += (hidden_states,)

        if return_dict:
            return Mamba2Output(
                last_hidden_state=hidden_states,
                cache_params=cache_params,
                hidden_states=collected_states,
            )
        return tuple(
            value
            for value in (hidden_states, cache_params, collected_states)
            if value is not None
        )
@register_module(
    TaskType.CAUSAL_LM,
    config=Mamba2Config,
    model_type="mamba2",
)
class Mamba2ForCausalLM(EasyDeLBaseModule):
    """Mamba-2 backbone with a bias-free linear language-modeling head."""

    def __init__(
        self,
        config: Mamba2Config,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: tp.Optional[tp.Union[str, lax.Precision]] = None,
        *,
        rngs: nn.Rngs,
    ) -> None:
        super().__init__(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        self.backbone = Mamba2Model(
            config=config,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )
        self.lm_head = nn.Linear(
            config.hidden_size,
            config.vocab_size,
            use_bias=False,
            dtype=dtype,
            param_dtype=param_dtype,
            precision=precision,
            rngs=rngs,
        )

    def __call__(
        self,
        input_ids: tp.Optional[chex.Array] = None,
        inputs_embeds: tp.Optional[chex.Array] = None,
        cache_params: tp.Optional[Mamba2Cache] = None,
        output_hidden_states: tp.Optional[bool] = None,
        return_dict: tp.Optional[bool] = None,
        cache_position: tp.Optional[chex.Array] = None,
        attention_mask: tp.Optional[chex.Array] = None,
        **kwargs,
    ) -> tp.Union[tp.Tuple, Mamba2CausalLMOutput]:
        """Run the backbone and project its last hidden state to float32 logits."""
        if return_dict is None:
            return_dict = self.config.use_return_dict
        backbone_out = self.backbone(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_params=cache_params,
            cache_position=cache_position,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = self.lm_head(backbone_out[0]).astype(jnp.float32)
        if return_dict:
            return Mamba2CausalLMOutput(
                logits=logits,
                cache_params=backbone_out.cache_params,
                hidden_states=backbone_out.hidden_states,
            )
        return (logits, *backbone_out[1:])

    def init_cache(self, batch_size: int, max_length: int):
        """Create a per-layer Mamba-2 cache for ``batch_size`` sequences.

        ``max_length`` is accepted for API compatibility but not used here.
        """
        metadata = Mamba2CacheMetaData(
            batch_size=batch_size,
            intermediate_size=int(self.config.expand * self.config.hidden_size),
            conv_kernel_size=self.config.conv_kernel,
            head_dim=self.config.head_dim,
            n_groups=self.config.n_groups,
            state_size=self.config.state_size,
            num_heads=self.config.num_heads,
        )
        return Mamba2Cache.init_layers_cache(
            metadata=metadata,
            dtype=self.dtype,
            num_hidden_layers=self.config.num_hidden_layers,
        )