Source code for easydel.modules.mixtral.modeling_mixtral_flax

# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import functools
import typing as tp

import chex
import jax
from eformer import common_types
from eformer.escale import apply_logical_sharding
from flax import nnx as nn
from jax import numpy as jnp

from easydel.infra.base_module import EasyDeLBaseModule
from easydel.infra.factory import TaskType, register_module
from easydel.infra.loss_utils import auxiliary_load_balancing_loss_func
from easydel.infra.modeling_outputs import (
	AttentionLayerOutput,
	DecoderLayerOutput,
	MoeCausalLMOutput,
	MoeModelOutput,
	SequenceClassifierOutput,
)
from easydel.infra.utils import (
	ACT2FN,
	auto_remat,
	block_wise_ffn,
	get_dot_general_by_bits,
)
from easydel.layers.attention import AttentionModule, FlexibleAttentionModule
from easydel.layers.caching import (
	PagedAttentionCache,
	PagedAttentionCacheView,
	PagedAttentionMetadata,
	TransformerCache,
	TransformerCacheView,
	TransformerMetadata,
)
from easydel.layers.linear import ParallelLinear
from easydel.layers.norms import RMSNorm

from .mixtral_configuration import MixtralConfig as MixtralConfig


class MixtralAttention(AttentionModule):
	"""Mixtral self-attention with rotary position embeddings and grouped-query attention.

	Queries are projected to ``config.num_attention_heads`` heads and keys/values to
	``config.num_key_value_heads`` heads; every head has size
	``hidden_size // num_attention_heads``. The score/softmax/value computation is
	delegated to a ``FlexibleAttentionModule``.

	Attributes:
		config (MixtralConfig): Configuration object for the model.
		dtype (jnp.dtype): Data type for computations.
		param_dtype (jnp.dtype): Data type for parameters.
		precision (jax.lax.PrecisionLike): Precision setting for JAX operations.
		rngs (nn.Rngs): Random number generators.
		hidden_size (int): Dimensionality of the hidden states.
		num_heads (int): Number of query heads.
		head_dim (int): Dimensionality of each attention head.
		num_key_value_heads (int): Number of key/value heads.
		num_key_value_groups (int): Number of query heads per key/value head.
		max_position_embeddings (int): Maximum sequence length from the config.
		q_proj (ParallelLinear): Query projection.
		k_proj (ParallelLinear): Key projection.
		v_proj (ParallelLinear): Value projection.
		o_proj (ParallelLinear): Output projection.
		attention_performer (FlexibleAttentionModule): Core attention computation.
		rotary: Rotary embedding obtained from ``config.get_basic_rope``.
	"""

	def __init__(
		self,
		config: MixtralConfig,
		dtype: jnp.dtype = jnp.float32,
		param_dtype: jnp.dtype = jnp.float32,
		precision: jax.lax.PrecisionLike = None,
		*,
		rngs: nn.Rngs,
	):
		"""Create the q/k/v/o projections, the attention kernel and the rotary embedding.

		Args:
			config (MixtralConfig): The configuration object for the Mixtral model.
			dtype (jnp.dtype): Data type for computation. Defaults to jnp.float32.
			param_dtype (jnp.dtype): Data type for parameters. Defaults to jnp.float32.
			precision (jax.lax.PrecisionLike): Precision for JAX operations. Defaults to None.
			rngs (nn.Rngs): Random number generators.
		"""
		super().__init__(config=config)
		self.dtype = dtype
		self.param_dtype = param_dtype
		self.precision = precision
		self.rngs = rngs

		self.hidden_size = config.hidden_size
		self.num_heads = config.num_attention_heads
		self.head_dim = self.hidden_size // self.num_heads
		self.num_key_value_heads = config.num_key_value_heads
		self.num_key_value_groups = self.num_heads // self.num_key_value_heads
		self.max_position_embeddings = config.max_position_embeddings

		# Shared settings for all four projections (computed once, as before).
		kernel_init = nn.initializers.normal()
		dot_general_kwargs = get_dot_general_by_bits(config.bits, config.easy_method)
		with_bias = getattr(config, "attention_bias", False)

		def _projection(in_features, out_features):
			return ParallelLinear(
				in_features,
				out_features,
				use_bias=with_bias,
				dtype=dtype,
				param_dtype=param_dtype,
				precision=precision,
				kernel_init=kernel_init,
				rngs=rngs,
				**dot_general_kwargs,
			)

		query_width = self.num_heads * self.head_dim
		kv_width = self.num_key_value_heads * self.head_dim
		# Construction order (q, k, v, o) is kept so RNG consumption is unchanged.
		self.q_proj = _projection(self.hidden_size, query_width)
		self.k_proj = _projection(self.hidden_size, kv_width)
		self.v_proj = _projection(self.hidden_size, kv_width)
		self.o_proj = _projection(query_width, self.hidden_size)

		self.attention_performer = FlexibleAttentionModule(
			dropout_prob=config.attention_dropout,
			base_config=config,
			softmax_scale=self.head_dim**-0.5,
		)
		self.rotary = self.config.get_basic_rope(self.dtype, self.head_dim)

	def __call__(
		self,
		hidden_states: chex.Array,
		attention_mask: chex.Array,
		position_ids: chex.Array,
		causal_mask: tp.Optional[chex.Array | bool],
		mode: common_types.RUNTIME_MODE_TYPES,  # type:ignore
		cache_view: tp.Optional[TransformerCacheView | PagedAttentionCacheView] = None,
		cache_metadata: tp.Optional[TransformerMetadata | PagedAttentionMetadata] = None,
		segment_ids: tp.Optional[chex.Array] = None,
		output_attentions: bool = False,
		fcm_mask: tp.Optional[chex.Array] = None,
		frequencies: tp.Optional[chex.Array] = None,
	):
		"""Run attention over ``hidden_states``.

		Args:
			hidden_states (chex.Array): Input of shape (batch_size, sequence_length, hidden_size).
			attention_mask (chex.Array): Attention mask forwarded to ``concatenate``.
			position_ids (chex.Array): Token positions used by the rotary embedding.
			causal_mask (tp.Optional[chex.Array | bool]): Causal mask forwarded to ``concatenate``.
			mode: Runtime mode forwarded to the attention kernel.
			cache_view: Optional KV-cache view for this layer.
			cache_metadata: Optional cache metadata.
			segment_ids (tp.Optional[chex.Array]): Optional segment ids for the kernel.
			output_attentions (bool): Whether to return attention weights.
			fcm_mask (tp.Optional[chex.Array]): Optional extra mask forwarded to ``concatenate``.
			frequencies (tp.Optional[chex.Array]): Optional precomputed rotary frequencies.

		Returns:
			AttentionLayerOutput: Projected attention output, the attention weights
			(only when ``output_attentions`` is True) and the updated cache view.
		"""
		batch_size, sequence_length = hidden_states.shape[:2]

		def _split_heads(states, head_count):
			return states.reshape(batch_size, sequence_length, head_count, self.head_dim)

		query_states = _split_heads(self.q_proj(hidden_states), self.config.num_attention_heads)
		key_states = _split_heads(self.k_proj(hidden_states), self.config.num_key_value_heads)
		value_states = _split_heads(self.v_proj(hidden_states), self.config.num_key_value_heads)

		query_states, key_states, value_states = self.apply_qkv_shardings(
			query_states, key_states, value_states
		)
		query_states, key_states = self.rotary(
			positions=position_ids,
			query=query_states,
			key=key_states,
			frequencies=frequencies,
		)

		# Merge with the KV cache and build the final mask/bias for the kernel.
		concatenated = self.concatenate(
			query=query_states,
			key=key_states,
			value=value_states,
			cache_view=cache_view,
			cache_metadata=cache_metadata,
			attention_mask=attention_mask,
			causal_mask=causal_mask,
			fcm_mask=fcm_mask,
		)
		key_states, value_states, attention_mask, init_attention_bias, cache_view = concatenated

		kernel_out = self.attention_performer.forward(
			query_states=query_states,
			key_states=key_states,
			value_states=value_states,
			mode=mode,
			bias=None,
			cache_metadata=cache_metadata,
			cache_view=cache_view,
			init_bias=init_attention_bias,
			attention_mask=attention_mask,
			segment_ids=segment_ids,
			causal=True,
			dropout_rng=self.rngs.params(),
		)

		merged = self.shard_attention_prod(self._merge_heads(kernel_out.attention_outputs))
		weights = kernel_out.attention_weights if output_attentions else None
		return AttentionLayerOutput(
			attention_output=self.o_proj(merged),
			attention_weight=weights,
			cache_view=cache_view,
		)
