easydel.trainers.reward_trainer._fn
Internal functions for Reward Model training.
This module contains the core computational functions used by the reward trainer, implementing training and evaluation steps for reward models in RLHF pipelines. Reward models learn to predict human preferences between pairs of model outputs, serving as a proxy for human judgment when training policies with reinforcement learning.
The module provides functions for:
- Training step computation with pairwise ranking losses
- Evaluation step for assessing reward model accuracy
- Support for Bradley-Terry model and margin-based losses
- Reward centering and normalization strategies
The reward model is trained to assign higher scores to preferred (chosen) responses compared to non-preferred (rejected) responses, learning from human preference data.
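For reference, here is the pairwise objective in its common form. It assumes a Bradley-Terry preference model, an optional per-pair margin $m$, and a centering weight $\lambda$ that plays the role of center_rewards_coefficient. This mirrors, for example, the formulation used by Hugging Face TRL's RewardTrainer. The exact loss computed by the model's compute_loss method is not spelled out in this reference, so treat this as an illustration:

```latex
P(y_c \succ y_r \mid x) = \sigma\big(r_\theta(x, y_c) - r_\theta(x, y_r)\big),
\qquad \sigma(z) = \frac{1}{1 + e^{-z}}

\mathcal{L}(\theta) =
  -\,\mathbb{E}_{(x, y_c, y_r)}\Big[\log \sigma\big(r_\theta(x, y_c) - r_\theta(x, y_r) - m\big)\Big]
  + \lambda\, \mathbb{E}_{(x, y_c, y_r)}\Big[\big(r_\theta(x, y_c) + r_\theta(x, y_r)\big)^2\Big]
```

With $m = 0$ and $\lambda = 0$, this reduces to the plain Bradley-Terry negative log-likelihood. The second term vanishes only when each pair's rewards sum to zero, which pushes the model toward mean-zero rewards.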
All functions are JAX-compatible and support distributed training through sharding.
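The following is a minimal sketch of the kind of sharding constraint applied to a batch, using only core JAX calls (jax.sharding.Mesh, NamedSharding, and jax.lax.with_sharding_constraint). It assumes a one-dimensional device mesh with a data-parallel axis named "dp" and hypothetical batch keys. EasyDeL derives its own mesh and partition spec, so the names here are illustrative only:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Assumption: a 1-D mesh over all devices with one axis named "dp" (data parallel).
mesh = Mesh(np.array(jax.devices()), axis_names=("dp",))
# Shard only the leading (batch) dimension; the other dimensions are replicated.
partition_spec = PartitionSpec("dp")


@jax.jit
def shard_batch(batch):
    """Constrain every array in the batch to the data-parallel layout."""
    sharding = NamedSharding(mesh, partition_spec)
    return jax.tree_util.tree_map(
        lambda x: jax.lax.with_sharding_constraint(x, sharding), batch
    )


# Hypothetical batch keys; the batch size must be divisible by the device count.
n = jax.device_count() * 2
batch = {
    "chosen_ids": jnp.zeros((n, 16), dtype=jnp.int32),
    "rejected_ids": jnp.ones((n, 16), dtype=jnp.int32),
}
sharded = shard_batch(batch)
```

On a single device the constraint leaves the layout unchanged. On several devices, each device holds batch_size / device_count rows of every array.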
- easydel.trainers.reward_trainer._fn.evaluation_step(state: EasyDeLState, batch: Mapping[str, Array], loss_config: easydel.infra.loss_utils.LossConfig | None = None, partition_spec: jax.sharding.PartitionSpec | None = None, center_rewards_coefficient: float | None = None) → tuple[Any, easydel.infra.loss_utils.LossMetrics]
Performs a single evaluation step by computing loss metrics for the input batch.
The function determines the required partitioning for the batch, applies sharding constraints, and defines an inner loss function. This inner function merges the current state with the graph state, sets the model to evaluation mode, and computes loss and metrics via the model’s compute_loss method. The computed LossMetrics are then returned.
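The evaluation step above can be sketched in plain JAX as a jitted function that computes the pairwise loss and a preference accuracy without taking gradients. Several pieces are hypothetical stand-ins for the real model and its compute_loss method: the linear reward head, the "chosen"/"rejected" feature-vector keys, and the names pairwise_reward_loss and eval_step. The loss follows the common Bradley-Terry form with optional margin and centering, which may differ in detail from EasyDeL's:

```python
from functools import partial

import jax
import jax.numpy as jnp


def pairwise_reward_loss(r_chosen, r_rejected, margin=None, center_coef=None):
    """Bradley-Terry style loss: mean of -log sigmoid(r_chosen - r_rejected - margin)."""
    diff = r_chosen - r_rejected
    if margin is not None:
        diff = diff - margin
    loss = -jnp.mean(jax.nn.log_sigmoid(diff))
    if center_coef is not None:
        # Centering term: penalizes pairs whose rewards do not sum to zero.
        loss = loss + center_coef * jnp.mean((r_chosen + r_rejected) ** 2)
    return loss


@partial(jax.jit, static_argnames=("center_coef",))
def eval_step(params, batch, center_coef=None):
    # Hypothetical linear reward head: one scalar score per feature vector.
    r_chosen = batch["chosen"] @ params["w"]
    r_rejected = batch["rejected"] @ params["w"]
    loss = pairwise_reward_loss(r_chosen, r_rejected, center_coef=center_coef)
    # Fraction of pairs where the chosen response outscores the rejected one.
    accuracy = jnp.mean((r_chosen > r_rejected).astype(jnp.float32))
    return {"loss": loss, "accuracy": accuracy}


params = {"w": jnp.array([1.0, -1.0])}
batch = {
    "chosen": jnp.array([[2.0, 0.0], [0.0, 1.0]]),    # rewards: 2.0, -1.0
    "rejected": jnp.array([[0.0, 0.0], [1.0, 0.0]]),  # rewards: 0.0,  1.0
}
metrics = eval_step(params, batch)  # loss ≈ 1.1269, accuracy 0.5
```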
- Parameters
state (EasyDeLState) – The current model state.
batch (tp.Mapping[str, jax.Array]) – A mapping of input arrays for evaluation.
loss_config (tp.Optional[LossConfig], optional) – Configuration for loss computation. Defaults to None.
partition_spec (tp.Optional[PartitionSpec], optional) – Specification for sharding the batch. Defaults to None.
center_rewards_coefficient (float, optional) – Coefficient to incentivize the reward model to output mean-zero rewards. Defaults to None.
- Returns
- A tuple containing:
(Any): An additional output from loss computation (if any).
LossMetrics: The computed loss metrics for the evaluation batch.
- Return type
tp.Tuple[tp.Any, LossMetrics]
- easydel.trainers.reward_trainer._fn.training_step(state: EasyDeLState, batch: Mapping[str, Array], loss_config: easydel.infra.loss_utils.LossConfig | None = None, learning_rate_fn: Callable[[Union[Array, ndarray, bool, number, float, int]], Union[Array, ndarray, bool, number, float, int]] = None, partition_spec: jax.sharding.PartitionSpec | None = None, gradient_accumulation_steps: int = 1, center_rewards_coefficient: float | None = None) → tuple[easydel.infra.base_state.EasyDeLState, easydel.infra.loss_utils.LossMetrics]
Performs a single training step by computing gradients via minibatch processing, updating the model state, and returning updated state and loss metrics.
The function first determines the batch and minibatch sizes, validating them with assertions. It then applies sharding constraints to the batch. The loss function is defined as an inner function that merges the current model state with an updated tree, prepares the inputs, and computes the loss using the model’s compute_loss method. Gradients are computed with jax.value_and_grad over the minibatches and accumulated. Finally, the state is updated with the computed gradients, and the metrics are updated accordingly.
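The following sketch shows the minibatch gradient-accumulation pattern described above, in plain JAX. It relies on several assumptions: a hypothetical linear reward head, feature-vector batch keys "chosen" and "rejected", and plain SGD standing in for the optimizer carried by EasyDeLState. Only the accumulation logic (split the batch, apply jax.value_and_grad per minibatch, average) reflects the text. training_step_sketch and loss_fn are illustrative names, not EasyDeL APIs:

```python
import jax
import jax.numpy as jnp


def loss_fn(params, minibatch):
    # Hypothetical linear reward head with a Bradley-Terry pairwise loss.
    r_chosen = minibatch["chosen"] @ params["w"]
    r_rejected = minibatch["rejected"] @ params["w"]
    return -jnp.mean(jax.nn.log_sigmoid(r_chosen - r_rejected))


def training_step_sketch(params, batch, step, learning_rate_fn, gradient_accumulation_steps=1):
    batch_size = batch["chosen"].shape[0]
    assert batch_size % gradient_accumulation_steps == 0, "batch must split evenly"
    minibatch_size = batch_size // gradient_accumulation_steps
    # Reshape every array to (gradient_accumulation_steps, minibatch_size, ...).
    minibatches = jax.tree_util.tree_map(
        lambda x: x.reshape((gradient_accumulation_steps, minibatch_size) + x.shape[1:]),
        batch,
    )

    def accumulate(carry, minibatch):
        loss_sum, grad_sum = carry
        loss, grads = jax.value_and_grad(loss_fn)(params, minibatch)
        grad_sum = jax.tree_util.tree_map(jnp.add, grad_sum, grads)
        return (loss_sum + loss, grad_sum), None

    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    (loss_sum, grad_sum), _ = jax.lax.scan(accumulate, (jnp.zeros(()), zeros), minibatches)
    # Average gradients and losses over the minibatches.
    grads = jax.tree_util.tree_map(lambda g: g / gradient_accumulation_steps, grad_sum)
    lr = learning_rate_fn(step)
    # Plain SGD stands in for the optimizer held by the real training state.
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, {"loss": loss_sum / gradient_accumulation_steps, "learning_rate": lr}


params = {"w": jnp.array([0.5, -0.25])}
batch = {
    "chosen": jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 0.0]]),
    "rejected": jnp.array([[0.0, 1.0], [1.0, 0.0], [0.0, 0.0], [0.0, 2.0]]),
}
new_params, metrics = training_step_sketch(
    params, batch, step=0, learning_rate_fn=lambda step: 0.1, gradient_accumulation_steps=2
)
```

In this sketch the minibatches are equal in size, so averaging their losses and gradients reproduces the full-batch values. Raising gradient_accumulation_steps therefore trades peak memory for sequential compute without changing the update.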
- Parameters
state (EasyDeLState) – The current model state, which includes parameters and model graph.
batch (tp.Mapping[str, jax.Array]) – A mapping of input arrays for the current batch.
loss_config (tp.Optional[LossConfig], optional) – Configuration settings for the loss computation. Defaults to None.
learning_rate_fn (optax.Schedule, optional) – A schedule function for the learning rate. Defaults to None.
partition_spec (tp.Optional[PartitionSpec], optional) – Specification for data sharding. Defaults to None.
gradient_accumulation_steps (int, optional) – Number of steps over which to accumulate gradients. Defaults to 1.
center_rewards_coefficient (float, optional) – Coefficient to incentivize the reward model to output mean-zero rewards. Defaults to None.
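learning_rate_fn can be any callable that maps a step count to a learning rate. The sketch below builds a hypothetical linear warmup-then-decay schedule in pure JAX. If optax is available, optax.linear_schedule and optax.warmup_cosine_decay_schedule produce comparable callables. make_linear_warmup_decay and its arguments are illustrative names only:

```python
import jax.numpy as jnp


def make_linear_warmup_decay(peak_lr, warmup_steps, total_steps):
    """Return step -> learning rate: linear warmup to peak_lr, then linear decay to 0."""

    def schedule(step):
        step = jnp.asarray(step, dtype=jnp.float32)
        # Warmup phase: ramp linearly from 0 to peak_lr.
        warmup = peak_lr * step / max(warmup_steps, 1)
        # Decay phase: ramp linearly from peak_lr down to 0 at total_steps.
        decay_frac = (total_steps - step) / max(total_steps - warmup_steps, 1)
        decay = peak_lr * jnp.clip(decay_frac, 0.0, 1.0)
        return jnp.where(step < warmup_steps, warmup, decay)

    return schedule


learning_rate_fn = make_linear_warmup_decay(peak_lr=1e-5, warmup_steps=100, total_steps=1000)
```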
- Returns
- A tuple containing:
The updated EasyDeLState after applying gradients.
LossMetrics containing computed loss and other related metrics.
- Return type
tp.Tuple[EasyDeLState, LossMetrics]