easydel.trainers.reward_trainer.reward_trainer

class easydel.trainers.reward_trainer.reward_trainer.RewardTrainer(arguments: RewardConfig, processing_class: Any, model: Optional[Union[EasyDeLBaseModule, EasyDeLState]] = None, train_dataset: Optional[Any] = None, eval_dataset: Optional[Union[Any, Dict[str, Any]]] = None, data_collator: Optional[RewardDataCollatorWithPadding] = None)

Bases: Trainer

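A trainer for reward models, configured through a RewardConfig. It extends the base Trainer and defines the step-function setup (configure_functions) and the batch collation (create_collect_function) documented below.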

configure_functions() → TrainerConfigureFunctionOutput

Configures and JIT-compiles the training and evaluation step functions.

This method prepares the functions that will be used during training and evaluation. It sets up sharding for the model parameters and optimizer state, JIT-compiles the training and evaluation functions with the appropriate static arguments and sharding constraints, and also sets up the checkpoint manager.

Returns

An object containing:
  • sharded_training_step_function: The compiled training step function.

  • sharded_evaluation_step_function: The compiled evaluation step function.

  • mesh: The device mesh used for computation.

  • checkpoint_manager: The checkpointer for saving/loading model state.

Return type

TrainerConfigureFunctionOutput
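The general pattern described above (build a device mesh, place parameters and batches with explicit shardings, then JIT-compile a step function with a static argument) can be sketched in plain JAX. This is not EasyDeL's implementation: toy_train_step, run_toy_steps, the "dp" mesh axis, and the clip_grads flag are hypothetical names chosen for illustration, and the checkpoint manager is left out.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# One-dimensional mesh over all visible devices; "dp" is an assumed axis name.
mesh = Mesh(np.array(jax.devices()), axis_names=("dp",))
replicated = NamedSharding(mesh, PartitionSpec())        # full copy on every device
batch_sharded = NamedSharding(mesh, PartitionSpec("dp"))  # split along the batch axis


def toy_train_step(params, batch, learning_rate, clip_grads):
    """Hypothetical stand-in for a training step: one SGD update on a linear model."""

    def loss_fn(p):
        pred = batch["x"] @ p["w"]
        return jnp.mean((pred - batch["y"]) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(params)
    if clip_grads:  # plain Python branch, resolved at trace time because the flag is static
        grads = jax.tree_util.tree_map(lambda g: jnp.clip(g, -1.0, 1.0), grads)
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, grads
    )
    # Sharding constraint on the output: keep the updated parameters replicated.
    new_params = jax.lax.with_sharding_constraint(new_params, replicated)
    return new_params, loss


# JIT-compile once; clip_grads is static, so each value gets its own compiled version.
sharded_toy_train_step = jax.jit(toy_train_step, static_argnames=("clip_grads",))


def run_toy_steps(num_steps=2):
    n = jax.device_count()
    params = jax.device_put({"w": jnp.zeros((4,))}, replicated)
    batch = jax.device_put(
        {"x": jnp.ones((2 * n, 4)), "y": jnp.full((2 * n,), 2.0)}, batch_sharded
    )
    losses = []
    for _ in range(num_steps):
        params, loss = sharded_toy_train_step(params, batch, 0.1, clip_grads=False)
        losses.append(float(loss))
    return losses
```

Calling run_toy_steps() returns the loss from each step. The trainer's own method additionally bundles the compiled functions, the mesh, and a checkpoint manager into the returned TrainerConfigureFunctionOutput.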

create_collect_function(max_sequence_length, truncation_mode='keep_end')

Creates a collate/collect function to process batches of data for training or evaluation.

This function returns a callable that takes a batch (a list of dictionaries) and converts it into a dictionary of JAX arrays. For models of class “ForCausalLMLoss”, it also performs truncation (either keeping the end or the start of the sequence) so that each sequence does not exceed the specified maximum length.

Parameters
  • max_sequence_length (int) – The maximum allowed sequence length.

  • truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – Determines whether to keep the end or the start of the sequence when truncating. Defaults to “keep_end”.

Returns

A function that takes a batch (list of dicts) and returns a processed dict of arrays.

Return type

tp.Callable
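The two truncation modes can be illustrated with a minimal pure-Python sketch. It is not the library's code: make_collect_fn is a hypothetical name, and the final conversion of each field to a JAX array (which would also require equal-length sequences) is omitted so that the example runs without dependencies.

```python
from typing import Callable, Dict, List, Literal


def make_collect_fn(
    max_sequence_length: int,
    truncation_mode: Literal["keep_end", "keep_start"] = "keep_end",
) -> Callable[[List[Dict[str, List[int]]]], Dict[str, List[List[int]]]]:
    """Hypothetical sketch of a collect-function factory with truncation."""
    if max_sequence_length <= 0:
        raise ValueError("max_sequence_length must be positive")
    if truncation_mode not in ("keep_end", "keep_start"):
        raise ValueError(f"unknown truncation_mode: {truncation_mode!r}")

    def truncate(seq: List[int]) -> List[int]:
        if len(seq) <= max_sequence_length:
            return list(seq)
        if truncation_mode == "keep_end":
            return list(seq[-max_sequence_length:])  # keep the last tokens
        return list(seq[:max_sequence_length])       # keep the first tokens

    def collect(batch: List[Dict[str, List[int]]]) -> Dict[str, List[List[int]]]:
        # Turn a list of per-example dicts into one dict of per-field lists.
        keys = batch[0].keys()
        return {key: [truncate(example[key]) for example in batch] for key in keys}

    return collect
```

For example, with max_sequence_length=3 the sequence [1, 2, 3, 4, 5] becomes [3, 4, 5] under "keep_end" and [1, 2, 3] under "keep_start"; sequences that are already short enough pass through unchanged.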