class MixtralBLockSparseTop2MLP(nn.Module):
	"""A single Mixtral expert: gated feed-forward network.

	Computes ``w2(act_fn(w1(x)) * w3(x))``, where ``act_fn`` is
	``ACT2FN[config.hidden_act]``. The input and output are passed through
	``apply_logical_sharding`` with the hidden-state sharding rule.

	Attributes:
		config (MixtralConfig): Configuration object for the model.
		dtype (jnp.dtype): Data type for computation.
		param_dtype (jnp.dtype): Data type for parameters.
		precision (jax.lax.PrecisionLike): Precision setting for JAX operations.
		rngs (nn.Rngs): Random number generators.
		w1 (ParallelLinear): Gate projection (hidden_size -> intermediate_size).
		w3 (ParallelLinear): Up projection (hidden_size -> intermediate_size).
		w2 (ParallelLinear): Down projection (intermediate_size -> hidden_size).
		act_fn (callable): Activation selected by ``config.hidden_act``.
	"""

	def __init__(
		self,
		config: MixtralConfig,
		dtype: jnp.dtype = jnp.float32,
		param_dtype: jnp.dtype = jnp.float32,
		precision: jax.lax.PrecisionLike = None,
		*,
		rngs: nn.Rngs,
	):
		"""Create the three bias-free projections of the expert.

		Args:
			config (MixtralConfig): The configuration object for the Mixtral model.
			dtype (jnp.dtype): Data type for computation. Defaults to jnp.float32.
			param_dtype (jnp.dtype): Data type for parameters. Defaults to jnp.float32.
			precision (jax.lax.PrecisionLike): Precision for JAX operations. Defaults to None.
			rngs (nn.Rngs): Random number generators.
		"""
		super().__init__()
		self.config = config
		self.dtype = dtype
		self.param_dtype = param_dtype
		self.precision = precision
		self.rngs = rngs

		kernel_init = nn.initializers.normal()
		dot_general_kwargs = get_dot_general_by_bits(config.bits, config.easy_method)
		hidden, inner = self.config.hidden_size, self.config.intermediate_size

		def _dense(in_features, out_features):
			return ParallelLinear(
				in_features,
				out_features,
				use_bias=False,
				dtype=dtype,
				param_dtype=param_dtype,
				precision=precision,
				kernel_init=kernel_init,
				rngs=rngs,
				**dot_general_kwargs,
			)

		# Same construction order as before (w1, w3, w2) to keep RNG usage identical.
		self.w1 = _dense(hidden, inner)
		self.w3 = _dense(hidden, inner)
		self.w2 = _dense(inner, hidden)
		self.act_fn = ACT2FN[self.config.hidden_act]

	def __call__(self, hidden_states: chex.Array):
		"""Apply the gated feed-forward transform and return a same-width output."""
		shard = functools.partial(
			apply_logical_sharding,
			dynamic_axes=common_types.HiddenStateSharding,
			partition_manager=self.config.partition_manager,
		)
		hidden_states = shard(hidden_states)
		gated = self.act_fn(self.w1(hidden_states))
		up = self.w3(hidden_states)
		return shard(self.w2(gated * up))
