easydel.modules.mixtral.modeling_mixtral

class easydel.modules.mixtral.modeling_mixtral.MixtralAttention(*args: Any, **kwargs: Any)

Bases: UnifiedAttention

Mixtral Attention module with sliding window support.

Inherits from UnifiedAttention with these Mixtral-specific customizations:

- Sliding window attention support
- Custom RoPE configuration
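The sliding window restricts each query to a fixed number of recent keys on top of the usual causal mask. The sketch below builds such a mask with NumPy to show the idea. It is not EasyDeL's implementation: the function name is hypothetical, and the off-by-one convention (a key is visible when `i - j < window`) is an assumption that may differ from what UnifiedAttention uses.

```python
import numpy as np


def sliding_window_causal_mask(seq_len: int, window: int) -> np.ndarray:
    """Return a (seq_len, seq_len) boolean mask; True means query i may attend key j.

    Assumed convention: key j is visible when it is not in the future (j <= i)
    and lies fewer than `window` positions behind the query (i - j < window).
    """
    q = np.arange(seq_len)[:, None]  # query positions, shape (seq_len, 1)
    k = np.arange(seq_len)[None, :]  # key positions, shape (1, seq_len)
    causal = k <= q                  # block attention to future tokens
    in_window = (q - k) < window     # block keys that fell out of the window
    return causal & in_window


# Each row allows at most `window` keys, ending at the query position.
mask = sliding_window_causal_mask(seq_len=5, window=2)
```

Because the mask is a plain boolean array, it can be combined with a padding mask using `&` before it is applied to the attention scores.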

class easydel.modules.mixtral.modeling_mixtral.MixtralDecoderLayer(*args: Any, **kwargs: Any)

Bases: Module

Mixtral Transformer decoder layer that combines self-attention with the sparse mixture-of-experts (MoE) feed-forward block.
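The data flow through one layer can be sketched as two pre-norm residual sub-blocks. This assumes the layout used by the reference Mixtral/Mistral decoder layer: RMSNorm, attention, and residual, followed by RMSNorm, MoE, and residual. `attn_fn` and `moe_fn` are hypothetical stand-ins for the attention module and MixtralSparseMoeBlock. The real layer also handles caches, masks, and router logits, and those are omitted here.

```python
import numpy as np


def rms_norm(x: np.ndarray, weight: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    # Divide each vector by its root-mean-square, then apply a learned per-feature weight.
    variance = np.mean(x * x, axis=-1, keepdims=True)
    return x / np.sqrt(variance + eps) * weight


def decoder_layer(hidden, attn_fn, moe_fn, input_norm_w, post_attn_norm_w):
    # Sub-block 1: pre-norm self-attention plus a residual connection.
    hidden = hidden + attn_fn(rms_norm(hidden, input_norm_w))
    # Sub-block 2: pre-norm sparse MoE feed-forward plus a residual connection.
    hidden = hidden + moe_fn(rms_norm(hidden, post_attn_norm_w))
    return hidden
```

Applying normalization before each sub-block keeps the residual path unnormalized. This is the usual pre-norm design choice for deep decoder stacks.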

class easydel.modules.mixtral.modeling_mixtral.MixtralForCausalLM(*args: Any, **kwargs: Any)

Bases: BaseCausalLMModule[MixtralModel, MixtralConfig]

Mixtral model with a Causal Language Modeling head.
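A causal LM head maps each hidden state to vocabulary logits, and training uses those logits to predict the next token. The sketch below shows the standard shifted next-token cross-entropy for one sequence. It is a generic illustration, not EasyDeL's loss function: the name `causal_lm_loss` is hypothetical, and details such as ignore-index handling or label smoothing are not modeled.

```python
import numpy as np


def causal_lm_loss(logits: np.ndarray, input_ids: np.ndarray) -> float:
    """Mean next-token cross-entropy for one sequence.

    logits: (seq_len, vocab_size) scores from the LM head.
    input_ids: (seq_len,) integer token ids.
    """
    shift_logits = logits[:-1]   # position t predicts token t + 1
    shift_labels = input_ids[1:]
    # Numerically stable log-softmax over the vocabulary axis.
    m = shift_logits.max(axis=-1, keepdims=True)
    log_probs = shift_logits - m - np.log(np.exp(shift_logits - m).sum(axis=-1, keepdims=True))
    nll = -log_probs[np.arange(shift_labels.shape[0]), shift_labels]
    return float(nll.mean())
```

With all-zero logits, every token is equally likely, so the loss equals `log(vocab_size)`.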

class easydel.modules.mixtral.modeling_mixtral.MixtralForSequenceClassification(*args: Any, **kwargs: Any)

Bases: BaseSequenceClassificationModule[MixtralModel, MixtralConfig]

Mixtral model with a Sequence Classification head.
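Decoder-only classifiers need a single vector per sequence. A common convention, used by Hugging Face's `*ForSequenceClassification` models, takes the logits at the last non-padding token. The sketch below assumes that convention and right padding. Whether BaseSequenceClassificationModule pools the same way is an assumption, and `pool_last_token` is a hypothetical helper.

```python
import numpy as np


def pool_last_token(logits: np.ndarray, input_ids: np.ndarray, pad_token_id: int) -> np.ndarray:
    """Select each sequence's logits at its last non-padding position.

    Assumes right padding: logits (batch, seq_len, num_labels), input_ids (batch, seq_len).
    """
    lengths = (input_ids != pad_token_id).sum(axis=1)  # real tokens per row
    last_idx = lengths - 1                              # index of the final real token
    return logits[np.arange(logits.shape[0]), last_idx]
```

The causal mask lets the last real token see the whole prefix, so its hidden state is the natural summary of the sequence.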

class easydel.modules.mixtral.modeling_mixtral.MixtralMoEMlp(*args: Any, **kwargs: Any)

Bases: Module

Mixtral MoE MLP built on ParallelMoELinear layers.
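Each Mixtral expert is a SwiGLU feed-forward network. The sketch below shows the computation for a single expert in NumPy. The weight names `w1`, `w3`, and `w2` mirror the kernel names in `reform_param`. The assumption here is that `w1` is the gate projection, `w3` the up projection, and `w2` the down projection. ParallelMoELinear evaluates all experts in batched form rather than one at a time as shown here.

```python
import numpy as np


def silu(x: np.ndarray) -> np.ndarray:
    # SiLU / swish activation: x * sigmoid(x).
    return x / (1.0 + np.exp(-x))


def expert_mlp(x, w1, w3, w2):
    """One expert's SwiGLU feed-forward: (silu(x @ w1) * (x @ w3)) @ w2.

    x: (tokens, hidden); w1, w3: (hidden, intermediate); w2: (intermediate, hidden).
    """
    gate = silu(x @ w1)   # gating branch
    up = x @ w3           # linear branch
    return (gate * up) @ w2
```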

reform_param: ClassVar

Maps fused parameter names (regex keys) to the split kernels they correspond to. The spliter and inverse_spliter entries are lambdas whose bodies are not shown in this reference.

- down_proj$: split into w2.kernel
- gate_up_proj$: split into w1.kernel and w3.kernel

Each key also carries an inverse_spliter that rebuilds the fused parameter from its splits.
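The mapping above pairs a fused `gate_up_proj` kernel with separate `w1` and `w3` kernels. The sketch below illustrates what such a splitter and its inverse could look like. The actual lambdas are not shown in this reference, so several details are assumptions: the fused kernel is taken to be `[w1 | w3]` concatenated along the last axis, each half is `intermediate` wide, and the helper names `split_gate_up` and `merge_gate_up` are hypothetical.

```python
import numpy as np


def split_gate_up(gate_up_kernel: np.ndarray) -> dict:
    # Assumed layout: last axis is [w1 (gate) | w3 (up)], each half `intermediate` wide.
    w1, w3 = np.split(gate_up_kernel, 2, axis=-1)
    return {"w1.kernel": w1, "w3.kernel": w3}


def merge_gate_up(w1: np.ndarray, w3: np.ndarray) -> np.ndarray:
    # Inverse of split_gate_up: concatenate the halves back into one fused kernel.
    return np.concatenate([w1, w3], axis=-1)
```

Keeping the split and merge functions as exact inverses means a checkpoint can be converted in either direction without changing any weight values.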

class easydel.modules.mixtral.modeling_mixtral.MixtralModel(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

The base Mixtral model transformer.

This class represents the core transformer architecture of the Mixtral model, consisting of an embedding layer, multiple MixtralDecoderLayer layers (with sparse MoE), and a final layer normalization.
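At the top level, the forward pass is an embedding lookup, a loop over the decoder layers, and a final RMSNorm. The NumPy sketch below captures only that structure. The names are hypothetical, each layer is modeled as a plain callable, and attention masks, caches, and gradient checkpointing are left out.

```python
import numpy as np


def rms_norm(x, weight, eps=1e-6):
    # Root-mean-square normalization over the hidden dimension.
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps) * weight


def model_forward(input_ids, embedding, layers, norm_weight):
    """input_ids: (batch, seq_len) ints; embedding: (vocab_size, hidden)."""
    hidden = embedding[input_ids]   # token embedding lookup -> (batch, seq_len, hidden)
    for layer in layers:            # decoder layers, applied in order
        hidden = layer(hidden)
    return rms_norm(hidden, norm_weight)  # final normalization
```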

config: MixtralConfig

Configuration object for the model.

dtype: jnp.dtype

Data type for computation.

param_dtype: jnp.dtype

Data type for parameters.

precision: jax.lax.PrecisionLike

Precision setting for JAX operations.

rngs: nn.Rngs

Random number generators.

embed_tokens: nn.Embed

Embedding layer for input tokens.

layers: tp.List[MixtralDecoderLayer]

List of decoder layers.

norm: RMSNorm

Final layer normalization.

gradient_checkpointing: EasyDeLGradientCheckPointers

Gradient checkpointing configuration.

get_decoder()

Returns the decoder part of the model’s graph definition.

get_embedding()

Returns the embedding layer of the module.

get_encoder()

Returns the encoder part of the model’s graph definition. Decoder-only models don’t have an encoder.

get_lm_head()

Returns the language model head of the module. Base models don’t have a language model head.

class easydel.modules.mixtral.modeling_mixtral.MixtralSparseMoeBlock(*args: Any, **kwargs: Any)

Bases: BaseMoeModule

Mixtral Sparse MoE block using BaseMoeModule.
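Mixtral routes each token to its top two experts and combines their outputs with softmax weights computed over the selected router logits. The sketch below shows this with an explicit per-token loop for clarity. That choice is for illustration only: BaseMoeModule presumably dispatches tokens in batched form, and the function name `sparse_moe` and its arguments are hypothetical. Taking the softmax over the top-k logits is mathematically equivalent to taking the softmax over all experts and then renormalizing the selected probabilities.

```python
import numpy as np


def sparse_moe(x, router_w, experts, top_k=2):
    """Top-k routed mixture of experts.

    x: (tokens, hidden); router_w: (hidden, num_experts);
    experts: list of callables mapping a (hidden,) vector to a (hidden,) vector.
    """
    logits = x @ router_w                                   # (tokens, num_experts)
    top_idx = np.argsort(-logits, axis=-1)[:, :top_k]       # k highest-scoring experts per token
    top_logits = np.take_along_axis(logits, top_idx, axis=-1)
    # Softmax over the selected experts only, so each token's weights sum to 1.
    weights = np.exp(top_logits - top_logits.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    out = np.zeros_like(x)
    for t in range(x.shape[0]):
        for slot in range(top_k):
            out[t] += weights[t, slot] * experts[top_idx[t, slot]](x[t])
    return out, top_idx
```

If every expert is the identity, the output equals the input because the routing weights for each token sum to 1. This makes a quick sanity check for any routing implementation.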