Trainer Integration#
EasyData integrates seamlessly with all EasyDeL trainers. This guide covers best practices for using data with SFT, DPO, GRPO, KTO, and other trainers.
Quick Reference#
| Trainer | Dataset Format | Preprocessing |
|---|---|---|
| SFT | Text or chat messages | Chat template applied |
| DPO | Chosen/rejected pairs | Prompt extraction + tokenization |
| ORPO | Chosen/rejected pairs | Same as DPO |
| KTO | Unpaired or paired preference | Unpairs + tokenizes |
| GRPO | Prompts (with reward function) | Prompt extraction |
| Reward | Chosen/rejected pairs | Tokenization |
| CPO | Chosen/rejected pairs | Same as DPO |
| BCO | Prompt/completion/label | Tokenization |
SFT Training#
Basic SFT#
from datasets import load_dataset
import easydel as ed
# Load instruction dataset
dataset = load_dataset("tatsu-lab/alpaca", split="train")
trainer = ed.SFTTrainer(
model=model,
train_dataset=dataset,
processing_class=tokenizer,
arguments=ed.SFTConfig(
max_sequence_length=2048,
# Chat template applied automatically
),
)
trainer.train()
SFT with Mixed Datasets#
from easydel.data import block_mixture_interleave
# Mix instruction datasets
alpaca = load_dataset("tatsu-lab/alpaca", split="train")
dolly = load_dataset("databricks/databricks-dolly-15k", split="train")
# Dict format for explicit mapping (recommended)
mixed = block_mixture_interleave(
datasets={"alpaca": alpaca, "dolly": dolly},
weights={"alpaca": 0.6, "dolly": 0.4},
block_size=1000,
seed=42,
stop="restart",
)
# Equal weights
mixed = block_mixture_interleave(
datasets={"alpaca": alpaca, "dolly": dolly},
weights=None, # Equal 50/50 mixing
)
trainer = ed.SFTTrainer(
model=model,
train_dataset=mixed,
processing_class=tokenizer,
arguments=ed.SFTConfig(...),
)
SFT with Pre-tokenized Data#
from easydel.data import ParquetShardedSource
# Load pre-tokenized data
source = ParquetShardedSource("./tokenized_sft/*.parquet")
trainer = ed.SFTTrainer(
model=model,
train_dataset=source, # Works directly
processing_class=tokenizer,
arguments=ed.SFTConfig(
max_sequence_length=2048,
),
)
DPO Training#
Basic DPO#
from datasets import load_dataset
import easydel as ed
# Load preference dataset (has chosen/rejected)
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
trainer = ed.DPOTrainer(
model=policy_model,
reference_model=ref_model, # Optional, deep-copied if None
train_dataset=dataset, # Raw preference data
processing_class=tokenizer,
arguments=ed.DPOConfig(
max_prompt_length=512,
max_completion_length=512,
beta=0.1,
),
)
trainer.train()
DPO Internal Preprocessing#
DPO trainer automatically:
Extracts shared prompt from chosen/rejected
Applies chat template if conversational
Tokenizes with proper truncation
Handles padding and masking
# Input format (conversational)
{
"chosen": [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
],
"rejected": [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Go away."},
],
}
# Or simpler format
{
"prompt": "Hello",
"chosen": "Hi there!",
"rejected": "Go away.",
}
GRPO Training#
Basic GRPO#
from datasets import load_dataset
import easydel as ed
# Load any preference dataset (prompts extracted automatically)
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
def reward_fn(prompts, completions, **kwargs):
# Your reward logic
return [score_completion(c) for c in completions]
trainer = ed.GRPOTrainer(
model=model,
reward_funcs=reward_fn,
train_dataset=dataset,
processing_class=tokenizer,
arguments=ed.GRPOConfig(
num_return_sequences=4,
max_prompt_length=512,
max_completion_length=512,
),
)
trainer.train()
GRPO Internal Preprocessing#
GRPO trainer automatically:
Extracts prompts from chosen/rejected conversations
Applies chat template with add_generation_prompt=True
Left-pads for efficient batch generation
GRPO with Custom Prompts#
# Direct prompt format also works
dataset = Dataset.from_dict({
"prompt": [
"Write a poem about AI",
"Explain quantum computing",
]
})
# Or conversational
dataset = Dataset.from_dict({
"prompt": [
[{"role": "user", "content": "Write a poem about AI"}],
[{"role": "user", "content": "Explain quantum computing"}],
]
})
KTO Training#
Basic KTO#
from datasets import load_dataset
import easydel as ed
# KTO accepts paired preference data (will unpair internally)
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
trainer = ed.KTOTrainer(
model=policy_model,
reference_model=ref_model,
train_dataset=dataset,
processing_class=tokenizer,
arguments=ed.KTOConfig(
max_prompt_length=512,
max_completion_length=512,
beta=0.1,
loss_type="kto",
),
)
trainer.train()
KTO Internal Preprocessing#
KTO trainer automatically:
Extracts shared prompt
Unpairs preference data (1 pair → 2 examples with labels)
Applies chat template
Tokenizes with BCO collator
# Input: paired preference data
{
"chosen": [...],
"rejected": [...],
}
# Internally converted to:
[
{"prompt": "...", "completion": "...", "label": True}, # From chosen
{"prompt": "...", "completion": "...", "label": False}, # From rejected
]
ORPO Training#
from datasets import load_dataset
import easydel as ed
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
trainer = ed.ORPOTrainer(
model=model,
train_dataset=dataset,
processing_class=tokenizer,
arguments=ed.ORPOConfig(
max_prompt_length=512,
max_completion_length=512,
beta=0.1,
),
)
trainer.train()
Reward Training#
from datasets import load_dataset
import easydel as ed
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
# Use sequence classification model
trainer = ed.RewardTrainer(
model=reward_model, # AutoEasyDeLModelForSequenceClassification
train_dataset=dataset,
processing_class=tokenizer,
arguments=ed.RewardConfig(
max_length=2048,
),
)
trainer.train()
Using ShardedDataSource#
All trainers accept ShardedDataSource directly:
from easydel.data import ParquetShardedSource, JsonShardedSource
# From local files
source = ParquetShardedSource("data/*.parquet")
# From cloud
source = ParquetShardedSource(
"gs://bucket/data/*.parquet",
storage_options={"token": "cloud"},
)
# Use with any trainer
trainer = ed.SFTTrainer(
model=model,
train_dataset=source, # Works directly
processing_class=tokenizer,
arguments=ed.SFTConfig(...),
)
Mixed Datasets with Trainers#
from datasets import load_dataset
from easydel.data import block_mixture_interleave
# Load multiple datasets
ds1 = load_dataset("dataset1", split="train")
ds2 = load_dataset("dataset2", split="train")
ds3 = load_dataset("dataset3", split="train")
# Dict format for explicit mapping (recommended)
mixed = block_mixture_interleave(
datasets={"ds1": ds1, "ds2": ds2, "ds3": ds3},
weights={"ds1": 0.5, "ds2": 0.3, "ds3": 0.2},
block_size=1000,
seed=42,
stop="restart",
)
# Works with any trainer
trainer = ed.DPOTrainer(
model=model,
train_dataset=mixed,
processing_class=tokenizer,
arguments=ed.DPOConfig(...),
)
Dynamic Weight Scheduling#
from easydel.data import (
MixedShardedSource,
HFDatasetShardedSource,
WeightScheduler,
WeightSchedulePoint,
)
# Wrap datasets
source1 = HFDatasetShardedSource(dataset1)
source2 = HFDatasetShardedSource(dataset2)
# Define schedule
scheduler = WeightScheduler(
schedule=[
WeightSchedulePoint(step=0, weights={"easy": 0.9, "hard": 0.1}),
WeightSchedulePoint(step=10000, weights={"easy": 0.5, "hard": 0.5}),
WeightSchedulePoint(step=50000, weights={"easy": 0.1, "hard": 0.9}),
],
interpolation="linear",
)
mixed = MixedShardedSource(
sources={"easy": source1, "hard": source2},
weight_scheduler=scheduler,
block_size=1000,
)
trainer = ed.SFTTrainer(
model=model,
train_dataset=mixed,
processing_class=tokenizer,
arguments=ed.SFTConfig(...),
)
Pre-tokenization for Trainers#
SFT Pre-tokenization#
from easydel.data import tokenize_and_save
# Tokenize once
tokenize_and_save(
data_files="conversations/*.jsonl",
tokenizer="meta-llama/Llama-2-7b-chat-hf",
output_path="./sft_tokenized",
max_length=2048,
)
# Use in training
from easydel.data import ParquetShardedSource
source = ParquetShardedSource("./sft_tokenized/*.parquet")
trainer = ed.SFTTrainer(train_dataset=source, ...)
DPO Pre-tokenization#
from easydel.trainers.transforms import DPOPreprocessTransform
from easydel.data import JsonShardedSource, TransformedShardedSource
transform = DPOPreprocessTransform(
tokenizer=tokenizer,
max_prompt_length=512,
max_completion_length=512,
)
source = JsonShardedSource("preference_data/*.jsonl")
tokenized = TransformedShardedSource(source, transform=transform)
# Save for later
from easydel.data import save_iterator
save_iterator(
tokenized.open_shard(tokenized.shard_names[0]),
output_path="./dpo_tokenized",
format="parquet",
)
Trainer-Specific Transforms#
Access trainer preprocessing logic directly:
from easydel.trainers.transforms import (
SFTPreprocessTransform,
DPOPreprocessTransform,
ORPOPreprocessTransform,
KTOPreprocessTransform,
GRPOPreprocessTransform,
RewardPreprocessTransform,
BCOPreprocessTransform,
CPOPreprocessTransform,
)
# SFT transform
sft_transform = SFTPreprocessTransform(
tokenizer=tokenizer,
max_length=2048,
)
# DPO transform
dpo_transform = DPOPreprocessTransform(
tokenizer=tokenizer,
max_prompt_length=512,
max_completion_length=512,
)
# Apply to data
from easydel.data import TransformedShardedSource
source = JsonShardedSource("data/*.jsonl")
transformed = TransformedShardedSource(source, transform=dpo_transform)
Streaming with Trainers#
from datasets import load_dataset
# Stream from HuggingFace
dataset = load_dataset(
"HuggingFaceFW/fineweb",
split="train",
streaming=True,
)
# Works with trainers
trainer = ed.SFTTrainer(
model=model,
train_dataset=dataset,
processing_class=tokenizer,
arguments=ed.SFTConfig(
shuffle_train_dataset=False, # Already streamed
...
),
)
Best Practices#
1. Match Data Format to Trainer#
# SFT: text or messages
{"text": "..."} or {"messages": [...]}
# DPO/ORPO/CPO: chosen/rejected
{"chosen": [...], "rejected": [...]}
# KTO: paired or unpaired
{"chosen": [...], "rejected": [...]} # Paired (unpaired internally)
{"prompt": "...", "completion": "...", "label": True/False} # Already unpaired
# GRPO: prompts
{"prompt": "..."} or {"chosen": [...], "rejected": [...]}
# Reward: chosen/rejected
{"prompt": "...", "chosen": "...", "rejected": "..."}
2. Use Appropriate Sequence Lengths#
# DPO/KTO/ORPO: separate prompt and completion lengths
ed.DPOConfig(
max_prompt_length=512, # For prompt
max_completion_length=512, # For completion
)
# SFT: single length
ed.SFTConfig(
max_sequence_length=2048, # Total length
)
# GRPO: prompt length + generation length
ed.GRPOConfig(
max_prompt_length=512,
max_completion_length=512, # For generation
)
3. Handle Pad Tokens#
# Most trainers handle this automatically
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# But you can specify explicitly
ed.DPOConfig(
padding_value=tokenizer.pad_token_id,
label_pad_token_id=-100, # Ignore in loss
)
4. Monitor Data Processing#
# Check preprocessing
trainer = ed.DPOTrainer(...)
# Before training, verify a batch
batch = next(iter(trainer._training_batch_iterator()))
print(f"Batch keys: {batch.keys()}")
print(f"Input shape: {batch['input_ids'].shape}")
Troubleshooting#
“Dataset must have column X”#
# Check your dataset columns
print(dataset.column_names)
# Map to expected format
dataset = dataset.map(lambda x: {"chosen": x["good"], "rejected": x["bad"]})
Tokenization errors#
# Ensure tokenizer has required tokens
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
Out of memory#
# Use streaming
dataset = load_dataset(..., streaming=True)
# Or reduce batch size
ed.SFTConfig(total_batch_size=4)
# Or use gradient accumulation
ed.SFTConfig(
total_batch_size=32,
gradient_accumulation_steps=8, # Effective batch = 4
)
Next Steps#
Quickstart - Get started quickly
Dataset Mixing - Advanced mixing strategies
Pre-tokenization - Offline tokenization