class MixtralSparseMoeBlock(nn.Module):
	"""Mixtral sparse Mixture-of-Experts block.

	A bias-free linear ``gate`` produces one router logit per expert. For every
	token, the ``config.num_experts_per_tok`` largest logits are softmax-normalised
	into routing weights. The block output is the weighted sum of those experts'
	outputs. All experts are evaluated densely on every token; unselected
	experts receive weight 0.

	Attributes:
		config (MixtralConfig): Configuration object for the model.
		dtype (jnp.dtype): Data type for computation.
		param_dtype (jnp.dtype): Data type for parameters.
		precision (jax.lax.PrecisionLike): Precision setting for JAX operations.
		rngs (nn.Rngs): Random number generators.
		gate (ParallelLinear): Router projection (hidden_size -> num_local_experts).
		experts (tp.List[MixtralBLockSparseTop2MLP]): ``num_local_experts`` expert MLPs.
	"""

	def __init__(
		self,
		config: MixtralConfig,
		dtype: jnp.dtype = jnp.float32,
		param_dtype: jnp.dtype = jnp.float32,
		precision: jax.lax.PrecisionLike = None,
		*,
		rngs: nn.Rngs,
	):
		"""Create the router and the expert MLPs.

		Args:
			config (MixtralConfig): The configuration object for the Mixtral model.
			dtype (jnp.dtype): Data type for computation. Defaults to jnp.float32.
			param_dtype (jnp.dtype): Data type for parameters. Defaults to jnp.float32.
			precision (jax.lax.PrecisionLike): Precision for JAX operations. Defaults to None.
			rngs (nn.Rngs): Random number generators.
		"""
		super().__init__()
		self.config = config
		self.dtype = dtype
		self.param_dtype = param_dtype
		self.precision = precision
		self.rngs = rngs
		# The router is intentionally built without the quantized dot_general kwargs.
		self.gate = ParallelLinear(
			config.hidden_size,
			config.num_local_experts,
			use_bias=False,
			rngs=rngs,
			dtype=dtype,
			param_dtype=param_dtype,
			precision=precision,
			kernel_init=nn.initializers.normal(),
		)
		self.experts = [
			MixtralBLockSparseTop2MLP(
				config=config,
				dtype=dtype,
				param_dtype=param_dtype,
				precision=precision,
				rngs=rngs,
			)
			for _ in range(config.num_local_experts)
		]

	def __call__(self, hidden_states: chex.Array) -> tp.Tuple[chex.Array, chex.Array]:
		"""Route tokens to their top-k experts and combine the expert outputs.

		Args:
			hidden_states (chex.Array): Input hidden states whose last axis is
				``hidden_size``.

		Returns:
			tp.Tuple[chex.Array, chex.Array]: ``(output, router_logits)``.
			``output`` has the shape and dtype of ``hidden_states``.
			``router_logits`` holds one logit per expert on the last axis, in
			``promote_types(self.dtype, float32)``.
		"""
		hidden_states = apply_logical_sharding(
			hidden_states,
			dynamic_axes=common_types.HiddenStateSharding,
			partition_manager=self.config.partition_manager,
		)
		routing_dtype = jnp.promote_types(self.dtype, jnp.float32)
		router_logits = self.gate(hidden_states).astype(routing_dtype)
		routing_weights, selected_experts = jax.lax.top_k(
			router_logits,
			k=self.config.num_experts_per_tok,
		)
		# Softmax restricted to the selected logits is equivalent to a full softmax
		# followed by renormalising the top-k probabilities.
		routing_weights = jax.nn.softmax(routing_weights, axis=-1)

		# The accumulator is promoted to routing_dtype on the first addition, so
		# the sum is carried out in at least float32.
		final_hidden_state = jnp.zeros_like(hidden_states)
		for index, expert in enumerate(self.experts):
			if self.config.use_scan_mlp:
				expert_output = block_wise_ffn(
					expert,
					hidden_states,
					self.config.scan_mlp_chunk_size,
				)
			else:
				expert_output = expert(hidden_states)
			# Per-token weight of this expert (0 where it was not selected).
			expert_weight = jnp.sum(
				jnp.where(selected_experts == index, routing_weights, 0.0),
				axis=-1,
			)[..., None]
			final_hidden_state = final_hidden_state + expert_weight * expert_output

		# Return in the input dtype so float32 routing weights do not silently
		# upcast the residual stream when the model runs in a lower precision.
		return (
			final_hidden_state.astype(hidden_states.dtype),
			router_logits,
		)
