easydel.trainers.reward_trainer.reward_config
- class easydel.trainers.reward_trainer.reward_config.RewardConfig(auto_shard_states: bool = True, aux_loss_enabled: bool = False, backend: str | None = None, clip_grad: float | None = None, custom_scheduler: tp.Callable[[int], tp.Any] | None = None, dataloader_num_workers: int | None = 0, dataloader_pin_memory: bool | None = False, do_eval: bool = False, do_last_save: bool = True, do_train: bool = True, eval_batch_size: int | None = None, evaluation_steps: int | None = None, extra_optimizer_kwargs: dict = <factory>, frozen_parameters: str | None = None, grain_shard_index: int | None = None, grain_shard_count: int | None = None, gradient_accumulation_steps: int = 1, ids_to_pop_from_dataset: list[str] | None = <factory>, is_fine_tuning: bool = True, init_tx: bool = True, jax_distributed_config: dict | None = None, learning_rate: float = 5e-05, learning_rate_end: float | None = None, log_all_workers: bool = False, log_grad_norms: bool = True, report_metrics: bool = True, log_steps: int = 10, loss_config: LossConfig | None = None, low_mem_usage: bool = True, max_evaluation_steps: int | None = None, max_sequence_length: int | None = 1024, max_training_steps: int | None = None, per_epoch_training_steps: int | None = None, per_epoch_evaluation_steps: int | None = None, model_name: str | None = None, model_parameters: dict | None = None, metrics_to_show_in_rich_pbar: list[str] | None = None, generation_top_p: float | None = None, generation_top_k: int | None = None, generation_temperature: float | None = None, generation_do_sample: bool | None = None, generation_num_return_sequences: int | None = None, generation_max_new_tokens: int | None = None, generation_shard_inputs: bool = True, generation_interval: int | None = None, generation_prompts: list[str | dict[str, tp.Any]] = <factory>, generation_use_train_prompts: bool = False, generation_num_prompts: int = 1, generation_dataset_prompt_field: str | None = 'prompt', generation_extra_kwargs: dict[str, tp.Any] | None = None, generation_config_overrides: dict[str, tp.Any] | None = None, generation_seed: int | None = None, generation_preview_print: bool = False, generation_log_to_wandb: bool = True, use_esurge_generation: bool = True, esurge_use_tqdm: bool = True, esurge_hbm_utilization: float | None = 0.45, esurge_max_num_seqs: int | None = None, esurge_min_input_pad: int | None = None, esurge_page_size: int | None = 32, esurge_silent_mode: bool = True, num_train_epochs: int = 10, offload_dataset: bool = False, offload_device_type: str = 'cpu', offload_device_index: int = 0, optimizer: AVAILABLE_OPTIMIZERS = EasyDeLOptimizers.ADAMW, performance_mode: bool = False, pruning_module: tp.Any = None, process_zero_is_admin: bool = True, progress_bar_type: tp.Literal['tqdm', 'rich', 'json'] = 'tqdm', remove_ckpt_after_load: bool = False, remove_unused_columns: bool = False, report_steps: int = 5, save_interval_minutes: float | None = None, save_directory: str = 'EasyDeL-Checkpoints', save_optimizer_state: bool = True, save_steps: int | None = None, save_total_limit: int | None = None, scheduler: AVAILABLE_SCHEDULERS = EasyDeLSchedulers.NONE, shuffle_seed_train: int = 64871, sparsify_module: bool = False, sparse_module_type: AVAILABLE_SPARSE_MODULE_TYPES = 'bcoo', state_apply_fn_kwarguments_to_model: dict | None = None, step_partition_spec: PartitionSpec = PartitionSpec(('dp', 'fsdp'), 'sp'), step_start_point: int | None = None, resume_if_possible: bool = True, shuffle_train_dataset: bool = True, total_batch_size: int = 32, training_time_limit: str | None = None, train_on_inputs: bool = True, trainer_prefix: str | None = 'rewardtrainer', truncation_mode: tp.Literal['keep_end', 'keep_start'] = 'keep_end', tx_mu_dtype: jnp.dtype | None = None, track_memory: bool | float = False, use_data_collactor: bool = True, use_grain: bool = True, use_wandb: bool = True, verbose: bool = True, wandb_entity: str | None = None, wandb_name: str | None = None, warmup_steps: int = 0, weight_decay: float = 0.01, weight_distribution_pattern: str = '.*', weight_distribution_log_steps: int = 50, _can_log_metrics: bool | None = None, _im_a_hidden_checkpoint_manager: Checkpointer | None = None, disable_dropout: bool = True, dataset_num_proc: int | None = None, center_rewards_coefficient: float | None = 0.1)
Bases: TrainingArguments
Configuration class for Reward Model training.
In RLHF pipelines, a reward model learns to predict human preferences between different model outputs. The trained reward model then serves as a proxy for human judgment, supplying the feedback signal used for policy optimization.
This configuration extends TrainingArguments with parameters specific to training reward models on pairwise preference data. The model learns to assign higher scores to preferred (chosen) responses than to non-preferred (rejected) responses.
Key concepts:
- Bradley-Terry model: P(chosen > rejected) = sigmoid(r_chosen - r_rejected)
- Margin-based losses: optionally enforce a minimum score difference between chosen and rejected responses
- Reward centering: regularization that keeps rewards close to mean zero
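The three concepts above combine into a single training objective. The sketch below is a plain-Python illustration of that objective, assuming the formulation used by Hugging Face TRL's reward trainer: the mean of -log sigmoid(r_chosen - r_rejected - margin), plus center_rewards_coefficient times the mean of (r_chosen + r_rejected)^2. EasyDeL's actual JAX implementation may differ in details, and the helper names `log_sigmoid` and `pairwise_reward_loss` are hypothetical.

```python
from __future__ import annotations

import math


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    # For negative x: log(e^x / (1 + e^x)) = x - log(1 + e^x).
    return x - math.log1p(math.exp(x))


def pairwise_reward_loss(
    r_chosen: list[float],
    r_rejected: list[float],
    margins: list[float] | None = None,
    center_rewards_coefficient: float | None = None,
) -> float:
    """Mean Bradley-Terry loss over a batch of (chosen, rejected) reward pairs.

    Hypothetical helper: assumes a TRL-style objective, not EasyDeL's code.
    """
    if len(r_chosen) != len(r_rejected) or not r_chosen:
        raise ValueError("need equally sized, non-empty chosen/rejected batches")
    if margins is None:
        margins = [0.0] * len(r_chosen)
    # -log sigmoid(r_c - r_r - m) shrinks as chosen beats rejected by more than m.
    losses = [
        -log_sigmoid(c - r - m) for c, r, m in zip(r_chosen, r_rejected, margins)
    ]
    loss = sum(losses) / len(losses)
    if center_rewards_coefficient is not None:
        # Penalize the squared sum of each pair's rewards, pulling rewards toward zero mean.
        penalty = sum((c + r) ** 2 for c, r in zip(r_chosen, r_rejected)) / len(r_chosen)
        loss += center_rewards_coefficient * penalty
    return loss
```

With equal rewards the pairwise term equals log 2, the loss of a model that cannot tell the responses apart; widening the gap in favor of the chosen response drives it toward zero.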
- trainer_prefix
Prefix for trainer logs and checkpoints. Default: “rewardtrainer”
- Type
str | None
- max_sequence_length
Maximum length of a sequence (prompt + completion). Sequences exceeding this limit are filtered out. Default: 1024
- Type
int | None
- disable_dropout
Whether to disable dropout during training for more deterministic behavior; recommended for reward models. Default: True
- Type
bool
- dataset_num_proc
Number of processes for parallel dataset preprocessing. None means sequential processing. Default: None
- Type
int | None
- center_rewards_coefficient
Coefficient for the reward-centering regularizer, which encourages the model to output mean-zero rewards and prevents reward drift. Default: 0.1
- Type
float | None
- remove_unused_columns
Whether to remove columns not used by the model’s forward pass. Set this to True only if the dataset is already tokenized. Default: False
- Type
bool
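The max_sequence_length filter can be pictured as a pass over tokenized preference pairs. The sketch below assumes the pairs are stored under the hypothetical keys `input_ids_chosen` and `input_ids_rejected` (the naming TRL uses), that a pair is dropped when either side exceeds the limit, and that None disables filtering. These are assumptions for illustration, not confirmed EasyDeL behavior.

```python
from __future__ import annotations


def filter_by_length(examples: list[dict], max_sequence_length: int | None) -> list[dict]:
    """Keep only pairs whose chosen AND rejected sequences fit within the limit.

    Hypothetical helper; the field names and None handling are assumptions.
    """
    if max_sequence_length is None:
        # Assumed: no limit configured means every pair is kept.
        return list(examples)
    return [
        ex
        for ex in examples
        if len(ex["input_ids_chosen"]) <= max_sequence_length
        and len(ex["input_ids_rejected"]) <= max_sequence_length
    ]
```

Dropping the whole pair, rather than truncating one side, keeps every remaining example a valid comparison between two complete responses.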
Example
>>> from easydel.trainers.reward_trainer.reward_config import RewardConfig
>>> config = RewardConfig(
...     max_sequence_length=2048,
...     center_rewards_coefficient=0.01,
...     learning_rate=2e-5,
...     num_train_epochs=1,
... )
Note
The reward model typically uses the same architecture as the base LLM but with a scalar reward head instead of the language modeling head. Training requires paired preference data with chosen and rejected examples.
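A scalar reward head is a linear map from the hidden size to a single number, applied at one position of the sequence. The sketch below assumes the score is read at the last non-padding token, a common choice (Hugging Face's `*ForSequenceClassification` models pool this way). EasyDeL's head may pool differently, and `scalar_reward` is a hypothetical name.

```python
from __future__ import annotations


def scalar_reward(
    hidden_states: list[list[float]],
    attention_mask: list[int],
    head_weight: list[float],
    head_bias: float = 0.0,
) -> float:
    """Project the hidden state of the last non-padding token to one scalar.

    Hypothetical helper; the pooling position is an assumption.
    """
    # Last position whose mask is 1; works for both left and right padding.
    # Raises ValueError if the mask contains no real tokens.
    last = max(i for i, m in enumerate(attention_mask) if m == 1)
    h = hidden_states[last]
    if len(h) != len(head_weight):
        raise ValueError("hidden size and head weight size differ")
    # A hidden_size -> 1 projection replaces the vocabulary-sized LM head.
    return sum(x * w for x, w in zip(h, head_weight)) + head_bias
```

Scoring the chosen and rejected sequences with the same head yields the r_chosen and r_rejected values that enter the Bradley-Terry comparison.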
- disable_dropout: bool = True
- classmethod from_dict(data: dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- remove_unused_columns: bool = False
- replace(**kwargs)
Creates a new instance with the specified fields replaced.
- to_dict() → dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
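The semantics of replace, to_dict/from_dict, and to_json/from_json can be illustrated with a stand-in. The sketch below uses a plain standard-library dataclass, `MiniRewardConfig`, that is hypothetical and is not a PyTree. It only mirrors the documented behavior: replace returns a new instance and leaves the original untouched, and the serializers round-trip field values.

```python
from __future__ import annotations

import dataclasses
import json


@dataclasses.dataclass
class MiniRewardConfig:
    """Hypothetical stand-in with a few RewardConfig fields; not EasyDeL code."""

    max_sequence_length: int | None = 1024
    center_rewards_coefficient: float | None = 0.1
    learning_rate: float = 5e-05

    def replace(self, **kwargs) -> MiniRewardConfig:
        # Builds a new instance; the original object is left unchanged.
        return dataclasses.replace(self, **kwargs)

    def to_dict(self) -> dict:
        return dataclasses.asdict(self)

    @classmethod
    def from_dict(cls, data: dict) -> MiniRewardConfig:
        return cls(**data)

    def to_json(self, **kwargs) -> str:
        # Extra kwargs (e.g. indent=2) are forwarded to json.dumps.
        return json.dumps(self.to_dict(), **kwargs)

    @classmethod
    def from_json(cls, json_str: str) -> MiniRewardConfig:
        return cls.from_dict(json.loads(json_str))
```

Deriving variants with replace keeps a base configuration reusable, for example when sweeping center_rewards_coefficient across several runs.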