easydel.trainers.odds_ratio_preference_optimization_trainer.orpo_trainer#
- class easydel.trainers.odds_ratio_preference_optimization_trainer.orpo_trainer.ORPOTrainer(arguments: ORPOConfig, model: Optional[Union[EasyDeLBaseModule, EasyDeLState]] = None, data_collator: Optional[DPODataCollatorWithPadding] = None, train_dataset: Optional[Any] = None, eval_dataset: Optional[Union[Any, Dict[str, Any]]] = None, processing_class: Optional[Any] = None)[source]#
Bases: Trainer
- arguments: ORPOConfig#
- build_tokenized_answer(prompt: str, answer: str) Dict[str, ndarray][source]#
Tokenizes a prompt and answer pair, handling special tokens and padding/truncation.
- Parameters
prompt (str) – The prompt text.
answer (str) – The answer text.
- Returns
A dictionary containing the tokenized prompt and answer, along with attention masks.
- Return type
tp.Dict[str, np.ndarray]
- Raises
ValueError – If there’s a mismatch in token lengths.
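The behaviour described above can be illustrated with a minimal sketch. It does not reproduce the EasyDeL implementation: `toy_tokenize`, `build_tokenized_answer_sketch`, and the output key names are hypothetical, and a whitespace tokenizer stands in for the real `processing_class`. The sketch assumes the common pattern of tokenizing the concatenated prompt and answer, splitting off the prompt portion, and raising `ValueError` when the pieces do not line up.

```python
import numpy as np


def toy_tokenize(text):
    """Hypothetical stand-in tokenizer: one integer id per whitespace-separated word."""
    return [sum(ord(ch) for ch in word) % 1000 + 1 for word in text.split()]


def build_tokenized_answer_sketch(prompt, answer):
    """Tokenize prompt + answer together, then split the ids back into two parts."""
    full_ids = toy_tokenize(prompt + " " + answer)
    prompt_ids = toy_tokenize(prompt)

    # The prompt tokenized alone must be a prefix of the joint tokenization;
    # otherwise the split point is ambiguous and the lengths cannot be trusted.
    if full_ids[: len(prompt_ids)] != prompt_ids:
        raise ValueError("Prompt ids and prompt+answer ids are inconsistent.")

    answer_ids = full_ids[len(prompt_ids):]
    return {
        # Key names are assumptions chosen for this illustration only.
        "prompt_input_ids": np.asarray(prompt_ids, dtype=np.int32),
        "prompt_attention_mask": np.ones(len(prompt_ids), dtype=np.int32),
        "input_ids": np.asarray(answer_ids, dtype=np.int32),
        "attention_mask": np.ones(len(answer_ids), dtype=np.int32),
    }


tokens = build_tokenized_answer_sketch("What is 2 + 2 ?", "It is 4 .")
print({name: array.shape for name, array in tokens.items()})
```

With a real subword tokenizer the prefix check matters more than it does here, because merging across the prompt/answer boundary can change the ids at the split point.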
- checkpoint_manager: tp.Any#
- checkpoint_path: tp.Optional[tp.Union[str, os.PathLike]]#
- config: EasyDeLBaseConfig#
- configure_functions() TrainerConfigureFunctionOutput[source]#
Configures and JIT-compiles the training and evaluation step functions.
- This method sets up the necessary functions for training and evaluation, including:
Initialization of the model state.
Sharding of the model parameters and optimizer state.
JIT-compilation of the training and evaluation step functions.
- Returns
An object containing the configured functions and other relevant information.
- Return type
TrainerConfigureFunctionOutput
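As a rough illustration of the setup described above, the sketch below initializes a state and JIT-compiles a training and an evaluation step with `jax.jit`. Everything in it is an assumption made for illustration: the linear model, the plain gradient-descent update, the dictionary-shaped state, and the name `configure_functions_sketch` are not part of EasyDeL, and sharding of parameters and optimizer state is omitted entirely.

```python
import jax
import jax.numpy as jnp


def init_state():
    """Hypothetical model state: parameters of a tiny linear model plus a step counter."""
    return {"params": {"w": jnp.zeros((3,)), "b": jnp.zeros(())}, "step": jnp.array(0)}


def loss_fn(params, batch):
    """Mean squared error of the linear model on one batch."""
    predictions = batch["x"] @ params["w"] + params["b"]
    return jnp.mean((predictions - batch["y"]) ** 2)


def training_step(state, batch, learning_rate=0.1):
    """One plain gradient-descent update; returns the new state and the loss."""
    loss, grads = jax.value_and_grad(loss_fn)(state["params"], batch)
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, state["params"], grads
    )
    return {"params": new_params, "step": state["step"] + 1}, loss


def evaluation_step(state, batch):
    """Loss only, no parameter update."""
    return loss_fn(state["params"], batch)


def configure_functions_sketch():
    """Bundle the initial state with JIT-compiled step functions."""
    return {
        "state": init_state(),
        "training_step": jax.jit(training_step),
        "evaluation_step": jax.jit(evaluation_step),
    }


fns = configure_functions_sketch()
batch = {"x": jnp.eye(3), "y": jnp.array([1.0, 2.0, 3.0])}
state = fns["state"]
initial_loss = float(fns["evaluation_step"](state, batch))
for _ in range(50):
    state, loss = fns["training_step"](state, batch)
final_loss = float(fns["evaluation_step"](state, batch))
print(initial_loss, final_loss)
```

Compiling the step functions once and returning them together with the state keeps the training loop free of setup logic, which is the role the returned TrainerConfigureFunctionOutput plays in the trainer.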
- create_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start'] = 'keep_end') Callable[source]#
Creates a data collection function for batching.
For ORPO training, this method simply returns the pre-configured data_collator.
- Parameters
max_sequence_length (int) – The maximum sequence length (not used in this implementation).
truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – The truncation mode (not used in this implementation). Defaults to “keep_end”.
- Returns
The data collator function.
- Return type
tp.Callable
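The pass-through behaviour described above can be sketched as follows. `TrainerSketch` and `pad_collator` are hypothetical stand-ins, not EasyDeL classes; the only property carried over from the documentation is that both arguments are accepted but ignored and the stored collator is returned unchanged.

```python
import numpy as np


def pad_collator(rows, pad_id=0):
    """Hypothetical collator: right-pad every "input_ids" row to the batch maximum."""
    width = max(len(row["input_ids"]) for row in rows)
    input_ids = np.full((len(rows), width), pad_id, dtype=np.int32)
    attention_mask = np.zeros((len(rows), width), dtype=np.int32)
    for i, row in enumerate(rows):
        n = len(row["input_ids"])
        input_ids[i, :n] = row["input_ids"]
        attention_mask[i, :n] = 1
    return {"input_ids": input_ids, "attention_mask": attention_mask}


class TrainerSketch:
    """Minimal stand-in holding a pre-configured collator."""

    def __init__(self, data_collator):
        self.data_collator = data_collator

    def create_collect_function(self, max_sequence_length, truncation_mode="keep_end"):
        # Both arguments are accepted for interface compatibility and ignored here.
        return self.data_collator


trainer = TrainerSketch(pad_collator)
collate = trainer.create_collect_function(max_sequence_length=128)
batch = collate([{"input_ids": [5, 6, 7]}, {"input_ids": [8]}])
print(batch["input_ids"])
```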
- data_collator: tp.Optional[tp.Callable]#
- dataloader_eval: tp.Optional[tp.Iterator[np.ndarray]]#
- dataloader_train: tp.Iterator[np.ndarray]#
- dataset_eval: tp.Optional[Dataset]#
- dataset_train: tp.Optional[Dataset]#
- dtype: tp.Any#
- evalu_tracker: CompilationTracker#
- finetune: bool#
- max_evaluation_steps: int#
- max_training_steps: int#
- memory_monitor: tp.Any#
- model_state: EasyDeLState#
- param_dtype: tp.Any#
- pruning_module: tp.Any#
- scheduler: optax.Schedule#
- sharded_evaluation_step_function: JitWrapped#
- sharded_training_step_function: JitWrapped#
- state: tp.Any#
- state_named_sharding: tp.Any#
- state_partition_spec: tp.Any#
- state_shape: tp.Any#
- tokenize_row(feature: Dict[str, str], state: Optional[object] = None) Dict[str, ndarray][source]#
Tokenizes a single row of data from the ORPO dataset.
This method tokenizes the prompt, chosen response, and rejected response, handles padding and truncation, and prepares the data for input to the model.
- Parameters
feature (tp.Dict) – A dictionary containing the “prompt”, “chosen”, and “rejected” texts.
state (EasyDeLState, optional) – Not used in this implementation. Defaults to None.
- Returns
A dictionary containing the tokenized prompt, chosen response, and rejected response, along with attention masks and labels.
- Return type
tp.Dict
- Raises
ValueError – If the input data types are incorrect.
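A minimal sketch of the row-level processing described above, under stated assumptions: the toy tokenizer, the output key names, the padding id of 0, the `-100` ignore label for prompt and padding positions, and truncation from the right are all illustrative choices, not confirmed details of the EasyDeL implementation.

```python
import numpy as np

IGNORE_INDEX = -100  # assumed label for positions excluded from the loss
PAD_ID = 0  # assumed padding token id


def toy_tokenize(text):
    """Hypothetical stand-in tokenizer: the id of each word is its length."""
    return [len(word) for word in text.split()]


def tokenize_row_sketch(feature, max_length=8):
    """Build padded ids, masks, and labels for the chosen and rejected responses."""
    for key in ("prompt", "chosen", "rejected"):
        if not isinstance(feature.get(key), str):
            raise ValueError(f"{key} should be a str, got {type(feature.get(key))}")

    prompt_ids = toy_tokenize(feature["prompt"])
    row = {}
    for name in ("chosen", "rejected"):
        answer_ids = toy_tokenize(feature[name])
        # Concatenate prompt and response, then truncate from the right.
        ids = (prompt_ids + answer_ids)[:max_length]
        # Prompt positions are masked out of the labels.
        labels = ([IGNORE_INDEX] * len(prompt_ids) + answer_ids)[:max_length]
        pad = max_length - len(ids)
        row[f"{name}_input_ids"] = np.asarray(ids + [PAD_ID] * pad, dtype=np.int32)
        row[f"{name}_attention_mask"] = np.asarray([1] * len(ids) + [0] * pad, dtype=np.int32)
        row[f"{name}_labels"] = np.asarray(labels + [IGNORE_INDEX] * pad, dtype=np.int32)
    return row


row = tokenize_row_sketch(
    {"prompt": "Name a color", "chosen": "Blue is one", "rejected": "I cannot"}
)
print(row["chosen_input_ids"], row["chosen_labels"])
```

Masking the prompt positions in the labels means only response tokens contribute to the loss for the chosen and rejected sequences.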
- train_tracker: CompilationTracker#
- tx: optax.GradientTransformation#
- wandb_runtime: tp.Any#