Direct Preference Optimization (DPO) Trainer
Direct Preference Optimization (DPO) is a technique for fine-tuning language models using human preferences without explicit reward modeling. This tutorial shows how to use EasyDeL’s DPOTrainer.
Overview
DPOTrainer fine-tunes a language model (the policy) on pairs of chosen and rejected responses, optimizing it to prefer the chosen response more strongly than a reference model does. This avoids the separate reward model and reinforcement-learning loop (such as PPO) typically used in RLHF.
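The core objective fits in a few lines. The following minimal pure-Python sketch computes the standard (sigmoid) DPO loss for a single preference pair. It assumes the summed response log-probabilities under the policy and the reference model have already been computed. The function name and arguments are illustrative and are not part of EasyDeL's API.

```python
import math


def dpo_sigmoid_loss(policy_chosen_logp, policy_rejected_logp,
                     ref_chosen_logp, ref_rejected_logp, beta=0.1):
    """Standard (sigmoid) DPO loss for one preference pair.

    Each argument is the summed log-probability of a full response given
    the prompt, under either the policy or the reference model.
    """
    # How much more the policy prefers "chosen" over "rejected"
    # than the reference model does.
    logits = (policy_chosen_logp - policy_rejected_logp) - (
        ref_chosen_logp - ref_rejected_logp
    )
    z = beta * logits
    # -log(sigmoid(z)), written in a numerically stable form.
    if z >= 0:
        return math.log1p(math.exp(-z))
    return -z + math.log1p(math.exp(z))
```

When the policy and reference assign identical log-probabilities, the margin is zero and the loss equals log 2 (about 0.693). The loss falls as the policy raises the chosen response's likelihood, relative to the rejected one, by more than the reference does.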
Configuration
The DPOTrainer is configured with the DPOConfig class, whose key parameters include the following:
from easydel.trainers import DPOConfig

dpo_config = DPOConfig(
    # Model and training basics
    model_name="DPOTrainer",         # Name of the model
    learning_rate=1e-6,              # Learning rate for optimization
    # Beta parameter controls how strongly to optimize preferences
    beta=0.1,                        # Temperature parameter for deviation from reference model
    # Loss function options
    loss_type="sigmoid",             # Loss type (sigmoid, hinge, ipo, etc.)
    label_smoothing=0.0,             # Smoothing factor for labels
    # Sequence length parameters
    max_length=512,                  # Maximum total sequence length
    max_prompt_length=256,           # Maximum length for prompts
    max_completion_length=256,       # Maximum length for completions
    # Reference model control
    reference_free=False,            # Whether to use reference-free variant
    sync_ref_model=False,            # Periodically sync reference model
    ref_model_sync_steps=64,         # Steps between reference model syncs
    # Training optimization
    disable_dropout=True,            # Disable dropout during training
    total_batch_size=16,             # Total batch size
    gradient_accumulation_steps=1,   # Steps for gradient accumulation
)
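The three length limits work together: a prompt and its completion must jointly fit in max_length. The following sketch uses a hypothetical helper (truncate_pair is not an EasyDeL function) to show one plausible way of applying these limits to token ID lists. It assumes that prompts keep their last tokens and completions keep their first tokens. EasyDeL's actual truncation logic may differ.

```python
def truncate_pair(prompt_ids, completion_ids, max_length=512,
                  max_prompt_length=256, max_completion_length=256):
    """Hypothetical helper: apply the three length limits to token ID lists.

    Assumption (not EasyDeL's documented behavior): prompts keep their
    last tokens, completions keep their first tokens, and any remaining
    overflow beyond max_length is trimmed from the completion.
    """
    # Keep the end of the prompt so the context nearest the answer survives.
    prompt = prompt_ids[max(0, len(prompt_ids) - max_prompt_length):]
    # Keep the beginning of the completion.
    completion = completion_ids[:max_completion_length]
    # If the pair still exceeds max_length, shorten the completion further.
    overflow = len(prompt) + len(completion) - max_length
    if overflow > 0:
        completion = completion[:max(0, len(completion) - overflow)]
    return prompt, completion
```

In this sketch, the values above (256 + 256 = 512) never trigger the max_length check. That check only takes effect when max_prompt_length + max_completion_length exceeds max_length.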
Basic Usage
Here’s a simple example showing how to initialize and use the DPOTrainer:
import easydel as ed
from transformers import AutoTokenizer
from datasets import load_dataset
import jax
# Load tokenizer and prepare dataset
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
# Load dataset with preference pairs
dataset = load_dataset(
    "trl-lib/ultrafeedback_binarized",
    split="train[:10%]",  # Using a small subset for demonstration
)
# Load model and reference model
model = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    dtype=jax.numpy.bfloat16,
    # ... model loading configs
)
ref_model = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    dtype=jax.numpy.bfloat16,
    # ... model loading configs
)
# Create DPO config
config = ed.DPOConfig(
    model_name="dpo_example",
    save_directory="dpo_checkpoints",
    beta=0.1,
    loss_type="sigmoid",
    max_length=2048,
    max_prompt_length=1024,
    total_batch_size=16,
    learning_rate=1e-6,
    num_train_epochs=3,
    use_wandb=True,
)
# Initialize the trainer
trainer = ed.DPOTrainer(
    model=model,
    reference_model=ref_model,
    arguments=config,
    train_dataset=dataset,
    processing_class=tokenizer,
)
# Start training
trainer.train()
Command Line Training
You can also run DPO training directly from the command line:
python -m easydel.scripts.finetune.dpo \
--repo_id meta-llama/Llama-3.1-8B-Instruct \
--dataset_name trl-lib/ultrafeedback_binarized \
--dataset_split "train[:90%]" \
--refrence_model_repo_id meta-llama/Llama-3.1-8B-Instruct \
--attn_mechanism vanilla \
--beta 0.08 \
--loss_type sigmoid \
--max_length 2048 \
--max_prompt_length 1024 \
--total_batch_size 16 \
--learning_rate 1e-6 \
--log_steps 50 \
--num_train_epochs 3 \
--do_last_save \
--save_steps 1000 \
--use_wandb
Dataset Format
DPOTrainer expects each dataset example to contain prompt, chosen, and rejected fields, as in this example:
{
    "prompt": "What is the capital of France?",
    "chosen": "The capital of France is Paris.",
    "rejected": "The capital of France is London."
}
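Before training, it can help to sanity-check records against this format. The helper below is hypothetical (it is not part of EasyDeL) and only handles the plain-string layout shown above. Conversational datasets that store lists of messages would need a different check.

```python
def validate_preference_example(example):
    """Hypothetical check for the plain-string prompt/chosen/rejected format."""
    required = ("prompt", "chosen", "rejected")
    missing = [key for key in required if key not in example]
    if missing:
        raise ValueError(f"missing fields: {missing}")
    for key in required:
        # Every field should be a non-empty string in this layout.
        if not isinstance(example[key], str) or not example[key].strip():
            raise ValueError(f"field {key!r} must be a non-empty string")
    # Identical responses carry no preference signal.
    if example["chosen"] == example["rejected"]:
        raise ValueError("chosen and rejected responses are identical")
    return True
```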
Advanced Usage
Custom Loss Functions
DPOTrainer supports several loss functions, selected with loss_type:
sigmoid: Standard DPO loss (default)
hinge: Hinge loss on the reference-normalized likelihood margin, as in SLiC
ipo: Identity Preference Optimization (IPO) loss
Several others, such as exo_pair, nca_pair, robust, and aot
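The following sketch writes the three named losses as functions of the DPO logits, to show how the options differ. Here, the DPO logits are the policy's chosen-minus-rejected log-probability margin minus the reference model's margin. The formulas follow widely used TRL-style definitions. EasyDeL's implementation may differ in scaling or other details, so treat this as an illustration rather than a copy of its code.

```python
import math


def log_sigmoid(x):
    # Numerically stable log(sigmoid(x)).
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def preference_loss(logits, beta, loss_type="sigmoid", label_smoothing=0.0):
    """Per-pair loss from the DPO logits (policy margin minus reference margin)."""
    if loss_type == "sigmoid":
        # Standard DPO; label_smoothing mixes in the flipped-label term.
        return (-log_sigmoid(beta * logits) * (1 - label_smoothing)
                - log_sigmoid(-beta * logits) * label_smoothing)
    if loss_type == "hinge":
        # Zero loss once beta * logits reaches a margin of 1.
        return max(0.0, 1.0 - beta * logits)
    if loss_type == "ipo":
        # Squared error pulling the logits toward the target 1 / (2 * beta).
        return (logits - 1.0 / (2.0 * beta)) ** 2
    raise ValueError(f"loss_type not covered by this sketch: {loss_type}")
```

For IPO, beta sets the target margin 1 / (2 * beta) rather than scaling a sigmoid. This is one reason the same numeric beta does not transfer across loss types.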
Example with IPO loss:
config = ed.DPOConfig(
    loss_type="ipo",
    beta=0.2,  # May need different beta for different loss types
    # Other parameters...
)
Reference Model Syncing
Instead of keeping the reference model frozen for the whole run, you can periodically synchronize it with the current policy:
config = ed.DPOConfig(
    sync_ref_model=True,
    ref_model_sync_steps=128,   # Sync every 128 steps
    ref_model_mixup_alpha=0.9,  # Mixing parameter
    # Other parameters...
)
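Each sync can be pictured as a soft blend of weights, applied every ref_model_sync_steps steps. The sketch below assumes the TR-DPO-style convention ref = alpha * policy + (1 - alpha) * ref, in which ref_model_mixup_alpha is the weight on the current policy. That convention is an assumption here, so check EasyDeL's source before relying on it. The helper names are hypothetical, and parameters are plain dicts of floats for illustration.

```python
def sync_reference(ref_params, policy_params, alpha):
    """Blend the policy into the reference: ref <- alpha * policy + (1 - alpha) * ref.

    Assumed convention: alpha is the weight on the current policy.
    """
    return {
        name: alpha * policy_params[name] + (1.0 - alpha) * ref_params[name]
        for name in ref_params
    }


def maybe_sync(step, ref_params, policy_params, sync_steps=128, alpha=0.9):
    """Hypothetical schedule: sync only on multiples of sync_steps."""
    if step > 0 and step % sync_steps == 0:
        return sync_reference(ref_params, policy_params, alpha)
    return ref_params
```

With JAX parameter pytrees, the same blend would typically be expressed with jax.tree_util.tree_map over the two parameter trees.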
Reference-Free Training
You can train without a separate reference model. In this mode, the loss depends only on the policy's own chosen-versus-rejected margin:
config = ed.DPOConfig(
    reference_free=True,
    # Other parameters...
)
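The difference between the two modes shows up in how the DPO logits are computed. In this minimal sketch, dpo_logits is an illustrative name, not an EasyDeL function. Reference-free mode treats the reference log-probabilities as zero, which is equivalent to an implicit reference that assigns equal probability to every response. This mirrors the TRL-style definition and is assumed, not confirmed, to match EasyDeL.

```python
def dpo_logits(policy_chosen_logp, policy_rejected_logp,
               ref_chosen_logp=None, ref_rejected_logp=None,
               reference_free=False):
    """DPO logits with or without a reference model."""
    policy_margin = policy_chosen_logp - policy_rejected_logp
    if reference_free:
        # Reference margin treated as zero: only the policy's margin counts.
        return policy_margin
    if ref_chosen_logp is None or ref_rejected_logp is None:
        raise ValueError("reference log-probabilities required unless reference_free=True")
    return policy_margin - (ref_chosen_logp - ref_rejected_logp)
```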
Tips for Effective DPO Training
Beta selection: Start with values between 0.05 and 0.2, then tune based on results
Learning rate: Use smaller learning rates than for standard SFT (for example, 1e-6 to 5e-6)
Batch size: Larger batch sizes often work better for preference learning
Validation: Monitor preference accuracy (the fraction of pairs in which the chosen response receives a higher implicit reward than the rejected one) on a held-out validation set
Dataset quality: The quality of preference pairs greatly impacts results
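Preference accuracy can be computed from the same log-probabilities the loss uses. The sketch below uses a hypothetical helper, preference_accuracy. It counts the fraction of pairs for which the implicit reward, beta times the policy-minus-reference log-probability, is higher for the chosen response than for the rejected one.

```python
def preference_accuracy(batch, beta=0.1):
    """Fraction of pairs whose implicit reward favors the chosen response.

    Each item is (policy_chosen, policy_rejected, ref_chosen, ref_rejected),
    all summed response log-probabilities.
    """
    if not batch:
        raise ValueError("empty batch")
    correct = 0
    for pol_c, pol_r, ref_c, ref_r in batch:
        chosen_reward = beta * (pol_c - ref_c)
        rejected_reward = beta * (pol_r - ref_r)
        # A boolean adds as 0 or 1.
        correct += chosen_reward > rejected_reward
    return correct / len(batch)
```

An accuracy near 0.5 suggests the model is not yet separating the pairs. Accuracy does not show how large the margins are, so it is worth tracking alongside the loss.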