easydel.trainers.odds_ratio_preference_optimization_trainer._fn
- easydel.trainers.odds_ratio_preference_optimization_trainer._fn.concatenated_forward(state: EasyDeLState, batch: Mapping[str, Union[List, Array, ndarray, bool, number]], is_encoder_decoder: bool, label_pad_token_id: int, padding_value: Any, max_length: int | None = None) → Tuple[Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number]]
Computes log-probabilities and logits for both chosen and rejected examples by concatenating the inputs and performing a forward pass through the model.
The function processes the batch by concatenating the chosen and rejected examples. It then calls the model (stored in state) to obtain the logits, computes the negative log-likelihood loss for the chosen examples using a dynamic cross entropy loss function, and splits the logits and log-probabilities into those corresponding to the chosen and rejected examples.
- Parameters
state (EasyDeLState) – The current state of the model containing parameters and the model itself.
batch (tp.Mapping[str, tp.Union[tp.List, chex.Array]]) – A dictionary containing input arrays for chosen and rejected examples as well as other necessary inputs.
is_encoder_decoder (bool) – Flag indicating whether the model is an encoder-decoder.
label_pad_token_id (int) – The token ID used to mark padding positions in the labels.
padding_value (Any) – The value used for padding. Must not be None.
max_length (int | None, optional) – Maximum length for the inputs (if applicable). Defaults to None.
- Returns
- A tuple containing:
chosen_log_probs: Log probabilities for the chosen examples.
rejected_log_probs: Log probabilities for the rejected examples.
chosen_logits: Logits for the chosen examples.
rejected_logits: Logits for the rejected examples.
chosen_nll_loss: Negative log-likelihood loss for the chosen examples.
chosen_accuracy: Accuracy metric computed on the chosen examples.
- Return type
tp.Tuple[chex.Array, chex.Array, chex.Array, chex.Array, chex.Array, chex.Array]
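The concatenate, forward, and split pattern described above can be sketched in a few lines of JAX. This is only an illustration under stated assumptions: `toy_model`, `toy_concatenated_forward`, and the `params` layout are hypothetical stand-ins, not EasyDeL's model or its actual implementation, and chosen and rejected sequences are assumed to already share the same length.

```python
import jax
import jax.numpy as jnp


def toy_model(params, input_ids):
    # Hypothetical stand-in for the real model: embedding lookup, then a
    # projection to vocabulary logits.
    hidden = params["embed"][input_ids]  # (batch, seq, dim)
    return hidden @ params["proj"]       # (batch, seq, vocab)


def toy_concatenated_forward(params, chosen_ids, rejected_ids):
    n_chosen = chosen_ids.shape[0]
    # Stack chosen and rejected along the batch axis so the model runs once.
    all_ids = jnp.concatenate([chosen_ids, rejected_ids], axis=0)
    all_logits = toy_model(params, all_ids)
    # Split the result back into the chosen and rejected parts.
    return all_logits[:n_chosen], all_logits[n_chosen:]


k1, k2 = jax.random.split(jax.random.PRNGKey(0))
params = {
    "embed": jax.random.normal(k1, (11, 4)),
    "proj": jax.random.normal(k2, (4, 11)),
}
chosen_ids = jnp.array([[1, 2, 3], [4, 5, 6]])
rejected_ids = jnp.array([[1, 2, 7], [4, 5, 8]])

chosen_logits, rejected_logits = toy_concatenated_forward(
    params, chosen_ids, rejected_ids
)
print(chosen_logits.shape, rejected_logits.shape)
```

Running both halves in a single forward pass avoids a second model call; the split relies only on the chosen examples coming first along the batch axis.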
- easydel.trainers.odds_ratio_preference_optimization_trainer._fn.concatenated_inputs(batch: Dict[str, Union[List, Array, ndarray, bool, number]], is_encoder_decoder: bool = False) → Dict[str, Union[Array, ndarray, bool, number]]
Concatenates chosen and rejected examples from the batch into unified arrays.
For each key in the batch that starts with “chosen” or “rejected”, the function creates a new key starting with “concatenated” and combines the corresponding arrays. In the case of an encoder-decoder model, the prompt inputs and attention masks are also repeated accordingly.
- Parameters
batch (tp.Dict[str, tp.Union[tp.List, chex.Array]]) – A dictionary containing the batch of data. Expected keys include those starting with “chosen”, “rejected”, “prompt_input_ids”, and “prompt_attention_mask”.
is_encoder_decoder (bool, optional) – Indicates whether the model is encoder-decoder. Defaults to False.
- Returns
A dictionary containing concatenated arrays with keys prefixed with “concatenated”.
- Return type
tp.Dict[str, chex.Array]
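As a rough illustration of the key renaming described above, the sketch below handles only the decoder-only case. `toy_concatenated_inputs` is a hypothetical helper, not the library function; it assumes each “chosen” key has a matching “rejected” key with the same shape and leaves out the encoder-decoder handling of prompt inputs.

```python
import jax.numpy as jnp


def toy_concatenated_inputs(batch):
    out = {}
    for key, value in batch.items():
        if key.startswith("chosen"):
            # Look up the matching rejected entry, e.g.
            # "chosen_input_ids" -> "rejected_input_ids".
            rejected = batch[key.replace("chosen", "rejected", 1)]
            new_key = key.replace("chosen", "concatenated", 1)
            # Chosen rows come first, rejected rows second, along the batch axis.
            out[new_key] = jnp.concatenate(
                [jnp.asarray(value), jnp.asarray(rejected)], axis=0
            )
    return out


batch = {
    "chosen_input_ids": [[1, 2, 3], [4, 5, 6]],
    "rejected_input_ids": [[1, 2, 7], [4, 5, 8]],
    "chosen_attention_mask": [[1, 1, 1], [1, 1, 1]],
    "rejected_attention_mask": [[1, 1, 1], [1, 1, 0]],
}
concatenated = toy_concatenated_inputs(batch)
print(sorted(concatenated))
print(concatenated["concatenated_input_ids"].shape)
```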
- easydel.trainers.odds_ratio_preference_optimization_trainer._fn.get_batch_logps(logits: Union[Array, ndarray, bool, number], labels: Union[Array, ndarray, bool, number], average_log_prob: bool = False, label_pad_token_id: int = -100, is_encoder_decoder: bool = False) → Union[Array, ndarray, bool, number]
Computes the log probabilities for a batch of sequences given the model logits and labels.
The function applies a log-softmax over the logits and extracts the log probability of each token corresponding to the label. It also masks out the padding tokens using label_pad_token_id.
- Parameters
logits (chex.Array) – The logits output by the model with shape (…, sequence_length, vocab_size).
labels (chex.Array) – The ground truth labels with shape matching logits except for the vocabulary dimension.
average_log_prob (bool, optional) – If True, returns the average log probability per sequence. Otherwise, returns the sum of log probabilities per sequence. Defaults to False.
label_pad_token_id (int, optional) – The token ID used for padding in the labels. Defaults to -100.
is_encoder_decoder (bool, optional) – Flag indicating whether the model is an encoder-decoder. Defaults to False.
- Returns
An array of log probabilities for each sequence in the batch.
- Return type
chex.Array
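The log-softmax, gather, and mask steps described above might look roughly like the following. `toy_batch_logps` is a hypothetical name used for illustration; the sketch does not shift labels relative to logits, and whether and how the library does so for decoder-only models is not shown here.

```python
import jax
import jax.numpy as jnp


def toy_batch_logps(logits, labels, average_log_prob=False, label_pad_token_id=-100):
    # True where the label is a real token, False at padding positions.
    mask = labels != label_pad_token_id
    # Replace padding labels with a valid index so the gather stays in range.
    safe_labels = jnp.where(mask, labels, 0)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Pick the log probability of each label token.
    token_logps = jnp.take_along_axis(
        log_probs, safe_labels[..., None], axis=-1
    ).squeeze(-1)
    # Zero out the contributions of padding positions.
    token_logps = token_logps * mask
    total = token_logps.sum(axis=-1)
    if average_log_prob:
        # Assumes every sequence has at least one non-padding label.
        return total / mask.sum(axis=-1)
    return total


# Uniform logits over 4 tokens, so every token has log probability log(1/4).
logits = jnp.zeros((2, 3, 4))
labels = jnp.array([[1, 2, -100], [0, -100, -100]])
summed = toy_batch_logps(logits, labels)
averaged = toy_batch_logps(logits, labels, average_log_prob=True)
print(summed, averaged)
```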
- easydel.trainers.odds_ratio_preference_optimization_trainer._fn.odds_ratio_loss(beta: float, policy_chosen_logps: Union[Array, ndarray, bool, number], policy_rejected_logps: Union[Array, ndarray, bool, number]) → Tuple[Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number], Union[Array, ndarray, bool, number]]
Computes the odds ratio loss used for training based on the log probabilities of chosen and rejected examples.
The log odds ratio is calculated as the difference between the chosen and rejected log probabilities (with a correction term for numerical stability). The log-sigmoid of this log odds ratio then forms the basis of the loss. The function also computes reward values for both chosen and rejected examples, as well as summary statistics.
- Parameters
beta (float) – A scaling hyperparameter applied to the loss and rewards.
policy_chosen_logps (chex.Array) – Log probabilities for the chosen examples.
policy_rejected_logps (chex.Array) – Log probabilities for the rejected examples.
- Returns
- A tuple containing:
losses: The computed odds ratio loss.
chosen_rewards: Rewards computed from the chosen log probabilities (detached).
rejected_rewards: Rewards computed from the rejected log probabilities (detached).
mean_ratio: The mean of the log sigmoid ratio.
mean_log_odds: The mean log odds difference.
- Return type
tp.Tuple[chex.Array, chex.Array, chex.Array, chex.Array, chex.Array]
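A minimal sketch of this computation, following the usual ORPO formulation, is given below. It rests on assumptions: `toy_odds_ratio_loss` is a hypothetical name, the correction term is taken to be the difference of the log(1 - p) terms computed with `log1p`, and the sign convention of `losses` is not confirmed against the library. Inputs must be strictly negative log probabilities (for example, per-token averages) so that log(1 - p) is finite.

```python
import jax
import jax.numpy as jnp


def toy_odds_ratio_loss(beta, chosen_logps, rejected_logps):
    # log odds(p) = log p - log(1 - p); log1p(-exp(logp)) evaluates log(1 - p).
    log_odds = (chosen_logps - rejected_logps) - (
        jnp.log1p(-jnp.exp(chosen_logps)) - jnp.log1p(-jnp.exp(rejected_logps))
    )
    # Log-sigmoid of the log odds ratio, scaled by beta.
    ratio = jax.nn.log_sigmoid(log_odds)
    losses = beta * ratio
    # Rewards are reporting-only values, so gradients are stopped (detached).
    chosen_rewards = beta * jax.lax.stop_gradient(chosen_logps)
    rejected_rewards = beta * jax.lax.stop_gradient(rejected_logps)
    return losses, chosen_rewards, rejected_rewards, jnp.mean(ratio), jnp.mean(log_odds)


# p_chosen = 0.5 has odds 1 and p_rejected = 0.25 has odds 1/3,
# so the odds ratio is 3.
chosen_logps = jnp.log(jnp.array([0.5]))
rejected_logps = jnp.log(jnp.array([0.25]))
losses, chosen_rewards, rejected_rewards, mean_ratio, mean_log_odds = (
    toy_odds_ratio_loss(0.1, chosen_logps, rejected_logps)
)
print(losses, mean_ratio, mean_log_odds)
```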
- easydel.trainers.odds_ratio_preference_optimization_trainer._fn.orpo_step(state: EasyDeLState, batch: dict, concatenated_forward: Callable, beta: float = 0.1, learning_rate_fn: Optional[Callable] = None, mode: Literal['train', 'eval'] = 'train', loss_config: Optional[LossConfig] = None, partition_spec: Optional[PartitionSpec] = None, gradient_accumulation_steps: int = 1) → Union[Tuple[EasyDeLState, LossMetrics], LossMetrics]
Performs a single training or evaluation step for the ORPO method.
The function handles both forward and backward passes (when in training mode) and computes the loss metrics. It supports minibatch processing and gradient accumulation. In training mode, the model state is updated based on the computed gradients, while in evaluation mode, only loss metrics are returned.
- Parameters
state (EasyDeLState) – The current model state containing parameters, optimizer state, etc.
batch (dict) – The input batch data.
concatenated_forward (tp.Callable) – A callable that performs the forward pass and returns logits and loss values for chosen and rejected examples.
beta (float, optional) – Scaling factor used in the odds ratio loss. Defaults to 0.1.
learning_rate_fn (tp.Optional[tp.Callable], optional) – A callable to compute the learning rate at the current step. Defaults to None.
mode (tp.Literal["train", "eval"], optional) – Specifies whether the step is for training or evaluation. Defaults to “train”.
loss_config (tp.Optional[LossConfig], optional) – Configuration for the loss computation. Defaults to None.
partition_spec (tp.Optional[PartitionSpec], optional) – Specification for sharding the batch data. Defaults to None.
gradient_accumulation_steps (int, optional) – Number of steps to accumulate gradients (only relevant in training mode). Defaults to 1.
- Returns
In “train” mode: A tuple containing the updated model state and the computed loss metrics.
In “eval” mode: The computed loss metrics.
- Return type
tp.Union[tp.Tuple[EasyDeLState, LossMetrics], LossMetrics]
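The train and eval branches and the gradient accumulation described above can be illustrated with a toy step function. Everything here is hypothetical: `toy_loss_fn` is a simple squared-error loss standing in for the ORPO objective, plain SGD replaces the optimizer held in `EasyDeLState`, and a bare parameter dictionary and scalar loss replace `EasyDeLState` and `LossMetrics`.

```python
import jax
import jax.numpy as jnp


def toy_loss_fn(params, batch):
    # Hypothetical scalar loss standing in for the ORPO objective.
    pred = batch["x"] @ params["w"]
    return jnp.mean((pred - batch["y"]) ** 2)


def toy_step(params, batch, mode="train", learning_rate=0.1,
             gradient_accumulation_steps=1):
    if mode == "eval":
        # Evaluation: forward pass only, no parameter update.
        return toy_loss_fn(params, batch)
    # Split the batch into equally sized minibatches along the leading axis.
    minibatches = jax.tree_util.tree_map(
        lambda a: a.reshape(gradient_accumulation_steps, -1, *a.shape[1:]), batch
    )
    loss_sum = 0.0
    grad_sum = jax.tree_util.tree_map(jnp.zeros_like, params)
    for i in range(gradient_accumulation_steps):
        minibatch = jax.tree_util.tree_map(lambda a: a[i], minibatches)
        loss, grads = jax.value_and_grad(toy_loss_fn)(params, minibatch)
        loss_sum = loss_sum + loss
        grad_sum = jax.tree_util.tree_map(jnp.add, grad_sum, grads)
    # Average the accumulated gradients, then apply one SGD update.
    scale = 1.0 / gradient_accumulation_steps
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * scale * g, params, grad_sum
    )
    return new_params, loss_sum * scale


params = {"w": jnp.zeros(2)}
batch = {
    "x": jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 1.0]]),
    "y": jnp.array([1.0, 2.0, 3.0, 4.0]),
}
loss_before = toy_step(params, batch, mode="eval")
new_params, train_loss = toy_step(params, batch, mode="train")
accum_params, accum_loss = toy_step(
    params, batch, mode="train", gradient_accumulation_steps=2
)
loss_after = toy_step(new_params, batch, mode="eval")
print(loss_before, loss_after)
```

With equally sized minibatches and a mean-reduced loss, averaging the accumulated gradients gives the same update as a single pass over the full batch, which is why the two training calls above end up with matching parameters.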