# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import typing
from functools import partial
import chex
import jax
import jax.numpy as jnp
from eformer import common_types
from eformer.escale import apply_logical_sharding
from ejkernel.types import MaskInfo
from flax import nnx as nn
from jax.ad_checkpoint import checkpoint_name
from jaxtyping import Array, Bool, Float, Int
from easydel.infra.base_module import EasyDeLBaseModule
from easydel.infra.factory import TaskType, register_module
from easydel.infra.loss_utils import auxiliary_load_balancing_loss_func
from easydel.infra.modeling_outputs import (
DecoderLayerOutput,
MoeCausalLMOutput,
MoeModelOutput,
SequenceClassifierOutput,
)
from easydel.infra.utils import ACT2FN, auto_remat, get_dot_general_by_bits
from easydel.layers.attention_unified import UnifiedAttention
from easydel.layers.base_modules import BaseCausalLMModule
from easydel.layers.caching import (
RaggedPagesCache,
RaggedPagesCacheView,
RaggedPagesMetadata,
TransformerCache,
TransformerCacheView,
TransformerMetadata,
)
from easydel.layers.linear import ColumnParallelLinear, RowParallelLinear
from easydel.layers.moe import (
BaseMoeModule,
ColumnParallelMoELinear,
MoeLoadBalancingStrategy,
MoeRoutingStrategy,
RowParallelMoELinear,
)
from easydel.layers.norms import RMSNorm as RMSNorm
from .qwen3_moe_configuration import Qwen3MoeConfig
[docs]class Qwen3MoeMLPStack(nn.Module):
"""Qwen3Moe MoE MLP using the new ParallelMoELinear layers."""
reform_param: typing.ClassVar = {
"gate_up_proj$": {
"splits": [
{"name": "gate_proj.kernel", "spliter": lambda x: x[..., : x.shape[-1] // 2]},
{"name": "up_proj.kernel", "spliter": lambda x: x[..., x.shape[-1] // 2 :]},
],
"inverse_spliter": lambda torch, gate, up: torch.stack((gate, up), dim=-1).flatten(-2),
},
"down_proj$": {
"splits": [
{"name": "down_proj.kernel", "spliter": lambda x: x},
],
"inverse_spliter": lambda x: x,
},
}
def __init__(
self,
config: Qwen3MoeConfig,
dtype: jnp.dtype = jnp.bfloat16,
param_dtype: jnp.dtype = jnp.bfloat16,
precision: jax.lax.PrecisionLike = None,
*,
rngs: nn.Rngs,
):
super().__init__()
self.config = config
self.dtype = dtype
self.param_dtype = param_dtype
self.precision = precision
self.gate_proj = ColumnParallelMoELinear(
num_experts=config.num_experts,
in_features=config.hidden_size,
out_features=config.moe_intermediate_size,
rngs=rngs,
kernel_init=nn.initializers.normal(),
use_bias=False,
partition_manager=config.partition_manager,
use_expert_tensor_mode=config.use_expert_tensor_mode,
dtype=dtype,
param_dtype=param_dtype,
)
self.down_proj = RowParallelMoELinear(
num_experts=config.num_experts,
in_features=config.moe_intermediate_size,
out_features=config.hidden_size,
rngs=rngs,
use_bias=False,
kernel_init=nn.initializers.normal(),
partition_manager=config.partition_manager,
use_expert_tensor_mode=config.use_expert_tensor_mode,
dtype=dtype,
param_dtype=param_dtype,
)
self.up_proj = ColumnParallelMoELinear(
num_experts=config.num_experts,
in_features=config.hidden_size,
out_features=config.moe_intermediate_size,
rngs=rngs,
use_bias=False,
kernel_init=nn.initializers.normal(),
partition_manager=config.partition_manager,
use_expert_tensor_mode=config.use_expert_tensor_mode,
dtype=dtype,
param_dtype=param_dtype,
)
self.act_fn = ACT2FN[config.hidden_act]
def __call__(
self,
hidden_states: Float[Array, "batch seq_len hidden_dim"],
group_sizes: chex.Array,
sorted_experts: chex.Array | None = None,
) -> chex.Array:
"""Forward pass through MoE MLP."""
return self.down_proj(
self.act_fn(self.gate_proj(hidden_states, group_sizes, sorted_experts))
* self.up_proj(hidden_states, group_sizes, sorted_experts),
group_sizes,
sorted_experts,
)
[docs]class Qwen3MoeMLP(nn.Module):
"""Qwen3Moe MLP module.
This module implements the feed-forward network (MLP) used in the Qwen3Moe model.
It uses a Gated Linear Unit (GLU) structure with SiLU activation.
Attributes:
config (Qwen3MoeConfig): Configuration object for the model.
dtype (jnp.dtype): Data type for computations.
param_dtype (jnp.dtype): Data type for parameters.
precision (jax.lax.PrecisionLike): Precision setting for JAX operations.
gate_proj (ParallelLinear): Linear layer for the GLU gate.
down_proj (ParallelLinear): Linear layer for the down projection.
up_proj (ParallelLinear): Linear layer for the GLU value.
act_fn (callable): Activation function (SiLU).
"""
def __init__(
self,
config: Qwen3MoeConfig,
intermediate_size=None,
dtype: jnp.dtype = jnp.bfloat16,
param_dtype: jnp.dtype = jnp.bfloat16,
precision: jax.lax.PrecisionLike = None,
*,
rngs: nn.Rngs,
):
"""Initializes the Qwen3MoeMLP module.
Args:
config (Qwen3MoeConfig): The configuration object for the Qwen3Moe model.
dtype (jnp.dtype): Data type for computation. Defaults to jnp.bfloat16.
param_dtype (jnp.dtype): Data type for parameters. Defaults to jnp.bfloat16.
precision (jax.lax.PrecisionLike): Precision setting for JAX operations. Defaults to None.
rngs (nn.Rngs): Random number generators.
"""
self.config = config
self.dtype = dtype
self.param_dtype = param_dtype
self.precision = precision
intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size
column_parallel_linear = partial(
ColumnParallelLinear,
dtype=dtype,
param_dtype=param_dtype,
use_bias=False,
kernel_init=jax.nn.initializers.normal(config.initializer_range),
precision=precision,
rngs=rngs,
**get_dot_general_by_bits(config.bits, config.easy_method),
)
row_parallel_linear = partial(
RowParallelLinear,
dtype=dtype,
param_dtype=param_dtype,
use_bias=False,
kernel_init=jax.nn.initializers.normal(config.initializer_range),
precision=precision,
rngs=rngs,
**get_dot_general_by_bits(config.bits, config.easy_method),
)
self.gate_proj = column_parallel_linear(config.hidden_size, intermediate_size, rngs=rngs)
self.down_proj = row_parallel_linear(intermediate_size, config.hidden_size, rngs=rngs)
self.up_proj = column_parallel_linear(config.hidden_size, intermediate_size, rngs=rngs)
self.act_fn = ACT2FN[self.config.hidden_act]
def __call__(self, hidden_states: Float[Array, "batch seq_len hidden_dim"]) -> jnp.ndarray:
"""Forward pass of the Qwen3MoeMLP module.
Args:
hidden_states (jnp.ndarray): Input hidden states.
Returns:
jnp.ndarray: Output hidden states after MLP transformation.
"""
hidden_states = apply_logical_sharding(
hidden_states,
dynamic_axes=common_types.HiddenStateSharding,
partition_manager=self.config.partition_manager,
)
gate = checkpoint_name(self.act_fn(self.gate_proj(hidden_states)), "mlp_gate")
up = checkpoint_name(self.up_proj(hidden_states), "mlp_up")
hidden_states = checkpoint_name(self.down_proj(gate * up), "mlp_down")
hidden_states = apply_logical_sharding(
hidden_states,
dynamic_axes=common_types.HiddenStateSharding,
partition_manager=self.config.partition_manager,
)
return checkpoint_name(hidden_states, "mlp_output")
[docs]class Qwen3MoeSparseBlock(BaseMoeModule):
"""Sparse Mixture of Experts (MoE) block for Qwen3 MoE.
This block routes input hidden states to a selected subset of experts
and combines their outputs.
Attributes:
config (Qwen3MoeConfig): Configuration object for the model.
gate (ParallelLinear): Linear layer for the gating network.
experts (nn.List[Qwen3MoeMLP]): List of expert MLP modules.
dtype (jnp.dtype): Data type for computations.
param_dtype (jnp.dtype): Data type for parameters.
precision (jax.lax.PrecisionLike): Precision setting for matrix multiplications.
rngs (nn.Rngs): Random number generators.
"""
def __init__(
self,
config: Qwen3MoeConfig,
dtype: jnp.dtype = jnp.bfloat16,
param_dtype: jnp.dtype = jnp.bfloat16,
precision: jax.lax.PrecisionLike = None,
*,
rngs: nn.Rngs,
):
"""Initializes the Qwen3MoeSparseBlock module.
Args:
config (Qwen3MoeConfig): The configuration object for the model.
dtype (jnp.dtype): Data type for computations (default: jnp.bfloat16).
param_dtype (jnp.dtype): Data type for parameters (default: jnp.bfloat16).
precision (jax.lax.PrecisionLike): Precision setting for JAX operations (default: None).
rngs (nn.Rngs): Random number generators.
"""
super().__init__(
config=config,
n_routed_experts=config.num_experts,
num_experts_per_tok=config.num_experts_per_tok,
hidden_size=config.hidden_size,
lbl_coef=None,
rzl_coef=None,
routing_strategy=MoeRoutingStrategy.TOP_K if config.norm_topk_prob else MoeRoutingStrategy.TOP_K_NDIV,
load_balancing_strategy=MoeLoadBalancingStrategy.STANDARD,
)
self.config = config
self.dtype = dtype
self.param_dtype = param_dtype
self.precision = precision
self.gate = ColumnParallelLinear(
config.hidden_size,
config.num_experts,
use_bias=False,
dtype=dtype,
param_dtype=param_dtype,
precision=precision,
rngs=rngs,
kernel_init=nn.initializers.normal(config.initializer_range),
)
self.experts = Qwen3MoeMLPStack(
config=config,
dtype=dtype,
param_dtype=param_dtype,
precision=precision,
rngs=rngs,
)
def __call__(self, hidden_states: Float[Array, "batch seq_len hidden_dim"]) -> tuple[chex.Array, chex.Array]:
"""Forward pass of the Sparse MoE block.
Args:
hidden_states (chex.Array): Input hidden states (batch_size * sequence_length, hidden_dim).
Returns:
tp.Tuple[chex.Array, chex.Array]: A tuple containing:
- final_hidden_states (chex.Array): The output hidden states after MoE processing.
- router_logits (chex.Array): The logits output by the gating network.
"""
out, router_logits = self.moe_call(
hidden_state=hidden_states,
gate_layer=self.gate,
expert_layer=self.experts,
wi_kernel=self.experts.gate_proj.kernel.value,
wu_kernel=self.experts.up_proj.kernel.value,
wd_kernel=self.experts.down_proj.kernel.value,
act_fn=self.experts.act_fn,
)
return checkpoint_name(out, "moe_expert_output"), checkpoint_name(router_logits, "moe_router_logits")
[docs]class Qwen3MoeAttention(UnifiedAttention):
"""Qwen3Moe Attention with Q/K normalization.
Inherits Q/K normalization (RMSNorm) from QKNormAttention.
Features:
- Layer-specific sliding window based on layer_idx and max_window_layers
- MoE model variant
"""
def __init__(
self,
config: Qwen3MoeConfig,
dtype: jnp.dtype = jnp.bfloat16,
param_dtype: jnp.dtype = jnp.bfloat16,
precision: jax.lax.PrecisionLike = None,
*,
rngs: nn.Rngs,
layer_idx: int,
):
sliding_window = config.sliding_window
if not (
config.use_sliding_window
and getattr(config, "sliding_window", None) is not None
and layer_idx >= config.max_window_layers
):
sliding_window = None
super().__init__(
config,
dtype,
param_dtype,
precision,
rngs=rngs,
layer_idx=layer_idx,
attention_type="standard",
causal=True,
sliding_window=sliding_window,
use_qk_norm=True,
)
self.layer_idx = layer_idx
def _postprocess_qkv(self, query_states, key_states, value_states):
return self.query_normalization(query_states), self.key_normalization(key_states), value_states
[docs]class Qwen3MoeDecoderLayer(nn.Module):
"""Qwen3Moe Transformer Decoder Layer.
This module represents a single decoder layer in the Qwen3Moe model,
combining self-attention and MLP sub-layers with residual connections
and RMS normalization.
Attributes:
config (Qwen3MoeConfig): Configuration object for the model.
layer_idx (int): The index of the layer in the model.
dtype (jnp.dtype): Data type for computations.
param_dtype (jnp.dtype): Data type for parameters.
precision (jax.lax.PrecisionLike): Precision setting for JAX operations.
rngs (nn.Rngs): Random number generators.
input_layernorm (RMSNorm): RMS normalization applied before the attention layer.
self_attn (Qwen3MoeAttention): The self-attention module.
mlp (Qwen3MoeMLP): The feed-forward (MLP) module.
post_attention_layernorm (RMSNorm): RMS normalization applied after the attention layer and before the MLP layer.
"""
def __init__(
self,
config: Qwen3MoeConfig,
dtype: jnp.dtype = jnp.bfloat16,
param_dtype: jnp.dtype = jnp.bfloat16,
precision: jax.lax.PrecisionLike = None,
*,
rngs: nn.Rngs,
layer_idx: int,
):
"""Initializes the Qwen3MoeDecoderLayer.
Args:
config (Qwen3MoeConfig): The configuration object for the Qwen3Moe model.
layer_idx (int): The index of the layer in the model.
dtype (jnp.dtype): Data type for computation. Defaults to jnp.bfloat16.
param_dtype (jnp.dtype): Data type for parameters. Defaults to jnp.bfloat16.
precision (jax.lax.PrecisionLike): Precision setting for JAX operations. Defaults to None.
rngs (nn.Rngs): Random number generators.
"""
self.config = config
self.dtype = dtype
self.param_dtype = param_dtype
self.precision = precision
attn_block = Qwen3MoeAttention
mlp_block = Qwen3MoeMLP
moe_block = Qwen3MoeSparseBlock
attn_block, mlp_block, moe_block = auto_remat(
attn_block,
mlp_block,
moe_block,
policy=config.gradient_checkpointing,
save_names=config.gradient_checkpointing_targets,
exclude_names=config.gradient_checkpointing_targets,
)
self.self_attn = attn_block(
config=config,
dtype=dtype,
param_dtype=param_dtype,
precision=precision,
rngs=rngs,
layer_idx=layer_idx,
)
self.is_moe = (layer_idx not in config.mlp_only_layers) and (
config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
)
if self.is_moe:
self.mlp = moe_block(
config=config,
dtype=dtype,
param_dtype=param_dtype,
precision=precision,
rngs=rngs,
)
else:
self.mlp = mlp_block(
config=config,
dtype=dtype,
param_dtype=param_dtype,
precision=precision,
rngs=rngs,
)
self.input_layernorm = RMSNorm(
config.hidden_size,
eps=self.config.rms_norm_eps,
dtype=dtype,
param_dtype=param_dtype,
rngs=rngs,
)
self.post_attention_layernorm = RMSNorm(
config.hidden_size,
eps=self.config.rms_norm_eps,
dtype=dtype,
param_dtype=param_dtype,
rngs=rngs,
)
def __call__(
self,
hidden_states: Float[Array, "batch seq_len hidden_dim"],
mask_info: MaskInfo,
position_ids: Int[Array, "batch seq_len"],
mode: common_types.RUNTIME_MODE_TYPES, # type:ignore
cache_view: TransformerCacheView | RaggedPagesCacheView | None = None,
cache_metadata: TransformerMetadata | RaggedPagesMetadata | None = None,
output_attentions: bool = False,
output_router_logits: bool = False,
frequencies: Float[Array, "seq_len head_dim"] | None = None,
):
"""Forward pass of the Qwen3MoeDecoderLayer module.
Args:
hidden_states (chex.Array): Input hidden states.
attention_mask (chex.Array): Mask to apply on the attention scores. Shape:
(batch_size, 1, query_length, key_length).
position_ids (chex.Array): Position indices for the tokens. Shape: (batch_size, sequence_length).
causal_mask (tp.Optional[chex.Array | bool]): Causal mask for ensuring autoregressive behavior.
cache_view (tp.Optional[TransformerCacheView | RaggedPagesCacheView]): Cache view for attention KVs.
cache_metadata (tp.Optional[TransformerMetadata | RaggedPagesMetadata]): Metadata for paged attention.
segment_ids (tp.Optional[chex.Array]): Segment IDs for segment-based attention (optional).
output_attentions (bool): Whether to return attention weights. Default is False.
fcm_mask (tp.Optional[chex.Array]): Flash Chunking Mask (FCM) for attention.
frequencies (tp.Optional[chex.Array]): Precomputed rotary frequency embeddings.
Returns:
tp.Tuple[chex.Array, tp.Optional[chex.Array]]:
A tuple containing the output hidden states and optionally the attention weights.
"""
attn_outputs = self.self_attn(
self.input_layernorm(hidden_states),
mask_info,
position_ids,
mode,
cache_view,
cache_metadata,
output_attentions,
frequencies,
)
hidden_states = checkpoint_name(hidden_states + attn_outputs.attention_output, "residual")
feed_forward_input = self.post_attention_layernorm(hidden_states)
feed_forward_hidden_states = self.mlp(feed_forward_input)
router_logits = None
if self.is_moe:
feed_forward_hidden_states, router_logits = feed_forward_hidden_states
hidden_states = checkpoint_name(hidden_states + feed_forward_hidden_states, "residual")
return DecoderLayerOutput(
hidden_states=checkpoint_name(hidden_states, "layer_output"),
attention_weight=attn_outputs.attention_weight,
router_logits=router_logits if output_router_logits else None,
cache_view=attn_outputs.cache_view,
)
[docs]@register_module(TaskType.BASE_MODULE, config=Qwen3MoeConfig, model_type="qwen3_moe")
class Qwen3MoeModel(EasyDeLBaseModule):
"""The base Qwen3Moe model transformer.
This class represents the core transformer architecture of the Qwen3Moe model,
consisting of an embedding layer, multiple Qwen3MoeDecoderLayer layers,
and a final RMS normalization layer.
Attributes:
config (Qwen3MoeConfig): Configuration object for the model.
dtype (jnp.dtype): Data type for computation.
param_dtype (jnp.dtype): Data type for parameters.
precision (jax.lax.PrecisionLike): Precision setting for JAX operations.
rngs (nn.Rngs): Random number generators.
embed_tokens (nn.Embed): Embedding layer for input tokens.
layers (tp.List[Qwen3MoeDecoderLayer]): List of decoder layers.
norm (RMSNorm): Final layer normalization.
gradient_checkpointing (EasyDeLGradientCheckPointers): Gradient checkpointing configuration.
"""
def __init__(
self,
config: Qwen3MoeConfig,
dtype: jnp.dtype = jnp.bfloat16,
param_dtype: jnp.dtype = jnp.bfloat16,
precision: jax.lax.PrecisionLike = None,
*,
rngs: nn.Rngs,
):
"""Initializes the Qwen3MoeModel.
Args:
config (Qwen3MoeConfig): The configuration object for the Qwen3Moe model.
dtype (jnp.dtype): Data type for computation. Defaults to jnp.bfloat16.
param_dtype (jnp.dtype): Data type for parameters. Defaults to jnp.bfloat16.
precision (jax.lax.PrecisionLike): Precision setting for JAX operations. Defaults to None.
rngs (nn.Rngs): Random number generators.
"""
super().__init__(
config=config,
dtype=dtype,
param_dtype=param_dtype,
precision=precision,
rngs=rngs,
)
embed_block = auto_remat(
nn.Embed,
policy=config.gradient_checkpointing,
save_names=config.gradient_checkpointing_targets,
exclude_names=config.gradient_checkpointing_targets,
)
self.embed_tokens = embed_block(
config.vocab_size,
config.hidden_size,
embedding_init=jax.nn.initializers.normal(stddev=config.initializer_range),
dtype=dtype,
param_dtype=param_dtype,
rngs=rngs,
)
self.layers = [
Qwen3MoeDecoderLayer(
config=config,
layer_idx=layer_idx,
dtype=dtype,
param_dtype=param_dtype,
precision=precision,
rngs=rngs,
)
for layer_idx in range(config.num_hidden_layers)
]
self.norm = RMSNorm(
config.hidden_size,
eps=config.rms_norm_eps,
dtype=dtype,
param_dtype=param_dtype,
rngs=rngs,
)
def __call__(
self,
input_ids: Int[Array, "batch seq_len"] | None = None,
inputs_embeds: Float[Array, "batch seq_len hidden_dim"] | None = None,
attention_mask: Bool[Array, "batch seq_len"] | None = None,
mask_info: MaskInfo | None = None,
position_ids: Int[Array, "batch seq_len"] | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
output_router_logits: bool | None = None,
mode: common_types.RUNTIME_MODE_TYPES | None = None, # type:ignore
past_key_values: TransformerCache | RaggedPagesCache | None = None,
cache_metadata: TransformerMetadata | RaggedPagesMetadata | None = None,
) -> MoeModelOutput:
"""Forward pass of the Qwen3MoeModel.
Args:
input_ids (tp.Optional[chex.Array]): Input token IDs. Shape: (batch_size, sequence_length).
inputs_embeds (tp.Optional[chex.Array]): Input embeddings.
Either `input_ids` or `inputs_embeds` must be provided.
attention_mask (tp.Optional[chex.Array]): Mask to avoid performing attention on padding token indices.
Shape: (batch_size, sequence_length).
position_ids (tp.Optional[chex.Array]): Position indices for the tokens.
Shape: (batch_size, sequence_length).
segment_ids (tp.Optional[chex.Array]): Segment IDs (unused).
output_attentions (tp.Optional[bool]): Whether to return attention weights.
Defaults to `config.output_attentions`.
output_hidden_states (tp.Optional[bool]): Whether to return hidden states for all layers.
Defaults to `config.output_hidden_states`.
past_key_values (tp.Optional[TransformerCache | RaggedPagesCache]):
Precomputed key/value states for attention.
cache_metadata (tp.Optional[TransformerMetadata | RaggedPagesMetadata]): Metadata for paged attention.
Returns:
MoeModelOutput: The model's output.
returns a `MoeModelOutput` object containing `last_hidden_state`, `hidden_states` (optional),
and `attentions` (optional).
Raises:
ValueError: If neither `input_ids` nor `inputs_embeds` is provided.
"""
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = checkpoint_name(self.embed_tokens(input_ids.astype("i4")), "embeddings")
sequence_length = inputs_embeds.shape[1]
output_router_logits = (
output_router_logits if output_router_logits is not None else self.config.output_router_logits
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
all_router_logits = () if output_router_logits else None
all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
assert sequence_length <= self.config.max_position_embeddings, (
f"Maximum Position Embedding Reached ! "
f"(Excepted <= {self.config.max_position_embeddings} got {sequence_length})"
)
mask_info = MaskInfo.dynamic_init(
mask_info=mask_info,
input_ids=input_ids,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
)
if position_ids is None:
position_ids = mask_info.q_position_ids
hidden_states = inputs_embeds
if mode is None:
mode = (
common_types.MODE_DECODE
if sequence_length == 1 and past_key_values is not None
else common_types.MODE_TRAIN
)
if past_key_values is None:
past_key_values = TransformerCache.init_empty(len(self.layers))
hidden_states = apply_logical_sharding(
hidden_states,
dynamic_axes=common_types.HiddenStateSharding,
partition_manager=self.config.partition_manager,
)
for idx, block in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = block(
hidden_states=hidden_states,
mask_info=mask_info,
position_ids=position_ids,
mode=mode,
cache_view=past_key_values.views[idx],
cache_metadata=cache_metadata,
output_attentions=output_attentions,
output_router_logits=output_router_logits,
frequencies=self.frequencies,
)
hidden_states = layer_outputs.hidden_states
if output_attentions:
all_attentions += (layer_outputs.attention_weight,)
past_key_values[idx] = layer_outputs.cache_view
if output_router_logits:
all_router_logits += (layer_outputs[-1],)
hidden_states = checkpoint_name(self.norm(hidden_states), "model_output")
return MoeModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_attentions,
past_key_values=past_key_values,
router_logits=all_router_logits,
)
[docs] def get_encoder(self):
"""
Returns the encoder part of the model's graph definition.
Decoder-Only models don't have an encoder.
"""
raise NotImplementedError("This is a decoder-only model and does not have an encoder.")
[docs] def get_decoder(self):
"""
Returns the decoder part of the model's graph definition.
"""
return self
[docs] def get_lm_head(self):
"""
Returns the language model head of the module.
Base Models don't have a Language Model Head.
"""
raise NotImplementedError("The base model does not have a language model head.")
[docs] def get_embedding(self):
"""
Returns the embedding layer of the module.
"""
return self.embed_tokens
[docs]@register_module(TaskType.CAUSAL_LM, config=Qwen3MoeConfig, model_type="qwen3_moe")
class Qwen3MoeForCausalLM(BaseCausalLMModule[Qwen3MoeModel, Qwen3MoeConfig]):
"""Qwen3 MoE model with a Causal Language Modeling head."""
_task_type = TaskType.CAUSAL_LM
_model_type = "qwen3_moe"
_config_class = Qwen3MoeConfig
def __init__(
self,
config: Qwen3MoeConfig,
dtype: jnp.dtype = jnp.bfloat16,
param_dtype: jnp.dtype = jnp.bfloat16,
precision: jax.lax.PrecisionLike = None,
*,
rngs: nn.Rngs,
):
super().__init__(
config=config,
base_model_class=Qwen3MoeModel,
base_model_name="model",
dtype=dtype,
param_dtype=param_dtype,
precision=precision,
rngs=rngs,
lm_head_bias=False,
router_aux_loss_coef=getattr(config, "router_aux_loss_coef", None),
)
def __call__(
self,
input_ids: Int[Array, "batch seq_len"] | None = None,
inputs_embeds: Float[Array, "batch seq_len hidden_dim"] | None = None,
attention_mask: Bool[Array, "batch seq_len"] | None = None,
mask_info: MaskInfo | None = None,
position_ids: Int[Array, "batch seq_len"] | None = None,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
output_router_logits: bool | None = None,
mode: common_types.RUNTIME_MODE_TYPES | None = None, # type:ignore
past_key_values: TransformerCache | RaggedPagesCache | None = None,
cache_metadata: TransformerMetadata | RaggedPagesMetadata | None = None,
apply_lm_head: bool = True,
) -> MoeCausalLMOutput:
"""Forward pass of the Qwen3MoeForCausalLM model."""
return self.forward_moe(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
mask_info=mask_info,
position_ids=position_ids,
mode=mode,
past_key_values=past_key_values,
cache_metadata=cache_metadata,
apply_lm_head=apply_lm_head,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_router_logits=output_router_logits,
aux_loss_fn=self._compute_aux_loss,
)
def _compute_aux_loss(self, outputs, attention_mask):
"""Compute auxiliary loss from router logits."""
if outputs.router_logits is None:
return None
aux_loss = auxiliary_load_balancing_loss_func(
gate_logits=outputs.router_logits,
num_experts=self.config.num_experts,
top_k=self.config.num_experts_per_tok,
attention_mask=attention_mask,
)
return aux_loss + (aux_loss * self.config.router_aux_loss_coef)
[docs]@register_module(TaskType.SEQUENCE_CLASSIFICATION, config=Qwen3MoeConfig, model_type="qwen3_moe")
class Qwen3MoeForSequenceClassification(EasyDeLBaseModule):
"""Qwen3Moe model with a Sequence Classification head.
This model consists of the base Qwen3Moe transformer (`Qwen3MoeModel`) followed by a
linear layer (`score`) that projects the transformer's output hidden states
(typically the hidden state of the last token or a pooled representation) to the number of classes
for classification.
Attributes:
config (Qwen3MoeConfig): Configuration object for the model.
dtype (jnp.dtype): Data type for computation.
param_dtype (jnp.dtype): Data type for parameters.
precision (jax.lax.PrecisionLike): Precision setting for JAX operations.
rngs (nn.Rngs): Random number generators.
model (Qwen3MoeModel): The core Qwen3Moe transformer model.
score (ParallelLinear): The linear layer for classification.
"""
def __init__(
self,
config: Qwen3MoeConfig,
dtype: jnp.dtype = jnp.bfloat16,
param_dtype: jnp.dtype = jnp.bfloat16,
precision: jax.lax.PrecisionLike = None,
*,
rngs: nn.Rngs,
):
"""Initializes the Qwen3MoeForSequenceClassification model.
Args:
config (Qwen3MoeConfig): The configuration object for the Qwen3Moe model.
Must include `num_labels`.
dtype (jnp.dtype): Data type for computation. Defaults to jnp.bfloat16.
param_dtype (jnp.dtype): Data type for parameters. Defaults to jnp.bfloat16.
precision (jax.lax.PrecisionLike): Precision setting for JAX operations. Defaults to None.
rngs (nn.Rngs): Random number generators.
Raises:
AssertionError: If `config.num_labels` is not defined.
"""
super().__init__(
config=config,
dtype=dtype,
param_dtype=param_dtype,
precision=precision,
rngs=rngs,
)
self.model = Qwen3MoeModel(
config=config,
dtype=dtype,
param_dtype=param_dtype,
precision=precision,
rngs=rngs,
)
assert hasattr(config, "num_labels"), (
"in order to use `SequenceClassification` Models in `EasyDeL` "
"you first need to attach `num_labels` to model `config`"
)
self.score = ColumnParallelLinear(
self.config.hidden_size,
config.num_labels,
dtype=dtype,
param_dtype=param_dtype,
use_bias=False,
kernel_init=jax.nn.initializers.normal(stddev=config.initializer_range),
precision=self.precision,
rngs=rngs,
)
def __call__(
self,
input_ids: Int[Array, "batch seq_len"] | None = None,
inputs_embeds: Float[Array, "batch seq_len hidden_dim"] | None = None,
attention_mask: Bool[Array, "batch seq_len"] | None = None,
mask_info: MaskInfo | None = None,
position_ids: Int[Array, "batch seq_len"] | None = None,
mode: common_types.RUNTIME_MODE_TYPES | None = None, # type:ignore
past_key_values: TransformerCache | RaggedPagesCache | None = None,
cache_metadata: TransformerMetadata | RaggedPagesMetadata | None = None,
apply_lm_head: bool = True,
output_attentions: bool | None = None,
output_hidden_states: bool | None = None,
) -> SequenceClassifierOutput:
"""Forward pass of the Qwen3MoeForSequenceClassification model.
Args:
input_ids (tp.Optional[chex.Array]): Input token IDs. Shape: (batch_size, sequence_length).
inputs_embeds (tp.Optional[chex.Array]): Input embeddings.
Either `input_ids` or `inputs_embeds` must be provided.
attention_mask (tp.Optional[chex.Array]): Mask to avoid performing attention on padding token indices.
Shape: (batch_size, sequence_length).
position_ids (tp.Optional[chex.Array]): Position indices for the tokens.
Shape: (batch_size, sequence_length).
segment_ids (tp.Optional[chex.Array]): Segment IDs (unused).
past_key_values (tp.Optional[TransformerCache | RaggedPagesCache]):
Precomputed key/value states for attention.
cache_metadata (tp.Optional[TransformerMetadata | RaggedPagesMetadata]): Metadata for paged attention.
output_attentions (tp.Optional[bool]): Whether to return attention weights.
Defaults to `config.output_attentions`.
output_hidden_states (tp.Optional[bool]): Whether to return hidden states for all layers.
Defaults to `config.output_hidden_states`.
Returns:
SequenceClassifierOutput: The model's output,
returns a `SequenceClassifierOutput` object containing `logits`, `hidden_states` (optional),
and `attentions` (optional).
Raises:
ValueError: If `config.pad_token_id` is None and `batch_size > 1`.
"""
transformer_outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
mask_info=mask_info,
position_ids=position_ids,
mode=mode,
past_key_values=past_key_values,
cache_metadata=cache_metadata,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
inputs_embeds=inputs_embeds,
)
hidden_states = transformer_outputs.last_hidden_state
logits = self.score(hidden_states)
if input_ids is not None:
batch_size = input_ids.shape[0]
else:
batch_size = inputs_embeds.shape[0]
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
else:
if input_ids is not None:
sequence_lengths = jnp.argmax(jnp.equal(input_ids, self.config.pad_token_id).astype("i4"), -1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
else:
sequence_lengths = -1
pooled_logits = logits[jnp.arange(batch_size), sequence_lengths]
return SequenceClassifierOutput(
logits=pooled_logits,
past_key_values=past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
[docs] def get_encoder(self):
"""
Returns the encoder part of the model's graph definition.
Decoder-Only models don't have an encoder.
"""
raise NotImplementedError("This is a decoder-only model and does not have an encoder.")
[docs] def get_decoder(self):
"""
Returns the decoder part of the model's graph definition.
"""
return self.model.get_decoder()
[docs] def get_lm_head(self):
"""
Returns the language model head of the module.
This model has a sequence classification head, not an LM Head.
"""
raise NotImplementedError("This model has a sequence classification head, not a language model head.")
[docs] def get_embedding(self):
"""
Returns the embedding layer of the module.
"""
return self.model.get_embedding()