easydel.trainers.training_configurations

Contents

easydel.trainers.training_configurations#

class easydel.trainers.training_configurations.TrainingArguments(auto_shard_states: bool = True, aux_loss_enabled: bool = False, backend: str | None = None, clip_grad: float | None = None, custom_scheduler: ~typing.Optional[~typing.Callable[[int], ~typing.Any]] = None, dataloader_num_workers: int | None = 0, dataloader_pin_memory: bool | None = False, do_eval: bool = False, do_last_save: bool = True, do_train: bool = True, eval_batch_size: int | None = None, evaluation_steps: int | None = None, extra_optimizer_kwargs: dict = <factory>, frozen_parameters: str | None = None, grain_shard_index: int | None = None, grain_shard_count: int | None = None, gradient_accumulation_steps: int = 1, ids_to_pop_from_dataset: list[str] | None = <factory>, is_fine_tuning: bool = True, init_tx: bool = True, jax_distributed_config: dict | None = None, learning_rate: float = 5e-05, learning_rate_end: float | None = None, log_all_workers: bool = False, log_grad_norms: bool = True, report_metrics: bool = True, log_steps: int = 10, loss_config: easydel.infra.loss_utils.LossConfig | None = None, low_mem_usage: bool = True, max_evaluation_steps: int | None = None, max_sequence_length: int | None = 4096, max_training_steps: int | None = None, per_epoch_training_steps: int | None = None, per_epoch_evaluation_steps: int | None = None, model_name: str | None = None, model_parameters: dict | None = None, metrics_to_show_in_rich_pbar: list[str] | None = None, generation_top_p: float | None = None, generation_top_k: int | None = None, generation_temperature: float | None = None, generation_do_sample: bool | None = None, generation_num_return_sequences: int | None = None, generation_max_new_tokens: int | None = None, generation_shard_inputs: bool = True, generation_interval: int | None = None, generation_prompts: list[str | dict[str, typing.Any]] = <factory>, generation_use_train_prompts: bool = False, generation_num_prompts: int = 1, generation_dataset_prompt_field: str | None = 'prompt', 
generation_extra_kwargs: dict[str, typing.Any] | None = None, generation_config_overrides: dict[str, typing.Any] | None = None, generation_seed: int | None = None, generation_preview_print: bool = False, generation_log_to_wandb: bool = True, use_esurge_generation: bool = True, esurge_use_tqdm: bool = True, esurge_hbm_utilization: float | None = 0.45, esurge_max_num_seqs: int | None = None, esurge_min_input_pad: int | None = None, esurge_page_size: int | None = 32, esurge_silent_mode: bool = True, num_train_epochs: int = 10, offload_dataset: bool = False, offload_device_type: str = 'cpu', offload_device_index: int = 0, optimizer: ~typing.Literal['adafactor', 'adamw', 'mars', 'muon', 'rmsprop', 'lion', 'skew', 'quad'] = EasyDeLOptimizers.ADAMW, performance_mode: bool = False, pruning_module: ~typing.Any = None, process_zero_is_admin: bool = True, progress_bar_type: ~typing.Literal['tqdm', 'rich', 'json'] = 'tqdm', remove_ckpt_after_load: bool = False, remove_unused_columns: bool = True, report_steps: int = 5, save_interval_minutes: float | None = None, save_directory: str = 'EasyDeL-Checkpoints', save_optimizer_state: bool = True, save_steps: int | None = None, save_total_limit: int | None = None, scheduler: ~typing.Literal['linear', 'cosine', 'none'] = EasyDeLSchedulers.NONE, shuffle_seed_train: int = 64871, sparsify_module: bool = False, sparse_module_type: ~typing.Literal['bcoo', 'bcsr', 'coo', 'csr'] = 'bcoo', state_apply_fn_kwarguments_to_model: dict | None = None, step_partition_spec: ~jax.sharding.PartitionSpec = PartitionSpec(('dp', 'fsdp'), 'sp'), step_start_point: int | None = None, resume_if_possible: bool = True, shuffle_train_dataset: bool = True, total_batch_size: int = 32, training_time_limit: str | None = None, train_on_inputs: bool = True, trainer_prefix: str | None = None, truncation_mode: ~typing.Literal['keep_end', 'keep_start'] = 'keep_end', tx_mu_dtype: numpy.dtype | None = None, track_memory: bool | float = False, use_data_collactor: bool = 
True, use_grain: bool = True, use_wandb: bool = True, verbose: bool = True, wandb_entity: str | None = None, wandb_name: str | None = None, warmup_steps: int = 0, weight_decay: float = 0.01, weight_distribution_pattern: str = '.*', weight_distribution_log_steps: int = 50, _can_log_metrics: bool | None = None, _im_a_hidden_checkpoint_manager: eformer.serialization.checkpointer.Checkpointer | None = None)[source]#

