easydel.modules.qwen2_moe.modeling_qwen2_moe

class easydel.modules.qwen2_moe.modeling_qwen2_moe.Qwen2MoeAttention(*args: Any, **kwargs: Any)

Bases: UnifiedAttention

Qwen2 MoE Attention module with sliding window support.

Inherits from UnifiedAttention with the following Qwen2Moe-specific customizations:

- Sliding window attention
- Custom bias configuration (the Q/K/V projections use qkv_bias; the output projection O has no bias)
- Attention dropout
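A minimal NumPy sketch of the first two customizations listed above. It is not the UnifiedAttention implementation: it builds a causal mask restricted to a sliding window and runs single-head attention whose Q/K/V projections carry a bias while the output projection does not. The window convention (a query at position i sees keys j with i - window < j <= i), the window size, and all weights are assumptions for illustration; attention dropout is omitted.

```python
import numpy as np


def sliding_window_causal_mask(seq_len: int, window: int) -> np.ndarray:
    """Boolean mask: True where query i may attend to key j.

    Assumed convention: causal (j <= i) and limited to the last `window`
    positions (i - j < window). `window` is a hypothetical stand-in for the
    config's sliding-window size.
    """
    i = np.arange(seq_len)[:, None]  # query positions, shape (seq_len, 1)
    j = np.arange(seq_len)[None, :]  # key positions, shape (1, seq_len)
    return (j <= i) & (i - j < window)


def masked_attention(q, k, v, mask):
    """Single-head scaled dot-product attention with a boolean mask."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = np.where(mask, scores, -np.inf)  # blocked positions get -inf
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v


rng = np.random.default_rng(0)
seq_len, hidden = 5, 8
x = rng.normal(size=(seq_len, hidden))
mask = sliding_window_causal_mask(seq_len, window=3)

# Q/K/V projections carry a bias (qkv_bias=True); the output projection has none.
wq, wk, wv, wo = (rng.normal(size=(hidden, hidden)) for _ in range(4))
bq, bk, bv = (rng.normal(size=hidden) for _ in range(3))
q, k, v = x @ wq + bq, x @ wk + bk, x @ wv + bv
out = masked_attention(q, k, v, mask) @ wo
```

Every row of the mask keeps at least its diagonal entry, so no softmax row is entirely masked out.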

class easydel.modules.qwen2_moe.modeling_qwen2_moe.Qwen2MoeDecoderLayer(*args: Any, **kwargs: Any)

Bases: Module

A single decoder layer for the Qwen2 MoE model.

This layer combines self-attention, a sparse MoE block (or a standard MLP), and residual connections with layer normalization.

config

Configuration object for the model.

Type

Qwen2MoeConfig

layer_idx

Index of the current layer.

Type

int

self_attn

Self-attention module.

Type

Qwen2MoeAttention

mlp

MoE block or standard MLP.

Type

Qwen2MoeSparseBlock | Qwen2MoeMLP

input_layernorm

Layer normalization applied before self-attention.

Type

RMSNorm

post_attention_layernorm

Layer normalization applied after self-attention and before the MLP/MoE block.

Type

RMSNorm

dropout_rng_key

Name of the RNG key for dropout.

Type

str

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for matrix multiplications.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs
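The attribute descriptions above imply a pre-norm residual layout: input_layernorm feeds self-attention, post_attention_layernorm feeds the MLP/MoE block, and each sublayer output is added back to its input. The sketch below assumes that layout and the standard RMSNorm formula. attn_fn and mlp_fn are hypothetical stand-ins for self_attn and mlp, eps=1e-6 is an assumed value (the real one would come from the config), and dropout is omitted.

```python
import numpy as np


def rms_norm(x: np.ndarray, weight: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """RMSNorm: divide by the root-mean-square of the last axis, then scale."""
    variance = np.mean(x * x, axis=-1, keepdims=True)
    return x / np.sqrt(variance + eps) * weight


def decoder_layer(x, attn_fn, mlp_fn, ln1_weight, ln2_weight):
    """Pre-norm residual layout: norm -> sublayer -> residual add, twice."""
    # input_layernorm before self-attention, then the first residual add.
    h = x + attn_fn(rms_norm(x, ln1_weight))
    # post_attention_layernorm before the MLP/MoE block, then the second residual add.
    return h + mlp_fn(rms_norm(h, ln2_weight))


hidden = 4
x = np.arange(2 * hidden, dtype=np.float64).reshape(2, hidden)
ones = np.ones(hidden)

# Zero-output sublayers reduce the layer to the identity.
y = decoder_layer(x, lambda t: np.zeros_like(t), lambda t: np.zeros_like(t), ones, ones)
normed = rms_norm(x, ones)
```

Passing sublayers that return zeros is a quick check that both residual connections are wired in: the output then equals the input exactly.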

class easydel.modules.qwen2_moe.modeling_qwen2_moe.Qwen2MoeForCausalLM(*args: Any, **kwargs: Any)

Bases: BaseCausalLMModule[Qwen2MoeModel, Qwen2MoeConfig]

Qwen2 MoE model with a Causal Language Modeling head.
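As a rough illustration of what a causal language modeling head does, the sketch below projects final hidden states to vocabulary logits and scores them with the usual shifted next-token cross-entropy, where position t predicts token t + 1. lm_head is a plain matrix standing in for the real head; whether this class computes the loss itself, or ties the head to the input embeddings, is not stated on this page.

```python
import numpy as np


def causal_lm_loss(hidden_states, lm_head, input_ids):
    """Next-token cross-entropy: the logits at position t predict token t + 1."""
    logits = hidden_states @ lm_head  # (seq_len, vocab_size)
    shift_logits = logits[:-1]        # predictions made at positions 0..T-2
    shift_labels = input_ids[1:]      # targets are the following tokens
    # Numerically stable log-softmax over the vocabulary axis.
    shifted = shift_logits - shift_logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    nll = -log_probs[np.arange(len(shift_labels)), shift_labels]
    return logits, nll.mean()


seq_len, hidden, vocab = 6, 4, 10
hidden_states = np.zeros((seq_len, hidden))
lm_head = np.random.default_rng(0).normal(size=(hidden, vocab))
input_ids = np.array([1, 5, 2, 7, 3, 9])
logits, loss = causal_lm_loss(hidden_states, lm_head, input_ids)
```

With all-zero hidden states every logit is equal, so the loss is exactly log(vocab), a handy sanity check.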

class easydel.modules.qwen2_moe.modeling_qwen2_moe.Qwen2MoeForSequenceClassification(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

Qwen2 MoE model with a sequence classification head.

This class wraps the base Qwen2MoeModel and adds a linear layer on top to perform sequence classification tasks.

config

Configuration object for the model.

Type

Qwen2MoeConfig

model

The base Qwen2 MoE model.

Type

Qwen2MoeModel

score

The sequence classification head (linear layer).

Type

ParallelLinear

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for matrix multiplications.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs
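How per-position scores become one prediction per sequence is not specified here. A common convention, used for example by the Hugging Face Qwen2 sequence-classification models, applies the score layer at every position and keeps the logits at the last non-padding token. The sketch assumes that convention, right padding, and a hypothetical pad_token_id of 0; score is a plain matrix with any bias omitted.

```python
import numpy as np


def last_token_pool(logits, input_ids, pad_token_id):
    """Select each sequence's logits at its last non-padding position.

    Assumes right padding; `pad_token_id` is an assumption about the config.
    """
    non_pad = input_ids != pad_token_id                  # (batch, seq_len)
    last_idx = non_pad.sum(axis=-1) - 1                  # index of the last real token
    return logits[np.arange(logits.shape[0]), last_idx]  # (batch, num_labels)


batch, seq_len, hidden, num_labels = 2, 5, 4, 3
rng = np.random.default_rng(0)
hidden_states = rng.normal(size=(batch, seq_len, hidden))
score = rng.normal(size=(hidden, num_labels))  # stand-in for the `score` layer
input_ids = np.array([[11, 12, 13, 0, 0], [21, 22, 23, 24, 25]])  # 0 = padding
logits = hidden_states @ score                 # (batch, seq_len, num_labels)
pooled = last_token_pool(logits, input_ids, pad_token_id=0)
```

Counting non-padding tokens only locates the last real token under right padding; left-padded batches would need a different index computation.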

get_decoder()

Returns the decoder part of the model’s graph definition.

get_embedding()

Returns the embedding layer of the module.

get_encoder()

Returns the encoder part of the model’s graph definition, if any. Decoder-only models such as this one have no encoder.

get_lm_head()

Returns the language model head of the module, if any. This model has a sequence classification head (score) rather than a language model head.

class easydel.modules.qwen2_moe.modeling_qwen2_moe.Qwen2MoeMLP(*args: Any, **kwargs: Any)

Bases: Module

Multi-Layer Perceptron (MLP) block for the Qwen2 MoE model.

config

Configuration object for the model.

Type

Qwen2MoeConfig

gate_proj

Linear layer for the gating mechanism.

Type

ParallelLinear

down_proj

Linear layer for down-projection.

Type

ParallelLinear

up_proj

Linear layer for up-projection.

Type

ParallelLinear

act_fn

Activation function (SiLU).

Type

callable

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for matrix multiplications.

Type

jax.lax.PrecisionLike
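The attributes list gate_proj, up_proj, down_proj and a SiLU activation but not how they combine. Qwen2-family MLPs conventionally use the gated (SwiGLU-style) form down_proj(act_fn(gate_proj(x)) * up_proj(x)). The NumPy sketch below assumes that form, omits biases, and uses plain matrices in place of the ParallelLinear layers.

```python
import numpy as np


def silu(x):
    """SiLU (swish): x * sigmoid(x)."""
    return x / (1.0 + np.exp(-x))


def gated_mlp(x, w_gate, w_up, w_down):
    """down_proj(act_fn(gate_proj(x)) * up_proj(x)), biases omitted."""
    return (silu(x @ w_gate) * (x @ w_up)) @ w_down


rng = np.random.default_rng(0)
hidden, intermediate = 4, 8
x = rng.normal(size=(3, hidden))
w_gate = rng.normal(size=(hidden, intermediate))  # hidden -> intermediate
w_up = rng.normal(size=(hidden, intermediate))    # hidden -> intermediate
w_down = rng.normal(size=(intermediate, hidden))  # intermediate -> hidden
y = gated_mlp(x, w_gate, w_up, w_down)
```

The gate and up projections both widen to the intermediate size; their elementwise product is what down_proj maps back to the hidden size.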

class easydel.modules.qwen2_moe.modeling_qwen2_moe.Qwen2MoeMLPStack(*args: Any, **kwargs: Any)

Bases: Module

Stacked expert MLP for Qwen2 MoE, built on ParallelMoELinear layers.
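The sketch below illustrates, under stated assumptions, what stacking experts buys: all expert kernels share one array with a leading expert axis, so every expert runs in a single batched einsum rather than a Python loop over modules. It also splits a fused gate/up kernel into gate_proj and up_proj halves and merges it back, mirroring the 'gate_up_proj$' entry in reform_param. The split axis and the gate-then-up ordering are assumptions, since the real splitter lambdas are not shown. Tokens are assumed pre-grouped with an equal count per expert, which real dispatch code does not require.

```python
import numpy as np


def silu(x):
    """SiLU (swish): x * sigmoid(x)."""
    return x / (1.0 + np.exp(-x))


def split_gate_up(fused):
    """Split a fused (E, H, 2I) kernel into gate and up halves on the last axis.

    Axis and ordering (gate first, then up) are assumptions for illustration.
    """
    gate, up = np.split(fused, 2, axis=-1)
    return gate, up


def merge_gate_up(gate, up):
    """Inverse of split_gate_up."""
    return np.concatenate([gate, up], axis=-1)


def stacked_expert_mlp(x, w_gate, w_up, w_down):
    """Apply every expert to its own token group in one batched computation.

    x: (E, T, H) tokens already grouped per expert
    w_gate, w_up: (E, H, I); w_down: (E, I, H)
    """
    gate = np.einsum("eth,ehi->eti", x, w_gate)
    up = np.einsum("eth,ehi->eti", x, w_up)
    return np.einsum("eti,eih->eth", silu(gate) * up, w_down)


E, T, H, I = 3, 2, 4, 6
rng = np.random.default_rng(0)
fused = rng.normal(size=(E, H, 2 * I))
w_gate, w_up = split_gate_up(fused)
w_down = rng.normal(size=(E, I, H))
x = rng.normal(size=(E, T, H))
y = stacked_expert_mlp(x, w_gate, w_up, w_down)
```

Each slice y[e] equals what a standalone gated MLP with expert e's kernels would produce on x[e].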

reform_param: ClassVar

Parameter-reformatting rules. Each key is a parameter-name pattern mapped to a list of splits (each a target parameter name plus a spliter function) and an inverse_spliter; these functions are lambdas whose bodies are not shown in the rendered docs.

- 'down_proj$': splits into down_proj.kernel
- 'gate_up_proj$': splits into gate_proj.kernel and up_proj.kernel

class easydel.modules.qwen2_moe.modeling_qwen2_moe.Qwen2MoeModel(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

The base Qwen2 MoE transformer model.

This class implements the core transformer architecture, including embedding layers, decoder layers, and final normalization.

config

Configuration object for the model.

Type

Qwen2MoeConfig

embed_tokens

Embedding layer for input tokens.

Type

nn.Embed

layers

List of decoder layers.

Type

nn.List[Qwen2MoeDecoderLayer]

norm

Final layer normalization.

Type

RMSNorm

gradient_checkpointing

Gradient checkpointing strategy.

Type

str

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for matrix multiplications.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs
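The forward pass implied by the attributes above (embed_tokens, then each entry of layers in order, then norm) can be sketched as follows. The layer callables are hypothetical stand-ins for Qwen2MoeDecoderLayer, eps=1e-6 is an assumed value, and attention masks, caching, and gradient checkpointing are left out.

```python
import numpy as np


def rms_norm(x, weight, eps=1e-6):
    """RMSNorm over the last axis."""
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps) * weight


def model_forward(input_ids, embedding, layers, norm_weight):
    """embed_tokens -> each decoder layer in order -> final norm."""
    hidden_states = embedding[input_ids]  # embedding lookup: (seq_len, hidden)
    for layer in layers:
        hidden_states = layer(hidden_states)
    return rms_norm(hidden_states, norm_weight)


rng = np.random.default_rng(0)
vocab, hidden = 10, 4
embedding = rng.normal(size=(vocab, hidden))
# Hypothetical stand-ins for decoder layers: any (seq_len, hidden) -> (seq_len, hidden) callable.
layers = [lambda h: h + 1.0, lambda h: 2.0 * h]
input_ids = np.array([3, 1, 4])
out = model_forward(input_ids, embedding, layers, np.ones(hidden))
```

Because the layers are applied in list order, swapping the two toy layers above would change the result; the real model likewise applies its decoder layers strictly in sequence.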

get_decoder()

Returns the decoder part of the model’s graph definition.

get_embedding()

Returns the embedding layer of the module.

get_encoder()

Returns the encoder part of the model’s graph definition, if any. Decoder-only models such as this one have no encoder.

get_lm_head()

Returns the language model head of the module, if any. Base models have no language model head.

class easydel.modules.qwen2_moe.modeling_qwen2_moe.Qwen2MoeSparseBlock(*args: Any, **kwargs: Any)

Bases: BaseMoeModule

Sparse Mixture of Experts (MoE) block for Qwen2 MoE.

This block routes input hidden states to a selected subset of experts and combines their outputs.

config

Configuration object for the model.

Type

Qwen2MoeConfig

gate

Linear layer for the gating network.

Type

ParallelLinear

experts

List of expert MLP modules.

Type

nn.List[Qwen2MoeMLP]

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for matrix multiplications.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs
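A NumPy sketch of top-k routing, assuming the softmax-then-top-k scheme used by the Hugging Face Qwen2-MoE implementation; whether BaseMoeModule routes exactly this way is not confirmed here. The normalize flag stands in for the norm_topk_prob option of that implementation's config, the experts are toy callables, and the per-token Python loop replaces the batched dispatch a real block would use.

```python
import numpy as np


def route_top_k(hidden_states, gate_kernel, top_k, normalize=False):
    """Softmax over router logits, then keep the top_k experts per token."""
    logits = hidden_states @ gate_kernel  # (tokens, num_experts)
    logits = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(logits)
    probs = probs / probs.sum(axis=-1, keepdims=True)
    # Indices of the top_k largest probabilities, highest first.
    top_idx = np.argsort(-probs, axis=-1)[:, :top_k]
    top_w = np.take_along_axis(probs, top_idx, axis=-1)
    if normalize:
        # Renormalize the kept weights so they sum to 1 per token.
        top_w = top_w / top_w.sum(axis=-1, keepdims=True)
    return top_idx, top_w


def sparse_moe(hidden_states, gate_kernel, experts, top_k, normalize=False):
    """Weighted sum of the selected experts' outputs for each token."""
    top_idx, top_w = route_top_k(hidden_states, gate_kernel, top_k, normalize)
    out = np.zeros_like(hidden_states)
    for t in range(hidden_states.shape[0]):
        for slot in range(top_k):
            expert = experts[top_idx[t, slot]]
            out[t] += top_w[t, slot] * expert(hidden_states[t])
    return out


rng = np.random.default_rng(0)
tokens, hidden, num_experts, top_k = 3, 4, 4, 2
x = rng.normal(size=(tokens, hidden))
gate_kernel = rng.normal(size=(hidden, num_experts))
# Toy experts: expert e scales its input by (e + 1).
experts = [lambda h, s=e + 1.0: s * h for e in range(num_experts)]
top_idx, top_w = route_top_k(x, gate_kernel, top_k, normalize=True)
y = sparse_moe(x, gate_kernel, experts, top_k, normalize=True)
```

Only top_k of the num_experts experts run for each token, which is what keeps the per-token cost of a sparse block below that of running every expert.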