# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Mixture of Experts (MoE) layer implementations for EasyDeL.
This module provides a comprehensive implementation of Mixture of Experts (MoE) layers
for large-scale neural networks. It includes support for various routing strategies,
load balancing techniques, and distributed training optimizations.
Key Components:
- **BaseMoeModule**: Abstract base class for MoE implementations with common
utilities for routing, permutation, and metric computation.
- **ParallelMoELinear**: Batched linear transformation layer for expert networks
with support for ragged and grouped matrix multiplication.
- **Routing Strategies**: Multiple routing algorithms including top-k, switch,
expert choice, and hash-based routing.
- **Load Balancing**: Various strategies to ensure balanced expert utilization
including standard, switch transformer, and expert choice methods.
- **Distributed Support**: Full support for expert parallelism (EP), tensor
parallelism (TP), and data parallelism (DP) with optimized all-to-all
communication patterns.
The module is designed for efficient execution on TPUs and GPUs with optimizations
for:
- Custom VJP for gradient-efficient sorting operations
- Pallas-based grouped matrix multiplication kernels for TPUs
- Ragged tensor operations for variable-length expert assignments
- Automatic sharding and partitioning for distributed training
Example:
>>> from easydel.layers.moe import BaseMoeModule, ParallelMoELinear
>>> # Create a custom MoE layer by extending BaseMoeModule
>>> class CustomMoE(BaseMoeModule):
... def __init__(self, config):
... super().__init__(config)
... # Initialize gate and expert layers
... def __call__(self, hidden_states):
... # Implement forward pass using _moe_call_standard helper
... return self._moe_call_standard(...)
"""
from __future__ import annotations
import typing
import typing as tp
from abc import ABC, abstractmethod
from collections.abc import Callable
from functools import partial
import jax
import jax.extend
from eformer import common_types
from eformer.loggings import get_logger
from ejkernel.modules import GroupedMatmulConfig, grouped_matmul
from flax import nnx as nn
from jax import numpy as jnp
from jax import shard_map
from jax.ad_checkpoint import checkpoint_name
from jax.sharding import PartitionSpec
from jaxtyping import Array, Bool, Float, Int
from easydel.infra.base_module import EasyDeLBaseConfig
from easydel.utils.helpers import check_bool_flag
from .utils import (
MoeFusedHooks,
MoeLoadBalancingStrategy,
MoEMethods,
MoeMetrics,
MoeRoutingStrategy,
get_all_to_all_params,
get_experts_location,
get_moe_partition_spec,
local_permute,
permute,
resolve_eformer_axis,
sort_activations,
unpermute,
)
if typing.TYPE_CHECKING:
from easydel.infra.base_config import EasyDeLBaseConfig
# Module-level logger, named after this module.
logger = get_logger(__name__)
# Local aliases for constants exposed by `eformer.common_types`, so the rest of
# the module can refer to them without the `common_types.` prefix.
# NOTE(review): BATCH, EMBED and EXPERT are not referenced in the part of the
# file visible here — confirm they are still needed before removing them.
BATCH = common_types.BATCH
# EMPTY is used as the per-dimension axis entry in `partition_manager.resolve(...)` calls.
EMPTY = common_types.EMPTY
EMBED = common_types.EMBED
EXPERT = common_types.EXPERT
# MODE_TRAIN is passed as `mode=` to `partition_manager.resolve(...)`.
MODE_TRAIN = common_types.MODE_TRAIN
# Logical parallelism axes; they are handed to `resolve_eformer_axis` to obtain
# the corresponding mesh axis names.
EP = common_types.EP
DP = common_types.DP
FSDP = common_types.FSDP
TP = common_types.TP
SP = common_types.SP
[docs]class BaseMoeModule(nn.Module, ABC):
"""An abstract base class for Mixture of Experts (MoE) modules.
This class provides a foundational structure and common utilities for
implementing various MoE architectures. It includes methods for token routing,
data permutation for efficient expert computation, load balancing loss
calculation, and sharding for distributed environments. Subclasses are
expected to implement the `__call__` method to define the specific MoE forward
pass.
Attributes:
config: The configuration object for the MoE module.
mesh: The JAX device mesh for distributed computation.
n_routed_experts: The total number of experts available for routing.
num_experts_per_tok: The number of experts each token is routed to (k).
hidden_size: The dimension of the hidden states.
lbl_coef: The coefficient for the load balancing loss.
rzl_coef: The coefficient for the router z-loss.
routing_strategy: The strategy used for routing tokens to experts.
load_balancing_strategy: The strategy used for calculating the load
balancing loss.
"""
def __init__(
    self,
    config: EasyDeLBaseConfig,
    n_routed_experts: int | None = None,
    num_experts_per_tok: int | None = None,
    hidden_size: int | None = None,
    lbl_coef: float | None = None,
    rzl_coef: float | None = None,
    routing_strategy: MoeRoutingStrategy = MoeRoutingStrategy.TOP_K,
    load_balancing_strategy: MoeLoadBalancingStrategy = MoeLoadBalancingStrategy.STANDARD,
    moe_hooks: MoeFusedHooks | None = None,
):
    """Set up the state shared by every MoE implementation.

    Args:
        config: Configuration object. It supplies the mesh, the partition
            manager, the MoE method and the expert-mesh variants, and acts as
            the fallback source for the three size arguments below.
        n_routed_experts: Total number of routed experts. A falsy value
            (``None`` or ``0``) falls back to ``config.n_routed_experts``.
        num_experts_per_tok: Number of experts selected per token. A falsy
            value falls back to ``config.num_experts_per_tok``.
        hidden_size: Hidden dimension of inputs/outputs. A falsy value falls
            back to ``config.hidden_size``.
        lbl_coef: Coefficient of the load balancing loss; ``None`` disables it.
        rzl_coef: Coefficient of the router z-loss; ``None`` disables it.
        routing_strategy: Strategy used to route tokens to experts.
        load_balancing_strategy: Strategy used for the load balancing loss.
        moe_hooks: Hook container for custom interventions. When ``None`` a
            default-constructed ``MoeFusedHooks`` is used.
    """
    super().__init__()

    # Distributed-execution handles come straight from the config.
    self.config = config
    self.mesh = config.mesh
    self.partition_manager = config.partition_manager

    # Explicit arguments win; the config is only consulted when the argument
    # is falsy (same short-circuit semantics as `arg or config.attr`).
    self.n_routed_experts = n_routed_experts if n_routed_experts else config.n_routed_experts
    self.num_experts_per_tok = num_experts_per_tok if num_experts_per_tok else config.num_experts_per_tok
    self.hidden_size = hidden_size if hidden_size else config.hidden_size

    # Auxiliary-loss coefficients and routing behaviour.
    self.lbl_coef = lbl_coef
    self.rzl_coef = rzl_coef
    self.routing_strategy = routing_strategy
    self.load_balancing_strategy = load_balancing_strategy

    if moe_hooks is None:
        moe_hooks = MoeFusedHooks()
    self.moe_hooks = moe_hooks

    stored_config = self.config
    self.module_moe_method = stored_config.moe_method
    self.expert_mesh = stored_config.expert_mesh
    self.auto_expert_mesh = stored_config.auto_expert_mesh
    self.expert_abstract_mesh = stored_config.expert_abstract_mesh

    # Keep a dtype that is already present on the instance/class; otherwise
    # default to bfloat16.
    self.dtype = getattr(self, "dtype", jnp.bfloat16)
def get_moe_spec(
    self,
    direction: tp.Literal["row", "column"],
    tensors_are_expert: bool,
    is_bias: bool = False,
) -> PartitionSpec:
    """Build the partition spec for an MoE expert weight or bias tensor.

    This is a thin wrapper that forwards to ``get_moe_partition_spec`` with
    this module's partition manager and the ``fsdp_is_ep_bound`` /
    ``sp_is_ep_bound`` flags read from the config; ``module_view`` is always
    passed as ``False``.

    Args:
        direction: Weight orientation. ``"column"`` is used for the wi/wu
            kernels and ``"row"`` for the wd kernel.
        tensors_are_expert: Selects expert-tensor mode when True, standard
            mode when False.
        is_bias: When True the spec is generated for a bias tensor instead of
            a kernel.

    Returns:
        The ``PartitionSpec`` produced by ``get_moe_partition_spec``.

    Examples:
        Standard mode (tensors_are_expert=False):
            - Column weight: [expert, None, tp]   # wi/wu: [E, H, M]
            - Row weight:    [expert, tp, None]   # wd: [E, M, H]
            - Bias:          [expert, None]       # [E, dim]
        Expert tensor mode (tensors_are_expert=True):
            - Column weight: [tp, None, None]     # Experts on TP
            - Row weight:    [tp, None, None]     # Experts on TP
            - Bias:          [tp, None]           # [E, dim]
    """
    cfg = self.config
    spec_kwargs = dict(
        partition_manager=self.partition_manager,
        direction=direction,
        tensors_are_expert=tensors_are_expert,
        is_bias=is_bias,
        fsdp_is_ep_bound=cfg.fsdp_is_ep_bound,
        sp_is_ep_bound=cfg.sp_is_ep_bound,
        module_view=False,
    )
    return get_moe_partition_spec(**spec_kwargs)
def _get_sharding_status(self):
    """Resolve the parallelism axis names of this layer and look up their sizes.

    Logical axes are translated to mesh axis names through
    ``resolve_eformer_axis`` and the partition manager; sizes are then read
    from ``self.mesh.shape``.

    When ``config.use_expert_tensor_mode`` is set, the roles of the EP and TP
    logical axes are swapped: the "expert" axis name is resolved from TP and
    the "tensor" axis name from EP. Otherwise expert -> EP and tensor -> TP.

    Returns:
        A 10-tuple, names first and sizes second, both in the same order:
            1. data_axis_name
            2. fsdp_axis_name
            3. expert_axis_name
            4. tensor_axis_name
            5. sp_axis_name
            6. dp_size
            7. fsdp_size
            8. ep_size
            9. tp_size
            10. sp_size

    Note:
        A size falls back to 1 when the resolved axis name is not a key of
        the mesh shape.
    """
    pm = self.partition_manager

    # In expert-tensor mode the expert/tensor roles trade logical axes.
    if self.config.use_expert_tensor_mode:
        expert_logical, tensor_logical = TP, EP
    else:
        expert_logical, tensor_logical = EP, TP

    data_axis_name = resolve_eformer_axis(DP, pm)
    expert_axis_name = resolve_eformer_axis(expert_logical, pm)
    tensor_axis_name = resolve_eformer_axis(tensor_logical, pm)
    fsdp_axis_name = resolve_eformer_axis(FSDP, pm)
    sp_axis_name = resolve_eformer_axis(SP, pm)

    axis_names = (
        data_axis_name,
        fsdp_axis_name,
        expert_axis_name,
        tensor_axis_name,
        sp_axis_name,
    )
    mesh_shape = self.mesh.shape
    # Missing axes count as a parallel degree of 1.
    axis_sizes = tuple(mesh_shape.get(axis, 1) for axis in axis_names)
    return (*axis_names, *axis_sizes)
def _replicate_and_sort_tokens(
    self,
    inputs_flat: jax.Array,
    selected_experts: jax.Array,
    use_custom_sort_vjp: bool = True,
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
    """Duplicate every token once per selected expert and group the copies by expert.

    Steps:
        1. Flatten the expert assignments and argsort them.
        2. Repeat each input row ``k`` times (``k`` = last dimension of
           ``selected_experts``) so rows line up with the flattened assignments.
        3. Reorder the repeated rows with ``sort_activations``.
        4. Count assignments per expert with ``bincount``.
        5. Expand the counts into a per-row expert-id vector.

    Args:
        inputs_flat: Token representations; rows are repeated along axis 0.
            Documented shape: (num_tokens, hidden_dim).
        selected_experts: Expert ids per token. Documented shape: (num_tokens, k).
        use_custom_sort_vjp: Forwarded to ``sort_activations``. Defaults to True.

    Returns:
        A tuple ``(sorted_inputs, sorted_by_expert, group_sizes, sorted_experts)``:
            - sorted_inputs: Repeated tokens reordered by expert.
            - sorted_by_expert: The argsort permutation, length num_tokens * k.
            - group_sizes: Assignments per expert, length ``n_routed_experts``.
            - sorted_experts: Expert id for each row of ``sorted_inputs``.

    Example:
        >>> # 2 tokens, 2 experts per token, 4 total experts
        >>> inputs = jnp.ones((2, 128))
        >>> selected = jnp.array([[0, 2], [1, 3]])
        >>> sorted_inputs, indices, sizes, expert_ids = _replicate_and_sort_tokens(inputs, selected)
        >>> # sizes: [1, 1, 1, 1] - one token per expert
    """
    copies_per_token = selected_experts.shape[-1]
    expert_per_row = selected_experts.reshape(-1)
    sorted_by_expert = jnp.argsort(expert_per_row)

    # Row i of the repeated input corresponds to entry i of `expert_per_row`.
    sorted_inputs = sort_activations(
        jnp.repeat(inputs_flat, copies_per_token, axis=0),
        sorted_by_expert,
        use_custom_sort_vjp,
    )

    group_sizes = jnp.bincount(expert_per_row, length=self.n_routed_experts)
    expert_ids = jnp.arange(self.n_routed_experts)
    # `total_repeat_length` pins the output length so it does not depend on
    # the values in `group_sizes`.
    sorted_experts = jnp.repeat(
        expert_ids,
        repeats=group_sizes,
        total_repeat_length=expert_per_row.shape[0],
    )
    return sorted_inputs, sorted_by_expert, group_sizes, sorted_experts
def _apply_capacity_mask(
    self,
    selected_experts: jax.Array,
    weights: jax.Array,
    capacity_factor: float,
) -> jax.Array:
    """Zero the routing weights of assignments that exceed an expert's capacity.

    Per batch element, the (token, slot) assignments are walked in order and a
    running count is kept for every expert. An assignment keeps its weight only
    while the running count of *the expert it selected* is within capacity;
    later assignments to that expert get weight 0.

    The capacity is:
        capacity = int(max(ceil(S * k / n_experts) * capacity_factor, capacity_factor))

    It is computed with plain Python arithmetic from static shapes, so the
    method can be traced under ``jax.jit`` as long as ``capacity_factor`` is a
    Python number.

    Args:
        selected_experts: Expert ids per token. Shape: (B, S, k).
        weights: Routing weights per token. Shape: (B, S, k).
        capacity_factor: Multiplier applied to the even-split capacity.

    Returns:
        Array of shape (B, S, k): the original weight where the assignment is
        within capacity, 0 otherwise. Ids outside ``[0, n_routed_experts)``
        produce an all-zero one-hot row and are therefore masked to 0 as well.

    Example:
        >>> experts = jnp.array([[[0], [0], [0]]])      # 3 tokens, all to expert 0
        >>> weights = jnp.ones((1, 3, 1))
        >>> # with 2 experts and capacity_factor=1.0 the capacity is 2
        >>> _apply_capacity_mask(experts, weights, capacity_factor=1.0)
        >>> # -> [[[1.], [1.], [0.]]]
    """
    B, S, k = selected_experts.shape
    n_experts = self.n_routed_experts
    slots_per_batch = S * k

    # Integer ceil-division: stays a concrete Python int, unlike jnp.ceil,
    # which yields a traced value inside jit and cannot be passed to int().
    even_split = -(-slots_per_batch // n_experts)
    cap = int(max(even_split * capacity_factor, capacity_factor))

    expert_mask = jax.nn.one_hot(selected_experts, num_classes=n_experts, dtype=jnp.int32)

    # Running per-expert assignment count over the flattened (token, slot) order.
    running = jnp.cumsum(expert_mask.reshape(B, slots_per_batch, n_experts), axis=1)
    running = running.reshape(B, S, k, n_experts)

    # Only the count of the expert this slot selected matters; without the
    # one-hot mask the sum would count every expert that is under capacity.
    within_capacity = expert_mask * (running <= cap)
    keep_for_slot = jnp.sum(within_capacity, axis=-1).astype(weights.dtype)
    return weights * keep_for_slot
def _expert_group_mask(self, gate_logits: jax.Array, n_groups: int, topk_groups: int) -> jax.Array:
    """Build a 0/1 mask that keeps only the experts of each token's best groups.

    Algorithm:
        1. View the expert axis as ``n_groups`` contiguous groups of
           ``n_routed_experts // n_groups`` experts.
        2. Score each group by the sum of its top-2 logits (or its single
           logit when a group holds only one expert).
        3. Pick the ``topk_groups`` highest-scoring groups per token.
        4. Expand the group selection back to one entry per expert.

    Args:
        gate_logits: Router scores. Shape: (batch*seq, n_routed_experts).
        n_groups: Number of groups. Must evenly divide ``n_routed_experts``,
            otherwise the reshape fails.
        topk_groups: Number of groups kept per token.

    Returns:
        float32 array of shape (batch*seq, n_routed_experts) holding 1.0 for
        experts in selected groups and 0.0 elsewhere.

    Example:
        >>> # 8 experts divided into 4 groups of 2 experts each
        >>> logits = jnp.ones((16, 8))
        >>> mask = _expert_group_mask(logits, n_groups=4, topk_groups=2)
        >>> masked_logits = logits * mask  # zero out non-selected groups
    """
    BS = gate_logits.shape[0]
    experts_per_group = self.n_routed_experts // n_groups
    scores_grouped = gate_logits.reshape(BS, n_groups, experts_per_group)

    # top_k rejects k larger than the last dimension, so clamp for the
    # one-expert-per-group case.
    score_k = min(2, experts_per_group)
    top_vals, _ = jax.lax.top_k(scores_grouped, k=score_k)
    group_scores = jnp.sum(top_vals.astype(jnp.float32), axis=-1)

    _, group_idx = jax.lax.top_k(group_scores, k=topk_groups)
    # One-hot each chosen group and merge the topk_groups rows into one mask.
    mask_groups = jax.nn.one_hot(group_idx, num_classes=n_groups, dtype=jnp.float32).sum(axis=-2)
    mask = jnp.broadcast_to(mask_groups[..., None], (BS, n_groups, experts_per_group)).reshape(BS, -1)
    return mask
def _compute_load_balancing_loss(
    self,
    router_probs: jax.Array,
    expert_loads: jax.Array,
    strategy: MoeLoadBalancingStrategy | None = None,
) -> float | None:
    """Return the auxiliary load balancing loss for the chosen strategy.

    Args:
        router_probs: Router probabilities; averaged over axis 0 to obtain a
            per-expert mean probability.
        expert_loads: Per-expert load statistic.
        strategy: Strategy override. A falsy value selects
            ``self.load_balancing_strategy``.

    Returns:
        ``None`` when the strategy is ``NONE`` or ``self.lbl_coef`` is ``None``.
        Otherwise ``self.lbl_coef`` times:
            - STANDARD: ``sum(f * p)`` with
              ``f = expert_loads * n_routed_experts / num_experts_per_tok`` and
              ``p = mean(router_probs, axis=0)``.
            - SWITCH_TRANSFORMER: ``n_routed_experts * sum((expert_loads / num_tokens) * p)``
              where ``num_tokens = router_probs.shape[0]``.
            - EMPTY_CHOICE: ``var(expert_loads)``.
        The value is produced by ``jnp`` operations even though the signature
        is annotated with ``float``.

    Raises:
        ValueError: If the strategy is none of the above.
    """
    strategy = strategy or self.load_balancing_strategy

    loss_disabled = strategy == MoeLoadBalancingStrategy.NONE or self.lbl_coef is None
    if loss_disabled:
        return None

    if strategy == MoeLoadBalancingStrategy.STANDARD:
        scaled_load = expert_loads * self.n_routed_experts / self.num_experts_per_tok
        mean_prob = jnp.mean(router_probs, axis=0)
        return self.lbl_coef * jnp.sum(scaled_load * mean_prob)

    if strategy == MoeLoadBalancingStrategy.SWITCH_TRANSFORMER:
        token_count = router_probs.shape[0]
        load_fraction = expert_loads / token_count
        prob_fraction = jnp.mean(router_probs, axis=0)
        return self.lbl_coef * self.n_routed_experts * jnp.sum(load_fraction * prob_fraction)

    if strategy == MoeLoadBalancingStrategy.EMPTY_CHOICE:
        return self.lbl_coef * jnp.var(expert_loads)

    raise ValueError(f"Unknown load balancing strategy: {strategy}")
def _compute_router_z_loss(self, router_logits: Float[Array, "batch_seq num_experts"]) -> float | None:
    """Return the router z-loss, or ``None`` when ``self.rzl_coef`` is ``None``.

    The loss is ``rzl_coef * mean(logsumexp(router_logits, axis=-1) ** 2)``,
    with the logits cast to float32 before the log-sum-exp.

    Args:
        router_logits: Router logits; the log-sum-exp runs over the last axis.

    Returns:
        The scaled z-loss (a value produced by ``jnp`` operations), or ``None``
        when the coefficient is unset.
    """
    coefficient = self.rzl_coef
    if coefficient is None:
        return None

    logits_f32 = router_logits.astype(jnp.float32)
    log_partition = jax.nn.logsumexp(logits_f32, axis=-1)
    return coefficient * jnp.mean(log_partition**2)
def _compute_metrics(
    self,
    router_logits: jax.Array,
    router_probs: jax.Array,
    selected_experts: jax.Array,
    selected_weights: jax.Array,
    expert_loads: jax.Array,
) -> MoeMetrics:
    """Bundle routing statistics and auxiliary losses into a ``MoeMetrics`` object.

    Args:
        router_logits: Router logits, used for the z-loss.
        router_probs: Router probabilities, used for the load balancing loss
            and the routing entropy.
        selected_experts: Expert ids chosen per token (stored as-is).
        selected_weights: Routing weights of the chosen experts (stored as-is).
        expert_loads: Per-expert load statistic.

    Returns:
        A ``MoeMetrics`` instance carrying the inputs above plus:
            - load_balancing_loss: Result of ``_compute_load_balancing_loss``.
            - router_z_loss: Result of ``_compute_router_z_loss``.
            - expert_utilization: Fraction of experts with ``expert_loads > 0``.
            - routing_entropy: Mean over tokens of
              ``-sum(p * log(p + 1e-8))`` along the last axis.
    """
    metrics = MoeMetrics(
        expert_loads=expert_loads,
        router_probs=router_probs,
        selected_experts=selected_experts,
        selected_weights=selected_weights,
    )

    # Auxiliary losses (either may be None when its coefficient is unset).
    metrics.load_balancing_loss = self._compute_load_balancing_loss(router_probs, expert_loads)
    metrics.router_z_loss = self._compute_router_z_loss(router_logits)

    # Share of experts that received any load.
    experts_in_use = expert_loads > 0
    metrics.expert_utilization = jnp.mean(experts_in_use)

    # The 1e-8 offset keeps log() finite for zero probabilities.
    per_token_entropy = -jnp.sum(router_probs * jnp.log(router_probs + 1e-8), axis=-1)
    metrics.routing_entropy = jnp.mean(per_token_entropy)
    return metrics
def _apply_expert_sharding(self, tensor: Float[Array, ...], tensor_type: str = "weight") -> Float[Array, ...]:
    """Resolve a partition spec for an expert tensor and place the tensor with it.

    Every axis entry handed to the partition manager is ``EMPTY``; only the
    number of entries (1, 2 or 3) and whether the tensor shape is forwarded
    depend on ``tensor_type`` and the tensor's rank / leading dimension.
    NOTE(review): ``EMPTY`` presumably means "not sharded" (replicated) —
    confirm against the partition manager.

    Spec selection:
        - "weight_col" / "weight_row": 3 entries for a rank-3 tensor whose
          leading dim equals ``n_routed_experts``; 2 entries for any rank-2
          tensor; otherwise 1 entry resolved without a shape.
        - "bias": 2 entries for a rank-2 tensor whose leading dim equals
          ``n_routed_experts``; otherwise 1 entry (shape still forwarded).
        - anything else (including the default "weight"): one entry per
          dimension for rank-3 or rank-2 tensors whose leading dim equals
          ``n_routed_experts``; otherwise 1 entry resolved without a shape.

    Args:
        tensor: The tensor to place.
        tensor_type: One of "weight_col", "weight_row", "bias" or "weight".

    Returns:
        The result of ``jax.device_put`` with a ``NamedSharding`` built from
        ``self.mesh`` and the resolved spec.

    Example:
        >>> weight = jnp.ones((n_experts, hidden_dim, intermediate_dim))
        >>> sharded_weight = _apply_expert_sharding(weight, "weight_col")
    """
    rank = tensor.ndim
    # Guard on rank first so shape[0] is never read for 0-d / 1-d tensors.
    expert_leading = rank in (2, 3) and tensor.shape[0] == self.n_routed_experts

    if tensor_type in ("weight_col", "weight_row"):
        if rank == 3 and expert_leading:
            n_axes, forward_shape = 3, True
        elif rank == 2:
            n_axes, forward_shape = 2, True
        else:
            n_axes, forward_shape = 1, False
    elif tensor_type == "bias":
        if rank == 2 and expert_leading:
            n_axes, forward_shape = 2, True
        else:
            n_axes, forward_shape = 1, True
    elif expert_leading:
        n_axes, forward_shape = rank, True
    else:
        n_axes, forward_shape = 1, False

    resolve_kwargs = {"axes": [EMPTY] * n_axes, "mode": MODE_TRAIN}
    if forward_shape:
        resolve_kwargs["shape"] = tensor.shape
    sharding_spec = self.partition_manager.resolve(**resolve_kwargs)

    placement = jax.sharding.NamedSharding(self.mesh, sharding_spec)
    return jax.device_put(tensor, placement)
def _get_gate_layer_sharding(self, weight_shape: tuple) -> PartitionSpec:
    """Resolve the partition spec for the gate/router weight matrix.

    Both axis entries are ``EMPTY`` and the spec is resolved in
    ``MODE_TRAIN`` for the given shape.

    Args:
        weight_shape: Shape of the gate weight, documented as
            (hidden_dim, n_experts).

    Returns:
        The ``PartitionSpec`` returned by the partition manager for
        ``[EMPTY, EMPTY]``.
    """
    gate_axes = [EMPTY, EMPTY]
    return self.partition_manager.resolve(axes=gate_axes, mode=MODE_TRAIN, shape=weight_shape)
def _get_gate_layer_bias_sharding(self, bias_shape: tuple) -> PartitionSpec:
    """Resolve the partition spec for the gate/router bias vector.

    The single axis entry is ``EMPTY`` and the spec is resolved in
    ``MODE_TRAIN`` for the given shape.

    Args:
        bias_shape: Shape of the gate bias, documented as (n_experts,).

    Returns:
        The ``PartitionSpec`` returned by the partition manager for ``[EMPTY]``.
    """
    bias_axes = [EMPTY]
    return self.partition_manager.resolve(axes=bias_axes, mode=MODE_TRAIN, shape=bias_shape)
def _validate_routing_inputs(
    self,
    hidden_states: Float[Array, "batch seq hidden_dim"],
    router_logits: Float[Array, "batch_seq num_experts"],
) -> None:
    """Check that hidden states and router logits have mutually consistent shapes.

    The checks run in this order and the first failure raises:
        1. Last dim of ``hidden_states`` equals ``self.hidden_size``.
        2. Last dim of ``router_logits`` equals ``self.n_routed_experts``.
        3. First dim of ``router_logits`` equals the product of the first two
           dims of ``hidden_states``.

    Args:
        hidden_states: Un-flattened layer input.
        router_logits: Router logits over flattened tokens.

    Raises:
        ValueError: If any of the three checks fails.
    """
    hidden_ok = hidden_states.shape[-1] == self.hidden_size
    if not hidden_ok:
        raise ValueError(
            f"Input hidden dimension {hidden_states.shape[-1]} doesn't "
            f"match config hidden dimension {self.hidden_size}"
        )

    experts_ok = router_logits.shape[-1] == self.n_routed_experts
    if not experts_ok:
        raise ValueError(
            f"Router logits expert dimension {router_logits.shape[-1]} doesn't match "
            f"config expert count {self.n_routed_experts}"
        )

    # Number of tokens once batch and sequence are flattened together.
    flattened_tokens = hidden_states.shape[0] * hidden_states.shape[1]
    if router_logits.shape[0] != flattened_tokens:
        raise ValueError(
            f"Router logits batch dimension {router_logits.shape[0]} doesn't match "
            f"flattened input batch dimension {hidden_states.shape[0] * hidden_states.shape[1]}"
        )
def _apply_capacity_constraint(
    self,
    selected_experts: jax.Array,
    selected_weights: jax.Array,
    capacity_factor: float | None = None,
) -> tuple[jax.Array, jax.Array]:
    """Softly down-weight assignments to experts that are over capacity.

    Each expert's capacity is ``int(capacity_factor * num_tokens / n_routed_experts)``,
    clamped to at least 1. For an expert holding ``c`` assignments the
    over-capacity ratio is ``max(c / capacity, 1)``; every weight pointing at
    that expert is divided by the ratio, and the weights of each token are
    then renormalised to sum to 1 (tokens whose weights sum to a non-positive
    value are left unnormalised).

    Args:
        selected_experts: Expert ids per token. The first dimension is taken
            as the token count; ids are also used to index a per-expert vector.
        selected_weights: Routing weights aligned with ``selected_experts``.
        capacity_factor: Capacity multiplier. ``None`` means 1.0.

    Returns:
        ``(selected_experts, constrained_weights)`` — the expert ids unchanged
        and the adjusted, renormalised weights.
    """
    if capacity_factor is None:
        capacity_factor = 1.0

    num_tokens = selected_experts.shape[0]
    # With few tokens the truncated capacity would be 0 and the division
    # below would produce inf/NaN ratios; one slot per expert is the floor.
    max_capacity = max(1, int(capacity_factor * num_tokens / self.n_routed_experts))

    expert_counts = jnp.bincount(selected_experts.flatten(), length=self.n_routed_experts)
    over_capacity_ratio = jnp.maximum(expert_counts / max_capacity, 1.0)

    # Gather the ratio of the expert each slot points at.
    weight_adjustments = 1.0 / over_capacity_ratio[selected_experts]
    constrained_weights = selected_weights * weight_adjustments

    weight_sum = jnp.sum(constrained_weights, axis=-1, keepdims=True)
    # Dividing by 1 where the sum is not positive leaves those rows untouched
    # and avoids evaluating a division by zero.
    safe_sum = jnp.where(weight_sum > 0, weight_sum, 1.0)
    constrained_weights = constrained_weights / safe_sum
    return selected_experts, constrained_weights
def _create_expert_mask(
    self,
    selected_experts: Int[Array, "batch_seq k"],
    expert_id: int,
) -> Bool[Array, "batch_seq"]:  # type: ignore #noqa
    """Flag the tokens that have ``expert_id`` among their selected experts.

    Args:
        selected_experts: Expert ids per token; the comparison is reduced over
            the last axis. Documented shape: (batch*seq, k).
        expert_id: Expert id to look for.

    Returns:
        Boolean array with the last axis removed; an entry is True when any of
        that token's selected experts equals ``expert_id``.

    Example:
        >>> selected = jnp.array([[0, 2], [1, 3], [0, 1]])  # 3 tokens, 2 experts each
        >>> mask = _create_expert_mask(selected, expert_id=0)
        >>> # mask = [True, False, True] - tokens 0 and 2 use expert 0
    """
    slot_matches = selected_experts == expert_id
    return jnp.any(slot_matches, axis=-1)
def _sparse_moe_call(
self,
hidden_state: jax.Array, # [B, S, H]
gate_layer: nn.Module, # [H, E]
wi_kernel: jax.Array, # [E, H, M]
wu_kernel: jax.Array, # [E, H, M]
wd_kernel: jax.Array, # [E, M, H]
wi_bias: jax.Array | None = None, # [E, H]
wu_bias: jax.Array | None = None, # [E, H]
wd_bias: jax.Array | None = None, # [E, M]
ffn_activation: Callable[[jax.Array, jax.Array], jax.Array] | None = None,
*,
act_fn: Callable[[jax.Array], jax.Array],
):
"""Fused MoE path using grouped matmul and shard_map.
This is the core fused MoE implementation that routes tokens to experts,
permutes them to an expert-grouped layout, applies expert FFNs via grouped
matmul kernels, and unpermutes/combines outputs. It supports both
ring-of-experts and all-to-all expert-parallel communication depending on
configuration and mesh sizes.
**Architecture Overview:**
1. **Routing**: Compute router logits via gate_layer and apply softmax
2. **Permutation**: Sort tokens by expert assignment for grouped computation
3. **Expert Computation**: Apply grouped matmul for W_i, W_u, activation, W_d
4. **Communication** (if EP > 1):
- Ring-of-Experts: All-gather pattern with local expert subsets
- All-to-All: Ragged all-to-all for token redistribution
5. **Unpermutation**: Restore token order and combine expert outputs
6. **Resharding**: Convert from 3D expert mesh back to 5D model mesh
**Hook Integration:**
This method reads hooks from `self.moe_hooks` (automatically configured
by `moe_call()` based on routing strategy). The following hooks
are used at specific points:
- `select_hook`: Refines expert selection weights. For TOP_K routing,
this defaults to weight normalization (softmax). Called during permutation.
- `refine_weights_hook`: Refines weights before W_i and W_u projections.
- `refine_inputs_hook`: Refines token representations before expert-parallel
all-to-all communication in distributed settings.
Other hooks in `MoeFusedHooks` are not used in this path but could be
added in future extensions.
Args:
hidden_state: Input tensor. Shape: [B, S, H].
gate_layer: Router module mapping H -> E (produces logits).
wi_kernel: Expert W_i kernel. Shape: [E, H, M].
wu_kernel: Expert W_u kernel. Shape: [E, H, M].
wd_kernel: Expert W_d kernel. Shape: [E, M, H].
wi_bias: Optional bias for W_i. Shape: [E, H].
wu_bias: Optional bias for W_u. Shape: [E, H].
wd_bias: Optional bias for W_d. Shape: [E, M].
ffn_activation: Optional custom activation combining (w0, w1) -> output.
act_fn: Activation used when `ffn_activation` is not provided.
Returns:
Tuple `(output, router_logits)` where:
- output: MoE layer output. Shape: [B, S, H].
- router_logits: Pre-softmax router logits for auxiliary losses. Shape: [B*S, E].
Example:
>>> # Setup MoE layer with 8 experts, top-2 routing
>>> config.n_routed_experts = 8
>>> config.num_experts_per_tok = 2
>>> config.use_ring_of_experts = False # Use all-to-all
>>>
>>> # Initialize expert kernels
>>> wi_kernel = jax.random.normal(key, (8, 768, 3072)) # gate/up
>>> wu_kernel = jax.random.normal(key, (8, 768, 3072)) # up
>>> wd_kernel = jax.random.normal(key, (8, 3072, 768)) # down
>>>
>>> # Call fused MoE
>>> hidden_states = jnp.ones((2, 512, 768)) # (batch, seq, hidden)
>>> output, logits = moe_layer._sparse_moe_call(
... hidden_state=hidden_states,
... gate_layer=gate,
... wi_kernel=wi_kernel,
... wu_kernel=wu_kernel,
... wd_kernel=wd_kernel,
... act_fn=jax.nn.silu,
... )
>>> # output.shape = (2, 512, 768)
>>> # logits.shape = (1024, 8) # batch*seq, n_experts
"""
select_hook = self.moe_hooks.select_hook if self.moe_hooks else None
refine_weights_hook = self.moe_hooks.refine_weights_hook if self.moe_hooks else None
refine_inputs_hook = self.moe_hooks.refine_inputs_hook if self.moe_hooks else None
hooks = self.moe_hooks
_BS, _SQLN, HD = hidden_state.shape
hidden_state = hidden_state.astype(self.dtype)
if hooks is not None and hooks.before_gate is not None:
hidden_state = hooks.before_gate(hidden_state)
prein_gate_logits = gate_layer(hidden_state.reshape(-1, HD))
if hooks is not None and hooks.after_gate is not None:
prein_gate_logits = hooks.after_gate(prein_gate_logits)
gate_logits = jax.nn.softmax(prein_gate_logits.astype("f4"), axis=-1).astype(prein_gate_logits.dtype)
if hooks is not None and hooks.before_topk is not None:
gate_logits = hooks.before_topk(gate_logits)
# Use expert_mesh (3D: dp, ep, tp) for cleaner sharding
expert_mesh = self.auto_expert_mesh
pm = self.partition_manager
# Resolve axis names from partition_manager (not directly from mesh)
dp_axis_name = resolve_eformer_axis(DP, pm)
expert_axis_name = resolve_eformer_axis(EP, pm)
tensor_axis_name = resolve_eformer_axis(TP, pm)
ep_size = expert_mesh.shape[expert_axis_name]
tp_size = expert_mesh.shape[tensor_axis_name]
if self.config.use_expert_tensor_mode:
assert tp_size == 1, "if using `ExpertTensorMode` Expert Parallel size should be 1."
# Simplified partition specs using 3D expert_mesh
input_ps = jax.sharding.PartitionSpec(dp_axis_name, None, None)
glps = jax.sharding.PartitionSpec(dp_axis_name, None)
if self.config.use_expert_tensor_mode:
output_ps = jax.sharding.PartitionSpec(dp_axis_name, None, None)
else:
output_ps = jax.sharding.PartitionSpec(dp_axis_name, None, tensor_axis_name)
if ffn_activation is None:
def ffn_activation(x0: jax.Array, x1: jax.Array) -> jax.Array:
return act_fn(x0) * x1
# Generate weight sharding specs using helper function
use_expert_tensor = self.config.use_expert_tensor_mode
wikps = self.get_moe_spec("column", use_expert_tensor, is_bias=False)
wukps = self.get_moe_spec("column", use_expert_tensor, is_bias=False)
wdkps = self.get_moe_spec("row", use_expert_tensor, is_bias=False)
wibps = self.get_moe_spec("column", use_expert_tensor, is_bias=True) if wi_bias is not None else None
wubps = self.get_moe_spec("column", use_expert_tensor, is_bias=True) if wu_bias is not None else None
wdbps = self.get_moe_spec("row", use_expert_tensor, is_bias=True) if wd_bias is not None else None
gmm_kws = {"preferred_element_type": jnp.bfloat16}
if self.config.moe_force_xla_gmm:
gmm_kws.update(dict(cfg=GroupedMatmulConfig(platform="xla", bypass_xla_tiling=True)))
else:
if jax.default_backend() == "tpu":
gmm_kws.update(platform="pallas")
if check_bool_flag("DISABLE_MOE_AUTOTUNE_ON_TPU", False):
gmm_kws.update(
cfg=GroupedMatmulConfig(
platform="pallas",
bypass_xla_tiling=True,
block_m=1024,
block_n=1024,
block_k=512,
)
)
@partial(
shard_map,
mesh=expert_mesh, # Use 3D expert_mesh instead of 5D mesh
in_specs=(input_ps, glps, wikps, wukps, wdkps, wibps, wubps, wdbps),
out_specs=output_ps,
check_vma=False,
)
def _sparse_call(
x: jax.Array,
gate_logits: jax.Array,
wi_kernel: jax.Array,
wu_kernel: jax.Array,
wd_kernel: jax.Array,
wi_bias: jax.Array | None,
wu_bias: jax.Array | None,
wd_bias: jax.Array | None,
):
batch_size, sequence_length, _ = x.shape
expert_shard_id = jax.lax.axis_index(expert_axis_name)
if self.config.use_ring_of_experts:
x, gate_logits = tuple(
jax.lax.all_gather(
z,
axis_name=expert_axis_name,
tiled=True,
)
for z in (x, gate_logits)
)
# "Route" tokens within each shard.
experts_per_shard = self.n_routed_experts // ep_size
x, sorted_selected_experts, weights, group_sizes, selected_experts = permute(
inputs=x,
gate_logits=gate_logits,
pre_bias_logits=None,
use_custom_sort_vjp=True,
roll_to_expert_id=experts_per_shard * expert_shard_id,
num_experts_per_tok=self.num_experts_per_tok,
num_experts=self.n_routed_experts,
dtype=self.dtype,
select_hook=select_hook,
refine_weights_hook=refine_weights_hook,
refine_inputs_hook=refine_inputs_hook,
)
group_sizes = group_sizes[:experts_per_shard] # only the local experts
# Optimize: use dynamic slice instead of masking to avoid wasted computation
valid_token_count = jnp.sum(group_sizes)
x = jax.lax.dynamic_slice_in_dim(x, 0, valid_token_count, axis=0)
else:
x, sorted_selected_experts, weights, group_sizes, selected_experts = permute(
inputs=x,
gate_logits=gate_logits,
pre_bias_logits=None,
use_custom_sort_vjp=True,
roll_to_expert_id=None,
num_experts_per_tok=self.num_experts_per_tok,
num_experts=self.n_routed_experts,
dtype=self.dtype,
select_hook=select_hook,
refine_weights_hook=refine_weights_hook,
refine_inputs_hook=refine_inputs_hook,
)
if ep_size > 1:
local_expert_size = self.n_routed_experts // ep_size
reshaped_group_sizes = jnp.sum(group_sizes.reshape(-1, local_expert_size), axis=1)
global_group_sizes = group_sizes
x, _local_sorted_indices, group_sizes, selected_experts = local_permute(
x,
global_group_sizes[None, :],
local_expert_size,
shard_index=expert_shard_id,
is_offset=True,
global_sorted_experts=selected_experts,
use_custom_sort_vjp=True,
)
layer_w0 = grouped_matmul(x, wi_kernel, group_sizes, **gmm_kws)
layer_w0 = checkpoint_name(layer_w0, "mlp_gate")
if wi_bias is not None:
layer_w0 = layer_w0 + wi_bias[selected_experts]
layer_w1 = grouped_matmul(x, wu_kernel, group_sizes, **gmm_kws)
layer_w1 = checkpoint_name(layer_w1, "mlp_up")
if wu_bias is not None:
layer_w1 = layer_w1 + wu_bias[selected_experts]
intermediate_layer = ffn_activation(layer_w0, layer_w1)
intermediate_output = grouped_matmul(intermediate_layer, wd_kernel, group_sizes, **gmm_kws)
intermediate_output = checkpoint_name(intermediate_output, "mlp_down")
# TP reduction: psum_scatter to shard output across TP on hidden dimension
# This matches output_ps = [DP, EMPTY, TP]
if tp_size > 1:
intermediate_output = jax.lax.psum_scatter(
intermediate_output,
tensor_axis_name,
scatter_dimension=1,
tiled=True,
)
if wd_bias is not None:
intermediate_output = intermediate_output + wd_bias[selected_experts]
if self.config.use_ring_of_experts:
# No need to mask - intermediate_output was already sliced to valid size
# If needed for unpermute shape matching, pad back to expected size
expected_size = sorted_selected_experts.shape[0]
current_size = intermediate_output.shape[0]
if current_size < expected_size:
padding = jnp.zeros(
(expected_size - current_size, intermediate_output.shape[1]), dtype=intermediate_output.dtype
)
intermediate_output = jnp.concatenate([intermediate_output, padding], axis=0)
output = unpermute(
intermediate_output,
sorted_selected_experts,
weights,
batch_size=batch_size,
sequence_length=sequence_length,
use_custom_sort_vjp=self.config.use_custom_sort_vjp,
num_experts_per_tok=self.num_experts_per_tok,
dtype=self.dtype,
)
output = jnp.reshape(output, (-1, sequence_length, HD))
output = jax.lax.psum_scatter(output, expert_axis_name, scatter_dimension=0, tiled=True)
else:
if ep_size > 1:
original_inputs_first_dim = batch_size * sequence_length * self.num_experts_per_tok
if sorted_selected_experts.shape[0] != original_inputs_first_dim:
raise ValueError("original_inputs_first_dim does not match the original tensor shape!")
output_shape = jnp.zeros((original_inputs_first_dim, HD // tp_size), dtype=intermediate_output.dtype)
input_offsets, send_sizes, output_offsets, recv_sizes = get_all_to_all_params(
reshaped_group_sizes,
expert_shard_id,
ep_size,
is_batch_sharded=False,
)
intermediate_output = jax.lax.ragged_all_to_all(
intermediate_output,
output_shape,
input_offsets,
send_sizes,
output_offsets,
recv_sizes,
axis_name=expert_axis_name,
)
output = unpermute(
intermediate_output,
sorted_selected_experts,
weights,
batch_size=batch_size,
sequence_length=sequence_length,
use_custom_sort_vjp=True,
num_experts_per_tok=self.num_experts_per_tok,
dtype=self.dtype,
)
return output
# print(
# wi_kernel.shape,
# wikps,
# wi_bias.shape,
# wibps,
# wu_kernel.shape,
# wukps,
# wu_bias.shape,
# wubps,
# wd_kernel.shape,
# wdkps,
# wd_bias.shape,
# wdbps,
# )
output = _sparse_call(
hidden_state,
gate_logits,
wi_kernel,
wu_kernel,
wd_kernel,
wi_bias,
wu_bias,
wd_bias,
)
# Reshard output back to original 5D mesh for compatibility with rest of model
# This ensures the output can be used in residual connections
original_output_ps = self.partition_manager.resolve(
axes=[DP, EMPTY, TP] if not self.config.use_expert_tensor_mode else [DP, EMPTY, EMPTY],
mode=MODE_TRAIN,
shape=output.shape,
)
output = jax.lax.with_sharding_constraint(output, jax.sharding.NamedSharding(self.mesh, original_output_ps))
return output, prein_gate_logits
[docs] def moe_call(
self,
hidden_state: jax.Array, # [B, S, H]
gate_layer: nn.Module,
expert_layer: nn.Module,
wi_kernel: jax.Array, # [E, H, M]
wu_kernel: jax.Array, # [E, H, M]
wd_kernel: jax.Array, # [E, M, H]
wi_bias: jax.Array | None = None, # [E, H]
wu_bias: jax.Array | None = None, # [E, H]
wd_bias: jax.Array | None = None, # [E, M]
ffn_activation: Callable[[jax.Array, jax.Array], jax.Array] | None = None,
reform_router_probs_fn: typing.Callable[[jax.Array], jax.Array] | None = None,
*,
act_fn: Callable[[jax.Array], jax.Array],
output_metrics: bool = False,
layer_idx: int | None = None,
):
"""Wrapper for fused MoE call with automatic hook configuration.
This method dispatches to either standard or fused MoE based on config, and
automatically configures hooks based on the routing strategy to ensure correct
expert weight handling.
**Hook Auto-Configuration:**
Before calling the fused MoE path, this method automatically configures
default hooks for the routing strategy if they're not already set by the user:
- **TOP_K**: Normalizes weights by their sum (softmax-like distribution).
- **TOP_K_NDIV**: Passes weights through unchanged (raw logit values).
- **SWITCH**: Enforces hard assignment with weight = 1.0.
- **EMPTY_CHOICE**: Uniform weights across expert selections.
- **HASH**: Uniform weights for deterministic assignments.
Each strategy gets an appropriate default `select_hook` that ensures
correct weight handling without requiring manual setup. Users can override
defaults by setting custom hooks on `self.moe_hooks` before calling the layer.
Args:
hidden_state: Input tensor. Shape: [B, S, H].
gate_layer: Router/gate module mapping H -> E (produces logits).
expert_layer: Expert layer module.
wi_kernel: Expert W_i (down/first) kernel. Shape: [E, H, M].
wu_kernel: Expert W_u (up/second) kernel. Shape: [E, H, M].
wd_kernel: Expert W_d (output/down) kernel. Shape: [E, M, H].
wi_bias: Optional bias for W_i. Shape: [E, H].
wu_bias: Optional bias for W_u. Shape: [E, H].
wd_bias: Optional bias for W_d. Shape: [E, M].
ffn_activation: Optional custom activation combining (w0, w1) -> output.
reform_router_probs_fn: Optional function to modify router probabilities
(used in standard MoE mode only).
act_fn: Activation function used when `ffn_activation` is not provided.
output_metrics: Whether to return metrics in standard MoE mode.
Returns:
Tuple of (output, logits) where:
- output: MoE layer output. Shape: [B, S, H].
- logits: Router logits for auxiliary loss computation. Shape: [B*S, E].
"""
self._configure_hooks_for_routing_strategy()
with self.auto_expert_mesh:
match self.module_moe_method:
case MoEMethods.STANDARD_MOE:
logger.warn_once(
"You are using MoEMethods.STANDARD_MOE which is not really recommended please switch to FUSED_MOE"
)
return self._moe_call_standard(
gate_layer=gate_layer,
expert_layer=expert_layer,
hidden_state=hidden_state,
output_metrics=output_metrics,
validate_inputs=False,
apply_capacity_constraint=False,
reform_router_probs_fn=reform_router_probs_fn,
layer_idx=layer_idx,
)
case MoEMethods.FUSED_MOE:
return self._sparse_moe_call(
hidden_state=hidden_state,
gate_layer=gate_layer,
wi_kernel=wi_kernel,
wu_kernel=wu_kernel,
wd_kernel=wd_kernel,
wi_bias=wi_bias,
wu_bias=wu_bias,
wd_bias=wd_bias,
act_fn=act_fn,
ffn_activation=ffn_activation,
)
case MoEMethods.DENSE_MOE:
logger.warn_once(
"You are using MoEMethods.DENSE_MOE which is not really recommended please switch to FUSED_MOE"
)
return self._moe_call_dense(
hidden_state=hidden_state,
gate_layer=gate_layer,
wi_kernel=wi_kernel,
wu_kernel=wu_kernel,
wd_kernel=wd_kernel,
wi_bias=wi_bias,
wu_bias=wu_bias,
wd_bias=wd_bias,
act_fn=act_fn,
ffn_activation=ffn_activation,
)
case _:
raise NotImplementedError()
def _moe_call_dense(
    self,
    hidden_state: jax.Array,  # [B, S, H]
    gate_layer: nn.Module,
    wi_kernel: jax.Array,  # [E, H, M]
    wu_kernel: jax.Array,  # [E, H, M]
    wd_kernel: jax.Array,  # [E, M, H]
    wi_bias: jax.Array | None = None,  # [E, M]
    wu_bias: jax.Array | None = None,  # [E, M]
    wd_bias: jax.Array | None = None,  # [E, H]
    ffn_activation: Callable[[jax.Array, jax.Array], jax.Array] | None = None,
    *,
    act_fn: Callable[[jax.Array], jax.Array],
    capacity_factor: float | None = None,
    output_metrics: bool = False,
):
    """Dense MoE path using per-token batched matmuls instead of ragged grouping.

    Tokens are flattened to ``[B*S, H]``, routed through ``gate_layer``, and the
    kernels of each token's selected experts are gathered with ``jnp.take`` so the
    gate/up/down projections can be evaluated with plain einsums. The per-expert
    outputs are then combined with the routing weights.

    Hooks on ``self.moe_hooks`` that are applied when set: ``before_gate``,
    ``after_gate``, ``before_topk``, ``select_hook`` / ``refine_weights_hook``
    (forwarded to ``get_experts_location``), ``before_wo``, ``after_wo``,
    ``before_combine`` and ``finalize_output``.

    Args:
        hidden_state: Input tensor. Shape: [B, S, H].
        gate_layer: Router module applied to the flattened tokens.
        wi_kernel: Gate-projection kernel. Shape: [E, H, M].
        wu_kernel: Up-projection kernel. Shape: [E, H, M].
        wd_kernel: Down-projection kernel. Shape: [E, M, H].
        wi_bias: Optional bias for the gate projection. Shape: [E, M].
        wu_bias: Optional bias for the up projection. Shape: [E, M].
        wd_bias: Optional bias for the down projection. Shape: [E, H].
        ffn_activation: Optional function combining the gate and up projections;
            defaults to ``act_fn(x0) * x1``.
        act_fn: Activation used by the default ``ffn_activation``.
        capacity_factor: When positive, routing weights are masked with
            ``self._apply_capacity_mask`` and re-normalized per token.
        output_metrics: When True, return metrics from ``self._compute_metrics``
            instead of the raw gate logits.

    Returns:
        ``(output, metrics)`` if ``output_metrics`` is True, otherwise
        ``(output, prein_gate_logits)`` where ``prein_gate_logits`` is the gate
        output before the ``after_gate`` hook. ``output`` has shape [B, S, H].
    """
    self._configure_hooks_for_routing_strategy()
    hooks = self.moe_hooks
    hidden_state = hidden_state.astype(self.dtype)
    if hooks.before_gate is not None:
        hidden_state = hooks.before_gate(hidden_state)
    batch_size, seq_len, hidden_dim = hidden_state.shape
    tokens = batch_size * seq_len
    hidden_flat = hidden_state.reshape(tokens, hidden_dim)

    # Keep the untouched gate output around: it is what callers get back for aux losses.
    prein_gate_logits = gate_layer(hidden_flat)
    gate_logits = prein_gate_logits
    if hooks.after_gate is not None:
        gate_logits = hooks.after_gate(gate_logits)
    # Softmax in float32 for numerical stability, then return to the module dtype.
    router_probs = jax.nn.softmax(gate_logits.astype(jnp.float32), axis=-1).astype(self.dtype)
    if hooks.before_topk is not None:
        router_probs = hooks.before_topk(router_probs)

    selected_weights, selected_experts = get_experts_location(
        gate_logits=router_probs,
        pre_bias_logits=None,
        select_hook=hooks.select_hook,
        refine_weights_hook=hooks.refine_weights_hook,
        num_experts_per_tok=self.num_experts_per_tok,
    )
    weights = selected_weights.astype(self.dtype)
    experts = selected_experts.astype(jnp.int32)

    if capacity_factor is not None and capacity_factor > 0:
        weights_shaped = weights.reshape(batch_size, seq_len, self.num_experts_per_tok)
        experts_shaped = experts.reshape(batch_size, seq_len, self.num_experts_per_tok)
        weights_shaped = self._apply_capacity_mask(experts_shaped, weights_shaped, capacity_factor)
        weights = weights_shaped.reshape(tokens, self.num_experts_per_tok)
        weight_sum = jnp.sum(weights, axis=-1, keepdims=True)
        # Re-normalize the surviving weights. Rows whose sum is not positive are left
        # unchanged by dividing by 1 instead of by the sum: dividing by zero inside a
        # `jnp.where` branch would still poison the gradients with NaNs even though
        # the forward value of that branch is discarded.
        safe_weight_sum = jnp.where(weight_sum > 0, weight_sum, jnp.ones_like(weight_sum))
        weights = weights / safe_weight_sum

    if ffn_activation is None:

        def ffn_activation(x0: jax.Array, x1: jax.Array) -> jax.Array:
            return act_fn(x0) * x1

    precision = getattr(self, "precision", None)
    # [T, 1, H]; the singleton axis broadcasts against the per-token expert axis.
    hidden_expanded = hidden_flat[:, None, :]

    # Gate projection with each token's selected expert kernels.
    wi_sel = jnp.take(wi_kernel.astype(self.dtype), experts, axis=0)
    w0 = jnp.einsum("tkh,tkhm->tkm", hidden_expanded, wi_sel, precision=precision)
    if wi_bias is not None:
        w0 = w0 + jnp.take(wi_bias.astype(self.dtype), experts, axis=0)

    # Up projection.
    wu_sel = jnp.take(wu_kernel.astype(self.dtype), experts, axis=0)
    w1 = jnp.einsum("tkh,tkhm->tkm", hidden_expanded, wu_sel, precision=precision)
    if wu_bias is not None:
        w1 = w1 + jnp.take(wu_bias.astype(self.dtype), experts, axis=0)

    intermediate = ffn_activation(w0, w1)
    if hooks.before_wo is not None:
        intermediate = hooks.before_wo(intermediate)

    # Down projection back to the hidden size.
    wd_sel = jnp.take(wd_kernel.astype(self.dtype), experts, axis=0)
    outputs = jnp.einsum("tkm,tkmh->tkh", intermediate, wd_sel, precision=precision)
    if wd_bias is not None:
        outputs = outputs + jnp.take(wd_bias.astype(self.dtype), experts, axis=0)
    if hooks.after_wo is not None:
        outputs = hooks.after_wo(outputs)

    if hooks.before_combine is not None:
        outputs, weights = hooks.before_combine(outputs, weights)
    # Weighted sum over the selected-expert axis.
    combined = jnp.sum(outputs * weights[..., None], axis=1)
    output = combined.reshape(batch_size, seq_len, hidden_dim)
    if hooks.finalize_output is not None:
        output = hooks.finalize_output(output)

    if output_metrics:
        # Count only assignments that still carry a positive weight.
        expert_mask = (weights > 0).astype(self.dtype)
        expert_loads = jnp.bincount(
            experts.reshape(-1),
            weights=expert_mask.reshape(-1),
            length=self.n_routed_experts,
        ).astype(self.dtype)
        metrics = self._compute_metrics(
            router_logits=prein_gate_logits,
            router_probs=router_probs,
            selected_experts=experts,
            selected_weights=weights,
            expert_loads=expert_loads,
        )
        return output, metrics
    return output, prein_gate_logits
def _configure_hooks_for_routing_strategy(self) -> None:
    """Install a default ``refine_weights_hook`` that matches ``self.routing_strategy``.

    Nothing is changed when ``self.moe_hooks`` already carries a
    ``refine_weights_hook``, when the routing strategy is not one of those listed
    below, or when the user configured a ``select_hook``. Otherwise
    ``self.moe_hooks`` is replaced by a copy with the default hook set.

    **Default hook by strategy:**
    TOP_K: Divide the selected weights by their sum (floored at 1e-8).
    TOP_K_NDIV: Return the selected weights unchanged.
    SWITCH: Replace every selected weight with 1.0.
    EMPTY_CHOICE: Uniform weights, 1 / size of the last axis.
    HASH: Uniform weights, 1 / size of the last axis.
    """
    hooks = self.moe_hooks
    # An explicitly configured refinement hook always takes precedence.
    if hooks.refine_weights_hook is not None:
        return

    def normalize_selected_weights(weights: jax.Array) -> jax.Array:
        """Rescale the selected weights of each token by their sum."""
        return weights / jnp.maximum(weights.sum(axis=-1, keepdims=True), 1e-8)

    def passthrough_weights(weights: jax.Array) -> jax.Array:
        """Leave the selected weights exactly as they were produced."""
        return weights

    def hard_assignment_weights(weights: jax.Array) -> jax.Array:
        """Give every selected expert a weight of exactly 1.0."""
        return jnp.ones_like(weights)

    def expert_choice_weights(weights: jax.Array) -> jax.Array:
        """Spread the contribution evenly over the last axis."""
        selection_width = weights.shape[-1]
        return jnp.ones_like(weights) / jnp.maximum(selection_width, 1)

    def uniform_weights(weights: jax.Array) -> jax.Array:
        """Assign the same weight (1/k) to each of the k assigned experts."""
        assigned_per_token = weights.shape[-1]
        return jnp.ones_like(weights) / jnp.maximum(assigned_per_token, 1)

    strategy = self.routing_strategy
    if strategy == MoeRoutingStrategy.TOP_K:
        default_hook = normalize_selected_weights
    elif strategy == MoeRoutingStrategy.TOP_K_NDIV:
        default_hook = passthrough_weights
    elif strategy == MoeRoutingStrategy.SWITCH:
        default_hook = hard_assignment_weights
    elif strategy == MoeRoutingStrategy.EMPTY_CHOICE:
        default_hook = expert_choice_weights
    elif strategy == MoeRoutingStrategy.HASH:
        default_hook = uniform_weights
    else:
        # Unknown strategy: there is no default to install.
        return

    # A user-provided select_hook suppresses the default refinement.
    if hooks.select_hook is not None:
        return
    self.moe_hooks = hooks.replace(refine_weights_hook=default_hook)
def _moe_call_standard(
    self,
    gate_layer: nn.Module,
    expert_layer: nn.Module,
    hidden_state: jax.Array,
    output_metrics: bool = False,
    validate_inputs: bool = False,
    apply_capacity_constraint: bool = False,
    reform_router_probs_fn: typing.Callable[[jax.Array], jax.Array] | None = None,
    layer_idx: int | None = None,
) -> tuple[jax.Array, MoeMetrics | jax.Array]:
    """Standard MoE forward pass: route, sort tokens by expert, run experts, combine.

    Expert selection is performed on the raw router logits; the router
    probabilities computed here are only handed to the metrics computation.

    **Hooks applied when set on ``self.moe_hooks``:**
    - `before_gate`: Transforms the hidden states before the router runs.
    - `after_gate`: Produces the router probabilities from the logits
      (a softmax is used when it is absent).
    - `before_topk`: Transforms the router probabilities.
    - `select_hook` / `refine_weights_hook`: Forwarded to the expert selection.
    - `refine_inputs_hook`: Transforms the flattened tokens before expert computation.
    - `before_combine`: Transforms expert outputs and weights before the weighted sum.
    - `finalize_output`: Transforms the final output.

    Args:
        gate_layer: Router module mapping hidden states to expert logits.
        expert_layer: Expert layer called with the sorted tokens, the group
            sizes and the sorted expert ids.
        hidden_state: Input tensor. Shape: [B, S, H].
        output_metrics: Whether to return detailed MoE metrics.
        validate_inputs: Whether to run ``self._validate_routing_inputs``.
        apply_capacity_constraint: Whether to run ``self._apply_capacity_constraint``.
        reform_router_probs_fn: Optional function to modify router probabilities.
        layer_idx: When not None, router and expert diagnostics for the first
            token are emitted with `jax.debug.print`, tagged with this index.

    Returns:
        Tuple of (output, metrics_or_logits) where:
        - output: MoE layer output. Shape: [B, S, H].
        - metrics_or_logits: MoeMetrics if output_metrics=True, else router_logits.
    """
    self._configure_hooks_for_routing_strategy()
    hooks = self.moe_hooks
    experts_per_token = self.num_experts_per_tok
    emit_debug = layer_idx is not None

    states = hidden_state.astype(self.dtype)
    if hooks.before_gate is not None:
        states = hooks.before_gate(states)
    batch_size, seq_len, hidden_size = states.shape
    flat_tokens = states.reshape(-1, hidden_size)

    # Router logits are kept in at least float32.
    router_logits = gate_layer(flat_tokens).astype(jnp.promote_types(self.dtype, jnp.float32))

    # Probabilities feed the metrics only; selection below works on the logits.
    if hooks.after_gate is None:
        router_probs = jax.nn.softmax(router_logits, axis=-1)
    else:
        router_probs = hooks.after_gate(router_logits)
    if reform_router_probs_fn is not None:
        router_probs = reform_router_probs_fn(router_probs)
    if hooks.before_topk is not None:
        router_probs = hooks.before_topk(router_probs)

    if validate_inputs:
        self._validate_routing_inputs(states, router_logits)

    # Select experts from the unmodified logits.
    selected_weights, selected_experts = get_experts_location(
        gate_logits=router_logits,
        pre_bias_logits=None,
        select_hook=hooks.select_hook,
        refine_weights_hook=hooks.refine_weights_hook,
        num_experts_per_tok=experts_per_token,
    )

    if emit_debug:
        # Top-k logits are recomputed here purely for the diagnostic printout.
        debug_top_logits, _ = jax.lax.top_k(router_logits, experts_per_token)
        jax.debug.print(" [ED Router L{}] logits[0]: {}", layer_idx, router_logits[0])
        jax.debug.print(
            " [ED Router L{}] top_idx[0]: {}, top_logits[0]: {}",
            layer_idx,
            selected_experts[0],
            debug_top_logits[0],
        )
        jax.debug.print(
            " [ED Router L{}] top_weights[0]: {} (sum={})",
            layer_idx,
            selected_weights[0],
            selected_weights[0].sum(),
        )
        jax.debug.print(" [ED Experts L{}] input[0,:5]: {}", layer_idx, flat_tokens[0, :5])

    if apply_capacity_constraint:
        selected_experts, selected_weights = self._apply_capacity_constraint(selected_experts, selected_weights)

    if hooks.refine_inputs_hook is not None:
        original_shape = (batch_size, seq_len, hidden_size)
        flat_tokens = hooks.refine_inputs_hook(flat_tokens, selected_weights, original_shape)

    # Group token copies by expert, run the experts, then undo the sort.
    sorted_inputs, sort_order, group_sizes, sorted_experts = self._replicate_and_sort_tokens(
        flat_tokens,
        selected_experts,
    )
    expert_out_sorted = expert_layer(sorted_inputs, group_sizes, sorted_experts)
    inverse_order = jnp.argsort(sort_order)
    expert_out = sort_activations(expert_out_sorted, inverse_order)
    per_token_out = expert_out.reshape(batch_size * seq_len, experts_per_token, hidden_size)

    if hooks.before_combine is not None:
        per_token_out, selected_weights = hooks.before_combine(per_token_out, selected_weights)

    # Weighted sum over each token's selected experts.
    weighted = per_token_out * selected_weights[..., None]
    output = jnp.sum(weighted, axis=1).reshape(batch_size, seq_len, hidden_size)

    if emit_debug:
        jax.debug.print(" [ED Experts L{}] output[0,:5]: {}", layer_idx, output.reshape(-1, hidden_size)[0, :5])

    if hooks.finalize_output is not None:
        output = hooks.finalize_output(output)

    if not output_metrics:
        return output, router_logits

    metrics = self._compute_metrics(
        router_logits,
        router_probs,
        selected_experts,
        selected_weights,
        group_sizes,
    )
    return output, metrics
@abstractmethod
def __call__(
    self,
    hidden_states: Float[Array, "batch seq hidden_dim"],
    **kwargs,
) -> tuple[Float[Array, "batch seq hidden_dim"], MoeMetrics]:
    """Run the forward pass of the MoE module.

    This is the abstract entry point: concrete subclasses override it with the
    logic of their specific MoE layer. The base implementation does nothing.

    Args:
        hidden_states: The input tensor, annotated as (batch, seq, hidden_dim).
        **kwargs: Extra keyword arguments a specific implementation may need.

    Returns:
        A tuple containing:
        - output: The output tensor from the MoE layer.
        - metrics: A `MoeMetrics` object containing metrics and auxiliary losses.
    """