easydel.modules.qwen3_moe.modeling_qwen3_moe

class easydel.modules.qwen3_moe.modeling_qwen3_moe.Qwen3MoeAttention(*args: Any, **kwargs: Any)[source]#

Bases: UnifiedAttention

Qwen3Moe Attention with Q/K normalization.

Inherits Q/K normalization (RMSNorm) from QKNormAttention.

Features:

- Layer-specific sliding window based on layer_idx and max_window_layers
- MoE model variant
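
To make the Q/K normalization concrete, here is a minimal JAX sketch, not EasyDeL's implementation. The helper names (rms_norm, qk_norm_scores, uses_sliding_window), the [batch, seq, heads, head_dim] layout, and the eps value are assumptions made for illustration. The sliding-window rule is likewise an assumption, borrowed from the convention used in the Hugging Face Qwen models; the rule applied by this class may differ.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, weight, eps=1e-6):
    # Normalize over the last (head_dim) axis; statistics are computed in float32.
    x32 = x.astype(jnp.float32)
    variance = jnp.mean(jnp.square(x32), axis=-1, keepdims=True)
    return (x32 * jax.lax.rsqrt(variance + eps)).astype(x.dtype) * weight


def qk_norm_scores(q, k, q_weight, k_weight):
    # q, k: [batch, seq, heads, head_dim] (assumed layout).
    # Each head vector is RMS-normalized before the scaled dot product.
    q = rms_norm(q, q_weight)
    k = rms_norm(k, k_weight)
    scale = q.shape[-1] ** -0.5
    return jnp.einsum("bqhd,bkhd->bhqk", q, k) * scale


def uses_sliding_window(layer_idx, max_window_layers, use_sliding_window=True):
    # Assumed rule: layers at or above max_window_layers use the sliding window.
    return use_sliding_window and layer_idx >= max_window_layers
```

Normalizing queries and keys per head bounds the magnitude of the attention logits, which is the usual motivation for this design.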

class easydel.modules.qwen3_moe.modeling_qwen3_moe.Qwen3MoeDecoderLayer(*args: Any, **kwargs: Any)[source]#

Bases: Module

Qwen3Moe Transformer Decoder Layer.

This module represents a single decoder layer in the Qwen3Moe model, combining self-attention and MLP sub-layers with residual connections and RMS normalization.
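
The data flow described above can be sketched as follows. This is a simplified illustration that assumes a pre-norm residual layout, which matches the attribute descriptions for input_layernorm and post_attention_layernorm. The function name decoder_layer and its callable arguments are hypothetical; masks, caches, and position embeddings are omitted.

```python
import jax.numpy as jnp


def decoder_layer(hidden, attn_fn, mlp_fn, input_norm, post_attn_norm):
    # Attention sub-layer: normalize, attend, then add the residual.
    residual = hidden
    hidden = attn_fn(input_norm(hidden))
    hidden = residual + hidden

    # Feed-forward sub-layer: normalize, apply the MLP, then add the residual.
    residual = hidden
    hidden = mlp_fn(post_attn_norm(hidden))
    return residual + hidden


# Usage with identity stand-ins for the four sub-modules.
x = jnp.ones((1, 2, 3))
identity = lambda h: h
out = decoder_layer(x, identity, identity, identity, identity)
```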

config#

Configuration object for the model.

Type

Qwen3MoeConfig

layer_idx#

The index of the layer in the model.

Type

int

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

input_layernorm#

RMS normalization applied before the attention layer.

Type

RMSNorm

self_attn#

The self-attention module.

Type

Qwen3MoeAttention

mlp#

The feed-forward (MLP) module.

Type

Qwen3MoeMLP

post_attention_layernorm#

RMS normalization applied after the attention layer and before the MLP layer.

Type

RMSNorm

class easydel.modules.qwen3_moe.modeling_qwen3_moe.Qwen3MoeForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: BaseCausalLMModule[Qwen3MoeModel, Qwen3MoeConfig]

Qwen3 MoE model with a Causal Language Modeling head.

class easydel.modules.qwen3_moe.modeling_qwen3_moe.Qwen3MoeForSequenceClassification(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

Qwen3Moe model with a Sequence Classification head.

This model consists of the base Qwen3Moe transformer (Qwen3MoeModel) followed by a linear layer (score) that projects the transformer’s output hidden states (typically the hidden state of the last token or a pooled representation) to the number of classes for classification.
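
As an illustration of last-token pooling, the sketch below projects every position with a score kernel and then selects the logits at the last non-padding token of each sequence. It assumes right-padded inputs and a known pad_token_id. The function name pool_and_score and the bias-free projection are assumptions, not the class's actual code.

```python
import jax.numpy as jnp


def pool_and_score(hidden_states, input_ids, score_kernel, pad_token_id):
    # hidden_states: [batch, seq, hidden]; score_kernel: [hidden, num_labels].
    logits = hidden_states @ score_kernel  # [batch, seq, num_labels]

    # Index of the last non-padding token per row (assumes right padding).
    # A row made entirely of padding yields -1, which selects the final position.
    non_pad = (input_ids != pad_token_id).astype(jnp.int32)
    last_idx = jnp.sum(non_pad, axis=-1) - 1

    batch_idx = jnp.arange(hidden_states.shape[0])
    return logits[batch_idx, last_idx]  # [batch, num_labels]
```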

config#

Configuration object for the model.

Type

Qwen3MoeConfig

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

model#

The core Qwen3Moe transformer model.

Type

Qwen3MoeModel

score#

The linear layer for classification.

Type

ParallelLinear

get_decoder()[source]#

Returns the decoder part of the model’s graph definition.

get_embedding()[source]#

Returns the embedding layer of the module.

get_encoder()[source]#

Returns the encoder part of the model’s graph definition. Decoder-only models do not have an encoder.

get_lm_head()[source]#

Returns the language model head of the module. This model has a sequence classification head, not an LM Head.

class easydel.modules.qwen3_moe.modeling_qwen3_moe.Qwen3MoeMLP(*args: Any, **kwargs: Any)[source]#

Bases: Module

Qwen3Moe MLP module.

This module implements the feed-forward network (MLP) used in the Qwen3Moe model. It uses a Gated Linear Unit (GLU) structure with SiLU activation.
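
A gated linear unit with SiLU activation can be written in a few lines of JAX. The sketch below is a bias-free version operating on plain arrays; the function name glu_mlp and the kernel shapes are assumptions for illustration, and the real module uses ParallelLinear layers rather than raw matrices.

```python
import jax
import jax.numpy as jnp


def glu_mlp(x, gate_kernel, up_kernel, down_kernel):
    # x: [..., hidden]; gate_kernel, up_kernel: [hidden, intermediate];
    # down_kernel: [intermediate, hidden].
    gate = jax.nn.silu(x @ gate_kernel)  # gate branch, passed through act_fn
    up = x @ up_kernel                   # value branch, left linear
    return (gate * up) @ down_kernel     # elementwise gating, then down projection
```

The gate branch decides, per intermediate feature, how much of the value branch passes through to the down projection.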

config#

Configuration object for the model.

Type

Qwen3MoeConfig

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

gate_proj#

Linear layer for the GLU gate.

Type

ParallelLinear

down_proj#

Linear layer for the down projection.

Type

ParallelLinear

up_proj#

Linear layer for the GLU value.

Type

ParallelLinear

act_fn#

Activation function (SiLU).

Type

callable

class easydel.modules.qwen3_moe.modeling_qwen3_moe.Qwen3MoeMLPStack(*args: Any, **kwargs: Any)[source]#

Bases: Module

Qwen3Moe MoE MLP stack built on ParallelMoELinear layers.

reform_param: ClassVar = {
    'down_proj$': {
        'inverse_spliter': <function Qwen3MoeMLPStack.<lambda>>,
        'splits': [
            {'name': 'down_proj.kernel', 'spliter': <function Qwen3MoeMLPStack.<lambda>>}
        ]
    },
    'gate_up_proj$': {
        'inverse_spliter': <function Qwen3MoeMLPStack.<lambda>>,
        'splits': [
            {'name': 'gate_proj.kernel', 'spliter': <function Qwen3MoeMLPStack.<lambda>>},
            {'name': 'up_proj.kernel', 'spliter': <function Qwen3MoeMLPStack.<lambda>>}
        ]
    }
}#

class easydel.modules.qwen3_moe.modeling_qwen3_moe.Qwen3MoeModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

The base Qwen3Moe model transformer.

This class represents the core transformer architecture of the Qwen3Moe model, consisting of an embedding layer, multiple Qwen3MoeDecoderLayer layers, and a final RMS normalization layer.
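
The three stages named above (embedding lookup, stacked decoder layers, final normalization) compose as in the sketch below. The function name model_forward and the use of plain callables for the layers and the norm are hypothetical simplifications; attention masks, caching, and gradient checkpointing are left out.

```python
import jax.numpy as jnp


def model_forward(input_ids, embedding, layers, final_norm):
    # Look up one embedding row per token id: [batch, seq] -> [batch, seq, hidden].
    hidden = jnp.take(embedding, input_ids, axis=0)

    # Apply each decoder layer in order.
    for layer in layers:
        hidden = layer(hidden)

    # Final normalization over the last decoder output.
    return final_norm(hidden)
```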

config#

Configuration object for the model.

Type

Qwen3MoeConfig

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

embed_tokens#

Embedding layer for input tokens.

Type

nn.Embed

layers#

List of decoder layers.

Type

tp.List[Qwen3MoeDecoderLayer]

norm#

Final layer normalization.

Type

RMSNorm

gradient_checkpointing#

Gradient checkpointing configuration.

Type

EasyDeLGradientCheckPointers

get_decoder()[source]#

Returns the decoder part of the model’s graph definition.

get_embedding()[source]#

Returns the embedding layer of the module.

get_encoder()[source]#

Returns the encoder part of the model’s graph definition. Decoder-only models do not have an encoder.

get_lm_head()[source]#

Returns the language model head of the module. Base models do not have a language model head.

class easydel.modules.qwen3_moe.modeling_qwen3_moe.Qwen3MoeSparseBlock(*args: Any, **kwargs: Any)[source]#

Bases: BaseMoeModule

Sparse Mixture of Experts (MoE) block for Qwen3 MoE.

This block routes input hidden states to a selected subset of experts and combines their outputs.
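
The routing step can be illustrated with a dense reference implementation. It assumes softmax over the gate logits followed by top-k selection and optional renormalization of the selected weights. The function name sparse_moe, the renormalize flag, and the callable experts are hypothetical. For clarity the sketch evaluates every expert on every token, whereas an efficient implementation would dispatch each token only to its selected experts.

```python
import jax
import jax.numpy as jnp


def sparse_moe(x, gate_kernel, expert_fns, top_k, renormalize=True):
    # x: [tokens, hidden]; gate_kernel: [hidden, num_experts].
    num_experts = len(expert_fns)
    router_logits = x @ gate_kernel
    probs = jax.nn.softmax(router_logits.astype(jnp.float32), axis=-1)

    # Keep the top_k experts per token.
    top_w, top_idx = jax.lax.top_k(probs, top_k)  # both [tokens, top_k]
    if renormalize:
        top_w = top_w / jnp.sum(top_w, axis=-1, keepdims=True)

    # Scatter the selected weights into a [tokens, num_experts] combine matrix;
    # experts that were not selected receive weight zero.
    one_hot = jax.nn.one_hot(top_idx, num_experts)  # [tokens, top_k, num_experts]
    combine = jnp.sum(one_hot * top_w[..., None], axis=1)

    # Dense reference: run all experts, then take the weighted sum.
    expert_out = jnp.stack([f(x) for f in expert_fns], axis=1)  # [tokens, E, hidden]
    return jnp.einsum("te,teh->th", combine.astype(x.dtype), expert_out)
```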

config#

Configuration object for the model.

Type

Qwen3MoeConfig

gate#

Linear layer for the gating network.

Type

ParallelLinear

experts#

List of expert MLP modules.

Type

nn.List[Qwen3MoeMLP]

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for matrix multiplications.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs