Source code for easydel.utils.checkpoint_managers.streamer

# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import typing as tp
from functools import partial

import jax
import jax.experimental
import jax.experimental.multihost_utils
import jax.numpy as jnp
import msgpack
import safetensors
from flax.serialization import to_bytes, to_state_dict
from flax.struct import PyTreeNode
from tqdm import tqdm
from eformer.jaximus import implicit
from easydel.utils.helpers import get_logger

from ..traversals import flatten_dict, is_flatten, unflatten_dict

# Module-level logger obtained from the project's logging helper.
logger = get_logger(__name__)

# Integer and floating-point jax.numpy dtypes (signed/unsigned ints from 4 to
# 64 bits, float16/32/64, bfloat16 and the default float alias).
# NOTE(review): not referenced anywhere in the visible part of this module;
# presumably the set of dtypes accepted for checkpointed arrays — confirm
# against callers before relying on it.
ALLOWED_DATA_TYPES = [
	jnp.int4,
	jnp.int8,
	jnp.int16,
	jnp.int32,
	jnp.int64,
	jnp.uint4,
	jnp.uint8,
	jnp.uint16,
	jnp.uint32,
	jnp.uint64,
	jnp.float16,
	jnp.float32,
	jnp.float64,
	jnp.bfloat16,
	jnp.float_,
]


@implicit
def put_dtype(
	array: jax.Array,
	dtype: tp.Optional[tp.Union[str, jnp.dtype]],
) -> jax.Array:
	"""
	Cast a floating-point array to the requested data type.

	Args:
		array: The input tensor.
		dtype: The desired data type, either a dtype object or one of the string
			aliases listed in the table below (for example ``"bf16"``,
			``"float32"`` or ``"fp8_e5m2"``). ``None`` or an empty string leaves
			the array untouched.

	Returns:
		``array`` cast to ``dtype`` when its current dtype is bfloat16, float16,
		float32 or float64; otherwise ``array`` itself, unchanged.

	Raises:
		ValueError: If ``dtype`` is a string that is not a known alias.
	"""
	# Test against None explicitly instead of relying on truthiness: dtype
	# objects are not guaranteed to be truthy, so `not dtype` could silently
	# skip a requested cast.
	if dtype is None or (isinstance(dtype, str) and not dtype):
		return array

	if isinstance(dtype, str):
		dtype_map = {
			"bf16": jnp.bfloat16,
			"bfloat16": jnp.bfloat16,
			"fp16": jnp.float16,
			"float16": jnp.float16,
			"fp32": jnp.float32,
			"float32": jnp.float32,
			"fp64": jnp.float64,
			"float64": jnp.float64,
			"fp8": jnp.float8_e5m2,
			"fp8_e4m3fn": jnp.float8_e4m3fn,
			"fp8_e4m3fnuz": jnp.float8_e4m3fnuz,
			"fp8_e4m3b11fnuz": jnp.float8_e4m3b11fnuz,
			"fp8_e5m2": jnp.float8_e5m2,
			"fp8_e5m2fnuz": jnp.float8_e5m2fnuz,
			"float8_e4m3fn": jnp.float8_e4m3fn,
			"float8_e4m3fnuz": jnp.float8_e4m3fnuz,
			"float8_e4m3b11fnuz": jnp.float8_e4m3b11fnuz,
			"float8_e5m2": jnp.float8_e5m2,
			"float8_e5m2fnuz": jnp.float8_e5m2fnuz,
		}
		# The lookup lives inside the try so an unknown alias surfaces as the
		# documented ValueError rather than a bare KeyError.
		try:
			dtype = dtype_map[dtype]
		except KeyError as e:
			raise ValueError(f"Unsupported dtype string: {dtype}") from e

	if array.dtype in (jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64):
		return array.astype(dtype)
	return array  # Return original array if it's not a float


def _read_process_array(
	key,
	shard_fns,
	mismatch_allowed,
	manager,
	callback: tp.Optional[tp.Callable[[jax.Array, str], jax.Array]] = None,
	dtype: tp.Optional[tp.Union[str, jnp.dtype]] = None,
):
	"""
	Read one tensor from an open checkpoint and post-process it.

	The tensor is fetched with ``manager.get_tensor(key)``, passed through the
	shard function registered under ``key`` (if any), then through ``callback``
	(if given) and finally through ``put_dtype``.

	Args:
		key: Name of the tensor inside the checkpoint.
		shard_fns: Mapping from tensor name to shard function; falsy to skip sharding.
		mismatch_allowed: When false, a missing shard function raises ``KeyError``.
		manager: Object exposing ``get_tensor(key)``.
		callback: Optional function called as ``callback(tensor, key)``.
		dtype: Optional target dtype forwarded to ``put_dtype``.

	Returns:
		A ``(key, tensor, mismatch)`` tuple where ``mismatch`` is 1 if no shard
		function could be applied for ``key`` and 0 otherwise.
	"""
	loaded = manager.get_tensor(key)
	missed = 0

	if shard_fns:
		try:
			shard_fn = shard_fns.get(key)
			if shard_fn is not None:
				loaded = shard_fn(loaded)
			elif mismatch_allowed:
				missed = 1
			else:
				raise KeyError(
					f"Shard Function {key} is None and NoneType OBJ is not callable."
				)
		except KeyError as k_err:
			# Also reached when the shard function itself raises KeyError.
			if not mismatch_allowed:
				raise KeyError(k_err) from None
			missed = 1

	if callback:
		loaded = callback(loaded, key)
	return key, put_dtype(loaded, dtype), missed


