easydel.trainers.reward_trainer.reward_trainer

class easydel.trainers.reward_trainer.reward_trainer.RewardTrainer(arguments: RewardConfig, processing_class: ProcessingClassType, model: EasyDeLBaseModule | EasyDeLState | None = None, train_dataset: Dataset | IterableDataset | ShardedDataSource | None = None, eval_dataset: Dataset | IterableDataset | ShardedDataSource | dict[str, Dataset] | None = None, data_collator: RewardDataCollatorWithPaddingTFDS | RewardDataCollatorWithPaddingGrain | None = None)

Bases: Trainer

Reward model trainer for RLHF pipelines.

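Trains reward models that learn to score responses according to human preferences. In an RLHF pipeline, the trained reward model typically supplies the feedback signal for policy optimization methods such as PPO. (Direct preference methods such as DPO optimize the policy on preference pairs directly and do not require a separate reward model.)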
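A common objective for learning such scores from preference pairs is the Bradley-Terry pairwise loss, -log(sigmoid(r_chosen - r_rejected)), which is small when the chosen response scores well above the rejected one. The sketch below assumes that objective; this page does not state which loss RewardTrainer actually optimizes, and pairwise_reward_loss and softplus are hypothetical helpers written in plain Python, not EasyDeL functions.

```python
import math
from typing import Sequence


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def pairwise_reward_loss(chosen_rewards: Sequence[float],
                         rejected_rewards: Sequence[float]) -> float:
    """Hypothetical helper: mean Bradley-Terry loss over a batch of preference pairs."""
    if len(chosen_rewards) != len(rejected_rewards) or not chosen_rewards:
        raise ValueError("need equally sized, non-empty reward lists")
    # -log(sigmoid(m)) == log(1 + exp(-m)) == softplus(-m), with m = r_chosen - r_rejected
    losses = [softplus(-(c - r)) for c, r in zip(chosen_rewards, rejected_rewards)]
    return sum(losses) / len(losses)


# Scalar scores a reward model might assign to two (chosen, rejected) pairs.
print(pairwise_reward_loss([1.5, 0.2], [-0.5, 0.4]))
```

Using softplus instead of computing log(sigmoid(...)) directly keeps the loss finite even for very large score margins.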

The trainer applies preprocessing lazily, as transforms executed while the dataset is iterated, instead of eagerly through Hugging Face .map() calls. This avoids a full preprocessing pass over the dataset before training starts.
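The difference between lazy and eager preprocessing can be illustrated in plain Python. LazyMap below is a hypothetical stand-in, not EasyDeL's transform API: it stores the source and a per-example function and runs the function only as examples are pulled, whereas an eager Dataset.map() call processes the whole dataset up front.

```python
class LazyMap:
    """Hypothetical lazy transform: applies `fn` to each example only when iterated."""

    def __init__(self, source, fn):
        self.source = source  # any re-iterable collection of examples
        self.fn = fn          # per-example preprocessing function

    def __iter__(self):
        # Nothing is precomputed: each example is transformed as it is yielded.
        for example in self.source:
            yield self.fn(example)


calls = []


def count_words(example):
    calls.append(example)  # record every invocation to show when work happens
    return {"chosen_len": len(example["chosen"].split()),
            "rejected_len": len(example["rejected"].split())}


raw = [{"chosen": "a helpful answer", "rejected": "no"},
       {"chosen": "yes indeed", "rejected": "meh"}]

lazy = LazyMap(raw, count_words)  # constructing the wrapper calls count_words zero times
first = next(iter(lazy))          # only the first example is processed here
```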

arguments

RewardConfig with training hyperparameters

Type

easydel.trainers.training_configurations.TrainingArguments

processing_class

Tokenizer or processor for text encoding

Example

>>> config = RewardConfig(
...     per_device_train_batch_size=8,
...     learning_rate=2e-5,
...     max_sequence_length=512
... )
>>> trainer = RewardTrainer(
...     arguments=config,
...     model=reward_model,
...     train_dataset=preference_dataset,
...     processing_class=tokenizer
... )
>>> trainer.train()

Note

The dataset should contain ‘chosen’ and ‘rejected’ columns holding the text of the preferred and non-preferred responses, respectively.
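A quick sanity check on the dataset shape described in the note might look like the sketch below. validate_preference_example is a hypothetical helper, and it assumes plain-string columns as the note describes; it is not a check that RewardTrainer itself performs.

```python
def validate_preference_example(example: dict) -> None:
    """Hypothetical check that a record has non-empty 'chosen' and 'rejected' text."""
    for key in ("chosen", "rejected"):
        if key not in example:
            raise KeyError(f"missing required column: {key!r}")
        value = example[key]
        if not isinstance(value, str) or not value.strip():
            raise ValueError(f"column {key!r} must be a non-empty string")


preference_rows = [
    {"chosen": "Paris is the capital of France.",
     "rejected": "France has no capital."},
]
for row in preference_rows:
    validate_preference_example(row)  # raises on a malformed record
```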

configure_functions() → TrainerConfigureFunctionOutput

Configure and JIT-compile training and evaluation step functions.

create_grain_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start'] = 'keep_end') → Callable

Create the data collection function used for Grain batching.
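Judging by its name, truncation_mode selects which end of an over-long sequence survives. The sketch below assumes 'keep_end' keeps the last max_sequence_length tokens and 'keep_start' keeps the first ones; truncate_ids is a hypothetical helper, not part of EasyDeL.

```python
from typing import List, Literal


def truncate_ids(ids: List[int], max_length: int,
                 mode: Literal["keep_end", "keep_start"] = "keep_end") -> List[int]:
    """Hypothetical helper: shorten `ids` to at most `max_length` tokens."""
    if max_length <= 0:
        raise ValueError("max_length must be positive")
    if len(ids) <= max_length:
        return list(ids)
    if mode == "keep_end":
        # Drop tokens from the front so the end of the sequence survives.
        return list(ids[-max_length:])
    if mode == "keep_start":
        # Drop tokens from the back so the beginning survives.
        return list(ids[:max_length])
    raise ValueError(f"unknown truncation mode: {mode!r}")


print(truncate_ids([1, 2, 3, 4, 5], 3))                     # keep_end
print(truncate_ids([1, 2, 3, 4, 5], 3, mode="keep_start"))  # keep_start
```

Keeping the end is a plausible default for reward modeling, since the response being scored usually follows the prompt.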

create_tfds_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start'] = 'keep_end') → Callable

Create the data collection function used for TFDS batching.
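Both collection functions receive a fixed max_sequence_length. Padding every batch to one fixed length keeps array shapes static, which matters under JAX's jit because a new input shape triggers recompilation. The sketch below shows right-padding of chosen/rejected pairs with attention masks; pad_pair_batch, the pad id of 0, and the output key names are illustrative assumptions, not the layout EasyDeL's collators produce.

```python
from typing import Dict, List, Sequence


def pad_pair_batch(chosen: Sequence[Sequence[int]],
                   rejected: Sequence[Sequence[int]],
                   max_length: int,
                   pad_id: int = 0) -> Dict[str, List[List[int]]]:
    """Hypothetical collator: right-pad chosen/rejected token ids to a fixed length."""

    def pad(seqs: Sequence[Sequence[int]]):
        ids, mask = [], []
        for seq in seqs:
            if len(seq) > max_length:
                # This sketch expects sequences that were already truncated.
                raise ValueError(f"sequence of length {len(seq)} exceeds {max_length}")
            n_pad = max_length - len(seq)
            ids.append(list(seq) + [pad_id] * n_pad)
            mask.append([1] * len(seq) + [0] * n_pad)  # 1 = real token, 0 = padding
        return ids, mask

    chosen_ids, chosen_mask = pad(chosen)
    rejected_ids, rejected_mask = pad(rejected)
    return {
        "input_ids_chosen": chosen_ids,
        "attention_mask_chosen": chosen_mask,
        "input_ids_rejected": rejected_ids,
        "attention_mask_rejected": rejected_mask,
    }


batch = pad_pair_batch(chosen=[[5, 6, 7]], rejected=[[8]], max_length=4)
```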