easydel.kernels.collective_matmul

class easydel.kernels.collective_matmul.MatrixMultiplyMethod(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)

Bases: Enum

Enumeration of distributed matrix multiplication methods.

ALL_GATHER = 1

Matrix multiplication using an all-gather communication pattern. Suitable when the output needs to be fully replicated across devices.

REDUCE_SCATTER = 2

Matrix multiplication using a reduce-scatter communication pattern. Efficient when the output can be partitioned across devices.

easydel.kernels.collective_matmul.calculate_mesh_dimension_size(axis_names: Optional[Union[Tuple[str, ...], str, Any]]) → int

Calculates the total number of devices along the specified mesh dimension(s).

This function computes the product of the number of devices along each specified mesh dimension, providing the total size of the submesh defined by these axes.

Parameters

axis_names – A single mesh dimension name (str) or a sequence (list/tuple) of mesh dimension names. For sequences, the order doesn't affect the result since multiplication is commutative.

Returns

The total number of devices in the submesh defined by the dimension(s).

Returns 1 if axis_names is an empty sequence.

Return type

int

Raises

TypeError – If axis_names is not a str or a sequence of str.

Examples

>>> # Assuming a mesh with 8 devices along "data" and 4 along "model"
>>> calculate_mesh_dimension_size("data")  # Single dimension
8
>>> calculate_mesh_dimension_size(["data", "model"])  # Multiple dimensions
32
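
The documented behavior (product of per-axis sizes, 1 for an empty sequence, TypeError otherwise) can be modeled in plain Python. This is a minimal sketch under an assumption: `submesh_size` and its explicit `mesh_shape` dictionary are hypothetical, since the real function takes only axis names and presumably reads the axis sizes from its JAX mesh context.

```python
import math
from collections.abc import Mapping, Sequence


def submesh_size(mesh_shape: Mapping[str, int], axis_names) -> int:
    """Hypothetical stand-in for calculate_mesh_dimension_size.

    mesh_shape maps each mesh axis name to its number of devices; the real
    function obtains these sizes from its context instead of an argument.
    """
    if isinstance(axis_names, str):
        return mesh_shape[axis_names]
    if isinstance(axis_names, Sequence) and all(isinstance(a, str) for a in axis_names):
        # Product of the per-axis sizes; math.prod of an empty iterable is 1.
        return math.prod(mesh_shape[a] for a in axis_names)
    raise TypeError(f"axis_names must be a str or a sequence of str, got {axis_names!r}")


mesh_shape = {"data": 8, "model": 4}
print(submesh_size(mesh_shape, "data"))             # 8
print(submesh_size(mesh_shape, ["data", "model"]))  # 32
print(submesh_size(mesh_shape, ()))                 # 1
```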
easydel.kernels.collective_matmul.compute_device_linear_index(axis_names: Optional[Union[Tuple[str, ...], str, Any]]) → int

Computes the linear index of the current device within the specified mesh dimensions.

This function flattens the multi-dimensional coordinates of the device within the submesh defined by axis_names into a single integer index using row-major ordering.

Parameters

axis_names – A single mesh dimension name (str) or a sequence (list/tuple) of mesh dimension names, ordered from major to minor dimensions. The order is important as it affects the resulting linear index.

Returns

The 0-based linear index of the current device within the submesh.

Returns 0 if axis_names is an empty sequence.

Return type

int

Raises

TypeError – If axis_names is not a str or a sequence of str.

Examples

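>>> # Assuming the current device sits at index 2 along "data" and 1 along "model" of an 8 x 4 mesh
>>> compute_device_linear_index("data")  # Single dimension
2
>>> compute_device_linear_index(["data", "model"])  # Multi-dimensional: 2 * 4 + 1
9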

Note

The calculation assumes row-major ordering where the rightmost dimension varies fastest (similar to C-style arrays).
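
The row-major flattening can be written as a Horner-style loop over the axes from major to minor. This is a minimal sketch under an assumption: `linear_device_index`, `mesh_shape`, and `coords` are hypothetical, since the real function takes only axis names and obtains the device's coordinates from its JAX context.

```python
from collections.abc import Mapping, Sequence


def linear_device_index(mesh_shape: Mapping[str, int], coords: Mapping[str, int], axis_names) -> int:
    """Hypothetical stand-in for compute_device_linear_index.

    mesh_shape maps axis names to sizes and coords gives the current device's
    position along each axis; the real function takes only axis_names.
    """
    if isinstance(axis_names, str):
        axis_names = (axis_names,)
    elif not (isinstance(axis_names, Sequence) and all(isinstance(a, str) for a in axis_names)):
        raise TypeError(f"axis_names must be a str or a sequence of str, got {axis_names!r}")
    index = 0  # an empty sequence therefore yields 0
    for name in axis_names:  # ordered from major to minor
        # Row-major (C-style) flattening: each later axis varies faster.
        index = index * mesh_shape[name] + coords[name]
    return index


mesh_shape = {"data": 8, "model": 4}
coords = {"data": 2, "model": 1}
print(linear_device_index(mesh_shape, coords, "data"))             # 2
print(linear_device_index(mesh_shape, coords, ["data", "model"]))  # 9  (2 * 4 + 1)
print(linear_device_index(mesh_shape, coords, ["model", "data"]))  # 10 (1 * 8 + 2)
```

Reversing the axis order changes the result, which is why the order of axis_names matters for the linear index but not for the mesh size.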

easydel.kernels.collective_matmul.create_distributed_matmul(method: MatrixMultiplyMethod, partition_dims: Optional[Union[Tuple[str, ...], str, Any]]) → Callable[[Array, Array], Array]

Creates a distributed matrix multiplication function using the specified method.

This factory function returns a specialized matrix multiplication function that implements the requested distributed computation strategy.

Parameters
  • method – The distributed matrix multiplication method to use.

  • partition_dims – Dimension names for collective operations.

Returns

A function that performs distributed matrix multiplication using the specified method.

Raises

ValueError – If an unsupported matrix multiplication method is provided.

Example

>>> matmul_fn = create_distributed_matmul(MatrixMultiplyMethod.ALL_GATHER, "data")
>>> result = matmul_fn(left_matrix, right_matrix)
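
A factory like this can be pictured as a lookup from method to implementation followed by binding partition_dims, so the returned callable takes only (lhs, rhs). The sketch below is one plausible shape for it, not easydel's code: `make_distributed_matmul`, `_all_gather_matmul`, and `_reduce_scatter_matmul` are hypothetical placeholders that return labels instead of computing, and the enum is redeclared locally (with the values listed above) so the snippet runs without easydel installed.

```python
import enum
from functools import partial
from typing import Callable


class MatrixMultiplyMethod(enum.Enum):
    # Local mirror of the documented enum, using the documented values.
    ALL_GATHER = 1
    REDUCE_SCATTER = 2


def _all_gather_matmul(lhs, rhs, partition_dims):
    # Hypothetical placeholder for perform_all_gather_matmul; returns a label.
    return ("all_gather", partition_dims)


def _reduce_scatter_matmul(lhs, rhs, partition_dims):
    # Hypothetical placeholder for perform_reduce_scatter_matmul; returns a label.
    return ("reduce_scatter", partition_dims)


_IMPLEMENTATIONS = {
    MatrixMultiplyMethod.ALL_GATHER: _all_gather_matmul,
    MatrixMultiplyMethod.REDUCE_SCATTER: _reduce_scatter_matmul,
}


def make_distributed_matmul(method, partition_dims) -> Callable:
    """Hypothetical sketch of a create_distributed_matmul-style factory."""
    try:
        impl = _IMPLEMENTATIONS[method]
    except (KeyError, TypeError):  # unknown or unhashable method
        raise ValueError(f"Unsupported matrix multiplication method: {method!r}") from None
    # Bind the axis names so callers only pass (lhs, rhs).
    return partial(impl, partition_dims=partition_dims)


matmul_fn = make_distributed_matmul(MatrixMultiplyMethod.ALL_GATHER, "data")
print(matmul_fn(None, None))  # ('all_gather', 'data')
```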
easydel.kernels.collective_matmul.perform_all_gather_matmul(lhs: Array, rhs: Array, partition_dims: Optional[Union[Tuple[str, ...], str, Any]]) → Array

Performs matrix multiplication with an all-gather communication pattern.

This function implements an efficient distributed matrix multiplication algorithm that uses an all-gather communication pattern. It processes chunks of the right-hand side matrix row-wise and accumulates partial results while shuffling the left-hand side matrix between devices.

Parameters
  • lhs – Left-hand side matrix.

  • rhs – Right-hand side matrix (should be pre-processed with prepare_matrix_for_all_gather).

  • partition_dims – Dimension names for collective operations.

Returns

The result of the distributed matrix multiplication.

Return type

Array

Note

This implementation achieves better performance than a naive distributed matrix multiplication by interleaving the device-to-device shuffling of the left-hand side matrix with the chunk-wise accumulation of partial results.
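
The chunk-and-shuffle schedule can be illustrated with a single-process NumPy simulation in which simulated devices are list entries and a ring shift stands in for the device-to-device transfer (on real hardware this would be a collective permute such as `jax.lax.ppermute`). `simulated_all_gather_matmul` is a hypothetical helper: it assumes lhs is sharded along its contracting dimension and rhs along its output columns, and it uses a simple one-directional ring, which need not match easydel's actual schedule or the preprocessing done by prepare_matrix_for_all_gather.

```python
import numpy as np


def simulated_all_gather_matmul(lhs: np.ndarray, rhs: np.ndarray, n_devices: int) -> np.ndarray:
    """Hypothetical single-process simulation of a ring all-gather matmul.

    Simulated device d starts with lhs[:, block d] (a shard of the contracting
    dimension) and owns rhs[:, block d] (a shard of the output columns).
    """
    k = lhs.shape[1] // n_devices                # rows per rhs chunk = cols per lhs block
    current = np.split(lhs, n_devices, axis=1)   # lhs block currently held by each device
    held = list(range(n_devices))                # index of that block along the contracting dim
    rhs_shards = np.split(rhs, n_devices, axis=1)
    dtype = np.result_type(lhs, rhs)
    out = [np.zeros((lhs.shape[0], r.shape[1]), dtype=dtype) for r in rhs_shards]

    for step in range(n_devices):
        for d in range(n_devices):
            b = held[d]
            # Multiply the held lhs block by the matching row chunk of this device's rhs shard.
            out[d] += current[d] @ rhs_shards[d][b * k:(b + 1) * k, :]
        if step < n_devices - 1:
            # Ring shift: device d receives the lhs block held by device d - 1.
            current = [current[(d - 1) % n_devices] for d in range(n_devices)]
            held = [held[(d - 1) % n_devices] for d in range(n_devices)]

    # Device d now holds output column block d; concatenating reassembles lhs @ rhs.
    return np.concatenate(out, axis=1)


rng = np.random.default_rng(0)
lhs = rng.standard_normal((4, 8))
rhs = rng.standard_normal((8, 12))
print(np.allclose(simulated_all_gather_matmul(lhs, rhs, n_devices=4), lhs @ rhs))  # True
```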

easydel.kernels.collective_matmul.perform_reduce_scatter_matmul(lhs: Array, rhs: Array, partition_dims: Optional[Union[Tuple[str, ...], str, Any]]) → Array

Performs matrix multiplication with a reduce-scatter communication pattern.

This function implements an efficient distributed matrix multiplication algorithm that uses a reduce-scatter communication pattern to minimize data movement. The algorithm processes chunks of the right-hand side matrix and accumulates partial results while shuffling data between devices.

Parameters
  • lhs – Left-hand side matrix.

  • rhs – Right-hand side matrix (should be pre-processed with prepare_matrix_for_reduce_scatter).

  • partition_dims – Dimension names for collective operations.

Returns

The result of the distributed matrix multiplication.

Return type

Array

Note

This implementation achieves better performance than a naive distributed matrix multiplication by interleaving the shuffling of data between devices with the chunk-wise accumulation of partial results.
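
A reduce-scatter matmul can likewise be simulated in one process: each simulated device computes partial products from its contracting-dimension shards, and per-block accumulators travel around a ring, picking up every device's contribution before landing on the device that owns that output block. `simulated_reduce_scatter_matmul` is a hypothetical helper assuming a one-directional ring and column-blocked output; it is not easydel's kernel, which runs on real devices with JAX collectives and expects rhs prepared by prepare_matrix_for_reduce_scatter.

```python
import numpy as np


def simulated_reduce_scatter_matmul(lhs: np.ndarray, rhs: np.ndarray, n_devices: int) -> np.ndarray:
    """Hypothetical single-process simulation of a ring reduce-scatter matmul.

    Simulated device d holds lhs[:, block d] and rhs[block d, :], i.e. both
    operands are sharded along the contracting dimension.
    """
    lhs_shards = np.split(lhs, n_devices, axis=1)
    rhs_shards = np.split(rhs, n_devices, axis=0)
    cols = rhs.shape[1] // n_devices  # width of one output column block

    def partial_product(d: int, b: int) -> np.ndarray:
        # Device d's contribution to output column block b.
        return lhs_shards[d] @ rhs_shards[d][:, b * cols:(b + 1) * cols]

    # Each device starts an accumulator for the block owned by its predecessor.
    block = [(d - 1) % n_devices for d in range(n_devices)]
    acc = [partial_product(d, block[d]) for d in range(n_devices)]
    for _ in range(n_devices - 1):
        # Ring shift: device d receives the accumulator from device d - 1 ...
        acc = [acc[(d - 1) % n_devices] for d in range(n_devices)]
        block = [block[(d - 1) % n_devices] for d in range(n_devices)]
        # ... and adds its own partial product for that block.
        acc = [acc[d] + partial_product(d, block[d]) for d in range(n_devices)]

    # After n_devices - 1 shifts, device d holds the fully reduced block d.
    return np.concatenate(acc, axis=1)


rng = np.random.default_rng(0)
lhs = rng.standard_normal((4, 8))
rhs = rng.standard_normal((8, 12))
print(np.allclose(simulated_reduce_scatter_matmul(lhs, rhs, n_devices=4), lhs @ rhs))  # True
```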

easydel.kernels.collective_matmul.prepare_matrix_for_all_gather(matrix: Array, device_mesh: Mesh, partition_dims: Optional[Union[Tuple[str, ...], str, Any]]) → Array

Prepares a matrix for all-gather collective matrix multiplication by reshuffling data.

This function reorganizes the input matrix across devices to optimize the subsequent all-gather collective matrix multiplication operation. It performs data swapping between pairs of devices to ensure proper data alignment.

Parameters
  • matrix – The input matrix to be prepared.

  • device_mesh – The device mesh used for distributed computation.

  • partition_dims – The dimension names along which the matrix is partitioned.

Returns

The prepared matrix with data appropriately reshuffled.

Return type

Array

Note

This preprocessing step is crucial for the efficiency of the subsequent all-gather collective matrix multiplication.

easydel.kernels.collective_matmul.prepare_matrix_for_reduce_scatter(matrix: Array, device_mesh: Mesh, partition_dims: Optional[Union[Tuple[str, ...], str, Any]]) → Array

Prepares a matrix for reduce-scatter collective matrix multiplication by reshuffling data.

This function reorganizes the input matrix across devices to optimize the subsequent reduce-scatter collective matrix multiplication operation. It performs data swapping between pairs of devices along the column dimension.

Parameters
  • matrix – The input matrix to be prepared.

  • device_mesh – The device mesh used for distributed computation.

  • partition_dims – The dimension names along which the matrix is partitioned.

Returns

The prepared matrix with data appropriately reshuffled.

Return type

Array

Note

This preprocessing step ensures efficient communication patterns during the subsequent reduce-scatter collective matrix multiplication.
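
Both preparation functions are described as swapping data between pairs of devices, but this reference does not state which devices are paired. The sketch below is only a hypothetical illustration of what a pairwise swap means at the shard level, assuming adjacent simulated devices (0 and 1, 2 and 3, ...) exchange their shards along a chosen axis (axis=1, the column dimension, in the reduce-scatter case). `swap_adjacent_shards` is not easydel's actual permutation.

```python
import numpy as np


def swap_adjacent_shards(matrix: np.ndarray, n_devices: int, axis: int) -> np.ndarray:
    """Hypothetical illustration of swapping shards between device pairs.

    Splits `matrix` into n_devices shards along `axis` (one per simulated
    device) and exchanges the shards of devices 0<->1, 2<->3, and so on.
    """
    if n_devices % 2:
        raise ValueError("pairwise swapping needs an even number of devices")
    shards = np.split(matrix, n_devices, axis=axis)
    for i in range(0, n_devices, 2):
        shards[i], shards[i + 1] = shards[i + 1], shards[i]
    return np.concatenate(shards, axis=axis)


m = np.arange(8).reshape(2, 4)
print(swap_adjacent_shards(m, 4, axis=1))
# [[1 0 3 2]
#  [5 4 7 6]]
```

A pairwise swap is its own inverse, so applying it twice restores the original layout.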

easydel.kernels.collective_matmul.run_all_tests()

Runs all distributed matrix multiplication tests and reports results.

This function executes both the all-gather and reduce-scatter matrix multiplication tests and collects their results.

Returns

A dictionary containing test results for both methods.

Return type

dict

easydel.kernels.collective_matmul.test_all_gather_matmul()

Tests the all-gather distributed matrix multiplication implementation.

This function creates a test case with random matrices, computes the expected result using standard matrix multiplication, and verifies that the distributed implementation produces the same result within numerical tolerance.

Returns

A tuple containing (actual_result, expected_result).

Return type

tuple

easydel.kernels.collective_matmul.test_reduce_scatter_matmul()

Tests the reduce-scatter distributed matrix multiplication implementation.

This function creates a test case with random matrices, computes the expected result using standard matrix multiplication, and verifies that the distributed implementation produces the same result within numerical tolerance.

Returns

A tuple containing (actual_result, expected_result).

Return type

tuple
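
Because the test functions return an (actual_result, expected_result) pair rather than a pass/fail flag, a caller can compare the two with a tolerance suited to the dtype. The sketch below is hypothetical: `fake_matmul_test` stands in for the easydel test functions by computing a sharded contraction as a sum of per-shard partial products, and the dictionary keys in `results` are illustrative, not the keys run_all_tests actually returns.

```python
import numpy as np


def fake_matmul_test(seed: int, n_shards: int = 4):
    """Hypothetical stand-in returning (actual_result, expected_result)."""
    rng = np.random.default_rng(seed)
    lhs = rng.standard_normal((4, 8)).astype(np.float32)
    rhs = rng.standard_normal((8, 6)).astype(np.float32)
    expected = lhs @ rhs
    # Sum of partial products over contracting-dimension shards.
    actual = sum(l @ r for l, r in zip(np.split(lhs, n_shards, axis=1), np.split(rhs, n_shards, axis=0)))
    return actual, expected


def results_match(pair, rtol: float = 1e-4, atol: float = 1e-4) -> bool:
    # float32 sums in a different order, so compare within a tolerance, not exactly.
    actual, expected = pair
    return bool(np.allclose(actual, expected, rtol=rtol, atol=atol))


results = {
    "all_gather": results_match(fake_matmul_test(seed=0)),      # illustrative key
    "reduce_scatter": results_match(fake_matmul_test(seed=1)),  # illustrative key
}
print(results)  # {'all_gather': True, 'reduce_scatter': True}
```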