# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import typing as tp
from functools import partial
import flax
import flax.nnx
import jax
from eformer.escale import with_sharding_constraint
from flax import nnx as nn
from jax import numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec
from transformers import AutoTokenizer
from easydel.infra.base_module import EasyDeLBaseModule
from easydel.infra.base_state import EasyDeLState
from easydel.infra.utils import ProcessingClassType
from easydel.utils import Registry
from easydel.utils.compiling_utils import ejit
from easydel.utils.helpers import capture_time, get_logger
from easydel.utils.traversals import deepcopy_model
from ..prompt_transforms import GRPOPreprocessTransform, is_conversational
from ..prompt_utils import apply_chat_template
from ..trainer.trainer import Trainer
from ..trainer_protocol import TrainerConfigureFunctionOutput
from ..training_configurations import MetricsType
from ._fn import get_per_token_logps, grpo_step
from .grpo_config import GRPOConfig
try:
import wandb # type:ignore
except ImportError:
wandb = None
if tp.TYPE_CHECKING:
from datasets import Dataset, IterableDataset
from easydel.data.core.protocols import ShardedDataSource
logger = get_logger(__name__)
RewardFunc = tp.Union[EasyDeLBaseModule, EasyDeLState, tp.Callable[[list, list], list[float]]] # noqa
def _fileaf(x):
    """Return ``True`` when ``x`` is a ``jax.Array`` (``is_leaf`` predicate for ``delete_tree``)."""
    return isinstance(x, jax.Array)


def delete_tree(pytree):
    """Call ``delete()`` on every ``jax.Array`` leaf of ``pytree``.

    Args:
        pytree: Any pytree. Leaves that are not ``jax.Array`` instances are
            left untouched.

    Returns:
        A pytree with the same structure as ``pytree`` in which every leaf is
        the result of the mapping function: the return value of ``delete()``
        for array leaves and ``None`` for everything else.
    """
    return jax.tree_util.tree_map(
        lambda x: x.delete() if isinstance(x, jax.Array) else None,
        pytree,
        is_leaf=_fileaf,
    )
