easydel.trainers.group_relative_policy_optimization.grpo_trainer#
- class easydel.trainers.group_relative_policy_optimization.grpo_trainer.GRPOTrainer(arguments: GRPOConfig, model: EasyDeLBaseModule | EasyDeLState | None, reward_funcs: RewardFunc | list[RewardFunc], train_dataset: Dataset | IterableDataset | ShardedDataSource | None = None, eval_dataset: Dataset | IterableDataset | ShardedDataSource | dict[str, Dataset] | None = None, processing_class: ProcessingClassType = None, reward_processing_classes: ProcessingClassType = None, data_tokenize_fn: tp.Callable | None = None)[source]#
Bases: Trainer
Group Relative Policy Optimization trainer for RLHF.
GRPO is a reinforcement learning method that optimizes a policy by comparing the responses sampled within a group. Scoring each response relative to its group reduces variance and is intended to give more stable training and better convergence than standard PPO in preference-based learning tasks.
Key features:
- Group-based advantage normalization
- Stable policy updates with KL regularization
- Support for multiple reward models
- Efficient generation and scoring pipeline
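To make "group-based advantage normalization" concrete, here is a minimal pure-Python sketch. It is not EasyDeL's implementation: the function name `group_relative_advantages`, the `eps` value, and the assumption that completions for one prompt are stored consecutively are all illustrative choices.

```python
import math


def group_relative_advantages(rewards, group_size, eps=1e-4):
    """Normalize rewards within consecutive groups of `group_size` completions.

    Hypothetical helper for illustration only; the layout (one group per
    prompt, stored back to back) is an assumption.
    """
    if group_size <= 0 or len(rewards) % group_size != 0:
        raise ValueError("len(rewards) must be a positive multiple of group_size")
    advantages = []
    for start in range(0, len(rewards), group_size):
        group = rewards[start:start + group_size]
        mean = sum(group) / group_size
        # Population standard deviation of the rewards in this group.
        std = math.sqrt(sum((r - mean) ** 2 for r in group) / group_size)
        # eps keeps the division finite when every reward in a group is equal.
        advantages.extend((r - mean) / (std + eps) for r in group)
    return advantages


# Two prompts with two sampled completions each.
advantages = group_relative_advantages([1.0, 3.0, 2.0, 2.0], group_size=2)
```

In the first group the better completion gets a positive advantage and the worse one a negative advantage of equal magnitude. In the second group both rewards are equal, so both advantages are zero and that group contributes no learning signal.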
- arguments#
GRPOConfig instance with training hyperparameters
- ref_state#
Reference model state for KL divergence computation
- processing_class#
Tokenizer or processor for text encoding
- reward_processing_classes#
Optional separate processors for reward models
- generation_config#
Configuration for response generation
- data_tokenize_fn#
Function to tokenize dataset samples
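The `ref_state` attribute above exists so that a KL term can be computed between the trained policy and a frozen reference model. The sketch below shows the per-token estimator commonly used with GRPO, `exp(d) - d - 1` with `d = ref_logp - policy_logp`. Whether EasyDeL uses exactly this form is an assumption, and `per_token_kl` is a hypothetical name.

```python
import math


def per_token_kl(policy_logps, ref_logps):
    """Per-token KL estimate between policy and reference log-probabilities.

    Uses exp(d) - d - 1 with d = ref_logp - policy_logp, which is always
    non-negative and is zero when the two models agree on a token.
    Illustrative only; not taken from the EasyDeL source.
    """
    kl = []
    for logp, ref_logp in zip(policy_logps, ref_logps):
        d = ref_logp - logp
        kl.append(math.exp(d) - d - 1.0)
    return kl


# Log-probabilities of two generated tokens under each model.
kl = per_token_kl(policy_logps=[-1.0, -2.0], ref_logps=[-1.0, -1.5])
```

The first token has identical log-probabilities under both models, so its penalty is zero. The second token diverges and receives a small positive penalty, which a KL coefficient would then scale in the loss.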
Example
>>> config = GRPOConfig(
...     per_device_train_batch_size=4,
...     grpo_n_samples=4,
...     grpo_beta=0.1,
...     learning_rate=1e-6
... )
>>> trainer = GRPOTrainer(
...     arguments=config,
...     model=model,
...     reward_funcs=reward_model,
...     train_dataset=dataset,
...     processing_class=tokenizer
... )
>>> trainer.train()
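The `reward_funcs` argument accepts a single reward function or a list of them. This page does not document the `RewardFunc` calling convention, so the sketch below assumes a callable that receives `prompts` and `completions` as lists of plain strings and returns one float per completion. Both the signature and the name `length_reward` are hypothetical; check them against the library before use.

```python
def length_reward(prompts, completions, **kwargs):
    """Hypothetical reward: favor completions close to a target word count.

    Assumes `completions` is a list of plain strings; the real RewardFunc
    contract may differ.
    """
    target = 20
    rewards = []
    for completion in completions:
        n_words = len(completion.split())
        # 0.0 at exactly `target` words, more negative the further away.
        rewards.append(-abs(n_words - target) / target)
    return rewards


scores = length_reward(prompts=["Say hi"], completions=["hi there"])
```

Under that assumption, such a function could be passed as `reward_funcs=length_reward`, or combined with other callables as `reward_funcs=[length_reward, other_reward]`.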
- arguments: GRPOConfig#
- configure_functions() TrainerConfigureFunctionOutput[source]#
Configures and JIT-compiles the training and evaluation step functions.
This method sets up the necessary functions for training and evaluation, including:
- Initialization of the model state.
- Sharding of the model parameters and optimizer state.
- JIT-compilation of the training and evaluation step functions.
- Returns
An object containing the configured functions and other relevant information.
- Return type
TrainerConfigureFunctionOutput
- create_grain_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start'] = 'keep_end') Callable[source]#
Create data collator for Grain data loading.
- create_tfds_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start'] = 'keep_end') Callable[source]#
Create data collator for TFDS data loading.
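Both collator factories above take a `truncation_mode` of `'keep_end'` or `'keep_start'`. The sketch below shows what those names suggest for a sequence longer than `max_sequence_length`. The semantics are an assumption inferred from the option names, and `truncate` is a hypothetical helper, not part of EasyDeL.

```python
def truncate(token_ids, max_sequence_length, truncation_mode="keep_end"):
    """Trim a token list to at most `max_sequence_length` items.

    'keep_end' keeps the last tokens, 'keep_start' keeps the first ones.
    Illustrative only; the collators' actual behavior may differ.
    """
    if max_sequence_length <= 0:
        raise ValueError("max_sequence_length must be positive")
    if truncation_mode == "keep_end":
        return token_ids[-max_sequence_length:]
    if truncation_mode == "keep_start":
        return token_ids[:max_sequence_length]
    raise ValueError(f"unknown truncation_mode: {truncation_mode!r}")


tail = truncate([1, 2, 3, 4, 5], 3, "keep_end")
head = truncate([1, 2, 3, 4, 5], 3, "keep_start")
```

For prompts, keeping the end is often the safer default because the most recent context and the instruction that triggers generation usually sit at the end of the sequence.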
- on_step_end(state: EasyDeLState, metrics: dict[str, float | list | tuple | numpy.ndarray | Any], step: int) tuple[easydel.infra.base_state.EasyDeLState, dict[str, float | list | tuple | numpy.ndarray | Any]][source]#
Hook called at the end of each training step. Returns the state and metrics, which the hook may have updated.
- property step_sharding#