easydel.trainers.odds_ratio_preference_optimization_trainer.__init__#
- class easydel.trainers.odds_ratio_preference_optimization_trainer.__init__.ORPOConfig(auto_shard_states: bool = True, aux_loss_enabled: bool = False, backend: tp.Optional[str] = None, clip_grad: tp.Optional[float] = None, custom_scheduler: tp.Optional[tp.Callable[[int], tp.Any]] = None, dataloader_num_workers: tp.Optional[int] = 0, dataloader_pin_memory: tp.Optional[bool] = False, do_eval: bool = False, do_last_save: bool = True, do_train: bool = True, eval_batch_size: tp.Optional[int] = None, evaluation_steps: tp.Optional[int] = None, extra_optimizer_kwargs: dict = <factory>, frozen_parameters: tp.Optional[str] = None, gradient_accumulation_steps: int = 1, ids_to_pop_from_dataset: tp.Optional[tp.List[str]] = <factory>, is_fine_tuning: bool = True, init_tx: bool = True, jax_distributed_config: tp.Optional[dict] = None, learning_rate: float = 1e-06, learning_rate_end: tp.Optional[float] = None, log_all_workers: bool = False, log_grad_norms: bool = True, report_metrics: bool = True, log_steps: int = 10, loss_config: tp.Optional[LossConfig] = None, low_mem_usage: bool = True, max_evaluation_steps: tp.Optional[int] = None, max_sequence_length: tp.Optional[int] = 4096, max_training_steps: tp.Optional[int] = None, model_name: str = 'ORPOTrainer', model_parameters: tp.Optional[dict] = None, metrics_to_show_in_rich_pbar: tp.Optional[tp.List[str]] = None, num_train_epochs: int = 10, offload_dataset: bool = False, offload_device_type: str = 'cpu', offload_device_index: int = 0, optimizer: AVAILABLE_OPTIMIZERS = EasyDeLOptimizers.ADAMW, performance_mode: bool = False, pruning_module: tp.Any = None, process_zero_is_admin: bool = True, progress_bar_type: tp.Literal['tqdm', 'rich', 'json'] = 'tqdm', remove_ckpt_after_load: bool = False, remove_unused_columns: bool = True, report_steps: int = 5, save_directory: str = 'EasyDeL-Checkpoints', save_optimizer_state: bool = True, save_steps: tp.Optional[int] = None, save_total_limit: tp.Optional[int] = None, scheduler: 
AVAILABLE_SCHEDULERS = EasyDeLSchedulers.NONE, sparsify_module: bool = False, sparse_module_type: AVAILABLE_SPARSE_MODULE_TYPES = 'bcoo', state_apply_fn_kwarguments_to_model: tp.Optional[dict] = None, step_partition_spec: PartitionSpec = PartitionSpec(('dp', 'fsdp'), 'sp'), step_start_point: tp.Optional[int] = None, shuffle_train_dataset: bool = True, total_batch_size: int = 32, training_time_limit: tp.Optional[str] = None, train_on_inputs: bool = True, truncation_mode: tp.Literal['keep_end', 'keep_start'] = 'keep_end', tx_mu_dtype: tp.Optional[jnp.dtype] = None, track_memory: bool = False, use_data_collactor: bool = True, use_wandb: bool = True, verbose: bool = True, wandb_entity: tp.Optional[str] = None, warmup_steps: int = 0, weight_decay: float = 0.01, weight_distribution_pattern: str = '.*?(layernorm|norm).*?', weight_distribution_log_steps: int = 0, max_length: ~typing.Optional[int] = 1024, max_prompt_length: ~typing.Optional[int] = 512, max_completion_length: ~typing.Optional[int] = None, beta: float = 0.1, disable_dropout: bool = True, label_pad_token_id: int = -100, padding_value: ~typing.Optional[int] = None, generate_during_eval: bool = False, is_encoder_decoder: ~typing.Optional[bool] = None, dataset_num_proc: ~typing.Optional[int] = None)[source]#
Bases:
TrainingArguments
Configuration class for ORPO training settings.
This class inherits from TrainingArguments and holds configuration parameters specific to ORPO training. As a dataclass, it gets an automatically generated initializer; the __post_init__ method then post-processes some of the parameters after object initialization.
- model_name#
The name of the model. Default is “ORPOTrainer”.
- Type
str
- learning_rate#
The learning rate used during training. Default is 1e-6.
- Type
float
- max_length#
The maximum allowed sequence length for the input. Default is 1024.
- Type
Optional[int]
- max_prompt_length#
The maximum allowed length of the prompt portion of the input. Default is 512.
- Type
Optional[int]
- max_completion_length#
The maximum allowed length of the completion. If not provided, it is set to max_length - max_prompt_length.
- Type
Optional[int]
- beta#
Weight applied to the odds-ratio term of the ORPO loss. Default is 0.1.
- Type
float
- disable_dropout#
Flag to disable dropout during training. Default is True.
- Type
bool
- label_pad_token_id#
The token id used for padding labels. Default is -100.
- Type
int
- padding_value#
The value used for padding sequences. Default is None.
- Type
Optional[int]
- generate_during_eval#
Flag indicating whether to generate sequences during evaluation. Default is False.
- Type
bool
- is_encoder_decoder#
Flag to indicate if the model is encoder-decoder. Default is None.
- Type
Optional[bool]
- model_init_kwargs#
Additional keyword arguments for model initialization. Default is None.
- Type
Optional[Dict[str, Any]]
- dataset_num_proc#
Number of processes to use for dataset processing. Default is None.
- Type
Optional[int]
- max_sequence_length#
Computed attribute representing the maximum sequence length used for training. It is set in the __post_init__ method.
- Type
int
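The fallback described for max_completion_length can be sketched with a plain dataclass. This is only an illustration: LengthConfigSketch is a hypothetical stand-in, not the real ORPOConfig, and it assumes max_length and max_prompt_length are both set when the fallback runs.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class LengthConfigSketch:
    """Hypothetical stand-in for the length-related fields of ORPOConfig."""

    max_length: Optional[int] = 1024
    max_prompt_length: Optional[int] = 512
    max_completion_length: Optional[int] = None

    def __post_init__(self):
        # Documented fallback: when no completion budget is given,
        # use whatever remains of max_length after the prompt budget.
        if self.max_completion_length is None:
            self.max_completion_length = self.max_length - self.max_prompt_length


# Defaults: the completion budget is derived from the other two fields.
default_cfg = LengthConfigSketch()

# Custom lengths: the derived value follows the new budgets.
custom_cfg = LengthConfigSketch(max_length=2048, max_prompt_length=256)

# An explicit value is left untouched.
explicit_cfg = LengthConfigSketch(max_completion_length=128)
```

With the documented defaults, the derived completion budget equals max_length - max_prompt_length; passing max_completion_length explicitly bypasses the fallback.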
- beta: float = 0.1#
- dataset_num_proc: Optional[int] = None#
- disable_dropout: bool = True#
- extra_optimizer_kwargs: dict#
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- generate_during_eval: bool = False#
- ids_to_pop_from_dataset: tp.Optional[tp.List[str]]#
- is_encoder_decoder: Optional[bool] = None#
- label_pad_token_id: int = -100#
- learning_rate: float = 1e-06#
- max_completion_length: Optional[int] = None#
- max_length: Optional[int] = 1024#
- max_prompt_length: Optional[int] = 512#
- model_name: str = 'ORPOTrainer'#
- padding_value: Optional[int] = None#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
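To make the role of the beta field more concrete, the following is a minimal sketch of an ORPO-style objective for a single preference pair. It assumes that beta scales the odds-ratio term that is combined with the negative log-likelihood of the chosen response, and that chosen_logp and rejected_logp are length-averaged sequence log-probabilities. The function names are hypothetical and the trainer's actual loss code may differ.

```python
import math


def log1mexp(x: float) -> float:
    """Return log(1 - exp(x)) for x < 0."""
    return math.log1p(-math.exp(x))


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def orpo_loss_sketch(
    chosen_logp: float,
    rejected_logp: float,
    nll_loss: float,
    beta: float = 0.1,
) -> float:
    """Hypothetical single-pair ORPO-style loss (assumed form, not the library code)."""
    # log odds(p) = log p - log(1 - p); the ratio compares chosen vs. rejected.
    log_odds_ratio = (chosen_logp - rejected_logp) - (
        log1mexp(chosen_logp) - log1mexp(rejected_logp)
    )
    # The odds-ratio term is <= 0 and approaches 0 as the chosen response
    # becomes much more likely than the rejected one.
    odds_ratio_term = log_sigmoid(log_odds_ratio)
    # Assumption: beta weights the odds-ratio term against the NLL term.
    return nll_loss - beta * odds_ratio_term


# Chosen response with probability 0.5, rejected with probability 0.25.
example_loss = orpo_loss_sketch(math.log(0.5), math.log(0.25), nll_loss=1.0)
```

Under this assumed form, setting beta to 0 reduces the objective to the plain NLL term, and larger values penalize pairs where the rejected response is not clearly less likely than the chosen one.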
- class easydel.trainers.odds_ratio_preference_optimization_trainer.__init__.ORPOTrainer(arguments: ORPOConfig, model: Optional[Union[EasyDeLBaseModule, EasyDeLState]] = None, data_collator: Optional[DPODataCollatorWithPadding] = None, train_dataset: Optional[Any] = None, eval_dataset: Optional[Union[Any, Dict[str, Any]]] = None, processing_class: Optional[Any] = None)[source]#
Bases:
Trainer
- arguments: ORPOConfig#
- build_tokenized_answer(prompt: str, answer: str) Dict[str, ndarray][source]#
Tokenizes a prompt and answer pair, handling special tokens and padding/truncation.
- Parameters
prompt (str) – The prompt text.
answer (str) – The answer text.
- Returns
A dictionary containing the tokenized prompt and answer, along with attention masks.
- Return type
tp.Dict[str, np.ndarray]
- Raises
ValueError – If there’s a mismatch in token lengths.
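One common way to tokenize a prompt and answer pair is to tokenize the concatenated text, tokenize the prompt alone, and treat the remaining tokens as the answer, raising ValueError if the lengths do not add up. The sketch below illustrates that idea only. toy_tokenize, the tokenize parameter, and the output key names are hypothetical, plain lists stand in for NumPy arrays, and the real method relies on the trainer's processing_class and may handle special tokens differently.

```python
from typing import Callable, Dict, List


def toy_tokenize(text: str) -> List[int]:
    """Hypothetical tokenizer: one integer id per whitespace-separated word."""
    return [sum(ord(ch) for ch in word) for word in text.split()]


def build_tokenized_answer_sketch(
    prompt: str,
    answer: str,
    tokenize: Callable[[str], List[int]] = toy_tokenize,
) -> Dict[str, List[int]]:
    """Split the tokens of prompt + answer into a prompt part and an answer part."""
    full_ids = tokenize(prompt + answer)
    prompt_ids = tokenize(prompt)

    # Everything after the prompt tokens is treated as the answer.
    answer_ids = full_ids[len(prompt_ids):]

    # Guard against tokenizations whose parts do not add up to the whole.
    if len(prompt_ids) + len(answer_ids) != len(full_ids):
        raise ValueError("Prompt and answer token lengths do not match the full sequence.")

    return {
        "prompt_input_ids": prompt_ids,
        "prompt_attention_mask": [1] * len(prompt_ids),
        "input_ids": answer_ids,
        "attention_mask": [1] * len(answer_ids),
    }


tokens = build_tokenized_answer_sketch("What is 2 + 2?", " It is 4.")
```

Tokenizing the concatenation first matters with real subword tokenizers, because tokens at the prompt/answer boundary can differ from what separate tokenization of the two strings would produce.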
- checkpoint_manager: tp.Any#
- checkpoint_path: tp.Optional[tp.Union[str, os.PathLike]]#
- config: EasyDeLBaseConfig#
- configure_functions() TrainerConfigureFunctionOutput[source]#
Configures and JIT-compiles the training and evaluation step functions.
- This method sets up the necessary functions for training and evaluation, including:
Initialization of the model state.
Sharding of the model parameters and optimizer state.
JIT-compilation of the training and evaluation step functions.
- Returns
An object containing the configured functions and other relevant information.
- Return type
TrainerConfigureFunctionOutput
- create_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start'] = 'keep_end') Callable[source]#
Creates a data collection function for batching.
For ORPO training, this method simply returns the pre-configured data_collator.
- Parameters
max_sequence_length (int) – The maximum sequence length (not used in this implementation).
truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – The truncation mode (not used in this implementation). Defaults to “keep_end”.
- Returns
The data collator function.
- Return type
tp.Callable
- data_collator: tp.Optional[tp.Callable]#
- dataloader_eval: tp.Optional[tp.Iterator[np.ndarray]]#
- dataloader_train: tp.Iterator[np.ndarray]#
- dataset_eval: tp.Optional[Dataset]#
- dataset_train: tp.Optional[Dataset]#
- dtype: tp.Any#
- evalu_tracker: CompilationTracker#
- finetune: bool#
- max_evaluation_steps: int#
- max_training_steps: int#
- memory_monitor: tp.Any#
- model_state: EasyDeLState#
- param_dtype: tp.Any#
- pruning_module: tp.Any#
- scheduler: optax.Schedule#
- sharded_evaluation_step_function: JitWrapped#
- sharded_training_step_function: JitWrapped#
- state: tp.Any#
- state_named_sharding: tp.Any#
- state_partition_spec: tp.Any#
- state_shape: tp.Any#
- tokenize_row(feature: Dict[str, str], state: Optional[object] = None) Dict[str, ndarray][source]#
Tokenizes a single row of data from the ORPO dataset.
This method tokenizes the prompt, chosen response, and rejected response, handles padding and truncation, and prepares the data for input to the model during ORPO training.
- Parameters
feature (tp.Dict) – A dictionary containing the “prompt”, “chosen”, and “rejected” texts.
state (EasyDeLState, optional) – Not used in this implementation. Defaults to None.
- Returns
A dictionary containing the tokenized prompt, chosen response, and rejected response, along with attention masks and labels.
- Return type
tp.Dict
- Raises
ValueError – If the input data types are incorrect.
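The truncation and label handling mentioned above can be illustrated on already-tokenized ids. This is a simplified sketch under stated assumptions: the prompt is cut according to truncation_mode, the response is cut from the right to fit max_length, and prompt positions in the labels are masked with label_pad_token_id. The helper names are hypothetical, and the order of checks in the actual tokenize_row may differ.

```python
from typing import Dict, List

LABEL_PAD_TOKEN_ID = -100  # documented default of label_pad_token_id


def truncate_prompt(
    ids: List[int], max_prompt_length: int, truncation_mode: str = "keep_end"
) -> List[int]:
    """Cut a prompt to max_prompt_length, keeping its start or its end."""
    if len(ids) <= max_prompt_length:
        return list(ids)
    if truncation_mode == "keep_start":
        return list(ids[:max_prompt_length])
    if truncation_mode == "keep_end":
        return list(ids[len(ids) - max_prompt_length:])
    raise ValueError(f"Unknown truncation mode: {truncation_mode}")


def build_sequence(
    prompt_ids: List[int],
    response_ids: List[int],
    max_length: int,
    max_prompt_length: int,
    truncation_mode: str = "keep_end",
) -> Dict[str, List[int]]:
    """Join prompt and response ids and build labels that ignore the prompt."""
    prompt_ids = truncate_prompt(prompt_ids, max_prompt_length, truncation_mode)
    # Assumption: the response is cut from the right so the pair fits max_length.
    room = max(max_length - len(prompt_ids), 0)
    response_ids = list(response_ids[:room])
    input_ids = prompt_ids + response_ids
    # Prompt positions are masked so that only response tokens count as labels.
    labels = [LABEL_PAD_TOKEN_ID] * len(prompt_ids) + response_ids
    return {
        "input_ids": input_ids,
        "attention_mask": [1] * len(input_ids),
        "labels": labels,
    }


prompt = [1, 2, 3, 4, 5, 6]
chosen = build_sequence(prompt, [10, 11, 12], max_length=6, max_prompt_length=4)
rejected = build_sequence(
    prompt, [20, 21, 22, 23], max_length=6, max_prompt_length=4,
    truncation_mode="keep_start",
)
```

In a full row, the same procedure would be applied once to the chosen response and once to the rejected response, so both sequences share the same prompt handling.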
- train_tracker: CompilationTracker#
- tx: optax.GradientTransformation#
- wandb_runtime: tp.Any#