easydel.layers.linear

class easydel.layers.linear.ParallelLinear(*args: Any, **kwargs: Any)

Bases: Module

A linear layer with optional tensor parallelism.

Behaves like nnx.Linear but can distribute computation and parameters across devices based on the TensorParallelConfig.
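The ROW and COLUMN parallelism mentioned in this section are two ways of splitting the same product x @ W + b, and both reproduce the unsharded result exactly. The sketch below checks that on one host with NumPy: "devices" are simulated by splitting arrays, concatenation stands in for an all-gather, and summation stands in for an all-reduce. It assumes an nnx.Linear-style kernel of shape (in_features, out_features); column_parallel and row_parallel are hypothetical names, not EasyDeL functions.

```python
import numpy as np

def column_parallel(x, w, b, n_shards):
    """Simulate a column-parallel linear layer on a single host."""
    # Each simulated device owns a slice of the kernel's output axis and of the bias.
    w_shards = np.split(w, n_shards, axis=1)
    b_shards = np.split(b, n_shards)
    local_outputs = [x @ ws + bs for ws, bs in zip(w_shards, b_shards)]
    # Concatenating the local output slices stands in for an all-gather.
    return np.concatenate(local_outputs, axis=-1)

def row_parallel(x, w, b, n_shards):
    """Simulate a row-parallel linear layer on a single host."""
    # Inputs are split along the feature axis, the kernel along its input axis.
    x_shards = np.split(x, n_shards, axis=-1)
    w_shards = np.split(w, n_shards, axis=0)
    # Each device yields a partial sum over its slice of in_features;
    # summing the partials stands in for an all-reduce. The bias is added once.
    partials = [xs @ ws for xs, ws in zip(x_shards, w_shards)]
    return np.sum(partials, axis=0) + b

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8))  # (batch, in_features)
w = rng.standard_normal((8, 6))  # (in_features, out_features)
b = rng.standard_normal(6)
reference = x @ w + b

assert np.allclose(column_parallel(x, w, b, n_shards=2), reference)
assert np.allclose(row_parallel(x, w, b, n_shards=2), reference)
```

Column parallelism needs out_features divisible by the number of shards; row parallelism needs in_features divisible by it.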

in_features

Number of input features.

out_features

Number of output features.

use_bias

Whether to include a bias term. Default is True.

dtype

The dtype of the computation (defaults to the dtype inferred from the input).

param_dtype

The dtype of the parameters. Default is float32.

precision

JAX precision for the dot product. Default is None.

kernel_init

Initializer for the kernel weights.

bias_init

Initializer for the bias.

parallel_config

Configuration for tensor parallelism. If None, the layer behaves like a standard non-parallel Linear layer.
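With parallel_config=None the layer reduces to an ordinary linear map, so the remaining attributes map onto a few lines of code. The sketch below is a hypothetical NumPy stand-in (reference_linear is not an EasyDeL function) that reads "inferred from input" literally and uses the input's dtype as the default computation dtype; the real layer may instead promote across inputs and parameters, as Flax layers commonly do.

```python
import numpy as np

def reference_linear(x, kernel, bias=None, dtype=None):
    """Hypothetical non-parallel forward pass (parallel_config=None), nnx.Linear style."""
    if dtype is None:
        # Assumption: "inferred from input" is taken to mean the input's dtype.
        dtype = x.dtype
    # Parameters are stored in param_dtype and cast to the computation dtype.
    y = x.astype(dtype) @ kernel.astype(dtype)
    if bias is not None:  # corresponds to use_bias=True
        y = y + bias.astype(dtype)
    return y

kernel = np.ones((3, 2), dtype=np.float32)  # param_dtype=float32, shape (in, out)
bias = np.zeros(2, dtype=np.float32)
x = np.ones((5, 3), dtype=np.float16)

y_default = reference_linear(x, kernel, bias)                 # computes in float16
y_fp32 = reference_linear(x, kernel, bias, dtype=np.float32)  # computes in float32
```

Keeping param_dtype at float32 while computing in a lower precision is the usual mixed-precision setup: the stored weights keep full precision even when the matmul does not.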

collective_forward(inputs: Array) → Array
native_forward(inputs: Array) → Array

Applies the linear transformation with optional tensor parallelism.

Parameters

inputs – The input array, of shape (…, in_features). For ROW parallelism, the input is expected to be sharded along the feature dimension over the mesh axis named by axis_name.

Returns

The transformed output array, of shape (…, out_features). The output is sharded only for COLUMN parallelism with reduce_scatter_output set to True; in all other cases it is fully replicated.
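Whether the output ends up replicated or sharded depends on which collective completes the layer. The sketch below simulates, on one host with NumPy, what each device holds after an all-gather, an all-reduce, or a reduce-scatter; list positions stand for devices, and the helper names are illustrative, not EasyDeL functions or the layer's internal code path.

```python
import numpy as np

def all_gather(local_slices):
    """Each device contributes a slice; every device receives the concatenation."""
    full = np.concatenate(local_slices, axis=-1)
    return [full for _ in local_slices]  # replicated

def all_reduce(partial_sums):
    """Each device contributes a partial sum; every device receives the total."""
    total = np.sum(partial_sums, axis=0)
    return [total for _ in partial_sums]  # replicated

def reduce_scatter(partial_sums):
    """Partial sums are added, then device i keeps only slice i of the total."""
    total = np.sum(partial_sums, axis=0)
    return np.split(total, len(partial_sums), axis=-1)  # sharded

rng = np.random.default_rng(0)
partials = [rng.standard_normal((4, 6)) for _ in range(2)]  # 2 simulated devices

replicated = all_reduce(partials)  # each device holds shape (4, 6)
sharded = reduce_scatter(partials)  # each device holds shape (4, 3)
# An all-reduce is equivalent to a reduce-scatter followed by an all-gather.
assert np.allclose(all_gather(sharded)[0], replicated[0])
```

Stopping after the reduce-scatter leaves each device 1/n of the output, which is what makes a sharded output cheaper to produce when the next layer can consume it sharded.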

class easydel.layers.linear.TensorParallelConfig(mesh: Mesh = None, axis_name: str = 'tp', matmul_method: Optional[MatrixMultiplyMethod] = None, reduce_output: bool = False, reduce_scatter_output: bool = False)

Bases: object

Configuration for Tensor Parallelism.

mesh

The JAX device mesh.

Type

jax._src.mesh.Mesh

axis_name

The name of the mesh axis to use for tensor parallelism.

Type

str

matmul_method

The matrix-multiplication method (a MatrixMultiplyMethod). Defaults to None.

reduce_scatter_output

If True and the layer uses COLUMN parallelism, use reduce-scatter instead of all-gather for the output. This keeps the output sharded, which is useful for sequence parallelism or for subsequent RowParallel layers. Defaults to False.

Type

bool

axis_name: str = 'tp'
classmethod from_dict(data: Dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

matmul_method: Optional[MatrixMultiplyMethod] = None
mesh: Mesh = None
reduce_output: bool = False
reduce_scatter_output: bool = False
replace(**kwargs)

Creates a new instance with specified fields replaced.

to_dict() → Dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.
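The replace, to_dict/from_dict and to_json/from_json methods follow the common immutable-config pattern: replace returns a modified copy, and the serializers round-trip. As an illustration only, the stand-in below uses the standard-library dataclasses and json modules. TPConfigSketch is hypothetical, mirrors just three of the documented fields, and implies nothing about how EasyDeL actually serializes mesh or matmul_method.

```python
import dataclasses
import json

@dataclasses.dataclass(frozen=True)
class TPConfigSketch:
    """Hypothetical stand-in for TensorParallelConfig (not the EasyDeL class).

    Only JSON-friendly fields are mirrored; mesh and matmul_method are omitted.
    """
    axis_name: str = "tp"
    reduce_output: bool = False
    reduce_scatter_output: bool = False

    def replace(self, **kwargs):
        # Returns a new instance; the original stays unchanged (frozen dataclass).
        return dataclasses.replace(self, **kwargs)

    def to_dict(self):
        return dataclasses.asdict(self)

    @classmethod
    def from_dict(cls, data):
        return cls(**data)

    def to_json(self, **kwargs):
        return json.dumps(self.to_dict(), **kwargs)

    @classmethod
    def from_json(cls, json_str):
        return cls.from_dict(json.loads(json_str))

base = TPConfigSketch()
sharded_out = base.replace(reduce_scatter_output=True)
assert base.reduce_scatter_output is False  # original untouched
assert TPConfigSketch.from_json(sharded_out.to_json()) == sharded_out
```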

easydel.layers.linear.get_matmul_output_sharding(lhs_pspec, rhs_pspec)

Determine the output sharding PartitionSpec for a matrix multiplication based on the partition specs of the input matrices.

For a matrix multiplication X @ W:

  • The contracting dimensions are reduced during the matmul.

  • The non-contracting dimensions of X and W determine the output sharding.

  • No sharding dimension appears more than once in the output.

  • The output spec is padded with None where needed so that it has the correct dimensionality.

Parameters
  • lhs_pspec – PartitionSpec for the left-hand side matrix X

  • rhs_pspec – PartitionSpec for the right-hand side matrix W

Returns

PartitionSpec for the output of X @ W
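The rules listed for get_matmul_output_sharding can be sketched on plain tuples standing in for PartitionSpecs. This is not the library's implementation; it assumes a 2-D kernel W of shape (in, out), contraction over the last axis of X and the first axis of W, and that when a mesh axis would appear twice the earlier output dimension keeps it. matmul_output_spec and _axis_names are hypothetical names.

```python
def _axis_names(entry):
    """Mesh axis names used by one spec entry (None, a name, or a tuple of names)."""
    if entry is None:
        return ()
    return tuple(entry) if isinstance(entry, (tuple, list)) else (entry,)

def matmul_output_spec(lhs_spec, rhs_spec, lhs_ndim):
    """Sketch: output spec of X @ W for X of rank lhs_ndim and a 2-D kernel W (in, out)."""
    # Pad both specs with None so every dimension has an explicit entry.
    lhs = tuple(lhs_spec) + (None,) * (lhs_ndim - len(lhs_spec))
    rhs = tuple(rhs_spec) + (None,) * (2 - len(rhs_spec))
    # Drop the contracting dims (last of X, first of W); keep the rest in order.
    candidate = lhs[:-1] + (rhs[-1],)
    out, used = [], set()
    for entry in candidate:
        names = _axis_names(entry)
        if names and used.isdisjoint(names):
            out.append(entry)
            used.update(names)
        else:
            # Unsharded, or a mesh axis already used by an earlier output dim.
            out.append(None)
    return tuple(out)

print(matmul_output_spec(("dp", None), (None, "tp"), lhs_ndim=2))  # ('dp', 'tp')
print(matmul_output_spec((None, "tp"), ("tp", None), lhs_ndim=2))  # (None, None)
print(matmul_output_spec(("tp", None), (None, "tp"), lhs_ndim=2))  # ('tp', None)
```

The second case is the row-parallel layout: the output spec is unsharded, but each device's local result is only a partial sum over its slice of the contracting dimension, so a reduction (all-reduce or reduce-scatter) is still needed.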

easydel.layers.linear.get_output_partition_spec(lhs: Array, rhs: Array, method: MatrixMultiplyMethod, axis_name: str)

Calculate output partition spec based on input arrays and matmul method.

Parameters
  • lhs – Left-hand side array (inputs)

  • rhs – Right-hand side array (weights)

  • method – Matrix multiplication method

  • axis_name – Axis name for sharding

Returns

Output partition specification

easydel.layers.linear.get_sharding(arr: Array) → Optional[PartitionSpec]

Gets the sharding of an array.

Parameters

arr – Array to get sharding from.

Returns

The array's sharding as a PartitionSpec, or None.
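The Optional[PartitionSpec] contract of get_sharding can be sketched with public JAX APIs. The sketch assumes the spec is read from a NamedSharding and that any other sharding, or an object with no sharding at all such as a NumPy array, yields None. partition_spec_of is a hypothetical name, and the library's own logic may differ. A one-device mesh is used so the example runs on any machine.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

def partition_spec_of(arr):
    """Hypothetical helper: the PartitionSpec of arr, or None if it has none."""
    sharding = getattr(arr, "sharding", None)
    # Only NamedSharding carries a PartitionSpec; e.g. SingleDeviceSharding does not.
    if isinstance(sharding, NamedSharding):
        return sharding.spec
    return None

# A 1-device mesh with a single "tp" axis, so this runs on CPU-only machines too.
mesh = Mesh(np.array(jax.devices()[:1]), ("tp",))
placed = jax.device_put(jnp.ones((4, 8)), NamedSharding(mesh, PartitionSpec(None, "tp")))

spec = partition_spec_of(placed)         # the spec the array was placed with
no_spec = partition_spec_of(np.ones(3))  # NumPy arrays carry no sharding
```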