Source code for easydel.modules.qwen3_moe.qwen3_moe_configuration

# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import typing as tp

from eformer.common_types import (
    EMPTY,
    MODE_TRAIN,
    TP,
    ColumnWise,
    DynamicShardingAxes,
    Replicated,
    RowWise,
)
from eformer.loggings import get_logger

from easydel.infra.base_module import EasyDeLBaseConfig
from easydel.infra.factory import register_config
from easydel.infra.utils import AttnMaskDetail, AttnMaskType
from easydel.layers.moe.utils import get_moe_partition_spec

# Module-level logger obtained from eformer's `get_logger`, keyed by this module's `__name__`.
logger = get_logger(__name__)


class ExpertTensorParallel(DynamicShardingAxes):
    """Sharding-axes declaration for Expert Tensor Parallelism (EPxTP).

    The class only sets two class-level constants on top of
    ``DynamicShardingAxes``: a three-entry ``axes`` list and a ``mode`` tag.

    NOTE(review): how each ``axes`` position maps onto a tensor dimension is
    decided by ``DynamicShardingAxes`` and its consumers, which are not part
    of this module -- confirm there before relying on a particular layout.
    """

    # Tagged with the training-mode constant.
    mode: tp.ClassVar = MODE_TRAIN

    # First entry is the tensor-parallel axis; the other two are left EMPTY.
    axes: tp.ClassVar = [TP, EMPTY, EMPTY]