[docs]@Registry.register("trainer", "grpo")
class GRPOTrainer(Trainer):
"""Group Relative Policy Optimization trainer for RLHF.

    GRPO is a reinforcement learning method that optimizes policies by comparing
    responses within groups, providing more stable training than standard PPO.
    It uses relative scoring within batches to reduce variance and improve
    convergence in preference-based learning tasks.

    Key features:
        - Group-based advantage normalization
        - Stable policy updates with KL regularization
        - Support for multiple reward models
        - Efficient generation and scoring pipeline

    Attributes:
        arguments: GRPOConfig instance with training hyperparameters
        ref_state: Reference model state for KL divergence computation
        processing_class: Tokenizer or processor for text encoding
        reward_processing_classes: Optional separate processors for reward models
        generation_config: Configuration for response generation
        data_tokenize_fn: Function to tokenize dataset samples

    Example:
        >>> config = GRPOConfig(
        ...     per_device_train_batch_size=4,
        ...     grpo_n_samples=4,
        ...     grpo_beta=0.1,
        ...     learning_rate=1e-6
        ... )
        >>> trainer = GRPOTrainer(
        ...     arguments=config,
        ...     model=model,
        ...     reward_funcs=reward_model,
        ...     train_dataset=dataset,
        ...     processing_class=tokenizer
        ... )
        >>> trainer.train()
"""
arguments: GRPOConfig # type hinting
def __init__(
    self,
    arguments: GRPOConfig,
    model: EasyDeLBaseModule | EasyDeLState | None,
    reward_funcs: RewardFunc | list[RewardFunc],
    train_dataset: Dataset | IterableDataset | ShardedDataSource | None = None,
    eval_dataset: Dataset | IterableDataset | ShardedDataSource | dict[str, Dataset] | None = None,
    processing_class: ProcessingClassType = None,
    reward_processing_classes: ProcessingClassType = None,
    data_tokenize_fn: tp.Callable | None = None,
):
    """Set up the policy/reference states, reward functions and generation defaults.

    Args:
        arguments: GRPO training configuration; must be a ``GRPOConfig``.
        model: Policy module or state. A module is converted with
            ``to_state()``; a deep copy of the resulting state is kept as
            ``ref_state``.
        reward_funcs: One reward function or a list of them. Each entry is
            either an EasyDeL module/state (reward model) or a plain callable.
        train_dataset: Training data. Its first sample is probed (best effort)
            to decide whether the data is conversational.
        eval_dataset: Evaluation data, probed the same way.
        processing_class: Tokenizer/processor used for the policy model.
        reward_processing_classes: A single processor or a list with one entry
            per reward function. ``None`` entries that belong to model-based
            rewards are loaded with ``AutoTokenizer``.
        data_tokenize_fn: Optional tokenization callable, stored as-is.

    Raises:
        AssertionError: If ``arguments`` is missing or not a ``GRPOConfig``,
            or if ``processing_class`` is ``None``.
        ValueError: If ``model`` is ``None``, if the number of reward
            processing classes differs from the number of reward functions,
            or if ``arguments.reward_weights`` has the wrong length.
    """
    assert arguments is not None, (
        "You Have to pass `arguments` that will be used for training, but you have passed `arguments=None`"
    )
    assert isinstance(arguments, GRPOConfig), f"arguments type must be `GRPOConfig` but got {type(arguments)}"
    assert processing_class is not None, "processing_class must be specified to tokenize a GRPO dataset."

    self.arguments = arguments
    self.truncation_mode = arguments.truncation_mode
    self.loss_type = arguments.loss_type.lower() if isinstance(arguments.loss_type, str) else arguments.loss_type
    self.epsilon = arguments.epsilon
    self.epsilon_high = arguments.epsilon_high
    self.delta = arguments.delta

    self.importance_sampling_level = arguments.importance_sampling_level
    if isinstance(self.importance_sampling_level, str):
        self.importance_sampling_level = self.importance_sampling_level.lower()

    # Later code only understands the strings "group" / "batch" / "none".
    # Booleans follow the TRL convention: True -> "group", False -> "none".
    scale_rewards = arguments.scale_rewards
    if isinstance(scale_rewards, bool):
        scale_rewards = "group" if scale_rewards else "none"
    elif isinstance(scale_rewards, str):
        scale_rewards = scale_rewards.lower()
    self.scale_rewards = scale_rewards

    self.top_entropy_quantile = arguments.top_entropy_quantile

    if model is None:
        raise ValueError("`model` must be an EasyDeLBaseModule or EasyDeLState, but got `model=None`.")
    if not isinstance(model, EasyDeLState):
        model = model.to_state()
    # Frozen copy of the starting policy, used as the reference model.
    self.ref_state = deepcopy_model(model=model)

    # Only reachable when asserts are stripped (``python -O``); kept as a fallback.
    if processing_class is None:
        processing_class = AutoTokenizer.from_pretrained(
            model.model.config._name_or_path,
            padding_side="left",
        )
    self.processing_class = processing_class

    # Work on private copies so the caller's lists are never mutated in place.
    reward_funcs = list(reward_funcs) if isinstance(reward_funcs, list) else [reward_funcs]

    if reward_processing_classes is None:
        reward_processing_classes = [None] * len(reward_funcs)
    elif isinstance(reward_processing_classes, list):
        reward_processing_classes = list(reward_processing_classes)
    else:
        reward_processing_classes = [reward_processing_classes]
    # Validate after normalisation: a single processor passed with several
    # reward functions used to be truncated silently by ``zip``.
    if len(reward_processing_classes) != len(reward_funcs):
        raise ValueError("The number of reward processing classes must match the number of reward functions.")

    empty_sharding = NamedSharding(spec=PartitionSpec(), mesh=model.model.mesh)

    for i, (reward_processing_class, reward_func) in enumerate(
        zip(reward_processing_classes, reward_funcs, strict=True)
    ):
        if isinstance(reward_func, EasyDeLBaseModule | EasyDeLState):
            if isinstance(reward_func, EasyDeLBaseModule):
                # Modules are converted to states and get a jitted forward
                # function whose inputs follow the state's own shardings.
                reward_func = reward_func.to_state()
                sharding = reward_func.shardings

                @ejit(
                    static_argnums=(0,),
                    in_shardings=(sharding.graphstate, sharding.graphother, empty_sharding),
                    out_shardings=empty_sharding,
                )
                def apply_fn(gd, gs, gt, batch):
                    batch = with_sharding_constraint(arr=batch, sharding=self.arguments.step_partition_spec)
                    return nn.merge(gd, gs, gt)(**batch)

                reward_func = reward_func.replace(apply_fn=apply_fn)

            if reward_processing_class is None:
                reward_processing_class = AutoTokenizer.from_pretrained(reward_func.model.config._name_or_path)
            if reward_processing_class.pad_token_id is None:
                reward_processing_class.pad_token = reward_processing_class.eos_token
            # Keep the reward model's pad id in sync with its tokenizer.
            reward_func.model.config.pad_token_id = reward_processing_class.pad_token_id
            reward_processing_classes[i] = reward_processing_class
            reward_funcs[i] = reward_func

    if arguments.reward_weights is not None and len(arguments.reward_weights) != len(reward_funcs):
        raise ValueError(
            f"Expected {len(reward_funcs)} reward weights, but got {len(arguments.reward_weights)} instead."
        )
    self.reward_weights = jnp.asarray(
        arguments.reward_weights if arguments.reward_weights is not None else [1.0] * len(reward_funcs),
        dtype="f4",
    )
    self.reward_func_names = [getattr(func, "__name__", None) or func.__class__.__name__ for func in reward_funcs]
    self.num_generations = arguments.num_generations
    self.reward_processing_classes = reward_processing_classes
    self.reward_funcs = reward_funcs

    # Fill generation settings that were left unset from the GRPO-level values.
    if getattr(self.arguments, "generation_num_return_sequences", None) is None:
        self.arguments.generation_num_return_sequences = self.num_generations
    if getattr(self.arguments, "generation_top_p", None) is None:
        self.arguments.generation_top_p = self.arguments.top_p
    if getattr(self.arguments, "generation_top_k", None) is None:
        self.arguments.generation_top_k = self.arguments.top_k
    if getattr(self.arguments, "generation_temperature", None) is None:
        self.arguments.generation_temperature = self.arguments.temperature
    if getattr(self.arguments, "generation_extra_kwargs", None) is None:
        self.arguments.generation_extra_kwargs = {}
    if self.arguments.generation_kwargs is not None:
        self.arguments.generation_extra_kwargs.update(self.arguments.generation_kwargs)
    for key, value in (
        ("min_p", self.arguments.min_p),
        ("repetition_penalty", self.arguments.repetition_penalty),
    ):
        # Explicit entries in generation_extra_kwargs take precedence.
        if value is not None and key not in self.arguments.generation_extra_kwargs:
            self.arguments.generation_extra_kwargs[key] = value

    # Check if datasets are conversational before passing to BaseTrainer.
    # The probe is best effort: datasets that are empty, keyed by name, or not
    # subscriptable at all (e.g. iterable datasets) keep the default ``False``.
    self.train_is_conversational = False
    self.eval_is_conversational = False
    if train_dataset is not None:
        try:
            self.train_is_conversational = is_conversational(train_dataset[0])
        except (IndexError, KeyError, TypeError):
            pass
    if eval_dataset is not None:
        try:
            self.eval_is_conversational = is_conversational(eval_dataset[0])
        except (IndexError, KeyError, TypeError):
            pass

    self.data_tokenize_fn = data_tokenize_fn

    log_table = None
    if self.arguments.use_wandb and self.arguments.can_log_metrics and wandb is not None:
        log_table = wandb.Table(columns=["generated_result", "input_prompt", "took", "length", "step"])
    self.log_table = log_table

    super().__init__(
        model_state=model,
        arguments=arguments,
        dataset_train=train_dataset,
        dataset_eval=eval_dataset,
        data_collator=None,
        processing_class=processing_class,
    )
