easydel.trainers.direct_preference_optimization_trainer.dpo_config#

class easydel.trainers.direct_preference_optimization_trainer.dpo_config.DPOConfig(auto_shard_states: bool = True, aux_loss_enabled: bool = False, backend: tp.Optional[str] = None, clip_grad: tp.Optional[float] = None, custom_scheduler: tp.Optional[tp.Callable[[int], tp.Any]] = None, dataloader_num_workers: tp.Optional[int] = 0, dataloader_pin_memory: tp.Optional[bool] = False, do_eval: bool = False, do_last_save: bool = True, do_train: bool = True, eval_batch_size: tp.Optional[int] = None, evaluation_steps: tp.Optional[int] = None, extra_optimizer_kwargs: dict = <factory>, frozen_parameters: tp.Optional[str] = None, gradient_accumulation_steps: int = 1, ids_to_pop_from_dataset: tp.Optional[tp.List[str]] = <factory>, is_fine_tuning: bool = True, init_tx: bool = True, jax_distributed_config: tp.Optional[dict] = None, learning_rate: float = 1e-06, learning_rate_end: tp.Optional[float] = None, log_all_workers: bool = False, log_grad_norms: bool = True, report_metrics: bool = True, log_steps: int = 10, loss_config: tp.Optional[LossConfig] = None, low_mem_usage: bool = True, max_evaluation_steps: tp.Optional[int] = None, max_sequence_length: tp.Optional[int] = 4096, max_training_steps: tp.Optional[int] = None, model_name: str = 'DPOTrainer', model_parameters: tp.Optional[dict] = None, metrics_to_show_in_rich_pbar: tp.Optional[tp.List[str]] = None, num_train_epochs: int = 10, offload_dataset: bool = False, offload_device_type: str = 'cpu', offload_device_index: int = 0, optimizer: AVAILABLE_OPTIMIZERS = EasyDeLOptimizers.ADAMW, performance_mode: bool = False, pruning_module: tp.Any = None, process_zero_is_admin: bool = True, progress_bar_type: tp.Literal['tqdm', 'rich', 'json'] = 'tqdm', remove_ckpt_after_load: bool = False, remove_unused_columns: bool = True, report_steps: int = 5, save_directory: str = 'EasyDeL-Checkpoints', save_optimizer_state: bool = True, save_steps: tp.Optional[int] = None, save_total_limit: tp.Optional[int] = None, scheduler: AVAILABLE_SCHEDULERS 
= EasyDeLSchedulers.NONE, sparsify_module: bool = False, sparse_module_type: AVAILABLE_SPARSE_MODULE_TYPES = 'bcoo', state_apply_fn_kwarguments_to_model: tp.Optional[dict] = None, step_partition_spec: PartitionSpec = PartitionSpec(('dp', 'fsdp'), 'sp'), step_start_point: tp.Optional[int] = None, shuffle_train_dataset: bool = True, total_batch_size: int = 32, training_time_limit: tp.Optional[str] = None, train_on_inputs: bool = True, truncation_mode: tp.Literal['keep_end', 'keep_start'] = 'keep_end', tx_mu_dtype: tp.Optional[jnp.dtype] = None, track_memory: bool = False, use_data_collactor: bool = True, use_wandb: bool = True, verbose: bool = True, wandb_entity: tp.Optional[str] = None, warmup_steps: int = 0, weight_decay: float = 0.01, weight_distribution_pattern: str = '.*?(layernorm|norm).*?', weight_distribution_log_steps: int = 0, beta: float = 0.1, label_smoothing: float = 0.0, loss_type: ~typing.Literal['sigmoid', 'hinge', 'ipo', 'exo_pair', 'nca_pair', 'robust', 'bco_pair', 'sppo_hard', 'aot', 'aot_pair', 'apo_zero', 'apo_down'] = 'sigmoid', use_weighting: bool = False, label_pad_token_id: int = -100, padding_value: ~typing.Optional[int] = None, max_length: ~typing.Optional[int] = 512, max_prompt_length: ~typing.Optional[int] = 256, max_completion_length: ~typing.Optional[int] = None, is_encoder_decoder: ~typing.Optional[bool] = None, disable_dropout: bool = True, precompute_ref_log_probs: bool = False, dataset_num_proc: ~typing.Optional[int] = None, reference_free: bool = False, force_use_ref_model: bool = False, sync_ref_model: bool = False, ref_model_mixup_alpha: float = 0.9, ref_model_sync_steps: int = 64, rpo_alpha: ~typing.Optional[float] = None, tools: ~typing.Optional[~typing.List[~typing.Union[dict, ~typing.Callable]]] = None)[source]#

Bases: TrainingArguments

Configuration class for Direct Preference Optimization (DPO) training.

Inherits from TrainingArguments and adds parameters specific to DPO training as described in https://arxiv.org/abs/2305.18290. This configuration controls various aspects of the DPO training process including loss computation, model architecture, and dataset processing.

beta#

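Temperature parameter (β) controlling how far the policy may deviate from the reference model. Higher values keep the policy closer to the reference model (stronger implicit KL regularization); lower values let it deviate more to fit the preference data. Default: 0.1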

Type

float

label_smoothing#

Label-smoothing factor applied in the loss calculation. It accounts for possibly noisy preference labels and helps prevent overconfidence. 0.0 means no smoothing. Default: 0.0

Type

float

loss_type#

Type of contrastive loss function to use. Valid options: ‘sigmoid’, ‘hinge’, ‘ipo’, ‘exo_pair’, ‘nca_pair’, ‘robust’, ‘bco_pair’, ‘sppo_hard’, ‘aot’, ‘aot_pair’, ‘apo_zero’, ‘apo_down’. Default: ‘sigmoid’

Type

LOSS_FN_VARIENTS

use_weighting#

Whether to apply example weighting in loss calculation. Default: False

Type

bool

label_pad_token_id#

Token ID used for padding labels. Default: -100

Type

int

padding_value#

Value used for padding sequences. If None, uses model’s default padding token. Default: None

Type

int | None

max_length#

Maximum total sequence length (prompt + completion). Default: 512

Type

int | None

max_prompt_length#

Maximum length for prompt sequences. Default: 256

Type

int | None

max_completion_length#

Maximum length for completion sequences. Auto-calculated as max_length - max_prompt_length if None. Default: None

Type

int | None

is_encoder_decoder#

Whether the model is an encoder-decoder model. Set it explicitly to override auto-detection; if None, it is auto-detected. Default: None

Type

bool | None

disable_dropout#

Whether to disable dropout during training for deterministic behavior. Default: True

Type

bool

precompute_ref_log_probs#

Whether to precompute reference model log probabilities before training. Default: False

Type

bool

dataset_num_proc#

Number of processes for dataset preprocessing. Default: None (sequential processing)

Type

int | None

reference_free#

Whether to use the reference-free variant of DPO. Default: False

Type

bool

force_use_ref_model#

Force the use of a reference model even when reference_free=True. Default: False

Type

bool

sync_ref_model#

Whether to periodically sync reference model with training model. Default: False

Type

bool

learning_rate#

Optimizer learning rate. Default: 1e-6

Type

float

ref_model_mixup_alpha#

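Mixing coefficient α used when syncing the reference model with the policy (TR-DPO style). On each sync the reference is updated as α·policy + (1 − α)·reference. Only used when sync_ref_model=True. Default: 0.9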

Type

float

ref_model_sync_steps#

Number of training steps between reference model syncs. Only used when sync_ref_model=True. Default: 64

Type

int

rpo_alpha#

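Weight α of the negative log-likelihood (NLL) term added to the DPO loss, as in the RPO paper (https://arxiv.org/abs/2404.19733). None disables the NLL term, leaving the plain DPO loss. Default: None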

Type

float | None

tools#

Additional tools for training process

Type

list[dict | Callable] | None

Example

>>> config = DPOConfig(
...   beta=0.2, loss_type="ipo", max_length=1024, learning_rate=5e-6
... )

beta: float = 0.1#
dataset_num_proc: Optional[int] = None#
disable_dropout: bool = True#
extra_optimizer_kwargs: dict#
force_use_ref_model: bool = False#
classmethod from_dict(data: Dict[str, Any]) → T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T#

Deserializes a JSON string into a PyTree object.

ids_to_pop_from_dataset: tp.Optional[tp.List[str]]#
is_encoder_decoder: Optional[bool] = None#
label_pad_token_id: int = -100#
label_smoothing: float = 0.0#
learning_rate: float = 1e-06#
loss_type: Literal['sigmoid', 'hinge', 'ipo', 'exo_pair', 'nca_pair', 'robust', 'bco_pair', 'sppo_hard', 'aot', 'aot_pair', 'apo_zero', 'apo_down'] = 'sigmoid'#
max_completion_length: Optional[int] = None#
max_length: Optional[int] = 512#
max_prompt_length: Optional[int] = 256#
model_name: str = 'DPOTrainer'#
padding_value: Optional[int] = None#
precompute_ref_log_probs: bool = False#
ref_model_mixup_alpha: float = 0.9#
ref_model_sync_steps: int = 64#
reference_free: bool = False#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

rpo_alpha: Optional[float] = None#
sync_ref_model: bool = False#
to_dict() → Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str#

Serializes the PyTree object to a JSON string.

tools: Optional[List[Union[dict, Callable]]] = None#
use_weighting: bool = False#