easydel.trainers.odds_ratio_preference_optimization_trainer.__init__

Contents

easydel.trainers.odds_ratio_preference_optimization_trainer.__init__#

class easydel.trainers.odds_ratio_preference_optimization_trainer.__init__.ORPOConfig(auto_shard_states: bool = True, aux_loss_enabled: bool = False, backend: tp.Optional[str] = None, clip_grad: tp.Optional[float] = None, custom_scheduler: tp.Optional[tp.Callable[[int], tp.Any]] = None, dataloader_num_workers: tp.Optional[int] = 0, dataloader_pin_memory: tp.Optional[bool] = False, do_eval: bool = False, do_last_save: bool = True, do_train: bool = True, eval_batch_size: tp.Optional[int] = None, evaluation_steps: tp.Optional[int] = None, extra_optimizer_kwargs: dict = <factory>, frozen_parameters: tp.Optional[str] = None, gradient_accumulation_steps: int = 1, ids_to_pop_from_dataset: tp.Optional[tp.List[str]] = <factory>, is_fine_tuning: bool = True, init_tx: bool = True, jax_distributed_config: tp.Optional[dict] = None, learning_rate: float = 1e-06, learning_rate_end: tp.Optional[float] = None, log_all_workers: bool = False, log_grad_norms: bool = True, report_metrics: bool = True, log_steps: int = 10, loss_config: tp.Optional[LossConfig] = None, low_mem_usage: bool = True, max_evaluation_steps: tp.Optional[int] = None, max_sequence_length: tp.Optional[int] = 4096, max_training_steps: tp.Optional[int] = None, model_name: str = 'ORPOTrainer', model_parameters: tp.Optional[dict] = None, metrics_to_show_in_rich_pbar: tp.Optional[tp.List[str]] = None, num_train_epochs: int = 10, offload_dataset: bool = False, offload_device_type: str = 'cpu', offload_device_index: int = 0, optimizer: AVAILABLE_OPTIMIZERS = EasyDeLOptimizers.ADAMW, performance_mode: bool = False, pruning_module: tp.Any = None, process_zero_is_admin: bool = True, progress_bar_type: tp.Literal['tqdm', 'rich', 'json'] = 'tqdm', remove_ckpt_after_load: bool = False, remove_unused_columns: bool = True, report_steps: int = 5, save_directory: str = 'EasyDeL-Checkpoints', save_optimizer_state: bool = True, save_steps: tp.Optional[int] = None, save_total_limit: tp.Optional[int] = None, scheduler: 
AVAILABLE_SCHEDULERS = EasyDeLSchedulers.NONE, sparsify_module: bool = False, sparse_module_type: AVAILABLE_SPARSE_MODULE_TYPES = 'bcoo', state_apply_fn_kwarguments_to_model: tp.Optional[dict] = None, step_partition_spec: PartitionSpec = PartitionSpec(('dp', 'fsdp'), 'sp'), step_start_point: tp.Optional[int] = None, shuffle_train_dataset: bool = True, total_batch_size: int = 32, training_time_limit: tp.Optional[str] = None, train_on_inputs: bool = True, truncation_mode: tp.Literal['keep_end', 'keep_start'] = 'keep_end', tx_mu_dtype: tp.Optional[jnp.dtype] = None, track_memory: bool = False, use_data_collactor: bool = True, use_wandb: bool = True, verbose: bool = True, wandb_entity: tp.Optional[str] = None, warmup_steps: int = 0, weight_decay: float = 0.01, weight_distribution_pattern: str = '.*?(layernorm|norm).*?', weight_distribution_log_steps: int = 0, max_length: ~typing.Optional[int] = 1024, max_prompt_length: ~typing.Optional[int] = 512, max_completion_length: ~typing.Optional[int] = None, beta: float = 0.1, disable_dropout: bool = True, label_pad_token_id: int = -100, padding_value: ~typing.Optional[int] = None, generate_during_eval: bool = False, is_encoder_decoder: ~typing.Optional[bool] = None, dataset_num_proc: ~typing.Optional[int] = None)[source]#

Bases: TrainingArguments

Configuration class for ORPO training settings.

This class inherits from TrainingArguments and holds the configuration parameters specific to ORPO (Odds Ratio Preference Optimization) training. Because it is a dataclass, its initializer is generated automatically; the __post_init__ method then further processes some of the parameters after object initialization.

model_name#

The name of the model. Default is “ORPOTrainer”.

Type

str

learning_rate#

The learning rate used during training. Default is 1e-6.

Type

float

max_length#

The maximum allowed sequence length for the input. Default is 1024.

Type

Optional[int]

max_prompt_length#

The maximum allowed length of the prompt portion of the input. Default is 512.

Type

Optional[int]

max_completion_length#

The maximum allowed length of the completion. If not provided, it is set to max_length - max_prompt_length.

Type

Optional[int]

beta#

Weight applied to the odds-ratio term of the ORPO loss (denoted λ in the ORPO paper). Default is 0.1.

Type

float

disable_dropout#

Flag to disable dropout during training. Default is True.

