easydel.trainers.group_relative_policy_optimization._fn

easydel.trainers.group_relative_policy_optimization._fn.compute_per_token_logps(logits, input_ids, prompt_length)

Compute per-token log probabilities for the whole batch in a single vectorized operation.

Parameters
  • logits – Pre-trimmed logits [batch_size, seq_len, vocab_size]

  • input_ids – Input token ids [batch_size, seq_len]

  • prompt_length – Length of the prompt, in tokens
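A minimal NumPy sketch of what a vectorized per-token log-probability computation can look like. It assumes, without confirmation from this docstring, that "pre-trimmed" means logits covers only the completion positions, so logits[:, t] scores input_ids[:, prompt_length + t]. The name per_token_logps_sketch is hypothetical and this is not EasyDeL's code; jax.numpy offers the same take_along_axis call, so the pattern carries over to JAX directly.

```python
import numpy as np


def per_token_logps_sketch(logits: np.ndarray, input_ids: np.ndarray, prompt_length: int) -> np.ndarray:
    """Hypothetical sketch: log p(token) for every completion token, no Python loops.

    Assumption: logits is already trimmed so that logits[:, t] scores
    input_ids[:, prompt_length + t].
    """
    # Only completion tokens (those after the prompt) are scored.
    targets = input_ids[:, prompt_length:]  # [batch, completion_len]
    assert logits.shape[:2] == targets.shape, "logits must be trimmed to the completion positions"
    # Numerically stable log-softmax over the vocabulary axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # One vectorized gather picks the log-prob of each target token.
    return np.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]


# Usage: 2 sequences of 5 tokens, prompt of 2 tokens, vocabulary of 5.
logits = np.zeros((2, 3, 5))
ids = np.array([[4, 1, 2, 3, 0], [3, 0, 1, 1, 4]])
print(per_token_logps_sketch(logits, ids, prompt_length=2).shape)  # (2, 3)
```

With uniform logits every entry equals -log(5); a real caller would pass model logits instead.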

easydel.trainers.group_relative_policy_optimization._fn.get_per_token_logps(model, input_ids, attention_mask, prompt_length)

Run the model on the inputs and compute per-token log probabilities from its output logits.

Parameters
  • model – The language model

  • input_ids – Input token ids [batch_size, seq_len]

  • attention_mask – Attention mask [batch_size, seq_len]

  • prompt_length – Length of the prompt, in tokens
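Going from full model logits to per-token log probabilities requires the usual causal-LM shift: the logit at position t predicts token t + 1, so positions prompt_length - 1 through seq_len - 2 score the completion tokens. The sketch below illustrates that alignment in NumPy. The model is represented as a plain callable returning logits of shape [batch, seq_len, vocab]; that interface, toy_model, and get_per_token_logps_sketch are all hypothetical and are not EasyDeL's model API.

```python
from typing import Callable

import numpy as np

# Hypothetical model interface: (input_ids, attention_mask) -> logits [batch, seq_len, vocab].
ModelFn = Callable[[np.ndarray, np.ndarray], np.ndarray]


def log_softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    shifted = x - x.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))


def get_per_token_logps_sketch(model: ModelFn, input_ids, attention_mask, prompt_length: int):
    logits = model(input_ids, attention_mask)   # [batch, seq_len, vocab]
    # Position t predicts token t + 1, so drop the last position and the
    # prompt positions that only predict other prompt tokens.
    logits = logits[:, prompt_length - 1 : -1]  # [batch, completion_len, vocab]
    targets = input_ids[:, prompt_length:]      # [batch, completion_len]
    log_probs = log_softmax(logits)
    return np.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]


VOCAB = 5


def toy_model(input_ids, attention_mask):
    # Hypothetical stand-in: strongly predicts (current token + 1) % VOCAB, ignores the mask.
    return 20.0 * np.eye(VOCAB)[(input_ids + 1) % VOCAB]


ids = np.array([[0, 1, 2, 4]])
out = get_per_token_logps_sketch(toy_model, ids, np.ones_like(ids), prompt_length=2)
# out ≈ [[0.0, -20.0]]: token 2 was predicted, token 4 was not (3 was).
```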

easydel.trainers.group_relative_policy_optimization._fn.grpo_step(state: EasyDeLState, batch: Mapping[str, Array], eos_token_id: int, num_generations: int, beta: float, prompt_length: int, loss_config: Optional[LossConfig] = None, learning_rate_fn: Callable[[Union[Array, ndarray, bool, number, float, int]], Union[Array, ndarray, bool, number, float, int]] = None, partition_spec: Optional[PartitionSpec] = None, gradient_accumulation_steps: int = 1, is_training: bool = True) → Tuple[EasyDeLState, LossMetrics]
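The grpo_step docstring does not describe its objective. For orientation only, here is a minimal NumPy sketch of the standard GRPO loss as formulated in the DeepSeekMath paper: rewards normalized within each group of num_generations completions, plus a per-token KL penalty against a reference model weighted by beta. It is not EasyDeL's implementation. The helper names, the contiguous layout of each prompt's completions, the EOS-based completion mask, and eps = 1e-4 are all assumptions.

```python
import numpy as np


def completion_mask(completion_ids: np.ndarray, eos_token_id: int) -> np.ndarray:
    """Hypothetical helper: 1.0 up to and including the first EOS token, 0.0 after it."""
    is_eos = completion_ids == eos_token_id
    seq_len = completion_ids.shape[1]
    # argmax returns the first True; rows with no EOS keep every token.
    first_eos = np.where(is_eos.any(axis=1), is_eos.argmax(axis=1), seq_len)
    positions = np.arange(seq_len)[None, :]
    return (positions <= first_eos[:, None]).astype(np.float32)


def group_advantages(rewards: np.ndarray, num_generations: int, eps: float = 1e-4) -> np.ndarray:
    """Normalize each reward against the other completions of the same prompt."""
    # Assumes the num_generations completions of each prompt are contiguous.
    grouped = rewards.reshape(-1, num_generations)
    mean = grouped.mean(axis=1, keepdims=True)
    std = grouped.std(axis=1, keepdims=True)
    return ((grouped - mean) / (std + eps)).reshape(-1)


def grpo_loss(per_token_logps, ref_per_token_logps, advantages, mask, beta: float) -> float:
    """Forward value of a GRPO-style loss with a per-token KL penalty."""
    # KL estimator exp(r) - r - 1 with r = reference minus policy log-probs.
    diff = ref_per_token_logps - per_token_logps
    per_token_kl = np.exp(diff) - diff - 1.0
    # In JAX the advantage would multiply exp(logp - jax.lax.stop_gradient(logp)),
    # which equals 1 in the forward pass, so only the advantage appears here.
    per_token_loss = -(advantages[:, None] - beta * per_token_kl)
    # Average over each sequence's unmasked tokens, then over the batch.
    per_seq = (per_token_loss * mask).sum(axis=1) / np.maximum(mask.sum(axis=1), 1.0)
    return float(per_seq.mean())


# Usage: one prompt with two completions; EOS id 2 ends the first completion early.
ids = np.array([[5, 2, 9, 9], [5, 5, 5, 5]])
mask = completion_mask(ids, eos_token_id=2)                      # [[1, 1, 0, 0], [1, 1, 1, 1]]
adv = group_advantages(np.array([1.0, 0.0]), num_generations=2)
logps = np.full((2, 4), -1.0)
print(grpo_loss(logps, logps, adv, mask, beta=0.04))  # ≈ 0.0: no KL, zero-mean advantages
```

Normalizing within the group is what makes the advantage "group relative": a completion is rewarded only for beating its siblings from the same prompt, so no learned value baseline is needed.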