Bases: object

Comprehensive configuration class for training and evaluation.

This class encapsulates all training hyperparameters, optimization settings, data loading configuration, logging preferences, and hardware-specific options. It provides a centralized way to manage the complex configuration required for distributed training of large models.

The configuration covers:

- Training hyperparameters (learning rate, batch size, epochs)
- Optimization settings (optimizer, scheduler, gradient clipping)
- Data loading (dataset configuration, batch collation)
- Checkpointing (save frequency, checkpoint limits)
- Logging (WandB, TensorBoard, metrics reporting)
- Hardware configuration (sharding, precision, device placement)
- Performance optimization (compilation, memory tracking)

Example

>>> args = TrainingArguments(
...     learning_rate=1e-4,
...     num_train_epochs=3,
...     total_batch_size=32,
...     save_steps=1000,
...     use_wandb=True
... )
auto_shard_states: bool = True#
aux_loss_enabled: bool = False#
backend: str | None = None#
property can_log_metrics#
clip_grad: float | None = None#
custom_scheduler: Optional[Callable[[int], Any]] = None#
dataloader_num_workers: int | None = 0#
dataloader_pin_memory: bool | None = False#
do_eval: bool = False#
do_last_save: bool = True#
do_train: bool = True#
ensure_checkpoint_path()[source]#

Create the checkpoint directory if it doesn’t exist.

Ensures the full checkpoint path including parent directories exists on the filesystem. Safe to call multiple times.

Note

Uses mkdir with parents=True to create the full directory tree.

ensure_training_time_limit(time_passed)[source]#
esurge_hbm_utilization: float | None = 0.45#
esurge_max_num_seqs: int | None = None#
esurge_min_input_pad: int | None = None#
esurge_page_size: int | None = 32#
esurge_silent_mode: bool = True#
esurge_use_tqdm: bool = True#
eval_batch_size: int | None = None#
evaluation_steps: int | None = None#
extra_optimizer_kwargs: dict#
classmethod from_dict(data: dict[str, Any]) → T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T#

Deserializes a JSON string into a PyTree object.

frozen_parameters: str | None = None#
generation_config_overrides: dict[str, Any] | None = None#
generation_dataset_prompt_field: str | None = 'prompt'#
generation_do_sample: bool | None = None#
generation_extra_kwargs: dict[str, Any] | None = None#
generation_interval: int | None = None#
generation_log_to_wandb: bool = True#
generation_max_new_tokens: int | None = None#
generation_num_prompts: int = 1#
generation_num_return_sequences: int | None = None#
generation_preview_print: bool = False#
generation_prompts: list[str | dict[str, Any]]#
generation_seed: int | None = None#
generation_shard_inputs: bool = True#
generation_temperature: float | None = None#
generation_top_k: int | None = None#
generation_top_p: float | None = None#
generation_use_train_prompts: bool = False#
get_checkpoint_policies()[source]#

Convert save_steps configuration to CheckpointInterval policies.

Returns

List of checkpoint interval policies.

Returns an empty list if save_steps is None.

Return type

list[CheckpointInterval]

Example

>>> args = TrainingArguments(save_steps=1000)
>>> policies = args.get_checkpoint_policies()
>>> # Returns: [CheckpointInterval(every=1000, until=None)]
get_optimizer_and_scheduler(steps: int | None = None)[source]#

Create and return the optimizer and learning rate scheduler.

This method uses the OptimizerFactory to create the configured optimizer and scheduler based on the training arguments. It handles:

- Built-in optimizers (e.g. AdamW, Adafactor, Lion; see the optimizer field for the full list)
- Learning rate schedules (linear, cosine, or none)
- Gradient clipping and weight decay
- Custom optimizers and schedulers

