easydel.trainers.group_relative_policy_optimization._fn
Internal functions for Group Relative Policy Optimization training.
This module contains the core computational functions used by the GRPO trainer, implementing group-based relative policy optimization for RLHF. GRPO improves training stability by normalizing rewards within each group of completions sampled for the same prompt, rather than across the entire batch, which reduces variance in the gradient estimates.
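Group normalization can be sketched in a few lines of NumPy. The sketch assumes the flat reward vector is laid out so that each run of num_generations consecutive entries belongs to one prompt, uses the population standard deviation plus a small eps to avoid division by zero, and names the helper group_relative_advantages purely for illustration; it is not EasyDeL's implementation.

```python
import numpy as np


def group_relative_advantages(rewards: np.ndarray, num_generations: int, eps: float = 1e-4) -> np.ndarray:
    """Hypothetical helper: z-score each reward against the other completions of the same prompt.

    Assumes `rewards` has length num_prompts * num_generations and that every run of
    `num_generations` consecutive entries shares one prompt.
    """
    grouped = rewards.reshape(-1, num_generations)     # [num_prompts, num_generations]
    mean = grouped.mean(axis=1, keepdims=True)          # per-group mean reward
    std = grouped.std(axis=1, keepdims=True)            # per-group population std (ddof=0)
    advantages = (grouped - mean) / (std + eps)         # normalize within the group only
    return advantages.reshape(-1)                       # back to a flat [batch] vector


# Two prompts, four completions each.
rewards = np.array([1.0, 0.0, 1.0, 0.0, 5.0, 5.0, 5.0, 5.0])
adv = group_relative_advantages(rewards, num_generations=4)
print(adv.round(3))
```

In the example the second group, where every completion received the same reward, gets zero advantage everywhere and therefore contributes no policy-gradient signal for that prompt.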
The module provides functions for:
- Computing per-token log probabilities from model outputs
- Calculating KL divergence penalties between policy and reference models
- Group-based reward normalization and advantage estimation
- Policy gradient loss computation with various clipping strategies
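The exact KL penalty is not spelled out on this page. A common choice in GRPO, and the one used in the original GRPO formulation, is the per-token estimator exp(r - p) - (r - p) - 1, where p and r are the policy and reference log probabilities of the sampled token. The NumPy sketch below assumes that estimator and averages it over valid completion tokens; whether this module uses the same form is an assumption, and per_token_kl and masked_mean are hypothetical names.

```python
import numpy as np


def per_token_kl(policy_logps: np.ndarray, ref_logps: np.ndarray) -> np.ndarray:
    """Assumed k3 estimator: exp(r - p) - (r - p) - 1 for policy log-prob p and reference log-prob r.

    It is non-negative and equals zero exactly where both models assign the same probability.
    """
    diff = ref_logps - policy_logps
    return np.exp(diff) - diff - 1.0


def masked_mean(values: np.ndarray, mask: np.ndarray) -> float:
    """Average over valid (mask == 1) completion tokens only; padding is ignored."""
    return float((values * mask).sum() / np.maximum(mask.sum(), 1.0))


policy = np.log(np.array([[0.5, 0.2, 0.9]]))
reference = np.log(np.array([[0.5, 0.4, 0.9]]))
mask = np.array([[1.0, 1.0, 0.0]])  # the last position is padding
kl = per_token_kl(policy, reference)
print(kl, masked_mean(kl, mask))
```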
All functions are JAX-compatible and support distributed training through sharding.
- easydel.trainers.group_relative_policy_optimization._fn.compute_per_token_logps(logits, input_ids, prompt_length)
Compute per-token log probabilities from precomputed logits in a single vectorized pass over the batch.
- Parameters
logits – Pre-trimmed logits [batch_size, seq_len, vocab_size]
input_ids – Input token ids [batch_size, seq_len]
prompt_length – Length of the prompt, in tokens
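The vectorized step is a log-softmax over the vocabulary followed by a gather of the observed token at every position. This NumPy sketch shows only that step and assumes logits and token ids are already aligned position by position; how compute_per_token_logps uses prompt_length to trim them is not documented here, and gather_token_logps is a hypothetical name.

```python
import numpy as np


def log_softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable log-softmax over the vocabulary (last) axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def gather_token_logps(logits: np.ndarray, token_ids: np.ndarray) -> np.ndarray:
    """Pick log p(token_ids[b, t]) from logits[b, t, :] for all positions at once.

    logits: [batch, steps, vocab], token_ids: [batch, steps] -> [batch, steps]
    """
    logps = log_softmax(logits)
    return np.take_along_axis(logps, token_ids[..., None], axis=-1).squeeze(-1)


rng = np.random.default_rng(0)
logits = rng.normal(size=(2, 3, 5))           # batch=2, steps=3, vocab=5
token_ids = rng.integers(0, 5, size=(2, 3))
out = gather_token_logps(logits, token_ids)
print(out.shape)
```

Gathering with take_along_axis avoids materializing a one-hot tensor of size [batch, steps, vocab].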
- easydel.trainers.group_relative_policy_optimization._fn.get_per_token_logps(model, input_ids, attention_mask, prompt_length)
Compute per-token log probabilities for generated sequences.
This function extracts log probabilities for each token in the completion portion of the sequence (after the prompt). It’s used to compute likelihood ratios between policy and reference models for GRPO training.
- Parameters
model – The language model (EasyDeLBaseModule) to compute log probabilities.
input_ids – Input token IDs including prompt and completion. Shape: [batch_size, seq_len]
attention_mask – Binary mask indicating valid tokens (1) vs padding (0). Shape: [batch_size, seq_len]
prompt_length – Number of tokens in the prompt portion. Log probabilities are only computed for tokens after this position.
- Returns
- Per-token log probabilities for the completion portion.
Shape: [batch_size, seq_len - prompt_length]
- Return type
chex.Array
Note
The function shifts logits by one position to align with the autoregressive nature of language models, where each position predicts the next token.
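The one-position shift in the note can be made concrete: because the logits at position t score token t + 1, the completion tokens input_ids[:, prompt_length:] are scored by logits[:, prompt_length - 1 : -1], which yields exactly the documented [batch_size, seq_len - prompt_length] shape. This NumPy sketch assumes the full-sequence logits have already been produced (the model call and how attention_mask is passed to it are omitted), and completion_logps is a hypothetical helper, not EasyDeL's code.

```python
import numpy as np


def completion_logps(logits: np.ndarray, input_ids: np.ndarray, prompt_length: int) -> np.ndarray:
    """Log-probabilities of the completion tokens from full-sequence logits.

    logits: [batch, seq_len, vocab], where logits[:, t] predicts token t + 1.
    Returns: [batch, seq_len - prompt_length]. Requires prompt_length >= 1.
    """
    # Shift by one: the logits that score input_ids[:, prompt_length:] end one step earlier.
    scoring_logits = logits[:, prompt_length - 1 : -1, :]
    targets = input_ids[:, prompt_length:]
    # Numerically stable log-softmax over the vocabulary axis.
    shifted = scoring_logits - scoring_logits.max(axis=-1, keepdims=True)
    logps = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return np.take_along_axis(logps, targets[..., None], axis=-1).squeeze(-1)


batch, seq_len, vocab, prompt_length = 2, 6, 7, 4
rng = np.random.default_rng(1)
logits = rng.normal(size=(batch, seq_len, vocab))
input_ids = rng.integers(0, vocab, size=(batch, seq_len))
out = completion_logps(logits, input_ids, prompt_length)
print(out.shape)
```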
- easydel.trainers.group_relative_policy_optimization._fn.get_per_token_logps_and_entropies(model, input_ids, attention_mask, prompt_length)
Return per-token log probabilities and entropies for the completion portion.
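The per-token entropy is the Shannon entropy of the next-token distribution at each position, H = -Σ_v p(v) log p(v), computed from the same logits as the log probabilities. The NumPy sketch below shows that computation under the assumption that logits are already aligned to the completion positions; token_entropies is a hypothetical name.

```python
import numpy as np


def token_entropies(logits: np.ndarray) -> np.ndarray:
    """Entropy of the softmax distribution at each position.

    logits: [batch, steps, vocab] -> [batch, steps]
    """
    shifted = logits - logits.max(axis=-1, keepdims=True)
    logps = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))  # stable log-softmax
    return -(np.exp(logps) * logps).sum(axis=-1)                           # -sum p log p


logits = np.array([[[0.0, 0.0, 0.0, 0.0],      # uniform over 4 tokens: entropy log(4)
                    [50.0, 0.0, 0.0, 0.0]]])   # almost one-hot: entropy near 0
ent = token_entropies(logits)
print(ent)
```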
- easydel.trainers.group_relative_policy_optimization._fn.grpo_step(state: EasyDeLState, batch: Mapping[str, Array], num_generations: int, beta: float, loss_config: easydel.infra.loss_utils.LossConfig | None = None, learning_rate_fn: Callable[[Union[Array, ndarray, bool, number, float, int]], Union[Array, ndarray, bool, number, float, int]] = None, partition_spec: jax.sharding.PartitionSpec | None = None, gradient_accumulation_steps: int = 1, is_training: bool = True, loss_type: str = 'dapo', epsilon: float = 0.2, epsilon_high: float = 0.2, delta: float | None = None, importance_sampling_level: str = 'token', top_entropy_quantile: float = 1.0) -> tuple[easydel.infra.base_state.EasyDeLState, easydel.infra.loss_utils.LossMetrics]
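The clipping arguments (epsilon, epsilon_high, delta) and loss_type='dapo' are not described on this page. Their names match Hugging Face TRL's GRPOConfig, where epsilon and epsilon_high are the lower and upper clipping ranges of the importance ratio, delta optionally caps the unclipped ratio, and the 'dapo' loss normalizes by the number of active completion tokens. The NumPy sketch below assumes the same meanings at importance_sampling_level='token'. It omits the beta-weighted KL term, gradient accumulation and sharding, and clipped_policy_loss is a hypothetical name, not the body of grpo_step.

```python
from typing import Optional

import numpy as np


def clipped_policy_loss(
    logps: np.ndarray,        # current policy log-probs, [batch, steps]
    old_logps: np.ndarray,    # log-probs under the policy that sampled the data, [batch, steps]
    advantages: np.ndarray,   # one group-relative advantage per sequence, [batch]
    mask: np.ndarray,         # 1 for completion tokens, 0 for padding, [batch, steps]
    epsilon: float = 0.2,
    epsilon_high: float = 0.2,
    delta: Optional[float] = None,
) -> float:
    """Assumed token-level clipped surrogate with asymmetric clipping and token-level averaging."""
    ratio = np.exp(logps - old_logps)                            # per-token importance ratio
    clipped = np.clip(ratio, 1.0 - epsilon, 1.0 + epsilon_high)  # clip to [1 - eps, 1 + eps_high]
    if delta is not None:
        ratio = np.minimum(ratio, delta)                         # optional cap on the unclipped branch
    adv = advantages[:, None]                                    # broadcast sequence advantage to tokens
    per_token_loss = -np.minimum(ratio * adv, clipped * adv)     # pessimistic (min) surrogate
    # Token-level normalization: every valid token in the batch carries equal weight.
    return float((per_token_loss * mask).sum() / np.maximum(mask.sum(), 1.0))


logps = np.log(np.array([[0.6, 0.5], [0.3, 0.3]]))
old_logps = np.log(np.array([[0.4, 0.5], [0.3, 0.6]]))
advantages = np.array([1.0, -1.0])
mask = np.array([[1.0, 1.0], [1.0, 0.0]])
loss = clipped_policy_loss(logps, old_logps, advantages, mask)
loss_capped = clipped_policy_loss(logps, old_logps, advantages, mask, delta=1.1)
print(loss, loss_capped)
```

In the example the first token's ratio of 1.5 is clipped to 1 + epsilon_high = 1.2 because its advantage is positive; setting delta=1.1 caps it further, which is why the two losses differ.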