BaseTrainer Documentation#

Introduction#

The BaseTrainer class in EasyDeL provides a robust and flexible foundation for training transformer models. It encapsulates the common logic required for training loops, evaluation, checkpointing, and configuration, allowing users to focus on the specifics of their model and data.

BaseTrainer implements the BaseTrainerProtocol, ensuring a consistent interface across different training scenarios. Use BaseTrainer when you need a feature-rich, configurable trainer for standard training tasks like Supervised Fine-Tuning (SFT), Direct Preference Optimization (DPO), etc., or as a starting point for more complex custom training workflows.

For details on the abstract interface that BaseTrainer implements, refer to the Trainer Protocol Documentation.

Core Concepts#

Understanding these core concepts is crucial for effectively using BaseTrainer:

  • TrainingArguments: A dataclass holding all hyperparameters and configuration settings for the training process (e.g., learning rate, batch size, number of epochs, checkpointing strategy, logging options). You initialize the trainer with an instance of this class.

  • EasyDeLState: A Flax TrainState subclass that bundles the model parameters, optimizer state, and potentially other training-related variables (like PRNG keys). BaseTrainer manages this state throughout the training process.

  • Training Loop: The core process orchestrated by the train() method. It involves iterating over epochs and steps, fetching data batches, executing the training step function, logging metrics, and handling checkpointing.

  • Checkpointing: BaseTrainer automatically saves the EasyDeLState (model parameters, optimizer state) and TrainingArguments at specified intervals (e.g., every N steps or epochs) to allow resuming training later. It also manages a configurable limit on the number of checkpoints to keep.

  • Distributed Training: BaseTrainer is designed with JAX’s pmap and shmap in mind, enabling efficient training across multiple devices (GPUs/TPUs) with minimal code changes required from the user for standard setups.

Key Methods and Configuration#

BaseTrainer relies on several key methods, some of which you might override for customization:

  • __init__(self, arguments: TrainingArguments, ...): Initializes the trainer. Requires TrainingArguments and dataset information.

  • configure_model(self) -> TrainerConfigureModelOutput: (Abstract method in protocol, implemented in BaseTrainer) Initializes the model and the initial EasyDeLState. It typically loads a pretrained model based on arguments. Returns a TrainerConfigureModelOutput tuple containing the initialized model and state.

  • configure_dataloaders(self) -> TrainerConfigureDataloaderOutput: (Abstract method in protocol, implemented in BaseTrainer) Sets up the training and evaluation dataloaders based on the provided datasets and arguments. Returns a TrainerConfigureDataloaderOutput tuple containing dataloader_train and dataloader_eval.

  • configure_functions(self) -> TrainerConfigureFunctionOutput: (Abstract method in protocol, implemented in subclasses like SFTTrainer) Defines the core training and evaluation step functions, often applying jax.jit for performance. Returns a TrainerConfigureFunctionOutput tuple containing train_step_fn and eval_step_fn.

  • create_collect_function(self) -> Callable: (Optional override) Defines how batches are collated and preprocessed before being passed to the training/evaluation step functions. Useful for custom padding or data manipulation.

  • train(self, ckpt_path: Optional[str] = None) -> EasyDeLState: The main entry point to start the training process. It orchestrates the entire training loop, including epoch/step iteration, data loading, calling the train_step_fn, logging, evaluation (if configured), and checkpointing. Optionally takes a ckpt_path to resume training. Returns the final EasyDeLState.

  • eval(self, state: EasyDeLState) -> Dict: Runs the evaluation loop using the eval_step_fn on the evaluation dataloader. Returns a dictionary of evaluation metrics. Typically called automatically by train() if an evaluation dataset and strategy are provided.

  • save_pretrained(self, ckpt_path: str, state: Optional[EasyDeLState] = None): Saves the current training state (EasyDeLState) and TrainingArguments to the specified ckpt_path. Called automatically during training based on arguments.save_strategy.

Configuration Example (TrainingArguments)#

from easydel.trainers import TrainingArguments

args = TrainingArguments(
    model_name="MyFineTunedModel",
    num_train_epochs=3,
    learning_rate=2e-5,
    total_batch_size=32, # Effective batch size (per_device_train_batch_size * num_devices)
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=1,
    weight_decay=0.01,
    max_steps=-1, # If > 0, overrides num_train_epochs
    gradient_checkpointing="nothing_saveable", # Or "everything_saveable"
    sharding_array=(1, -1, 1, 1), # Device mesh shape for model sharding
    use_fast_kernels=True,
    logging_steps=10,
    save_steps=100, # Save checkpoint every 100 steps
    save_total_limit=2, # Keep only the latest 2 checkpoints
    evaluation_strategy="steps", # Or "epoch"
    eval_steps=100, # Evaluate every 100 steps
    output_dir="runs/my_model_run"
    # Add other relevant TrainingArguments as needed
)

# Example usage (assuming MyTrainerSubclass exists):
# trainer = MyTrainerSubclass(arguments=args, ...)
# trainer.train()

Customization#

While BaseTrainer handles most standard scenarios, you can customize its behavior:

  • Subclassing: Create a new class inheriting from BaseTrainer (or a more specific trainer like SFTTrainer) and override methods like configure_functions or create_collect_function for custom logic.

  • Hooks: BaseTrainer provides hooks (empty methods like on_step_start, on_step_end, on_log) that you can override in your subclass to inject custom actions at specific points in the training loop without rewriting the entire loop.

from easydel.trainers import BaseTrainer, TrainerConfigureFunctionOutput
import jax
import jax.numpy as jnp # Assuming jnp is needed

class CustomTrainer(BaseTrainer):
    def create_collect_function(self, *args, **kwargs):
        # Example: Custom data collation
        def collect_fn(batch):
            # Implement your custom batch processing logic here
            processed_batch = batch # Placeholder
            return processed_batch
        return collect_fn

    def configure_functions(self, *args, **kwargs) -> TrainerConfigureFunctionOutput:
        # Define your custom train and eval steps
        def train_step(state, batch):
            # Implement custom forward/backward pass logic
            # This is highly dependent on your specific task
            loss = jnp.mean(batch.get("labels", 0.0)) # Placeholder loss
            metrics = {"loss": loss} # Placeholder metrics
            # Assuming gradients are computed somehow (e.g., using jax.grad)
            # grads = ...
            # new_state = state.apply_gradients(grads=grads) # Placeholder state update
            new_state = state # Placeholder
            return new_state, metrics

        def eval_step(state, batch):
            # Implement custom evaluation logic
            # This is highly dependent on your specific task
            loss = jnp.mean(batch.get("labels", 0.0)) # Placeholder loss
            metrics = {"eval_loss": loss} # Placeholder metrics
            return metrics

        return TrainerConfigureFunctionOutput(
            train_step_fn=jax.jit(train_step),
            eval_step_fn=jax.jit(eval_step)
        )

    def on_step_end(self, state, metrics, *args, **kwargs):
        # Example Hook: Print learning rate every step if scheduler exists
        if hasattr(self, "scheduler") and self.scheduler is not None:
             current_lr = self.scheduler(state.step)
             print(f"Step completed. Current LR: {current_lr}")
        # Call parent hook if needed for base functionality
        super().on_step_end(state, metrics, *args, **kwargs)