easydel.modules.stablelm.modeling_stablelm#

class easydel.modules.stablelm.modeling_stablelm.StableLmAttention(*args: Any, **kwargs: Any)[source]#

Bases: UnifiedAttention

StableLM Attention with Q/K normalization.

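Inherits Q/K normalization from its base class, UnifiedAttention; the query and key norm submodules it uses are named in norms_mapping (q_layernorm and k_layernorm). Features:

- Uses LayerNorm instead of RMSNorm
- Per-head normalization (StableLmLayerNormPerHead)
- Partial RoPE (partial_rotary_factor)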

norms_mapping: ClassVar = {'key_normalization': 'k_layernorm', 'query_normalization': 'q_layernorm'}#
class easydel.modules.stablelm.modeling_stablelm.StableLmDecoderLayer(*args: Any, **kwargs: Any)[source]#

Bases: Module

A single decoder layer for the StableLM model.

This layer combines self-attention and an MLP, wrapped in residual connections with layer normalization. It also supports parallel residual connections.

config#

Configuration object for the model.

Type

StableLmConfig

self_attn#

Self-attention module.

Type

StableLmAttention

mlp#

MLP module.

Type

StableLmMLP

input_layernorm#

Layer normalization applied before self-attention.

Type

nn.LayerNorm

post_attention_layernorm#

Layer normalization applied after self-attention and before the MLP.

Type

nn.LayerNorm

dropout_rng_key#

Name of the RNG key for dropout.

Type

str

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for matrix multiplications.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

class easydel.modules.stablelm.modeling_stablelm.StableLmForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: BaseCausalLMModule[StableLmModel, StableLmConfig]

StableLM model with a Causal Language Modeling (CLM) head.

get_decoder()[source]#

Returns the decoder part of the model’s graph definition.

get_embedding()[source]#

Returns the embedding layer of the module.

get_encoder()[source]#

Returns the encoder part of the model’s graph definition. Decoder-only models don’t have an encoder.

get_lm_head()[source]#

Returns the language model head of the module.

class easydel.modules.stablelm.modeling_stablelm.StableLmLayerNormPerHead(*args: Any, **kwargs: Any)[source]#

Bases: Module

Applies layer normalization independently over each attention head’s feature dimension.

norms#

List of LayerNorm modules, one per head.

Type

list[nn.LayerNorm]

class easydel.modules.stablelm.modeling_stablelm.StableLmMLP(*args: Any, **kwargs: Any)[source]#

Bases: Module

Multi-Layer Perceptron (MLP) block for the StableLM model.

config#

Configuration object for the model.

Type

StableLmConfig

gate_proj#

Linear layer for the gating mechanism.

Type

ParallelLinear

down_proj#

Linear layer for down-projection.

Type

ParallelLinear

up_proj#

Linear layer for up-projection.

Type

ParallelLinear

act_fn#

Activation function (specified in config).

Type

callable

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for matrix multiplications.

Type

jax.lax.PrecisionLike

class easydel.modules.stablelm.modeling_stablelm.StableLmModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

The base StableLM transformer model.

This class implements the core transformer architecture, including embedding layers, decoder layers, and final normalization.

config#

Configuration object for the model.

Type

StableLmConfig

embed_tokens#

Embedding layer for input tokens.

Type

nn.Embed

layers#

List of decoder layers.

Type

nn.List[StableLmDecoderLayer]

norm#

Final layer normalization.

Type

nn.LayerNorm

gradient_checkpointing#

Gradient checkpointing strategy.

Type

str

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for matrix multiplications.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

property frequencies#

Cached property for precomputed rotary frequencies.

get_decoder()[source]#

Returns the decoder part of the model’s graph definition.

get_embedding()[source]#

Returns the embedding layer of the module.

get_encoder()[source]#

Returns the encoder part of the model’s graph definition. Decoder-only models don’t have an encoder.

get_lm_head()[source]#

Returns the language model head of the module. Base models don’t have a language model head.