easydel.trainers.direct_preference_optimization_trainer.dpo_config
- class easydel.trainers.direct_preference_optimization_trainer.dpo_config.DPOConfig(auto_shard_states: bool = True, aux_loss_enabled: bool = False, backend: ~typing.Optional[str] = None, clip_grad: ~typing.Optional[float] = None, custom_scheduler: ~typing.Optional[~typing.Callable[[int], ~typing.Any]] = None, dataloader_num_workers: ~typing.Optional[int] = 0, dataloader_pin_memory: ~typing.Optional[bool] = False, do_eval: bool = False, do_last_save: bool = True, do_train: bool = True, eval_batch_size: ~typing.Optional[int] = None, evaluation_steps: ~typing.Optional[int] = None, extra_optimizer_kwargs: dict = <factory>, frozen_parameters: ~typing.Optional[str] = None, gradient_accumulation_steps: int = 1, ids_to_pop_from_dataset: ~typing.Optional[~typing.List[str]] = <factory>, is_fine_tuning: bool = True, init_tx: bool = True, jax_distributed_config: ~typing.Optional[dict] = None, learning_rate: float = 5e-05, learning_rate_end: ~typing.Optional[float] = None, log_all_workers: bool = False, log_grad_norms: bool = True, report_metrics: bool = True, log_steps: int = 10, loss_config: ~typing.Optional[~easydel.infra.loss_utils.LossConfig] = None, low_mem_usage: bool = True, max_evaluation_steps: ~typing.Optional[int] = None, max_sequence_length: ~typing.Optional[int] = 4096, max_training_steps: ~typing.Optional[int] = None, model_name: str = 'BaseTrainer', model_parameters: ~typing.Optional[dict] = None, metrics_to_show_in_rich_pbar: ~typing.Optional[~typing.List[str]] = None, num_train_epochs: int = 10, offload_dataset: bool = False, offload_device_type: str = 'cpu', offload_device_index: int = 0, optimizer: ~typing.Literal['adafactor', 'lion', 'adamw', 'rmsprop'] = EasyDeLOptimizers.ADAMW, performance_mode: bool = False, pruning_module: ~typing.Any = None, process_zero_is_admin: bool = True, progress_bar_type: ~typing.Literal['tqdm', 'rich', 'json'] = 'tqdm', remove_ckpt_after_load: bool = False, remove_unused_columns: bool = True, report_steps: int = 5, save_directory: str = 'EasyDeL-Checkpoints', save_optimizer_state: bool = True, save_steps: ~typing.Optional[int] = None, save_total_limit: ~typing.Optional[int] = None, scheduler: ~typing.Literal['linear', 'cosine', 'none'] = EasyDeLSchedulers.NONE, sparsify_module: bool = False, sparse_module_type: ~typing.Literal['bcoo', 'bcsr', 'coo', 'csr'] = 'bcoo', state_apply_fn_kwarguments_to_model: ~typing.Optional[dict] = None, step_partition_spec: ~jax._src.partition_spec.PartitionSpec = PartitionSpec(('dp', 'fsdp'), 'sp'), step_start_point: ~typing.Optional[int] = None, shuffle_train_dataset: bool = True, total_batch_size: int = 32, training_time_limit: ~typing.Optional[str] = None, train_on_inputs: bool = True, truncation_mode: ~typing.Literal['keep_end', 'keep_start'] = 'keep_end', tx_mu_dtype: ~typing.Optional[~numpy.dtype] = None, track_memory: bool = False, use_data_collactor: bool = True, use_wandb: bool = True, verbose: bool = True, wandb_entity: ~typing.Optional[str] = None, warmup_steps: int = 0, weight_decay: float = 0.01, weight_distribution_pattern: str = '.*?(layernorm|norm).*?', weight_distribution_log_steps: int = 0)
Bases: TrainingArguments
Configuration class for Direct Preference Optimization (DPO) training.
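Inherits from TrainingArguments and adds the parameters specific to DPO training, as described in https://arxiv.org/abs/2305.18290. These parameters control loss computation (beta, label_smoothing, loss_type), reference-model handling (reference_free, precompute_ref_log_probs, sync_ref_model), model architecture (is_encoder_decoder), sequence lengths (max_length, max_prompt_length, max_completion_length), and dataset preprocessing (dataset_num_proc).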
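DPO compares how much probability the policy and the reference model assign to each chosen and rejected completion. The sketch below is minimal and rests on one assumption: as in other DPO implementations, prompt and padding positions in the labels are set to label_pad_token_id (default -100, see the attribute list). Under that assumption, a summed per-sequence log-probability can be computed in JAX as shown. `sequence_log_probs` is a hypothetical helper for illustration only. It is not part of the EasyDeL API.

```python
import jax
import jax.numpy as jnp

# Mirrors the documented label_pad_token_id default (assumed masking convention).
LABEL_PAD_TOKEN_ID = -100


def sequence_log_probs(logits, labels, label_pad_token_id=LABEL_PAD_TOKEN_ID):
    """Sum of per-token log-probabilities of `labels` under `logits`.

    logits: [batch, seq_len, vocab]; labels: [batch, seq_len], already shifted so
    that labels[:, t] is the token predicted by logits[:, t]. Positions equal to
    label_pad_token_id (prompt or padding) are excluded from the sum.
    """
    mask = labels != label_pad_token_id
    # Replace ignored labels with a valid index so the gather stays in range.
    safe_labels = jnp.where(mask, labels, 0)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    token_log_probs = jnp.take_along_axis(log_probs, safe_labels[..., None], axis=-1)[..., 0]
    # Zero out masked positions, then sum over the sequence axis.
    return jnp.sum(jnp.where(mask, token_log_probs, 0.0), axis=-1)
```

The same function would be evaluated for both the chosen and the rejected completion, once under the policy and once under the reference model.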
- beta
Temperature parameter (β) controlling how far the policy may deviate from the reference model. Higher values keep the policy closer to the reference model (a stronger implicit KL penalty); lower values allow larger deviations. Default: 0.1
- Type
float
- label_smoothing
Smoothing factor for labels in loss calculation. Helps prevent overconfidence. 0.0 means no smoothing. Default: 0.0
- Type
float
- loss_type
Type of contrastive loss function to use. Valid options: ‘sigmoid’, ‘hinge’, ‘ipo’, ‘exo_pair’, ‘nca_pair’, ‘robust’, ‘bco_pair’, ‘sppo_hard’, ‘aot’, ‘aot_pair’, ‘apo_zero’, ‘apo_down’. Default: ‘sigmoid’
- Type
LOSS_FN_VARIENTS
- use_weighting
Whether to apply example weighting in loss calculation. Default: False
- Type
bool
- label_pad_token_id
Token ID used for padding labels. Default: -100
- Type
int
- padding_value
Value used for padding sequences. If None, uses model’s default padding token. Default: None
- Type
int | None
- max_length
Maximum total sequence length (prompt + completion). Default: 512
- Type
int | None
- max_prompt_length
Maximum length for prompt sequences. Default: 256
- Type
int | None
- max_completion_length
Maximum length for completion sequences. Auto-calculated as max_length - max_prompt_length if None. Default: None
- Type
int | None
- is_encoder_decoder
Whether the model is an encoder-decoder. Auto-detected if None. Default: None
- Type
bool | None
- disable_dropout
Whether to disable dropout during training for deterministic behavior. Default: True
- Type
bool
- precompute_ref_log_probs
Whether to precompute reference model log probabilities before training. Default: False
- Type
bool
- dataset_num_proc
Number of processes for dataset preprocessing. Default: None (sequential processing)
- Type
int | None
- reference_free
Whether to use the reference-free variant of DPO. Default: False
- Type
bool
- force_use_ref_model
Force use of the reference model even when reference_free=True. Default: False
- Type
bool
- sync_ref_model
Whether to periodically sync the reference model with the training (policy) model. Default: False
- Type
bool
- learning_rate
Optimizer learning rate. Default: 1e-6
- Type
float
- ref_model_mixup_alpha
Alpha parameter for mixup between policy and reference models. Default: 0.9
- Type
float
- ref_model_sync_steps
Number of steps between reference model syncs. Default: 64
- Type
int
- rpo_alpha
Alpha parameter for Relative Preference Optimization. None disables RPO. Default: None
- Type
float | None
- tools
Additional tools for the training process. Default: None
- Type
list[dict | Callable] | None
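The options above combine into a per-example loss. The sketch below implements three of the loss_type values ('sigmoid', 'hinge', 'ipo'), together with beta, label_smoothing, and reference_free. Its inputs are summed per-sequence log-probabilities of the chosen and rejected completions under the policy and reference models. The formulas follow the widely used TRL DPO implementation. Whether EasyDeL computes them term for term in the same way is an assumption. `dpo_loss` is a hypothetical helper, not the library's API.

```python
import jax
import jax.numpy as jnp


def dpo_loss(
    policy_chosen_logps,
    policy_rejected_logps,
    ref_chosen_logps,
    ref_rejected_logps,
    beta=0.1,
    label_smoothing=0.0,
    loss_type="sigmoid",
    reference_free=False,
):
    """Per-example preference loss for a subset of loss_type options (sketch)."""
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    if reference_free:
        # Reference-free DPO drops the reference log-ratio (treats it as 0).
        ref_logratios = jnp.zeros_like(pi_logratios)
    else:
        ref_logratios = ref_chosen_logps - ref_rejected_logps
    logits = pi_logratios - ref_logratios

    if loss_type == "sigmoid":
        # Original DPO loss; label_smoothing > 0 assumes a chance of flipped labels.
        return (
            -jax.nn.log_sigmoid(beta * logits) * (1 - label_smoothing)
            - jax.nn.log_sigmoid(-beta * logits) * label_smoothing
        )
    if loss_type == "hinge":
        return jax.nn.relu(1 - beta * logits)
    if loss_type == "ipo":
        # IPO regresses the log-ratio gap toward 1 / (2 * beta).
        return (logits - 1 / (2 * beta)) ** 2
    raise NotImplementedError(f"loss_type={loss_type!r} is not covered by this sketch")
```

Consider β = 0.1 with equal policy and reference log-ratios (logits = 0). The 'sigmoid' loss is then log 2, and the 'ipo' loss is (0 − 5)² = 25, because IPO pulls the gap toward 1/(2β) rather than pushing it without bound.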
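Example
>>> from easydel.trainers.direct_preference_optimization_trainer.dpo_config import DPOConfig
>>> config = DPOConfig(
...     beta=0.2,
...     loss_type="ipo",
...     max_length=1024,
...     learning_rate=5e-6,
... )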
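The example above sets max_length=1024. It leaves max_prompt_length at its default of 256 and max_completion_length unset. Per the max_completion_length description, the completion budget therefore becomes 1024 − 256 = 768 tokens. The minimal pure-Python sketch below covers that rule. It also truncates prompts according to truncation_mode ('keep_end' or 'keep_start', inherited from TrainingArguments). Two things here are assumptions: that truncation_mode is applied to the prompt (TRL's convention), and both helper names, which are hypothetical.

```python
def resolve_lengths(max_length=512, max_prompt_length=256, max_completion_length=None):
    """Derive max_completion_length = max_length - max_prompt_length when it is None."""
    if max_completion_length is None and max_length is not None and max_prompt_length is not None:
        max_completion_length = max_length - max_prompt_length
    return max_length, max_prompt_length, max_completion_length


def truncate_prompt(prompt_ids, max_prompt_length, truncation_mode="keep_end"):
    """Truncate prompt token ids: keep_end keeps the last tokens, keep_start the first."""
    if max_prompt_length is None or len(prompt_ids) <= max_prompt_length:
        return list(prompt_ids)
    if truncation_mode == "keep_end":
        # Compute an explicit start index so max_prompt_length == 0 yields [].
        start = len(prompt_ids) - max_prompt_length
        return list(prompt_ids[start:])
    if truncation_mode == "keep_start":
        return list(prompt_ids[:max_prompt_length])
    raise ValueError(f"unknown truncation_mode: {truncation_mode!r}")
```

For instance, `resolve_lengths(1024, 256)` gives `(1024, 256, 768)`, and `truncate_prompt([1, 2, 3, 4, 5], 3)` keeps `[3, 4, 5]`.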
- beta: float = 0.1
- dataset_num_proc: Optional[int] = None
- disable_dropout: bool = True
- extra_optimizer_kwargs: dict
- force_use_ref_model: bool = False
- ids_to_pop_from_dataset: tp.Optional[tp.List[str]]
- is_encoder_decoder: Optional[bool] = None
- label_pad_token_id: int = -100
- label_smoothing: float = 0.0
- learning_rate: float = 1e-06
- loss_type: Literal['sigmoid', 'hinge', 'ipo', 'exo_pair', 'nca_pair', 'robust', 'bco_pair', 'sppo_hard', 'aot', 'aot_pair', 'apo_zero', 'apo_down'] = 'sigmoid'
- max_completion_length: Optional[int] = None
- max_length: Optional[int] = 512
- max_prompt_length: Optional[int] = 256
- model_name: str = 'DPOTrainer'
The name of the model.
- padding_value: Optional[int] = None
- precompute_ref_log_probs: bool = False
- ref_model_mixup_alpha: float = 0.9
- ref_model_sync_steps: int = 64
- reference_free: bool = False
- replace(**kwargs)
- rpo_alpha: Optional[float] = None
- sync_ref_model: bool = False
- tools: Optional[List[Union[dict, Callable]]] = None
- use_weighting: bool = False
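When sync_ref_model=True, the reference model is updated from the policy every ref_model_sync_steps steps, and ref_model_mixup_alpha weights the mix. The minimal JAX sketch below assumes the TR-DPO-style soft update used by TRL, ref ← α·policy + (1 − α)·ref, applied leaf by leaf over parameter pytrees. The exact update EasyDeL performs, the choice to skip step 0, and both function names are assumptions.

```python
import jax
import jax.numpy as jnp


def mix_reference_params(policy_params, ref_params, alpha=0.9):
    """Soft update of every leaf: ref <- alpha * policy + (1 - alpha) * ref."""
    return jax.tree_util.tree_map(
        lambda p, r: alpha * p + (1.0 - alpha) * r, policy_params, ref_params
    )


def maybe_sync_reference(
    step,
    policy_params,
    ref_params,
    sync_ref_model=False,
    ref_model_sync_steps=64,
    ref_model_mixup_alpha=0.9,
):
    """Return updated reference params on sync steps, else the unchanged ones."""
    # Assumption: no sync at step 0, then every ref_model_sync_steps steps.
    if sync_ref_model and step > 0 and step % ref_model_sync_steps == 0:
        return mix_reference_params(policy_params, ref_params, ref_model_mixup_alpha)
    return ref_params
```

With the defaults (α = 0.9, every 64 steps), the reference model moves most of the way toward the current policy at each sync. Smaller α values keep it closer to its previous state.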