easydel.trainers.odds_ratio_preference_optimization_trainer.orpo_trainer

class easydel.trainers.odds_ratio_preference_optimization_trainer.orpo_trainer.ORPOTrainer(arguments: ORPOConfig, model: Optional[Union[EasyDeLBaseModule, EasyDeLState]] = None, data_collator: Optional[DPODataCollatorWithPadding] = None, train_dataset: Optional[Any] = None, eval_dataset: Optional[Union[Any, Dict[str, Any]]] = None, processing_class: Optional[Any] = None)

Bases: Trainer

arguments: ORPOConfig
build_tokenized_answer(prompt: str, answer: str) → Dict[str, ndarray]

Tokenizes a prompt and answer pair, handling special tokens and padding/truncation.

Parameters
  • prompt (str) – The prompt text.

  • answer (str) – The answer text.

Returns

A dictionary containing the tokenized prompt and answer, along with attention masks.

Return type

tp.Dict[str, np.ndarray]

Raises

ValueError – If there’s a mismatch in token lengths.
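A minimal sketch of how a prompt/answer split like this is commonly done, modelled on the approach in TRL's DPO and ORPO trainers rather than taken from EasyDeL's code. It assumes the tokenizer is a plain callable from text to token ids; the hypothetical toy_tokenize below merges "ab" into one token to mimic BPE-style merges, and the output key names are assumptions. The idea is to tokenize prompt + answer jointly and, if the last prompt token merged with the first answer token, move the boundary back by one.

```python
from typing import Callable, Dict, List

import numpy as np


def toy_tokenize(text: str) -> List[int]:
    """Hypothetical tokenizer: one id per character, except "ab" becomes id 256."""
    ids, i = [], 0
    while i < len(text):
        if text.startswith("ab", i):
            ids.append(256)
            i += 2
        else:
            ids.append(ord(text[i]))
            i += 1
    return ids


def build_tokenized_answer_sketch(
    prompt: str, answer: str, tokenize: Callable[[str], List[int]] = toy_tokenize
) -> Dict[str, np.ndarray]:
    full_ids = tokenize(prompt + answer)
    prompt_ids = tokenize(prompt)
    start = len(prompt_ids)  # index where the answer should begin
    if start > len(full_ids):
        raise ValueError("Prompt alone tokenizes to more tokens than prompt + answer.")
    # If the standalone prompt disagrees with the joint tokenization, the last prompt
    # token was merged with the first answer token: move the boundary back by one.
    if full_ids[:start] != prompt_ids:
        start -= 1
    full = np.asarray(full_ids, dtype=np.int32)
    return {
        "prompt_input_ids": full[:start],
        "prompt_attention_mask": np.ones(start, dtype=np.int32),
        "input_ids": full[start:],
        "attention_mask": np.ones(len(full) - start, dtype=np.int32),
    }
```

For example, build_tokenized_answer_sketch("xa", "b") yields prompt ids [120] and answer ids [256], because "a" and "b" merged across the boundary; without a merge ("xy", "z") the split falls exactly after the prompt.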

checkpoint_manager: tp.Any
checkpoint_path: tp.Optional[tp.Union[str, os.PathLike]]
config: EasyDeLBaseConfig
configure_functions() → TrainerConfigureFunctionOutput

Configures and JIT-compiles the training and evaluation step functions.

This method sets up the necessary functions for training and evaluation, including:
  • Initialization of the model state.

  • Sharding of the model parameters and optimizer state.

  • JIT-compilation of the training and evaluation step functions.

Returns

An object containing the configured functions and other relevant information.

Return type

TrainerConfigureFunctionOutput
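As a rough illustration of the steps listed above, the sketch below builds a device mesh, JIT-compiles a training step and an evaluation step with jax.jit, and bundles them in a container. Everything in it is a stand-in: ConfiguredFunctionsSketch is a hypothetical substitute for TrainerConfigureFunctionOutput (its field names reuse this trainer's sharded_training_step_function and sharded_evaluation_step_function attribute names), a toy squared-error loss replaces the ORPO objective, and plain SGD replaces the optax optimizer. Parameters are simply replicated across the mesh rather than sharded by partition rules.

```python
from dataclasses import dataclass
from typing import Callable, Dict

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


@dataclass
class ConfiguredFunctionsSketch:
    """Hypothetical stand-in for TrainerConfigureFunctionOutput."""
    sharded_training_step_function: Callable
    sharded_evaluation_step_function: Callable
    mesh: Mesh


def toy_loss(params: Dict[str, jax.Array], batch: Dict[str, jax.Array]) -> jax.Array:
    # Squared-error loss standing in for the ORPO objective.
    preds = batch["x"] @ params["w"]
    return jnp.mean((preds - batch["y"]) ** 2)


def configure_functions_sketch(learning_rate: float = 0.1) -> ConfiguredFunctionsSketch:
    # One-dimensional data-parallel mesh over all available devices.
    mesh = Mesh(np.array(jax.devices()), ("dp",))

    def train_step(params, batch):
        loss, grads = jax.value_and_grad(toy_loss)(params, batch)
        # Plain SGD update in place of an optax GradientTransformation.
        new_params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
        return new_params, loss

    def eval_step(params, batch):
        return toy_loss(params, batch)

    return ConfiguredFunctionsSketch(jax.jit(train_step), jax.jit(eval_step), mesh)


fns = configure_functions_sketch()
# Replicate the parameters over every device in the mesh (no partitioning rules).
replicated = NamedSharding(fns.mesh, PartitionSpec())
params = jax.device_put({"w": jnp.zeros((2,))}, replicated)
batch = {"x": jnp.eye(2), "y": jnp.array([1.0, 2.0])}
params, train_loss = fns.sharded_training_step_function(params, batch)
eval_loss = fns.sharded_evaluation_step_function(params, batch)
```

Compiling once up front and reusing the returned callables is what lets every subsequent step skip tracing; a real trainer would additionally pass partition specs so large parameters are split across devices instead of replicated.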

create_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start'] = 'keep_end') → Callable

Creates the data collation function used for batching.

In this trainer, the method simply returns the pre-configured data_collator; both arguments are ignored.

Parameters
  • max_sequence_length (int) – The maximum sequence length (not used in this implementation).

  • truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – The truncation mode (not used in this implementation). Defaults to "keep_end".

Returns

The data collator function.

Return type

tp.Callable
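Because the method only hands back the collator given to the trainer, the real batching work happens inside that collator. The sketch below pairs a pass-through create_collect_function_sketch with a hypothetical pad_collator that right-pads each key of a list of tokenized rows. The padding values (pad_token_id for *_input_ids, -100 for *_labels, 0 otherwise) follow the convention of TRL's DPODataCollatorWithPadding and are an assumption about EasyDeL's collator; TRL's version also left-pads prompt fields, which this sketch does not.

```python
from typing import Callable, Dict, List, Literal

import numpy as np

LABEL_PAD_TOKEN_ID = -100  # assumed ignore index for labels


def pad_collator(features: List[Dict[str, List[int]]], pad_token_id: int = 0) -> Dict[str, np.ndarray]:
    """Hypothetical collator: right-pads every key to the longest sequence in the batch."""
    batch = {}
    for key in features[0]:
        seqs = [np.asarray(f[key], dtype=np.int32) for f in features]
        max_len = max(len(s) for s in seqs)
        if key.endswith("input_ids"):
            pad_value = pad_token_id
        elif key.endswith("labels"):
            pad_value = LABEL_PAD_TOKEN_ID
        else:  # attention masks and anything else
            pad_value = 0
        batch[key] = np.stack(
            [np.pad(s, (0, max_len - len(s)), constant_values=pad_value) for s in seqs]
        )
    return batch


def create_collect_function_sketch(
    data_collator: Callable,
    max_sequence_length: int,
    truncation_mode: Literal["keep_end", "keep_start"] = "keep_end",
) -> Callable:
    # Mirrors the documented behaviour: both arguments are ignored and the
    # pre-configured collator is returned unchanged.
    del max_sequence_length, truncation_mode
    return data_collator


collate = create_collect_function_sketch(pad_collator, max_sequence_length=128)
batch = collate([
    {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "labels": [-100, 2, 3]},
    {"input_ids": [4], "attention_mask": [1], "labels": [4]},
])
```

Padding labels with -100 rather than the pad token keeps padded positions out of a cross-entropy loss that ignores that index.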

data_collator: tp.Optional[tp.Callable]
dataloader_eval: tp.Optional[tp.Iterator[np.ndarray]]
dataloader_train: tp.Iterator[np.ndarray]
dataset_eval: tp.Optional[Dataset]
dataset_train: tp.Optional[Dataset]
dtype: tp.Any
evalu_tracker: CompilationTracker
finetune: bool
max_evaluation_steps: int
max_training_steps: int
memory_monitor: tp.Any
model_state: EasyDeLState
param_dtype: tp.Any
pruning_module: tp.Any
scheduler: optax.Schedule
sharded_evaluation_step_function: JitWrapped
sharded_training_step_function: JitWrapped
state: tp.Any
state_named_sharding: tp.Any
state_partition_spec: tp.Any
state_shape: tp.Any
timer: Timers
tokenize_row(feature: Dict[str, str], state: Optional[object] = None) → Dict[str, ndarray]

Tokenizes a single row of data from the ORPO dataset.

This method tokenizes the prompt, chosen response, and rejected response, handles padding and truncation, and prepares the data for input to the model.

Parameters
  • feature (tp.Dict) – A dictionary containing the “prompt”, “chosen”, and “rejected” texts.

  • state (EasyDeLState, optional) – Not used in this implementation. Defaults to None.

Returns

A dictionary containing the tokenized prompt, chosen response, and rejected response, along with attention masks and labels.

Return type

tp.Dict

Raises

ValueError – If the input data types are incorrect.
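A simplified sketch of the per-row processing described above, loosely following TRL's ORPO preprocessing rather than EasyDeL's exact code. Assumptions: a character-level toy tokenizer in place of processing_class, hypothetical max_length and max_prompt_length limits, output keys such as chosen_input_ids and chosen_labels, and -100 as the label value that masks prompt tokens out of the loss. The unused state argument is omitted. When the prompt plus the longer response exceeds max_length, the prompt is truncated first (keeping its end or start according to truncation_mode), then the responses are cut to fit.

```python
from typing import Dict, List, Literal

import numpy as np

LABEL_PAD_TOKEN_ID = -100  # assumed ignore index: prompt positions do not contribute to the loss


def toy_tokenize(text: str) -> List[int]:
    # Hypothetical character-level tokenizer standing in for processing_class.
    return [ord(c) for c in text]


def tokenize_row_sketch(
    feature: Dict[str, str],
    max_length: int = 16,
    max_prompt_length: int = 8,
    truncation_mode: Literal["keep_end", "keep_start"] = "keep_end",
) -> Dict[str, np.ndarray]:
    for key in ("prompt", "chosen", "rejected"):
        if not isinstance(feature[key], str):
            raise ValueError(f"{key} should be a str but got {type(feature[key]).__name__}")

    prompt_ids = toy_tokenize(feature["prompt"])
    answers = {k: toy_tokenize(feature[k]) for k in ("chosen", "rejected")}
    longest = max(len(ids) for ids in answers.values())

    # Step 1: truncate the prompt if prompt + longer answer is too long.
    if len(prompt_ids) + longest > max_length:
        if truncation_mode == "keep_start":
            prompt_ids = prompt_ids[:max_prompt_length]
        else:  # "keep_end": keep the tokens closest to the answer
            prompt_ids = prompt_ids[-max_prompt_length:]
    # Step 2: if still too long, cut the answers to the remaining budget.
    if len(prompt_ids) + longest > max_length:
        budget = max_length - len(prompt_ids)
        answers = {k: ids[:budget] for k, ids in answers.items()}

    row = {"prompt_input_ids": np.asarray(prompt_ids, dtype=np.int32)}
    for k, ids in answers.items():
        full = prompt_ids + ids
        row[f"{k}_input_ids"] = np.asarray(full, dtype=np.int32)
        row[f"{k}_attention_mask"] = np.ones(len(full), dtype=np.int32)
        # Labels copy the answer tokens and mask the prompt positions.
        row[f"{k}_labels"] = np.asarray([LABEL_PAD_TOKEN_ID] * len(prompt_ids) + ids, dtype=np.int32)
    return row
```

For example, a row with prompt "Hi" and chosen "ok" gives chosen_input_ids [72, 105, 111, 107] and chosen_labels [-100, -100, 111, 107], so only the response tokens are scored.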

train_tracker: CompilationTracker
tx: optax.GradientTransformation
wandb_runtime: tp.Any