Parameters

steps – Optional override for the number of training steps. If not provided, uses the value from self.optimizer_kwargs.

Returns

A tuple of (optimizer, scheduler) where:
  • optimizer: Optax GradientTransformation

  • scheduler: Learning rate schedule function

Return type

tuple

Note

The optimizer is an Optax transformation chain that may include gradient clipping, weight decay, and other transformations.

get_path() → eformer.paths.GCSPath | eformer.paths.LocalPath | eformer.paths.MLUtilPath[source]#

Get the path to the checkpoint directory.

Returns

The path to the checkpoint directory, combining save_directory and model_name.

Return type

ePathLike

Note

Creates a model-specific subdirectory within the main save directory.

get_save_interval_timedelta()[source]#

Get time-based checkpoint save interval as timedelta.

Returns

Time interval for temporary checkpoints, or None if no time-based saving is configured.

Return type

timedelta | None

Note

Currently returns None. This can be extended to support time-based checkpoint saving via a new TrainingArguments field.

get_streaming_checkpointer()[source]#

Get the asynchronous checkpoint manager.

Returns

The checkpoint manager for handling asynchronous model checkpointing.

Return type

AsyncCheckpointManager

Note

The AsyncCheckpointManager allows non-blocking checkpoint saves during training, improving training efficiency.

get_tensorboard() → SummaryWriter | None[source]#

Get the TensorBoard SummaryWriter for logging metrics.

Returns

The TensorBoard writer instance, or None if TensorBoard is not available or not configured.

Return type

SummaryWriter | None

Note

Handles ModuleNotFoundError gracefully if TensorBoard is not installed. Uses a cached property internally for efficiency.

get_tx_template(possible_max: int | None = None) → GradientTransformation[source]#
get_wandb_init()[source]#

Initialize Weights & Biases for experiment tracking.

This method creates a new WandB run with the appropriate configuration:

- Project name based on the model name and optional prefix
- Run name with a timestamp if not specified
- Configuration dictionary built from the training arguments
- Standard tags for EasyDeL experiments

The method handles process-level initialization, ensuring only the main process creates the WandB run in distributed settings.

Returns

The initialized WandB run object, or None if:
  • WandB is not installed

  • use_wandb is False

  • Not the main process and log_all_workers is False

Return type

wandb.Run | None

Note

WandB initialization is skipped in performance mode to reduce overhead.

gradient_accumulation_steps: int = 1#
grain_shard_count: int | None = None#
grain_shard_index: int | None = None#
ids_to_pop_from_dataset: list[str] | None#
init_tx: bool = True#
is_fine_tuning: bool = True#
property is_process_zero#
jax_distributed_config: dict | None = None#
learning_rate: float = 5e-05#
learning_rate_end: float | None = None#
classmethod load_arguments(json_file: str | os.PathLike)[source]#

Load training arguments from a JSON file.

This class method reconstructs a TrainingArguments instance from a previously saved JSON configuration file. It handles class resolution and proper type conversion.

Parameters

json_file – Path to the JSON file containing saved arguments

Returns

Reconstructed configuration object with all settings from the saved file.

Return type

TrainingArguments

Note

The JSON file should contain a ‘trainer_config_class’ field for proper class resolution when using subclasses.

classmethod load_from_json(config_dict)[source]#
log_all_workers: bool = False#
log_grad_norms: bool = True#
log_metrics(metrics: dict[str, float | list | tuple | numpy.ndarray | Any], step: int, log_as: Optional[Literal['summary', 'config']] = None)[source]#

Log metrics to configured logging backends.

This method handles logging to multiple backends (WandB, TensorBoard) and supports various metric types including scalars, histograms, and distributions. It automatically filters and formats metrics for each backend’s requirements.

Parameters
  • metrics – Dictionary of metric names to values. Values can be:
      - Scalars (float, int)
      - Arrays (NumPy, JAX arrays)
      - Histograms (tuple of bin_counts and bin_edges)
      - Tensors (automatically converted)

  • step – The current training/evaluation step

  • log_as – Special logging mode:
      - None: regular step-based logging
      - “summary”: log as a final summary (WandB only)
      - “config”: log as configuration (WandB only)

