easydel.modules.stablelm.modeling_stablelm
- class easydel.modules.stablelm.modeling_stablelm.StableLmAttention(*args: Any, **kwargs: Any)
Bases: UnifiedAttention
StableLM Attention with Q/K normalization.
Inherits Q/K normalization from QKNormAttention. Features:
  - Uses LayerNorm instead of RMSNorm
  - Per-head normalization (StableLmLayerNormPerHead)
  - Partial RoPE (partial_rotary_factor)
- norms_mapping: ClassVar = {'key_normalization': 'k_layernorm', 'query_normalization': 'q_layernorm'}
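Partial RoPE can be illustrated with a minimal JAX sketch. This is not EasyDeL's implementation: the function name `apply_partial_rope`, the `[batch, seq, heads, head_dim]` layout, the "rotate half" convention, and the default values are assumptions made for illustration. The idea shown is that only the first `head_dim * partial_rotary_factor` features of each head are rotated, while the rest pass through unchanged.

```python
import jax.numpy as jnp


def apply_partial_rope(x, positions, partial_rotary_factor=0.25, base=10000.0):
    """Rotate only the leading slice of each head (hypothetical helper).

    x: [batch, seq, heads, head_dim]; positions: [seq].
    Assumes head_dim * partial_rotary_factor is an even integer.
    """
    head_dim = x.shape[-1]
    rotary_dim = int(head_dim * partial_rotary_factor)
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]

    # Inverse frequencies are computed over the rotated slice only.
    inv_freq = 1.0 / (base ** (jnp.arange(0, rotary_dim, 2) / rotary_dim))
    angles = positions[:, None] * inv_freq[None, :]        # [seq, rotary_dim / 2]
    angles = jnp.concatenate([angles, angles], axis=-1)    # [seq, rotary_dim]
    cos = jnp.cos(angles)[None, :, None, :]
    sin = jnp.sin(angles)[None, :, None, :]

    # "Rotate half" convention: (x1, x2) -> (-x2, x1).
    x1, x2 = jnp.split(x_rot, 2, axis=-1)
    rotated = jnp.concatenate([-x2, x1], axis=-1)
    x_rot = x_rot * cos + rotated * sin

    # Features beyond rotary_dim carry no positional rotation.
    return jnp.concatenate([x_rot, x_pass], axis=-1)
```

With `head_dim = 16` and a factor of `0.25`, only four features per head are rotated; the remaining twelve are returned untouched, and position 0 is left unchanged because its rotation angle is zero.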
- class easydel.modules.stablelm.modeling_stablelm.StableLmDecoderLayer(*args: Any, **kwargs: Any)
Bases: Module
A single decoder layer for the StableLM model.
This layer combines self-attention, MLP, and residual connections with layer normalization. It supports parallel residual connections.
- config
Configuration object for the model.
- Type
- self_attn
Self-attention module.
- Type
- mlp
MLP module.
- Type
- input_layernorm
Layer normalization applied before self-attention.
- Type
nn.LayerNorm
- post_attention_layernorm
Layer normalization applied after self-attention and before the MLP.
- Type
nn.LayerNorm
- dropout_rng_key
Name of the RNG key for dropout.
- Type
str
- dtype
Data type for computations.
- Type
jnp.dtype
- param_dtype
Data type for parameters.
- Type
jnp.dtype
- precision
Precision setting for matrix multiplications.
- Type
jax.lax.PrecisionLike
- rngs
Random number generators.
- Type
nn.Rngs
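The difference between parallel and sequential residual connections can be sketched as follows. This is a toy illustration of one common formulation, not the layer's actual code: `decoder_block`, the parameter-free `layer_norm`, and the `attn_fn` / `mlp_fn` callables are hypothetical stand-ins. In a real layer each normalization has its own learned parameters.

```python
import jax.numpy as jnp


def layer_norm(x, eps=1e-5):
    # Parameter-free LayerNorm over the last axis (no learned scale or bias).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


def decoder_block(x, attn_fn, mlp_fn, parallel_residual):
    normed = layer_norm(x)          # plays the role of input_layernorm
    attn_out = attn_fn(normed)
    if parallel_residual:
        # Attention and MLP both read the same normalized input,
        # and their outputs are added to the residual in one step.
        return x + attn_out + mlp_fn(normed)
    # Sequential: the MLP sees the attention-updated residual,
    # normalized again (the role of post_attention_layernorm).
    x = x + attn_out
    return x + mlp_fn(layer_norm(x))
```

In the parallel form the MLP does not depend on the attention output, so the two branches can be computed independently before being summed.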
- class easydel.modules.stablelm.modeling_stablelm.StableLmForCausalLM(*args: Any, **kwargs: Any)
Bases: BaseCausalLMModule[StableLmModel, StableLmConfig]
StableLM model with a Causal Language Modeling (CLM) head.
- class easydel.modules.stablelm.modeling_stablelm.StableLmLayerNormPerHead(*args: Any, **kwargs: Any)
Bases: Module
Applies Layer Normalization independently to each attention head’s dimension.
- norms
List of LayerNorm modules, one per head.
- Type
list[nn.LayerNorm]
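A minimal sketch of per-head layer normalization, assuming inputs shaped `[batch, seq, heads, head_dim]` and one learned scale vector per head with no bias. The helper names and the input layout are illustrative assumptions rather than the module's actual interface.

```python
import jax.numpy as jnp


def layer_norm(x, scale, eps=1e-5):
    # Normalize over the last axis, then apply a learned scale.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps) * scale


def layer_norm_per_head(x, scales, eps=1e-5):
    """Apply a separate LayerNorm to each head (hypothetical helper).

    x: [batch, seq, heads, head_dim]
    scales: one [head_dim] scale vector per head (a list or a stacked array).
    """
    num_heads = x.shape[2]
    # Mirrors "one LayerNorm per head": head h is normalized with scales[h].
    per_head = [layer_norm(x[:, :, h, :], scales[h], eps) for h in range(num_heads)]
    return jnp.stack(per_head, axis=2)
```

Because each head has its own parameters, changing `scales[1]` affects only head 1; the statistics are likewise computed per head over `head_dim`, never across heads.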
- class easydel.modules.stablelm.modeling_stablelm.StableLmMLP(*args: Any, **kwargs: Any)
Bases: Module
Multi-Layer Perceptron (MLP) block for the StableLM model.
- config
Configuration object for the model.
- Type
- gate_proj
Linear layer for the gating mechanism.
- Type
- down_proj
Linear layer for down-projection.
- Type
- up_proj
Linear layer for up-projection.
- Type
- act_fn
Activation function (specified in config).
- Type
callable
- dtype
Data type for computations.
- Type
jnp.dtype
- param_dtype
Data type for parameters.
- Type
jnp.dtype
- precision
Precision setting for matrix multiplications.
- Type
jax.lax.PrecisionLike
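The attribute names suggest a gated MLP, in which the activated gate projection multiplies the up-projection before the down-projection. A minimal sketch under that assumption follows. The function `gated_mlp`, the plain weight matrices, and the SiLU default are illustrative choices; the actual activation is whatever the config specifies.

```python
import jax
import jax.numpy as jnp


def gated_mlp(x, w_gate, w_up, w_down, act_fn=jax.nn.silu):
    """Gated feed-forward block (hypothetical helper, no biases).

    x: [..., hidden]; w_gate, w_up: [hidden, intermediate];
    w_down: [intermediate, hidden].
    """
    gate = act_fn(x @ w_gate)    # role of gate_proj followed by act_fn
    up = x @ w_up                # role of up_proj
    return (gate * up) @ w_down  # role of down_proj
```

The output has the same trailing dimension as the input, so the block can be added directly to a residual stream.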
- class easydel.modules.stablelm.modeling_stablelm.StableLmModel(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule
The base StableLM transformer model.
This class implements the core transformer architecture, including embedding layers, decoder layers, and final normalization.
- config
Configuration object for the model.
- Type
- embed_tokens
Embedding layer for input tokens.
- Type
nn.Embed
- layers
List of decoder layers.
- Type
nn.List[StableLmDecoderLayer]
- norm
Final layer normalization.
- Type
nn.LayerNorm
- gradient_checkpointing
Gradient checkpointing strategy.
- Type
str
- dtype
Data type for computations.
- Type
jnp.dtype
- param_dtype
Data type for parameters.
- Type
jnp.dtype
- precision
Precision setting for matrix multiplications.
- Type
jax.lax.PrecisionLike
- rngs
Random number generators.
- Type
nn.Rngs
- property frequencies
Cached property for precomputed rotary frequencies.
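The pattern of a cached property holding precomputed rotary frequencies can be sketched as below. The class `RotaryTable`, its constructor arguments, and the returned `(cos, sin)` pair are assumptions for illustration. The real property may return a different structure.

```python
from functools import cached_property

import jax.numpy as jnp


class RotaryTable:
    """Hypothetical holder for a rotary cos/sin lookup table."""

    def __init__(self, rotary_dim, max_positions, base=10000.0):
        self.rotary_dim = rotary_dim
        self.max_positions = max_positions
        self.base = base

    @cached_property
    def frequencies(self):
        # Computed on first access, then stored on the instance and reused.
        exponents = jnp.arange(0, self.rotary_dim, 2) / self.rotary_dim
        inv_freq = 1.0 / (self.base ** exponents)
        positions = jnp.arange(self.max_positions)
        angles = jnp.outer(positions, inv_freq)  # [max_positions, rotary_dim / 2]
        return jnp.cos(angles), jnp.sin(angles)
```

Since the table depends only on static configuration values, computing it once and indexing it by position avoids recomputing trigonometric functions on every forward pass.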