easydel.layers.linear
Linear layers with parallel and distributed computation support.
Provides optimized linear layers with support for model parallelism, tensor parallelism, and various sharding strategies for distributed training.
- Classes:
ParallelLinear: Linear layer with tensor/model parallelism support
Linear: Standard linear layer (alias for ParallelLinear)
- Functions:
get_sharding: Extract sharding specification from an array
get_output_partition_spec: Calculate output sharding for matmul
get_matmul_output_sharding: Determine output sharding from input specs
- Key Features:
Automatic sharding and gathering for distributed training
Support for various matrix multiplication methods
Mixed precision support
Efficient initialization strategies
Integration with JAX’s shard_map
Example
>>> import jax.numpy as jnp
>>> from easydel.layers import ParallelLinear
>>> # Create a parallel linear layer
>>> layer = ParallelLinear(
... features=768,
... use_bias=True,
... gather_output=False,
... axis_name="model",
... dtype=jnp.bfloat16
... )
>>> output = layer(input_tensor)
- class easydel.layers.linear.ColumnParallelLinear(*args: Any, **kwargs: Any)
Bases: ParallelLinear
- class easydel.layers.linear.ParallelLinear(*args: Any, **kwargs: Any)
Bases: Module
A Linear layer with optional parallelism.
Behaves like nnx.Linear but can distribute computation and parameters across devices based on the TensorParallelConfig.
- in_features
Number of input features.
- out_features
Number of output features.
- use_bias
Whether to include a bias term. Default is True.
- dtype
The dtype of the computation. Defaults to the dtype inferred from the input.
- param_dtype
The dtype of the parameters. Default is float32.
- precision
JAX precision for the dot product. Default is None.
- kernel_init
Initializer for the kernel weights.
- bias_init
Initializer for the bias.
- parallel_config
Configuration for tensor parallelism. If None, the layer behaves like a standard non-parallel Linear layer.
- native_forward(inputs: Shaped[Array, '... in_features'], w: jax.Array | None = None) -> Shaped[Array, '... out_features']
Applies the linear transformation with optional tensor parallelism.
- Parameters
inputs – The input array. Shape: (…, in_features). For ROW parallelism, the input is expected to be sharded along the feature dimension (axis_name).
- Returns
The transformed output array. Shape: (…, out_features). Output is sharded for COLUMN parallelism if reduce_scatter_output is True. Otherwise, output is fully replicated.
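The difference between COLUMN and ROW parallelism described above can be illustrated without any devices. The following is a minimal NumPy simulation, not EasyDeL's implementation: the function names `column_parallel_matmul` and `row_parallel_matmul` are hypothetical, list concatenation stands in for an all-gather, and a plain sum stands in for an all-reduce.

```python
import numpy as np


def column_parallel_matmul(x, w, num_shards):
    """Hypothetical sketch: the kernel is split along out_features."""
    # Each simulated device holds one column slice of the kernel.
    w_shards = np.split(w, num_shards, axis=1)
    # Every shard sees the full input and produces a slice of the output.
    partial_outputs = [x @ w_s for w_s in w_shards]
    # Concatenating the slices plays the role of an all-gather.
    return np.concatenate(partial_outputs, axis=-1)


def row_parallel_matmul(x, w, num_shards):
    """Hypothetical sketch: the kernel is split along in_features."""
    # The input must be split along its feature axis to match the kernel rows.
    x_shards = np.split(x, num_shards, axis=-1)
    w_shards = np.split(w, num_shards, axis=0)
    # Each shard produces a full-shaped partial sum.
    partial_sums = [x_s @ w_s for x_s, w_s in zip(x_shards, w_shards)]
    # Summing the partial results plays the role of an all-reduce.
    return sum(partial_sums)


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8))   # (..., in_features)
w = rng.standard_normal((8, 4))   # (in_features, out_features)
reference = x @ w
col_out = column_parallel_matmul(x, w, num_shards=2)
row_out = row_parallel_matmul(x, w, num_shards=2)
```

Both variants reproduce the unsharded product. They differ in which operand is split and in which collective is needed to assemble the result, which is why ROW parallelism expects an input already sharded along the feature dimension.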
- class easydel.layers.linear.RowParallelLinear(*args: Any, **kwargs: Any)
Bases: ParallelLinear
- class easydel.layers.linear.TensorParallelConfig(mesh: Mesh = None, axis_name: str = 'tp', matmul_method: None = None, reduce_output: bool = False, reduce_scatter_output: bool = False)
Bases: object
Configuration for Tensor Parallelism.
- mesh
The JAX device mesh.
- Type
Mesh
- axis_name
The name of the mesh axis to use for tensor parallelism.
- Type
str
- matmul_type
The type of matmul (MatrixMultiplyMethod). The constructor signature names this parameter matmul_method.
- reduce_scatter_output
If True and parallel_type is COLUMN, use reduce-scatter instead of all-gather for the output. This keeps the output sharded (useful for sequence parallelism or subsequent RowParallel layers). Defaults to False.
- Type
bool
- axis_name: str = 'tp'
- classmethod from_dict(data: dict[str, Any]) -> T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) -> T
Deserializes a JSON string into a PyTree object.
- reduce_output: bool = False
- reduce_scatter_output: bool = False
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- to_dict() -> dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) -> str
Serializes the PyTree object to a JSON string.
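The replace and serialization methods listed above follow a common immutable-config pattern. The sketch below shows that pattern with a plain frozen dataclass. `ConfigSketch` is a hypothetical stand-in and not the real TensorParallelConfig. It keeps only the three JSON-friendly fields from the signature, and the actual EasyDeL methods may handle the mesh and other fields differently.

```python
import dataclasses
import json
from typing import Any


@dataclasses.dataclass(frozen=True)
class ConfigSketch:
    """Hypothetical stand-in for illustration; not the real TensorParallelConfig."""

    axis_name: str = "tp"
    reduce_output: bool = False
    reduce_scatter_output: bool = False

    def replace(self, **kwargs: Any) -> "ConfigSketch":
        # Return a new instance; the original stays unchanged.
        return dataclasses.replace(self, **kwargs)

    def to_dict(self) -> dict[str, Any]:
        return dataclasses.asdict(self)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ConfigSketch":
        return cls(**data)

    def to_json(self, **kwargs: Any) -> str:
        # Extra keyword arguments are forwarded to json.dumps.
        return json.dumps(self.to_dict(), **kwargs)

    @classmethod
    def from_json(cls, json_str: str) -> "ConfigSketch":
        return cls.from_dict(json.loads(json_str))


base = ConfigSketch()
updated = base.replace(reduce_scatter_output=True)
restored = ConfigSketch.from_json(updated.to_json())
```

Because the instance is frozen, `replace` is the only way to derive a variant, and a JSON round trip yields an equal object.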
- easydel.layers.linear.get_matmul_output_sharding(lhs_pspec: jax.sharding.PartitionSpec | None, rhs_pspec: jax.sharding.PartitionSpec | None) -> PartitionSpec
Determine output sharding for matrix multiplication.
Calculates the output PartitionSpec based on input partition specs, following matrix multiplication rules where contracting dimensions are reduced and non-contracting dimensions determine output sharding.
For X @ W:
- Contracting dimensions are reduced during matmul
- Non-contracting dimensions determine output sharding
- Ensures no duplicate sharding dimensions in output
- Ensures correct output dimensionality with None padding if needed
- Parameters
lhs_pspec – PartitionSpec for the left-hand side matrix X.
rhs_pspec – PartitionSpec for the right-hand side matrix W.
- Returns
PartitionSpec for the output of X @ W.
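The rules above can be sketched in plain Python. This is an illustration of the stated rules only, not the library's code: `matmul_output_spec` is a hypothetical helper, partition specs are modeled as plain tuples whose entries are a mesh-axis name or None, the contracting dimensions are assumed to be the last of X and the first of W, and the None-padding step is omitted.

```python
def matmul_output_spec(lhs_spec, rhs_spec):
    """Hypothetical helper: output spec for X @ W, with specs as plain tuples."""
    lhs = tuple(lhs_spec) if lhs_spec is not None else ()
    rhs = tuple(rhs_spec) if rhs_spec is not None else ()

    # Drop the contracting dimensions: the last of X and the first of W.
    candidates = lhs[:-1] + rhs[1:]

    used = set()
    out = []
    for axis in candidates:
        if axis is not None and axis in used:
            # A mesh axis may shard only one output dimension.
            out.append(None)
        else:
            out.append(axis)
            if axis is not None:
                used.add(axis)
    return tuple(out)


# Column-parallel kernel: the output keeps the "tp" sharding on its last dim.
column_case = matmul_output_spec(("dp", None), (None, "tp"))
# Row-parallel kernel: "tp" shards the contracting dims and drops out.
row_case = matmul_output_spec(("dp", "tp"), ("tp", None))
# "tp" would appear twice, so the second occurrence is replaced by None.
duplicate_case = matmul_output_spec(("tp", None), (None, "tp"))
```

Here `column_case` is `("dp", "tp")`, `row_case` is `("dp", None)`, and `duplicate_case` is `("tp", None)`.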
- easydel.layers.linear.get_output_partition_spec(lhs: Shaped[Array, '...'], rhs: Shaped[Array, '...'], method: MatrixMultiplyMethod, axis_name: str) -> Ps | None
Calculate output partition spec for matrix multiplication.
Determines the appropriate output sharding based on input sharding and the matrix multiplication method used.
- Parameters
lhs – Left-hand side array (inputs).
rhs – Right-hand side array (weights).
method – Matrix multiplication method.
axis_name – Axis name for sharding.
- Returns
Output partition specification for the result.
- easydel.layers.linear.get_sharding(arr: Shaped[Array, '...']) -> jax.sharding.PartitionSpec | None
Get the sharding specification of an array.
Extracts the PartitionSpec from a sharded JAX array.
- Parameters
arr – Array to get sharding from.
- Returns
PartitionSpec of the array, or None if not sharded.
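As a rough illustration of what reading a PartitionSpec from an array involves, the sketch below uses only public JAX APIs. `read_partition_spec` is a hypothetical stand-in written for this example. It assumes that only arrays carrying a NamedSharding expose a spec, which may differ from how get_sharding decides. The one-device mesh is there so the example runs on a single CPU.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def read_partition_spec(arr):
    """Hypothetical stand-in: return the array's PartitionSpec, or None."""
    sharding = getattr(arr, "sharding", None)
    # A NamedSharding stores its PartitionSpec in the `spec` attribute.
    if isinstance(sharding, NamedSharding):
        return sharding.spec
    return None


# A one-device mesh with a single axis named "tp" (works on a CPU-only host).
mesh = Mesh(np.array(jax.devices()[:1]), axis_names=("tp",))
sharded = jax.device_put(
    jnp.ones((4, 8)),
    NamedSharding(mesh, PartitionSpec(None, "tp")),
)
plain = np.ones((4, 8))  # a NumPy array has no sharding attribute

sharded_spec = read_partition_spec(sharded)
plain_spec = read_partition_spec(plain)
```

The first call returns the spec the array was placed with, and the second returns None because the NumPy array carries no sharding.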