easydel.modules.openelm.modeling_openelm#

class easydel.modules.openelm.modeling_openelm.OpenELMDecoderLayer(*args: Any, **kwargs: Any)[source]#

Bases: Module

OpenELM Transformer Decoder Layer.

This module represents a single decoder layer in the OpenELM model. It combines a self-attention sub-layer and a feed-forward network (FFN) sub-layer, each wrapped in a residual connection, with RMS normalization (RMSNorm) applied before each sub-layer.
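The pre-norm residual pattern described above can be sketched in plain NumPy. This is a minimal illustration, not the EasyDeL implementation: `rms_norm`, `decoder_layer`, and the stand-in sub-layer callables are hypothetical names, and the epsilon value is an assumption.

```python
import numpy as np

def rms_norm(x, weight, eps=1e-6):
    # Scale each feature vector by the reciprocal of its root-mean-square.
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return x / rms * weight

def decoder_layer(x, attn_fn, ffn_fn, attn_norm_w, ffn_norm_w):
    # Pre-norm residual block: normalize, apply the sub-layer, add the input back.
    x = x + attn_fn(rms_norm(x, attn_norm_w))
    x = x + ffn_fn(rms_norm(x, ffn_norm_w))
    return x

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 4, 8))  # (batch, sequence, model_dim)
ones = np.ones(8)

# Stand-in sub-layers that return zeros, so only the residual path contributes.
zero_fn = lambda h: np.zeros_like(h)
out = decoder_layer(x, zero_fn, zero_fn, ones, ones)
```

With zero-valued sub-layers the output equals the input, which shows that the residual path carries the hidden state unchanged and each sub-layer only adds a correction to it.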

config#

Configuration object for the model.

Type

OpenELMConfig

layer_idx#

The index of the current layer.

Type

int

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

attn#

The self-attention module.

Type

OpenELMMultiHeadCausalAttention

ffn#

The feed-forward network (FFN) module.

Type

OpenELMFeedForwardNetwork

attn_norm#

RMS normalization applied before the attention sub-layer.

Type

RMSNorm

ffn_norm#

RMS normalization applied before the FFN sub-layer.

Type

RMSNorm

class easydel.modules.openelm.modeling_openelm.OpenELMFeedForwardNetwork(*args: Any, **kwargs: Any)[source]#

Bases: Module

OpenELM Feed-Forward Network (FFN) module.

This module implements the FFN layer used in the OpenELM model. It supports both standard MLP and Gated Linear Unit (GLU) variants.
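The difference between the two variants can be sketched in NumPy as follows. This is an illustration under stated assumptions: the SiLU activation, the function names, and the weight names (`w_gate`, `w_up`, `w_down`) are placeholders chosen for the example and are not taken from the EasyDeL source.

```python
import numpy as np

def silu(x):
    # SiLU (swish) activation; assumed here only for illustration.
    return x / (1.0 + np.exp(-x))

def ffn_plain(x, w_up, w_down, act=silu):
    # Standard MLP: project up, apply the activation, project down.
    return act(x @ w_up) @ w_down

def ffn_glu(x, w_gate, w_up, w_down, act=silu):
    # GLU variant: the activated gate scales the up projection element-wise.
    return (act(x @ w_gate) * (x @ w_up)) @ w_down

rng = np.random.default_rng(0)
model_dim, hidden_dim = 8, 16
x = rng.normal(size=(2, 4, model_dim))
w_gate = rng.normal(size=(model_dim, hidden_dim))
w_up = rng.normal(size=(model_dim, hidden_dim))
w_down = rng.normal(size=(hidden_dim, model_dim))

y_plain = ffn_plain(x, w_up, w_down)
y_glu = ffn_glu(x, w_gate, w_up, w_down)
```

Both variants map `model_dim` back to `model_dim`; the GLU form adds one extra projection whose activated output gates the hidden features.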

config#

Configuration object for the model.

Type

OpenELMConfig

layer_idx#

The index of the current layer.

Type

int

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

ffn_with_glu#

Whether the FFN uses a Gated Linear Unit.

Type

bool

proj_1#

First linear projection layer (or gate projection in GLU).

Type

ParallelLinear

proj_2#

Second linear projection layer (down projection).

Type

ParallelLinear

gate_proj#

Gate projection layer used only if ffn_with_glu is True.

Type

ColumnParallelLinear, optional

activation_fn#

The activation function.

Type

callable

class easydel.modules.openelm.modeling_openelm.OpenELMForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: BaseCausalLMModule[OpenELMModel, OpenELMConfig]

OpenELM model with a Causal Language Modeling head.

class easydel.modules.openelm.modeling_openelm.OpenELMModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

The base OpenELM model transformer.

This class represents the core transformer architecture of the OpenELM model, consisting of an embedding layer, multiple OpenELMDecoderLayer layers, and a final RMS normalization layer.
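The overall data flow (embedding lookup, a stack of layers, final normalization) can be sketched in NumPy. This is a toy model under stated assumptions: `toy_model` and the stand-in layers are hypothetical, and each "layer" here is a simple residual function rather than a real OpenELMDecoderLayer.

```python
import numpy as np

def rms_norm(x, weight, eps=1e-6):
    # Normalize each feature vector by its root-mean-square, then rescale.
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps) * weight

def toy_model(input_ids, embedding, layers, final_norm_w):
    h = embedding[input_ids]      # token lookup -> (batch, seq, model_dim)
    for layer in layers:          # apply the decoder layers in order
        h = layer(h)
    return rms_norm(h, final_norm_w)  # final normalization

rng = np.random.default_rng(0)
vocab_size, model_dim, num_layers = 32, 8, 3
embedding = rng.normal(size=(vocab_size, model_dim))

# Stand-in layers: each adds a small residual update to the hidden state.
layers = [lambda h, s=0.1 * (i + 1): h + s * np.tanh(h) for i in range(num_layers)]

input_ids = np.array([[1, 5, 7, 2]])
hidden = toy_model(input_ids, embedding, layers, np.ones(model_dim))
```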

config#

Configuration object for the model.

Type

OpenELMConfig

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

token_embeddings#

Embedding layer for input tokens.

Type

nn.Embed

layers#

List of decoder layers.

Type

tp.List[OpenELMDecoderLayer]

norm#

Final RMS normalization layer.

Type

RMSNorm

gradient_checkpointing#

Gradient checkpointing configuration.

Type

EasyDeLGradientCheckPointers

property frequencies#

Retrieves or computes the frequency components (e.g., for RoPE) from the configuration.

Uses self.config.get_basic_frequencies() and caches the result.

Returns

The frequency components, potentially cached.

Return type

jnp.ndarray
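As a rough illustration of the compute-once-then-cache behaviour described for this property, the sketch below builds a standard RoPE sine/cosine table and caches it with `functools.cached_property`. Everything here is an assumption for the example: `ToyRotaryConfig`, `compute_frequencies`, the base of 10000, and the sin/cos layout are hypothetical and may differ from what `get_basic_frequencies()` actually returns.

```python
from functools import cached_property

import numpy as np

class ToyRotaryConfig:
    """Hypothetical config holding the values needed for a RoPE table."""

    def __init__(self, head_dim, max_positions, base=10000.0):
        self.head_dim = head_dim
        self.max_positions = max_positions
        self.base = base

    def compute_frequencies(self):
        # Standard RoPE inverse frequencies: base^(-2i / head_dim).
        exponents = np.arange(0, self.head_dim, 2) / self.head_dim
        inv_freq = 1.0 / (self.base ** exponents)
        # One rotation angle per (position, frequency) pair.
        angles = np.outer(np.arange(self.max_positions), inv_freq)
        return np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)

class ToyModel:
    def __init__(self, config):
        self.config = config

    @cached_property
    def frequencies(self):
        # Computed on first access, then stored on the instance.
        return self.config.compute_frequencies()

model = ToyModel(ToyRotaryConfig(head_dim=8, max_positions=16))
freqs = model.frequencies
```

Repeated access returns the same array object, so the table is built only once per model instance.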

get_decoder()[source]#

Returns the decoder part of the model’s graph definition.

get_embedding()[source]#

Returns the embedding layer of the module.

get_encoder()[source]#

Returns the encoder part of the model’s graph definition. Decoder-only models don’t have an encoder.

get_lm_head()[source]#

Returns the language model head of the module. Base models don’t have a language model head.

class easydel.modules.openelm.modeling_openelm.OpenELMMultiHeadCausalAttention(*args: Any, **kwargs: Any)[source]#

Bases: UnifiedAttention

OpenELM causal attention based on UnifiedAttention with per-layer head configuration.
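To make "per-layer head configuration" concrete, the NumPy sketch below runs causal attention with a different number of query and key/value heads in each layer. It is a simplified illustration: the lists `num_query_heads` and `num_kv_heads` and their values are assumptions for the example, positional encoding and the learned projections are omitted, and sharing key/value heads across groups of query heads is one common way to handle unequal head counts, not a confirmed detail of this class.

```python
import numpy as np

def causal_attention(q, k, v):
    # q: (q_heads, seq, head_dim); k, v: (kv_heads, seq, head_dim)
    q_heads, seq, head_dim = q.shape
    groups = q_heads // k.shape[0]
    # Share each key/value head across a group of query heads.
    k = np.repeat(k, groups, axis=0)
    v = np.repeat(v, groups, axis=0)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(head_dim)
    # Causal mask: position i may only attend to positions <= i.
    mask = np.tril(np.ones((seq, seq), dtype=bool))
    scores = np.where(mask, scores, -np.inf)
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v

# Hypothetical per-layer head counts; they change from layer to layer.
num_query_heads = [4, 8]
num_kv_heads = [2, 4]
head_dim, seq = 4, 5

rng = np.random.default_rng(0)
outputs = []
for q_h, kv_h in zip(num_query_heads, num_kv_heads):
    q = rng.normal(size=(q_h, seq, head_dim))
    k = rng.normal(size=(kv_h, seq, head_dim))
    v = rng.normal(size=(kv_h, seq, head_dim))
    outputs.append(causal_attention(q, k, v))
```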

define_network(config: OpenELMConfig, dtype: dtype, param_dtype: dtype, precision: Union[None, str, Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision], DotAlgorithm, DotAlgorithmPreset], rngs: Rngs) → None[source]#

Define network structure.

Override this to customize the projection structure (e.g., a fused QKV projection). The default creates separate Q/K/V/O projections.

Parameters
  • config – Model configuration

  • dtype – Data type for computations

  • param_dtype – Data type for parameters

  • precision – JAX precision setting

  • rngs – Random number generators
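The override pattern described for this method can be illustrated with a toy stand-in. The classes below are hypothetical and do not use the EasyDeL API: `ToyAttentionBase` mimics a base class whose constructor calls `define_network`, and the subclass replaces the separate Q/K/V projections with one fused matrix. Plain NumPy arrays stand in for real projection layers.

```python
import numpy as np

class ToyAttentionBase:
    """Hypothetical base class: the constructor delegates layer creation."""

    def __init__(self, dim, seed=0):
        self.dim = dim
        self.rng = np.random.default_rng(seed)
        self.define_network()

    def define_network(self):
        # Default: separate Q/K/V/O projection matrices.
        for name in ("q_proj", "k_proj", "v_proj", "o_proj"):
            setattr(self, name, self.rng.normal(size=(self.dim, self.dim)))

    def project_qkv(self, x):
        return x @ self.q_proj, x @ self.k_proj, x @ self.v_proj

class ToyFusedAttention(ToyAttentionBase):
    def define_network(self):
        # Override: a single fused QKV matrix plus an output projection.
        self.qkv_proj = self.rng.normal(size=(self.dim, 3 * self.dim))
        self.out_proj = self.rng.normal(size=(self.dim, self.dim))

    def project_qkv(self, x):
        # One matmul, then split the result into Q, K, and V.
        return np.split(x @ self.qkv_proj, 3, axis=-1)

x = np.ones((2, 5, 8))  # (batch, sequence, dim)
q_sep, k_sep, v_sep = ToyAttentionBase(dim=8).project_qkv(x)
q_fused, k_fused, v_fused = ToyFusedAttention(dim=8).project_qkv(x)
```

Both variants yield Q, K, and V with the same shapes; the fused form trades three matrix multiplications for one larger multiplication followed by a split.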

projection_mapping: ClassVar = {'key_projection': 'k_proj', 'mla_kv_a_layernorm': 'kv_a_layernorm', 'mla_kv_a_proj_with_mqa': 'kv_a_proj_with_mqa', 'mla_kv_b_proj': 'kv_b_proj', 'mla_q_a_layernorm': 'q_a_layernorm', 'mla_q_a_proj': 'q_a_proj', 'mla_q_b_proj': 'q_b_proj', 'mla_q_proj': 'q_proj', 'output_projection': 'out_proj', 'query_key_value_projection': 'qkv_proj', 'query_projection': 'q_proj', 'value_projection': 'v_proj'}#