# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from eformer.common_types import ColumnWise, Replicated, RowWise
from eformer.loggings import get_logger
from easydel.infra.base_module import EasyDeLBaseConfig
from easydel.infra.factory import register_config
from easydel.infra.utils import AttnMaskDetail, AttnMaskType
# Module-level logger, named after this module via `__name__`.
logger = get_logger(__name__)
@register_config("qwen3")
class Qwen3Config(EasyDeLBaseConfig):
    """Configuration container for the Qwen3 decoder architecture.

    Every constructor argument is stored on the instance under the same name.
    A few of them receive extra handling, documented below; the remaining ones
    are stored unchanged.

    Args:
        num_key_value_heads: Falls back to ``num_attention_heads`` when ``None``
            (kept for backward compatibility).
        use_sliding_window: Whether sliding-window attention is enabled. Only
            consulted here when ``layer_types`` has to be derived.
        sliding_window: Window size. Stored as given (it is NOT reset to
            ``None`` when ``use_sliding_window`` is false) and forwarded as the
            ``size`` of every entry produced by :meth:`get_mask_details`.
        max_window_layers: Index of the first layer that uses sliding attention
            when ``layer_types`` is derived automatically.
        rope_scaling: Optional mapping. When it has a ``"type"`` key, the value
            is mirrored under ``"rope_type"`` on a copy of the mapping.
        layer_types: Optional explicit per-layer attention kinds
            (``"sliding_attention"`` / ``"full_attention"``). When ``None`` it
            is derived from the sliding-window settings.
        tie_word_embeddings: Forwarded to ``EasyDeLBaseConfig.__init__``.
        **kwargs: Forwarded to ``EasyDeLBaseConfig.__init__``.
    """

    model_type = "qwen3"

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=4096,
        intermediate_size=22016,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=32,
        head_dim=128,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=28,
        attention_dropout=0.0,
        layer_types: list[str] | None = None,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.use_sliding_window = use_sliding_window
        self.sliding_window = sliding_window  # we check `use_sliding_window` in the modeling code
        self.max_window_layers = max_window_layers

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads

        self.head_dim = head_dim
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta

        if rope_scaling is not None and "type" in rope_scaling:
            # Mirror the legacy "type" key under "rope_type". Work on a copy so
            # the dict object owned by the caller is not modified.
            rope_scaling = dict(rope_scaling)
            rope_scaling["rope_type"] = rope_scaling["type"]
        self.rope_scaling = rope_scaling

        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout

        self.layer_types = layer_types
        if self.layer_types is None:
            # `sliding_window` keeps its value even when sliding attention is
            # disabled, so the derivation must also honour `use_sliding_window`;
            # otherwise the default config would silently mark every layer at
            # or beyond `max_window_layers` as sliding.
            sliding_enabled = bool(self.use_sliding_window) and self.sliding_window is not None
            self.layer_types = [
                "sliding_attention" if sliding_enabled and i >= self.max_window_layers else "full_attention"
                for i in range(self.num_hidden_layers)
            ]

        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)

    def get_partition_rules(self, *args, **kwargs):
        """Get the partition rules for the model.

        Positional and keyword arguments are accepted but not used.

        Returns:
            `tp.Tuple[tp.Tuple[str, PartitionSpec]]`: Pairs of
            ``(parameter-path regex, spec)`` where each spec is obtained from
            ``self.partition_manager.resolve`` for ``ColumnWise``, ``RowWise``
            or ``Replicated``.
        """
        pmag = self.partition_manager
        # NOTE(review): specific patterns are listed before the `.*bias` / `.*`
        # catch-alls, which presumably relies on in-order matching by the
        # consumer of these rules -- confirm before reordering.
        return (
            (r"embed_tokens/embedding", pmag.resolve(ColumnWise)),
            (r"self_attn/(q_proj|k_proj|v_proj)/kernel", pmag.resolve(ColumnWise)),
            (r"self_attn/o_proj/kernel", pmag.resolve(RowWise)),
            (r"self_attn/(q_norm|k_norm)/kernel", pmag.resolve(Replicated)),
            (r"self_attn/(q_proj|k_proj|v_proj|o_proj)/bias", pmag.resolve(Replicated)),
            (r"mlp/(gate_proj|up_proj)/kernel", pmag.resolve(ColumnWise)),
            (r"mlp/down_proj/kernel", pmag.resolve(RowWise)),
            (r"mlp/.*proj/bias", pmag.resolve(Replicated)),
            (
                r".*/(input_layernorm|post_attention_layernorm|norm)/kernel",
                pmag.resolve(Replicated),
            ),
            (r"lm_head/kernel", pmag.resolve(ColumnWise)),
            (r"score/kernel", pmag.resolve(RowWise)),
            (r".*bias", pmag.resolve(Replicated)),
            (r".*", pmag.resolve(Replicated)),
        )

    def get_mask_details(self) -> dict[int, AttnMaskDetail]:
        """Retrieve attention mask details for each layer in the model.

        Builds one `AttnMaskDetail` per layer index in
        ``range(self.num_hidden_layers)``. The mask type of a layer comes from
        ``AttnMaskType.from_hf(self.layer_types[layer_idx])`` and the size is
        always ``self.sliding_window``, whatever the layer's type.

        Returns:
            dict[int, AttnMaskDetail]: Mapping from layer index to its mask
            details.

        Notes:
            - If `self.layer_types` is None, an empty dictionary is returned.
            - `self.layer_types` is indexed by layer index, so it must hold at
              least `self.num_hidden_layers` entries; a shorter list raises
              `IndexError`.
        """
        if self.layer_types is None:
            return {}
        return {
            layer_idx: AttnMaskDetail(
                mask_type=AttnMaskType.from_hf(self.layer_types[layer_idx]),
                size=self.sliding_window,
            )
            for layer_idx in range(self.num_hidden_layers)
        }
# Names exported by `from <this module> import *`.
__all__ = ["Qwen3Config"]