# Source code for easydel.layers.linear

# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
# Copyright 2024 The Improved Version Contributors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import typing as tp

import jax.numpy as jnp
from eformer import escale as es
from eformer.pytree import auto_pytree
from flax import nnx as nn
from flax.nnx.nn.dtypes import promote_dtype
from jax import lax
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as Ps

from easydel.kernels.collective_matmul import (
	MatrixMultiplyMethod,
	create_distributed_matmul,
	prepare_matrix_for_all_gather,
	prepare_matrix_for_reduce_scatter,
)

# Type Aliases
Array = jnp.ndarray
Dtype = jnp.dtype
Initializer = nn.initializers.Initializer
PrecisionLike = lax.PrecisionLike
Shape = tp.Sequence[int]
AxisNames = tp.Union[str, tp.Sequence[str], tp.Tuple[str, ...]]

# Default initializers
default_kernel_init = nn.initializers.lecun_normal()
default_bias_init = nn.initializers.zeros


def get_sharding(arr: Array) -> tp.Optional[Ps]:
	"""Return the ``PartitionSpec`` of an array's sharding, if it has one.

	Args:
		arr: Array whose sharding should be inspected.

	Returns:
		``arr.sharding.spec`` when the array carries a sharding that exposes a
		``spec`` attribute (e.g. a ``NamedSharding``). Otherwise ``None``: this
		covers arrays without a ``sharding`` attribute and shardings that have
		no ``spec`` (e.g. a single-device sharding).
	"""
	sharding = getattr(arr, "sharding", None)
	# Only named/partition-spec based shardings expose `.spec`; other Sharding
	# implementations would raise AttributeError here, so treat them as "no spec".
	return getattr(sharding, "spec", None)
def get_output_partition_spec(
	lhs: Array,
	rhs: Array,
	method: MatrixMultiplyMethod,
	axis_name: str,
):
	"""Calculate the output partition spec for the distributed matmul.

	Args:
		lhs: Left-hand side array (inputs).
		rhs: Right-hand side array (weights).
		method: Matrix multiplication method. Currently not used by this function.
		axis_name: Mesh axis name placed on the last output dimension when
			``lhs`` is not 2-D.

	Returns:
		``None`` if either operand has no partition spec. For a 2-D ``lhs``, a
		spec built from the (padded) ``rhs`` spec. Otherwise a spec that is
		``None`` on every leading dimension and ``axis_name`` on the last one.
	"""
	from jax.sharding import PartitionSpec as P

	lhs_spec = get_sharding(lhs)
	rhs_spec = get_sharding(rhs)
	if lhs_spec is None or rhs_spec is None:
		return None

	if lhs.ndim == 2:
		# A PartitionSpec may list fewer entries than the array's rank (missing
		# trailing entries mean "not sharded"), so pad before indexing to avoid
		# an IndexError on specs like P() or P("tp").
		rhs_axes = tuple(rhs_spec) + (None,) * max(0, 2 - len(rhs_spec))
		# NOTE(review): output dim 0 takes rhs_axes[1] and output dim 1 takes
		# rhs_axes[0], and lhs_spec is ignored; confirm this matches the layout
		# produced by the collective matmul kernels.
		return P(rhs_axes[1], rhs_axes[0])

	return P(*((None,) * (lhs.ndim - 1) + (axis_name,)))
def _append_unclaimed_axes(dims, claimed_axes, output_dims):
	"""Append one output entry per spec entry in ``dims``, dropping reused mesh axes.

	Args:
		dims: Partition-spec entries (axis name, tuple of axis names, or None).
		claimed_axes: Set of axes already used by earlier output dims; updated in place.
		output_dims: List of output spec entries; appended to in place.
	"""
	for dim in dims:
		if isinstance(dim, tuple):
			remaining = tuple(axis for axis in dim if axis not in claimed_axes)
			claimed_axes.update(dim)
			# Always keep the positional slot (None if nothing is left, including
			# an empty tuple) so later dimensions do not shift left.
			output_dims.append(remaining if remaining else None)
		elif dim in claimed_axes:
			# A mesh axis may shard at most one dimension of the output.
			output_dims.append(None)
		else:
			output_dims.append(dim)
			claimed_axes.add(dim)


