# Source code for easydel.trainers.training_configurations

# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import functools
import re
import typing as tp
import warnings
from copy import deepcopy
from dataclasses import dataclass, field, fields
from pathlib import Path

import jax
import jax.numpy as jnp
import numpy as np
from eformer.optimizers import OptimizerFactory, SchedulerConfig
from jax.sharding import PartitionSpec

from easydel.infra.errors import EasyDeLTimerError
from easydel.infra.etils import (
	AVAILABLE_OPTIMIZERS,
	AVAILABLE_SCHEDULERS,
	AVAILABLE_SPARSE_MODULE_TYPES,
	EasyDeLOptimizers,
	EasyDeLSchedulers,
)
from easydel.infra.loss_utils import LossConfig
from easydel.utils.compiling_utils import hash_fn
from easydel.utils.helpers import get_logger

from .utils import JaxDistributedConfig

try:
	import wandb  # type: ignore # noqa: F821
except ImportError:
	wandb = None

if tp.TYPE_CHECKING:
	from flax.metrics.tensorboard import SummaryWriter
	from jax import Array
	from torch import Tensor

	MetricsType = tp.Dict[
		str,
		tp.Union[float, tp.List, tp.Tuple, np.ndarray, Array, Tensor],
	]
else:
	SummaryWriter = tp.Any
	MetricsType = tp.Any
logger = get_logger(__name__)


def get_safe_arr(xs):
	"""Convert single-element NumPy/JAX values into plain Python scalars.

	NumPy scalars (``np.generic``) and single-element ``np.ndarray`` / ``jax.Array``
	values are unwrapped with ``.item()``. Loggers therefore see a Python ``int`` or
	``float`` instead of an array. Every other value is returned unchanged.

	Args:
		xs: The value to convert.

	Returns:
		A Python scalar when ``xs`` is a NumPy scalar or a single-element NumPy/JAX
		array; otherwise ``xs`` itself.
	"""
	# np.ndarray is included so single-element NumPy arrays are unwrapped exactly like
	# single-element JAX arrays; only size-1 values are safe to call .item() on.
	if isinstance(xs, (np.generic, np.ndarray, jax.Array)) and xs.size == 1:
		return xs.item()
	return xs
