easydel.modules.qwen3.modeling_qwen3
- class easydel.modules.qwen3.modeling_qwen3.Qwen3Attention(*args: Any, **kwargs: Any)
Bases: UnifiedAttention
Qwen3 attention with Q/K normalization.
Inherits Q/K normalization (RMSNorm) from QKNormAttention. Features: layer-specific sliding window.
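The Q/K normalization and sliding-window features can be illustrated with a small single-sequence sketch in JAX. It assumes RMSNorm is applied per head over `head_dim` to queries and keys (not values) before the dot product, and it models the sliding window as a causal mask that keeps only the most recent `window` keys. The helper names (`rms_norm`, `sliding_window_causal_mask`, `qk_norm_attention`) are illustrative, not EasyDeL's API; the real module also handles batching, rotary embeddings, KV caching, and sharding.

```python
from typing import Optional

import jax
import jax.numpy as jnp


def rms_norm(x: jax.Array, weight: jax.Array, eps: float = 1e-6) -> jax.Array:
    """RMSNorm over the last axis: x / sqrt(mean(x**2) + eps) * weight."""
    variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x * jax.lax.rsqrt(variance + eps) * weight


def sliding_window_causal_mask(seq_len: int, window: Optional[int]) -> jax.Array:
    """Boolean [seq_len, seq_len] mask; True where query i may attend to key j."""
    i = jnp.arange(seq_len)[:, None]
    j = jnp.arange(seq_len)[None, :]
    mask = j <= i  # causal: no attention to future positions
    if window is not None:
        mask = mask & (i - j < window)  # keep only the most recent `window` keys
    return mask


def qk_norm_attention(q, k, v, q_norm_w, k_norm_w, window: Optional[int] = None):
    """Single-sequence attention; q, k, v have shape [seq, heads, head_dim]."""
    # Per-head RMSNorm over head_dim, applied to queries and keys only.
    q = rms_norm(q, q_norm_w)
    k = rms_norm(k, k_norm_w)
    # Rotary embeddings would normally be applied here; omitted for brevity.
    scores = jnp.einsum("qhd,khd->hqk", q, k) / jnp.sqrt(q.shape[-1])
    mask = sliding_window_causal_mask(q.shape[0], window)
    scores = jnp.where(mask[None, :, :], scores, jnp.finfo(scores.dtype).min)
    probs = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("hqk,khd->qhd", probs, v)


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (6, 2, 8))
k = jax.random.normal(kk, (6, 2, 8))
v = jax.random.normal(kv, (6, 2, 8))
out = qk_norm_attention(q, k, v, jnp.ones(8), jnp.ones(8), window=3)
print(out.shape)  # (6, 2, 8)
```

A layer-specific window then amounts to passing a different `window` (or `None`) depending on the layer index; which layers use one is decided by the model configuration, which this sketch does not model.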
- class easydel.modules.qwen3.modeling_qwen3.Qwen3DecoderLayer(*args: Any, **kwargs: Any)
Bases: Module
Qwen3 Transformer decoder layer.
This module represents a single decoder layer in the Qwen3 model, combining self-attention and MLP sub-layers with residual connections and RMS normalization.
- config (Qwen3Config): Configuration object for the model.
- layer_idx (int): The index of the layer in the model.
- dtype (jnp.dtype): Data type for computations.
- param_dtype (jnp.dtype): Data type for parameters.
- precision (jax.lax.PrecisionLike): Precision setting for JAX operations.
- rngs (nn.Rngs): Random number generators.
- self_attn (Qwen3Attention): The self-attention module.
- post_attention_layernorm: RMS normalization applied after the attention layer and before the MLP layer.
- config: Qwen3Config
- precision: Union[None, str, Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision], DotAlgorithm, DotAlgorithmPreset]
- self_attn: Qwen3Attention
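The residual wiring described for the decoder layer can be sketched as a pre-norm block. This page lists only `post_attention_layernorm`; the RMSNorm applied to the input before attention (`input_ln_w` here) is an assumption based on the usual Qwen/LLaMA-style layout, and `self_attn` and `mlp` are passed in as plain callables rather than the real submodules. The function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, weight, eps=1e-6):
    """RMSNorm over the last (hidden) axis."""
    return x * jax.lax.rsqrt(jnp.mean(x * x, axis=-1, keepdims=True) + eps) * weight


def decoder_layer(x, input_ln_w, post_attn_ln_w, self_attn, mlp):
    """Pre-norm residual block over hidden states x of shape [..., d_model]."""
    # Attention sub-layer (assumed pre-norm): normalize, attend, add the residual.
    h = x + self_attn(rms_norm(x, input_ln_w))
    # MLP sub-layer: post_attention_layernorm sits between attention and the MLP.
    return h + mlp(rms_norm(h, post_attn_ln_w))


x = jnp.arange(12.0).reshape(3, 4)
zero = lambda t: jnp.zeros_like(t)
# With both sub-layers returning zeros, the residual path passes x through unchanged.
print(jnp.array_equal(decoder_layer(x, jnp.ones(4), jnp.ones(4), zero, zero), x))  # True
```

Keeping the residual path free of normalization is what lets the identity check above hold exactly: each sub-layer only adds a correction on top of its input.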
- class easydel.modules.qwen3.modeling_qwen3.Qwen3ForCausalLM(*args: Any, **kwargs: Any)
Bases: BaseCausalLMModule[Qwen3Model, Qwen3Config]
Qwen3 model with a causal language modeling head.
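A causal LM head maps each final hidden state to vocabulary logits, and training scores the logits at position t against token t + 1. The sketch below shows that computation with a plain weight matrix; `lm_head_w`, `lm_head_logits`, and `causal_lm_loss` are hypothetical names, and the sketch does not show how `BaseCausalLMModule` actually builds its head or computes its loss (for example, whether the head weights are tied to `embed_tokens`).

```python
import jax
import jax.numpy as jnp


def lm_head_logits(hidden, lm_head_w):
    """Project hidden states [batch, seq, d_model] to logits [batch, seq, vocab]."""
    return hidden @ lm_head_w


def causal_lm_loss(logits, input_ids):
    """Mean next-token cross-entropy: logits at position t are scored against token t + 1."""
    shift_logits = logits[:, :-1, :]
    shift_labels = input_ids[:, 1:]
    log_probs = jax.nn.log_softmax(shift_logits, axis=-1)
    token_nll = -jnp.take_along_axis(log_probs, shift_labels[..., None], axis=-1)[..., 0]
    return token_nll.mean()


hidden = jax.random.normal(jax.random.PRNGKey(0), (2, 5, 16))
lm_head_w = jnp.zeros((16, 7))  # an all-zero head gives uniform predictions
input_ids = jnp.array([[1, 2, 3, 4, 5], [6, 0, 1, 2, 3]])
loss = causal_lm_loss(lm_head_logits(hidden, lm_head_w), input_ids)
# Uniform predictions over a 7-token vocabulary give a loss of log(7).
```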
- class easydel.modules.qwen3.modeling_qwen3.Qwen3ForSequenceClassification(*args: Any, **kwargs: Any)
Bases: BaseSequenceClassificationModule[Qwen3Model, Qwen3Config]
Qwen3 model with a sequence classification head.
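For decoder-only models, a common way to build a sequence-classification head is to take the hidden state of the last non-padding token and project it to `num_labels` scores; Hugging Face's decoder-only sequence-classification models do this. Whether `BaseSequenceClassificationModule` pools the same way is not stated on this page, so treat the following as a sketch under assumptions: `score_w`, `last_token_pool`, `classification_logits`, and right padding are all assumed.

```python
import jax.numpy as jnp


def last_token_pool(hidden, attention_mask):
    """Select each row's hidden state at its last non-padding position.

    Assumes right padding, i.e. attention_mask rows look like [1, 1, ..., 0, 0].
    """
    last_idx = attention_mask.sum(axis=-1) - 1  # [batch]
    return hidden[jnp.arange(hidden.shape[0]), last_idx]


def classification_logits(hidden, attention_mask, score_w):
    """Pool to [batch, d_model], then project to [batch, num_labels] with score_w."""
    return last_token_pool(hidden, attention_mask) @ score_w


hidden = jnp.arange(24.0).reshape(2, 4, 3)  # [batch=2, seq=4, d_model=3]
attention_mask = jnp.array([[1, 1, 1, 0], [1, 1, 0, 0]])
pooled = last_token_pool(hidden, attention_mask)
print(pooled.tolist())  # [[6.0, 7.0, 8.0], [15.0, 16.0, 17.0]]
```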
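- class easydel.modules.qwen3.modeling_qwen3.Qwen3MLP(*args: Any, **kwargs: Any)
Bases: Module
Qwen3 MLP module.
This module implements the feed-forward network (MLP) used in the Qwen3 model. It uses a gated linear unit (GLU) structure with SiLU activation: the SiLU-activated output of gate_proj is multiplied elementwise by the output of up_proj, and down_proj maps the product back to the hidden size.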
- config (Qwen3Config): Configuration object for the model.
- dtype (jnp.dtype): Data type for computations.
- param_dtype (jnp.dtype): Data type for parameters.
- precision (jax.lax.PrecisionLike): Precision setting for JAX operations.
- gate_proj (ColumnParallelLinear): Linear layer for the GLU gate.
- down_proj (RowParallelLinear): Linear layer for the down projection.
- up_proj (ColumnParallelLinear): Linear layer for the GLU value.
- act_fn (callable): Activation function (SiLU).
- act_fn: callable
- config: Qwen3Config
- down_proj: RowParallelLinear
- gate_proj: ColumnParallelLinear
- precision: Union[None, str, Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision], DotAlgorithm, DotAlgorithmPreset]
- up_proj: ColumnParallelLinear
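The gated MLP can be written in a few lines. This sketch assumes bias-free projections and represents `gate_proj`, `up_proj`, and `down_proj` as plain weight matrices instead of the parallel linear layers listed above; `gated_silu_mlp` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def gated_silu_mlp(x, gate_w, up_w, down_w):
    """down_proj(silu(gate_proj(x)) * up_proj(x)), assuming bias-free projections.

    Shapes: x [..., d_model]; gate_w, up_w [d_model, d_ff]; down_w [d_ff, d_model].
    """
    gate = jax.nn.silu(x @ gate_w)  # GLU gate, passed through SiLU (act_fn)
    value = x @ up_w                # GLU value
    return (gate * value) @ down_w  # project back to d_model


k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
d_model, d_ff = 8, 32
x = jax.random.normal(k1, (2, 5, d_model))
out = gated_silu_mlp(
    x,
    jax.random.normal(k2, (d_model, d_ff)),
    jax.random.normal(k3, (d_model, d_ff)),
    jax.random.normal(k4, (d_ff, d_model)),
)
print(out.shape)  # (2, 5, 8)
```

The column/row split in the attribute types above matches the usual tensor-parallel pairing: if the two column-parallel layers are sharded along `d_ff`, the elementwise `gate * value` product stays local to each shard, and the row-parallel `down_proj` then reduces the partial results back to `d_model`. How EasyDeL configures that sharding is not covered on this page.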
- class easydel.modules.qwen3.modeling_qwen3.Qwen3Model(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule
The base Qwen3 model transformer.
This class represents the core transformer architecture of the Qwen3 model, consisting of an embedding layer, multiple Qwen3DecoderLayer layers, and a final RMS normalization layer.
- config: Configuration object for the model.
- dtype (jnp.dtype): Data type for computation.
- param_dtype (jnp.dtype): Data type for parameters.
- precision (jax.lax.PrecisionLike): Precision setting for JAX operations.
- rngs (nn.Rngs): Random number generators.
- embed_tokens (nn.Embed): Embedding layer for input tokens.
- layers (tp.List[Qwen3DecoderLayer]): List of decoder layers.
- gradient_checkpointing: Gradient checkpointing configuration.
- get_decoder() → Qwen3Model
Returns the decoder part of the model's graph definition.
- get_encoder() → None
Returns the encoder part of the model's graph definition. Decoder-only models don't have an encoder, so this returns None.
- get_lm_head() → None
Returns the language model head of the module. Base models don't have a language model head, so this returns None.
- layers: list[easydel.modules.qwen3.modeling_qwen3.Qwen3DecoderLayer]
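The base model's forward pass (embedding lookup, the decoder stack, then a final RMSNorm) can be sketched with toy callables standing in for the decoder layers. The optional `remat` flag wraps each layer in `jax.checkpoint`, one standard way to trade recomputation for activation memory; how EasyDeL maps its `gradient_checkpointing` setting onto JAX rematerialization is not shown on this page, so that flag and `base_model_forward` are hypothetical. Because the base model has no language model head (`get_lm_head()` returns None), its output is hidden states rather than logits.

```python
from typing import Callable, Sequence

import jax
import jax.numpy as jnp


def rms_norm(x, weight, eps=1e-6):
    """RMSNorm over the last (hidden) axis."""
    return x * jax.lax.rsqrt(jnp.mean(x * x, axis=-1, keepdims=True) + eps) * weight


def base_model_forward(
    input_ids: jax.Array,
    embed_table: jax.Array,
    layers: Sequence[Callable[[jax.Array], jax.Array]],
    final_norm_w: jax.Array,
    remat: bool = False,  # hypothetical stand-in for a gradient checkpointing setting
) -> jax.Array:
    """input_ids [batch, seq] -> final hidden states [batch, seq, d_model]."""
    h = jnp.take(embed_table, input_ids, axis=0)  # embed_tokens lookup
    for layer in layers:  # stand-ins for the Qwen3DecoderLayer stack
        # jax.checkpoint recomputes the layer's activations during the backward pass.
        h = (jax.checkpoint(layer) if remat else layer)(h)
    return rms_norm(h, final_norm_w)  # final RMS normalization


embed_table = jax.random.normal(jax.random.PRNGKey(0), (10, 4))  # vocab=10, d_model=4
input_ids = jnp.array([[1, 2, 3], [4, 5, 6]])
layers = [lambda t: t + 1.0, lambda t: jnp.tanh(t)]  # toy layers, not real decoder layers
hidden = base_model_forward(input_ids, embed_table, layers, jnp.ones(4), remat=True)
print(hidden.shape)  # (2, 3, 4)
```

Rematerialization changes only memory and compute during differentiation, not the forward result, so the checkpointed and plain passes should agree numerically.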