def get_matmul_output_sharding(lhs_pspec, rhs_pspec):
	"""Determine the output PartitionSpec of ``X @ W`` from the operands' specs.

	For ``X (..., K) @ W (K, N)`` the contracting dimension ``K`` disappears.
	The output keeps X's leading dimensions followed by W's last dimension.
	A mesh axis already used by an earlier output dimension is replaced by
	``None``, so the result never shards two dimensions over the same axis.
	The result is padded with ``None`` to the rank implied by ``lhs_pspec``.

	The specs are assumed to list one entry per array dimension (the function
	has no access to the array ranks).

	Args:
		lhs_pspec: PartitionSpec of the left-hand side matrix X.
		rhs_pspec: PartitionSpec of the right-hand side matrix W.

	Returns:
		PartitionSpec for the output of ``X @ W``. If either spec is None,
		returns an empty ``Ps()``.
	"""
	if lhs_pspec is None or rhs_pspec is None:
		return Ps()

	lhs_axes = tuple(lhs_pspec)
	rhs_axes = tuple(rhs_pspec)

	# LHS contributes every dimension except its last (contracting) one.
	lhs_output_dims = lhs_axes[:-1]
	# RHS contributes only its last dimension, and only when the spec names it.
	# With fewer than two entries, the only named entry (if any) is the
	# contracting dimension, and the output dimension is unsharded. The slot is
	# then filled with None by the padding below.
	rhs_output_dims = rhs_axes[-1:] if len(rhs_axes) >= 2 else ()

	claimed_axes = set()
	output_dims = []
	_append_unclaimed_axes(lhs_output_dims, claimed_axes, output_dims)
	_append_unclaimed_axes(rhs_output_dims, claimed_axes, output_dims)

	# All LHS dims except the contracting one, plus one from the RHS.
	expected_dims = len(lhs_axes)
	output_dims.extend([None] * (expected_dims - len(output_dims)))
	return Ps(*output_dims)
@auto_pytree
class TensorParallelConfig:
	"""Configuration for tensor parallelism.

	Attributes:
		mesh: The JAX device mesh. When ``None`` and ``matmul_method`` is set,
			it is filled in from ``es.get_incontext_mesh()`` in ``__post_init__``.
		axis_name: Name of the mesh axis used for tensor parallelism.
		matmul_method: Distributed matmul strategy (``MatrixMultiplyMethod``).
			``None`` disables the collective path and skips mesh validation.
		reduce_output: Output-reduction flag. NOTE(review): not read by
			``ParallelLinear``; confirm which caller consumes it.
		reduce_scatter_output: If True, use reduce-scatter instead of all-gather
			for the output so it stays sharded (useful for sequence parallelism
			or a following row-parallel layer). NOTE(review): not read by
			``ParallelLinear``; confirm which caller consumes it.

	Raises:
		ValueError: If ``matmul_method`` is set but no mesh is given or found
			in context, or if ``axis_name`` is not one of the mesh's axis names.
	"""

	mesh: Mesh = None
	axis_name: str = "tp"
	matmul_method: tp.Optional[MatrixMultiplyMethod] = None
	reduce_output: bool = False
	reduce_scatter_output: bool = False

	def __post_init__(self):
		# The mesh only matters when a collective matmul will actually run.
		if self.matmul_method is None:
			return
		if self.mesh is None:
			self.mesh = es.get_incontext_mesh()
		if self.mesh is None:
			# Fail with a clear message instead of an AttributeError on `.axis_names`.
			raise ValueError(
				"matmul_method is set but no mesh was provided and none was found in context."
			)
		if self.axis_name not in self.mesh.axis_names:
			raise ValueError(
				f"axis_name '{self.axis_name}' not found in mesh axis names: {self.mesh.axis_names}"
			)