class MixtralDecoderLayer(nn.Module):
	"""One Mixtral decoder layer: pre-norm self-attention then pre-norm sparse MoE.

	Both sub-blocks are wrapped by ``auto_remat`` using
	``config.gradient_checkpointing``. Each sub-block's output is added back to
	its input (residual connection).

	Attributes:
		config (MixtralConfig): Configuration object for the model.
		dtype (jnp.dtype): Data type for computation.
		param_dtype (jnp.dtype): Data type for parameters.
		precision (jax.lax.PrecisionLike): Precision setting for JAX operations.
		rngs (nn.Rngs): Random number generators.
		self_attn (MixtralAttention): Self-attention sub-block.
		block_sparse_moe (MixtralSparseMoeBlock): Sparse MoE sub-block.
		input_layernorm (RMSNorm): Norm applied before attention.
		post_attention_layernorm (RMSNorm): Norm applied before the MoE block.
	"""

	def __init__(
		self,
		config: MixtralConfig,
		dtype: jnp.dtype = jnp.float32,
		param_dtype: jnp.dtype = jnp.float32,
		precision: jax.lax.PrecisionLike = None,
		*,
		rngs: nn.Rngs,
	):
		"""Create the attention block, the MoE block and the two RMSNorms.

		Args:
			config (MixtralConfig): The configuration object for the Mixtral model.
			dtype (jnp.dtype): Data type for computation. Defaults to jnp.float32.
			param_dtype (jnp.dtype): Data type for parameters. Defaults to jnp.float32.
			precision (jax.lax.PrecisionLike): Precision for JAX operations. Defaults to None.
			rngs (nn.Rngs): Random number generators.
		"""
		super().__init__()
		self.config = config
		self.dtype = dtype
		self.param_dtype = param_dtype
		self.precision = precision
		self.rngs = rngs

		attn_cls, moe_cls = auto_remat(
			MixtralAttention,
			MixtralSparseMoeBlock,
			policy=config.gradient_checkpointing,
		)
		block_kwargs = dict(
			config=config,
			dtype=dtype,
			param_dtype=param_dtype,
			precision=precision,
			rngs=rngs,
		)
		# Order (attention, MoE, norms) matches the original so RNG usage is identical.
		self.self_attn = attn_cls(**block_kwargs)
		self.block_sparse_moe = moe_cls(**block_kwargs)

		norm_kwargs = dict(
			dim=config.hidden_size,
			eps=config.rms_norm_eps,
			dtype=dtype,
			param_dtype=param_dtype,
			rngs=rngs,
		)
		self.input_layernorm = RMSNorm(**norm_kwargs)
		self.post_attention_layernorm = RMSNorm(**norm_kwargs)

	def __call__(
		self,
		hidden_states: chex.Array,
		attention_mask: chex.Array,
		position_ids: chex.Array,
		causal_mask: tp.Optional[chex.Array | bool],
		mode: common_types.RUNTIME_MODE_TYPES,  # type:ignore
		cache_view: tp.Optional[TransformerCacheView | PagedAttentionCacheView] = None,
		cache_metadata: tp.Optional[TransformerMetadata | PagedAttentionMetadata] = None,
		segment_ids: tp.Optional[chex.Array] = None,
		output_attentions: bool = False,
		output_router_logits: bool = False,
		fcm_mask: tp.Optional[chex.Array] = None,
		frequencies: tp.Optional[chex.Array] = None,
	) -> DecoderLayerOutput:
		"""Apply the decoder layer.

		Args:
			hidden_states (chex.Array): Input of shape (batch_size, sequence_length, hidden_size).
			attention_mask (chex.Array): Attention mask forwarded to the attention block.
			position_ids (chex.Array): Token positions for the rotary embedding.
			causal_mask (tp.Optional[chex.Array | bool]): Causal mask for attention.
			mode: Runtime mode forwarded to attention.
			cache_view: Optional KV-cache view for this layer.
			cache_metadata: Optional cache metadata.
			segment_ids (tp.Optional[chex.Array]): Optional segment ids.
			output_attentions (bool): Whether attention weights should be returned.
			output_router_logits (bool): Whether MoE router logits should be returned.
			fcm_mask (tp.Optional[chex.Array]): Optional extra attention mask.
			frequencies (tp.Optional[chex.Array]): Optional precomputed rotary frequencies.

		Returns:
			DecoderLayerOutput: Hidden states after both residual updates, the
			attention weights from the attention block, the router logits (only
			when ``output_router_logits`` is True) and the updated cache view.
		"""
		normed = apply_logical_sharding(
			self.input_layernorm(hidden_states),
			dynamic_axes=common_types.HiddenStateSharding,
			partition_manager=self.config.partition_manager,
		)
		# Positional arguments are used because the block is wrapped by auto_remat.
		attn_outputs = self.self_attn(
			normed,
			attention_mask,
			position_ids,
			causal_mask,
			mode,
			cache_view,
			cache_metadata,
			segment_ids,
			output_attentions,
			fcm_mask,
			frequencies,
		)
		hidden_states = hidden_states + attn_outputs.attention_output

		moe_output, router_logits = self.block_sparse_moe(
			self.post_attention_layernorm(hidden_states)
		)
		hidden_states = hidden_states + moe_output

		return DecoderLayerOutput(
			hidden_states=hidden_states,
			attention_weight=attn_outputs.attention_weight,
			router_logits=router_logits if output_router_logits else None,
			cache_view=attn_outputs.cache_view,
		)
