easydel.trainers.odds_ratio_preference_optimization_trainer.orpo_trainer

class easydel.trainers.odds_ratio_preference_optimization_trainer.orpo_trainer.ORPOTrainer(arguments: ORPOConfig, model: EasyDeLBaseModule | EasyDeLState | None = None, data_collator: DPODataCollatorWithPaddingTFDS | DPODataCollatorWithPaddingGrain | None = None, train_dataset: Dataset | IterableDataset | ShardedDataSource | None = None, eval_dataset: Dataset | IterableDataset | ShardedDataSource | dict[str, Dataset] | None = None, processing_class: ProcessingClassType = None)

Bases: Trainer

Odds Ratio Preference Optimization trainer.

ORPO is a reference-free preference optimization method that directly optimizes the odds ratio between preferred and rejected responses. Unlike DPO, ORPO doesn’t require a reference model, making it more memory-efficient while maintaining competitive performance.
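As a rough illustration of the odds-ratio objective, the sketch below follows the formulation in the ORPO paper: a supervised negative log-likelihood term on the preferred response plus a weighted odds-ratio penalty. It is not EasyDeL's implementation. The function names are hypothetical, the inputs are assumed to be length-averaged log-probabilities (strictly negative), and treating `beta` as the counterpart of `orpo_beta` is an assumption.

```python
import math


def log_odds(avg_logp: float) -> float:
    """Return log(p / (1 - p)) for p = exp(avg_logp); requires avg_logp < 0."""
    return avg_logp - math.log1p(-math.exp(avg_logp))


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def orpo_loss(chosen_avg_logp: float, rejected_avg_logp: float, beta: float = 0.1) -> float:
    """Hypothetical per-example ORPO loss (sketch, not the library's code).

    No reference model appears anywhere: both terms depend only on the
    policy's own log-probabilities.
    """
    # Log of the odds ratio between the preferred and rejected responses.
    log_odds_ratio = log_odds(chosen_avg_logp) - log_odds(rejected_avg_logp)
    # -log(sigmoid(x)) equals softplus(-x).
    odds_ratio_term = softplus(-log_odds_ratio)
    # Supervised term: negative log-likelihood of the preferred response.
    nll_term = -chosen_avg_logp
    return nll_term + beta * odds_ratio_term


good = orpo_loss(chosen_avg_logp=-0.5, rejected_avg_logp=-2.0)
bad = orpo_loss(chosen_avg_logp=-2.0, rejected_avg_logp=-0.5)
```

Under these assumptions, `good` is smaller than `bad`, since the loss rewards assigning higher probability to the preferred response.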

The trainer applies its preprocessing transforms lazily, as examples are iterated, which performs better than eager Hugging Face .map() calls that process the whole dataset up front.
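The difference between lazy and eager preprocessing can be sketched with a plain Python generator. This is only a conceptual illustration: `lazy_map` and `fake_tokenize` are hypothetical names, and the trainer's actual transform pipeline is not shown in this page.

```python
from typing import Any, Callable, Dict, Iterable, Iterator, List

Example = Dict[str, Any]


def lazy_map(fn: Callable[[Example], Example], source: Iterable[Example]) -> Iterator[Example]:
    """Apply `fn` to each example only when the consumer asks for it."""
    for example in source:
        yield fn(example)


calls: List[str] = []  # records which prompts have actually been processed


def fake_tokenize(example: Example) -> Example:
    """Stand-in for a tokenization transform (whitespace split, illustration only)."""
    calls.append(example["prompt"])
    return {**example, "prompt_len": len(example["prompt"].split())}


dataset = [{"prompt": "a b c"}, {"prompt": "d e"}, {"prompt": "f"}]

stream = lazy_map(fake_tokenize, dataset)  # nothing has been processed yet
first = next(stream)                       # processes exactly one example
```

After pulling a single item, only the first example has been transformed; an eager pass would have processed all three before training could start.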

arguments

ORPOConfig with training hyperparameters

Type

easydel.trainers.odds_ratio_preference_optimization_trainer.orpo_config.ORPOConfig

processing_class

Tokenizer or processor for text encoding

padding_value

Token ID used for padding

Example

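>>> from easydel.trainers.odds_ratio_preference_optimization_trainer.orpo_config import ORPOConfig
>>> from easydel.trainers.odds_ratio_preference_optimization_trainer.orpo_trainer import ORPOTrainer
>>> # `model`, `tokenizer`, and `preference_dataset` are assumed to be defined already.
>>> config = ORPOConfig(
...     per_device_train_batch_size=4,
...     orpo_beta=0.1,
...     learning_rate=5e-6,
...     max_prompt_length=512,
...     max_completion_length=512,
... )
>>> trainer = ORPOTrainer(
...     arguments=config,
...     model=model,
...     train_dataset=preference_dataset,
...     processing_class=tokenizer,
... )
>>> trainer.train()

arguments: ORPOConfig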
configure_functions() -> TrainerConfigureFunctionOutput

Configure and JIT-compile training and evaluation step functions.

create_grain_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start'] = 'keep_end') -> Callable

Create data collection function for Grain batching.

create_tfds_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start'] = 'keep_end') -> Callable

Create data collection function for TFDS batching.
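Both collect functions take `max_sequence_length` and `truncation_mode`. The sketch below shows what these two modes conventionally mean for a single token sequence. It is an assumption inferred from the parameter names, not the library's code: `truncate_and_pad` is a hypothetical helper, and right-padding with `padding_value` is likewise assumed.

```python
from typing import List, Literal


def truncate_and_pad(
    tokens: List[int],
    max_sequence_length: int,
    truncation_mode: Literal["keep_end", "keep_start"] = "keep_end",
    padding_value: int = 0,
) -> List[int]:
    """Fit `tokens` to exactly `max_sequence_length` entries (illustrative only)."""
    if max_sequence_length <= 0:
        raise ValueError("max_sequence_length must be positive")
    if truncation_mode == "keep_end":
        kept = tokens[-max_sequence_length:]   # drop tokens from the start
    elif truncation_mode == "keep_start":
        kept = tokens[:max_sequence_length]    # drop tokens from the end
    else:
        raise ValueError(f"unknown truncation_mode: {truncation_mode!r}")
    # Assumed right-padding so every sequence in a batch has the same length.
    return kept + [padding_value] * (max_sequence_length - len(kept))


tail = truncate_and_pad([1, 2, 3, 4, 5], 3)                                # keeps the last 3 tokens
head = truncate_and_pad([1, 2, 3, 4, 5], 3, truncation_mode="keep_start")  # keeps the first 3 tokens
padded = truncate_and_pad([1, 2], 4)                                       # short input gets padded
```

With prompts, 'keep_end' is the usual default because the tokens closest to the completion tend to matter most, though whether that rationale applies here is not stated in this page.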