easydel.layers.moe

EasyDeL Mixture-of-Experts (MoE) layers and utilities.

This package provides comprehensive building blocks for implementing Mixture of Experts models in JAX with EasyDeL, supporting various routing strategies, load balancing techniques, and distributed training optimizations.

Core Components

Base Classes:
  • BaseMoeModule: Abstract base class for MoE implementations with routing, permutation, metrics computation, and distributed execution utilities.

Linear Layers:
  • ParallelMoELinear: Batched per-expert linear transformation layer with support for ragged/grouped matrix multiplication

  • RowParallelMoELinear: Row-parallel variant (input dimension partitioned)

  • ColumnParallelMoELinear: Column-parallel variant (output dimension partitioned)

Enumerations:
  • MoEMethods: Execution methods (FUSED_MOE, STANDARD_MOE, DENSE_MOE)

  • MoeRoutingStrategy: Token routing strategies (TOP_K, TOP_K_NDIV, SWITCH, EMPTY_CHOICE for Expert Choice, HASH)

  • MoeLoadBalancingStrategy: Load balancing loss strategies (STANDARD, SWITCH_TRANSFORMER, EMPTY_CHOICE for Expert Choice, NONE)

Data Classes:
  • MoeFusedHooks: Hook system for custom interventions during MoE execution

  • MoeMetrics: Container for MoE metrics (expert loads, routing entropy, auxiliary losses)

Partition Spec Utilities:
  • get_moe_partition_spec: Generate partition specs for MoE weight tensors

Features

Routing Strategies:
  • Top-K routing with weight normalization

  • Switch Transformer (top-1) routing

  • Expert Choice routing (inverted selection)

  • Hash-based deterministic routing

  • Custom hooks for implementing novel routing methods

Load Balancing:
  • Standard load balancing loss

  • Switch Transformer auxiliary loss

  • Expert Choice variance-based loss

  • Router z-loss for stability

Distributed Training:
  • Expert Parallelism (EP): Partition experts across devices

  • Tensor Parallelism (TP): Partition weight matrices within experts

  • Data Parallelism (DP): Replicate parameters and split the batch across devices

  • Fully Sharded Data Parallel (FSDP): Memory-efficient parameter sharding

  • Sequence Parallelism (SP): Partition sequence dimension

  • Expert Tensor Mode: Alternative sharding with experts on TP axis

  • 3D Expert Mesh: Simplified mesh combining FSDP, EP, and SP into a single expert dimension

Execution Modes:
  • Fused MoE: Optimized grouped matmul with shard_map (best for TPU/GPU)

  • Standard MoE: Traditional permute-compute-unpermute (most flexible)

  • Dense MoE: Per-token einsum operations (debugging/fallback)

Communication Patterns:
  • Ring-of-Experts: Efficient all-gather pattern for expert parallelism

  • All-to-All: Ragged all-to-all communication for token redistribution

  • Automatic selection based on mesh configuration

Recent Improvements

3D Expert Mesh (v0.0.81+):

The MoE implementation has been refactored to use a simplified 3D expert mesh that combines FSDP, EP, and SP axes into a single unified expert dimension. This provides:

  • Cleaner sharding specifications with (dp, expert, tp) layout

  • Simplified partition spec generation via get_moe_partition_spec

  • Better compatibility with grouped matmul kernels

  • Automatic resharding between 5D model mesh and 3D expert mesh

The 3D expert mesh is created automatically in BaseMoeModule._create_expert_mesh() and used internally for all MoE operations, with automatic resharding at boundaries.
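The axis merge can be pictured with plain device indices. The sketch below is a hypothetical illustration, not EasyDeL's _create_expert_mesh(): it assumes the (dp, fsdp, ep, tp, sp) axis order used in the distributed training example, moves TP to the end, and flattens FSDP, EP, and SP into one expert axis.

```python
import numpy as np


def merge_to_expert_mesh(device_ids: np.ndarray) -> np.ndarray:
    """Collapse a (dp, fsdp, ep, tp, sp) grid into (dp, expert, tp).

    Hypothetical helper: the axis ordering used inside EasyDeL may differ.
    """
    dp, fsdp, ep, tp, sp = device_ids.shape
    # Reorder to (dp, fsdp, ep, sp, tp) so the three merged axes are adjacent.
    grouped = device_ids.transpose(0, 1, 2, 4, 3)
    # Flatten fsdp * ep * sp into a single "expert" axis; TP stays last.
    return grouped.reshape(dp, fsdp * ep * sp, tp)


# 8 device ids in a (dp=1, fsdp=1, ep=4, tp=2, sp=1) layout.
ids = np.arange(8).reshape(1, 1, 4, 2, 1)
expert_grid = merge_to_expert_mesh(ids)
print(expert_grid.shape)  # (1, 4, 2)
```

With real devices, the same reshape applied to np.array(jax.devices()) could be wrapped in jax.sharding.Mesh(grid, ("dp", "expert", "tp")).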

Example Usage

Basic MoE Layer:

>>> from easydel.layers.moe import (
...     BaseMoeModule,
...     MoEMethods,
...     MoeRoutingStrategy,
...     MoeLoadBalancingStrategy,
... )
>>> from flax import nnx as nn
>>>
>>> # Configure MoE execution on an existing EasyDeL model config
>>> config.moe_method = MoEMethods.FUSED_MOE
>>> config.routing_strategy = MoeRoutingStrategy.TOP_K
>>> config.load_balancing_strategy = MoeLoadBalancingStrategy.STANDARD
>>> config.n_routed_experts = 8
>>> config.num_experts_per_tok = 2
>>>
>>> # Create custom MoE layer by extending BaseMoeModule
>>> class MyMoELayer(BaseMoeModule):
...     def __init__(self, config, rngs):
...         super().__init__(config)
...         self.gate = nn.Linear(config.hidden_size, config.n_routed_experts, rngs=rngs)
...         # Initialize the expert module (self.experts) and the expert FFN
...         # weights (self.wi_kernel, self.wu_kernel, self.wd_kernel)...
...
...     def __call__(self, hidden_states):
...         # moe_call returns (output, router_logits)
...         output, router_logits = self.moe_call(
...             hidden_state=hidden_states,
...             gate_layer=self.gate,
...             expert_layer=self.experts,  # required argument of moe_call
...             wi_kernel=self.wi_kernel,
...             wu_kernel=self.wu_kernel,
...             wd_kernel=self.wd_kernel,
...             act_fn=nn.silu,
...         )
...         return output

Custom Routing with Hooks:

>>> import jax
>>> from easydel.layers.moe import MoeFusedHooks
>>>
>>> # Define custom weight refinement
>>> def temperature_scaling(weights):
...     temperature = 0.5
...     return jax.nn.softmax(weights / temperature, axis=-1)
>>>
>>> # Create hooks with custom logic
>>> hooks = MoeFusedHooks(refine_weights_hook=temperature_scaling)
>>>
>>> # Attach the hooks to the layer before calling it
>>> moe_layer = MyMoELayer(config, rngs=rngs)
>>> moe_layer.moe_hooks = hooks

Distributed Training Setup:

>>> import jax
>>> import numpy as np
>>> from jax.sharding import Mesh
>>> from eformer.escale import PartitionManager
>>>
>>> # Create 5D mesh for model training (assumes 8 devices)
>>> devices = np.array(jax.devices())
>>> mesh = Mesh(
...     devices.reshape(1, 1, 4, 2, 1),  # (dp, fsdp, ep, tp, sp)
...     axis_names=("dp", "fsdp", "expert", "tensor", "sequence")
... )
>>>
>>> # Configure partition manager
>>> config.mesh = mesh
>>> config.partition_manager = PartitionManager(mesh, ...)
>>>
>>> # MoE layer automatically creates 3D expert mesh:
>>> # Combines FSDP*EP*SP into single expert axis (1*4*1 = 4)
>>> # Resulting expert_mesh shape: (dp=1, expert=4, tp=2)


Notes

The modules are designed to work seamlessly with JAX distributed meshes and EFormer’s PartitionManager for automatic sharding. All implementations support gradient checkpointing and mixed precision training.

For optimal performance on TPUs, use MoEMethods.FUSED_MOE with the grouped matmul kernel. On GPUs, both FUSED_MOE and STANDARD_MOE work well, with FUSED_MOE providing better performance for large expert counts.

class easydel.layers.moe.BaseMoeModule(*args: Any, **kwargs: Any)

Bases: Module, ABC

An abstract base class for Mixture of Experts (MoE) modules.

This class provides a foundational structure and common utilities for implementing various MoE architectures. It includes methods for token routing, data permutation for efficient expert computation, load balancing loss calculation, and sharding for distributed environments. Subclasses are expected to implement the __call__ method to define the specific MoE forward pass.

config

The configuration object for the MoE module.

mesh

The JAX device mesh for distributed computation.

n_routed_experts

The total number of experts available for routing.

num_experts_per_tok

The number of experts each token is routed to (k).

hidden_size

The dimension of the hidden states.

lbl_coef

The coefficient for the load balancing loss.

rzl_coef

The coefficient for the router z-loss.

routing_strategy

The strategy used for routing tokens to experts.

load_balancing_strategy

The strategy used for calculating the load balancing loss.

get_moe_spec(direction: Literal['row', 'column'], tensors_are_expert: bool, is_bias: bool = False) -> PartitionSpec

Generate partition spec for MoE weight tensors.

This helper creates appropriate partition specs for MoE expert weights based on the sharding strategy and tensor properties.

Parameters
  • direction – Weight matrix orientation:
      - “column”: column-wise sharding (wi/wu kernels)
      - “row”: row-wise sharding (wd kernel)

  • tensors_are_expert – If True, uses expert tensor mode (experts on TP axis). If False, uses standard mode (experts on EP axis).

  • is_bias – If True, generates spec for bias tensor (2D instead of 3D).

Returns

PartitionSpec appropriate for the tensor.

Examples

Standard mode (tensors_are_expert=False):
  • Column weight: [expert, None, tp] # wi/wu: [E, H, M]

  • Row weight: [expert, tp, None] # wd: [E, M, H]

  • Bias: [expert, None] # [E, dim]

Expert tensor mode (tensors_are_expert=True):
  • Column weight: [tp, None, None] # Experts on TP

  • Row weight: [tp, None, None] # Experts on TP

  • Bias: [tp, None] # [E, dim]
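The example table can be restated as code. This is a hypothetical sketch that only reproduces the six cases listed above with jax.sharding.PartitionSpec; the axis names "expert" and "tp" are placeholders, whereas the real get_moe_spec resolves axis names through the partition manager.

```python
from jax.sharding import PartitionSpec as P


def sketch_moe_spec(direction, tensors_are_expert, is_bias=False,
                    expert_axis="expert", tp_axis="tp"):
    """Hypothetical restatement of the get_moe_spec examples (not the real API)."""
    if tensors_are_expert:
        # Expert tensor mode: the expert dimension lives on the TP axis.
        return P(tp_axis, None) if is_bias else P(tp_axis, None, None)
    if is_bias:
        # Standard-mode bias: [E, dim], sharded over experts only.
        return P(expert_axis, None)
    if direction == "column":
        # wi/wu: [E, H, M], experts on EP and the output dim M on TP.
        return P(expert_axis, None, tp_axis)
    if direction == "row":
        # wd: [E, M, H], experts on EP and the input dim M on TP.
        return P(expert_axis, tp_axis, None)
    raise ValueError(f"unknown direction: {direction!r}")


print(sketch_moe_spec("column", False))
print(sketch_moe_spec("row", True))
```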

moe_call(hidden_state: Array, gate_layer: Module, expert_layer: Module, wi_kernel: Array, wu_kernel: Array, wd_kernel: Array, wi_bias: jax.Array | None = None, wu_bias: jax.Array | None = None, wd_bias: jax.Array | None = None, ffn_activation: collections.abc.Callable[[jax.Array, jax.Array], jax.Array] | None = None, reform_router_probs_fn: Optional[Callable[[Array], Array]] = None, *, act_fn: Callable[[Array], Array], output_metrics: bool = False, layer_idx: int | None = None)

Wrapper for fused MoE call with automatic hook configuration.

This method dispatches to either standard or fused MoE based on config, and automatically configures hooks based on the routing strategy to ensure correct expert weight handling.

Hook Auto-Configuration:

Before calling the fused MoE path, this method automatically configures default hooks for the routing strategy if they’re not already set by the user:

  • TOP_K: Normalizes weights by their sum (softmax-like distribution).

  • TOP_K_NDIV: Passes weights through unchanged (raw logit values).

  • SWITCH: Enforces hard assignment with weight = 1.0.

  • EMPTY_CHOICE: Uniform weights across expert selections.

  • HASH: Uniform weights for deterministic assignments.

Each strategy gets an appropriate default select_hook that ensures correct weight handling without requiring manual setup. Users can override defaults by setting custom hooks on self.moe_hooks before calling the layer.

Parameters
  • hidden_state – Input tensor. Shape: [B, S, H].

  • gate_layer – Router/gate module mapping H -> E (produces logits).

  • expert_layer – Expert layer module.

  • wi_kernel – Expert W_i (gate/first) kernel. Shape: [E, H, M].

  • wu_kernel – Expert W_u (up/second) kernel. Shape: [E, H, M].

  • wd_kernel – Expert W_d (output/down) kernel. Shape: [E, M, H].

  • wi_bias – Optional bias for W_i. Shape: [E, M].

  • wu_bias – Optional bias for W_u. Shape: [E, M].

  • wd_bias – Optional bias for W_d. Shape: [E, H].

  • ffn_activation – Optional custom activation combining (w0, w1) -> output.

  • reform_router_probs_fn – Optional function to modify router probabilities (used in standard MoE mode only).

  • act_fn – Activation function used when ffn_activation is not provided.

  • output_metrics – Whether to return metrics in standard MoE mode.

Returns

Tuple of (output, logits) where:
  • output: MoE layer output. Shape: [B, S, H].

  • logits: Router logits for auxiliary loss computation. Shape: [B*S, E].
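To make the kernel shapes concrete, here is a minimal per-token reference of a top-k MoE forward pass. It is a sketch under stated assumptions, not EasyDeL's implementation: the function name dense_moe_reference is hypothetical, and it assumes a softmax router, TOP_K weight normalization, and a gated FFN act_fn(x @ W_i) * (x @ W_u) followed by W_d.

```python
import jax
import jax.numpy as jnp


def dense_moe_reference(hidden_state, gate_kernel, wi, wu, wd, k, act_fn=jax.nn.silu):
    """Hypothetical per-token MoE reference.

    hidden_state: [B, S, H]; gate_kernel: [H, E];
    wi, wu: [E, H, M]; wd: [E, M, H]. Returns (output [B, S, H], logits [B*S, E]).
    """
    B, S, H = hidden_state.shape
    x = hidden_state.reshape(B * S, H)                   # [T, H]
    logits = x @ gate_kernel                             # [T, E]
    probs = jax.nn.softmax(logits, axis=-1)
    weights, experts = jax.lax.top_k(probs, k)           # [T, k] each
    weights = weights / weights.sum(-1, keepdims=True)   # TOP_K normalization
    # Gather each token's selected kernels: [T, k, H, M] and [T, k, M, H].
    gate_proj = jnp.einsum("th,tkhm->tkm", x, wi[experts])
    up_proj = jnp.einsum("th,tkhm->tkm", x, wu[experts])
    h = act_fn(gate_proj) * up_proj                      # gated activation
    y = jnp.einsum("tkm,tkmh->tkh", h, wd[experts])      # [T, k, H]
    out = jnp.einsum("tk,tkh->th", weights, y)           # weighted combine
    return out.reshape(B, S, H), logits


keys = jax.random.split(jax.random.PRNGKey(0), 5)
B, S, H, M, E, k = 2, 3, 4, 5, 6, 2
hs = jax.random.normal(keys[0], (B, S, H))
gate_w = jax.random.normal(keys[1], (H, E))
wi = jax.random.normal(keys[2], (E, H, M))
wu = jax.random.normal(keys[3], (E, H, M))
wd = jax.random.normal(keys[4], (E, M, H))
out, logits = dense_moe_reference(hs, gate_w, wi, wu, wd, k)
```

Gathering full kernels per token is only practical for tiny shapes; it mirrors the DENSE_MOE idea of selection by indexing rather than grouped matmuls.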

class easydel.layers.moe.ColumnParallelMoELinear(*args: Any, **kwargs: Any)

Bases: ParallelMoELinear

Column-parallel variant of ParallelMoELinear.

This class specializes ParallelMoELinear for column-wise parallelism, where the output dimension is partitioned across devices. In column parallelism, each device computes a subset of output features independently without requiring reduction.

The weight matrix is partitioned along the output dimension (columns), and each device produces its portion of the output directly without communication.

_direction

Fixed to “column” to indicate column-wise parallelism.

Type

Optional[Literal[‘row’, ‘column’]]

Example

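>>> from flax import nnx as nn
>>> from easydel.layers.moe import ColumnParallelMoELinear
>>>
>>> # Create a column-parallel MoE linear layer
>>> rngs = nn.Rngs(0)
>>> layer = ColumnParallelMoELinear(
...     num_experts=8,
...     in_features=768,
...     out_features=3072,
...     rngs=rngs
... )
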
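The claim that column parallelism needs no reduction can be checked with a small single-host NumPy simulation. The shard count and sizes are arbitrary assumptions; each "shard" computes its own output columns independently, and concatenation is only needed if the full output must end up on one device.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8))   # tokens routed to one expert: [T, in_features]
w = rng.standard_normal((8, 6))   # that expert's kernel: [in_features, out_features]
num_shards = 3                    # hypothetical tensor-parallel size

# Column parallelism: each shard owns a slice of the output columns.
w_cols = np.split(w, num_shards, axis=1)          # 3 x [8, 2]
local_outputs = [x @ w_c for w_c in w_cols]       # each [4, 2], no communication
y_column = np.concatenate(local_outputs, axis=1)  # gather only if the full output is needed
```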
class easydel.layers.moe.MoEMethods(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)

Bases: str, Enum

Enumeration of available MoE execution methods.

This enum defines different strategies for executing Mixture of Experts layers, each optimized for different use cases and hardware configurations.

FUSED_MOE

Fused execution path using grouped matrix multiplication and shard_map. Optimal for distributed training with expert parallelism on TPUs/GPUs. Uses ragged tensor operations and custom kernels for maximum efficiency. Automatically falls back to STANDARD_MOE when FSDP*SP axis size > 1.

STANDARD_MOE

Standard (non-fused) execution path. More flexible and easier to debug. Uses traditional permutation, expert computation, and unpermutation steps. Supports all sharding configurations and is the fallback when the fused path is unavailable.

DENSE_MOE

Dense batched execution using per-token matrix multiplications. Instead of ragged/grouped operations, uses dense einsum operations with expert selection via indexing. Useful for debugging or when grouped matmul kernels are not available.

Example

>>> from easydel.layers.moe import MoEMethods
>>> # Configure in model config
>>> config.moe_method = MoEMethods.FUSED_MOE
DENSE_MOE = 'dense_moe'
FUSED_MOE = 'fused_moe'
STANDARD_MOE = 'standard_moe'
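The permute-compute-unpermute steps of STANDARD_MOE can be walked through on a toy routing table. This jax.numpy sketch is illustrative only: the helper names and the stand-in "expert FFN" are hypothetical, and a real implementation also applies routing weights and handles sharding.

```python
import jax.numpy as jnp


def permute_by_expert(rows, expert_ids, num_experts):
    """Sort routed rows so that rows for the same expert are contiguous."""
    order = jnp.argsort(expert_ids)
    group_sizes = jnp.bincount(expert_ids, length=num_experts)
    return rows[order], expert_ids[order], group_sizes, order


def unpermute(sorted_rows, order):
    """Undo the permutation: row i of sorted_rows goes back to position order[i]."""
    return jnp.zeros_like(sorted_rows).at[order].set(sorted_rows)


# Toy setup: 4 tokens, hidden size 3, 4 experts, top-2 routing.
T, H, E, k = 4, 3, 4, 2
x = jnp.arange(T * H, dtype=jnp.float32).reshape(T, H)
experts = jnp.array([[0, 2], [1, 2], [3, 0], [2, 1]])  # [T, k] routing decisions
flat_experts = experts.reshape(-1)                       # [T*k]
rows = jnp.repeat(x, k, axis=0)                          # row t*k + j is token t

# 1. Permute: group rows by expert (group_sizes feeds a grouped matmul).
sorted_rows, sorted_experts, group_sizes, order = permute_by_expert(rows, flat_experts, E)
# 2. Compute: stand-in "expert FFN" that scales each row by (expert_id + 1).
sorted_out = sorted_rows * (sorted_experts[:, None] + 1)
# 3. Unpermute, then combine the k outputs per token (equal weights here).
out_rows = unpermute(sorted_out, order)
combined = out_rows.reshape(T, k, H).mean(axis=1)
```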
class easydel.layers.moe.MoeFusedHooks(before_gate: collections.abc.Callable | None = None, after_gate: collections.abc.Callable | None = None, before_topk: collections.abc.Callable | None = None, select_hook: collections.abc.Callable | None = None, refine_inputs_hook: collections.abc.Callable | None = None, after_ep_receive: collections.abc.Callable | None = None, refine_weights_hook: collections.abc.Callable | None = None, after_wiwu: collections.abc.Callable | None = None, before_wo: collections.abc.Callable | None = None, after_wo: collections.abc.Callable | None = None, before_combine: collections.abc.Callable | None = None, finalize_output: collections.abc.Callable | None = None)

Bases: object

Optional callbacks executed at key points of the fused MoE pipeline.

Hooks allow fine-grained customization of the MoE execution flow without modifying core logic. Each hook is invoked at a specific stage of the pipeline and can inspect/modify tensors.

Routing & Selection Hooks:
before_gate: Invoked before gate/router layer. Can preprocess hidden states.

Signature: (hidden_states: Array) -> Array

after_gate: Invoked after gate/router logits are computed. Can postprocess logits.

Signature: (gate_logits: Array) -> Array

before_topk: Invoked before top-k expert selection. Can modify logits before selection.

Signature: (gate_logits: Array) -> Array

select_hook: Invoked after top-k selection to refine expert weights/scores.

Used for custom routing logic or weight normalization. Signature: (selected_weights: Array, selected_experts: Array) -> (weights: Array, experts: Array)

Default for TOP_K routing: Normalizes weights by their sum (softmax-like normalization).

Expert Processing Hooks:
refine_weights_hook: Invoked before W_i and W_u (gate/up) projections.

Can refine expert weights before linear transformations. Signature: (weights: Array) -> Array

after_wiwu: Invoked after W_i and W_u (gate/up) projections.

Can post-process expert intermediate activations. Signature: (intermediate: Array) -> Array

before_wo: Invoked before W_o (output) projection.

Can modify combined expert outputs before final projection. Signature: (combined_output: Array) -> Array

after_wo: Invoked after W_o (output) projection.

Can post-process final expert layer outputs. Signature: (output: Array) -> Array

Distributed Execution Hooks:
refine_inputs_hook: Invoked before expert-parallel all-to-all communication.

Can refine token representations or route them to specific experts. Signature: (inputs: Array, weights: Array, shape: Tuple) -> Array

after_ep_receive: Invoked after receiving tokens from other expert shards.

Can refine received token representations before expert computation. Signature: (received_tokens: Array) -> Array

Output Combination Hooks:
before_combine: Invoked before combining outputs from multiple experts per token.

Can adjust expert weights or outputs before weighted sum. Signature: (outputs: Array, weights: Array) -> (outputs: Array, weights: Array)

finalize_output: Invoked at the very end of MoE computation.

Can apply final normalization, residual connections, etc. Signature: (final_output: Array) -> Array

Default behavior: All hooks are None, so the pipeline proceeds without intervention.

after_ep_receive: collections.abc.Callable | None = None
after_gate: collections.abc.Callable | None = None
after_wiwu: collections.abc.Callable | None = None
after_wo: collections.abc.Callable | None = None
before_combine: collections.abc.Callable | None = None
before_gate: collections.abc.Callable | None = None
before_topk: collections.abc.Callable | None = None
before_wo: collections.abc.Callable | None = None
finalize_output: collections.abc.Callable | None = None
refine_inputs_hook: collections.abc.Callable | None = None
refine_weights_hook: collections.abc.Callable | None = None
replace(**kws) -> MoeFusedHooks
select_hook: collections.abc.Callable | None = None
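The "None means no intervention" behavior and the routing-stage hook order can be sketched with a stand-in dataclass. SketchHooks and route are hypothetical names covering only four of the MoeFusedHooks fields; dataclasses.replace plays the role of MoeFusedHooks.replace(**kws) here, and the select_hook signature follows the one documented above.

```python
from dataclasses import dataclass, replace
from typing import Callable, Optional

import jax
import jax.numpy as jnp


@dataclass(frozen=True)
class SketchHooks:
    """Hypothetical stand-in for a few MoeFusedHooks fields (illustration only)."""
    before_gate: Optional[Callable] = None
    after_gate: Optional[Callable] = None
    before_topk: Optional[Callable] = None
    select_hook: Optional[Callable] = None


def route(hidden, gate_kernel, k, hooks):
    """Routing stage that calls each hook only when it is set (None = no-op)."""
    if hooks.before_gate is not None:
        hidden = hooks.before_gate(hidden)
    logits = hidden @ gate_kernel
    if hooks.after_gate is not None:
        logits = hooks.after_gate(logits)
    if hooks.before_topk is not None:
        logits = hooks.before_topk(logits)
    weights, experts = jax.lax.top_k(jax.nn.softmax(logits, axis=-1), k)
    if hooks.select_hook is not None:
        weights, experts = hooks.select_hook(weights, experts)
    return weights, experts


def renormalize(weights, experts):
    # (selected_weights, selected_experts) -> (weights, experts)
    return weights / weights.sum(axis=-1, keepdims=True), experts


hooks = replace(SketchHooks(), select_hook=renormalize)
x = jax.random.normal(jax.random.PRNGKey(0), (5, 4))
gate = jax.random.normal(jax.random.PRNGKey(1), (4, 6))
weights, experts = route(x, gate, 2, hooks)
```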
class easydel.layers.moe.MoeLoadBalancingStrategy(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)

Bases: Enum

Defines the available strategies for calculating the load balancing loss.

STANDARD

A common load balancing loss based on the product of expert loads and mean router probabilities.

SWITCH_TRANSFORMER

The load balancing loss used in the Switch Transformer paper.

EMPTY_CHOICE

A load balancing loss variant suitable for Expert Choice routing, often based on the variance of expert loads.

NONE

No load balancing loss is applied.

EMPTY_CHOICE = 'expert_choice'
NONE = 'none'
STANDARD = 'standard'
SWITCH_TRANSFORMER = 'switch_transformer'
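Both STANDARD and SWITCH_TRANSFORMER build on expert loads and mean router probabilities. Below is a minimal sketch of the Switch Transformer form together with the router z-loss. The function names are hypothetical and EasyDeL's exact scaling may differ; in a training objective these terms would be weighted by coefficients such as lbl_coef and rzl_coef.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def load_balancing_loss(router_probs, selected_experts, num_experts):
    """Switch-Transformer-style auxiliary loss: num_experts * sum_i f_i * P_i.

    f_i: fraction of routed (token, slot) assignments sent to expert i.
    P_i: mean router probability given to expert i.
    Equals 1.0 under perfectly uniform routing.
    """
    one_hot = jax.nn.one_hot(selected_experts, num_experts)  # [T, k, E]
    f = one_hot.sum(axis=(0, 1)) / one_hot.sum()             # [E]
    p = router_probs.mean(axis=0)                            # [E]
    return num_experts * jnp.sum(f * p)


def router_z_loss(router_logits):
    """Mean squared log-sum-exp of router logits; penalizes large logits."""
    return jnp.mean(logsumexp(router_logits, axis=-1) ** 2)
```

When every token sends all its probability mass and its assignment to the same expert, the loss reaches num_experts, which is the imbalance the loss is meant to discourage.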
class easydel.layers.moe.MoeMetrics(expert_loads: Float[Array, 'num_experts'], router_probs: Float[Array, 'batch_seq num_experts'], selected_experts: Int[Array, 'batch_seq num_experts_per_tok'], selected_weights: Float[Array, 'batch_seq num_experts_per_tok'], load_balancing_loss: float | None = None, router_z_loss: float | None = None, expert_utilization: float | None = None, routing_entropy: float | None = None)

Bases: object

A container for storing metrics and auxiliary losses from an MoE layer.

expert_loads

An array representing the number of tokens routed to each expert. Shape: (num_experts,).

Type

jaxtyping.Float[Array, ‘num_experts’]

router_probs

The probabilities output by the router for each token and expert. Shape: (num_tokens, num_experts).

Type

jaxtyping.Float[Array, ‘batch_seq num_experts’]

selected_experts

The indices of the experts selected for each token. Shape: (num_tokens, num_experts_per_tok).

Type

jaxtyping.Int[Array, ‘batch_seq num_experts_per_tok’]

selected_weights

The weights assigned to the selected experts for each token. Shape: (num_tokens, num_experts_per_tok).

Type

jaxtyping.Float[Array, ‘batch_seq num_experts_per_tok’]

load_balancing_loss

The calculated auxiliary loss to encourage balanced load across experts.

Type

float | None

router_z_loss

The calculated auxiliary loss to encourage small router logits, promoting stability.

Type

float | None

expert_utilization

The fraction of experts that were utilized (i.e., received at least one token).

Type

float | None

routing_entropy

The entropy of the router probabilities, measuring routing confidence.

Type

float | None

expert_loads: Float[Array, 'num_experts']
expert_utilization: float | None = None
load_balancing_loss: float | None = None
router_probs: Float[Array, 'batch_seq num_experts']
router_z_loss: float | None = None
routing_entropy: float | None = None
selected_experts: Int[Array, 'batch_seq num_experts_per_tok']
selected_weights: Float[Array, 'batch_seq num_experts_per_tok']
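Several MoeMetrics fields can be derived directly from router_probs and selected_experts. This sketch uses a hypothetical helper, summarize_routing. It counts (token, slot) assignments per expert and adds a small epsilon inside the log to avoid log(0); EasyDeL may compute these quantities differently.

```python
import jax.numpy as jnp


def summarize_routing(router_probs, selected_experts, num_experts):
    """Compute MoeMetrics-style statistics from routing results (a sketch)."""
    # Number of (token, slot) assignments received by each expert.
    expert_loads = jnp.bincount(selected_experts.reshape(-1), length=num_experts)
    # Fraction of experts that received at least one assignment.
    expert_utilization = jnp.mean(expert_loads > 0)
    # Per-token entropy of the router distribution, averaged over tokens.
    token_entropy = -jnp.sum(router_probs * jnp.log(router_probs + 1e-9), axis=-1)
    routing_entropy = jnp.mean(token_entropy)
    return expert_loads, expert_utilization, routing_entropy


loads, util, ent = summarize_routing(
    jnp.full((3, 4), 0.25), jnp.array([[0, 1], [0, 1], [1, 0]]), 4
)
```

A uniform router gives the maximum entropy log(num_experts); lower values indicate more confident routing.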
class easydel.layers.moe.MoeRoutingStrategy(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)

Bases: Enum

Defines the available strategies for routing tokens to experts in an MoE layer.

Each strategy determines how tokens are assigned to experts and how expert weights are computed. When using fused MoE, hooks are automatically configured based on the routing strategy to ensure correct behavior.

Attributes:

TOP_K: Standard top-k routing with weight normalization.
  • Each token is routed to the k experts with the highest router logits.

  • Expert weights are normalized by their sum, so the k selected weights form a distribution.

  • Default hook: select_hook normalizes weights so they sum to 1.0.

  • Use case: Most common approach, works well for balanced expert utilization.

TOP_K_NDIV: Top-k routing WITHOUT weight normalization.
  • Each token is routed to the k experts with the highest router logits.

  • Expert weights are NOT normalized (raw logit values used as weights).

  • Default hook: select_hook passes weights through unchanged.

  • Use case: When you want to use raw logits directly for expert combination.

SWITCH: Switch Transformer-style routing (token -> 1 expert only).
  • Each token is routed to ONLY the single top-1 expert.

  • Expert weight is enforced as exactly 1.0 (hard assignment).

  • Default hook: select_hook sets weights to 1.0 (hard gating).

  • Use case: Sparse, efficient routing with reduced computational cost.

EMPTY_CHOICE: Expert Choice routing (expert -> k tokens).
  • INVERTED routing: each expert selects its top-k tokens.

  • Better load balancing compared to top-k token routing.

  • Default hook: select_hook uses uniform weights (1/k per expert).

  • Use case: Scenarios requiring strict expert load balancing.

HASH: Hash-based routing (token -> expert by token_id % num_experts).
  • Simple deterministic routing based on token IDs.

  • Experts receive roughly equal numbers of tokens when token IDs are evenly distributed.

  • Default hook: select_hook uses uniform weights (1/k per expert).

  • Use case: Debugging, baseline comparisons, or fully deterministic execution.

Hook Auto-Configuration:

When using MoeFusedHooks with fused MoE execution, the select_hook is automatically configured based on the routing strategy if not explicitly set:

  • TOP_K: Normalize weights → sum to 1.0 (probability distribution)

  • TOP_K_NDIV: Passthrough → raw weights unchanged

  • SWITCH: Hard gating → weights = 1.0

  • EMPTY_CHOICE: Uniform → weights = 1/k

  • HASH: Uniform → weights = 1/k

Custom hooks can override the defaults by setting them on self.moe_hooks before calling the MoE layer.

EMPTY_CHOICE = 'expert_choice'
HASH = 'hash'
SWITCH = 'switch'
TOP_K = 'top_k'
TOP_K_NDIV = 'top_k_ndiv'
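The default select_hook behaviors listed under Hook Auto-Configuration can be written as small functions with the (selected_weights, selected_experts) -> (weights, experts) signature. These are hypothetical sketches keyed by the enum values above, not the hooks EasyDeL installs; hash_route likewise only illustrates the token_id % num_experts rule.

```python
import jax.numpy as jnp


def normalize_topk(weights, experts):   # TOP_K: weights sum to 1.0
    return weights / weights.sum(axis=-1, keepdims=True), experts


def passthrough(weights, experts):      # TOP_K_NDIV: weights unchanged
    return weights, experts


def hard_gate(weights, experts):        # SWITCH: weight fixed at 1.0
    return jnp.ones_like(weights), experts


def uniform(weights, experts):          # EMPTY_CHOICE / HASH: 1/k each
    k = weights.shape[-1]
    return jnp.full_like(weights, 1.0 / k), experts


# Keyed by the enum values of MoeRoutingStrategy.
SKETCH_SELECT_HOOKS = {
    "top_k": normalize_topk,
    "top_k_ndiv": passthrough,
    "switch": hard_gate,
    "expert_choice": uniform,
    "hash": uniform,
}


def hash_route(token_ids, num_experts):
    """HASH-style assignment: token -> expert by token_id % num_experts."""
    return token_ids % num_experts
```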
class easydel.layers.moe.ParallelMoELinear(*args: Any, **kwargs: Any)

Bases: Module

A batched linear transformation layer for Mixture of Experts (MoE) models.

This layer applies a separate linear transformation for each expert in an MoE setup. The inputs are assumed to be sorted and grouped by expert, with group_sizes specifying how many tokens belong to each expert. It supports:
  • Ragged matrix multiplication via jax.lax.ragged_dot_general.

  • Grouped matrix multiplication (GMM) via a Pallas kernel for TPUs.

Can optionally integrate with a PartitionManager to shard parameters and use shard_map for distributed execution.

Distributed Execution:

This layer supports multiple parallelism strategies:

  • Expert Parallelism (EP): Partition experts across devices on the expert axis

  • Tensor Parallelism (TP): Partition weight matrices within each expert

  • Data Parallelism (DP): Replicate parameters and split the batch across devices

  • Row/Column Parallelism: Control which dimension is partitioned (input vs output)

The sharding strategy is controlled by:
  1. direction: “row” or “column” determines which dimension is partitioned.

  2. use_expert_tensor_mode: whether experts are on the TP axis (True) or the EP axis (False).

  3. partition_manager: provides the mesh and axis resolution for sharding.

num_experts

Number of experts.

in_features

Input feature dimension.

out_features

Output feature dimension.

out_first

If True, kernel shape is (num_experts, out_features, in_features); otherwise (num_experts, in_features, out_features).

dtype

Data type for computation. None means the dtype is inherited from the inputs.

param_dtype

Data type for parameters (weights, biases).

kernel_init

Initializer function for the kernel weights.

bias_init

Initializer function for the bias.

kernel

Weight matrix parameter for the transformation. Shape: (num_experts, out_features, in_features) if out_first else (num_experts, in_features, out_features).

bias

Optional bias parameter. Shape: (num_experts, out_features). None if use_bias=False.

partition_manager

Handles sharding of parameters for distributed execution.

_direction

Sharding direction for ALT sharding (“row”, “column”, or None).

Type

Optional[Literal[‘row’, ‘column’]]

Example

>>> import jax.numpy as jnp
>>> from flax import nnx as nn
>>> from easydel.layers.moe import ParallelMoELinear
>>>
>>> # Create a column-parallel MoE linear layer
>>> rngs = nn.Rngs(0)
>>> layer = ParallelMoELinear(
...     num_experts=8,
...     in_features=768,
...     out_features=3072,
...     direction="column",
...     rngs=rngs
... )
>>>
>>> # Inputs are sorted tokens grouped by expert
>>> sorted_tokens = jnp.ones((1024, 768))  # 1024 tokens, 768 features
>>> group_sizes = jnp.array([128, 132, 125, 130, 127, 129, 126, 127])  # per expert, sums to 1024
>>> sorted_experts = jnp.repeat(jnp.arange(8), group_sizes)
>>>
>>> # Apply expert FFN
>>> output = layer(sorted_tokens, group_sizes, sorted_experts)
>>> # output.shape = (1024, 3072)

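What the grouped/ragged matmul computes can be spelled out with a slow reference: rows in group g are multiplied by kernel[g]. This sketch (grouped_matmul_reference is a hypothetical name) materializes one kernel per row, which the ragged and Pallas GMM paths avoid; recent JAX releases also expose jax.lax.ragged_dot for the same contraction.

```python
import jax.numpy as jnp


def grouped_matmul_reference(sorted_tokens, group_sizes, kernel):
    """Reference grouped matmul.

    sorted_tokens: [T, in]; kernel: [E, in, out]; group_sizes: [E], summing to T.
    """
    # Expert id for each row, e.g. sizes [2, 0, 3] -> [0, 0, 2, 2, 2].
    expert_ids = jnp.repeat(
        jnp.arange(kernel.shape[0]), group_sizes,
        total_repeat_length=sorted_tokens.shape[0],
    )
    # Gathers a [T, in, out] tensor: fine for tests, wasteful at scale.
    return jnp.einsum("ti,tio->to", sorted_tokens, kernel[expert_ids])


tokens = jnp.arange(15, dtype=jnp.float32).reshape(5, 3)
kernel = jnp.arange(36, dtype=jnp.float32).reshape(3, 3, 4)
sizes = jnp.array([2, 0, 3])  # expert 1 receives no tokens
out = grouped_matmul_reference(tokens, sizes, kernel)
```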
property alt_sharding: jax.sharding.PartitionSpec | None

Returns the ALT (Alternative) sharding configuration for this layer.

ALT sharding provides pre-defined sharding patterns for common parallelism strategies, simplifying the configuration of distributed execution.

property alt_sharding_axis: list[str] | None

Returns the axis names for ALT sharding configuration.

Returns

List of axis names (e.g., [“expert”, “tp”, “dp”]) for the configured ALT sharding pattern, or None if no ALT sharding is configured.

property can_use_shard_map: bool

Checks if this layer can use shard_map for distributed execution.

Returns

True if both a partition manager and parallelism direction are configured, indicating the layer is ready for distributed execution with shard_map.

property direction: Optional[Literal['row', 'column']]

Returns the parallelism direction for this layer.

Returns

“row” for row-wise parallelism (input dimension partitioned), “column” for column-wise parallelism (output dimension partitioned), or None if no parallelism direction is set.

property expert_axis: str

Semantic axis name representing the expert dimension.

class easydel.layers.moe.RowParallelMoELinear(*args: Any, **kwargs: Any)

Bases: ParallelMoELinear

Row-parallel variant of ParallelMoELinear.

This class specializes ParallelMoELinear for row-wise parallelism, where the input dimension is partitioned across devices. In row parallelism, each device holds a subset of input features and computes partial results that are then reduced across devices.

The weight matrix is partitioned along the input dimension (rows), and an all-reduce operation is performed after the matrix multiplication to combine partial results.

_direction

Fixed to “row” to indicate row-wise parallelism.

Type

Optional[Literal[‘row’, ‘column’]]

Example

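>>> from flax import nnx as nn
>>> from easydel.layers.moe import RowParallelMoELinear
>>>
>>> # Create a row-parallel MoE linear layer, e.g. for the down
>>> # projection (wd) that maps the intermediate size back to hidden
>>> rngs = nn.Rngs(0)
>>> layer = RowParallelMoELinear(
...     num_experts=8,
...     in_features=3072,
...     out_features=768,
...     rngs=rngs
... )
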
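Why row parallelism needs a reduction can be checked numerically. In this single-host NumPy sketch (shard count and sizes are arbitrary assumptions), each shard multiplies its slice of the input features by the matching kernel rows. Summing the partial results, which is the role an all-reduce such as jax.lax.psum would play inside shard_map, recovers the full product.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 6))   # tokens routed to one expert: [T, in_features]
w = rng.standard_normal((6, 8))   # that expert's kernel: [in_features, out_features]
num_shards = 3                    # hypothetical tensor-parallel size

# Row parallelism: each shard owns a slice of the input features and the
# matching rows of the kernel.
x_parts = np.split(x, num_shards, axis=1)   # 3 x [4, 2]
w_rows = np.split(w, num_shards, axis=0)    # 3 x [2, 8]
partial_sums = [xp @ wr for xp, wr in zip(x_parts, w_rows)]  # each full-shape [4, 8]
y_row = np.sum(partial_sums, axis=0)        # stands in for the all-reduce
```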
easydel.layers.moe.get_moe_partition_spec(partition_manager: PartitionManager, direction: Literal['row', 'column'], tensors_are_expert: bool, is_bias: bool = False, fsdp_is_ep_bound: bool = True, sp_is_ep_bound: bool = True, module_view: bool = False) -> PartitionSpec