@register_module(
	TaskType.BASE_MODULE,
	config=MixtralConfig,
	model_type="mixtral",
)
class MixtralModel(EasyDeLBaseModule):
	"""The base Mixtral transformer.

	The model consists of a token embedding, ``config.num_hidden_layers``
	``MixtralDecoderLayer`` blocks (each with a sparse MoE) and a final RMSNorm.

	Attributes:
		config (MixtralConfig): Configuration object for the model.
		dtype (jnp.dtype): Data type for computation.
		param_dtype (jnp.dtype): Data type for parameters.
		precision (jax.lax.PrecisionLike): Precision setting for JAX operations.
		rngs (nn.Rngs): Random number generators.
		embed_tokens (nn.Embed): Embedding layer for input tokens.
		layers (tp.List[MixtralDecoderLayer]): Decoder layers.
		norm (RMSNorm): Final normalization.
	"""

	def __init__(
		self,
		config: MixtralConfig,
		dtype: jnp.dtype = jnp.float32,
		param_dtype: jnp.dtype = jnp.float32,
		precision: jax.lax.PrecisionLike = None,
		*,
		rngs: nn.Rngs,
	):
		"""Create the embedding, the decoder stack and the final norm.

		Args:
			config (MixtralConfig): The configuration object for the Mixtral model.
			dtype (jnp.dtype): Data type for computation. Defaults to jnp.float32.
			param_dtype (jnp.dtype): Data type for parameters. Defaults to jnp.float32.
			precision (jax.lax.PrecisionLike): Precision for JAX operations. Defaults to None.
			rngs (nn.Rngs): Random number generators.
		"""
		super().__init__(
			config=config,
			dtype=dtype,
			param_dtype=param_dtype,
			precision=precision,
			rngs=rngs,
		)
		self.embed_tokens = nn.Embed(
			config.vocab_size,
			config.hidden_size,
			dtype=dtype,
			param_dtype=param_dtype,
			rngs=rngs,
		)
		self.layers = [
			MixtralDecoderLayer(
				config=config,
				dtype=dtype,
				param_dtype=param_dtype,
				precision=precision,
				rngs=rngs,
			)
			for _ in range(config.num_hidden_layers)
		]
		self.norm = RMSNorm(
			dim=config.hidden_size,
			eps=config.rms_norm_eps,
			dtype=dtype,
			param_dtype=param_dtype,
			rngs=rngs,
		)

	def __call__(
		self,
		input_ids: tp.Optional[chex.Array] = None,
		inputs_embeds: tp.Optional[chex.Array] = None,
		attention_mask: tp.Optional[chex.Array] = None,
		position_ids: tp.Optional[chex.Array] = None,
		segment_ids: tp.Optional[chex.Array] = None,
		output_attentions: tp.Optional[bool] = None,
		output_hidden_states: tp.Optional[bool] = None,
		output_router_logits: tp.Optional[bool] = None,
		mode: tp.Optional[common_types.RUNTIME_MODE_TYPES] = None,  # type:ignore
		past_key_values: tp.Optional[TransformerCache | PagedAttentionCache] = None,
		cache_metadata: tp.Optional[TransformerMetadata | PagedAttentionMetadata] = None,
	) -> MoeModelOutput:
		"""Run the transformer.

		Args:
			input_ids (tp.Optional[chex.Array]): Token ids of shape (batch_size, sequence_length).
			inputs_embeds (tp.Optional[chex.Array]): Embeddings of shape
				(batch_size, sequence_length, hidden_size). Exactly one of
				``input_ids`` / ``inputs_embeds`` must be given.
			attention_mask (tp.Optional[chex.Array]): Padding mask; non-bool masks are
				converted with ``mask == 1``. Defaults to all ones.
			position_ids (tp.Optional[chex.Array]): Token positions. When omitted, they
				are derived from the cumulative sum of the attention mask.
			segment_ids (tp.Optional[chex.Array]): Optional segment ids forwarded to layers.
			output_attentions (tp.Optional[bool]): Defaults to ``config.output_attentions``.
			output_hidden_states (tp.Optional[bool]): Defaults to ``config.output_hidden_states``.
			output_router_logits (tp.Optional[bool]): Defaults to ``config.output_router_logits``.
			mode: Runtime mode. When omitted, decode mode is selected if
				``sequence_length == 1`` and a cache is given; otherwise train mode.
			past_key_values (tp.Optional[TransformerCache | PagedAttentionCache]): KV cache;
				an empty cache is created when omitted.
			cache_metadata (tp.Optional[TransformerMetadata | PagedAttentionMetadata]): Cache metadata.

		Returns:
			MoeModelOutput: ``last_hidden_state`` (after the final norm). Optionally
			also ``hidden_states``: the input of every layer plus the final
			normalized output, ``num_hidden_layers + 1`` entries. Optionally also
			``attentions`` and ``router_logits``. Always includes ``past_key_values``.

		Raises:
			ValueError: If both or neither of ``input_ids`` and ``inputs_embeds`` are given.
		"""
		output_attentions = (
			output_attentions if output_attentions is not None else self.config.output_attentions
		)
		output_router_logits = (
			output_router_logits
			if output_router_logits is not None
			else self.config.output_router_logits
		)
		output_hidden_states = (
			output_hidden_states
			if output_hidden_states is not None
			else self.config.output_hidden_states
		)
		all_hidden_states = () if output_hidden_states else None
		all_self_attns = () if output_attentions else None
		all_router_logits = () if output_router_logits else None

		if (input_ids is None) ^ (inputs_embeds is not None):
			raise ValueError(
				"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
			)
		if inputs_embeds is None:
			inputs_embeds = self.embed_tokens(input_ids.astype("i4"))
		batch_size, sequence_length, _ = inputs_embeds.shape

		assert sequence_length <= self.config.max_position_embeddings, (
			f"Maximum Position Embedding Reached ! (Excepted <= {self.config.max_position_embeddings} got {sequence_length})"
		)

		if attention_mask is None:
			attention_mask = jnp.ones((batch_size, sequence_length), "b1")
		elif attention_mask.dtype != jnp.bool:
			attention_mask = jnp.astype(attention_mask == 1, "b1")

		if position_ids is None:
			# Count attended tokens; positions before the first attended token
			# (cumsum - 1 == -1) are clamped to 0. jnp.maximum replaces the
			# deprecated jnp.clip(a_min=...) keyword with identical results.
			position_ids = jnp.broadcast_to(
				jnp.maximum(jnp.cumsum(attention_mask, axis=-1) - 1, 0),
				(batch_size, sequence_length),
			).astype(jnp.int32)

		if attention_mask.ndim == 2:
			attention_mask = jnp.expand_dims(attention_mask, (1, 2))

		hidden_states = inputs_embeds
		if mode is None:
			mode = (
				common_types.MODE_DECODE
				if sequence_length == 1 and past_key_values is not None
				else common_types.MODE_TRAIN
			)
		if past_key_values is None:
			past_key_values = TransformerCache.init_empty(len(self.layers))

		hidden_states = apply_logical_sharding(
			hidden_states,
			dynamic_axes=common_types.HiddenStateSharding,
			partition_manager=self.config.partition_manager,
		)
		for idx, block in enumerate(self.layers):
			if output_hidden_states:
				all_hidden_states += (hidden_states,)
			layer_outputs = block(
				hidden_states=hidden_states,
				attention_mask=attention_mask,
				position_ids=position_ids,
				mode=mode,
				cache_view=past_key_values.views[idx],
				cache_metadata=cache_metadata,
				output_attentions=output_attentions,
				output_router_logits=output_router_logits,
				causal_mask=self.causal_mask,
				segment_ids=segment_ids,
				frequencies=self.frequencies,
			)
			hidden_states = layer_outputs.hidden_states
			if output_attentions:
				all_self_attns += (layer_outputs.attention_weight,)
			if output_router_logits:
				all_router_logits += (layer_outputs.router_logits,)
			past_key_values[idx] = layer_outputs.cache_view

		hidden_states = self.norm(hidden_states)
		# Include the final normalized output; previously the last layer's
		# output was silently missing from ``all_hidden_states``.
		if output_hidden_states:
			all_hidden_states += (hidden_states,)

		return MoeModelOutput(
			last_hidden_state=hidden_states,
			hidden_states=all_hidden_states,
			attentions=all_self_attns,
			router_logits=all_router_logits,
			past_key_values=past_key_values,
		)
