easydel.trainers.direct_preference_optimization_trainer.dpo_trainer
- class easydel.trainers.direct_preference_optimization_trainer.dpo_trainer.DPOTrainer(arguments: DPOConfig, model: EasyDeLBaseModule | EasyDeLState, reference_model: EasyDeLBaseModule | EasyDeLState | None = None, processing_class: ProcessingClassType = None, train_dataset: Dataset | IterableDataset | ShardedDataSource | None = None, eval_dataset: Dataset | IterableDataset | ShardedDataSource | None = None, data_collator: tp.Callable | None = None)
Bases: Trainer
Trainer for Direct Preference Optimization (DPO).
This trainer implements the Direct Preference Optimization algorithm for training language models from human preferences without requiring a separate reward model. DPO directly optimizes the policy to match human preferences by maximizing the likelihood of preferred completions relative to rejected ones.
The trainer applies its preprocessing transforms lazily during iteration, which performs better than eager Hugging Face .map() calls.
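The preference objective described above can be illustrated with a small, framework-free sketch of the sigmoid DPO loss for a single preference pair. This is not EasyDeL's implementation: the helper name `dpo_sigmoid_loss` and its scalar inputs (summed sequence log probabilities from the policy and the reference model) are assumptions made for illustration only.

```python
import math


def _softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def dpo_sigmoid_loss(
    policy_chosen_logp: float,
    policy_rejected_logp: float,
    ref_chosen_logp: float,
    ref_rejected_logp: float,
    beta: float = 0.1,
) -> float:
    """Sigmoid DPO loss for one preference pair (hypothetical helper)."""
    # How much more the policy favors the chosen completion than the rejected one.
    policy_margin = policy_chosen_logp - policy_rejected_logp
    # The same margin under the frozen reference model.
    ref_margin = ref_chosen_logp - ref_rejected_logp
    logits = beta * (policy_margin - ref_margin)
    # -log(sigmoid(z)) is equal to softplus(-z).
    return _softplus(-logits)


# Policy identical to the reference: the loss equals log(2).
baseline = dpo_sigmoid_loss(-12.0, -15.0, -12.0, -15.0)
# Policy favors the chosen completion more than the reference does: the loss drops.
improved = dpo_sigmoid_loss(-10.0, -18.0, -12.0, -15.0)
```

The loss falls below log(2) only when the policy widens the chosen-versus-rejected margin relative to the reference, and beta scales how strongly that difference is weighted.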
- processing_class
Tokenizer or processor for data preprocessing.
- reference_state
Reference model state for KL divergence computation.
- Type
- padding_value
Token ID used for padding sequences.
- Type
int
Example
>>> config = DPOConfig(
...     beta=0.1,
...     loss_type="sigmoid",
...     max_length=512,
...     learning_rate=5e-6,
... )
>>> trainer = DPOTrainer(
...     arguments=config,
...     model=model,
...     reference_model=reference_model,
...     processing_class=tokenizer,
...     train_dataset=preference_dataset,
... )
>>> trainer.train()
Note
The trainer expects datasets with ‘prompt’, ‘chosen’, and ‘rejected’ columns. These will be automatically tokenized via lazy transforms during iteration.
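As a rough illustration of the column layout named in the note, the sketch below builds two made-up preference rows and checks them for the three expected columns. The rows and the `missing_columns` helper are hypothetical and are not part of EasyDeL.

```python
REQUIRED_COLUMNS = ("prompt", "chosen", "rejected")

# Made-up rows for illustration; real data would come from a preference dataset.
preference_rows = [
    {
        "prompt": "What is the capital of France?",
        "chosen": "The capital of France is Paris.",
        "rejected": "France is a country in Europe.",
    },
    {
        "prompt": "Add 2 and 3.",
        "chosen": "2 + 3 = 5.",
        "rejected": "2 + 3 = 6.",
    },
]


def missing_columns(row: dict) -> list[str]:
    """Return the required column names that are absent from a row."""
    return [name for name in REQUIRED_COLUMNS if name not in row]


# Map row index -> missing column names, keeping only rows that have a problem.
problems = {
    index: missing_columns(row)
    for index, row in enumerate(preference_rows)
    if missing_columns(row)
}
```

Rows shaped like these could then be wrapped in a Hugging Face dataset (for example with `datasets.Dataset.from_list`) before being passed as train_dataset; that step is a suggestion rather than something the source prescribes.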
- compute_reference_log_probs(state: EasyDeLState, padded_batch: dict) -> tuple[Any, Any]
Compute log probabilities of the reference model for a batch.
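The source does not describe how these log probabilities are computed. A common approach, sketched here in plain Python, is to take a log-softmax over the vocabulary at each position, pick out the log probability of the label token, mask out prompt and padding positions, and sum the rest. The names `log_softmax`, `sequence_log_prob`, and `loss_mask` are assumptions for illustration, not EasyDeL identifiers.

```python
import math


def log_softmax(logits: list[float]) -> list[float]:
    """Log-softmax over one position's vocabulary logits."""
    peak = max(logits)
    log_norm = peak + math.log(sum(math.exp(v - peak) for v in logits))
    return [v - log_norm for v in logits]


def sequence_log_prob(
    token_logits: list[list[float]],
    labels: list[int],
    loss_mask: list[int],
) -> float:
    """Sum the label log probabilities over unmasked positions (hypothetical helper)."""
    total = 0.0
    for logits, label, keep in zip(token_logits, labels, loss_mask):
        if keep:
            total += log_softmax(logits)[label]
    return total


# Three positions over a 4-token vocabulary with uniform logits.
# The first position is treated as a prompt token and masked out.
uniform = [[0.0, 0.0, 0.0, 0.0] for _ in range(3)]
value = sequence_log_prob(uniform, labels=[2, 1, 3], loss_mask=[0, 1, 1])
```

In a DPO setting a function like this would be applied once to the chosen completion and once to the rejected one, with the reference model's parameters held fixed.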
- configure_dataloaders()
Configure dataloaders with optional precomputed reference log probs.
- configure_functions() -> TrainerConfigureFunctionOutput
Configure and JIT-compile training and evaluation step functions.
- create_grain_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start'] = 'keep_end') -> Callable
Create data collection function for Grain batching.
- create_tfds_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start'] = 'keep_end') -> Callable
Create data collection function for TFDS batching.
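Both collect functions accept max_sequence_length and truncation_mode. The source does not spell out the semantics, so the following is only a sketch under two assumptions: that 'keep_end' keeps the last tokens and 'keep_start' keeps the first tokens, and that short sequences are right-padded with a padding value. The helper `truncate_and_pad` is hypothetical.

```python
def truncate_and_pad(
    token_ids: list[int],
    max_sequence_length: int,
    truncation_mode: str = "keep_end",
    padding_value: int = 0,
) -> list[int]:
    """Force a token list to exactly max_sequence_length (hypothetical helper)."""
    if max_sequence_length <= 0:
        raise ValueError("max_sequence_length must be positive")
    if truncation_mode == "keep_end":
        # Assumed meaning: drop tokens from the front, keep the tail.
        kept = token_ids[-max_sequence_length:]
    elif truncation_mode == "keep_start":
        # Assumed meaning: drop tokens from the back, keep the head.
        kept = token_ids[:max_sequence_length]
    else:
        raise ValueError(f"unknown truncation_mode: {truncation_mode!r}")
    # Assumed right padding; the real collator may pad on the other side.
    return kept + [padding_value] * (max_sequence_length - len(kept))


tail = truncate_and_pad([1, 2, 3, 4, 5], 3, "keep_end")
head = truncate_and_pad([1, 2, 3, 4, 5], 3, "keep_start")
padded = truncate_and_pad([1, 2], 4, padding_value=0)
```

Fixing every example to one length keeps batch shapes static, which generally avoids recompilation of JIT-compiled step functions.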
- on_step_end(state: EasyDeLState, metrics: dict[str, float | list | tuple | numpy.ndarray | Any], step: int) -> tuple[easydel.infra.base_state.EasyDeLState, dict[str, float | list | tuple | numpy.ndarray | Any]]
Hook called at the end of each step for reference model sync.
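The source does not say how the reference model is synchronized. One scheme used in DPO-style training is a periodic soft update that moves the reference parameters toward the current policy. The sketch below shows that idea on plain dictionaries of floats; `soft_sync`, `maybe_sync`, `sync_every`, and `alpha` are hypothetical names, and EasyDeL's actual behavior and configuration options may differ.

```python
def soft_sync(reference: dict, policy: dict, alpha: float) -> dict:
    """Blend policy parameters into the reference: (1 - alpha) * ref + alpha * policy."""
    return {
        name: (1.0 - alpha) * reference[name] + alpha * policy[name]
        for name in reference
    }


def maybe_sync(
    reference: dict,
    policy: dict,
    step: int,
    sync_every: int,
    alpha: float,
) -> dict:
    """Apply the soft update only on every sync_every-th step (hypothetical hook body)."""
    if step > 0 and step % sync_every == 0:
        return soft_sync(reference, policy, alpha)
    return reference


reference_params = {"w": 0.0}
policy_params = {"w": 1.0}

# Step 3 is not a sync step when sync_every=2, so the reference is unchanged.
unchanged = maybe_sync(reference_params, policy_params, step=3, sync_every=2, alpha=0.25)
# Step 4 is a sync step, so the reference moves a quarter of the way to the policy.
updated = maybe_sync(reference_params, policy_params, step=4, sync_every=2, alpha=0.25)
```

Setting alpha to 1.0 in this sketch would copy the policy into the reference outright, which corresponds to a hard update.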