easydel.trainers.training_configurations
- class easydel.trainers.training_configurations.TrainingArguments(auto_shard_states: 'bool' = True, aux_loss_enabled: 'bool' = False, backend: 'tp.Optional[str]' = None, clip_grad: 'tp.Optional[float]' = None, custom_scheduler: 'tp.Optional[tp.Callable[[int], tp.Any]]' = None, dataloader_num_workers: 'tp.Optional[int]' = 0, dataloader_pin_memory: 'tp.Optional[bool]' = False, do_eval: 'bool' = False, do_last_save: 'bool' = True, do_train: 'bool' = True, eval_batch_size: 'tp.Optional[int]' = None, evaluation_steps: 'tp.Optional[int]' = None, extra_optimizer_kwargs: 'dict' = <factory>, frozen_parameters: 'tp.Optional[str]' = None, gradient_accumulation_steps: 'int' = 1, ids_to_pop_from_dataset: 'tp.Optional[tp.List[str]]' = <factory>, is_fine_tuning: 'bool' = True, init_tx: 'bool' = True, jax_distributed_config: 'tp.Optional[dict]' = None, learning_rate: 'float' = 5e-05, learning_rate_end: 'tp.Optional[float]' = None, log_all_workers: 'bool' = False, log_grad_norms: 'bool' = True, report_metrics: 'bool' = True, log_steps: 'int' = 10, loss_config: 'tp.Optional[LossConfig]' = None, low_mem_usage: 'bool' = True, max_evaluation_steps: 'tp.Optional[int]' = None, max_sequence_length: 'tp.Optional[int]' = 4096, max_training_steps: 'tp.Optional[int]' = None, model_name: 'str' = 'BaseTrainer', model_parameters: 'tp.Optional[dict]' = None, metrics_to_show_in_rich_pbar: 'tp.Optional[tp.List[str]]' = None, num_train_epochs: 'int' = 10, offload_dataset: 'bool' = False, offload_device_type: 'str' = 'cpu', offload_device_index: 'int' = 0, optimizer: 'AVAILABLE_OPTIMIZERS' = <EasyDeLOptimizers.ADAMW: 'adamw'>, performance_mode: 'bool' = False, pruning_module: 'tp.Any' = None, process_zero_is_admin: 'bool' = True, progress_bar_type: "tp.Literal['tqdm', 'rich', 'json']" = 'tqdm', remove_ckpt_after_load: 'bool' = False, remove_unused_columns: 'bool' = True, report_steps: 'int' = 5, save_directory: 'str' = 'EasyDeL-Checkpoints', save_optimizer_state: 'bool' = True, save_steps: 'tp.Optional[int]' = None, save_total_limit: 'tp.Optional[int]' = None, scheduler: 'AVAILABLE_SCHEDULERS' = <EasyDeLSchedulers.NONE: 'None'>, sparsify_module: 'bool' = False, sparse_module_type: 'AVAILABLE_SPARSE_MODULE_TYPES' = 'bcoo', state_apply_fn_kwarguments_to_model: 'tp.Optional[dict]' = None, step_partition_spec: 'PartitionSpec' = PartitionSpec(('dp', 'fsdp'), 'sp'), step_start_point: 'tp.Optional[int]' = None, shuffle_train_dataset: 'bool' = True, total_batch_size: 'int' = 32, training_time_limit: 'tp.Optional[str]' = None, train_on_inputs: 'bool' = True, truncation_mode: "tp.Literal['keep_end', 'keep_start']" = 'keep_end', tx_mu_dtype: 'tp.Optional[jnp.dtype]' = None, track_memory: 'bool' = False, use_data_collactor: 'bool' = True, use_wandb: 'bool' = True, verbose: 'bool' = True, wandb_entity: 'tp.Optional[str]' = None, warmup_steps: 'int' = 0, weight_decay: 'float' = 0.01)
Bases: object
- auto_shard_states: bool = True
- aux_loss_enabled: bool = False
- backend: Optional[str] = None
- clip_grad: Optional[float] = None
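`clip_grad` takes a single float threshold. Assuming it is applied as clipping by global norm (the usual choice for one scalar threshold, as in optax's `clip_by_global_norm`; this page does not confirm which variant EasyDeL uses), its effect on a set of gradients can be sketched in plain Python. The helper name `clip_by_global_norm` and the dict-of-lists gradient layout are illustrative only.

```python
import math
from typing import Dict, List, Optional, Tuple

Grads = Dict[str, List[float]]


def clip_by_global_norm(grads: Grads, clip_grad: Optional[float]) -> Tuple[Grads, float]:
    """Hypothetical helper: rescale all gradients so their joint L2 norm is at most clip_grad."""
    # Global norm: L2 norm taken over every element of every parameter's gradient.
    global_norm = math.sqrt(sum(x * x for leaf in grads.values() for x in leaf))
    # No clipping requested, or already within the threshold: return unchanged.
    if clip_grad is None or global_norm <= clip_grad:
        return grads, global_norm
    # One shared scale factor preserves the direction of the full gradient vector.
    scale = clip_grad / global_norm
    return {name: [x * scale for x in leaf] for name, leaf in grads.items()}, global_norm


grads = {"w": [3.0], "b": [4.0]}  # global norm is 5.0
clipped, norm = clip_by_global_norm(grads, clip_grad=1.0)
print(norm, clipped)
```

Scaling every parameter by the same factor, rather than clipping each one separately, keeps the update pointing in the same direction and only shortens it.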
- custom_scheduler: Optional[Callable[[int], Any]] = None
- dataloader_num_workers: Optional[int] = 0
- dataloader_pin_memory: Optional[bool] = False
- do_eval: bool = False
- do_last_save: bool = True
- do_train: bool = True
- eval_batch_size: Optional[int] = None
- evaluation_steps: Optional[int] = None
- extra_optimizer_kwargs: dict
- classmethod from_dict(config: Dict[str, Any]) → TrainingArguments
Creates a TrainingArguments instance from a dictionary.
- Parameters
config (tp.Dict[str, tp.Any]) – The configuration dictionary.
- Returns
A TrainingArguments object initialized with values from the dictionary.
- Return type
TrainingArguments
- frozen_parameters: Optional[str] = None
- get_optimizer_and_scheduler(steps: Optional[int] = None)
Returns the configured optimizer and learning rate scheduler.
- Parameters
steps (tp.Optional[int]) – The number of training steps. If not provided, uses the value from self.optimizer_kwargs.
- Returns
A tuple containing the optimizer and scheduler.
- Return type
tuple
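The fields `learning_rate`, `learning_rate_end`, `warmup_steps` and `scheduler` feed the scheduler returned by `get_optimizer_and_scheduler`. The exact curves EasyDeL builds are not documented on this page, so the following is only a plain-Python sketch of common shapes, assuming a linear warmup from 0, then linear or cosine decay towards `learning_rate_end` (taken as 0.0 when unset), and a constant rate for `'none'`. The function `lr_at` is hypothetical.

```python
import math
from typing import Optional


def lr_at(
    step: int,
    *,
    total_steps: int,
    scheduler: str = "linear",
    learning_rate: float = 5e-05,
    learning_rate_end: Optional[float] = None,
    warmup_steps: int = 0,
) -> float:
    """Hypothetical learning-rate curve; not EasyDeL's implementation."""
    # Assumption: an unset end value means decaying all the way to zero.
    end = 0.0 if learning_rate_end is None else learning_rate_end
    # Linear warmup from 0 up to the peak learning rate.
    if step < warmup_steps:
        return learning_rate * step / warmup_steps
    kind = scheduler.lower()  # accept both 'none' and 'None'
    if kind == "none":
        return learning_rate
    # Fraction of the post-warmup decay phase completed, clamped to at most 1.
    decay_steps = max(total_steps - warmup_steps, 1)
    progress = min((step - warmup_steps) / decay_steps, 1.0)
    if kind == "linear":
        return learning_rate + (end - learning_rate) * progress
    if kind == "cosine":
        return end + 0.5 * (learning_rate - end) * (1.0 + math.cos(math.pi * progress))
    raise ValueError(f"unknown scheduler: {scheduler!r}")


for step in (0, 10, 60, 110):
    print(step, lr_at(step, total_steps=110, scheduler="cosine", learning_rate=1.0, warmup_steps=10))
```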
- get_path() → Path
Returns the path to the checkpoint directory.
- Returns
The path to the checkpoint directory.
- Return type
Path
- get_streaming_checkpointer()
Returns the checkpoint manager, responsible for saving model checkpoints.
- Returns
The checkpoint manager.
- Return type
- get_tensorboard()
Returns the TensorBoard SummaryWriter, used for logging metrics.
- Returns
The TensorBoard SummaryWriter.
- Return type
flax.metrics.tensorboard.SummaryWriter
- get_wandb_init()
Initializes Weights & Biases for experiment tracking if enabled.
- Returns
The WandB run object if initialized, else None.
- Return type
tp.Optional[wandb.sdk.wandb_run.Run]
- gradient_accumulation_steps: int = 1
- ids_to_pop_from_dataset: Optional[List[str]]
- init_tx: bool = True
- is_fine_tuning: bool = True
- property is_process_zero
- jax_distributed_config: Optional[dict] = None
- learning_rate: float = 5e-05
- learning_rate_end: Optional[float] = None
- log_all_workers: bool = False
- log_grad_norms: bool = True
- log_metrics(metrics: Any, step: int, log_as: Optional[Literal['summary', 'config']] = None)
Logs training metrics to Weights & Biases and/or TensorBoard.
- Parameters
metrics (tp.Dict[str, tp.Union[float, tp.List, tp.Tuple, np.ndarray, 'jnp.ndarray', 'torch.Tensor']]) – A dictionary where keys are metric names and values are metric values.
step (int) – The current training step or iteration.
- log_steps: int = 10
- loss_config: Optional[LossConfig] = None
- low_mem_usage: bool = True
- max_evaluation_steps: Optional[int] = None
- max_sequence_length: Optional[int] = 4096
- max_training_steps: Optional[int] = None
- metrics_to_show_in_rich_pbar: Optional[List[str]] = None
- model_name: str = 'BaseTrainer'
- model_parameters: Optional[dict] = None
- num_train_epochs: int = 10
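How `num_train_epochs`, `total_batch_size`, `gradient_accumulation_steps` and `max_training_steps` combine into a step count is not spelled out on this page. The sketch below assumes that one optimizer step consumes `total_batch_size` examples split into `gradient_accumulation_steps` equal micro-batches, that an incomplete final batch is dropped, and that `max_training_steps`, when set, overrides the epoch-derived total (the convention Hugging Face's `max_steps` follows). Any of these may differ from EasyDeL's trainer; `plan_steps` and `StepPlan` are hypothetical.

```python
from typing import NamedTuple, Optional


class StepPlan(NamedTuple):
    steps_per_epoch: int
    total_steps: int
    micro_batch_size: int


def plan_steps(
    num_examples: int,
    total_batch_size: int = 32,
    num_train_epochs: int = 10,
    gradient_accumulation_steps: int = 1,
    max_training_steps: Optional[int] = None,
) -> StepPlan:
    """Hypothetical step arithmetic under the assumptions stated above."""
    # The per-step batch must split evenly into micro-batches.
    if total_batch_size % gradient_accumulation_steps != 0:
        raise ValueError("total_batch_size must be divisible by gradient_accumulation_steps")
    micro_batch_size = total_batch_size // gradient_accumulation_steps
    # Drop the incomplete final batch of each epoch.
    steps_per_epoch = num_examples // total_batch_size
    total_steps = steps_per_epoch * num_train_epochs
    # An explicit step budget replaces the epoch-derived total.
    if max_training_steps is not None:
        total_steps = max_training_steps
    return StepPlan(steps_per_epoch, total_steps, micro_batch_size)


print(plan_steps(10_000, total_batch_size=32, gradient_accumulation_steps=4))
# StepPlan(steps_per_epoch=312, total_steps=3120, micro_batch_size=8)
```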
- offload_dataset: bool = False
- property offload_device
- offload_device_index: int = 0
- offload_device_type: str = 'cpu'
- optimizer: Literal['adafactor', 'lion', 'adamw', 'rmsprop'] = 'adamw'
- performance_mode: bool = False
- process_zero_is_admin: bool = True
- progress_bar_type: Literal['tqdm', 'rich', 'json'] = 'tqdm'
- pruning_module: Any = None
- remove_ckpt_after_load: bool = False
- remove_unused_columns: bool = True
- report_metrics: bool = True
- report_steps: int = 5
- save_directory: str = 'EasyDeL-Checkpoints'
- save_optimizer_state: bool = True
- save_steps: Optional[int] = None
- save_total_limit: Optional[int] = None
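`save_steps` and `save_total_limit` together suggest a rolling set of checkpoints. Assuming a checkpoint is written whenever the step is a multiple of `save_steps` and the oldest one is deleted once more than `save_total_limit` exist (this page does not state the deletion order), the retained set can be simulated as follows. The function name and the `run-<step>` naming are made up for the example.

```python
from collections import deque
from typing import Deque, List, Optional


def simulate_checkpoints(
    total_steps: int,
    save_steps: Optional[int],
    save_total_limit: Optional[int] = None,
) -> List[str]:
    """Hypothetical retention policy: keep only the newest save_total_limit checkpoints."""
    if not save_steps:
        return []  # periodic saving disabled
    kept: Deque[str] = deque()
    for step in range(save_steps, total_steps + 1, save_steps):
        kept.append(f"run-{step}")
        # Evict the oldest checkpoint once the limit is exceeded.
        if save_total_limit is not None and len(kept) > save_total_limit:
            kept.popleft()
    return list(kept)


print(simulate_checkpoints(total_steps=100, save_steps=20, save_total_limit=2))
# ['run-80', 'run-100']
```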
- scheduler: Literal['linear', 'cosine', 'none'] = 'None'
- shuffle_train_dataset: bool = True
- sparse_module_type: Literal['bcoo', 'bcsr', 'coo', 'csr'] = 'bcoo'
- sparsify_module: bool = False
- state_apply_fn_kwarguments_to_model: Optional[dict] = None
- step_partition_spec: PartitionSpec = PartitionSpec(('dp', 'fsdp'), 'sp')
- step_start_point: Optional[int] = None
- to_dict() → Dict[str, Any]
Converts the TrainingArguments object into a dictionary.
- Returns
A dictionary representation of the TrainingArguments.
- Return type
tp.Dict[str, tp.Any]
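`to_dict` and `from_dict` are documented as inverse conversions. A minimal sketch of that round-trip pattern on a hypothetical two-field dataclass (`MiniArgs`, not part of EasyDeL) uses `dataclasses.asdict` and keyword construction. EasyDeL's own methods may need extra handling for fields such as `step_partition_spec` or the optimizer and scheduler enums, which this sketch does not attempt.

```python
from dataclasses import asdict, dataclass
from typing import Any, Dict


@dataclass
class MiniArgs:
    """Hypothetical stand-in for a config dataclass such as TrainingArguments."""

    learning_rate: float = 5e-05
    total_batch_size: int = 32

    def to_dict(self) -> Dict[str, Any]:
        # Field-name -> value mapping; asdict deep-copies values and recurses into nested dataclasses.
        return asdict(self)

    @classmethod
    def from_dict(cls, config: Dict[str, Any]) -> "MiniArgs":
        # Keys must match field names; unknown keys raise TypeError in this sketch.
        return cls(**config)


args = MiniArgs(learning_rate=1e-4)
restored = MiniArgs.from_dict(args.to_dict())
print(restored == args)  # True
```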
- total_batch_size: int = 32
- track_memory: bool = False
- train_on_inputs: bool = True
- training_time_limit: Optional[str] = None
- property training_time_seconds: int
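`training_time_seconds` exposes `training_time_limit` as a number of seconds. The accepted string format is not given on this page; the parser below is a hypothetical sketch that assumes an integer followed by a unit such as `s`, `min`, `h` or `d` (for example `'30min'` or `'2h'`) and returns `None` when no limit is set. Both the unit table and the function name are assumptions.

```python
import re
from typing import Optional

# Assumed unit suffixes; EasyDeL's accepted spellings may differ.
_UNIT_SECONDS = {"s": 1, "sec": 1, "min": 60, "h": 3600, "hr": 3600, "d": 86400}


def time_limit_to_seconds(limit: Optional[str]) -> Optional[int]:
    """Hypothetical parser for strings like '30min' or '2h'."""
    if limit is None:
        return None  # no time limit configured
    match = re.fullmatch(r"\s*(\d+)\s*([A-Za-z]+)\s*", limit)
    if match is None or match.group(2).lower() not in _UNIT_SECONDS:
        raise ValueError(f"unrecognized time limit: {limit!r}")
    value, unit = match.groups()
    return int(value) * _UNIT_SECONDS[unit.lower()]


print(time_limit_to_seconds("30min"), time_limit_to_seconds("2h"))  # 1800 7200
```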
- truncation_mode: Literal['keep_end', 'keep_start'] = 'keep_end'
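`truncation_mode` selects which part of an over-long example survives when it exceeds `max_sequence_length`. Assuming the names mean what they say (`'keep_end'` keeps the last tokens, `'keep_start'` keeps the first ones), the rule reduces to a slice. `truncate` is a hypothetical helper, not an EasyDeL function.

```python
from typing import List, Optional, Sequence


def truncate(
    token_ids: Sequence[int],
    max_sequence_length: Optional[int] = 4096,
    truncation_mode: str = "keep_end",
) -> List[int]:
    """Hypothetical truncation rule for a single tokenized example."""
    if max_sequence_length is not None and max_sequence_length < 0:
        raise ValueError("max_sequence_length must be non-negative")
    if max_sequence_length is None or len(token_ids) <= max_sequence_length:
        return list(token_ids)  # nothing to cut
    if truncation_mode == "keep_end":
        # Drop tokens from the front, keeping the most recent context.
        return list(token_ids[len(token_ids) - max_sequence_length:])
    if truncation_mode == "keep_start":
        # Drop tokens from the back.
        return list(token_ids[:max_sequence_length])
    raise ValueError(f"unknown truncation_mode: {truncation_mode!r}")


print(truncate([1, 2, 3, 4, 5], max_sequence_length=3))  # [3, 4, 5]
print(truncate([1, 2, 3, 4, 5], 3, "keep_start"))  # [1, 2, 3]
```

Slicing from `len(token_ids) - max_sequence_length` rather than with a negative index keeps `max_sequence_length=0` correct, since `token_ids[-0:]` would return the whole sequence.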
- use_data_collactor: bool = True
- use_wandb: bool = True
- verbose: bool = True
- wandb_entity: Optional[str] = None
- warmup_steps: int = 0
- weight_decay: float = 0.01