@register_module(
	TaskType.CAUSAL_LM,
	config=MixtralConfig,
	model_type="mixtral",
)
class MixtralForCausalLM(EasyDeLBaseModule):
	"""Mixtral with a causal language-modeling head.

	The model wraps ``MixtralModel`` and projects its final hidden states to
	vocabulary logits with a bias-free ``lm_head``. When router logits are
	requested, it also computes the MoE load-balancing auxiliary loss.

	Attributes:
		config (MixtralConfig): Configuration object for the model.
		dtype (jnp.dtype): Data type for computation.
		param_dtype (jnp.dtype): Data type for parameters.
		precision (jax.lax.PrecisionLike): Precision setting for JAX operations.
		rngs (nn.Rngs): Random number generators.
		model (MixtralModel): The core Mixtral transformer.
		lm_head (ParallelLinear): Projection from hidden states to vocabulary logits.
	"""

	def __init__(
		self,
		config: MixtralConfig,
		dtype: jnp.dtype = jnp.float32,
		param_dtype: jnp.dtype = jnp.float32,
		precision: jax.lax.PrecisionLike = None,
		*,
		rngs: nn.Rngs,
	):
		"""Create the base model and the LM head.

		Args:
			config (MixtralConfig): The configuration object for the Mixtral model.
			dtype (jnp.dtype): Data type for computation. Defaults to jnp.float32.
			param_dtype (jnp.dtype): Data type for parameters. Defaults to jnp.float32.
			precision (jax.lax.PrecisionLike): Precision for JAX operations. Defaults to None.
			rngs (nn.Rngs): Random number generators.
		"""
		super().__init__(
			config=config,
			dtype=dtype,
			param_dtype=param_dtype,
			precision=precision,
			rngs=rngs,
		)
		self.model = MixtralModel(
			config=config,
			dtype=dtype,
			param_dtype=param_dtype,
			precision=precision,
			rngs=rngs,
		)
		self.lm_head = ParallelLinear(
			config.hidden_size,
			config.vocab_size,
			rngs=rngs,
			dtype=dtype,
			param_dtype=param_dtype,
			precision=precision,
			use_bias=False,
			kernel_init=nn.initializers.normal(config.initializer_range),
			**get_dot_general_by_bits(config.bits, config.easy_method),
		)

	def __call__(
		self,
		input_ids: tp.Optional[chex.Array] = None,
		inputs_embeds: tp.Optional[chex.Array] = None,
		attention_mask: tp.Optional[chex.Array] = None,
		position_ids: tp.Optional[chex.Array] = None,
		segment_ids: tp.Optional[chex.Array] = None,
		output_attentions: tp.Optional[bool] = None,
		output_hidden_states: tp.Optional[bool] = None,
		output_router_logits: tp.Optional[bool] = None,
		mode: tp.Optional[common_types.RUNTIME_MODE_TYPES] = None,  # type:ignore
		past_key_values: tp.Optional[TransformerCache | PagedAttentionCache] = None,
		cache_metadata: tp.Optional[TransformerMetadata | PagedAttentionMetadata] = None,
	) -> MoeCausalLMOutput | tp.Tuple:
		"""Run the model and the LM head.

		Args:
			input_ids (tp.Optional[chex.Array]): Token ids of shape (batch_size, sequence_length).
			inputs_embeds (tp.Optional[chex.Array]): Input embeddings; exclusive with ``input_ids``.
			attention_mask (tp.Optional[chex.Array]): Padding mask, also passed to the aux loss.
			position_ids (tp.Optional[chex.Array]): Token positions.
			segment_ids (tp.Optional[chex.Array]): Optional segment ids.
			output_attentions (tp.Optional[bool]): Defaults to ``config.output_attentions``.
			output_hidden_states (tp.Optional[bool]): Defaults to ``config.output_hidden_states``.
			output_router_logits (tp.Optional[bool]): Defaults to ``config.output_router_logits``.
				The auxiliary loss is only computed when this resolves to True.
			mode: Runtime mode forwarded to the base model.
			past_key_values (tp.Optional[TransformerCache | PagedAttentionCache]): KV cache.
			cache_metadata (tp.Optional[TransformerMetadata | PagedAttentionMetadata]): Cache metadata.

		Returns:
			MoeCausalLMOutput: Contains ``logits`` and ``aux_loss``. ``aux_loss`` is
			the load-balancing loss multiplied by ``config.router_aux_loss_coef``,
			or None when router logits are not requested. Also carries
			``hidden_states``, ``attentions``, ``router_logits`` and
			``past_key_values`` from the base model.
		"""
		if output_router_logits is None:
			output_router_logits = self.config.output_router_logits
		outputs = self.model(
			input_ids=input_ids,
			attention_mask=attention_mask,
			position_ids=position_ids,
			inputs_embeds=inputs_embeds,
			output_attentions=output_attentions,
			output_hidden_states=output_hidden_states,
			output_router_logits=output_router_logits,
			mode=mode,
			past_key_values=past_key_values,
			cache_metadata=cache_metadata,
			segment_ids=segment_ids,
		)
		logits = self.lm_head(outputs.last_hidden_state)

		aux_loss = None
		if output_router_logits and outputs.router_logits is not None:
			aux_loss = auxiliary_load_balancing_loss_func(
				gate_logits=outputs.router_logits,
				num_experts=self.config.num_local_experts,
				top_k=self.config.num_experts_per_tok,
				attention_mask=attention_mask,
			)
			# Scale by the coefficient. The previous ``aux_loss += aux_loss * coef``
			# produced ``aux_loss * (1 + coef)``, i.e. an essentially unscaled loss.
			aux_loss = aux_loss * self.config.router_aux_loss_coef

		return MoeCausalLMOutput(
			aux_loss=aux_loss,
			logits=logits,
			hidden_states=outputs.hidden_states,
			attentions=outputs.attentions,
			router_logits=outputs.router_logits,
			past_key_values=outputs.past_key_values,
		)
