BaseTrainer Documentation
Introduction
The BaseTrainer class in EasyDeL provides a robust and flexible foundation for training transformer models. It encapsulates the common logic required for training loops, evaluation, checkpointing, and configuration, allowing users to focus on the specifics of their model and data.
BaseTrainer implements the BaseTrainerProtocol, ensuring a consistent interface across different training scenarios. Use BaseTrainer when you need a feature-rich, configurable trainer for standard training tasks like Supervised Fine-Tuning (SFT), Direct Preference Optimization (DPO), etc., or as a starting point for more complex custom training workflows.
For details on the abstract interface that BaseTrainer implements, refer to the Trainer Protocol Documentation.
Core Concepts
Understanding these core concepts is crucial for effectively using BaseTrainer:
- TrainingArguments: A dataclass holding all hyperparameters and configuration settings for the training process (e.g., learning rate, batch size, number of epochs, checkpointing strategy, logging options). You initialize the trainer with an instance of this class.
- EasyDeLState: A Flax TrainState subclass that bundles the model parameters, optimizer state, and potentially other training-related variables (like PRNG keys). BaseTrainer manages this state throughout the training process.
- Training Loop: The core process orchestrated by the train() method. It iterates over epochs and steps, fetches data batches, executes the training step function, logs metrics, and handles checkpointing.
- Checkpointing: BaseTrainer automatically saves the EasyDeLState (model parameters, optimizer state) and TrainingArguments at specified intervals (e.g., every N steps or epochs) so that training can be resumed later. It also enforces a configurable limit on the number of checkpoints to keep.
- Distributed Training: BaseTrainer is designed with JAX's pmap and shard_map (shmap) in mind, enabling efficient training across multiple devices (GPUs/TPUs) with minimal code changes required from the user for standard setups.
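The training loop and checkpoint limit described above can be illustrated with a small pure-Python sketch. It is not EasyDeL's implementation: run_training_loop, toy_train_step, and the in-memory list of "checkpoints" are hypothetical stand-ins, and only the parameter names save_steps and save_total_limit are borrowed from TrainingArguments.

```python
from collections import deque
from typing import Callable, Iterable


def run_training_loop(
    state: dict,
    batches: Iterable,
    train_step_fn: Callable,
    num_epochs: int,
    save_steps: int,
    save_total_limit: int,
):
    """Toy loop: run a step, record metrics, checkpoint every `save_steps` steps."""
    kept = deque()  # steps whose checkpoints would still be on disk
    history = []    # (step, metrics) pairs, standing in for a logger
    step = 0
    for _ in range(num_epochs):
        for batch in batches:
            state, metrics = train_step_fn(state, batch)
            step += 1
            history.append((step, metrics))
            if step % save_steps == 0:
                kept.append(step)
                # Drop the oldest checkpoint once the limit is exceeded.
                while len(kept) > save_total_limit:
                    kept.popleft()
    return state, history, list(kept)


def toy_train_step(state, batch):
    # Stand-in for a real (jitted) step: accumulate the batch sum.
    new_state = {"total": state["total"] + sum(batch)}
    return new_state, {"loss": float(sum(batch))}


final_state, history, kept_checkpoints = run_training_loop(
    state={"total": 0},
    batches=[[1, 2], [3, 4], [5, 6]],
    train_step_fn=toy_train_step,
    num_epochs=2,
    save_steps=2,
    save_total_limit=2,
)
```

With six steps in total and a save every two steps, checkpoints are written at steps 2, 4, and 6, and the limit of two leaves only steps 4 and 6.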
Key Methods and Configuration
BaseTrainer relies on several key methods, some of which you might override for customization:
- __init__(self, arguments: TrainingArguments, ...): Initializes the trainer. Requires TrainingArguments and dataset information.
- configure_model(self) -> TrainerConfigureModelOutput: (Abstract method in protocol, implemented in BaseTrainer) Initializes the model and the initial EasyDeLState. It typically loads a pretrained model based on arguments. Returns a TrainerConfigureModelOutput tuple containing the initialized model and state.
- configure_dataloaders(self) -> TrainerConfigureDataloaderOutput: (Abstract method in protocol, implemented in BaseTrainer) Sets up the training and evaluation dataloaders based on the provided datasets and arguments. Returns a TrainerConfigureDataloaderOutput tuple containing dataloader_train and dataloader_eval.
- configure_functions(self) -> TrainerConfigureFunctionOutput: (Abstract method in protocol, implemented in subclasses like SFTTrainer) Defines the core training and evaluation step functions, often applying jax.jit for performance. Returns a TrainerConfigureFunctionOutput tuple containing train_step_fn and eval_step_fn.
- create_collect_function(self) -> Callable: (Optional override) Defines how batches are collated and preprocessed before being passed to the training/evaluation step functions. Useful for custom padding or data manipulation.
- train(self, ckpt_path: Optional[str] = None) -> EasyDeLState: The main entry point to start the training process. It orchestrates the entire training loop, including epoch/step iteration, data loading, calling the train_step_fn, logging, evaluation (if configured), and checkpointing. Optionally takes a ckpt_path to resume training. Returns the final EasyDeLState.
- eval(self, state: EasyDeLState) -> Dict: Runs the evaluation loop using the eval_step_fn on the evaluation dataloader. Returns a dictionary of evaluation metrics. Typically called automatically by train() if an evaluation dataset and strategy are provided.
- save_pretrained(self, ckpt_path: str, state: Optional[EasyDeLState] = None): Saves the current training state (EasyDeLState) and TrainingArguments to the specified ckpt_path. Called automatically during training based on arguments.save_strategy.
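To show how the configure_* methods and train()/eval() listed above fit together, here is a minimal sketch of the pattern in plain Python. Every class in it (FunctionOutput, DataloaderOutput, SketchTrainer, CountingTrainer) is a hypothetical stand-in written for this illustration; the real EasyDeL types carry more fields and the real methods do considerably more work.

```python
from typing import Callable, NamedTuple, Optional


class FunctionOutput(NamedTuple):
    # Hypothetical stand-in for TrainerConfigureFunctionOutput.
    train_step_fn: Callable
    eval_step_fn: Callable


class DataloaderOutput(NamedTuple):
    # Hypothetical stand-in for TrainerConfigureDataloaderOutput.
    dataloader_train: list
    dataloader_eval: Optional[list]


class SketchTrainer:
    """Base class: owns the loop, delegates the step functions to subclasses."""

    def __init__(self, train_data, eval_data=None):
        self.train_data = train_data
        self.eval_data = eval_data

    def configure_dataloaders(self) -> DataloaderOutput:
        return DataloaderOutput(self.train_data, self.eval_data)

    def configure_functions(self) -> FunctionOutput:
        raise NotImplementedError("subclasses define the step functions")

    def train(self, state):
        loaders = self.configure_dataloaders()
        fns = self.configure_functions()
        for batch in loaders.dataloader_train:
            state, _ = fns.train_step_fn(state, batch)
        return state

    def eval(self, state) -> dict:
        loaders = self.configure_dataloaders()
        fns = self.configure_functions()
        losses = [fns.eval_step_fn(state, b)["eval_loss"] for b in loaders.dataloader_eval]
        return {"eval_loss": sum(losses) / len(losses)}


class CountingTrainer(SketchTrainer):
    def configure_functions(self) -> FunctionOutput:
        def train_step(state, batch):
            # The "state" here is just a running count of examples seen.
            return state + len(batch), {"loss": 0.0}

        def eval_step(state, batch):
            return {"eval_loss": float(len(batch))}

        return FunctionOutput(train_step, eval_step)


trainer = CountingTrainer(train_data=[[1, 2], [3]], eval_data=[[1], [2, 3, 4]])
final_state = trainer.train(state=0)
eval_metrics = trainer.eval(final_state)
```

The design point is that the base class fixes the order of operations while a subclass supplies only the task-specific step functions.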
Configuration Example (TrainingArguments)
from easydel.trainers import TrainingArguments

args = TrainingArguments(
    model_name="MyFineTunedModel",
    num_train_epochs=3,
    learning_rate=2e-5,
    total_batch_size=32,  # Effective batch size (per_device_train_batch_size * num_devices)
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=1,
    weight_decay=0.01,
    max_steps=-1,  # If > 0, overrides num_train_epochs
    gradient_checkpointing="nothing_saveable",  # Or "everything_saveable"
    sharding_array=(1, -1, 1, 1),  # Device mesh shape for model sharding
    use_fast_kernels=True,
    logging_steps=10,
    save_steps=100,  # Save checkpoint every 100 steps
    save_total_limit=2,  # Keep only the latest 2 checkpoints
    evaluation_strategy="steps",  # Or "epoch"
    eval_steps=100,  # Evaluate every 100 steps
    output_dir="runs/my_model_run",
    # Add other relevant TrainingArguments as needed
)

# Example usage (assuming MyTrainerSubclass exists):
# trainer = MyTrainerSubclass(arguments=args, ...)
# trainer.train()
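The batch-size and step settings in the configuration above interact through simple arithmetic, sketched below. The helper functions are hypothetical and not part of EasyDeL. The sketch assumes four devices, a dataset of 10,000 examples, that an incomplete final batch is dropped, and that gradient accumulation multiplies the effective batch size; none of these is stated in the configuration itself.

```python
def effective_batch_size(per_device_batch_size, num_devices, gradient_accumulation_steps=1):
    # per_device_train_batch_size * num_devices, as in the comment on total_batch_size.
    # Multiplying by the accumulation steps is an assumption of this sketch.
    return per_device_batch_size * num_devices * gradient_accumulation_steps


def total_training_steps(num_examples, batch_size, num_train_epochs, max_steps=-1):
    # A positive max_steps overrides the epoch-based count.
    if max_steps > 0:
        return max_steps
    steps_per_epoch = num_examples // batch_size  # assumes the last partial batch is dropped
    return steps_per_epoch * num_train_epochs


def checkpoint_steps(total_steps, save_steps):
    # Steps at which a periodic checkpoint would be written.
    return list(range(save_steps, total_steps + 1, save_steps))


batch = effective_batch_size(per_device_batch_size=8, num_devices=4)
steps = total_training_steps(num_examples=10_000, batch_size=batch, num_train_epochs=3)
saves = checkpoint_steps(steps, save_steps=100)
```

Under these assumptions the run has 936 steps and writes nine periodic checkpoints, of which save_total_limit=2 would retain only the two most recent.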
Customization
While BaseTrainer handles most standard scenarios, you can customize its behavior:
- Subclassing: Create a new class inheriting from BaseTrainer (or a more specific trainer like SFTTrainer) and override methods like configure_functions or create_collect_function for custom logic.
- Hooks: BaseTrainer provides hooks (empty methods like on_step_start, on_step_end, on_log) that you can override in your subclass to inject custom actions at specific points in the training loop without rewriting the entire loop.
from easydel.trainers import BaseTrainer, TrainerConfigureFunctionOutput
import jax
import jax.numpy as jnp  # Assuming jnp is needed


class CustomTrainer(BaseTrainer):
    def create_collect_function(self, *args, **kwargs):
        # Example: Custom data collation
        def collect_fn(batch):
            # Implement your custom batch processing logic here
            processed_batch = batch  # Placeholder
            return processed_batch

        return collect_fn

    def configure_functions(self, *args, **kwargs) -> TrainerConfigureFunctionOutput:
        # Define your custom train and eval steps
        def train_step(state, batch):
            # Implement custom forward/backward pass logic
            # This is highly dependent on your specific task
            loss = jnp.mean(batch.get("labels", 0.0))  # Placeholder loss
            metrics = {"loss": loss}  # Placeholder metrics
            # Assuming gradients are computed somehow (e.g., using jax.grad)
            # grads = ...
            # new_state = state.apply_gradients(grads=grads)  # Placeholder state update
            new_state = state  # Placeholder
            return new_state, metrics

        def eval_step(state, batch):
            # Implement custom evaluation logic
            # This is highly dependent on your specific task
            loss = jnp.mean(batch.get("labels", 0.0))  # Placeholder loss
            metrics = {"eval_loss": loss}  # Placeholder metrics
            return metrics

        return TrainerConfigureFunctionOutput(
            train_step_fn=jax.jit(train_step),
            eval_step_fn=jax.jit(eval_step)
        )

    def on_step_end(self, state, metrics, *args, **kwargs):
        # Example Hook: Print learning rate every step if scheduler exists
        if hasattr(self, "scheduler") and self.scheduler is not None:
            current_lr = self.scheduler(state.step)
            print(f"Step completed. Current LR: {current_lr}")
        # Call parent hook if needed for base functionality
        super().on_step_end(state, metrics, *args, **kwargs)
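The sketch below shows, in plain Python, how a loop might call hooks such as on_step_start, on_step_end, and on_log. It is only an illustration of the pattern: LoopWithHooks, RecordingLoop, and the hook signatures used here are hypothetical and may differ from what BaseTrainer actually passes to its hooks.

```python
class LoopWithHooks:
    """Toy loop whose hooks are empty by default, so subclasses opt in."""

    def on_step_start(self, state, step):
        pass

    def on_step_end(self, state, metrics, step):
        pass

    def on_log(self, state, metrics, step):
        pass

    def run(self, state, batches, step_fn, logging_steps=1):
        for step, batch in enumerate(batches, start=1):
            self.on_step_start(state, step)
            state, metrics = step_fn(state, batch)
            self.on_step_end(state, metrics, step)
            if step % logging_steps == 0:
                self.on_log(state, metrics, step)
        return state


class RecordingLoop(LoopWithHooks):
    def __init__(self):
        self.events = []

    def on_step_end(self, state, metrics, step):
        self.events.append(("end", step, metrics["loss"]))
        # Chain to the parent so any base behavior still runs.
        super().on_step_end(state, metrics, step)

    def on_log(self, state, metrics, step):
        self.events.append(("log", step))


loop = RecordingLoop()
final = loop.run(
    state=0,
    batches=[1, 2, 3, 4],
    step_fn=lambda s, b: (s + b, {"loss": float(b)}),
    logging_steps=2,
)
```

Because the base hooks do nothing, overriding one of them changes behavior at that point only, and the loop body itself stays untouched.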