easydel.modules.grok_1.modeling_grok_1_flax
- class easydel.modules.grok_1.modeling_grok_1_flax.Grok1Attention(*args: Any, **kwargs: Any)
Bases: AttentionModule
Grok-1 Attention module.
This module implements the multi-head attention mechanism with rotary position embeddings used in the Grok-1 model.
- config: Configuration object for the model.
- layer_index (int): The index of the current layer.
- dtype (jnp.dtype): Data type for computations.
- param_dtype (jnp.dtype): Data type for parameters.
- precision (jax.lax.PrecisionLike): Precision setting for JAX operations.
- rngs (nn.Rngs): Random number generators.
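The summary above mentions rotary position embeddings. The following NumPy sketch illustrates them with rotate-half RoPE followed by causal scaled dot-product attention. Several choices here are assumptions made for the example: the function names, the base of 10000, and the (seq_len, num_heads, head_dim) layout. The sketch omits any Grok-1-specific details and is not EasyDeL's implementation.

```python
import numpy as np

def rope_frequencies(head_dim: int, max_len: int, base: float = 10000.0):
    # Inverse frequency per channel pair: base^(-2i / head_dim) (assumed base)
    inv_freq = 1.0 / (base ** (np.arange(0, head_dim, 2) / head_dim))
    angles = np.outer(np.arange(max_len), inv_freq)  # (max_len, head_dim // 2)
    return np.cos(angles), np.sin(angles)

def apply_rope(x, cos, sin):
    # x: (seq_len, num_heads, head_dim); channel i is paired with channel i + half
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    cos = cos[: x.shape[0], None, :]  # broadcast over heads
    sin = sin[: x.shape[0], None, :]
    return np.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)

def attention(q, k, v):
    # Causal scaled dot-product attention over (seq_len, num_heads, head_dim) inputs
    seq_len, _, head_dim = q.shape
    scores = np.einsum("qhd,khd->hqk", q, k) / np.sqrt(head_dim)
    mask = np.tril(np.ones((seq_len, seq_len), dtype=bool))
    scores = np.where(mask, scores, -np.inf)  # block attention to future positions
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return np.einsum("hqk,khd->qhd", weights, v)
```

Each channel pair is only rotated, so RoPE leaves vector norms unchanged. It also makes query-key dot products depend on relative position rather than absolute position.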
- class easydel.modules.grok_1.modeling_grok_1_flax.Grok1BLockSparseMLP(*args: Any, **kwargs: Any)
Bases: Module
Grok-1 Block Sparse MLP module.
This module implements the feed-forward network used for each expert in the sparse Mixture of Experts layer of the Grok-1 model.
- config: Configuration object for the model.
- dtype (jnp.dtype): Data type for computations.
- param_dtype (jnp.dtype): Data type for parameters.
- precision (jax.lax.PrecisionLike): Precision setting for JAX operations.
- rngs (nn.Rngs): Random number generators.
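Below is a sketch of a single expert feed-forward network. It assumes a gated GELU layout: GELU is applied to one input projection, the result is multiplied elementwise by a second projection, and an output projection follows. The weight names w_gate, w_up and w_down are hypothetical. Check the exact layout used by Grok1BLockSparseMLP against its source.

```python
import numpy as np

def gelu(x):
    # Tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def expert_mlp(x, w_gate, w_up, w_down):
    # x: (tokens, hidden); w_gate, w_up: (hidden, intermediate); w_down: (intermediate, hidden)
    # Assumed gated layout: GELU(x @ w_gate) * (x @ w_up), projected back to hidden size
    return (gelu(x @ w_gate) * (x @ w_up)) @ w_down
```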
- class easydel.modules.grok_1.modeling_grok_1_flax.Grok1DecoderLayer(*args: Any, **kwargs: Any)
Bases: Module
Grok-1 Transformer Decoder Layer.
This module represents a single decoder layer in the Grok-1 model, combining self-attention and a sparse MoE block with residual connections and layer normalization.
- config: Configuration object for the model.
- dtype (jnp.dtype): Data type for computations.
- param_dtype (jnp.dtype): Data type for parameters.
- precision (jax.lax.PrecisionLike): Precision setting for JAX operations.
- rngs (nn.Rngs): Random number generators.
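The residual-plus-normalization pattern can be sketched as follows. The sketch assumes RMSNorm is applied both before and after each sublayer (a "sandwich" layout), and the norm weight names are hypothetical. Attention and the MoE block are passed in as plain callables so the example stays self-contained.

```python
import numpy as np
from typing import Callable, Dict

def rms_norm(x, weight, eps=1e-5):
    # Divide by the root-mean-square over the last axis, then apply a learned scale
    return x / np.sqrt(np.mean(x**2, axis=-1, keepdims=True) + eps) * weight

def decoder_layer(h, attn: Callable, moe: Callable, norms: Dict[str, np.ndarray]):
    # Attention sublayer: pre-norm, attention, post-norm, residual add
    attn_out = attn(rms_norm(h, norms["pre_attn"]))
    h = h + rms_norm(attn_out, norms["post_attn"])
    # MoE sublayer: same pattern around the sparse expert block
    moe_out = moe(rms_norm(h, norms["pre_moe"]))
    h = h + rms_norm(moe_out, norms["post_moe"])
    return h
```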
- class easydel.modules.grok_1.modeling_grok_1_flax.Grok1ForCausalLM(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule
Grok-1 model with a language modeling head.
This model extends the base Grok1Model by adding a linear layer on top to predict the next token in a sequence, making it suitable for causal language modeling tasks. It also includes handling for the Mixture of Experts auxiliary loss.
- config: Configuration object for the model.
- dtype (jnp.dtype): Data type for computations.
- param_dtype (jnp.dtype): Data type for parameters.
- precision (jax.lax.PrecisionLike): Precision setting for JAX operations.
- rngs (nn.Rngs): Random number generators.
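The sketch below shows one way to combine the language-modeling objective with an MoE auxiliary loss. It assumes next-token cross-entropy on shifted labels plus a Switch/Mixtral-style load-balancing term scaled by a hypothetical coefficient aux_coef. The formulation EasyDeL actually uses may differ.

```python
import numpy as np

def log_softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=-1, keepdims=True))

def causal_lm_loss(hidden, lm_head, input_ids):
    # hidden: (seq, hidden); lm_head: (hidden, vocab); input_ids: (seq,)
    logits = hidden @ lm_head
    logp = log_softmax(logits[:-1])  # position t predicts token t + 1
    targets = input_ids[1:]
    return -logp[np.arange(len(targets)), targets].mean()

def load_balancing_loss(router_logits, top_k):
    # router_logits: (tokens, num_experts)
    num_experts = router_logits.shape[-1]
    probs = np.exp(log_softmax(router_logits))
    top = np.argsort(-probs, axis=-1)[:, :top_k]
    mask = np.zeros_like(probs)
    np.put_along_axis(mask, top, 1.0, axis=-1)
    tokens_per_expert = mask.mean(axis=0)  # fraction of tokens routed to each expert
    prob_per_expert = probs.mean(axis=0)   # mean router probability for each expert
    return num_experts * np.sum(tokens_per_expert * prob_per_expert)

def total_loss(hidden, lm_head, input_ids, router_logits, top_k=2, aux_coef=0.01):
    # aux_coef is a hypothetical weighting for the balancing term
    return causal_lm_loss(hidden, lm_head, input_ids) + aux_coef * load_balancing_loss(router_logits, top_k)
```

With uniform router logits the balancing term equals top_k, which is the value it takes under perfectly uniform routing.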
- class easydel.modules.grok_1.modeling_grok_1_flax.Grok1Model(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule
Grok-1 model implementation.
This class implements the main Grok-1 transformer model architecture, consisting of an embedding layer, multiple Grok1DecoderLayer layers (with sparse MoE), and a final RMS normalization layer.
- config: Configuration object for the model.
- dtype (jnp.dtype): Data type for computations.
- param_dtype (jnp.dtype): Data type for parameters.
- precision (jax.lax.PrecisionLike): Precision setting for JAX operations.
- rngs (nn.Rngs): Random number generators.
- property frequencies
Computes the frequency components (e.g., for RoPE) from the configuration via self.config.get_basic_frequencies() and caches the result, so later accesses reuse the stored value.
Returns: The frequency components, served from the cache after the first access.
Return type: jnp.ndarray
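The compute-once behaviour of frequencies can be reproduced with functools.cached_property. In this sketch, ToyConfig and ToyModel are hypothetical stand-ins. The cos/sin layout of shape (positions, head_dim) returned by get_basic_frequencies is also an assumption, not the format of EasyDeL's method.

```python
import functools
import numpy as np

class ToyConfig:
    # Hypothetical stand-in for the model configuration (not EasyDeL's config class)
    def __init__(self, head_dim=8, max_position_embeddings=16, rope_theta=10000.0):
        self.head_dim = head_dim
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta
        self.calls = 0  # counts invocations so the caching is observable

    def get_basic_frequencies(self):
        self.calls += 1
        inv_freq = 1.0 / (self.rope_theta ** (np.arange(0, self.head_dim, 2) / self.head_dim))
        angles = np.outer(np.arange(self.max_position_embeddings), inv_freq)
        # Assumed layout: cos values followed by sin values along the last axis
        return np.concatenate([np.cos(angles), np.sin(angles)], axis=-1)

class ToyModel:
    def __init__(self, config):
        self.config = config

    @functools.cached_property
    def frequencies(self):
        # Computed on first access, then stored on the instance and reused
        return self.config.get_basic_frequencies()
```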
- class easydel.modules.grok_1.modeling_grok_1_flax.Grok1SparseMoeBlock(*args: Any, **kwargs: Any)
Bases: Module
Grok-1 Sparse Mixture of Experts (MoE) block.
This module implements the sparse MoE layer used in Grok-1. It routes tokens to a subset of experts based on learned gating weights.
- config: Configuration object for the model.
- dtype (jnp.dtype): Data type for computations.
- param_dtype (jnp.dtype): Data type for parameters.
- precision (jax.lax.PrecisionLike): Precision setting for JAX operations.
- rngs (nn.Rngs): Random number generators.
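Top-k routing can be sketched in NumPy as follows:

1. A linear router scores every expert for each token.
2. Each token is sent to its top_k experts.
3. The expert outputs are summed, weighted by the router's softmax probabilities.

The names are hypothetical. Some MoE variants also renormalize the selected weights so they sum to one; the description above does not say whether Grok1SparseMoeBlock does so.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def sparse_moe(x, w_router, experts, top_k=2):
    # x: (tokens, hidden); w_router: (hidden, num_experts); experts: list of callables
    router_logits = x @ w_router
    probs = softmax(router_logits)
    top_idx = np.argsort(-probs, axis=-1)[:, :top_k]        # (tokens, top_k)
    top_w = np.take_along_axis(probs, top_idx, axis=-1)     # router weight of each choice
    out = np.zeros_like(x)
    for e, expert in enumerate(experts):
        # Tokens (and the top-k slot) that selected expert e
        token_ids, slot = np.nonzero(top_idx == e)
        if token_ids.size == 0:
            continue
        out[token_ids] += top_w[token_ids, slot][:, None] * expert(x[token_ids])
    return out, router_logits
```

Looping over experts and gathering only the tokens routed to each one keeps the computation sparse, because an expert never processes tokens that were not assigned to it.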