@register_module(
	TaskType.SEQUENCE_CLASSIFICATION,
	config=MixtralConfig,
	model_type="mixtral",
)
class MixtralForSequenceClassification(EasyDeLBaseModule):
	"""Mixtral model with a Sequence Classification head.

	This model consists of the base Mixtral transformer (`MixtralModel`) followed by a
	linear layer (`score`) that projects every position's hidden state to
	`config.num_labels` scores. For each sequence, the scores of the last non-padding
	token are pooled and returned. When router logits are requested, the MoE auxiliary
	load-balancing loss is also computed.

	Attributes:
		config (MixtralConfig): Configuration object for the model.
		dtype (jnp.dtype): Data type for computation.
		param_dtype (jnp.dtype): Data type for parameters.
		precision (jax.lax.PrecisionLike): Precision setting for JAX operations.
		rngs (nn.Rngs): Random number generators.
		model (MixtralModel): The core Mixtral transformer model.
		score (ParallelLinear): The linear layer for classification.
	"""

	def __init__(
		self,
		config: MixtralConfig,
		dtype: jnp.dtype = jnp.float32,
		param_dtype: jnp.dtype = jnp.float32,
		precision: jax.lax.PrecisionLike = None,
		*,
		rngs: nn.Rngs,
	):
		"""Initializes the MixtralForSequenceClassification model.

		Args:
			config (MixtralConfig): The configuration object for the Mixtral model.
				Must include `num_labels`.
			dtype (jnp.dtype): Data type for computation. Defaults to jnp.float32.
			param_dtype (jnp.dtype): Data type for parameters. Defaults to jnp.float32.
			precision (jax.lax.PrecisionLike): Precision setting for JAX operations.
				Defaults to None.
			rngs (nn.Rngs): Random number generators.

		Raises:
			AssertionError: If `config.num_labels` is not defined.
		"""
		super().__init__(
			config=config,
			dtype=dtype,
			param_dtype=param_dtype,
			precision=precision,
			rngs=rngs,
		)
		# Validate before building the backbone. An explicit raise (rather than `assert`)
		# still fires under `python -O`, and it keeps the documented AssertionError type.
		if not hasattr(config, "num_labels"):
			raise AssertionError(
				"in order to use `SequenceClassification` Models in `EasyDeL` you first need to attach `num_labels` to model `config`"
			)
		self.model = MixtralModel(
			config=config,
			dtype=dtype,
			param_dtype=param_dtype,
			precision=precision,
			rngs=rngs,
		)
		self.score = ParallelLinear(
			self.config.hidden_size,
			config.num_labels,
			dtype=dtype,
			param_dtype=param_dtype,
			use_bias=False,
			kernel_init=jax.nn.initializers.normal(stddev=config.initializer_range),
			precision=self.precision,
			rngs=rngs,
		)

	def __call__(
		self,
		input_ids: tp.Optional[chex.Array] = None,
		inputs_embeds: tp.Optional[chex.Array] = None,
		attention_mask: tp.Optional[chex.Array] = None,
		position_ids: tp.Optional[chex.Array] = None,
		segment_ids: tp.Optional[chex.Array] = None,
		output_attentions: tp.Optional[bool] = None,
		output_hidden_states: tp.Optional[bool] = None,
		output_router_logits: tp.Optional[bool] = None,
		mode: tp.Optional[common_types.RUNTIME_MODE_TYPES] = None,  # type:ignore
		past_key_values: tp.Optional[TransformerCache | PagedAttentionCache] = None,
		cache_metadata: tp.Optional[TransformerMetadata | PagedAttentionMetadata] = None,
	) -> SequenceClassifierOutput:
		"""Forward pass of the MixtralForSequenceClassification model.

		Args:
			input_ids (tp.Optional[chex.Array]): Input token IDs.
				Shape: (batch_size, sequence_length).
			inputs_embeds (tp.Optional[chex.Array]): Input embeddings.
				Shape: (batch_size, sequence_length, hidden_size).
				Either `input_ids` or `inputs_embeds` must be provided.
			attention_mask (tp.Optional[chex.Array]): Mask to avoid performing attention on
				padding token indices. Shape: (batch_size, sequence_length).
			position_ids (tp.Optional[chex.Array]): Position indices for the tokens.
				Shape: (batch_size, sequence_length).
			segment_ids (tp.Optional[chex.Array]): Segment IDs, forwarded to the base model.
			output_attentions (tp.Optional[bool]): Whether to return attention weights.
				Defaults to `config.output_attentions`.
			output_hidden_states (tp.Optional[bool]): Whether to return hidden states for
				all layers. Defaults to `config.output_hidden_states`.
			output_router_logits (tp.Optional[bool]): Whether to return router logits from
				the MoE layers (and compute the auxiliary loss).
				Defaults to `config.output_router_logits`.
			mode (tp.Optional[common_types.RUNTIME_MODE_TYPES]): Runtime mode, forwarded to
				the base model.
			past_key_values (tp.Optional[TransformerCache | PagedAttentionCache]):
				Precomputed key/value states for attention.
			cache_metadata (tp.Optional[TransformerMetadata | PagedAttentionMetadata]):
				Metadata for paged attention.

		Returns:
			SequenceClassifierOutput: Contains the pooled `logits`, `past_key_values`
			(as returned by the base model), `hidden_states` (optional), `attentions`
			(optional) and `aux_loss` (optional, already scaled by
			`config.router_aux_loss_coef`).

		Raises:
			ValueError: If `config.pad_token_id` is None and `batch_size > 1`.
		"""
		# Resolve the default the same way the causal-LM head does, so the auxiliary
		# loss is computed whenever the config enables router logits.
		if output_router_logits is None:
			output_router_logits = self.config.output_router_logits

		transformer_outputs = self.model(
			input_ids=input_ids,
			attention_mask=attention_mask,
			position_ids=position_ids,
			mode=mode,
			past_key_values=past_key_values,
			cache_metadata=cache_metadata,
			output_attentions=output_attentions,
			output_hidden_states=output_hidden_states,
			output_router_logits=output_router_logits,
			inputs_embeds=inputs_embeds,
			segment_ids=segment_ids,
		)

		hidden_states = transformer_outputs.last_hidden_state
		logits = self.score(hidden_states)

		if input_ids is not None:
			batch_size = input_ids.shape[0]
		else:
			batch_size = inputs_embeds.shape[0]

		if self.config.pad_token_id is None and batch_size != 1:
			raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")

		if self.config.pad_token_id is None or input_ids is None:
			# No way to locate padding: pool the final position.
			sequence_lengths = -1
		else:
			# Position of the first pad token minus one, i.e. the last real token
			# (assumes right padding — TODO confirm for left-padded inputs). When a row
			# has no pad token, argmax yields 0 -> -1, which the modulo maps to the
			# final position.
			sequence_lengths = (
				jnp.argmax(jnp.equal(input_ids, self.config.pad_token_id).astype("i4"), -1) - 1
			)
			sequence_lengths = sequence_lengths % input_ids.shape[-1]

		pooled_logits = logits[jnp.arange(batch_size), sequence_lengths]

		aux_loss = None
		if output_router_logits and transformer_outputs.router_logits is not None:
			aux_loss = auxiliary_load_balancing_loss_func(
				gate_logits=transformer_outputs.router_logits,
				num_experts=self.config.num_local_experts,
				top_k=self.config.num_experts_per_tok,
				attention_mask=attention_mask,
			)
			# Scale by the coefficient. The previous `aux_loss += aux_loss * coef`
			# produced aux_loss * (1 + coef), over-weighting the balancing term.
			aux_loss = aux_loss * self.config.router_aux_loss_coef

		return SequenceClassifierOutput(
			logits=pooled_logits,
			# Return the cache produced by the model, not the caller's input cache.
			past_key_values=transformer_outputs.past_key_values,
			hidden_states=transformer_outputs.hidden_states,
			attentions=transformer_outputs.attentions,
			aux_loss=aux_loss,
		)