easydel.trainers.training_configurations#
- class easydel.trainers.training_configurations.TrainingArguments(auto_shard_states: bool = True, aux_loss_enabled: bool = False, backend: str | None = None, clip_grad: float | None = None, custom_scheduler: ~typing.Optional[~typing.Callable[[int], ~typing.Any]] = None, dataloader_num_workers: int | None = 0, dataloader_pin_memory: bool | None = False, do_eval: bool = False, do_last_save: bool = True, do_train: bool = True, eval_batch_size: int | None = None, evaluation_steps: int | None = None, extra_optimizer_kwargs: dict = <factory>, frozen_parameters: str | None = None, grain_shard_index: int | None = None, grain_shard_count: int | None = None, gradient_accumulation_steps: int = 1, ids_to_pop_from_dataset: list[str] | None = <factory>, is_fine_tuning: bool = True, init_tx: bool = True, jax_distributed_config: dict | None = None, learning_rate: float = 5e-05, learning_rate_end: float | None = None, log_all_workers: bool = False, log_grad_norms: bool = True, report_metrics: bool = True, log_steps: int = 10, loss_config: easydel.infra.loss_utils.LossConfig | None = None, low_mem_usage: bool = True, max_evaluation_steps: int | None = None, max_sequence_length: int | None = 4096, max_training_steps: int | None = None, per_epoch_training_steps: int | None = None, per_epoch_evaluation_steps: int | None = None, model_name: str | None = None, model_parameters: dict | None = None, metrics_to_show_in_rich_pbar: list[str] | None = None, generation_top_p: float | None = None, generation_top_k: int | None = None, generation_temperature: float | None = None, generation_do_sample: bool | None = None, generation_num_return_sequences: int | None = None, generation_max_new_tokens: int | None = None, generation_shard_inputs: bool = True, generation_interval: int | None = None, generation_prompts: list[str | dict[str, typing.Any]] = <factory>, generation_use_train_prompts: bool = False, generation_num_prompts: int = 1, generation_dataset_prompt_field: str | None = 'prompt', 
generation_extra_kwargs: dict[str, typing.Any] | None = None, generation_config_overrides: dict[str, typing.Any] | None = None, generation_seed: int | None = None, generation_preview_print: bool = False, generation_log_to_wandb: bool = True, use_esurge_generation: bool = True, esurge_use_tqdm: bool = True, esurge_hbm_utilization: float | None = 0.45, esurge_max_num_seqs: int | None = None, esurge_min_input_pad: int | None = None, esurge_page_size: int | None = 32, esurge_silent_mode: bool = True, num_train_epochs: int = 10, offload_dataset: bool = False, offload_device_type: str = 'cpu', offload_device_index: int = 0, optimizer: ~typing.Literal['adafactor', 'adamw', 'mars', 'muon', 'rmsprop', 'lion', 'skew', 'quad'] = EasyDeLOptimizers.ADAMW, performance_mode: bool = False, pruning_module: ~typing.Any = None, process_zero_is_admin: bool = True, progress_bar_type: ~typing.Literal['tqdm', 'rich', 'json'] = 'tqdm', remove_ckpt_after_load: bool = False, remove_unused_columns: bool = True, report_steps: int = 5, save_interval_minutes: float | None = None, save_directory: str = 'EasyDeL-Checkpoints', save_optimizer_state: bool = True, save_steps: int | None = None, save_total_limit: int | None = None, scheduler: ~typing.Literal['linear', 'cosine', 'none'] = EasyDeLSchedulers.NONE, shuffle_seed_train: int = 64871, sparsify_module: bool = False, sparse_module_type: ~typing.Literal['bcoo', 'bcsr', 'coo', 'csr'] = 'bcoo', state_apply_fn_kwarguments_to_model: dict | None = None, step_partition_spec: ~jax.sharding.PartitionSpec = PartitionSpec(('dp', 'fsdp'), 'sp'), step_start_point: int | None = None, resume_if_possible: bool = True, shuffle_train_dataset: bool = True, total_batch_size: int = 32, training_time_limit: str | None = None, train_on_inputs: bool = True, trainer_prefix: str | None = None, truncation_mode: ~typing.Literal['keep_end', 'keep_start'] = 'keep_end', tx_mu_dtype: numpy.dtype | None = None, track_memory: bool | float = False, use_data_collactor: bool = 
True, use_grain: bool = True, use_wandb: bool = True, verbose: bool = True, wandb_entity: str | None = None, wandb_name: str | None = None, warmup_steps: int = 0, weight_decay: float = 0.01, weight_distribution_pattern: str = '.*', weight_distribution_log_steps: int = 50, _can_log_metrics: bool | None = None, _im_a_hidden_checkpoint_manager: eformer.serialization.checkpointer.Checkpointer | None = None)[source]#
Bases: object
Comprehensive configuration class for training and evaluation.
This class encapsulates all training hyperparameters, optimization settings, data loading configuration, logging preferences, and hardware-specific options. It provides a centralized way to manage the complex configuration required for distributed training of large models.
The configuration covers:
- Training hyperparameters (learning rate, batch size, epochs)
- Optimization settings (optimizer, scheduler, gradient clipping)
- Data loading (dataset configuration, batch collation)
- Checkpointing (save frequency, checkpoint limits)
- Logging (WandB, TensorBoard, metrics reporting)
- Hardware configuration (sharding, precision, device placement)
- Performance optimization (compilation, memory tracking)
Example
>>> args = TrainingArguments(
...     learning_rate=1e-4,
...     num_train_epochs=3,
...     total_batch_size=32,
...     save_steps=1000,
...     use_wandb=True
... )
- auto_shard_states: bool = True#
- aux_loss_enabled: bool = False#
- property can_log_metrics#
- custom_scheduler: Optional[Callable[[int], Any]] = None#
- do_eval: bool = False#
- do_last_save: bool = True#
- do_train: bool = True#
- ensure_checkpoint_path()[source]#
Create the checkpoint directory if it doesn’t exist.
Ensures the full checkpoint path including parent directories exists on the filesystem. Safe to call multiple times.
Note
Uses mkdir with parents=True to create full directory tree.
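The directory creation described in the note can be sketched with the standard library alone. The helper name `ensure_dir` and the temporary location are illustrative and not part of EasyDeL; `exist_ok=True` is an assumption about how repeated calls are made safe.

```python
import tempfile
from pathlib import Path


def ensure_dir(path: Path) -> Path:
    """Create `path` and any missing parents; do nothing if it already exists."""
    # parents=True builds the full tree; exist_ok=True (assumed) makes repeat calls safe.
    path.mkdir(parents=True, exist_ok=True)
    return path


# Hypothetical checkpoint location inside a throwaway temp directory.
root = Path(tempfile.mkdtemp())
ckpt_dir = ensure_dir(root / "EasyDeL-Checkpoints" / "my-model")
ensure_dir(ckpt_dir)  # second call is a no-op
```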
- esurge_silent_mode: bool = True#
- esurge_use_tqdm: bool = True#
- extra_optimizer_kwargs: dict#
- classmethod from_dict(data: dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
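As a rough illustration of the dictionary and JSON round trip, the sketch below uses a small stand-in dataclass rather than the real class. `MiniArgs` is hypothetical and carries only three of the documented fields; the actual PyTree-based serialization may treat nested or non-JSON types differently.

```python
import json
from dataclasses import asdict, dataclass
from typing import Any


@dataclass
class MiniArgs:
    """Hypothetical stand-in with three of the documented fields."""

    learning_rate: float = 5e-05
    num_train_epochs: int = 10
    total_batch_size: int = 32

    def to_dict(self) -> dict[str, Any]:
        return asdict(self)

    def to_json(self, **kwargs) -> str:
        # Extra keyword arguments are forwarded to json.dumps (e.g. indent=2).
        return json.dumps(self.to_dict(), **kwargs)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "MiniArgs":
        return cls(**data)

    @classmethod
    def from_json(cls, json_str: str) -> "MiniArgs":
        return cls.from_dict(json.loads(json_str))


original = MiniArgs(learning_rate=1e-4, num_train_epochs=3)
restored = MiniArgs.from_json(original.to_json())
```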
- generation_log_to_wandb: bool = True#
- generation_num_prompts: int = 1#
- generation_preview_print: bool = False#
- generation_prompts: list[str | dict[str, Any]]#
- generation_shard_inputs: bool = True#
- generation_use_train_prompts: bool = False#
- get_checkpoint_policies()[source]#
Convert save_steps configuration to CheckpointInterval policies.
- Returns
- List of checkpoint interval policies.
Returns empty list if save_steps is None.
- Return type
list[CheckpointInterval]
Example
>>> args = TrainingArguments(save_steps=1000)
>>> policies = args.get_checkpoint_policies()
>>> # Returns: [CheckpointInterval(every=1000, until=None)]
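A minimal sketch of the conversion, assuming a policy object with only the `every` and `until` fields shown in the example. The `CheckpointInterval` defined here is a local stand-in, and `should_save` is a hypothetical helper showing one plausible way such a policy could be evaluated.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class CheckpointInterval:
    """Local stand-in for the real policy class."""

    every: int
    until: Optional[int] = None


def checkpoint_policies(save_steps: Optional[int]) -> list[CheckpointInterval]:
    """Return an empty list when step-based saving is disabled."""
    if save_steps is None:
        return []
    return [CheckpointInterval(every=save_steps, until=None)]


def should_save(step: int, policies: list[CheckpointInterval]) -> bool:
    """Assumed semantics: save on positive multiples of `every`, up to `until` if set."""
    for policy in policies:
        in_range = policy.until is None or step <= policy.until
        if in_range and step > 0 and step % policy.every == 0:
            return True
    return False


policies = checkpoint_policies(1000)
```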
- get_optimizer_and_scheduler(steps: int | None = None)[source]#
Create and return the optimizer and learning rate scheduler.
This method uses the OptimizerFactory to create the configured optimizer and scheduler based on the training arguments. It handles:
- Standard optimizers (AdamW, SGD, etc.)
- Learning rate schedules (linear, cosine, constant)
- Gradient clipping and weight decay
- Custom optimizers and schedulers
- Parameters
steps – Optional override for the number of training steps. If not provided, uses the value from self.optimizer_kwargs.
- Returns
- A tuple of (optimizer, scheduler) where:
optimizer: Optax GradientTransformation
scheduler: Learning rate schedule function
- Return type
tuple
Note
The optimizer is an Optax transformation chain that may include gradient clipping, weight decay, and other transformations.
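The scheduler half of the returned tuple is a function from step to learning rate, the same shape as `custom_scheduler`. The sketch below is plain Python rather than Optax, and `linear_schedule` is a hypothetical helper: it assumes a linear warmup over `warmup_steps` followed by a linear decay from `learning_rate` to `learning_rate_end`, which may not match the library's exact curve.

```python
from typing import Callable


def linear_schedule(
    learning_rate: float,
    learning_rate_end: float,
    warmup_steps: int,
    total_steps: int,
) -> Callable[[int], float]:
    """Build a step -> learning-rate function (assumed shape, not EasyDeL's)."""

    def schedule(step: int) -> float:
        # Warmup phase: ramp linearly from 0 to the peak learning rate.
        if warmup_steps > 0 and step < warmup_steps:
            return learning_rate * step / warmup_steps
        # Decay phase: interpolate linearly from the peak to the end value.
        decay_steps = max(total_steps - warmup_steps, 1)
        progress = min((step - warmup_steps) / decay_steps, 1.0)
        return learning_rate + (learning_rate_end - learning_rate) * progress

    return schedule


schedule = linear_schedule(1e-4, 0.0, warmup_steps=100, total_steps=1100)
```

A callable like `schedule` could in principle be passed as `custom_scheduler`, since that field is typed as `Callable[[int], Any]`.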
- get_path() eformer.paths.GCSPath | eformer.paths.LocalPath | eformer.paths.MLUtilPath[source]#
Get the path to the checkpoint directory.
- Returns
- The path to the checkpoint directory, combining
save_directory and model_name.
- Return type
ePathLike
Note
Creates a model-specific subdirectory within the main save directory.
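A minimal sketch of how the two fields combine, using `pathlib` in place of the eformer path types. The function name `checkpoint_path` is illustrative, and falling back to the bare save directory when `model_name` is unset is an assumption.

```python
from pathlib import Path
from typing import Optional


def checkpoint_path(save_directory: str, model_name: Optional[str]) -> Path:
    """Join the save directory with a model-specific subdirectory."""
    base = Path(save_directory)
    # Assumed fallback: with no model name, use the save directory itself.
    return base / model_name if model_name else base


path = checkpoint_path("EasyDeL-Checkpoints", "my-model")
```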
- get_save_interval_timedelta()[source]#
Get time-based checkpoint save interval as timedelta.
- Returns
- Time interval for temporary checkpoints,
or None if no time-based saving is configured.
- Return type
timedelta | None
Note
Currently returns None. Can be extended to support time-based checkpoint saving via new TrainingArguments field.
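The class signature lists a `save_interval_minutes` field, but whether this method reads it is not stated here. The sketch below only shows how a minutes value could be turned into a `timedelta`; the helper `save_interval` and its treatment of non-positive values are assumptions.

```python
from datetime import timedelta
from typing import Optional


def save_interval(minutes: Optional[float]) -> Optional[timedelta]:
    """Map a minutes setting to a timedelta, or None when disabled."""
    # Assumed convention: None or a non-positive value disables time-based saves.
    if minutes is None or minutes <= 0:
        return None
    return timedelta(minutes=minutes)


half_hour = save_interval(30)
```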
- get_streaming_checkpointer()[source]#
Get the asynchronous checkpoint manager.
- Returns
- The checkpoint manager for handling
asynchronous model checkpointing.
- Return type
AsyncCheckpointManager
Note
The AsyncCheckpointManager allows non-blocking checkpoint saves during training, improving training efficiency.
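The idea of a non-blocking save can be illustrated with a worker thread from the standard library. `BackgroundSaver` is a hypothetical helper and does not reflect the AsyncCheckpointManager API; it only shows why the training loop does not have to wait for the write.

```python
import json
import tempfile
from concurrent.futures import Future, ThreadPoolExecutor
from pathlib import Path


class BackgroundSaver:
    """Hypothetical helper that writes checkpoints on a single worker thread."""

    def __init__(self) -> None:
        self._pool = ThreadPoolExecutor(max_workers=1)

    def save(self, state: dict, path: Path) -> Future:
        # Serialize immediately so later mutations of `state` are not captured.
        snapshot = json.dumps(state)
        return self._pool.submit(path.write_text, snapshot)

    def wait(self) -> None:
        # Blocks until every queued write has finished.
        self._pool.shutdown(wait=True)


saver = BackgroundSaver()
target = Path(tempfile.mkdtemp()) / "step_1000.json"
future = saver.save({"step": 1000, "loss": 0.42}, target)
# ... the training loop would keep running here ...
saver.wait()
```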
- get_tensorboard() SummaryWriter | None[source]#
Get the TensorBoard SummaryWriter for logging metrics.
- Returns
- The TensorBoard writer instance, or None if
TensorBoard is not available or not configured.
- Return type
SummaryWriter | None
Note
Handles ModuleNotFoundError gracefully if TensorBoard is not installed. Uses cached property internally for efficiency.
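The combination of a cached property and a guarded import mentioned in the note might look roughly like the following. `WriterHolder` and its `module_name` parameter are hypothetical, introduced so the fallback path can be exercised without TensorBoard installed; which module EasyDeL actually imports is not stated here.

```python
import importlib
from functools import cached_property
from typing import Any, Optional


class WriterHolder:
    """Hypothetical holder that creates a SummaryWriter at most once."""

    def __init__(self, module_name: str, log_dir: str) -> None:
        self.module_name = module_name
        self.log_dir = log_dir

    @cached_property
    def _writer(self) -> Optional[Any]:
        try:
            module = importlib.import_module(self.module_name)
        except ModuleNotFoundError:
            # The logging backend is optional; fall back to None.
            return None
        return module.SummaryWriter(self.log_dir)

    def get_tensorboard(self) -> Optional[Any]:
        return self._writer


holder = WriterHolder("a_module_that_is_not_installed", "logs")
writer = holder.get_tensorboard()
```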
- get_wandb_init()[source]#
Initialize Weights & Biases for experiment tracking.
This method creates a new WandB run with appropriate configuration:
- Project name based on model name and optional prefix
- Run name with timestamp if not specified
- Configuration dictionary from training arguments
- Standard tags for EasyDeL experiments
The method handles process-level initialization, ensuring only the main process creates the WandB run in distributed settings.
- Returns
- The initialized WandB run object, or None if:
WandB is not installed
use_wandb is False
Not the main process and log_all_workers is False
- Return type
wandb.Run | None
Note
WandB initialization is skipped in performance mode to reduce overhead.
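The conditions listed above reduce to a small predicate. The function `should_init_wandb` is an illustrative reading of those conditions, not the library's code, and the order in which the real method checks them is an assumption.

```python
def should_init_wandb(
    use_wandb: bool,
    wandb_installed: bool,
    is_process_zero: bool,
    log_all_workers: bool,
    performance_mode: bool,
) -> bool:
    """Return True only when a WandB run should be created on this process."""
    # Disabled, unavailable, or performance mode: no run is created.
    if not use_wandb or not wandb_installed or performance_mode:
        return False
    # Non-main processes log only when log_all_workers is set.
    return is_process_zero or log_all_workers


main_process = should_init_wandb(True, True, True, False, False)
worker_process = should_init_wandb(True, True, False, False, False)
```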
- gradient_accumulation_steps: int = 1#
- init_tx: bool = True#
- is_fine_tuning: bool = True#
- property is_process_zero#
- learning_rate: float = 5e-05#
- classmethod load_arguments(json_file: str | os.PathLike)[source]#
Load training arguments from a JSON file.
This class method reconstructs a TrainingArguments instance from a previously saved JSON configuration file. It handles class resolution and proper type conversion.
- Parameters
json_file – Path to the JSON file containing saved arguments
- Returns
- Reconstructed configuration object with all
settings from the saved file
- Return type
TrainingArguments
Note
The JSON file should contain a ‘trainer_config_class’ field for proper class resolution when using subclasses.
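Class resolution through a `trainer_config_class` field can be sketched with a name-to-class registry. `BaseArgs`, `DistillArgs`, and `REGISTRY` are hypothetical stand-ins; how EasyDeL actually maps the stored name back to a class is not shown in this reference.

```python
import json
import tempfile
from dataclasses import dataclass
from pathlib import Path


@dataclass
class BaseArgs:
    learning_rate: float = 5e-05


@dataclass
class DistillArgs(BaseArgs):
    temperature: float = 2.0


# Hypothetical registry mapping stored class names to classes.
REGISTRY = {cls.__name__: cls for cls in (BaseArgs, DistillArgs)}


def load_arguments(json_file: Path) -> BaseArgs:
    data = json.loads(Path(json_file).read_text())
    # Assumed fallback: use the base class when the field is absent.
    cls = REGISTRY[data.pop("trainer_config_class", "BaseArgs")]
    return cls(**data)


path = Path(tempfile.mkdtemp()) / "args.json"
path.write_text(
    json.dumps(
        {"trainer_config_class": "DistillArgs", "learning_rate": 1e-4, "temperature": 4.0}
    )
)
loaded = load_arguments(path)
```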
- log_all_workers: bool = False#
- log_grad_norms: bool = True#
- log_metrics(metrics: dict[str, float | list | tuple | numpy.ndarray | Any], step: int, log_as: Optional[Literal['summary', 'config']] = None)[source]#
Log metrics to configured logging backends.
This method handles logging to multiple backends (WandB, TensorBoard) and supports various metric types including scalars, histograms, and distributions. It automatically filters and formats metrics for each backend’s requirements.
- Parameters
metrics – Dictionary of metric names to values. Values can be:
- Scalars (float, int)
- Arrays (numpy, JAX arrays)
- Histograms (tuple of bin_counts and bin_edges)
- Tensors (automatically converted)
step – The current training/evaluation step
log_as – Special logging mode:
- None: Regular step-based logging
- “summary”: Log as final summary (WandB only)
- “config”: Log as configuration (WandB only)
Note
Metrics are automatically filtered for None values
Array metrics are converted to appropriate formats
Gradient norm metrics are restructured for clarity
Logging only occurs if can_log_metrics is True
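A simplified view of the filtering and gating steps, assuming a list as the logging sink. `prepare_metrics` and `log_metrics_sketch` are hypothetical helpers; array conversion and gradient-norm restructuring are left out because their exact form is not described here.

```python
from numbers import Number
from typing import Any


def prepare_metrics(metrics: dict[str, Any]) -> dict[str, Any]:
    """Drop None values and coerce plain numbers to float."""
    prepared: dict[str, Any] = {}
    for name, value in metrics.items():
        if value is None:
            continue  # None values are filtered out
        if isinstance(value, Number):
            value = float(value)
        prepared[name] = value  # tuples such as histograms pass through unchanged
    return prepared


def log_metrics_sketch(
    metrics: dict[str, Any], step: int, can_log_metrics: bool, sink: list
) -> None:
    """Append (step, metrics) to the sink only when logging is allowed."""
    if not can_log_metrics:
        return
    sink.append((step, prepare_metrics(metrics)))


sink: list = []
raw = {"loss": 1, "lr": None, "hist": ([1, 2], [0.0, 0.5, 1.0])}
log_metrics_sketch(raw, step=10, can_log_metrics=True, sink=sink)
log_metrics_sketch(raw, step=20, can_log_metrics=False, sink=sink)
```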
- log_steps: int = 10#
- log_weight_distribution(state, step: int)[source]#
Log weight distribution histograms and statistics.
This method computes and logs detailed statistics about model weights, including histograms and summary statistics (mean, std, min, max). It’s useful for monitoring training stability and detecting issues like gradient explosion or vanishing.
- Parameters
state – Model state containing parameters to analyze
step – Current training step for logging
Note
Only logs at intervals defined by weight_distribution_log_steps
Uses weight_distribution_pattern to filter parameters
Computes statistics across all processes in distributed training
Logs both histograms and scalar statistics for each parameter
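The interval check, pattern filter, and summary statistics can be sketched on plain Python lists. `weight_stats` is a hypothetical helper: the real method works on a model state across processes, and matching with `re.search` and using the population standard deviation are assumptions.

```python
import re
import statistics


def weight_stats(
    params: dict[str, list[float]],
    step: int,
    pattern: str = ".*",
    log_steps: int = 50,
) -> dict[str, float]:
    """Return per-parameter statistics, or an empty dict off the logging interval."""
    if step % log_steps != 0:
        return {}
    regex = re.compile(pattern)
    stats: dict[str, float] = {}
    for name, values in params.items():
        if not regex.search(name):
            continue  # parameter name does not match the filter
        stats[f"{name}/mean"] = statistics.fmean(values)
        stats[f"{name}/std"] = statistics.pstdev(values)
        stats[f"{name}/min"] = min(values)
        stats[f"{name}/max"] = max(values)
    return stats


params = {"attn/kernel": [1.0, 2.0, 3.0], "mlp/bias": [0.0, 0.0]}
logged = weight_stats(params, step=100, pattern="attn")
skipped = weight_stats(params, step=101)
```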
- loss_config: easydel.infra.loss_utils.LossConfig | None = None#
- low_mem_usage: bool = True#
- num_train_epochs: int = 10#
- offload_dataset: bool = False#
- property offload_device#
- offload_device_index: int = 0#
- offload_device_type: str = 'cpu'#
- optimizer: Literal['adafactor', 'adamw', 'mars', 'muon', 'rmsprop', 'lion', 'skew', 'quad'] = 'adamw'#
- performance_mode: bool = False#
- process_zero_is_admin: bool = True#
- progress_bar_type: Literal['tqdm', 'rich', 'json'] = 'tqdm'#
- pruning_module: Any = None#
- remove_ckpt_after_load: bool = False#
- remove_unused_columns: bool = True#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
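The same copy-with-changes pattern exists in the standard library as `dataclasses.replace`, which the sketch below uses on a stand-in class. `MiniArgs` is hypothetical; with the real class the call would be `args.replace(learning_rate=...)`.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class MiniArgs:
    """Hypothetical stand-in with two of the documented fields."""

    learning_rate: float = 5e-05
    num_train_epochs: int = 10


base = MiniArgs()
# Build one variant per learning rate without mutating `base`.
sweep = [replace(base, learning_rate=lr) for lr in (1e-5, 5e-5, 1e-4)]
```

This is convenient for hyperparameter sweeps, where each run shares most settings with a base configuration.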
- report_metrics: bool = True#
- report_steps: int = 5#
- resume_if_possible: bool = True#
- save_arguments(json_file_path: str | os.PathLike | eformer.paths.GCSPath | eformer.paths.LocalPath | eformer.paths.MLUtilPath)[source]#
Save training arguments to a JSON file.
This method serializes the current configuration to a JSON file, preserving all settings for later reconstruction. The saved file includes class information for proper deserialization.
- Parameters
json_file_path – Path where the JSON file will be saved. Parent directories are created if needed.
Note
The saved JSON includes a ‘trainer_config_class’ field to ensure proper class resolution when loading.
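A minimal sketch of the save side, assuming the class name is stored as a plain string under `trainer_config_class`. `MiniArgs` and `save_arguments_sketch` are hypothetical, and the real method may store a fully qualified name or handle non-JSON field types differently.

```python
import json
import tempfile
from dataclasses import asdict, dataclass
from pathlib import Path


@dataclass
class MiniArgs:
    """Hypothetical stand-in with two of the documented fields."""

    learning_rate: float = 5e-05
    save_steps: int = 1000


def save_arguments_sketch(args: MiniArgs, json_file_path: Path) -> None:
    json_file_path = Path(json_file_path)
    # Create missing parent directories before writing.
    json_file_path.parent.mkdir(parents=True, exist_ok=True)
    payload = asdict(args)
    # Record the class name so a loader can pick the right class.
    payload["trainer_config_class"] = type(args).__name__
    json_file_path.write_text(json.dumps(payload, indent=2))


out = Path(tempfile.mkdtemp()) / "run-1" / "config" / "args.json"
save_arguments_sketch(MiniArgs(learning_rate=1e-4), out)
saved = json.loads(out.read_text())
```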
- save_directory: str = 'EasyDeL-Checkpoints'#
- save_optimizer_state: bool = True#
- scheduler: Literal['linear', 'cosine', 'none'] = 'None'#
- shuffle_seed_train: int = 64871#
- shuffle_train_dataset: bool = True#
- sparse_module_type: Literal['bcoo', 'bcsr', 'coo', 'csr'] = 'bcoo'#
- sparsify_module: bool = False#
- step_partition_spec: PartitionSpec = PartitionSpec(('dp', 'fsdp'), 'sp')#
- to_dict() dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- to_json_string() str[source]#
Serializes this instance to a JSON string.
- Returns
String containing all the attributes that make up this configuration instance in JSON format.
- Return type
str
- total_batch_size: int = 32#
- track_memory: bool | float = False#
- train_on_inputs: bool = True#
- property training_time_seconds: int#
- truncation_mode: Literal['keep_end', 'keep_start'] = 'keep_end'#
- tx_mu_dtype: numpy.dtype | None = None#
- use_data_collactor: bool = True#
- use_esurge_generation: bool = True#
- use_grain: bool = True#
- use_wandb: bool = True#
- verbose: bool = True#
- warmup_steps: int = 0#
- weight_decay: float = 0.01#
- weight_distribution_log_steps: int = 50#
- weight_distribution_pattern: str = '.*'#