# Source code for eformer.escale.partition.constraints

# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import dataclasses
import os
import re
import typing as tp
import warnings
from functools import partial

import chex
import jax
import jax.extend
import jax.numpy as jnp
import numpy as np
from jax import lax
from jax import tree_util as tu
from jax.interpreters import pxla
from jax.lax import with_sharding_constraint as _with_sharding_constraint
from jax.sharding import Mesh, NamedSharding, PartitionSpec

from eformer.pytree import auto_pytree, named_tree_map

# Minimum element count for a matched array to actually be partitioned by
# match_partition_rules; smaller arrays are left replicated (PartitionSpec()).
MIN_SHARDING_SIZE = int(os.environ.get("MIN_SHARDING_SIZE", "16384"))

# Enables the warnings match_partition_rules emits when it overrides a matched
# spec. Comparison is exact and case-sensitive against the values below.
LOG_SHARDING_MOVE = os.environ.get("LOG_SHARDING_MOVE", "false") in {"true", "yes", "1", "on"}

# Type of an axis specification: a single mesh axis name, a tuple of names, or
# None; tp.Any keeps it permissive (e.g. the Ellipsis placeholder used by
# PartitionAxis for "resolve default later").
AxisType = tp.Optional[tp.Union[tp.Tuple[str, ...], str, tp.Any]]


def names_in_current_mesh(*names: str) -> bool:
	"""
	Report whether every given axis name exists in the mesh of the current context.

	The mesh is read from ``pxla.thread_resources.env.physical_mesh``; when no
	mesh is active that mesh has no axis names.

	Args:
	    *names: Axis names to look up. Passing none at all yields True.

	Returns:
	    True when each name is one of the active mesh's axis names, else False.
	"""
	available = frozenset(pxla.thread_resources.env.physical_mesh.axis_names)
	return all(axis_name in available for axis_name in names)


def make_shard_and_gather_fns(
	partition_specs: tp.Dict[str, PartitionSpec],
	mesh: tp.Optional[Mesh] = None,
) -> tp.Tuple[tp.Dict[str, tp.Callable], tp.Dict[str, tp.Callable]]:
	"""
	Build per-leaf functions that place arrays on a mesh and pull them back.

	Every ``PartitionSpec`` in ``partition_specs`` is paired with ``mesh`` into a
	``NamedSharding``; one shard function and one gather function are then
	created per leaf, so both returned trees mirror ``partition_specs`` and can
	be applied leaf-wise with ``jax.tree_util.tree_map``.

	Shard functions:
	    - multi-process (``jax.process_count() > 1``): run a jitted identity with
	      ``out_shardings`` set to the leaf's sharding, wait for it, and assert
	      the result carries that sharding;
	    - single process: call this module's ``with_sharding_constraint`` under
	      the mesh context.

	Gather functions run a jitted identity whose output is fully replicated
	(``PartitionSpec()``) and return the result through ``jax.device_get``.

	Args:
	        partition_specs: Pytree (typically a dict) of ``PartitionSpec`` leaves.
	        mesh: Mesh to shard over. When None, the mesh active in the current
	                context is used (``get_incontext_mesh``).

	Returns:
	        ``(shard_fns, gather_fns)``, each shaped like ``partition_specs``.
	"""
	target_mesh = get_incontext_mesh() if mesh is None else mesh

	def build_shard_fn(sharding: NamedSharding) -> tp.Callable:
		"""Return a function that lays one tensor out according to ``sharding``."""
		if jax.process_count() <= 1:

			def shard_fn(tensor: jnp.ndarray) -> jnp.ndarray:
				with target_mesh:
					tensor = with_sharding_constraint(tensor, sharding=sharding)
				return tensor

			return shard_fn

		# Multi-process: let jit's out_shardings perform the placement.
		@partial(jax.jit, out_shardings=sharding)
		def _place(tensor):
			return jnp.asarray(tensor)

		def shard_fn(tensor: jnp.ndarray) -> jnp.ndarray:
			with target_mesh:
				placed = jax.block_until_ready(_place(tensor))
				assert placed.sharding == sharding, "sharding Failed!."
			return placed

		return shard_fn

	def build_gather_fn(_sharding: NamedSharding) -> tp.Callable:
		"""Return a function that replicates one tensor and fetches it to host."""

		@partial(jax.jit, out_shardings=NamedSharding(mesh=target_mesh, spec=PartitionSpec()))
		def _replicate(tensor):
			return jnp.asarray(tensor)

		def gather_fn(tensor: jnp.ndarray) -> jnp.ndarray:
			replicated = jax.block_until_ready(_replicate(tensor))
			return jax.device_get(replicated)

		return gather_fn

	named_shardings = tu.tree_map(
		lambda spec: NamedSharding(mesh=target_mesh, spec=spec),
		partition_specs,
	)
	shard_fns = tu.tree_map(build_shard_fn, named_shardings)
	gather_fns = tu.tree_map(build_gather_fn, named_shardings)
	return shard_fns, gather_fns


