# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import os
import pathlib
import pickle
import typing as tp
from functools import partial
import jax
import optax
from flax import nnx as nn
from flax import struct
from jax.sharding import NamedSharding, PartitionSpec
from safetensors.flax import load_file as safe_load_file
from safetensors.flax import save_file as safe_save_file
from easydel.utils.helpers import get_logger
from easydel.utils.traversals import specs_to_name_sharding
if tp.TYPE_CHECKING:
from jax.sharding import Mesh
from .base_module import EasyDeLBaseModule, PartitionLike
else:
EasyDeLBaseModule = tp.Any
PartitionLike = tp.Any
Mesh = tp.Any
WEIGHTS_NAME = "easydel-model.parameters"
OPTIMIZER_NAME = "easydel-optstate.parameters"
OPTIMIZER_STRUCT_NAME = "easydel-optstate.structure"
logger = get_logger(__name__)
[docs]class EasyDeLState(struct.PyTreeNode):
"""
**EasyDeLState A Snapshot of Your EasyDeL Model**
The `EasyDeLState` class acts like a comprehensive container that holds all the essential information about your EasyDeL
model at a given point in time. Think of it as a snapshot of your model. It includes
"""
step: int | jax.Array
graphdef: nn.GraphDef
graphstate: nn.GraphState
graphother: nn.GraphState
tx: optax.GradientTransformation = struct.field(pytree_node=False)
opt_state: tp.Optional[optax.OptState] = struct.field(pytree_node=True)
apply_fn: tp.Optional[tp.Callable] = None
[docs] def apply_gradients(self, *, grads):
"""
Applies gradients to the model parameters and updates the optimizer state.
This function is typically called during training to update the model based on the computed gradients.
Args:
grads: A dictionary of gradients, where keys correspond to model parameters.
Returns:
EasyDeLState: An updated EasyDeLState object with modified parameters and optimizer state.
"""
assert self.opt_state is not None
assert self.tx is not None
updates, new_opt_state = self.tx.update(
updates=grads,
state=self.opt_state,
params=self.graphstate,
)
if hasattr(self.tx, "apply_updates_hook"):
graphstate = self.tx.apply_updates_hook(self.graphstate, updates)
else:
graphstate = optax.apply_updates(self.graphstate, updates)
return self.replace(
step=self.step + 1,
graphstate=graphstate,
opt_state=new_opt_state,
)
[docs] @classmethod
def create(
cls,
*, # Force keyword arguments
step: tp.Optional[int] = None,
graphdef: tp.Optional[nn.GraphDef] = None,
graphstate: tp.Optional[nn.GraphState] = None,
graphother: tp.Optional[nn.GraphState] = None,
model: tp.Optional[nn.Module] = None,
tx: tp.Optional[optax.GradientTransformation] = None,
opt_state: tp.Optional[optax.OptState] = None,
init_opt_state: bool = False,
) -> EasyDeLState:
"""
Create an instance with flexible initialization options.
Args:
step: Optional number of training steps.
graphdef: Optional graph definition.
graphstate: Optional graph state.
graphother: Optional graph *others.
model: Optional neural network module.
tx: Optional gradient transformation.
opt_state: Optional optimizer state.
Raises:
ValueError: If initialization parameters are inconsistent.
"""
# Validate mutual exclusivity of model and graph-related parameters
graph_params_provided = (
graphdef is not None or graphstate is not None or graphother is not None
)
if model is not None and graph_params_provided:
raise ValueError(
"Cannot provide both a model and graph-related parameters. "
"Choose either model or (graphdef, graphstate)."
)
if model is not None:
graphdef, graphstate, graphother = nn.split(model, nn.Param, ...)
if graphdef is not None and graphstate is None and graphother is None:
raise ValueError(
"When providing graphdef, (graphstate, graphother) must also be provided.",
)
if graphstate is not None and graphdef is None and graphother is None:
raise ValueError(
"When providing graphstate, (graphdef, graphother) must also be provided.",
)
if graphother is not None and graphdef is None and graphstate is None:
raise ValueError(
"When providing graphother, (graphstate, graphdef) must also be provided.",
)
if init_opt_state and opt_state is not None:
raise ValueError(
"When passing `init_opt_state` as `True` you can't also provide `opt_state`"
)
if init_opt_state and tx is None:
raise ValueError(
"When passing `init_opt_state` as `True` you have to also provide `tx`."
)
if init_opt_state:
opt_state = tx.init(graphstate)
if step is None:
step = 0
return cls(
step=step,
graphdef=graphdef,
graphstate=graphstate,
graphother=graphother,
tx=tx,
opt_state=opt_state,
)
[docs] def init_tx(
self,
tx: optax.GradientTransformation,
partition_rules: PartitionLike = None,
) -> EasyDeLState:
"""
Initialize the optimizer state with the given gradient transformation.
Args:
tx (optax.GradientTransformation): A gradient transformation to initialize the optimizer state.
partition_rules (Optional[Any], optional): Rules for partitioning the optimizer state. Defaults to None.
Returns:
EasyDeLState: An updated EasyDeLState object with the new gradient transformation and sharded optimizer state.
"""
if partition_rules is None:
partition_rules = self.model.config.get_partition_rules()
from eformer.escale import match_partition_rules
eval_opt_state = jax.eval_shape(lambda: tx.init(self.graphstate))
partition_specs = match_partition_rules(partition_rules, eval_opt_state)
named_shardings = specs_to_name_sharding(partition_specs, self.model.mesh)
@partial(jax.jit, out_shardings=named_shardings)
def make():
return tx.init(self.graphstate)
opt_state = make()
return self.replace(tx=tx, opt_state=opt_state)
[docs] def shard_optimizer_state(
self,
opt_state: tp.Optional[tp.Any] = None,
partition_rules: PartitionLike = None,
) -> tp.Any:
"""
Shards the optimizer state according to the provided partition rules.
Args:
opt_state (Optional[Any]): The optimizer state to be sharded. If None, the method will use `self.opt_state`.
Raises a ValueError if both `opt_state` and `self.opt_state` are None.
partition_rules (Optional[Any]): The partition rules to be used for sharding. If None, the method will use
the partition rules from `self.model.config`.
Returns:
Any: The sharded optimizer state.
Raises:
ValueError: If both `opt_state` and `self.opt_state` are None.
"""
if opt_state is None and self.opt_state is None:
raise ValueError("Optimizer state is not initialized.")
if opt_state is None:
opt_state = self.opt_state
if partition_rules is None:
partition_rules = self.model.config.get_partition_rules()
from eformer.escale import make_shard_and_gather_fns, match_partition_rules
with self.model.mesh:
partition_specs = match_partition_rules(partition_rules, opt_state)
shard_fns, _ = make_shard_and_gather_fns(partition_specs)
opt_state = jax.tree_util.tree_map(
lambda f, o: f(o),
shard_fns,
opt_state,
)
return self.replace(opt_state=opt_state)
[docs] def gather_optimizer_state(self, partition_rules=None):
assert self.opt_state is not None, "Optimizer state is not initialized."
if partition_rules is None:
partition_rules = self.model.config.get_partition_rules()
from eformer.escale import make_shard_and_gather_fns, match_partition_rules
partition_specs = match_partition_rules(partition_rules, self.opt_state)
_, gather = make_shard_and_gather_fns(partition_specs)
self = self.replace(
opt_state=jax.tree_util.tree_map(
lambda f, o: f(o),
gather,
self.opt_state,
)
)
return self
[docs] def merge(self, tree) -> EasyDeLBaseModule:
return nn.merge(self.graphdef, tree, self.graphother)
[docs] def merge_to_state(self, tree) -> EasyDeLState:
return self.replace(graphstate=tree)
@property
def model(self) -> EasyDeLBaseModule:
return nn.merge(self.graphdef, self.graphstate, self.graphother)
@property
def size(self) -> int:
"""
Calculates the total size of the optimizer state and model graph state.
Returns:
int: The total size in bytes.
"""
def calculate_size(pytree):
if pytree is None:
return 0
leaves, _ = jax.tree_util.tree_flatten(pytree)
return sum(
leaf.size * leaf.itemsize
for leaf in leaves
if isinstance(leaf, jax.numpy.ndarray)
)
opt_state_size = calculate_size(self.opt_state)
graphstate_size = calculate_size(self.graphstate)
return opt_state_size + graphstate_size
[docs] def load_optimizer(self, load_directory: tp.Union[str, os.PathLike]):
load_directory = pathlib.Path(load_directory)
optim_path = load_directory / OPTIMIZER_NAME
struct_path = load_directory / OPTIMIZER_STRUCT_NAME
if not (optim_path.exists() and struct_path.exists()):
raise FileNotFoundError(f"Optimizer files missing in {load_directory}")
try:
# All processes load simultaneously
with open(struct_path, "rb") as f:
tdef = pickle.load(f)
tensors = safe_load_file(str(optim_path))
ordered_params = [tensors[f"param_{i}"] for i in range(len(tensors))]
sharded_params = [arr for arr in ordered_params]
opt_state = jax.tree_util.tree_unflatten(tdef, sharded_params)
logger.info(f"Optimizer state loaded from {load_directory}")
self = self.replace(opt_state=opt_state)
return self
except Exception as e:
logger.error(f"Optimizer load failed: {str(e)}")
raise e
[docs] def save_state(
self,
save_directory: tp.Union[str, os.PathLike],
float_dtype: tp.Optional[jax.numpy.dtype] = None,
verbose: bool = True,
mismatch_allowed: bool = True,
save_optimizer: bool = True,
enable: tp.Optional[bool] = None,
):
save_directory = pathlib.Path(save_directory)
save_directory = pathlib.Path(save_directory)
if save_optimizer:
if enable is None:
enable = jax.process_index() == 0
if enable:
save_directory.mkdir(parents=True, exist_ok=True)
optim_path = save_directory / OPTIMIZER_NAME
struct_path = save_directory / OPTIMIZER_STRUCT_NAME
else:
optim_path = pathlib.Path("/dev/null")
struct_path = pathlib.Path("/dev/null")
logger.info(f"Coordinated optimizer save through {optim_path}")
try:
tdef = jax.tree_util.tree_structure(self.opt_state)
with open(struct_path, "wb") as f:
pickle.dump(tdef, f)
@partial(
jax.jit,
out_shardings=NamedSharding(self.model.mesh, PartitionSpec()),
)
def gather_fn(x):
return x
tree = jax.tree_util.tree_leaves(self.opt_state)
gathered = {
f"param_{i}": jax.device_get(gather_fn(param)) for i, param in enumerate(tree)
}
safe_save_file(tensors=gathered, filename=str(optim_path))
except Exception as e:
logger.error(f"Optimizer save failed: {str(e)}")
raise
else:
logger.info("Skipping optimizer saving as requested")
self.model.save_pretrained(
save_directory=str(save_directory),
gather_fns=self.model._gather_fns,
float_dtype=float_dtype,
mismatch_allowed=mismatch_allowed,
verbose=verbose,
enable=enable,
)
self,
load_directory: tp.Union[str, os.PathLike],
verbose: bool = True,
): ...
[docs] def shard_with_shape(self, shape) -> EasyDeLState:
"""shard current state with a given shape"""
from eformer.escale import with_sharding_constraint
self = nn.from_tree(
jax.tree_util.tree_map(
lambda arr, sharding: with_sharding_constraint(
arr,
sharding,
),
nn.to_tree(self),
nn.to_tree(shape),
)
)
return self
[docs] def shard_state(self, partition_rules: PartitionLike = None) -> EasyDeLState:
"""
Shards the entire state, according to the provided partition rules.
Args:
partition_rules (Optional[Any]): The partition rules to be used for sharding. If None, the method will use
the partition rules from `self.model.config`.
Returns:
EasyDeLState: An updated EasyDeLState object with the sharded state.
"""
with self.model.mesh:
if self.opt_state is not None:
self = self.shard_optimizer_state(partition_rules=partition_rules)
self = self.shard_model(partition_rules=partition_rules)
return self
[docs] def gather_state(self):
"""
Gathers the entire state.
Returns:
EasyDeLState: An updated EasyDeLState object with the gathered state.
"""
if self.opt_state is not None:
self = self.gather_optimizer_state()
self = self.gather_model()
return self
[docs] def gather_model(
self,
partition_rules: PartitionLike = None,
mesh: tp.Optional[Mesh] = None,
) -> EasyDeLState:
"""
Gathers the model according to the provided partition rules.
Returns:
EasyDeLState: An updated EasyDeLState object with the gathered model.
"""
from eformer.escale import make_shard_and_gather_fns, match_partition_rules
rules = partition_rules or self.model._get_partition_rules(None)
mesh = mesh or self.model._get_mesh(None)
partition_specs = match_partition_rules(
rules=rules,
tree=self.graphstate,
)
_, gather_fns = make_shard_and_gather_fns(
partition_specs=partition_specs,
mesh=mesh,
)
graphstate = jax.tree_util.tree_map(
lambda f, o: f(o),
gather_fns,
self.graphstate,
)
graphother = jax.tree_util.tree_map(
lambda f, o: f(o),
gather_fns,
self.graphother,
)
self = self.replace(graphstate=graphstate, graphother=graphother)
return self
[docs] def shard_model(
self,
partition_rules: PartitionLike = None,
mesh: tp.Optional[Mesh] = None,
) -> EasyDeLState:
"""
Shards the model according to the provided partition rules.
Args:
partition_rules (Optional[Any]): The partition rules to be used for sharding. If None, the method will use
the partition rules from `self.model.config`.
mesh (Optional[Mesh]): The mesh to be used for sharding. If None, the method will use the mesh from `self.model`.
Returns:
EasyDeLState: An updated EasyDeLState object with the sharded model.
"""
rules = partition_rules or self.model._get_partition_rules(None)
mesh = mesh or self.model._get_mesh(None)
def appy_sharding_on_tree(tree):
from eformer.escale import make_shard_and_gather_fns, match_partition_rules
partition_specs = match_partition_rules(rules, tree)
shard_fns, _ = make_shard_and_gather_fns(partition_specs, mesh)
return jax.tree_util.tree_map(lambda f, o: f(o), shard_fns, tree)
graphstate = appy_sharding_on_tree(self.graphstate)
graphother = appy_sharding_on_tree(self.graphother)
self = self.replace(graphstate=graphstate, graphother=graphother)
return self
@property
def shardings(self):
"""
Returns the sharding information for the state.
Returns:
Any: The sharding information.
"""
return nn.from_tree(
jax.tree_util.tree_map(
lambda x: x.sharding if hasattr(x, "sharding") else None,
nn.to_tree(self),
)
)
def __repr__(self):
return "EasyDeLState-" + str(self.model)
__str__ = __repr__