Trainer Integration#

EasyData integrates seamlessly with all EasyDeL trainers. This guide covers best practices for using data with SFT, DPO, GRPO, KTO, and other trainers.

Quick Reference#

| Trainer | Dataset Format | Preprocessing |
|---|---|---|
| SFTTrainer | Text or chat messages | Chat template applied |
| DPOTrainer | Chosen/rejected pairs | Prompt extraction + tokenization |
| ORPOTrainer | Chosen/rejected pairs | Same as DPO |
| KTOTrainer | Unpaired or paired preference | Unpairs + tokenizes |
| GRPOTrainer | Prompts (with reward function) | Prompt extraction |
| RewardTrainer | Chosen/rejected pairs | Tokenization |
| CPOTrainer | Chosen/rejected pairs | Same as DPO |
| BCOTrainer | Prompt/completion/label | Tokenization |

SFT Training#

Basic SFT#

from datasets import load_dataset
import easydel as ed

# Load instruction dataset
dataset = load_dataset("tatsu-lab/alpaca", split="train")

trainer = ed.SFTTrainer(
    model=model,
    train_dataset=dataset,
    processing_class=tokenizer,
    arguments=ed.SFTConfig(
        max_sequence_length=2048,
        # Chat template applied automatically
    ),
)
trainer.train()

SFT with Mixed Datasets#

from easydel.data import block_mixture_interleave

# Mix instruction datasets
alpaca = load_dataset("tatsu-lab/alpaca", split="train")
dolly = load_dataset("databricks/databricks-dolly-15k", split="train")

# Dict format for explicit mapping (recommended)
mixed = block_mixture_interleave(
    datasets={"alpaca": alpaca, "dolly": dolly},
    weights={"alpaca": 0.6, "dolly": 0.4},
    block_size=1000,
    seed=42,
    stop="restart",
)

# Equal weights
mixed = block_mixture_interleave(
    datasets={"alpaca": alpaca, "dolly": dolly},
    weights=None,  # Equal 50/50 mixing
)

trainer = ed.SFTTrainer(
    model=model,
    train_dataset=mixed,
    processing_class=tokenizer,
    arguments=ed.SFTConfig(...),
)

SFT with Pre-tokenized Data#

from easydel.data import ParquetShardedSource

# Load pre-tokenized data
source = ParquetShardedSource("./tokenized_sft/*.parquet")

trainer = ed.SFTTrainer(
    model=model,
    train_dataset=source,  # Works directly
    processing_class=tokenizer,
    arguments=ed.SFTConfig(
        max_sequence_length=2048,
    ),
)

DPO Training#

Basic DPO#

from datasets import load_dataset
import easydel as ed

# Load preference dataset (has chosen/rejected)
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

trainer = ed.DPOTrainer(
    model=policy_model,
    reference_model=ref_model,  # Optional, deep-copied if None
    train_dataset=dataset,      # Raw preference data
    processing_class=tokenizer,
    arguments=ed.DPOConfig(
        max_prompt_length=512,
        max_completion_length=512,
        beta=0.1,
    ),
)
trainer.train()

DPO Internal Preprocessing#

The DPO trainer automatically:

  1. Extracts shared prompt from chosen/rejected

  2. Applies chat template if conversational

  3. Tokenizes with proper truncation

  4. Handles padding and masking

# Input format (conversational)
{
    "chosen": [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi there!"},
    ],
    "rejected": [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Go away."},
    ],
}

# Or simpler format
{
    "prompt": "Hello",
    "chosen": "Hi there!",
    "rejected": "Go away.",
}

GRPO Training#

Basic GRPO#

from datasets import load_dataset
import easydel as ed

# Load any preference dataset (prompts extracted automatically)
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

def reward_fn(prompts, completions, **kwargs):
    # Your reward logic
    return [score_completion(c) for c in completions]

trainer = ed.GRPOTrainer(
    model=model,
    reward_funcs=reward_fn,
    train_dataset=dataset,
    processing_class=tokenizer,
    arguments=ed.GRPOConfig(
        num_return_sequences=4,
        max_prompt_length=512,
        max_completion_length=512,
    ),
)
trainer.train()

GRPO Internal Preprocessing#

The GRPO trainer automatically:

  1. Extracts prompts from chosen/rejected conversations

  2. Applies chat template with add_generation_prompt=True

  3. Left-pads for efficient batch generation

GRPO with Custom Prompts#

# Direct prompt format also works
dataset = Dataset.from_dict({
    "prompt": [
        "Write a poem about AI",
        "Explain quantum computing",
    ]
})

