# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import typing
import typing as tp
from eformer.common_types import (
EMPTY,
MODE_TRAIN,
TP,
ColumnWise,
DynamicShardingAxes,
Replicated,
RowWise,
)
from easydel.infra.base_module import EasyDeLBaseConfig
from easydel.infra.factory import register_config
from easydel.layers.moe.utils import get_moe_partition_spec
from easydel.layers.rotary_embedding import RopeConfig
# Module-level archive map for pretrained DeepSeek configurations; it is defined empty here.
# NOTE(review): presumably kept for parity with HF-style configuration modules — confirm it is still referenced.
DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
class ExpertTensorParallel(DynamicShardingAxes):
    """Expert Tensor Parallelism (EPxTP) sharding axes.

    Declares a three-entry axis layout ``[TP, EMPTY, EMPTY]`` that is tagged
    with ``MODE_TRAIN``.
    """

    # The first entry maps to the tensor-parallel axis, the other two are EMPTY.
    # NOTE(review): presumably the three entries line up with (expert, in, out) kernel dims — confirm.
    axes: tp.ClassVar = [TP, EMPTY, EMPTY]
    mode: tp.ClassVar = MODE_TRAIN
@register_config("deepseek_v3")
class DeepseekV3Config(EasyDeLBaseConfig):
    r"""Configuration class that stores the hyperparameters of a [`DeepseekV3Model`].

    It is used to instantiate a DeepSeek model according to the specified arguments, defining the model
    architecture. The class derives from `EasyDeLBaseConfig` and is decorated with
    `register_config("deepseek_v3")`. The defaults listed below are the defaults of `__init__`.

    Args:
        vocab_size (`int`, *optional*, defaults to 129280):
            Vocabulary size of the model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`DeepseekV3Model`].
        hidden_size (`int`, *optional*, defaults to 7168):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 18432):
            Dimension of the MLP representations.
        moe_intermediate_size (`int`, *optional*, defaults to 2048):
            Dimension of the MoE representations.
        num_hidden_layers (`int`, *optional*, defaults to 61):
            Number of hidden layers in the Transformer decoder.
        num_nextn_predict_layers (`int`, *optional*, defaults to 1):
            Number of nextn predict layers in the DeepSeekV3 Model.
        num_attention_heads (`int`, *optional*, defaults to 128):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*, defaults to 128):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be
            constructed by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If `None` is passed explicitly, it falls back to
            `num_attention_heads`.
        n_shared_experts (`int`, *optional*, defaults to 1):
            Number of shared experts, None means dense model.
        n_routed_experts (`int`, *optional*, defaults to 256):
            Number of routed experts, None means dense model.
        ep_size (`int`, *optional*, defaults to 1):
            Expert parallelism size.
        routed_scaling_factor (`float`, *optional*, defaults to 2.5):
            Scaling factor for routed experts.
        kv_lora_rank (`int`, *optional*, defaults to 512):
            Rank for KV LoRA.
        q_lora_rank (`int`, *optional*, defaults to 1536):
            Rank for Q LoRA.
        qk_rope_head_dim (`int`, *optional*, defaults to 64):
            Head dimension for QK with RoPE.
        v_head_dim (`int`, *optional*, defaults to 128):
            Head dimension for V.
        qk_nope_head_dim (`int`, *optional*, defaults to 128):
            Head dimension for QK without RoPE.
        topk_method (`str`, *optional*, defaults to `"noaux_tc"`):
            Topk method used in routed gate.
        n_group (`int`, *optional*, defaults to 8):
            Number of groups for routed experts.
        topk_group (`int`, *optional*, defaults to 4):
            Number of selected groups for each token (for each token, ensuring the selected experts
            are only within `topk_group` groups).
        num_experts_per_tok (`int`, *optional*, defaults to 8):
            Number of selected experts, None means dense model.
        moe_layer_freq (`int`, *optional*, defaults to 1):
            The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers.
        first_k_dense_replace (`int`, *optional*, defaults to 3):
            Number of dense layers in shallow layers (embed->dense->dense->...->dense->moe->moe...->lm_head).
                                                            \--k dense layers--/
        norm_topk_prob (`bool`, *optional*, defaults to `True`):
            Whether to normalize the weights of the routed experts.
        scoring_func (`str`, *optional*, defaults to `"sigmoid"`):
            Method of computing expert weights.
        aux_loss_alpha (`float`, *optional*, defaults to 0.001):
            Auxiliary loss weight coefficient.
        seq_aux (`bool`, *optional*, defaults to `True`):
            Whether to compute the auxiliary loss for each individual sample.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 4096):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*):
            Padding token id.
        bos_token_id (`int`, *optional*, defaults to 0):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 1):
            End of stream token id.
        pretraining_tp (`int`, *optional*, defaults to 1):
            Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
            document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This
            value is necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
            issue](https://github.com/pytorch/pytorch/issues/76232).
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary describing the RoPE configuration. `_get_rope_config` passes it to
            `RopeConfig.from_dict`; when it is `None`, a built-in `"yarn"` configuration is used instead.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.

    Example:

    ```python
    >>> # Initializing a Deepseek-V3 style configuration
    >>> configuration = DeepseekV3Config()
    >>> # Reading a value back from the configuration
    >>> configuration.hidden_size
    7168
    ```"""

    model_type = "deepseek_v3"
    keys_to_ignore_at_inference: tp.ClassVar = ["past_key_values"]

    def __init__(
        self,
        vocab_size=129280,
        hidden_size=7168,
        intermediate_size=18432,
        moe_intermediate_size=2048,
        num_hidden_layers=61,
        num_nextn_predict_layers=1,
        num_attention_heads=128,
        num_key_value_heads=128,
        n_shared_experts=1,
        n_routed_experts=256,
        ep_size=1,
        routed_scaling_factor=2.5,
        kv_lora_rank=512,
        q_lora_rank=1536,
        qk_rope_head_dim=64,
        v_head_dim=128,
        qk_nope_head_dim=128,
        topk_method="noaux_tc",
        n_group=8,
        topk_group=4,
        num_experts_per_tok=8,
        moe_layer_freq=1,
        first_k_dense_replace=3,
        norm_topk_prob=True,
        scoring_func="sigmoid",
        aux_loss_alpha=0.001,
        seq_aux=True,
        hidden_act="silu",
        max_position_embeddings=4096,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=0,
        eos_token_id=1,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        **kwargs,
    ):
        """Initialize a new DeepseekV3Config instance.

        Every named argument is stored on the instance under the same name. The token ids,
        ``tie_word_embeddings`` and any extra ``**kwargs`` are forwarded to the base-class initializer.

        Args:
            vocab_size (int, optional): Size of the vocabulary. Defaults to 129280.
            hidden_size (int, optional): Dimensionality of the embeddings and hidden states. Defaults to 7168.
            intermediate_size (int, optional): Dimensionality of the MLP layer. Defaults to 18432.
            moe_intermediate_size (int, optional): Dimensionality of the MoE intermediate layer. Defaults to 2048.
            num_hidden_layers (int, optional): Number of hidden layers in the model. Defaults to 61.
            num_nextn_predict_layers (int, optional): Number of next-n prediction layers. Defaults to 1.
            num_attention_heads (int, optional): Number of attention heads. Defaults to 128.
            num_key_value_heads (int, optional): Number of key/value heads (for GQA). Defaults to 128;
                an explicit ``None`` falls back to ``num_attention_heads``.
            n_shared_experts (int, optional): Number of shared MoE experts. Defaults to 1.
            n_routed_experts (int, optional): Number of routed MoE experts. Defaults to 256.
            ep_size (int, optional): Expert parallelism size. Defaults to 1.
            routed_scaling_factor (float, optional): Scaling factor for routed experts. Defaults to 2.5.
            kv_lora_rank (int, optional): Rank for KV LoRA. Defaults to 512.
            q_lora_rank (int, optional): Rank for Q LoRA. Defaults to 1536.
            qk_rope_head_dim (int, optional): Head dimension for QK with RoPE. Defaults to 64.
            v_head_dim (int, optional): Head dimension for V. Defaults to 128.
            qk_nope_head_dim (int, optional): Head dimension for QK without RoPE. Defaults to 128.
            topk_method (str, optional): Method for top-k expert selection. Defaults to "noaux_tc".
            n_group (int, optional): Number of expert groups. Defaults to 8.
            topk_group (int, optional): Top-k groups. Defaults to 4.
            num_experts_per_tok (int, optional): Number of experts per token. Defaults to 8.
            moe_layer_freq (int, optional): Frequency of MoE layers. Defaults to 1.
            first_k_dense_replace (int, optional): First k dense layers to replace. Defaults to 3.
            norm_topk_prob (bool, optional): Whether to normalize top-k probabilities. Defaults to True.
            scoring_func (str, optional): Scoring function for expert selection. Defaults to "sigmoid".
            aux_loss_alpha (float, optional): Weight for auxiliary loss. Defaults to 0.001.
            seq_aux (bool, optional): Whether to use sequence auxiliary loss. Defaults to True.
            hidden_act (str, optional): Activation function. Defaults to "silu".
            max_position_embeddings (int, optional): Maximum sequence length. Defaults to 4096.
            initializer_range (float, optional): Range for weight initialization. Defaults to 0.02.
            rms_norm_eps (float, optional): Epsilon for RMS normalization. Defaults to 1e-6.
            use_cache (bool, optional): Whether to use KV cache for generation. Defaults to True.
            pad_token_id (int, optional): ID for padding token. Defaults to None.
            bos_token_id (int, optional): ID for beginning of sequence token. Defaults to 0.
            eos_token_id (int, optional): ID for end of sequence token. Defaults to 1.
            pretraining_tp (int, optional): Tensor parallelism size during pretraining. Defaults to 1.
            tie_word_embeddings (bool, optional): Whether to tie input/output embeddings. Defaults to False.
            rope_theta (float, optional): Base value for RoPE. Defaults to 10000.0.
            rope_scaling (Dict, optional): RoPE scaling configuration. Defaults to None.
            attention_bias (bool, optional): Whether to use bias in attention. Defaults to False.
            attention_dropout (float, optional): Dropout rate for attention. Defaults to 0.0.
            **kwargs: Additional arguments forwarded to the base-class initializer.
        """
        # Core transformer dimensions.
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.moe_intermediate_size = moe_intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_nextn_predict_layers = num_nextn_predict_layers
        self.num_attention_heads = num_attention_heads

        # Mixture-of-experts layout and routing.
        self.n_shared_experts = n_shared_experts
        self.n_routed_experts = n_routed_experts
        self.ep_size = ep_size
        self.routed_scaling_factor = routed_scaling_factor

        # Low-rank projection ranks and per-head dimensions used by attention.
        self.kv_lora_rank = kv_lora_rank
        self.q_lora_rank = q_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.qk_nope_head_dim = qk_nope_head_dim

        # Gate / top-k selection settings.
        self.topk_method = topk_method
        self.n_group = n_group
        self.topk_group = topk_group
        self.num_experts_per_tok = num_experts_per_tok
        self.moe_layer_freq = moe_layer_freq
        self.first_k_dense_replace = first_k_dense_replace
        self.norm_topk_prob = norm_topk_prob
        self.scoring_func = scoring_func
        self.aux_loss_alpha = aux_loss_alpha
        self.seq_aux = seq_aux

        # For backward compatibility: an explicit None means "same as the attention heads".
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads

        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    def get_partition_rules(self, *args, **kwargs):
        """Get the partition rules for the model.

        Positional and keyword arguments are accepted but not used by this implementation.

        Returns:
            `tp.Tuple[tp.Tuple[str, PartitionSpec]]`: Pairs of (parameter-path regex, partition spec).
            The order is significant to the extent that the generic ``.*bias`` and ``.*`` patterns are
            listed last (presumably consumers take the first matching rule — confirm against the caller).
        """
        pmag = self.partition_manager

        # Each sharding kind is resolved once and reused by every rule that needs it.
        column_wise = pmag.resolve(ColumnWise)
        row_wise = pmag.resolve(RowWise)
        replicated = pmag.resolve(Replicated)

        expert_tensor_mode = self.use_expert_tensor_mode

        # Arguments shared by both expert-kernel rules; only `direction` differs between them.
        moe_spec_kwargs = dict(
            partition_manager=pmag,
            tensors_are_expert=expert_tensor_mode,
            is_bias=False,
            fsdp_is_ep_bound=self.fsdp_is_ep_bound,
            sp_is_ep_bound=self.sp_is_ep_bound,
            module_view=True,
        )

        return (
            (r"embed_tokens/embedding", column_wise),
            (r"self_attn/q_proj/kernel", column_wise),
            (r"self_attn/q_a_proj/kernel", column_wise),
            (r"self_attn/q_b_proj/kernel", column_wise),
            (r"self_attn/kv_a_proj_with_mqa/kernel", column_wise),
            (r"self_attn/kv_b_proj/kernel", column_wise),
            (r"self_attn/o_proj/kernel", row_wise),
            (r"self_attn/.*proj/bias", replicated),
            (r"self_attn/(q_a_layernorm|kv_a_layernorm)/kernel", replicated),
            (r"mlp/(gate_proj|up_proj)/kernel", column_wise),
            (r"mlp/down_proj/kernel", row_wise),
            # The router kernel is replicated in expert-tensor mode, column-sharded otherwise.
            (r"mlp/gate/kernel", replicated if expert_tensor_mode else column_wise),
            (r"mlp/gate/e_score_correction_bias", replicated),
            (
                r"mlp/experts/(gate_proj|up_proj)/kernel",
                get_moe_partition_spec(direction="column", **moe_spec_kwargs),
            ),
            (
                r"mlp/experts/down_proj/kernel",
                get_moe_partition_spec(direction="row", **moe_spec_kwargs),
            ),
            (r"mlp/shared_experts/(gate_proj|up_proj)/kernel", column_wise),
            (r"mlp/shared_experts/down_proj/kernel", row_wise),
            (r".*(input_layernorm|post_attention_layernorm|norm)/kernel", replicated),
            (r"lm_head/kernel", column_wise),
            (r".*bias", replicated),
            (r".*", replicated),
        )

    def _get_rope_config(self) -> RopeConfig:
        """Get RoPE configuration from the instance attributes.

        Returns:
            RopeConfig: Built from ``self.rope_scaling`` when that attribute exists and is not ``None``;
            otherwise built from a fixed ``"yarn"`` configuration.
        """
        rope_scaling = getattr(self, "rope_scaling", None)

        if rope_scaling is None:
            # NOTE(review): this fallback hard-codes base=10000 instead of reading `self.rope_theta`
            # — confirm that is intentional.
            config = RopeConfig.from_dict(
                dict(
                    rope_type="yarn",
                    base=10000,
                    scaling_factor=1.0,
                    original_max_position_embeddings=4096,
                    beta_fast=32,
                    beta_slow=1,
                    mscale=1,
                    mscale_all_dim=0,
                )
            )
        else:
            config = RopeConfig.from_dict(rope_scaling)
            # Only a user-supplied dict can omit this field (the fallback above sets it explicitly),
            # so back-fill it from the instance attribute when one is present.
            if config.original_max_position_embeddings is None:
                config.original_max_position_embeddings = getattr(self, "original_max_position_embeddings", None)

        return config