easydel.trainers.training_configurations

Contents

easydel.trainers.training_configurations#

class easydel.trainers.training_configurations.TrainingArguments(auto_shard_states: 'bool' = True, aux_loss_enabled: 'bool' = False, backend: 'tp.Optional[str]' = None, clip_grad: 'tp.Optional[float]' = None, custom_scheduler: 'tp.Optional[tp.Callable[[int], tp.Any]]' = None, dataloader_num_workers: 'tp.Optional[int]' = 0, dataloader_pin_memory: 'tp.Optional[bool]' = False, do_eval: 'bool' = False, do_last_save: 'bool' = True, do_train: 'bool' = True, eval_batch_size: 'tp.Optional[int]' = None, evaluation_steps: 'tp.Optional[int]' = None, extra_optimizer_kwargs: 'dict' = <factory>, frozen_parameters: 'tp.Optional[str]' = None, gradient_accumulation_steps: 'int' = 1, ids_to_pop_from_dataset: 'tp.Optional[tp.List[str]]' = <factory>, is_fine_tuning: 'bool' = True, init_tx: 'bool' = True, jax_distributed_config: 'tp.Optional[dict]' = None, learning_rate: 'float' = 5e-05, learning_rate_end: 'tp.Optional[float]' = None, log_all_workers: 'bool' = False, log_grad_norms: 'bool' = True, report_metrics: 'bool' = True, log_steps: 'int' = 10, loss_config: 'tp.Optional[LossConfig]' = None, low_mem_usage: 'bool' = True, max_evaluation_steps: 'tp.Optional[int]' = None, max_sequence_length: 'tp.Optional[int]' = 4096, max_training_steps: 'tp.Optional[int]' = None, model_name: 'str' = 'BaseTrainer', model_parameters: 'tp.Optional[dict]' = None, metrics_to_show_in_rich_pbar: 'tp.Optional[tp.List[str]]' = None, num_train_epochs: 'int' = 10, offload_dataset: 'bool' = False, offload_device_type: 'str' = 'cpu', offload_device_index: 'int' = 0, optimizer: 'AVAILABLE_OPTIMIZERS' = <EasyDeLOptimizers.ADAMW: 'adamw'>, performance_mode: 'bool' = False, pruning_module: 'tp.Any' = None, process_zero_is_admin: 'bool' = True, progress_bar_type: "tp.Literal['tqdm', 'rich', 'json']" = 'tqdm', remove_ckpt_after_load: 'bool' = False, remove_unused_columns: 'bool' = True, report_steps: 'int' = 5, save_directory: 'str' = 'EasyDeL-Checkpoints', save_optimizer_state: 'bool' = True, save_steps: 'tp.Optional[int]' = None, save_total_limit: 'tp.Optional[int]' = None, scheduler: 'AVAILABLE_SCHEDULERS' = <EasyDeLSchedulers.NONE: 'None'>, sparsify_module: 'bool' = False, sparse_module_type: 'AVAILABLE_SPARSE_MODULE_TYPES' = 'bcoo', state_apply_fn_kwarguments_to_model: 'tp.Optional[dict]' = None, step_partition_spec: 'PartitionSpec' = PartitionSpec(('dp', 'fsdp'), 'sp'), step_start_point: 'tp.Optional[int]' = None, shuffle_train_dataset: 'bool' = True, total_batch_size: 'int' = 32, training_time_limit: 'tp.Optional[str]' = None, train_on_inputs: 'bool' = True, truncation_mode: "tp.Literal['keep_end', 'keep_start']" = 'keep_end', tx_mu_dtype: 'tp.Optional[jnp.dtype]' = None, track_memory: 'bool' = False, use_data_collactor: 'bool' = True, use_wandb: 'bool' = True, verbose: 'bool' = True, wandb_entity: 'tp.Optional[str]' = None, warmup_steps: 'int' = 0, weight_decay: 'float' = 0.01, weight_distribution_pattern: 'str' = '.*?(layernorm|norm).*?', weight_distribution_log_steps: 'int' = 0)[source]#

Bases: object

auto_shard_states: bool = True#
aux_loss_enabled: bool = False#
backend: Optional[str] = None#
clip_grad: Optional[float] = None#
custom_scheduler: Optional[Callable[[int], Any]] = None#
dataloader_num_workers: Optional[int] = 0#
dataloader_pin_memory: Optional[bool] = False#
do_eval: bool = False#
do_last_save: bool = True#
do_train: bool = True#
ensure_checkpoint_path()[source]#

Creates the checkpoint directory if it doesn’t exist.

ensure_training_time_limit(time_passed)[source]#
eval_batch_size: Optional[int] = None#
evaluation_steps: Optional[int] = None#
extra_optimizer_kwargs: dict#
classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

frozen_parameters: Optional[str] = None#
get_optimizer_and_scheduler(steps: Optional[int] = None)[source]#

Returns the configured optimizer and learning rate scheduler.

Parameters

steps (tp.Optional[int]) – The number of training steps. If not provided, uses the value from self.optimizer_kwargs.

Returns

A tuple containing the optimizer and scheduler.

Return type

tuple

get_path() Path[source]#

Returns the path to the checkpoint directory.

Returns

The path to the checkpoint directory.

Return type

Path

get_streaming_checkpointer()[source]#

Returns the checkpoint manager, responsible for saving model checkpoints.

Returns

The checkpoint manager.

Return type

CheckpointManager

get_tensorboard()[source]#

Returns the TensorBoard SummaryWriter, used for logging metrics.

Returns

The TensorBoard SummaryWriter.

Return type

flax.metrics.tensorboard.SummaryWriter

get_wandb_init()[source]#

Initializes Weights & Biases for experiment tracking if enabled.

Returns

The WandB run object if initialized, else None.

Return type

tp.Optional[wandb.sdk.wandb_run.Run]

gradient_accumulation_steps: int = 1#
ids_to_pop_from_dataset: Optional[List[str]]#
init_tx: bool = True#
is_fine_tuning: bool = True#
property is_process_zero#
jax_distributed_config: Optional[dict] = None#
learning_rate: float = 5e-05#
learning_rate_end: Optional[float] = None#
log_all_workers: bool = False#
log_grad_norms: bool = True#
log_metrics(metrics: Any, step: int, log_as: Optional[Literal['summary', 'config']] = None)[source]#

Logs training metrics to Weights & Biases and/or TensorBoard.

Parameters
  • metrics (tp.Dict[str, tp.Union[float, tp.List, tp.Tuple, np.ndarray, 'jnp.ndarray', 'torch.Tensor']]) – A dictionary where keys are metric names and values are metric values.

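  • step (int) – The current training step or iteration.

  • log_as (tp.Optional[tp.Literal['summary', 'config']]) – Optional selector that accepts 'summary' or 'config'. Defaults to None.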

log_steps: int = 10#
log_weight_distribution(state, step: int)[source]#
loss_config: Optional[LossConfig] = None#
low_mem_usage: bool = True#
max_evaluation_steps: Optional[int] = None#
max_sequence_length: Optional[int] = 4096#
max_training_steps: Optional[int] = None#
metrics_to_show_in_rich_pbar: Optional[List[str]] = None#
model_name: str = 'BaseTrainer'#
model_parameters: Optional[dict] = None#
num_train_epochs: int = 10#
offload_dataset: bool = False#
property offload_device#
offload_device_index: int = 0#
offload_device_type: str = 'cpu'#
optimizer: Literal['adafactor', 'lion', 'adamw', 'rmsprop'] = 'adamw'#
performance_mode: bool = False#
process_zero_is_admin: bool = True#
progress_bar_type: Literal['tqdm', 'rich', 'json'] = 'tqdm'#
pruning_module: Any = None#
remove_ckpt_after_load: bool = False#
remove_unused_columns: bool = True#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

report_metrics: bool = True#
report_steps: int = 5#
save_directory: str = 'EasyDeL-Checkpoints'#
save_optimizer_state: bool = True#
save_steps: Optional[int] = None#
save_total_limit: Optional[int] = None#
scheduler: Literal['linear', 'cosine', 'none'] = 'None'#
shuffle_train_dataset: bool = True#
sparse_module_type: Literal['bcoo', 'bcsr', 'coo', 'csr'] = 'bcoo'#
sparsify_module: bool = False#
state_apply_fn_kwarguments_to_model: Optional[dict] = None#
step_partition_spec: PartitionSpec = PartitionSpec(('dp', 'fsdp'), 'sp')#
step_start_point: Optional[int] = None#
to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

total_batch_size: int = 32#
track_memory: bool = False#
train_on_inputs: bool = True#
training_time_limit: Optional[str] = None#
property training_time_seconds: int#
truncation_mode: Literal['keep_end', 'keep_start'] = 'keep_end'#
tx_mu_dtype: Optional[dtype] = None#
use_data_collactor: bool = True#
use_wandb: bool = True#
verbose: bool = True#
wandb_entity: Optional[str] = None#
warmup_steps: int = 0#
weight_decay: float = 0.01#
weight_distribution_log_steps: int = 0#
weight_distribution_pattern: str = '.*?(layernorm|norm).*?'#
easydel.trainers.training_configurations.get_safe_arr(xs)[source]#