# Or conversational
dataset = Dataset.from_dict({
    "prompt": [
        [{"role": "user", "content": "Write a poem about AI"}],
        [{"role": "user", "content": "Explain quantum computing"}],
    ]
})

KTO Training#

Basic KTO#

from datasets import load_dataset
import easydel as ed

# KTO accepts paired preference data (will unpair internally)
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

trainer = ed.KTOTrainer(
    model=policy_model,
    reference_model=ref_model,
    train_dataset=dataset,
    processing_class=tokenizer,
    arguments=ed.KTOConfig(
        max_prompt_length=512,
        max_completion_length=512,
        beta=0.1,
        loss_type="kto",
    ),
)
trainer.train()

KTO Internal Preprocessing#

The KTO trainer automatically:

  1. Extracts shared prompt

  2. Unpairs preference data (1 pair → 2 examples with labels)

  3. Applies chat template

  4. Tokenizes with BCO collator

# Input: paired preference data
{
    "chosen": [...],
    "rejected": [...],
}

# Internally converted to:
[
    {"prompt": "...", "completion": "...", "label": True},   # From chosen
    {"prompt": "...", "completion": "...", "label": False},  # From rejected
]

ORPO Training#

from datasets import load_dataset
import easydel as ed

dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

trainer = ed.ORPOTrainer(
    model=model,
    train_dataset=dataset,
    processing_class=tokenizer,
    arguments=ed.ORPOConfig(
        max_prompt_length=512,
        max_completion_length=512,
        beta=0.1,
    ),
)
trainer.train()

Reward Training#

from datasets import load_dataset
import easydel as ed

dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

# Use sequence classification model
trainer = ed.RewardTrainer(
    model=reward_model,  # AutoEasyDeLModelForSequenceClassification
    train_dataset=dataset,
    processing_class=tokenizer,
    arguments=ed.RewardConfig(
        max_length=2048,
    ),
)
trainer.train()

Using ShardedDataSource#

All trainers accept ShardedDataSource directly:

from easydel.data import ParquetShardedSource, JsonShardedSource

# From local files
source = ParquetShardedSource("data/*.parquet")

# From cloud
source = ParquetShardedSource(
    "gs://bucket/data/*.parquet",
    storage_options={"token": "cloud"},
)

# Use with any trainer
trainer = ed.SFTTrainer(
    model=model,
    train_dataset=source,  # Works directly
    processing_class=tokenizer,
    arguments=ed.SFTConfig(...),
)

Mixed Datasets with Trainers#

from datasets import load_dataset
from easydel.data import block_mixture_interleave

# Load multiple datasets
ds1 = load_dataset("dataset1", split="train")
ds2 = load_dataset("dataset2", split="train")
ds3 = load_dataset("dataset3", split="train")

# Dict format for explicit mapping (recommended)
mixed = block_mixture_interleave(
    datasets={"ds1": ds1, "ds2": ds2, "ds3": ds3},
    weights={"ds1": 0.5, "ds2": 0.3, "ds3": 0.2},
    block_size=1000,
    seed=42,
    stop="restart",
)

# Works with any trainer
trainer = ed.DPOTrainer(
    model=model,
    train_dataset=mixed,
    processing_class=tokenizer,
    arguments=ed.DPOConfig(...),
)

Dynamic Weight Scheduling#

from easydel.data import (
    MixedShardedSource,
    HFDatasetShardedSource,
    WeightScheduler,
    WeightSchedulePoint,
)

# Wrap datasets
source1 = HFDatasetShardedSource(dataset1)
source2 = HFDatasetShardedSource(dataset2)

# Define schedule
scheduler = WeightScheduler(
    schedule=[
        WeightSchedulePoint(step=0, weights={"easy": 0.9, "hard": 0.1}),
        WeightSchedulePoint(step=10000, weights={"easy": 0.5, "hard": 0.5}),
        WeightSchedulePoint(step=50000, weights={"easy": 0.1, "hard": 0.9}),
    ],
    interpolation="linear",
)

mixed = MixedShardedSource(
    sources={"easy": source1, "hard": source2},
    weight_scheduler=scheduler,
    block_size=1000,
)

trainer = ed.SFTTrainer(
    model=model,
    train_dataset=mixed,
    processing_class=tokenizer,
    arguments=ed.SFTConfig(...),
)

Pre-tokenization for Trainers#

SFT Pre-tokenization#

from easydel.data import tokenize_and_save

# Tokenize once
tokenize_and_save(
    data_files="conversations/*.jsonl",
    tokenizer="meta-llama/Llama-2-7b-chat-hf",
    output_path="./sft_tokenized",
    max_length=2048,
)

