easydel.utils.checkpoint_managers.streamer
- class easydel.utils.checkpoint_managers.streamer.CheckpointManager(checkpoint_dir: Union[str, os.PathLike], enable: Optional[bool] = None, float_dtype: numpy.dtype = jax.numpy.bfloat16, save_optimizer_state: bool = True, verbose: bool = False)
Bases: object
A class to manage saving and loading checkpoints.
- Parameters
checkpoint_dir – The directory to save checkpoints to.
enable – Whether to enable saving and loading checkpoints.
float_dtype – The floating-point data type to use for saving checkpoints.
save_optimizer_state – Whether to save the optimizer state in the checkpoint.
verbose – Whether to print verbose output.
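The `float_dtype` option suggests that floating-point arrays are cast before they are written. The sketch below assumes that only floating-point leaves are cast, while integer leaves (such as a step counter) are left untouched; `cast_floats` is a hypothetical helper, not part of EasyDeL. NumPy has no built-in bfloat16 (JAX gets it from the `ml_dtypes` package), so the example uses `float16` as a stand-in.

```python
import numpy as np


def cast_floats(flat_state, float_dtype=np.float16):
    """Hypothetical helper: cast floating-point arrays to float_dtype, keep other dtypes as-is."""
    out = {}
    for key, value in flat_state.items():
        arr = np.asarray(value)
        if np.issubdtype(arr.dtype, np.floating):
            out[key] = arr.astype(float_dtype)  # e.g. float32 weights -> float16
        else:
            out[key] = arr  # integer / bool leaves are not touched
    return out


state = {
    "dense.kernel": np.ones((2, 2), dtype=np.float32),
    "step": np.array(10, dtype=np.int32),
}
casted = cast_floats(state)
print(casted["dense.kernel"].dtype, casted["step"].dtype)  # float16 int32
```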
- static load_checkpoint(path: Union[str, PathLike], shard_fns: Optional[dict[Callable]] = None, verbose: bool = False, mismatch_allowed: bool = True, callback: Optional[Callable[[Array, str], Array]] = None, dtype: Optional[Union[str, dtype]] = None) → Tuple[Union[PyTreeNode, dict], dict]
Load a checkpoint from the given path.
- Parameters
path – The path to the checkpoint file.
shard_fns – A dictionary of functions to shard the state after loading.
verbose – Whether to print verbose output.
mismatch_allowed – Whether to allow mismatches between the state dictionary and shard functions.
callback – Optional callback called with each loaded array and its key; it returns the (possibly modified) array.
- Returns
A tuple containing the loaded state dictionary and metadata.
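A minimal sketch of how the per-key hooks of `load_checkpoint` could fit together, assuming a flat `{key: array}` state, that `callback` runs before the matching shard function, and that `mismatch_allowed=False` turns a missing shard function into an error. `apply_load_hooks` is hypothetical; the ordering and error behaviour are assumptions, not EasyDeL's confirmed implementation.

```python
from typing import Callable, Dict, Optional

import numpy as np


def apply_load_hooks(
    flat_state: Dict[str, np.ndarray],
    shard_fns: Optional[Dict[str, Callable[[np.ndarray], np.ndarray]]] = None,
    callback: Optional[Callable[[np.ndarray, str], np.ndarray]] = None,
    mismatch_allowed: bool = True,
) -> Dict[str, np.ndarray]:
    """Hypothetical: run callback(array, key), then the key's shard function."""
    out = {}
    for key, array in flat_state.items():
        if callback is not None:
            array = callback(array, key)  # same (array, key) -> array shape as documented
        if shard_fns is not None:
            fn = shard_fns.get(key)
            if fn is None:
                # Assumption: a missing shard function is only an error when mismatches are disallowed.
                if not mismatch_allowed:
                    raise KeyError(f"no shard function for {key!r}")
            else:
                array = fn(array)
        out[key] = array
    return out


loaded = {"w": np.arange(4.0), "b": np.zeros(2)}
double_w = lambda a, k: a * 2 if k == "w" else a
res = apply_load_hooks(loaded, shard_fns={"w": lambda a: a.reshape(2, 2)}, callback=double_w)
print(res["w"].shape)  # (2, 2)
```

In real use a shard function would typically place the array on devices with a sharding (for example via `jax.device_put`); the reshape here only stands in for that step.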
- static save_checkpoint(state: PyTreeNode, path: Union[str, PathLike], gather_fns: Optional[Union[dict[Callable], bool]] = None, float_dtype: Optional[Union[str, dtype]] = None, verbose: bool = True, mismatch_allowed: bool = True, metadata: Optional[dict[str, str]] = None, enable: Optional[bool] = None) → Union[str, PathLike]
Save a checkpoint to the given path using SafeTensors.
- Parameters
state – The state dictionary to save.
path – The path to the checkpoint file.
gather_fns – A dictionary of functions to gather the state before saving.
float_dtype – The floating-point data type to use for saving the checkpoint.
verbose – Whether to print verbose output.
mismatch_allowed – Whether to allow mismatches between the state dictionary and gather functions.
metadata – Additional metadata to store in the checkpoint.
enable – Whether the checkpointer is enabled to save the file.
- Returns
The path the data was saved to.
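SafeTensors stores a flat mapping from string names to tensors, and its header metadata is a string-to-string map, which matches the `metadata: Optional[dict[str, str]]` annotation. The sketch below shows one way a nested state could be flattened into such names and how metadata could be validated up front; the `.` separator and both helper names are assumptions, not EasyDeL's key format.

```python
from typing import Any, Dict, Mapping, Optional


def flatten_state(tree: Mapping[str, Any], sep: str = ".", prefix: str = "") -> Dict[str, Any]:
    """Hypothetical: turn {"a": {"b": x}} into {"a.b": x}."""
    flat: Dict[str, Any] = {}
    for key, value in tree.items():
        name = f"{prefix}{sep}{key}" if prefix else str(key)
        if isinstance(value, Mapping):
            flat.update(flatten_state(value, sep, name))  # recurse into sub-trees
        else:
            flat[name] = value  # leaf: store under its joined path
    return flat


def check_metadata(metadata: Optional[Dict[str, str]]) -> Dict[str, str]:
    """Hypothetical: reject metadata that is not a str -> str mapping."""
    metadata = dict(metadata or {})
    for k, v in metadata.items():
        if not isinstance(k, str) or not isinstance(v, str):
            raise TypeError(f"metadata entries must be str -> str, got {k!r}: {v!r}")
    return metadata


state = {"params": {"dense": {"kernel": [[1.0]], "bias": [0.0]}}, "step": 3}
print(sorted(flatten_state(state)))  # ['params.dense.bias', 'params.dense.kernel', 'step']
```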
- save_pickle(obj: object, filename: Union[str, PathLike])
Save an object to a pickle file.
- Parameters
obj – The object to save.
filename – The filename to save the object to.
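A standalone, stdlib-only equivalent of what `save_pickle` describes, plus a matching hypothetical `load_pickle` for the round trip. This is an illustration, not EasyDeL's source. Unpickling data from untrusted sources is unsafe, so pickled checkpoints should only be loaded from locations you control.

```python
import os
import pickle
import tempfile
from typing import Union


def save_pickle(obj: object, filename: Union[str, os.PathLike]) -> None:
    """Serialize obj to filename in binary mode."""
    with open(filename, "wb") as f:
        pickle.dump(obj, f)


def load_pickle(filename: Union[str, os.PathLike]) -> object:
    """Hypothetical counterpart: read the object back."""
    with open(filename, "rb") as f:
        return pickle.load(f)


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "opt_state.pkl")
    save_pickle({"step": 5, "lr": 1e-3}, path)
    restored = load_pickle(path)

print(restored)  # {'step': 5, 'lr': 0.001}
```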
- static save_state_to_file(state: PyTreeNode, path: Union[str, PathLike], gather_fns: Optional[dict[Callable]] = None, float_dtype: Optional[Union[str, dtype]] = None, verbose: bool = False, mismatch_allowed: bool = True)
Save the state dictionary to a file.
- Parameters
state – The state dictionary to save.
path – The path to the file to save the state dictionary to.
gather_fns – A dictionary of functions to gather the state before saving.
float_dtype – The floating-point data type to use for saving the state dictionary.
verbose – Whether to print verbose output.
mismatch_allowed – Whether to allow mismatches between the state dictionary and gather functions.
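The `gather_fns` and `mismatch_allowed` parameters can be read as: apply a per-key gather function (for sharded JAX arrays, something that brings the data to host memory, such as `jax.device_get`) and decide what happens to keys without one. This sketch assumes a flat state, passes unmatched keys through when mismatches are allowed, and otherwise reports every missing key at once; `gather_state` is hypothetical and `np.asarray` stands in for a real gather function.

```python
from typing import Any, Callable, Dict, Optional

import numpy as np


def gather_state(
    flat_state: Dict[str, Any],
    gather_fns: Optional[Dict[str, Callable[[Any], Any]]] = None,
    mismatch_allowed: bool = True,
) -> Dict[str, Any]:
    """Hypothetical: apply gather_fns[key] to each leaf before saving."""
    if gather_fns is None:
        return dict(flat_state)  # nothing to gather
    missing = [k for k in flat_state if k not in gather_fns]
    if missing and not mismatch_allowed:
        # Report all unmatched keys in one error instead of failing on the first.
        raise KeyError(f"no gather function for: {missing}")
    return {k: gather_fns[k](v) if k in gather_fns else v for k, v in flat_state.items()}


state = {"w": [1.0, 2.0], "step": 7}
out = gather_state(state, gather_fns={"w": np.asarray})
print(type(out["w"]).__name__, out["step"])  # ndarray 7
```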