easydel.trainers.trainer.trainer

Main Trainer implementation for EasyDeL.

This module contains the core Trainer class that orchestrates the complete training pipeline for neural network models using JAX/Flax. The trainer provides a high-level interface for:

  • Distributed training across multiple devices and hosts

  • Automatic mixed precision training

  • Gradient accumulation for large batch sizes

  • Comprehensive checkpointing and recovery

  • Integration with various data loaders (Grain, TensorFlow datasets)

  • Metrics tracking and logging (WandB, TensorBoard)

  • Memory-efficient training with sharding strategies

The Trainer class is designed to be flexible and extensible, supporting various model architectures including language models, vision models, and multimodal architectures.

class easydel.trainers.trainer.trainer.Trainer(arguments: TrainingArguments | None = None, model_state: EasyDeLState | None = None, model: tp.type[EasyDeLBaseModule] | None = None, dataset_train: Dataset | IterableDataset | ShardedDataSource | None = None, dataset_eval: Dataset | IterableDataset | ShardedDataSource | None = None, data_collator: tp.Callable | None = None, finetune: bool = True, processing_class: PreTrainedTokenizerBase | None = None, **deprecated_kwargs)

Bases: BaseTrainer

Main trainer implementation for EasyDeL models.

This class provides a complete training and evaluation pipeline for JAX-based models with support for distributed training, gradient accumulation, mixed precision, and various optimization strategies.

The trainer handles:

  • Distributed training across multiple devices and hosts

  • Automatic checkpointing and resumption

  • Gradient accumulation for large effective batch sizes

  • Learning rate scheduling and optimization

  • Comprehensive metrics tracking and logging

  • Memory-efficient data loading with Grain or TensorFlow datasets

Key Features:

  • JIT compilation of training and evaluation steps

  • Automatic mixed precision training

  • Support for model and data parallelism

  • Integration with WandB and TensorBoard

  • Flexible data collation and preprocessing
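Gradient accumulation computes gradients on several micro-batches and averages them before a single optimizer update, so the effective batch size is the micro-batch size times the number of accumulation steps. The following is a minimal JAX sketch of that idea on a toy linear-regression loss. The names `loss_fn` and `accumulated_grads` are hypothetical, and this is not EasyDeL's internal implementation; it only checks that averaged micro-batch gradients match the full-batch gradient when the loss is a mean and all micro-batches have equal size.

```python
import jax
import jax.numpy as jnp


def loss_fn(w, x, y):
    # Toy mean-squared-error loss for a linear model (illustrative only).
    return jnp.mean((x @ w - y) ** 2)


def accumulated_grads(w, x, y, num_micro_batches):
    # Split the full batch into equal micro-batches along the leading axis.
    xs = x.reshape(num_micro_batches, -1, x.shape[-1])
    ys = y.reshape(num_micro_batches, -1)

    def body(grad_sum, micro_batch):
        mx, my = micro_batch
        # Gradient of the micro-batch loss, added to the running sum.
        return grad_sum + jax.grad(loss_fn)(w, mx, my), None

    grad_sum, _ = jax.lax.scan(body, jnp.zeros_like(w), (xs, ys))
    # Averaging over micro-batches recovers the full-batch gradient here.
    return grad_sum / num_micro_batches


x = jnp.arange(24.0).reshape(8, 3) / 10.0
y = jnp.linspace(0.0, 1.0, 8)
w = jnp.array([0.1, -0.2, 0.3])

full = jax.grad(loss_fn)(w, x, y)
accum = accumulated_grads(w, x, y, num_micro_batches=4)
print(jnp.allclose(full, accum, atol=1e-5))
```

Using `jax.lax.scan` keeps only one micro-batch's activations alive at a time, which is the memory benefit that motivates accumulation in the first place.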

Example

>>> trainer = Trainer(
...     arguments=training_args,
...     model=model,
...     dataset_train=train_dataset,
...     dataset_eval=eval_dataset
... )
>>> output = trainer.train()

configure_functions() → TrainerConfigureFunctionOutput

Configure and JIT-compile training and evaluation step functions.

This method is crucial for performance because it:

  1. Sets up proper sharding specifications for distributed training

  2. JIT-compiles the step functions with appropriate static arguments

  3. Configures input/output sharding for efficient data movement

  4. Sets up the checkpoint manager for model persistence

The first call traces each step function and compiles it into optimized XLA code; later calls with the same input shapes, dtypes, and static arguments reuse the compiled executable instead of retracing.

Returns

Contains:
  • sharded_training_step_function: JIT-compiled training function with gradient computation and parameter updates

  • sharded_evaluation_step_function: JIT-compiled evaluation function for forward passes only

  • mesh: Device mesh for distributed computation

  • checkpoint_manager: AsyncCheckpointManager for saving/loading

Return type

TrainerConfigureFunctionOutput

Note

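  • Static arguments are treated as compile-time constants rather than traced values; passing a different value triggers recompilation

  • donate_argnums=(0,) on the training step donates the input state's buffers so XLA can reuse them for the updated state (in-place updates); the donated input must not be used after the call

  • An empty sharding spec (PartitionSpec()) means the array is replicated across all devices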
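The three notes above map directly onto standard `jax.jit` and `jax.sharding` features. The sketch below illustrates them on a hypothetical toy step (`train_step`, an SGD update on a linear model), not on EasyDeL's real step functions: `learning_rate` is static, `params` (argument 0) is donated, and an empty `PartitionSpec()` produces a fully replicated array.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


# Hypothetical toy training step (not EasyDeL's). learning_rate is static:
# it is baked into the compiled program, and a new value recompiles.
# params (argument 0) is donated, so XLA may reuse its buffers for the output.
@partial(jax.jit, static_argnums=(2,), donate_argnums=(0,))
def train_step(params, batch, learning_rate):
    def loss_fn(p):
        preds = batch["x"] @ p["w"]
        return jnp.mean((preds - batch["y"]) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(params)
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, grads
    )
    return new_params, loss


# A 1-D mesh over all available devices; an empty PartitionSpec() means
# "keep a full copy of the array on every device".
mesh = Mesh(np.array(jax.devices()), ("dp",))
replicated = NamedSharding(mesh, PartitionSpec())
probe = jax.device_put(jnp.ones((4,)), replicated)
print(probe.sharding.is_fully_replicated)

params = jax.device_put({"w": jnp.zeros((3,))}, replicated)
batch = {"x": jnp.ones((8, 3)), "y": jnp.ones((8,))}

losses = []
for _ in range(3):
    # Rebind params every step: the donated input buffer must not be reused.
    params, loss = train_step(params, batch, 0.1)
    losses.append(float(loss))
print(losses)
```

On backends that cannot donate buffers (for example some CPU builds), JAX emits a warning and falls back to copying, so the sketch still runs.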

create_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start']) → Callable

Creates a function to collect and process batches of data for training or evaluation.

The returned function pads or truncates each sequence to max_sequence_length, using truncation_mode to decide which end of an over-long sequence is kept.

Parameters
  • max_sequence_length (int) – The maximum allowed sequence length.

  • truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – The truncation mode. Defaults to “keep_end”.

Returns

A function that takes a batch of data and returns a processed batch.

Return type

tp.Callable
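The two truncation modes differ only in which tokens survive: "keep_end" drops tokens from the start, "keep_start" drops them from the end. The helper below is a hypothetical, pure-Python sketch of that per-sequence logic. The padding side (right padding), the pad id of 0, and the attention-mask convention are assumptions for illustration, not details taken from EasyDeL.

```python
def pad_or_truncate(tokens, max_sequence_length, truncation_mode="keep_end", pad_id=0):
    """Hypothetical helper: fit one token list to exactly max_sequence_length.

    Assumptions (not EasyDeL API): right padding with pad_id, and an
    attention mask that marks real tokens with 1 and padding with 0.
    """
    if max_sequence_length <= 0:
        raise ValueError("max_sequence_length must be positive")
    if truncation_mode == "keep_end":
        kept = list(tokens[-max_sequence_length:])  # drop tokens from the start
    elif truncation_mode == "keep_start":
        kept = list(tokens[:max_sequence_length])   # drop tokens from the end
    else:
        raise ValueError(f"unknown truncation_mode: {truncation_mode!r}")
    n_pad = max_sequence_length - len(kept)
    input_ids = kept + [pad_id] * n_pad
    attention_mask = [1] * len(kept) + [0] * n_pad
    return input_ids, attention_mask


print(pad_or_truncate([1, 2, 3, 4, 5, 6], 4, "keep_end"))
print(pad_or_truncate([1, 2, 3, 4, 5, 6], 4, "keep_start"))
print(pad_or_truncate([7, 8], 4))
```

"keep_end" is the usual choice for causal language modeling when the most recent context matters most; "keep_start" preserves the beginning of a prompt.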

create_grain_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start'] = 'keep_end') → Callable

Creates a collate/collect function to process batches of data for training or evaluation.

This function returns a callable that takes a batch (a list of dictionaries) and converts it into a dictionary of JAX arrays. For models whose loss type is “ForCausalLMLoss” (causal language models), it also truncates each sequence, keeping either its end or its start, so that no sequence exceeds max_sequence_length.

Parameters
  • max_sequence_length (int) – The maximum allowed sequence length.

  • truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – Determines whether to keep the end or the start of the sequence when truncating. Defaults to “keep_end”.

Returns

A function that takes a batch (list of dicts) and returns a processed dict of arrays.

Return type

tp.Callable
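The collate behaviour described above (list of dicts in, dict of stacked arrays out, with optional truncation along the sequence axis) can be sketched as follows. `make_collate` and its `is_causal_lm` flag are hypothetical stand-ins: the flag replaces whatever check the library performs on the model's loss type, and the sketch assumes every example in a batch already has the same length for each key.

```python
import jax.numpy as jnp
import numpy as np


def make_collate(max_sequence_length, truncation_mode="keep_end", is_causal_lm=True):
    """Hypothetical collate factory; not EasyDeL's implementation.

    Assumes all examples in a batch share the same length for each key.
    """
    if max_sequence_length <= 0:
        raise ValueError("max_sequence_length must be positive")
    if truncation_mode not in ("keep_end", "keep_start"):
        raise ValueError(f"unknown truncation_mode: {truncation_mode!r}")

    def truncate(arr):
        # Only arrays with a sequence axis, shape (batch, seq, ...), are cut.
        if arr.ndim < 2:
            return arr
        if truncation_mode == "keep_end":
            return arr[:, -max_sequence_length:]
        return arr[:, :max_sequence_length]

    def collate(batch):
        # list[dict[str, list]] -> dict[str, array] with a leading batch axis.
        stacked = {key: np.stack([np.asarray(ex[key]) for ex in batch]) for key in batch[0]}
        if is_causal_lm:
            stacked = {key: truncate(value) for key, value in stacked.items()}
        return {key: jnp.asarray(value) for key, value in stacked.items()}

    return collate


batch = [
    {"input_ids": [1, 2, 3, 4, 5], "attention_mask": [1, 1, 1, 1, 1]},
    {"input_ids": [6, 7, 8, 9, 10], "attention_mask": [1, 1, 1, 1, 0]},
]
out = make_collate(3, "keep_end")(batch)
print(out["input_ids"].tolist())
```

Stacking with NumPy first and converting once at the end keeps host-side work cheap before the arrays are handed to the device.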

create_tfds_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start'] = 'keep_end') → Callable

Creates a collate/collect function to process batches of data for training or evaluation.

This function returns a callable that takes a batch (a list of dictionaries) and converts it into a dictionary of JAX arrays. For models whose loss type is “ForCausalLMLoss” (causal language models), it also truncates each sequence, keeping either its end or its start, so that no sequence exceeds max_sequence_length.

Parameters
  • max_sequence_length (int) – The maximum allowed sequence length.

  • truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – Determines whether to keep the end or the start of the sequence when truncating. Defaults to “keep_end”.

Returns

A function that takes a batch (list of dicts) and returns a processed dict of arrays.

Return type

tp.Callable

eval(model_state: EasyDeLState) → Iterator[dict]

Evaluate the model on the evaluation dataset.

This method performs model evaluation without gradient computation, yielding metrics for each evaluation step. It is useful for:

  • Periodic evaluation during training

  • Final model evaluation after training

  • Standalone evaluation of checkpoints

The evaluation process:

  1. Switches to evaluation mode (no gradient computation)

  2. Iterates through the evaluation dataset

  3. Computes forward passes and metrics

  4. Yields metrics for monitoring and analysis

  5. Handles multi-host synchronization

Parameters

model_state – Model state containing parameters for evaluation. This can be different from the training state, allowing evaluation of checkpoints or other models.

Yields

dict

Evaluation metrics for each step, including:
  • loss: Average loss value

  • accuracy: Average accuracy (if applicable)

  • throughput: Tokens/samples per second

  • Additional model-specific metrics

Raises

AssertionError – If the evaluation dataloader is not configured

Example

>>> for metrics in trainer.eval(model_state):
...     print(f"Eval loss: {metrics['eval/loss']}")

Note

  • Evaluation is performed without gradient computation

  • Catches RuntimeError from multi-host synchronization issues

  • Progress bar shows evaluation progress in real-time
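The generator-style contract of eval() (a forward pass per step, no gradients, one metrics dict yielded per step) can be sketched with a toy model. `eval_step` and `evaluate` are hypothetical names; treating "loss" as a running average over the steps seen so far is an assumption about what "Average loss value" means, and the "eval/" key prefix simply mirrors the docstring example.

```python
import jax
import jax.numpy as jnp


@jax.jit
def eval_step(w, batch):
    # Forward pass only: no jax.grad call, so no gradients are computed.
    preds = batch["x"] @ w
    return jnp.mean((preds - batch["y"]) ** 2)


def evaluate(w, batches):
    """Hypothetical eval loop that yields metrics after every step."""
    total_loss = 0.0
    for step, batch in enumerate(batches, start=1):
        total_loss += float(eval_step(w, batch))
        # Assumed semantics: running mean of the per-step losses so far.
        yield {"eval/step": step, "eval/loss": total_loss / step}


w = jnp.array([1.0, 0.0])
batches = [
    {"x": jnp.eye(2), "y": jnp.array([1.0, 0.0])},
    {"x": jnp.eye(2), "y": jnp.array([0.0, 0.0])},
]
for metrics in evaluate(w, batches):
    print(f"Eval loss: {metrics['eval/loss']}")
```

Yielding per step lets the caller log, stop early, or drive a progress bar without the evaluation loop knowing about any of them.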

train() → TrainerOutput

Execute the complete training pipeline.

This is the main entry point for training. It orchestrates the entire training workflow from initialization to completion:

  1. Calls start_training_hook for custom initialization

  2. Sets up metrics tracking and logging infrastructure

  3. Logs initial configuration and model information

  4. Executes the main training loop across all epochs

  5. Handles interruptions and saves final checkpoints

  6. Runs final evaluation if configured

  7. Cleans up resources and returns results

The method is designed to be robust to interruptions and will save the model state before exiting on errors or keyboard interrupts.

Returns

Contains:
  • state: Final model state after training

  • mesh: Device mesh used for training

  • checkpoint_path: Path to the final checkpoint

  • last_save_file_name: Name of the last saved file

Return type

TrainerOutput

Example

>>> trainer = Trainer(arguments=args, model=model, ...)
>>> output = trainer.train()
>>> print(f"Final checkpoint: {output.checkpoint_path}")

Note

  • Automatically resumes from checkpoints if configured

  • Saves checkpoints periodically based on save_steps

  • Can be interrupted with Ctrl+C without losing progress
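The interruption handling described for train() follows a common pattern: checkpoint every save_steps steps, and always save once more on the way out, whether the loop finishes, raises, or receives Ctrl+C. The loop below is a hypothetical pure-Python sketch of that pattern; `run_training`, `step_fn`, and `save_fn` are placeholders, not EasyDeL API.

```python
def run_training(state, steps, step_fn, save_fn, save_steps):
    """Hypothetical loop: periodic checkpoints plus a final save on exit.

    step_fn(state) -> new state; save_fn(state, step) persists a checkpoint.
    """
    completed = 0   # number of fully finished steps
    last_saved = 0  # step of the most recent checkpoint
    try:
        for step in range(1, steps + 1):
            state = step_fn(state)
            completed = step
            if completed % save_steps == 0:
                save_fn(state, completed)
                last_saved = completed
    except KeyboardInterrupt:
        # Ctrl+C: stop cleanly and fall through to the final save.
        pass
    finally:
        # Runs on normal exit, Ctrl+C, or any other error (which re-raises).
        if completed != last_saved:
            save_fn(state, completed)
    return state, completed


saved = []


def step_fn(s):
    if s == 5:
        raise KeyboardInterrupt  # simulate the user pressing Ctrl+C
    return s + 1


final_state, final_step = run_training(
    0, steps=10, step_fn=step_fn,
    save_fn=lambda s, step: saved.append((step, s)), save_steps=2,
)
print(saved, final_state, final_step)
```

Tracking `completed` separately from the loop index matters: if the interrupt arrives mid-step, the saved state is labelled with the last step that actually finished.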