easydel.trainers.reward_trainer.__init__
- class easydel.trainers.reward_trainer.__init__.RewardConfig(auto_shard_states: bool = True, aux_loss_enabled: bool = False, backend: tp.Optional[str] = None, clip_grad: tp.Optional[float] = None, custom_scheduler: tp.Optional[tp.Callable[[int], tp.Any]] = None, dataloader_num_workers: tp.Optional[int] = 0, dataloader_pin_memory: tp.Optional[bool] = False, do_eval: bool = False, do_last_save: bool = True, do_train: bool = True, eval_batch_size: tp.Optional[int] = None, evaluation_steps: tp.Optional[int] = None, extra_optimizer_kwargs: dict = <factory>, frozen_parameters: tp.Optional[str] = None, gradient_accumulation_steps: int = 1, ids_to_pop_from_dataset: tp.Optional[tp.List[str]] = <factory>, is_fine_tuning: bool = True, init_tx: bool = True, jax_distributed_config: tp.Optional[dict] = None, learning_rate: float = 5e-05, learning_rate_end: tp.Optional[float] = None, log_all_workers: bool = False, log_grad_norms: bool = True, report_metrics: bool = True, log_steps: int = 10, loss_config: tp.Optional[LossConfig] = None, low_mem_usage: bool = True, max_evaluation_steps: tp.Optional[int] = None, max_sequence_length: tp.Optional[int] = 1024, max_training_steps: tp.Optional[int] = None, model_name: str = 'RewardTrainer', model_parameters: tp.Optional[dict] = None, metrics_to_show_in_rich_pbar: tp.Optional[tp.List[str]] = None, num_train_epochs: int = 10, offload_dataset: bool = False, offload_device_type: str = 'cpu', offload_device_index: int = 0, optimizer: AVAILABLE_OPTIMIZERS = EasyDeLOptimizers.ADAMW, performance_mode: bool = False, pruning_module: tp.Any = None, process_zero_is_admin: bool = True, progress_bar_type: tp.Literal['tqdm', 'rich', 'json'] = 'tqdm', remove_ckpt_after_load: bool = False, remove_unused_columns: bool = False, report_steps: int = 5, save_directory: str = 'EasyDeL-Checkpoints', save_optimizer_state: bool = True, save_steps: tp.Optional[int] = None, save_total_limit: tp.Optional[int] = None, scheduler: AVAILABLE_SCHEDULERS = EasyDeLSchedulers.NONE, sparsify_module: bool = False, sparse_module_type: AVAILABLE_SPARSE_MODULE_TYPES = 'bcoo', state_apply_fn_kwarguments_to_model: tp.Optional[dict] = None, step_partition_spec: PartitionSpec = PartitionSpec(('dp', 'fsdp'), 'sp'), step_start_point: tp.Optional[int] = None, shuffle_train_dataset: bool = True, total_batch_size: int = 32, training_time_limit: tp.Optional[str] = None, train_on_inputs: bool = True, truncation_mode: tp.Literal['keep_end', 'keep_start'] = 'keep_end', tx_mu_dtype: tp.Optional[jnp.dtype] = None, track_memory: bool = False, use_data_collactor: bool = True, use_wandb: bool = True, verbose: bool = True, wandb_entity: tp.Optional[str] = None, warmup_steps: int = 0, weight_decay: float = 0.01, weight_distribution_pattern: str = '.*?(layernorm|norm).*?', weight_distribution_log_steps: int = 0, disable_dropout: bool = True, dataset_num_proc: tp.Optional[int] = None, center_rewards_coefficient: tp.Optional[float] = 0.1)
Bases: TrainingArguments
Configuration class for the RewardTrainer.
- Parameters
model_name (str) – The name of the model. Defaults to “RewardTrainer”.
max_sequence_length (int, optional) – Maximum length of the sequences (prompt + completion) in the batch; entries that exceed the limit are filtered out. Defaults to 1024.
disable_dropout (bool, optional) – Whether to disable dropout in the model. Defaults to True.
dataset_num_proc (int, optional) – Number of processes to use for processing the dataset. Defaults to None.
center_rewards_coefficient (float, optional) – Coefficient to incentivize the reward model to output mean-zero rewards. Defaults to 0.1.
remove_unused_columns (bool, optional) – Whether to remove the columns that are not used by the model’s forward pass. Can be True only if the dataset is pretokenized. Defaults to False.
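To illustrate what center_rewards_coefficient does, here is a small NumPy sketch. It assumes the formulation used by Hugging Face TRL's RewardTrainer, whose RewardConfig documents the same parameters: a pairwise loss, -log σ(r_chosen - r_rejected), plus center_rewards_coefficient times the mean of (r_chosen + r_rejected)². The helper name pairwise_reward_loss is hypothetical, and EasyDeL's actual loss may differ in detail.

```python
import numpy as np


def pairwise_reward_loss(rewards_chosen, rewards_rejected, center_rewards_coefficient=0.1):
    """Hypothetical helper: pairwise reward loss with an optional centering penalty.

    Assumes the TRL-style formulation; this is not taken from EasyDeL's source.
    """
    rewards_chosen = np.asarray(rewards_chosen, dtype=np.float64)
    rewards_rejected = np.asarray(rewards_rejected, dtype=np.float64)
    margin = rewards_chosen - rewards_rejected
    # -log(sigmoid(margin)) == log(1 + exp(-margin)), computed stably with logaddexp.
    loss = np.mean(np.logaddexp(0.0, -margin))
    if center_rewards_coefficient is not None:
        # Penalize the squared sum of each pair's rewards, pushing rewards toward mean zero.
        loss += center_rewards_coefficient * np.mean((rewards_chosen + rewards_rejected) ** 2)
    return float(loss)
```

In this sketch, a pair scored (1.0, -1.0) incurs no centering penalty, while a pair scored (1.0, 1.0) is penalized even though only the ranking term should care about the difference. Passing None drops the penalty entirely.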
- center_rewards_coefficient: Optional[float] = 0.1
- dataset_num_proc: Optional[int] = None
- disable_dropout: bool = True
- extra_optimizer_kwargs: dict
- classmethod from_dict(data: Dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- ids_to_pop_from_dataset: tp.Optional[tp.List[str]]
- max_sequence_length: Optional[int] = 1024
- model_name: str = 'RewardTrainer'
- remove_unused_columns: bool = False
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- to_dict() → Dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
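A minimal configuration sketch follows. It assumes EasyDeL is installed and that RewardConfig can be imported from easydel.trainers.reward_trainer, as the module path above suggests. The values are illustrative rather than recommendations, the output directory is hypothetical, and only fields from the signature above plus the documented replace and to_dict methods are used.

```python
from easydel.trainers.reward_trainer import RewardConfig

# Illustrative values only; every keyword is a field from the signature above.
config = RewardConfig(
    save_directory="reward-model-checkpoints",  # hypothetical output directory
    total_batch_size=16,
    learning_rate=5e-6,
    num_train_epochs=1,
    max_sequence_length=512,
    center_rewards_coefficient=0.01,
    use_wandb=False,
)

# replace() returns a new instance with the given fields changed;
# `config` itself is left as it was.
debug_config = config.replace(max_training_steps=50, log_steps=5)

# to_dict() serializes the configuration, for example for logging.
settings = debug_config.to_dict()
print(settings["max_training_steps"])
```

Because replace creates a new instance, a base configuration can be shared and specialized per run without mutating it.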
- class easydel.trainers.reward_trainer.__init__.RewardTrainer(arguments: RewardConfig, processing_class: Any, model: Optional[Union[EasyDeLBaseModule, EasyDeLState]] = None, train_dataset: Optional[Any] = None, eval_dataset: Optional[Union[Any, Dict[str, Any]]] = None, data_collator: Optional[RewardDataCollatorWithPadding] = None)
Bases: Trainer
This trainer extends Trainer to train reward models, using the settings in RewardConfig.
- configure_functions() → TrainerConfigureFunctionOutput
Configures and JIT-compiles the training and evaluation step functions.
This method prepares the functions that will be used during training and evaluation. It sets up sharding for the model parameters and optimizer state, JIT-compiles the training and evaluation functions with the appropriate static arguments and sharding constraints, and also sets up the checkpoint manager.
- Returns
- An object containing:
sharded_training_step_function: The compiled training step function.
sharded_evaluation_step_function: The compiled evaluation step function.
mesh: The device mesh used for computation.
checkpoint_manager: The checkpointer for saving/loading model state.
- Return type
TrainerConfigureFunctionOutput
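The sharding-plus-JIT pattern described for configure_functions can be sketched in plain JAX. This is a generic illustration, not EasyDeL's implementation: the toy train_step, its quadratic loss, and the 0.1 learning rate are hypothetical, and the mesh axis names simply mirror the default step_partition_spec of PartitionSpec(('dp', 'fsdp'), 'sp'). All available devices are placed on the dp axis, so the batch size must be divisible by the device count.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Put every available device on the "dp" axis; "fsdp" and "sp" have size 1 here.
devices = np.array(jax.devices()).reshape(-1, 1, 1)
mesh = Mesh(devices, axis_names=("dp", "fsdp", "sp"))

# Batch: leading (batch) axis split over dp*fsdp, sequence axis over sp.
batch_sharding = NamedSharding(mesh, PartitionSpec(("dp", "fsdp"), "sp"))
# Parameters are fully replicated in this toy example.
replicated = NamedSharding(mesh, PartitionSpec())


def train_step(params, batch):
    """Hypothetical training step: one SGD update on a toy quadratic loss."""

    def loss_fn(p):
        preds = batch["input_ids"].astype(jnp.float32) * p["w"]
        return jnp.mean(preds**2)

    loss, grads = jax.value_and_grad(loss_fn)(params)
    new_params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
    return new_params, loss


# Compile once with explicit input and output shardings.
sharded_training_step_function = jax.jit(
    train_step,
    in_shardings=(replicated, batch_sharding),
    out_shardings=(replicated, replicated),
)

params = {"w": jnp.array(1.0)}
batch = {"input_ids": jnp.ones((8, 4), dtype=jnp.int32)}  # batch of 8, sequence length 4
new_params, loss = sharded_training_step_function(params, batch)
```

Passing a single sharding for a whole argument (here, the params dict) applies it to every leaf of that argument, which keeps the specification short when all parameters share one layout.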
- create_collect_function(max_sequence_length, truncation_mode='keep_end')
Creates a collate/collect function to process batches of data for training or evaluation.
This function returns a callable that takes a batch (a list of dictionaries) and converts it into a dictionary of JAX arrays. For models of class “ForCausalLMLoss”, it also performs truncation (either keeping the end or the start of the sequence) so that each sequence does not exceed the specified maximum length.
- Parameters
max_sequence_length (int) – The maximum allowed sequence length.
truncation_mode (tp.Literal["keep_end", "keep_start"], optional) – Determines whether to keep the end or the start of the sequence when truncating. Defaults to “keep_end”.
- Returns
A function that takes a batch (list of dicts) and returns a processed dict of arrays.
- Return type
tp.Callable
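The truncation behavior can be sketched as follows. This version uses NumPy in place of jax.numpy so it runs without an accelerator (stacking and slicing behave the same way in both), and it always truncates, whereas the documented method truncates only for models of class “ForCausalLMLoss”. The name make_collate_fn is hypothetical.

```python
import numpy as np


def make_collate_fn(max_sequence_length, truncation_mode="keep_end"):
    """Hypothetical collate factory: stack per-example fields, truncating the last axis."""
    if truncation_mode not in ("keep_end", "keep_start"):
        raise ValueError(f"unknown truncation_mode: {truncation_mode!r}")

    def collate_fn(batch):
        results = {}
        for key in batch[0]:
            arrays = [np.asarray(example[key]) for example in batch]
            if truncation_mode == "keep_end":
                # Keep the last max_sequence_length positions.
                arrays = [a[..., -max_sequence_length:] for a in arrays]
            else:
                # Keep the first max_sequence_length positions.
                arrays = [a[..., :max_sequence_length] for a in arrays]
            results[key] = np.stack(arrays)
        return results

    return collate_fn
```

Because np.stack requires equal shapes, this sketch assumes every field has already been padded to a common length, for example by a data collator. With max_sequence_length=3, an example [1, 2, 3, 4, 5] becomes [3, 4, 5] under "keep_end" and [1, 2, 3] under "keep_start".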