easydel.modules.qwen3_moe.modeling_qwen3_moe#
- class easydel.modules.qwen3_moe.modeling_qwen3_moe.Qwen3MoeAttention(*args: Any, **kwargs: Any)[source]#
Bases: UnifiedAttention
Qwen3Moe Attention with Q/K normalization.
Inherits Q/K normalization (RMSNorm) from QKNormAttention.
Features:
- Layer-specific sliding window based on layer_idx and max_window_layers
- MoE model variant
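As an illustration of the Q/K normalization described above, the following is a minimal JAX sketch, not the EasyDeL implementation. The function names, the tensor layout `(batch, seq, heads, head_dim)`, and the `eps` value are assumptions made for this example. The sliding-window helper encodes one plausible rule (layers at or beyond `max_window_layers` use the window), which is an assumption and is not confirmed by this page.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, weight, eps=1e-6):
    # Scale each vector on the last axis to unit root-mean-square,
    # then apply a learned per-feature gain.
    variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x * jax.lax.rsqrt(variance + eps) * weight


def qk_norm_scores(q, k, q_weight, k_weight):
    # q, k: (batch, seq, heads, head_dim); weights: (head_dim,)
    # Queries and keys are RMS-normalized per head before the dot product.
    q = rms_norm(q, q_weight)
    k = rms_norm(k, k_weight)
    scale = q.shape[-1] ** -0.5
    # Unmasked, pre-softmax scores with shape (batch, heads, q_len, k_len).
    return jnp.einsum("bqhd,bkhd->bhqk", q, k) * scale


def layer_uses_sliding_window(layer_idx, max_window_layers, use_sliding_window=True):
    # Assumed rule for illustration only: layers with index at or beyond
    # max_window_layers use the sliding window, earlier layers attend fully.
    return bool(use_sliding_window) and layer_idx >= max_window_layers
```

Normalizing queries and keys bounds the magnitude of the attention logits, which is the usual motivation for this design.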
- class easydel.modules.qwen3_moe.modeling_qwen3_moe.Qwen3MoeDecoderLayer(*args: Any, **kwargs: Any)[source]#
Bases: Module
Qwen3Moe Transformer Decoder Layer.
This module represents a single decoder layer in the Qwen3Moe model, combining self-attention and MLP sub-layers with residual connections and RMS normalization.
- config#
Configuration object for the model.
- Type
- layer_idx#
The index of the layer in the model.
- Type
int
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- self_attn#
The self-attention module.
- Type
- mlp#
The feed-forward (MLP) module.
- Type
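The decoder layer structure described above (self-attention and MLP sub-layers, each with a residual connection and RMS normalization) can be sketched as follows. This is a simplified illustration, not the EasyDeL code: the pre-norm ordering, the function names, and the callables `attn_fn` and `mlp_fn` are assumptions introduced for the example.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, weight, eps=1e-6):
    # Root-mean-square normalization over the hidden axis.
    variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x * jax.lax.rsqrt(variance + eps) * weight


def decoder_layer(x, attn_fn, mlp_fn, input_norm_weight, post_attn_norm_weight):
    # Assumed pre-norm layout: normalize, apply the sub-layer, add the residual.
    h = x + attn_fn(rms_norm(x, input_norm_weight))
    # Second sub-layer: the feed-forward (dense MLP or MoE block in this model).
    return h + mlp_fn(rms_norm(h, post_attn_norm_weight))
```

Because each sub-layer output is added to its input, a sub-layer that returns zeros leaves the hidden states unchanged.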
- class easydel.modules.qwen3_moe.modeling_qwen3_moe.Qwen3MoeForCausalLM(*args: Any, **kwargs: Any)[source]#
Bases: BaseCausalLMModule[Qwen3MoeModel, Qwen3MoeConfig]
Qwen3 MoE model with a Causal Language Modeling head.
- class easydel.modules.qwen3_moe.modeling_qwen3_moe.Qwen3MoeForSequenceClassification(*args: Any, **kwargs: Any)[source]#
Bases: EasyDeLBaseModule
Qwen3Moe model with a Sequence Classification head.
This model consists of the base Qwen3Moe transformer (Qwen3MoeModel) followed by a linear layer (score) that projects the transformer’s output hidden states (typically the hidden state of the last token or a pooled representation) to the number of classes for classification.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- model#
The core Qwen3Moe transformer model.
- Type
- score#
The linear layer for classification.
- Type
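To make the "hidden state of the last token" pooling concrete, here is a minimal JAX sketch. It is an illustration under stated assumptions, not the library's code: it assumes right-padded inputs, a bias-free `score` projection given as a plain kernel, and the hypothetical names `classify` and `pad_token_id`.

```python
import jax.numpy as jnp


def classify(hidden_states, input_ids, score_kernel, pad_token_id):
    # hidden_states: (batch, seq, hidden); score_kernel: (hidden, num_labels)
    logits = hidden_states @ score_kernel  # (batch, seq, num_labels)
    # Index of the last non-padding token per row, assuming right padding.
    last_index = jnp.sum(input_ids != pad_token_id, axis=-1) - 1
    batch_index = jnp.arange(hidden_states.shape[0])
    # Pick one logits vector per sequence: (batch, num_labels)
    return logits[batch_index, last_index]
```

With left padding, or with no padding token defined, the last position of each row would be the natural choice instead.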
- class easydel.modules.qwen3_moe.modeling_qwen3_moe.Qwen3MoeMLP(*args: Any, **kwargs: Any)[source]#
Bases: Module
Qwen3Moe MLP module.
This module implements the feed-forward network (MLP) used in the Qwen3Moe model. It uses a Gated Linear Unit (GLU) structure with SiLU activation.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- gate_proj#
Linear layer for the GLU gate.
- Type
- down_proj#
Linear layer for the down projection.
- Type
- up_proj#
Linear layer for the GLU value.
- Type
- act_fn#
Activation function (SiLU).
- Type
callable
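The gated structure described above can be written in a few lines of JAX. This is a minimal sketch assuming bias-free projections passed as plain kernels; the function name `glu_mlp` and the kernel shapes are assumptions for illustration, not the module's actual signature.

```python
import jax
import jax.numpy as jnp


def glu_mlp(x, gate_kernel, up_kernel, down_kernel):
    # x: (..., hidden); gate/up kernels: (hidden, intermediate);
    # down kernel: (intermediate, hidden)
    gate = jax.nn.silu(x @ gate_kernel)  # GLU gate with SiLU activation
    value = x @ up_kernel                # GLU value
    # Elementwise gating, then projection back to the hidden size.
    return (gate * value) @ down_kernel
```

The gate and value projections share the same input, so the intermediate width appears twice on the way up and once on the way down.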
- class easydel.modules.qwen3_moe.modeling_qwen3_moe.Qwen3MoeMLPStack(*args: Any, **kwargs: Any)[source]#
Bases: Module
Qwen3Moe MoE MLP using the new ParallelMoELinear layers.
- reform_param: ClassVar = {'down_proj$': {'inverse_spliter': <function Qwen3MoeMLPStack.<lambda>>, 'splits': [{'name': 'down_proj.kernel', 'spliter': <function Qwen3MoeMLPStack.<lambda>>}]}, 'gate_up_proj$': {'inverse_spliter': <function Qwen3MoeMLPStack.<lambda>>, 'splits': [{'name': 'gate_proj.kernel', 'spliter': <function Qwen3MoeMLPStack.<lambda>>}, {'name': 'up_proj.kernel', 'spliter': <function Qwen3MoeMLPStack.<lambda>>}]}}#
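The `reform_param` mapping above pairs a fused `gate_up_proj` parameter with separate `gate_proj.kernel` and `up_proj.kernel` entries. The actual splitter lambdas are not shown on this page, so the following is only a guess at the general idea: it assumes the fused kernel concatenates gate and up along the last axis, gate first. The helper names are hypothetical.

```python
import jax.numpy as jnp


def split_gate_up(fused_kernel):
    # Assumption: fused layout is [gate | up] along the last axis.
    gate_kernel, up_kernel = jnp.split(fused_kernel, 2, axis=-1)
    return gate_kernel, up_kernel


def merge_gate_up(gate_kernel, up_kernel):
    # Inverse of split_gate_up under the same layout assumption.
    return jnp.concatenate([gate_kernel, up_kernel], axis=-1)
```

Whatever the real layout is, a splitter and its inverse should round-trip, which is the property worth checking when converting checkpoints.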
- class easydel.modules.qwen3_moe.modeling_qwen3_moe.Qwen3MoeModel(*args: Any, **kwargs: Any)[source]#
Bases: EasyDeLBaseModule
The base Qwen3Moe model transformer.
This class represents the core transformer architecture of the Qwen3Moe model, consisting of an embedding layer, multiple Qwen3MoeDecoderLayer layers, and a final RMS normalization layer.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- embed_tokens#
Embedding layer for input tokens.
- Type
nn.Embed
- layers#
List of decoder layers.
- Type
tp.List[Qwen3MoeDecoderLayer]
- gradient_checkpointing#
Gradient checkpointing configuration.
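The data flow described for this class (embedding layer, a stack of decoder layers, final RMS normalization) can be sketched as below. This is a schematic under stated assumptions, not the EasyDeL forward pass: layers are modeled as plain callables, masks and caches are omitted, and the `remat` flag is a hypothetical stand-in showing how `jax.checkpoint` could wrap each layer for gradient checkpointing.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, weight, eps=1e-6):
    # Root-mean-square normalization over the hidden axis.
    variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x * jax.lax.rsqrt(variance + eps) * weight


def model_forward(input_ids, embed_table, layer_fns, final_norm_weight, remat=False):
    # embed_table: (vocab, hidden); input_ids: (batch, seq)
    hidden = jnp.take(embed_table, input_ids, axis=0)
    for layer_fn in layer_fns:
        if remat:
            # Recompute this layer's activations during the backward pass
            # instead of storing them (trades compute for memory).
            layer_fn = jax.checkpoint(layer_fn)
        hidden = layer_fn(hidden)
    # Final normalization before any task-specific head.
    return rms_norm(hidden, final_norm_weight)
```

Wrapping a layer in `jax.checkpoint` does not change its forward output; it only affects what is saved for differentiation.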
- class easydel.modules.qwen3_moe.modeling_qwen3_moe.Qwen3MoeSparseBlock(*args: Any, **kwargs: Any)[source]#
Bases: BaseMoeModule
Sparse Mixture of Experts (MoE) block for Qwen3 MoE.
This block routes input hidden states to a selected subset of experts and combines their outputs.
- config#
Configuration object for the model.
- Type
- gate#
Linear layer for the gating network.
- Type
- experts#
List of expert MLP modules.
- Type
nn.List[Qwen3MoeMLP]
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for matrix multiplications.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
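The routing behavior described for this block (send each token to a selected subset of experts, then combine their outputs) can be illustrated with a small JAX sketch. It is deliberately simple and is not the EasyDeL implementation: it runs every expert on every token and masks with the routing weights, whereas a real MoE layer dispatches tokens sparsely. The names `sparse_moe`, `top_k`, and `norm_topk_prob`, and the softmax-then-top-k order, are assumptions for this example.

```python
import jax
import jax.numpy as jnp


def sparse_moe(x, gate_kernel, expert_fns, top_k, norm_topk_prob=True):
    # x: (tokens, hidden); gate_kernel: (hidden, num_experts)
    num_experts = gate_kernel.shape[-1]
    router_probs = jax.nn.softmax(x @ gate_kernel, axis=-1)
    # Keep the top_k experts per token: weights and indices are (tokens, top_k).
    top_weights, top_index = jax.lax.top_k(router_probs, top_k)
    if norm_topk_prob:
        # Renormalize so the selected weights sum to 1 for each token.
        top_weights = top_weights / jnp.sum(top_weights, axis=-1, keepdims=True)
    # Scatter into a dense (tokens, num_experts) matrix; unselected experts get 0.
    one_hot = jax.nn.one_hot(top_index, num_experts)
    combine = jnp.sum(one_hot * top_weights[..., None], axis=1)
    # Dense evaluation for clarity: (tokens, num_experts, hidden).
    expert_out = jnp.stack([fn(x) for fn in expert_fns], axis=1)
    return jnp.einsum("te,teh->th", combine, expert_out)
```

If every expert is the identity function and the weights are renormalized, the block returns its input unchanged, which makes a convenient sanity check.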