easydel.modules.qwen2_moe.modeling_qwen2_moe_flax#
- class easydel.modules.qwen2_moe.modeling_qwen2_moe_flax.Qwen2MoeAttention(*args: Any, **kwargs: Any)#
Bases: AttentionModule
Qwen2 MoE Attention module.
- config#
Configuration object for the model.
- Type
- hidden_size#
Dimensionality of the hidden states.
- Type
int
- head_dim#
Dimensionality of each attention head.
- Type
int
- num_key_value_groups#
Number of groups for key/value heads (for GQA).
- Type
int
- q_proj#
Linear layer for query projection.
- Type
- k_proj#
Linear layer for key projection.
- Type
- v_proj#
Linear layer for value projection.
- Type
- o_proj#
Linear layer for output projection.
- Type
- attention_performer#
Module for performing attention computation.
- resid_dropout#
Dropout layer for residual connections.
- Type
nn.Dropout
- rotary#
Rotary positional embedding module.
- Type
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for matrix multiplications.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
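The role of `num_key_value_groups` in grouped-query attention (GQA) can be sketched as follows. This is a simplified illustration, not the EasyDeL implementation: it assumes `q`, `k`, and `v` have already been projected and reshaped to `(batch, seq, heads, head_dim)`, and it leaves out rotary embeddings, dropout, caching, and the output projection. The helper names `repeat_kv` and `gqa_attention` are hypothetical.

```python
import jax
import jax.numpy as jnp


def repeat_kv(x, num_key_value_groups):
    """Repeat each key/value head so it lines up with its group of query heads."""
    # x: (batch, seq, num_kv_heads, head_dim)
    return jnp.repeat(x, num_key_value_groups, axis=2)


def gqa_attention(q, k, v):
    """Causal grouped-query attention on already-projected q, k, v (hypothetical helper)."""
    seq_len, num_heads, head_dim = q.shape[1], q.shape[2], q.shape[3]
    # Number of query heads that share one key/value head.
    num_key_value_groups = num_heads // k.shape[2]
    k = repeat_kv(k, num_key_value_groups)
    v = repeat_kv(v, num_key_value_groups)
    # Scaled dot-product scores: (batch, heads, q_len, kv_len).
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(head_dim)
    # Causal mask: position i may only attend to positions <= i.
    causal = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    scores = jnp.where(causal, scores, jnp.finfo(scores.dtype).min)
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", weights, v)
```

With 4 query heads and 2 key/value heads, `num_key_value_groups` is 2, so query heads 0 and 1 read from key/value head 0 and query heads 2 and 3 read from key/value head 1.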
- class easydel.modules.qwen2_moe.modeling_qwen2_moe_flax.Qwen2MoeDecoderLayer(*args: Any, **kwargs: Any)#
Bases: Module
A single decoder layer for the Qwen2 MoE model.
This layer combines self-attention, a sparse MoE block (or a standard MLP), and residual connections with layer normalization.
- config#
Configuration object for the model.
- Type
- layer_idx#
Index of the current layer.
- Type
int
- self_attn#
Self-attention module.
- Type
- mlp#
MoE block or standard MLP.
- post_attention_layernorm#
Layer normalization applied after self-attention and before the MLP/MoE block.
- Type
- dropout_rng_key#
Name of the RNG key for dropout.
- Type
str
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for matrix multiplications.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
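The residual structure described for this layer can be sketched roughly as below. Several details are assumptions rather than facts taken from this page: the normalization is written as RMSNorm (as used across the Qwen2 family), a separate input normalization is applied before attention, and `attn_fn` and `mlp_fn` are hypothetical stand-ins for `self_attn` and `mlp`.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, weight, eps=1e-6):
    """RMS normalization over the last axis (assumed normalization type)."""
    variance = jnp.mean(x * x, axis=-1, keepdims=True)
    return x * jax.lax.rsqrt(variance + eps) * weight


def decoder_layer(x, attn_fn, mlp_fn, input_norm_weight, post_attn_norm_weight):
    """Pre-norm decoder layer: two residual branches, attention then MLP/MoE."""
    # Residual branch 1: normalize, attend, add back.
    h = x + attn_fn(rms_norm(x, input_norm_weight))
    # Residual branch 2: normalize after attention, run the MLP/MoE block, add back.
    return h + mlp_fn(rms_norm(h, post_attn_norm_weight))
```

If both sub-blocks return zeros, the layer reduces to the identity, which is a quick way to sanity-check the residual wiring.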
- class easydel.modules.qwen2_moe.modeling_qwen2_moe_flax.Qwen2MoeForCausalLM(*args: Any, **kwargs: Any)#
Bases: EasyDeLBaseModule
Qwen2 MoE model with a Causal Language Modeling (CLM) head.
This class wraps the base Qwen2MoeModel and adds a linear layer (the language model head) that maps hidden states to next-token logits.
- config#
Configuration object for the model.
- Type
- model#
The base Qwen2 MoE model.
- Type
- lm_head#
The language model head (linear layer).
- Type
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for matrix multiplications.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
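What the language model head contributes can be illustrated with a minimal sketch. It assumes the head is a bias-free linear projection with a kernel of shape `(hidden_size, vocab_size)`; the function names are hypothetical and greedy decoding is used only to keep the example short.

```python
import jax.numpy as jnp


def next_token_logits(hidden_states, lm_head_kernel):
    """Project hidden states to vocabulary logits and keep the last position."""
    # hidden_states: (batch, seq, hidden); lm_head_kernel: (hidden, vocab)
    logits = hidden_states @ lm_head_kernel
    return logits[:, -1, :]


def greedy_next_token(hidden_states, lm_head_kernel):
    """Pick the highest-scoring token id for each sequence in the batch."""
    return jnp.argmax(next_token_logits(hidden_states, lm_head_kernel), axis=-1)
```

Only the final position is needed to choose the next token; during training the logits at every position are compared against the input shifted by one token.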
- class easydel.modules.qwen2_moe.modeling_qwen2_moe_flax.Qwen2MoeForSequenceClassification(*args: Any, **kwargs: Any)#
Bases: EasyDeLBaseModule
Qwen2 MoE model with a sequence classification head.
This class wraps the base Qwen2MoeModel and adds a linear layer on top to perform sequence classification tasks.
- config#
Configuration object for the model.
- Type
- model#
The base Qwen2 MoE model.
- Type
- score#
The sequence classification head (linear layer).
- Type
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for matrix multiplications.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
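One common way to turn per-token outputs into a single prediction per sequence is to read the `score` head at the last non-padding token. The sketch below assumes that pooling rule, right-padded inputs, and a bias-free head; none of these is confirmed by this page, and the function names are hypothetical.

```python
import jax.numpy as jnp


def pool_last_token(logits, input_ids, pad_token_id):
    """Select the logits at the last non-padding position of each sequence."""
    # Assumes right padding, so the count of real tokens gives the last index.
    last_index = jnp.sum(input_ids != pad_token_id, axis=-1) - 1
    return logits[jnp.arange(logits.shape[0]), last_index]


def classify(hidden_states, score_kernel, input_ids, pad_token_id):
    """Apply the classification head per token, then pool to one row per sequence."""
    # hidden_states: (batch, seq, hidden); score_kernel: (hidden, num_labels)
    logits = hidden_states @ score_kernel
    return pool_last_token(logits, input_ids, pad_token_id)
```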
- class easydel.modules.qwen2_moe.modeling_qwen2_moe_flax.Qwen2MoeMLP(*args: Any, **kwargs: Any)#
Bases: Module
Multi-Layer Perceptron (MLP) block for the Qwen2 MoE model.
- config#
Configuration object for the model.
- Type
- gate_proj#
Linear layer for the gating mechanism.
- Type
- down_proj#
Linear layer for down-projection.
- Type
- up_proj#
Linear layer for up-projection.
- Type
- act_fn#
Activation function (SiLU).
- Type
callable
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for matrix multiplications.
- Type
jax.lax.PrecisionLike
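The three projections and the SiLU activation listed above are typically combined as a gated (SwiGLU-style) feed-forward block. The sketch below assumes that arrangement and bias-free kernels; `swiglu_mlp` is a hypothetical name and the kernels are plain arrays rather than layer objects.

```python
import jax
import jax.numpy as jnp


def swiglu_mlp(x, gate_kernel, up_kernel, down_kernel):
    """Gated MLP: down_proj(act_fn(gate_proj(x)) * up_proj(x))."""
    # gate_kernel, up_kernel: (hidden, intermediate); down_kernel: (intermediate, hidden)
    gate = jax.nn.silu(x @ gate_kernel)  # gating signal in the intermediate space
    up = x @ up_kernel                   # candidate features in the intermediate space
    return (gate * up) @ down_kernel     # gate the features, project back to hidden size
```

The gate decides, element by element, how much of each up-projected feature passes through to the down-projection.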
- class easydel.modules.qwen2_moe.modeling_qwen2_moe_flax.Qwen2MoeModel(*args: Any, **kwargs: Any)#
Bases: EasyDeLBaseModule
The base Qwen2 MoE transformer model.
This class implements the core transformer architecture, including embedding layers, decoder layers, and final normalization.
- config#
Configuration object for the model.
- Type
- embed_tokens#
Embedding layer for input tokens.
- Type
nn.Embed
- layers#
List of decoder layers.
- Type
nn.List[Qwen2MoeDecoderLayer]
- gradient_checkpointing#
Gradient checkpointing strategy.
- Type
str
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for matrix multiplications.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
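The overall flow (embedding lookup, a stack of layers, final normalization) and the effect of gradient checkpointing can be sketched with a toy stack. Everything here is illustrative: the `params` layout, the `tanh` stand-in layer, and the boolean `use_checkpoint` flag are assumptions, whereas the real model takes a string-valued checkpointing strategy from its configuration.

```python
import jax
import jax.numpy as jnp


def run_stack(params, input_ids, use_checkpoint=False):
    """Embed tokens, apply each layer in order, then apply a final RMS norm."""
    h = params["embed"][input_ids]  # (batch, seq, hidden)

    def layer(h, w):
        # Toy residual layer standing in for a full decoder layer.
        return h + jnp.tanh(h @ w)

    # jax.checkpoint recomputes activations in the backward pass instead of
    # storing them, trading compute for memory without changing the result.
    step = jax.checkpoint(layer) if use_checkpoint else layer
    for w in params["layers"]:
        h = step(h, w)

    variance = jnp.mean(h * h, axis=-1, keepdims=True)
    return h * jax.lax.rsqrt(variance + 1e-6) * params["norm"]
```

Outputs and gradients match with checkpointing on or off; only peak memory use and backward-pass compute differ.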
- class easydel.modules.qwen2_moe.modeling_qwen2_moe_flax.Qwen2MoeSparseMoeBlock(*args: Any, **kwargs: Any)#
Bases: Module
Sparse Mixture of Experts (MoE) block for Qwen2 MoE.
This block routes input hidden states to a selected subset of experts and combines their outputs.
- config#
Configuration object for the model.
- Type
- gate#
Linear layer for the gating network.
- Type
- experts#
List of expert MLP modules.
- Type
nn.List[Qwen2MoeMLP]
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for matrix multiplications.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
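The routing described above can be sketched as top-k gating over a softmax of the router logits. This is a dense, readable illustration rather than the library's implementation: it evaluates every expert for every token and masks the results, it omits any shared expert, and the names `moe_block`, `expert_fns`, `top_k`, and `norm_topk_prob` are assumptions made for the example.

```python
import jax
import jax.numpy as jnp


def moe_block(x, gate_kernel, expert_fns, top_k, norm_topk_prob=False):
    """Route each token to its top-k experts and mix their outputs."""
    # x: (tokens, hidden); gate_kernel: (hidden, num_experts)
    num_experts = gate_kernel.shape[-1]
    probs = jax.nn.softmax(x @ gate_kernel, axis=-1)
    top_vals, top_idx = jax.lax.top_k(probs, top_k)  # each (tokens, top_k)
    if norm_topk_prob:
        # Optionally renormalize so the selected weights sum to 1 per token.
        top_vals = top_vals / jnp.sum(top_vals, axis=-1, keepdims=True)
    # Scatter the selected weights into a dense (tokens, num_experts) matrix;
    # experts that were not selected get weight 0.
    one_hot = jax.nn.one_hot(top_idx, num_experts)
    weights = jnp.sum(one_hot * top_vals[..., None], axis=1)
    # Dense evaluation of all experts: (tokens, num_experts, hidden).
    expert_out = jnp.stack([fn(x) for fn in expert_fns], axis=1)
    return jnp.einsum("te,teh->th", weights, expert_out)
```

A production implementation would dispatch tokens only to their selected experts instead of computing all of them, but the weighted combination it produces is the same.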