easydel.trainers.direct_preference_optimization_trainer._fn
- easydel.trainers.direct_preference_optimization_trainer._fn.concatenated_forward(model: EasyDeLBaseModule, batch: Dict[str, Union[List, Array, ndarray, bool, number]], is_encoder_decoder: bool, label_pad_token_id: int, padding_value: int, max_length: int | None = None, truncation_mode: str = 'keep_end', aux_loss_enabled: bool = False, loss_type: str = 'sigmoid') -> Dict[str, Union[Array, ndarray, bool, number]]
Runs the model on concatenated chosen/rejected inputs so that both halves are processed in a single forward pass.
This function first concatenates the inputs (using the concatenated_inputs function) and then runs a forward pass through the model. It handles both encoder-decoder and decoder-only architectures, applies truncation when max_length is set, and computes per-token log probabilities.
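A minimal JAX sketch of the log-probability step follows. It is an illustration, not EasyDeL's code: it assumes logits of shape (2B, T, V) whose first B rows are the chosen examples and last B rows the rejected ones, labels already aligned with the logits (causal shift applied), and padded label positions equal to label_pad_token_id. The helper names sequence_logps and split_chosen_rejected are hypothetical.

```python
import jax
import jax.numpy as jnp


def sequence_logps(logits, labels, label_pad_token_id=-100):
    """Sum per-token log probabilities of `labels` under `logits` (hypothetical helper)."""
    loss_mask = labels != label_pad_token_id
    # Replace pad ids with a valid index so take_along_axis stays in range.
    safe_labels = jnp.where(loss_mask, labels, 0)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    per_token = jnp.take_along_axis(log_probs, safe_labels[..., None], axis=-1)[..., 0]
    # Zero out padded positions before summing over the sequence axis.
    return jnp.sum(per_token * loss_mask, axis=-1)


def split_chosen_rejected(logits, labels, label_pad_token_id=-100):
    """Split a concatenated (2B, T, V) batch into chosen/rejected outputs."""
    logps = sequence_logps(logits, labels, label_pad_token_id)
    half = logits.shape[0] // 2
    mask = (labels != label_pad_token_id)[..., None]

    def masked_mean(x, m):
        # Mean logit over non-padded tokens and the whole vocabulary.
        return jnp.sum(x * m) / (jnp.sum(m) * x.shape[-1])

    return {
        "chosen_logps": logps[:half],
        "rejected_logps": logps[half:],
        "mean_chosen_logits": masked_mean(logits[:half], mask[:half]),
        "mean_rejected_logits": masked_mean(logits[half:], mask[half:]),
    }
```

Because both halves come out of one forward pass, slicing at the midpoint of the batch axis is all that is needed to recover the chosen and rejected values.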
- Parameters
model (EasyDeLBaseModule) – The model to run.
batch (tp.Dict[str, tp.Union[tp.List, chex.Array]]) – The input batch of data.
is_encoder_decoder (bool) – Flag indicating whether the model is an encoder-decoder.
label_pad_token_id (int) – Token id used to mark padded tokens in the labels.
padding_value (int) – Padding value for inputs.
max_length (int | None, optional) – Maximum sequence length for truncation. Defaults to None.
truncation_mode (str, optional) – Truncation strategy (“keep_end” or “keep_start”). Defaults to “keep_end”.
aux_loss_enabled (bool, optional) – If True, enables auxiliary loss computation. Defaults to False.
loss_type (str, optional) – The type of loss function to be used. Defaults to “sigmoid”.
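The two truncation modes can be sketched as slicing along the sequence axis. This assumes that “keep_end” keeps the last max_length positions and “keep_start” keeps the first max_length positions; the helper truncate is hypothetical.

```python
import jax.numpy as jnp


def truncate(arrays, max_length=None, truncation_mode="keep_end"):
    """Truncate every (batch, seq) array in `arrays` along the sequence axis."""
    if max_length is None:
        # No limit: return the inputs unchanged.
        return arrays
    if truncation_mode == "keep_end":
        return {k: v[:, -max_length:] for k, v in arrays.items()}
    if truncation_mode == "keep_start":
        return {k: v[:, :max_length] for k, v in arrays.items()}
    raise ValueError(f"Unknown truncation_mode: {truncation_mode!r}")
```

Applying the same slice to every array (input ids, attention mask, loss mask) keeps them aligned position by position.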
- Returns
- A dictionary containing:
“chosen_logps”: Log probabilities for chosen examples.
“rejected_logps”: Log probabilities for rejected examples.
“mean_chosen_logits”: Mean logits over tokens for chosen examples.
“mean_rejected_logits”: Mean logits over tokens for rejected examples.
Optionally, if aux_loss_enabled is True and the model output contains “aux_loss”, it is included in the output dictionary.
- Return type
tp.Dict[str, chex.Array]
- easydel.trainers.direct_preference_optimization_trainer._fn.concatenated_inputs(batch: Dict[str, Union[List, Array, ndarray, bool, number]], padding_value: int) -> Dict[str, Union[Array, ndarray, bool, number]]
Concatenates chosen and rejected examples from the batch, and pads the inputs to a uniform length.
This function is used to merge paired inputs (e.g. chosen vs. rejected examples) so that the model can process them in one forward pass. It concatenates the prompt inputs, attention masks, and (if present) image-related arrays. The completion inputs (and their attention masks) are padded to the length of the longest completion among the chosen and rejected examples.
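The concatenation and padding described above can be sketched as follows, assuming 2-D (batch, seq) integer arrays and right padding. Stacking the prompt arrays twice, so that each chosen and rejected completion keeps its prompt, is an assumption about how the prompt is duplicated; completion attention masks are padded with 0. Image-related keys are omitted, and the function names pad_to_length and concat_inputs are hypothetical.

```python
import jax.numpy as jnp


def pad_to_length(x, length, pad_value):
    """Right-pad a (batch, seq) array along the sequence axis to `length`."""
    return jnp.pad(x, ((0, 0), (0, length - x.shape[1])), constant_values=pad_value)


def concat_inputs(batch, padding_value):
    # The prompt is shared by both halves, so it is stacked twice along the batch axis.
    out = {
        "prompt_input_ids": jnp.concatenate([batch["prompt_input_ids"]] * 2, axis=0),
        "prompt_attention_mask": jnp.concatenate([batch["prompt_attention_mask"]] * 2, axis=0),
    }
    # Pad both completions to the longest completion length before stacking.
    max_len = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
    out["completion_input_ids"] = jnp.concatenate(
        [
            pad_to_length(batch["chosen_input_ids"], max_len, padding_value),
            pad_to_length(batch["rejected_input_ids"], max_len, padding_value),
        ],
        axis=0,
    )
    out["completion_attention_mask"] = jnp.concatenate(
        [
            pad_to_length(batch["chosen_attention_mask"], max_len, 0),
            pad_to_length(batch["rejected_attention_mask"], max_len, 0),
        ],
        axis=0,
    )
    return out
```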
- Parameters
batch (tp.Dict[str, tp.Union[tp.List, chex.Array]]) – A dictionary containing the batch of data. Expected keys include “prompt_input_ids” and “prompt_attention_mask”; “chosen_input_ids” and “rejected_input_ids”; and “chosen_attention_mask” and “rejected_attention_mask”. Optionally, keys like “pixel_values”, “pixel_attention_mask”, and “image_sizes” may be present.
padding_value (int) – The padding value to use when padding completion inputs.
- Returns
- A dictionary with concatenated arrays under keys such as:
“prompt_input_ids”, “prompt_attention_mask”
“completion_input_ids”, “completion_attention_mask”
and optionally image-related keys.
- Return type
tp.Dict[str, chex.Array]
- easydel.trainers.direct_preference_optimization_trainer._fn.evaluation_step(state: EasyDeLState, batch: dict, reference_state: EasyDeLState, concatenated_forward: Callable, beta: float = 0.1, label_smoothing: float = 0, loss_type: Literal['sigmoid', 'hinge', 'ipo', 'exo_pair', 'nca_pair', 'robust', 'bco_pair', 'sppo_hard', 'aot', 'aot_pair', 'apo_zero', 'apo_down'] = 'sigmoid', reference_free: bool = False, partition_spec: Optional[PartitionSpec] = None) -> LossMetrics
Performs a single evaluation step.
This function computes loss metrics for the input batch using the provided model state. It can optionally use a reference state to compute reference log probabilities.
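The effect of the reference log probabilities on evaluation metrics can be sketched as follows. This is not EasyDeL's LossMetrics computation: it assumes the common DPO convention that implicit rewards are beta times the policy/reference log-ratio, and that reference_free=True amounts to treating the reference log probabilities as zero. The helper implicit_rewards and its output keys are hypothetical.

```python
import jax.numpy as jnp


def implicit_rewards(chosen_logps, rejected_logps, ref_chosen_logps, ref_rejected_logps,
                     beta=0.1, reference_free=False):
    if reference_free:
        # Ignore the reference model: use zero reference log probabilities.
        ref_chosen_logps = jnp.zeros_like(chosen_logps)
        ref_rejected_logps = jnp.zeros_like(rejected_logps)
    chosen_rewards = beta * (chosen_logps - ref_chosen_logps)
    rejected_rewards = beta * (rejected_logps - ref_rejected_logps)
    return {
        "chosen_rewards": chosen_rewards,
        "rejected_rewards": rejected_rewards,
        # Fraction of pairs where the chosen completion receives the higher reward.
        "accuracy": jnp.mean(chosen_rewards > rejected_rewards),
        "margin": jnp.mean(chosen_rewards - rejected_rewards),
    }
```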
- Parameters
state (EasyDeLState) – The current model state.
batch (dict) – Input batch data.
reference_state (EasyDeLState) – A reference model state used to compute reference log probabilities. It has no default value, so it must be supplied explicitly.
concatenated_forward (tp.Callable) – Function to perform a forward pass on concatenated inputs.
beta (float, optional) – Scaling factor for loss computation. Defaults to 0.1.
label_smoothing (float, optional) – Label smoothing factor. Defaults to 0.
loss_type (LOSS_FN_VARIENTS, optional) – Type of loss function to use. Defaults to “sigmoid”.
reference_free (bool, optional) – If True, ignores reference log probabilities. Defaults to False.
partition_spec (tp.Optional[PartitionSpec], optional) – Partitioning specification for sharding the batch. Defaults to None.
- Returns
The computed loss metrics.
- Return type
LossMetrics
- easydel.trainers.direct_preference_optimization_trainer._fn.get_loss_function(loss_type: Literal['sigmoid', 'hinge', 'ipo', 'exo_pair', 'nca_pair', 'robust', 'bco_pair', 'sppo_hard', 'aot', 'aot_pair', 'apo_zero', 'apo_down'], beta: float, label_smoothing: Union[float, int])
Returns a loss function based on the specified loss type.
This function maps a given loss type (e.g., “sigmoid”, “hinge”, “ipo”, etc.) to a corresponding loss function implementation that computes the DPO (Direct Preference Optimization) loss.
- Parameters
loss_type (LOSS_FN_VARIENTS) – The type of loss function to return.
beta (float) – A scaling factor applied to the loss computation.
label_smoothing (tp.Union[float, int]) – A value for label smoothing used in some loss functions.
- Returns
A callable loss function that accepts arguments (chosen_logps, rejected_logps, ref_chosen_logps, ref_rejected_logps, beta, label_smoothing, **kwargs) and returns the computed loss.
- Return type
tp.Callable
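A minimal sketch of such a dispatcher follows. It covers only the “sigmoid”, “hinge”, and “ipo” variants, using the formulas commonly used for DPO (for example in Hugging Face TRL); whether EasyDeL's implementations match them exactly is an assumption. get_loss_function_sketch and the per-variant helpers are hypothetical. In this sketch the beta and label_smoothing given to the dispatcher act as defaults that call-time arguments override, and each returned callable yields per-example losses.

```python
import jax
import jax.numpy as jnp


def _margin(chosen_logps, rejected_logps, ref_chosen_logps, ref_rejected_logps):
    # Policy log-ratio minus reference log-ratio, one value per preference pair.
    return (chosen_logps - rejected_logps) - (ref_chosen_logps - ref_rejected_logps)


def _sigmoid(h, beta, label_smoothing):
    # Standard DPO loss; label_smoothing > 0 mixes in the flipped-label term.
    return (-jax.nn.log_sigmoid(beta * h) * (1 - label_smoothing)
            - jax.nn.log_sigmoid(-beta * h) * label_smoothing)


def _hinge(h, beta, label_smoothing):
    return jax.nn.relu(1 - beta * h)


def _ipo(h, beta, label_smoothing):
    return (h - 1 / (2 * beta)) ** 2


_VARIANTS = {"sigmoid": _sigmoid, "hinge": _hinge, "ipo": _ipo}


def get_loss_function_sketch(loss_type, beta, label_smoothing):
    """Return a per-example loss callable for `loss_type` (hypothetical)."""
    if loss_type not in _VARIANTS:
        raise ValueError(f"loss_type {loss_type!r} is not covered by this sketch")
    variant = _VARIANTS[loss_type]
    default_beta, default_smoothing = beta, label_smoothing

    def loss_fn(chosen_logps, rejected_logps, ref_chosen_logps, ref_rejected_logps,
                beta=default_beta, label_smoothing=default_smoothing, **kwargs):
        h = _margin(chosen_logps, rejected_logps, ref_chosen_logps, ref_rejected_logps)
        return variant(h, beta, label_smoothing)

    return loss_fn
```

With all log probabilities equal, the margin is zero, so the sigmoid loss is log 2, the hinge loss is 1, and the IPO loss is (1 / (2 * beta)) squared.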
- easydel.trainers.direct_preference_optimization_trainer._fn.training_step(state: EasyDeLState, batch: dict, reference_state: EasyDeLState, learning_rate_fn: Callable, concatenated_forward: Callable, beta: float = 0.1, label_smoothing: float = 0, loss_type: Literal['sigmoid', 'hinge', 'ipo', 'exo_pair', 'nca_pair', 'robust', 'bco_pair', 'sppo_hard', 'aot', 'aot_pair', 'apo_zero', 'apo_down'] = 'sigmoid', reference_free: bool = False, loss_config: Optional[LossConfig] = None, partition_spec: Optional[PartitionSpec] = None, gradient_accumulation_steps: int = 1) -> Tuple[EasyDeLState, LossMetrics]
Performs a single training step.
This function computes the loss with the specified loss function and its gradients over minibatches of the input batch (gradient_accumulation_steps sets how many are accumulated), updates the model state, and returns the updated state along with loss metrics.
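The minibatch gradient computation can be sketched with jax.value_and_grad and jax.lax.scan. This illustrates the general gradient-accumulation technique rather than EasyDeL's training_step: it assumes the batch's leading dimension is divisible by gradient_accumulation_steps and that the loss is a mean over examples, so averaging per-minibatch gradients reproduces the full-batch gradient. accumulate_grads and sgd_update are hypothetical, and the plain SGD update merely stands in for whatever optimizer the state carries.

```python
import jax
import jax.numpy as jnp


def accumulate_grads(loss_fn, params, batch, gradient_accumulation_steps=1):
    """Average loss and gradients over equal-sized minibatches of `batch`."""
    grad_fn = jax.value_and_grad(loss_fn)
    # Reshape each array to (steps, minibatch, ...) so minibatches can be scanned.
    minibatches = jax.tree_util.tree_map(
        lambda x: x.reshape(gradient_accumulation_steps, -1, *x.shape[1:]), batch
    )

    def body(carry, minibatch):
        loss_sum, grad_sum = carry
        loss, grads = grad_fn(params, minibatch)
        grad_sum = jax.tree_util.tree_map(jnp.add, grad_sum, grads)
        return (loss_sum + loss, grad_sum), None

    init = (jnp.zeros(()), jax.tree_util.tree_map(jnp.zeros_like, params))
    (loss_sum, grad_sum), _ = jax.lax.scan(body, init, minibatches)
    scale = 1.0 / gradient_accumulation_steps
    return loss_sum * scale, jax.tree_util.tree_map(lambda g: g * scale, grad_sum)


def sgd_update(params, grads, step, learning_rate_fn):
    """Plain SGD step with a learning-rate schedule `learning_rate_fn(step)`."""
    lr = learning_rate_fn(step)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
```

Scanning over minibatches keeps activation memory at the size of one minibatch while still producing one averaged gradient per optimizer update.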
- Parameters
state (EasyDeLState) – The current model state.
batch (dict) – Input batch data.
reference_state (EasyDeLState) – A reference model state used for computing reference log probabilities.
learning_rate_fn (tp.Callable) – Function to compute the learning rate.
concatenated_forward (tp.Callable) – Function to perform a forward pass on concatenated inputs.
beta (float, optional) – Scaling factor for loss computation. Defaults to 0.1.
label_smoothing (float, optional) – Label smoothing factor. Defaults to 0.
loss_type (LOSS_FN_VARIENTS, optional) – Type of loss function to use. Defaults to “sigmoid”.
reference_free (bool, optional) – If True, ignores reference log probabilities. Defaults to False.
loss_config (tp.Optional[LossConfig], optional) – Additional configuration for loss. Defaults to None.
partition_spec (tp.Optional[PartitionSpec], optional) – Partitioning specification for sharding the batch. Defaults to None.
gradient_accumulation_steps (int, optional) – Number of steps for gradient accumulation. Defaults to 1.
- Returns
A tuple containing the updated model state and the loss metrics.
- Return type
tp.Tuple[EasyDeLState, LossMetrics]