easydel.layers.linear
- class easydel.layers.linear.ParallelLinear(*args: Any, **kwargs: Any)
Bases: Module
A Linear layer with optional parallelism.
Behaves like nnx.Linear but can distribute computation and parameters across devices based on the TensorParallelConfig.
- in_features
Number of input features.
- out_features
Number of output features.
- use_bias
Whether to include a bias term. Default is True.
- dtype
The dtype of the computation. Defaults to the dtype inferred from the input.
- param_dtype
The dtype of the parameters. Default is float32.
- precision
JAX precision for the dot product. Default is None.
- kernel_init
Initializer for the kernel weights.
- bias_init
Initializer for the bias.
- parallel_config
Configuration for tensor parallelism. If None, the layer behaves like a standard non-parallel Linear layer.
- native_forward(inputs: Array) → Array
Applies the linear transformation with optional tensor parallelism.
- Parameters
inputs – The input array. Shape: (…, in_features). For ROW parallelism, the input is expected to be sharded along the feature dimension (axis_name).
- Returns
The transformed output array. Shape: (…, out_features). Output is sharded for COLUMN parallelism if reduce_scatter_output is True. Otherwise, output is fully replicated.
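The difference between the COLUMN and ROW cases described above can be illustrated without any devices. The following is a minimal NumPy sketch, assuming the usual tensor-parallel convention (COLUMN splits the kernel along out_features, ROW splits it along in_features); the function names are hypothetical and this is not EasyDeL's implementation. Concatenation and summation stand in for the all-gather and all-reduce collectives.

```python
import numpy as np

def column_parallel_linear(x, w, num_shards):
    """Split W along out_features; each shard computes a slice of the output."""
    w_shards = np.split(w, num_shards, axis=1)        # each: (in, out / n)
    partial_outputs = [x @ w_k for w_k in w_shards]   # each: (..., out / n)
    # Concatenating the slices plays the role of an all-gather.
    return np.concatenate(partial_outputs, axis=-1)

def row_parallel_linear(x, w, num_shards):
    """Split W along in_features; the input is split along its feature axis."""
    x_shards = np.split(x, num_shards, axis=-1)       # each: (..., in / n)
    w_shards = np.split(w, num_shards, axis=0)        # each: (in / n, out)
    partial_sums = [x_k @ w_k for x_k, w_k in zip(x_shards, w_shards)]
    # Summing the partial products plays the role of an all-reduce.
    return sum(partial_sums)

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 8))   # (batch, in_features)
w = rng.normal(size=(8, 4))   # (in_features, out_features)
reference = x @ w             # unsharded result for comparison

col_out = column_parallel_linear(x, w, num_shards=2)
row_out = row_parallel_linear(x, w, num_shards=4)
```

Both variants reproduce the unsharded product. If the column-parallel slices were left unconcatenated, each shard would hold only its own part of the output, which corresponds to the sharded-output case mentioned in the Returns description.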
- class easydel.layers.linear.TensorParallelConfig(mesh: Mesh = None, axis_name: str = 'tp', matmul_method: Optional[MatrixMultiplyMethod] = None, reduce_output: bool = False, reduce_scatter_output: bool = False)
Bases: object
Configuration for Tensor Parallelism.
- mesh
The JAX device mesh.
- Type
Mesh
- axis_name
The name of the mesh axis to use for tensor parallelism.
- Type
str
- matmul_method
The matrix multiplication method to use (MatrixMultiplyMethod).
- reduce_scatter_output
If True and parallel_type is COLUMN, use reduce-scatter instead of all-gather for the output. This keeps the output sharded (useful for sequence parallelism or subsequent RowParallel layers). Defaults to False.
- Type
bool
- axis_name: str = 'tp'
- classmethod from_dict(data: Dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- matmul_method: Optional[MatrixMultiplyMethod] = None
- reduce_output: bool = False
- reduce_scatter_output: bool = False
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- to_dict() → Dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
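The replace and serialization helpers listed above follow a common immutable-config pattern. The sketch below shows that pattern with a hypothetical stand-in class, `SketchTPConfig`, built on the standard library; it is not the EasyDeL class, it omits the mesh field (a device mesh is not JSON-serializable), and it uses a plain string in place of MatrixMultiplyMethod.

```python
import dataclasses
import json
from typing import Any, Dict, Optional

@dataclasses.dataclass(frozen=True)
class SketchTPConfig:
    """Hypothetical stand-in mirroring the simple fields documented above."""
    axis_name: str = "tp"
    matmul_method: Optional[str] = None  # assumption: string placeholder
    reduce_output: bool = False
    reduce_scatter_output: bool = False

    def replace(self, **kwargs) -> "SketchTPConfig":
        # Returns a new instance; the original stays unchanged.
        return dataclasses.replace(self, **kwargs)

    def to_dict(self) -> Dict[str, Any]:
        return dataclasses.asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "SketchTPConfig":
        return cls(**data)

    def to_json(self, **kwargs) -> str:
        return json.dumps(self.to_dict(), **kwargs)

    @classmethod
    def from_json(cls, json_str: str) -> "SketchTPConfig":
        return cls.from_dict(json.loads(json_str))

base = SketchTPConfig()
scattered = base.replace(reduce_scatter_output=True)
restored = SketchTPConfig.from_json(scattered.to_json())
```

Because replace returns a copy, a shared base configuration can be specialized per layer without mutating it.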
- easydel.layers.linear.get_matmul_output_sharding(lhs_pspec, rhs_pspec)
Determine the output sharding PartitionSpec for a matrix multiplication based on the partition specs of the input matrices.
For matrix multiplication X @ W:
  - The contracting dimensions get reduced during matmul
  - The non-contracting dimensions of X and W determine the output sharding
  - Ensures no duplicate sharding dimensions in the output
  - Ensures correct output dimensionality with None padding if needed
- Parameters
lhs_pspec – PartitionSpec for the left-hand side matrix X
rhs_pspec – PartitionSpec for the right-hand side matrix W
- Returns
PartitionSpec for the output of X @ W
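One way to read the rules listed for get_matmul_output_sharding is sketched below on plain tuples, where each entry is a mesh axis name or None. This is an interpretation of the description, not the library's code: the function name and the `lhs_ndim` / `rhs_ndim` parameters are hypothetical, and W is assumed to be two-dimensional with its first dimension contracted.

```python
def sketch_matmul_output_spec(lhs_spec, rhs_spec, lhs_ndim=None, rhs_ndim=2):
    """Derive an output spec for X @ W from tuple-based partition specs."""
    lhs_ndim = len(lhs_spec) if lhs_ndim is None else lhs_ndim
    # Pad short specs with None so there is one entry per array dimension.
    lhs = tuple(lhs_spec) + (None,) * (lhs_ndim - len(lhs_spec))
    rhs = tuple(rhs_spec) + (None,) * (rhs_ndim - len(rhs_spec))

    # X contracts over its last dimension, W over its first; drop both.
    kept = list(lhs[:-1]) + list(rhs[1:])

    out, used = [], set()
    for axis in kept:
        if axis is not None and axis in used:
            axis = None  # a mesh axis may shard at most one output dimension
        if axis is not None:
            used.add(axis)
        out.append(axis)
    return tuple(out)

# Column-style kernel: output features stay sharded on "tp".
column_case = sketch_matmul_output_spec((None, None), (None, "tp"))
# Row-style kernel: the sharded contracting dimension disappears.
row_case = sketch_matmul_output_spec((None, "tp"), ("tp", None))
# The same mesh axis on both kept dimensions is used only once.
duplicate_case = sketch_matmul_output_spec(("tp", None), (None, "tp"))
# A short spec for a 3-D input is padded with None before the rules apply.
padded_case = sketch_matmul_output_spec(("dp",), (None, "tp"), lhs_ndim=3)
```

In the row-style case the result is unsharded only after the partial products are summed across the "tp" axis; the spec alone does not express that reduction.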
- easydel.layers.linear.get_output_partition_spec(lhs: Array, rhs: Array, method: MatrixMultiplyMethod, axis_name: str)
Calculate output partition spec based on input arrays and matmul method.
- Parameters
lhs – Left-hand side array (inputs)
rhs – Right-hand side array (weights)
method – Matrix multiplication method
axis_name – Axis name for sharding
- Returns
Output partition specification
- easydel.layers.linear.get_sharding(arr: Array) → Optional[PartitionSpec]
Gets the sharding of an array.
- Parameters
arr – Array to get sharding from.
- Returns
Sharding of the array.