class ParallelLinear(nn.Module):
	"""A Linear layer with optional tensor-parallel (collective) execution.

	Behaves like ``nnx.Linear`` unless ``parallel_config.matmul_method`` is set.
	In that case the forward pass runs the matmul through ``shard_map`` with a
	kernel built by ``create_distributed_matmul``.

	Attributes:
		in_features: Number of input features.
		out_features: Number of output features, or a sequence of sizes for
			several projections merged into one kernel. Their sum is the
			kernel/bias width, and ``tp_merged`` records how many were merged.
		use_bias: Whether to include a bias term. Default is True.
		dtype: Computation dtype passed to ``promote_dtype`` (``None`` lets it
			be inferred from the operands).
		param_dtype: The dtype of the parameters. Default is float32.
		precision: JAX precision for the einsum in the non-distributed path.
		kernel_init: Initializer for the kernel weights.
		bias_init: Initializer for the bias.
		parallel_config: Tensor-parallel configuration. If None, or if its
			``matmul_method`` is None, the layer is a standard Linear layer.
		rngs: RNG streams for parameter initialization; ``nn.Rngs(0)`` if None.
	"""

	def __init__(
		self,
		in_features: int,
		out_features: tp.Union[int, tp.Sequence[int]],
		*,
		use_bias: bool = True,
		dtype: tp.Optional[Dtype] = None,
		param_dtype: Dtype = jnp.float32,
		precision: PrecisionLike = None,
		kernel_init: Initializer = default_kernel_init,
		bias_init: Initializer = default_bias_init,
		parallel_config: tp.Optional[TensorParallelConfig] = None,
		rngs: tp.Optional[nn.Rngs] = None,
	):
		if rngs is None:
			rngs = nn.Rngs(0)
		self.in_features = in_features
		self.out_features = out_features
		self.use_bias = use_bias
		self.dtype = dtype
		self.param_dtype = param_dtype
		self.precision = precision
		self.kernel_init = kernel_init
		self.bias_init = bias_init
		self.parallel_config = parallel_config
		self.rngs = rngs

		is_merged = isinstance(out_features, tp.Sequence)
		self.tp_merged = len(out_features) if is_merged else 1
		# Sum for any sequence (including length 1) so the width is always an int.
		out_features_sum = sum(out_features) if is_merged else out_features

		kernel_key = rngs.params()
		kernel_shape = (in_features, out_features_sum)
		self.kernel = nn.Param(kernel_init(kernel_key, kernel_shape, param_dtype))

		if use_bias:
			bias_key = rngs.params()
			# Bias width must match the kernel's output width, even for merged projections.
			bias_shape = (out_features_sum,)
			self.bias = nn.Param(bias_init(bias_key, bias_shape, param_dtype))
		else:
			self.bias = None

		self.distributed_matmul = None
		if parallel_config is not None and parallel_config.matmul_method is not None:
			self.distributed_matmul = create_distributed_matmul(
				parallel_config.matmul_method,
				parallel_config.axis_name,
			)

	def collective_forward(self, inputs: Array) -> Array:
		"""Apply the layer through the configured distributed matmul kernel.

		The inputs are flattened to 2-D ``(-1, in_features)``. Partition specs
		come from the operands' shardings. The kernel is prepared for the
		configured method, and the matmul runs under ``shard_map``. The result
		is reshaped back to ``(..., total_out_features)`` and the bias is added.
		``self.precision`` is not forwarded to the distributed kernel.

		Args:
			inputs: Input array of shape ``(..., in_features)``.

		Returns:
			Output array of shape ``(..., total_out_features)``.
		"""
		config = self.parallel_config
		kernel = self.kernel.value
		bias = self.bias.value if self.use_bias else None
		if bias is not None:
			inputs, kernel, bias = promote_dtype((inputs, kernel, bias), dtype=self.dtype)
		else:
			inputs, kernel = promote_dtype((inputs, kernel), dtype=self.dtype)

		# Capture the output width before the kernel is rearranged for the collective;
		# using self.out_features here would break when it is a sequence.
		out_width = kernel.shape[-1]

		orig_shape = inputs.shape
		inputs_2d = inputs.reshape(-1, orig_shape[-1])

		input_spec = get_sharding(inputs_2d)
		kernel_spec = get_sharding(kernel)
		output_spec = get_output_partition_spec(
			inputs_2d,
			kernel,
			config.matmul_method,
			config.axis_name,
		)

		if config.matmul_method == MatrixMultiplyMethod.REDUCE_SCATTER:
			kernel = prepare_matrix_for_reduce_scatter(kernel, config.mesh, config.axis_name)
		elif config.matmul_method == MatrixMultiplyMethod.ALL_GATHER:
			kernel = prepare_matrix_for_all_gather(kernel, config.mesh, config.axis_name)

		output_2d = shard_map(
			self.distributed_matmul,
			mesh=config.mesh,
			in_specs=(input_spec, kernel_spec),
			out_specs=output_spec,
			check_rep=False,
		)(inputs_2d, kernel)

		output = output_2d.reshape(orig_shape[:-1] + (out_width,))
		if bias is not None:
			output = output + jnp.reshape(bias, (1,) * (output.ndim - 1) + (-1,))
		return output

	def native_forward(self, inputs: Array) -> Array:
		"""Apply the linear transformation locally, without collectives.

		Computes ``inputs @ kernel (+ bias)`` with ``jnp.einsum`` at
		``self.precision`` after promoting operands via ``promote_dtype``.

		Args:
			inputs: Input array of shape ``(..., in_features)``.

		Returns:
			Output array of shape ``(..., total_out_features)``.
		"""
		kernel = self.kernel.value
		bias = self.bias.value if self.use_bias else None
		if bias is not None:
			inputs, kernel, bias = promote_dtype((inputs, kernel, bias), dtype=self.dtype)
		else:
			inputs, kernel = promote_dtype((inputs, kernel), dtype=self.dtype)

		subscript = "...ik,...kj->...ij" if inputs.ndim > 1 else "...k,...kj->...j"
		y = jnp.einsum(
			subscript,
			inputs,
			kernel,
			precision=self.precision,
			optimize=True,
		)
		if bias is not None:
			y = y + jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))
		return y

	def __call__(self, inputs: Array) -> Array:
		"""Dispatch to the collective path when a distributed matmul is configured."""
		if self.distributed_matmul is None:
			return self.native_forward(inputs=inputs)
		return self.collective_forward(inputs=inputs)