easydel.trainers.trainer.trainer
Main Trainer implementation for EasyDeL.
This module contains the core Trainer class that orchestrates the complete training pipeline for neural network models using JAX/Flax. The trainer provides a high-level interface for:
Distributed training across multiple devices and hosts
Automatic mixed precision training
Gradient accumulation for large batch sizes
Comprehensive checkpointing and recovery
Integration with various data loaders (Grain, TensorFlow datasets)
Metrics tracking and logging (WandB, TensorBoard)
Memory-efficient training with sharding strategies
The Trainer class is designed to be flexible and extensible, supporting various model architectures including language models, vision models, and multimodal architectures.
- class easydel.trainers.trainer.trainer.Trainer(arguments: TrainingArguments | None = None, model_state: EasyDeLState | None = None, model: tp.type[EasyDeLBaseModule] | None = None, dataset_train: Dataset | IterableDataset | ShardedDataSource | None = None, dataset_eval: Dataset | IterableDataset | ShardedDataSource | None = None, data_collator: tp.Callable | None = None, finetune: bool = True, processing_class: PreTrainedTokenizerBase | None = None, **deprecated_kwargs)
Bases: BaseTrainer
Main trainer implementation for EasyDeL models.
This class provides a complete training and evaluation pipeline for JAX-based models with support for distributed training, gradient accumulation, mixed precision, and various optimization strategies.
The trainer handles:
- Distributed training across multiple devices and hosts
- Automatic checkpointing and resumption
- Gradient accumulation for large effective batch sizes
- Learning rate scheduling and optimization
- Comprehensive metrics tracking and logging
- Memory-efficient data loading with Grain or TensorFlow datasets
Key Features:
- JIT compilation of training and evaluation steps
- Automatic mixed precision training
- Support for model and data parallelism
- Integration with WandB and TensorBoard
- Flexible data collation and preprocessing
Example
>>> trainer = Trainer(
...     arguments=training_args,
...     model=model,
...     dataset_train=train_dataset,
...     dataset_eval=eval_dataset
... )
>>> output = trainer.train()
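Gradient accumulation, listed among the features above, can be illustrated independently of EasyDeL. The following is a minimal sketch in plain Python, not the trainer's implementation: `grad_mse` and `accumulated_grad` are hypothetical helper names, and the model is a one-parameter linear fit chosen only for illustration. It shows that averaging gradients over equally sized micro-batches reproduces the full-batch gradient.

```python
def grad_mse(w, xs, ys):
    """Gradient of the mean squared error for the model y = w * x."""
    n = len(xs)
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / n


def accumulated_grad(w, xs, ys, num_micro_batches):
    """Average per-micro-batch gradients (assumes equal-sized micro-batches)."""
    if len(xs) % num_micro_batches != 0:
        raise ValueError("batch size must be divisible by num_micro_batches")
    size = len(xs) // num_micro_batches
    total = 0.0
    for i in range(num_micro_batches):
        lo, hi = i * size, (i + 1) * size
        total += grad_mse(w, xs[lo:hi], ys[lo:hi])
    return total / num_micro_batches


xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]
full = grad_mse(0.5, xs, ys)
accum = accumulated_grad(0.5, xs, ys, num_micro_batches=2)
print(full, accum)  # the two values agree
```

Because only one micro-batch is processed at a time, peak activation memory scales with the micro-batch size while the update matches the larger effective batch.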
- configure_functions() -> TrainerConfigureFunctionOutput
Configure and JIT-compile training and evaluation step functions.
This method is crucial for performance as it:
1. Sets up proper sharding specifications for distributed training
2. JIT-compiles the step functions with appropriate static arguments
3. Configures input/output sharding for efficient data movement
4. Sets up the checkpoint manager for model persistence
The compilation process traces through the computation graph once and generates optimized XLA code for subsequent executions.
- Returns
- Contains:
sharded_training_step_function: JIT-compiled training function with gradient computation and parameter updates
sharded_evaluation_step_function: JIT-compiled evaluation function for forward passes only
mesh: Device mesh for distributed computation
checkpoint_manager: AsyncCheckpointManager for saving/loading
- Return type
TrainerConfigureFunctionOutput
Note
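Static arguments are fixed at compile time; passing a different value triggers recompilation
donate_argnums=(0,) on the training step lets XLA reuse the input state's buffers for the output, so the donated state must not be used after the call
Empty sharding specs indicate replication across devices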
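The notes on static arguments and buffer donation can be made concrete with a small JAX example. This is a sketch under stated assumptions, not EasyDeL's actual step function: `sgd_step`, the least-squares loss, and the `learning_rate` static argument are all hypothetical and chosen only to show how `jax.jit` is configured.

```python
from functools import partial

import jax
import jax.numpy as jnp


# `learning_rate` is static: it is baked into the compiled program, and a new
# value causes a recompile. Argument 0 (`params`) is donated, so XLA may reuse
# its buffers for the returned parameters. Some backends may emit a warning if
# a donated buffer cannot be reused.
@partial(jax.jit, static_argnames=("learning_rate",), donate_argnums=(0,))
def sgd_step(params, batch, learning_rate):
    """One SGD step on a least-squares loss (illustrative only)."""

    def loss_fn(p):
        pred = batch["x"] @ p["w"]
        return jnp.mean((pred - batch["y"]) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(params)
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, grads
    )
    return new_params, loss


params = {"w": jnp.zeros(2)}
batch = {"x": jnp.eye(2), "y": jnp.array([1.0, 2.0])}

# Rebind `params` to the result: the donated input must not be read again.
params, loss = sgd_step(params, batch, learning_rate=0.1)
```

Rebinding the name to the returned value, as in the last line, is the usual way to stay safe with donation, since the old array may have been invalidated.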
- create_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start']) -> Callable
Creates a function to collect and process batches of data for training or evaluation.
This function handles padding or truncating sequences to the specified max_sequence_length based on the chosen truncation_mode.
- Parameters
max_sequence_length (int) – The maximum allowed sequence length.
truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – The truncation mode. Defaults to “keep_end”.
- Returns
A function that takes a batch of data and returns a processed batch.
- Return type
tp.Callable
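The two truncation modes can be sketched for a single token sequence as follows. This is an illustration only: `fit_to_length` is a hypothetical helper, and right-padding with a `pad_id` of 0 is an assumption, since the text above does not state the padding side or value.

```python
from typing import Literal


def fit_to_length(
    tokens: list[int],
    max_sequence_length: int,
    truncation_mode: Literal["keep_end", "keep_start"] = "keep_end",
    pad_id: int = 0,  # assumed padding value, not taken from EasyDeL
) -> list[int]:
    """Truncate or right-pad `tokens` to exactly `max_sequence_length` items."""
    if max_sequence_length <= 0:
        raise ValueError("max_sequence_length must be positive")
    if len(tokens) > max_sequence_length:
        if truncation_mode == "keep_end":
            return tokens[-max_sequence_length:]  # drop the oldest tokens
        if truncation_mode == "keep_start":
            return tokens[:max_sequence_length]  # drop the newest tokens
        raise ValueError(f"unknown truncation_mode: {truncation_mode!r}")
    return tokens + [pad_id] * (max_sequence_length - len(tokens))


print(fit_to_length([1, 2, 3, 4, 5], 3, "keep_end"))    # [3, 4, 5]
print(fit_to_length([1, 2, 3, 4, 5], 3, "keep_start"))  # [1, 2, 3]
print(fit_to_length([1, 2], 4))                         # [1, 2, 0, 0]
```

"keep_end" suits data where the most recent tokens matter most, such as the tail of a conversation, while "keep_start" preserves the beginning of a document.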
- create_grain_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start'] = 'keep_end') -> Callable
Creates a collate/collect function to process batches of data for training or evaluation.
This function returns a callable that takes a batch (a list of dictionaries) and converts it into a dictionary of JAX arrays. For models of class “ForCausalLMLoss”, it also performs truncation (either keeping the end or the start of the sequence) so that each sequence does not exceed the specified maximum length.
- Parameters
max_sequence_length (int) – The maximum allowed sequence length.
truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – Determines whether to keep the end or the start of the sequence when truncating. Defaults to “keep_end”.
- Returns
A function that takes a batch (list of dicts) and returns a processed dict of arrays.
- Return type
tp.Callable
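The conversion from a list of per-example dictionaries to a dictionary of JAX arrays might look like the sketch below. It is not the function EasyDeL returns: `collate` is a hypothetical name, it assumes every example has the same keys and equal-length fields, and it truncates unconditionally rather than only for the model class mentioned above.

```python
import jax.numpy as jnp


def collate(batch, max_sequence_length, truncation_mode="keep_end"):
    """Stack a list of example dicts into one dict of arrays, truncating fields."""
    if max_sequence_length <= 0:
        raise ValueError("max_sequence_length must be positive")
    out = {}
    for key in batch[0]:
        rows = []
        for example in batch:
            values = jnp.asarray(example[key])
            if truncation_mode == "keep_end":
                values = values[-max_sequence_length:]
            else:  # "keep_start"
                values = values[:max_sequence_length]
            rows.append(values)
        # Stacking requires all rows to share one length after truncation.
        out[key] = jnp.stack(rows)
    return out


batch = [
    {"input_ids": [1, 2, 3, 4], "attention_mask": [1, 1, 1, 1]},
    {"input_ids": [5, 6, 7, 8], "attention_mask": [1, 1, 1, 0]},
]
arrays = collate(batch, max_sequence_length=3)
print(arrays["input_ids"].shape)  # (2, 3)
```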
- create_tfds_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start'] = 'keep_end') -> Callable
Creates a collate/collect function to process batches of data for training or evaluation.
This function returns a callable that takes a batch (a list of dictionaries) and converts it into a dictionary of JAX arrays. For models of class “ForCausalLMLoss”, it also performs truncation (either keeping the end or the start of the sequence) so that each sequence does not exceed the specified maximum length.
- Parameters
max_sequence_length (int) – The maximum allowed sequence length.
truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – Determines whether to keep the end or the start of the sequence when truncating. Defaults to “keep_end”.
- Returns
A function that takes a batch (list of dicts) and returns a processed dict of arrays.
- Return type
tp.Callable
- eval(model_state: EasyDeLState) -> Iterator[dict]
Evaluate the model on the evaluation dataset.
This method performs model evaluation without gradient computation, yielding metrics for each evaluation step. It’s useful for:
- Periodic evaluation during training
- Final model evaluation after training
- Standalone evaluation of checkpoints
The evaluation process:
1. Switches to evaluation mode (no gradient computation)
2. Iterates through the evaluation dataset
3. Computes forward passes and metrics
4. Yields metrics for monitoring and analysis
5. Handles multi-host synchronization
- Parameters
model_state – Model state containing parameters for evaluation. This can be different from the training state, allowing evaluation of checkpoints or other models.
- Yields
dict – Evaluation metrics for each step, including:
loss: Average loss value
accuracy: Average accuracy (if applicable)
throughput: Tokens/samples per second
Additional model-specific metrics
- Raises
AssertionError – If evaluation dataloader is not configured
Example
>>> for metrics in trainer.eval(model_state):
...     print(f"Eval loss: {metrics['eval/loss']}")
Note
Evaluation is performed without gradient computation
Catches RuntimeError from multi-host synchronization issues
Progress bar shows evaluation progress in real-time
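Because eval returns an iterator, no evaluation work happens until the caller iterates over it. The sketch below shows one way to consume such a stream and keep both the per-step history and the final metrics. `fake_eval_steps` is a hypothetical stand-in for `trainer.eval(model_state)`, and the `eval/loss` key is taken from the example above rather than from a verified list of metric names.

```python
def fake_eval_steps():
    """Hypothetical stand-in for an iterator of per-step metric dicts."""
    yield {"eval/loss": 2.0}
    yield {"eval/loss": 1.0}


def drain_metrics(metric_stream):
    """Consume the stream, returning (history, final); final is None if empty."""
    history = []
    for metrics in metric_stream:
        history.append(dict(metrics))  # copy in case the producer reuses the dict
    final = history[-1] if history else None
    return history, final


history, final = drain_metrics(fake_eval_steps())
print(len(history), final["eval/loss"])
```

Since the loss is described as an average, the last yielded dictionary is likely the value to report, but that interpretation should be checked against the actual metric definitions.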
- train() -> TrainerOutput
Execute the complete training pipeline.
This is the main entry point for training. It orchestrates the entire training workflow from initialization to completion:
Calls start_training_hook for custom initialization
Sets up metrics tracking and logging infrastructure
Logs initial configuration and model information
Executes the main training loop across all epochs
Handles interruptions and saves final checkpoints
Runs final evaluation if configured
Cleans up resources and returns results
The method is designed to be robust to interruptions and will save the model state before exiting on errors or keyboard interrupts.
- Returns
- Contains:
state: Final model state after training
mesh: Device mesh used for training
checkpoint_path: Path to the final checkpoint
last_save_file_name: Name of the last saved file
- Return type
TrainerOutput
Example
>>> trainer = Trainer(arguments=args, model=model, ...)
>>> output = trainer.train()
>>> print(f"Final loss: {output.state.metrics['loss']}")
Note
Automatically resumes from checkpoints if configured
Saves checkpoints periodically based on save_steps
Can be interrupted with Ctrl+C without losing progress
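The combination of periodic saving and interruption handling described in these notes follows a common pattern, sketched below in plain Python. This is not the trainer's code: `run_loop`, `step_fn`, and `save_fn` are hypothetical names, the integer state stands in for a model state, and the interrupt is simulated by raising `KeyboardInterrupt` from the step function.

```python
def run_loop(step_fn, save_fn, num_steps, save_steps):
    """Run up to `num_steps` steps, saving periodically and on interruption."""
    state = 0
    completed = 0
    try:
        for step in range(1, num_steps + 1):
            state = step_fn(state)
            completed = step
            if step % save_steps == 0:
                save_fn(f"step-{step}", state)
    except KeyboardInterrupt:
        # Persist the latest completed state before returning control.
        save_fn(f"interrupted-{completed}", state)
    return state, completed


saved = {}


def step_fn(state):
    """Increment the state; simulate Ctrl+C on the fourth step."""
    if state == 3:
        raise KeyboardInterrupt
    return state + 1


state, completed = run_loop(step_fn, saved.__setitem__, num_steps=10, save_steps=2)
print(state, completed, saved)
```

The state is only reassigned after a step finishes, so the checkpoint written in the exception handler always reflects the last fully completed step.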