easydel.modules.qwen2_moe.modeling_qwen2_moe_flax

class easydel.modules.qwen2_moe.modeling_qwen2_moe_flax.Qwen2MoeAttention(*args: Any, **kwargs: Any)

Bases: AttentionModule

Qwen2 MoE attention module: multi-head self-attention with rotary positional embeddings and grouped-query attention (GQA).

config

Configuration object for the model.

Type

Qwen2MoeConfig

hidden_size

Dimensionality of the hidden states.

Type

int

head_dim

Dimensionality of each attention head.

Type

int

num_key_value_groups

Number of query heads that share each key/value head (num_attention_heads // num_key_value_heads), used for grouped-query attention (GQA).

Type

int

q_proj

Linear layer for query projection.

Type

ParallelLinear

k_proj

Linear layer for key projection.

Type

ParallelLinear

v_proj

Linear layer for value projection.

Type

ParallelLinear

o_proj

Linear layer for output projection.

Type

ParallelLinear

attention_performer

Module for performing attention computation.

Type

FlexibleAttentionModule

resid_dropout

Dropout layer for residual connections.

Type

nn.Dropout

rotary

Rotary positional embedding module.

Type

RotaryEmbedding

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for matrix multiplications.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs
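The attention path described above (rotary embeddings on queries and keys, then key/value heads shared across query groups) can be sketched in NumPy. This is a minimal sketch, not the EasyDeL implementation: it assumes a single unbatched sequence, the "rotate-half" RoPE convention with base 10000 (the real base comes from the model config), causal masking, and no dropout or KV cache. The helper names `apply_rope`, `repeat_kv`, and `gqa_causal_attention` are hypothetical.

```python
import numpy as np


def apply_rope(x, positions, base=10000.0):
    """Rotary position embedding, "rotate-half" convention.

    x: [seq, heads, head_dim] (head_dim even); positions: [seq].
    """
    head_dim = x.shape[-1]
    inv_freq = 1.0 / (base ** (np.arange(0, head_dim, 2) / head_dim))  # [head_dim // 2]
    angles = positions[:, None] * inv_freq[None, :]                    # [seq, head_dim // 2]
    emb = np.concatenate([angles, angles], axis=-1)[:, None, :]        # [seq, 1, head_dim]
    x1, x2 = np.split(x, 2, axis=-1)
    rotated = np.concatenate([-x2, x1], axis=-1)
    return x * np.cos(emb) + rotated * np.sin(emb)


def repeat_kv(kv, num_key_value_groups):
    """Repeat each key/value head so every query head has a matching one.

    [seq, kv_heads, d] -> [seq, kv_heads * num_key_value_groups, d];
    query head h uses key/value head h // num_key_value_groups.
    """
    return np.repeat(kv, num_key_value_groups, axis=1)


def gqa_causal_attention(q, k, v):
    """Causal scaled dot-product attention with RoPE and GQA.

    q: [seq, num_heads, d]; k, v: [seq, num_kv_heads, d].
    """
    seq, num_heads, head_dim = q.shape
    num_key_value_groups = num_heads // k.shape[1]
    positions = np.arange(seq, dtype=np.float64)
    q, k = apply_rope(q, positions), apply_rope(k, positions)
    k = repeat_kv(k, num_key_value_groups)
    v = repeat_kv(v, num_key_value_groups)
    scores = np.einsum("qhd,khd->hqk", q, k) / np.sqrt(head_dim)  # [heads, q, k]
    causal = np.tril(np.ones((seq, seq), dtype=bool))
    scores = np.where(causal[None], scores, -np.inf)              # hide future tokens
    scores = scores - scores.max(axis=-1, keepdims=True)          # stable softmax
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return np.einsum("hqk,khd->qhd", weights, v)                  # [seq, num_heads, d]
```

Repeating key/value heads is the simplest way to express GQA; optimized kernels usually avoid materializing the repeated tensors.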

class easydel.modules.qwen2_moe.modeling_qwen2_moe_flax.Qwen2MoeDecoderLayer(*args: Any, **kwargs: Any)

Bases: Module

A single decoder layer for the Qwen2 MoE model.

This layer combines self-attention with a feed-forward block, which is either a sparse MoE block or a standard dense MLP depending on the configuration. Each sub-block is preceded by RMS normalization and wrapped in a residual connection.

config

Configuration object for the model.

Type

Qwen2MoeConfig

layer_idx

Index of this layer within the decoder stack.

Type

int

self_attn

Self-attention module.

Type

Qwen2MoeAttention

mlp

Feed-forward block: a sparse MoE block or a standard dense MLP.

Type

Qwen2MoeSparseMoeBlock | Qwen2MoeMLP

input_layernorm

Layer normalization applied before self-attention.

Type

RMSNorm

post_attention_layernorm

Layer normalization applied after self-attention and before the MLP/MoE block.

Type

RMSNorm

dropout_rng_key

Name of the RNG key for dropout.

Type

str

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for matrix multiplications.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs
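The data flow through one decoder layer can be sketched as a pre-norm residual block, with `input_layernorm` feeding `self_attn` and `post_attention_layernorm` feeding `mlp`. This is a NumPy sketch under stated assumptions: the sub-blocks are passed in as plain callables, and `uses_sparse_moe` is a hypothetical helper that encodes the layer-selection rule of the Hugging Face Qwen2-MoE reference implementation (config fields `num_experts`, `decoder_sparse_step`, `mlp_only_layers`). Whether EasyDeL uses exactly the same rule is an assumption.

```python
import numpy as np


def rms_norm(x, weight, eps=1e-6):
    """RMSNorm over the last axis: x / sqrt(mean(x**2) + eps) * weight."""
    variance = np.mean(x * x, axis=-1, keepdims=True)
    return x / np.sqrt(variance + eps) * weight


def decoder_layer(x, attn_fn, mlp_fn, ln1_weight, ln2_weight):
    """Pre-norm layer: each sub-block sees normalized input and adds back a residual."""
    h = x + attn_fn(rms_norm(x, ln1_weight))    # input_layernorm -> self_attn
    return h + mlp_fn(rms_norm(h, ln2_weight))  # post_attention_layernorm -> mlp


def uses_sparse_moe(layer_idx, num_experts, decoder_sparse_step, mlp_only_layers):
    """Assumed rule (as in Hugging Face Qwen2-MoE).

    True -> Qwen2MoeSparseMoeBlock, False -> dense Qwen2MoeMLP.
    """
    return (
        layer_idx not in mlp_only_layers
        and num_experts > 0
        and (layer_idx + 1) % decoder_sparse_step == 0
    )
```

With `decoder_sparse_step=1` and an empty `mlp_only_layers`, every layer gets the sparse MoE block; larger steps interleave dense and sparse layers.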

class easydel.modules.qwen2_moe.modeling_qwen2_moe_flax.Qwen2MoeForCausalLM(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

Qwen2 MoE model with a Causal Language Modeling (CLM) head.

This class wraps the base Qwen2MoeModel and adds a linear layer (the language model head) that projects the final hidden states to next-token logits over the vocabulary.

config

Configuration object for the model.

Type

Qwen2MoeConfig

model

The base Qwen2 MoE model.

Type

Qwen2MoeModel

lm_head

The language model head (linear layer).

Type

ParallelLinear

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for matrix multiplications.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs
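How the head's output is typically used can be sketched in NumPy: project hidden states to logits, train with a shifted next-token cross-entropy, and decode greedily from the last position. The helper names are hypothetical, the head weight is assumed to be a dense `[hidden, vocab]` matrix (sharding done by `ParallelLinear` is not modeled), and the loss shown is the standard causal LM objective rather than a documented EasyDeL function.

```python
import numpy as np


def lm_head_logits(hidden_states, lm_head_weight):
    """Project final hidden states [seq, hidden] to vocabulary logits [seq, vocab]."""
    return hidden_states @ lm_head_weight


def causal_lm_loss(logits, input_ids):
    """Mean next-token cross-entropy: logits at position t are scored against token t + 1."""
    shift_logits = logits[:-1]
    shift_labels = input_ids[1:]
    shifted = shift_logits - shift_logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))  # log-softmax
    return -log_probs[np.arange(len(shift_labels)), shift_labels].mean()


def greedy_next_token(logits):
    """Greedy decoding step: the most likely token after the last position."""
    return int(np.argmax(logits[-1]))
```

For uniform logits the loss equals log(vocab_size), which is a quick sanity check for an untrained head.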

class easydel.modules.qwen2_moe.modeling_qwen2_moe_flax.Qwen2MoeForSequenceClassification(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

Qwen2 MoE model with a sequence classification head.

This class wraps the base Qwen2MoeModel and adds a linear layer on top of the final hidden states that produces per-label logits for sequence classification.

config

Configuration object for the model.

Type

Qwen2MoeConfig

model

The base Qwen2 MoE model.

Type

Qwen2MoeModel

score

The sequence classification head (linear layer).

Type

ParallelLinear

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for matrix multiplications.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs
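Decoder-only classifiers need one vector per sequence, and a common choice (used by the Hugging Face Qwen2-MoE classifier) is the output at the last non-padding token. The NumPy sketch below assumes EasyDeL pools the same way; `classify_sequences` is a hypothetical helper, `pad_token_id` is assumed to come from the config, and the `score` weight is treated as a dense `[hidden, num_labels]` matrix.

```python
import numpy as np


def classify_sequences(hidden_states, score_weight, input_ids, pad_token_id):
    """Score every position, then keep the logits of each sequence's last non-padding token.

    hidden_states: [batch, seq, hidden]; score_weight: [hidden, num_labels];
    input_ids: [batch, seq]. Works for both left and right padding.
    """
    logits = hidden_states @ score_weight               # [batch, seq, num_labels]
    non_pad = input_ids != pad_token_id                  # [batch, seq]
    positions = np.arange(input_ids.shape[1])
    last_idx = (positions * non_pad).argmax(axis=-1)     # last real token per row
    return logits[np.arange(logits.shape[0]), last_idx]  # [batch, num_labels]
```

Multiplying positions by the padding mask and taking the argmax finds the rightmost real token without assuming a particular padding side.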

class easydel.modules.qwen2_moe.modeling_qwen2_moe_flax.Qwen2MoeMLP(*args: Any, **kwargs: Any)

Bases: Module

Gated Multi-Layer Perceptron (MLP) block for the Qwen2 MoE model. It serves both as the dense feed-forward block of a decoder layer and as each individual expert inside Qwen2MoeSparseMoeBlock.

config

Configuration object for the model.

Type

Qwen2MoeConfig

gate_proj

Linear layer whose activated output gates the up-projection element-wise (not to be confused with the MoE router gate).

Type

ParallelLinear

down_proj

Linear layer that projects the gated intermediate activations back to the hidden size.

Type

ParallelLinear

up_proj

Linear layer for up-projection.

Type

ParallelLinear

act_fn

Activation function (SiLU).

Type

callable

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for matrix multiplications.

Type

jax.lax.PrecisionLike
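The three projections combine as `down_proj(act_fn(gate_proj(x)) * up_proj(x))`, the gated (SwiGLU-style) feed-forward used by Qwen2 models. The NumPy sketch below assumes weights stored as dense `[in, out]` matrices without biases; `gated_mlp` is a hypothetical name, not an EasyDeL function.

```python
import numpy as np


def silu(x):
    """SiLU (swish): x * sigmoid(x)."""
    return x / (1.0 + np.exp(-x))


def gated_mlp(x, w_gate, w_up, w_down):
    """down_proj(act_fn(gate_proj(x)) * up_proj(x)).

    x: [tokens, hidden]; w_gate, w_up: [hidden, intermediate]; w_down: [intermediate, hidden].
    """
    gate = silu(x @ w_gate)      # gate_proj followed by act_fn
    up = x @ w_up                # up_proj
    return (gate * up) @ w_down  # element-wise gating, then down_proj back to hidden
```

Because the same function serves as both the dense MLP and each MoE expert, experts differ only in their weights (and possibly their intermediate size).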

class easydel.modules.qwen2_moe.modeling_qwen2_moe_flax.Qwen2MoeModel(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

The base Qwen2 MoE transformer model.

This class implements the core transformer architecture: the token embedding layer, the stack of decoder layers, and the final normalization.

config

Configuration object for the model.

Type

Qwen2MoeConfig

embed_tokens

Embedding layer for input tokens.

Type

nn.Embed

layers

List of decoder layers.

Type

nn.List[Qwen2MoeDecoderLayer]

norm

Final layer normalization.

Type

RMSNorm

gradient_checkpointing

Gradient checkpointing strategy.

Type

str

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for matrix multiplications.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs
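At a high level the base model composes its attributes in order: `embed_tokens`, each entry of `layers`, then `norm`. The NumPy sketch below shows only that composition, with decoder layers passed in as plain callables and a hypothetical `model_forward` helper. Attention masks, caching, and the `gradient_checkpointing` strategy are not modeled; in JAX, per-layer rematerialization is usually done with `jax.checkpoint`, but how EasyDeL maps the strategy string to a policy is not described here.

```python
import numpy as np


def rms_norm(x, weight, eps=1e-6):
    """RMSNorm over the last axis."""
    variance = np.mean(x * x, axis=-1, keepdims=True)
    return x / np.sqrt(variance + eps) * weight


def model_forward(input_ids, embedding, layers, norm_weight):
    """embed_tokens -> decoder layers in order -> final norm.

    input_ids: [seq] ints; embedding: [vocab, hidden];
    layers: callables mapping [seq, hidden] -> [seq, hidden].
    """
    hidden = embedding[input_ids]  # embedding lookup
    for layer in layers:
        hidden = layer(hidden)     # each decoder layer preserves the shape
    return rms_norm(hidden, norm_weight)
```

The output of this function is what the causal LM and classification heads consume.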

class easydel.modules.qwen2_moe.modeling_qwen2_moe_flax.Qwen2MoeSparseMoeBlock(*args: Any, **kwargs: Any)

Bases: Module

Sparse Mixture of Experts (MoE) block for Qwen2 MoE.

This block routes each token's hidden state to a subset of experts chosen by the gating network and combines the expert outputs, weighted by the routing scores.

config

Configuration object for the model.

Type

Qwen2MoeConfig

gate

Linear router that produces one score (logit) per expert for each token.

Type

ParallelLinear

experts

List of expert MLP modules.

Type

nn.List[Qwen2MoeMLP]

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for matrix multiplications.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs
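Top-k routing can be sketched in NumPy: softmax over the router logits, keep the `top_k` experts per token (optionally renormalizing their weights), and let each expert process only the tokens routed to it. This is a readable sketch, not EasyDeL's dispatch code: the per-expert loop trades speed for clarity, and `top_k` / `norm_topk_prob` are named after the Hugging Face Qwen2-MoE config fields as an assumption. The Hugging Face reference implementation also adds a sigmoid-gated shared expert to the output; that part is not listed among the attributes above and is omitted here.

```python
import numpy as np


def route_tokens(x, gate_weight, top_k, norm_topk_prob=False):
    """Softmax router + top-k selection.

    x: [tokens, hidden]; gate_weight: [hidden, num_experts].
    Returns expert indices and weights, both [tokens, top_k], highest weight first.
    """
    logits = x @ gate_weight                               # router logits
    logits = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(logits)
    probs = probs / probs.sum(axis=-1, keepdims=True)      # softmax over all experts
    top_idx = np.argsort(-probs, axis=-1)[:, :top_k]
    top_w = np.take_along_axis(probs, top_idx, axis=-1)
    if norm_topk_prob:
        top_w = top_w / top_w.sum(axis=-1, keepdims=True)  # renormalize selected weights
    return top_idx, top_w


def sparse_moe(x, gate_weight, experts, top_k, norm_topk_prob=False):
    """Each expert sees only its routed tokens; outputs are weight-summed per token."""
    top_idx, top_w = route_tokens(x, gate_weight, top_k, norm_topk_prob)
    out = np.zeros_like(x)
    for e, expert in enumerate(experts):
        token_ids, slot = np.nonzero(top_idx == e)  # a token selects each expert at most once
        if token_ids.size == 0:
            continue
        out[token_ids] += top_w[token_ids, slot][:, None] * expert(x[token_ids])
    return out
```

Setting `top_k` equal to the number of experts recovers a dense softmax mixture, which is a convenient check that the sparse dispatch is wired correctly.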