easydel.kernels.collective_matmul
- class easydel.kernels.collective_matmul.MatrixMultiplyMethod(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)
Bases: Enum
Enumeration of distributed matrix multiplication methods.
- ALL_GATHER = 1
Matrix multiplication using the all-gather communication pattern. Suitable when the output needs to be fully replicated across devices.
- REDUCE_SCATTER = 2
Matrix multiplication using the reduce-scatter communication pattern. Efficient when the output can be partitioned across devices.
- easydel.kernels.collective_matmul.calculate_mesh_dimension_size(axis_names: Optional[Union[Tuple[str, ...], str, Any]]) → int
Calculates the total number of devices along the specified mesh dimension(s).
This function computes the product of the number of devices along each specified mesh dimension, providing the total size of the submesh defined by these axes.
- Parameters
axis_names: A single mesh dimension name (str) or a sequence (list/tuple) of mesh dimension names. For sequences, the order doesn't affect the result, since multiplication is commutative.
- Returns
- The total number of devices in the submesh defined by the dimension(s).
Returns 1 if axis_names is an empty sequence.
- Return type
int
- Raises
TypeError: If axis_names is not a str or a sequence of str.
Examples
Assuming a mesh with 8 devices along "data" and 4 along "model":
>>> calculate_mesh_dimension_size("data")  # Single dimension
8
>>> calculate_mesh_dimension_size(["data", "model"])  # Multiple dimensions
32
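The computation described above is a product over the mesh shape. The sketch below assumes the per-axis device counts are available as a mapping from axis name to size; a jax.sharding.Mesh exposes one as mesh.shape. The helper mesh_dimension_size is hypothetical and is not part of EasyDeL. The 8-by-4 layout is also an assumption, chosen only to reproduce the example outputs.

```python
import math
from typing import Mapping, Sequence, Union


def mesh_dimension_size(
    mesh_shape: Mapping[str, int], axis_names: Union[str, Sequence[str]]
) -> int:
    """Product of the device counts along `axis_names` (hypothetical helper)."""
    if isinstance(axis_names, str):
        axis_names = (axis_names,)
    if not isinstance(axis_names, (list, tuple)) or not all(
        isinstance(name, str) for name in axis_names
    ):
        raise TypeError("axis_names must be a str or a sequence of str")
    # math.prod of an empty sequence is 1, matching the documented behavior.
    return math.prod(mesh_shape[name] for name in axis_names)


# Assumed mesh layout: 8 devices along "data", 4 along "model".
shape = {"data": 8, "model": 4}
print(mesh_dimension_size(shape, "data"))             # 8
print(mesh_dimension_size(shape, ["data", "model"]))  # 32
```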
- easydel.kernels.collective_matmul.compute_device_linear_index(axis_names: Optional[Union[Tuple[str, ...], str, Any]]) → int
Computes the linear index of the current device within the specified mesh dimensions.
This function flattens the multi-dimensional coordinates of the device within the submesh defined by axis_names into a single integer index using row-major ordering.
- Parameters
axis_names: A single mesh dimension name (str) or a sequence (list/tuple) of mesh dimension names, ordered from major to minor dimension. The order matters because it determines the resulting linear index.
- Returns
- The 0-based linear index of the current device within the submesh.
Returns 0 if axis_names is an empty sequence.
- Return type
int
- Raises
TypeError: If axis_names is not a str or a sequence of str.
Examples
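Assuming a mesh with 8 devices along "data" and 4 along "model", for the device at coordinates data=2, model=1:
>>> compute_device_linear_index("data")  # Single dimension
2
>>> compute_device_linear_index(["data", "model"])  # Multi-dimensional: 2 * 4 + 1
9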
Note
The calculation assumes row-major ordering where the rightmost dimension varies fastest (similar to C-style arrays).
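Row-major flattening can be written as a Horner-style loop over the axes, from major to minor. In this sketch, device_linear_index and its coords and mesh_shape arguments are hypothetical. Inside a shard_map body, each per-axis coordinate would typically come from jax.lax.axis_index(name), and the same loop works on those traced values.

```python
from typing import Mapping, Sequence, Union


def device_linear_index(
    coords: Mapping[str, int],
    mesh_shape: Mapping[str, int],
    axis_names: Union[str, Sequence[str]],
) -> int:
    """Row-major (C-order) flattening of per-axis coordinates (hypothetical helper).

    axis_names is ordered from major to minor; the last name varies fastest.
    """
    if isinstance(axis_names, str):
        axis_names = (axis_names,)
    index = 0  # an empty axis list yields 0
    for name in axis_names:
        # Shift the running index by this axis' size, then add the coordinate.
        index = index * mesh_shape[name] + coords[name]
    return index


shape = {"data": 8, "model": 4}   # assumed mesh layout
coords = {"data": 2, "model": 1}  # assumed position of the current device
print(device_linear_index(coords, shape, "data"))             # 2
print(device_linear_index(coords, shape, ["data", "model"]))  # 9
```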
- easydel.kernels.collective_matmul.create_distributed_matmul(method: MatrixMultiplyMethod, partition_dims: Optional[Union[Tuple[str, ...], str, Any]]) → Callable[[Array, Array], Array]
Creates a distributed matrix multiplication function using the specified method.
This factory function returns a specialized matrix multiplication function that implements the requested distributed computation strategy.
- Parameters
method: The distributed matrix multiplication method to use.
partition_dims: The mesh dimension name(s) over which the collective operations run.
- Returns
A function that takes (lhs, rhs) and performs distributed matrix multiplication using the specified method.
- Raises
ValueError: If an unsupported matrix multiplication method is provided.
Example
>>> matmul_fn = create_distributed_matmul(MatrixMultiplyMethod.ALL_GATHER, "data")
>>> result = matmul_fn(left_matrix, right_matrix)
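One common way to build a factory like this is an enum-keyed dispatch table that binds partition_dims with functools.partial. The sketch below is an assumption about the general shape of such a factory, not EasyDeL's code. make_matmul and the two _impl stand-ins are hypothetical, and the stand-ins only report which path was selected.

```python
import enum
from functools import partial
from typing import Callable


class MatrixMultiplyMethod(enum.Enum):
    ALL_GATHER = 1
    REDUCE_SCATTER = 2


# Hypothetical stand-ins for the real kernels; they only report the chosen path.
def _all_gather_impl(lhs, rhs, partition_dims):
    return ("all_gather", partition_dims)


def _reduce_scatter_impl(lhs, rhs, partition_dims):
    return ("reduce_scatter", partition_dims)


_IMPLS = {
    MatrixMultiplyMethod.ALL_GATHER: _all_gather_impl,
    MatrixMultiplyMethod.REDUCE_SCATTER: _reduce_scatter_impl,
}


def make_matmul(method, partition_dims) -> Callable:
    """Return a two-argument matmul with partition_dims already bound."""
    try:
        impl = _IMPLS[method]
    except KeyError:
        raise ValueError(f"Unsupported matrix multiply method: {method!r}") from None
    return partial(impl, partition_dims=partition_dims)


fn = make_matmul(MatrixMultiplyMethod.ALL_GATHER, "data")
print(fn(None, None))  # ('all_gather', 'data')
```

Binding partition_dims up front keeps the returned callable's signature at (lhs, rhs), which matches the documented Callable[[Array, Array], Array] return type.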
- easydel.kernels.collective_matmul.perform_all_gather_matmul(lhs: Array, rhs: Array, partition_dims: Optional[Union[Tuple[str, ...], str, Any]]) → Array
Performs matrix multiplication with an all-gather communication pattern.
This function implements an efficient distributed matrix multiplication algorithm that uses an all-gather communication pattern. It processes chunks of the right-hand side matrix row-wise and accumulates partial results while shuffling the left-hand side matrix between devices.
- Parameters
lhs: Left-hand side matrix.
rhs: Right-hand side matrix; it should first be pre-processed with prepare_matrix_for_all_gather.
partition_dims: The mesh dimension name(s) over which the collective operations run.
- Returns
The result of the distributed matrix multiplication
- Return type
Array
Note
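This implementation achieves better performance than a naive distributed matrix multiplication (a full all-gather followed by a single local matmul) by overlapping the shuffling of data between devices with the chunked partial multiplications.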
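The sketch below shows one way to write an all-gather collective matmul with shard_map. It rests on assumptions that go beyond the text above: a 1-D mesh with a hypothetical axis "i", lhs sharded along its contracting dimension K, and rhs sharded along its output columns N. Each device passes the lhs blocks around a ring with jax.lax.ppermute and multiplies each block by the matching row chunk of its local rhs, which replaces an up-front all-gather of lhs. The documented function expects rhs to be pre-processed with prepare_matrix_for_all_gather. This sketch instead selects the matching rhs rows with jax.lax.dynamic_slice_in_dim, so it needs no pre-shuffle.

```python
import os

# Request 4 virtual CPU devices so the sketch runs without accelerators.
# This only has an effect if it runs before JAX is first imported.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

from functools import partial

import jax
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older JAX releases

AXIS = "i"  # hypothetical 1-D mesh axis name
n = jax.device_count()
mesh = Mesh(np.array(jax.devices()), (AXIS,))


def all_gather_matmul_block(x_blk, w_blk):
    """Per-device body. x_blk: (M, K/n) sharded on K; w_blk: (K, N/n) sharded on N."""
    idx = jax.lax.axis_index(AXIS)
    k = x_blk.shape[1]  # rows of w that pair with one lhs block
    # Ring shift: device d sends its current lhs block to device (d + 1) % n.
    shift = partial(jax.lax.ppermute, axis_name=AXIS,
                    perm=[(d, (d + 1) % n) for d in range(n)])

    def w_rows(src):
        # Row chunk of the local rhs that matches the lhs block that started on `src`.
        return jax.lax.dynamic_slice_in_dim(w_blk, src * k, k, axis=0)

    acc = x_blk @ w_rows(idx)  # step 0: the locally resident lhs block
    for step in range(1, n):
        x_blk = shift(x_blk)  # now holds the block that started on idx - step
        acc = acc + x_blk @ w_rows((idx - step) % n)
    return acc


matmul = jax.jit(shard_map(all_gather_matmul_block, mesh=mesh,
                           in_specs=(P(None, AXIS), P(None, AXIS)),
                           out_specs=P(None, AXIS)))

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 4 * n)).astype(np.float32)
w = rng.standard_normal((4 * n, 2 * n)).astype(np.float32)
print(np.allclose(matmul(x, w), x @ w, atol=1e-4))
```

Each ring shift depends only on the previous block, not on the current partial product, so the compiler is free to schedule the next transfer alongside the current matmul. Whether the two actually overlap depends on the backend.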
- easydel.kernels.collective_matmul.perform_reduce_scatter_matmul(lhs: Array, rhs: Array, partition_dims: Optional[Union[Tuple[str, ...], str, Any]]) → Array
Performs matrix multiplication with a reduce-scatter communication pattern.
This function implements an efficient distributed matrix multiplication algorithm that uses a reduce-scatter communication pattern to minimize data movement. The algorithm processes chunks of the right-hand side matrix and accumulates partial results while shuffling data between devices.
- Parameters
lhs: Left-hand side matrix.
rhs: Right-hand side matrix; it should first be pre-processed with prepare_matrix_for_reduce_scatter.
partition_dims: The mesh dimension name(s) over which the collective operations run.
- Returns
The result of the distributed matrix multiplication
- Return type
Array
Note
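This implementation achieves better performance than a naive distributed matrix multiplication (a full local matmul followed by a reduce-scatter of the result) by overlapping the shuffling of data between devices with the chunked partial multiplications.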
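A reduce-scatter collective matmul can be sketched along the same lines. The assumptions here are a 1-D mesh with a hypothetical axis "i", lhs sharded along K, and rhs sharded along its rows K. Instead of computing the full local partial product and calling jax.lax.psum_scatter, each device adds its contribution for one output column block to a running sum that travels around the ring. Here rhs column chunks are selected with jax.lax.dynamic_slice_in_dim instead of being pre-shuffled by prepare_matrix_for_reduce_scatter.

```python
import os

# Request 4 virtual CPU devices; only effective before JAX is first imported.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

from functools import partial

import jax
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older JAX releases

AXIS = "i"  # hypothetical 1-D mesh axis name
n = jax.device_count()
mesh = Mesh(np.array(jax.devices()), (AXIS,))


def reduce_scatter_matmul_block(x_blk, w_blk):
    """Per-device body. x_blk: (M, K/n) and w_blk: (K/n, N), both sharded on K."""
    idx = jax.lax.axis_index(AXIS)
    c = w_blk.shape[1] // n  # width of one output column block
    # Ring shift: device d sends the running sum to device (d + 1) % n.
    shift = partial(jax.lax.ppermute, axis_name=AXIS,
                    perm=[(d, (d + 1) % n) for d in range(n)])

    def local_partial(block):
        # This device's K-chunk contribution to output column block `block`.
        return x_blk @ jax.lax.dynamic_slice_in_dim(w_blk, block * c, c, axis=1)

    # At step s, device d adds its share of block (d - s - 1) % n and passes the
    # sum on; after n - 1 hops, device d holds the complete sum for block d.
    acc = local_partial((idx - 1) % n)
    for step in range(1, n):
        acc = shift(acc)
        acc = acc + local_partial((idx - step - 1) % n)
    return acc


matmul = jax.jit(shard_map(reduce_scatter_matmul_block, mesh=mesh,
                           in_specs=(P(None, AXIS), P(AXIS, None)),
                           out_specs=P(None, AXIS)))

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 4 * n)).astype(np.float32)
w = rng.standard_normal((4 * n, 2 * n)).astype(np.float32)
print(np.allclose(matmul(x, w), x @ w, atol=1e-4))
```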
- easydel.kernels.collective_matmul.prepare_matrix_for_all_gather(matrix: Array, device_mesh: Mesh, partition_dims: Optional[Union[Tuple[str, ...], str, Any]]) → Array
Prepares a matrix for all-gather collective matrix multiplication by reshuffling data.
This function reorganizes the input matrix across devices to optimize the subsequent all-gather collective matrix multiplication operation. It performs data swapping between pairs of devices to ensure proper data alignment.
- Parameters
matrix: The input matrix to be prepared.
device_mesh: The device mesh used for distributed computation.
partition_dims: The dimension names along which the matrix is partitioned.
- Returns
The prepared matrix with data appropriately reshuffled
- Return type
Array
Note
This preprocessing step is crucial for the efficiency of the subsequent all-gather collective matrix multiplication.
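The text does not say which devices are paired. Assuming the common pairing of device 2j with device 2j + 1, a NumPy model can show what such a swap does to a matrix split into equal shards along the partitioned axis. pairwise_swap_shards is a hypothetical helper for reasoning about or testing the layout. It is not the library's implementation.

```python
import numpy as np


def pairwise_swap_shards(a: np.ndarray, num_shards: int, axis: int) -> np.ndarray:
    """Swap shards 0<->1, 2<->3, ... along `axis` (hypothetical pairing)."""
    if num_shards % 2:
        raise ValueError("a pairwise swap needs an even number of shards")
    # np.split raises ValueError if the axis does not divide evenly.
    shards = np.split(a, num_shards, axis=axis)
    # i ^ 1 maps each shard index to its partner: 0<->1, 2<->3, ...
    swapped = [shards[i ^ 1] for i in range(num_shards)]
    return np.concatenate(swapped, axis=axis)


a = np.arange(8).reshape(4, 2)
print(pairwise_swap_shards(a, num_shards=4, axis=0).tolist())
# [[2, 3], [0, 1], [6, 7], [4, 5]]
```

Applying the swap twice restores the original layout, which makes a convenient sanity check in tests.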
- easydel.kernels.collective_matmul.prepare_matrix_for_reduce_scatter(matrix: Array, device_mesh: Mesh, partition_dims: Optional[Union[Tuple[str, ...], str, Any]]) → Array
Prepares a matrix for reduce-scatter collective matrix multiplication by reshuffling data.
This function reorganizes the input matrix across devices to optimize the subsequent reduce-scatter collective matrix multiplication operation. It performs data swapping between pairs of devices along the column dimension.
- Parameters
matrix: The input matrix to be prepared.
device_mesh: The device mesh used for distributed computation.
partition_dims: The dimension names along which the matrix is partitioned.
- Returns
The prepared matrix with data appropriately reshuffled
- Return type
Array
Note
This preprocessing step ensures efficient communication patterns during the subsequent reduce-scatter collective matrix multiplication.
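On devices, a pairwise swap of column shards maps naturally onto a single jax.lax.ppermute inside shard_map. The sketch assumes a 1-D mesh with a hypothetical axis "i". It also pairs device d with device d ^ 1, which is an assumption because the text does not specify the pairing. The result is checked against a NumPy reconstruction of the swapped column blocks.

```python
import os

# Request 4 virtual CPU devices; only effective before JAX is first imported.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older JAX releases

AXIS = "i"  # hypothetical 1-D mesh axis name
n = jax.device_count()
mesh = Mesh(np.array(jax.devices()), (AXIS,))

# Hypothetical pairing: device d swaps with d ^ 1 (0<->1, 2<->3, ...);
# with an odd device count the last device keeps its own block.
perm = [(d, d ^ 1) if (d ^ 1) < n else (d, d) for d in range(n)]


def swap_with_partner(block):
    # Each device sends its column block to its partner and receives the partner's.
    return jax.lax.ppermute(block, axis_name=AXIS, perm=perm)


swap_columns = jax.jit(shard_map(swap_with_partner, mesh=mesh,
                                 in_specs=(P(None, AXIS),),
                                 out_specs=P(None, AXIS)))

m = jnp.arange(3 * 2 * n, dtype=jnp.float32).reshape(3, 2 * n)
out = np.asarray(swap_columns(m))

# Host-side reference: device d ends up with the block sent by its source.
src_of = {dst: src for src, dst in perm}
blocks = np.split(np.asarray(m), n, axis=1)
expected = np.concatenate([blocks[src_of[d]] for d in range(n)], axis=1)
print(np.array_equal(out, expected))
```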
- easydel.kernels.collective_matmul.run_all_tests()
Runs all distributed matrix multiplication tests and reports results.
This function executes both the all-gather and reduce-scatter matrix multiplication tests and collects their results.
- Returns
A dictionary containing test results for both methods
- Return type
dict
- easydel.kernels.collective_matmul.test_all_gather_matmul()
Tests the all-gather distributed matrix multiplication implementation.
This function creates a test case with random matrices, computes the expected result using standard matrix multiplication, and verifies that the distributed implementation produces the same result within numerical tolerance.
- Returns
A tuple containing (actual_result, expected_result)
- Return type
tuple
- easydel.kernels.collective_matmul.test_reduce_scatter_matmul()
Tests the reduce-scatter distributed matrix multiplication implementation.
This function creates a test case with random matrices, computes the expected result using standard matrix multiplication, and verifies that the distributed implementation produces the same result within numerical tolerance.
- Returns
A tuple containing (actual_result, expected_result)
- Return type
tuple
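The test procedure described above can be sketched as follows: random inputs, a standard matmul as the reference, a comparison within a tolerance, and one (actual, expected) tuple per method. As stand-ins for the collective kernels, the sketch tests simple baselines of the kind the notes call naive. These are assumed here to be jax.lax.all_gather followed by a local matmul, and a local matmul followed by jax.lax.psum_scatter. The names check and run_all, the dictionary keys, the axis name "i" and the tolerance are all hypothetical.

```python
import os

# Request 4 virtual CPU devices; only effective before JAX is first imported.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

import jax
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older JAX releases

AXIS = "i"  # hypothetical 1-D mesh axis name
n = jax.device_count()
mesh = Mesh(np.array(jax.devices()), (AXIS,))


def naive_all_gather(x_blk, w_blk):
    # Gather the K-sharded lhs in full, then multiply by the local rhs column block.
    x_full = jax.lax.all_gather(x_blk, AXIS, axis=1, tiled=True)
    return x_full @ w_blk


def naive_reduce_scatter(x_blk, w_blk):
    # Local partial product over this device's K-chunk, then sum and scatter columns.
    return jax.lax.psum_scatter(x_blk @ w_blk, AXIS, scatter_dimension=1, tiled=True)


def check(fn, in_specs, seed=0, atol=1e-4):
    """Return (actual, expected), mirroring the documented test_* helpers."""
    rng = np.random.default_rng(seed)
    lhs = rng.standard_normal((8, 4 * n)).astype(np.float32)
    rhs = rng.standard_normal((4 * n, 2 * n)).astype(np.float32)
    dist = jax.jit(shard_map(fn, mesh=mesh, in_specs=in_specs,
                             out_specs=P(None, AXIS)))
    actual = np.asarray(dist(lhs, rhs))
    expected = lhs @ rhs  # standard single-device matmul as the reference
    assert np.allclose(actual, expected, atol=atol), "distributed result mismatch"
    return actual, expected


def run_all():
    """Collect results for both methods in a dict, like run_all_tests."""
    return {
        "all_gather": check(naive_all_gather, (P(None, AXIS), P(None, AXIS))),
        "reduce_scatter": check(naive_reduce_scatter, (P(None, AXIS), P(AXIS, None))),
    }


results = run_all()
print(sorted(results))  # ['all_gather', 'reduce_scatter']
```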