Source code for easydel.trainers.odds_ratio_preference_optimization_trainer.orpo_config
# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import typing as tp
from dataclasses import field
from eformer.pytree import auto_pytree
from easydel.utils.compiling_utils import hash_fn
from ..training_configurations import TrainingArguments
[docs]@auto_pytree
class ORPOConfig(TrainingArguments):
"""
Configuration class for ORPO training settings.
This class inherits from TrainingArguments and holds configuration
parameters specific to the ORPO model training. The dataclass automatically
generates an initializer, and the __post_init__ method further processes
some of the parameters after object initialization.
Attributes:
model_name (str): The name of the model. Default is "ORPOTrainer".
learning_rate (float): The learning rate used during training.
Default is 1e-6.
max_length (Optional[int]): The maximum allowed sequence length for the input.
Default is 1024.
max_prompt_length (Optional[int]): The maximum allowed length of the prompt portion
of the input. Default is 512.
max_completion_length (Optional[int]): The maximum allowed length of the completion.
If not provided, it is set to max_length - max_prompt_length.
beta (float): A hyperparameter beta, with a default value of 0.1.
disable_dropout (bool): Flag to disable dropout during training.
Default is True.
label_pad_token_id (int): The token id used for padding labels.
Default is -100.
padding_value (Optional[int]): The value used for padding sequences.
Default is None.
generate_during_eval (bool): Flag indicating whether to generate sequences during evaluation.
Default is False.
is_encoder_decoder (Optional[bool]): Flag to indicate if the model is encoder-decoder.
Default is None.
model_init_kwargs (Optional[Dict[str, Any]]): Additional keyword arguments for model initialization.
Default is None.
dataset_num_proc (Optional[int]): Number of processes to use for dataset processing.
Default is None.
max_sequence_length (int): Computed attribute representing the maximum sequence length
used for training. It is set in the __post_init__ method.
"""
model_name: str = field(
default="ORPOTrainer",
metadata={"help": "The name of the model."},
)
learning_rate: float = field(
default=1e-6,
metadata={"help": "The learning rate used during training."},
)
max_length: tp.Optional[int] = field(
default=1024,
metadata={"help": "The maximum allowed sequence length for the input."},
)
max_prompt_length: tp.Optional[int] = field(
default=512,
metadata={"help": "The maximum allowed length of the prompt portion of the input."},
)
max_completion_length: tp.Optional[int] = field(
default=None,
metadata={
"help": "The maximum allowed length of the completion. If not provided, it is set to max_length - max_prompt_length."
},
)
beta: float = field(
default=0.1,
metadata={"help": "A hyperparameter beta."},
)
disable_dropout: bool = field(
default=True,
metadata={"help": "Flag to disable dropout during training."},
)
label_pad_token_id: int = field(
default=-100,
metadata={"help": "The token id used for padding labels."},
)
padding_value: tp.Optional[int] = field(
default=None,
metadata={"help": "The value used for padding sequences."},
)
generate_during_eval: bool = field(
default=False,
metadata={
"help": "Flag indicating whether to generate sequences during evaluation."
},
)
is_encoder_decoder: tp.Optional[bool] = field(
default=None,
metadata={"help": "Flag to indicate if the model is encoder-decoder."},
)
dataset_num_proc: tp.Optional[int] = field(
default=None,
metadata={"help": "Number of processes to use for dataset processing."},
)
def __post_init__(self):
"""
Post-initialization processing.
This method is automatically called after the dataclass __init__ method.
It sets the 'max_completion_length' if it is not provided by subtracting the
'max_prompt_length' from 'max_length'. It also defines 'max_sequence_length'
(here set to twice the max_length, based on a chosen/rejected policy).
Returns:
The result of the superclass __post_init__ method.
"""
# If max_completion_length is not provided, derive it from max_length and max_prompt_length.
if self.max_completion_length is None:
self.max_completion_length = self.max_length - self.max_prompt_length
# Set max_sequence_length based on a chosen policy.
self.max_sequence_length = self.max_length * 2 # Chosen - Rejected
# Call the post_init of the parent class if it exists.
if hasattr(super(), "__post_init__"):
super().__post_init__()
__hash__ = hash_fn