def _get_preprocess_transform(self) -> GRPOPreprocessTransform | None:
    """Return the GRPO preprocessing transform, or ``None`` for pre-tokenized data."""
    if not self._is_pretokenized():
        transform_kwargs = {
            "tokenizer": self.processing_class,
            "max_prompt_length": self.arguments.max_prompt_length,
            # ``tools`` is optional on the config, hence the getattr fallback.
            "tools": getattr(self.arguments, "tools", None),
            "skip_apply_chat_template": self.arguments.skip_apply_chat_template,
        }
        return GRPOPreprocessTransform(**transform_kwargs)
    return None
def _is_pretokenized(self) -> bool:
    """Report whether the first training sample already carries ``input_ids``.

    Returns ``False`` when there is no training source, when it has no
    shards, or when its first shard yields no samples.
    """
    source = self._train_source
    if source is None:
        return False
    try:
        sample_iter = iter(source.open_shard(source.shard_names[0]))
        first_sample = next(sample_iter)
        return "input_ids" in first_sample
    except (StopIteration, IndexError):
        return False
def create_grain_collect_function(
    self,
    max_sequence_length: int,
    truncation_mode: tp.Literal["keep_end", "keep_start"] = "keep_end",
) -> tp.Callable:
    """Create data collator for Grain data loading.

    Args:
        max_sequence_length: Accepted for interface compatibility; not used
            by this implementation.
        truncation_mode: Accepted for interface compatibility; not used by
            this implementation.

    Returns:
        A ``GRPODataCollatorGrain`` configured with the trainer's
        ``max_prompt_length`` and ``padding_value``.
    """
    # Imported lazily, inside the method, as in the original implementation.
    from ..utils import GRPODataCollatorGrain

    return GRPODataCollatorGrain(
        max_prompt_length=self.arguments.max_prompt_length,
        pad_token_id=self.padding_value,
    )