# Use in training
from easydel.data import ParquetShardedSource

source = ParquetShardedSource("./sft_tokenized/*.parquet")
trainer = ed.SFTTrainer(train_dataset=source, ...)

DPO Pre-tokenization#

from easydel.trainers.transforms import DPOPreprocessTransform
from easydel.data import JsonShardedSource, TransformedShardedSource

transform = DPOPreprocessTransform(
    tokenizer=tokenizer,
    max_prompt_length=512,
    max_completion_length=512,
)

source = JsonShardedSource("preference_data/*.jsonl")
tokenized = TransformedShardedSource(source, transform=transform)

# Save for later
from easydel.data import save_iterator
save_iterator(
    tokenized.open_shard(tokenized.shard_names[0]),
    output_path="./dpo_tokenized",
    format="parquet",
)

Trainer-Specific Transforms#

Access trainer preprocessing logic directly:

from easydel.trainers.transforms import (
    SFTPreprocessTransform,
    DPOPreprocessTransform,
    ORPOPreprocessTransform,
    KTOPreprocessTransform,
    GRPOPreprocessTransform,
    RewardPreprocessTransform,
    BCOPreprocessTransform,
    CPOPreprocessTransform,
)

# SFT transform
sft_transform = SFTPreprocessTransform(
    tokenizer=tokenizer,
    max_length=2048,
)

# DPO transform
dpo_transform = DPOPreprocessTransform(
    tokenizer=tokenizer,
    max_prompt_length=512,
    max_completion_length=512,
)

# Apply to data
from easydel.data import TransformedShardedSource

source = JsonShardedSource("data/*.jsonl")
transformed = TransformedShardedSource(source, transform=dpo_transform)

Streaming with Trainers#

from datasets import load_dataset

# Stream from HuggingFace
dataset = load_dataset(
    "HuggingFaceFW/fineweb",
    split="train",
    streaming=True,
)

# Works with trainers
trainer = ed.SFTTrainer(
    model=model,
    train_dataset=dataset,
    processing_class=tokenizer,
    arguments=ed.SFTConfig(
        shuffle_train_dataset=False,  # Streaming dataset; skip trainer-side shuffling
        ...
    ),
)

Best Practices#

1. Match Data Format to Trainer#

# SFT: text or messages
{"text": "..."} or {"messages": [...]}

# DPO/ORPO/CPO: chosen/rejected
{"chosen": [...], "rejected": [...]}

# KTO: paired or unpaired
{"chosen": [...], "rejected": [...]}  # Paired (split into unpaired examples internally)
{"prompt": "...", "completion": "...", "label": True/False}  # Already unpaired

# GRPO: prompts
{"prompt": "..."} or {"chosen": [...], "rejected": [...]}

# Reward: chosen/rejected
{"prompt": "...", "chosen": "...", "rejected": "..."}

2. Use Appropriate Sequence Lengths#

# DPO/KTO/ORPO: separate prompt and completion lengths
ed.DPOConfig(
    max_prompt_length=512,      # For prompt
    max_completion_length=512,  # For completion
)

# SFT: single length
ed.SFTConfig(
    max_sequence_length=2048,  # Total length
)

# GRPO: prompt length + generation length
ed.GRPOConfig(
    max_prompt_length=512,
    max_completion_length=512,  # For generation
)

3. Handle Pad Tokens#

# Most trainers handle this automatically
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# But you can specify explicitly
ed.DPOConfig(
    padding_value=tokenizer.pad_token_id,
    label_pad_token_id=-100,  # Ignore in loss
)

4. Monitor Data Processing#

# Check preprocessing
trainer = ed.DPOTrainer(...)

# Before training, verify a batch
batch = next(iter(trainer._training_batch_iterator()))
print(f"Batch keys: {batch.keys()}")
print(f"Input shape: {batch['input_ids'].shape}")

Troubleshooting#

“Dataset must have column X”#

# Check your dataset columns
print(dataset.column_names)

# Map to expected format
dataset = dataset.map(lambda x: {"chosen": x["good"], "rejected": x["bad"]})

Tokenization errors#

# Ensure tokenizer has required tokens
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

Out of memory#

# Use streaming
dataset = load_dataset(..., streaming=True)

# Or reduce batch size
ed.SFTConfig(total_batch_size=4)

# Or use gradient accumulation
ed.SFTConfig(
    total_batch_size=32,
    gradient_accumulation_steps=8,  # 32 / 8 = 4 samples per micro-batch
)

Next Steps#