easydel.modules.grok_1.modeling_grok_1

class easydel.modules.grok_1.modeling_grok_1.Grok1Attention(*args: Any, **kwargs: Any)

Bases: AttentionModule

Grok-1 Attention module.

This module implements the multi-head attention mechanism with rotary position embeddings used in the Grok-1 model.
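The following is a minimal sketch of multi-head causal attention with rotary position embeddings, not the EasyDeL implementation. The function names, the half-split rotation convention, and the base of 10000 are assumptions made for illustration; Grok1Attention may differ in layout, masking, and scaling.

```python
import jax
import jax.numpy as jnp


def rope_angles(head_dim: int, seq_len: int, base: float = 10000.0):
    """Return rotation angles of shape (seq_len, head_dim // 2)."""
    inv_freq = 1.0 / (base ** (jnp.arange(0, head_dim, 2) / head_dim))
    return jnp.outer(jnp.arange(seq_len), inv_freq)


def apply_rope(x, angles):
    """Rotate feature pairs of x, shaped (batch, seq, heads, head_dim)."""
    x1, x2 = jnp.split(x, 2, axis=-1)
    cos = jnp.cos(angles)[None, :, None, :]
    sin = jnp.sin(angles)[None, :, None, :]
    return jnp.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)


def causal_attention(q, k, v):
    """Scaled dot-product attention where each token sees only the past."""
    seq_len, head_dim = q.shape[1], q.shape[-1]
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(head_dim)
    mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    scores = jnp.where(mask[None, None], scores, jnp.finfo(scores.dtype).min)
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", weights, v)


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
batch, seq_len, heads, head_dim = 2, 5, 4, 8
q = jax.random.normal(kq, (batch, seq_len, heads, head_dim))
k = jax.random.normal(kk, (batch, seq_len, heads, head_dim))
v = jax.random.normal(kv, (batch, seq_len, heads, head_dim))

angles = rope_angles(head_dim, seq_len)
# Only queries and keys are rotated; values are left unchanged.
out = causal_attention(apply_rope(q, angles), apply_rope(k, angles), v)
```

Because the rotation acts on pairs of features, it preserves vector norms, and the attention scores end up depending on the relative distance between positions.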

config

Configuration object for the model.

Type: Grok1Config

layer_index

The index of the current layer.

Type: int

dtype

Data type for computations.

Type: jnp.dtype

param_dtype

Data type for parameters.

Type: jnp.dtype

precision

Precision setting for JAX operations.

Type: jax.lax.PrecisionLike

rngs

Random number generators.

Type: nn.Rngs

class easydel.modules.grok_1.modeling_grok_1.Grok1BLockSparseMLP(*args: Any, **kwargs: Any)

Bases: Module

Grok-1 Block Sparse MLP module.

This module implements the specific MLP structure used within the sparse Mixture of Experts layer in the Grok-1 model.
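As a rough illustration of what an expert MLP in a Mixture of Experts layer can look like, here is a gated feed-forward sketch. The weight names (`w_gate`, `w_up`, `w_down`), the GELU activation, and the absence of biases are assumptions for this example and are not taken from the EasyDeL source.

```python
import jax
import jax.numpy as jnp


def gated_mlp(x, w_gate, w_up, w_down):
    """Gated feed-forward: project up twice, gate, then project back down."""
    gate = jax.nn.gelu(x @ w_gate)   # (tokens, ffn_dim), activation assumed
    up = x @ w_up                    # (tokens, ffn_dim)
    return (gate * up) @ w_down      # (tokens, hidden_dim)


hidden_dim, ffn_dim = 8, 16
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
w_gate = jax.random.normal(k1, (hidden_dim, ffn_dim)) * 0.1
w_up = jax.random.normal(k2, (hidden_dim, ffn_dim)) * 0.1
w_down = jax.random.normal(k3, (ffn_dim, hidden_dim)) * 0.1

x = jax.random.normal(k4, (3, hidden_dim))
y = gated_mlp(x, w_gate, w_up, w_down)
```

In a sparse MoE layer, one such MLP exists per expert, and each token is processed only by the experts selected for it.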

config

Configuration object for the model.

Type: Grok1Config

dtype

Data type for computations.

Type: jnp.dtype

param_dtype

Data type for parameters.

Type: jnp.dtype

precision

Precision setting for JAX operations.

Type: jax.lax.PrecisionLike

rngs

Random number generators.

Type: nn.Rngs

class easydel.modules.grok_1.modeling_grok_1.Grok1DecoderLayer(*args: Any, **kwargs: Any)

Bases: Module

Grok-1 Transformer Decoder Layer.

This module represents a single decoder layer in the Grok-1 model, combining self-attention and a sparse MoE block with residual connections and layer normalization.
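The combination of sublayers, residual connections, and normalization can be sketched as below. This is an assumed pre-norm arrangement with RMS normalization and placeholder sublayer callables (`attn_fn`, `moe_fn`); the actual placement and number of normalization layers in Grok1DecoderLayer may differ.

```python
import jax.numpy as jnp


def rms_norm(x, weight, eps=1e-5):
    """Scale x so that its root-mean-square over the last axis is about 1."""
    variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x / jnp.sqrt(variance + eps) * weight


def decoder_layer(x, attn_fn, moe_fn, attn_norm_w, moe_norm_w):
    """Normalize, apply a sublayer, and add the result back to the input."""
    x = x + attn_fn(rms_norm(x, attn_norm_w))   # self-attention branch
    x = x + moe_fn(rms_norm(x, moe_norm_w))     # sparse MoE branch
    return x


hidden_dim = 6
x = jnp.arange(12.0).reshape(2, hidden_dim) + 1.0
ones = jnp.ones((hidden_dim,))

# Placeholder sublayers standing in for attention and the MoE block.
out = decoder_layer(x, lambda h: 0.5 * h, lambda h: -0.25 * h, ones, ones)
```

The residual form means that a sublayer producing zeros leaves the hidden states unchanged, which is one reason this structure trains stably at depth.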

config

Configuration object for the model.

Type: Grok1Config

dtype

Data type for computations.

Type: jnp.dtype

param_dtype

Data type for parameters.

Type: jnp.dtype

precision

Precision setting for JAX operations.

Type: jax.lax.PrecisionLike

rngs

Random number generators.

Type: nn.Rngs

class easydel.modules.grok_1.modeling_grok_1.Grok1ForCausalLM(*args: Any, **kwargs: Any)

Bases: BaseCausalLMModule[Grok1Model, Grok1Config]

Grok-1 model with a language modeling head.

This model extends the base Grok1Model by adding a linear layer on top to predict the next token in a sequence, making it suitable for causal language modeling tasks. It also includes handling for the Mixture of Experts auxiliary loss.

config

Configuration object for the model.

Type: Grok1Config

dtype

Data type for computations.

Type: jnp.dtype

param_dtype

Data type for parameters.

Type: jnp.dtype

precision

Precision setting for JAX operations.

Type: jax.lax.PrecisionLike

rngs

Random number generators.

Type: nn.Rngs

apply_lm_head(hidden_states: Float[Array, 'batch seq_len hidden_dim']) -> Union[Array, ndarray, bool, number]

Apply LM head with Grok-1’s output multiplier scale.
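A minimal sketch of a scaled LM head follows. The function name `apply_scaled_lm_head`, the kernel layout, and the example scale of 0.5 are assumptions for illustration; the real method reads its multiplier from the model configuration, and the exact point where the scale is applied is not stated here.

```python
import jax
import jax.numpy as jnp


def apply_scaled_lm_head(hidden_states, lm_head_kernel, output_multiplier_scale):
    """Project hidden states to vocabulary logits, then scale them."""
    logits = hidden_states @ lm_head_kernel   # (batch, seq_len, vocab_size)
    return logits * output_multiplier_scale


batch, seq_len, hidden_dim, vocab_size = 2, 3, 8, 11
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
hidden_states = jax.random.normal(k1, (batch, seq_len, hidden_dim))
lm_head_kernel = jax.random.normal(k2, (hidden_dim, vocab_size))

logits = apply_scaled_lm_head(hidden_states, lm_head_kernel, 0.5)
```

For a bias-free linear head, scaling the logits is equivalent to scaling the hidden states before the projection. A positive scale does not change which token has the highest logit, but it does change the sharpness of the softmax distribution.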

class easydel.modules.grok_1.modeling_grok_1.Grok1Model(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

Grok-1 model implementation.

This class implements the main Grok-1 transformer architecture: an embedding layer, a stack of Grok1DecoderLayer modules (each with a sparse MoE block), and a final RMS normalization layer.

config

Configuration object for the model.

Type: Grok1Config

dtype

Data type for computations.

Type: jnp.dtype

param_dtype

Data type for parameters.

Type: jnp.dtype

precision

Precision setting for JAX operations.

Type: jax.lax.PrecisionLike

rngs

Random number generators.

Type: nn.Rngs

property frequencies

Retrieves or computes the frequency components (e.g., for RoPE) from the configuration.

Uses self.config.get_basic_frequencies() and caches the result.
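The compute-once-then-cache pattern described above can be sketched with `functools.cached_property`. `ToyConfig` and `ToyModel` are hypothetical stand-ins, and the body of `get_basic_frequencies` here is an assumed RoPE-style table, not the library's implementation or return format.

```python
from functools import cached_property

import jax.numpy as jnp


class ToyConfig:
    """Hypothetical stand-in for Grok1Config with only the fields used here."""

    def __init__(self, head_dim, max_position_embeddings, rope_theta=10000.0):
        self.head_dim = head_dim
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta
        self.calls = 0  # counts how often the table is actually computed

    def get_basic_frequencies(self):
        self.calls += 1
        exponents = jnp.arange(0, self.head_dim, 2) / self.head_dim
        inv_freq = 1.0 / (self.rope_theta ** exponents)
        return jnp.outer(jnp.arange(self.max_position_embeddings), inv_freq)


class ToyModel:
    """Hypothetical model exposing a cached `frequencies` property."""

    def __init__(self, config):
        self.config = config

    @cached_property
    def frequencies(self):
        # Computed on first access, then stored on the instance.
        return self.config.get_basic_frequencies()


config = ToyConfig(head_dim=8, max_position_embeddings=16)
model = ToyModel(config)
first = model.frequencies
second = model.frequencies  # served from the cache, no recomputation
```

Caching matters because the table depends only on the configuration, so every layer and every forward pass can share a single copy.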

Returns: The frequency components, potentially cached.

Return type: jnp.ndarray

get_decoder()

Returns the decoder part of the model’s graph definition.

get_embedding()

Returns the embedding layer of the module.

get_encoder()

Returns the encoder part of the model’s graph definition. Grok-1 is a decoder-only model and has no encoder.

get_lm_head()

Returns the language model head of the module. The base Grok1Model has no language model head.

class easydel.modules.grok_1.modeling_grok_1.Grok1SparseMoeBlock(*args: Any, **kwargs: Any)

Bases: Module

Grok-1 Sparse Mixture of Experts (MoE) block.

This module implements the sparse MoE layer used in Grok-1. It routes tokens to a subset of experts based on learned gating weights.
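Token routing of this kind can be sketched as top-k selection over softmax gating probabilities. The names `route_tokens` and `moe_forward`, the renormalization of the selected weights, and the use of a single matrix per expert are assumptions for this example. It is also a dense reference that evaluates every expert on every token, which a real sparse implementation avoids.

```python
import jax
import jax.numpy as jnp


def route_tokens(x, w_router, top_k):
    """Pick the top_k experts per token and their normalized weights."""
    probs = jax.nn.softmax(x @ w_router, axis=-1)        # (tokens, experts)
    top_weights, top_idx = jax.lax.top_k(probs, top_k)   # (tokens, top_k) each
    top_weights = top_weights / top_weights.sum(axis=-1, keepdims=True)
    return top_weights, top_idx


def moe_forward(x, w_router, expert_kernels, top_k=2):
    """Weighted sum of the selected experts' outputs for each token."""
    top_weights, top_idx = route_tokens(x, w_router, top_k)
    # Dense reference: run all experts, then keep only the selected ones.
    all_out = jnp.einsum("td,edf->tef", x, expert_kernels)   # (tokens, experts, out)
    token_ids = jnp.arange(x.shape[0])[:, None]
    chosen = all_out[token_ids, top_idx]                      # (tokens, top_k, out)
    return jnp.sum(chosen * top_weights[:, :, None], axis=1)


tokens, hidden_dim, num_experts = 6, 8, 4
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k1, (tokens, hidden_dim))
w_router = jax.random.normal(k2, (hidden_dim, num_experts))
expert_kernels = jax.random.normal(k3, (num_experts, hidden_dim, hidden_dim)) * 0.1

weights, idx = route_tokens(x, w_router, top_k=2)
y = moe_forward(x, w_router, expert_kernels, top_k=2)
```

With `top_k` equal to the number of experts, this reduces to a dense mixture weighted by the full softmax, which makes a convenient correctness check.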

config

Configuration object for the model.

Type: Grok1Config

dtype

Data type for computations.

Type: jnp.dtype

param_dtype

Data type for parameters.

Type: jnp.dtype

precision

Precision setting for JAX operations.

Type: jax.lax.PrecisionLike

rngs

Random number generators.

Type: nn.Rngs