easydel.modules.mixtral.__init__

class easydel.modules.mixtral.__init__.MixtralConfig(vocab_size=32000, hidden_size=4096, intermediate_size=14336, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=8, hidden_act='silu', max_position_embeddings=131072, initializer_range=0.02, rms_norm_eps=1e-05, use_cache=True, pad_token_id=None, bos_token_id=1, eos_token_id=2, tie_word_embeddings=False, rope_theta=1000000.0, sliding_window=4096, attention_dropout=0.0, num_experts_per_tok=2, num_local_experts=8, output_router_logits=False, router_aux_loss_coef=0.001, gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, number_rep_kv: int = 1, bits: Optional[int] = None, rope_scaling: Dict[str, Union[str, float]] = None, attention_bias: bool = False, initialization_of_moe: bool = False, router_jitter_noise=0.0, **kwargs)

Bases: EasyDeLBaseConfig

Configuration objects inherit from [EasyDeLBaseConfig] and can be used to control the model outputs. Read the documentation from [EasyDeLBaseConfig] for more information.

Parameters
  • vocab_size (int, optional, defaults to 32000) – Vocabulary size of the Mixtral model. Defines the number of different tokens that can be represented by the input_ids passed to the forward method.

  • hidden_size (int, optional, defaults to 4096) – Dimensionality of the hidden representations.

  • intermediate_size (int, optional, defaults to 14336) – Dimensionality of the “intermediate” (i.e., feed-forward) layer in the Transformer decoder.

  • num_hidden_layers (int, optional, defaults to 32) – Number of hidden layers in the Transformer decoder.

  • num_attention_heads (int, optional, defaults to 32) – Number of attention heads for each attention layer in the Transformer decoder.

  • num_key_value_heads (int, optional, defaults to 8) – Number of key and value heads for each attention layer in the Transformer decoder.

  • hidden_act (str or function, optional, defaults to “silu”) – The non-linear activation function (function or string) to use in the decoder. If string, “silu”, “gelu”, “relu”, “swish” and “gelu_new” are supported.

  • max_position_embeddings (int, optional, defaults to 4096 * 32 = 131072) – The maximum sequence length that this model might ever be used with. This is typically set to something large just in case.

  • initializer_range (float, optional, defaults to 0.02) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • rms_norm_eps (float, optional, defaults to 1e-5) – The epsilon used by the rms normalization layers.

  • use_cache (bool, optional, defaults to True) – Whether the model should return the last key/value attention states (not used by all models). Only relevant if config.is_decoder=True.

  • pad_token_id (int, optional) – The index of the padding token in the vocabulary.

  • bos_token_id (int, optional, defaults to 1) – The id of the beginning-of-sequence token.

  • eos_token_id (int, optional, defaults to 2) – The id of the end-of-sequence token.

  • tie_word_embeddings (bool, optional, defaults to False) – Whether to tie the weights of the input embeddings and the output embeddings.

  • rope_theta (float, optional, defaults to 1e6) – The theta value to use for rotary position embeddings.

  • sliding_window (int, optional, defaults to 4096) – The sliding window size.

  • attention_dropout (float, optional, defaults to 0.0) – The dropout ratio for the attention probabilities.

  • num_experts_per_tok (int, optional, defaults to 2) – The number of experts per token.

  • num_local_experts (int, optional, defaults to 8) – The number of local experts.

  • output_router_logits (bool, optional, defaults to False) – Whether to output router logits.

  • router_aux_loss_coef (float, optional, defaults to 0.001) – The router auxiliary loss coefficient.

  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional, defaults to EasyDeLGradientCheckPointers.NONE) – The gradient checkpointing strategy.

  • use_scan_mlp (bool, optional, defaults to False) – Whether to use the scan implementation for the MLP.

  • scan_mlp_chunk_size (int, optional, defaults to 1024) – The chunk size to use when scanning the MLP.

  • number_rep_kv (int, optional, defaults to 1) – Number of repetitions for the key and value vectors.

  • bits (int, optional) – The number of bits to quantize the model to.

  • rope_scaling (tp.Dict[str, tp.Union[str, float]], optional) – The configuration for rope scaling.

  • attention_bias (bool, optional, defaults to False) – Whether to use bias in the attention layer.

  • initialization_of_moe (bool, optional, defaults to False) – Whether to initialize the MoE layers.

  • router_jitter_noise (float, optional, defaults to 0.0) – The jitter noise for the router.
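
As a rough illustration of how some of the defaults above relate to each other, the sketch below recomputes a few derived quantities in plain Python. SketchMixtralDefaults is a hypothetical stand-in that copies a handful of defaults from the signature; it is not the EasyDeL class. The derived names head_dim and kv_groups are assumptions based on the usual grouped-query attention convention, not attributes documented here.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SketchMixtralDefaults:
    """Hypothetical stand-in holding a few defaults copied from the signature."""

    hidden_size: int = 4096
    num_attention_heads: int = 32
    num_key_value_heads: int = 8
    max_position_embeddings: int = 4096 * 32
    num_local_experts: int = 8
    num_experts_per_tok: int = 2

    @property
    def head_dim(self) -> int:
        # Assumed convention: the hidden size is split evenly across query heads.
        return self.hidden_size // self.num_attention_heads

    @property
    def kv_groups(self) -> int:
        # Assumed grouped-query attention: query heads sharing one key/value head.
        return self.num_attention_heads // self.num_key_value_heads


cfg = SketchMixtralDefaults()
print(cfg.head_dim, cfg.kv_groups, cfg.max_position_embeddings)
```

Under these assumptions each key/value head serves several query heads, and only num_experts_per_tok of the num_local_experts experts are evaluated for any one token.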

attach_custom_arguments(gradient_checkpointing: EasyDeLGradientCheckPointers = EasyDeLGradientCheckPointers.NONE, use_scan_mlp: bool = False, scan_mlp_chunk_size: int = 1024, number_rep_kv: int = 1, bits: Optional[int] = None, attention_dropout: float = 0.0, rope_scaling: Dict[str, Union[str, float]] = None, attention_bias: bool = False, initialization_of_moe: bool = False, **kwargs)

Attaches the following EasyDeL-specific arguments to the configuration:

Parameters
  • gradient_checkpointing (EasyDeLGradientCheckPointers, optional) – Gradient checkpointing strategy. Defaults to EasyDeLGradientCheckPointers.NONE.

  • use_scan_mlp (bool, optional) – Whether to use scan for MLP layers. Defaults to False.

  • scan_mlp_chunk_size (int, optional) – Chunk size for scan MLP. Defaults to 1024.

  • number_rep_kv (int, optional) – Number of repetitions for key/value heads. Defaults to 1.

  • bits (tp.Optional[int], optional) – Quantization bits. Defaults to None.

  • attention_dropout (float, optional) – Dropout probability for attention. Defaults to 0.0.

  • rope_scaling (tp.Dict[str, tp.Union[str, float]], optional) – RoPE scaling configuration. Defaults to None.

  • attention_bias (bool, optional) – Whether to use bias in attention layers. Defaults to False.

  • initialization_of_moe (bool, optional) – Whether MoE layers are being initialized. Defaults to False.

  • **kwargs – Additional keyword arguments (ignored).

Return type

A tuple of the following

get_partition_rules(*args, **kwargs)

Get the partition rules for the model. This method defines how the model’s parameters are partitioned across devices for distributed training and inference.

Parameters
  • *args – Additional positional arguments (unused).

  • **kwargs – Additional keyword arguments (unused).

Returns

A tuple of partition rules, where each rule is a tuple containing a regex pattern for parameter names and the corresponding PartitionSpec.

Return type

tp.Tuple[tp.Tuple[str, PartitionSpec]]
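
To make the (regex pattern, PartitionSpec) rule shape more concrete, here is a minimal sketch of how such rules could be applied to parameter names. Everything in it is hypothetical: the patterns, the axis names, the parameter paths, and the first-match-wins lookup are assumptions for illustration, and plain tuples stand in for PartitionSpec so the snippet runs without JAX. It does not reproduce the rules this method actually returns.

```python
import re

# Hypothetical rules: plain tuples of axis names stand in for PartitionSpec,
# and the patterns and axis names are invented for illustration only.
SKETCH_RULES = (
    (r"embed_tokens/embedding", ("tp", "fsdp")),
    (r"self_attn/(q_proj|k_proj|v_proj)/kernel", ("fsdp", "tp")),
    (r".*", (None,)),  # catch-all: replicate anything not matched above
)


def match_rule(param_name, rules):
    """Return the spec of the first rule whose regex matches the parameter name."""
    for pattern, spec in rules:
        if re.search(pattern, param_name):
            return spec
    raise ValueError(f"no partition rule matches {param_name!r}")


print(match_rule("model/layers/0/self_attn/q_proj/kernel", SKETCH_RULES))
print(match_rule("model/norm/kernel", SKETCH_RULES))
```

Ordering the rules from most specific to most general, with a catch-all last, keeps every parameter covered under the first-match assumption.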

static get_weight_decay_exclusions()

Returns a tuple of parameter names for which weight decay should be excluded.

Returns

An empty tuple, indicating no specific weight decay exclusions for this model.

Return type

tuple
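
The sketch below shows one way an exclusion tuple like this could be turned into a per-parameter weight decay mask. How EasyDeL actually consumes the tuple is not stated here, so the regex matching, the weight_decay_mask helper, and the parameter names are all assumptions made for illustration.

```python
import re


def weight_decay_mask(param_names, exclusions):
    """Map each parameter name to True (decay) unless an exclusion pattern matches."""
    return {
        name: not any(re.search(pattern, name) for pattern in exclusions)
        for name in param_names
    }


# Hypothetical parameter names.
names = ["model/embed_tokens/embedding", "model/norm/kernel", "lm_head/kernel"]

# With an empty exclusion tuple, as documented above, every parameter is decayed.
print(weight_decay_mask(names, ()))

# A hypothetical non-empty tuple would switch decay off for matching names.
print(weight_decay_mask(names, ("norm",)))
```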

property granted_freq_max_position_embedding: int

Returns the maximum position embedding size specifically for frequency-based position embeddings.

If freq_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_position_embeddings.

Returns

The granted maximum position embedding size for frequency encoding.

Return type

int

property granted_mask_max_position_embedding: int

Returns the maximum position embedding size specifically for mask-based position embeddings.

If mask_max_position_embeddings is set, it returns that value. Otherwise, it falls back to max_position_embeddings.

Returns

The granted maximum position embedding size for mask encoding.

Return type

int
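
The fallback rule shared by the two granted_* properties can be sketched as follows. SketchPositionLimits is a hypothetical stand-in, not the EasyDeL configuration class, and it assumes that "set" means "not None".

```python
from typing import Optional


class SketchPositionLimits:
    """Hypothetical stand-in mirroring only the fallback rule described above."""

    def __init__(
        self,
        max_position_embeddings: int,
        freq_max_position_embeddings: Optional[int] = None,
        mask_max_position_embeddings: Optional[int] = None,
    ) -> None:
        self.max_position_embeddings = max_position_embeddings
        self.freq_max_position_embeddings = freq_max_position_embeddings
        self.mask_max_position_embeddings = mask_max_position_embeddings

    @property
    def granted_freq_max_position_embedding(self) -> int:
        # Assumption: "set" means "not None".
        if self.freq_max_position_embeddings is not None:
            return self.freq_max_position_embeddings
        return self.max_position_embeddings

    @property
    def granted_mask_max_position_embedding(self) -> int:
        if self.mask_max_position_embeddings is not None:
            return self.mask_max_position_embeddings
        return self.max_position_embeddings


limits = SketchPositionLimits(131072, freq_max_position_embeddings=32768)
print(limits.granted_freq_max_position_embedding)  # the override is used
print(limits.granted_mask_max_position_embedding)  # falls back to the maximum
```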

model_type: str = 'mixtral'

static rng_keys()

Returns the names of the random number generator keys used by the model.

Returns

A tuple containing “params”, “dropout”, and “jitter” as the RNG keys.

Return type

tuple

class easydel.modules.mixtral.__init__.MixtralForCausalLM(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

Mixtral model with a Causal Language Modeling head.

This model consists of the base Mixtral transformer (MixtralModel) followed by a linear layer (lm_head) that projects the transformer’s output hidden states to the vocabulary size, producing logits for next token prediction. It also handles the calculation of the auxiliary loss from the MoE layers.
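
For intuition about the auxiliary loss, the sketch below implements one common Switch-Transformer-style load-balancing formulation in plain Python. The text above does not say which formula EasyDeL uses, so the formula, its normalization, and the helper names softmax and load_balancing_loss are assumptions; only the router_aux_loss_coef default of 0.001 is taken from MixtralConfig.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def load_balancing_loss(router_logits, top_k):
    """Assumed formulation: num_experts * sum_e(slot_fraction_e * mean_prob_e)."""
    num_tokens = len(router_logits)
    num_experts = len(router_logits[0])
    counts = [0] * num_experts
    prob_sums = [0.0] * num_experts
    for logits in router_logits:
        probs = softmax(logits)
        # Indices of the top_k most probable experts for this token.
        chosen = sorted(range(num_experts), key=lambda e: probs[e], reverse=True)[:top_k]
        for e in chosen:
            counts[e] += 1
        for e in range(num_experts):
            prob_sums[e] += probs[e]
    loss = 0.0
    for e in range(num_experts):
        fraction = counts[e] / (num_tokens * top_k)  # share of routing slots
        mean_prob = prob_sums[e] / num_tokens        # average router probability
        loss += fraction * mean_prob
    return num_experts * loss


uniform = [[0.0, 0.0, 0.0, 0.0] for _ in range(6)]    # router has no preference
skewed = [[10.0, 10.0, 0.0, 0.0] for _ in range(6)]   # router favors two experts

router_aux_loss_coef = 0.001  # default from MixtralConfig
print(load_balancing_loss(uniform, top_k=2))
print(router_aux_loss_coef * load_balancing_loss(skewed, top_k=2))
```

In this formulation the loss grows when both the routing decisions and the router probabilities concentrate on the same few experts, and the coefficient keeps the term small relative to the language modeling loss.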

Attributes
  • config (MixtralConfig) – Configuration object for the model.

  • dtype (jnp.dtype) – Data type for computation.

  • param_dtype (jnp.dtype) – Data type for parameters.

  • precision (jax.lax.PrecisionLike) – Precision setting for JAX operations.

  • rngs (nn.Rngs) – Random number generators.

  • model (MixtralModel) – The core Mixtral transformer model.

  • lm_head (ParallelLinear) – The linear layer for projecting hidden states to vocabulary logits.

  • num_experts (int) – Total number of experts.

  • num_experts_per_tok (int) – Number of experts to route per token.

class easydel.modules.mixtral.__init__.MixtralForSequenceClassification(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

Mixtral model with a Sequence Classification head.

This model consists of the base Mixtral transformer (MixtralModel) followed by a linear layer (score) that projects the transformer’s output hidden states (typically the hidden state of the last non-padding token, as is usual for decoder-only models) to the number of classes for classification. It also handles the calculation of the auxiliary loss from the MoE layers.

Attributes
  • config (MixtralConfig) – Configuration object for the model.

  • dtype (jnp.dtype) – Data type for computation.

  • param_dtype (jnp.dtype) – Data type for parameters.

  • precision (jax.lax.PrecisionLike) – Precision setting for JAX operations.

  • rngs (nn.Rngs) – Random number generators.

  • model (MixtralModel) – The core Mixtral transformer model.

  • score (ParallelLinear) – The linear layer for classification.

  • num_experts (int) – Total number of experts.

  • num_experts_per_tok (int) – Number of experts to route per token.

class easydel.modules.mixtral.__init__.MixtralModel(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

The base Mixtral model transformer.

This class represents the core transformer architecture of the Mixtral model, consisting of an embedding layer, multiple MixtralDecoderLayer layers (with sparse MoE), and a final layer normalization.
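
The sparse MoE step inside each decoder layer can be pictured with the following plain-Python sketch of top-k routing for a single token. It shows the general technique only: the helper names route_token and sparse_moe, the renormalization of the selected weights, and the toy scalar "experts" are assumptions, not the EasyDeL implementation.

```python
import math


def route_token(router_logits, top_k=2):
    """Return [(expert_index, weight), ...] for one token, weights summing to 1."""
    m = max(router_logits)
    exps = [math.exp(x - m) for x in router_logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Keep the top_k most probable experts and renormalize their weights.
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]
    norm = sum(probs[i] for i in top)
    return [(i, probs[i] / norm) for i in top]


def sparse_moe(x, router_logits, experts, top_k=2):
    """Weighted sum of the outputs of the selected experts only."""
    return sum(w * experts[i](x) for i, w in route_token(router_logits, top_k))


# Eight toy "experts", each just scaling its input by a different factor.
experts = [lambda x, s=s: s * x for s in range(1, 9)]
logits = [0.0, 2.0, 0.0, 2.0, 0.0, 0.0, 0.0, 0.0]

print(route_token(logits))
print(sparse_moe(1.0, logits, experts))
```

Only the selected experts are evaluated for a token, which is why the compute per token depends on num_experts_per_tok rather than on num_local_experts.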

Attributes
  • config (MixtralConfig) – Configuration object for the model.

  • dtype (jnp.dtype) – Data type for computation.

  • param_dtype (jnp.dtype) – Data type for parameters.

  • precision (jax.lax.PrecisionLike) – Precision setting for JAX operations.

  • rngs (nn.Rngs) – Random number generators.

  • embed_tokens (nn.Embed) – Embedding layer for input tokens.

  • layers (tp.List[MixtralDecoderLayer]) – List of decoder layers.

  • norm (RMSNorm) – Final layer normalization.

  • gradient_checkpointing (EasyDeLGradientCheckPointers) – Gradient checkpointing configuration.