def create_tfds_collect_function(
    self,
    max_sequence_length: int,
    truncation_mode: tp.Literal["keep_end", "keep_start"] = "keep_end",
) -> tp.Callable:
    """Create data collator for TFDS data loading.

    Args:
        max_sequence_length: Accepted for interface compatibility; not used
            by this implementation.
        truncation_mode: Accepted for interface compatibility; not used by
            this implementation.

    Returns:
        A ``GRPODataCollatorTFDS`` configured with the trainer's
        ``max_prompt_length`` and ``padding_value``.
    """
    # Imported lazily, inside the method, as in the original implementation.
    from ..utils import GRPODataCollatorTFDS

    return GRPODataCollatorTFDS(
        max_prompt_length=self.arguments.max_prompt_length,
        pad_token_id=self.padding_value,
    )
@property
def step_sharding(self):
    """``NamedSharding`` pairing the model mesh with the configured step partition spec."""
    model_mesh = self.model.mesh
    partition_spec = self.arguments.step_partition_spec
    return NamedSharding(mesh=model_mesh, spec=partition_spec)
def _preprocess_batch_input(
    self,
    state: EasyDeLState,
    batch: dict[str, jax.Array],
    is_train: bool,
) -> tuple[dict[str, jax.Array], dict[str, float | int | str]]:
    """Turn a prompt batch into a GRPO training batch plus logging metrics.

    Steps, in order: generate completions for the prompts, build the
    completion mask, compute reference-model per-token log-probs, score the
    completions with every reward function, gather the arrays, and compute
    group-relative advantages.

    Args:
        state: Current policy state, passed to generation and read for the
            step number when logging.
        batch: Prompt batch providing ``"input_ids"`` and ``"attention_mask"``
            (after ``_purify_batch``).
        is_train: Selects the train or eval "conversational" flag.

    Returns:
        A tuple ``(model_inputs, metrics)``. ``model_inputs`` holds
        ``prompt_ids``, ``prompt_mask``, ``completion_ids``,
        ``completion_mask``, ``ref_per_token_logps``, ``advantages`` and
        ``num_items_in_batch``; ``metrics`` holds reward statistics, the mean
        completion length and timing values.

    Raises:
        ValueError: If ``self.scale_rewards`` is not ``"group"``, ``"none"``
            or ``"batch"``.
    """
    # Purify batch first to handle list of dicts (uncollated batch)
    batch = self._purify_batch(batch)
    with capture_time() as preprocessing_time_fn:
        prompt_ids, prompt_mask = batch["input_ids"], batch["attention_mask"]
        with capture_time() as generation_time_fn:
            results = self.generate_unified(
                input_ids=prompt_ids,
                attention_mask=prompt_mask,
                state=state,
                apply_chat_template=False,  # GRPO doesn't apply chat template to prompts
                shard_inputs=False,  # Already sharded
                all_gather=False,  # We'll handle gathering ourselves
            )
            # From here on the prompt tensors are the ones returned by generation.
            sequences = results.sequences
            prompt_ids = results.prompt_ids
            prompt_mask = results.prompt_mask
            completion_ids = results.completion_ids
            completion_prompts = results.completion_prompts
        generation_time = generation_time_fn()
        prompt_completion_ids = sequences
        completion_mask = self._make_attn_mask(completion_ids)
        if self.arguments.mask_truncated_completions:
            # Zero the whole mask row of any completion that contains no EOS token.
            eos_tokens = jnp.asarray(self._eos_token_id).reshape(-1)
            has_eos = jnp.any(jnp.isin(completion_ids, eos_tokens), axis=1)
            completion_mask = completion_mask * has_eos[:, None].astype(completion_mask.dtype)
        # Derive how many completions we have per prompt instead of trusting config-only value.
        generation_factor = completion_ids.shape[0] // max(prompt_mask.shape[0], 1)
        generation_factor = max(generation_factor, 1)
        # Repeat each prompt mask row so it lines up with its completions.
        ridmask = prompt_mask.repeat(generation_factor, 0)
        with capture_time() as token_logps_time_fn:
            ref_per_token_logps = self.compute_refmodel_logps(
                self.ref_state.graphstate,
                self.ref_state.graphother,
                prompt_completion_ids,
                jnp.concatenate([ridmask, completion_mask], -1),
            )
        token_logps_time = token_logps_time_fn()
        completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
        is_conversational = self.train_is_conversational if is_train else self.eval_is_conversational
        if is_conversational:
            # Wrap each completion as a single assistant message.
            completions = [[{"role": "assistant", "content": completion}] for completion in completions_text]
        else:
            completions = completions_text
        # One column per reward function; NaN marks "no reward produced".
        rewards_per_func = jnp.full(
            (prompt_ids.shape[0] * generation_factor, len(self.reward_funcs)),
            jnp.nan,
            dtype="f4",
        )
        with capture_time() as rewarding_time_fn:
            for i, (reward_func, reward_processing_class) in enumerate(
                zip(self.reward_funcs, self.reward_processing_classes, strict=False)
            ):
                if isinstance(reward_func, EasyDeLState):
                    # Model-based reward: tokenize prompt+completion and read logits[:, 0].
                    if is_conversational:
                        messages = [
                            {"messages": p + c} for p, c in zip(completion_prompts, completions, strict=False)
                        ]
                        texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
                    else:
                        texts = [p + c for p, c in zip(completion_prompts, completions, strict=False)]
                    rew = reward_func.apply_fn(
                        reward_func.graphdef,
                        reward_func.graphstate,
                        reward_func.graphother,
                        dict(
                            reward_processing_class(
                                texts,
                                return_tensors="jax",
                                padding="max_length",
                                padding_side="right",
                                add_special_tokens=False,
                                truncation=True,
                                return_attention_mask=True,
                                max_length=self.arguments.max_sequence_length,
                            )
                        ),
                    ).logits[:, 0]
                else:
                    # Plain callable reward; a ``None`` entry in its output becomes NaN.
                    in_prompts = completion_prompts
                    output_reward_func = reward_func(
                        prompts=in_prompts,
                        completions=completions,
                        max_length=self.arguments.max_sequence_length,
                        batch=batch,
                    )
                    rew = jnp.array(
                        [val if val is not None else jnp.nan for val in output_reward_func],
                        dtype="f4",
                    )
                rewards_per_func = rewards_per_func.at[:, i].set(rew.reshape(-1))
        rewarding_time = rewarding_time_fn()
        # Keep the pre-gather values for the wandb table below.
        log_completion_ids = completion_ids
        log_completion_length = jnp.sum(completion_mask, -1)
        prompt_ids = self._all_gather(prompt_ids)
        prompt_mask = self._all_gather(prompt_mask)
        completion_ids = self._all_gather(completion_ids)
        completion_mask = self._all_gather(completion_mask)
        ref_per_token_logps = self._all_gather(ref_per_token_logps)
        rewards_per_func = self._all_gather(rewards_per_func)
        with capture_time() as grouped_comp_time_fn:
            # Recompute the group size from the gathered shapes.
            generation_factor = completion_ids.shape[0] // max(prompt_mask.shape[0], 1)
            generation_factor = max(generation_factor, 1)
            # Weighted sum over reward functions; NaN entries are ignored by nansum.
            rewards = jnp.nansum(rewards_per_func * self.reward_weights[None, :], axis=1)
            mean_grouped_rewards = jnp.nanmean(rewards.reshape(-1, generation_factor), axis=-1)
            advantages = rewards - mean_grouped_rewards.repeat(generation_factor, axis=0)
            # Only the strings "group", "none" and "batch" are accepted here;
            # any other value (including a bool) ends in the ValueError branch.
            if self.scale_rewards in ("group", "none"):
                std_rewards = jnp.nanstd(rewards.reshape(-1, generation_factor), axis=-1)
                std_rewards = std_rewards.repeat(generation_factor, axis=0)
            elif self.scale_rewards == "batch":
                std_rewards = jnp.nanstd(rewards)
                std_rewards = jnp.broadcast_to(std_rewards, advantages.shape)
            else:
                raise ValueError(
                    f"Invalid value for scale_rewards: {self.scale_rewards}. Must be 'batch', 'group', or 'none'."
                )
            is_std_zero = jnp.isclose(std_rewards, 0.0)
            if self.scale_rewards != "none":
                # The 1e-4 term keeps the division finite when the std is zero.
                advantages = advantages / (std_rewards + 1e-4)
            advantages = jnp.nan_to_num(advantages)
        grouped_comp_time = grouped_comp_time_fn()
    preprocessing_time = preprocessing_time_fn()
    completion_length = jnp.sum(completion_mask, -1)
    metrics_dict = {
        "reward_mean": jnp.nanmean(rewards, -1),
        "reward_std": jnp.nanmean(std_rewards),
        "completion_length": jnp.mean(completion_length),
        "grouped_comp_time": grouped_comp_time,
        "rewarding_time": rewarding_time,
        "token_logps_time": token_logps_time,
        "generation_time": generation_time,
        "preprocessing_time": preprocessing_time,
        "frac_reward_zero_std": jnp.mean(is_std_zero.astype(jnp.float32)),
    }
    # Per-reward-function mean, keyed by the function's name.
    for i, reward_func_name in enumerate(self.reward_func_names):
        metrics_dict[reward_func_name] = jnp.nanmean(rewards_per_func[:, i])
    if self.log_table is not None:
        # ``log_table`` is only created when wandb is available (see ``__init__``).
        cur_step = jax.device_get(state.step)
        decoded_prompt = completion_prompts
        decoded_text = self._decode_prompt_batch(
            self.processing_class,
            jax.device_get(log_completion_ids),
            False,
            self._pad_token_id,
            True,
        )
        for decoded, prompt, length in zip(decoded_text, decoded_prompt, log_completion_length, strict=False):
            prompt_repr = prompt if isinstance(prompt, str) else str(prompt)
            self.log_table.add_data(decoded, prompt_repr, generation_time, float(jax.device_get(length)), cur_step)
        wandb.log({"generations": self.log_table}, step=cur_step)
    # Every array returned below was passed through ``self._all_gather`` above
    # or is derived from gathered values.
    return (
        {
            "prompt_ids": prompt_ids,
            "prompt_mask": prompt_mask,
            "completion_ids": completion_ids,
            "completion_mask": completion_mask,
            "ref_per_token_logps": ref_per_token_logps,
            "advantages": advantages,
            "num_items_in_batch": jnp.sum(completion_mask),
        },
        metrics_dict,
    )
