easydel.modules.stablelm.modeling_stablelm_flax

class easydel.modules.stablelm.modeling_stablelm_flax.StableLmAttention(*args: Any, **kwargs: Any)#

Bases: AttentionModule

StableLM attention module with rotary position embeddings (RoPE) and optional per-head LayerNorm on the query and key states.
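
The partial rotary scheme described by this class can be sketched in plain JAX. This is a minimal illustration, not EasyDeL's implementation: the function names `rotate_half` and `apply_partial_rope`, the rotate-half pairing, and the rounding rule `int(partial_rotary_factor * head_dim)` are assumptions made for the example.

```python
import jax.numpy as jnp


def rotate_half(x):
    """Map the two halves of the last axis (x1, x2) to (-x2, x1)."""
    x1, x2 = jnp.split(x, 2, axis=-1)
    return jnp.concatenate([-x2, x1], axis=-1)


def apply_partial_rope(x, positions, rope_theta=10000.0, partial_rotary_factor=0.25):
    """Rotate only the first `rotary_dim` features of each head.

    x: [batch, seq, heads, head_dim]; positions: [seq].
    """
    head_dim = x.shape[-1]
    rotary_dim = int(partial_rotary_factor * head_dim)  # assumed rounding rule
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]

    # Standard RoPE inverse frequencies: rope_theta ** (-2i / rotary_dim).
    inv_freq = 1.0 / (rope_theta ** (jnp.arange(0, rotary_dim, 2) / rotary_dim))
    angles = positions[:, None] * inv_freq[None, :]      # [seq, rotary_dim // 2]
    angles = jnp.concatenate([angles, angles], axis=-1)  # [seq, rotary_dim]
    cos = jnp.cos(angles)[None, :, None, :]
    sin = jnp.sin(angles)[None, :, None, :]

    x_rot = x_rot * cos + rotate_half(x_rot) * sin
    # Features beyond rotary_dim pass through without positional rotation.
    return jnp.concatenate([x_rot, x_pass], axis=-1)


q = jnp.ones((1, 4, 2, 16))
out = apply_partial_rope(q, jnp.arange(4))
```

With `head_dim=16` and a factor of `0.25`, only the first 4 features of each head are rotated; the remaining 12 are returned unchanged.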

config#

Configuration object for the model.

Type

StableLmConfig

hidden_size#

Dimensionality of the hidden states.

Type

int

num_heads#

Number of attention heads.

Type

int

head_dim#

Dimensionality of each attention head.

Type

int

num_key_value_heads#

Number of key/value heads, used for grouped-query attention (GQA).

Type

int

num_key_value_groups#

Number of query heads that share each key/value head (num_heads // num_key_value_heads).

Type

int

max_position_embeddings#

Maximum sequence length.

Type

int

rope_theta#

Base value (theta) used to compute the RoPE frequencies.

Type

float

partial_rotary_factor#

Fraction of each head's dimension to which RoPE is applied.

Type

float

q_proj#

Linear layer for query projection.

Type

ParallelLinear

k_proj#

Linear layer for key projection.

Type

ParallelLinear

v_proj#

Linear layer for value projection.

Type

ParallelLinear

o_proj#

Linear layer for output projection.

Type

ParallelLinear

rotary_emb_dim#

Number of features in each head that RoPE rotates.

Type

int

attention_performer#

Module for performing attention computation.

Type

FlexibleAttentionModule

qk_layernorm#

Whether to apply LayerNorm to query and key states.

Type

bool

q_layernorm#

Per-head LayerNorm for query states (used only if qk_layernorm is True).

Type

StableLmLayerNormPerHead

k_layernorm#

Per-head LayerNorm for key states (used only if qk_layernorm is True).

Type

StableLmLayerNormPerHead

rotary#

Rotary positional embedding module.

Type

RotaryEmbedding

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for matrix multiplications.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs
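
The relationship between `num_heads`, `num_key_value_heads`, and `num_key_value_groups` listed above can be sketched as follows. This is a simplified stand-in for the attention computation, assuming key/value heads are repeated to match the query heads; the function name `grouped_query_attention` is hypothetical, and the real module delegates this work to `attention_performer`.

```python
import jax
import jax.numpy as jnp


def grouped_query_attention(q, k, v):
    """Causal attention where several query heads share one key/value head.

    q: [batch, seq, num_heads, head_dim]
    k, v: [batch, seq, num_key_value_heads, head_dim]
    """
    num_heads, head_dim = q.shape[2], q.shape[3]
    num_key_value_groups = num_heads // k.shape[2]

    # Repeat each key/value head so it lines up with its group of query heads.
    k = jnp.repeat(k, num_key_value_groups, axis=2)
    v = jnp.repeat(v, num_key_value_groups, axis=2)

    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(head_dim)
    seq = q.shape[1]
    causal = jnp.tril(jnp.ones((seq, seq), dtype=bool))
    # Masked positions get the most negative finite value before the softmax.
    scores = jnp.where(causal, scores, jnp.finfo(scores.dtype).min)
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", weights, v)


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (1, 5, 8, 4))  # 8 query heads
k = jax.random.normal(kk, (1, 5, 2, 4))  # 2 key/value heads, so 4 groups
v = jax.random.normal(kv, (1, 5, 2, 4))
out = grouped_query_attention(q, k, v)
```

Here query heads 0 to 3 read from key/value head 0 and query heads 4 to 7 read from key/value head 1, so the key/value projections are a quarter of the size of the query projection.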

class easydel.modules.stablelm.modeling_stablelm_flax.StableLmDecoderLayer(*args: Any, **kwargs: Any)#

Bases: Module

A single decoder layer for the StableLM model.

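This layer combines self-attention and an MLP, each wrapped in a residual connection and preceded by layer normalization. It also supports a parallel residual layout, in which the attention and MLP branches read the same normalized input and their outputs are added to the residual together.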
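
The difference between the sequential and parallel residual layouts can be sketched as below. This is an illustrative outline only: the flag name `use_parallel_residual`, the parameter-free `layer_norm` helper, and the stand-in `attn_fn` and `mlp_fn` callables are assumptions for the example, not the layer's actual code.

```python
import jax.numpy as jnp


def layer_norm(x, eps=1e-5):
    """Parameter-free LayerNorm over the last axis (scale and bias omitted)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


def decoder_layer(x, attn_fn, mlp_fn, use_parallel_residual):
    """Combine attention and MLP outputs with one of two residual layouts."""
    normed = layer_norm(x)
    attn_out = attn_fn(normed)
    if use_parallel_residual:
        # Both branches read the same normalized input.
        return x + attn_out + mlp_fn(normed)
    # Sequential layout: the MLP sees the attention update, normalized again.
    h = x + attn_out
    return h + mlp_fn(layer_norm(h))


x = jnp.arange(8.0).reshape(2, 4)
attn_fn = jnp.tanh            # stand-in for self-attention
mlp_fn = lambda h: h * h      # stand-in for the MLP
parallel = decoder_layer(x, attn_fn, mlp_fn, use_parallel_residual=True)
sequential = decoder_layer(x, attn_fn, mlp_fn, use_parallel_residual=False)
```

In the parallel layout the two branches have no data dependency on each other, which is why only one normalization of the input is needed in this sketch.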

config#

Configuration object for the model.

Type

StableLmConfig

self_attn#

Self-attention module.

Type

StableLmAttention

mlp#

MLP module.

Type

StableLmMLP

input_layernorm#

Layer normalization applied before self-attention.

Type

nn.LayerNorm

post_attention_layernorm#

Layer normalization applied after self-attention and before the MLP.

Type

nn.LayerNorm

dropout_rng_key#

Name of the RNG key for dropout.

Type

str

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for matrix multiplications.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

class easydel.modules.stablelm.modeling_stablelm_flax.StableLmForCausalLM(*args: Any, **kwargs: Any)#

Bases: EasyDeLBaseModule

StableLM model with a Causal Language Modeling (CLM) head.

This class wraps the base StableLmModel and adds a linear layer (the language model head) that maps the final hidden states to next-token logits.
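
What the language model head contributes can be sketched with a plain matrix multiplication. The example below is a hedged illustration rather than the class's forward pass: `lm_head_logits`, `greedy_next_token`, and `causal_lm_loss` are hypothetical helper names, and the head is assumed to have no bias.

```python
import jax
import jax.numpy as jnp


def lm_head_logits(hidden_states, lm_head_kernel):
    """Project hidden states [batch, seq, hidden] to logits [batch, seq, vocab]."""
    return hidden_states @ lm_head_kernel


def greedy_next_token(logits):
    """Pick the highest-scoring token at the last position of each sequence."""
    return jnp.argmax(logits[:, -1, :], axis=-1)


def causal_lm_loss(logits, input_ids):
    """Mean cross-entropy of predicting token t + 1 from position t."""
    log_probs = jax.nn.log_softmax(logits[:, :-1, :], axis=-1)
    targets = input_ids[:, 1:]
    picked = jnp.take_along_axis(log_probs, targets[..., None], axis=-1)
    return -picked.mean()


k_hidden, k_kernel, k_ids = jax.random.split(jax.random.PRNGKey(0), 3)
hidden = jax.random.normal(k_hidden, (2, 6, 8))     # stand-in for StableLmModel output
kernel = jax.random.normal(k_kernel, (8, 32))       # hidden_size=8, vocab_size=32
input_ids = jax.random.randint(k_ids, (2, 6), 0, 32)

logits = lm_head_logits(hidden, kernel)
next_ids = greedy_next_token(logits)
loss = causal_lm_loss(logits, input_ids)
```

The one-position shift in `causal_lm_loss` is what makes the objective causal: the logits at position t are scored against the token at position t + 1.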

config#

Configuration object for the model.

Type

StableLmConfig

model#

The base StableLM model.

Type

StableLmModel

lm_head#

The language model head (linear layer).

Type

ParallelLinear

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for matrix multiplications.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

class easydel.modules.stablelm.modeling_stablelm_flax.StableLmLayerNormPerHead(*args: Any, **kwargs: Any)#

Bases: Module

Applies a separate LayerNorm to each attention head, normalizing over that head’s feature dimension.

norms#

List of LayerNorm modules, one per head.

Type

list[nn.LayerNorm]
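
A per-head normalization of this kind might look like the following sketch, which loops over heads in the same spirit as the `norms` list. The helper names, the presence of a bias term, and the `[batch, seq, num_heads, head_dim]` layout are assumptions for illustration.

```python
import jax
import jax.numpy as jnp


def layer_norm(x, scale, bias, eps=1e-5):
    """LayerNorm over the last axis with a learned scale and bias."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps) * scale + bias


def layer_norm_per_head(x, scales, biases):
    """Normalize every head with its own parameters.

    x: [batch, seq, num_heads, head_dim]; scales, biases: [num_heads, head_dim].
    """
    outs = [
        layer_norm(x[:, :, h, :], scales[h], biases[h])
        for h in range(x.shape[2])
    ]
    return jnp.stack(outs, axis=2)


num_heads, head_dim = 4, 8
x = jax.random.normal(jax.random.PRNGKey(0), (2, 3, num_heads, head_dim))
# Give head h a constant scale of h + 1 so the heads are visibly independent.
scales = jnp.arange(1.0, num_heads + 1)[:, None] * jnp.ones((num_heads, head_dim))
biases = jnp.zeros((num_heads, head_dim))
out = layer_norm_per_head(x, scales, biases)
```

Because each head has its own scale and bias, the statistics of one head never influence another, unlike a single LayerNorm over the full hidden dimension.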

class easydel.modules.stablelm.modeling_stablelm_flax.StableLmMLP(*args: Any, **kwargs: Any)#

Bases: Module

Multi-Layer Perceptron (MLP) block for the StableLM model.

config#

Configuration object for the model.

Type

StableLmConfig

gate_proj#

Linear layer for the gating mechanism.

Type

ParallelLinear

down_proj#

Linear layer for down-projection.

Type

ParallelLinear

up_proj#

Linear layer for up-projection.

Type

ParallelLinear

act_fn#

Activation function (specified in config).

Type

callable

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for matrix multiplications.

Type

jax.lax.PrecisionLike
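
How `gate_proj`, `up_proj`, `down_proj`, and `act_fn` typically fit together can be sketched as below. This assumes the common gated layout `down_proj(act_fn(gate_proj(x)) * up_proj(x))` with bias-free projections and SiLU as the activation; the function name `gated_mlp` and the kernel arguments are hypothetical.

```python
import jax
import jax.numpy as jnp


def gated_mlp(x, gate_kernel, up_kernel, down_kernel, act_fn=jax.nn.silu):
    """Compute down_proj(act_fn(gate_proj(x)) * up_proj(x)) without biases."""
    gate = act_fn(x @ gate_kernel)   # [..., intermediate_size]
    up = x @ up_kernel               # [..., intermediate_size]
    return (gate * up) @ down_kernel  # back to [..., hidden_size]


hidden_size, intermediate_size = 8, 16
k_x, k_gate, k_up, k_down = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(k_x, (2, 3, hidden_size))
gate_kernel = jax.random.normal(k_gate, (hidden_size, intermediate_size))
up_kernel = jax.random.normal(k_up, (hidden_size, intermediate_size))
down_kernel = jax.random.normal(k_down, (intermediate_size, hidden_size))

out = gated_mlp(x, gate_kernel, up_kernel, down_kernel)
```

The gate branch decides, feature by feature, how much of the up-projected signal reaches the down projection.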

class easydel.modules.stablelm.modeling_stablelm_flax.StableLmModel(*args: Any, **kwargs: Any)#

Bases: EasyDeLBaseModule

The base StableLM transformer model.

This class implements the core transformer architecture, including embedding layers, decoder layers, and final normalization.

config#

Configuration object for the model.

Type

StableLmConfig

embed_tokens#

Embedding layer for input tokens.

Type

nn.Embed

layers#

List of decoder layers.

Type

nn.List[StableLmDecoderLayer]

norm#

Final layer normalization.

Type

nn.LayerNorm

gradient_checkpointing#

Gradient checkpointing strategy.

Type

str

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for matrix multiplications.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

property frequencies#

Cached property for precomputed rotary frequencies.
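
The idea of caching a precomputed frequency table can be sketched with `functools.cached_property`. The class `RotaryTable` below is hypothetical and only illustrates the pattern; the table layout and the way EasyDeL actually builds and stores its frequencies may differ.

```python
from functools import cached_property

import jax.numpy as jnp


class RotaryTable:
    """Hypothetical holder that builds the RoPE angle table once and reuses it."""

    def __init__(self, max_position_embeddings, rotary_dim, rope_theta=10000.0):
        self.max_position_embeddings = max_position_embeddings
        self.rotary_dim = rotary_dim
        self.rope_theta = rope_theta

    @cached_property
    def frequencies(self):
        # inv_freq[i] = rope_theta ** (-2i / rotary_dim)
        exponents = jnp.arange(0, self.rotary_dim, 2) / self.rotary_dim
        inv_freq = 1.0 / (self.rope_theta ** exponents)
        positions = jnp.arange(self.max_position_embeddings)
        # angles[p, i] = p * inv_freq[i], shape [max_positions, rotary_dim // 2]
        return jnp.outer(positions, inv_freq)


table = RotaryTable(max_position_embeddings=16, rotary_dim=8)
first = table.frequencies
second = table.frequencies  # served from the cache, not recomputed
```

Computing the table once per model instance avoids rebuilding it on every forward pass, since it depends only on configuration values.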