easydel.trainers.trainer_protocol

class easydel.trainers.trainer_protocol.BaseTrainerProtocol(arguments: TrainingArguments | None = None, model_state: EasyDeLState | None = None, model: tp.type[EasyDeLBaseModule] | None = None, dataset_train: Dataset | None = None, dataset_eval: Dataset | None = None, data_collator: tp.Callable | None = None, finetune: bool = True, **deprecated_kwargs)[source]#

Bases: object

Abstract base protocol defining the interface for all trainer implementations.

This protocol ensures that all trainer implementations provide the necessary methods and attributes for training and evaluation workflows. It defines the contract that concrete trainer classes must fulfill.

The protocol covers:

  • Initialization and configuration methods

  • Training and evaluation loops

  • Checkpoint management

  • Metrics logging and monitoring

  • Hook methods for customization

All methods marked with @abstractmethod must be implemented by subclasses.
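The effect of @abstractmethod can be seen in a minimal sketch that uses only the standard library. MiniTrainerProtocol, IncompleteTrainer, CompleteTrainer, and can_instantiate are hypothetical names made up for this illustration; they are not part of EasyDeL, and the sketch assumes the protocol relies on Python's standard abc machinery.

```python
from abc import ABC, abstractmethod


class MiniTrainerProtocol(ABC):
    """Hypothetical, cut-down stand-in for a trainer protocol."""

    @abstractmethod
    def train(self):
        """Run training and return a result object."""

    @abstractmethod
    def finish(self):
        """Finalize the training process."""


class IncompleteTrainer(MiniTrainerProtocol):
    # Only one of the two abstract methods is implemented.
    def train(self):
        return "done"


class CompleteTrainer(MiniTrainerProtocol):
    def train(self):
        return "done"

    def finish(self):
        return None


def can_instantiate(cls) -> bool:
    """Return True if cls() succeeds, False if abstract methods are missing."""
    try:
        cls()
    except TypeError:
        return False
    return True
```

Python refuses to instantiate IncompleteTrainer because finish is still abstract, which is how a missing method surfaces at construction time rather than in the middle of a training run.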

abstract apply_training_hooks(metrics: LossMetrics) LossMetrics[source]#

Apply training hooks to check for issues and enforce limits.

Parameters

metrics – Current training metrics including loss.

Returns

Potentially modified metrics.

Return type

LossMetrics

Raises

Note

Checks for NaN losses and training time limits based on configuration in training arguments.
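A check of this kind might look like the sketch below. The function name check_training_limits, its parameters, and the exception types are assumptions chosen for illustration; the source does not state which exceptions the real hook raises or how the time limit is configured.

```python
import math
import time


def check_training_limits(loss, start_time, time_limit_s=None, now=None):
    """Hypothetical check: fail on a NaN loss or an exhausted time budget.

    `now` can be injected for testing; by default the monotonic clock is used.
    """
    if math.isnan(loss):
        raise FloatingPointError("loss became NaN")
    now = time.monotonic() if now is None else now
    if time_limit_s is not None and now - start_time > time_limit_s:
        raise TimeoutError("training time limit exceeded")
    # Nothing wrong: hand the value back unchanged.
    return loss
```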

arguments: TrainingArguments#
abstract calculate_number_total_flops(params, is_training=True)[source]#

Calculate total FLOPs for the model.

checkpoint_manager: Any#
checkpoint_path: str | os.PathLike | None#
abstract compile_aot() bool[source]#

Compile the state ahead of time for faster execution.

config: EasyDeLBaseConfig#
abstract configure_dataloaders() TrainerConfigureDataloaderOutput[source]#

Configure dataloaders for training and evaluation.

Creates training and evaluation dataloaders using the provided datasets and data collator. Determines the maximum number of training and evaluation steps based on dataset sizes and arguments.

Returns

TrainerConfigureDataloaderOutput containing:

  • dataloader_train: Training data iterator

  • max_training_steps: Total training steps

  • dataloader_eval: Optional evaluation data iterator

  • max_evaluation_steps: Optional total evaluation steps

Return type

TrainerConfigureDataloaderOutput

Note

Automatically selects between Grain and TensorFlow dataloaders based on arguments.use_grain setting.
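How a step count can be derived from a dataset size is sketched below. The function estimate_max_steps and its parameters are hypothetical, and the sketch assumes that a trailing partial batch is dropped and that an explicit cap takes precedence; the actual rule used by a concrete trainer may differ.

```python
def estimate_max_steps(num_examples, batch_size, num_epochs=1, max_steps=None):
    """Hypothetical estimate of the total number of steps."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    # Assumption: the last, incomplete batch of each epoch is dropped.
    steps_per_epoch = num_examples // batch_size
    total = steps_per_epoch * num_epochs
    # Assumption: an explicit cap wins when it is smaller.
    return total if max_steps is None else min(total, max_steps)
```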

abstract configure_functions() TrainerConfigureFunctionOutput[source]#

Configure and JIT-compile training and evaluation step functions.

This method prepares the computational graph for training and evaluation, including setting up sharding specifications, compiling functions with appropriate static arguments, and initializing the checkpoint manager.

Returns

Compiled functions and infrastructure

Return type

TrainerConfigureFunctionOutput

abstract configure_model() TrainerConfigureModelOutput[source]#

Configure model, optimizer, scheduler, and configuration.

Retrieves model configuration from the model state and creates the optimizer and scheduler using training arguments.

Returns

TrainerConfigureModelOutput containing:

  • model: The EasyDeL model instance

  • tx: Gradient transformation (optimizer)

  • scheduler: Learning rate schedule

  • config: Optional model configuration

Return type

TrainerConfigureModelOutput

Note

If pruning_module is set, it wraps the optimizer for structured pruning support.

abstract static count_model_parameters(prm)[source]#

Count total number of model parameters.

Parameters

prm – Model parameters (can be frozen or unfrozen PyTree).

Returns

Total number of parameters in the model.

Return type

int

Note

Handles both frozen and unfrozen Flax parameter dictionaries.
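The idea of counting over a nested parameter tree can be sketched without JAX or Flax. In this illustration each leaf is represented by its shape tuple rather than by a real array, and types.MappingProxyType stands in for a frozen dictionary; count_parameters is a hypothetical helper, not the EasyDeL implementation.

```python
import math
from collections.abc import Mapping
from types import MappingProxyType


def count_parameters(tree) -> int:
    """Sum element counts over a nested mapping whose leaves are shape tuples."""
    if isinstance(tree, Mapping):
        # Works for plain dicts and for read-only mappings alike.
        return sum(count_parameters(child) for child in tree.values())
    # Leaf: a shape such as (4, 8); the empty shape () counts as one scalar.
    return math.prod(tree)


unfrozen = {
    "dense": {"kernel": (4, 8), "bias": (8,)},
    "embed": {"table": (10, 4)},
}
frozen = MappingProxyType(unfrozen)  # read-only view, stand-in for a frozen dict
```

Checking against Mapping rather than dict is what lets one function accept both the mutable and the read-only form.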

abstract create_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start']) Callable[source]#

Create a generic data collection function for batch processing.

Parameters
  • max_sequence_length – Maximum allowed sequence length for padding/truncation.

  • truncation_mode – How to truncate sequences that exceed the maximum length: “keep_end” keeps the end of the sequence, “keep_start” keeps the beginning.

Returns

Callable that processes batches of data.

Note

This is a generic version that can be used with any dataloader type. Implementations should handle padding, truncation, and format conversion.
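The two truncation modes can be illustrated on a plain list of token ids. The helper fit_to_length, the pad_id parameter, and the choice of right-side padding are assumptions made for this sketch; a real collect function would operate on batched arrays.

```python
def fit_to_length(tokens, max_sequence_length, truncation_mode="keep_end", pad_id=0):
    """Hypothetical helper: truncate or right-pad a token list to a fixed length."""
    if truncation_mode not in ("keep_end", "keep_start"):
        raise ValueError(f"unknown truncation_mode: {truncation_mode!r}")
    tokens = list(tokens)
    if len(tokens) > max_sequence_length:
        if truncation_mode == "keep_end":
            # Drop tokens from the front so the tail survives.
            tokens = tokens[len(tokens) - max_sequence_length:]
        else:
            # Drop tokens from the back so the head survives.
            tokens = tokens[:max_sequence_length]
    # Assumption: shorter sequences are padded on the right.
    return tokens + [pad_id] * (max_sequence_length - len(tokens))
```

With tokens [1, 2, 3, 4, 5] and a maximum length of 3, “keep_end” yields [3, 4, 5] and “keep_start” yields [1, 2, 3].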

abstract create_grain_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start']) Callable[source]#

Create a Grain data collection function for batch processing.

Parameters
  • max_sequence_length – Maximum allowed sequence length for padding/truncation.

  • truncation_mode – How to truncate sequences that exceed the maximum length: “keep_end” keeps the end of the sequence, “keep_start” keeps the beginning.

Returns

Callable that processes batches for Grain dataloader.

Note

Function should handle padding, truncation, and data format conversion compatible with Grain’s data pipeline.

abstract create_progress_bar(total: int, desc: str = '', disabled: bool = False) BaseProgressBar[source]#

Create a progress bar of the specified type.

Parameters
  • total – Total number of steps for the progress bar.

  • desc – Description text to display.

  • disabled – Whether to disable the progress bar.

Returns

Progress bar instance of the configured type.

Return type

BaseProgressBar

Note

The type is determined by arguments.progress_bar_type:

  • “tqdm”: Standard tqdm progress bar

  • “rich”: Rich library progress bar with metrics

  • “json”: JSON-formatted progress output

Passing disabled=True returns a NullProgressBar.

abstract create_tfds_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start']) Callable[source]#

Create a TensorFlow Dataset collection function for batch processing.

Parameters
  • max_sequence_length – Maximum allowed sequence length for padding/truncation.

  • truncation_mode – How to truncate sequences that exceed the maximum length: “keep_end” keeps the end of the sequence, “keep_start” keeps the beginning.

Returns

Callable that processes batches for TensorFlow datasets.

Note

Function should handle padding, truncation, and data format conversion compatible with tf.data API.

data_collator: Optional[Callable]#
dataloader_eval: Optional[Iterator[ndarray]]#
dataloader_train: Iterator[ndarray]#
dataset_eval: Any | None#
dataset_train: Any | None#
dtype: Any#
abstract eval(model_state: EasyDeLState) Iterator[dict][source]#

Evaluate the model on the evaluation dataset.

This method runs the model in evaluation mode, computing metrics without updating parameters. It yields metrics for each evaluation step, allowing for streaming evaluation and progress monitoring.

Parameters

model_state – The model state to evaluate

Yields

dict – Evaluation metrics for each step
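Yielding one metrics dictionary per step is naturally expressed as a Python generator. The sketch below is a simplified illustration: evaluate, mean_loss, and the "step" key are hypothetical, and batches are plain lists of per-example losses rather than model inputs.

```python
def mean_loss(batch):
    """Hypothetical per-batch evaluation step: average the given losses."""
    return {"loss": sum(batch) / len(batch)}


def evaluate(batches, step_fn):
    """Yield one metrics dict per batch without collecting them all first."""
    for step, batch in enumerate(batches):
        metrics = step_fn(batch)
        yield {"step": step, **metrics}
```

Because the function is a generator, a caller can update a progress bar or stop early after any step, for example `for metrics in evaluate(batches, mean_loss): ...`.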

evalu_tracker: CompilationTracker#
abstract property evaluation_batch_size#
finetune: bool#
abstract finish()[source]#

Finalize the training process.

abstract initialize_trainer_utils()[source]#

Initialize all trainer utilities in the correct order.

This orchestration method sets up all trainer components:

  1. Weights & Biases logging (if enabled)

  2. Training timer for performance monitoring

  3. Dataloaders for training and evaluation

  4. Model, optimizer, and learning rate scheduler

  5. Model state sharding across devices

  6. Compiled training and evaluation functions

Note

The initialization order is critical as later steps depend on earlier ones being completed.

abstract property is_enable#
abstract property is_process_zero#
abstract log_metrics(metrics: dict[str, float], pbar: BaseProgressBar, step: int, mode: str = 'train') None[source]#

Log metrics and update progress bar.

Parameters
  • metrics – Dictionary of metric names and values.

  • pbar – Progress bar instance to update.

  • step – Current step number.

  • mode – Prefix applied to metric names, either “train” or “eval”.

Note

  • Updates progress bar every log_steps

  • Logs to wandb/tensorboard every report_steps

  • Filters out internal metrics (mlperf, grad_norm)
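The prefixing, filtering, and interval logic described above can be sketched as two small helpers. The names prepare_metrics and should_emit, the "mode/name" key format, and substring-based filtering are assumptions made for illustration only.

```python
def prepare_metrics(metrics, mode="train", hidden=("mlperf", "grad_norm")):
    """Hypothetical helper: drop internal metrics and prefix the rest with the mode."""
    return {
        f"{mode}/{name}": value
        for name, value in metrics.items()
        if not any(marker in name for marker in hidden)
    }


def should_emit(step, every):
    """Return True on steps that are a multiple of the reporting interval."""
    return every > 0 and step % every == 0
```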

abstract log_weight_distribution(state: EasyDeLState, step: int)[source]#

Log weight distribution statistics.

Parameters
  • state – Model state containing parameters.

  • step – Current training step.

Note

Logs statistics like mean, std, min, max of weights for monitoring training stability.
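The statistics mentioned in the note can be computed for a flat list of weights with the standard library alone. weight_stats is a hypothetical helper; it uses the population standard deviation, which is an assumption, and a real implementation would work on device arrays per parameter.

```python
import statistics


def weight_stats(values):
    """Hypothetical summary of a flat collection of weights."""
    values = list(values)
    return {
        "mean": statistics.fmean(values),
        "std": statistics.pstdev(values),  # population standard deviation
        "min": min(values),
        "max": max(values),
    }
```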

max_evaluation_steps: int#
max_training_steps: int#
memory_monitor: Any#
abstract property mesh#
abstract property model#
model_state: EasyDeLState#
abstract on_step_end(state: EasyDeLState, metrics: dict[str, float | list | tuple | numpy.ndarray | Any], step: int) tuple[easydel.infra.base_state.EasyDeLState, dict[str, float | list | tuple | numpy.ndarray | Any]][source]#

Hook called at the end of the step.

abstract on_step_start(state: EasyDeLState, step: int) EasyDeLState[source]#

Hook called at the start of the step.
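Where the two step hooks sit relative to the update can be shown with a toy loop. Everything here is a simplified stand-in: run_steps, toy_step, and the example hooks are hypothetical, and the state is a plain integer counter rather than an EasyDeLState.

```python
def toy_step(state):
    """Hypothetical update: advance the counter and report a shrinking loss."""
    new_state = state + 1
    return new_state, {"loss": 1.0 / new_state}


def passthrough_start(state, step):
    """Start-of-step hook that leaves the state unchanged."""
    return state


def tag_step_end(state, metrics, step):
    """End-of-step hook that records the step index in the metrics."""
    return state, {**metrics, "step": step}


def run_steps(state, num_steps, step_fn, on_step_start, on_step_end):
    """Call the start hook, the update, then the end hook for every step."""
    history = []
    for step in range(num_steps):
        state = on_step_start(state, step)
        state, metrics = step_fn(state)
        state, metrics = on_step_end(state, metrics, step)
        history.append(metrics)
    return state, history
```

Both hooks return the state, so a subclass can replace or adjust it; the end hook can additionally rewrite the metrics before they are logged.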

param_dtype: Any#
pruning_module: Any#
abstract save_information(output_path: str | eformer.paths.GCSPath | eformer.paths.LocalPath | eformer.paths.MLUtilPath) None[source]#

Save the generated information to a markdown file.

abstract save_pretrained(state: EasyDeLState, save_directory: str | None = None, gather_fns: Optional[Union[Any, Mapping[str, Callable], dict[Callable]]] = None, to_torch: bool = False, easystate_to_huggingface_model_kwargs: dict | None = None, torch_save_pretrained_kwargs: dict | None = None)[source]#

Save the model state as a checkpoint file or to a Torch-compatible directory.

scheduler: Callable[[Union[Array, ndarray, bool, number, float, int]], Union[Array, ndarray, bool, number, float, int]]#
sharded_evaluation_step_function: Any#
sharded_training_step_function: Any#
abstract specs_to_name_sharding(tree, mesh=None)[source]#

Convert specs to named sharding.

abstract start_evaluation_hook()[source]#

Hook to run before evaluation starts.

abstract start_training_hook()[source]#

Hook to run before training starts.

state: Any#
state_named_sharding: Any#
state_partition_spec: Any#
state_shape: Any#
state_shardings: Any#
timer: Timers#
abstract train() Any[source]#

Execute the complete training process.

This is the main entry point for training. It orchestrates the entire training workflow including initialization, training loops, evaluation, checkpointing, and finalization.

Returns

TrainerOutput or similar object containing final state and metrics

train_tracker: CompilationTracker#
abstract property training_batch_size#
tx: GradientTransformation#
wandb_runtime: Any#
class easydel.trainers.trainer_protocol.TrainerConfigureDataloaderOutput(dataloader_train: Iterator[ndarray], max_training_steps: int, dataloader_eval: Optional[Iterator[ndarray]] = None, max_evaluation_steps: int | None = None)[source]#

Bases: object

Output configuration for dataloader setup.

Contains the configured dataloaders and computed maximum steps for training and evaluation phases.

dataloader_train#

Iterator over training batches

Type

Iterator[numpy.ndarray]

max_training_steps#

Total number of training steps

Type

int

dataloader_eval#

Optional iterator over evaluation batches

Type

Optional[Iterator[numpy.ndarray]]

max_evaluation_steps#

Optional total number of evaluation steps

Type

int | None

dataloader_eval: Optional[Iterator[ndarray]] = None#
dataloader_train: Iterator[ndarray]#
classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

max_evaluation_steps: int | None = None#
max_training_steps: int#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.
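The replace, to_dict, and to_json helpers behave much like the utilities available for standard frozen dataclasses. The sketch below is only an analogy built on the standard library: StepBudget is a hypothetical class, and the calls shown are dataclasses and json functions, not the methods of the output classes documented here.

```python
import dataclasses
import json
from typing import Optional


@dataclasses.dataclass(frozen=True)
class StepBudget:
    """Hypothetical record holding only the integer fields of the output."""

    max_training_steps: int
    max_evaluation_steps: Optional[int] = None


budget = StepBudget(max_training_steps=1000)

# Analogue of replace(**kwargs): a new instance, the original is untouched.
updated = dataclasses.replace(budget, max_evaluation_steps=50)

# Analogue of to_dict() / to_json() followed by from_json() / from_dict().
as_dict = dataclasses.asdict(updated)
round_trip = StepBudget(**json.loads(json.dumps(as_dict)))
```

Fields such as dataloaders or compiled functions are not JSON-serializable in general, which is why this sketch keeps only the integer fields.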

class easydel.trainers.trainer_protocol.TrainerConfigureFunctionOutput(sharded_training_step_function: Any, mesh: Mesh, checkpoint_manager: AsyncCheckpointManager, sharded_evaluation_step_function: Any | None = None)[source]#

Bases: object

Output configuration for training and evaluation functions.

Contains the compiled step functions and supporting infrastructure.

sharded_training_step_function#

JIT-compiled training step function

Type

Any

mesh#

Device mesh for distributed computation

Type

jax._src.mesh.Mesh

checkpoint_manager#

Manager for saving/loading checkpoints

Type

eformer.serialization.async_manager.AsyncCheckpointManager

sharded_evaluation_step_function#

Optional JIT-compiled evaluation function

Type

Any | None

checkpoint_manager: AsyncCheckpointManager#
classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

mesh: Mesh#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

sharded_evaluation_step_function: Any | None = None#
sharded_training_step_function: Any#
to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

class easydel.trainers.trainer_protocol.TrainerConfigureModelOutput(model: EasyDeLBaseModule, tx: GradientTransformation, scheduler: Callable[[Union[Array, ndarray, bool, number, float, int]], Union[Array, ndarray, bool, number, float, int]], config: easydel.infra.base_config.EasyDeLBaseConfig | None = None)[source]#

Bases: object

Output configuration for model setup.

Contains the configured model, optimizer, scheduler, and model configuration.

model#

The initialized EasyDeL model

Type

easydel.infra.base_module.EasyDeLBaseModule

tx#

Gradient transformation (optimizer) for training

Type

optax._src.base.GradientTransformation

scheduler#

Learning rate schedule function

Type

collections.abc.Callable[[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, float, int]], Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, float, int]]
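A scheduler in this sense is any callable that maps a step count to a learning rate. The sketch below builds one with linear warmup followed by linear decay; make_warmup_linear_schedule and its parameters are hypothetical, and the shape of the schedule is chosen only for illustration, not taken from EasyDeL.

```python
def make_warmup_linear_schedule(peak_lr, warmup_steps, total_steps):
    """Return a callable step -> learning rate (warmup, then linear decay)."""

    def schedule(step):
        if warmup_steps > 0 and step < warmup_steps:
            # Ramp up linearly from 0 to peak_lr.
            return peak_lr * step / warmup_steps
        # Decay linearly to 0 at total_steps, then stay at 0.
        remaining = max(total_steps - step, 0)
        return peak_lr * remaining / max(total_steps - warmup_steps, 1)

    return schedule
```

A callable built this way fits the scheduler signature for plain int or float steps.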

config#

Optional model configuration object

Type

easydel.infra.base_config.EasyDeLBaseConfig | None

config: easydel.infra.base_config.EasyDeLBaseConfig | None = None#
classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

model: EasyDeLBaseModule#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

scheduler: Callable[[Union[Array, ndarray, bool, number, float, int]], Union[Array, ndarray, bool, number, float, int]]#
to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

tx: GradientTransformation#
class easydel.trainers.trainer_protocol.TrainerOutput(state: EasyDeLState, mesh: jax._src.mesh.Mesh | None, last_save_file_name: str | None = None, checkpoint_path: str | None = None)[source]#

Bases: object

Final output from the training process.

Contains the final model state and checkpoint information.

state#

Final model state after training

Type

easydel.infra.base_state.EasyDeLState

mesh#

Device mesh used during training

Type

jax._src.mesh.Mesh | None

last_save_file_name#

Name of the last saved checkpoint file

Type

str | None

checkpoint_path#

Full path to the last checkpoint

Type

str | None

checkpoint_path: str | None = None#
classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

last_save_file_name: str | None = None#
mesh: jax._src.mesh.Mesh | None#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

state: EasyDeLState#
to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.