easydel.modules.internlm2.modeling_internlm2_flax
- class easydel.modules.internlm2.modeling_internlm2_flax.InternLM2Attention(*args: Any, **kwargs: Any)
Bases: AttentionModule
InternLM2 Attention module.
- config
Configuration object for the model.
- Type
- dtype
Data type for computation. Default is jnp.float32.
- Type
jnp.dtype
- param_dtype
Data type for parameters. Default is jnp.float32.
- Type
jnp.dtype
- precision
Precision setting for JAX operations. Default is None.
- Type
jax.lax.PrecisionLike
- rngs
Random number generators.
- Type
nn.Rngs
- hidden_size
Dimensionality of the hidden states.
- Type
int
- num_heads
Number of attention heads.
- Type
int
- head_dim
Dimensionality of each attention head.
- Type
int
- num_key_value_heads
Number of key/value heads; smaller than num_heads when grouped-query attention (GQA) is used.
- Type
int
- num_key_value_groups
Number of query heads that share each key/value head (num_heads // num_key_value_heads).
- Type
int
- max_position_embeddings
Maximum sequence length supported.
- Type
int
- wqkv
Single linear layer that produces the query, key, and value projections.
- Type
- wo
Linear layer for the output projection.
- Type
- attention_performer
Module that performs the core attention computation.
- rotary
Rotary position embedding module.
- Type
RoPE
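The attributes above describe grouped-query attention: num_heads query heads share num_key_value_heads key/value heads, so each key/value head serves num_key_value_groups query heads. The sketch below, in plain JAX, splits a fused wqkv output and runs causal GQA. It assumes the interleaved per-KV-head layout [q_1 … q_g, k, v] used by the reference Hugging Face InternLM2 code; the function name and weight shapes are hypothetical, not EasyDeL's API, and rotary embeddings are omitted for brevity.

```python
import jax
import jax.numpy as jnp


def gqa_attention(x, wqkv, wo, num_heads, num_kv_heads, head_dim):
    """Causal grouped-query attention from a fused QKV projection (sketch).

    x:    [batch, seq, hidden]
    wqkv: [hidden, (num_heads + 2 * num_kv_heads) * head_dim]  (assumed layout)
    wo:   [num_heads * head_dim, hidden]
    """
    batch, seq, _ = x.shape
    groups = num_heads // num_kv_heads  # plays the role of num_key_value_groups

    # View the fused output as [batch, seq, kv_heads, groups + 2, head_dim]:
    # per KV head, `groups` query slices followed by one key and one value slice.
    qkv = (x @ wqkv).reshape(batch, seq, num_kv_heads, groups + 2, head_dim)
    q = qkv[..., :groups, :].reshape(batch, seq, num_heads, head_dim)
    k = qkv[..., -2, :]  # [batch, seq, kv_heads, head_dim]
    v = qkv[..., -1, :]

    # Share each KV head across its group of query heads.
    k = jnp.repeat(k, groups, axis=2)  # [batch, seq, num_heads, head_dim]
    v = jnp.repeat(v, groups, axis=2)

    # Scaled dot-product scores with a causal mask: [batch, heads, q_len, k_len].
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(head_dim)
    causal = jnp.tril(jnp.ones((seq, seq), dtype=bool))
    scores = jnp.where(causal, scores, jnp.finfo(scores.dtype).min)
    probs = jax.nn.softmax(scores, axis=-1)

    out = jnp.einsum("bhqk,bkhd->bqhd", probs, v)
    return out.reshape(batch, seq, num_heads * head_dim) @ wo


# Usage: 8 query heads sharing 2 KV heads (4 query heads per KV head).
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
hidden, num_heads, num_kv_heads, head_dim = 64, 8, 2, 8
x = jax.random.normal(k1, (2, 5, hidden))
wqkv = jax.random.normal(k2, (hidden, (num_heads + 2 * num_kv_heads) * head_dim)) * 0.02
wo = jax.random.normal(k3, (num_heads * head_dim, hidden)) * 0.02
y = gqa_attention(x, wqkv, wo, num_heads, num_kv_heads, head_dim)
print(y.shape)  # (2, 5, 64)
```

Repeating the KV heads is the simplest way to express the sharing; fused kernels typically avoid materializing the repeated tensors.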
- class easydel.modules.internlm2.modeling_internlm2_flax.InternLM2Block(*args: Any, **kwargs: Any)
Bases: Module
InternLM2 Transformer Block.
This module combines the self-attention layer and the MLP layer with residual connections and layer normalization.
- config
Configuration object for the model.
- Type
- dtype
Data type for computation. Default is jnp.float32.
- Type
jnp.dtype
- param_dtype
Data type for parameters. Default is jnp.float32.
- Type
jnp.dtype
- precision
Precision setting for JAX operations. Default is None.
- Type
jax.lax.PrecisionLike
- attention
The self-attention module.
- Type
InternLM2Attention
- feed_forward
The feed-forward (MLP) module.
- Type
InternLM2MLP
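The block description (self-attention and an MLP, each with a residual connection and normalization) can be sketched as a pre-norm layer. This sketch assumes RMSNorm applied before each sub-layer, as in the reference InternLM2 design; rms_norm, block_forward, and the toy sub-layer callables are hypothetical stand-ins, not EasyDeL's modules.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, weight, eps=1e-6):
    # Scale by the reciprocal root-mean-square over the hidden dimension.
    variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x * jax.lax.rsqrt(variance + eps) * weight


def block_forward(x, attention_fn, mlp_fn, attn_norm_w, ffn_norm_w):
    # Residual connection around normalized self-attention.
    h = x + attention_fn(rms_norm(x, attn_norm_w))
    # Residual connection around the normalized feed-forward network.
    return h + mlp_fn(rms_norm(h, ffn_norm_w))


# Usage with toy stand-ins for the attention and MLP sub-layers (illustration only).
hidden = 16
x = jax.random.normal(jax.random.PRNGKey(0), (2, 4, hidden))
ones = jnp.ones((hidden,))
out = block_forward(x, lambda t: 0.5 * t, jnp.tanh, ones, ones)
print(out.shape)  # (2, 4, 16)
```

Normalizing inside the residual branch leaves an unnormalized identity path through the whole stack, which is the usual reason for choosing pre-norm in deep decoders.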
- class easydel.modules.internlm2.modeling_internlm2_flax.InternLM2ForCausalLM(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule
InternLM2 model with a Causal Language Modeling head.
This model consists of the base InternLM2 transformer (InternLM2Model) followed by a linear layer (lm_head) that projects the transformer’s output hidden states to the vocabulary size, producing logits for next-token prediction.
- config
Configuration object for the model.
- Type
- dtype
Data type for computation. Default is jnp.float32.
- Type
jnp.dtype
- param_dtype
Data type for parameters. Default is jnp.float32.
- Type
jnp.dtype
- precision
Precision setting for JAX operations. Default is None.
- Type
jax.lax.PrecisionLike
- rngs
Random number generators.
- Type
nn.Rngs
- module
The core InternLM2 transformer model.
- Type
InternLM2Model
- lm_head
The linear layer for projecting hidden states to vocabulary logits.
- Type
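The lm_head projection and its use for next-token prediction can be illustrated in plain JAX. In this sketch, hidden_states stands in for the output of InternLM2Model and lm_head_w is a hypothetical [hidden, vocab] weight matrix; the helper names are illustrative, not EasyDeL functions. The loss shifts targets by one position so that the logits at step t are scored against token t + 1.

```python
import jax
import jax.numpy as jnp


def lm_logits(hidden_states, lm_head_w):
    # [batch, seq, hidden] @ [hidden, vocab] -> [batch, seq, vocab]
    return hidden_states @ lm_head_w


def next_token_greedy(logits):
    # The next-token prediction comes from the last position.
    return jnp.argmax(logits[:, -1, :], axis=-1)


def causal_lm_loss(logits, input_ids):
    # Position t predicts token t + 1: drop the last logit and the first label.
    shift_logits = logits[:, :-1, :]
    labels = input_ids[:, 1:]
    log_probs = jax.nn.log_softmax(shift_logits, axis=-1)
    token_ll = jnp.take_along_axis(log_probs, labels[..., None], axis=-1)[..., 0]
    return -jnp.mean(token_ll)


batch, seq, hidden, vocab = 2, 6, 32, 100
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
hidden_states = jax.random.normal(k1, (batch, seq, hidden))
lm_head_w = jax.random.normal(k2, (hidden, vocab)) * 0.02
input_ids = jax.random.randint(k3, (batch, seq), 0, vocab)

logits = lm_logits(hidden_states, lm_head_w)
print(logits.shape)  # (2, 6, 100)
print(next_token_greedy(logits).shape)  # (2,)
print(causal_lm_loss(logits, input_ids))
```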
- class easydel.modules.internlm2.modeling_internlm2_flax.InternLM2ForSequenceClassification(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule
InternLM2 model with a Sequence Classification head.
This model consists of the base InternLM2 transformer (InternLM2Model) followed by a linear layer (score) that projects the transformer’s output hidden states (typically the hidden state of the first token) to the number of classes for classification.
- config
Configuration object for the model.
- Type
- dtype
Data type for computation. Default is jnp.float32.
- Type
jnp.dtype
- param_dtype
Data type for parameters. Default is jnp.float32.
- Type
jnp.dtype
- precision
Precision setting for JAX operations. Default is None.
- Type
jax.lax.PrecisionLike
- rngs
Random number generators.
- Type
nn.Rngs
- module
The core InternLM2 transformer model.
- Type
InternLM2Model
- score
The linear layer for classification.
- Type
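A sketch of the score head in plain JAX, with a hypothetical score_w weight and helper name. Pooling defaults to the first token, as the description above states. Because InternLM2 attends causally, the first token's hidden state has not seen the rest of the sequence, so decoder-only classifiers often pool the last non-padding token instead; the pad_token_id branch shows that variant and is an assumption, not necessarily this module's behavior.

```python
import jax
import jax.numpy as jnp


def classify(hidden_states, score_w, input_ids=None, pad_token_id=None):
    # Per-token class scores: [batch, seq, hidden] @ [hidden, num_labels].
    logits = hidden_states @ score_w
    batch = logits.shape[0]
    if pad_token_id is None:
        # Pool the first token, as described for this module.
        idx = jnp.zeros((batch,), dtype=jnp.int32)
    else:
        # Pool the last non-padding token (assumes right padding, no interior pads).
        idx = jnp.sum(input_ids != pad_token_id, axis=-1) - 1
    return logits[jnp.arange(batch), idx]  # [batch, num_labels]


hidden, num_labels = 8, 3
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
hidden_states = jax.random.normal(k1, (2, 5, hidden))
score_w = jax.random.normal(k2, (hidden, num_labels))
input_ids = jnp.array([[5, 6, 7, 0, 0], [1, 2, 3, 4, 9]])  # 0 is padding in this toy batch

print(classify(hidden_states, score_w).shape)  # (2, 3)
print(classify(hidden_states, score_w, input_ids, pad_token_id=0).shape)  # (2, 3)
```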
- class easydel.modules.internlm2.modeling_internlm2_flax.InternLM2MLP(*args: Any, **kwargs: Any)
Bases: Module
InternLM2 MLP module.
- config
Configuration object for the model.
- Type
- dtype
Data type for computation. Default is jnp.float32.
- Type
jnp.dtype
- param_dtype
Data type for parameters. Default is jnp.float32.
- Type
jnp.dtype
- precision
Precision setting for JAX operations. Default is None.
- Type
jax.lax.PrecisionLike
- w1
First linear transformation (gate projection).
- Type
- w3
Second linear transformation (up projection).
- Type
- w2
Third linear transformation (down projection).
- Type
- act_fn
Activation function (e.g., SiLU).
- Type
callable
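The gate/up/down naming corresponds to the gated (SwiGLU-style) MLP, w2(act_fn(w1(x)) * w3(x)). The sketch below, in plain JAX, assumes SiLU as act_fn and bias-free projections with hypothetical weight matrices; it illustrates the computation rather than EasyDeL's implementation.

```python
import jax
import jax.numpy as jnp


def gated_mlp(x, w1, w3, w2):
    gate = jax.nn.silu(x @ w1)  # gate projection followed by the activation
    up = x @ w3                 # up projection
    return (gate * up) @ w2     # elementwise gating, then down projection


hidden, intermediate = 16, 64
k0, k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(k0, (2, 4, hidden))
w1 = jax.random.normal(k1, (hidden, intermediate)) * 0.02
w3 = jax.random.normal(k2, (hidden, intermediate)) * 0.02
w2 = jax.random.normal(k3, (intermediate, hidden)) * 0.02
print(gated_mlp(x, w1, w3, w2).shape)  # (2, 4, 16)
```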
- class easydel.modules.internlm2.modeling_internlm2_flax.InternLM2Model(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule
The base InternLM2 transformer model.
This class represents the core transformer architecture of the InternLM2 model, consisting of an embedding layer, a stack of transformer blocks, and a final layer normalization.
- config
Configuration object for the model.
- Type
- dtype
Data type for computation. Default is jnp.float32.
- Type
jnp.dtype
- param_dtype
Data type for parameters. Default is jnp.float32.
- Type
jnp.dtype
- precision
Precision setting for JAX operations. Default is None.
- Type
jax.lax.PrecisionLike
- embed_tokens
Embedding layer for input tokens.
- Type
nn.Embed
- layers
Sequence of transformer blocks.
- Type
tp.Sequence[InternLM2Block]
- gradient_checkpointing
Gradient checkpointing configuration.
- scan_layers
Whether to iterate over the transformer blocks with a JAX scan instead of a Python loop.
- Type
bool
- blocks_class
The class used for the transformer blocks.
- Type
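The embedding, block stack, and final normalization, together with the scan_layers and gradient_checkpointing options, can be illustrated with a toy model. This sketch assumes per-layer parameters stacked along a leading axis so that jax.lax.scan traces the block body once instead of once per layer, and it uses jax.checkpoint as a stand-in for gradient checkpointing. toy_block, model_forward, and the parameter layout are hypothetical, not EasyDeL's scan implementation.

```python
import jax
import jax.numpy as jnp


def toy_block(h, layer_params):
    # Hypothetical stand-in for InternLM2Block: residual around a tanh layer.
    w, b = layer_params
    return h + jnp.tanh(h @ w + b)


def rms_norm(x, eps=1e-6):
    return x * jax.lax.rsqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)


def model_forward(input_ids, embedding, stacked_params, scan_layers=True, remat=False):
    h = embedding[input_ids]  # token embedding lookup: [batch, seq, hidden]
    # jax.checkpoint recomputes the block's activations during the backward pass.
    block = jax.checkpoint(toy_block) if remat else toy_block
    if scan_layers:
        # One scan over the leading (layer) axis of every stacked parameter array.
        h, _ = jax.lax.scan(lambda carry, p: (block(carry, p), None), h, stacked_params)
    else:
        for i in range(stacked_params[0].shape[0]):
            h = block(h, (stacked_params[0][i], stacked_params[1][i]))
    return rms_norm(h)  # final normalization


num_layers, vocab, hidden = 4, 50, 8
k0, k1, k2 = jax.random.split(jax.random.PRNGKey(0), 3)
embedding = jax.random.normal(k0, (vocab, hidden))
stacked_params = (
    jax.random.normal(k1, (num_layers, hidden, hidden)) * 0.1,
    jax.random.normal(k2, (num_layers, hidden)) * 0.1,
)
input_ids = jnp.array([[1, 2, 3], [4, 5, 6]])

out_scan = model_forward(input_ids, embedding, stacked_params, scan_layers=True)
out_loop = model_forward(input_ids, embedding, stacked_params, scan_layers=False)
print(out_scan.shape)  # (2, 3, 8)
print(jnp.allclose(out_scan, out_loop, atol=1e-5))  # True
```

The two paths compute the same function; scanning mainly shortens compilation for deep stacks, while checkpointing trades extra forward compute for lower activation memory.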