easydel.trainers.group_relative_policy_optimization.grpo_trainer

class easydel.trainers.group_relative_policy_optimization.grpo_trainer.GRPOTrainer(arguments: GRPOConfig, model: EasyDeLBaseModule | EasyDeLState | None, reward_funcs: RewardFunc | list[RewardFunc], train_dataset: Dataset | IterableDataset | ShardedDataSource | None = None, eval_dataset: Dataset | IterableDataset | ShardedDataSource | dict[str, Dataset] | None = None, processing_class: ProcessingClassType = None, reward_processing_classes: ProcessingClassType = None, data_tokenize_fn: tp.Callable | None = None)

Bases: Trainer

Group Relative Policy Optimization trainer for RLHF.

GRPO is a reinforcement-learning method that samples a group of completions for each prompt and scores each completion relative to the other members of its group. Normalizing rewards by the group's mean and standard deviation provides a baseline without training a separate value (critic) model, as PPO requires, which reduces the variance of the advantage estimates and the memory cost of training.
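The group-relative scoring can be illustrated with a minimal sketch. For a group of G completions of the same prompt with rewards r_1, ..., r_G, each completion's advantage is (r_i - mean(r)) / (std(r) + eps). The function name `group_relative_advantages`, the flat prompt-major layout of the rewards, the use of the population standard deviation, and the `eps` value are assumptions for illustration, not EasyDeL's implementation.

```python
import statistics


def group_relative_advantages(rewards, group_size, eps=1e-4):
    """Normalize rewards within consecutive groups of `group_size`.

    Assumes `rewards` is a flat list laid out prompt-major: the first
    `group_size` entries belong to prompt 0, the next `group_size` to
    prompt 1, and so on (hypothetical layout).
    """
    if group_size <= 0 or len(rewards) % group_size != 0:
        raise ValueError("len(rewards) must be a positive multiple of group_size")
    advantages = []
    for start in range(0, len(rewards), group_size):
        group = rewards[start:start + group_size]
        mean = statistics.fmean(group)
        # Population std of the group (some implementations use the sample
        # std). A group with identical rewards has std 0, so all of its
        # advantages are 0 and it contributes no learning signal.
        std = statistics.pstdev(group)
        advantages.extend((r - mean) / (std + eps) for r in group)
    return advantages


# Two prompts with four sampled completions each.
rewards = [1.0, 0.0, 0.0, 1.0, 0.5, 0.5, 0.5, 0.5]
print(group_relative_advantages(rewards, group_size=4))
```

The first group yields advantages close to +1 and -1, while the second group, whose completions all scored the same, yields zeros.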

Key features:
  • Group-based advantage normalization
  • Stable policy updates with KL regularization
  • Support for multiple reward models
  • Efficient generation and scoring pipeline
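`reward_funcs` accepts a single reward function or a list of them. One common way to use several rewards, sketched here under stated assumptions, is to score every completion with each function and sum the (optionally weighted) scores before group normalization. The call signature `fn(prompts, completions) -> list[float]`, the two toy reward functions, and the weighting scheme are hypothetical; consult `RewardFunc` in EasyDeL for the actual contract and for how the trainer combines multiple rewards.

```python
def length_reward(prompts, completions):
    # Hypothetical reward: prefer completions close to 20 characters.
    return [-abs(len(c) - 20) / 20 for c in completions]


def keyword_reward(prompts, completions):
    # Hypothetical reward: 1.0 if the completion mentions "answer".
    return [1.0 if "answer" in c.lower() else 0.0 for c in completions]


def combine_rewards(reward_fns, prompts, completions, weights=None):
    """Weighted sum of per-completion scores from several reward functions."""
    weights = weights or [1.0] * len(reward_fns)
    if len(weights) != len(reward_fns):
        raise ValueError("need exactly one weight per reward function")
    totals = [0.0] * len(completions)
    for fn, weight in zip(reward_fns, weights):
        scores = fn(prompts, completions)
        for i, score in enumerate(scores):
            totals[i] += weight * score
    return totals


prompts = ["What is 2 + 2?", "What is 2 + 2?"]
completions = ["The answer is 4.", "4"]
print(combine_rewards([length_reward, keyword_reward], prompts, completions, weights=[0.5, 1.0]))
```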

arguments

GRPOConfig instance with training hyperparameters

Type

easydel.trainers.group_relative_policy_optimization.grpo_config.GRPOConfig

ref_state

Reference model state for KL divergence computation
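For the KL term, the GRPO paper (DeepSeekMath) uses a per-token estimator of KL(policy || reference) that is never negative: exp(log p_ref - log p_policy) - (log p_ref - log p_policy) - 1, scaled by a coefficient beta and subtracted from the objective. The sketch below computes that estimator from the log-probabilities of the sampled tokens. Whether EasyDeL's trainer uses exactly this estimator, and the helper names `per_token_kl` and `kl_penalty`, are assumptions.

```python
import math


def per_token_kl(policy_logps, ref_logps):
    """Per-token KL estimate from sampled-token log-probabilities.

    `policy_logps` come from the model being trained, `ref_logps` from the
    frozen reference model (the role played by `ref_state`).
    """
    if len(policy_logps) != len(ref_logps):
        raise ValueError("log-prob sequences must have the same length")
    kls = []
    for lp, ref_lp in zip(policy_logps, ref_logps):
        delta = ref_lp - lp
        # exp(x) - x - 1 >= 0 for every real x, so each estimate is non-negative.
        kls.append(math.exp(delta) - delta - 1.0)
    return kls


def kl_penalty(policy_logps, ref_logps, beta):
    """beta times the mean per-token KL over one completion (hypothetical helper)."""
    kls = per_token_kl(policy_logps, ref_logps)
    if not kls:
        raise ValueError("completion must contain at least one token")
    return beta * sum(kls) / len(kls)


print(kl_penalty([-1.0, -1.0], [-1.0, -2.0], beta=0.1))
```

When the policy and reference assign identical log-probabilities, the penalty is exactly zero; it grows as the policy drifts away from the reference.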

processing_class

Tokenizer or processor for text encoding

reward_processing_classes

Optional separate processors for reward models

generation_config

Configuration for response generation

data_tokenize_fn

Function to tokenize dataset samples

Example

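>>> from easydel.trainers.group_relative_policy_optimization.grpo_config import GRPOConfig
>>> from easydel.trainers.group_relative_policy_optimization.grpo_trainer import GRPOTrainer
>>> # `model`, `reward_model`, `dataset`, and `tokenizer` are assumed to be defined already.
>>> config = GRPOConfig(
...     per_device_train_batch_size=4,
...     grpo_n_samples=4,
...     grpo_beta=0.1,
...     learning_rate=1e-6,
... )
>>> trainer = GRPOTrainer(
...     arguments=config,
...     model=model,
...     reward_funcs=reward_model,
...     train_dataset=dataset,
...     processing_class=tokenizer,
... )
>>> trainer.train()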

arguments: GRPOConfig

configure_functions() TrainerConfigureFunctionOutput

Configures and JIT-compiles the training and evaluation step functions.

This method sets up the necessary functions for training and evaluation, including:
  • Initialization of the model state.

  • Sharding of the model parameters and optimizer state.

  • JIT-compilation of the training and evaluation step functions.

Returns

An object containing the configured functions and other relevant information.

Return type

TrainerConfigureFunctionOutput

create_grain_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start'] = 'keep_end') Callable

Creates a data collator for Grain data loading, using max_sequence_length and truncation_mode to control how sequences are truncated.

create_tfds_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start'] = 'keep_end') Callable

Creates a data collator for TFDS data loading, using max_sequence_length and truncation_mode to control how sequences are truncated.
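Both collators take `max_sequence_length` and `truncation_mode`. Judging by the parameter names, 'keep_end' keeps the last tokens of an over-long sequence and 'keep_start' keeps the first. The pure-Python sketch below shows that truncation followed by padding to a fixed length. The left-padding side, `pad_token_id`, the `input_ids`/`attention_mask` field names, and the helpers `truncate` and `make_collate_fn` are assumptions; the real collators operate on the trainer's own batch format.

```python
def truncate(ids, max_len, mode="keep_end"):
    """Trim a token-id list to at most `max_len` tokens."""
    if mode == "keep_end":
        # Keep the last max_len tokens.
        return ids[max(len(ids) - max_len, 0):]
    if mode == "keep_start":
        # Keep the first max_len tokens.
        return ids[:max_len]
    raise ValueError(f"unknown truncation_mode: {mode!r}")


def make_collate_fn(max_sequence_length, truncation_mode="keep_end", pad_token_id=0):
    """Return a hypothetical collate fn: truncate, then left-pad to a fixed length."""

    def collate(batch):
        input_ids, attention_mask = [], []
        for example in batch:
            ids = truncate(example["input_ids"], max_sequence_length, truncation_mode)
            pad = max_sequence_length - len(ids)
            # Left padding keeps each prompt adjacent to where generation starts.
            input_ids.append([pad_token_id] * pad + ids)
            attention_mask.append([0] * pad + [1] * len(ids))
        return {"input_ids": input_ids, "attention_mask": attention_mask}

    return collate


collate = make_collate_fn(max_sequence_length=4, truncation_mode="keep_end")
print(collate([{"input_ids": [5, 6, 7, 8, 9]}, {"input_ids": [1, 2]}]))
```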

on_step_end(state: EasyDeLState, metrics: dict[str, float | list | tuple | numpy.ndarray | Any], step: int) tuple[easydel.infra.base_state.EasyDeLState, dict[str, float | list | tuple | numpy.ndarray | Any]]

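Hook called at the end of each training step. It receives the current state, the step's metrics, and the step index, and returns the (possibly updated) state and metrics.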
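A subclass of `GRPOTrainer` could override `on_step_end` (calling `super().on_step_end(state, metrics, step)` first) to post-process metrics, as long as it returns a `(state, metrics)` tuple. The stand-alone function below only illustrates that contract with plain Python values in place of `EasyDeLState`; the `log_every` parameter and the metric names `rewards`, `reward_mean`, and `should_log` are hypothetical.

```python
def on_step_end(state, metrics, step, log_every=10):
    """Hypothetical hook body: pass the state through and annotate metrics."""
    metrics = dict(metrics)  # copy so the caller's dict is not mutated
    rewards = metrics.get("rewards")
    if rewards:
        metrics["reward_mean"] = sum(rewards) / len(rewards)
    metrics["should_log"] = step % log_every == 0
    return state, metrics


state, metrics = on_step_end(state="dummy-state", metrics={"rewards": [1.0, 0.0, 0.5]}, step=20)
print(metrics["reward_mean"], metrics["should_log"])
```

Returning a new metrics dict rather than mutating the input keeps the hook free of side effects on whatever the trainer passed in.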

property step_sharding

easydel.trainers.group_relative_policy_optimization.grpo_trainer.delete_tree(pytree)