easydel.layers.linear

Linear layers with parallel and distributed computation support.

Provides optimized linear layers with support for model parallelism, tensor parallelism, and various sharding strategies for distributed training.

Classes:

  • ParallelLinear: Linear layer with tensor/model parallelism support

  • Linear: Standard linear layer (alias for ParallelLinear)

Functions:

  • get_sharding: Extract the sharding specification from an array

  • get_output_partition_spec: Calculate the output sharding for a matmul

  • get_matmul_output_sharding: Determine the output sharding from input specs

Key Features:
  • Automatic sharding and gathering for distributed training

  • Support for various matrix multiplication methods

  • Mixed precision support

  • Efficient initialization strategies

  • Integration with JAX’s shard_map

Example

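>>> import jax.numpy as jnp
>>> from flax import nnx
>>> from easydel.layers.linear import ParallelLinear
>>> # Create a linear layer; without a parallel_config it is not sharded
>>> layer = ParallelLinear(
...     in_features=768,
...     out_features=3072,
...     use_bias=True,
...     dtype=jnp.bfloat16,
...     rngs=nnx.Rngs(0),
... )
>>> input_tensor = jnp.ones((2, 768), dtype=jnp.bfloat16)
>>> output = layer(input_tensor)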
class easydel.layers.linear.ColumnParallelLinear(*args: Any, **kwargs: Any)

Bases: ParallelLinear
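Column parallelism is commonly understood (as in Megatron-LM) to mean splitting the kernel along its output-feature axis, so each device computes a slice of the output columns from the full input. The sketch below shows that generic pattern with plain JAX `shard_map`. It is not EasyDeL's implementation: the helper `column_parallel_matmul` is hypothetical, and the mesh axis is assumed to be named `"tp"` to match the `TensorParallelConfig` default.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

try:  # newer JAX exposes shard_map at the top level
    from jax import shard_map
except ImportError:  # older JAX keeps it under jax.experimental
    from jax.experimental.shard_map import shard_map

# Assumed 1-D mesh over all local devices, with the axis named "tp".
mesh = Mesh(np.array(jax.devices()), ("tp",))


def column_parallel_matmul(x, w, gather_output=False):
    """Hypothetical helper: x is replicated, w is split along out_features."""

    def local(x_full, w_cols):
        # Each device multiplies the full input by its own slice of output columns.
        return x_full @ w_cols

    y = shard_map(
        local,
        mesh=mesh,
        in_specs=(P(), P(None, "tp")),  # x replicated, w split on its last axis
        out_specs=P(None, "tp"),  # y stays split on out_features
    )(x, w)
    if gather_output:
        # Replicate the full output on every device (XLA performs the all-gather).
        y = jax.device_put(y, NamedSharding(mesh, P()))
    return y
```

Leaving the output split on `out_features` is what allows a following row-parallel layer to consume it without an extra gather; `out_features` must be divisible by the mesh size.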

class easydel.layers.linear.ParallelLinear(*args: Any, **kwargs: Any)

Bases: Module

A Linear layer with optional parallelism.

Behaves like nnx.Linear but can distribute computation and parameters across devices based on the TensorParallelConfig.

in_features

Number of input features.

out_features

Number of output features.

use_bias

Whether to include a bias term. Default is True.

dtype

The dtype of the computation. If not set, it is inferred from the input.

param_dtype

The dtype of the parameters. Default is float32.

precision

JAX precision for the dot product. Default is None.

kernel_init

Initializer for the kernel weights.

bias_init

Initializer for the bias.

parallel_config

Configuration for tensor parallelism. If None, the layer behaves like a standard non-parallel Linear layer.
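With `parallel_config=None` the layer is described as behaving like a standard `nnx.Linear`. The plain-JAX sketch below shows what that computation amounts to under the attributes listed above, assuming the usual Flax convention of a kernel shaped `(in_features, out_features)` that is stored in `param_dtype` and cast to `dtype` for the matmul. The function `reference_linear` is hypothetical and not part of EasyDeL.

```python
import jax
import jax.numpy as jnp


def reference_linear(x, kernel, bias=None, dtype=None, precision=None):
    """Hypothetical reference: y = x @ kernel + bias over the last axis of x."""
    # Infer the computation dtype from the input when none is given.
    dtype = dtype or x.dtype
    x = x.astype(dtype)
    kernel = kernel.astype(dtype)  # parameters may be stored in param_dtype (e.g. float32)
    # Contract the trailing in_features axis; leading axes are treated as batch axes.
    y = jnp.dot(x, kernel, precision=precision)
    if bias is not None:
        y = y + bias.astype(dtype)
    return y


key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
in_features, out_features = 768, 3072
kernel = jax.random.normal(key_w, (in_features, out_features), jnp.float32) * 0.02
bias = jnp.zeros((out_features,), jnp.float32)
x = jax.random.normal(key_x, (2, 16, in_features), jnp.bfloat16)

y = reference_linear(x, kernel, bias)  # dtype inferred from x as bfloat16
print(y.shape, y.dtype)  # (2, 16, 3072) bfloat16
```

Keeping parameters in float32 while computing in bfloat16 is the usual mixed-precision split that the separate `dtype` and `param_dtype` attributes allow.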

native_forward(inputs: Shaped[Array, '... in_features'], w: jax.Array | None = None) → Shaped[Array, '... out_features']

Applies the linear transformation with optional tensor parallelism.

Parameters

inputs – The input array. Shape: (…, in_features). For ROW parallelism, the input is expected to be sharded along the feature dimension (axis_name).

Returns

The transformed output array. Shape: (…, out_features). For COLUMN parallelism, the output is sharded if reduce_scatter_output is True; otherwise, it is fully replicated.

class easydel.layers.linear.RowParallelLinear(*args: Any, **kwargs: Any)

Bases: ParallelLinear
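Row parallelism conventionally splits the kernel along its input-feature axis, which is consistent with the `native_forward` note that ROW inputs arrive sharded along the feature dimension. Each device then holds only a partial sum, and an all-reduce (`jax.lax.psum`) combines them. A minimal plain-JAX sketch, assuming a 1-D mesh axis named `"tp"`; `row_parallel_matmul` is a hypothetical helper, not EasyDeL's code:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer JAX exposes shard_map at the top level
    from jax import shard_map
except ImportError:  # older JAX keeps it under jax.experimental
    from jax.experimental.shard_map import shard_map

# Assumed 1-D mesh over all local devices, with the axis named "tp".
mesh = Mesh(np.array(jax.devices()), ("tp",))


def row_parallel_matmul(x, w):
    """Hypothetical helper: x split on in_features, w split on its rows."""

    def local(x_blk, w_rows):
        # Each device contracts only its slice of in_features: a partial sum.
        partial = x_blk @ w_rows
        # All-reduce the partial sums so every device holds the full result.
        return jax.lax.psum(partial, "tp")

    return shard_map(
        local,
        mesh=mesh,
        in_specs=(P(None, "tp"), P("tp", None)),
        out_specs=P(),  # after psum the result is identical on every device
    )(x, w)
```

Here `in_features` must be divisible by the mesh size.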

class easydel.layers.linear.TensorParallelConfig(mesh: Mesh = None, axis_name: str = 'tp', matmul_method: None = None, reduce_output: bool = False, reduce_scatter_output: bool = False)

Bases: object

Configuration for Tensor Parallelism.

mesh

The JAX device mesh.

Type

jax.sharding.Mesh

axis_name

The name of the mesh axis to use for tensor parallelism.

Type

str
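`TensorParallelConfig.mesh` expects a JAX device mesh, and `axis_name` should name one of its axes. A minimal sketch of building a one-axis mesh that uses the default name `"tp"`, assuming every local device is devoted to tensor parallelism:

```python
import numpy as np
import jax
from jax.sharding import Mesh

# Arrange every available device along a single axis named "tp" (the config default).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("tp",))

print(mesh.axis_names)  # ('tp',)
print(mesh.shape["tp"] == len(jax.devices()))  # True
```

Such a mesh could then be passed as `TensorParallelConfig(mesh=mesh, axis_name="tp")`, following the constructor signature above.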

matmul_method

The matrix multiplication method (MatrixMultiplyMethod).

reduce_scatter_output

If True and parallel_type is COLUMN, use reduce-scatter instead of all-gather for the output. This keeps the output sharded (useful for sequence parallelism or subsequent RowParallel layers). Defaults to False.
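Reduce-scatter and all-reduce differ only in what each device keeps after summing partial results: `jax.lax.psum` leaves the full sum on every device, while `jax.lax.psum_scatter` leaves each device one slice of it. The sketch below applies reduce-scatter to a row-style partial product as a generic illustration of why the output stays sharded. It does not describe how EasyDeL wires this flag into column-parallel layers, and both `matmul_reduce_scatter` and the `"tp"` axis name are assumptions.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer JAX exposes shard_map at the top level
    from jax import shard_map
except ImportError:  # older JAX keeps it under jax.experimental
    from jax.experimental.shard_map import shard_map

# Assumed 1-D mesh over all local devices, with the axis named "tp".
mesh = Mesh(np.array(jax.devices()), ("tp",))


def matmul_reduce_scatter(x, w):
    """Hypothetical helper: sum partial products but keep the result sharded."""

    def local(x_blk, w_rows):
        partial = x_blk @ w_rows  # partial sum over this device's in_features slice
        # Sum across "tp", then give each device one slice of the last axis.
        return jax.lax.psum_scatter(partial, "tp", scatter_dimension=1, tiled=True)

    return shard_map(
        local,
        mesh=mesh,
        in_specs=(P(None, "tp"), P("tp", None)),
        out_specs=P(None, "tp"),  # each device holds out_features / mesh size columns
    )(x, w)
```

Compared with an all-reduce followed by slicing, reduce-scatter moves less data, which is why it suits a sharded downstream consumer.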

Type

bool

axis_name: str = 'tp'

classmethod from_dict(data: dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

matmul_method: None = None

mesh: Mesh = None

reduce_output: bool = False

reduce_scatter_output: bool = False

replace(**kwargs)

Creates a new instance with specified fields replaced.

to_dict() → dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.

easydel.layers.linear.get_matmul_output_sharding(lhs_pspec: jax.sharding.PartitionSpec | None, rhs_pspec: jax.sharding.PartitionSpec | None) → PartitionSpec

Determine output sharding for matrix multiplication.

Calculates the output PartitionSpec based on input partition specs, following matrix multiplication rules where contracting dimensions are reduced and non-contracting dimensions determine output sharding.

For X @ W:

  • Contracting dimensions are reduced during the matmul

  • Non-contracting dimensions determine the output sharding

  • No mesh axis appears more than once in the output spec

  • The output has the correct dimensionality, padded with None if needed

Parameters
  • lhs_pspec – PartitionSpec for the left-hand side matrix X.

  • rhs_pspec – PartitionSpec for the right-hand side matrix W.

Returns

PartitionSpec for the output of X @ W.
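These rules can be illustrated with a small hypothetical function that derives the output spec of `X @ W` for a 2-D kernel whose last axis holds `out_features`. It assumes both specs are given and list every dimension, and it is a sketch of the stated rules rather than EasyDeL's implementation, which per the description above also handles `None` inputs and padding.

```python
from jax.sharding import PartitionSpec as P


def matmul_output_pspec(lhs_pspec, rhs_pspec):
    """Hypothetical sketch of the X @ W output-spec rules (not EasyDeL's code)."""
    lhs, rhs = tuple(lhs_pspec), tuple(rhs_pspec)
    # X's last dim and W's first dim are contracted; the remaining dims survive.
    out = list(lhs[:-1]) + [rhs[-1]]
    used = set()
    for i, entry in enumerate(out):
        # An entry may be a single axis name, a tuple of names, or None.
        names = entry if isinstance(entry, tuple) else (entry,)
        names = tuple(name for name in names if name is not None)
        if any(name in used for name in names):
            out[i] = None  # this mesh axis already shards an earlier output dim
        else:
            used.update(names)
    return P(*out)
```

For example, `P("dp", None)` for X and `P(None, "tp")` for W give an output of `P("dp", "tp")`, while `P("tp", None)` with `P(None, "tp")` drops the repeated `"tp"` and gives `P("tp", None)`.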

easydel.layers.linear.get_output_partition_spec(lhs: Shaped[Array, '...'], rhs: Shaped[Array, '...'], method: MatrixMultiplyMethod, axis_name: str) → Ps | None

Calculate output partition spec for matrix multiplication.

Determines the appropriate output sharding based on input sharding and the matrix multiplication method used.

Parameters
  • lhs – Left-hand side array (inputs).

  • rhs – Right-hand side array (weights).

  • method – Matrix multiplication method.

  • axis_name – Axis name for sharding.

Returns

Output partition specification for the result.

easydel.layers.linear.get_sharding(arr: Shaped[Array, '...']) → jax.sharding.PartitionSpec | None

Get the sharding specification of an array.

Extracts the PartitionSpec from a sharded JAX array.

Parameters

arr – Array to get sharding from.

Returns

PartitionSpec of the array, or None if not sharded.
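With `jax.sharding.NamedSharding`, the partition spec of an array is available as the `.spec` attribute of its `.sharding`. A minimal sketch of the same idea using a hypothetical `pspec_of` helper and an assumed one-axis mesh named `"tp"`; EasyDeL's `get_sharding` may handle more cases than this.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def pspec_of(arr):
    """Hypothetical helper: return arr's PartitionSpec, or None if it has none."""
    sharding = getattr(arr, "sharding", None)
    # Only NamedSharding carries a PartitionSpec; anything else yields None.
    return sharding.spec if isinstance(sharding, NamedSharding) else None


mesh = Mesh(np.array(jax.devices()), ("tp",))
rows = 4 * len(jax.devices())  # leading dim divisible by the mesh size
x = jax.device_put(jnp.zeros((rows, 8)), NamedSharding(mesh, P("tp", None)))
print(pspec_of(x))
```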