easydel.modules.qwen3.__init__
- class easydel.modules.qwen3.__init__.Qwen3Config(vocab_size=151936, hidden_size=4096, intermediate_size=22016, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=32, head_dim=128, hidden_act='silu', max_position_embeddings=32768, initializer_range=0.02, rms_norm_eps=1e-06, use_cache=True, tie_word_embeddings=False, rope_theta=10000.0, rope_scaling=None, attention_bias=False, use_sliding_window=False, sliding_window=4096, max_window_layers=28, attention_dropout=0.0, **kwargs)
Bases: EasyDeLBaseConfig
- get_partition_rules(*args, **kwargs)
Get the partition rules for the model: pairs of a parameter-name pattern and the PartitionSpec used to shard the parameters it matches.
- Parameters
fully_sharded_data_parallel (bool, optional, defaults to True) – Whether to use fully sharded data parallelism.
- Returns
The partition rules.
- Return type
tp.Tuple[tp.Tuple[str, PartitionSpec]]
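As an illustration of how rules of this shape are typically consumed, the sketch below matches hypothetical parameter paths against an ordered tuple of (regex, spec) pairs and returns the first hit. It assumes `/`-separated parameter paths, `re.search` matching, and first-match-wins ordering; the patterns and the mesh-axis names `fsdp` and `tp` are invented for the example, and plain tuples stand in for `jax.sharding.PartitionSpec`. EasyDeL's own matching logic and rule set may differ.

```python
import re
from typing import Optional, Sequence, Tuple

# Hypothetical stand-in for jax.sharding.PartitionSpec: one mesh-axis name
# (or None for an unsharded dimension) per array dimension.
Spec = Tuple[Optional[str], ...]

# Illustrative rules only: the patterns and axis names ("fsdp", "tp") are
# assumptions, not what Qwen3Config.get_partition_rules actually returns.
RULES: Tuple[Tuple[str, Spec], ...] = (
    (r"embed_tokens/embedding", ("tp", "fsdp")),
    (r"self_attn/(q_proj|k_proj|v_proj)/kernel", ("fsdp", "tp")),
    (r"self_attn/o_proj/kernel", ("tp", "fsdp")),
    (r"mlp/(gate_proj|up_proj)/kernel", ("fsdp", "tp")),
    (r"mlp/down_proj/kernel", ("tp", "fsdp")),
    (r"lm_head/kernel", ("fsdp", "tp")),
    (r".*", ()),  # catch-all: replicate everything else (e.g. norm scales)
)


def match_rule(path: str, rules: Sequence[Tuple[str, Spec]]) -> Spec:
    """Return the spec of the first rule whose pattern is found in `path`."""
    for pattern, spec in rules:
        if re.search(pattern, path):
            return spec
    raise ValueError(f"no partition rule matches {path!r}")
```

Because the first matching rule wins, specific patterns must come before general ones, and the trailing catch-all guarantees that every parameter receives a spec.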
- model_type: str = 'qwen3'
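The attention-related fields of Qwen3Config determine the projection widths under grouped-query attention (GQA). The hypothetical helper below is not part of EasyDeL; it copies the relevant defaults from the signature above and assumes the common convention that query heads map onto key/value heads in contiguous groups.

```python
from dataclasses import dataclass


@dataclass
class AttentionShape:
    """Hypothetical helper mirroring a few Qwen3Config fields (defaults copied
    from the signature above); it is not part of EasyDeL."""

    hidden_size: int = 4096
    num_attention_heads: int = 32
    num_key_value_heads: int = 32
    head_dim: int = 128

    def __post_init__(self) -> None:
        # GQA requires each key/value head to serve a whole number of query heads.
        if self.num_attention_heads % self.num_key_value_heads != 0:
            raise ValueError("num_attention_heads must be divisible by num_key_value_heads")

    @property
    def q_width(self) -> int:
        # Output width of the query projection.
        return self.num_attention_heads * self.head_dim

    @property
    def kv_width(self) -> int:
        # Output width of each of the key and value projections.
        return self.num_key_value_heads * self.head_dim

    @property
    def group_size(self) -> int:
        # Number of query heads that share one key/value head.
        return self.num_attention_heads // self.num_key_value_heads

    def kv_head_for(self, query_head: int) -> int:
        # Assumed contiguous grouping: heads 0..group_size-1 share KV head 0, etc.
        return query_head // self.group_size
```

With the defaults, num_key_value_heads equals num_attention_heads, so every query head has its own key/value head (standard multi-head attention). Setting num_key_value_heads=8 shrinks the key and value projections by a factor of four, and query head 5 then reads key/value head 1.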
- class easydel.modules.qwen3.__init__.Qwen3ForCausalLM(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule
Qwen3 model with a causal language modeling head.
This model consists of the base Qwen3 transformer (Qwen3Model) followed by a linear layer (lm_head) that projects the transformer’s output hidden states to the vocabulary size, producing logits for next-token prediction. Optionally (tie_word_embeddings in the config), the input token embeddings can be tied to the output projection layer.
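A minimal NumPy sketch of the output projection, assuming a bias-free lm_head whose kernel has shape (hidden_size, vocab_size); `lm_logits` is a hypothetical helper, not an EasyDeL function, and dtype/precision handling is omitted. With tying enabled, the transposed input embedding table plays the role of the kernel.

```python
import numpy as np


def lm_logits(hidden, embedding, lm_head_kernel=None, tie_word_embeddings=False):
    """Hypothetical helper: project hidden states to vocabulary logits.

    hidden:         (batch, seq_len, hidden_size)
    embedding:      (vocab_size, hidden_size) input embedding table
    lm_head_kernel: (hidden_size, vocab_size) separate output projection
    """
    if tie_word_embeddings:
        # Reuse the input embedding matrix, transposed, as the output projection.
        return hidden @ embedding.T
    return hidden @ lm_head_kernel


rng = np.random.default_rng(0)
batch, seq_len, hidden_size, vocab_size = 2, 5, 8, 11
hidden = rng.standard_normal((batch, seq_len, hidden_size))
embedding = rng.standard_normal((vocab_size, hidden_size))

logits = lm_logits(hidden, embedding, tie_word_embeddings=True)  # (2, 5, 11)
# Greedy next-token prediction reads the logits at the last position.
next_token = logits[:, -1, :].argmax(axis=-1)                     # (2,)
```

Tying saves a vocab_size × hidden_size parameter matrix, which is substantial for a 151936-token vocabulary.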
- config
Configuration object for the model.
- Type
Qwen3Config
- dtype
Data type for computation.
- Type
jnp.dtype
- param_dtype
Data type for parameters.
- Type
jnp.dtype
- precision
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs
Random number generators.
- Type
nn.Rngs
- model
The core Qwen3 transformer model.
- Type
Qwen3Model
- lm_head
The linear layer for projecting hidden states to vocabulary logits.
- Type
- class easydel.modules.qwen3.__init__.Qwen3ForSequenceClassification(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule
Qwen3 model with a sequence classification head.
This model consists of the base Qwen3 transformer (Qwen3Model) followed by a linear layer (score) that maps the transformer’s output hidden states (typically the hidden state of the last token or a pooled representation) to one logit per class.
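This page does not say which pooling rule EasyDeL applies, so the following NumPy sketch assumes one common choice: take the hidden state at the last non-padding position (read from a right-padded attention mask) and apply a bias-free score matrix of shape (hidden_size, num_labels). Both helper functions are hypothetical.

```python
import numpy as np


def last_token_pool(hidden, attention_mask):
    """Hypothetical helper: hidden state of the last non-padding token per row.

    hidden:         (batch, seq_len, hidden_size)
    attention_mask: (batch, seq_len), 1 for real tokens, 0 for padding;
                    assumes right padding and at least one real token per row.
    """
    last_index = attention_mask.sum(axis=-1) - 1            # (batch,)
    return hidden[np.arange(hidden.shape[0]), last_index]    # (batch, hidden_size)


def classify(hidden, attention_mask, score_kernel):
    # Hypothetical score head: (hidden_size, num_labels) linear map, no bias.
    return last_token_pool(hidden, attention_mask) @ score_kernel


hidden = np.arange(2 * 4 * 3, dtype=float).reshape(2, 4, 3)
mask = np.array([[1, 1, 1, 0],   # last real token at position 2
                 [1, 1, 1, 1]])  # last real token at position 3
pooled = last_token_pool(hidden, mask)            # [[6, 7, 8], [21, 22, 23]]
logits = classify(hidden, mask, np.ones((3, 2)))  # [[21, 21], [66, 66]]
```

Pooling at the last token suits a causal model: it is the only position whose hidden state has attended to the whole sequence.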
- config
Configuration object for the model.
- Type
Qwen3Config
- dtype
Data type for computation.
- Type
jnp.dtype
- param_dtype
Data type for parameters.
- Type
jnp.dtype
- precision
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs
Random number generators.
- Type
nn.Rngs
- model
The core Qwen3 transformer model.
- Type
Qwen3Model
- score
The linear layer for classification.
- Type
- class easydel.modules.qwen3.__init__.Qwen3Model(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule
The base Qwen3 transformer model.
This class represents the core transformer architecture of the Qwen3 model: an embedding layer, a stack of Qwen3DecoderLayer blocks (num_hidden_layers in the config), and a final RMS normalization layer.
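The data flow described above can be sketched in NumPy as: embed, apply each layer in order, then normalize. The RMS normalization uses the default rms_norm_eps of 1e-6 from Qwen3Config; the decoder layers are hypothetical placeholder callables, since attention, the MLP, and rotary embeddings are out of scope here. This illustrates the structure only and is not EasyDeL's implementation.

```python
import numpy as np


def rms_norm(x, weight, eps=1e-6):
    """RMS normalization over the last axis: x / sqrt(mean(x**2) + eps) * weight."""
    variance = np.mean(x * x, axis=-1, keepdims=True)
    return x / np.sqrt(variance + eps) * weight


def qwen3_like_forward(input_ids, embedding, layers, norm_weight, eps=1e-6):
    """Hypothetical skeleton of the embed -> decoder stack -> final norm pipeline.

    `layers` holds callables standing in for Qwen3DecoderLayer; each maps
    (batch, seq_len, hidden_size) hidden states to the same shape.
    """
    hidden = embedding[input_ids]  # token lookup: (batch, seq_len, hidden_size)
    for layer in layers:
        hidden = layer(hidden)
    return rms_norm(hidden, norm_weight, eps)


rng = np.random.default_rng(0)
vocab_size, hidden_size = 10, 4
embedding = rng.standard_normal((vocab_size, hidden_size))
# Hypothetical placeholder layers (a simple scaled residual), not real decoder layers.
layers = [lambda h: h + 0.1 * h for _ in range(3)]
out = qwen3_like_forward(np.array([[1, 2, 3]]), embedding, layers, np.ones(hidden_size))
```

With a unit norm weight, each output vector has a root-mean-square value just below 1, whatever scale the layers produced.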
- config
Configuration object for the model.
- Type
Qwen3Config
- dtype
Data type for computation.
- Type
jnp.dtype
- param_dtype
Data type for parameters.
- Type
jnp.dtype
- precision
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs
Random number generators.
- Type
nn.Rngs
- embed_tokens
Embedding layer for input tokens.
- Type
nn.Embed
- layers
List of decoder layers.
- Type
tp.List[Qwen3DecoderLayer]
- gradient_checkpointing
Gradient checkpointing configuration.