# Source code for easydel.modules.gpt_oss.gpt_oss_configuration

# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""GPT-OSS Model Configuration

This module provides configuration classes for the GPT-OSS model,
a transformer-based language model with Mixture of Experts (MoE) architecture. The model
features sparse routing, sliding window attention, and efficient parameter sharding for
distributed training.

The configuration includes custom sharding specifications for MoE components and
comprehensive model hyperparameters.
"""

from eformer.common_types import ColumnWise, Replicated, RowWise

from easydel.infra.base_module import EasyDeLBaseConfig
from easydel.infra.factory import register_config
from easydel.layers.moe.utils import get_moe_partition_spec


@register_config("gpt_oss")
class GptOssConfig(EasyDeLBaseConfig):
    """Configuration class for GPT-OSS model.

    GPT-OSS is a transformer-based language model featuring:
    - Mixture of Experts (MoE) architecture with sparse routing
    - Alternating sliding window and full attention layers
    - RMSNorm for layer normalization
    - Rotary Position Embeddings (RoPE) with optional scaling
    - Efficient parameter sharding for distributed training

    Attributes:
        num_hidden_layers (int): Number of transformer layers. Default: 36
        num_local_experts (int): Number of expert networks per MoE layer. Default: 128
        vocab_size (int): Size of the vocabulary. Default: 201088
        hidden_size (int): Dimension of hidden representations. Default: 2880
        intermediate_size (int): Dimension of MLP intermediate layer. Default: 2880
        head_dim (int): Dimension of each attention head. Default: 64. When ``None``
            is passed, it falls back to ``hidden_size // num_attention_heads``.
        num_attention_heads (int): Number of attention heads. Default: 64
        num_key_value_heads (int): Number of key-value heads for GQA. Default: 8.
            When ``None`` is passed, it falls back to ``num_attention_heads``.
        sliding_window (int): Size of sliding window for local attention. Default: 128
        rope_theta (float): Base frequency for RoPE. Default: 150000.0
        tie_word_embeddings (bool): Whether to tie input/output embeddings. Default: False
        hidden_act (str): Activation function for MLP. Default: "silu"
        initializer_range (float): Standard deviation for weight initialization. Default: 0.02
        max_position_embeddings (int): Maximum sequence length. Default: 131072
        rms_norm_eps (float): Epsilon for RMS normalization. Default: 1e-5
        rope_scaling (dict): Configuration for RoPE scaling. Default: YARN scaling with factor 32
        attention_dropout (float): Dropout rate for attention weights. Default: 0.0
        num_experts_per_tok (int): Number of experts to route each token to. Default: 4
        router_aux_loss_coef (float): Coefficient for load balancing auxiliary loss. Default: 0.9
        output_router_logits (bool): Whether to output router logits. Default: False
        use_cache (bool): Whether to use key-value caching. Default: True
        layer_types (list): Attention type for each layer. Default: alternating sliding/full
        mlp_activations_limit (float): Stored unchanged on the config. Default: 7.0.
            NOTE(review): presumably the clamp limit for expert MLP activations;
            confirm against the model code.
        attention_bias (bool): Always set to ``True`` by ``__init__``; it is not a
            constructor argument.

    Example:
        >>> config = GptOssConfig(
        ...     num_hidden_layers=24,
        ...     num_local_experts=64,
        ...     hidden_size=2048,
        ...     num_attention_heads=32
        ... )
        >>> model = GptOssForCausalLM(config)
    """

    model_type = "gpt_oss"

    def __init__(
        self,
        num_hidden_layers: int = 36,
        num_local_experts: int = 128,
        vocab_size: int = 201088,
        hidden_size: int = 2880,
        intermediate_size: int = 2880,
        head_dim: int = 64,
        num_attention_heads: int = 64,
        num_key_value_heads: int = 8,
        sliding_window: int = 128,
        rope_theta: float = 150000.0,
        tie_word_embeddings: bool = False,
        hidden_act: str = "silu",
        initializer_range: float = 0.02,
        max_position_embeddings: int = 131072,
        rms_norm_eps: float = 1e-5,
        rope_scaling=None,
        attention_dropout: float = 0.0,
        num_experts_per_tok: int = 4,
        router_aux_loss_coef: float = 0.9,
        output_router_logits: bool = False,
        use_cache: bool = True,
        layer_types=None,
        mlp_activations_limit: float = 7.0,
        **kwargs,
    ):
        """Initialize GPT-OSS configuration.

        Args:
            num_hidden_layers: Number of transformer layers
            num_local_experts: Number of expert networks per MoE layer
            vocab_size: Size of the vocabulary
            hidden_size: Dimension of hidden representations
            intermediate_size: Dimension of MLP intermediate layer
            head_dim: Dimension of each attention head; ``None`` selects
                ``hidden_size // num_attention_heads``
            num_attention_heads: Number of attention heads
            num_key_value_heads: Number of key-value heads for grouped-query
                attention; ``None`` selects ``num_attention_heads``
            sliding_window: Size of sliding window for local attention
            rope_theta: Base frequency for rotary position embeddings
            tie_word_embeddings: Whether to tie input and output embeddings
            hidden_act: Activation function name for MLP layers
            initializer_range: Standard deviation for weight initialization
            max_position_embeddings: Maximum sequence length the model can handle
            rms_norm_eps: Epsilon value for RMS normalization layers
            rope_scaling: Dictionary configuring RoPE scaling (e.g., YARN parameters).
                ``None`` selects the default YARN configuration. A legacy ``"type"``
                key is mirrored to ``"rope_type"`` on a copy; the caller's
                dictionary is never modified.
            attention_dropout: Dropout probability for attention weights
            num_experts_per_tok: Number of experts each token is routed to
            router_aux_loss_coef: Coefficient for the load balancing auxiliary loss
            output_router_logits: Whether to output router logits for analysis
            use_cache: Whether to use key-value caching for inference
            layer_types: List specifying attention type for each layer; ``None``
                selects alternating sliding/full attention starting with sliding
            mlp_activations_limit: Value stored as ``self.mlp_activations_limit``
            **kwargs: Additional configuration parameters forwarded to the base class
        """
        if rope_scaling is None:
            rope_scaling = {
                "rope_type": "yarn",
                "factor": 32.0,
                "beta_fast": 32.0,
                "beta_slow": 1.0,
                "truncate": False,
            }
        elif "type" in rope_scaling:
            # Mirror the legacy "type" key into "rope_type" on a fresh dict so the
            # caller's argument is left untouched.
            rope_scaling = {**rope_scaling, "rope_type": rope_scaling["type"]}

        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        if layer_types is None:
            # Default: alternating sliding window and full attention
            # (even layer indices use the sliding window, odd ones full attention).
            layer_types = [
                "sliding_attention" if i % 2 == 0 else "full_attention"
                for i in range(num_hidden_layers)
            ]

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_local_experts = num_local_experts
        self.sliding_window = sliding_window
        self.num_experts_per_tok = num_experts_per_tok
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_dropout = attention_dropout
        self.head_dim = head_dim if head_dim is not None else hidden_size // num_attention_heads
        self.layer_types = layer_types
        # Not configurable through the constructor: always enabled.
        self.attention_bias = True
        self.max_position_embeddings = max_position_embeddings
        self.router_aux_loss_coef = router_aux_loss_coef
        self.output_router_logits = output_router_logits
        self.use_cache = use_cache
        self.mlp_activations_limit = mlp_activations_limit
        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)

    def get_partition_rules(self, *args, **kwargs):
        """Get the partition rules for distributed training of GPT-OSS model.

        Returns partition specifications for different parameter groups to enable
        efficient model parallelism. The rules specify how to shard parameters
        across devices for:
        - Embeddings: Column-wise sharding
        - Attention: Column-wise for QKV, row-wise for output projection
        - MoE: Custom expert-parallel sharding for expert parameters
        - Normalization: Replicated across devices

        Positional and keyword arguments are accepted but not used.

        Returns:
            tuple: Partition rules as (regex_pattern, PartitionSpec) pairs
        """
        pmag = self.partition_manager

        # Options shared by all four expert partition-spec lookups.
        moe_spec_kwargs = {
            "fsdp_is_ep_bound": self.fsdp_is_ep_bound,
            "sp_is_ep_bound": self.sp_is_ep_bound,
            "module_view": True,
            "tensors_are_expert": self.use_expert_tensor_mode,
            "partition_manager": pmag,
        }
        expert_col_kernel = get_moe_partition_spec(direction="column", is_bias=False, **moe_spec_kwargs)
        expert_row_kernel = get_moe_partition_spec(direction="row", is_bias=False, **moe_spec_kwargs)
        expert_col_bias = get_moe_partition_spec(direction="column", is_bias=True, **moe_spec_kwargs)
        expert_row_bias = get_moe_partition_spec(direction="row", is_bias=True, **moe_spec_kwargs)

        replicated = pmag.resolve(Replicated)
        router_kernel = pmag.resolve(Replicated if self.use_expert_tensor_mode else ColumnWise)

        # Rule order is kept exactly as-is: specific patterns come first and the
        # catch-alls last (presumably the first matching pattern wins; confirm
        # against the rule consumer).
        return (
            (r".*embed_tokens/embedding", pmag.resolve(ColumnWise)),
            (r".*self_attn/(q_proj|k_proj|v_proj)/kernel", pmag.resolve(ColumnWise)),
            (r".*self_attn/o_proj/kernel", pmag.resolve(RowWise)),
            (r".*self_attn/.*proj/bias", replicated),
            (r".*mlp/gate/kernel", router_kernel),
            (r".*mlp/gate/bias", replicated),
            # Legacy paths (original HuggingFace parameter names)
            (r".*mlp/experts/gate_up_proj$", expert_col_kernel),
            (r".*mlp/experts/down_proj$", expert_row_kernel),
            (r".*mlp/experts/gate_up_proj_bias$", expert_col_bias),
            (r".*mlp/experts/down_proj_bias$", expert_row_bias),
            # New split paths (after reform_param transformation)
            (r".*mlp/experts/(gate_proj|up_proj)/kernel", expert_col_kernel),
            (r".*mlp/experts/down_proj/kernel", expert_row_kernel),
            (r".*mlp/experts/(gate_proj|up_proj)/bias", expert_col_bias),
            (r".*mlp/experts/down_proj/bias", expert_row_bias),
            (r".*layernorm/scale", replicated),
            (r".*rms_norm/scale", replicated),
            (r".*norm/scale", replicated),
            (r".*lm_head/kernel", pmag.resolve(ColumnWise)),
            (r".*score/kernel", pmag.resolve(ColumnWise)),
            (r".*bias", replicated),
            (r".*", replicated),
        )


# Public API of this module: only the configuration class is exported.
__all__ = ["GptOssConfig"]