easydel.utils.checkpoint_managers.__init__

class easydel.utils.checkpoint_managers.__init__.CheckpointManager(checkpoint_dir: Union[str, os.PathLike], enable: Optional[bool] = None, float_dtype: numpy.dtype = jax.numpy.bfloat16, save_optimizer_state: bool = True, verbose: bool = False)

Bases: object

A class to manage saving and loading checkpoints.

Parameters
  • checkpoint_dir – The directory to save checkpoints to.

  • enable – Whether to enable saving and loading checkpoints.

  • float_dtype – The floating-point data type to use for saving checkpoints.

  • save_optimizer_state – Whether to save the optimizer state in the checkpoint.

  • verbose – Whether to print verbose output.
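As an illustration of what float_dtype does to a saved state, the following sketch casts only the floating-point leaves of a nested dict and leaves integer leaves, such as a step counter, untouched. It assumes float_dtype behaves this way; `cast_floats` is a hypothetical helper, and NumPy's `float16` stands in for `jax.numpy.bfloat16` so the example runs without JAX.

```python
import numpy as np


def cast_floats(tree: dict, float_dtype=np.float16) -> dict:
    """Recursively cast floating-point arrays in a nested dict to float_dtype.

    Hypothetical helper: assumes float_dtype applies only to floating-point
    leaves, so integer leaves (e.g. step counters) keep their dtype.
    """
    out = {}
    for key, value in tree.items():
        if isinstance(value, dict):
            out[key] = cast_floats(value, float_dtype)  # descend into sub-trees
        elif np.issubdtype(np.asarray(value).dtype, np.floating):
            out[key] = np.asarray(value).astype(float_dtype)  # lower precision copy
        else:
            out[key] = value  # non-float leaves are kept as-is
    return out


state = {
    "params": {"dense": {"kernel": np.ones((2, 3), dtype=np.float32)}},
    "step": np.array(100, dtype=np.int32),
}
saved = cast_floats(state)
print(saved["params"]["dense"]["kernel"].dtype, saved["step"].dtype)  # float16 int32
```

Casting float32 weights to a 16-bit type halves their size on disk, at the cost of precision.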

static load_checkpoint(path: Union[str, PathLike], shard_fns: Optional[dict[Callable]] = None, verbose: bool = False, mismatch_allowed: bool = True, callback: Optional[Callable[[Array, str], Array]] = None, dtype: Optional[Union[str, dtype]] = None) → Tuple[Union[PyTreeNode, dict], dict]

Load a checkpoint from the given path.

Parameters
  • path – The path to the checkpoint file.

  • shard_fns – A dictionary of functions to shard the state after loading.

  • verbose – Whether to print verbose output.

  • mismatch_allowed – Whether to allow mismatches between the state dictionary and shard functions.

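  • callback – Optional function applied to each loaded tensor; it is called with the tensor and its key and returns the tensor to keep.

  • dtype – Optional data type to cast the loaded tensors to.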

Returns

A tuple containing the loaded state dictionary and metadata.
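The per-tensor post-processing described by load_checkpoint's parameters can be sketched on a flat dict of already-read tensors. This sketch assumes dotted keys such as `params.dense.kernel`, a nested shard_fns dict that mirrors the state, that mismatch_allowed decides whether a tensor without a shard function is kept as-is or raises KeyError, and that dtype is applied only to floating-point tensors; the order of the steps is this sketch's own choice. `restore_state`, `flatten` and `unflatten` are hypothetical names, and reading the SafeTensors file itself is omitted.

```python
from typing import Callable, Optional

import numpy as np


def flatten(tree: dict, prefix: str = "") -> dict:
    """Flatten a nested dict into {"a.b.c": leaf} form."""
    flat = {}
    for key, value in tree.items():
        path = f"{prefix}.{key}" if prefix else key
        if isinstance(value, dict):
            flat.update(flatten(value, path))
        else:
            flat[path] = value
    return flat


def unflatten(flat: dict) -> dict:
    """Rebuild a nested dict from dotted keys (inverse of flatten)."""
    tree: dict = {}
    for path, value in flat.items():
        *parents, leaf = path.split(".")
        node = tree
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return tree


def restore_state(
    tensors: dict,
    metadata: dict,
    shard_fns: Optional[dict] = None,
    mismatch_allowed: bool = True,
    callback: Optional[Callable[[np.ndarray, str], np.ndarray]] = None,
    dtype=None,
):
    """Hypothetical post-read step: shard, call back, cast, then unflatten."""
    flat_fns = flatten(shard_fns) if shard_fns is not None else None
    restored = {}
    for key, tensor in tensors.items():
        if flat_fns is not None:
            if key in flat_fns:
                tensor = flat_fns[key](tensor)  # place/shard this tensor
            elif not mismatch_allowed:
                raise KeyError(f"no shard function for {key!r}")
        if callback is not None:
            tensor = callback(tensor, key)  # hook receives tensor and key
        if dtype is not None and np.issubdtype(tensor.dtype, np.floating):
            tensor = tensor.astype(dtype)  # only floating tensors are cast
        restored[key] = tensor
    return unflatten(restored), metadata


seen_keys = []


def log_key(tensor, key):
    seen_keys.append(key)  # record the order in which keys are processed
    return tensor


tensors = {"params.dense.kernel": np.ones((2, 2), np.float32), "step": np.array(7)}
state, meta = restore_state(
    tensors,
    {"format": "demo"},
    shard_fns={"params": {"dense": {"kernel": lambda x: x * 2}}},  # no entry for "step"
    callback=log_key,
    dtype=np.float16,
)
```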

static save_checkpoint(state: PyTreeNode, path: Union[str, PathLike], gather_fns: Optional[Union[dict[Callable], bool]] = None, float_dtype: Optional[Union[str, dtype]] = None, verbose: bool = True, mismatch_allowed: bool = True, metadata: Optional[dict[str, str]] = None, enable: Optional[bool] = None) → Union[str, PathLike]

Save a checkpoint to the given path using SafeTensors.

Parameters
  • state – The state dictionary to save.

  • path – The path to the checkpoint file.

  • gather_fns – A dictionary of functions to gather the state before saving.

  • float_dtype – The floating-point data type to use for saving the checkpoint.

  • verbose – Whether to print verbose output.

  • mismatch_allowed – Whether to allow mismatches between the state dictionary and gather functions.

  • metadata – Additional metadata to store in the checkpoint.

  • enable – Whether the checkpointer is enabled to write the checkpoint file.

Returns

The path where the checkpoint was saved.
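To show what a SafeTensors checkpoint looks like on disk, here is a minimal hand-rolled writer of that layout: an 8-byte little-endian header length, a JSON header with one entry per tensor plus an optional `__metadata__` map, then the raw tensor bytes. It also shows why metadata is typed `dict[str, str]`: the format stores metadata values only as strings. `write_safetensors` and the dotted tensor names are hypothetical and this is not EasyDeL's code; in practice `safetensors.numpy.save_file` from the `safetensors` package writes this layout.

```python
import json
import os
import struct
import tempfile
from typing import Optional

import numpy as np

# SafeTensors dtype codes for the NumPy dtypes this sketch supports.
_DTYPE_CODES = {"float32": "F32", "float16": "F16", "int32": "I32", "int64": "I64"}


def write_safetensors(tensors: dict, path: str, metadata: Optional[dict] = None) -> str:
    """Write a flat {name: array} dict in the SafeTensors layout and return path.

    Layout: 8-byte little-endian header length, JSON header, raw tensor bytes.
    """
    header: dict = {}
    if metadata:
        # The format stores metadata as a string-to-string map.
        header["__metadata__"] = {str(k): str(v) for k, v in metadata.items()}
    chunks, offset = [], 0
    for name, array in tensors.items():
        array = np.asarray(array)
        # Tensor data is stored little-endian in C (row-major) order.
        data = array.astype(array.dtype.newbyteorder("<")).tobytes(order="C")
        header[name] = {
            "dtype": _DTYPE_CODES[array.dtype.name],
            "shape": list(array.shape),
            "data_offsets": [offset, offset + len(data)],  # relative to data start
        }
        chunks.append(data)
        offset += len(data)
    header_bytes = json.dumps(header).encode("utf-8")
    header_bytes += b" " * (-len(header_bytes) % 8)  # pad with spaces to 8-byte alignment
    with open(path, "wb") as f:
        f.write(struct.pack("<Q", len(header_bytes)))
        f.write(header_bytes)
        f.write(b"".join(chunks))
    return path


saved_path = write_safetensors(
    {
        "params.dense.kernel": np.arange(6, dtype=np.float32).reshape(2, 3),
        "step": np.array(3, dtype=np.int64),
    },
    os.path.join(tempfile.mkdtemp(), "checkpoint.safetensors"),
    metadata={"step": 3},
)
```

Putting the header before the data lets a reader look up one tensor's byte range without parsing the rest of the file.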

save_pickle(obj: object, filename: Union[str, PathLike])

Save an object to a pickle file.

Parameters
  • obj – The object to save.

  • filename – The filename to save the object to.
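A standard-library equivalent, assuming save_pickle simply writes obj with `pickle.dump`. The free function below and its `load_pickle` counterpart are hypothetical helpers for illustration, not part of the documented API.

```python
import os
import pickle
import tempfile
from typing import Union


def save_pickle(obj: object, filename: Union[str, os.PathLike]) -> None:
    """Serialize obj to filename with the standard pickle module."""
    with open(filename, "wb") as f:
        pickle.dump(obj, f)


def load_pickle(filename: Union[str, os.PathLike]) -> object:
    """Read the object back; only unpickle files from a trusted source."""
    with open(filename, "rb") as f:
        return pickle.load(f)


path = os.path.join(tempfile.mkdtemp(), "trainer_config.pkl")
save_pickle({"learning_rate": 3e-4, "total_steps": 1000}, path)
restored = load_pickle(path)
```

Unpickling can execute arbitrary code, so pickle files should only be loaded from trusted sources.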

static save_state_to_file(state: PyTreeNode, path: Union[str, PathLike], gather_fns: Optional[dict[Callable]] = None, float_dtype: Optional[Union[str, dtype]] = None, verbose: bool = False, mismatch_allowed: bool = True)

Save the state dictionary to a file.

Parameters
  • state – The state dictionary to save.

  • path – The path of the file the state dictionary is written to.

  • gather_fns – A dictionary of functions to gather the state before saving.

  • float_dtype – The floating-point data type to use for saving the state dictionary.

  • verbose – Whether to print verbose output.

  • mismatch_allowed – Whether to allow mismatches between the state dictionary and gather functions.
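One plausible reading of how gather_fns, float_dtype and mismatch_allowed interact before the state is written can be sketched as follows. It assumes a flat state keyed by dotted paths, that a key without a gather function counts as a tolerated mismatch when mismatch_allowed is True and raises KeyError otherwise, and that only floating-point values are cast. `gather_flat_state` is hypothetical, and `np.savez` is only a stand-in container for this sketch, not the format the library writes.

```python
import os
import tempfile
from typing import Optional

import numpy as np


def gather_flat_state(
    flat_state: dict,
    gather_fns: Optional[dict] = None,
    float_dtype=None,
    mismatch_allowed: bool = True,
):
    """Hypothetical pre-write step: gather each value, then cast floats.

    Returns the gathered {key: array} dict and the number of keys that had
    no gather function (tolerated only when mismatch_allowed is True).
    """
    gathered, mismatches = {}, 0
    for key, value in flat_state.items():
        if gather_fns is not None:
            fn = gather_fns.get(key)
            if fn is None:
                if not mismatch_allowed:
                    raise KeyError(f"no gather function for {key!r}")
                mismatches += 1  # tolerated: keep the value as-is
            else:
                value = fn(value)  # e.g. bring a sharded array to the host
        value = np.asarray(value)
        if float_dtype is not None and np.issubdtype(value.dtype, np.floating):
            value = value.astype(float_dtype)
        gathered[key] = value
    return gathered, mismatches


state = {"params.dense.kernel": np.ones((2, 2), np.float32), "opt_state.count": np.array(5)}
gathered, mismatches = gather_flat_state(
    state,
    gather_fns={"params.dense.kernel": np.asarray},  # no entry for opt_state.count
    float_dtype=np.float16,
)
path = os.path.join(tempfile.mkdtemp(), "state.npz")
np.savez(path, **gathered)  # stand-in container for this sketch only
```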