Group Relative Policy Optimization (GRPO) Trainer
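Group Relative Policy Optimization (GRPO) is a reinforcement learning method for fine-tuning language models. It samples a group of completions for each prompt, scores them with one or more reward functions, and updates the policy using each completion's reward relative to the rest of its group. This tutorial demonstrates how to use EasyDeL’s GRPOTrainer.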
Overview
The GRPO trainer generates multiple completions for each prompt, evaluates them with reward functions, and optimizes the policy to favor higher-reward generations. This is particularly useful for tasks like mathematical reasoning, where answers can be checked automatically and reinforcing the solutions that score above their group's average can significantly improve performance.
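The "group relative" part of the paragraph above can be sketched in a few lines. This is a minimal illustration of the normalization described in the GRPO paper, not EasyDeL's internal code; the function name `group_relative_advantages` and the `eps` value are assumptions made for the example.

```python
from statistics import mean, pstdev


def group_relative_advantages(rewards, eps=1e-4):
    """Normalize rewards within one prompt's group of completions.

    Hypothetical helper: each reward is centered on the group mean and
    scaled by the group standard deviation (plus eps for stability).
    """
    mu = mean(rewards)
    sigma = pstdev(rewards)  # population standard deviation of the group
    return [(r - mu) / (sigma + eps) for r in rewards]


# Four completions sampled for a single prompt, scored by a reward function.
rewards = [1.0, 0.0, 0.0, 1.0]
advantages = group_relative_advantages(rewards)

# If every completion gets the same reward, no completion stands out.
flat_advantages = group_relative_advantages([1.0, 1.0, 1.0, 1.0])
```

Completions that beat the group average receive a positive advantage and the others a negative one. When all rewards in a group are identical, every advantage is zero and that prompt contributes no learning signal.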
Configuration
The GRPOTrainer is configured using the GRPOConfig class with the following key parameters:
from easydel.trainers import GRPOConfig

grpo_config = GRPOConfig(
    # Model and training basics
    model_name="GRPOTrainer",  # Name of the model
    learning_rate=1e-6,  # Learning rate for optimization
    # Sequence parameters
    max_prompt_length=512,  # Maximum length for prompts
    max_completion_length=256,  # Maximum length for completions
    # GRPO specific parameters
    beta=0.04,  # Controls policy deviation penalty
    # Reference model parameters
    sync_ref_model=False,  # Whether to sync the reference model periodically
    ref_model_mixup_alpha=0.9,  # Alpha parameter for reference model mixing
    ref_model_sync_steps=64,  # Steps between reference model syncs
    # Dataset processing
    dataset_num_proc=None,  # Processes for dataset processing
    skip_apply_chat_template=False,  # Skip chat template extraction
    # Batch and training parameters
    total_batch_size=16,  # Total batch size
    num_train_epochs=3,  # Number of training epochs
)
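To illustrate what the reference model parameters might control, here is a minimal sketch of a periodic mixing update. It assumes the TR-DPO style rule that TRL documents for options with the same names (new reference = alpha * policy + (1 - alpha) * old reference); whether EasyDeL applies exactly this rule is an assumption, and `mix_reference_params` is a hypothetical helper operating on plain floats rather than real model weights.

```python
def mix_reference_params(policy_params, ref_params, alpha):
    """Blend policy weights into the reference weights (hypothetical helper)."""
    return {
        name: alpha * policy_params[name] + (1.0 - alpha) * ref_params[name]
        for name in ref_params
    }


# Toy single-weight "models" standing in for full parameter trees.
policy = {"w": 1.0}
reference = {"w": 0.0}

ref_model_mixup_alpha = 0.9
ref_model_sync_steps = 64

for step in range(1, 129):
    # In real training the policy would be updated here.
    if step % ref_model_sync_steps == 0:
        reference = mix_reference_params(policy, reference, ref_model_mixup_alpha)
```

With these toy values the reference moves 90% of the way toward the policy at each sync, so a larger alpha or a smaller sync interval makes the reference track the policy more closely.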
Basic Usage
Here’s how to set up and use the GRPOTrainer for fine-tuning a language model:
import easydel as ed
import jax
from jax import numpy as jnp
from transformers import AutoTokenizer
from datasets import load_dataset
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
# Load model
model = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    dtype=jnp.bfloat16,
)
# Load dataset (must have a "prompt" field)
dataset = load_dataset("gsm8k", "main", split="train[:10%]")
# GSM8K stores the problem text under "question", so rename it to "prompt"
dataset = dataset.rename_column("question", "prompt")
# Define a reward function - must return values between 0 and 1
import re

def math_problem_reward_function(completion, reference_answer=None):
    """Example reward function that extracts and verifies an answer."""
    # Extract the final integer answer from the completion
    answer_match = re.search(r"The answer is (-?\d+)", completion)
    if answer_match is None:
        return 0.0  # No parsable answer
    if reference_answer is None:
        return 0.5  # Default score when no reference is available
    try:
        # The reference may arrive as a string such as "8"
        return 1.0 if int(answer_match.group(1)) == int(reference_answer) else 0.0
    except (TypeError, ValueError):
        return 0.0
# Initialize vInference for generation
inference = ed.vInference(
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=256,
    num_return_sequences=4,  # Generate 4 different completions
    temperature=0.7,
    top_p=0.95,
)
# Create GRPO config
config = ed.GRPOConfig(
    model_name="grpo_math_solver",
    save_directory="grpo_checkpoints",
    beta=0.04,
    max_prompt_length=512,
    max_completion_length=256,
    total_batch_size=16,
    learning_rate=1e-6,
    num_train_epochs=3,
    use_wandb=True,
)
# Initialize trainer
trainer = ed.GRPOTrainer(
    arguments=config,
    vinference=inference,
    model=model,
    reward_funcs=math_problem_reward_function,
    train_dataset=dataset,
    processing_class=tokenizer,
)
# Start training
trainer.train()
Command Line Training
You can run GRPO training directly from the command line for math problem solving:
python -m easydel.scripts.finetune.gsm8k_grpo \
    --repo_id meta-llama/Llama-3.1-8B-Instruct \
    --attn_mechanism vanilla \
    --max_prompt_length 2048 \
    --max_completion_length 1024 \
    --beta 0.04 \
    --top_p 0.95 \
    --top_k 50 \
    --num_return_sequences 4 \
    --xml_reward 0.125 \
    --correctness_reward 2.0 \
    --total_batch_size 16 \
    --learning_rate 1e-6 \
    --num_train_epochs 3 \
    --use_wandb
Or for NuminaMath fine-tuning:
python -m easydel.scripts.finetune.numinamath_grpo \
    --repo_id meta-llama/Llama-3.1-8B-Instruct \
    --attn_mechanism vanilla \
    --max_prompt_length 2048 \
    --max_completion_length 1024 \
    --beta 0.04 \
    --top_p 0.95 \
    --num_return_sequences 4 \
    --total_batch_size 16 \
    --learning_rate 1e-6 \
    --num_train_epochs 3 \
    --use_wandb
Dataset Format
For GRPO, the dataset must include a “prompt” field that contains the text prompt:
{
    "prompt": "Solve this math problem step-by-step: If John has 5 apples and buys 3 more, how many does he have in total?",
    "answer": "8"  # Optional reference answer for reward calculation
}
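Datasets rarely arrive in exactly this shape. As a hedged sketch, the hypothetical helper `to_grpo_record` below converts a GSM8K-style row, which has "question" and "answer" fields and ends its answer with `#### <final answer>`, into the prompt/answer record shown above. The prompt prefix is copied from the example record and is only an illustration.

```python
import re


def to_grpo_record(example):
    """Convert a GSM8K-style row into a {"prompt", "answer"} record.

    Hypothetical helper: keeps only the final answer that follows "####".
    """
    match = re.search(r"####\s*(-?[\d,]+(?:\.\d+)?)", example["answer"])
    if match:
        final_answer = match.group(1).replace(",", "")  # "1,000" -> "1000"
    else:
        final_answer = example["answer"].strip()
    return {
        "prompt": "Solve this math problem step-by-step: " + example["question"],
        "answer": final_answer,
    }


row = {
    "question": "If John has 5 apples and buys 3 more, how many does he have in total?",
    "answer": "John has 5 + 3 = 8 apples.\n#### 8",
}
record = to_grpo_record(row)
```

With the Hugging Face `datasets` library, a function like this could be applied to a whole split with `dataset.map(to_grpo_record, remove_columns=["question"])`.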
Advanced Usage
Multiple Reward Functions
You can use multiple reward functions to shape the model’s behavior:
def correctness_reward(completion, reference=None):
    # Check if the reference answer appears in the completion
    if reference is None:
        return 0.0  # Nothing to compare against
    return 1.0 if str(reference) in completion else 0.0

def reasoning_reward(completion):
    # Score the quality of reasoning in the completion
    text = completion.lower()
    if "step 1" in text and "therefore" in text:
        return 0.8  # Good reasoning pattern
    return 0.3  # Lacking clear reasoning

# Use multiple reward functions
trainer = ed.GRPOTrainer(
    # Other parameters...
    reward_funcs=[correctness_reward, reasoning_reward],
    # If you need specific tokenizers for different reward functions:
    reward_processing_classes=[tokenizer, tokenizer],
)
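How the trainer aggregates several reward functions is not spelled out here. One common choice is a weighted sum per completion, sketched below under that assumption; `contains_answer`, `shows_steps`, `combined_reward`, and the weights are all hypothetical names and values chosen for illustration.

```python
def contains_answer(completion):
    """1.0 if the completion states a final answer in the expected phrasing."""
    return 1.0 if "The answer is" in completion else 0.0


def shows_steps(completion):
    """1.0 if the completion shows an explicit step and a conclusion."""
    text = completion.lower()
    return 1.0 if "step 1" in text and "therefore" in text else 0.0


def combined_reward(completion, reward_fns, weights):
    """Weighted sum of individual rewards (an assumed aggregation rule)."""
    return sum(w * fn(completion) for fn, w in zip(reward_fns, weights))


completion = "Step 1: 5 + 3 = 8. Therefore, The answer is 8"
score = combined_reward(completion, [contains_answer, shows_steps], [2.0, 0.125])
partial = combined_reward("The answer is 8", [contains_answer, shows_steps], [2.0, 0.125])
```

Giving the correctness-style term a much larger weight than the formatting-style term keeps the model from earning most of its reward through style alone.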
Custom Generation Parameters
You can customize the generation parameters in vInference:
inference = ed.vInference(
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=256,
    num_return_sequences=8,  # More generations for better exploration
    temperature=0.9,  # Higher temperature for more diversity
    top_p=0.92,
    top_k=100,
    do_sample=True,
    repetition_penalty=1.1,
)
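To see why a higher temperature yields more diverse samples, the sketch below applies standard temperature scaling to a toy set of logits. It is a generic illustration of softmax sampling, not vInference's implementation; `softmax_with_temperature` and the logit values are made up for the example.

```python
import math


def softmax_with_temperature(logits, temperature):
    """Turn logits into probabilities after dividing by the temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.0]  # toy scores for three candidate tokens
probs_low_temp = softmax_with_temperature(logits, 0.7)
probs_high_temp = softmax_with_temperature(logits, 0.9)
```

At temperature 0.9 the top token takes a smaller share of the probability mass than at 0.7, so less likely tokens are sampled more often and the completions in a group differ more from one another.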
Tips for Effective GRPO Training
Reward function design: Create reward functions that measure the specific qualities you want in generations, such as a correct final answer or a required output format.
Number of generations: More generations per prompt (4-8) typically provide a better training signal but use more memory.
Beta tuning: Start with low values (0.01-0.05) and adjust based on training stability.
Diversity vs. quality: Tune temperature, top_p, and top_k so that samples are diverse but still high quality.
Dataset selection: GRPO works best on datasets with clear evaluation criteria (math, coding, reasoning).
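To make the beta tip concrete, here is a minimal sketch of a per-token loss with a KL penalty toward the reference model. It assumes the KL estimator used in the GRPO paper and the on-policy case where the importance ratio equals 1; `kl_estimate` and `per_token_loss` are illustrative names, not EasyDeL functions, and EasyDeL's exact loss may differ.

```python
import math


def kl_estimate(logp, ref_logp):
    """Non-negative per-token KL estimate between policy and reference."""
    diff = ref_logp - logp
    return math.exp(diff) - diff - 1.0


def per_token_loss(logp, ref_logp, advantage, beta):
    """Negative advantage plus a beta-weighted KL penalty (on-policy case)."""
    return -(advantage - beta * kl_estimate(logp, ref_logp))


# The policy assigns a higher log-probability to this token than the reference does.
logp, ref_logp, advantage = -0.5, -1.0, 1.0
loss_small_beta = per_token_loss(logp, ref_logp, advantage, beta=0.04)
loss_large_beta = per_token_loss(logp, ref_logp, advantage, beta=0.5)
```

A larger beta raises the loss whenever the policy drifts from the reference, which stabilizes training but limits how far the rewards can move the model. A smaller beta does the opposite.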