easydel.modules.olmo2.modeling_olmo2_flax

class easydel.modules.olmo2.modeling_olmo2_flax.Olmo2Attention(*args: Any, **kwargs: Any)#

Bases: AttentionModule

OLMo-2 Attention module.

This module implements the multi-head attention mechanism with rotary position embeddings and Grouped Query Attention (GQA) used in the OLMo-2 model. It includes RMSNorm applied to query and key projections before the attention calculation.
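The combination described above can be sketched in plain JAX. This is a minimal illustration, not the module's actual code: the helper names `rms_norm` and `qk_norm_gqa_attention` are hypothetical, the norm is assumed to act on the whole projected vector before the head split, the causal mask is an assumption, and rotary embeddings are left out to keep the sketch short.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, weight, eps=1e-6):
    # Scale by the reciprocal root-mean-square over the last axis.
    variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x * jax.lax.rsqrt(variance + eps) * weight


def qk_norm_gqa_attention(q, k, v, q_norm_weight, k_norm_weight, num_heads, num_kv_heads):
    """q: (batch, seq, num_heads * head_dim); k, v: (batch, seq, num_kv_heads * head_dim)."""
    batch, seq, _ = q.shape
    head_dim = q.shape[-1] // num_heads
    # Assumption: the norm is applied to the full projection, then heads are split.
    q = rms_norm(q, q_norm_weight).reshape(batch, seq, num_heads, head_dim)
    k = rms_norm(k, k_norm_weight).reshape(batch, seq, num_kv_heads, head_dim)
    v = v.reshape(batch, seq, num_kv_heads, head_dim)
    # GQA: every key/value head serves `num_key_value_groups` query heads.
    num_key_value_groups = num_heads // num_kv_heads
    k = jnp.repeat(k, num_key_value_groups, axis=2)
    v = jnp.repeat(v, num_key_value_groups, axis=2)
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(head_dim)
    # Causal mask: position i may only attend to positions <= i.
    causal = jnp.tril(jnp.ones((seq, seq), dtype=bool))
    scores = jnp.where(causal[None, None, :, :], scores, jnp.finfo(scores.dtype).min)
    probs = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", probs, v)
```

With `num_heads=4` and `num_kv_heads=2`, query heads 0 and 1 read key/value head 0, and query heads 2 and 3 read key/value head 1.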

config#

Configuration object for the model.

Type

Olmo2Config

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

hidden_size#

Dimensionality of the hidden states.

Type

int

num_heads#

Number of attention heads.

Type

int

head_dim#

Dimensionality of each attention head.

Type

int

num_key_value_groups#

Number of query heads that share each key/value head.

Type

int

q_proj#

Linear layer for query projection.

Type

ParallelLinear

k_proj#

Linear layer for key projection.

Type

ParallelLinear

v_proj#

Linear layer for value projection.

Type

ParallelLinear

o_proj#

Linear layer for the output projection.

Type

ParallelLinear

q_norm#

RMS Normalization applied to the query projection.

Type

RMSNorm

k_norm#

RMS Normalization applied to the key projection.

Type

RMSNorm

attention_performer#

Module to perform the core attention computation.

Type

FlexibleAttentionModule

rotary#

Rotary position embedding module.

Type

RoPE
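As a rough illustration of what a rotary position embedding does to queries and keys, the sketch below uses the common "rotate half" formulation. The function names, the `base` of 10000, and the half-split layout are assumptions made for this example; the `RoPE` module referenced above may use a different layout or frequency scaling.

```python
import jax.numpy as jnp


def rotary_tables(head_dim, positions, base=10000.0):
    # One rotation frequency per pair of channels.
    inv_freq = 1.0 / (base ** (jnp.arange(0, head_dim, 2) / head_dim))
    angles = positions[:, None] * inv_freq[None, :]  # (seq, head_dim // 2)
    return jnp.cos(angles), jnp.sin(angles)


def apply_rotary(x, cos, sin):
    """x: (batch, seq, heads, head_dim); cos, sin: (seq, head_dim // 2)."""
    x1, x2 = jnp.split(x, 2, axis=-1)
    cos = cos[None, :, None, :]
    sin = sin[None, :, None, :]
    # Rotate each (x1, x2) channel pair by its position-dependent angle.
    return jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
```

Because each channel pair is rotated, vector norms are unchanged, and the dot product between a rotated query and a rotated key depends only on the distance between their positions.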

class easydel.modules.olmo2.modeling_olmo2_flax.Olmo2DecoderLayer(*args: Any, **kwargs: Any)#

Bases: Module

OLMo-2 Transformer Decoder Layer.

This module represents a single decoder layer in the OLMo-2 model, combining self-attention and MLP sub-layers with residual connections and layer normalization applied before each sub-layer.
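Taking the description above at face value (normalize, run the sub-layer, add the residual), the data flow of one layer could be sketched as follows. `attn_fn` and `mlp_fn` are hypothetical stand-ins for the attention and MLP sub-layers, and `rms_norm` is a local helper rather than the library's `RMSNorm`.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, weight, eps=1e-6):
    variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x * jax.lax.rsqrt(variance + eps) * weight


def decoder_layer(hidden, attn_fn, mlp_fn, attn_norm_weight, mlp_norm_weight):
    # Attention sub-layer: normalize the input, transform it, add the residual.
    hidden = hidden + attn_fn(rms_norm(hidden, attn_norm_weight))
    # MLP sub-layer: same pattern with its own normalization weights.
    hidden = hidden + mlp_fn(rms_norm(hidden, mlp_norm_weight))
    return hidden
```

The residual path means that a sub-layer returning zeros leaves the hidden states untouched.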

config#

Configuration object for the model.

Type

Olmo2Config

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

self_attn#

The self-attention module.

Type

Olmo2Attention

mlp#

The feed-forward (MLP) module.

Type

Olmo2MLP

input_layernorm#

Layer normalization before the attention layer.

Type

RMSNorm

post_attention_layernorm#

Layer normalization before the MLP layer.

Type

RMSNorm

class easydel.modules.olmo2.modeling_olmo2_flax.Olmo2ForCausalLM(*args: Any, **kwargs: Any)#

Bases: EasyDeLBaseModule

OLMo-2 model with a Causal Language Modeling head.

This model consists of the base OLMo-2 transformer (Olmo2Model) followed by a linear layer (lm_head) that projects the transformer’s output hidden states to the vocabulary size, producing logits for next token prediction.
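The projection step described above amounts to a matrix product over the hidden dimension. The sketch below is an assumption-laden illustration: `lm_logits` and `greedy_next_token` are hypothetical helpers, the kernel is a plain array rather than a `ParallelLinear` layer, and greedy decoding is only one of several ways to pick the next token.

```python
import jax.numpy as jnp


def lm_logits(hidden_states, lm_head_kernel):
    # (batch, seq, hidden) @ (hidden, vocab) -> (batch, seq, vocab)
    return hidden_states @ lm_head_kernel


def greedy_next_token(logits):
    # Only the logits at the final position predict the token that follows.
    return jnp.argmax(logits[:, -1, :], axis=-1)
```

Each position's logits score the token that should come next, so during generation only the last position is read.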

config#

Configuration object for the model.

Type

Olmo2Config

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

model#

The core OLMo-2 transformer model.

Type

Olmo2Model

lm_head#

The linear layer for projecting hidden states to vocabulary logits.

Type

ParallelLinear

class easydel.modules.olmo2.modeling_olmo2_flax.Olmo2ForSequenceClassification(*args: Any, **kwargs: Any)#

Bases: EasyDeLBaseModule

OLMo-2 model with a Sequence Classification head.

This model consists of the base OLMo-2 transformer (Olmo2Model) followed by a linear layer (score) that projects the transformer’s output hidden states (typically the hidden state of the last token) to the number of classes for classification.
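Pooling on the last token, as mentioned above, can be sketched like this. The helper name `classify`, the use of an attention mask to locate the last real token, and the assumption of right-padded batches are all illustrative choices, not confirmed details of the class.

```python
import jax.numpy as jnp


def classify(hidden_states, attention_mask, score_kernel):
    """hidden_states: (batch, seq, hidden); attention_mask: (batch, seq) of 0/1."""
    # Assumption: right padding, so the last real token sits at sum(mask) - 1.
    last_index = jnp.sum(attention_mask, axis=-1) - 1
    batch_index = jnp.arange(hidden_states.shape[0])
    pooled = hidden_states[batch_index, last_index]  # (batch, hidden)
    # (batch, hidden) @ (hidden, num_classes) -> (batch, num_classes)
    return pooled @ score_kernel
```

In a causal model the last non-padding token is the only position that has attended to the whole sequence, which is why it is the usual pooling choice.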

config#

Configuration object for the model.

Type

Olmo2Config

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

model#

The core OLMo-2 transformer model.

Type

Olmo2Model

score#

The linear layer for classification.

Type

ParallelLinear

class easydel.modules.olmo2.modeling_olmo2_flax.Olmo2MLP(*args: Any, **kwargs: Any)#

Bases: Module

OLMo-2 MLP module.

This module implements the feed-forward network (MLP) used in the OLMo-2 model. It consists of gate, up, and down projections with a SiLU activation.
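A gate/up/down MLP with SiLU is commonly wired as `down(silu(gate(x)) * up(x))`. The sketch below assumes that wiring and uses bias-free plain arrays in place of `ParallelLinear` layers; `gated_mlp` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def gated_mlp(x, gate_kernel, up_kernel, down_kernel):
    # The gate branch passes through SiLU and modulates the up branch elementwise.
    gate = jax.nn.silu(x @ gate_kernel)
    up = x @ up_kernel
    # The down projection maps the intermediate width back to the hidden size.
    return (gate * up) @ down_kernel
```

The gate and up kernels share the shape `(hidden, intermediate)`, and the down kernel has shape `(intermediate, hidden)`, so the output matches the input width.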

config#

Configuration object for the model.

Type

Olmo2Config

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

gate_proj#

Linear layer for the gate projection.

Type

ParallelLinear

down_proj#

Linear layer for the down projection.

Type

ParallelLinear

up_proj#

Linear layer for the up projection.

Type

ParallelLinear

act_fn#

Activation function (SiLU).

Type

callable

class easydel.modules.olmo2.modeling_olmo2_flax.Olmo2Model(*args: Any, **kwargs: Any)#

Bases: EasyDeLBaseModule

The base OLMo-2 transformer model.

This class represents the core transformer architecture of the OLMo-2 model, consisting of an embedding layer, multiple Olmo2DecoderLayer layers, and a final RMS normalization layer.
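The three stages named above (embedding lookup, a stack of decoder layers, a final RMS normalization) compose as in the following sketch. `forward` is a hypothetical function, each entry of `layers` is any callable from hidden states to hidden states, and attention masks, position ids, and caching are omitted.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, weight, eps=1e-6):
    variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x * jax.lax.rsqrt(variance + eps) * weight


def forward(input_ids, embedding, layers, final_norm_weight):
    # Token ids -> vectors: (batch, seq) -> (batch, seq, hidden).
    hidden = jnp.take(embedding, input_ids, axis=0)
    # Each decoder layer maps hidden states to hidden states of the same shape.
    for layer in layers:
        hidden = layer(hidden)
    # The final normalization produces the states consumed by task heads.
    return rms_norm(hidden, final_norm_weight)
```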

config#

Configuration object for the model.

Type

Olmo2Config

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

embed_tokens#

Embedding layer for input tokens.

Type

nn.Embed

layers#

List of decoder layers.

Type

tp.List[Olmo2DecoderLayer]

norm#

Final layer normalization.

Type

RMSNorm

gradient_checkpointing#

Gradient checkpointing configuration.

Type

EasyDeLGradientCheckPointers
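For background on what gradient checkpointing trades off, the sketch below wraps a toy layer in `jax.checkpoint`, which drops the layer's intermediates after the forward pass and recomputes them during the backward pass. How the values of `EasyDeLGradientCheckPointers` map onto `jax.checkpoint` and its policies is not stated here, so this only illustrates the general technique; `layer` and `make_loss` are hypothetical.

```python
import jax
import jax.numpy as jnp


def layer(x, w):
    return jnp.tanh(x @ w)


def make_loss(use_checkpoint):
    # With checkpointing, intermediates are recomputed in the backward pass
    # instead of being kept in memory.
    layer_fn = jax.checkpoint(layer) if use_checkpoint else layer

    def loss(weights, x):
        for w in weights:
            x = layer_fn(x, w)
        return jnp.sum(x)

    return loss


k1, k2, k3, kx = jax.random.split(jax.random.PRNGKey(0), 4)
weights = [jax.random.normal(k, (8, 8)) * 0.1 for k in (k1, k2, k3)]
x = jax.random.normal(kx, (2, 8))

grads_plain = jax.grad(make_loss(False))(weights, x)
grads_remat = jax.grad(make_loss(True))(weights, x)
```

Both variants compute the same gradients; checkpointing changes only the memory and compute profile of the backward pass.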