def on_step_end(
    self,
    state: EasyDeLState,
    metrics: MetricsType,
    step: int,
) -> tuple[EasyDeLState, MetricsType]:
    """Hook called at the end of a training step; optionally refreshes the reference model.

    When ``arguments.sync_ref_model`` is set, a reference state exists and
    ``step`` is a multiple of ``arguments.ref_model_sync_steps``, the
    reference parameters are replaced by
    ``alpha * policy + (1 - alpha) * reference`` with
    ``alpha = arguments.ref_model_mixup_alpha``.

    Args:
        state: Current policy state; returned unchanged.
        metrics: Metrics of the finished step; returned unchanged.
        step: Index of the step that just finished.

    Returns:
        The ``(state, metrics)`` pair that was passed in.
    """
    should_sync = (
        self.arguments.sync_ref_model
        and self.ref_state is not None
        and (step % self.arguments.ref_model_sync_steps == 0)
    )
    if should_sync:
        alpha = self.arguments.ref_model_mixup_alpha
        # Blend a copy of the current policy parameters into the reference.
        new_graphstate = jax.tree_util.tree_map(
            lambda new, old: alpha * new + (1 - alpha) * old,
            deepcopy_model(state.graphstate),
            self.ref_state.graphstate,
        )
        self.ref_state = self.ref_state.replace(graphstate=new_graphstate)
    return state, metrics