def get_names_from_partition_spec(
	partition_specs: tp.Dict[str, PartitionSpec],
) -> tp.List[str]:
	"""
	Collect the unique mesh axis names referenced by one or more partition specs.

	Accepts a single ``PartitionSpec``, a (possibly nested) dict whose values are
	specs, any iterable of spec entries, or a bare axis-name string. Entries may
	be ``None`` (dimension not sharded), ``PartitionSpec.UNCONSTRAINED`` (ignored,
	it names no axis), a single axis name, or a nested tuple of names.

	Args:
	        partition_specs: The spec(s) to scan.

	Returns:
	        Sorted list of unique axis names. Sorting makes the result
	        deterministic; iteration order of a plain set of strings is not
	        stable across interpreter runs.
	"""
	# Older JAX versions may not define UNCONSTRAINED; fall back to None, which
	# is skipped anyway.
	unconstrained = getattr(PartitionSpec, "UNCONSTRAINED", None)

	# A bare string is one axis name, not a sequence of single-character names.
	if isinstance(partition_specs, str):
		return [partition_specs]

	names: tp.Set[str] = set()
	if isinstance(partition_specs, dict):
		partition_specs = partition_specs.values()
	for item in partition_specs:
		if item is None or item is unconstrained:
			continue
		if isinstance(item, str):
			names.add(item)
		else:
			names.update(get_names_from_partition_spec(item))
	return sorted(names)


def with_sharding_constraint(
	arr: jnp.ndarray,
	sharding: tp.Union[PartitionSpec, NamedSharding],
) -> jnp.ndarray:
	"""
	Apply ``jax.lax.with_sharding_constraint`` only when the target mesh defines every axis the spec names.

	The target mesh is ``sharding.mesh`` for a ``NamedSharding``, and the mesh
	active in the current context for a ``PartitionSpec``. Checking against the
	mesh that will actually be used matters for ``NamedSharding`` inputs, whose
	mesh need not be the context mesh (or any context mesh may be absent).

	Args:
	        arr: Array to constrain. Values that are not ``jax.Array`` /
	                ``jnp.ndarray`` instances are returned unchanged.
	        sharding: A ``NamedSharding`` or a ``PartitionSpec``.

	Returns:
	        ``arr`` with the constraint applied, or ``arr`` unchanged when the
	        spec references an axis that the target mesh does not have.

	Raises:
	        AssertionError: If ``sharding`` is a ``PartitionSpec`` and no mesh is
	                active in the current context (from ``get_incontext_mesh``).
	"""
	if not isinstance(arr, (jax.Array, jnp.ndarray)):
		return arr

	if isinstance(sharding, NamedSharding):
		mesh, spec = sharding.mesh, sharding.spec
	else:
		mesh, spec = get_incontext_mesh(), sharding

	axis_names = get_names_from_partition_spec(spec)
	if not set(axis_names) <= set(mesh.axis_names):
		# Spec mentions axes this mesh lacks: skip rather than fail.
		return arr

	with mesh:
		return _with_sharding_constraint(arr, spec)


def match_partition_rules(
	rules: tp.List[tp.Tuple[str, PartitionSpec]],
	tree: tp.Dict,
) -> tp.Dict:
	"""
	Assign a ``PartitionSpec`` to every leaf of ``tree`` by regex-matching its name.

	Leaf names are produced by ``named_tree_map`` with ``"/"`` as separator and
	are tested against ``rules`` in order with ``re.search``; the first matching
	rule decides the spec, subject to these adjustments:

	- leaves without a ``shape`` attribute, and 0-d leaves, always get
	  ``PartitionSpec()`` (before any rule is consulted);
	- matched leaves with fewer than ``MIN_SHARDING_SIZE`` elements get
	  ``PartitionSpec()`` instead of the matched spec;
	- a matched spec with more entries than the leaf has dimensions is
	  truncated to the leaf's rank.

	The last two adjustments emit a warning when ``LOG_SHARDING_MOVE`` is set.

	Args:
	        rules: Ordered ``(regex, PartitionSpec)`` pairs.
	        tree: Pytree of parameters (typically a nested dict).

	Returns:
	        Whatever ``named_tree_map`` builds from the per-leaf specs —
	        presumably a tree shaped like ``tree`` (confirm in eformer.pytree).

	Raises:
	        ValueError: If a shaped, non-scalar leaf matches none of the rules.
	"""

	def _spec_for_leaf(name: str, leaf: jnp.ndarray) -> PartitionSpec:
		# Unshaped values and scalars are never partitioned.
		if not hasattr(leaf, "shape") or len(leaf.shape) == 0:
			return PartitionSpec()
		num_elements = np.prod(leaf.shape)

		for pattern, spec in rules:
			if re.search(pattern, name) is None:
				continue
			if num_elements < MIN_SHARDING_SIZE:
				if LOG_SHARDING_MOVE:
					warnings.warn(
						f"PartitionSpec Related to {name} was safer and faster being local array.",
						stacklevel=1,
					)
				return PartitionSpec()
			if len(spec) > leaf.ndim:
				# Drop trailing entries that have no matching array dimension.
				spec = PartitionSpec(*tuple(spec[: leaf.ndim]))
				if LOG_SHARDING_MOVE:
					warnings.warn(
						f"PartitionSpec Related to {name} went out of range (will be auto trimed to {spec}).",
						stacklevel=1,
					)
			return spec

		raise ValueError(f"Partition rule not found for param: {name}")

	return named_tree_map(_spec_for_leaf, tree, sep="/")


def analyze_sharding_strategy(
	pytree: tp.Any,
	partition_specs: tp.Dict[str, PartitionSpec],
	mesh: tp.Optional[Mesh] = None,
) -> tp.Dict:
	"""
	Summarize how ``partition_specs`` would lay ``pytree`` out on ``mesh``.

	Args:
	    pytree: Pytree whose leaves expose ``shape`` and ``dtype`` (arrays).
	        Leaves lacking either attribute are skipped.
	    partition_specs: Pytree with the same structure as ``pytree`` whose
	        leaves are the ``PartitionSpec`` for the matching array. A spec
	        entry may name one mesh axis or a tuple of axes.
	    mesh: Mesh used to resolve axis sizes. Defaults to the mesh active in
	        the current context (``get_incontext_mesh``).

	Returns:
	    A dict with:
	    - ``total_parameters``: element count over all analyzed leaves.
	    - ``sharded_parameters``: element count of leaves whose spec shards at
	      least one dimension (has a non-None entry).
	    - ``memory_per_device``: ``{device: bytes}`` for every device in the
	      mesh; each device holds one shard of every leaf, so all devices get
	      the same total.
	    - ``balance_score``: not computed; always ``0.0``.
	    - ``partition_stats``: per-leaf dicts keyed by ``tree_util.keystr`` of
	      the leaf path, with ``shape``, ``spec``, ``bytes`` (full array) and
	      ``bytes_per_device`` (largest shard).

	Raises:
	    KeyError: If a spec names an axis that ``mesh`` does not define.
	"""
	if mesh is None:
		mesh = get_incontext_mesh()
	mesh_shape = mesh.shape

	analysis = {
		"total_parameters": 0,
		"sharded_parameters": 0,
		"memory_per_device": {},
		"balance_score": 0.0,
		"partition_stats": {},
	}

	def _shard_count(entry) -> int:
		"""Number of shards one spec entry splits its dimension into."""
		if entry is None:
			return 1
		axes = entry if isinstance(entry, (tuple, list)) else (entry,)
		count = 1
		for axis in axes:
			count *= mesh_shape[axis]
		return count

	def analyze_leaf(path, array, spec: PartitionSpec) -> None:
		if not (hasattr(array, "shape") and hasattr(array, "dtype")):
			return None
		shape = tuple(array.shape)
		entries = tuple(spec)
		num_elements = int(np.prod(shape))  # np.prod(()) == 1.0 for scalars
		itemsize = np.dtype(array.dtype).itemsize

		analysis["total_parameters"] += num_elements
		if any(entry is not None for entry in entries):
			analysis["sharded_parameters"] += num_elements

		# Ceil division gives the largest shard when a dimension does not divide
		# evenly; spec entries beyond the array's rank are ignored, and missing
		# trailing entries mean "not sharded".
		shard_shape = [-(-dim // _shard_count(entry)) for dim, entry in zip(shape, entries)]
		shard_shape.extend(shape[len(entries):])

		analysis["partition_stats"][tu.keystr(path)] = {
			"shape": shape,
			"spec": spec,
			"bytes": num_elements * itemsize,
			"bytes_per_device": int(np.prod(shard_shape)) * itemsize,
		}
		return None

	tu.tree_map_with_path(analyze_leaf, pytree, partition_specs)

	per_device_total = sum(
		stat["bytes_per_device"] for stat in analysis["partition_stats"].values()
	)
	for device in np.asarray(mesh.devices).flat:
		analysis["memory_per_device"][device] = per_device_total

	return analysis


def create_pattern_based_partition_spec(
	pattern: str,
	mesh: tp.Optional[Mesh] = None,
	default_spec: tp.Optional[PartitionSpec] = None,
) -> tp.Callable[[str, chex.Array], PartitionSpec]:
	"""
	Build a ``(name, array) -> PartitionSpec`` function from a compact rule string.

	``pattern`` is a comma-separated list of ``regexes->axes`` rules, where
	``regexes`` is a ``|``-separated list of regular expressions searched for
	in the parameter name and ``axes`` is a ``.``-separated list of mesh axis
	names (one per array dimension). Whitespace around each regex and axis
	name is ignored. Segments without ``->`` are skipped. Rules are tried in
	order and the first regex that matches decides the spec; the ``array``
	argument of the returned function is not inspected.

	Args:
	        pattern: The rule string.
	        mesh: Only used to require an active mesh when omitted; it is not
	                consulted when building specs.
	        default_spec: Spec returned when nothing matches
	                (default ``PartitionSpec()``).

	Example:
	        pattern_fn = create_pattern_based_partition_spec(
	                "attention|mlp->data,hidden->model"
	        )
	        pattern_fn("layer/attention/q", array)  # PartitionSpec("data")

	Raises:
	        AssertionError: If ``mesh`` is None and no mesh is active.
	"""
	if default_spec is None:
		default_spec = PartitionSpec()
	if mesh is None:
		# Fails fast when called outside a mesh context.
		mesh = get_incontext_mesh()

	rules: tp.List[tp.Tuple[str, PartitionSpec]] = []
	for rule in pattern.split(","):
		if "->" not in rule:
			continue
		regexes, axes = rule.split("->")
		spec = PartitionSpec(*(axis.strip() for axis in axes.split(".")))
		rules.extend((regex.strip(), spec) for regex in regexes.split("|"))

	def get_partition_spec(name: str, array: chex.Array) -> PartitionSpec:
		for regex, spec in rules:
			if re.search(regex, name):
				return spec
		return default_spec

	return get_partition_spec


def extract_sharding_structure(pytree: tp.Any) -> tp.Any:
	"""
	Mirror ``pytree`` with each leaf replaced by its ``NamedSharding`` (or None).

	A leaf maps to its ``.sharding`` only when it is a ``jax.Array`` whose
	sharding is a ``NamedSharding``; every other leaf maps to None.
	"""

	def _named_or_none(leaf: tp.Any) -> tp.Optional[NamedSharding]:
		if not isinstance(leaf, jax.Array):
			return None
		current = leaf.sharding
		return current if isinstance(current, NamedSharding) else None

	flat_leaves, structure = jax.tree_util.tree_flatten(pytree)
	return jax.tree_util.tree_unflatten(structure, [_named_or_none(leaf) for leaf in flat_leaves])


def get_shardings_with_structure(pytree: tp.Any) -> tp.Any:
	"""
	Alias of ``extract_sharding_structure``.

	Produces a pytree with the structure of ``pytree`` whose leaves are each
	input leaf's ``NamedSharding``, or None where the leaf has none.
	"""
	shardings = extract_sharding_structure(pytree)
	return shardings


def get_incontext_mesh() -> Mesh:
	"""
	Return the mesh currently installed in this thread's resource environment.

	Reads ``pxla.thread_resources.env.physical_mesh``. Its ``empty`` attribute
	is accepted either as a method or as a plain property.

	Returns:
	    Mesh: The active mesh.

	Raises:
	    AssertionError: If the active mesh is empty, i.e. no mesh context is
	                    entered.
	"""
	active = pxla.thread_resources.env.physical_mesh
	empty_flag = active.empty
	is_empty = empty_flag() if callable(empty_flag) else empty_flag
	if is_empty:
		raise AssertionError("No mesh found under this context manager.")
	return active


def get_axes_size_in_mesh(axis_names: AxisType, mesh: tp.Optional[Mesh] = None) -> int:
	"""
	Calculates the total size of the specified mesh axes.

	A single axis name returns that dimension's size. A sequence (list or
	tuple) returns the product of its axes' sizes, so an empty sequence is 1.
	``None`` — an unpartitioned axis — is also 1, and ``None`` entries inside a
	sequence contribute a factor of 1.

	If no mesh is explicitly provided, the mesh active in the current context
	is used (via ``get_incontext_mesh()``).

	Args:
	    axis_names: A single mesh axis name (str), a sequence (list/tuple) of
	                axis names, or None.
	    mesh: The mesh to query. If None, the current context's mesh is used.

	Returns:
	    int: The size of the axis, or the product of the sizes of the axes.

	Raises:
	    KeyError: If any named axis is not in the mesh.
	    TypeError: If ``axis_names`` is not a str, list, tuple or None.
	    AssertionError: If ``mesh`` is None and no mesh is active
	                    (raised by ``get_incontext_mesh()``).
	"""
	if mesh is None:
		mesh = get_incontext_mesh()

	# mesh.shape maps axis name -> size.
	mesh_shape: tp.Mapping[str, int] = mesh.shape

	if axis_names is None:
		return 1
	if isinstance(axis_names, str):
		return mesh_shape[axis_names]
	if isinstance(axis_names, (list, tuple)):
		size = 1
		for axis in axis_names:
			if axis is not None:
				size *= mesh_shape[axis]
		return size
	raise TypeError(f"axis_names must be str or Sequence[str], got {type(axis_names)}")


def get_mesh_axis_names(mesh: tp.Optional[Mesh] = None) -> tp.List[str]:
	"""
	List the names of every axis in a mesh, in the mesh's own order.

	Args:
	    mesh: Mesh to inspect. When None, the mesh active in the current
	          context is used (via ``get_incontext_mesh()``).

	Returns:
	    List[str]: The keys of ``mesh.shape``, in iteration order.

	Raises:
	    AssertionError: If ``mesh`` is None and no mesh is active
	                    (raised by ``get_incontext_mesh()``).
	"""
	target = get_incontext_mesh() if mesh is None else mesh
	return [axis_name for axis_name in target.shape]


def get_mesh_axis_size(axis_names: AxisType) -> int:
	"""
	Number of devices spanned by one named axis or the product over several.

	Sizes come from ``lax.psum(1, axis_name=...)``, so this must run where the
	axis names are bound (inside a collective/mapped context).

	Args:
	    axis_names: One axis name (str) or a list/tuple of axis names. Order is
	                irrelevant because the sizes are multiplied.

	Returns:
	    The axis size, or the product of the sizes; 1 for an empty sequence.

	Raises:
	    TypeError: If ``axis_names`` is neither a str nor a list/tuple.
	"""
	if isinstance(axis_names, str):
		return lax.psum(1, axis_name=axis_names)
	if not isinstance(axis_names, (list, tuple)):
		raise TypeError(
			f"Input 'axis_names' must be a string or sequence (list/tuple), "
			f"but got type {type(axis_names)}"
		)
	# An empty sequence leaves the running product at 1.
	total = 1
	for name in axis_names:
		total *= lax.psum(1, axis_name=name)
	return total


def get_submesh_device_index(axis_names: AxisType) -> int:
	"""
	Row-major linear index of the current device within the given mesh axes.

	The per-axis coordinates (``lax.axis_index``) are flattened into one integer,
	treating ``axis_names`` as ordered from most major to most minor. Axis sizes
	come from ``lax.psum(1, ...)``, so this must run where the axes are bound.

	Args:
	    axis_names: One axis name (str) or a list/tuple of axis names ordered
	                major -> minor.

	Returns:
	    The 0-based index within the submesh; 0 for an empty sequence.

	Raises:
	    TypeError: If ``axis_names`` is neither a str nor a list/tuple.
	"""
	if isinstance(axis_names, str):
		return lax.axis_index(axis_name=axis_names)
	if not isinstance(axis_names, (list, tuple)):
		raise TypeError(
			f"Input 'axis_names' must be a string or sequence (list/tuple), "
			f"but got type {type(axis_names)}"
		)
	# Horner's scheme, scanning major -> minor: scale the running index by the
	# current axis size, then add this device's coordinate on that axis. This
	# equals sum(coord[d] * prod(size[k] for k > d)).
	flat_index = 0
	for name in axis_names:
		flat_index = flat_index * lax.psum(1, axis_name=name) + lax.axis_index(axis_name=name)
	return flat_index


def extract_shardings(tree, mesh: tp.Optional[Mesh] = None):
	"""
	Extracts JAX NamedSharding objects from the leaves of a PyTree.

	For each leaf:
	- a ``.sharding`` that is already a ``NamedSharding`` is returned as-is;
	- a ``.sharding`` that is a ``PartitionSpec`` is combined with a mesh into
	  a ``NamedSharding``. The mesh is ``mesh`` if given, otherwise the mesh
	  active in the current context, looked up only the first time such a
	  leaf is found;
	- anything else (no ``.sharding``, or another sharding type) yields None.

	Resolving the context mesh lazily means trees whose leaves already carry
	``NamedSharding`` work outside any mesh context.

	Args:
	    tree: The input PyTree (e.g. nested dict, list, tuple).
	    mesh: Optional mesh used to convert ``PartitionSpec`` shardings.

	Returns:
	    A PyTree shaped like ``tree`` holding a ``NamedSharding`` or None per leaf.

	Raises:
	    AssertionError: If a leaf has a ``PartitionSpec`` sharding, ``mesh`` is
	                    None, and no mesh is active (from ``get_incontext_mesh``).
	"""
	resolved_mesh = mesh

	def cond(x):
		nonlocal resolved_mesh
		sharding = getattr(x, "sharding", None)
		if isinstance(sharding, jax.sharding.PartitionSpec):
			if resolved_mesh is None:
				resolved_mesh = get_incontext_mesh()
			sharding = jax.sharding.NamedSharding(mesh=resolved_mesh, spec=sharding)
		if not isinstance(sharding, jax.sharding.NamedSharding):
			return None
		return sharding

	return jax.tree_util.tree_map(cond, tree)


def get_partition_spec(tree):
	"""
	Map every leaf of a PyTree to its ``jax.sharding.PartitionSpec``.

	- ``jax.Array`` leaves yield ``leaf.sharding.spec``.
	- Python ``int``/``float`` leaves (``bool`` included, being an ``int``)
	  yield an empty ``PartitionSpec()``, i.e. replicated.
	- Any other leaf type is rejected.

	Args:
	    tree: PyTree of JAX arrays and/or Python scalars.

	Returns:
	    A PyTree shaped like ``tree`` with one ``PartitionSpec`` per leaf.

	Raises:
	    ValueError: If a leaf is not a ``jax.Array``, ``int`` or ``float``.
	    AttributeError: If a ``jax.Array`` leaf's sharding has no ``spec``.
	"""

	def _leaf_spec(leaf):
		if not isinstance(leaf, jax.Array):
			if isinstance(leaf, (int, float)):
				return PartitionSpec()
			raise ValueError(
				f"Unsupported leaf type for get_partition_spec: {type(leaf)}. "
				"Expected jax.Array, int, or float."
			)
		leaf_sharding = getattr(leaf, "sharding", None)
		if hasattr(leaf_sharding, "spec"):
			return leaf_sharding.spec
		raise AttributeError(
			f"jax.Array leaf does not have expected .sharding.spec: {leaf}"
		)

	return jax.tree_util.tree_map(_leaf_spec, tree)


@auto_pytree
class PartitionAxis:
	"""
	Configuration for partitioning model axes across a device mesh.

	Defines the mesh dimension names for standard parallelism strategies and maps
	logical model axes to these dimensions. Allows overriding defaults.

	Mesh Dimensions:
	    data_parallel_axis: Name for data parallel mesh dim. Default: "dp".
	    fully_sharded_data_parallel_axis: Name for FSDP mesh dim. Default: "fsdp".
	    tensor_parallel_axis: Name for tensor parallel mesh dim. Default: "tp".
	    sequence_parallel_axis: Name for sequence parallel mesh dim. Default: "sp".
	    expert_parallel_axis: Name for expert parallel mesh dim (MoE). Default: "ep".

	Logical Model Axes:
	    Map logical tensor axes (like batch, sequence, hidden) to one or more of
	    the mesh dimension names above, or None if not partitioned. Fields whose
	    default is ``...`` (Ellipsis) are filled in by ``__post_init__`` from the
	    mesh dimension names; e.g. ``head_axis`` defaults to the value of
	    ``tensor_parallel_axis`` ('tp') and ``batch_axis`` to
	    ``(fully_sharded_data_parallel_axis, data_parallel_axis)``. Every field
	    can be overridden at instantiation.
	"""

	# --- Mesh Dimension Names ---
	data_parallel_axis: str = "dp"
	fully_sharded_data_parallel_axis: str = "fsdp"
	tensor_parallel_axis: str = "tp"
	sequence_parallel_axis: str = "sp"
	expert_parallel_axis: str = "ep"

	# --- Logical Model Axes (``...`` = resolve default in __post_init__) ---
	batch_axis: AxisType = ...
	sequence_axis: AxisType = ...
	query_sequence_axis: AxisType = ...
	head_axis: AxisType = ...
	key_sequence_axis: AxisType = ...
	hidden_state_axis: AxisType = ...
	mlp_intermediate_axis: AxisType = ...
	vocab_axis: AxisType = ...
	expert_axis: AxisType = ...
	expert_gate_axis: AxisType = None

	attention_dim_axis: AxisType = None  # Usually not partitioned
	bias_head_sequence_axis: AxisType = None
	bias_key_sequence_axis: AxisType = None

	# --- Generation Specific ---
	generation_batch_axis: AxisType = None
	generation_query_sequence_axis: AxisType = None  # Often length 1, not sharded
	generation_head_axis: AxisType = ...
	generation_key_sequence_axis: AxisType = ...
	generation_attention_dim_axis: AxisType = None

	def __post_init__(self):
		"""
		Replace Ellipsis placeholders with their default mesh axes, then validate.

		Values are written with ``object.__setattr__`` so this also works if
		``auto_pytree`` produces a frozen dataclass (presumably it does — confirm
		in eformer.pytree).
		"""
		# Field name -> default, derived from the (already final) mesh dim names.
		defaults = {
			# Batch sharding spans both the FSDP and DP mesh dimensions.
			"batch_axis": (self.fully_sharded_data_parallel_axis, self.data_parallel_axis),
			"sequence_axis": self.sequence_parallel_axis,
			"query_sequence_axis": self.sequence_parallel_axis,
			"head_axis": self.tensor_parallel_axis,
			"key_sequence_axis": self.sequence_parallel_axis,
			"hidden_state_axis": self.tensor_parallel_axis,
			"mlp_intermediate_axis": self.tensor_parallel_axis,
			"vocab_axis": self.tensor_parallel_axis,
			"expert_axis": self.expert_parallel_axis,
			"generation_head_axis": self.tensor_parallel_axis,
			"generation_key_sequence_axis": self.sequence_parallel_axis,
		}
		for field_name, default in defaults.items():
			if getattr(self, field_name) is Ellipsis:
				object.__setattr__(self, field_name, default)
		self._safety_check()

	def _safety_check(self):
		"""Raise ValueError if any field is still the Ellipsis placeholder."""
		for field in dataclasses.fields(self):
			# Identity check: ``==`` would dispatch to the field value's __eq__.
			if getattr(self, field.name) is Ellipsis:
				raise ValueError(f"`{field.name}` shouldn't be ellipsis")