@register_config("qwen3_moe")
class Qwen3MoeConfig(EasyDeLBaseConfig):
    """Configuration container for Qwen3-MoE, the Qwen3 variant with routed experts.

    Nearly every constructor argument is copied onto the instance under the
    same name. The arguments that receive extra handling are:

    * ``sliding_window`` -- kept only when ``use_sliding_window`` is truthy;
      otherwise the attribute is set to ``None``.
    * ``rope_scaling`` -- when the mapping contains a ``"type"`` key, its value
      is additionally written to ``"rope_type"`` on that same mapping.
    * ``mlp_only_layers`` -- ``None`` is replaced by a new empty list.
    * ``layer_types`` -- when ``None``, one entry per hidden layer is
      generated: ``"sliding_attention"`` for layers whose index is at least
      ``max_window_layers`` (and only if ``sliding_window`` ended up
      non-``None``), ``"full_attention"`` for all others.
    * ``tie_word_embeddings`` and any extra ``**kwargs`` -- not assigned here;
      they are forwarded to the base-class constructor.
    """

    model_type = "qwen3_moe"

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=2048,
        intermediate_size=6144,
        num_hidden_layers=24,
        num_attention_heads=32,
        num_key_value_heads=4,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=28,
        attention_dropout=0.0,
        decoder_sparse_step=1,
        moe_intermediate_size=768,
        num_experts_per_tok=8,
        num_experts=128,
        norm_topk_prob=False,
        output_router_logits=False,
        router_aux_loss_coef=0.001,
        mlp_only_layers=None,
        layer_types: list[str] | None = None,
        **kwargs,
    ):
        """Store the hyper-parameters and derive the per-layer attention types.

        See the class docstring for the arguments that are transformed rather
        than stored as given. The base-class constructor is invoked last with
        ``tie_word_embeddings`` and the remaining ``**kwargs``.
        """
        # --- general model hyper-parameters ---------------------------------
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.max_position_embeddings = max_position_embeddings
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache

        # --- attention ------------------------------------------------------
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout

        # --- sliding window -------------------------------------------------
        self.use_sliding_window = use_sliding_window
        # The window size is discarded unless sliding-window attention is enabled.
        self.sliding_window = None if not use_sliding_window else sliding_window
        self.max_window_layers = max_window_layers

        # --- rotary embedding -----------------------------------------------
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        scaling = self.rope_scaling
        if scaling is not None and "type" in scaling:
            # Mirror the "type" entry under "rope_type". The write happens in
            # place on the mapping stored just above.
            scaling["rope_type"] = scaling["type"]

        # --- mixture of experts ---------------------------------------------
        self.decoder_sparse_step = decoder_sparse_step
        self.moe_intermediate_size = moe_intermediate_size
        self.num_experts_per_tok = num_experts_per_tok
        self.num_experts = num_experts
        self.norm_topk_prob = norm_topk_prob
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef
        self.mlp_only_layers = mlp_only_layers if mlp_only_layers is not None else []

        # --- per-layer attention kind ---------------------------------------
        if layer_types is None:
            windowed = self.sliding_window is not None
            layer_types = [
                "sliding_attention" if windowed and idx >= self.max_window_layers else "full_attention"
                for idx in range(self.num_hidden_layers)
            ]
        self.layer_types = layer_types

        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)

    def get_partition_rules(self, *args, **kwargs):
        """Build the parameter-name to sharding-spec table for this model.

        Positional and keyword arguments are accepted but not used.

        Returns:
            A tuple of ``(regex_pattern, spec)`` pairs. Dense weights get a
            spec from ``self.partition_manager.resolve(...)``; the routed
            expert kernels get theirs from ``get_moe_partition_spec``. The
            tuple lists specific patterns first and ends with the catch-alls
            ``.*bias`` and ``.*``.

        NOTE(review): the ordering suggests consumers take the first matching
        pattern -- confirm against the code that applies these rules.
        """
        pmag = self.partition_manager
        resolve = pmag.resolve

        def expert_kernel_spec(direction):
            # Both expert-kernel rules share every keyword except the direction.
            return get_moe_partition_spec(
                partition_manager=self.partition_manager,
                direction=direction,
                tensors_are_expert=self.use_expert_tensor_mode,
                is_bias=False,
                fsdp_is_ep_bound=self.fsdp_is_ep_bound,
                sp_is_ep_bound=self.sp_is_ep_bound,
                module_view=True,
            )

        return (
            # Embeddings and attention projections.
            (r"embed_tokens/embedding", resolve(ColumnWise)),
            (r"self_attn/(q_proj|k_proj|v_proj)/kernel", resolve(ColumnWise)),
            (r"self_attn/o_proj/kernel", resolve(RowWise)),
            (r"self_attn/.*proj/bias", resolve(Replicated)),
            (r"self_attn/(q_norm|k_norm)/kernel", resolve(Replicated)),
            # Dense MLP projections.
            (r"mlp/(gate_proj|up_proj)/kernel", resolve(ColumnWise)),
            (r"mlp/down_proj/kernel", resolve(RowWise)),
            (r"mlp/.*proj/bias", resolve(Replicated)),
            # Router gate: replicated in expert-tensor mode, column-wise otherwise.
            (r"mlp/gate/kernel", resolve(Replicated if self.use_expert_tensor_mode else ColumnWise)),
            (r"mlp/gate/bias", resolve(Replicated)),
            # Routed expert weights.
            (r"mlp/experts/(gate_proj|up_proj)/kernel", expert_kernel_spec("column")),
            (r"mlp/experts/down_proj/kernel", expert_kernel_spec("row")),
            (r"mlp/experts/.*bias", resolve(Replicated)),
            # Normalisation layers and output heads.
            (r".*/(input_layernorm|post_attention_layernorm)/kernel", resolve(Replicated)),
            (r"norm/scale", resolve(Replicated)),
            (r"norm/bias", resolve(Replicated)),
            (r"lm_head/kernel", resolve(ColumnWise)),
            (r"score/kernel", resolve(RowWise)),
            # Catch-alls.
            (r".*bias", resolve(Replicated)),
            (r".*", resolve(Replicated)),
        )

    def get_mask_details(self) -> dict[int, AttnMaskDetail]:
        """Report which layers use a sliding-window attention mask.

        A layer receives an entry only when all of the following hold:

        * ``self.sliding_window`` is not ``None``;
        * ``self.use_sliding_window`` is truthy;
        * the layer index is greater than or equal to ``self.max_window_layers``.

        Layers that do not qualify are simply absent from the result, so the
        mapping is empty when no layer qualifies. ``self.layer_types`` is not
        consulted here.

        Returns:
            dict[int, AttnMaskDetail]: Layer index mapped to an
            ``AttnMaskDetail`` with ``mask_type=AttnMaskType.SLIDING`` and
            ``size=self.sliding_window``.
        """
        return {
            idx: AttnMaskDetail(mask_type=AttnMaskType.SLIDING, size=self.sliding_window)
            for idx in range(self.num_hidden_layers)
            if self.sliding_window is not None and self.use_sliding_window and idx >= self.max_window_layers
        }
# Names exported by `from <module> import *`; only the config class is listed.
__all__ = ["Qwen3MoeConfig"]