easydel.modules.stablelm.modeling_stablelm_flax
- class easydel.modules.stablelm.modeling_stablelm_flax.StableLmAttention(*args: Any, **kwargs: Any)
Bases: AttentionModule
StableLM attention module with Rotary Position Embeddings (RoPE) and optional per-head LayerNorm on the query and key states.
- config
Configuration object for the model.
- hidden_size
Dimensionality of the hidden states.
- Type
int
- num_heads
Number of attention heads.
- Type
int
- head_dim
Dimensionality of each attention head.
- Type
int
- num_key_value_heads
Number of key/value heads (for grouped-query attention, GQA).
- Type
int
- num_key_value_groups
Number of query heads that share each key/value head.
- Type
int
- max_position_embeddings
Maximum sequence length.
- Type
int
- rope_theta
Base value for RoPE.
- Type
float
- partial_rotary_factor
Fraction of each head's dimension to which RoPE is applied.
- Type
float
- q_proj
Linear layer for query projection.
- k_proj
Linear layer for key projection.
- v_proj
Linear layer for value projection.
- o_proj
Linear layer for output projection.
- rotary_emb_dim
Number of features in each head that RoPE rotates (the rotary embedding dimensionality).
- Type
int
- attention_performer
Module for performing attention computation.
- qk_layernorm
Whether to apply LayerNorm to query and key states.
- Type
bool
- q_layernorm
LayerNorm for query states (if qk_layernorm is True).
- k_layernorm
LayerNorm for key states (if qk_layernorm is True).
- rotary
Rotary positional embedding module.
- dtype
Data type for computations.
- Type
jnp.dtype
- param_dtype
Data type for parameters.
- Type
jnp.dtype
- precision
Precision setting for matrix multiplications.
- Type
jax.lax.PrecisionLike
- rngs
Random number generators.
- Type
nn.Rngs
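The size relationships among the attributes above, partial RoPE, and GQA key/value sharing can be sketched in NumPy (jax.numpy accepts the same calls). This is a minimal illustration, not EasyDeL's implementation. The formulas `head_dim = hidden_size // num_heads`, `num_key_value_groups = num_heads // num_key_value_heads` and `rotary_emb_dim = int(partial_rotary_factor * head_dim)` are assumptions modeled on the Hugging Face StableLM code, as is the NeoX-style "rotate half" convention. The helper names are hypothetical.

```python
import numpy as np


def attention_dims(hidden_size, num_heads, num_key_value_heads, partial_rotary_factor):
    """Derive per-head sizes (assumed formulas, mirroring Hugging Face StableLM)."""
    head_dim = hidden_size // num_heads
    num_key_value_groups = num_heads // num_key_value_heads  # query heads per KV head
    rotary_emb_dim = int(partial_rotary_factor * head_dim)  # features RoPE rotates
    return head_dim, num_key_value_groups, rotary_emb_dim


def apply_partial_rope(x, positions, rotary_emb_dim, rope_theta=10000.0):
    """Rotate the first `rotary_emb_dim` features of x and pass the rest through.

    x: (seq_len, num_heads, head_dim); positions: (seq_len,) integer positions.
    """
    x_rot, x_pass = x[..., :rotary_emb_dim], x[..., rotary_emb_dim:]
    inv_freq = 1.0 / (rope_theta ** (np.arange(0, rotary_emb_dim, 2) / rotary_emb_dim))
    angles = np.outer(positions, inv_freq)            # (seq_len, rotary_emb_dim // 2)
    emb = np.concatenate([angles, angles], axis=-1)   # (seq_len, rotary_emb_dim)
    cos, sin = np.cos(emb)[:, None, :], np.sin(emb)[:, None, :]  # broadcast over heads
    half = rotary_emb_dim // 2
    rotated_half = np.concatenate([-x_rot[..., half:], x_rot[..., :half]], axis=-1)
    return np.concatenate([x_rot * cos + rotated_half * sin, x_pass], axis=-1)


def repeat_kv(kv, num_key_value_groups):
    """Expand (seq_len, kv_heads, d) to (seq_len, kv_heads * groups, d) for GQA."""
    return np.repeat(kv, num_key_value_groups, axis=1)


# Illustrative sizes only (not taken from a specific checkpoint).
head_dim, groups, rot_dim = attention_dims(2048, 32, 8, 0.25)
print(head_dim, groups, rot_dim)  # 64 4 16

rng = np.random.default_rng(0)
q = rng.standard_normal((5, 32, head_dim))
k = rng.standard_normal((5, 8, head_dim))
q_rope = apply_partial_rope(q, np.arange(5), rot_dim)
k_full = repeat_kv(apply_partial_rope(k, np.arange(5), rot_dim), groups)
print(q_rope.shape, k_full.shape)  # (5, 32, 64) (5, 32, 64)
```

`np.repeat` places each key/value head's copies next to each other. As a result, query head `j` reads key/value head `j // num_key_value_groups`, and only the first `rotary_emb_dim` features of every head carry positional information.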
- class easydel.modules.stablelm.modeling_stablelm_flax.StableLmDecoderLayer(*args: Any, **kwargs: Any)
Bases: Module
A single decoder layer of the StableLM model.
This layer combines self-attention and an MLP, each with layer normalization and a residual connection. It optionally supports parallel residual connections. In that mode, the attention and MLP branches both read the same normalized input and are added to the residual together.
- config
Configuration object for the model.
- self_attn
Self-attention module.
- Type
StableLmAttention
- mlp
MLP module.
- Type
StableLmMLP
- input_layernorm
Layer normalization applied before self-attention.
- Type
nn.LayerNorm
- post_attention_layernorm
Layer normalization applied after self-attention and before the MLP.
- Type
nn.LayerNorm
- dropout_rng_key
Name of the RNG key for dropout.
- Type
str
- dtype
Data type for computations.
- Type
jnp.dtype
- param_dtype
Data type for parameters.
- Type
jnp.dtype
- precision
Precision setting for matrix multiplications.
- Type
jax.lax.PrecisionLike
- rngs
Random number generators.
- Type
nn.Rngs
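The difference between the sequential and parallel residual paths can be sketched as follows. This is a hypothetical NumPy illustration. `attn` and `mlp` stand in for the self-attention and MLP sub-modules, the layer norms omit their learned scale and bias, and dropout is left out. The exact wiring is assumed to follow the Hugging Face StableLM decoder layer.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    """LayerNorm over the last axis, without learned scale/bias (simplification)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def decoder_layer(x, attn, mlp, use_parallel_residual):
    """One decoder layer; `attn` and `mlp` are placeholder callables."""
    if use_parallel_residual:
        # x + attn(ln1(x)) + mlp(ln1(x)): both branches read the same normalized
        # input, so this sketch needs no second norm on the parallel path.
        h = layer_norm(x)  # input_layernorm
        return x + attn(h) + mlp(h)
    # Sequential: x = x + attn(ln1(x)); x = x + mlp(ln2(x))
    x = x + attn(layer_norm(x))    # input_layernorm -> self-attention -> residual
    return x + mlp(layer_norm(x))  # post_attention_layernorm -> MLP -> residual


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 4, 8))  # (batch, seq_len, hidden_size)
attn = lambda h: np.sin(h)          # toy nonlinear stand-in for self-attention
mlp = lambda h: np.tanh(h)          # toy stand-in for the MLP
sequential = decoder_layer(x, attn, mlp, use_parallel_residual=False)
parallel = decoder_layer(x, attn, mlp, use_parallel_residual=True)
print(sequential.shape, parallel.shape)  # (2, 4, 8) (2, 4, 8)
```

In the parallel form, the two branches do not depend on each other, which lets them be computed concurrently. In the sequential form, the MLP sees the attention output.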
- class easydel.modules.stablelm.modeling_stablelm_flax.StableLmForCausalLM(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule
StableLM model with a causal language modeling (CLM) head.
This class wraps the base StableLmModel and adds a linear layer (the language model head) that maps the final hidden states to next-token logits over the vocabulary.
- config
Configuration object for the model.
- model
The base StableLM model.
- Type
StableLmModel
- lm_head
The language model head (linear layer).
- dtype
Data type for computations.
- Type
jnp.dtype
- param_dtype
Data type for parameters.
- Type
jnp.dtype
- precision
Precision setting for matrix multiplications.
- Type
jax.lax.PrecisionLike
- rngs
Random number generators.
- Type
nn.Rngs
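The sketch below shows how a language model head turns hidden states into next-token logits and how a causal LM loss lines those logits up with the inputs. This is a hypothetical NumPy illustration, not EasyDeL's API. `lm_head_kernel` stands in for the lm_head weights (no bias assumed). The shift-by-one cross-entropy is the standard causal LM objective, not a confirmed detail of StableLmForCausalLM.

```python
import numpy as np


def lm_head_logits(hidden_states, lm_head_kernel):
    """(batch, seq_len, hidden) @ (hidden, vocab) -> (batch, seq_len, vocab)."""
    return hidden_states @ lm_head_kernel


def causal_lm_loss(logits, input_ids):
    """Mean cross-entropy where position t predicts token t + 1."""
    shift_logits = logits[:, :-1, :]  # drop the last position (no next token)
    shift_labels = input_ids[:, 1:]   # drop the first token (nothing predicts it)
    # Numerically stable log-softmax over the vocabulary axis.
    m = shift_logits.max(axis=-1, keepdims=True)
    log_norm = m + np.log(np.exp(shift_logits - m).sum(axis=-1, keepdims=True))
    log_probs = shift_logits - log_norm
    nll = -np.take_along_axis(log_probs, shift_labels[..., None], axis=-1)[..., 0]
    return nll.mean()


def greedy_next_token(logits):
    """Pick the highest-scoring token at the last position of each sequence."""
    return logits[:, -1, :].argmax(axis=-1)


rng = np.random.default_rng(0)
hidden = rng.standard_normal((2, 6, 16))  # stand-in for the base model's output
kernel = rng.standard_normal((16, 100))   # toy vocabulary of 100 tokens
input_ids = rng.integers(0, 100, size=(2, 6))
logits = lm_head_logits(hidden, kernel)
print(logits.shape)                           # (2, 6, 100)
print(causal_lm_loss(logits, input_ids) > 0)  # True
print(greedy_next_token(logits).shape)        # (2,)
```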
- class easydel.modules.stablelm.modeling_stablelm_flax.StableLmLayerNormPerHead(*args: Any, **kwargs: Any)
Bases: Module
Applies a separate LayerNorm to each attention head, normalizing over that head's feature dimension.
- norms
List of LayerNorm modules, one per head.
- Type
list[nn.LayerNorm]
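The "one LayerNorm per head" structure can be sketched in NumPy in two ways. The first is a loop that mirrors the `norms` list; the second is an equivalent vectorized form that stacks the per-head scales. This is an illustrative sketch. The array layout (batch, seq_len, num_heads, head_dim) and the scale-only (bias-free) norms are assumptions modeled on the Hugging Face StableLM implementation.

```python
import numpy as np


def layer_norm(x, scale, eps=1e-5):
    """LayerNorm over the last axis with a learned scale and no bias (assumed)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * scale


def layer_norm_per_head_loop(x, scales):
    """x: (batch, seq_len, num_heads, head_dim); scales: one (head_dim,) array per head."""
    return np.stack([layer_norm(x[:, :, h, :], s) for h, s in enumerate(scales)], axis=2)


def layer_norm_per_head(x, scales):
    """Same result in one call: stacked scales of shape (num_heads, head_dim) broadcast."""
    return layer_norm(x, np.stack(scales))


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 5, 4, 8))                       # 4 heads, head_dim = 8
scales = [rng.uniform(0.5, 1.5, size=8) for _ in range(4)]  # separate params per head
out = layer_norm_per_head(x, scales)
print(out.shape)  # (2, 5, 4, 8)
```

Because the statistics are taken over `head_dim` alone, the heads never mix. What the per-head list adds is a separate set of scale parameters for each head.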
- class easydel.modules.stablelm.modeling_stablelm_flax.StableLmMLP(*args: Any, **kwargs: Any)
Bases: Module
Multi-Layer Perceptron (MLP) block for the StableLM model.
- config
Configuration object for the model.
- gate_proj
Linear layer for the gating mechanism.
- down_proj
Linear layer for down-projection.
- up_proj
Linear layer for up-projection.
- act_fn
Activation function (specified in config).
- Type
callable
- dtype
Data type for computations.
- Type
jnp.dtype
- param_dtype
Data type for parameters.
- Type
jnp.dtype
- precision
Precision setting for matrix multiplications.
- Type
jax.lax.PrecisionLike
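The attributes above describe a gated MLP. Here is a minimal NumPy sketch. It assumes the SwiGLU-style wiring `down_proj(act_fn(gate_proj(x)) * up_proj(x))` used by the Hugging Face StableLM MLP, SiLU as `act_fn`, and bias-free projections. The weight arrays are hypothetical stand-ins for the three linear layers.

```python
import numpy as np


def silu(x):
    """SiLU / swish activation: x * sigmoid(x)."""
    return x / (1.0 + np.exp(-x))


def gated_mlp(x, w_gate, w_up, w_down, act_fn=silu):
    """x: (..., hidden); w_gate, w_up: (hidden, intermediate); w_down: (intermediate, hidden)."""
    gate = act_fn(x @ w_gate)     # gate_proj followed by the activation
    up = x @ w_up                 # up_proj
    return (gate * up) @ w_down   # element-wise gating, then down_proj back to hidden size


rng = np.random.default_rng(0)
hidden_size, intermediate_size = 8, 32
x = rng.standard_normal((2, 5, hidden_size))
w_gate = rng.standard_normal((hidden_size, intermediate_size))
w_up = rng.standard_normal((hidden_size, intermediate_size))
w_down = rng.standard_normal((intermediate_size, hidden_size))
print(gated_mlp(x, w_gate, w_up, w_down).shape)  # (2, 5, 8)
```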
- class easydel.modules.stablelm.modeling_stablelm_flax.StableLmModel(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule
The base StableLM transformer model.
This class implements the core transformer architecture: the token embedding layer, the stack of decoder layers, and the final layer normalization.
- config
Configuration object for the model.
- embed_tokens
Embedding layer for input tokens.
- Type
nn.Embed
- layers
List of decoder layers.
- Type
nn.List[StableLmDecoderLayer]
- norm
Final layer normalization.
- Type
nn.LayerNorm
- gradient_checkpointing
Gradient checkpointing strategy.
- Type
str
- dtype
Data type for computations.
- Type
jnp.dtype
- param_dtype
Data type for parameters.
- Type
jnp.dtype
- precision
Precision setting for matrix multiplications.
- Type
jax.lax.PrecisionLike
- rngs
Random number generators.
- Type
nn.Rngs
- property frequencies
Cached property for precomputed rotary frequencies.
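The forward pass described above (embedding lookup, decoder stack, final norm) and a cached `frequencies` table can be sketched as a toy class. Everything here is hypothetical and stands in for EasyDeL's actual modules, whose internals may differ. That includes the class name, the use of `functools.cached_property`, the (cos, sin) layout of the cached table, and the callable layers.

```python
import functools

import numpy as np


class ToyStableLmModel:
    """Hypothetical stand-in: embed_tokens -> layers -> norm, plus cached RoPE tables."""

    def __init__(self, vocab_size, hidden_size, max_position_embeddings,
                 rotary_emb_dim, layers, rope_theta=10000.0, seed=0):
        rng = np.random.default_rng(seed)
        self.embed_tokens = rng.standard_normal((vocab_size, hidden_size))
        self.layers = layers  # callables taking (hidden_states, frequencies)
        self.max_position_embeddings = max_position_embeddings
        self.rotary_emb_dim = rotary_emb_dim
        self.rope_theta = rope_theta

    @functools.cached_property
    def frequencies(self):
        """Computed on first access, then reused for every later forward pass."""
        d = self.rotary_emb_dim
        inv_freq = 1.0 / (self.rope_theta ** (np.arange(0, d, 2) / d))
        angles = np.outer(np.arange(self.max_position_embeddings), inv_freq)
        return np.cos(angles), np.sin(angles)  # each (max_position_embeddings, d // 2)

    @staticmethod
    def norm(x, eps=1e-5):
        """Final LayerNorm (learned scale and bias omitted for brevity)."""
        mean = x.mean(axis=-1, keepdims=True)
        return (x - mean) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

    def __call__(self, input_ids):
        hidden_states = self.embed_tokens[input_ids]  # (batch, seq_len, hidden_size)
        for layer in self.layers:
            hidden_states = layer(hidden_states, self.frequencies)
        return self.norm(hidden_states)


# Two trivial toy layers that ignore the frequency table.
model = ToyStableLmModel(vocab_size=50, hidden_size=16, max_position_embeddings=128,
                         rotary_emb_dim=8, layers=[lambda h, f: h + 0.1 * h] * 2)
out = model(np.array([[1, 2, 3], [4, 5, 6]]))
print(out.shape)                   # (2, 3, 16)
print(model.frequencies[0].shape)  # (128, 4)
```

Caching the table means the trigonometric work is done once per model instance rather than once per layer or per call.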