easydel.modules.internlm2.modeling_internlm2_flax

class easydel.modules.internlm2.modeling_internlm2_flax.InternLM2Attention(*args: Any, **kwargs: Any)#

Bases: AttentionModule

InternLM2 Attention module.

config#

Configuration object for the model.

Type

InternLM2Config

dtype#

Data type for computation. Default is jnp.float32.

Type

jnp.dtype

param_dtype#

Data type for parameters. Default is jnp.float32.

Type

jnp.dtype

precision#

Precision setting for JAX operations. Default is None.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

hidden_size#

Dimensionality of the hidden states.

Type

int

num_heads#

Number of attention heads.

Type

int

head_dim#

Dimensionality of each attention head.

Type

int

num_key_value_heads#

Number of key/value heads (for GQA).

Type

int

num_key_value_groups#

Number of query heads that share each key/value head (num_heads // num_key_value_heads).

Type

int

max_position_embeddings#

Maximum sequence length supported.

Type

int

wqkv#

Fused linear layer that produces the query, key, and value projections in a single matrix multiplication.

Type

ParallelLinear

wo#

Linear layer for the output projection.

Type

ParallelLinear

attention_performer#

Module to perform the core attention computation.

Type

FlexibleAttentionModule

rotary#

Rotary position embedding module.

Type

RoPE
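The attributes above describe grouped-query attention (GQA) driven by a fused wqkv projection. The following is a minimal sketch in plain JAX, not EasyDeL's implementation. The function names split_fused_qkv and gqa_attention are hypothetical, and the fused layout (for each key/value head: its query heads, then one key slice, then one value slice) is an assumption that may differ from the actual module. Rotary embeddings and the output projection wo are omitted.

```python
import jax
import jax.numpy as jnp


def split_fused_qkv(qkv, num_heads, num_key_value_heads, head_dim):
    """Split a fused projection of shape (batch, seq, fused_dim) into q, k, v.

    Assumed layout (hypothetical): per key/value head, `groups` query slices
    followed by one key slice and one value slice.
    """
    batch, seq, _ = qkv.shape
    groups = num_heads // num_key_value_heads  # num_key_value_groups
    qkv = qkv.reshape(batch, seq, num_key_value_heads, groups + 2, head_dim)
    q = qkv[..., :groups, :].reshape(batch, seq, num_heads, head_dim)
    k = qkv[..., -2, :]  # (batch, seq, num_key_value_heads, head_dim)
    v = qkv[..., -1, :]
    return q, k, v


def gqa_attention(q, k, v):
    """Causal attention where several query heads share one key/value head."""
    groups = q.shape[2] // k.shape[2]
    # Repeat each key/value head so it lines up with its group of query heads.
    k = jnp.repeat(k, groups, axis=2)
    v = jnp.repeat(v, groups, axis=2)
    scale = q.shape[-1] ** -0.5
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) * scale
    seq = q.shape[1]
    causal = jnp.tril(jnp.ones((seq, seq), dtype=bool))
    scores = jnp.where(causal, scores, -jnp.inf)
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", weights, v)


# Example shapes: 4 query heads sharing 2 key/value heads.
num_heads, kv_heads, head_dim = 4, 2, 8
fused_dim = (num_heads + 2 * kv_heads) * head_dim
qkv = jax.random.normal(jax.random.PRNGKey(0), (2, 5, fused_dim))
q, k, v = split_fused_qkv(qkv, num_heads, kv_heads, head_dim)
out = gqa_attention(q, k, v)  # (batch, seq, num_heads, head_dim)
```

With num_heads = 4 and num_key_value_heads = 2, num_key_value_groups is 2, so the key/value tensors are half the size of the query tensor.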

class easydel.modules.internlm2.modeling_internlm2_flax.InternLM2Block(*args: Any, **kwargs: Any)#

Bases: Module

InternLM2 Transformer Block.

This module combines the self-attention layer and the MLP layer. Each sub-layer is preceded by RMS normalization and wrapped in a residual connection.

config#

Configuration object for the model.

Type

InternLM2Config

dtype#

Data type for computation. Default is jnp.float32.

Type

jnp.dtype

param_dtype#

Data type for parameters. Default is jnp.float32.

Type

jnp.dtype

precision#

Precision setting for JAX operations. Default is None.

Type

jax.lax.PrecisionLike

attention#

The self-attention module.

Type

InternLM2Attention

feed_forward#

The feed-forward (MLP) module.

Type

InternLM2MLP

attention_norm#

RMS normalization applied to the input of the attention layer.

Type

RMSNorm

ffn_norm#

RMS normalization applied to the input of the MLP layer.

Type

RMSNorm
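The data flow implied by the attributes above (normalize, apply the sub-layer, add the residual) can be sketched in plain JAX. This is an illustration under assumptions, not the EasyDeL code: rms_norm and block_forward are hypothetical helpers, the eps value is a placeholder, and the attention and feed-forward sub-layers are passed in as plain callables.

```python
import jax.numpy as jnp


def rms_norm(x, weight, eps=1e-6):
    """Scale x by the reciprocal root-mean-square of its last axis."""
    variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x / jnp.sqrt(variance + eps) * weight


def block_forward(x, attention_fn, feed_forward_fn, attention_norm_w, ffn_norm_w):
    """Pre-norm transformer block: each sub-layer sees a normalized input
    and its output is added back to the unnormalized residual stream."""
    x = x + attention_fn(rms_norm(x, attention_norm_w))
    x = x + feed_forward_fn(rms_norm(x, ffn_norm_w))
    return x


# Toy usage with stand-in sub-layers (hypothetical, for shape checking only).
hidden = jnp.ones((2, 3, 4))
ones = jnp.ones((4,))
result = block_forward(
    hidden,
    attention_fn=lambda h: h,                   # identity stand-in
    feed_forward_fn=lambda h: jnp.zeros_like(h),  # contributes nothing
    attention_norm_w=ones,
    ffn_norm_w=ones,
)
```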

class easydel.modules.internlm2.modeling_internlm2_flax.InternLM2ForCausalLM(*args: Any, **kwargs: Any)#

Bases: EasyDeLBaseModule

InternLM2 model with a Causal Language Modeling head.

This model consists of the base InternLM2 transformer (InternLM2Model) followed by a linear layer (lm_head) that projects the transformer’s output hidden states to the vocabulary size, producing logits for next token prediction.

config#

Configuration object for the model.

Type

InternLM2Config

dtype#

Data type for computation. Default is jnp.float32.

Type

jnp.dtype

param_dtype#

Data type for parameters. Default is jnp.float32.

Type

jnp.dtype

precision#

Precision setting for JAX operations. Default is None.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

module#

The core InternLM2 transformer model.

Type

InternLM2Model

lm_head#

The linear layer for projecting hidden states to vocabulary logits.

Type

ParallelLinear
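As a rough illustration of what the lm_head projection yields, the sketch below maps hidden states to vocabulary logits, picks a greedy next token, and computes a shifted next-token loss. It uses plain JAX arrays rather than the EasyDeL classes; lm_head_logits, greedy_next_token, next_token_loss, and the bias-free kernel are assumptions made for the example.

```python
import jax
import jax.numpy as jnp


def lm_head_logits(hidden_states, lm_head_kernel):
    """Project (batch, seq, hidden) to (batch, seq, vocab)."""
    return hidden_states @ lm_head_kernel


def greedy_next_token(logits):
    """Most likely next token, read from the last position of each sequence."""
    return jnp.argmax(logits[:, -1, :], axis=-1)


def next_token_loss(logits, input_ids):
    """Mean cross-entropy where position t predicts token t + 1."""
    shift_logits = logits[:, :-1, :]
    shift_labels = input_ids[:, 1:]
    log_probs = jax.nn.log_softmax(shift_logits, axis=-1)
    token_ll = jnp.take_along_axis(log_probs, shift_labels[..., None], axis=-1)
    return -jnp.mean(token_ll[..., 0])


# Toy shapes: hidden size 8, vocabulary of 16 tokens.
hidden_states = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 8))
kernel = jax.random.normal(jax.random.PRNGKey(1), (8, 16))
input_ids = jnp.array([[1, 2, 3, 4], [5, 6, 7, 8]])

logits = lm_head_logits(hidden_states, kernel)
next_ids = greedy_next_token(logits)
loss = next_token_loss(logits, input_ids)
```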

class easydel.modules.internlm2.modeling_internlm2_flax.InternLM2ForSequenceClassification(*args: Any, **kwargs: Any)#

Bases: EasyDeLBaseModule

InternLM2 model with a Sequence Classification head.

This model consists of the base InternLM2 transformer (InternLM2Model) followed by a linear layer (score) that projects the transformer’s output hidden states to the number of classes. One position per sequence is then selected for classification; for decoder-only models this is typically the last non-padding token rather than the first token.

config#

Configuration object for the model.

Type

InternLM2Config

dtype#

Data type for computation. Default is jnp.float32.

Type

jnp.dtype

param_dtype#

Data type for parameters. Default is jnp.float32.

Type

jnp.dtype

precision#

Precision setting for JAX operations. Default is None.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

module#

The core InternLM2 transformer model.

Type

InternLM2Model

score#

The linear layer for classification.

Type

ParallelLinear
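The pooling step of a sequence-classification head can be sketched as follows: score every position, then keep one position per sequence. This is a plain-JAX illustration, not the EasyDeL implementation. The helpers last_non_pad_positions and classify are hypothetical, right padding is assumed, and which position the actual module selects should be checked against its source.

```python
import jax.numpy as jnp


def last_non_pad_positions(input_ids, pad_token_id):
    """Index of the last real token per sequence (assumes right padding
    and at least one non-padding token in every row)."""
    not_pad = (input_ids != pad_token_id).astype(jnp.int32)
    return jnp.sum(not_pad, axis=-1) - 1


def classify(hidden_states, score_kernel, positions):
    """Project every position to class logits, then pick one per sequence."""
    logits = hidden_states @ score_kernel  # (batch, seq, num_labels)
    batch_index = jnp.arange(hidden_states.shape[0])
    return logits[batch_index, positions]  # (batch, num_labels)


# Toy batch: the first row has two padding tokens (pad id 0).
input_ids = jnp.array([[5, 6, 7, 0, 0], [1, 2, 3, 4, 9]])
hidden_states = jnp.arange(30, dtype=jnp.float32).reshape(2, 5, 3)
score_kernel = jnp.eye(3)  # identity so the pooled row is easy to inspect

positions = last_non_pad_positions(input_ids, pad_token_id=0)
pooled = classify(hidden_states, score_kernel, positions)
```

Passing jnp.zeros over the batch as positions would pool the first token instead, so the same helper covers either convention.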

class easydel.modules.internlm2.modeling_internlm2_flax.InternLM2MLP(*args: Any, **kwargs: Any)#

Bases: Module

InternLM2 MLP module.

config#

Configuration object for the model.

Type

InternLM2Config

dtype#

Data type for computation. Default is jnp.float32.

Type

jnp.dtype

param_dtype#

Data type for parameters. Default is jnp.float32.

Type

jnp.dtype

precision#

Precision setting for JAX operations. Default is None.

Type

jax.lax.PrecisionLike

w1#

First linear transformation (gate projection).

Type

ParallelLinear

w3#

Second linear transformation (up projection).

Type

ParallelLinear

w2#

Third linear transformation (down projection).

Type

ParallelLinear

act_fn#

Activation function (e.g., SiLU).

Type

callable
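The gate, up, and down projections listed above are usually combined as a gated MLP. The sketch below shows that composition in plain JAX, assuming bias-free kernels and SiLU as the activation; gated_mlp is a hypothetical function and the exact wiring in EasyDeL is not confirmed here.

```python
import jax
import jax.numpy as jnp


def gated_mlp(x, w1, w3, w2):
    """Gated feed-forward: down(act(gate(x)) * up(x)).

    w1: gate projection (hidden, intermediate)
    w3: up projection   (hidden, intermediate)
    w2: down projection (intermediate, hidden)
    """
    gate = jax.nn.silu(x @ w1)
    up = x @ w3
    return (gate * up) @ w2


# Toy shapes: hidden size 8, intermediate size 32.
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(k1, (2, 3, 8))
w1 = jax.random.normal(k2, (8, 32))
w3 = jax.random.normal(k3, (8, 32))
w2 = jax.random.normal(k4, (32, 8))

y = gated_mlp(x, w1, w3, w2)  # same shape as x
```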

class easydel.modules.internlm2.modeling_internlm2_flax.InternLM2Model(*args: Any, **kwargs: Any)#

Bases: EasyDeLBaseModule

The base InternLM2 model transformer.

This class represents the core transformer architecture of the InternLM2 model, consisting of embedding layers, multiple transformer blocks, and a final layer normalization.

config#

Configuration object for the model.

Type

InternLM2Config

dtype#

Data type for computation. Default is jnp.float32.

Type

jnp.dtype

param_dtype#

Data type for parameters. Default is jnp.float32.

Type

jnp.dtype

precision#

Precision setting for JAX operations. Default is None.

Type

jax.lax.PrecisionLike

embed_tokens#

Embedding layer for input tokens.

Type

nn.Embed

layers#

Sequence of transformer blocks.

Type

tp.Sequence[InternLM2Block]

norm#

Final RMS normalization applied to the output of the last transformer block.

Type

RMSNorm

gradient_checkpointing#

Gradient checkpointing configuration.

Type

EasyDeLGradientCheckPointers

scan_layers#

Whether to run the transformer blocks with JAX scan (jax.lax.scan) rather than iterating over them one by one.

Type

bool

blocks_class#

The class used for the transformer blocks.

Type

InternLM2Block
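The scan_layers and gradient_checkpointing attributes refer to two general JAX techniques: running identical layers with jax.lax.scan over stacked parameters, and rematerializing activations with jax.checkpoint. The sketch below demonstrates both on a toy residual layer. toy_block, run_layers_loop, and run_layers_scan are hypothetical, and how EasyDeL wires these options internally is not shown.

```python
import jax
import jax.numpy as jnp


def toy_block(x, w):
    """Stand-in for a transformer block: a residual update with one kernel."""
    return x + jnp.tanh(x @ w)


def run_layers_loop(x, stacked_w):
    """Apply the layers one after another in a Python loop."""
    for i in range(stacked_w.shape[0]):
        x = toy_block(x, stacked_w[i])
    return x


def run_layers_scan(x, stacked_w, remat=False):
    """Apply the same layers with jax.lax.scan over the leading axis.

    With remat=True each layer is wrapped in jax.checkpoint, so its
    activations are recomputed during the backward pass instead of stored.
    """
    block = jax.checkpoint(toy_block) if remat else toy_block

    def step(carry, w):
        return block(carry, w), None

    out, _ = jax.lax.scan(step, x, stacked_w)
    return out


key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (2, 4))
stacked_w = 0.1 * jax.random.normal(key_w, (3, 4, 4))  # 3 layers

loop_out = run_layers_loop(x, stacked_w)
scan_out = run_layers_scan(x, stacked_w)
remat_out = run_layers_scan(x, stacked_w, remat=True)
grads = jax.grad(lambda w: jnp.sum(run_layers_scan(x, w, remat=True)))(stacked_w)
```

All three variants compute the same values. Scan mainly reduces compile time for deep stacks, and checkpointing trades extra compute for lower memory during training.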