easydel.trainers.trainer.modeling_output
Training output data structures.
This module defines the output types returned by training operations, including the final state and associated metadata from training runs.
- class easydel.trainers.trainer.modeling_output.TrainerOutput(state: Any, mesh: jax._src.mesh.Mesh | None, checkpoint_manager: Any, gather_fns: Optional[Union[Any, Mapping[str, Callable], dict[str, Callable]]] = None, shard_fns: Optional[Union[Any, Mapping[str, Callable], dict[str, Callable]]] = None, last_save_file_name: str | None = None, checkpoint_path: str | None = None)
Bases: object
Output from a training run.
Contains the final model state and metadata from the training process: checkpointing information, plus the gather and shard functions used to move the state between the host and the devices.
- state
The final EasyDeLState after training completion.
Type: Any
- mesh
The JAX sharding mesh used during training, if any.
Type: jax._src.mesh.Mesh | None
- checkpoint_manager
Manager object for handling model checkpoints.
Type: Any
- gather_fns
Functions for gathering sharded parameters onto the host.
Type: Optional[Union[Any, Mapping[str, Callable], dict[str, Callable]]]
- shard_fns
Functions for sharding parameters across devices.
Type: Optional[Union[Any, Mapping[str, Callable], dict[str, Callable]]]
- last_save_file_name
Name of the most recently saved checkpoint file.
Type: str | None
- checkpoint_path
Full path to the checkpoint directory.
Type: str | None
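To make the field layout concrete, here is a minimal sketch that uses a plain Python dataclass as a stand-in. `TrainerOutputSketch` and the helper `apply_fns` are hypothetical names invented for illustration; the real `TrainerOutput` is an EasyDeL PyTree, and its `mesh` would be a JAX mesh rather than `None`. The sketch assumes the `Mapping[str, Callable]` form of `gather_fns` (one function per named parameter) and uses toy functions in place of real host gathering.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable, Mapping, Optional


# Hypothetical stand-in mirroring the documented fields and their order.
# It is NOT the EasyDeL class and has none of its PyTree machinery.
@dataclass(frozen=True)
class TrainerOutputSketch:
    state: Any
    mesh: Any  # would be a JAX Mesh (or None) in the real class
    checkpoint_manager: Any
    gather_fns: Optional[Mapping[str, Callable]] = None
    shard_fns: Optional[Mapping[str, Callable]] = None
    last_save_file_name: Optional[str] = None
    checkpoint_path: Optional[str] = None


def apply_fns(fns: Mapping[str, Callable], params: Mapping[str, Any]) -> dict[str, Any]:
    """Hypothetical helper: apply the function registered under each parameter name."""
    return {name: fns[name](value) for name, value in params.items()}


# Toy "gather" functions: they only convert tuples to lists, standing in
# for moving sharded arrays onto the host.
gather = {"w": list, "b": list}
out = TrainerOutputSketch(
    state={"w": (1.0, 2.0), "b": (0.5,)},
    mesh=None,
    checkpoint_manager=None,
    gather_fns=gather,
    checkpoint_path="/tmp/run/checkpoints",  # hypothetical path
)
host_params = apply_fns(out.gather_fns, out.state)
print(host_params)  # {'w': [1.0, 2.0], 'b': [0.5]}
```

Keying the functions by parameter name keeps the per-parameter logic next to the state it applies to; the documented type also allows a non-mapping (`Any`) form, which this sketch does not cover.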
- checkpoint_manager: Any
- classmethod from_dict(data: dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- gather_fns: Optional[Union[Any, Mapping[str, Callable], dict[str, Callable]]] = None
- mesh: jax._src.mesh.Mesh | None
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- shard_fns: Optional[Union[Any, Mapping[str, Callable], dict[str, Callable]]] = None
- state: Any
- to_dict() → dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
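The `replace`, `to_dict`/`from_dict`, and `to_json`/`from_json` methods suggest an immutable-update and round-trip contract. The sketch below illustrates that contract with a plain frozen dataclass. `CheckpointInfoSketch` is a hypothetical class, not EasyDeL code, and it keeps only the JSON-friendly checkpoint fields, because `state`, `mesh`, and the function mappings are not plain JSON values. Forwarding `to_json` keyword arguments to `json.dumps` is likewise an assumption.

```python
from __future__ import annotations

import dataclasses
import json
from typing import Any, Optional, TypeVar

T = TypeVar("T", bound="CheckpointInfoSketch")


@dataclasses.dataclass(frozen=True)
class CheckpointInfoSketch:
    """Hypothetical stand-in with only JSON-friendly fields (not EasyDeL's class)."""

    last_save_file_name: Optional[str] = None
    checkpoint_path: Optional[str] = None

    def replace(self, **kwargs: Any) -> CheckpointInfoSketch:
        # Build a new instance; the original object is left unchanged.
        return dataclasses.replace(self, **kwargs)

    def to_dict(self) -> dict[str, Any]:
        return dataclasses.asdict(self)

    @classmethod
    def from_dict(cls: type[T], data: dict[str, Any]) -> T:
        return cls(**data)

    def to_json(self, **kwargs: Any) -> str:
        # Assumption: extra kwargs (e.g. indent=2) go straight to json.dumps.
        return json.dumps(self.to_dict(), **kwargs)

    @classmethod
    def from_json(cls: type[T], json_str: str) -> T:
        return cls.from_dict(json.loads(json_str))


info = CheckpointInfoSketch(checkpoint_path="/tmp/run")  # hypothetical path
updated = info.replace(last_save_file_name="step_1000")
restored = CheckpointInfoSketch.from_json(updated.to_json())
print(restored == updated)  # True
```

Because `replace` returns a new object, `info` still has `last_save_file_name=None` after the update, which is the usual pattern for immutable state containers.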