easydel.modules.mixtral.modeling_mixtral_flax

class easydel.modules.mixtral.modeling_mixtral_flax.MixtralAttention(*args: Any, **kwargs: Any)

Bases: AttentionModule

Mixtral Attention module.

This module implements the multi-head attention mechanism with rotary position embeddings and Grouped Query Attention (GQA) used in the Mixtral model.
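The grouped-query part can be sketched in plain `jax.numpy`. This is an illustrative sketch, not EasyDeL's implementation: it leaves out rotary embeddings, the KV cache, sharding and `attention_performer`, and the function name `gqa_attention` and the `(batch, seq, heads, head_dim)` layout are assumptions. Each key/value head is repeated `num_heads // num_key_value_heads` times so that it lines up with the query heads it serves; with the published Mixtral-8x7B settings (32 query heads, 8 key/value heads) each key/value head serves 4 query heads.

```python
import jax
import jax.numpy as jnp


def gqa_attention(q, k, v):
    """Causal grouped-query attention (illustrative sketch, not EasyDeL's code).

    q:    (batch, seq, num_heads, head_dim)
    k, v: (batch, seq, num_key_value_heads, head_dim)
    """
    num_heads, num_kv_heads, head_dim = q.shape[2], k.shape[2], q.shape[3]
    num_key_value_groups = num_heads // num_kv_heads
    # Repeat each key/value head so query heads [g*n, (g+1)*n) share head g.
    k = jnp.repeat(k, num_key_value_groups, axis=2)
    v = jnp.repeat(v, num_key_value_groups, axis=2)

    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(head_dim)
    seq = q.shape[1]
    causal = jnp.tril(jnp.ones((seq, seq), dtype=bool))  # query i sees keys 0..i
    scores = jnp.where(causal, scores, jnp.finfo(scores.dtype).min)
    probs = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", probs, v)


key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(key_q, (1, 5, 8, 16))  # 8 query heads
k = jax.random.normal(key_k, (1, 5, 2, 16))  # 2 key/value heads, so 4 query heads each
v = jax.random.normal(key_v, (1, 5, 2, 16))
out = gqa_attention(q, k, v)
print(out.shape)  # (1, 5, 8, 16)
```

Repeating the heads keeps the einsum simple; the memory saving of GQA comes from projecting and caching only `num_key_value_heads` key/value heads.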

config

Configuration object for the model.

Type

MixtralConfig

dtype

Data type for computations.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

hidden_size

Dimensionality of the hidden states.

Type

int

num_heads

Number of attention heads.

Type

int

head_dim

Dimensionality of each attention head.

Type

int

num_key_value_heads

Number of key/value heads (for GQA).

Type

int

num_key_value_groups

Number of query heads that share each key/value head (num_heads // num_key_value_heads).

Type

int

max_position_embeddings

Maximum sequence length supported.

Type

int

q_proj

Linear layer for query projection.

Type

ParallelLinear

k_proj

Linear layer for key projection.

Type

ParallelLinear

v_proj

Linear layer for value projection.

Type

ParallelLinear

o_proj

Linear layer for the output projection.

Type

ParallelLinear

attention_performer

Module to perform the core attention computation.

Type

FlexibleAttentionModule

rotary

Rotary position embedding module.

Type

RoPE
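A minimal sketch of rotary embeddings in the "rotate-half" layout used by the Hugging Face Mixtral and LLaMA code is shown below. It assumes EasyDeL's `RoPE` module follows the same convention, which may not hold in every detail; the helper names are hypothetical, and `theta=1e6` mirrors the `rope_theta` value published for Mixtral. The usage lines check the property that motivates RoPE: the score between two rotated vectors depends only on their relative offset.

```python
import jax
import jax.numpy as jnp


def rope_cos_sin(positions, head_dim, theta=1e6):
    # One inverse frequency per (i, i + head_dim/2) dimension pair.
    inv_freq = 1.0 / (theta ** (jnp.arange(0, head_dim, 2) / head_dim))
    angles = positions[:, None] * inv_freq[None, :]      # (seq, head_dim // 2)
    angles = jnp.concatenate([angles, angles], axis=-1)  # (seq, head_dim)
    return jnp.cos(angles), jnp.sin(angles)


def rotate_half(x):
    # (x1, x2) -> (-x2, x1): the 90-degree partner used by the rotation.
    x1, x2 = jnp.split(x, 2, axis=-1)
    return jnp.concatenate([-x2, x1], axis=-1)


def apply_rope(x, cos, sin):
    """x: (batch, seq, heads, head_dim); cos, sin: (seq, head_dim)."""
    return x * cos[None, :, None, :] + rotate_half(x) * sin[None, :, None, :]


q_vec, k_vec = jax.random.normal(jax.random.PRNGKey(0), (2, 8))
positions = jnp.arange(6, dtype=jnp.float32)
cos, sin = rope_cos_sin(positions, head_dim=8)


def rotated_at(vec, pos):
    # Place the same vector at every position, rotate, and read back one position.
    x = jnp.broadcast_to(vec, (1, 6, 1, 8))
    return apply_rope(x, cos, sin)[0, pos, 0]


score_a = jnp.dot(rotated_at(q_vec, 2), rotated_at(k_vec, 0))  # offset 2
score_b = jnp.dot(rotated_at(q_vec, 5), rotated_at(k_vec, 3))  # offset 2
print(bool(jnp.allclose(score_a, score_b, atol=1e-4)))  # True
```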

class easydel.modules.mixtral.modeling_mixtral_flax.MixtralBLockSparseTop2MLP(*args: Any, **kwargs: Any)

Bases: Module

Mixtral Block Sparse Top-2 MLP module.

This module implements one expert of the sparse Mixture of Experts layer in the Mixtral model. It consists of three linear projections (gate w1, up w3, down w2) and a SiLU activation function, and computes w2(SiLU(w1(x)) * w3(x)).
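The expert computation can be sketched with plain weight matrices standing in for `ParallelLinear` (no bias, no sharding). The function name and shapes are illustrative assumptions; the gating structure follows the gate, up and down projections (`w1`, `w3`, `w2`) listed for this class.

```python
import jax
import jax.numpy as jnp


def expert_mlp(x, w1, w3, w2):
    """Sketch of one expert. x: (tokens, hidden); w1, w3: (hidden, inter); w2: (inter, hidden)."""
    gate = jax.nn.silu(x @ w1)  # gate projection followed by SiLU
    up = x @ w3                 # up projection
    return (gate * up) @ w2     # elementwise gating, then down projection


hidden, intermediate = 8, 32
keys = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(keys[0], (3, hidden))
w1 = jax.random.normal(keys[1], (hidden, intermediate)) * 0.1
w3 = jax.random.normal(keys[2], (hidden, intermediate)) * 0.1
w2 = jax.random.normal(keys[3], (intermediate, hidden)) * 0.1
y = expert_mlp(x, w1, w3, w2)
print(y.shape)  # (3, 8)
```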

config

Configuration object for the model.

Type

MixtralConfig

dtype

Data type for computation.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

w1

Gate projection layer.

Type

ParallelLinear

w3

Up projection layer.

Type

ParallelLinear

w2

Down projection layer.

Type

ParallelLinear

act_fn

Activation function (SiLU).

Type

callable

class easydel.modules.mixtral.modeling_mixtral_flax.MixtralDecoderLayer(*args: Any, **kwargs: Any)

Bases: Module

Mixtral Transformer Decoder Layer.

This module represents a single decoder layer in the Mixtral model. It applies self-attention followed by a sparse MoE block; each sub-block is preceded by RMS normalization and wrapped in a residual connection.
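The pre-norm residual wiring implied by the `input_layernorm` and `post_attention_layernorm` attributes can be sketched as follows. `self_attn` and `block_sparse_moe` are placeholder callables, the MoE block is assumed to return `(output, router_logits)` as in the Hugging Face reference implementation, and `eps=1e-5` is an assumed default; this is not EasyDeL's code.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, weight, eps=1e-5):
    # Divide by the root-mean-square over the hidden axis (no mean subtraction), then scale.
    variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x * jax.lax.rsqrt(variance + eps) * weight


def decoder_layer(hidden_states, params, self_attn, block_sparse_moe):
    """Pre-norm residual layer (sketch). self_attn and block_sparse_moe are placeholders;
    the MoE callable is assumed to return (output, router_logits)."""
    residual = hidden_states
    hidden_states = residual + self_attn(rms_norm(hidden_states, params["input_layernorm"]))

    residual = hidden_states
    normed = rms_norm(hidden_states, params["post_attention_layernorm"])
    moe_out, router_logits = block_sparse_moe(normed)
    return residual + moe_out, router_logits


def zero_attn(h):
    return jnp.zeros_like(h)


def zero_moe(h):
    return jnp.zeros_like(h), None


hidden = 8
x = jax.random.normal(jax.random.PRNGKey(0), (2, 3, hidden))
params = {"input_layernorm": jnp.ones(hidden), "post_attention_layernorm": jnp.ones(hidden)}
# Sub-layers that output zeros leave only the residual path, so the layer acts as the identity.
out, _ = decoder_layer(x, params, zero_attn, zero_moe)
```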

config

Configuration object for the model.

Type

MixtralConfig

dtype

Data type for computation.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

input_layernorm

Layer normalization before the attention layer.

Type

RMSNorm

self_attn

The self-attention module.

Type

MixtralAttention

post_attention_layernorm

Layer normalization after the attention layer and before the MoE block.

Type

RMSNorm

block_sparse_moe

The sparse Mixture of Experts block.

Type

MixtralSparseMoeBlock

class easydel.modules.mixtral.modeling_mixtral_flax.MixtralForCausalLM(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

Mixtral model with a Causal Language Modeling head.

This model consists of the base Mixtral transformer (MixtralModel) followed by a linear layer (lm_head) that projects the transformer’s output hidden states to the vocabulary size, producing logits for next-token prediction. It also computes the auxiliary load-balancing loss from the router logits of the MoE layers.
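The auxiliary loss in the Hugging Face reference implementation of Mixtral is a Switch Transformer-style load-balancing loss: for each top-k slot, the fraction of tokens dispatched to each expert is multiplied by the mean router probability for that expert, summed, and scaled by the number of experts. The sketch below reproduces that formula in `jax.numpy` under the assumption that EasyDeL computes something equivalent; the function name is hypothetical. The result is typically added to the language-modeling loss after scaling by a small coefficient (`router_aux_loss_coef` in the Hugging Face configuration).

```python
import jax
import jax.numpy as jnp


def load_balancing_loss(router_logits, num_experts, top_k=2):
    """Sketch. router_logits: (num_tokens, num_experts), concatenated across MoE layers."""
    routing_weights = jax.nn.softmax(router_logits, axis=-1)
    _, selected_experts = jax.lax.top_k(routing_weights, top_k)   # (tokens, top_k)
    expert_mask = jax.nn.one_hot(selected_experts, num_experts)   # (tokens, top_k, experts)
    # Fraction of tokens dispatched to each expert, per top-k slot.
    tokens_per_expert = jnp.mean(expert_mask, axis=0)             # (top_k, experts)
    # Mean router probability assigned to each expert.
    router_prob_per_expert = jnp.mean(routing_weights, axis=0)    # (experts,)
    return num_experts * jnp.sum(tokens_per_expert * router_prob_per_expert[None, :])


balanced = load_balancing_loss(jnp.zeros((16, 4)), num_experts=4)
skewed = load_balancing_loss(jnp.tile(jnp.array([10.0, 0.0, 0.0, 0.0]), (16, 1)), num_experts=4)
print(float(balanced))  # 2.0: uniform router probabilities give exactly top_k
```

A router that piles almost all probability on one expert (`skewed`) scores close to `num_experts`, so minimizing this term pushes the router toward spreading tokens across experts.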

config

Configuration object for the model.

Type

MixtralConfig

dtype

Data type for computation.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

model

The core Mixtral transformer model.

Type

MixtralModel

lm_head

The linear layer for projecting hidden states to vocabulary logits.

Type

ParallelLinear

num_experts

Total number of experts.

Type

int

num_experts_per_tok

Number of experts each token is routed to (the router's top-k; 2 for the released Mixtral models).

Type

int
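Training the causal LM head amounts to scoring each position's logits against the following token. A minimal `jax.numpy` sketch of that shifted cross-entropy is below; it ignores padding masks and label smoothing, and `causal_lm_loss` is a hypothetical helper rather than an EasyDeL API.

```python
import jax
import jax.numpy as jnp


def causal_lm_loss(logits, input_ids):
    """Mean next-token cross-entropy (sketch; no padding mask).

    logits: (batch, seq, vocab); input_ids: (batch, seq).
    Position t predicts token t + 1, so drop the last logit and the first label.
    """
    shift_logits = logits[:, :-1, :]
    shift_labels = input_ids[:, 1:]
    log_probs = jax.nn.log_softmax(shift_logits, axis=-1)
    # Log-probability assigned to each true next token.
    token_log_probs = jnp.take_along_axis(log_probs, shift_labels[..., None], axis=-1)[..., 0]
    return -jnp.mean(token_log_probs)


vocab = 10
input_ids = jnp.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 0]])
uniform = causal_lm_loss(jnp.zeros((2, 5, vocab)), input_ids)
print(float(uniform))  # about 2.3026, i.e. log(10) for uniform logits over 10 tokens
```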

class easydel.modules.mixtral.modeling_mixtral_flax.MixtralForSequenceClassification(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

Mixtral model with a Sequence Classification head.

This model consists of the base Mixtral transformer (MixtralModel) followed by a linear layer (score) that projects the transformer’s output hidden states to the number of classes. Because Mixtral is a causal decoder, the classification logits are taken at the last non-padding token of each sequence, the only position that has attended to the whole input. It also handles the calculation of the auxiliary loss from the MoE layers.
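Decoder-only classifiers in the Hugging Face style pool the logits at the last non-padding token, the only position whose causal attention covers the whole input. The sketch below assumes right padding and a 0/1 `attention_mask`; the reference implementation locates that position from `pad_token_id` instead, and `pool_last_token` is a hypothetical helper.

```python
import jax.numpy as jnp


def pool_last_token(logits, attention_mask):
    """Sketch. logits: (batch, seq, num_labels); attention_mask: (batch, seq), right-padded 0/1."""
    last_index = jnp.sum(attention_mask, axis=-1) - 1  # index of the last real token
    batch_index = jnp.arange(logits.shape[0])
    return logits[batch_index, last_index]             # (batch, num_labels)


logits = jnp.arange(2 * 4 * 3, dtype=jnp.float32).reshape(2, 4, 3)
attention_mask = jnp.array([[1, 1, 1, 0], [1, 1, 0, 0]])
pooled = pool_last_token(logits, attention_mask)
# Row 0 is read at position 2 and row 1 at position 1: [6, 7, 8] and [15, 16, 17].
```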

config

Configuration object for the model.

Type

MixtralConfig

dtype

Data type for computation.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

model

The core Mixtral transformer model.

Type

MixtralModel

score

The linear layer for classification.

Type

ParallelLinear

num_experts

Total number of experts.

Type

int

num_experts_per_tok

Number of experts each token is routed to (the router's top-k; 2 for the released Mixtral models).

Type

int

class easydel.modules.mixtral.modeling_mixtral_flax.MixtralModel(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

The base Mixtral model transformer.

This class represents the core transformer architecture of the Mixtral model: a token embedding layer, a stack of MixtralDecoderLayer layers (each with a sparse MoE block), and a final layer normalization.
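The data flow through the base model can be sketched as an embedding lookup, a loop over decoder layers that also collects their router logits, and a final RMSNorm. The layers are placeholder callables assumed to return `(hidden_states, router_logits)`; caching, masks and sharding are omitted, and the helper names are hypothetical rather than EasyDeL APIs.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, weight, eps=1e-5):
    variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x * jax.lax.rsqrt(variance + eps) * weight


def mixtral_model_forward(input_ids, embedding, layers, norm_weight):
    """Sketch. input_ids: (batch, seq); embedding: (vocab, hidden);
    layers: callables returning (hidden_states, router_logits)."""
    hidden_states = embedding[input_ids]  # token embedding lookup
    all_router_logits = []
    for layer in layers:
        hidden_states, router_logits = layer(hidden_states)
        all_router_logits.append(router_logits)
    return rms_norm(hidden_states, norm_weight), all_router_logits


def identity_layer(h):
    # Placeholder standing in for a MixtralDecoderLayer.
    return h, None


vocab, hidden = 11, 8
embedding = jax.random.normal(jax.random.PRNGKey(0), (vocab, hidden))
input_ids = jnp.array([[3, 1, 4, 1]])
final, router_logits = mixtral_model_forward(input_ids, embedding, [identity_layer] * 2, jnp.ones(hidden))
print(final.shape)  # (1, 4, 8)
```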

config

Configuration object for the model.

Type

MixtralConfig

dtype

Data type for computation.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

embed_tokens

Embedding layer for input tokens.

Type

nn.Embed

layers

List of decoder layers.

Type

tp.List[MixtralDecoderLayer]

norm

Final layer normalization.

Type

RMSNorm

gradient_checkpointing

Gradient checkpointing configuration.

Type

EasyDeLGradientCheckPointers
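The members of `EasyDeLGradientCheckPointers` are not listed here. Assuming they select a rematerialization policy for JAX's own mechanism, the sketch below shows that plain-JAX mechanism: `jax.checkpoint` with a policy from `jax.checkpoint_policies`. It does not reproduce EasyDeL's mapping from enum members to policies, and `block` is a hypothetical stand-in for a decoder layer.

```python
import jax
import jax.numpy as jnp


def block(x, w):
    # Hypothetical stand-in for one decoder layer: two matmuls with a nonlinearity.
    return jnp.tanh(x @ w) @ w.T


# Save no intermediates; recompute them during the backward pass instead.
checkpointed_block = jax.checkpoint(block, policy=jax.checkpoint_policies.nothing_saveable)


def loss_ckpt(w, x):
    return jnp.sum(checkpointed_block(x, w) ** 2)


def loss_plain(w, x):
    return jnp.sum(block(x, w) ** 2)


w = jax.random.normal(jax.random.PRNGKey(0), (4, 4))
x = jax.random.normal(jax.random.PRNGKey(1), (2, 4))
g_ckpt = jax.grad(loss_ckpt)(w, x)
g_plain = jax.grad(loss_plain)(w, x)
print(bool(jnp.allclose(g_ckpt, g_plain)))  # True: remat trades compute for memory, not accuracy
```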

class easydel.modules.mixtral.modeling_mixtral_flax.MixtralSparseMoeBlock(*args: Any, **kwargs: Any)

Bases: Module

Mixtral Sparse Mixture of Experts (MoE) block.

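This module implements the sparse MoE layer used in Mixtral. A learned linear router (gate) scores every expert for each token; the token is sent to the top num_experts_per_tok experts (two in Mixtral), and their outputs are summed, weighted by the router probabilities renormalized over the selected experts.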
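A dense reference version of top-k routing is sketched below, modeled on the Hugging Face Mixtral implementation: softmax over the router logits, keep the top `top_k` probabilities, renormalize them to sum to one, and mix the expert outputs. For clarity every expert runs on every token; a real sparse implementation (presumably EasyDeL's as well) dispatches each token only to its selected experts. Names and shapes are assumptions.

```python
import jax
import jax.numpy as jnp


def sparse_moe(x, gate_w, experts, top_k=2):
    """Dense reference sketch. x: (tokens, hidden); gate_w: (hidden, num_experts);
    experts: list of callables. Non-selected experts receive weight 0."""
    num_experts = gate_w.shape[1]
    router_logits = x @ gate_w                                      # (tokens, experts)
    probs = jax.nn.softmax(router_logits, axis=-1)
    top_weights, top_idx = jax.lax.top_k(probs, top_k)              # (tokens, top_k)
    top_weights = top_weights / jnp.sum(top_weights, axis=-1, keepdims=True)  # renormalize
    # Scatter the renormalized weights back into a (tokens, experts) matrix.
    combine = jnp.sum(jax.nn.one_hot(top_idx, num_experts) * top_weights[..., None], axis=1)
    expert_outputs = jnp.stack([expert(x) for expert in experts], axis=1)  # (tokens, experts, hidden)
    out = jnp.einsum("te,teh->th", combine, expert_outputs)
    return out, router_logits


hidden, num_experts = 8, 4
keys = jax.random.split(jax.random.PRNGKey(0), 2)
x = jax.random.normal(keys[0], (5, hidden))
gate_w = jax.random.normal(keys[1], (hidden, num_experts))
# With identity experts, the renormalized top-2 weights sum to 1, so the output equals the input.
identity_experts = [lambda h: h for _ in range(num_experts)]
out, router_logits = sparse_moe(x, gate_w, identity_experts)
```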

config

Configuration object for the model.

Type

MixtralConfig

dtype

Data type for computation.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

Precision setting for JAX operations.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

gate

Linear layer for computing router logits.

Type

ParallelLinear

experts

List of expert MLP modules.

Type

tp.List[MixtralBLockSparseTop2MLP]