easydel.modules.olmo.modeling_olmo_flax#
- class easydel.modules.olmo.modeling_olmo_flax.OlmoAttention(*args: Any, **kwargs: Any)#
Bases: AttentionModule
OLMo Attention module.
This module implements the multi-head attention mechanism with rotary position embeddings and Grouped Query Attention (GQA) used in the OLMo model. It also supports optional QKV clipping.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- hidden_size#
Dimensionality of the hidden states.
- Type
int
- head_dim#
Dimensionality of each attention head.
- Type
int
- num_key_value_groups#
Number of query heads that share each key/value head.
- Type
int
- q_proj#
Linear layer for query projection.
- Type
- k_proj#
Linear layer for key projection.
- Type
- v_proj#
Linear layer for value projection.
- Type
- o_proj#
Linear layer for the output projection.
- Type
- attention_performer#
Module to perform the core attention computation.
- rotary#
Rotary position embedding module.
- Type
RoPE
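The grouped-query attention and optional QKV clipping described for OlmoAttention can be sketched in plain JAX. This is a minimal illustration, not EasyDeL's implementation: the function name `gqa_attention`, the `clip_qkv` argument, and the tensor layout are assumptions made for the example, and the rotary embedding step is omitted.

```python
import jax
import jax.numpy as jnp


def gqa_attention(q, k, v, clip_qkv=None):
    """Causal grouped-query attention (hypothetical helper, not EasyDeL API).

    q:    (batch, seq, num_heads, head_dim)
    k, v: (batch, seq, num_kv_heads, head_dim)
    """
    # Optional QKV clipping: bound the projected values to [-clip_qkv, clip_qkv].
    if clip_qkv is not None:
        q, k, v = (jnp.clip(x, -clip_qkv, clip_qkv) for x in (q, k, v))

    # Each key/value head is shared by this many query heads.
    num_key_value_groups = q.shape[2] // k.shape[2]
    k = jnp.repeat(k, num_key_value_groups, axis=2)
    v = jnp.repeat(v, num_key_value_groups, axis=2)

    # Scaled dot-product scores: (batch, heads, query_pos, key_pos).
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(q.shape[-1])

    # Causal mask: a position may only attend to itself and earlier positions.
    seq = q.shape[1]
    causal = jnp.tril(jnp.ones((seq, seq), dtype=bool))
    scores = jnp.where(causal, scores, jnp.finfo(scores.dtype).min)

    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", weights, v)
```

With 8 query heads and 2 key/value heads, `num_key_value_groups` is 4, so each key/value head is repeated four times before the attention scores are computed.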
- class easydel.modules.olmo.modeling_olmo_flax.OlmoDecoderLayer(*args: Any, **kwargs: Any)#
Bases: Module
OLMo Transformer Decoder Layer.
This module represents a single decoder layer in the OLMo model, combining self-attention and MLP sub-layers with residual connections. Unlike typical transformer blocks, OLMo applies the layer normalization after the residual connection.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- self_attn#
The self-attention module.
- Type
- class easydel.modules.olmo.modeling_olmo_flax.OlmoForCausalLM(*args: Any, **kwargs: Any)#
Bases: EasyDeLBaseModule
OLMo model with a Causal Language Modeling head.
This model consists of the base OLMo transformer (OlmoModel) followed by a linear layer (lm_head) that projects the transformer’s output hidden states to the vocabulary size, producing logits for next token prediction.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- lm_head#
The linear layer for projecting hidden states to vocabulary logits.
- Type
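The role of lm_head can be illustrated with a small JAX sketch: hidden states are projected to vocabulary logits, and the logits at the last position score the next token. The function names and the bias-free kernel of shape `(hidden_size, vocab_size)` are assumptions for illustration only, not the EasyDeL API.

```python
import jax.numpy as jnp


def project_to_vocab(hidden_states, lm_head_kernel):
    """Map (batch, seq, hidden) hidden states to (batch, seq, vocab) logits."""
    return hidden_states @ lm_head_kernel


def greedy_next_token(hidden_states, lm_head_kernel):
    """Pick the highest-scoring next token from the last position's logits."""
    logits = project_to_vocab(hidden_states, lm_head_kernel)
    return jnp.argmax(logits[:, -1, :], axis=-1)
```

Greedy selection is only one decoding choice; sampling strategies would consume the same last-position logits.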
- class easydel.modules.olmo.modeling_olmo_flax.OlmoForSequenceClassification(*args: Any, **kwargs: Any)#
Bases: EasyDeLBaseModule
OLMo model with a Sequence Classification head.
This model consists of the base OLMo transformer (OlmoModel) followed by a linear layer (score) that projects the transformer’s output hidden states (typically the hidden state of the last token) to the number of classes for classification.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- score#
The linear layer for classification.
- Type
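Pooling the last token's hidden state before the score projection might look like the sketch below. It assumes right-padded inputs and an `attention_mask` of ones and zeros; the function name, the mask-based index lookup, and the bias-free kernel are illustrative assumptions rather than the library's actual code.

```python
import jax.numpy as jnp


def classify_last_token(hidden_states, score_kernel, attention_mask):
    """Project each sequence's last non-padding hidden state to class logits.

    hidden_states:  (batch, seq, hidden)
    score_kernel:   (hidden, num_labels)
    attention_mask: (batch, seq), 1 for real tokens, 0 for right padding
    """
    # Index of the last real token in each sequence.
    last_index = attention_mask.sum(axis=-1) - 1
    batch_index = jnp.arange(hidden_states.shape[0])
    pooled = hidden_states[batch_index, last_index]  # (batch, hidden)
    return pooled @ score_kernel  # (batch, num_labels)
```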
- class easydel.modules.olmo.modeling_olmo_flax.OlmoMLP(*args: Any, **kwargs: Any)#
Bases: Module
OLMo MLP module.
This module implements the feed-forward network (MLP) used in the OLMo model. It consists of gate, up, and down projections with a SiLU activation.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- gate_proj#
Linear layer for the gate projection.
- Type
- down_proj#
Linear layer for the down projection.
- Type
- up_proj#
Linear layer for the up projection.
- Type
- act_fn#
Activation function (SiLU).
- Type
callable
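A gated feed-forward block built from the three projections above is commonly computed as `down_proj(act_fn(gate_proj(x)) * up_proj(x))`. The sketch below assumes that form and bias-free weight matrices; the function and argument names are hypothetical and do not mirror the EasyDeL source.

```python
import jax
import jax.numpy as jnp


def gated_mlp(x, w_gate, w_up, w_down):
    """Gated MLP with SiLU activation (illustrative, bias-free).

    x:      (..., hidden)
    w_gate: (hidden, intermediate)
    w_up:   (hidden, intermediate)
    w_down: (intermediate, hidden)
    """
    gate = jax.nn.silu(x @ w_gate)  # activated gate branch
    up = x @ w_up                   # linear branch
    return (gate * up) @ w_down     # elementwise gating, then project back
```

The gate and up branches share the same input and intermediate width, so their elementwise product can be projected back to the hidden size by a single matrix.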
- class easydel.modules.olmo.modeling_olmo_flax.OlmoModel(*args: Any, **kwargs: Any)#
Bases: EasyDeLBaseModule
The base OLMo model transformer.
This class represents the core transformer architecture of the OLMo model, consisting of an embedding layer and multiple OlmoDecoderLayer layers. Note that OLMo does not have a final layer normalization.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- embed_tokens#
Embedding layer for input tokens.
- Type
nn.Embed
- layers#
List of decoder layers.
- Type
tp.List[OlmoDecoderLayer]
- gradient_checkpointing#
Gradient checkpointing configuration.