Type

bool

label_pad_token_id#

The token id used for padding labels. Default is -100.

Type

int

padding_value#

The value used for padding sequences. Default is None.

Type

Optional[int]

generate_during_eval#

Flag indicating whether to generate sequences during evaluation. Default is False.

Type

bool

is_encoder_decoder#

Flag to indicate if the model is encoder-decoder. Default is None.

Type

Optional[bool]

model_init_kwargs#

Additional keyword arguments for model initialization. Default is None.

Type

Optional[Dict[str, Any]]

dataset_num_proc#

Number of processes to use for dataset processing. Default is None.

Type

Optional[int]

max_sequence_length#

Computed attribute representing the maximum sequence length used for training. It is set in the __post_init__ method.

Type

int

beta: float = 0.1#
dataset_num_proc: Optional[int] = None#
disable_dropout: bool = True#
extra_optimizer_kwargs: dict#
classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

generate_during_eval: bool = False#
ids_to_pop_from_dataset: tp.Optional[tp.List[str]]#
is_encoder_decoder: Optional[bool] = None#
label_pad_token_id: int = -100#
learning_rate: float = 1e-06#
max_completion_length: Optional[int] = None#
max_length: Optional[int] = 1024#
max_prompt_length: Optional[int] = 512#
model_name: str = 'ORPOTrainer'#
padding_value: Optional[int] = None#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

class easydel.trainers.odds_ratio_preference_optimization_trainer.__init__.ORPOTrainer(arguments: ORPOConfig, model: Optional[Union[EasyDeLBaseModule, EasyDeLState]] = None, data_collator: Optional[DPODataCollatorWithPadding] = None, train_dataset: Optional[Any] = None, eval_dataset: Optional[Union[Any, Dict[str, Any]]] = None, processing_class: Optional[Any] = None)[source]#

Bases: Trainer

arguments: ORPOConfig#
build_tokenized_answer(prompt: str, answer: str) Dict[str, ndarray][source]#

Tokenizes a prompt and answer pair, handling special tokens and padding/truncation.

Parameters
  • prompt (str) – The prompt text.

  • answer (str) – The answer text.

Returns

A dictionary containing the tokenized prompt and answer, along with attention masks.

Return type

tp.Dict[str, np.ndarray]

Raises

ValueError – If there’s a mismatch in token lengths.

checkpoint_manager: tp.Any#
checkpoint_path: tp.Optional[tp.Union[str, os.PathLike]]#
config: EasyDeLBaseConfig#
configure_functions() TrainerConfigureFunctionOutput[source]#

Configures and JIT-compiles the training and evaluation step functions.

This method sets up the necessary functions for training and evaluation, including:
  • Initialization of the model state.

  • Sharding of the model parameters and optimizer state.

  • JIT-compilation of the training and evaluation step functions.

Returns

An object containing the configured functions and other relevant information.

Return type

TrainerConfigureFunctionOutput

create_collect_function(max_sequence_length: int, truncation_mode: Literal['keep_end', 'keep_start'] = 'keep_end') Callable[source]#

Creates a data collection function for batching.

For ORPO training, this method simply returns the pre-configured data_collator.

Parameters
  • max_sequence_length (int) – The maximum sequence length (not used in this implementation).

  • truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – The truncation mode (not used in this implementation). Defaults to “keep_end”.

Returns

The data collator function.

Return type

tp.Callable

data_collator: tp.Optional[tp.Callable]#
dataloader_eval: tp.Optional[tp.Iterator[np.ndarray]]#
dataloader_train: tp.Iterator[np.ndarray]#
dataset_eval: tp.Optional[Dataset]#
dataset_train: tp.Optional[Dataset]#
dtype: tp.Any#
evalu_tracker: CompilationTracker#
finetune: bool#
max_evaluation_steps: int#
max_training_steps: int#
memory_monitor: tp.Any#
model_state: EasyDeLState#
param_dtype: tp.Any#
pruning_module: tp.Any#
scheduler: optax.Schedule#
sharded_evaluation_step_function: JitWrapped#
sharded_training_step_function: JitWrapped#
state: tp.Any#
state_named_sharding: tp.Any#
state_partition_spec: tp.Any#
state_shape: tp.Any#
timer: Timers#
tokenize_row(feature: Dict[str, str], state: Optional[object] = None) Dict[str, ndarray][source]#

Tokenizes a single row of data from the ORPO dataset.

This method tokenizes the prompt, chosen response, and rejected response, handles padding and truncation, and prepares the data for input to the model during ORPO training.

Parameters
  • feature (tp.Dict) – A dictionary containing the “prompt”, “chosen”, and “rejected” texts.

  • state (EasyDeLState, optional) – Not used in this implementation. Defaults to None.

Returns

A dictionary containing the tokenized prompt, chosen response, and rejected response, along with attention masks and labels.

Return type

tp.Dict

Raises

ValueError – If the input data types are incorrect.

train_tracker: CompilationTracker#
tx: optax.GradientTransformation#
wandb_runtime: tp.Any#