# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
# Copyright 2024 The Improved Version Contributors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Linear layers with parallel and distributed computation support.
Provides optimized linear layers with support for model parallelism,
tensor parallelism, and various sharding strategies for distributed training.
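Classes:
ParallelLinear: Linear layer with optional tensor/model parallelism support
RowParallelLinear: ParallelLinear tagged for row-parallel use
ColumnParallelLinear: ParallelLinear tagged for column-parallel use
TensorParallelConfig: Configuration for tensor parallelism (mesh, axis name, matmul method)
Linear: Standard linear layer (alias for ParallelLinear)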
Functions:
get_sharding: Extract sharding specification from an array
get_output_partition_spec: Calculate output sharding for matmul
get_matmul_output_sharding: Determine output sharding from input specs
Key Features:
- Automatic sharding and gathering for distributed training
- Support for various matrix multiplication methods
- Mixed precision support
- Efficient initialization strategies
- Integration with JAX's shard_map
Example:
>>> import jax.numpy as jnp
>>> from flax import nnx as nn
>>> from easydel.layers import ParallelLinear
>>> # Create a 768 -> 3072 linear layer (no tensor parallelism configured)
>>> layer = ParallelLinear(
...     768,
...     3072,
...     use_bias=True,
...     dtype=jnp.bfloat16,
...     rngs=nn.Rngs(0),
... )
>>> output = layer(input_tensor)  # input_tensor has shape (..., 768)
"""
from __future__ import annotations
import typing as tp
import jax.numpy as jnp
from eformer import escale as es
from eformer.pytree import auto_pytree
from flax import nnx as nn
from flax.nnx.nn.dtypes import promote_dtype
from jax import lax
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as Ps
from jaxtyping import Array, Shaped
# Type aliases shared by the layer signatures in this module.
Dtype: tp.TypeAlias = jnp.dtype
Initializer: tp.TypeAlias = nn.initializers.Initializer
PrecisionLike: tp.TypeAlias = lax.PrecisionLike
Shape: tp.TypeAlias = tp.Sequence[int]
# A single mesh axis name, or a collection of them.
AxisNames: tp.TypeAlias = str | tp.Sequence[str] | tuple[str, ...]

# Default parameter initializers: LeCun-normal for kernels, zeros for biases.
default_kernel_init = nn.initializers.lecun_normal()
default_bias_init = nn.initializers.zeros
def get_sharding(arr: Shaped[Array, "..."]) -> Ps | None:
    """Get the PartitionSpec of an array, if it has one.

    Args:
        arr: Array (or any object) whose ``sharding`` attribute is inspected.

    Returns:
        ``arr.sharding.spec`` when the array has a sharding that exposes a
        ``spec``. Returns ``None`` in these cases:

        - the array has no ``sharding`` attribute;
        - ``arr.sharding`` is ``None``;
        - the sharding type carries no PartitionSpec (e.g. a single-device
          sharding).
    """
    sharding = getattr(arr, "sharding", None)
    if sharding is None:
        return None
    # Only mesh-based shardings expose a PartitionSpec. Other sharding types
    # have no `spec` attribute, and accessing it directly would raise
    # AttributeError instead of returning "no spec".
    return getattr(sharding, "spec", None)
def get_output_partition_spec(
    lhs: Shaped[Array, "..."],
    rhs: Shaped[Array, "..."],
    method: "MatrixMultiplyMethod",  # noqa #type:ignore
    axis_name: str,
) -> Ps | None:
    """Calculate the output partition spec for a collective matrix multiplication.

    Args:
        lhs: Left-hand side array (inputs).
        rhs: Right-hand side array (weights). Assumed to be a 2-D kernel when
            ``lhs`` is 2-D -- TODO confirm against callers.
        method: Matrix multiplication method. Not read by this function.
        axis_name: Mesh axis placed on the last output dim when ``lhs`` is
            not 2-D.

    Returns:
        ``None`` if either operand has no PartitionSpec.

        If ``lhs`` is 2-D, returns ``Ps(rhs_spec[1], rhs_spec[0])``, i.e. the
        two kernel entries swapped. Entries missing from a short spec count
        as ``None``.

        Otherwise returns a spec that leaves every leading dim unpartitioned
        and shards the last dim over ``axis_name``.
    """
    lhs_spec = get_sharding(lhs)
    rhs_spec = get_sharding(rhs)
    if lhs_spec is None or rhs_spec is None:
        return None

    if lhs.ndim == 2:
        # A PartitionSpec may omit trailing (replicated) dims, e.g. P("tp") for
        # a 2-D kernel. Read missing entries as None instead of raising
        # IndexError.
        kernel_dim0, kernel_dim1 = (rhs_spec[i] if i < len(rhs_spec) else None for i in range(2))
        # NOTE(review): the swapped order preserves the original mapping;
        # confirm it matches the collective kernels that consume this spec.
        return Ps(kernel_dim1, kernel_dim0)

    return Ps(*((None,) * (lhs.ndim - 1)), axis_name)
[docs]def get_matmul_output_sharding(lhs_pspec: Ps | None, rhs_pspec: Ps | None) -> Ps:
"""Determine output sharding for matrix multiplication.
Calculates the output PartitionSpec based on input partition specs,
following matrix multiplication rules where contracting dimensions
are reduced and non-contracting dimensions determine output sharding.
For X @ W:
- Contracting dimensions are reduced during matmul
- Non-contracting dimensions determine output sharding
- Ensures no duplicate sharding dimensions in output
Args:
lhs_pspec: PartitionSpec for left-hand side matrix.
rhs_pspec: PartitionSpec for right-hand side matrix.
Returns:
Output PartitionSpec for the multiplication result.
- Ensures correct output dimensionality with None padding if needed
Args:
lhs_pspec: PartitionSpec for the left-hand side matrix X
rhs_pspec: PartitionSpec for the right-hand side matrix W
Returns:
PartitionSpec for the output of X @ W
"""
lhs_is_none: bool = lhs_pspec is None
rhs_is_none: bool = rhs_pspec is None
either_none: bool = lhs_is_none or rhs_is_none
if either_none:
empty_spec: Ps = Ps()
return empty_spec
lhs_length: int = len(lhs_pspec)
lhs_gt_one: bool = lhs_length > 1
lhs_output_dims: tuple
if lhs_gt_one:
lhs_output_dims = lhs_pspec[:-1]
else:
lhs_output_dims = ()
rhs_length: int = len(rhs_pspec)
rhs_ge_two: bool = rhs_length >= 2
rhs_output_dims: tuple
if rhs_ge_two:
rhs_last: str | None = rhs_pspec[-1]
rhs_output_dims = (rhs_last,)
else:
rhs_is_empty: bool = not rhs_pspec
if rhs_is_empty:
rhs_output_dims = ()
else:
rhs_last_item: str | None = rhs_pspec[-1]
rhs_output_dims = (rhs_last_item,)
all_shard_dims: set[str] = set()
output_dims: list[str | None | tuple] = []
# Process LHS dimensions
for dim in lhs_output_dims:
dim_is_tuple: bool = isinstance(dim, tuple)
if dim_is_tuple:
filtered_tuple: tuple = tuple(d for d in dim if d not in all_shard_dims)
for d in dim:
all_shard_dims.add(d)
filtered_is_nonempty: bool = bool(filtered_tuple)
dim_is_nonempty: bool = bool(dim)
if filtered_is_nonempty:
output_dims.append(filtered_tuple)
elif dim_is_nonempty:
output_dims.append(None)
else:
dim_not_in_set: bool = dim not in all_shard_dims
if dim_not_in_set:
output_dims.append(dim)
all_shard_dims.add(dim)
else:
output_dims.append(None)
# Process RHS dimensions
for dim in rhs_output_dims:
dim_is_tuple_rhs: bool = isinstance(dim, tuple)
if dim_is_tuple_rhs:
filtered_tuple_rhs: tuple = tuple(d for d in dim if d not in all_shard_dims)
for d in dim:
all_shard_dims.add(d)
filtered_rhs_nonempty: bool = bool(filtered_tuple_rhs)
dim_rhs_nonempty: bool = bool(dim)
if filtered_rhs_nonempty:
output_dims.append(filtered_tuple_rhs)
elif dim_rhs_nonempty:
output_dims.append(None)
else:
dim_rhs_not_in_set: bool = dim not in all_shard_dims
if dim_rhs_not_in_set:
output_dims.append(dim)
all_shard_dims.add(dim)
else:
output_dims.append(None)
# Pad with None to match expected dimensionality
output_dims_length: int = len(output_dims)
lhs_pspec_length: int = len(lhs_pspec)
target_length: int = lhs_pspec_length - 1 + 1
needs_padding: bool = output_dims_length < target_length
while needs_padding:
output_dims.append(None)
output_dims_length = len(output_dims)
needs_padding = output_dims_length < target_length
result_spec: Ps = Ps(*output_dims)
return result_spec
@auto_pytree
class TensorParallelConfig:
    """Configuration for tensor parallelism.

    Attributes:
        mesh: JAX device mesh. If ``matmul_method`` is set and this is None,
            it is filled from ``es.get_incontext_mesh()``.
        axis_name: Mesh axis used for tensor parallelism. It must appear in
            ``mesh.axis_names`` whenever ``matmul_method`` is set.
        matmul_method: Collective matmul strategy (presumably a
            ``MatrixMultiplyMethod``; TODO confirm). When None, no mesh
            resolution or validation happens.
        reduce_output: Stored flag; not read by the code visible here.
        reduce_scatter_output: Intended to make column-parallel layers use
            reduce-scatter instead of all-gather, keeping the output sharded.
            Not read by the code visible here. Defaults to False.

    Raises:
        ValueError: On construction, if ``matmul_method`` is set and
            ``axis_name`` is not one of the mesh's axis names.
    """

    mesh: Mesh = None
    axis_name: str = "tp"
    matmul_method: None = None
    reduce_output: bool = False
    reduce_scatter_output: bool = False

    def __post_init__(self):
        # Without a matmul method there is no collective path to configure.
        if self.matmul_method is None:
            return
        if self.mesh is None:
            self.mesh = es.get_incontext_mesh()
        mesh_axes = self.mesh.axis_names
        if self.axis_name in mesh_axes:
            return
        raise ValueError(f"axis_name '{self.axis_name}' not found in mesh axis names: {str(mesh_axes)}")
class ParallelLinear(nn.Module):
    """A linear layer with hooks for tensor parallelism.

    Computes ``y = scale * (inputs @ kernel) + bias``, like ``nnx.Linear``.

    A ``TensorParallelConfig`` can be attached, but the collective forward
    path is currently disabled. ``__call__`` always runs ``native_forward``,
    a plain local matmul whose sharding is whatever JAX propagates from the
    inputs and parameters.

    Attributes:
        in_features: Number of input features.
        out_features: Number of output features, or a sequence of sizes for
            several projections fused into one kernel. In that case the
            kernel and bias span their sum.
        tp_merged: Number of fused projections (``len(out_features)`` for a
            sequence, otherwise 1).
        use_bias: Whether a bias term is added. Default is True.
        dtype: Computation dtype. When None, it is inferred by promoting the
            operand dtypes.
        param_dtype: Dtype of the parameters. Default is float32.
        precision: JAX precision for the matmul. Default is None.
        kernel_init: Initializer for the ``(in_features, total_out)`` kernel.
        bias_init: Initializer for the ``(total_out,)`` bias.
        parallel_config: Optional tensor-parallelism configuration. It is
            stored, but the forward pass does not use it.
        rngs: RNG streams used for initialization. ``nn.Rngs(0)`` if omitted.
    """

    # Parallel direction marker; subclasses set "row" or "column".
    _direction: tp.Literal["row", "column"] | None = None

    def __init__(
        self,
        in_features: int,
        out_features: int | tp.Sequence[int],
        *,
        scale: float | tp.Literal["fan_in", "fan_out"] = 1.0,
        use_bias: bool = True,
        dtype: Dtype | None = None,
        param_dtype: Dtype = jnp.float32,
        precision: PrecisionLike = None,
        kernel_init: Initializer = default_kernel_init,
        bias_init: Initializer = default_bias_init,
        parallel_config: TensorParallelConfig | None = None,
        rngs: nn.Rngs | None = None,
    ):
        """Build the kernel (and optional bias) parameters.

        Args:
            in_features: Number of input features.
            out_features: Output size, or a sequence of sizes for fused
                projections.
            scale: Factor multiplied into the matmul result before the bias
                is added. Accepts:

                - a number;
                - ``"fan_in"``, meaning ``in_features ** -0.5``;
                - ``"fan_out"``, meaning ``total_out ** -0.5``.
            use_bias: Whether to create and add a bias.
            dtype: Computation dtype (None means infer from the operands).
            param_dtype: Parameter dtype.
            precision: JAX matmul precision.
            kernel_init: Kernel initializer.
            bias_init: Bias initializer.
            parallel_config: Optional tensor-parallelism configuration.
            rngs: RNG streams. A deterministic ``nn.Rngs(0)`` is used when
                omitted.
        """
        if rngs is None:
            rngs = nn.Rngs(0)

        # Fused projections share one kernel, so parameters span the summed
        # width. Summing also handles a length-1 sequence correctly.
        if isinstance(out_features, tp.Sequence):
            tp_merged = len(out_features)
            total_out_features = sum(out_features)
        else:
            tp_merged = 1
            total_out_features = out_features

        if scale == "fan_in":
            scale_value = in_features**-0.5
        elif scale == "fan_out":
            scale_value = total_out_features**-0.5
        else:
            scale_value = scale

        # Skip the multiply entirely when the scale is a no-op.
        if scale_value != 1.0:

            def _scale_operator(x: Array) -> Array:
                return x * scale_value

        else:

            def _scale_operator(x: Array) -> Array:
                return x

        self._scale_operator: tp.Callable[[Array], Array] = _scale_operator
        self.in_features: int = in_features
        self.out_features = out_features
        self.use_bias: bool = use_bias
        self.dtype: Dtype | None = dtype
        self.param_dtype: Dtype = param_dtype
        self.precision: PrecisionLike = precision
        self.kernel_init: Initializer = kernel_init
        self.bias_init: Initializer = bias_init
        self.parallel_config: TensorParallelConfig | None = parallel_config
        self.rngs: nn.Rngs = rngs
        self.tp_merged: int = tp_merged

        # Kernel key is drawn before the bias key, preserving the original
        # RNG consumption order.
        self.kernel: nn.Param = nn.Param(kernel_init(rngs.params(), (in_features, total_out_features), param_dtype))
        if use_bias:
            # The bias must match the kernel's output width, which is the
            # summed width for fused projections.
            self.bias: nn.Param | None = nn.Param(bias_init(rngs.params(), (total_out_features,), param_dtype))
        else:
            self.bias = None

        # NOTE: a shard_map-based collective forward path (driven by
        # `distributed_matmul` and `parallel_config.matmul_method`) is
        # currently disabled, so this stays None.
        self.distributed_matmul: tp.Any | None = None

    def native_forward(
        self,
        inputs: Shaped[Array, "... in_features"],
        w: Array | None = None,
    ) -> Shaped[Array, "... out_features"]:
        """Apply the linear transformation locally (no explicit collectives).

        Args:
            inputs: Input array of shape ``(..., in_features)``.
            w: Optional kernel to use instead of ``self.kernel`` (e.g. tied
                weights). It is expected to have shape
                ``(in_features, total_out)``.

        Returns:
            Array of shape ``(..., total_out)``. It equals ``scale * (inputs
            @ kernel)``, plus the bias when ``use_bias`` is set. Operands are
            first promoted to ``self.dtype`` (or their common dtype when it is
            None).
        """
        kernel = self.kernel.value if w is None else w
        bias = self.bias.value if self.use_bias else None

        if bias is not None:
            inputs, kernel, bias = promote_dtype((inputs, kernel, bias), dtype=self.dtype)
        else:
            inputs, kernel = promote_dtype((inputs, kernel), dtype=self.dtype)

        # A 1-D input contracts straight into the kernel. Higher-rank inputs
        # keep their leading dims.
        subscript = "...ik,...kj->...ij" if inputs.ndim > 1 else "...k,...kj->...j"
        y = jnp.einsum(subscript, inputs, kernel, precision=self.precision, optimize=True)
        y = self._scale_operator(y)

        if bias is not None:
            # Broadcast the 1-D bias across all leading dims.
            y = y + jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))
        return y

    def __call__(
        self,
        inputs: Shaped[Array, "... in_features"],
        w: Array | None = None,
    ) -> Shaped[Array, "... out_features"]:
        """Forward pass; always delegates to ``native_forward``."""
        return self.native_forward(inputs=inputs, w=w)
class RowParallelLinear(ParallelLinear):
    """``ParallelLinear`` tagged with ``_direction = "row"``.

    Only the class-level ``_direction`` marker differs from the base class.
    Nothing in the visible code reads it; presumably it is consumed elsewhere
    (e.g. by sharding rules) -- TODO confirm.
    """

    _direction: tp.Literal["row", "column"] | None = "row"
class ColumnParallelLinear(ParallelLinear):
    """``ParallelLinear`` tagged with ``_direction = "column"``.

    Only the class-level ``_direction`` marker differs from the base class.
    Nothing in the visible code reads it; presumably it is consumed elsewhere
    (e.g. by sharding rules) -- TODO confirm.
    """

    _direction: tp.Literal["row", "column"] | None = "column"