easydel.trainers.odds_ratio_preference_optimization_trainer.orpo_config#

class easydel.trainers.odds_ratio_preference_optimization_trainer.orpo_config.ORPOConfig(auto_shard_states: bool = True, aux_loss_enabled: bool = False, backend: str | None = None, clip_grad: float | None = None, custom_scheduler: tp.Callable[[int], tp.Any] | None = None, dataloader_num_workers: int | None = 0, dataloader_pin_memory: bool | None = False, do_eval: bool = False, do_last_save: bool = True, do_train: bool = True, eval_batch_size: int | None = None, evaluation_steps: int | None = None, extra_optimizer_kwargs: dict = <factory>, frozen_parameters: str | None = None, grain_shard_index: int | None = None, grain_shard_count: int | None = None, gradient_accumulation_steps: int = 1, ids_to_pop_from_dataset: list[str] | None = <factory>, is_fine_tuning: bool = True, init_tx: bool = True, jax_distributed_config: dict | None = None, learning_rate: float = 1e-06, learning_rate_end: float | None = None, log_all_workers: bool = False, log_grad_norms: bool = True, report_metrics: bool = True, log_steps: int = 10, loss_config: LossConfig | None = None, low_mem_usage: bool = True, max_evaluation_steps: int | None = None, max_sequence_length: int | None = 4096, max_training_steps: int | None = None, per_epoch_training_steps: int | None = None, per_epoch_evaluation_steps: int | None = None, model_name: str | None = None, model_parameters: dict | None = None, metrics_to_show_in_rich_pbar: list[str] | None = None, generation_top_p: float | None = None, generation_top_k: int | None = None, generation_temperature: float | None = None, generation_do_sample: bool | None = None, generation_num_return_sequences: int | None = None, generation_max_new_tokens: int | None = None, generation_shard_inputs: bool = True, generation_interval: int | None = None, generation_prompts: list[str | dict[str, tp.Any]] = <factory>, generation_use_train_prompts: bool = False, generation_num_prompts: int = 1, generation_dataset_prompt_field: str | None = 'prompt', generation_extra_kwargs: dict[str, 
tp.Any] | None = None, generation_config_overrides: dict[str, tp.Any] | None = None, generation_seed: int | None = None, generation_preview_print: bool = False, generation_log_to_wandb: bool = True, use_esurge_generation: bool = True, esurge_use_tqdm: bool = True, esurge_hbm_utilization: float | None = 0.45, esurge_max_num_seqs: int | None = None, esurge_min_input_pad: int | None = None, esurge_page_size: int | None = 32, esurge_silent_mode: bool = True, num_train_epochs: int = 10, offload_dataset: bool = False, offload_device_type: str = 'cpu', offload_device_index: int = 0, optimizer: AVAILABLE_OPTIMIZERS = EasyDeLOptimizers.ADAMW, performance_mode: bool = False, pruning_module: tp.Any = None, process_zero_is_admin: bool = True, progress_bar_type: tp.Literal['tqdm', 'rich', 'json'] = 'tqdm', remove_ckpt_after_load: bool = False, remove_unused_columns: bool = True, report_steps: int = 5, save_interval_minutes: float | None = None, save_directory: str = 'EasyDeL-Checkpoints', save_optimizer_state: bool = True, save_steps: int | None = None, save_total_limit: int | None = None, scheduler: AVAILABLE_SCHEDULERS = EasyDeLSchedulers.NONE, shuffle_seed_train: int = 64871, sparsify_module: bool = False, sparse_module_type: AVAILABLE_SPARSE_MODULE_TYPES = 'bcoo', state_apply_fn_kwarguments_to_model: dict | None = None, step_partition_spec: PartitionSpec = PartitionSpec(('dp', 'fsdp'), 'sp'), step_start_point: int | None = None, resume_if_possible: bool = True, shuffle_train_dataset: bool = True, total_batch_size: int = 32, training_time_limit: str | None = None, train_on_inputs: bool = True, trainer_prefix: str | None = 'orpotrainer', truncation_mode: tp.Literal['keep_end', 'keep_start'] = 'keep_end', tx_mu_dtype: jnp.dtype | None = None, track_memory: bool | float = False, use_data_collactor: bool = True, use_grain: bool = True, use_wandb: bool = True, verbose: bool = True, wandb_entity: str | None = None, wandb_name: str | None = None, warmup_steps: int = 0, 
weight_decay: float = 0.01, weight_distribution_pattern: str = '.*', weight_distribution_log_steps: int = 50, _can_log_metrics: bool | None = None, _im_a_hidden_checkpoint_manager: Checkpointer | None = None, max_length: int | None = 1024, max_prompt_length: int | None = 512, max_completion_length: int | None = None, beta: float = 0.1, disable_dropout: bool = True, label_pad_token_id: int = -100, padding_value: int | None = None, generate_during_eval: bool = False, is_encoder_decoder: bool | None = None, dataset_num_proc: int | None = None)[source]#

Bases: TrainingArguments

Configuration class for Odds Ratio Preference Optimization training.

ORPO is a reference-free preference optimization method that uses odds ratios to model preferences between chosen and rejected responses. Unlike DPO, ORPO doesn’t require a reference model, making it more memory-efficient and simpler to implement while achieving comparable or better performance.

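The key idea of ORPO is to express preference learning through log-odds, $\log\frac{p}{1-p}$, where $p$ is the model's likelihood of a response. An odds-ratio penalty contrasts the log-odds of the chosen and rejected responses. This penalty is added directly to the standard supervised fine-tuning (negative log-likelihood) loss on the chosen response. Contrasting odds rather than raw probabilities gives better-behaved gradients, and training against the SFT objective removes the need for KL regularization with a reference model.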

trainer_prefix#

Prefix for trainer logs and checkpoints. Default: “orpotrainer”

Type

str | None

learning_rate#

Learning rate for the optimizer. Default: 1e-6

Type

float

max_length#

Maximum total sequence length (prompt + completion). Default: 1024

Type

int | None

max_prompt_length#

Maximum length for prompt sequences. Default: 512

Type

int | None

max_completion_length#

Maximum length for completion sequences. If None, it is automatically computed as max_length - max_prompt_length.

Type

int | None

beta#

Weight of the odds-ratio term relative to the supervised fine-tuning (negative log-likelihood) loss; denoted λ in the ORPO paper. Higher values put more emphasis on separating chosen from rejected responses. Default: 0.1

Type

float

disable_dropout#

Whether to disable dropout during training for deterministic behavior. Default: True

Type

bool

label_pad_token_id#

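Token ID used to pad labels; label positions equal to this value are excluded from the loss computation. Default: -100 (the conventional "ignore index", e.g. the default ignore_index of PyTorch's CrossEntropyLoss)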

Type

int

padding_value#

Value used for padding input sequences. If None, uses the tokenizer’s pad_token_id.

Type

int | None

generate_during_eval#

Whether to generate sample outputs during evaluation for qualitative assessment. Default: False

Type

bool

is_encoder_decoder#

Whether the model has an encoder-decoder architecture. Auto-detected if None.

Type

bool | None

model_init_kwargs#

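Additional keyword arguments for model initialization. Note: this attribute does not appear among the ORPOConfig constructor parameters listed in the signature above, so it cannot be passed when constructing the config.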

Type

dict | None

dataset_num_proc#

Number of processes for parallel dataset preprocessing. If None, preprocessing runs sequentially in a single process.

Type

int | None

max_sequence_length#

Computed attribute for maximum sequence length used in training (2 * max_length for concatenated chosen/rejected).

Type

int

Example

>>> config = ORPOConfig(
...     beta=0.2,
...     max_length=2048,
...     learning_rate=2e-6,
...     num_train_epochs=3
... )

Note

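ORPO loss: $\mathcal{L}_{\text{ORPO}} = \mathcal{L}_{\text{SFT}} - \beta \, \log\sigma\!\left(\log\frac{\text{odds}_\theta(y_w \mid x)}{\text{odds}_\theta(y_l \mid x)}\right)$, where $\text{odds}_\theta(y \mid x) = \frac{P_\theta(y \mid x)}{1 - P_\theta(y \mid x)}$ for each response, $y_w$ and $y_l$ are the chosen and rejected responses, and $\mathcal{L}_{\text{SFT}}$ is the negative log-likelihood of the chosen response.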

beta: float = 0.1#
dataset_num_proc: int | None = None#
disable_dropout: bool = True#
classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

generate_during_eval: bool = False#
is_encoder_decoder: bool | None = None#
label_pad_token_id: int = -100#
learning_rate: float = 1e-06#
max_completion_length: int | None = None#
max_length: int | None = 1024#
max_prompt_length: int | None = 512#
padding_value: int | None = None#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

trainer_prefix: str | None = 'orpotrainer'#