easydel.modules.openelm.modeling_openelm_flax
- class easydel.modules.openelm.modeling_openelm_flax.OpenELMDecoderLayer(*args: Any, **kwargs: Any)
Bases: Module
OpenELM Transformer Decoder Layer.
This module represents a single decoder layer in the OpenELM model. It combines a self-attention sub-layer and a feed-forward network (FFN) sub-layer, each preceded by layer normalization and wrapped in a residual connection.
- config#
Configuration object for the model.
- Type
- layer_idx#
The index of the current layer.
- Type
int
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- attn#
The self-attention module.
- ffn#
The feed-forward network (FFN) module.
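The pre-normalization residual pattern described for this layer can be sketched in plain JAX. This is a minimal illustration, not EasyDeL's implementation: `rms_norm`, `pre_norm_block`, `toy_attn`, and `toy_ffn` are hypothetical names, and the choice of an RMS-style norm without a learned scale is an assumption.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, eps=1e-6):
    # Hypothetical stand-in for the layer's normalization module (no learned scale).
    return x * jax.lax.rsqrt(jnp.mean(x * x, axis=-1, keepdims=True) + eps)


def pre_norm_block(x, attn_fn, ffn_fn):
    # Normalize first, run the sub-layer, then add the residual.
    x = x + attn_fn(rms_norm(x))
    x = x + ffn_fn(rms_norm(x))
    return x


# Toy sub-layers so the sketch runs end to end; real ones would be attention and an FFN.
toy_attn = lambda h: 0.5 * h
toy_ffn = lambda h: jnp.tanh(h)

hidden = jnp.ones((2, 4, 8))  # (batch, sequence, model_dim)
out = pre_norm_block(hidden, toy_attn, toy_ffn)
```

Because each sub-layer only adds to the residual stream, a sub-layer that outputs zeros leaves the hidden states unchanged.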
- class easydel.modules.openelm.modeling_openelm_flax.OpenELMFeedForwardNetwork(*args: Any, **kwargs: Any)
Bases: Module
OpenELM Feed-Forward Network (FFN) module.
This module implements the FFN layer used in the OpenELM model. It supports both standard MLP and Gated Linear Unit (GLU) variants.
- config#
Configuration object for the model.
- Type
- layer_idx#
The index of the current layer.
- Type
int
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- ffn_with_glu#
Whether the FFN uses a Gated Linear Unit.
- Type
bool
- proj_1#
First linear projection layer (or gate projection in GLU).
- Type
- proj_2#
Second linear projection layer (down projection).
- Type
- gate_proj#
Gate projection layer used only if ffn_with_glu is True.
- Type
ParallelLinear, optional
- activation_fn#
The activation function.
- Type
callable
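The difference between the standard MLP and GLU variants can be sketched as follows. This is a generic illustration under stated assumptions: the weight names `w_up`, `w_down`, and `w_gate` are hypothetical and are not claimed to map one-to-one onto `proj_1`, `proj_2`, and `gate_proj`, and SiLU is an assumed activation.

```python
import jax
import jax.numpy as jnp


def ffn(x, w_up, w_down, w_gate=None, act=jax.nn.silu):
    if w_gate is not None:
        # GLU variant: an activated gate modulates the up-projection elementwise.
        hidden = act(x @ w_gate) * (x @ w_up)
    else:
        # Standard MLP variant: project up, then apply the activation.
        hidden = act(x @ w_up)
    # Both variants project back down to the model dimension.
    return hidden @ w_down


model_dim, inner_dim = 8, 16  # illustrative sizes only
k_up, k_down, k_gate, k_x = jax.random.split(jax.random.PRNGKey(0), 4)
w_up = jax.random.normal(k_up, (model_dim, inner_dim))
w_down = jax.random.normal(k_down, (inner_dim, model_dim))
w_gate = jax.random.normal(k_gate, (model_dim, inner_dim))
x = jax.random.normal(k_x, (2, 4, model_dim))

mlp_out = ffn(x, w_up, w_down)          # standard MLP path
glu_out = ffn(x, w_up, w_down, w_gate)  # gated path
```

Both paths return a tensor with the same shape as the input, so the variant can be switched per configuration without changing the surrounding layer.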
- class easydel.modules.openelm.modeling_openelm_flax.OpenELMForCausalLM(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule
OpenELM model with a Causal Language Modeling head.
This model consists of the base OpenELM transformer (OpenELMModel) followed by a linear layer (lm_head) that projects the transformer’s output hidden states to the vocabulary size, producing logits for next-token prediction. Optionally, the input token embeddings can be tied to the output projection, in which case the embedding matrix is reused in place of a separate lm_head.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- transformer#
The core OpenELM transformer model.
- Type
- lm_head#
The linear layer for projecting hidden states to vocabulary logits. This is None if config.share_input_output_layers is True.
- Type
ParallelLinear, optional
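The relationship between a separate output head and tied embeddings can be sketched like this. It is a minimal illustration, assuming an embedding matrix of shape (vocab, model_dim); `lm_logits` and the array names are hypothetical and do not reflect EasyDeL's internal API.

```python
import jax
import jax.numpy as jnp


def lm_logits(hidden, embedding, lm_head=None):
    # hidden: (batch, seq, model_dim); embedding: (vocab, model_dim)
    if lm_head is None:
        # Tied weights: reuse the embedding matrix, transposed, as the output projection.
        return hidden @ embedding.T
    # Untied weights: lm_head has shape (model_dim, vocab).
    return hidden @ lm_head


vocab, model_dim = 32, 8  # illustrative sizes only
k_emb, k_head, k_h = jax.random.split(jax.random.PRNGKey(0), 3)
embedding = jax.random.normal(k_emb, (vocab, model_dim))
head = jax.random.normal(k_head, (model_dim, vocab))
hidden = jax.random.normal(k_h, (2, 4, model_dim))

tied_logits = lm_logits(hidden, embedding)
untied_logits = lm_logits(hidden, embedding, head)
# Greedy next-token choice from the last position of each sequence.
next_token = jnp.argmax(tied_logits[:, -1, :], axis=-1)
```

Tying removes one (model_dim x vocab) parameter matrix, which is why the head can be absent when the configuration shares input and output layers.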
- class easydel.modules.openelm.modeling_openelm_flax.OpenELMModel(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule
The base OpenELM model transformer.
This class represents the core transformer architecture of the OpenELM model, consisting of an embedding layer, multiple OpenELMDecoderLayer layers, and a final RMS normalization layer.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- token_embeddings#
Embedding layer for input tokens.
- Type
nn.Embed
- layers#
List of decoder layers.
- Type
tp.List[OpenELMDecoderLayer]
- gradient_checkpointing#
Gradient checkpointing configuration.
- property frequencies#
Retrieves or computes the frequency components (e.g., for RoPE) from the configuration.
Uses self.config.get_basic_frequencies() and caches the result.
- Returns
The frequency components, potentially cached.
- Return type
jnp.ndarray
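As a rough picture of what cached RoPE frequency components can look like, the sketch below computes the standard rotary angles once and reuses them on later calls. It is an assumption-laden illustration: `rope_frequencies`, its parameters, and the cos/sin return format are hypothetical and are not the output format of `get_basic_frequencies()`.

```python
from functools import lru_cache

import jax.numpy as jnp


@lru_cache(maxsize=None)
def rope_frequencies(head_dim, max_positions, base=10000.0):
    # One inverse frequency per pair of channels: base ** (-2i / head_dim).
    inv_freq = 1.0 / (base ** (jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim))
    positions = jnp.arange(max_positions, dtype=jnp.float32)
    # Rotation angle for every (position, channel pair).
    angles = jnp.outer(positions, inv_freq)
    return jnp.cos(angles), jnp.sin(angles)


cos, sin = rope_frequencies(head_dim=8, max_positions=16)
# The second call with the same arguments is served from the cache.
cos_again, _ = rope_frequencies(head_dim=8, max_positions=16)
```

The tables depend only on static configuration values, so computing them once and caching the result avoids repeating the work on every forward pass.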
- class easydel.modules.openelm.modeling_openelm_flax.OpenELMMultiHeadCausalAttention(*args: Any, **kwargs: Any)
Bases: AttentionModule
OpenELM Multi-Head Causal Attention module.
This module implements the multi-head causal self-attention mechanism used in the OpenELM model. It supports Grouped Query Attention (GQA) and optional RMS Normalization of query and key projections.
- config#
Configuration object for the model.
- Type
- layer_idx#
The index of the current layer.
- Type
int
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- qkv_proj#
Combined linear layer for query, key, and value projections.
- Type
- out_proj#
Linear layer for the output projection.
- Type
- head_dim#
Dimensionality of each attention head.
- Type
int
- attention_performer#
Module to perform the core attention computation.
- num_q_heads#
Number of query heads.
- Type
int
- num_k_heads#
Number of key heads.
- Type
int
- num_v_heads#
Number of value heads.
- Type
int
- transformer_dim#
Dimensionality of the transformer model.
- Type
int
- num_groups#
Number of query groups for GQA.
- Type
int
- rotary#
Rotary position embedding module.
- Type
RoPE
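How a combined query/key/value projection and grouped query attention fit together can be sketched in plain JAX. This is a minimal illustration, not EasyDeL's implementation: the sizes are arbitrary, `num_groups` is assumed here to be the number of query heads per key/value head, and rotary embeddings and the optional query/key normalization are omitted.

```python
import jax
import jax.numpy as jnp

num_q_heads, num_k_heads, num_v_heads, head_dim = 8, 2, 2, 4  # illustrative sizes
num_groups = num_q_heads // num_k_heads  # assumed: query heads per key/value head
batch, seq = 2, 5
total_heads = num_q_heads + num_k_heads + num_v_heads

# Stand-in for the output of a combined query/key/value projection.
qkv = jax.random.normal(jax.random.PRNGKey(0), (batch, seq, total_heads * head_dim))
qkv = qkv.reshape(batch, seq, total_heads, head_dim)
q, k, v = jnp.split(qkv, [num_q_heads, num_q_heads + num_k_heads], axis=2)

# Repeat each key/value head so it lines up with its group of query heads.
k = jnp.repeat(k, num_groups, axis=2)
v = jnp.repeat(v, num_groups, axis=2)

# Causal scaled dot-product attention.
scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(head_dim)
causal = jnp.tril(jnp.ones((seq, seq), dtype=bool))
scores = jnp.where(causal, scores, -jnp.inf)  # block attention to future positions
weights = jax.nn.softmax(scores, axis=-1)
out = jnp.einsum("bhqk,bkhd->bqhd", weights, v)
```

With fewer key/value heads than query heads, the key/value tensors that must be stored shrink by a factor of `num_groups`, while the output keeps one slice per query head.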