easydel.modules.qwen3.modeling_qwen3

class easydel.modules.qwen3.modeling_qwen3.Qwen3Attention(*args: Any, **kwargs: Any)

Bases: UnifiedAttention

Qwen3 Attention with Q/K normalization.

Inherits Q/K normalization (RMSNorm) from QKNormAttention.

Features:

- Layer-specific sliding window
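A minimal sketch of the two features named above, written as standalone JAX functions rather than the EasyDeL implementation. `qk_norm` applies RMSNorm independently to each query and key head vector, and `sliding_window_causal_mask` builds a causal mask that also restricts each query to the most recent `window` keys. The helper names, the `[batch, seq, num_heads, head_dim]` layout, and the `eps` default are assumptions for illustration. Which layers actually receive a window is decided by the model configuration and is not modeled here.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, weight, eps=1e-6):
    # Scale x by the inverse root-mean-square of its last axis, then apply a learned weight.
    var = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x * jax.lax.rsqrt(var + eps) * weight


def qk_norm(q, k, q_weight, k_weight, eps=1e-6):
    # Hypothetical helper. q, k: [batch, seq, num_heads, head_dim]; weights: [head_dim].
    # Each head vector is normalized separately (over the last axis only).
    return rms_norm(q, q_weight, eps), rms_norm(k, k_weight, eps)


def sliding_window_causal_mask(seq_len, window=None):
    # Hypothetical helper. Boolean [seq_len, seq_len]; True means query i may attend to key j.
    i = jnp.arange(seq_len)[:, None]
    j = jnp.arange(seq_len)[None, :]
    mask = j <= i  # causal: no attention to future keys
    if window is not None:
        mask = mask & (i - j < window)  # keep only the `window` most recent keys
    return mask


q = jax.random.normal(jax.random.PRNGKey(0), (1, 4, 2, 8))
k = jax.random.normal(jax.random.PRNGKey(1), (1, 4, 2, 8))
qn, kn = qk_norm(q, k, jnp.ones(8), jnp.ones(8))
# Rows of the mask: [1 0 0 0], [1 1 0 0], [0 1 1 0], [0 0 1 1]
print(sliding_window_causal_mask(4, window=2).astype(jnp.int32))
```

Normalizing queries and keys is usually motivated as a way to keep attention logits from growing with the scale of the projections; the window then bounds how far back each position can look.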

class easydel.modules.qwen3.modeling_qwen3.Qwen3DecoderLayer(*args: Any, **kwargs: Any)

Bases: Module

Qwen3 Transformer Decoder Layer.

This module represents a single decoder layer in the Qwen3 model, combining self-attention and MLP sub-layers with residual connections and RMS normalization.
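The pre-norm residual pattern described above can be sketched as follows, assuming each sub-layer normalizes its input and adds its output back to the residual stream. `decoder_layer`, `attn_fn`, and `mlp_fn` are hypothetical stand-ins for the layer, `self_attn`, and `mlp`; masks, KV caching, and position handling are omitted.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, weight, eps=1e-6):
    # Scale x by the inverse root-mean-square of its last axis, then apply a learned weight.
    var = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x * jax.lax.rsqrt(var + eps) * weight


def decoder_layer(x, attn_fn, mlp_fn, input_ln_weight, post_attn_ln_weight, eps=1e-6):
    # Attention sub-layer: normalize first (input_layernorm), then add to the residual stream.
    h = x + attn_fn(rms_norm(x, input_ln_weight, eps))
    # MLP sub-layer: normalize the updated stream (post_attention_layernorm), then add the MLP output.
    return h + mlp_fn(rms_norm(h, post_attn_ln_weight, eps))


x = jnp.ones((1, 3, 4))
zero = lambda t: jnp.zeros_like(t)
identity = lambda t: t
out_zero = decoder_layer(x, zero, zero, jnp.ones(4), jnp.ones(4))  # sub-layers contribute nothing, so this equals x
out_id = decoder_layer(x, identity, identity, jnp.ones(4), jnp.ones(4))  # approximately 3.0 everywhere
```

Because each sub-layer only adds to the residual stream, a layer whose sub-layers output zeros passes its input through unchanged.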

config

Configuration object for the model.

Type

Qwen3Config

layer_idx

The index of the layer in the model.

Type

int

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

input_layernorm

RMS normalization applied before the attention layer.

Type

RMSNorm

self_attn

The self-attention module.

Type

Qwen3Attention

mlp

The feed-forward (MLP) module.

Type

Qwen3MLP

post_attention_layernorm

RMS normalization applied after the attention layer and before the MLP layer.

Type

RMSNorm

config: Qwen3Config
dtype: dtype
input_layernorm: RMSNorm
mlp: Qwen3MLP
param_dtype: dtype
post_attention_layernorm: RMSNorm
precision: Union[None, str, Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision], DotAlgorithm, DotAlgorithmPreset]
self_attn: Qwen3Attention
class easydel.modules.qwen3.modeling_qwen3.Qwen3ForCausalLM(*args: Any, **kwargs: Any)

Bases: BaseCausalLMModule[Qwen3Model, Qwen3Config]

Qwen3 model with a Causal Language Modeling head.
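As a sketch of what a causal language-modeling head computes, assuming a bias-free projection: final hidden states are projected to vocabulary logits, and training uses next-token cross-entropy in which the logits at position t are scored against token t + 1. `lm_logits`, `next_token_loss`, and the `[hidden, vocab]` weight layout are hypothetical; padding masks and weight tying are not modeled.

```python
import jax
import jax.numpy as jnp


def lm_logits(hidden_states, w_lm_head):
    # hidden_states: [batch, seq, hidden]; w_lm_head: [hidden, vocab] -> [batch, seq, vocab].
    return hidden_states @ w_lm_head


def next_token_loss(logits, input_ids):
    # Shift so that the logits at position t are scored against token t + 1.
    shifted_logits = logits[:, :-1, :]
    labels = input_ids[:, 1:]
    log_probs = jax.nn.log_softmax(shifted_logits, axis=-1)
    # Pick out the log-probability assigned to each target token.
    target_log_probs = jnp.take_along_axis(log_probs, labels[..., None], axis=-1)[..., 0]
    return -target_log_probs.mean()


vocab = 10
hidden_states = jnp.zeros((2, 5, 4))
logits = lm_logits(hidden_states, jnp.ones((4, vocab)))  # all-zero logits, i.e. a uniform distribution
ids = jnp.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 0]])
loss = next_token_loss(logits, ids)  # uniform predictions give log(vocab)
```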

class easydel.modules.qwen3.modeling_qwen3.Qwen3ForSequenceClassification(*args: Any, **kwargs: Any)

Bases: BaseSequenceClassificationModule[Qwen3Model, Qwen3Config]

Qwen3 model with a Sequence Classification head.
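One common design for a decoder-only sequence classifier, used for example by Hugging Face's Qwen-family classification heads, scores the hidden state of the last non-padding token. The sketch below assumes right padding and an attention mask with 1 for real tokens. Whether `BaseSequenceClassificationModule` pools exactly this way is an assumption, and `last_token_pool`, `classify`, and `w_score` are hypothetical names.

```python
import jax.numpy as jnp


def last_token_pool(hidden_states, attention_mask):
    # hidden_states: [batch, seq, hidden]; attention_mask: [batch, seq], 1 = real token.
    # Assuming right padding, the last real token sits at index (number of real tokens - 1).
    last_idx = attention_mask.sum(axis=-1) - 1
    return hidden_states[jnp.arange(hidden_states.shape[0]), last_idx]


def classify(hidden_states, attention_mask, w_score):
    # w_score: [hidden, num_labels] -> logits: [batch, num_labels].
    return last_token_pool(hidden_states, attention_mask) @ w_score


hidden_states = jnp.arange(2 * 4 * 3, dtype=jnp.float32).reshape(2, 4, 3)
attention_mask = jnp.array([[1, 1, 1, 0], [1, 1, 0, 0]])
pooled = last_token_pool(hidden_states, attention_mask)  # token 2 of row 0, token 1 of row 1
scores = classify(hidden_states, attention_mask, jnp.ones((3, 2)))
```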

class easydel.modules.qwen3.modeling_qwen3.Qwen3MLP(*args: Any, **kwargs: Any)

Bases: Module

Qwen3 MLP module.

This module implements the feed-forward network (MLP) used in the Qwen3 model. It uses a Gated Linear Unit (GLU) structure with SiLU activation.
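A minimal sketch of the gated feed-forward computation, assuming the conventional SwiGLU formulation `down_proj(silu(gate_proj(x)) * up_proj(x))` with bias-free projections. Plain weight matrices stand in for the parallel linear layers, and `swiglu_mlp`, `w_gate`, `w_up`, and `w_down` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def swiglu_mlp(x, w_gate, w_up, w_down):
    # x: [..., hidden]; w_gate, w_up: [hidden, intermediate]; w_down: [intermediate, hidden].
    gate = jax.nn.silu(x @ w_gate)  # SiLU-activated gate branch
    value = x @ w_up                # linear value branch
    return (gate * value) @ w_down  # elementwise gating, then project back to hidden


key_x, key_g, key_u, key_d = jax.random.split(jax.random.PRNGKey(0), 4)
hidden, intermediate = 4, 8
x = jax.random.normal(key_x, (2, 3, hidden))
w_gate = jax.random.normal(key_g, (hidden, intermediate))
w_up = jax.random.normal(key_u, (hidden, intermediate))
w_down = jax.random.normal(key_d, (intermediate, hidden))
y = swiglu_mlp(x, w_gate, w_up, w_down)  # shape (2, 3, 4)
```

The intermediate width is typically larger than the hidden width, so the gate and value projections expand the representation and `down_proj` contracts it again.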

config

Configuration object for the model.

Type

Qwen3Config

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

gate_proj

Linear layer for the GLU gate.

Type

ColumnParallelLinear

down_proj

Linear layer for the down projection.

Type

RowParallelLinear

up_proj

Linear layer for the GLU value.

Type

ColumnParallelLinear

act_fn

Activation function (SiLU).

Type

callable

act_fn: callable
config: Qwen3Config
down_proj: RowParallelLinear
dtype: dtype
gate_proj: ColumnParallelLinear
param_dtype: dtype
precision: Union[None, str, Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision], DotAlgorithm, DotAlgorithmPreset]
up_proj: ColumnParallelLinear
class easydel.modules.qwen3.modeling_qwen3.Qwen3Model(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

The base Qwen3 model transformer.

This class represents the core transformer architecture of the Qwen3 model, consisting of an embedding layer, multiple Qwen3DecoderLayer layers, and a final RMS normalization layer.
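The pipeline described above (token embedding, a stack of decoder layers, and a final RMS normalization) can be sketched as a plain function. `base_model_forward` is hypothetical, and each element of `layers` is assumed to be any callable mapping `[batch, seq, hidden]` to the same shape, standing in for `Qwen3DecoderLayer`.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, weight, eps=1e-6):
    # Scale x by the inverse root-mean-square of its last axis, then apply a learned weight.
    var = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x * jax.lax.rsqrt(var + eps) * weight


def base_model_forward(input_ids, embed_table, layers, norm_weight, eps=1e-6):
    # input_ids: [batch, seq] integers; embed_table: [vocab, hidden].
    x = jnp.take(embed_table, input_ids, axis=0)  # embedding lookup -> [batch, seq, hidden]
    for layer in layers:
        x = layer(x)  # each decoder layer preserves the shape
    return rms_norm(x, norm_weight, eps)  # final normalization


embed_table = jax.random.normal(jax.random.PRNGKey(0), (16, 8))
input_ids = jnp.array([[3, 1, 4, 1]])
layers = [lambda h: h * 2.0, lambda h: h + 1.0]  # toy stand-ins for decoder layers
out = base_model_forward(input_ids, embed_table, layers, jnp.ones(8))  # shape (1, 4, 8)
```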

config

Configuration object for the model.

Type

Qwen3Config

dtype

Data type for computation.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

embed_tokens

Embedding layer for input tokens.

Type

nn.Embed

layers

List of decoder layers.

Type

tp.List[Qwen3DecoderLayer]

norm

Final layer normalization.

Type

RMSNorm

gradient_checkpointing

Gradient checkpointing configuration.

Type

EasyDeLGradientCheckPointers

embed_tokens: Embed
get_decoder() → Qwen3Model

Returns the decoder part of the model’s graph definition.

get_embedding() → Embed

Returns the embedding layer of the module.

get_encoder() → None

Returns the encoder part of the model’s graph definition. Decoder-only models don’t have an encoder, so this always returns None.

get_lm_head() → None

Returns the language model head of the module. Base models don’t have a language model head, so this always returns None.

layers: list[easydel.modules.qwen3.modeling_qwen3.Qwen3DecoderLayer]
norm: RMSNorm