# Constants
# Backends accepted by `TrainingArguments.backend`; `None` lets JAX choose its default.
AVAILABLE_BACKENDS: tp.List[tp.Optional[str]] = ["cpu", "gpu", "tpu", None]
@dataclass
class TrainingArguments:
	"""Configuration container for EasyDeL trainers.

	Holds the following settings:

	- optimizer and scheduler settings;
	- checkpointing options;
	- logging/reporting options;
	- distributed-training options.

	``__post_init__`` performs these steps:

	- validates the values;
	- initializes JAX distributed training;
	- builds ``optimizer_kwargs``;
	- adjusts the logging flags;
	- fills in derived defaults.
	"""

	auto_shard_states: bool = field(
		default=True,
		metadata={"help": "Whether to automatically shard model states across devices."},
	)
	aux_loss_enabled: bool = field(
		default=False,
		metadata={"help": "Whether to enable the auxiliary loss."},
	)
	backend: tp.Optional[str] = field(
		default=None,
		metadata={
			"help": "The JAX backend to use (e.g., 'cpu', 'gpu', 'tpu'). If None, JAX will choose."
		},
	)
	clip_grad: tp.Optional[float] = field(
		default=None,
		metadata={"help": "The value at which to clip the gradients."},
	)
	custom_scheduler: tp.Optional[tp.Callable[[int], tp.Any]] = field(
		default=None,
		metadata={
			"help": "A custom scheduler function that takes the current step as input."
		},
	)
	dataloader_num_workers: tp.Optional[int] = field(
		default=0,
		metadata={"help": "The number of workers to use for the dataloader."},
	)
	dataloader_pin_memory: tp.Optional[bool] = field(
		default=False,
		metadata={"help": "Whether to pin memory for the dataloader."},
	)
	do_eval: bool = field(
		default=False,
		metadata={"help": "Whether to run evaluation during training."},
	)
	do_last_save: bool = field(
		default=True,
		metadata={"help": "Whether to save the model at the end of training."},
	)
	do_train: bool = field(
		default=True,
		metadata={"help": "Whether to run training."},
	)
	eval_batch_size: tp.Optional[int] = field(
		default=None,
		metadata={"help": "The batch size to use for evaluation."},
	)
	evaluation_steps: tp.Optional[int] = field(
		default=None,
		metadata={"help": "Run evaluation every X steps."},
	)
	extra_optimizer_kwargs: dict = field(
		default_factory=dict,
		metadata={"help": "Additional keyword arguments to pass to the optimizer."},
	)
	frozen_parameters: tp.Optional[str] = field(
		default=None,
		metadata={"help": "A regex pattern of parameters to freeze (not train)."},
	)
	gradient_accumulation_steps: int = field(
		default=1,
		metadata={"help": "The number of steps to accumulate gradients over."},
	)
	ids_to_pop_from_dataset: tp.Optional[tp.List[str]] = field(
		default_factory=list,
		metadata={"help": "A list of dataset columns to remove before training."},
	)
	is_fine_tuning: bool = field(
		default=True,
		metadata={"help": "Whether the training is a fine-tuning run."},
	)
	init_tx: bool = field(
		default=True,
		metadata={"help": "Whether to initialize the training state."},
	)
	jax_distributed_config: tp.Optional[dict] = field(
		default=None,
		metadata={"help": "Configuration for JAX distributed training."},
	)
	learning_rate: float = field(
		default=5e-5,
		metadata={"help": "The learning rate."},
	)
	learning_rate_end: tp.Optional[float] = field(
		default=None,
		metadata={"help": "The final learning rate for linear decay schedulers."},
	)
	log_all_workers: bool = field(
		default=False,
		metadata={
			"help": "Whether to log metrics from all workers in a distributed setup."
		},
	)
	log_grad_norms: bool = field(
		default=True,
		metadata={"help": "Whether to log gradient norms."},
	)
	report_metrics: bool = field(
		default=True,
		metadata={"help": "Whether to report metrics to a logger."},
	)
	log_steps: int = field(
		default=10,
		metadata={"help": "Log metrics every X steps."},
	)
	loss_config: tp.Optional[LossConfig] = field(
		default=None,
		metadata={"help": "Configuration for the loss function."},
	)
	low_mem_usage: bool = field(
		default=True,
		metadata={"help": "Whether to try to minimize memory usage."},
	)
	max_evaluation_steps: tp.Optional[int] = field(
		default=None,
		metadata={"help": "Maximum number of evaluation steps."},
	)
	max_sequence_length: tp.Optional[int] = field(
		default=4096,
		metadata={"help": "The maximum sequence length."},
	)
	max_training_steps: tp.Optional[int] = field(
		default=None,
		metadata={"help": "The maximum number of training steps."},
	)
	model_name: str = field(
		default="BaseTrainer",
		metadata={"help": "The name of the model."},
	)
	model_parameters: tp.Optional[dict] = field(
		default=None,
		metadata={"help": "Model architecture config"},
	)
	metrics_to_show_in_rich_pbar: tp.Optional[tp.List[str]] = field(
		default=None,
		metadata={"help": "Metrics to display in the rich progress bar."},
	)
	num_train_epochs: int = field(
		default=10,
		metadata={"help": "The number of training epochs."},
	)
	offload_dataset: bool = field(
		default=False,
		metadata={"help": "Whether to offload the dataset to CPU or disk."},
	)
	offload_device_type: str = field(
		default="cpu",
		metadata={"help": "The device type to offload the dataset to (cpu or disk)."},
	)
	offload_device_index: int = field(
		default=0,
		metadata={"help": "The device index to offload the dataset to."},
	)
	optimizer: AVAILABLE_OPTIMIZERS = field(
		default=EasyDeLOptimizers.ADAMW,
		metadata={"help": "The optimizer to use."},
	)
	performance_mode: bool = field(
		default=False,
		metadata={"help": "Whether to enable performance mode (e.g., XLA compilation)."},
	)
	pruning_module: tp.Any = field(
		default=None,
		metadata={"help": "The pruning module to use."},
	)
	process_zero_is_admin: bool = field(
		default=True,
		metadata={"help": "Whether the process with rank 0 is the admin process."},
	)
	progress_bar_type: tp.Literal["tqdm", "rich", "json"] = field(
		default="tqdm",
		metadata={"help": "The type of progress bar to use."},
	)
	remove_ckpt_after_load: bool = field(
		default=False,
		metadata={"help": "Whether to remove the checkpoint after loading it."},
	)
	remove_unused_columns: bool = field(
		default=True,
		metadata={"help": "Whether to remove unused columns from the dataset."},
	)
	report_steps: int = field(
		default=5,
		metadata={"help": "Report metrics every X steps."},
	)
	save_directory: str = field(
		default="EasyDeL-Checkpoints",
		metadata={"help": "The directory to save checkpoints to."},
	)
	save_optimizer_state: bool = field(
		default=True,
		metadata={"help": "Whether to save the optimizer state along with the model."},
	)
	save_steps: tp.Optional[int] = field(
		default=None,
		metadata={"help": "Save a checkpoint every X steps."},
	)
	save_total_limit: tp.Optional[int] = field(
		default=None,
		metadata={"help": "The maximum number of checkpoints to keep."},
	)
	scheduler: AVAILABLE_SCHEDULERS = field(
		default=EasyDeLSchedulers.NONE,
		metadata={"help": "The scheduler to use."},
	)
	sparsify_module: bool = field(
		default=False,
		metadata={"help": "Whether to sparsify the model."},
	)
	sparse_module_type: AVAILABLE_SPARSE_MODULE_TYPES = field(
		default="bcoo",
		metadata={"help": "The type of sparse module to use."},
	)
	state_apply_fn_kwarguments_to_model: tp.Optional[dict] = field(
		default=None,
		metadata={"help": "Keyword arguments to pass to the state apply function."},
	)
	step_partition_spec: PartitionSpec = field(
		default=PartitionSpec(("dp", "fsdp"), "sp"),
		metadata={"help": "The partition specification for the training step."},
	)
	step_start_point: tp.Optional[int] = field(
		default=None,
		metadata={"help": "The step to start training from (for resuming)."},
	)
	shuffle_train_dataset: bool = field(
		default=True,
		metadata={"help": "Whether to shuffle the training dataset."},
	)
	total_batch_size: int = field(
		default=32,
		metadata={"help": "The total batch size."},
	)
	training_time_limit: tp.Optional[str] = field(
		default=None,
		metadata={"help": "The maximum training time (e.g., '1d', '2h30m')."},
	)
	train_on_inputs: bool = field(
		default=True,
		metadata={"help": "Whether to train on the input data."},
	)
	truncation_mode: tp.Literal["keep_end", "keep_start"] = field(
		default="keep_end",
		metadata={"help": "The truncation mode to use."},
	)
	tx_mu_dtype: tp.Optional[jnp.dtype] = field(
		default=None,
		metadata={"help": "The dtype to use for the `tx.mu` variable."},
	)
	track_memory: bool = field(
		default=False,
		metadata={"help": "Whether to track memory usage."},
	)
	use_data_collactor: bool = field(
		default=True,
		metadata={"help": "Whether to use a data collator."},
	)
	use_wandb: bool = field(
		default=True,
		metadata={"help": "Whether to use Weights & Biases for logging."},
	)
	verbose: bool = field(
		default=True,
		metadata={"help": "Whether to print verbose output."},
	)
	wandb_entity: tp.Optional[str] = field(
		default=None,
		metadata={"help": "The Weights & Biases entity."},
	)
	warmup_steps: int = field(
		default=0,
		metadata={"help": "The number of warmup steps for the learning rate scheduler."},
	)
	weight_decay: float = field(
		default=0.01,
		metadata={"help": "The weight decay value."},
	)

	@property
	def offload_device(self):
		"""The JAX device selected by ``offload_device_type`` and ``offload_device_index``."""
		return jax.devices(self.offload_device_type)[self.offload_device_index]

	@property
	def training_time_seconds(self) -> tp.Optional[int]:
		"""``training_time_limit`` converted to seconds, or ``None`` when no limit is set."""
		if self.training_time_limit is None:
			return None
		return self._time_to_seconds(self.training_time_limit)

	@functools.cached_property
	def is_process_zero(self):
		"""Whether this host is JAX process 0 (computed once per instance)."""
		return jax.process_index() == 0

	def __post_init__(self):
		"""
		Runs the initialization steps below, in order.

		- Validates the configuration.
		- Sets up distributed training.
		- Builds the optimizer kwargs.
		- Adjusts the logging flags.
		- Fills in derived defaults.

		This method is automatically called after the object is initialized.
		"""
		self._validate_config()
		self._setup_distributed()
		self._setup_optimizer()
		self._setup_logging()
		self._ensure_variables()

	def _validate_config(self):
		"""
		Performs validation checks on the provided configuration settings.

		Raises:
			ValueError: If ``gradient_accumulation_steps`` is not positive, or if
				``backend`` is not one of ``AVAILABLE_BACKENDS``.
		"""
		# An explicit raise (rather than `assert`) keeps this check active under `python -O`.
		if self.gradient_accumulation_steps <= 0:
			raise ValueError("`gradient_accumulation_steps` can't be lower than 1.")
		if self.backend not in AVAILABLE_BACKENDS:
			raise ValueError(
				f"Backend {self.backend} is not recognized. Available backends: {AVAILABLE_BACKENDS}"
			)

	def _setup_distributed(self):
		"""
		Records the device count and initializes JAX distributed training.

		The number of devices visible for ``backend`` is stored in
		``self.available_backends``. Despite its name, that attribute is a device count.
		JAX distributed training is then initialized through
		``JaxDistributedConfig.initialize(self.jax_distributed_config)``.
		"""
		self.available_backends = len(jax.devices(self.backend))
		JaxDistributedConfig.initialize(self.jax_distributed_config)

	def _setup_optimizer(self):
		"""
		Collects the optimizer and scheduler settings into ``self.optimizer_kwargs``.

		Entries from ``extra_optimizer_kwargs`` are merged last. They therefore
		override the values taken from the individual fields.
		"""
		extra_optimizer_kwargs = (
			self.extra_optimizer_kwargs if self.extra_optimizer_kwargs is not None else {}
		)
		self.optimizer_kwargs = {
			"learning_rate": self.learning_rate,
			"learning_rate_end": self.learning_rate_end,
			"optimizer": self.optimizer,
			"scheduler": self.scheduler,
			"warmup_steps": self.warmup_steps,
			"gradient_accumulation_steps": self.gradient_accumulation_steps,
			"weight_decay": self.weight_decay,
			"steps": self.max_training_steps,
			"clip_grad": self.clip_grad,
			"mu_dtype": self.tx_mu_dtype,
			**extra_optimizer_kwargs,
		}

	def _setup_logging(self):
		"""
		Adjusts the ``use_wandb`` and ``report_metrics`` flags.

		Performance mode disables both WandB logging and metric reporting. Metric
		reporting is also disabled on every process other than process 0, unless
		``log_all_workers`` is set.
		"""
		if self.use_wandb and self.performance_mode:
			logger.info("WandB logging disabled due to performance mode")
			self.use_wandb = False
		if self.report_metrics and self.performance_mode:
			logger.info("Metrics reporting disabled due to performance mode")
			self.report_metrics = False
		if self.report_metrics:
			if not self.is_process_zero and not self.log_all_workers:
				logger.info(
					"Metrics reporting disabled and it's only working on process index 0 or "
					"admin process (`log_all_workers` is `False`)."
				)
				self.report_metrics = False

	def _ensure_variables(self):
		"""
		Fills in derived defaults.

		- ``step_start_point`` falls back to 0.
		- ``eval_batch_size`` falls back to ``total_batch_size``.
		- ``loss_config`` falls back to a default ``LossConfig()``.
		"""
		self.step_start_point = self.step_start_point or 0
		self.eval_batch_size = (
			self.eval_batch_size if self.eval_batch_size is not None else self.total_batch_size
		)
		if self.loss_config is None:
			self.loss_config = LossConfig()

	@staticmethod
	def _time_to_seconds(time_str: str) -> int:
		"""
		Converts a time string such as "50min", "23h", "30s", "1d" or "2h30m" to seconds.

		The string may combine several ``<number><unit>`` components, optionally
		separated by whitespace. Their durations are summed. Matching is
		case-insensitive.

		Args:
			time_str (str): The time string to convert.

		Returns:
			int: The equivalent time in seconds.

		Raises:
			ValueError: If ``time_str`` is not made up entirely of
				``<number><unit>`` components.
		"""
		unit_to_seconds = {
			"d": 86400,
			"day": 86400,
			"days": 86400,
			"h": 3600,
			"hr": 3600,
			"hrs": 3600,
			"hour": 3600,
			"hours": 3600,
			"m": 60,
			"min": 60,
			"mins": 60,
			"minute": 60,
			"minutes": 60,
			"s": 1,
			"sec": 1,
			"secs": 1,
			"second": 1,
			"seconds": 1,
		}
		# Longest unit names first, so alternation prefers e.g. "min" over "m".
		units = "|".join(sorted(unit_to_seconds, key=len, reverse=True))
		component = rf"(\d+)\s*({units})"
		normalized = time_str.strip().lower()
		# Require the whole string to consist of components. A prefix-only match would
		# silently drop the remainder, e.g. reading "2h30m" as just "2h".
		if not re.fullmatch(rf"(?:{component}\s*)+", normalized):
			raise ValueError(
				f"Invalid time format {time_str!r}. Use `50min` for minutes, `23h` for hours, "
				"`30s` for seconds or `1d` for days; components can be combined, e.g. `2h30m`."
			)
		return sum(
			int(value) * unit_to_seconds[unit]
			for value, unit in re.findall(component, normalized)
		)

	def get_path(self) -> Path:
		"""
		Returns the path to the checkpoint directory.

		Returns:
			Path: ``save_directory / model_name``.
		"""
		return Path(self.save_directory, self.model_name)

	def ensure_checkpoint_path(self):
		"""
		Creates the checkpoint directory (including parents) if it doesn't exist.
		"""
		path = self.get_path()
		path.mkdir(parents=True, exist_ok=True)

	def get_optimizer_and_scheduler(self, steps: tp.Optional[int] = None):
		"""
		Builds the optimizer and learning-rate scheduler via ``OptimizerFactory``.

		Args:
			steps (tp.Optional[int]): The number of training steps. When falsy, the value
				already stored in ``self.optimizer_kwargs["steps"]`` is used. The chosen
				value is written back to ``self.optimizer_kwargs``.

		Returns:
			tuple: A tuple containing the optimizer and scheduler.
		"""
		self.optimizer_kwargs["steps"] = steps or self.optimizer_kwargs["steps"]
		optimizer_kwargs = deepcopy(self.optimizer_kwargs)
		scheduler = optimizer_kwargs.pop("scheduler", None)
		# Both the raw string and the enum member mean "no scheduler".
		if scheduler == "none" or scheduler == EasyDeLSchedulers.NONE:
			scheduler = None
		scheduler_config = SchedulerConfig(
			scheduler_type=scheduler,
			steps=optimizer_kwargs.pop("steps"),
			learning_rate=optimizer_kwargs.pop("learning_rate"),
			learning_rate_end=optimizer_kwargs.pop("learning_rate_end"),
			warmup_steps=optimizer_kwargs.pop("warmup_steps"),
			exponent=optimizer_kwargs.pop("exponent", 1),
		)
		# Removed so it is not forwarded to OptimizerFactory.create via **optimizer_kwargs.
		optimizer_kwargs.pop("gradient_accumulation_steps", 0)
		optimizer, scheduler = OptimizerFactory.create(
			optimizer_type=optimizer_kwargs.pop("optimizer"),
			scheduler_config=scheduler_config,
			clip_grad=optimizer_kwargs.pop("clip_grad"),
			weight_decay=optimizer_kwargs.pop("weight_decay"),
			custom_scheduler=self.custom_scheduler,
			**optimizer_kwargs,
		)
		return optimizer, scheduler

	def get_streaming_checkpointer(self):
		"""
		Returns the checkpoint manager, responsible for saving model checkpoints.

		Returns:
			CheckpointManager: A manager rooted at ``save_directory/model_name``.
		"""
		import os.path

		from easydel.utils.checkpoint_managers import CheckpointManager

		return CheckpointManager(
			checkpoint_dir=os.path.join(self.save_directory, self.model_name),
			save_optimizer_state=self.save_optimizer_state,
			verbose=self.verbose,
		)

	@functools.cached_property
	def _tensorboard(self):
		# Created lazily, once per instance, inside the (created-on-demand) save directory.
		from flax.metrics.tensorboard import SummaryWriter

		return SummaryWriter(log_dir=str(self._get_save_directory(create=True)))

	def get_tensorboard(self):
		"""
		Returns the TensorBoard SummaryWriter, used for logging metrics.

		Returns:
			flax.metrics.tensorboard.SummaryWriter: The TensorBoard SummaryWriter.
		"""
		return self._tensorboard

	def get_wandb_init(self):
		"""
		Initializes Weights & Biases for experiment tracking if enabled.

		WandB is initialized only when both ``report_metrics`` and ``use_wandb`` are
		set. If ``use_wandb`` is requested but the ``wandb`` package could not be
		imported, a warning is emitted and nothing is initialized.

		Returns:
			tp.Optional[wandb.sdk.wandb_run.Run]: The WandB run object if initialized, else None.
		"""
		if not (self.report_metrics and self.use_wandb):
			return None
		if wandb is None:
			warnings.warn(
				"you have used `use_wandb=True` but you haven't install wandb.",
				stacklevel=2,
			)
			return None
		return wandb.init(
			project=f"EasyDeL-{self.model_name}",
			config=self.to_dict(),
			tags=["EasyDeL", "JAX/Flax"],
			entity=self.wandb_entity,
		)

	def ensure_training_time_limit(self, time_passed):
		"""
		Raises ``EasyDeLTimerError`` once ``time_passed`` exceeds ``training_time_limit``.

		Args:
			time_passed: Elapsed training time. It is compared against the limit
				converted to seconds.

		Raises:
			EasyDeLTimerError: If a limit is set and ``time_passed`` is greater than it.
		"""
		limit = self.training_time_seconds
		if limit is not None and time_passed > limit:
			raise EasyDeLTimerError("Time Out")

	def log_metrics(
		self,
		metrics: MetricsType,
		step: int,
		log_as: tp.Optional[tp.Literal["summary", "config"]] = None,
	):
		"""
		Logs training metrics to Weights & Biases and/or TensorBoard.

		Processing happens only when ``report_metrics`` is set:

		- ``None`` values are dropped.
		- Names are normalized by ``_restructure_metric_name``.
		- Single-element arrays are unwrapped to Python scalars.

		Args:
			metrics: A dictionary where keys are metric names and values are metric
				values (floats, lists, tuples, NumPy/JAX arrays or torch tensors).
			step (int): The current training step or iteration.
			log_as: ``"summary"`` or ``"config"`` to update the WandB run summary or
				config instead of logging a step; ``None`` logs a regular step.
		"""
		if self.report_metrics:
			filtered_metrics = {k: v for k, v in metrics.items() if v is not None}
			metrics = {
				self._restructure_metric_name(k): get_safe_arr(v)
				for k, v in filtered_metrics.items()
			}
			self._log_to_wandb(metrics, step, log_as)
			self._log_to_tensorboard(metrics, step, log_as)

	def _restructure_metric_name(self, metric_name: str) -> str:
		"""
		Restructures the metric name for logging.

		Names starting with ``train/grad_norm/`` are moved under ``grad_norm/``.

		Args:
			metric_name (str): The original metric name.

		Returns:
			str: The restructured metric name.
		"""
		if metric_name.startswith("train/grad_norm/"):
			return metric_name.replace("train/grad_norm/", "grad_norm/")
		return metric_name

	def _log_to_wandb(
		self,
		metrics,
		step,
		log_as: tp.Optional[tp.Literal["summary", "config"]] = None,
	):
		"""
		Log metrics to Weights & Biases (wandb), when enabled and importable.

		List, tuple and array values are converted to ``wandb.Histogram`` objects.
		Failures are reported as warnings rather than raised, so logging never
		interrupts training.

		Args:
			metrics (dict): A dictionary of metrics to log. Keys are metric names,
				values are the metric values.
			step (int): The current step or iteration number.
			log_as: ``"summary"`` / ``"config"`` to update the run summary / config
				instead of logging a step.
		"""
		if self.use_wandb and wandb is not None:
			if log_as == "summary":
				wandb.summary.update(metrics)
			elif log_as == "config":
				wandb.config.update(metrics)
			else:
				wandb_metrics = {}
				for key, value in metrics.items():
					try:
						wandb_metrics[key] = (
							self._create_wandb_histogram(value)
							if isinstance(value, (list, tuple, np.ndarray, jnp.ndarray))
							else value
						)
					except Exception as e:
						warnings.warn(f"Failed to log metric {key} to wandb: {e}", stacklevel=3)
				try:
					wandb.log(wandb_metrics, step=step)
				except Exception as e:
					# Best effort: report the failure instead of silently dropping it.
					warnings.warn(f"Failed to log metrics to wandb: {e}", stacklevel=3)

	def _log_to_tensorboard(
		self,
		metrics,
		step,
		log_as: tp.Optional[tp.Literal["summary", "config"]] = None,
	):
		"""
		Log metrics to TensorBoard.

		Python ``int``/``float`` values are written as scalars. Lists, tuples and
		arrays are written as histograms. Other value types are skipped. A failing
		metric produces a warning instead of an exception.

		Args:
			metrics (dict): A dictionary of metrics to log. Keys are metric names,
				values are the metric values.
			step (int): The current step or iteration number.
			log_as: Accepted for signature parity with ``_log_to_wandb``; not used here.
		"""
		summary_writer = self.get_tensorboard()
		try:
			for key, value in metrics.items():
				try:
					if isinstance(value, (float, int)):
						summary_writer.scalar(key, value, step)
					elif isinstance(value, (list, tuple, np.ndarray, jnp.ndarray)):
						summary_writer.histogram(key, np.array(value), step)
				except Exception as e:
					warnings.warn(f"Failed to log metric {key} to TensorBoard: {e}", stacklevel=1)
		finally:
			# One flush per call, not one per metric.
			summary_writer.flush()

	def _create_wandb_histogram(self, value):
		"""
		Create a wandb.Histogram object from the given value.

		Args:
			value: The value to convert into a wandb.Histogram, e.g. a list, tuple,
				NumPy array or JAX array.

		Returns:
			wandb.Histogram or None: A wandb.Histogram object if successful, None if an
			error occurs (a warning is emitted in that case).

		Notes:
			- Non-numpy array inputs are converted to numpy arrays.
			- float16 and bfloat16 dtypes are converted to float32 before building the
			  histogram.
		"""
		try:
			if not isinstance(value, np.ndarray):
				value = np.array(value)
			# NumPy has no `np.bfloat16` attribute; compare by dtype name. The bfloat16
			# dtype carried by JAX-derived arrays is named "bfloat16".
			if value.dtype.name in ("float16", "bfloat16"):
				value = value.astype(np.float32)
			return wandb.Histogram(value)
		except Exception as e:
			warnings.warn(f"Failed to create wandb histogram: {e}", stacklevel=2)
			return None

	def to_dict(self) -> tp.Dict[str, tp.Any]:
		"""
		Converts the TrainingArguments object into a dictionary.

		Every public instance attribute is included:

		- the dataclass fields;
		- the attributes set in ``__post_init__`` (``optimizer_kwargs``,
		  ``available_backends``);
		- ``is_process_zero``, once that cached property has been computed.

		Returns:
			tp.Dict[str, tp.Any]: A dictionary representation of the TrainingArguments.
		"""
		return {k: v for k, v in self.__dict__.items() if not k.startswith("_")}

	@classmethod
	def from_dict(cls, config: tp.Dict[str, tp.Any]) -> "TrainingArguments":
		"""
		Creates a TrainingArguments instance from a dictionary.

		The runtime-derived keys emitted by ``to_dict`` are ignored, so
		``TrainingArguments.from_dict(args.to_dict())`` round-trips. Those keys are
		``optimizer_kwargs``, ``available_backends`` and ``is_process_zero``. Any other
		unknown key still makes the constructor raise ``TypeError``.

		Args:
			config (tp.Dict[str, tp.Any]): The configuration dictionary.

		Returns:
			TrainingArguments: A TrainingArguments object initialized with values from the dictionary.
		"""
		derived_keys = {"optimizer_kwargs", "available_backends", "is_process_zero"}
		return cls(**{k: v for k, v in config.items() if k not in derived_keys})

	def __repr__(self):
		cls_name = self.__class__.__name__
		field_lines = [
			f" {f.name}: {getattr(self, f.name)!r}".replace("\n", "\n ")
			for f in fields(self)
		]
		return f"{cls_name}(\n" + "\n".join(field_lines) + "\n)"

	__str__ = __repr__

	def _get_save_directory(self, create: bool = True) -> Path:
		"""Return ``save_directory / model_name``, creating it when ``create`` is True."""
		directory = Path(self.save_directory) / Path(self.model_name)
		if create:
			directory.mkdir(exist_ok=True, parents=True)
		return directory

	def _get_save_directory_milestone(self, step, create: bool = True) -> Path:
		"""Return ``<save directory>/run-<step>``, creating it when ``create`` is True."""
		directory_name = f"run-{step}"
		save_directory = self._get_save_directory(create=create) / directory_name
		if create:
			save_directory.mkdir(exist_ok=True, parents=True)
		return save_directory

	__hash__ = hash_fn