easydel.modules.olmo.modeling_olmo#

class easydel.modules.olmo.modeling_olmo.OlmoAttention(*args: Any, **kwargs: Any)[source]#

Bases: UnifiedAttention

OLMo attention built on UnifiedAttention with optional QKV clipping.
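
The clipping step can be pictured with the minimal sketch below. It assumes that "QKV clipping" means clamping the projected query, key, and value tensors element-wise to a symmetric range; the helper name `clip_qkv` and the `clip_value` parameter are hypothetical and are not taken from the EasyDeL source.

```python
import jax.numpy as jnp


def clip_qkv(query, key, value, clip_value=None):
    """Clamp Q/K/V element-wise to [-clip_value, clip_value].

    Hypothetical helper for illustration only. When clip_value is None,
    the tensors are returned unchanged (clipping is optional).
    """
    if clip_value is None:
        return query, key, value
    return tuple(jnp.clip(t, -clip_value, clip_value) for t in (query, key, value))


q = jnp.array([[-10.0, 0.5, 9.0]])
k = jnp.array([[3.0, -7.5, 0.0]])
v = jnp.array([[8.5, -8.5, 1.0]])

# Values outside [-8, 8] are clamped; values inside are left as they are.
q_c, k_c, v_c = clip_qkv(q, k, v, clip_value=8.0)
```

Bounding the projections this way limits the magnitude of the attention logits, which is the usual motivation for such a step.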

class easydel.modules.olmo.modeling_olmo.OlmoDecoderLayer(*args: Any, **kwargs: Any)[source]#

Bases: Module

OLMo Transformer Decoder Layer.

This module represents a single decoder layer in the OLMo model. It combines a self-attention sub-layer and an MLP sub-layer, each wrapped in a residual connection. Unlike typical transformer blocks, OLMo applies layer normalization after the residual connection.

config#

Configuration object for the model.

Type

OlmoConfig

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

self_attn#

The self-attention module.

Type

OlmoAttention

mlp#

The feed-forward (MLP) module.

Type

OlmoMLP
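
As a rough illustration of the ordering described for this layer (normalization applied after each residual connection), the sketch below wires two sub-layers together. It is not the EasyDeL implementation: `layer_norm` is a plain normalization without learned scale or bias, chosen for brevity, and `attn_fn` and `mlp_fn` are hypothetical stand-ins for the real attention and MLP modules.

```python
import jax
import jax.numpy as jnp


def layer_norm(x, eps=1e-5):
    # Normalize over the hidden (last) axis; no learned parameters (assumption).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


def decoder_layer(x, attn_fn, mlp_fn):
    # Residual add first, then normalization, for each sub-layer.
    x = layer_norm(x + attn_fn(x))
    x = layer_norm(x + mlp_fn(x))
    return x


x = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 8))
# Toy stand-ins: a scaling "attention" and a tanh "MLP".
out = decoder_layer(x, attn_fn=lambda h: 0.5 * h, mlp_fn=jnp.tanh)
```

With this ordering, the output of every layer is normalized over the hidden axis regardless of what the sub-layers return.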

class easydel.modules.olmo.modeling_olmo.OlmoForCausalLM(*args: Any, **kwargs: Any)[source]#

Bases: BaseCausalLMModule[OlmoModel, OlmoConfig]

OLMo model with a Causal Language Modeling head.
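
Conceptually, a causal language modeling head projects the final hidden states onto the vocabulary to obtain next-token logits. The sketch below shows that idea only; `lm_head_logits`, `greedy_next_token`, and the weight matrix are hypothetical and do not reflect the actual OlmoForCausalLM interface.

```python
import jax.numpy as jnp


def lm_head_logits(hidden_states, w_lm_head):
    # (batch, seq, hidden) @ (hidden, vocab) -> (batch, seq, vocab)
    return hidden_states @ w_lm_head


def greedy_next_token(logits):
    # Pick the highest-scoring vocabulary id at the last sequence position.
    return jnp.argmax(logits[:, -1, :], axis=-1)


hidden_size, vocab_size = 4, 6
# Toy hidden states: position i holds the i-th one-hot vector.
hidden_states = jnp.eye(hidden_size)[None, :, :]
# Toy head: hidden unit i votes for a single vocabulary id.
w_lm_head = (
    jnp.zeros((hidden_size, vocab_size))
    .at[jnp.arange(hidden_size), jnp.array([2, 0, 5, 3])]
    .set(1.0)
)

logits = lm_head_logits(hidden_states, w_lm_head)
next_token = greedy_next_token(logits)
```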

class easydel.modules.olmo.modeling_olmo.OlmoForSequenceClassification(*args: Any, **kwargs: Any)[source]#

Bases: BaseSequenceClassificationModule[OlmoModel, OlmoConfig]

OLMo model with a Sequence Classification head.

class easydel.modules.olmo.modeling_olmo.OlmoMLP(*args: Any, **kwargs: Any)[source]#

Bases: Module

OLMo MLP module.

This module implements the feed-forward network (MLP) used in the OLMo model. It consists of gate, up, and down projections with a SiLU activation.

config#

Configuration object for the model.

Type

OlmoConfig

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

gate_proj#

Linear layer for the gate projection.

Type

ParallelLinear

down_proj#

Linear layer for the down projection.

Type

ParallelLinear

up_proj#

Linear layer for the up projection.

Type

ParallelLinear

act_fn#

Activation function (SiLU).

Type

callable
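
The way the three projections are commonly combined in a gated MLP can be sketched as follows. This assumes the usual arrangement `down(silu(gate(x)) * up(x))`; the weights here are plain arrays rather than ParallelLinear modules, biases are omitted, and `gated_mlp` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def gated_mlp(x, w_gate, w_up, w_down):
    # SiLU-activated gate branch modulates the linear "up" branch.
    gate = jax.nn.silu(x @ w_gate)
    up = x @ w_up
    # Project the gated activations back to the hidden size.
    return (gate * up) @ w_down


hidden_size, intermediate_size = 8, 32
k_x, k_g, k_u, k_d = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(k_x, (2, 4, hidden_size))
w_gate = jax.random.normal(k_g, (hidden_size, intermediate_size)) * 0.02
w_up = jax.random.normal(k_u, (hidden_size, intermediate_size)) * 0.02
w_down = jax.random.normal(k_d, (intermediate_size, hidden_size)) * 0.02

y = gated_mlp(x, w_gate, w_up, w_down)  # same shape as x
```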

class easydel.modules.olmo.modeling_olmo.OlmoModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

The base OLMo model transformer.

This class represents the core transformer architecture of the OLMo model, consisting of an embedding layer and multiple OlmoDecoderLayer layers. Note that OLMo does not have a final layer normalization.

config#

Configuration object for the model.

Type

OlmoConfig

dtype#

Data type for computations.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

embed_tokens#

Embedding layer for input tokens.

Type

nn.Embed

layers#

List of decoder layers.

Type

tp.List[OlmoDecoderLayer]

gradient_checkpointing#

Gradient checkpointing configuration.

Type

EasyDeLGradientCheckPointers
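
The overall data flow described for this class (token embedding followed by a stack of decoder layers) can be sketched as below. This is a simplified illustration under stated assumptions: `forward`, `embed_table`, and `layer_fns` are hypothetical names, the layers are toy functions, and attention masks, caches, and gradient checkpointing are left out.

```python
import jax
import jax.numpy as jnp


def forward(input_ids, embed_table, layer_fns):
    # Look up one embedding vector per token: (batch, seq) -> (batch, seq, hidden).
    hidden = embed_table[input_ids]
    # Apply the decoder layers in order.
    for layer in layer_fns:
        hidden = layer(hidden)
    return hidden


vocab_size, hidden_size = 10, 4
embed_table = jax.random.normal(jax.random.PRNGKey(0), (vocab_size, hidden_size))
# Toy stand-ins for decoder layers: a residual update with tanh.
layer_fns = [lambda h: h + jnp.tanh(h) for _ in range(3)]

input_ids = jnp.array([[1, 5, 2]])
out = forward(input_ids, embed_table, layer_fns)
```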

get_decoder()[source]#

Returns the decoder part of the model’s graph definition.

get_embedding()[source]#

Returns the embedding layer of the module.

get_encoder()[source]#

Returns the encoder part of the model’s graph definition. Decoder-only models don’t have an encoder.

get_lm_head()[source]#

Returns the language model head of the module. Base models don’t have a language model head.