easydel.modules.openelm.modeling_openelm_flax


class easydel.modules.openelm.modeling_openelm_flax.OpenELMDecoderLayer(*args: Any, **kwargs: Any)#

Bases: Module

OpenELM Transformer Decoder Layer.

This module represents a single decoder layer in the OpenELM model. It combines a self-attention sub-layer and an FFN sub-layer; each sub-layer is wrapped in a residual connection and preceded by RMS normalization (pre-norm).

config#

Configuration object for the model.

Type

OpenELMConfig

layer_idx#

The index of the current layer.

Type

int

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

attn#

The self-attention module.

Type

OpenELMMultiHeadCausalAttention

ffn#

The feed-forward network (FFN) module.

Type

OpenELMFeedForwardNetwork

attn_norm#

RMS normalization applied before the attention sub-layer.

Type

RMSNorm

ffn_norm#

RMS normalization applied before the FFN sub-layer.

Type

RMSNorm
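
The pre-norm residual pattern described for this layer can be sketched in plain JAX. This is a minimal illustration, not the EasyDeL implementation: `rms_norm`, `decoder_layer`, the `eps` default, and the stand-in sub-layer functions are all hypothetical names chosen for the example.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, weight, eps=1e-6):
    """RMS normalization over the last axis (eps is an assumed default)."""
    variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x * jax.lax.rsqrt(variance + eps) * weight


def decoder_layer(x, attn_fn, ffn_fn, attn_norm_weight, ffn_norm_weight):
    """Pre-norm residual block: normalize, apply the sub-layer, add the input back."""
    x = x + attn_fn(rms_norm(x, attn_norm_weight))
    x = x + ffn_fn(rms_norm(x, ffn_norm_weight))
    return x


# Toy usage with stand-in sub-layers (hypothetical, not the EasyDeL modules).
hidden = jnp.ones((2, 4, 8))  # (batch, sequence, model_dim)
weight = jnp.ones((8,))
out = decoder_layer(hidden, lambda h: 0.5 * h, lambda h: 0.25 * h, weight, weight)
```

Because the normalization sits inside the residual branch, the skip path carries the unnormalized hidden state from layer to layer; only the input to each sub-layer is rescaled.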

class easydel.modules.openelm.modeling_openelm_flax.OpenELMFeedForwardNetwork(*args: Any, **kwargs: Any)#

Bases: Module

OpenELM Feed-Forward Network (FFN) module.

This module implements the FFN layer used in the OpenELM model. It supports both standard MLP and Gated Linear Unit (GLU) variants.

config#

Configuration object for the model.

Type

OpenELMConfig

layer_idx#

The index of the current layer.

Type

int

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

ffn_with_glu#

Whether the FFN uses a Gated Linear Unit.

Type

bool

proj_1#

First linear projection layer (or gate projection in GLU).

Type

ParallelLinear

proj_2#

Second linear projection layer (down projection).

Type

ParallelLinear

gate_proj#

Gate projection layer used only if ffn_with_glu is True.

Type

ParallelLinear, optional

activation_fn#

The activation function.

Type

callable
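
The difference between the standard MLP and GLU variants can be sketched as follows. This is an illustration under stated assumptions rather than the module's actual code: the function `ffn`, the weight names `w_in`, `w_out`, and `w_gate`, and the choice of `jax.nn.silu` as the activation are hypothetical, and how these weights map onto `proj_1`, `proj_2`, and `gate_proj` is not specified here.

```python
import jax
import jax.numpy as jnp


def ffn(x, w_in, w_out, w_gate=None, activation_fn=jax.nn.silu):
    """Feed-forward block; uses the gated (GLU) form when w_gate is given."""
    if w_gate is not None:
        # GLU variant: the activated gate branch scales the linear branch.
        hidden = activation_fn(x @ w_gate) * (x @ w_in)
    else:
        # Standard MLP variant: project, activate, project back.
        hidden = activation_fn(x @ w_in)
    return hidden @ w_out


key_x, key_in, key_out, key_gate = jax.random.split(jax.random.PRNGKey(0), 4)
model_dim, intermediate_dim = 8, 16  # toy sizes
x = jax.random.normal(key_x, (2, 4, model_dim))
w_in = jax.random.normal(key_in, (model_dim, intermediate_dim))
w_out = jax.random.normal(key_out, (intermediate_dim, model_dim))
w_gate = jax.random.normal(key_gate, (model_dim, intermediate_dim))

mlp_out = ffn(x, w_in, w_out)           # standard MLP path
glu_out = ffn(x, w_in, w_out, w_gate)   # gated path
```

Both variants map `model_dim` back to `model_dim`; the GLU form adds one extra input projection in exchange for an element-wise gate.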

class easydel.modules.openelm.modeling_openelm_flax.OpenELMForCausalLM(*args: Any, **kwargs: Any)#

Bases: EasyDeLBaseModule

OpenELM model with a Causal Language Modeling head.

This model consists of the base OpenELM transformer (OpenELMModel) followed by a linear layer (lm_head) that projects the transformer’s output hidden states to the vocabulary size, producing logits for next token prediction. Optionally, the input token embeddings can be tied to the output projection layer.

config#

Configuration object for the model.

Type

OpenELMConfig

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

transformer#

The core OpenELM transformer model.

Type

OpenELMModel

lm_head#

The linear layer for projecting hidden states to vocabulary logits. This is None if config.share_input_output_layers is True.

Type

ParallelLinear, optional
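
A rough sketch of the logits projection, including the tied-weights case where no separate `lm_head` exists. The function `compute_logits` and its argument names are hypothetical; the example only assumes the common convention that a tied head reuses the transposed token embedding table.

```python
import jax
import jax.numpy as jnp


def compute_logits(hidden_states, embedding_table, lm_head_kernel=None):
    """Project hidden states to vocabulary logits.

    With no separate head (the tied case), the embedding table is reused.
    """
    if lm_head_kernel is None:
        # Tied weights: (batch, seq, dim) x (vocab, dim)^T -> (batch, seq, vocab)
        return hidden_states @ embedding_table.T
    return hidden_states @ lm_head_kernel


vocab_size, model_dim = 32, 8  # toy sizes
key_e, key_h = jax.random.split(jax.random.PRNGKey(0))
embedding_table = jax.random.normal(key_e, (vocab_size, model_dim))
hidden_states = jax.random.normal(key_h, (2, 4, model_dim))

logits = compute_logits(hidden_states, embedding_table)
next_token = jnp.argmax(logits[:, -1, :], axis=-1)  # greedy pick for the last position
```

Tying the two matrices removes a `vocab_size x model_dim` parameter block, which matters most for small models with large vocabularies.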

class easydel.modules.openelm.modeling_openelm_flax.OpenELMModel(*args: Any, **kwargs: Any)#

Bases: EasyDeLBaseModule

The base OpenELM model transformer.

This class represents the core transformer architecture of the OpenELM model, consisting of an embedding layer, multiple OpenELMDecoderLayer layers, and a final RMS normalization layer.

config#

Configuration object for the model.

Type

OpenELMConfig

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

token_embeddings#

Embedding layer for input tokens.

Type

nn.Embed

layers#

List of decoder layers.

Type

tp.List[OpenELMDecoderLayer]

norm#

Final RMS normalization applied after the last decoder layer.

Type

RMSNorm

gradient_checkpointing#

Gradient checkpointing configuration.

Type

EasyDeLGradientCheckPointers

property frequencies#

Retrieves or computes the frequency components (e.g., for RoPE) from the configuration.

Uses self.config.get_basic_frequencies() and caches the result.

Returns

The frequency components, potentially cached.

Return type

jnp.ndarray
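
As an illustration of the "compute once, then reuse" behaviour described for `frequencies`, the sketch below builds standard RoPE cosine and sine tables and caches them with `functools.lru_cache`. It does not reproduce `config.get_basic_frequencies()`; the function `rope_frequencies`, its parameters, and the base of 10000 are assumptions made for the example.

```python
from functools import lru_cache

import jax.numpy as jnp


@lru_cache(maxsize=None)
def rope_frequencies(head_dim, max_positions, base=10000.0):
    """Return cached (cos, sin) tables of shape (max_positions, head_dim // 2)."""
    # One inverse frequency per pair of channels.
    inv_freq = 1.0 / (base ** (jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim))
    positions = jnp.arange(max_positions, dtype=jnp.float32)
    angles = jnp.outer(positions, inv_freq)  # rotation angle per position and channel pair
    return jnp.cos(angles), jnp.sin(angles)


cos, sin = rope_frequencies(head_dim=8, max_positions=16)
cos_again, _ = rope_frequencies(head_dim=8, max_positions=16)  # served from the cache
```

The tables depend only on static configuration values, so computing them once and sharing them across layers avoids repeated work.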

class easydel.modules.openelm.modeling_openelm_flax.OpenELMMultiHeadCausalAttention(*args: Any, **kwargs: Any)#

Bases: AttentionModule

OpenELM Multi-Head Causal Attention module.

This module implements the multi-head causal self-attention mechanism used in the OpenELM model. It supports Grouped Query Attention (GQA) and optional RMS Normalization of query and key projections.

config#

Configuration object for the model.

Type

OpenELMConfig

layer_idx#

The index of the current layer.

Type

int

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

qkv_proj#

Combined linear layer for query, key, and value projections.

Type

ParallelLinear

q_norm#

RMS Normalization applied to the query projection if enabled.

Type

RMSNorm, optional

k_norm#

RMS Normalization applied to the key projection if enabled.

Type

RMSNorm, optional

out_proj#

Linear layer for the output projection.

Type

ParallelLinear

head_dim#

Dimensionality of each attention head.

Type

int

attention_performer#

Module to perform the core attention computation.

Type

FlexibleAttentionModule

num_q_heads#

Number of query heads.

Type

int

num_k_heads#

Number of key heads.

Type

int

num_v_heads#

Number of value heads.

Type

int

transformer_dim#

Dimensionality of the transformer model.

Type

int

num_groups#

Number of query groups for GQA.

Type

int

rotary#

Rotary position embedding module.

Type

RoPE
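
The grouped-query attention flow implied by the attributes above (a fused QKV projection, separate query and key/value head counts, and a group count) can be sketched roughly as follows. This is not the EasyDeL code path: `gqa_attention` is a hypothetical helper, the head layout inside the fused projection is an assumption, and the sketch omits RoPE, the optional query/key RMS normalization, and the output projection.

```python
import jax
import jax.numpy as jnp


def gqa_attention(qkv, num_q_heads, num_kv_heads, head_dim):
    """Causal grouped-query attention from a fused QKV projection output."""
    batch, seq_len, _ = qkv.shape
    # Assumed layout: query heads first, then key heads, then value heads.
    qkv = qkv.reshape(batch, seq_len, num_q_heads + 2 * num_kv_heads, head_dim)
    q, k, v = jnp.split(qkv, [num_q_heads, num_q_heads + num_kv_heads], axis=2)

    # Each key/value head serves num_groups query heads.
    num_groups = num_q_heads // num_kv_heads
    k = jnp.repeat(k, num_groups, axis=2)
    v = jnp.repeat(v, num_groups, axis=2)

    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(head_dim)
    # Causal mask: position i may only attend to positions <= i.
    causal_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    scores = jnp.where(causal_mask, scores, jnp.finfo(scores.dtype).min)
    weights = jax.nn.softmax(scores, axis=-1)
    out = jnp.einsum("bhqk,bkhd->bqhd", weights, v)
    return out.reshape(batch, seq_len, num_q_heads * head_dim)


num_q_heads, num_kv_heads, head_dim = 4, 2, 8  # toy sizes
qkv = jax.random.normal(
    jax.random.PRNGKey(0), (2, 5, (num_q_heads + 2 * num_kv_heads) * head_dim)
)
out = gqa_attention(qkv, num_q_heads, num_kv_heads, head_dim)
```

Sharing key/value heads across query groups shrinks the key/value projections and the KV cache while keeping the full number of query heads.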