easydel.modules.qwen2.modeling_qwen_flax

class easydel.modules.qwen2.modeling_qwen_flax.Qwen2Attention(*args: Any, **kwargs: Any)#

Bases: AttentionModule

Qwen2 Attention module.

This module implements the multi-head attention mechanism used in the Qwen2 model. It supports Grouped Query Attention (GQA) and Rotary Position Embeddings (RoPE). It also includes a residual dropout layer.

config#

Configuration object for the model.

Type

Qwen2Config

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

hidden_size#

Dimensionality of the hidden states.

Type

int

head_dim#

Dimensionality of each attention head.

Type

int

num_key_value_groups#

Number of query heads that share each key/value head (num_attention_heads // num_key_value_heads).

Type

int

q_proj#

Linear layer for query projection.

Type

ParallelLinear

k_proj#

Linear layer for key projection.

Type

ParallelLinear

v_proj#

Linear layer for value projection.

Type

ParallelLinear

o_proj#

Linear layer for the output projection.

Type

ParallelLinear

attention_performer#

Module to perform the core attention computation.

Type

FlexibleAttentionModule

resid_dropout#

Dropout applied to the residual connection within the attention block, when enabled in the config.

Type

nn.Dropout

rotary#

Rotary position embedding module.

Type

RoPE
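
The grouped-query attention described for Qwen2Attention can be illustrated with a minimal JAX sketch. It is not the EasyDeL implementation: the function name, the tensor layout, and the explicit repetition of key/value heads are assumptions made for illustration, and RoPE, dropout, and the projection layers are left out.

```python
import jax
import jax.numpy as jnp


def grouped_query_attention(q, k, v, num_key_value_groups):
    """Causal attention in which several query heads share one key/value head.

    q:    [batch, seq, num_q_heads, head_dim]
    k, v: [batch, seq, num_kv_heads, head_dim]
    (hypothetical helper; layout chosen for this sketch only)
    """
    # Repeat each key/value head so it lines up with its group of query heads.
    k = jnp.repeat(k, num_key_value_groups, axis=2)
    v = jnp.repeat(v, num_key_value_groups, axis=2)

    head_dim = q.shape[-1]
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(head_dim)

    # Causal mask: position i may only attend to positions <= i.
    seq_len = q.shape[1]
    causal = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    scores = jnp.where(causal[None, None, :, :], scores, jnp.finfo(scores.dtype).min)

    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", weights, v)


batch, seq_len, num_q_heads, num_kv_heads, head_dim = 2, 5, 8, 2, 16
num_key_value_groups = num_q_heads // num_kv_heads  # query heads per key/value head

kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (batch, seq_len, num_q_heads, head_dim))
k = jax.random.normal(kk, (batch, seq_len, num_kv_heads, head_dim))
v = jax.random.normal(kv, (batch, seq_len, num_kv_heads, head_dim))

out = grouped_query_attention(q, k, v, num_key_value_groups)
```

The output has one slice per query head even though only num_kv_heads key/value heads were projected, which is where the memory saving of GQA comes from.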

class easydel.modules.qwen2.modeling_qwen_flax.Qwen2DecoderLayer(*args: Any, **kwargs: Any)#

Bases: Module

Qwen2 Transformer Decoder Layer.

This module represents a single decoder layer in the Qwen2 model, combining self-attention and MLP sub-layers with residual connections and RMS normalization.

config#

Configuration object for the model.

Type

Qwen2Config

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

input_layernorm#

RMS normalization applied before the attention layer.

Type

RMSNorm

self_attn#

The self-attention module.

Type

Qwen2Attention

mlp#

The feed-forward (MLP) module.

Type

Qwen2MLP

post_attention_layernorm#

RMS normalization applied after the attention layer and before the MLP layer.

Type

RMSNorm
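
The way Qwen2DecoderLayer combines its sub-layers can be sketched as a pre-norm residual block. This is a simplified illustration, not the library code: rms_norm, decoder_layer, the params dictionary, the eps value, and the stand-in attn_fn and mlp_fn callables are all assumptions.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, weight, eps=1e-6):
    # Scale by the reciprocal root-mean-square over the hidden dimension.
    variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x * jax.lax.rsqrt(variance + eps) * weight


def decoder_layer(hidden_states, params, attn_fn, mlp_fn):
    # Attention sub-layer: normalize, attend, add back to the residual stream.
    residual = hidden_states
    hidden_states = rms_norm(hidden_states, params["input_layernorm"])
    hidden_states = residual + attn_fn(hidden_states)

    # MLP sub-layer: normalize, transform, add back to the residual stream.
    residual = hidden_states
    hidden_states = rms_norm(hidden_states, params["post_attention_layernorm"])
    return residual + mlp_fn(hidden_states)


hidden_size = 8
x = jax.random.normal(jax.random.PRNGKey(0), (2, 5, hidden_size))
params = {
    "input_layernorm": jnp.ones((hidden_size,)),
    "post_attention_layernorm": jnp.ones((hidden_size,)),
}

# Stand-ins for Qwen2Attention and Qwen2MLP, used only to exercise the wiring.
out = decoder_layer(x, params, attn_fn=lambda h: 0.5 * h, mlp_fn=jnp.tanh)
```

Because each sub-layer only adds to the residual stream, a sub-layer that outputs zeros leaves the hidden states unchanged.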

class easydel.modules.qwen2.modeling_qwen_flax.Qwen2ForCausalLM(*args: Any, **kwargs: Any)#

Bases: EasyDeLBaseModule

Qwen2 model with a Causal Language Modeling head.

This model consists of the base Qwen2 transformer (Qwen2Model) followed by a linear layer (lm_head) that projects the transformer’s output hidden states to the vocabulary size, producing logits for next token prediction. Optionally, the input token embeddings can be tied to the output projection layer.

config#

Configuration object for the model.

Type

Qwen2Config

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

model#

The core Qwen2 transformer model.

Type

Qwen2Model

lm_head#

The linear layer for projecting hidden states to vocabulary logits.

Type

ParallelLinear
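
The projection from hidden states to vocabulary logits, with and without tied embeddings, can be sketched as follows. The function name, the tie_word_embeddings flag name, and the array shapes are assumptions for this example and do not describe how EasyDeL stores or shards these parameters.

```python
import jax
import jax.numpy as jnp

vocab_size, hidden_size = 32, 8
k_embed, k_head, k_hidden = jax.random.split(jax.random.PRNGKey(0), 3)

embedding = jax.random.normal(k_embed, (vocab_size, hidden_size))      # input token embeddings
lm_head_kernel = jax.random.normal(k_head, (hidden_size, vocab_size))  # separate output projection
hidden_states = jax.random.normal(k_hidden, (2, 5, hidden_size))       # transformer output


def project_to_logits(hidden_states, embedding, lm_head_kernel, tie_word_embeddings):
    # With tied weights, the transposed embedding matrix acts as the output projection.
    kernel = embedding.T if tie_word_embeddings else lm_head_kernel
    return hidden_states @ kernel


logits = project_to_logits(hidden_states, embedding, lm_head_kernel, tie_word_embeddings=False)
tied_logits = project_to_logits(hidden_states, embedding, lm_head_kernel, tie_word_embeddings=True)

# Greedy next-token prediction reads the logits at the last position.
next_token = jnp.argmax(logits[:, -1, :], axis=-1)
```

Tying removes one vocab_size by hidden_size matrix from the parameter count, at the cost of forcing the input and output representations to share weights.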

class easydel.modules.qwen2.modeling_qwen_flax.Qwen2ForSequenceClassification(*args: Any, **kwargs: Any)#

Bases: EasyDeLBaseModule

Qwen2 model with a Sequence Classification head.

This model consists of the base Qwen2 transformer (Qwen2Model) followed by a linear layer (score) that projects the transformer’s output hidden states (typically the hidden state of the last token or a pooled representation) to the number of classes for classification.

config#

Configuration object for the model.

Type

Qwen2Config

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

model#

The core Qwen2 transformer model.

Type

Qwen2Model

score#

The linear layer for classification.

Type

ParallelLinear
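
One common way to pool "the hidden state of the last token" is to pick, per sequence, the final non-padding position and pass it through the score layer. The sketch below assumes right-padded inputs and an attention mask of ones and zeros; classify, score_kernel, and the shapes are hypothetical and may differ from what the library does.

```python
import jax
import jax.numpy as jnp


def classify(hidden_states, attention_mask, score_kernel):
    # Index of the last real token in each sequence (assumes right padding).
    last_index = jnp.sum(attention_mask, axis=-1) - 1
    batch_index = jnp.arange(hidden_states.shape[0])
    pooled = hidden_states[batch_index, last_index]  # [batch, hidden_size]
    return pooled @ score_kernel                     # [batch, num_labels]


hidden_size, num_labels = 8, 3
k_hidden, k_score = jax.random.split(jax.random.PRNGKey(0))
hidden_states = jax.random.normal(k_hidden, (2, 5, hidden_size))
score_kernel = jax.random.normal(k_score, (hidden_size, num_labels))

# First sequence has three real tokens, second has five.
attention_mask = jnp.array([[1, 1, 1, 0, 0],
                            [1, 1, 1, 1, 1]])

class_logits = classify(hidden_states, attention_mask, score_kernel)
```

With left-padded inputs the last position would always be the real last token, so the index computation above would need to change.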

class easydel.modules.qwen2.modeling_qwen_flax.Qwen2MLP(*args: Any, **kwargs: Any)#

Bases: Module

Qwen2 MLP module.

This module implements the feed-forward network (MLP) used in the Qwen2 model. It uses a Gated Linear Unit (GLU) structure with SiLU activation, a combination often called SwiGLU, and applies dropout to its output.

config#

Configuration object for the model.

Type

Qwen2Config

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

gate_proj#

Linear layer for the GLU gate.

Type

ParallelLinear

down_proj#

Linear layer for the down projection.

Type

ParallelLinear

up_proj#

Linear layer for the GLU value.

Type

ParallelLinear

dropout#

Dropout layer applied to the output.

Type

nn.Dropout

act_fn#

Activation function (SiLU).

Type

callable
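
The gated structure built from gate_proj, up_proj, and down_proj can be sketched as below. This is an illustration under assumptions: gated_mlp and the plain kernel arrays are hypothetical, biases are omitted, and dropout is left out as it would be at inference time.

```python
import jax
import jax.numpy as jnp


def gated_mlp(x, gate_kernel, up_kernel, down_kernel):
    gate = jax.nn.silu(x @ gate_kernel)  # gate branch, passed through SiLU
    up = x @ up_kernel                   # value branch, no activation
    return (gate * up) @ down_kernel     # elementwise gating, then project back down


hidden_size, intermediate_size = 8, 32
k_x, k_gate, k_up, k_down = jax.random.split(jax.random.PRNGKey(0), 4)

x = jax.random.normal(k_x, (2, 5, hidden_size))
gate_kernel = jax.random.normal(k_gate, (hidden_size, intermediate_size))
up_kernel = jax.random.normal(k_up, (hidden_size, intermediate_size))
down_kernel = jax.random.normal(k_down, (intermediate_size, hidden_size))

mlp_out = gated_mlp(x, gate_kernel, up_kernel, down_kernel)
```

The gate and up projections both expand to the intermediate size, so the block has three weight matrices rather than the two of a plain feed-forward network.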

class easydel.modules.qwen2.modeling_qwen_flax.Qwen2Model(*args: Any, **kwargs: Any)#

Bases: EasyDeLBaseModule

The base Qwen2 transformer model.

This class represents the core transformer architecture of the Qwen2 model, consisting of an embedding layer, multiple Qwen2DecoderLayer layers, and a final RMS normalization layer.

config#

Configuration object for the model.

Type

Qwen2Config

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

embed_tokens#

Embedding layer for input tokens.

Type

nn.Embed

layers#

List of decoder layers.

Type

tp.List[Qwen2DecoderLayer]

norm#

Final RMS normalization applied after the last decoder layer.

Type

RMSNorm

dropout#

Dropout layer applied after embeddings.

Type

nn.Dropout

gradient_checkpointing#

Gradient checkpointing configuration.

Type

EasyDeLGradientCheckPointers
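
The overall data flow of Qwen2Model (embedding lookup, a stack of layers, final RMS normalization) and the effect of gradient checkpointing can be sketched with plain JAX. Everything here is an assumption for illustration: toy_layer stands in for Qwen2DecoderLayer, the params dictionary and use_checkpointing flag are hypothetical, and jax.checkpoint is used directly rather than through EasyDeLGradientCheckPointers.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, weight, eps=1e-6):
    variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x * jax.lax.rsqrt(variance + eps) * weight


def toy_layer(hidden_states, kernel):
    # Stand-in for a decoder layer: a residual update with a single matmul.
    return hidden_states + jnp.tanh(hidden_states @ kernel)


def forward(params, input_ids, use_checkpointing=False):
    # With checkpointing, layer activations are recomputed during the backward
    # pass instead of being stored, trading compute for memory.
    layer_fn = jax.checkpoint(toy_layer) if use_checkpointing else toy_layer

    hidden_states = jnp.take(params["embed_tokens"], input_ids, axis=0)
    for kernel in params["layers"]:
        hidden_states = layer_fn(hidden_states, kernel)
    return rms_norm(hidden_states, params["norm"])


vocab_size, hidden_size, num_layers = 16, 8, 3
keys = jax.random.split(jax.random.PRNGKey(0), num_layers + 1)
params = {
    "embed_tokens": jax.random.normal(keys[0], (vocab_size, hidden_size)),
    "layers": [0.1 * jax.random.normal(key, (hidden_size, hidden_size)) for key in keys[1:]],
    "norm": jnp.ones((hidden_size,)),
}
input_ids = jnp.array([[1, 2, 3, 4], [5, 6, 7, 0]])

out = forward(params, input_ids)
out_ckpt = forward(params, input_ids, use_checkpointing=True)
```

Checkpointing changes only what is kept in memory for the backward pass, so the outputs and gradients should match the uncheckpointed version up to floating-point tolerance.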