easydel.modules.qwen3_moe.__init__

class easydel.modules.qwen3_moe.__init__.Qwen3MoeConfig(vocab_size=151936, hidden_size=2048, intermediate_size=6144, num_hidden_layers=24, num_attention_heads=32, num_key_value_heads=4, hidden_act='silu', max_position_embeddings=32768, initializer_range=0.02, rms_norm_eps=1e-06, use_cache=True, tie_word_embeddings=False, rope_theta=10000.0, rope_scaling=None, attention_bias=False, use_sliding_window=False, sliding_window=4096, max_window_layers=28, attention_dropout=0.0, decoder_sparse_step=1, moe_intermediate_size=768, num_experts_per_tok=8, num_experts=128, norm_topk_prob=False, output_router_logits=False, router_aux_loss_coef=0.001, mlp_only_layers=None, **kwargs)

Bases: EasyDeLBaseConfig
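
The parameters num_experts, num_experts_per_tok and norm_topk_prob control how each token is routed to experts. The sketch below illustrates the usual top-k routing scheme in plain Python for a single token. It assumes the common softmax-then-top-k formulation and has not been checked against EasyDeL's implementation. The helper name route_token is hypothetical.

```python
import math


def route_token(router_logits, num_experts_per_tok=8, norm_topk_prob=False):
    """Pick the top-k experts for one token from its router logits (illustrative)."""
    # Softmax over all experts, shifted by the max for numerical stability.
    peak = max(router_logits)
    exps = [math.exp(x - peak) for x in router_logits]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Indices of the k most probable experts, highest probability first.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    top = order[:num_experts_per_tok]
    weights = [probs[i] for i in top]

    # Optionally renormalise so the selected weights sum to 1.
    if norm_topk_prob:
        kept = sum(weights)
        weights = [w / kept for w in weights]
    return top, weights


# Made-up logits for 128 experts (the default num_experts).
logits = [0.1 * ((7 * i) % 13) for i in range(128)]
experts, weights = route_token(logits, num_experts_per_tok=8, norm_topk_prob=True)
experts_raw, weights_raw = route_token(logits, num_experts_per_tok=8)
```

With norm_topk_prob=False (the default here), the selected weights keep their original softmax values, so they sum to less than 1 whenever unselected experts hold some probability mass.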

get_partition_rules(*args, **kwargs)

Get the partition rules for the model.

Parameters

fully_sharded_data_parallel (bool, optional, defaults to True) – Whether to use fully sharded data parallelism.

Returns

The partition rules.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]
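
The return type is a tuple of (pattern, PartitionSpec) pairs. As a rough illustration of how such rules are typically consumed, the sketch below matches a parameter path against each pattern in order and returns the first hit. It rests on several assumptions: the patterns, the mesh axis names ("fsdp", "sp", "tp") and the helper spec_for are hypothetical, plain tuples stand in for jax.sharding.PartitionSpec so the example has no dependencies, and first-match regex search is assumed rather than confirmed by this page.

```python
import re

# Stand-ins for PartitionSpec: one entry per array dimension, each naming
# the mesh axis (or axes) that dimension would be split over.
RULES = (
    (r"embed_tokens/embedding", ("tp", ("fsdp", "sp"))),
    (r"self_attn/(q|k|v)_proj/kernel", (("fsdp", "sp"), "tp")),
    (r"self_attn/o_proj/kernel", ("tp", ("fsdp", "sp"))),
    (r".*", (None,)),  # catch-all: leave the parameter replicated
)


def spec_for(param_path, rules=RULES):
    """Return the spec of the first rule whose pattern matches the path."""
    for pattern, spec in rules:
        if re.search(pattern, param_path):
            return spec
    raise ValueError(f"no partition rule matches {param_path!r}")


q_spec = spec_for("model/layers/0/self_attn/q_proj/kernel")
norm_spec = spec_for("model/norm/kernel")
```

Because the first match wins, specific patterns go before the catch-all.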

model_type: str = 'qwen3_moe'

class easydel.modules.qwen3_moe.__init__.Qwen3MoeForCausalLM(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

Qwen3Moe model with a Causal Language Modeling head.

This model consists of the base Qwen3Moe transformer (Qwen3MoeModel) followed by a linear layer (lm_head) that projects the transformer’s output hidden states to the vocabulary size, producing logits for next token prediction. Optionally, the input token embeddings can be tied to the output projection layer.
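
To make the projection and the optional weight tying concrete, here is a minimal plain-Python sketch with toy sizes (vocabulary of 4, hidden size of 3). It is not EasyDeL code: lm_logits and matvec are hypothetical helpers, and the tied case is assumed to reuse the embedding matrix as the output weight, which is the usual meaning of tie_word_embeddings.

```python
def matvec(matrix, vector):
    """Multiply a (rows x cols) matrix by a length-cols vector."""
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]


def lm_logits(hidden, embedding, lm_head=None):
    """Project one hidden state to vocabulary logits.

    lm_head is a (vocab_size x hidden_size) weight. Passing None models the
    tied case, where the embedding matrix doubles as the output projection.
    """
    weight = embedding if lm_head is None else lm_head
    return matvec(weight, hidden)


# Toy embedding table: 4 tokens, hidden size 3.
embedding = [
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
    [0.5, 0.5, 0.0],
]
hidden = [0.2, 0.9, 0.1]  # final hidden state for the last position

logits = lm_logits(hidden, embedding)  # tied weights
next_token = max(range(len(logits)), key=logits.__getitem__)  # greedy pick
```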

config (Qwen3MoeConfig) – Configuration object for the model.

dtype (jnp.dtype) – Data type for computation.

param_dtype (jnp.dtype) – Data type for parameters.

precision (jax.lax.PrecisionLike) – Precision setting for JAX operations.

rngs (nn.Rngs) – Random number generators.

model (Qwen3MoeModel) – The core Qwen3Moe transformer model.

lm_head (ParallelLinear) – The linear layer for projecting hidden states to vocabulary logits.

class easydel.modules.qwen3_moe.__init__.Qwen3MoeForSequenceClassification(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

Qwen3Moe model with a Sequence Classification head.

This model consists of the base Qwen3Moe transformer (Qwen3MoeModel) followed by a linear layer (score) that projects the transformer’s output hidden states (typically the hidden state of the last token or a pooled representation) to the number of classes for classification.
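
One common way to obtain "the hidden state of the last token" is to take, for each sequence, the hidden state at the last non-padding position and pass it through the score layer. The sketch below shows that idea in plain Python. It assumes right-padded inputs and a pad_token_id of 0, and the helper names are hypothetical. Whether EasyDeL pools exactly this way is not stated on this page.

```python
def last_token_index(input_ids, pad_token_id):
    """Index of the last non-padding token, assuming right padding."""
    for i in range(len(input_ids) - 1, -1, -1):
        if input_ids[i] != pad_token_id:
            return i
    return 0  # sequence is all padding; fall back to the first position


def pool_last_token(hidden_states, input_ids, pad_token_id):
    """Select one hidden vector per sequence: the last real token's."""
    return [
        seq[last_token_index(ids, pad_token_id)]
        for seq, ids in zip(hidden_states, input_ids)
    ]


def score(pooled, weight):
    """Apply a (num_labels x hidden_size) linear layer without bias."""
    return [[sum(w * x for w, x in zip(row, vec)) for row in weight] for vec in pooled]


PAD = 0  # assumed pad_token_id for this example
input_ids = [[5, 6, 7, PAD, PAD], [8, 9, PAD, PAD, PAD]]
# Fake hidden states of size 2, derived from the token ids for readability.
hidden_states = [[[float(tok), 1.0] for tok in ids] for ids in input_ids]

pooled = pool_last_token(hidden_states, input_ids, PAD)
class_logits = score(pooled, weight=[[1.0, -1.0], [0.5, 0.0]])
```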

config (Qwen3MoeConfig) – Configuration object for the model.

dtype (jnp.dtype) – Data type for computation.

param_dtype (jnp.dtype) – Data type for parameters.

precision (jax.lax.PrecisionLike) – Precision setting for JAX operations.

rngs (nn.Rngs) – Random number generators.

model (Qwen3MoeModel) – The core Qwen3Moe transformer model.

score (ParallelLinear) – The linear layer for classification.

class easydel.modules.qwen3_moe.__init__.Qwen3MoeModel(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

The base Qwen3Moe transformer model.

This class represents the core transformer architecture of the Qwen3Moe model, consisting of an embedding layer, multiple Qwen3MoeDecoderLayer layers, and a final RMS normalization layer.
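
Not every decoder layer in the stack has to use a mixture-of-experts block. The configuration parameters decoder_sparse_step and mlp_only_layers suggest a per-layer choice between a sparse MoE block and a dense MLP. The sketch below shows one plausible selection rule, borrowed from the convention in the Hugging Face Qwen MoE reference models. It is an assumption, not something this page confirms for EasyDeL, and sparse_layer_indices is a hypothetical helper.

```python
def sparse_layer_indices(
    num_hidden_layers=24,
    decoder_sparse_step=1,
    mlp_only_layers=None,
    num_experts=128,
):
    """Return the indices of layers that would use an MoE block.

    Assumed rule: layer i is sparse when it is not listed in mlp_only_layers,
    the model has experts, and (i + 1) is a multiple of decoder_sparse_step.
    """
    dense_only = set(mlp_only_layers or [])
    return [
        i
        for i in range(num_hidden_layers)
        if i not in dense_only
        and num_experts > 0
        and (i + 1) % decoder_sparse_step == 0
    ]


all_sparse = sparse_layer_indices()                         # defaults
every_other = sparse_layer_indices(decoder_sparse_step=2)   # layers 1, 3, 5, ...
skip_first = sparse_layer_indices(mlp_only_layers=[0, 1])   # layers 2..23
```

Under this rule the defaults shown in the Qwen3MoeConfig signature (decoder_sparse_step=1, mlp_only_layers=None) make all 24 layers sparse.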

config (Qwen3MoeConfig) – Configuration object for the model.

dtype (jnp.dtype) – Data type for computation.

param_dtype (jnp.dtype) – Data type for parameters.

precision (jax.lax.PrecisionLike) – Precision setting for JAX operations.

rngs (nn.Rngs) – Random number generators.

embed_tokens (nn.Embed) – Embedding layer for input tokens.

layers (tp.List[Qwen3MoeDecoderLayer]) – List of decoder layers.

norm (RMSNorm) – Final layer normalization.

gradient_checkpointing (EasyDeLGradientCheckPointers) – Gradient checkpointing configuration.