easydel.modules.olmo2.modeling_olmo2

class easydel.modules.olmo2.modeling_olmo2.Olmo2Attention(*args: Any, **kwargs: Any)

Bases: UnifiedAttention

OLMo-2 Attention with Q/K normalization.

Applies RMSNorm to the query and key projections (Q/K normalization) to improve training stability. Apart from that, it is standard RoPE-based causal attention without a sliding window.
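To make the data flow concrete, here is a minimal NumPy sketch of Q/K-normalized causal attention with RoPE. It is not EasyDeL's implementation: the helper names are hypothetical, and it assumes bias-free projections, one key/value head per query head (no grouped-query attention), rotate-half RoPE, and RMSNorm applied over the full projected query/key width before the heads are split (the layout of the Hugging Face OLMo-2 port; EasyDeL's layout may differ). Replacing `numpy` with `jax.numpy` gives the same math in JAX.

```python
import numpy as np


def rms_norm(x, weight, eps=1e-6):
    # Divide by the root-mean-square over the last axis, then apply a learned gain.
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps) * weight


def rope(x, base=10000.0):
    # Rotate-half RoPE on x of shape (seq, heads, head_dim); head_dim must be even.
    seq, _, d = x.shape
    inv_freq = 1.0 / base ** (np.arange(0, d, 2) / d)                # (d/2,)
    angles = np.arange(seq)[:, None] * inv_freq[None, :]             # (seq, d/2)
    angles = np.concatenate([angles, angles], axis=-1)[:, None, :]   # (seq, 1, d)
    x1, x2 = x[..., : d // 2], x[..., d // 2 :]
    rotated = np.concatenate([-x2, x1], axis=-1)
    return x * np.cos(angles) + rotated * np.sin(angles)


def qk_norm_attention(x, wq, wk, wv, wo, q_gain, k_gain, num_heads):
    # x: (seq, hidden); all weights are (hidden, hidden) with no biases (assumption).
    seq, hidden = x.shape
    head_dim = hidden // num_heads
    # Assumed layout: RMSNorm over the full projected width, before splitting heads.
    q = rms_norm(x @ wq, q_gain).reshape(seq, num_heads, head_dim)
    k = rms_norm(x @ wk, k_gain).reshape(seq, num_heads, head_dim)
    v = (x @ wv).reshape(seq, num_heads, head_dim)
    q, k = rope(q), rope(k)
    scores = np.einsum("qhd,khd->hqk", q, k) / np.sqrt(head_dim)
    # Causal mask: position i attends to every position <= i (no sliding window).
    causal = np.tril(np.ones((seq, seq), dtype=bool))
    scores = np.where(causal, scores, -np.inf)
    scores = scores - scores.max(axis=-1, keepdims=True)  # stable softmax
    probs = np.exp(scores)
    probs = probs / probs.sum(axis=-1, keepdims=True)
    out = np.einsum("hqk,khd->qhd", probs, v).reshape(seq, hidden)
    return out @ wo
```

Because the queries and keys are RMS-normalized before the dot product, rescaling `wq` or `wk` leaves the attention logits essentially unchanged; only the learned gains set their scale.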

class easydel.modules.olmo2.modeling_olmo2.Olmo2DecoderLayer(*args: Any, **kwargs: Any)

Bases: Module

OLMo-2 Transformer Decoder Layer.

This module represents a single decoder layer in the OLMo-2 model, combining self-attention and MLP sub-layers with residual connections. Following OLMo-2's reordered normalization, RMSNorm is applied to the output of each sub-layer before it is added back to the residual stream, rather than to the sub-layer input.

config

Configuration object for the model.

Type: Olmo2Config

dtype

Data type for computations.

Type: jnp.dtype

param_dtype

Data type for parameters.

Type: jnp.dtype

precision

Precision setting for JAX operations.

Type: jax.lax.PrecisionLike

rngs

Random number generators.

Type: nn.Rngs

self_attn

The self-attention module.

Type: Olmo2Attention

mlp

The feed-forward (MLP) module.

Type: Olmo2MLP

post_attention_layernorm

RMSNorm applied to the self-attention output before the residual addition.

Type: RMSNorm

post_feedforward_layernorm

RMSNorm applied to the MLP output before the residual addition.

Type: RMSNorm
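OLMo-2's residual structure can be sketched in a few lines of NumPy. This is a hedged illustration, not EasyDeL code: `attn_fn` and `mlp_fn` are hypothetical stand-ins for the attention and MLP sub-modules, the two gain vectors stand in for the RMSNorm weights, and the ordering follows the "reordered norm" formulation of the OLMo 2 technical report, h = x + RMSNorm(Attn(x)).

```python
import numpy as np


def rms_norm(x, weight, eps=1e-6):
    # Divide by the root-mean-square over the last axis, then apply a learned gain.
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps) * weight


def reordered_norm_block(x, attn_fn, mlp_fn, attn_out_gain, mlp_out_gain):
    # Each sub-layer reads the un-normalized residual stream; its output is
    # RMS-normalized before being added back to the stream.
    h = x + rms_norm(attn_fn(x), attn_out_gain)
    return h + rms_norm(mlp_fn(h), mlp_out_gain)
```

With this ordering, each sub-layer's update to the residual stream has an RMS set by its gain, whatever the scale of the sub-layer's raw output.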

class easydel.modules.olmo2.modeling_olmo2.Olmo2ForCausalLM(*args: Any, **kwargs: Any)

Bases: BaseCausalLMModule[Olmo2Model, Olmo2Config]

OLMo-2 model with a Causal Language Modeling head.
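As a rough sketch of what a causal language modeling head adds on top of the base model, the NumPy snippet below projects final hidden states to vocabulary logits and computes the standard next-token cross-entropy with a one-position shift. The function name and the plain `(hidden_size, vocab)` weight matrix are assumptions for illustration; EasyDeL's actual head, weight layout, and loss handling (for example, padding or ignore-index masking) are not shown.

```python
import numpy as np


def causal_lm_loss(hidden, lm_head, input_ids):
    # hidden: (seq, hidden_size) from the base model; lm_head: (hidden_size, vocab)
    logits = hidden @ lm_head                                # (seq, vocab)
    # Position t predicts token t + 1: drop the last logit and the first token.
    logits, targets = logits[:-1], input_ids[1:]
    logits = logits - logits.max(axis=-1, keepdims=True)     # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets].mean()
```

With an all-zero head every token is equally likely, so the loss equals log(vocab).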

get_decoder()

Returns the decoder part of the model’s graph definition.

get_embedding()

Returns the embedding layer of the module.

get_encoder()

Returns the encoder part of the model’s graph definition. Decoder-only models don’t have an encoder.

get_lm_head()

Returns the language model head of the module.

class easydel.modules.olmo2.modeling_olmo2.Olmo2ForSequenceClassification(*args: Any, **kwargs: Any)

Bases: BaseSequenceClassificationModule[Olmo2Model, Olmo2Config]

OLMo-2 model with a Sequence Classification head.
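A decoder-only model has no dedicated classification token, so a common convention (used by Hugging Face's `*ForSequenceClassification` heads for Llama-style models) is to score the hidden state of the last non-padding token. The sketch below assumes that convention and right padding; `last_token_pool`, `classify`, and `score_weight` are hypothetical names, and the pooling actually performed by `BaseSequenceClassificationModule` may differ.

```python
import numpy as np


def last_token_pool(hidden, input_ids, pad_token_id):
    # hidden: (batch, seq, hidden_size); input_ids: (batch, seq); assumes right padding.
    lengths = (input_ids != pad_token_id).sum(axis=-1)
    last = np.maximum(lengths - 1, 0)  # guard against all-padding rows
    return hidden[np.arange(hidden.shape[0]), last]


def classify(hidden, input_ids, score_weight, pad_token_id):
    # score_weight: (hidden_size, num_labels); returns (batch, num_labels) logits.
    return last_token_pool(hidden, input_ids, pad_token_id) @ score_weight
```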

get_decoder()

Returns the decoder part of the model’s graph definition.

get_embedding()

Returns the embedding layer of the module.

get_encoder()

Returns the encoder part of the model’s graph definition. Decoder-only models don’t have an encoder.

get_lm_head()

Returns the language model head of the module. This model has a sequence classification head, not an LM head.

class easydel.modules.olmo2.modeling_olmo2.Olmo2MLP(*args: Any, **kwargs: Any)

Bases: Module

OLMo-2 MLP module.

This module implements the feed-forward network (MLP) used in the OLMo-2 model. It is a gated MLP built from gate, up, and down projections: the output is down_proj(act_fn(gate_proj(x)) * up_proj(x)), where act_fn is SiLU.
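The gate/up/down structure can be sketched in NumPy as follows. This assumes bias-free projections stored as `(in, out)` matrices; `gated_mlp` and `silu` are illustrative helpers, not EasyDeL functions.

```python
import numpy as np


def silu(x):
    # SiLU (swish): x * sigmoid(x)
    return x / (1.0 + np.exp(-x))


def gated_mlp(x, w_gate, w_up, w_down):
    # x: (..., hidden); w_gate, w_up: (hidden, intermediate); w_down: (intermediate, hidden)
    return (silu(x @ w_gate) * (x @ w_up)) @ w_down
```

The activated gate branch multiplies the up branch element-wise, so each intermediate unit is modulated by its own learned gate before the down projection maps back to the hidden size.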

config

Configuration object for the model.

Type: Olmo2Config

dtype

Data type for computations.

Type: jnp.dtype

param_dtype

Data type for parameters.

Type: jnp.dtype

precision

Precision setting for JAX operations.

Type: jax.lax.PrecisionLike

gate_proj

Linear layer for the gate projection.

Type: ParallelLinear

down_proj

Linear layer for the down projection.

Type: ParallelLinear

up_proj

Linear layer for the up projection.

Type: ParallelLinear

act_fn

Activation function (SiLU).

Type: callable

class easydel.modules.olmo2.modeling_olmo2.Olmo2Model(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

The base OLMo-2 model transformer.

This class represents the core transformer architecture of the OLMo-2 model, consisting of an embedding layer, multiple Olmo2DecoderLayer layers, and a final RMS normalization layer.
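The composition described above (embedding lookup, decoder stack, final RMSNorm) can be sketched as a plain NumPy function. Each decoder layer here is any callable mapping `(seq, hidden)` to `(seq, hidden)`; the function names are hypothetical, and the sketch omits attention masks, KV caching, and gradient checkpointing.

```python
import numpy as np


def rms_norm(x, weight, eps=1e-6):
    # Divide by the root-mean-square over the last axis, then apply a learned gain.
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps) * weight


def base_model_forward(input_ids, embedding, layers, final_gain):
    # input_ids: (seq,) ints; embedding: (vocab, hidden)
    h = embedding[input_ids]          # token embedding lookup -> (seq, hidden)
    for layer in layers:              # decoder layers applied in order
        h = layer(h)
    return rms_norm(h, final_gain)    # final RMSNorm before any task head
```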

config

Configuration object for the model.

Type: Olmo2Config

dtype

Data type for computation.

Type: jnp.dtype

param_dtype

Data type for parameters.

Type: jnp.dtype

precision

Precision setting for JAX operations.

Type: jax.lax.PrecisionLike

rngs

Random number generators.

Type: nn.Rngs

embed_tokens

Embedding layer for input tokens.

Type: nn.Embed

layers

List of decoder layers.

Type: tp.List[Olmo2DecoderLayer]

norm

Final layer normalization.

Type: RMSNorm

gradient_checkpointing

Gradient checkpointing configuration.

Type: EasyDeLGradientCheckPointers

get_decoder()

Returns the decoder part of the model’s graph definition.

get_embedding()

Returns the embedding layer of the module.

get_encoder()

Returns the encoder part of the model’s graph definition. Decoder-only models don’t have an encoder.

get_lm_head()

Returns the language model head of the module. Base models don’t have a language model head.