class CheckpointManager:
	"""
	A class to manage saving and loading checkpoints.

	Args:
		checkpoint_dir: The directory to save checkpoints to.
		enable: Whether to enable saving and loading checkpoints.
		float_dtype: The floating-point data type to use for saving checkpoints.
		save_optimizer_state: Whether to save the optimizer state in the checkpoint.
		verbose: Whether to print verbose output.
	"""

	def __init__(
		self,
		checkpoint_dir: tp.Union[str, os.PathLike],
		enable: tp.Optional[bool] = None,
		float_dtype: jnp.dtype = jnp.bfloat16,
		save_optimizer_state: bool = True,
		verbose: bool = False,
	):
		"""Store the configuration on the instance; no I/O happens here."""
		self.float_dtype = float_dtype
		self.save_optimizer_state = save_optimizer_state
		self.checkpoint_dir = checkpoint_dir
		self.enable = enable
		self.verbose = verbose

	@staticmethod
	def load_checkpoint(
		path: tp.Union[str, os.PathLike],
		shard_fns: tp.Optional[dict[tp.Callable]] = None,
		verbose: bool = False,
		mismatch_allowed: bool = True,
		callback: tp.Optional[tp.Callable[[jax.Array, str], jax.Array]] = None,
		dtype: tp.Optional[tp.Union[str, jnp.dtype]] = None,
	) -> tp.Tuple[tp.Union[PyTreeNode, dict], dict]:
		"""
		Load a checkpoint from the given path.

		Args:
			path: The path to the checkpoint file.
			shard_fns: A dictionary of functions to shard the state after loading.
			verbose: Whether to print verbose output.
			mismatch_allowed: Whether to allow mismatches between the state dictionary
				and shard functions.
			callback: Optional callback applied to each loaded tensor, called as
				``callback(tensor, key)``.
			dtype: Optional dtype (object or string alias) that floating-point tensors
				are cast to after loading.

		Returns:
			A tuple containing the loaded state dictionary and metadata.
		"""
		with safetensors.safe_open(path, framework="flax") as f:
			metadata = f.metadata()
			keys = list(f.keys())

			if shard_fns and not is_flatten(shard_fns):
				shard_fns = flatten_dict(shard_fns, sep=".")

			process_func = partial(
				_read_process_array,
				shard_fns=shard_fns,
				mismatch_allowed=mismatch_allowed,
				manager=f,
				callback=callback,
				dtype=dtype,
			)
			# Tensors are read through `f`, so this must run while the file is open.
			results = [
				process_func(key)
				for key in tqdm(
					keys,
					desc="Loading",
					total=len(keys),
					disable=not verbose,
				)
			]

		state = {key: tensor for key, tensor, _ in results}
		mismatch_count = sum(mismatch for _, _, mismatch in results)

		if verbose and mismatch_count:
			logger.info(f"Sharding mismatch: {mismatch_count}")

		state = unflatten_dict(state, sep=".")
		return state, metadata

	@staticmethod
	def save_checkpoint(
		state: PyTreeNode,
		path: tp.Union[str, os.PathLike],
		gather_fns: tp.Optional[tp.Union[dict[tp.Callable], bool]] = None,
		float_dtype: tp.Optional[tp.Union[str, jnp.dtype]] = None,
		verbose: bool = True,
		mismatch_allowed: bool = True,
		metadata: tp.Optional[dict[str, str]] = None,
		enable: tp.Optional[bool] = None,
	) -> tp.Union[str, os.PathLike]:
		"""
		Save a checkpoint to the given path using SafeTensors.

		Args:
			state: The state dictionary to save.
			path: The path to the checkpoint file.
			gather_fns: A dictionary of functions to gather the state before saving,
				or ``True`` to fetch every entry with ``jax.device_get``.
			float_dtype: The floating-point data type to use for saving the checkpoint.
			verbose: Whether to print verbose output.
			mismatch_allowed: Whether to allow mismatches between the state dictionary
				and gather functions.
			metadata: Additional metadata to store in the checkpoint.
			enable: Whether this call actually writes the file. ``None`` means
				"only on the process whose ``jax.process_index()`` is 0".

		Returns:
			The path the data was written to (the null device when disabled).

		Raises:
			KeyError: If a gather function is missing and ``mismatch_allowed`` is false.
		"""
		# Imported explicitly: `import safetensors` alone does not guarantee that
		# the `safetensors.flax` submodule has been loaded.
		from safetensors.flax import save_file

		if enable is None:
			enable = jax.process_index() == 0
		if not enable:
			# The gather/cast steps below still run; only the output is discarded.
			path = os.devnull

		state = to_state_dict(state)
		gather_mismatch_count = 0
		if not is_flatten(state):
			state = flatten_dict(state, sep=".")

		if gather_fns:
			# Iterating the tqdm object advances the bar by itself, so no manual
			# `update(1)` is issued (that would count every key twice).
			pbar_gather = tqdm(
				list(state.keys()),
				desc="Gathering State",
				disable=not verbose,
			)
			if isinstance(gather_fns, bool):
				for key in pbar_gather:
					state[key] = jax.device_get(state[key])
			else:
				if not is_flatten(gather_fns):
					gather_fns = flatten_dict(gather_fns, sep=".")
				for key in pbar_gather:
					callable_func = gather_fns.get(key, None)
					if callable_func is None:
						if not mismatch_allowed:
							raise KeyError(f"Gather Function {key} missing.")
						gather_mismatch_count += 1
					else:
						state[key] = callable_func(state[key])
					pbar_gather.set_postfix(gather_mismatch=gather_mismatch_count)

		# Drop None leaves, turn non-jax values into host arrays and apply the
		# optional float cast.
		state = {
			key: put_dtype(
				jax.device_get(jnp.array(value))
				if not isinstance(value, jax.Array)
				else value,
				float_dtype,
			)
			for key, value in state.items()
			if value is not None
		}

		save_file(tensors=state, filename=path, metadata=metadata)
		return path

	@staticmethod
	def save_state_to_file(
		state: PyTreeNode,
		path: tp.Union[str, os.PathLike],
		gather_fns: tp.Optional[dict[tp.Callable]] = None,
		float_dtype: tp.Optional[tp.Union[str, jnp.dtype]] = None,
		verbose: bool = False,
		mismatch_allowed: bool = True,
	):
		"""
		Save the state dictionary to a file.

		Each flattened ``(key, value)`` pair is written as one msgpack record whose
		payload is ``to_bytes(value)``.

		Args:
			state: The state dictionary to save.
			path: The path to the file to save the state dictionary to.
			gather_fns: A dictionary of functions to gather the state before saving.
			float_dtype: The floating-point data type to use for saving the state dictionary.
			verbose: Whether to print verbose output.
			mismatch_allowed: Whether to allow mismatches between the state dictionary
				and gather functions.

		Raises:
			KeyError: If a gather function is missing and ``mismatch_allowed`` is false.
		"""
		state = to_state_dict(state)
		packer = msgpack.Packer()
		flatten_state = flatten_dict(state)
		if gather_fns:
			gather_fns = flatten_dict(gather_fns)

		pbar = tqdm(
			flatten_state.items(),
			disable=not verbose,
			desc="Saving State to File",
		)

		gather_mismatch_count = 0
		with open(path, "wb") as stream:
			for key, value in pbar:
				if gather_fns:
					try:
						callable_func = gather_fns.get(key)
						if callable_func is None:
							if not mismatch_allowed:
								raise KeyError(
									f"Gather Function {key} is None and NoneType OBJ is not callable."
								)
							gather_mismatch_count += 1
						else:
							value = callable_func(value)
					except KeyError as k_err:
						# Also reached when the gather function itself raises KeyError.
						if not mismatch_allowed:
							raise KeyError(k_err) from None
						gather_mismatch_count += 1
				pbar.set_postfix(gather_mismatch=gather_mismatch_count)
				value = put_dtype(value, float_dtype)
				stream.write(packer.pack((key, to_bytes(value))))

	def save_pickle(self, obj: object, filename: tp.Union[str, os.PathLike]):
		"""
		Save an object to a pickle file.

		The file is written to ``checkpoint_dir/filename`` when ``self.enable`` is
		truthy; otherwise the pickle stream is sent to the null device.

		Args:
			obj: The object to save.
			filename: The filename to save the object to.
		"""
		import pickle

		# NOTE(review): `enable=None` (the constructor default) counts as disabled
		# here, whereas `save_checkpoint` resolves None to "process 0 only" —
		# confirm which behaviour is intended.
		if self.enable:
			path = os.path.join(self.checkpoint_dir, filename)
		else:
			path = os.devnull
		with open(path, "wb") as stream:
			pickle.dump(obj, stream)