easydel.modules.olmo2.modeling_olmo2_flax#
- class easydel.modules.olmo2.modeling_olmo2_flax.Olmo2Attention(*args: Any, **kwargs: Any)#
Bases: AttentionModule
OLMo-2 Attention module.
This module implements the multi-head attention mechanism with rotary position embeddings and Grouped Query Attention (GQA) used in the OLMo-2 model. It includes RMSNorm applied to query and key projections before the attention calculation.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
Dimensionality of the hidden states.
- Type
int
- num_heads#
Number of attention heads.
- Type
int
- head_dim#
Dimensionality of each attention head.
- Type
int
- num_key_value_groups#
Number of query heads that share each key/value head.
- Type
int
- q_proj#
Linear layer for query projection.
- Type
- k_proj#
Linear layer for key projection.
- Type
- v_proj#
Linear layer for value projection.
- Type
- o_proj#
Linear layer for the output projection.
- Type
- attention_performer#
Module to perform the core attention computation.
- rotary#
Rotary position embedding module.
- Type
RoPE
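The attention computation described above can be sketched in plain NumPy. This is a minimal illustration, not the EasyDeL implementation: the function names `rms_norm` and `gqa_attention`, the weight layout, and the choice to normalize over the full projected dimension are assumptions, and rotary position embeddings are left out to keep the sketch short. It shows RMSNorm applied to the query and key projections, and each key/value head being shared by `num_key_value_groups` query heads.

```python
import numpy as np

def rms_norm(x, weight, eps=1e-6):
    """Scale x by the reciprocal root-mean-square of its last axis."""
    variance = np.mean(x * x, axis=-1, keepdims=True)
    return x / np.sqrt(variance + eps) * weight

def gqa_attention(hidden, wq, wk, wv, wo, q_norm_w, k_norm_w, num_heads, num_kv_heads):
    """Causal grouped-query attention with normalized queries and keys (hypothetical helper)."""
    batch, seq, hidden_size = hidden.shape
    head_dim = hidden_size // num_heads
    num_key_value_groups = num_heads // num_kv_heads

    # Project, then normalize queries and keys before the attention calculation.
    q = rms_norm(hidden @ wq, q_norm_w)
    k = rms_norm(hidden @ wk, k_norm_w)
    v = hidden @ wv

    # Split into heads: (batch, heads, seq, head_dim).
    q = q.reshape(batch, seq, num_heads, head_dim).transpose(0, 2, 1, 3)
    k = k.reshape(batch, seq, num_kv_heads, head_dim).transpose(0, 2, 1, 3)
    v = v.reshape(batch, seq, num_kv_heads, head_dim).transpose(0, 2, 1, 3)

    # Each key/value head serves num_key_value_groups query heads.
    k = np.repeat(k, num_key_value_groups, axis=1)
    v = np.repeat(v, num_key_value_groups, axis=1)

    # Scaled dot-product scores with a causal mask, then a stable softmax.
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(head_dim)
    causal = np.tril(np.ones((seq, seq), dtype=bool))
    scores = np.where(causal, scores, -np.inf)
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)

    # Merge heads and apply the output projection.
    out = (weights @ v).transpose(0, 2, 1, 3).reshape(batch, seq, hidden_size)
    return out @ wo

rng = np.random.default_rng(0)
hidden_size, num_heads, num_kv_heads = 16, 4, 2
kv_dim = num_kv_heads * (hidden_size // num_heads)
params = dict(
    wq=rng.normal(size=(hidden_size, hidden_size)),
    wk=rng.normal(size=(hidden_size, kv_dim)),
    wv=rng.normal(size=(hidden_size, kv_dim)),
    wo=rng.normal(size=(hidden_size, hidden_size)),
    q_norm_w=np.ones(hidden_size),
    k_norm_w=np.ones(kv_dim),
)
x = rng.normal(size=(1, 5, hidden_size))
out = gqa_attention(x, num_heads=num_heads, num_kv_heads=num_kv_heads, **params)
print(out.shape)  # (1, 5, 16)
```

With 4 query heads and 2 key/value heads, `num_key_value_groups` is 2, so the key and value projections are half the width of the query projection.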
- class easydel.modules.olmo2.modeling_olmo2_flax.Olmo2DecoderLayer(*args: Any, **kwargs: Any)#
Bases: Module
OLMo-2 Transformer Decoder Layer.
This module represents a single decoder layer in the OLMo-2 model, combining self-attention and MLP sub-layers with residual connections. Unlike pre-norm decoders, OLMo-2 applies layer normalization to the output of each sub-layer before adding it back to the residual stream.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- self_attn#
The self-attention module.
- Type
- class easydel.modules.olmo2.modeling_olmo2_flax.Olmo2ForCausalLM(*args: Any, **kwargs: Any)#
Bases: EasyDeLBaseModule
OLMo-2 model with a Causal Language Modeling head.
This model consists of the base OLMo-2 transformer (Olmo2Model) followed by a linear layer (lm_head) that projects the transformer’s output hidden states to the vocabulary size, producing logits for next token prediction.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- model#
The core OLMo-2 transformer model.
- Type
- lm_head#
The linear layer for projecting hidden states to vocabulary logits.
- Type
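The role of lm_head can be illustrated with a small NumPy sketch. The helper names `lm_head_logits` and `greedy_next_token` and the toy kernel values are hypothetical; the sketch only shows the shape change from hidden states to vocabulary logits and one simple way (greedy argmax) to read a next-token prediction from the final position.

```python
import numpy as np

def lm_head_logits(hidden_states, lm_head_kernel):
    """Project (batch, seq, hidden) hidden states to (batch, seq, vocab) logits."""
    return hidden_states @ lm_head_kernel

def greedy_next_token(logits):
    """Pick the highest-scoring vocabulary id at the last sequence position."""
    return logits[:, -1, :].argmax(axis=-1)

# Toy values: hidden size 2, vocabulary size 3.
hidden_states = np.array([[[1.0, 0.0], [0.0, 1.0]]])             # (1, 2, 2)
lm_head_kernel = np.array([[0.1, 0.9, 0.0], [0.7, 0.2, 0.1]])   # (2, 3)

logits = lm_head_logits(hidden_states, lm_head_kernel)
next_token = greedy_next_token(logits)
print(logits.shape, next_token)  # (1, 2, 3) [0]
```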
- class easydel.modules.olmo2.modeling_olmo2_flax.Olmo2ForSequenceClassification(*args: Any, **kwargs: Any)#
Bases: EasyDeLBaseModule
OLMo-2 model with a Sequence Classification head.
This model consists of the base OLMo-2 transformer (Olmo2Model) followed by a linear layer (score) that projects the transformer’s output hidden states (typically the hidden state of the last token) to the number of classes for classification.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- model#
The core OLMo-2 transformer model.
- Type
- score#
The linear layer for classification.
- Type
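Last-token pooling followed by the score projection might look like the following NumPy sketch. The helper `classify_last_token` and the explicit `lengths` argument are assumptions made for illustration; how the actual module locates the last token (for example, via a padding token id) is not specified here.

```python
import numpy as np

def classify_last_token(hidden_states, lengths, score_kernel):
    """Pool each sequence's last real token and project it to class logits."""
    batch_index = np.arange(hidden_states.shape[0])
    pooled = hidden_states[batch_index, lengths - 1]   # (batch, hidden)
    return pooled @ score_kernel                        # (batch, num_classes)

# Two sequences of padded length 3 with hidden size 2.
hidden_states = np.arange(12, dtype=float).reshape(2, 3, 2)
lengths = np.array([3, 2])        # the second sequence has one padding position
score_kernel = np.eye(2)          # identity kernel keeps the pooled values visible

class_logits = classify_last_token(hidden_states, lengths, score_kernel)
print(class_logits)
```

Here the first sequence is pooled at position 2 and the second at position 1, so padding positions never reach the classifier.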
- class easydel.modules.olmo2.modeling_olmo2_flax.Olmo2MLP(*args: Any, **kwargs: Any)#
Bases: Module
OLMo-2 MLP module.
This module implements the feed-forward network (MLP) used in the OLMo-2 model. It consists of gate, up, and down projections with a SiLU activation.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- gate_proj#
Linear layer for the gate projection.
- Type
- down_proj#
Linear layer for the down projection.
- Type
- up_proj#
Linear layer for the up projection.
- Type
- act_fn#
Activation function (SiLU).
- Type
callable
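A gated feed-forward block of this kind is commonly computed as down_proj(act_fn(gate_proj(x)) * up_proj(x)). The NumPy sketch below assumes that combination and bias-free projections; the function names `silu` and `gated_mlp` and the kernel shapes are hypothetical and chosen only to make the data flow visible.

```python
import numpy as np

def silu(x):
    """SiLU activation: x * sigmoid(x)."""
    return x / (1.0 + np.exp(-x))

def gated_mlp(x, gate_kernel, up_kernel, down_kernel):
    """Gate the up projection with SiLU(gate projection), then project back down."""
    gate = silu(x @ gate_kernel)       # (..., intermediate)
    up = x @ up_kernel                 # (..., intermediate)
    return (gate * up) @ down_kernel   # (..., hidden)

rng = np.random.default_rng(0)
hidden_size, intermediate_size = 8, 32
gate_kernel = rng.normal(size=(hidden_size, intermediate_size))
up_kernel = rng.normal(size=(hidden_size, intermediate_size))
down_kernel = rng.normal(size=(intermediate_size, hidden_size))

x = rng.normal(size=(2, 4, hidden_size))
y = gated_mlp(x, gate_kernel, up_kernel, down_kernel)
print(y.shape)  # (2, 4, 8)
```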
- class easydel.modules.olmo2.modeling_olmo2_flax.Olmo2Model(*args: Any, **kwargs: Any)#
Bases: EasyDeLBaseModule
The base OLMo-2 model transformer.
This class represents the core transformer architecture of the OLMo-2 model, consisting of an embedding layer, multiple Olmo2DecoderLayer layers, and a final RMS normalization layer.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- embed_tokens#
Embedding layer for input tokens.
- Type
nn.Embed
- layers#
List of decoder layers.
- Type
tp.List[Olmo2DecoderLayer]
- gradient_checkpointing#
Gradient checkpointing configuration.
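The overall data flow of the base model (embedding lookup, a stack of decoder layers, then a final RMS normalization) can be sketched as follows. This is an illustrative NumPy outline under stated assumptions: `model_forward` is a hypothetical helper, and the decoder layers are replaced by simple stand-in callables rather than real Olmo2DecoderLayer instances.

```python
import numpy as np

def rms_norm(x, weight, eps=1e-6):
    """Scale x by the reciprocal root-mean-square of its last axis."""
    variance = np.mean(x * x, axis=-1, keepdims=True)
    return x / np.sqrt(variance + eps) * weight

def model_forward(input_ids, embedding, layers, final_norm_weight):
    """Embed tokens, run every layer in order, then apply the final norm."""
    hidden = embedding[input_ids]          # (batch, seq, hidden)
    for layer in layers:
        hidden = layer(hidden)
    return rms_norm(hidden, final_norm_weight)

rng = np.random.default_rng(0)
vocab_size, hidden_size = 10, 4
embedding = rng.normal(size=(vocab_size, hidden_size))

# Stand-in layers: any callable mapping hidden states to hidden states.
layers = [lambda h: h + 0.1 * np.tanh(h) for _ in range(3)]

input_ids = np.array([[1, 2, 3]])
final_hidden = model_forward(input_ids, embedding, layers, np.ones(hidden_size))
print(final_hidden.shape)  # (1, 3, 4)
```

With a unit norm weight, every output vector has a root-mean-square close to 1, which is the property the final normalization layer provides to whatever head consumes these hidden states.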