easydel.modules.mixtral.modeling_mixtral_flax#
- class easydel.modules.mixtral.modeling_mixtral_flax.MixtralAttention(*args: Any, **kwargs: Any)[source]#
Bases: AttentionModule
Mixtral Attention module.
This module implements the multi-head attention mechanism with rotary position embeddings and Grouped Query Attention (GQA) used in the Mixtral model.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computations.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
Dimensionality of the hidden states.
- Type
int
- num_heads#
Number of attention heads.
- Type
int
- head_dim#
Dimensionality of each attention head.
- Type
int
- num_key_value_heads#
Number of key/value heads (for GQA).
- Type
int
- num_key_value_groups#
Number of query heads that share each key/value head (num_heads // num_key_value_heads).
- Type
int
- max_position_embeddings#
Maximum sequence length supported.
- Type
int
- q_proj#
Linear layer for query projection.
- Type
- k_proj#
Linear layer for key projection.
- Type
- v_proj#
Linear layer for value projection.
- Type
- o_proj#
Linear layer for the output projection.
- Type
- attention_performer#
Module to perform the core attention computation.
- rotary#
Rotary position embedding module.
- Type
RoPE
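The relationship between num_heads, num_key_value_heads and num_key_value_groups can be sketched in plain JAX. This is a minimal illustration, not EasyDeL's implementation: the function name `grouped_query_attention` is hypothetical, there is no batch axis, and rotary embeddings and the projection layers are omitted.

```python
import jax
import jax.numpy as jnp


def grouped_query_attention(q, k, v):
    """Causal attention where several query heads share one key/value head.

    q:    (seq, num_heads, head_dim)
    k, v: (seq, num_key_value_heads, head_dim)
    """
    seq_len, num_heads, head_dim = q.shape
    num_key_value_heads = k.shape[1]
    # Each key/value head serves this many query heads.
    num_key_value_groups = num_heads // num_key_value_heads

    # Repeat every key/value head so the head axis matches the queries.
    k = jnp.repeat(k, num_key_value_groups, axis=1)
    v = jnp.repeat(v, num_key_value_groups, axis=1)

    # Scaled dot-product scores, shape (num_heads, seq, seq).
    scores = jnp.einsum("qhd,khd->hqk", q, k) / jnp.sqrt(head_dim)

    # Causal mask: position i may only attend to positions <= i.
    causal = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    scores = jnp.where(causal[None, :, :], scores, -jnp.inf)

    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("hqk,khd->qhd", weights, v)


key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(key_q, (5, 8, 16))  # 8 query heads
k = jax.random.normal(key_k, (5, 2, 16))  # 2 key/value heads
v = jax.random.normal(key_v, (5, 2, 16))
out = grouped_query_attention(q, k, v)    # same shape as q
```

With 8 query heads and 2 key/value heads, each key/value head is shared by 4 query heads, so the key and value tensors are a quarter of the size they would be under standard multi-head attention.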
- class easydel.modules.mixtral.modeling_mixtral_flax.MixtralBLockSparseTop2MLP(*args: Any, **kwargs: Any)[source]#
Bases: Module
Mixtral Block Sparse Top-2 MLP module.
This module implements the specific MLP structure used within the sparse Mixture of Experts layer in the Mixtral model. It consists of three linear projections (gate, up, down) and a SiLU activation function.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- w1#
Gate projection layer.
- Type
- w3#
Up projection layer.
- Type
- w2#
Down projection layer.
- Type
- act_fn#
Activation function (SiLU).
- Type
callable
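A gated MLP of this shape is usually computed as down(SiLU(gate(x)) * up(x)). The sketch below assumes that ordering, uses plain weight matrices in place of the w1, w3 and w2 layers, and omits biases. The function name `expert_mlp` and the sizes are hypothetical.

```python
import jax
import jax.numpy as jnp


def expert_mlp(x, w1, w3, w2):
    """Gated MLP: w1 is the gate, w3 the up projection, w2 the down projection."""
    gate = jax.nn.silu(x @ w1)  # (tokens, intermediate)
    up = x @ w3                 # (tokens, intermediate)
    return (gate * up) @ w2     # (tokens, hidden)


hidden_size, intermediate_size = 8, 32
k1, k2, k3, kx = jax.random.split(jax.random.PRNGKey(0), 4)
w1 = jax.random.normal(k1, (hidden_size, intermediate_size))
w3 = jax.random.normal(k3, (hidden_size, intermediate_size))
w2 = jax.random.normal(k2, (intermediate_size, hidden_size))
x = jax.random.normal(kx, (4, hidden_size))
y = expert_mlp(x, w1, w3, w2)
```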
- class easydel.modules.mixtral.modeling_mixtral_flax.MixtralDecoderLayer(*args: Any, **kwargs: Any)[source]#
Bases: Module
Mixtral Transformer Decoder Layer.
This module represents a single decoder layer in the Mixtral model, combining self-attention and a sparse MoE block with residual connections and layer normalization.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- self_attn#
The self-attention module.
- Type
- post_attention_layernorm#
Layer normalization after the attention layer and before the MoE block.
- Type
- block_sparse_moe#
The sparse Mixture of Experts block.
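The residual structure of such a layer can be sketched as two pre-normalized sub-blocks. This is an illustration under assumptions: the normalization is taken to be RMS normalization, `self_attn` and `block_sparse_moe` are stand-in callables, and the helper names and norm-weight arguments are hypothetical.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, weight, eps=1e-6):
    """Scale x so its root mean square over the last axis is about 1."""
    variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x * jax.lax.rsqrt(variance + eps) * weight


def decoder_layer(x, self_attn, block_sparse_moe, input_norm_w, post_attn_norm_w):
    # Attention sub-block: normalize, attend, add the residual.
    residual = x
    x = residual + self_attn(rms_norm(x, input_norm_w))
    # MoE sub-block: normalize, run the experts, add the residual.
    residual = x
    return residual + block_sparse_moe(rms_norm(x, post_attn_norm_w))


x = jax.random.normal(jax.random.PRNGKey(0), (4, 8))
ones = jnp.ones((8,))
# With sub-layers that return zeros, only the residual path remains.
passthrough = decoder_layer(x, jnp.zeros_like, jnp.zeros_like, ones, ones)
```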
- class easydel.modules.mixtral.modeling_mixtral_flax.MixtralForCausalLM(*args: Any, **kwargs: Any)[source]#
Bases: EasyDeLBaseModule
Mixtral model with a Causal Language Modeling head.
This model consists of the base Mixtral transformer (MixtralModel) followed by a linear layer (lm_head) that projects the transformer’s output hidden states to the vocabulary size, producing logits for next token prediction. It also handles the calculation of the auxiliary loss from the MoE layers.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- model#
The core Mixtral transformer model.
- Type
- lm_head#
The linear layer for projecting hidden states to vocabulary logits.
- Type
- num_experts#
Total number of experts.
- Type
int
- num_experts_per_tok#
Number of experts each token is routed to.
- Type
int
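An auxiliary loss for MoE layers is commonly a load-balancing term computed from the router logits. The sketch below follows the Switch-Transformer-style formulation. Whether EasyDeL uses exactly this formula is an assumption, and the function name `load_balancing_loss` is hypothetical.

```python
import jax
import jax.numpy as jnp


def load_balancing_loss(router_logits, num_experts, num_experts_per_tok):
    """router_logits: (tokens, num_experts). Returns a scalar penalty."""
    routing_weights = jax.nn.softmax(router_logits, axis=-1)
    _, selected = jax.lax.top_k(routing_weights, num_experts_per_tok)
    # One-hot of the chosen experts: (tokens, top_k, num_experts).
    expert_mask = jax.nn.one_hot(selected, num_experts)
    # Fraction of tokens sent to each expert, per top-k slot.
    tokens_per_expert = jnp.mean(expert_mask, axis=0)
    # Mean router probability assigned to each expert.
    router_prob_per_expert = jnp.mean(routing_weights, axis=0)
    return num_experts * jnp.sum(tokens_per_expert * router_prob_per_expert[None, :])


uniform_logits = jnp.zeros((16, 8))
aux_loss = load_balancing_loss(uniform_logits, num_experts=8, num_experts_per_tok=2)
```

The term grows when the experts that receive the most tokens are also the ones the router assigns the highest probability, which penalizes collapsing onto a few experts. For a perfectly uniform router it equals num_experts_per_tok.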
- class easydel.modules.mixtral.modeling_mixtral_flax.MixtralForSequenceClassification(*args: Any, **kwargs: Any)[source]#
Bases: EasyDeLBaseModule
Mixtral model with a Sequence Classification head.
This model consists of the base Mixtral transformer (MixtralModel) followed by a linear layer (score) that projects the transformer’s output hidden states (typically the hidden state of the last non-padding token) to the number of classes for classification. It also handles the calculation of the auxiliary loss from the MoE layers.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- model#
The core Mixtral transformer model.
- Type
- score#
The linear layer for classification.
- Type
- num_experts#
Total number of experts.
- Type
int
- num_experts_per_tok#
Number of experts each token is routed to.
- Type
int
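The classification head reduces to picking one hidden state per sequence and projecting it. In this minimal sketch the pooled position is passed in explicitly, the score layer is a plain matrix, and the name `classify` and all shapes are hypothetical.

```python
import jax
import jax.numpy as jnp


def classify(hidden_states, score_kernel, pool_positions):
    """hidden_states: (batch, seq, hidden); pool_positions: (batch,) token indices."""
    batch_index = jnp.arange(hidden_states.shape[0])
    pooled = hidden_states[batch_index, pool_positions]  # (batch, hidden)
    return pooled @ score_kernel                         # (batch, num_labels)


k_h, k_s = jax.random.split(jax.random.PRNGKey(0))
hidden_states = jax.random.normal(k_h, (2, 4, 8))
score_kernel = jax.random.normal(k_s, (8, 3))
pool_positions = jnp.array([3, 1])  # one pooled token per sequence
logits = classify(hidden_states, score_kernel, pool_positions)
```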
- class easydel.modules.mixtral.modeling_mixtral_flax.MixtralModel(*args: Any, **kwargs: Any)[source]#
Bases: EasyDeLBaseModule
The base Mixtral model transformer.
This class represents the core transformer architecture of the Mixtral model, consisting of an embedding layer, multiple MixtralDecoderLayer layers (with sparse MoE), and a final layer normalization.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- embed_tokens#
Embedding layer for input tokens.
- Type
nn.Embed
- layers#
List of decoder layers.
- Type
tp.List[MixtralDecoderLayer]
- gradient_checkpointing#
Gradient checkpointing configuration.
- class easydel.modules.mixtral.modeling_mixtral_flax.MixtralSparseMoeBlock(*args: Any, **kwargs: Any)[source]#
Bases: Module
Mixtral Sparse Mixture of Experts (MoE) block.
This module implements the sparse MoE layer used in Mixtral. A learned gating (router) layer scores every expert for each token; the token is sent to its top-2 experts, and their outputs are combined using the routing weights.
- config#
Configuration object for the model.
- Type
- dtype#
Data type for computation.
- Type
jnp.dtype
- param_dtype#
Data type for parameters.
- Type
jnp.dtype
- precision#
Precision setting for JAX operations.
- Type
jax.lax.PrecisionLike
- rngs#
Random number generators.
- Type
nn.Rngs
- gate#
Linear layer for computing router logits.
- Type
- experts#
List of expert MLP modules.
- Type
tp.List[MixtralBLockSparseTop2MLP]
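Top-2 routing can be sketched as follows. For clarity the sketch evaluates every expert densely and then masks the results, which a real implementation would avoid. The experts are stand-in callables, the gate is a plain matrix, and the name `sparse_moe` is hypothetical. Renormalizing the two selected weights so they sum to 1 is an assumption about the routing scheme.

```python
import jax
import jax.numpy as jnp


def sparse_moe(x, gate_kernel, experts, top_k=2):
    """x: (tokens, hidden); gate_kernel: (hidden, num_experts)."""
    router_logits = x @ gate_kernel                           # (tokens, num_experts)
    routing_weights = jax.nn.softmax(router_logits, axis=-1)
    top_weights, top_experts = jax.lax.top_k(routing_weights, top_k)
    # Renormalize so the selected experts' weights sum to 1 per token.
    top_weights = top_weights / jnp.sum(top_weights, axis=-1, keepdims=True)

    # Per-token weight for every expert; zero for unselected experts.
    one_hot = jax.nn.one_hot(top_experts, len(experts))       # (tokens, top_k, num_experts)
    combine = jnp.sum(one_hot * top_weights[..., None], axis=1)

    # Dense evaluation of all experts: (tokens, num_experts, hidden).
    expert_outputs = jnp.stack([expert(x) for expert in experts], axis=1)
    return jnp.einsum("te,teh->th", combine, expert_outputs), router_logits


k_x, k_g = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(k_x, (6, 8))
gate_kernel = jax.random.normal(k_g, (8, 4))
# Four identical stand-in experts, so the mixture must reproduce the input.
experts = [lambda h: h for _ in range(4)]
y, router_logits = sparse_moe(x, gate_kernel, experts)
```

The router logits are returned alongside the output because the auxiliary loss is computed from them.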