easydel.modules.qwen2.modeling_qwen#

class easydel.modules.qwen2.modeling_qwen.Qwen2Attention(*args: Any, **kwargs: Any)[source]#

Bases: UnifiedAttention

Qwen2 Attention module with sliding window support.

Inherits from UnifiedAttention with Qwen2-specific customizations:

- Sliding window attention (layer-specific)
- Custom bias configuration (Q/K/V use bias, O doesn’t)
- Residual dropout
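To illustrate what sliding window attention restricts, here is a minimal NumPy sketch of a causal mask limited to a window of recent keys. It is not EasyDeL code: the function name `sliding_window_causal_mask` and the convention that the window counts the current token are assumptions made for illustration only.

```python
import numpy as np


def sliding_window_causal_mask(seq_len: int, window: int) -> np.ndarray:
    """Return a boolean mask where [q, k] is True if query q may attend to key k.

    Hypothetical helper: the window is assumed to include the current token.
    """
    q = np.arange(seq_len)[:, None]  # query positions as a column
    k = np.arange(seq_len)[None, :]  # key positions as a row
    causal = k <= q                  # never attend to future tokens
    in_window = (q - k) < window     # keep only the most recent `window` keys
    return causal & in_window


mask = sliding_window_causal_mask(seq_len=5, window=3)
print(mask.astype(int))
```

Each row has at most `window` ones, so late queries no longer see the earliest keys, while early queries behave exactly as under a plain causal mask.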

class easydel.modules.qwen2.modeling_qwen.Qwen2DecoderLayer(*args: Any, **kwargs: Any)[source]#

Bases: Module

Qwen2 Transformer Decoder Layer.

This module represents a single decoder layer in the Qwen2 model, combining self-attention and MLP sub-layers with residual connections and RMS normalization.

config#

Configuration object for the model.

Type

Qwen2Config

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

input_layernorm#

RMS normalization applied before the attention layer.

Type

RMSNorm

self_attn#

The self-attention module.

Type

Qwen2Attention

mlp#

The feed-forward (MLP) module.

Type

Qwen2MLP

post_attention_layernorm#

RMS normalization applied after the attention layer and before the MLP layer.

Type

RMSNorm
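The data flow implied by the attributes above (normalize, attend, add residual, normalize, apply MLP, add residual) can be sketched in NumPy as follows. This is an illustrative sketch, not the EasyDeL implementation: `rms_norm`, `decoder_layer`, the `eps` value, and the stand-in `attn_fn`/`mlp_fn` callables are all hypothetical names and choices.

```python
import numpy as np


def rms_norm(x, weight, eps=1e-6):
    """Scale x by the reciprocal of its root mean square over the last axis."""
    variance = np.mean(x * x, axis=-1, keepdims=True)
    return x / np.sqrt(variance + eps) * weight


def decoder_layer(x, attn_fn, mlp_fn, w_input_norm, w_post_attn_norm):
    """Pre-norm residual block: attention sub-layer, then MLP sub-layer."""
    # input_layernorm runs before attention; the residual adds the raw input back
    h = x + attn_fn(rms_norm(x, w_input_norm))
    # post_attention_layernorm runs after attention and before the MLP
    return h + mlp_fn(rms_norm(h, w_post_attn_norm))


rng = np.random.default_rng(0)
x = rng.normal(size=(2, 4, 8))           # (batch, sequence, hidden)
w = np.ones(8)                           # identity norm scale
zero = lambda h: np.zeros_like(h)        # stand-in sub-layer that outputs zeros
out = decoder_layer(x, zero, zero, w, w)
print(out.shape)
```

With both sub-layers replaced by zero functions the layer reduces to the identity, which shows that the residual connections carry the input through unchanged.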

class easydel.modules.qwen2.modeling_qwen.Qwen2ForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: BaseCausalLMModule[Qwen2Model, Qwen2Config]

Qwen2 model with a Causal Language Modeling head.

get_decoder()[source]#

Returns the decoder part of the model’s graph definition.

get_embedding()[source]#

Returns the embedding layer of the module.

get_encoder()[source]#

Returns the encoder part of the model’s graph definition. Decoder-only models don’t have an encoder.

get_lm_head()[source]#

Returns the language model head of the module.
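As a rough picture of what a language model head computes, the NumPy sketch below projects final hidden states to vocabulary logits and picks the next token greedily. It does not call EasyDeL: `lm_head_logits`, the bias-free projection, and the shapes are assumptions chosen for illustration.

```python
import numpy as np


def lm_head_logits(hidden_states, w_lm_head):
    """Project (batch, seq, hidden) hidden states to (batch, seq, vocab) logits."""
    return hidden_states @ w_lm_head


rng = np.random.default_rng(0)
hidden_states = rng.normal(size=(2, 5, 8))   # hypothetical hidden size of 8
w_lm_head = rng.normal(size=(8, 32))         # hypothetical vocabulary of 32 tokens
logits = lm_head_logits(hidden_states, w_lm_head)
next_token = logits[:, -1, :].argmax(axis=-1)  # greedy choice at the last position
print(logits.shape, next_token.shape)
```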

class easydel.modules.qwen2.modeling_qwen.Qwen2ForSequenceClassification(*args: Any, **kwargs: Any)[source]#

Bases: BaseSequenceClassificationModule[Qwen2Model, Qwen2Config]

Qwen2 model with a Sequence Classification head.

get_decoder()[source]#

Returns the decoder part of the model’s graph definition.

get_embedding()[source]#

Returns the embedding layer of the module.

get_encoder()[source]#

Returns the encoder part of the model’s graph definition. Decoder-only models don’t have an encoder.

get_lm_head()[source]#

Returns the language model head of the module. This model has a sequence classification head, not an LM head.
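For contrast with an LM head, the sketch below shows one common way a sequence classification head can be applied on top of a decoder-only model: score the hidden state of the last non-padding token. The pooling rule is an assumption and is not stated on this page. `classify_last_token`, right padding, and the bias-free score matrix are likewise hypothetical.

```python
import numpy as np


def classify_last_token(hidden_states, attention_mask, w_score):
    """Score each sequence from the hidden state of its last real token.

    Assumes right padding, so the last real token sits at index sum(mask) - 1.
    """
    last_index = attention_mask.sum(axis=1) - 1
    batch_index = np.arange(hidden_states.shape[0])
    pooled = hidden_states[batch_index, last_index]  # (batch, hidden)
    return pooled @ w_score                          # (batch, num_labels)


rng = np.random.default_rng(0)
hidden_states = rng.normal(size=(2, 4, 8))
attention_mask = np.array([[1, 1, 1, 0], [1, 1, 1, 1]])  # first row has one pad token
w_score = rng.normal(size=(8, 3))                        # three hypothetical labels
class_logits = classify_last_token(hidden_states, attention_mask, w_score)
print(class_logits.shape)
```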

class easydel.modules.qwen2.modeling_qwen.Qwen2MLP(*args: Any, **kwargs: Any)[source]#

Bases: Module

Qwen2 MLP module.

This module implements the feed-forward network (MLP) used in the Qwen2 model. It uses a Gated Linear Unit (GLU) structure with SiLU activation and includes dropout.

config#

Configuration object for the model.

Type

Qwen2Config

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

gate_proj#

Linear layer for the GLU gate.

Type

ParallelLinear

down_proj#

Linear layer for the down projection.

Type

ParallelLinear

up_proj#

Linear layer for the GLU value.

Type

ParallelLinear

dropout#

Dropout layer applied to the output.

Type

nn.Dropout

act_fn#

Activation function (SiLU).

Type

callable
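The gated structure described by `gate_proj`, `up_proj`, `down_proj`, and `act_fn` can be sketched in NumPy as shown below. This is an illustrative sketch rather than the EasyDeL implementation: the function names are hypothetical, the projections are assumed to be bias-free, and dropout is omitted as it would be at inference time.

```python
import numpy as np


def silu(x):
    """SiLU activation: x * sigmoid(x)."""
    return x / (1.0 + np.exp(-x))


def gated_mlp(x, w_gate, w_up, w_down):
    """GLU-style feed-forward: down_proj(silu(gate_proj(x)) * up_proj(x))."""
    gate = silu(x @ w_gate)         # activated gate branch
    value = x @ w_up                # linear value branch
    return (gate * value) @ w_down  # elementwise gating, then project back


rng = np.random.default_rng(0)
x = rng.normal(size=(2, 8))         # (tokens, hidden)
w_gate = rng.normal(size=(8, 16))   # hidden -> intermediate
w_up = rng.normal(size=(8, 16))     # hidden -> intermediate
w_down = rng.normal(size=(16, 8))   # intermediate -> hidden
y = gated_mlp(x, w_gate, w_up, w_down)
print(y.shape)
```

The gate and value branches share the same input and intermediate width, so their elementwise product can be projected back to the hidden size by `w_down`.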

class easydel.modules.qwen2.modeling_qwen.Qwen2Model(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

The base Qwen2 model transformer.

This class represents the core transformer architecture of the Qwen2 model, consisting of an embedding layer, multiple Qwen2DecoderLayer layers, and a final RMS normalization layer.

config#

Configuration object for the model.

Type

Qwen2Config

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

embed_tokens#

Embedding layer for input tokens.

Type

nn.Embed

layers#

List of decoder layers.

Type

tp.List[Qwen2DecoderLayer]

norm#

Final RMS normalization layer.

Type

RMSNorm

dropout#

Dropout layer applied after embeddings.

Type

nn.Dropout

gradient_checkpointing#

Gradient checkpointing configuration.

Type

EasyDeLGradientCheckPointers
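The overall flow suggested by `embed_tokens`, `layers`, and `norm` is: look up token embeddings, pass them through each decoder layer in order, then apply the final normalization. The NumPy sketch below illustrates that flow under stated assumptions. It is not EasyDeL code: `base_model_forward` and `rms_norm` are hypothetical names, the layers are arbitrary callables, and dropout after the embeddings is omitted.

```python
import numpy as np


def rms_norm(x, weight, eps=1e-6):
    """Scale x by the reciprocal of its root mean square over the last axis."""
    variance = np.mean(x * x, axis=-1, keepdims=True)
    return x / np.sqrt(variance + eps) * weight


def base_model_forward(input_ids, embed_table, layers, norm_weight):
    """Embed tokens, run the decoder stack in order, apply the final norm."""
    h = embed_table[input_ids]        # (batch, seq) -> (batch, seq, hidden)
    for layer in layers:              # each layer maps hidden states to hidden states
        h = layer(h)
    return rms_norm(h, norm_weight)


rng = np.random.default_rng(0)
embed_table = rng.normal(size=(10, 8))         # hypothetical vocab of 10, hidden size 8
input_ids = np.array([[1, 2, 3], [4, 5, 6]])
layers = [lambda h: h + 1.0, lambda h: h * 0.5]  # stand-ins for decoder layers
norm_weight = np.ones(8)
hidden = base_model_forward(input_ids, embed_table, layers, norm_weight)
print(hidden.shape)
```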

get_decoder()[source]#

Returns the decoder part of the model’s graph definition.

get_embedding()[source]#

Returns the embedding layer of the module.

get_encoder()[source]#

Returns the encoder part of the model’s graph definition. Decoder-only models don’t have an encoder.

get_lm_head()[source]#

Returns the language model head of the module. Base models don’t have a language model head.