easydel.modules.qwen3_moe.__init__#
- class easydel.modules.qwen3_moe.__init__.Qwen3MoeConfig(vocab_size=151936, hidden_size=2048, intermediate_size=6144, num_hidden_layers=24, num_attention_heads=32, num_key_value_heads=4, hidden_act='silu', max_position_embeddings=32768, initializer_range=0.02, rms_norm_eps=1e-06, use_cache=True, tie_word_embeddings=False, rope_theta=10000.0, rope_scaling=None, attention_bias=False, use_sliding_window=False, sliding_window=4096, max_window_layers=28, attention_dropout=0.0, decoder_sparse_step=1, moe_intermediate_size=768, num_experts_per_tok=8, num_experts=128, norm_topk_prob=False, output_router_logits=False, router_aux_loss_coef=0.001, mlp_only_layers=None, **kwargs)
Bases: EasyDeLBaseConfig
- get_partition_rules(*args, **kwargs)
Get the partition rules for the model.
- Parameters
fully_sharded_data_parallel (bool, optional, defaults to True) – Whether to use fully sharded data parallelism.
- Returns
The partition rules.
- Return type
tp.Tuple[tp.Tuple[str, PartitionSpec]]
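The return type above is a tuple of (pattern, spec) pairs. The following is a minimal sketch of how such rules are commonly resolved against parameter paths, assuming first-match-wins regex search. Plain tuples stand in for `PartitionSpec`, and the patterns, axis names, and the `resolve_spec` helper are hypothetical illustrations, not the rules EasyDeL actually returns.

```python
import re

# Hypothetical rules: (regex over the parameter path, per-axis sharding).
# Plain tuples stand in for jax.sharding.PartitionSpec in this sketch.
EXAMPLE_RULES = (
    (r"embed_tokens/embedding", ("tp", "fsdp")),
    (r"self_attn/(q_proj|k_proj|v_proj)/kernel", ("fsdp", "tp")),
    (r"self_attn/o_proj/kernel", ("tp", "fsdp")),
    (r".*", (None,)),  # fallback rule: replicate everything else
)


def resolve_spec(param_path, rules=EXAMPLE_RULES):
    """Return the spec of the first rule whose pattern matches the path."""
    for pattern, spec in rules:
        if re.search(pattern, param_path):
            return spec
    raise ValueError(f"no rule matches {param_path!r}")
```

Because the first matching rule wins, specific patterns should precede the catch-all fallback.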
- model_type: str = 'qwen3_moe'#
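To make the MoE-related fields more concrete, here is a small sketch of how `decoder_sparse_step`, `mlp_only_layers`, `num_experts_per_tok`, and `norm_topk_prob` are typically interpreted. The layer-selection rule and the softmax-then-top-k routing order are assumptions modeled on the common Qwen MoE convention and are not confirmed by this page; `MoeLayout`, `is_moe_layer`, and `route_token` are hypothetical names.

```python
import math
from dataclasses import dataclass


@dataclass
class MoeLayout:
    """Subset of the Qwen3MoeConfig defaults listed in the signature above."""
    num_hidden_layers: int = 24
    decoder_sparse_step: int = 1
    num_experts: int = 128
    num_experts_per_tok: int = 8
    mlp_only_layers: tuple = ()  # the config default is None, treated as empty


def is_moe_layer(cfg, layer_idx):
    # Assumed rule: a layer is sparse unless listed in mlp_only_layers,
    # and only every decoder_sparse_step-th layer uses experts.
    if layer_idx in cfg.mlp_only_layers:
        return False
    return cfg.num_experts > 0 and (layer_idx + 1) % cfg.decoder_sparse_step == 0


def route_token(router_logits, num_experts_per_tok, norm_topk_prob):
    """Pick the top-k experts for one token and return (indices, weights)."""
    peak = max(router_logits)
    exps = [math.exp(x - peak) for x in router_logits]  # stable softmax
    total = sum(exps)
    probs = [e / total for e in exps]
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    top = order[:num_experts_per_tok]
    weights = [probs[i] for i in top]
    if norm_topk_prob:
        # Renormalize so the selected experts' weights sum to 1.
        kept = sum(weights)
        weights = [w / kept for w in weights]
    return top, weights
```

With the defaults (`decoder_sparse_step=1`, `mlp_only_layers=None`), every layer is sparse under this rule, and with `norm_topk_prob=False` the selected weights keep their raw softmax mass.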
- class easydel.modules.qwen3_moe.__init__.Qwen3MoeForCausalLM(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule
Qwen3Moe model with a Causal Language Modeling head.
This model consists of the base Qwen3Moe transformer (Qwen3MoeModel) followed by a linear layer (lm_head) that projects the transformer’s output hidden states to the vocabulary size, producing logits for next token prediction. Optionally, the input token embeddings can be tied to the output projection layer.
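The projection described above can be sketched in plain Python for a single position. This is an illustration under stated assumptions, not the module's implementation: `lm_logits` and `greedy_next_token` are hypothetical helpers, the kernel is assumed to have shape `[hidden_size][vocab_size]`, and tying is assumed to reuse the embedding table transposed.

```python
def lm_logits(hidden, embedding, lm_head_kernel=None, tie_word_embeddings=False):
    """Project one hidden state to vocabulary logits.

    hidden: [hidden_size]
    embedding: [vocab_size][hidden_size]
    lm_head_kernel: [hidden_size][vocab_size], unused when weights are tied
    """
    if tie_word_embeddings:
        # Tied weights: the logit for token v is <hidden, embedding[v]>.
        return [sum(h * e for h, e in zip(hidden, row)) for row in embedding]
    vocab_size = len(lm_head_kernel[0])
    return [
        sum(hidden[i] * lm_head_kernel[i][v] for i in range(len(hidden)))
        for v in range(vocab_size)
    ]


def greedy_next_token(logits):
    """Index of the highest logit, i.e. the greedy next-token choice."""
    return max(range(len(logits)), key=lambda v: logits[v])
```

Tying saves one `vocab_size x hidden_size` matrix; the config default shown in the signature is `tie_word_embeddings=False`.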
- config#
Configuration object for the model.
- Type
Qwen3MoeConfig
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- model#
The core Qwen3Moe transformer model.
- Type
Qwen3MoeModel
- lm_head#
The linear layer for projecting hidden states to vocabulary logits.
- Type
- class easydel.modules.qwen3_moe.__init__.Qwen3MoeForSequenceClassification(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule
Qwen3Moe model with a Sequence Classification head.
This model consists of the base Qwen3Moe transformer (Qwen3MoeModel) followed by a linear layer (score) that projects the transformer’s output hidden states (typically the hidden state of the last token or a pooled representation) to the number of classes for classification.
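As a rough illustration of the "last token" pooling mentioned above, the sketch below selects the hidden state of the last non-padding token and applies a linear `score` projection. It assumes right padding and a kernel of shape `[hidden_size][num_labels]`; `last_token_index` and `classify` are hypothetical helpers, and the module's actual pooling may differ.

```python
def last_token_index(input_ids, pad_token_id):
    """Index of the last non-padding token, assuming right padding."""
    for i in range(len(input_ids) - 1, -1, -1):
        if input_ids[i] != pad_token_id:
            return i
    return len(input_ids) - 1  # all padding: fall back to the final position


def classify(hidden_states, input_ids, score_kernel, pad_token_id):
    """Pool the last real token's hidden state and project it to class logits.

    hidden_states: [seq_len][hidden_size]
    score_kernel: [hidden_size][num_labels]
    """
    pooled = hidden_states[last_token_index(input_ids, pad_token_id)]
    num_labels = len(score_kernel[0])
    return [
        sum(pooled[i] * score_kernel[i][c] for i in range(len(pooled)))
        for c in range(num_labels)
    ]
```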
- config#
Configuration object for the model.
- Type
Qwen3MoeConfig
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- model#
The core Qwen3Moe transformer model.
- Type
Qwen3MoeModel
- score#
The linear layer for classification.
- Type
- class easydel.modules.qwen3_moe.__init__.Qwen3MoeModel(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule
The base Qwen3Moe model transformer.
This class represents the core transformer architecture of the Qwen3Moe model, consisting of an embedding layer, multiple Qwen3MoeDecoderLayer layers, and a final RMS normalization layer.
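The embed, decode, normalize pipeline described above can be sketched as follows. This is a simplified plain-Python illustration, not the module's code: `rms_norm` and `forward` are hypothetical helpers, each decoder layer is modeled as an opaque callable, and the normalization uses the standard RMSNorm formula with the `rms_norm_eps` default from the config.

```python
import math


def rms_norm(x, weight, eps=1e-6):
    """Scale x by the reciprocal of its root mean square, then by weight."""
    mean_sq = sum(v * v for v in x) / len(x)
    scale = 1.0 / math.sqrt(mean_sq + eps)
    return [w * v * scale for v, w in zip(x, weight)]


def forward(token_ids, embedding, layers, norm_weight, eps=1e-6):
    """Embed tokens, run the decoder stack, then apply the final RMS norm.

    embedding: [vocab_size][hidden_size]
    layers: callables mapping [seq_len][hidden_size] to the same shape
    """
    hidden = [list(embedding[t]) for t in token_ids]
    for layer in layers:
        hidden = layer(hidden)
    return [rms_norm(h, norm_weight, eps) for h in hidden]
```

Unlike LayerNorm, this normalization does not subtract the mean, so only the vector's scale changes.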
- config#
Configuration object for the model.
- Type
Qwen3MoeConfig
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- embed_tokens#
Embedding layer for input tokens.
- Type
nn.Embed
- layers#
List of decoder layers.
- Type
tp.List[Qwen3MoeDecoderLayer]
- gradient_checkpointing#
Gradient checkpointing configuration.