easydel.trainers.trainer.modeling_output

Training output data structures.

This module defines the output types returned by training operations, including the final state and associated metadata from training runs.

class easydel.trainers.trainer.modeling_output.TrainerOutput(state: Any, mesh: jax._src.mesh.Mesh | None, checkpoint_manager: Any, gather_fns: Optional[Union[Any, Mapping[str, Callable], dict[str, Callable]]] = None, shard_fns: Optional[Union[Any, Mapping[str, Callable], dict[str, Callable]]] = None, last_save_file_name: str | None = None, checkpoint_path: str | None = None)

Bases: object

Output from a training run.

Contains the final model state and metadata from the training process. This includes checkpointing information and the functions used to gather and shard the state's parameters.
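To see which fields are required and which default to None, you can mirror the field layout with a plain frozen dataclass. `TrainerOutputSketch` below is a hypothetical stand-in, not EasyDeL's class. The real `state` is normally an EasyDeLState, and the real `mesh` is a JAX mesh. Here `mesh` is typed as `Any` and placeholder values are used, to avoid depending on JAX.

```python
from dataclasses import MISSING, dataclass, fields
from typing import Any, Callable, Mapping, Optional


# Hypothetical stand-in mirroring the documented TrainerOutput fields.
# Not the EasyDeL class; mesh is typed as Any to avoid a JAX dependency.
@dataclass(frozen=True)
class TrainerOutputSketch:
    state: Any
    mesh: Any
    checkpoint_manager: Any
    gather_fns: Optional[Mapping[str, Callable]] = None
    shard_fns: Optional[Mapping[str, Callable]] = None
    last_save_file_name: Optional[str] = None
    checkpoint_path: Optional[str] = None


# Required fields are the ones declared without a default value.
required = [f.name for f in fields(TrainerOutputSketch) if f.default is MISSING]
print(required)  # ['state', 'mesh', 'checkpoint_manager']

# Only the three required fields must be supplied; the rest fall back to None.
out = TrainerOutputSketch(state={"step": 100}, mesh=None, checkpoint_manager=None)
print(out.checkpoint_path)  # None
```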

state

The final EasyDeLState after training completion.

Type

Any

mesh

The JAX sharding mesh used during training, if any.

Type

jax._src.mesh.Mesh | None

checkpoint_manager

Manager object for handling model checkpoints.

Type

Any

gather_fns

Functions for gathering sharded parameters onto the host.

Type

Optional[Union[Any, Mapping[str, Callable], dict[str, Callable]]]

shard_fns

Functions for sharding parameters across devices.

Type

Optional[Union[Any, Mapping[str, Callable], dict[str, Callable]]]
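When `gather_fns` and `shard_fns` are mappings from parameter names to callables, using them amounts to applying the function registered under each name. The sketch below assumes a flat dictionary of parameters keyed by name. The helper `apply_fns` and the parameter names are hypothetical. The stand-in callables only tag and untag values; in real sharded training they would presumably move arrays between host and devices. The non-mapping (`Any`) form allowed by the type annotation is not covered here.

```python
from typing import Any, Callable, Mapping


def apply_fns(
    params: Mapping[str, Any], fns: Mapping[str, Callable[[Any], Any]]
) -> dict[str, Any]:
    """Hypothetical helper: apply the function registered for each parameter name."""
    missing = set(params) - set(fns)
    if missing:
        raise KeyError(f"no function registered for: {sorted(missing)}")
    return {name: fns[name](value) for name, value in params.items()}


# Stand-ins: "sharding" wraps a value, "gathering" unwraps it again.
# Real shard/gather functions would move arrays between host and devices (assumption).
shard_fns = {
    "dense/kernel": lambda v: ("sharded", v),
    "dense/bias": lambda v: ("sharded", v),
}
gather_fns = {
    "dense/kernel": lambda v: v[1],
    "dense/bias": lambda v: v[1],
}

params = {"dense/kernel": [[1.0, 2.0]], "dense/bias": [0.5]}
sharded = apply_fns(params, shard_fns)     # per-name "placement"
restored = apply_fns(sharded, gather_fns)  # per-name "gather back to host"
print(restored == params)  # True
```

Failing loudly on a parameter with no registered function is a design choice of this sketch. It avoids silently leaving part of the state unsharded.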

last_save_file_name

Name of the most recently saved checkpoint file.

Type

str | None

checkpoint_path

Full path to the checkpoint directory.

Type

str | None
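Both path fields may be None, so code that looks for the latest checkpoint should check them before combining. The sketch below assumes that `last_save_file_name` is relative to `checkpoint_path`. The descriptions above do not confirm this. The helper `latest_checkpoint` and the example paths are hypothetical.

```python
from pathlib import Path
from typing import Optional


def latest_checkpoint(
    checkpoint_path: Optional[str], last_save_file_name: Optional[str]
) -> Optional[Path]:
    """Hypothetical helper: locate the most recent checkpoint.

    Assumes the file name is relative to the checkpoint directory and
    returns None when either field is unset.
    """
    if checkpoint_path is None or last_save_file_name is None:
        return None
    return Path(checkpoint_path) / last_save_file_name


print(latest_checkpoint("/tmp/run-1", "checkpoint_100"))  # /tmp/run-1/checkpoint_100 (POSIX)
print(latest_checkpoint(None, "checkpoint_100"))          # None
```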

checkpoint_manager: Any
checkpoint_path: str | None = None
classmethod from_dict(data: dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.
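The two class methods can be read as "parse, then construct an instance of the calling class". This sketch mirrors the documented signatures on a hypothetical stand-in, `OutputSketch`. It assumes that `from_dict` maps dictionary keys directly onto constructor fields, which is not necessarily how the library does it. Only JSON-friendly fields are used, since values such as a device mesh have no obvious JSON form.

```python
import json
from dataclasses import dataclass
from typing import Any, Optional, Type, TypeVar

T = TypeVar("T", bound="OutputSketch")


@dataclass(frozen=True)
class OutputSketch:
    """Hypothetical stand-in with only JSON-friendly fields."""

    state: Any
    checkpoint_path: Optional[str] = None

    @classmethod
    def from_dict(cls: Type[T], data: dict[str, Any]) -> T:
        # Build an instance of whichever class this is called on (the "T").
        return cls(**data)

    @classmethod
    def from_json(cls: Type[T], json_str: str) -> T:
        # Parse the JSON text, then delegate to from_dict.
        return cls.from_dict(json.loads(json_str))


out = OutputSketch.from_json('{"state": {"step": 100}, "checkpoint_path": "/tmp/run-1"}')
print(out.state["step"], out.checkpoint_path)  # 100 /tmp/run-1
```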

gather_fns: Optional[Union[Any, Mapping[str, Callable], dict[str, Callable]]] = None
last_save_file_name: str | None = None
mesh: jax._src.mesh.Mesh | None
replace(**kwargs)

Creates a new instance with specified fields replaced.
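`replace` follows the copy-on-write pattern of immutable containers: it returns a new object and leaves the original untouched. The sketch shows that pattern with Python's `dataclasses.replace` on a hypothetical frozen stand-in, `OutputSketch`. Whether EasyDeL's `replace` is built on `dataclasses.replace` is not stated here.

```python
import dataclasses
from typing import Any, Optional


@dataclasses.dataclass(frozen=True)
class OutputSketch:
    """Hypothetical immutable stand-in used to illustrate replace()."""

    state: Any
    checkpoint_path: Optional[str] = None

    def replace(self, **kwargs) -> "OutputSketch":
        # Return a copy with the given fields swapped; self is not modified.
        return dataclasses.replace(self, **kwargs)


old = OutputSketch(state={"step": 100})
new = old.replace(checkpoint_path="/tmp/run-1")
print(old.checkpoint_path, new.checkpoint_path)  # None /tmp/run-1
print(new.state is old.state)  # True: unchanged fields are shared, not copied
```

Because unchanged fields are shared rather than copied, replacing a small metadata field is cheap even when `state` holds large arrays.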

shard_fns: Optional[Union[Any, Mapping[str, Callable], dict[str, Callable]]] = None
state: Any
to_dict() → dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.
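The serializers are the inverse of `from_dict` and `from_json`. This sketch uses a hypothetical stand-in, `OutputSketch`, with `to_dict` built on `dataclasses.asdict`. The documented `to_json(**kwargs)` signature does not say where its keyword arguments go. Forwarding them to `json.dumps` is an assumption of this sketch.

```python
import json
from dataclasses import asdict, dataclass
from typing import Any, Optional


@dataclass(frozen=True)
class OutputSketch:
    """Hypothetical stand-in with JSON-friendly fields only."""

    state: Any
    checkpoint_path: Optional[str] = None

    def to_dict(self) -> dict[str, Any]:
        # Field name -> value; nested dataclasses, lists and dicts are converted recursively.
        return asdict(self)

    def to_json(self, **kwargs) -> str:
        # Extra keyword arguments are forwarded to json.dumps (assumption), e.g. indent=2.
        return json.dumps(self.to_dict(), **kwargs)


out = OutputSketch(state={"step": 100}, checkpoint_path="/tmp/run-1")
print(out.to_json(sort_keys=True))
# {"checkpoint_path": "/tmp/run-1", "state": {"step": 100}}
```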