easydel.trainers.direct_preference_optimization_trainer.dpo_trainer#
- class easydel.trainers.direct_preference_optimization_trainer.dpo_trainer.DPOTrainer(arguments: DPOConfig, model: Union[EasyDeLBaseModule, EasyDeLState], reference_model: Optional[Union[EasyDeLBaseModule, EasyDeLState]] = None, processing_class: Optional[Any] = None, train_dataset: Optional[Any] = None, eval_dataset: Optional[Any] = None, data_collator: Optional[Callable] = None)[source]#
Bases: Trainer
Trainer for Direct Preference Optimization (DPO).
This trainer handles the training, evaluation, and checkpointing of language models using the DPO algorithm. It supports sharding, gradient accumulation, mixed precision training, LoRA, and precomputed reference model log probabilities.
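For orientation, the standard per-example DPO objective can be sketched in plain Python as below. This is an illustration of the published loss, not EasyDeL's implementation: the function names and the `beta=0.1` default are assumptions made for the example, and the inputs are taken to be summed sequence log probabilities under the policy and the reference model.

```python
import math


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def dpo_loss(
    policy_chosen: float,
    policy_rejected: float,
    ref_chosen: float,
    ref_rejected: float,
    beta: float = 0.1,  # hypothetical default, chosen only for this sketch
) -> float:
    """Sigmoid DPO loss for one preference pair.

    Each argument is a summed log probability of a full response.
    """
    policy_margin = policy_chosen - policy_rejected
    ref_margin = ref_chosen - ref_rejected
    # The loss shrinks as the policy prefers the chosen response
    # more strongly than the reference model does.
    return -log_sigmoid(beta * (policy_margin - ref_margin))
```

When the policy and the reference agree exactly, the loss equals log 2; it drops below that as the policy's preference margin exceeds the reference's.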
- checkpoint_manager: tp.Any#
- checkpoint_path: tp.Optional[tp.Union[str, os.PathLike]]#
- compute_reference_log_probs(state: EasyDeLState, padded_batch: Dict) tuple[Any, Any][source]#
Computes the reference model’s log probabilities for a single padded batch of a DPO-specific dataset.
- Parameters
state (EasyDeLState) – The EasyDeLState object of the model (used if no reference model is provided).
padded_batch (tp.Dict) – The padded batch of data.
- Returns
A tuple containing the log probabilities for the chosen and rejected responses.
- Return type
tuple[tp.Any, tp.Any]
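To make the returned quantities concrete, here is a minimal sketch of how a sequence log probability is commonly computed in DPO: take the log-softmax of the logits at each position, pick out the label token, and sum over the positions that belong to the response. The helper names and the `loss_mask` argument are assumptions for illustration; the method's actual batching, sharding, and masking are not described in this page.

```python
import math


def log_softmax(logits: list[float]) -> list[float]:
    """Log-softmax over one position's vocabulary logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]


def sequence_log_prob(
    logits: list[list[float]],  # shape [seq_len][vocab_size]
    labels: list[int],          # target token id at each position
    loss_mask: list[int],       # 1 for response tokens, 0 for prompt/padding
) -> float:
    """Sum of label log probabilities over the unmasked positions."""
    total = 0.0
    for step_logits, label, keep in zip(logits, labels, loss_mask):
        if keep:
            total += log_softmax(step_logits)[label]
    return total
```

Calling such a helper once on the chosen sequence and once on the rejected sequence yields a pair analogous to the tuple described above.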
- config: EasyDeLBaseConfig#
- configure_dataloaders()[source]#
Returns the training dataloader, potentially with precomputed reference log probabilities.
If precompute_ref_log_probs is enabled, this method computes the reference model’s log probabilities for the chosen and rejected responses in the training dataset and adds them as columns to the dataset.
- Returns
The training dataloader.
- Return type
tensorflow.data.Dataset
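The precomputation idea can be sketched as follows: run the reference model once per row, store the two log probabilities as extra columns, and reuse them in every later epoch. This is a simplified illustration over a list of dictionaries. The column names `ref_chosen_logps` and `ref_rejected_logps` and the `reference_fn` callable are hypothetical and may not match what the trainer writes.

```python
from typing import Callable


def add_reference_columns(
    rows: list[dict],
    reference_fn: Callable[[dict], tuple[float, float]],
) -> list[dict]:
    """Return copies of `rows` with reference log probabilities attached.

    `reference_fn` stands in for a call to the reference model and must
    return (chosen_log_prob, rejected_log_prob) for one row.
    """
    augmented = []
    for row in rows:
        chosen_lp, rejected_lp = reference_fn(row)
        augmented.append(
            {
                **row,
                "ref_chosen_logps": chosen_lp,      # hypothetical column name
                "ref_rejected_logps": rejected_lp,  # hypothetical column name
            }
        )
    return augmented
```

The trade-off is extra work before training in exchange for not having to evaluate the reference model during each training step.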
- configure_functions() TrainerConfigureFunctionOutput[source]#
Configures and JIT-compiles the training and evaluation step functions.
- This method sets up the necessary functions for training and evaluation, including:
Initialization of the model state.
Sharding of the model parameters and optimizer state.
JIT-compilation of the training and evaluation step functions.
- Returns
An object containing the configured functions and other relevant information.
- Return type
TrainerConfigureFunctionOutput
- create_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start'] = 'keep_end') Callable[source]#
Creates a data collection function for batching.
For DPO training, this method simply returns the pre-configured data_collator.
- Parameters
max_sequence_length (int) – The maximum sequence length (not used in this implementation).
truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – The truncation mode (not used in this implementation). Defaults to “keep_end”.
- Returns
The data collator function.
- Return type
tp.Callable
- data_collator: tp.Optional[tp.Callable]#
- dataloader_eval: tp.Optional[tp.Iterator[np.ndarray]]#
- dataloader_train: tp.Iterator[np.ndarray]#
- dataset_eval: tp.Optional[Dataset]#
- dataset_train: tp.Optional[Dataset]#
- dtype: tp.Any#
- evalu_tracker: CompilationTracker#
- finetune: bool#
- max_evaluation_steps: int#
- max_training_steps: int#
- memory_monitor: tp.Any#
- model_state: EasyDeLState#
- on_step_end(state: EasyDeLState, metrics: Any, step: int) Tuple[EasyDeLState, Any][source]#
Hook called at the end of each step.
- param_dtype: tp.Any#
- static process_row(features, processing_class, max_prompt_length, max_completion_length, add_special_tokens)[source]#
- pruning_module: tp.Any#
- scheduler: optax.Schedule#
- sharded_evaluation_step_function: JitWrapped#
- sharded_training_step_function: JitWrapped#
- state: tp.Any#
- state_named_sharding: tp.Any#
- state_partition_spec: tp.Any#
- state_shape: tp.Any#
- static tokenize_row(features, processing_class, max_prompt_length, max_completion_length, add_special_tokens)[source]#
Tokenize a row of the dataset.
- Parameters
features (dict[str, str]) – Row of the dataset, should contain the keys “prompt”, “chosen”, and “rejected”.
processing_class (PreTrainedTokenizerBase) – Processing class used to process the data.
max_prompt_length (int or None) – Maximum length of the prompt sequence. If None, the prompt sequence is not truncated.
max_completion_length (int or None) – Maximum length of the completion sequences. If None, the completion sequences are not truncated.
add_special_tokens (bool) – Whether to add special tokens to the sequences. Typically used for encoder-decoder models. If True, the prompt sequence will have a bos token prepended and an eos token appended. In any case, the completion sequences will have an eos token appended.
- Returns
Tokenized sequences with the keys “prompt_input_ids”, “chosen_input_ids”, and “rejected_input_ids”.
- Return type
dict[str, list[int]]
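The behaviour described above can be sketched with a toy whitespace tokenizer. Everything here is illustrative: `ToyTokenizer` and `tokenize_row_sketch` are hypothetical names, and the truncation sides (keeping the end of the prompt and the start of each completion) are assumptions, since this page does not say which side is dropped.

```python
class ToyTokenizer:
    """Whitespace tokenizer standing in for a real processing class."""

    bos_token_id = 1
    eos_token_id = 2

    def __init__(self):
        self.vocab = {}

    def __call__(self, text, add_special_tokens=False):
        # Assign ids starting at 3 so they never collide with bos/eos.
        ids = [self.vocab.setdefault(w, len(self.vocab) + 3) for w in text.split()]
        return {"input_ids": ids}


def tokenize_row_sketch(
    features, processing_class, max_prompt_length, max_completion_length, add_special_tokens
):
    tok = processing_class
    prompt = tok(features["prompt"], add_special_tokens=False)["input_ids"]
    chosen = tok(features["chosen"], add_special_tokens=False)["input_ids"]
    rejected = tok(features["rejected"], add_special_tokens=False)["input_ids"]

    if add_special_tokens:
        # Prompt gets bos prepended and eos appended.
        prompt = [tok.bos_token_id] + prompt + [tok.eos_token_id]
    # Completions always end with eos.
    chosen = chosen + [tok.eos_token_id]
    rejected = rejected + [tok.eos_token_id]

    if max_prompt_length is not None:
        prompt = prompt[-max_prompt_length:]  # assumption: keep the prompt's end
    if max_completion_length is not None:
        chosen = chosen[:max_completion_length]  # assumption: keep the start
        rejected = rejected[:max_completion_length]

    return {
        "prompt_input_ids": prompt,
        "chosen_input_ids": chosen,
        "rejected_input_ids": rejected,
    }
```

With `max_prompt_length=2` and no completion limit, a three-word prompt keeps its last two token ids, and both completions end with the eos id.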
- train_tracker: CompilationTracker#
- tx: optax.GradientTransformation#
- wandb_runtime: tp.Any#