easydel.trainers.direct_preference_optimization_trainer.__init__#
- class easydel.trainers.direct_preference_optimization_trainer.__init__.DPOConfig(auto_shard_states: bool = True, aux_loss_enabled: bool = False, backend: tp.Optional[str] = None, clip_grad: tp.Optional[float] = None, custom_scheduler: tp.Optional[tp.Callable[[int], tp.Any]] = None, dataloader_num_workers: tp.Optional[int] = 0, dataloader_pin_memory: tp.Optional[bool] = False, do_eval: bool = False, do_last_save: bool = True, do_train: bool = True, eval_batch_size: tp.Optional[int] = None, evaluation_steps: tp.Optional[int] = None, extra_optimizer_kwargs: dict = <factory>, frozen_parameters: tp.Optional[str] = None, gradient_accumulation_steps: int = 1, ids_to_pop_from_dataset: tp.Optional[tp.List[str]] = <factory>, is_fine_tuning: bool = True, init_tx: bool = True, jax_distributed_config: tp.Optional[dict] = None, learning_rate: float = 1e-06, learning_rate_end: tp.Optional[float] = None, log_all_workers: bool = False, log_grad_norms: bool = True, report_metrics: bool = True, log_steps: int = 10, loss_config: tp.Optional[LossConfig] = None, low_mem_usage: bool = True, max_evaluation_steps: tp.Optional[int] = None, max_sequence_length: tp.Optional[int] = 4096, max_training_steps: tp.Optional[int] = None, model_name: str = 'DPOTrainer', model_parameters: tp.Optional[dict] = None, metrics_to_show_in_rich_pbar: tp.Optional[tp.List[str]] = None, num_train_epochs: int = 10, offload_dataset: bool = False, offload_device_type: str = 'cpu', offload_device_index: int = 0, optimizer: AVAILABLE_OPTIMIZERS = EasyDeLOptimizers.ADAMW, performance_mode: bool = False, pruning_module: tp.Any = None, process_zero_is_admin: bool = True, progress_bar_type: tp.Literal['tqdm', 'rich', 'json'] = 'tqdm', remove_ckpt_after_load: bool = False, remove_unused_columns: bool = True, report_steps: int = 5, save_directory: str = 'EasyDeL-Checkpoints', save_optimizer_state: bool = True, save_steps: tp.Optional[int] = None, save_total_limit: tp.Optional[int] = None, scheduler: AVAILABLE_SCHEDULERS 
= EasyDeLSchedulers.NONE, sparsify_module: bool = False, sparse_module_type: AVAILABLE_SPARSE_MODULE_TYPES = 'bcoo', state_apply_fn_kwarguments_to_model: tp.Optional[dict] = None, step_partition_spec: PartitionSpec = PartitionSpec(('dp', 'fsdp'), 'sp'), step_start_point: tp.Optional[int] = None, shuffle_train_dataset: bool = True, total_batch_size: int = 32, training_time_limit: tp.Optional[str] = None, train_on_inputs: bool = True, truncation_mode: tp.Literal['keep_end', 'keep_start'] = 'keep_end', tx_mu_dtype: tp.Optional[jnp.dtype] = None, track_memory: bool = False, use_data_collactor: bool = True, use_wandb: bool = True, verbose: bool = True, wandb_entity: tp.Optional[str] = None, warmup_steps: int = 0, weight_decay: float = 0.01, weight_distribution_pattern: str = '.*?(layernorm|norm).*?', weight_distribution_log_steps: int = 0, beta: float = 0.1, label_smoothing: float = 0.0, loss_type: ~typing.Literal['sigmoid', 'hinge', 'ipo', 'exo_pair', 'nca_pair', 'robust', 'bco_pair', 'sppo_hard', 'aot', 'aot_pair', 'apo_zero', 'apo_down'] = 'sigmoid', use_weighting: bool = False, label_pad_token_id: int = -100, padding_value: ~typing.Optional[int] = None, max_length: ~typing.Optional[int] = 512, max_prompt_length: ~typing.Optional[int] = 256, max_completion_length: ~typing.Optional[int] = None, is_encoder_decoder: ~typing.Optional[bool] = None, disable_dropout: bool = True, precompute_ref_log_probs: bool = False, dataset_num_proc: ~typing.Optional[int] = None, reference_free: bool = False, force_use_ref_model: bool = False, sync_ref_model: bool = False, ref_model_mixup_alpha: float = 0.9, ref_model_sync_steps: int = 64, rpo_alpha: ~typing.Optional[float] = None, tools: ~typing.Optional[~typing.List[~typing.Union[dict, ~typing.Callable]]] = None)[source]#
Bases: TrainingArguments
Configuration class for Direct Preference Optimization (DPO) training.
Inherits from TrainingArguments and adds parameters specific to DPO training as described in https://arxiv.org/abs/2305.18290. This configuration controls various aspects of the DPO training process including loss computation, model architecture, and dataset processing.
- beta#
Temperature parameter (β) controlling deviation from the reference model. Higher values keep the trained policy closer to the reference model; lower values allow larger deviations in favor of the preference data. Default: 0.1
- Type
float
- label_smoothing#
Smoothing factor for labels in loss calculation. Helps prevent overconfidence. 0.0 means no smoothing. Default: 0.0
- Type
float
- loss_type#
Type of contrastive loss function to use. Valid options: ‘sigmoid’, ‘hinge’, ‘ipo’, ‘exo_pair’, ‘nca_pair’, ‘robust’, ‘bco_pair’, ‘sppo_hard’, ‘aot’, ‘aot_pair’, ‘apo_zero’, ‘apo_down’. Default: ‘sigmoid’
- Type
LOSS_FN_VARIENTS
- use_weighting#
Whether to apply example weighting in loss calculation. Default: False
- Type
bool
- label_pad_token_id#
Token ID used for padding labels. Default: -100
- Type
int
- padding_value#
Value used for padding sequences. If None, uses model’s default padding token. Default: None
- Type
int | None
- max_length#
Maximum total sequence length (prompt + completion). Default: 512
- Type
int | None
- max_prompt_length#
Maximum length for prompt sequences. Default: 256
- Type
int | None
- max_completion_length#
Maximum length for completion sequences. Auto-calculated as max_length - max_prompt_length if None. Default: None
- Type
int | None
- is_encoder_decoder#
Explicitly set if model is encoder-decoder. Auto-detected if None. Default: None
- Type
bool | None
- disable_dropout#
Whether to disable dropout during training for deterministic behavior. Default: True
- Type
bool
- precompute_ref_log_probs#
Whether to precompute reference model log probabilities before training. Default: False
- Type
bool
- dataset_num_proc#
Number of processes for dataset preprocessing. Default: None (sequential processing)
- Type
int | None
- reference_free#
Whether to use reference-free variant of DPO. Default: False
- Type
bool
- force_use_ref_model#
Force use of the reference model even when reference_free=True. Default: False
- Type
bool
- sync_ref_model#
Whether to periodically sync reference model with training model. Default: False
- Type
bool
- learning_rate#
Optimizer learning rate. Default: 1e-6
- Type
float
- ref_model_mixup_alpha#
Alpha parameter for mixup between policy and reference models. Default: 0.9
- Type
float
- ref_model_sync_steps#
Number of steps between reference model syncs. Default: 64
- Type
int
- rpo_alpha#
Alpha parameter for Relative Preference Optimization. None disables RPO. Default: None
- Type
float | None
- tools#
Additional tools for the training process. Default: None
- Type
list[dict | Callable] | None
Example
>>> config = DPOConfig(
...     beta=0.2, loss_type="ipo", max_length=1024, learning_rate=5e-6
... )
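To make the roles of beta and label_smoothing concrete, here is a minimal sketch of the 'sigmoid' loss from the DPO paper linked above, written in plain Python. It assumes the per-sequence log-probabilities have already been summed, and the function names are illustrative rather than part of EasyDeL; the library's own implementation may differ in detail.

```python
import math


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    # log(sigmoid(x)) = min(x, 0) - log(1 + exp(-|x|))
    return min(x, 0.0) - math.log1p(math.exp(-abs(x)))


def sigmoid_dpo_loss(
    policy_chosen_logp: float,
    policy_rejected_logp: float,
    ref_chosen_logp: float,
    ref_rejected_logp: float,
    beta: float = 0.1,
    label_smoothing: float = 0.0,
) -> float:
    """Per-example sigmoid DPO loss (hypothetical helper, not an EasyDeL API)."""
    # How much more strongly the policy prefers "chosen" over "rejected"
    # than the reference model does.
    margin = (policy_chosen_logp - policy_rejected_logp) - (
        ref_chosen_logp - ref_rejected_logp
    )
    logits = beta * margin
    # label_smoothing mixes in the loss for the flipped preference label.
    return (
        -(1.0 - label_smoothing) * log_sigmoid(logits)
        - label_smoothing * log_sigmoid(-logits)
    )


# Policy identical to the reference: the margin is zero.
loss_at_zero_margin = sigmoid_dpo_loss(-10.0, -12.0, -10.0, -12.0)
# Policy prefers the chosen response more than the reference does.
loss_when_policy_prefers_chosen = sigmoid_dpo_loss(-8.0, -14.0, -10.0, -12.0)
```

With a zero margin the loss equals log 2 regardless of beta; it falls as the policy separates chosen from rejected responses more strongly than the reference does.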
- beta: float = 0.1#
- dataset_num_proc: Optional[int] = None#
- disable_dropout: bool = True#
- extra_optimizer_kwargs: dict#
- force_use_ref_model: bool = False#
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- ids_to_pop_from_dataset: tp.Optional[tp.List[str]]#
- is_encoder_decoder: Optional[bool] = None#
- label_pad_token_id: int = -100#
- label_smoothing: float = 0.0#
- learning_rate: float = 1e-06#
- loss_type: Literal['sigmoid', 'hinge', 'ipo', 'exo_pair', 'nca_pair', 'robust', 'bco_pair', 'sppo_hard', 'aot', 'aot_pair', 'apo_zero', 'apo_down'] = 'sigmoid'#
- max_completion_length: Optional[int] = None#
- max_length: Optional[int] = 512#
- max_prompt_length: Optional[int] = 256#
- model_name: str = 'DPOTrainer'#
- padding_value: Optional[int] = None#
- precompute_ref_log_probs: bool = False#
- ref_model_mixup_alpha: float = 0.9#
- ref_model_sync_steps: int = 64#
- reference_free: bool = False#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- rpo_alpha: Optional[float] = None#
- sync_ref_model: bool = False#
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- tools: Optional[List[Union[dict, Callable]]] = None#
- use_weighting: bool = False#
- class easydel.trainers.direct_preference_optimization_trainer.__init__.DPOTrainer(arguments: DPOConfig, model: Union[EasyDeLBaseModule, EasyDeLState], reference_model: Optional[Union[EasyDeLBaseModule, EasyDeLState]] = None, processing_class: Optional[Any] = None, train_dataset: Optional[Any] = None, eval_dataset: Optional[Any] = None, data_collator: Optional[Callable] = None)[source]#
Bases: Trainer
Trainer for Direct Preference Optimization (DPO).
This trainer handles the training, evaluation, and checkpointing of language models using the DPO algorithm. It supports sharding, gradient accumulation, mixed precision training, LoRA, and precomputed reference model log probabilities.
- checkpoint_manager: tp.Any#
- checkpoint_path: tp.Optional[tp.Union[str, os.PathLike]]#
- compute_reference_log_probs(state: EasyDeLState, padded_batch: Dict) tuple[Any, Any][source]#
Computes the reference model's log probabilities for a single padded batch of a DPO-specific dataset.
- Parameters
state (EasyDeLState) – The EasyDeLState object of the model (used if no reference model is provided).
padded_batch (tp.Dict) – The padded batch of data.
- Returns
A tuple containing the log probabilities for the chosen and rejected responses.
- Return type
tuple[tp.Any, tp.Any]
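As a rough illustration of what a per-sequence log probability means here, the sketch below sums token log-probabilities over the positions whose label is not the padding id (label_pad_token_id, -100 by default). It is a plain-Python approximation under the assumption that logits and labels are already aligned position by position; sequence_log_prob is a hypothetical helper, not the EasyDeL method.

```python
import math


def sequence_log_prob(token_logits, labels, label_pad_token_id=-100):
    """Sum log p(label) over non-padded positions (illustrative only)."""
    total = 0.0
    for logits, label in zip(token_logits, labels):
        if label == label_pad_token_id:
            continue  # padded / masked positions contribute nothing
        # log-softmax of the label entry, computed stably
        peak = max(logits)
        log_norm = peak + math.log(sum(math.exp(v - peak) for v in logits))
        total += logits[label] - log_norm
    return total


# Toy vocabulary of size 2 with uniform logits at three positions.
uniform = [[0.0, 0.0]] * 3
chosen_logp = sequence_log_prob(uniform, [-100, 1, 0])      # two scored tokens
rejected_logp = sequence_log_prob(uniform, [-100, 1, -100])  # one scored token
```

The method documented above returns one such value for the chosen responses and one for the rejected responses.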
- config: EasyDeLBaseConfig#
- configure_dataloaders()[source]#
Returns the training dataloader, potentially with precomputed reference log probabilities.
If precompute_ref_log_probs is enabled, this method computes the reference model’s log probabilities for the chosen and rejected responses in the training dataset and adds them as columns to the dataset.
- Returns
The training dataloader.
- Return type
tensorflow.data.Dataset
- configure_functions() TrainerConfigureFunctionOutput[source]#
Configures and JIT-compiles the training and evaluation step functions.
- This method sets up the necessary functions for training and evaluation, including:
Initialization of the model state.
Sharding of the model parameters and optimizer state.
JIT-compilation of the training and evaluation step functions.
- Returns
An object containing the configured functions and other relevant information.
- Return type
TrainerConfigureFunctionOutput
- create_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start'] = 'keep_end') Callable[source]#
Creates a data collection function for batching.
For DPO training, this method simply returns the pre-configured data_collator.
- Parameters
max_sequence_length (int) – The maximum sequence length (not used in this implementation).
truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – The truncation mode (not used in this implementation). Defaults to “keep_end”.
- Returns
The data collator function.
- Return type
tp.Callable
- data_collator: tp.Optional[tp.Callable]#
- dataloader_eval: tp.Optional[tp.Iterator[np.ndarray]]#
- dataloader_train: tp.Iterator[np.ndarray]#
- dataset_eval: tp.Optional[Dataset]#
- dataset_train: tp.Optional[Dataset]#
- dtype: tp.Any#
- evalu_tracker: CompilationTracker#
- finetune: bool#
- max_evaluation_steps: int#
- max_training_steps: int#
- memory_monitor: tp.Any#
- model_state: EasyDeLState#
- on_step_end(state: EasyDeLState, metrics: Any, step: int) Tuple[EasyDeLState, Any][source]#
Hook called at the end of each training step.
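This page does not say what the hook does for DPO. One plausible use of a step-end hook is the periodic reference-model sync controlled by sync_ref_model, ref_model_sync_steps, and ref_model_mixup_alpha. The sketch below shows that idea on a flat dict of scalar parameters. The mixing direction (alpha weights the policy) follows the common TR-DPO convention and is an assumption here, as are both function names.

```python
def should_sync(step: int, sync_steps: int = 64) -> bool:
    """True on every sync_steps-th step (hypothetical helper)."""
    return step > 0 and step % sync_steps == 0


def mix_reference_params(ref_params: dict, policy_params: dict, alpha: float = 0.9) -> dict:
    """Return new reference parameters: alpha * policy + (1 - alpha) * reference.

    Assumption: this mirrors the usual TR-DPO mixup; EasyDeL may differ.
    """
    return {
        name: alpha * policy_params[name] + (1.0 - alpha) * ref_params[name]
        for name in ref_params
    }


ref = {"w": 0.0}
policy = {"w": 1.0}
for step in range(1, 129):
    if should_sync(step):
        ref = mix_reference_params(ref, policy)  # runs at steps 64 and 128
```

After the two syncs in this toy loop the reference value has moved from 0.0 to 0.99, that is, most of the way toward the policy.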
- param_dtype: tp.Any#
- static process_row(features, processing_class, max_prompt_length, max_completion_length, add_special_tokens)[source]#
- pruning_module: tp.Any#
- scheduler: optax.Schedule#
- sharded_evaluation_step_function: JitWrapped#
- sharded_training_step_function: JitWrapped#
- state: tp.Any#
- state_named_sharding: tp.Any#
- state_partition_spec: tp.Any#
- state_shape: tp.Any#
- static tokenize_row(features, processing_class, max_prompt_length, max_completion_length, add_special_tokens)[source]#
Tokenize a row of the dataset.
- Parameters
features (dict[str, str]) – Row of the dataset, should contain the keys “prompt”, “chosen”, and “rejected”.
processing_class (PreTrainedTokenizerBase) – Processing class used to process the data.
max_prompt_length (int or None) – Maximum length of the prompt sequence. If None, the prompt sequence is not truncated.
max_completion_length (int or None) – Maximum length of the completion sequences. If None, the completion sequences are not truncated.
add_special_tokens (bool) – Whether to add special tokens to the sequences. Typically used for encoder-decoder models. If True, the prompt sequence will have a bos token prepended and an eos token appended. In any case, the completion sequences will have an eos token appended.
- Returns
Tokenized sequences with the keys “prompt_input_ids”, “chosen_input_ids”, and “rejected_input_ids”.
- Return type
dict[str, list[int]]
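The behavior described above can be sketched with a toy whitespace tokenizer. ToyTokenizer and tokenize_row_sketch are hypothetical stand-ins, and the truncation sides (keep the end of the prompt, keep the start of each completion) are assumptions; the real method works with a PreTrainedTokenizerBase and may differ.

```python
class ToyTokenizer:
    """Hypothetical tokenizer: maps each whitespace-separated word to an id."""

    bos_token_id = 1
    eos_token_id = 2

    def __init__(self):
        self.vocab = {}

    def encode(self, text):
        # Ids start at 3 so they never collide with bos/eos.
        return [self.vocab.setdefault(w, len(self.vocab) + 3) for w in text.split()]


def tokenize_row_sketch(
    features, tokenizer, max_prompt_length, max_completion_length, add_special_tokens
):
    prompt = tokenizer.encode(features["prompt"])
    chosen = tokenizer.encode(features["chosen"])
    rejected = tokenizer.encode(features["rejected"])

    if add_special_tokens:
        prompt = [tokenizer.bos_token_id] + prompt + [tokenizer.eos_token_id]
    # Completions always end with an eos token.
    chosen = chosen + [tokenizer.eos_token_id]
    rejected = rejected + [tokenizer.eos_token_id]

    if max_prompt_length is not None:
        prompt = prompt[-max_prompt_length:]  # assumption: keep the end
    if max_completion_length is not None:
        chosen = chosen[:max_completion_length]  # assumption: keep the start
        rejected = rejected[:max_completion_length]

    return {
        "prompt_input_ids": prompt,
        "chosen_input_ids": chosen,
        "rejected_input_ids": rejected,
    }


row = tokenize_row_sketch(
    {"prompt": "what is two plus two", "chosen": "four", "rejected": "five six"},
    ToyTokenizer(),
    max_prompt_length=3,
    max_completion_length=2,
    add_special_tokens=False,
)
```

Note that in this sketch truncation can cut off the eos token of a long completion, as happens to the rejected response in the example row.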
- train_tracker: CompilationTracker#
- tx: optax.GradientTransformation#
- wandb_runtime: tp.Any#