easydel.trainers.direct_preference_optimization_trainer.dpo_trainer#

class easydel.trainers.direct_preference_optimization_trainer.dpo_trainer.DPOTrainer(arguments: DPOConfig, model: Union[EasyDeLBaseModule, EasyDeLState], reference_model: Optional[Union[EasyDeLBaseModule, EasyDeLState]] = None, processing_class: Optional[Any] = None, train_dataset: Optional[Any] = None, eval_dataset: Optional[Any] = None, data_collator: Optional[Callable] = None)[source]#

Bases: Trainer

Trainer for Direct Preference Optimization (DPO).

This trainer handles the training, evaluation, and checkpointing of language models using the DPO algorithm. It supports sharding, gradient accumulation, mixed precision training, LoRA, and precomputed reference model log probabilities.
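As orientation, the standard sigmoid form of the DPO objective can be sketched in plain Python. This is a minimal illustration of the published loss, not this trainer's implementation: the function name `dpo_sigmoid_loss`, its argument names, and the default `beta=0.1` are assumptions made for the example, and the inputs are taken to be summed sequence log probabilities.

```python
import math


def dpo_sigmoid_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1):
    """Per-example DPO loss from summed sequence log probabilities (illustrative)."""
    # How much more the policy prefers "chosen" over "rejected" than the reference does.
    margin = (policy_chosen - policy_rejected) - (ref_chosen - ref_rejected)
    z = beta * margin
    # -log(sigmoid(z)) equals softplus(-z); this form avoids overflow for large |z|.
    return max(-z, 0.0) + math.log1p(math.exp(-abs(z)))


# Policy identical to the reference: the loss is log(2).
loss_at_zero = dpo_sigmoid_loss(-5.0, -5.0, -5.0, -5.0)
# Policy favours the chosen response more than the reference does: the loss drops.
loss_good = dpo_sigmoid_loss(-4.0, -9.0, -5.0, -5.0)
```

The reference log probabilities enter only through the margin, which is why they can be computed once ahead of time and reused.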

arguments: DPOConfig#
checkpoint_manager: tp.Any#
checkpoint_path: tp.Optional[tp.Union[str, os.PathLike]]#
compute_reference_log_probs(state: EasyDeLState, padded_batch: Dict) tuple[Any, Any][source]#

Computes the reference model's log probabilities for a single padded batch of a DPO-specific dataset.

Parameters
  • state (EasyDeLState) – The EasyDeLState object of the model (used if no reference model is provided).

  • padded_batch (tp.Dict) – The padded batch of data.

Returns

A tuple containing the log probabilities for the chosen and rejected responses.

Return type

tuple[tp.Any, tp.Any]
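The idea of a per-response log probability can be sketched without any framework. The sketch below is an assumption about the general technique, not the library's code: `log_softmax`, `sequence_log_prob`, and the `completion_mask` argument are hypothetical names, and the logits are assumed to be already aligned with the tokens they predict.

```python
import math


def log_softmax(logits):
    # Numerically stable log-softmax over one vocabulary distribution.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def sequence_log_prob(token_logits, token_ids, completion_mask):
    """Sum log p(token) over the positions where completion_mask is 1."""
    total = 0.0
    for logits, tok, keep in zip(token_logits, token_ids, completion_mask):
        if keep:
            total += log_softmax(logits)[tok]
    return total


# Toy example: 3 positions, vocabulary of size 2; the first position is prompt
# (masked out), so only the last two tokens contribute.
logits = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
chosen_lp = sequence_log_prob(logits, [0, 1, 0], [0, 1, 1])
```

Running the same computation on the chosen and rejected sequences yields the two values described in the return tuple.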

config: EasyDeLBaseConfig#
configure_dataloaders()[source]#

Returns the training dataloader, potentially with precomputed reference log probabilities.

If precompute_ref_log_probs is enabled, this method computes the reference model’s log probabilities for the chosen and rejected responses in the training dataset and adds them as columns to the dataset.

Returns

The training dataloader.

Return type

tensorflow.data.Dataset
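The precomputation step described above can be pictured as a single pass over the dataset that attaches two extra columns per row. The following is a simplified sketch on a list of dicts: the helper names and the column names `ref_chosen_logp` and `ref_rejected_logp` are hypothetical, and the real method may name and store these values differently.

```python
def add_reference_log_probs(rows, reference_fn, batch_size=2):
    """Return new rows with reference log probs attached as extra columns."""
    out = []
    for start in range(0, len(rows), batch_size):
        batch = rows[start:start + batch_size]
        # reference_fn plays the role of a reference-model forward pass.
        chosen, rejected = reference_fn(batch)
        for row, c, r in zip(batch, chosen, rejected):
            out.append({**row, "ref_chosen_logp": c, "ref_rejected_logp": r})
    return out


def toy_reference_fn(batch):
    # Stand-in scorer: negative string length instead of a real model.
    chosen = [-float(len(r["chosen"])) for r in batch]
    rejected = [-float(len(r["rejected"])) for r in batch]
    return chosen, rejected


rows = [
    {"chosen": "yes", "rejected": "no"},
    {"chosen": "a", "rejected": "bcd"},
    {"chosen": "ok", "rejected": "nah"},
]
augmented = add_reference_log_probs(rows, toy_reference_fn)
```

Once the columns exist, the training loop no longer needs the reference model in memory.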

configure_functions() TrainerConfigureFunctionOutput[source]#

Configures and JIT-compiles the training and evaluation step functions.

This method sets up the necessary functions for training and evaluation, including:
  • Initialization of the model state.

  • Sharding of the model parameters and optimizer state.

  • JIT-compilation of the training and evaluation step functions.

Returns

An object containing the configured functions and other relevant information.

Return type

TrainerConfigureFunctionOutput

create_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start'] = 'keep_end') Callable[source]#

Creates a data collection function for batching.

For DPO training, this method simply returns the pre-configured data_collator.

Parameters
  • max_sequence_length (int) – The maximum sequence length (not used in this implementation).

  • truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – The truncation mode (not used in this implementation). Defaults to “keep_end”.

Returns

The data collator function.

Return type

tp.Callable
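The documented behavior, where both arguments are accepted but the pre-configured collator is returned unchanged, can be illustrated with a small stand-in class. `CollatorHolder` and `toy_collator` are hypothetical names used only for this sketch; they are not part of the library.

```python
from typing import Callable, Literal


class CollatorHolder:
    """Hypothetical stand-in for a trainer that owns a data_collator."""

    def __init__(self, data_collator: Callable):
        self.data_collator = data_collator

    def create_collect_function(
        self,
        max_sequence_length: int,
        truncation_mode: Literal["keep_end", "keep_start"] = "keep_end",
    ) -> Callable:
        # Both arguments are accepted for interface compatibility and ignored.
        return self.data_collator


def toy_collator(rows):
    # Turn a list of row dicts into a dict of column lists.
    return {key: [row[key] for row in rows] for key in rows[0]}


holder = CollatorHolder(toy_collator)
collate = holder.create_collect_function(max_sequence_length=512)
batch = collate([{"x": 1}, {"x": 2}])
```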

data_collator: tp.Optional[tp.Callable]#
dataloader_eval: tp.Optional[tp.Iterator[np.ndarray]]#
dataloader_train: tp.Iterator[np.ndarray]#
dataset_eval: tp.Optional[Dataset]#
dataset_train: tp.Optional[Dataset]#
dtype: tp.Any#
evalu_tracker: CompilationTracker#
finetune: bool#
max_evaluation_steps: int#
max_training_steps: int#
memory_monitor: tp.Any#
model_state: EasyDeLState#
on_step_end(state: EasyDeLState, metrics: Any, step: int) Tuple[EasyDeLState, Any][source]#

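Hook called at the end of each step. It receives the current state, the step metrics, and the step index, and returns a (state, metrics) tuple.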
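A hook with this signature takes the state, metrics, and step index and hands back a state and metrics pair. The standalone function below is only a sketch of that shape: `on_step_end_sketch` is a hypothetical name, the state is treated as an opaque object, and the metrics are assumed to be a plain dict.

```python
from typing import Any, Dict, Tuple


def on_step_end_sketch(
    state: Any, metrics: Dict[str, float], step: int
) -> Tuple[Any, Dict[str, float]]:
    """Return the state unchanged and a copy of the metrics tagged with the step."""
    updated = dict(metrics)  # copy so the caller's dict is not mutated
    updated["step"] = step
    return state, updated


state = object()  # opaque placeholder for a real model state
metrics = {"loss": 0.5}
new_state, new_metrics = on_step_end_sketch(state, metrics, step=10)
```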

param_dtype: tp.Any#
static process_row(features, processing_class, max_prompt_length, max_completion_length, add_special_tokens)[source]#
pruning_module: tp.Any#
scheduler: optax.Schedule#
sharded_evaluation_step_function: JitWrapped#
sharded_training_step_function: JitWrapped#
state: tp.Any#
state_named_sharding: tp.Any#
state_partition_spec: tp.Any#
state_shape: tp.Any#
timer: Timers#
static tokenize_row(features, processing_class, max_prompt_length, max_completion_length, add_special_tokens)[source]#

Tokenize a row of the dataset.

Parameters
  • features (dict[str, str]) – Row of the dataset, should contain the keys “prompt”, “chosen”, and “rejected”.

  • processing_class (PreTrainedTokenizerBase) – Processing class used to process the data.

  • max_prompt_length (int or None) – Maximum length of the prompt sequence. If None, the prompt sequence is not truncated.

  • max_completion_length (int or None) – Maximum length of the completion sequences. If None, the completion sequences are not truncated.

  • add_special_tokens (bool) – Whether to add special tokens to the sequences. Typically used for encoder-decoder models. If True, the prompt sequence will have a bos token prepended and an eos token appended. In any case, the completion sequences will have an eos token appended.

Returns

Tokenized sequences with the keys “prompt_input_ids”, “chosen_input_ids”, and “rejected_input_ids”.

Return type

dict[str, list[int]]
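The parameter descriptions above can be made concrete with a toy tokenizer. This is a sketch under stated assumptions, not the library's code: `ToyTokenizer` and `tokenize_row_sketch` are hypothetical, the truncation direction (keep the end of the prompt, keep the start of the completions) is assumed rather than documented here, and length limits are assumed to be positive when given.

```python
class ToyTokenizer:
    """Hypothetical whitespace tokenizer standing in for a real processing class."""

    bos_token_id = 1
    eos_token_id = 2

    def __init__(self):
        self.vocab = {}

    def encode(self, text):
        # Assign ids starting at 3 in order of first appearance.
        return [self.vocab.setdefault(w, len(self.vocab) + 3) for w in text.split()]


def tokenize_row_sketch(features, tok, max_prompt_length, max_completion_length,
                        add_special_tokens):
    prompt = tok.encode(features["prompt"])
    chosen = tok.encode(features["chosen"])
    rejected = tok.encode(features["rejected"])
    if add_special_tokens:
        prompt = [tok.bos_token_id] + prompt + [tok.eos_token_id]
    # Completions always get an eos token appended.
    chosen = chosen + [tok.eos_token_id]
    rejected = rejected + [tok.eos_token_id]
    if max_prompt_length is not None:
        prompt = prompt[-max_prompt_length:]  # assumption: keep the end of the prompt
    if max_completion_length is not None:
        chosen = chosen[:max_completion_length]  # assumption: keep the start
        rejected = rejected[:max_completion_length]
    return {
        "prompt_input_ids": prompt,
        "chosen_input_ids": chosen,
        "rejected_input_ids": rejected,
    }


row = {"prompt": "what is two plus two", "chosen": "four", "rejected": "five maybe"}
out = tokenize_row_sketch(row, ToyTokenizer(), max_prompt_length=3,
                          max_completion_length=None, add_special_tokens=False)
```

With `max_prompt_length=3` the five-token prompt is cut to its last three tokens, while the completions are left untruncated because `max_completion_length` is None.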

train_tracker: CompilationTracker#
tx: optax.GradientTransformation#
wandb_runtime: tp.Any#