easydel.modules.openelm.modeling_openelm#
- class easydel.modules.openelm.modeling_openelm.OpenELMDecoderLayer(*args: Any, **kwargs: Any)[source]#
Bases: Module
OpenELM Transformer Decoder Layer.
This module represents a single decoder layer in the OpenELM model, combining self-attention and FFN sub-layers with residual connections and layer normalization applied before each sub-layer.
- config#
Configuration object for the model.
- Type
OpenELMConfig
- layer_idx#
The index of the current layer.
- Type
int
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- attn#
The self-attention module.
- ffn#
The feed-forward network (FFN) module.
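The pre-normalization layout described for this layer can be sketched as follows. This is a minimal illustration, not EasyDeL's implementation: it assumes RMS normalization for both sub-layers, and `rms_norm`, `pre_norm_decoder_layer`, and the stand-in `attn_fn`/`ffn_fn` callables are hypothetical names introduced here.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, weight, eps=1e-6):
    # Normalize by the root-mean-square over the feature axis, then scale.
    variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x * jax.lax.rsqrt(variance + eps) * weight


def pre_norm_decoder_layer(x, attn_fn, ffn_fn, attn_norm_weight, ffn_norm_weight):
    # Sub-layer 1: normalize, apply attention, add the residual.
    x = x + attn_fn(rms_norm(x, attn_norm_weight))
    # Sub-layer 2: normalize, apply the FFN, add the residual.
    x = x + ffn_fn(rms_norm(x, ffn_norm_weight))
    return x


# Toy usage with stand-in sub-layers (hypothetical, not OpenELM's attn/ffn).
hidden = jnp.ones((2, 4, 8))  # (batch, sequence, features)
ones = jnp.ones((8,))
out = pre_norm_decoder_layer(
    hidden,
    attn_fn=lambda h: 0.5 * h,
    ffn_fn=lambda h: jnp.zeros_like(h),
    attn_norm_weight=ones,
    ffn_norm_weight=ones,
)
```

Because normalization is applied to the input of each sub-layer rather than to its output, the residual path carries the unnormalized hidden state from one layer to the next.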
- class easydel.modules.openelm.modeling_openelm.OpenELMFeedForwardNetwork(*args: Any, **kwargs: Any)[source]#
Bases: Module
OpenELM Feed-Forward Network (FFN) module.
This module implements the FFN layer used in the OpenELM model. It supports both standard MLP and Gated Linear Unit (GLU) variants.
- config#
Configuration object for the model.
- Type
OpenELMConfig
- layer_idx#
The index of the current layer.
- Type
int
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- ffn_with_glu#
Whether the FFN uses a Gated Linear Unit.
- Type
bool
- proj_1#
First linear projection layer (or gate projection in GLU).
- Type
- proj_2#
Second linear projection layer (down projection).
- Type
- gate_proj#
Gate projection layer used only if ffn_with_glu is True.
- Type
ColumnParallelLinear, optional
- activation_fn#
The activation function.
- Type
callable
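The difference between the standard MLP and GLU variants can be sketched as below. This is an illustration under stated assumptions: the weight names `w_up`, `w_gate`, and `w_down` are hypothetical, SiLU is assumed as the activation, and how EasyDeL maps these roles onto `proj_1`, `gate_proj`, and `proj_2` is not confirmed by the attribute descriptions above.

```python
import jax
import jax.numpy as jnp


def ffn(x, w_up, w_down, w_gate=None, activation_fn=jax.nn.silu):
    if w_gate is not None:
        # GLU variant: an activated gate branch scales the up projection elementwise.
        hidden = activation_fn(x @ w_gate) * (x @ w_up)
    else:
        # Standard MLP variant: project up, then activate.
        hidden = activation_fn(x @ w_up)
    # Both variants finish with the down projection back to the model width.
    return hidden @ w_down


key_up, key_gate, key_down = jax.random.split(jax.random.PRNGKey(0), 3)
x = jnp.ones((2, 4, 8))  # (batch, sequence, features)
w_up = jax.random.normal(key_up, (8, 16))
w_gate = jax.random.normal(key_gate, (8, 16))
w_down = jax.random.normal(key_down, (16, 8))

mlp_out = ffn(x, w_up, w_down)
glu_out = ffn(x, w_up, w_down, w_gate=w_gate)
```

Both variants map the input back to its original feature width; the GLU variant adds one extra projection whose activated output gates the hidden state.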
- class easydel.modules.openelm.modeling_openelm.OpenELMForCausalLM(*args: Any, **kwargs: Any)[source]#
Bases: BaseCausalLMModule[OpenELMModel, OpenELMConfig]
OpenELM model with a Causal Language Modeling head.
- class easydel.modules.openelm.modeling_openelm.OpenELMModel(*args: Any, **kwargs: Any)[source]#
Bases: EasyDeLBaseModule
The base OpenELM model transformer.
This class represents the core transformer architecture of the OpenELM model, consisting of an embedding layer, multiple OpenELMDecoderLayer layers, and a final RMS normalization layer.
- config#
Configuration object for the model.
- Type
OpenELMConfig
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- token_embeddings#
Embedding layer for input tokens.
- Type
nn.Embed
- layers#
List of decoder layers.
- Type
tp.List[OpenELMDecoderLayer]
- gradient_checkpointing#
Gradient checkpointing configuration.
- property frequencies#
Retrieves or computes the frequency components (e.g., for RoPE) from the configuration.
Uses self.config.get_basic_frequencies() and caches the result.
- Returns
The frequency components, potentially cached.
- Return type
jnp.ndarray
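The compute-once-then-cache behaviour described for `frequencies` can be sketched as follows. This is a hedged illustration: it uses the standard RoPE frequency formula and `functools.cached_property`, and `rope_frequencies` and `ToyModel` are hypothetical names. What `get_basic_frequencies()` actually returns in EasyDeL, and how it is cached, is not shown in this page.

```python
from functools import cached_property

import jax.numpy as jnp


def rope_frequencies(head_dim, max_positions, base=10000.0):
    # Standard RoPE: inverse frequencies base^(-2i/d) for each pair of dimensions.
    exponents = jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim
    inv_freq = 1.0 / (base**exponents)
    positions = jnp.arange(max_positions, dtype=jnp.float32)
    # Outer product gives the rotation angle for every (position, pair).
    return jnp.outer(positions, inv_freq)


class ToyModel:
    # Hypothetical stand-in for a model that caches its frequency table.
    def __init__(self, head_dim, max_positions):
        self.head_dim = head_dim
        self.max_positions = max_positions

    @cached_property
    def frequencies(self):
        # Computed on first access, then reused on later accesses.
        return rope_frequencies(self.head_dim, self.max_positions)


model = ToyModel(head_dim=8, max_positions=16)
freqs = model.frequencies
```

Repeated accesses to `model.frequencies` return the same array object, so the table is built only once per model instance.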
- class easydel.modules.openelm.modeling_openelm.OpenELMMultiHeadCausalAttention(*args: Any, **kwargs: Any)[source]#
Bases: UnifiedAttention
OpenELM causal attention based on UnifiedAttention with per-layer head configuration.
- define_network(config: OpenELMConfig, dtype: dtype, param_dtype: dtype, precision: Union[None, str, Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision], DotAlgorithm, DotAlgorithmPreset], rngs: Rngs) None[source]#
Define network structure.
Override this method to customize the projection structure (e.g., a fused QKV projection). The default implementation creates separate Q/K/V/O projections.
- Parameters
config – Model configuration
dtype – Data type for computations
param_dtype – Data type for parameters
precision – JAX precision setting
rngs – Random number generators
- projection_mapping: ClassVar = {'key_projection': 'k_proj', 'mla_kv_a_layernorm': 'kv_a_layernorm', 'mla_kv_a_proj_with_mqa': 'kv_a_proj_with_mqa', 'mla_kv_b_proj': 'kv_b_proj', 'mla_q_a_layernorm': 'q_a_layernorm', 'mla_q_a_proj': 'q_a_proj', 'mla_q_b_proj': 'q_b_proj', 'mla_q_proj': 'q_proj', 'output_projection': 'out_proj', 'query_key_value_projection': 'qkv_proj', 'query_projection': 'q_proj', 'value_projection': 'v_proj'}#
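One plausible reading of `projection_mapping` is that it translates generic projection roles into the attribute names a given attention class uses. The sketch below illustrates that idea only; `ToyAttention` and `resolve_projection` are hypothetical, the string values stand in for real layers, and how UnifiedAttention actually consumes the mapping is not described in this page.

```python
# Subset of the mapping shown above: generic role -> attribute name.
PROJECTION_MAPPING = {
    "query_projection": "q_proj",
    "key_projection": "k_proj",
    "value_projection": "v_proj",
    "output_projection": "out_proj",
    "query_key_value_projection": "qkv_proj",
}


class ToyAttention:
    # Hypothetical module that defines only a fused QKV and an output projection.
    def __init__(self):
        self.qkv_proj = "fused-qkv-layer"
        self.out_proj = "output-layer"


def resolve_projection(module, role, mapping=PROJECTION_MAPPING):
    # Translate the generic role into the attribute name, then look it up.
    # Returns None when the module does not define that projection.
    return getattr(module, mapping[role], None)


attn = ToyAttention()
fused = resolve_projection(attn, "query_key_value_projection")
separate_q = resolve_projection(attn, "query_projection")
```

With a lookup like this, shared attention code can ask for a projection by role and fall back gracefully when a subclass uses a fused layout instead of separate Q/K/V layers.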