# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import typing as tp
from collections import defaultdict
from functools import partial
import jax
from eformer.loggings import get_logger
from jax import numpy as jnp
from jax.sharding import PartitionSpec
from tqdm.autonotebook import tqdm
from easydel.infra.base_module import EasyDeLBaseModule
from easydel.infra.base_state import EasyDeLState
from easydel.infra.utils import ProcessingClassType
from easydel.utils import Registry
from easydel.utils.compiling_utils import ejit
from easydel.utils.traversals import deepcopy_model
from ..base_trainer import TrainerConfigureFunctionOutput
from ..prompt_transforms import DPOPreprocessTransform
from ..trainer.trainer import Trainer
from ..training_configurations import MetricsType
from ..utils import DataCollatorForPreferenceGrain, DataCollatorForPreferenceTFDS
from ._fn import concatenated_forward, evaluation_step, training_step
from .dpo_config import DPOConfig
if tp.TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from easydel.data.core.protocols import ShardedDataSource
# Module-level logger, named after this module via eformer's `get_logger`.
logger = get_logger(__name__)
@Registry.register("trainer", "dpo")
class DPOTrainer(Trainer):
    """Trainer for Direct Preference Optimization (DPO).

    This trainer implements the Direct Preference Optimization algorithm for training
    language models from human preferences without requiring a separate reward model.
    DPO directly optimizes the policy to match human preferences by maximizing the
    likelihood of preferred completions relative to rejected ones.

    Untokenized datasets are tokenized by a lazy ``DPOPreprocessTransform`` that is
    applied during iteration, rather than by an eager HF ``.map()`` pass.

    Attributes:
        arguments (DPOConfig): Configuration object containing all training parameters.
        processing_class: Tokenizer or processor for data preprocessing.
        reference_state (EasyDeLState): Reference model state used to compute the
            reference log-probabilities (the implicit KL anchor of DPO). When no
            reference model is passed, this is a deep copy of the policy model.
        padding_value (int): Token ID used for padding sequences.
        truncation_mode: Copied from ``arguments.truncation_mode``.
        is_encoder_decoder (bool): Copied from ``arguments.is_encoder_decoder``.

    Example:
        >>> config = DPOConfig(
        ...     beta=0.1,
        ...     loss_type="sigmoid",
        ...     max_length=512,
        ...     learning_rate=5e-6
        ... )
        >>> trainer = DPOTrainer(
        ...     arguments=config,
        ...     model=model,
        ...     reference_model=reference_model,
        ...     processing_class=tokenizer,
        ...     train_dataset=preference_dataset
        ... )
        >>> trainer.train()

    Note:
        The trainer expects datasets with 'prompt', 'chosen', and 'rejected' columns.
        Unless the training data already contains ``prompt_input_ids``, these are
        tokenized via lazy transforms during iteration.
    """

    arguments: DPOConfig

    def __init__(
        self,
        arguments: DPOConfig,
        model: EasyDeLBaseModule | EasyDeLState,
        reference_model: EasyDeLBaseModule | EasyDeLState | None = None,
        processing_class: ProcessingClassType | None = None,
        train_dataset: Dataset | IterableDataset | ShardedDataSource | None = None,
        eval_dataset: Dataset | IterableDataset | ShardedDataSource | None = None,
        data_collator: tp.Callable | None = None,
    ):
        """Validate inputs, build collators and states, then initialize the base Trainer.

        Args:
            arguments: DPO training configuration. Its ``padding_value`` field is
                overwritten with the padding id that is actually resolved.
            model: Policy model, or an already-built ``EasyDeLState``. Modules are
                converted with ``to_state()``.
            reference_model: Reference model or state. Defaults to a deep copy of the
                policy state.
            processing_class: Tokenizer/processor. Required.
            train_dataset: Training data (HF dataset or ``ShardedDataSource``).
            eval_dataset: Optional evaluation data.
            data_collator: Optional collator. When given, it replaces BOTH the TFDS and
                the Grain preference collators.

        Raises:
            ValueError: If ``arguments`` or ``processing_class`` is None, or if no
                padding id can be determined.
            TypeError: If ``arguments`` is not a ``DPOConfig``.
        """
        if arguments is None:
            raise ValueError("arguments cannot be None")
        if not isinstance(arguments, DPOConfig):
            raise TypeError(f"arguments must be DPOConfig, got {type(arguments)}")
        if processing_class is None:
            raise ValueError("processing_class must be specified to tokenize a DPO dataset.")

        self.arguments = arguments
        self.truncation_mode = arguments.truncation_mode
        self.processing_class = processing_class
        self.is_encoder_decoder = arguments.is_encoder_decoder
        self._precomputed_train_ref_log_probs = False
        self._precomputed_eval_ref_log_probs = False

        self.padding_value = self._resolve_padding_value(arguments, processing_class)
        # Write the resolved id back so the config reflects what is actually used.
        arguments.padding_value = self.padding_value

        # Setup data collators: a user-supplied collator overrides both backends.
        if data_collator is None:
            collator_kwargs = dict(
                max_prompt_length=arguments.max_prompt_length,
                max_completion_length=arguments.max_completion_length,
                pad_token_id=self.padding_value,
                label_pad_token_id=arguments.label_pad_token_id,
                is_encoder_decoder=arguments.is_encoder_decoder,
            )
            self.input_data_collator_tfds = DataCollatorForPreferenceTFDS(**collator_kwargs)
            self.input_data_collator_grain = DataCollatorForPreferenceGrain(**collator_kwargs)
        else:
            self.input_data_collator_tfds = data_collator
            self.input_data_collator_grain = data_collator

        self._stored_metrics = defaultdict(lambda: defaultdict(list))

        # Setup models: both policy and reference are handled as EasyDeLState.
        if not isinstance(model, EasyDeLState):
            model = model.to_state()
        if reference_model is None:
            reference_model = deepcopy_model(model)
        if not isinstance(reference_model, EasyDeLState):
            reference_model = reference_model.to_state()
        # Set before super().__init__, since the base class may read the extra-args
        # properties below during its own setup (assumed; not visible here).
        self.reference_state = reference_model

        # The base Trainer gets no collator; collation is provided through the
        # create_*_collect_function overrides below.
        super().__init__(
            model_state=model,
            arguments=arguments,
            dataset_train=train_dataset,
            dataset_eval=eval_dataset,
            data_collator=None,
            processing_class=processing_class,
        )

    @staticmethod
    def _resolve_padding_value(arguments: DPOConfig, processing_class: tp.Any) -> int:
        """Return the padding id from the config, the processor, or its inner tokenizer.

        Lookup order: ``arguments.padding_value``, then ``processing_class.pad_token_id``,
        then ``processing_class.tokenizer.pad_token_id``. Missing attributes are treated
        as None, so a processor without a usable tokenizer yields the ValueError below
        rather than an AttributeError.

        Raises:
            ValueError: If none of the sources provides a non-None padding id.
        """
        if arguments.padding_value is not None:
            return arguments.padding_value
        pad_token_id = getattr(processing_class, "pad_token_id", None)
        if pad_token_id is None:
            inner_tokenizer = getattr(processing_class, "tokenizer", None)
            pad_token_id = getattr(inner_tokenizer, "pad_token_id", None)
        if pad_token_id is None:
            raise ValueError(
                "`padding_value` is not specified in `DPOConfig`, and `pad_token_id` is missing in the "
                "`processing_class`. Please either set the `padding_value` argument in `DPOConfig`, or set "
                "`tokenizer.pad_token` (e.g., `tokenizer.pad_token = tokenizer.eos_token`) before instantiating "
                "the trainer."
            )
        return pad_token_id

    def _get_preprocess_transform(self) -> DPOPreprocessTransform | None:
        """Get DPO preprocessing transform for ShardedDataSource.

        Returns a transform that handles:
            - Prompt extraction from chosen/rejected
            - Chat template application
            - Triple tokenization (prompt, chosen, rejected)

        Returns:
            DPOPreprocessTransform, or None if the training data is already tokenized
            (see ``_is_pretokenized``).
        """
        if self._is_pretokenized():
            return None
        return DPOPreprocessTransform(
            tokenizer=self.processing_class,
            max_prompt_length=self.arguments.max_prompt_length,
            max_completion_length=self.arguments.max_completion_length,
            tools=getattr(self.arguments, "tools", None),
            label_pad_token_id=self.arguments.label_pad_token_id,
        )

    def _is_pretokenized(self) -> bool:
        """Return True if the training data already carries DPO token ids.

        Peeks at the first record of the first training shard and checks for a
        ``prompt_input_ids`` field. Returns False when there is no training source,
        the source has no shards, or the first shard is empty.
        """
        # ``_train_source`` is presumably populated by the base Trainer; treat a
        # missing attribute the same as "no source".
        source = getattr(self, "_train_source", None)
        if source is None:
            return False
        shard_iter = None
        try:
            shard_iter = iter(source.open_shard(source.shard_names[0]))
            sample = next(shard_iter)
            return "prompt_input_ids" in sample
        except (StopIteration, IndexError):
            return False
        finally:
            # Only one record is needed. Close the shard iterator when it supports it
            # (e.g. a generator), so any reader it holds is released now rather than
            # whenever it is garbage-collected.
            close = getattr(shard_iter, "close", None)
            if callable(close):
                close()

    def create_grain_collect_function(
        self,
        max_sequence_length: int,
        truncation_mode: tp.Literal["keep_end", "keep_start"] = "keep_end",
    ) -> tp.Callable:
        """Return the collator used for Grain batching.

        ``max_sequence_length`` and ``truncation_mode`` are accepted for interface
        compatibility but ignored. The returned collator is the one built in
        ``__init__``, which is configured with ``max_prompt_length`` and
        ``max_completion_length`` (or the user-supplied ``data_collator``).
        """
        return self.input_data_collator_grain

    def create_tfds_collect_function(
        self,
        max_sequence_length: int,
        truncation_mode: tp.Literal["keep_end", "keep_start"] = "keep_end",
    ) -> tp.Callable:
        """Return the collator used for TFDS batching.

        ``max_sequence_length`` and ``truncation_mode`` are accepted for interface
        compatibility but ignored. The returned collator is the one built in
        ``__init__``, which is configured with ``max_prompt_length`` and
        ``max_completion_length`` (or the user-supplied ``data_collator``).
        """
        return self.input_data_collator_tfds

    def compute_reference_log_probs(
        self,
        state: EasyDeLState,
        padded_batch: dict,
    ) -> tuple[tp.Any, tp.Any]:
        """Compute the reference model's log-probabilities for a padded batch.

        Uses ``self.reference_state.model``. When no reference state is set, it falls
        back to the policy model ``state.model``.

        Returns:
            ``(chosen_logps, rejected_logps)`` taken from the forward outputs.
        """
        # ``self.concatenated_forward`` is assumed to be configured elsewhere in the
        # trainer (not visible in this section) — TODO confirm.
        forward_model = state.model if self.reference_state is None else self.reference_state.model
        outs = self.concatenated_forward(forward_model, batch=padded_batch)
        return outs["chosen_logps"], outs["rejected_logps"]

    @property
    def _train_shared_fn_extra_args(self) -> tuple[tp.Any]:
        """Extra arguments for the shared training step: the reference state.

        Presumably appended by the base Trainer to each training-step call — confirm
        against ``Trainer``.
        """
        return (self.reference_state,)

    @property
    def _eval_shared_fn_extra_args(self) -> tuple[tp.Any]:
        """Extra arguments for the shared evaluation step: the reference state.

        Presumably appended by the base Trainer to each evaluation-step call — confirm
        against ``Trainer``.
        """
        return (self.reference_state,)

    def on_step_end(
        self,
        state: EasyDeLState,
        metrics: MetricsType,
        step: int,
    ) -> tuple[EasyDeLState, MetricsType]:
        """Periodically hard-sync the reference model to the current policy.

        When ``arguments.sync_ref_model`` is enabled and a reference state exists, the
        reference state's ``graphstate`` is replaced by a deep copy of the policy's
        on every step where ``step % arguments.ref_model_sync_steps == 0``.
        ``state`` and ``metrics`` are returned unchanged.

        Raises:
            ValueError: If syncing is enabled but ``ref_model_sync_steps`` is None or
                not positive. Without this check, the modulo would fail mid-training
                with ZeroDivisionError/TypeError.
        """
        if not self.arguments.sync_ref_model or self.reference_state is None:
            return state, metrics
        sync_every = self.arguments.ref_model_sync_steps
        if sync_every is None or sync_every <= 0:
            raise ValueError(
                "`ref_model_sync_steps` must be a positive integer when `sync_ref_model=True`, "
                f"got {sync_every!r}."
            )
        # NOTE(review): if steps are counted from 0, this also syncs on the very first step.
        if step % sync_every == 0:
            self.reference_state = self.reference_state.replace(graphstate=deepcopy_model(state.graphstate))
        return state, metrics