Note

  • Metrics are automatically filtered for None values

  • Array metrics are converted to appropriate formats

  • Gradient norm metrics are restructured for clarity

  • Logging only occurs if can_log_metrics is True

log_steps: int = 10#
log_weight_distribution(state, step: int)[source]#

Log weight distribution histograms and statistics.

This method computes and logs detailed statistics about model weights, including histograms and summary statistics (mean, std, min, max). It’s useful for monitoring training stability and detecting issues such as exploding or vanishing gradients.

Parameters
  • state – Model state containing parameters to analyze

  • step – Current training step for logging

Note

  • Only logs at intervals defined by weight_distribution_log_steps

  • Uses weight_distribution_pattern to filter parameters

  • Computes statistics across all processes in distributed training

  • Logs both histograms and scalar statistics for each parameter

loss_config: easydel.infra.loss_utils.LossConfig | None = None#
low_mem_usage: bool = True#
max_evaluation_steps: int | None = None#
max_sequence_length: int | None = 4096#
max_training_steps: int | None = None#
metrics_to_show_in_rich_pbar: list[str] | None = None#
model_name: str | None = None#
model_parameters: dict | None = None#
num_train_epochs: int = 10#
offload_dataset: bool = False#
property offload_device#
offload_device_index: int = 0#
offload_device_type: str = 'cpu'#
optimizer: Literal['adafactor', 'adamw', 'mars', 'muon', 'rmsprop', 'lion', 'skew', 'quad'] = 'adamw'#
per_epoch_evaluation_steps: int | None = None#
per_epoch_training_steps: int | None = None#
performance_mode: bool = False#
process_zero_is_admin: bool = True#
progress_bar_type: Literal['tqdm', 'rich', 'json'] = 'tqdm'#
pruning_module: Any = None#
remove_ckpt_after_load: bool = False#
remove_unused_columns: bool = True#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

report_metrics: bool = True#
report_steps: int = 5#
resume_if_possible: bool = True#
save_arguments(json_file_path: str | os.PathLike | eformer.paths.GCSPath | eformer.paths.LocalPath | eformer.paths.MLUtilPath)[source]#

Save training arguments to a JSON file.

This method serializes the current configuration to a JSON file, preserving all settings for later reconstruction. The saved file includes class information for proper deserialization.

Parameters

json_file_path – Path where the JSON file will be saved. Parent directories are created if needed.

Note

The saved JSON includes a ‘trainer_config_class’ field to ensure proper class resolution when loading.

save_directory: str = 'EasyDeL-Checkpoints'#
save_interval_minutes: float | None = None#
save_optimizer_state: bool = True#
save_steps: int | None = None#
save_total_limit: int | None = None#
scheduler: Literal['linear', 'cosine', 'none'] = 'None'#
shuffle_seed_train: int = 64871#
shuffle_train_dataset: bool = True#
sparse_module_type: Literal['bcoo', 'bcsr', 'coo', 'csr'] = 'bcoo'#
sparsify_module: bool = False#
state_apply_fn_kwarguments_to_model: dict | None = None#
step_partition_spec: PartitionSpec = PartitionSpec(('dp', 'fsdp'), 'sp')#
step_start_point: int | None = None#
to_dict() → dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str#

Serializes the PyTree object to a JSON string.

to_json_string() → str[source]#

Serializes this instance to a JSON string.

Returns

String containing all the attributes that make up this configuration instance in JSON format.

Return type

str

total_batch_size: int = 32#
track_memory: bool | float = False#
train_on_inputs: bool = True#
trainer_prefix: str | None = None#
training_time_limit: str | None = None#
property training_time_seconds: int#
truncation_mode: Literal['keep_end', 'keep_start'] = 'keep_end'#
tx_mu_dtype: numpy.dtype | None = None#
use_data_collactor: bool = True#
use_esurge_generation: bool = True#
use_grain: bool = True#
use_wandb: bool = True#
verbose: bool = True#
wandb_entity: str | None = None#
wandb_name: str | None = None#
warmup_steps: int = 0#
weight_decay: float = 0.01#
weight_distribution_log_steps: int = 50#
weight_distribution_pattern: str = '.*'#
easydel.trainers.training_configurations.get_safe_arr(xs)[source]#