easydel.trainers.group_relative_policy_optimization.grpo_trainer

class easydel.trainers.group_relative_policy_optimization.grpo_trainer.GRPOTrainer(arguments: GRPOConfig, vinference: vInference, model: Optional[Union[EasyDeLBaseModule, EasyDeLState]], reward_funcs: Union[EasyDeLBaseModule, EasyDeLState, Callable[[list, list], list[float]], list[Union[EasyDeLBaseModule, EasyDeLState, Callable[[list, list], list[float]]]]], train_dataset: Optional[Any] = None, eval_dataset: Optional[Union[Any, Dict[str, Any]]] = None, processing_class: Optional[Any] = None, reward_processing_classes: Optional[Any] = None, data_tokenize_fn: Optional[Callable] = None)

Bases: Trainer
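
The `reward_funcs` parameter accepts plain callables of type `Callable[[list, list], list[float]]`. The sketch below shows one possible callable of that shape. It assumes, without confirmation from this page, that the two list arguments are prompts and completions given as plain strings and that one score per completion is expected. The name `brevity_reward` and the `max_words` budget are hypothetical.

```python
from typing import List


def brevity_reward(prompts: List[str], completions: List[str]) -> List[float]:
    """Return one score per completion; shorter completions score higher."""
    # Assumption: completions are plain strings aligned one-to-one with prompts.
    max_words = 64  # hypothetical word budget, chosen only for illustration
    rewards = []
    for completion in completions:
        n_words = len(completion.split())
        # Linear penalty on length, clipped at zero.
        rewards.append(max(0.0, 1.0 - n_words / max_words))
    return rewards


scores = brevity_reward(["What is 2 + 2?"] * 2, ["4", "The answer is four."])
```

A callable like this could be passed alone or inside a list as `reward_funcs`, since the signature allows both forms.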

arguments: GRPOConfig
checkpoint_manager: tp.Any
checkpoint_path: tp.Optional[tp.Union[str, os.PathLike]]
config: EasyDeLBaseConfig
configure_functions() -> TrainerConfigureFunctionOutput

Configures and JIT-compiles the training and evaluation step functions.

This method sets up the functions needed for training and evaluation, including:
  • Initialization of the model state.
  • Sharding of the model parameters and optimizer state.
  • JIT-compilation of the training and evaluation step functions.

Returns

An object containing the configured functions and other relevant information.

Return type

TrainerConfigureFunctionOutput

data_collator: tp.Optional[tp.Callable]
dataloader_eval: tp.Optional[tp.Iterator[np.ndarray]]
dataloader_train: tp.Iterator[np.ndarray]
dataset_eval: tp.Optional[Dataset]
dataset_train: tp.Optional[Dataset]
dtype: tp.Any
evalu_tracker: CompilationTracker
finetune: bool
max_evaluation_steps: int
max_training_steps: int
memory_monitor: tp.Any
model_state: EasyDeLState
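on_step_end(state: EasyDeLState, metrics: Any, step: int) -> Tuple[EasyDeLState, Any]

Hook called at the end of each step. It receives the current state, the step metrics, and the step index, and returns the (possibly modified) state and metrics.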
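
As a rough illustration of how such a hook might be overridden, the sketch below records the loss every few steps and returns the state and metrics unchanged. `StepHookBase` is a hypothetical stand-in that mirrors only the hook signature so the example runs without EasyDeL installed; in practice the override would live in a subclass of the trainer. Treating `metrics` as a dict with a `"loss"` key is also an assumption, since the signature only declares `Any`.

```python
from typing import Any, List, Tuple


class StepHookBase:
    """Hypothetical stand-in for the trainer; only the hook signature is mirrored."""

    def on_step_end(self, state: Any, metrics: Any, step: int) -> Tuple[Any, Any]:
        # Default behaviour in this sketch: pass everything through untouched.
        return state, metrics


class LossHistoryHook(StepHookBase):
    def __init__(self, every: int = 10) -> None:
        self.every = every
        self.history: List[Tuple[int, float]] = []

    def on_step_end(self, state: Any, metrics: Any, step: int) -> Tuple[Any, Any]:
        # Assumption: metrics behaves like a mapping with a "loss" entry.
        if step % self.every == 0:
            self.history.append((step, float(metrics["loss"])))
        # The hook must hand back both values so training can continue.
        return state, metrics


hook = LossHistoryHook(every=2)
state = {"params": None}
for step in range(1, 5):
    state, _ = hook.on_step_end(state, {"loss": 1.0 / step}, step)
```

Returning the `(state, metrics)` pair even when nothing changes keeps the override compatible with the declared return type.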

param_dtype: tp.Any
pruning_module: tp.Any
scheduler: optax.Schedule
sharded_evaluation_step_function: JitWrapped
sharded_training_step_function: JitWrapped
state: tp.Any
state_named_sharding: tp.Any
state_partition_spec: tp.Any
state_shape: tp.Any
timer: Timers
train_tracker: CompilationTracker
tx: optax.GradientTransformation
wandb_runtime: tp.Any
easydel.trainers.group_relative_policy_optimization.grpo_trainer.delete_tree(pytree)