easydel.trainers.direct_preference_optimization_trainer.__init__

Contents

easydel.trainers.direct_preference_optimization_trainer.__init__#

class easydel.trainers.direct_preference_optimization_trainer.__init__.DPOConfig(auto_shard_states: bool = True, aux_loss_enabled: bool = False, backend: ~typing.Optional[str] = None, clip_grad: ~typing.Optional[float] = None, custom_scheduler: ~typing.Optional[~typing.Callable[[int], ~typing.Any]] = None, dataloader_num_workers: ~typing.Optional[int] = 0, dataloader_pin_memory: ~typing.Optional[bool] = False, do_eval: bool = False, do_last_save: bool = True, do_train: bool = True, eval_batch_size: ~typing.Optional[int] = None, evaluation_steps: ~typing.Optional[int] = None, extra_optimizer_kwargs: dict = <factory>, frozen_parameters: ~typing.Optional[str] = None, gradient_accumulation_steps: int = 1, ids_to_pop_from_dataset: ~typing.Optional[~typing.List[str]] = <factory>, is_fine_tuning: bool = True, init_tx: bool = True, jax_distributed_config: ~typing.Optional[dict] = None, learning_rate: float = 5e-05, learning_rate_end: ~typing.Optional[float] = None, log_all_workers: bool = False, log_grad_norms: bool = True, report_metrics: bool = True, log_steps: int = 10, loss_config: ~typing.Optional[~easydel.infra.loss_utils.LossConfig] = None, low_mem_usage: bool = True, max_evaluation_steps: ~typing.Optional[int] = None, max_sequence_length: ~typing.Optional[int] = 4096, max_training_steps: ~typing.Optional[int] = None, model_name: str = 'BaseTrainer', model_parameters: ~typing.Optional[dict] = None, metrics_to_show_in_rich_pbar: ~typing.Optional[~typing.List[str]] = None, num_train_epochs: int = 10, offload_dataset: bool = False, offload_device_type: str = 'cpu', offload_device_index: int = 0, optimizer: ~typing.Literal['adafactor', 'lion', 'adamw', 'rmsprop'] = EasyDeLOptimizers.ADAMW, performance_mode: bool = False, pruning_module: ~typing.Any = None, process_zero_is_admin: bool = True, progress_bar_type: ~typing.Literal['tqdm', 'rich', 'json'] = 'tqdm', remove_ckpt_after_load: bool = False, remove_unused_columns: bool = True, report_steps: int = 5, save_directory: 
str = 'EasyDeL-Checkpoints', save_optimizer_state: bool = True, save_steps: ~typing.Optional[int] = None, save_total_limit: ~typing.Optional[int] = None, scheduler: ~typing.Literal['linear', 'cosine', 'none'] = EasyDeLSchedulers.NONE, sparsify_module: bool = False, sparse_module_type: ~typing.Literal['bcoo', 'bcsr', 'coo', 'csr'] = 'bcoo', state_apply_fn_kwarguments_to_model: ~typing.Optional[dict] = None, step_partition_spec: ~jax._src.partition_spec.PartitionSpec = PartitionSpec(('dp', 'fsdp'), 'sp'), step_start_point: ~typing.Optional[int] = None, shuffle_train_dataset: bool = True, total_batch_size: int = 32, training_time_limit: ~typing.Optional[str] = None, train_on_inputs: bool = True, truncation_mode: ~typing.Literal['keep_end', 'keep_start'] = 'keep_end', tx_mu_dtype: ~typing.Optional[~numpy.dtype] = None, track_memory: bool = False, use_data_collactor: bool = True, use_wandb: bool = True, verbose: bool = True, wandb_entity: ~typing.Optional[str] = None, warmup_steps: int = 0, weight_decay: float = 0.01, weight_distribution_pattern: str = '.*?(layernorm|norm).*?', weight_distribution_log_steps: int = 0)[source]#

Bases: TrainingArguments

Configuration class for Direct Preference Optimization (DPO) training.

Inherits from TrainingArguments and adds parameters specific to DPO training as described in https://arxiv.org/abs/2305.18290. This configuration controls various aspects of the DPO training process including loss computation, model architecture, and dataset processing.

beta#

Temperature parameter (β) controlling deviation from the reference model. Higher values keep the policy closer to the reference model (a stronger implicit KL penalty); lower values allow larger deviation. Default: 0.1

Type

float

label_smoothing#

Smoothing factor for labels in loss calculation. Helps prevent overconfidence. 0.0 means no smoothing. Default: 0.0

Type

float

loss_type#

Type of contrastive loss function to use. Valid options: ‘sigmoid’, ‘hinge’, ‘ipo’, ‘exo_pair’, ‘nca_pair’, ‘robust’, ‘bco_pair’, ‘sppo_hard’, ‘aot’, ‘aot_pair’, ‘apo_zero’, ‘apo_down’. Default: ‘sigmoid’

Type

LOSS_FN_VARIENTS

use_weighting#

Whether to apply example weighting in loss calculation. Default: False

Type

bool

label_pad_token_id#

Token ID used for padding labels. Default: -100

Type

int

padding_value#

Value used for padding sequences. If None, uses model’s default padding token. Default: None

Type

int | None

max_length#

Maximum total sequence length (prompt + completion). Default: 512

Type

int | None

max_prompt_length#

Maximum length for prompt sequences. Default: 256

Type

int | None

max_completion_length#

Maximum length for completion sequences. Auto-calculated as max_length - max_prompt_length if None. Default: None

Type

int | None

is_encoder_decoder#

Explicitly set if model is encoder-decoder. Auto-detected if None. Default: None

Type

bool | None

disable_dropout#

Whether to disable dropout during training for deterministic behavior. Default: True

Type

bool

precompute_ref_log_probs#

Whether to precompute reference model log probabilities before training. Default: False

Type

bool

dataset_num_proc#

Number of processes for dataset preprocessing. Default: None (sequential processing)

Type

int | None

reference_free#

Whether to use reference-free variant of DPO. Default: False

Type

bool

force_use_ref_model#

Force the use of a reference model even when reference_free=True. Default: False

Type

bool

sync_ref_model#

Whether to periodically sync reference model with training model. Default: False

Type

bool

learning_rate#

Optimizer learning rate. Default: 1e-6

Type

float

ref_model_mixup_alpha#

Alpha parameter for mixup between policy and reference models. Default: 0.9

Type

float

ref_model_sync_steps#

Number of steps between reference model syncs. Default: 64

Type

int

rpo_alpha#

Alpha parameter for Relative Preference Optimization. None disables RPO. Default: None

Type

float | None

tools#

Additional tools for the training process.

Type

list[dict | Callable] | None

Example

>>> config = DPOConfig(
...   beta=0.2, loss_type="ipo", max_length=1024, learning_rate=5e-6
... )
beta: float = Field(name=None,type=None,default=0.1,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Temperature parameter (β) controlling deviation from reference model. Higher values make training focus more on preference matching.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
dataset_num_proc: Optional[int] = Field(name=None,type=None,default=None,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Number of processes for dataset preprocessing. Default: None (sequential processing)'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
disable_dropout: bool = Field(name=None,type=None,default=True,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Whether to disable dropout during training for deterministic behavior.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
extra_optimizer_kwargs: dict#
force_use_ref_model: bool = Field(name=None,type=None,default=False,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Force use reference model even when reference_free=True.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
ids_to_pop_from_dataset: tp.Optional[tp.List[str]]#
is_encoder_decoder: Optional[bool] = Field(name=None,type=None,default=None,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Explicitly set if model is encoder-decoder. Auto-detected if None.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
label_pad_token_id: int = Field(name=None,type=None,default=-100,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Token ID used for padding labels.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
label_smoothing: float = Field(name=None,type=None,default=0.0,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Smoothing factor for labels in loss calculation. Helps prevent overconfidence. 0.0 means no smoothing.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
learning_rate: float = Field(name=None,type=None,default=1e-06,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Optimizer learning rate.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
loss_type: Literal['sigmoid', 'hinge', 'ipo', 'exo_pair', 'nca_pair', 'robust', 'bco_pair', 'sppo_hard', 'aot', 'aot_pair', 'apo_zero', 'apo_down'] = Field(name=None,type=None,default='sigmoid',default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': "Type of contrastive loss function to use. Valid options: 'sigmoid', 'hinge', 'ipo', 'exo_pair', 'nca_pair', 'robust', 'bco_pair', 'sppo_hard', 'aot', 'aot_pair', 'apo_zero', 'apo_down'."}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
max_completion_length: Optional[int] = Field(name=None,type=None,default=None,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Maximum length for completion sequences. Auto-calculated as max_length - max_prompt_length if None.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
max_length: Optional[int] = Field(name=None,type=None,default=512,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Maximum total sequence length (prompt + completion).'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
max_prompt_length: Optional[int] = Field(name=None,type=None,default=256,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Maximum length for prompt sequences.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
model_name: str = Field(name=None,type=None,default='DPOTrainer',default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'The name of the model.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
padding_value: Optional[int] = Field(name=None,type=None,default=None,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': "Value used for padding sequences. If None, uses model's default padding token."}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
precompute_ref_log_probs: bool = Field(name=None,type=None,default=False,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Whether to precompute reference model log probabilities before training.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
ref_model_mixup_alpha: float = Field(name=None,type=None,default=0.9,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Alpha parameter for mixup between policy and reference models.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
ref_model_sync_steps: int = Field(name=None,type=None,default=64,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Number of steps between reference model syncs.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
reference_free: bool = Field(name=None,type=None,default=False,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Whether to use reference-free variant of DPO.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
replace(**kwargs)#
rpo_alpha: Optional[float] = Field(name=None,type=None,default=None,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Alpha parameter for Relative Preference Optimization. None disables RPO.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
sync_ref_model: bool = Field(name=None,type=None,default=False,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Whether to periodically sync reference model with training model.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
tools: Optional[List[Union[dict, Callable]]] = Field(name=None,type=None,default=None,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Additional tools for training process.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
use_weighting: bool = Field(name=None,type=None,default=False,default_factory=<dataclasses._MISSING_TYPE object>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'help': 'Whether to apply example weighting in loss calculation.'}),kw_only=<dataclasses._MISSING_TYPE object>,_field_type=None)#
class easydel.trainers.direct_preference_optimization_trainer.__init__.DPOTrainer(arguments: DPOConfig, model: Union[EasyDeLBaseModule, EasyDeLState], reference_model: Optional[Union[EasyDeLBaseModule, EasyDeLState]] = None, processing_class: Optional[Any] = None, train_dataset: Optional[Any] = None, eval_dataset: Optional[Any] = None, data_collator: Optional[Callable] = None)[source]#

Bases: Trainer

Trainer for Direct Preference Optimization (DPO).

This trainer handles the training, evaluation, and checkpointing of language models using the DPO algorithm. It supports sharding, gradient accumulation, mixed precision training, LoRA, and precomputed reference model log probabilities.

arguments: DPOConfig#
checkpoint_manager: tp.Any#
checkpoint_path: tp.Optional[tp.Union[str, os.PathLike]]#
compute_reference_log_probs(state: EasyDeLState, padded_batch: Dict) → tuple[Any, Any][source]#

Computes log probabilities of the reference model for a single padded batch of a DPO-specific dataset.

Parameters
  • state (EasyDeLState) – The EasyDeLState object of the model (used if no reference model is provided).

  • padded_batch (tp.Dict) – The padded batch of data.

Returns

A tuple containing the log probabilities for the chosen and rejected responses.

Return type

tuple[tp.Any, tp.Any]

config: EasyDeLBaseConfig#
configure_dataloaders()[source]#

Returns the training dataloader, potentially with precomputed reference log probabilities.

If precompute_ref_log_probs is enabled, this method computes the reference model’s log probabilities for the chosen and rejected responses in the training dataset and adds them as columns to the dataset.

Returns

The training dataloader.

Return type

tensorflow.data.Dataset

configure_functions() → TrainerConfigureFunctionOutput[source]#

Configures and JIT-compiles the training and evaluation step functions.

This method sets up the necessary functions for training and evaluation, including:
  • Initialization of the model state.

  • Sharding of the model parameters and optimizer state.

  • JIT-compilation of the training and evaluation step functions.

Returns

An object containing the configured functions and other relevant information.

Return type

TrainerConfigureFunctionOutput

create_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start'] = 'keep_end') → Callable[source]#

Creates a data collection function for batching.

For DPO training, this method simply returns the pre-configured data_collator.

Parameters
  • max_sequence_length (int) – The maximum sequence length (not used in this implementation).

  • truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – The truncation mode (not used in this implementation). Defaults to “keep_end”.

Returns

The data collator function.

Return type

tp.Callable

data_collator: tp.Optional[tp.Callable]#
dataloader_eval: tp.Optional[tp.Iterator[np.ndarray]]#
dataloader_train: tp.Iterator[np.ndarray]#
dataset_eval: tp.Optional[Dataset]#
dataset_train: tp.Optional[Dataset]#
dtype: tp.Any#
evalu_tracker: CompilationTracker#
finetune: bool#
max_evaluation_steps: int#
max_training_steps: int#
memory_monitor: tp.Any#
model_state: EasyDeLState#
on_step_end(state: EasyDeLState, metrics: Any, step: int) → Tuple[EasyDeLState, Any][source]#

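Hook invoked at the end of a training step. Receives the current state, the step metrics, and the step number, and returns the (possibly updated) state and metrics.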

param_dtype: tp.Any#
static process_row(features, processing_class, max_prompt_length, max_completion_length, add_special_tokens)[source]#
pruning_module: tp.Any#
scheduler: optax.Schedule#
sharded_evaluation_step_function: JitWrapped#
sharded_training_step_function: JitWrapped#
state: tp.Any#
state_named_sharding: tp.Any#
state_partition_spec: tp.Any#
state_shape: tp.Any#
timer: Timers#
static tokenize_row(features, processing_class, max_prompt_length, max_completion_length, add_special_tokens)[source]#

Tokenize a row of the dataset.

Parameters
  • features (dict[str, str]) – Row of the dataset, should contain the keys “prompt”, “chosen”, and “rejected”.

  • processing_class (PreTrainedTokenizerBase) – Processing class used to process the data.

  • max_prompt_length (int or None) – Maximum length of the prompt sequence. If None, the prompt sequence is not truncated.

  • max_completion_length (int or None) – Maximum length of the completion sequences. If None, the completion sequences are not truncated.

  • add_special_tokens (bool) – Whether to add special tokens to the sequences. Typically used for encoder-decoder models. If True, the prompt sequence will have a bos token prepended and an eos token appended. In any case, the completion sequences will have an eos token appended.

Returns

Tokenized sequences with the keys “prompt_input_ids”, “chosen_input_ids”, and “rejected_input_ids”.

Return type

dict[str, list[int]]

train_tracker: CompilationTracker#
tx: optax.GradientTransformation#
wandb_runtime: tp.Any#