easydel.modules.mixtral.modeling_mixtral
- class easydel.modules.mixtral.modeling_mixtral.MixtralAttention(*args: Any, **kwargs: Any)
Bases: UnifiedAttention
Mixtral attention module with sliding window support.
Inherits from UnifiedAttention with Mixtral-specific customizations:
  - Sliding window attention support
  - Custom RoPE configuration
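To illustrate what sliding window attention restricts, here is a minimal NumPy sketch of a causal sliding window mask applied before the softmax. It is not the UnifiedAttention implementation. The helper names are hypothetical, and the window convention (query `i` sees key `j` only when `0 <= i - j < window`) is an assumption; implementations differ on whether the boundary is inclusive.

```python
import numpy as np


def sliding_window_causal_mask(seq_len: int, window: int) -> np.ndarray:
    """Return a (seq_len, seq_len) boolean mask; True means query i may attend to key j.

    Assumed convention: causal, and each query sees at most `window` keys,
    i.e. 0 <= i - j < window.
    """
    q = np.arange(seq_len)[:, None]  # query positions as a column
    k = np.arange(seq_len)[None, :]  # key positions as a row
    distance = q - k
    return (distance >= 0) & (distance < window)


def masked_softmax(scores: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Softmax over keys after setting disallowed positions to -inf."""
    scores = np.where(mask, scores, -np.inf)
    # Every row keeps its diagonal entry, so the row max is finite
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)


mask = sliding_window_causal_mask(seq_len=5, window=3)
weights = masked_softmax(np.zeros((5, 5)), mask)
```

With equal scores, the last query spreads its attention uniformly over keys 2, 3, and 4 only. Earlier keys fall outside the window even though causality alone would allow them.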
- class easydel.modules.mixtral.modeling_mixtral.MixtralDecoderLayer(*args: Any, **kwargs: Any)
Bases: Module
Mixtral Transformer decoder layer that combines self-attention with a sparse Mixture-of-Experts (MoE) block.
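The residual wiring of a decoder layer like this can be sketched in NumPy. This assumes the pre-norm RMSNorm layout of the reference Mixtral architecture. It uses hypothetical callables `attn_fn` and `moe_fn` in place of MixtralAttention and MixtralSparseMoeBlock, and it omits the router outputs that the real layer may also return.

```python
import numpy as np


def rms_norm(x: np.ndarray, weight: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    # Divide by the root-mean-square over the feature axis, then apply a learned scale
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return x / rms * weight


def decoder_layer(x, attn_fn, moe_fn, attn_norm_w, moe_norm_w):
    """Pre-norm residual layer (assumed structure): self-attention, then sparse MoE."""
    h = x + attn_fn(rms_norm(x, attn_norm_w))   # residual around self-attention
    return h + moe_fn(rms_norm(h, moe_norm_w))  # residual around the MoE block


def zero_branch(y: np.ndarray) -> np.ndarray:
    # Stand-in sublayer that contributes nothing, so only the residual path remains
    return np.zeros_like(y)


x = np.random.default_rng(0).normal(size=(2, 4, 8))  # (batch, seq, d_model)
ones = np.ones(8)
out = decoder_layer(x, zero_branch, zero_branch, ones, ones)
```

If both sublayers contribute zero, the output equals the input. This is one reason the residual form keeps deep stacks trainable.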
- class easydel.modules.mixtral.modeling_mixtral.MixtralForCausalLM(*args: Any, **kwargs: Any)
Bases: BaseCausalLMModule[MixtralModel, MixtralConfig]
Mixtral model with a causal language modeling head.
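The following sketches what a causal language modeling head does: it projects hidden states to vocabulary logits and scores each position against the next token. This is a generic NumPy illustration. `lm_head_logits` and `causal_lm_loss` are hypothetical helpers, and whether BaseCausalLMModule ties weights or computes its loss this way is not stated here.

```python
import numpy as np


def lm_head_logits(hidden: np.ndarray, lm_head: np.ndarray) -> np.ndarray:
    # (batch, seq, d_model) @ (d_model, vocab) -> (batch, seq, vocab)
    return hidden @ lm_head


def causal_lm_loss(logits: np.ndarray, input_ids: np.ndarray) -> float:
    """Mean next-token cross-entropy: position t is scored against token t + 1."""
    shift_logits = logits[:, :-1, :]  # no target exists for the last position
    shift_labels = input_ids[:, 1:]   # the first token is never predicted
    # Numerically stable log-softmax over the vocabulary axis
    m = shift_logits.max(axis=-1, keepdims=True)
    log_probs = shift_logits - m - np.log(np.exp(shift_logits - m).sum(axis=-1, keepdims=True))
    picked = np.take_along_axis(log_probs, shift_labels[..., None], axis=-1)[..., 0]
    return float(-picked.mean())


rng = np.random.default_rng(0)
hidden = rng.normal(size=(2, 4, 16))
lm_head = rng.normal(size=(16, 10))
input_ids = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])
loss = causal_lm_loss(lm_head_logits(hidden, lm_head), input_ids)
```

Uniform logits over a vocabulary of size V give a loss of exactly log(V). This is a handy sanity check for an untrained head.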
- class easydel.modules.mixtral.modeling_mixtral.MixtralForSequenceClassification(*args: Any, **kwargs: Any)
Bases: BaseSequenceClassificationModule[MixtralModel, MixtralConfig]
Mixtral model with a sequence classification head.
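Decoder-only models usually classify a sequence from the hidden state of its last real token, because only that position has seen the whole input. The sketch below shows this pooling in NumPy. It assumes right padding and an attention mask of ones and zeros. `last_token_pool` and `classify` are hypothetical names, and whether BaseSequenceClassificationModule pools exactly this way is an assumption.

```python
import numpy as np


def last_token_pool(hidden: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Pick each sequence's last non-padding hidden state (assumes right padding)."""
    last_idx = attention_mask.sum(axis=1) - 1            # (batch,)
    return hidden[np.arange(hidden.shape[0]), last_idx]  # (batch, d_model)


def classify(hidden, attention_mask, score_w):
    # Linear scoring head on the pooled state: (batch, d_model) @ (d_model, num_labels)
    return last_token_pool(hidden, attention_mask) @ score_w


hidden = np.arange(12, dtype=float).reshape(2, 3, 2)  # (batch=2, seq=3, d_model=2)
attention_mask = np.array([[1, 1, 1], [1, 1, 0]])     # second sequence ends in padding
pooled = last_token_pool(hidden, attention_mask)
```

For the padded second sequence, pooling picks position 1 rather than the padding slot at position 2.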
- class easydel.modules.mixtral.modeling_mixtral.MixtralMoEMlp(*args: Any, **kwargs: Any)
Bases: Module
Mixtral MoE MLP (the expert feed-forward networks), built from ParallelMoELinear layers.
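Each Mixtral expert is a SwiGLU feed-forward network. The gate projection (w1) passes through SiLU and is multiplied by the up projection (w3); the down projection (w2) then maps the result back to the model dimension. The NumPy sketch below computes all experts at once. It assumes expert weights are stacked along a leading axis, which is a guess at how ParallelMoELinear lays them out rather than a documented fact.

```python
import numpy as np


def silu(x: np.ndarray) -> np.ndarray:
    return x / (1.0 + np.exp(-x))


def expert_mlp(x, w1, w3, w2):
    # One expert: (silu(x @ w1) * (x @ w3)) @ w2, i.e. a SwiGLU feed-forward
    return (silu(x @ w1) * (x @ w3)) @ w2


def all_experts_mlp(x, w1, w3, w2):
    # Experts stacked on a leading axis (assumed layout):
    # x: (tokens, d); w1, w3: (E, d, f); w2: (E, f, d) -> (E, tokens, d)
    gate = np.einsum("td,edf->etf", x, w1)
    up = np.einsum("td,edf->etf", x, w3)
    return np.einsum("etf,efd->etd", silu(gate) * up, w2)


rng = np.random.default_rng(0)
E, d, f = 4, 8, 16
x = rng.normal(size=(5, d))
w1, w3 = rng.normal(size=(E, d, f)), rng.normal(size=(E, d, f))
w2 = rng.normal(size=(E, f, d))
stacked = all_experts_mlp(x, w1, w3, w2)
```

Slice `e` of the stacked result matches running expert `e` alone. Stacking lets one batched contraction replace a Python loop over experts.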
- reform_param: ClassVar
Mapping between fused parameters and their split counterparts. Keys are name patterns anchored with `$`. The individual `spliter` and `inverse_spliter` functions are lambdas and are not shown.
  - `gate_up_proj$` splits into `w1.kernel` and `w3.kernel`, each with its own `spliter`; an `inverse_spliter` fuses them back.
  - `down_proj$` corresponds to `w2.kernel` through a single `spliter`; an `inverse_spliter` reverses it.
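The split and inverse pair for `gate_up_proj` can be illustrated with a NumPy sketch. The actual lambdas are not shown above, so everything here is hypothetical. It assumes the fused kernel stores w1 (gate) in the first half of its last axis and w3 (up) in the second half. A real conversion may also need transposes, for example when importing PyTorch `Linear` weights stored as (out, in).

```python
import numpy as np


def split_gate_up(gate_up_kernel: np.ndarray):
    # Hypothetical layout: (..., d_model, 2 * d_ff) with w1 (gate) first, w3 (up) second
    w1, w3 = np.split(gate_up_kernel, 2, axis=-1)
    return w1, w3


def merge_gate_up(w1: np.ndarray, w3: np.ndarray) -> np.ndarray:
    # Inverse of split_gate_up: re-fuse along the same axis
    return np.concatenate([w1, w3], axis=-1)


fused = np.arange(2 * 3 * 8).reshape(2, 3, 8)  # (num_experts, d_model, 2 * d_ff)
w1, w3 = split_gate_up(fused)
```

Whatever the real layout is, the key property is that merging the split halves reproduces the fused tensor exactly, so checkpoints round-trip without loss.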
- class easydel.modules.mixtral.modeling_mixtral.MixtralModel(*args: Any, **kwargs: Any)
Bases: EasyDeLBaseModule
The base Mixtral transformer model.
This class implements the core Mixtral architecture: an embedding layer, a stack of MixtralDecoderLayer layers (each containing a sparse MoE block), and a final layer normalization.
- config
  Configuration object for the model.
  Type: MixtralConfig
- dtype
  Data type for computation.
  Type: jnp.dtype
- param_dtype
  Data type for parameters.
  Type: jnp.dtype
- precision
  Precision setting for JAX operations.
  Type: jax.lax.PrecisionLike
- rngs
  Random number generators.
  Type: nn.Rngs
- embed_tokens
  Embedding layer for input tokens.
  Type: nn.Embed
- layers
  List of decoder layers.
  Type: tp.List[MixtralDecoderLayer]
- gradient_checkpointing
  Gradient checkpointing configuration.
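The data flow through MixtralModel (embedding lookup, decoder stack, final norm) can be summarized with a short NumPy sketch. The decoder layers are hypothetical callables, and the final normalization is assumed to be RMSNorm, as in the reference Mixtral architecture. This does not call any EasyDeL API.

```python
import numpy as np


def rms_norm(x: np.ndarray, weight: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    # Scale each feature vector to unit root-mean-square, then apply a learned weight
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return x / rms * weight


def model_forward(input_ids, embedding, layers, final_norm_w):
    """Embedding lookup -> stack of decoder layers -> final norm (RMSNorm assumed)."""
    hidden = embedding[input_ids]  # (batch, seq) -> (batch, seq, d_model)
    for layer in layers:           # each layer: callable mapping hidden -> hidden
        hidden = layer(hidden)
    return rms_norm(hidden, final_norm_w)


rng = np.random.default_rng(0)
embedding = rng.normal(size=(10, 4))  # (vocab, d_model)
input_ids = np.array([[0, 1, 2], [3, 4, 5]])
out = model_forward(input_ids, embedding, [], np.ones(4))
doubled = model_forward(input_ids, embedding, [lambda h: 2 * h], np.ones(4))
```

Because RMSNorm removes overall scale, a layer that only doubles its input leaves the final output essentially unchanged.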
- class easydel.modules.mixtral.modeling_mixtral.MixtralSparseMoeBlock(*args: Any, **kwargs: Any)
Bases: BaseMoeModule
Mixtral sparse MoE block built on BaseMoeModule.
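In a Mixtral-style sparse MoE block, a router scores every expert for each token, only the top-k experts run (k = 2 in the reference Mixtral configuration), and their outputs are mixed with the renormalized router probabilities. The NumPy sketch below shows this routing. `top_k_route` and `sparse_moe` are hypothetical helpers, and the experts are plain callables. Whether BaseMoeModule uses exactly this softmax, top-k, and renormalize order is an assumption. In JAX, the selection step would typically use `jax.lax.top_k`.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def top_k_route(router_logits: np.ndarray, k: int = 2):
    """router_logits: (tokens, num_experts) -> expert indices and mixing weights, each (tokens, k)."""
    probs = softmax(router_logits)
    top_idx = np.argsort(-probs, axis=-1)[:, :k]  # k most probable experts, in descending order
    top_w = np.take_along_axis(probs, top_idx, axis=-1)
    top_w = top_w / top_w.sum(axis=-1, keepdims=True)  # renormalize over the selected experts
    return top_idx, top_w


def sparse_moe(x, router_w, experts, k=2):
    # x: (tokens, d); router_w: (d, E); experts: list of E callables mapping (n, d) -> (n, d)
    top_idx, top_w = top_k_route(x @ router_w, k)
    out = np.zeros_like(x)
    for e, expert in enumerate(experts):
        token_ids, slot = np.nonzero(top_idx == e)  # tokens routed to expert e (at most once each)
        if token_ids.size == 0:
            continue
        out[token_ids] += top_w[token_ids, slot][:, None] * expert(x[token_ids])
    return out


rng = np.random.default_rng(0)
x = rng.normal(size=(6, 3))
router_w = rng.normal(size=(3, 4))
idx, w = top_k_route(x @ router_w, k=2)
identity_experts = [lambda h: h for _ in range(4)]
y = sparse_moe(x, router_w, identity_experts)
```

With identity experts, the output equals the input because the selected weights sum to one. Each token still pays for only k expert evaluations, no matter how many experts exist.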