easydel.modules.grok_1.modeling_grok_1_flax

class easydel.modules.grok_1.modeling_grok_1_flax.Grok1Attention(*args: Any, **kwargs: Any)

Bases: AttentionModule

Grok-1 Attention module.

This module implements the multi-head attention mechanism with rotary position embeddings used in the Grok-1 model.
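The following minimal JAX sketch illustrates causal self-attention with rotary position embeddings. It is not the EasyDeL implementation. It omits the query/key/value/output projections, key/value caching, and sharding, and the helper names (`rotate_half`, `apply_rope`, `causal_attention`) and the RoPE base of 10000 are assumptions.

```python
import jax
import jax.numpy as jnp


def rotate_half(x):
    # Split the last axis into halves (x1, x2) and return (-x2, x1).
    x1, x2 = jnp.split(x, 2, axis=-1)
    return jnp.concatenate([-x2, x1], axis=-1)


def apply_rope(x, positions, base=10000.0):
    # x: [batch, seq, heads, head_dim]; positions: [seq] (float).
    head_dim = x.shape[-1]
    inv_freq = 1.0 / (base ** (jnp.arange(0, head_dim, 2) / head_dim))
    angles = positions[:, None] * inv_freq[None, :]        # [seq, head_dim / 2]
    angles = jnp.concatenate([angles, angles], axis=-1)    # [seq, head_dim]
    cos = jnp.cos(angles)[None, :, None, :]
    sin = jnp.sin(angles)[None, :, None, :]
    # Rotate each (x_i, x_{i + head_dim/2}) pair by a position-dependent angle.
    return x * cos + rotate_half(x) * sin


def causal_attention(q, k, v):
    # q, k, v: [batch, seq, heads, head_dim] (already projected).
    positions = jnp.arange(q.shape[1], dtype=jnp.float32)
    q = apply_rope(q, positions)
    k = apply_rope(k, positions)
    scale = q.shape[-1] ** -0.5
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) * scale
    seq = q.shape[1]
    # Each query position may only attend to itself and earlier positions.
    mask = jnp.tril(jnp.ones((seq, seq), dtype=bool))
    scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", weights, v)
```

Because RoPE only rotates pairs of features, it leaves vector norms unchanged. Relative position information then appears in the query-key dot products.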

config

Configuration object for the model.

Type

Grok1Config

layer_index

The index of the current layer.

Type

int

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

class easydel.modules.grok_1.modeling_grok_1_flax.Grok1BLockSparseMLP(*args: Any, **kwargs: Any)

Bases: Module

Grok-1 Block Sparse MLP module.

This module implements the feed-forward network used by each expert in the sparse Mixture of Experts layer of the Grok-1 model.
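The following hedged sketch shows a single expert, assuming a GELU-gated feed-forward form `(gelu(x @ W) * (x @ V)) @ W1`. The parameter names (`linear`, `linear_v`, `linear_1`), the scaled-normal initialization, and the fixed GELU activation are illustrative assumptions. The actual module may take its activation from the configuration.

```python
import jax
import jax.numpy as jnp


def init_expert(key, hidden_size, intermediate_size, dtype=jnp.float32):
    # Hypothetical parameter layout for one expert.
    k1, k2, k3 = jax.random.split(key, 3)
    return {
        "linear": jax.random.normal(k1, (hidden_size, intermediate_size), dtype) * hidden_size ** -0.5,
        "linear_v": jax.random.normal(k2, (hidden_size, intermediate_size), dtype) * hidden_size ** -0.5,
        "linear_1": jax.random.normal(k3, (intermediate_size, hidden_size), dtype) * intermediate_size ** -0.5,
    }


def expert_mlp(params, x):
    # Gate branch: activation applied to one up-projection.
    gate = jax.nn.gelu(x @ params["linear"])
    # Value branch: a second, linear up-projection.
    value = x @ params["linear_v"]
    # Element-wise gating, then project back to hidden_size.
    return (gate * value) @ params["linear_1"]
```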

config

Configuration object for the model.

Type

Grok1Config

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

class easydel.modules.grok_1.modeling_grok_1_flax.Grok1DecoderLayer(*args: Any, **kwargs: Any)

Bases: Module

Grok-1 Transformer Decoder Layer.

This module represents a single decoder layer in the Grok-1 model, combining self-attention and a sparse MoE block with residual connections and layer normalization.
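The following simplified pre-norm sketch shows one way a decoder layer can wire its two sub-blocks together with residual connections. The attention and MoE sub-blocks are passed in as plain callables. The RMS normalization and its placement are assumptions, and the actual Grok-1 layer may apply additional normalization around each sub-block.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, weight, eps=1e-5):
    # Scale each vector by the reciprocal of its root-mean-square.
    variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x * jax.lax.rsqrt(variance + eps) * weight


def decoder_layer(x, attn_fn, moe_fn, norm_weights):
    # Attention sub-block: normalize, attend, add the residual.
    h = x + attn_fn(rms_norm(x, norm_weights["attn"]))
    # MoE sub-block: normalize, route through experts, add the residual.
    return h + moe_fn(rms_norm(h, norm_weights["moe"]))
```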

config

Configuration object for the model.

Type

Grok1Config

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

class easydel.modules.grok_1.modeling_grok_1_flax.Grok1ForCausalLM(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

Grok-1 model with a language modeling head.

This model extends the base Grok1Model with a linear language-modeling head that projects the final hidden states to vocabulary logits for next-token prediction, making it suitable for causal language modeling tasks. It also handles the Mixture of Experts auxiliary loss.
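The following sketch, which rests on several assumptions, shows how next-token logits and a causal language-modeling loss can be computed from final hidden states. The logits at position t are scored against the token at position t + 1, and an optional MoE auxiliary loss is added with a weighting coefficient. `project_to_vocab`, `causal_lm_loss`, and `aux_coef` are hypothetical names, not part of the EasyDeL API.

```python
import jax
import jax.numpy as jnp


def project_to_vocab(hidden, lm_head_weight):
    # hidden: [batch, seq, hidden_size]; lm_head_weight: [hidden_size, vocab_size].
    return hidden @ lm_head_weight


def causal_lm_loss(logits, input_ids, aux_loss=None, aux_coef=0.0):
    # Predict token t + 1 from position t: drop the last logit and the first label.
    shift_logits = logits[:, :-1, :]
    labels = input_ids[:, 1:]
    log_probs = jax.nn.log_softmax(shift_logits, axis=-1)
    # Negative log-likelihood of each target token.
    nll = -jnp.take_along_axis(log_probs, labels[..., None], axis=-1)[..., 0]
    loss = nll.mean()
    if aux_loss is not None:
        # Hypothetical weighting of the MoE auxiliary loss.
        loss = loss + aux_coef * aux_loss
    return loss
```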

config

Configuration object for the model.

Type

Grok1Config

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

class easydel.modules.grok_1.modeling_grok_1_flax.Grok1Model(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

Grok-1 model implementation.

This class implements the main Grok-1 transformer model architecture. It consists of an embedding layer, a stack of Grok1DecoderLayer modules (each containing a sparse MoE block), and a final RMS normalization layer.
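The following minimal sketch shows the forward pass described above: an embedding lookup, a stack of decoder layers (stand-in callables here), and a final RMS normalization. The function name and the `eps` default are assumptions. Attention masks, caching, and any embedding scaling are omitted.

```python
import jax
import jax.numpy as jnp


def grok1_like_forward(input_ids, embedding, layers, final_norm_weight, eps=1e-5):
    # Token embedding lookup: [batch, seq] -> [batch, seq, hidden_size].
    h = jnp.take(embedding, input_ids, axis=0)
    # Each decoder layer maps [batch, seq, hidden] -> [batch, seq, hidden].
    for layer in layers:
        h = layer(h)
    # Final RMS normalization over the hidden dimension.
    variance = jnp.mean(jnp.square(h), axis=-1, keepdims=True)
    return h * jax.lax.rsqrt(variance + eps) * final_norm_weight
```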

config

Configuration object for the model.

Type

Grok1Config

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

property frequencies

Retrieves or computes the frequency components (e.g., for RoPE) from the configuration.

Uses self.config.get_basic_frequencies() and caches the result.

Returns

The frequency components, cached after the first computation.

Return type

jnp.ndarray
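One way to get the compute-once, cache-thereafter behaviour described above is `functools.cached_property`. In this sketch, `basic_frequencies` is a hypothetical stand-in for `config.get_basic_frequencies()` that returns a `[max_positions, head_dim]` table of concatenated cosines and sines. `TinyModel` is likewise hypothetical, and the real method's signature and output layout may differ.

```python
import functools

import jax.numpy as jnp


def basic_frequencies(head_dim, max_positions, base=10000.0):
    # Hypothetical stand-in for config.get_basic_frequencies().
    inv_freq = 1.0 / (base ** (jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim))
    angles = jnp.outer(jnp.arange(max_positions, dtype=jnp.float32), inv_freq)
    # Layout assumption: [cos | sin] along the last axis.
    return jnp.concatenate([jnp.cos(angles), jnp.sin(angles)], axis=-1)


class TinyModel:
    def __init__(self, head_dim, max_positions):
        self.head_dim = head_dim
        self.max_positions = max_positions
        self.calls = 0  # counts how often the table is actually computed

    @functools.cached_property
    def frequencies(self):
        # Computed on first access, then stored on the instance.
        self.calls += 1
        return basic_frequencies(self.head_dim, self.max_positions)
```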

class easydel.modules.grok_1.modeling_grok_1_flax.Grok1SparseMoeBlock(*args: Any, **kwargs: Any)

Bases: Module

Grok-1 Sparse Mixture of Experts (MoE) block.

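This module implements the sparse MoE layer used in Grok-1. A learned router (gate) scores every expert for each token. Each token is sent to only a subset of the experts, and the outputs of those experts are combined using the router weights.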
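The following dense sketch, written for readability, shows top-k routing. A linear router produces per-token probabilities over the experts, `jax.lax.top_k` selects `k` experts per token, and each token's output is the probability-weighted sum of its selected experts' outputs. For simplicity every expert processes every token. The value of `k` and the choice not to renormalize the selected probabilities are assumptions, not details taken from EasyDeL.

```python
import jax
import jax.numpy as jnp


def sparse_moe(x, router_w, experts, k=2):
    # x: [tokens, hidden]; router_w: [hidden, num_experts]; experts: list of callables.
    probs = jax.nn.softmax(x @ router_w, axis=-1)      # [tokens, num_experts]
    top_p, top_idx = jax.lax.top_k(probs, k)           # both [tokens, k]
    out = jnp.zeros_like(x)
    for e, expert in enumerate(experts):
        # Routing weight of expert e for each token (0 where e was not selected).
        weight = jnp.sum(jnp.where(top_idx == e, top_p, 0.0), axis=-1)
        out = out + weight[:, None] * expert(x)
    return out, probs
```

Returning the full `probs` tensor makes it available for an auxiliary load-balancing loss. A production implementation would instead dispatch only the selected tokens to each expert.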

config

Configuration object for the model.

Type

Grok1Config

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs