# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module provides classes and functions for managing JAX sharding configurations
and applying sharding constraints to arrays.
It includes the `PartitionAxis` class for defining logical-to-physical axis mappings
and the `PartitionManager` class, which resolves these rules into `PartitionSpec`s
and applies them to arrays.
"""
import dataclasses
import hashlib
import typing as tp
import jax
from jax.sharding import PartitionSpec
from eformer.common_types import (
BATCH,
BIAS_HEAD_SEQ,
BIAS_KV_SEQ,
EMBED,
EMPTY,
EXPERT,
EXPERT_GATE,
GENERATION_MODES,
HEAD,
HEAD_DIM,
KV_HEAD,
KV_HEAD_DIM,
KV_LENGTH,
LENGTH,
MLP_INTERMEDIATE,
MODE_DECODE,
MODE_TRAIN,
NOT_GIVEN,
QUERY_LENGTH,
RUNTIME_MODE_TYPES,
VOCAB,
AxisType,
DynamicShardingAxes,
)
from eformer.pytree import xTree, PyTree
from .constraints import get_corrected_named_sharding, with_sharding_constraint
def hash_fn(self) -> int:
    """
    Compute a deterministic integer hash for a configuration-like object.

    Intended to be assigned as ``__hash__`` on config classes. The ``repr`` of
    every plain-value attribute in ``self.__dict__`` is collected in attribute
    order. Plain values are bools, numbers, strings, ``None``, tuples, lists
    and dicts. The collected strings are joined with a separator and hashed
    with :func:`get_safe_hash_int`. Attributes of any other type (e.g. nested
    objects) are ignored.

    Args:
        self: The object to hash.

    Returns:
        A non-negative integer derived from the object's plain-value attributes.
    """
    # Strings and tuples must be included: axis configurations are made of
    # mesh-axis names (str) and tuples of them. Excluding them made every
    # such instance hash to the same value.
    plain_types = (bool, int, float, str, type(None), tuple, list, dict)
    # A separator plus repr() keeps ("ab", "c") distinct from ("a", "bc")
    # and the string "None" distinct from None.
    shu = "\x1f".join(
        repr(cu) for cu in self.__dict__.values() if isinstance(cu, plain_types)
    )
    return get_safe_hash_int(shu)
def get_safe_hash_int(text, algorithm="md5"):
    """
    Hash the string form of ``text`` and return the digest as an integer.

    Args:
        text: Any object; ``str(text)`` is UTF-8 encoded and hashed.
        algorithm: A hash algorithm name accepted by ``hashlib.new``
            (e.g. "md5", "sha256"). Defaults to "md5".

    Returns:
        The digest interpreted as a big-endian unsigned integer.

    Raises:
        ValueError: If ``algorithm`` is not a supported hash algorithm.
    """
    data = str(text).encode()
    try:
        # hashlib.new only accepts real algorithm names, unlike getattr on the
        # module. usedforsecurity=False lets MD5 work on FIPS-restricted
        # builds; the hash is only used for identity, not security.
        hash_object = hashlib.new(algorithm, data, usedforsecurity=False)
    except ValueError as e:
        raise ValueError(f"Unsupported hash algorithm: {algorithm}") from e
    return int.from_bytes(hash_object.digest(), byteorder="big")
class PartitionAxis(xTree):
    """
    Configuration for partitioning model axes across a device mesh.

    Defines the mesh dimension names for standard parallelism strategies and maps
    logical model axes to these dimensions. Allows overriding defaults.

    Mesh Dimensions Attributes:
        data_parallel_axis: Name for data parallel mesh dim. Default: "dp".
        fully_sharded_data_parallel_axis: Name for FSDP mesh dim. Default: "fsdp".
        tensor_parallel_axis: Name for tensor parallel mesh dim. Default: "tp".
        sequence_parallel_axis: Name for sequence parallel mesh dim. Default: "sp".
        expert_parallel_axis: Name for expert parallel mesh dim (MoE). Default: "ep".

    Logical Model Axes Attributes:
        Each maps a logical tensor axis (batch, sequence, hidden, ...) to one or
        more mesh dimension names defined above, or None if not partitioned.
        Fields left as NOT_GIVEN are resolved in ``__post_init__``. For example,
        ``head_axis`` defaults to the value of ``tensor_parallel_axis`` ('tp').

        batch_axis: Batch dimension.
            Default: (fully_sharded_data_parallel_axis, data_parallel_axis).
        sequence_axis: General sequence length dimension.
            Default: sequence_parallel_axis.
        query_sequence_axis: Query sequence length dimension.
            Default: sequence_parallel_axis.
        head_axis: Attention head dimension. Default: tensor_parallel_axis.
        kv_head_axis: Key/value head dimension. Default: None.
        key_sequence_axis: Key/value sequence length dimension.
            Default: sequence_parallel_axis.
        hidden_state_axis: Embedding / hidden state dimension.
            Default: tensor_parallel_axis.
        mlp_intermediate_axis: MLP intermediate dimension.
            Default: tensor_parallel_axis.
        vocab_axis: Vocabulary dimension. Default: tensor_parallel_axis.
        expert_axis: Expert dimension. Default: expert_parallel_axis.
        expert_gate_axis: Expert gate dimension. Default: None.
        attention_dim_axis: Dimension within each attention head. Default: None.
        attention_kv_dim_axis: Dimension within each key/value head. Default: None.
        bias_head_sequence_axis: Bias related to head and sequence dims. Default: None.
        bias_key_sequence_axis: Bias related to key/value sequence dims. Default: None.

    Generation Axes Attributes:
        In generation modes ``resolve_spec`` uses these instead of the matching
        standard field whenever they are not None.

        decode_batch_axis: Default: the resolved batch_axis.
        decode_query_sequence_axis: Default: None (falls back to query_sequence_axis).
        decode_head_axis: Default: the resolved head_axis.
        decode_kv_head_axis: Default: the resolved kv_head_axis.
        decode_key_sequence_axis: Default: the resolved key_sequence_axis.
        decode_attention_dim_axis: Default: None (falls back to attention_dim_axis).
        decode_attention_kv_dim_axis: Default: None (falls back to attention_kv_dim_axis).
    """

    data_parallel_axis: str = "dp"
    fully_sharded_data_parallel_axis: str = "fsdp"
    tensor_parallel_axis: str = "tp"
    sequence_parallel_axis: str = "sp"
    expert_parallel_axis: str = "ep"

    batch_axis: AxisType = NOT_GIVEN
    sequence_axis: AxisType = NOT_GIVEN
    query_sequence_axis: AxisType = NOT_GIVEN
    head_axis: AxisType = NOT_GIVEN
    kv_head_axis: AxisType = NOT_GIVEN
    key_sequence_axis: AxisType = NOT_GIVEN
    hidden_state_axis: AxisType = NOT_GIVEN
    mlp_intermediate_axis: AxisType = NOT_GIVEN
    vocab_axis: AxisType = NOT_GIVEN
    expert_axis: AxisType = NOT_GIVEN
    expert_gate_axis: AxisType = None
    attention_dim_axis: AxisType = None
    attention_kv_dim_axis: AxisType = None
    bias_head_sequence_axis: AxisType = None
    bias_key_sequence_axis: AxisType = None

    decode_batch_axis: AxisType = NOT_GIVEN
    decode_query_sequence_axis: AxisType = None
    decode_head_axis: AxisType = NOT_GIVEN
    decode_kv_head_axis: AxisType = NOT_GIVEN
    decode_key_sequence_axis: AxisType = NOT_GIVEN
    decode_attention_dim_axis: AxisType = None
    decode_attention_kv_dim_axis: AxisType = None

    # Maps semantic axis name constants (e.g. BATCH) to the *standard* attribute
    # names of this class (e.g. "batch_axis"). Generation-mode overrides are
    # applied afterwards via _STANDARD_TO_GENERATION_ATTR_MAP, so no entry here
    # may point at a decode_* attribute. A value of None marks a dimension that
    # is never sharded.
    _SEMANTIC_MAP: tp.ClassVar[tp.Dict[str, tp.Optional[str]]] = {
        BATCH: "batch_axis",
        LENGTH: "sequence_axis",
        QUERY_LENGTH: "query_sequence_axis",
        KV_LENGTH: "key_sequence_axis",
        EMBED: "hidden_state_axis",
        HEAD: "head_axis",
        KV_HEAD: "kv_head_axis",
        MLP_INTERMEDIATE: "mlp_intermediate_axis",
        VOCAB: "vocab_axis",
        EXPERT: "expert_axis",
        EXPERT_GATE: "expert_gate_axis",
        HEAD_DIM: "attention_dim_axis",
        # Previously pointed at "decode_attention_kv_dim_axis", which made
        # train mode read the decode field and left attention_kv_dim_axis unused.
        KV_HEAD_DIM: "attention_kv_dim_axis",
        BIAS_HEAD_SEQ: "bias_head_sequence_axis",
        BIAS_KV_SEQ: "bias_key_sequence_axis",
        EMPTY: None,  # Represents an unsharded dimension
    }

    # Maps standard attribute names to their generation-specific counterparts,
    # used by resolve_spec when the runtime mode is in GENERATION_MODES.
    _STANDARD_TO_GENERATION_ATTR_MAP: tp.ClassVar[tp.Dict[str, str]] = {
        "batch_axis": "decode_batch_axis",
        "query_sequence_axis": "decode_query_sequence_axis",
        "key_sequence_axis": "decode_key_sequence_axis",
        "head_axis": "decode_head_axis",
        "kv_head_axis": "decode_kv_head_axis",
        "attention_dim_axis": "decode_attention_dim_axis",
        "attention_kv_dim_axis": "decode_attention_kv_dim_axis",
    }

    def __post_init__(self):
        """
        Post-initialization hook to resolve default axis values.

        Any axis attribute still set to NOT_GIVEN is replaced by a default
        derived from the mesh dimension names. For decode_* fields the default
        is the already-resolved standard field. Explicitly provided values are
        kept unchanged.
        """
        resolved_values = {}

        def resolve_field(name, default_logic):
            """Record `name`'s value, computing `default_logic()` if it is NOT_GIVEN."""
            current_value = getattr(self, name)
            if current_value is NOT_GIVEN:
                resolved_values[name] = default_logic()
            elif name not in resolved_values:
                resolved_values[name] = current_value

        def get_resolved(name):
            """Return the resolved value of `name`, falling back to the raw attribute."""
            return resolved_values.get(name, getattr(self, name))

        # Standard axis defaults.
        resolve_field(
            "batch_axis",
            lambda: (self.fully_sharded_data_parallel_axis, self.data_parallel_axis),
        )
        resolve_field("sequence_axis", lambda: self.sequence_parallel_axis)
        resolve_field("query_sequence_axis", lambda: self.sequence_parallel_axis)
        resolve_field("head_axis", lambda: self.tensor_parallel_axis)
        resolve_field("kv_head_axis", lambda: None)
        resolve_field("key_sequence_axis", lambda: self.sequence_parallel_axis)
        resolve_field("hidden_state_axis", lambda: self.tensor_parallel_axis)
        resolve_field("mlp_intermediate_axis", lambda: self.tensor_parallel_axis)
        resolve_field("vocab_axis", lambda: self.tensor_parallel_axis)
        resolve_field("expert_axis", lambda: self.expert_parallel_axis)

        # Generation-specific defaults copy the resolved standard axes, so they
        # must be resolved after the standard fields above.
        resolve_field("decode_batch_axis", lambda: get_resolved("batch_axis"))
        resolve_field("decode_head_axis", lambda: get_resolved("head_axis"))
        resolve_field("decode_kv_head_axis", lambda: get_resolved("kv_head_axis"))
        resolve_field(
            "decode_key_sequence_axis",
            lambda: get_resolved("key_sequence_axis"),
        )

        # Carry over every remaining field unchanged.
        for fld in dataclasses.fields(self):
            if fld.name not in resolved_values and fld.name not in [
                "_SEMANTIC_MAP",
                "_STANDARD_TO_GENERATION_ATTR_MAP",
            ]:
                resolved_values[fld.name] = getattr(self, fld.name)

        # object.__setattr__ bypasses any __setattr__ override on the base
        # class (presumably a frozen/immutable xTree — confirm).
        for name, value in resolved_values.items():
            object.__setattr__(self, name, value)

        self._safety_check()

    def _safety_check(self):
        """
        Check that no axis attribute still holds NOT_GIVEN after resolution.

        Raises:
            ValueError: If any attribute is still NOT_GIVEN, indicating a
                configuration error.
        """
        for fld in dataclasses.fields(self):
            if fld.name not in ["_SEMANTIC_MAP", "_STANDARD_TO_GENERATION_ATTR_MAP"]:
                val = getattr(self, fld.name)
                if val == NOT_GIVEN:
                    raise ValueError(f"Partitioning rule `{fld.name}` was not resolved.")

    def resolve_spec(
        self,
        axes: tp.Sequence[tp.Optional[str]],
        mode: RUNTIME_MODE_TYPES,  # type:ignore
    ) -> PartitionSpec:
        """
        Generate a PartitionSpec from a sequence of semantic axis names and a mode.

        Each semantic axis name (like BATCH, LENGTH) is mapped to the mesh axis
        rule stored on this instance. In a generation mode, a non-None decode_*
        override takes precedence over the standard rule.

        Args:
            axes: A sequence of semantic axis name strings (e.g. [BATCH, LENGTH, HEAD]).
                None or "_" marks a dimension that should not be sharded, as does
                any semantic name that maps to None (e.g. EMPTY).
            mode: The current operational mode (e.g. MODE_TRAIN, MODE_DECODE),
                which determines whether generation-specific rules apply.

        Returns:
            A jax.sharding.PartitionSpec with one entry per element of `axes`.

        Raises:
            ValueError: If an unknown semantic axis name is encountered, or if
                a resolved axis rule is still NOT_GIVEN (normally caught by
                `_safety_check`; re-checked here for robustness).
            LookupError: If an attribute named in the semantic maps does not
                exist on the instance (indicates an internal inconsistency).
        """
        resolved_rules: list[AxisType] = []
        for axis_name in axes:
            if axis_name is None or axis_name == "_":
                # None or "_" explicitly means no sharding for this dimension.
                resolved_rules.append(None)
                continue

            if axis_name not in self._SEMANTIC_MAP:
                raise ValueError(f"Unknown semantic axis name: '{axis_name}'")
            standard_attr_name = self._SEMANTIC_MAP[axis_name]
            if standard_attr_name is None:
                # Known semantic axis that is never sharded (e.g. EMPTY).
                resolved_rules.append(None)
                continue

            target_attr_name = standard_attr_name
            if mode in GENERATION_MODES:
                gen_attr_name = self._STANDARD_TO_GENERATION_ATTR_MAP.get(standard_attr_name)
                if gen_attr_name:
                    # A None / NOT_GIVEN override falls back to the standard rule.
                    gen_val = getattr(self, gen_attr_name, None)
                    if gen_val is not None and gen_val is not NOT_GIVEN:
                        target_attr_name = gen_attr_name

            try:
                mesh_axis_rule: AxisType = getattr(self, target_attr_name)
            except AttributeError as e:
                raise LookupError(
                    f"Internal error: Attribute '{target_attr_name}' not found in PartitionAxis instance."
                ) from e

            if mesh_axis_rule is NOT_GIVEN:
                raise ValueError(
                    f"Resolved axis rule for '{axis_name}' ('{target_attr_name}') is still NOT_GIVEN."
                )
            resolved_rules.append(mesh_axis_rule)

        return PartitionSpec(*resolved_rules)

    __hash__ = hash_fn
class PartitionManager(PyTree):
    """
    Resolves semantic axis names into sharding constraints using a PartitionAxis.

    `resolve` turns semantic axis names (plus a runtime mode) into a
    `PartitionSpec` via `PartitionAxis.resolve_spec`. `shard` additionally
    applies that spec to an array with `with_sharding_constraint`.

    Args:
        paxis: The PartitionAxis instance defining the sharding strategy.
    """

    # NOTE(review): earlier docs described this class as a context manager with
    # a `_CURRENT_PARTITION_MANAGER` context variable. No __enter__/__exit__ is
    # defined here — confirm whether the PyTree base provides that behavior.
    paxis: PartitionAxis

    def __post_init__(self):
        if not isinstance(self.paxis, PartitionAxis):
            raise TypeError(f"Expected PartitionAxis, got {type(self.paxis)}")

    def shard(
        self,
        x: jax.Array,
        axes: tp.Sequence[tp.Optional[str]] = NOT_GIVEN,
        mode: RUNTIME_MODE_TYPES | int = NOT_GIVEN,  # type:ignore
        dynamic_axes: tp.Optional[DynamicShardingAxes] = NOT_GIVEN,
        auto_correct: bool = True,
    ) -> jax.Array:
        """
        Apply a sharding constraint to `x` using this manager's PartitionAxis.

        The semantic `axes` are resolved with `resolve` (using `x.shape` for
        integer-mode inference), optionally corrected for the array shape, and
        applied with `with_sharding_constraint`.

        Args:
            x: The JAX array to apply the sharding constraint to.
            axes: A sequence of semantic axis name strings or None.
            mode: The runtime mode constant, or an integer index into `x.shape`.
                When an integer is given, a size of 1 at that index selects
                MODE_DECODE and anything else selects MODE_TRAIN.
            dynamic_axes: A `DynamicShardingAxes` providing `axes` and `mode`.
                Used only when `axes` or `mode` is NOT_GIVEN; its values then
                replace both arguments.
            auto_correct: If True, correct the resolved `PartitionSpec` for the
                array shape using `get_corrected_named_sharding`. Defaults to True.

        Returns:
            The array `x` with the sharding constraint applied.

        Raises:
            AssertionError: If `axes`/`mode` are incomplete and no `dynamic_axes`
                is given, or if integer-mode inference lacks a shape.
            ValueError: Propagated from `PartitionAxis.resolve_spec`.
        """
        spec = self.resolve(
            axes=axes,
            mode=mode,
            dynamic_axes=dynamic_axes,
            shape=x.shape,
        )
        if auto_correct:
            spec = get_corrected_named_sharding(x.shape, spec).spec
        return with_sharding_constraint(x, spec)

    def resolve(
        self,
        axes: tp.Sequence[tp.Optional[str]] = NOT_GIVEN,
        mode: RUNTIME_MODE_TYPES | int = NOT_GIVEN,  # type:ignore
        dynamic_axes: tp.Optional[DynamicShardingAxes] = NOT_GIVEN,
        shape=NOT_GIVEN,
    ) -> PartitionSpec:
        """
        Resolve semantic axes and a mode into a PartitionSpec.

        Args:
            axes: A sequence of semantic axis name strings or None.
            mode: The runtime mode constant, or an integer index into `shape`
                used to infer the mode (size 1 -> MODE_DECODE, else MODE_TRAIN).
            dynamic_axes: Supplies `axes` and `mode` when either is NOT_GIVEN.
            shape: The array shape; required when the mode is an integer.

        Returns:
            The PartitionSpec produced by `PartitionAxis.resolve_spec`.

        Raises:
            AssertionError: If `axes`/`mode` are incomplete and `dynamic_axes`
                is missing, or if an integer mode is used without a shape.
                Raised explicitly (not via `assert`) so the check survives `-O`.
            ValueError: Propagated from `PartitionAxis.resolve_spec`.
        """
        if axes is NOT_GIVEN or mode is NOT_GIVEN:
            if dynamic_axes is NOT_GIVEN or dynamic_axes is None:
                raise AssertionError(
                    "if axes or mode is empty you should provide dynamic axes"
                )
            axes = dynamic_axes.axes
            mode = dynamic_axes.mode
        if isinstance(mode, int):
            if shape is NOT_GIVEN or shape is None:
                raise AssertionError(
                    "when using dynamic mode detection shape should be provided"
                )
            mode = MODE_DECODE if shape[mode] == 1 else MODE_TRAIN
        return self.paxis.resolve_spec(axes, mode)

    def __str__(self):
        """String representation of the PartitionManager."""
        return "PartitionManager(...)"

    def __repr__(self):
        """Representation of the PartitionManager."""
        return "PartitionManager(...)"

    __hash__ = hash_fn
def apply_logical_sharding(
    x: jax.Array,
    partition_manager: PartitionManager,
    axes: tp.Sequence[tp.Optional[str]] = NOT_GIVEN,
    mode: RUNTIME_MODE_TYPES | int = NOT_GIVEN,  # type:ignore
    dynamic_axes: tp.Optional[DynamicShardingAxes] = NOT_GIVEN,
    auto_correct: bool = True,
) -> jax.Array:
    """
    Apply logical sharding to a JAX array using the given PartitionManager.

    A thin functional wrapper that forwards all arguments unchanged to
    `partition_manager.shard`. No manager lookup or fallback is performed;
    the caller must supply the manager explicitly.

    Args:
        x: The JAX array to apply sharding to.
        partition_manager: The `PartitionManager` whose `shard` method is used.
        axes: A sequence of semantic axis name strings or None. Required
            unless `dynamic_axes` is provided.
        mode: The runtime mode, or an integer dimension index used to infer it.
            Required unless `dynamic_axes` is provided.
        dynamic_axes: An optional `DynamicShardingAxes` tuple supplying `axes`
            and `mode`. It is used when either of those arguments is NOT_GIVEN.
        auto_correct: If True, automatically corrects the resolved PartitionSpec.
            Defaults to True.

    Returns:
        Whatever `partition_manager.shard` returns: the array with sharding
        constraints applied.

    Raises:
        Any exception raised by `partition_manager.shard`.
    """
    return partition_manager.shard(
        x=x,
        axes=axes,
        mode=mode,
        dynamic_axes=dynamic_axes,
        auto_correct=auto_correct,
    )