# Source code for easydel.__init__.trainers.training_configurations

# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from eformer.pytree import auto_pytree
import functools
import re
import typing as tp
import warnings
from copy import deepcopy
from dataclasses import field, fields
from pathlib import Path

import jax
import jax.experimental
import jax.experimental.multihost_utils
import jax.numpy as jnp
import numpy as np
from eformer.optimizers import OptimizerFactory, SchedulerConfig
from jax.sharding import PartitionSpec

from easydel.infra.errors import EasyDeLTimerError
from easydel.infra.etils import (
	AVAILABLE_OPTIMIZERS,
	AVAILABLE_SCHEDULERS,
	AVAILABLE_SPARSE_MODULE_TYPES,
	EasyDeLOptimizers,
	EasyDeLSchedulers,
)
from easydel.infra.loss_utils import LossConfig
from easydel.utils.compiling_utils import hash_fn
from easydel.utils.helpers import get_logger

from .utils import JaxDistributedConfig, compute_weight_stats

# `wandb` is an optional dependency: when it is not installed the name is bound
# to None so the rest of the module can test `wandb is None` before using it.
try:
	import wandb  # type: ignore # noqa: F821
except ImportError:
	wandb = None

# The names below are only needed for type annotations. Under a type checker
# they are imported for real; at runtime they fall back to `tp.Any`, so flax
# and torch are not imported by this branch.
if tp.TYPE_CHECKING:
	from flax.metrics.tensorboard import SummaryWriter
	from jax import Array
	from torch import Tensor  # type:ignore

	# Mapping of metric name -> scalar or array-like metric value.
	MetricsType = tp.Dict[
		str,
		tp.Union[float, tp.List, tp.Tuple, np.ndarray, Array, Tensor],
	]
else:
	SummaryWriter = tp.Any
	MetricsType = tp.Any
# Module-level logger obtained from the project's `get_logger` helper.
logger = get_logger(__name__)


def get_safe_arr(xs):
	"""
	Unwrap single-element array values into plain Python scalars.

	NumPy scalars, NumPy arrays and JAX arrays that hold exactly one element
	are converted with `.item()`; every other value (larger arrays, Python
	numbers, strings, None, ...) is returned unchanged.

	Args:
	    xs: The value to normalise.

	Returns:
	    A Python scalar when `xs` is a size-1 array-like, otherwise `xs` itself.
	"""
	# `np.ndarray` is listed explicitly: `np.generic` only covers NumPy *scalars*,
	# so a size-1 ndarray would otherwise stay an array and be treated as a
	# histogram instead of a scalar by the metric loggers.
	if isinstance(xs, (np.generic, np.ndarray, jax.Array)) and xs.size == 1:
		return xs.item()
	return xs


# Constants
# Values accepted for `TrainingArguments.backend`. `None` is a valid entry
# (it is the field's default), hence the Optional element type.
AVAILABLE_BACKENDS: tp.List[tp.Optional[str]] = ["cpu", "gpu", "tpu", None]


@auto_pytree
class TrainingArguments:
	"""
	Configuration container for a training run.

	Holds every user-facing training option as a dataclass field and, in
	`__post_init__`, validates the values, initialises JAX distributed,
	assembles `optimizer_kwargs`, adjusts the logging switches and fills in
	defaults that depend on other fields.

	Besides the declared fields, the following attributes are set at runtime:
	`available_backends` (device count for the chosen backend) and
	`optimizer_kwargs` (dict consumed by `get_optimizer_and_scheduler`).
	"""

	auto_shard_states: bool = field(
		default=True,
		metadata={"help": "Whether to automatically shard model states across devices."},
	)
	aux_loss_enabled: bool = field(
		default=False,
		metadata={"help": "Whether to enable the auxiliary loss."},
	)
	backend: tp.Optional[str] = field(
		default=None,
		metadata={
			"help": "The JAX backend to use (e.g., 'cpu', 'gpu', 'tpu'). If None, JAX will choose."
		},
	)
	clip_grad: tp.Optional[float] = field(
		default=None,
		metadata={"help": "The value at which to clip the gradients."},
	)
	custom_scheduler: tp.Optional[tp.Callable[[int], tp.Any]] = field(
		default=None,
		metadata={
			"help": "A custom scheduler function that takes the current step as input."
		},
	)
	dataloader_num_workers: tp.Optional[int] = field(
		default=0,
		metadata={"help": "The number of workers to use for the dataloader."},
	)
	dataloader_pin_memory: tp.Optional[bool] = field(
		default=False,
		metadata={"help": "Whether to pin memory for the dataloader."},
	)
	do_eval: bool = field(
		default=False,
		metadata={"help": "Whether to run evaluation during training."},
	)
	do_last_save: bool = field(
		default=True,
		metadata={"help": "Whether to save the model at the end of training."},
	)
	do_train: bool = field(
		default=True,
		metadata={"help": "Whether to run training."},
	)
	eval_batch_size: tp.Optional[int] = field(
		default=None,
		metadata={"help": "The batch size to use for evaluation."},
	)
	evaluation_steps: tp.Optional[int] = field(
		default=None,
		metadata={"help": "Run evaluation every X steps."},
	)
	extra_optimizer_kwargs: dict = field(
		default_factory=dict,
		metadata={"help": "Additional keyword arguments to pass to the optimizer."},
	)
	frozen_parameters: tp.Optional[str] = field(
		default=None,
		metadata={"help": "A regex pattern of parameters to freeze (not train)."},
	)
	gradient_accumulation_steps: int = field(
		default=1,
		metadata={"help": "The number of steps to accumulate gradients over."},
	)
	ids_to_pop_from_dataset: tp.Optional[tp.List[str]] = field(
		default_factory=list,
		metadata={"help": "A list of dataset columns to remove before training."},
	)
	is_fine_tuning: bool = field(
		default=True,
		metadata={"help": "Whether the training is a fine-tuning run."},
	)
	init_tx: bool = field(
		default=True,
		metadata={"help": "Whether to initialize the training state."},
	)
	jax_distributed_config: tp.Optional[dict] = field(
		default=None,
		metadata={"help": "Configuration for JAX distributed training."},
	)
	learning_rate: float = field(
		default=5e-5,
		metadata={"help": "The learning rate."},
	)
	learning_rate_end: tp.Optional[float] = field(
		default=None,
		metadata={"help": "The final learning rate for linear decay schedulers."},
	)
	log_all_workers: bool = field(
		default=False,
		metadata={
			"help": "Whether to log metrics from all workers in a distributed setup."
		},
	)
	log_grad_norms: bool = field(
		default=True,
		metadata={"help": "Whether to log gradient norms."},
	)
	report_metrics: bool = field(
		default=True,
		metadata={"help": "Whether to report metrics to a logger."},
	)
	log_steps: int = field(
		default=10,
		metadata={"help": "Log metrics every X steps."},
	)
	loss_config: tp.Optional[LossConfig] = field(
		default=None,
		metadata={"help": "Configuration for the loss function."},
	)
	low_mem_usage: bool = field(
		default=True,
		metadata={"help": "Whether to try to minimize memory usage."},
	)
	max_evaluation_steps: tp.Optional[int] = field(
		default=None,
		metadata={"help": "Maximum number of evaluation steps."},
	)
	max_sequence_length: tp.Optional[int] = field(
		default=4096,
		metadata={"help": "The maximum sequence length."},
	)
	max_training_steps: tp.Optional[int] = field(
		default=None,
		metadata={"help": "The maximum number of training steps."},
	)
	model_name: str = field(
		default="BaseTrainer",
		metadata={"help": "The name of the model."},
	)
	model_parameters: tp.Optional[dict] = field(
		default=None,
		metadata={"help": "Model architecture config"},
	)
	metrics_to_show_in_rich_pbar: tp.Optional[tp.List[str]] = field(
		default=None,
		metadata={"help": "Metrics to display in the rich progress bar."},
	)
	num_train_epochs: int = field(
		default=10,
		metadata={"help": "The number of training epochs."},
	)
	offload_dataset: bool = field(
		default=False,
		metadata={"help": "Whether to offload the dataset to CPU or disk."},
	)
	offload_device_type: str = field(
		default="cpu",
		metadata={"help": "The device type to offload the dataset to (cpu or disk)."},
	)
	offload_device_index: int = field(
		default=0,
		metadata={"help": "The device index to offload the dataset to."},
	)
	optimizer: AVAILABLE_OPTIMIZERS = field(
		default=EasyDeLOptimizers.ADAMW,
		metadata={"help": "The optimizer to use."},
	)
	performance_mode: bool = field(
		default=False,
		metadata={"help": "Whether to enable performance mode (e.g., XLA compilation)."},
	)
	pruning_module: tp.Any = field(
		default=None,
		metadata={"help": "The pruning module to use."},
	)
	process_zero_is_admin: bool = field(
		default=True,
		metadata={"help": "Whether the process with rank 0 is the admin process."},
	)
	progress_bar_type: tp.Literal["tqdm", "rich", "json"] = field(
		default="tqdm",
		metadata={"help": "The type of progress bar to use."},
	)
	remove_ckpt_after_load: bool = field(
		default=False,
		metadata={"help": "Whether to remove the checkpoint after loading it."},
	)
	remove_unused_columns: bool = field(
		default=True,
		metadata={"help": "Whether to remove unused columns from the dataset."},
	)
	report_steps: int = field(
		default=5,
		metadata={"help": "Report metrics every X steps."},
	)
	save_directory: str = field(
		default="EasyDeL-Checkpoints",
		metadata={"help": "The directory to save checkpoints to."},
	)
	save_optimizer_state: bool = field(
		default=True,
		metadata={"help": "Whether to save the optimizer state along with the model."},
	)
	save_steps: tp.Optional[int] = field(
		default=None,
		metadata={"help": "Save a checkpoint every X steps."},
	)
	save_total_limit: tp.Optional[int] = field(
		default=None,
		metadata={"help": "The maximum number of checkpoints to keep."},
	)
	scheduler: AVAILABLE_SCHEDULERS = field(
		default=EasyDeLSchedulers.NONE,
		metadata={"help": "The scheduler to use."},
	)
	sparsify_module: bool = field(
		default=False,
		metadata={"help": "Whether to sparsify the model."},
	)
	sparse_module_type: AVAILABLE_SPARSE_MODULE_TYPES = field(
		default="bcoo",
		metadata={"help": "The type of sparse module to use."},
	)
	state_apply_fn_kwarguments_to_model: tp.Optional[dict] = field(
		default=None,
		metadata={"help": "Keyword arguments to pass to the state apply function."},
	)
	step_partition_spec: PartitionSpec = field(
		default=PartitionSpec(("dp", "fsdp"), "sp"),
		metadata={"help": "The partition specification for the training step."},
	)
	step_start_point: tp.Optional[int] = field(
		default=None,
		metadata={"help": "The step to start training from (for resuming)."},
	)
	shuffle_train_dataset: bool = field(
		default=True,
		metadata={"help": "Whether to shuffle the training dataset."},
	)
	total_batch_size: int = field(
		default=32,
		metadata={"help": "The total batch size."},
	)
	training_time_limit: tp.Optional[str] = field(
		default=None,
		metadata={"help": "The maximum training time (e.g., '1d', '2h30m')."},
	)
	train_on_inputs: bool = field(
		default=True,
		metadata={"help": "Whether to train on the input data."},
	)
	truncation_mode: tp.Literal["keep_end", "keep_start"] = field(
		default="keep_end",
		metadata={"help": "The truncation mode to use."},
	)
	tx_mu_dtype: tp.Optional[jnp.dtype] = field(
		default=None,
		metadata={"help": "The dtype to use for the `tx.mu` variable."},
	)
	track_memory: bool = field(
		default=False,
		metadata={"help": "Whether to track memory usage."},
	)
	use_data_collactor: bool = field(
		default=True,
		metadata={"help": "Whether to use a data collator."},
	)
	use_wandb: bool = field(
		default=True,
		metadata={"help": "Whether to use Weights & Biases for logging."},
	)
	verbose: bool = field(
		default=True,
		metadata={"help": "Whether to print verbose output."},
	)
	wandb_entity: tp.Optional[str] = field(
		default=None,
		metadata={"help": "The Weights & Biases entity."},
	)
	warmup_steps: int = field(
		default=0,
		metadata={"help": "The number of warmup steps for the learning rate scheduler."},
	)
	weight_decay: float = field(
		default=0.01,
		metadata={"help": "The weight decay value."},
	)
	weight_distribution_pattern: str = field(
		default=r".*?(layernorm|norm).*?",
		metadata={"help": "The pattern to use to extract weight distribution."},
	)
	weight_distribution_log_steps: int = field(
		default=0,
		metadata={"help": "log weight distribution every X steps."},
	)

	@property
	def offload_device(self):
		"""The JAX device selected by `offload_device_type` and `offload_device_index`."""
		return jax.devices(self.offload_device_type)[self.offload_device_index]

	@property
	def training_time_seconds(self) -> tp.Optional[int]:
		"""`training_time_limit` converted to seconds, or None when no limit is set."""
		if self.training_time_limit is None:
			return None
		return self._time_to_seconds(self.training_time_limit)

	@functools.cached_property
	def is_process_zero(self):
		"""True when this is JAX process index 0 (computed once, then cached)."""
		return jax.process_index() == 0

	def __post_init__(self):
		"""
		Validates the configuration, sets up distributed training, builds the
		optimizer kwargs, configures logging and fills in dependent defaults.
		Called automatically after the object is initialized.
		"""
		self._validate_config()
		self._setup_distributed()
		self._setup_optimizer()
		self._setup_logging()
		self._ensure_variables()

	def _validate_config(self):
		"""
		Performs validation checks on the provided configuration settings.

		Raises:
		    ValueError: If `gradient_accumulation_steps` is lower than 1 or
		        `backend` is not one of `AVAILABLE_BACKENDS`.
		"""
		# Raised explicitly rather than asserted: `assert` is stripped under
		# `python -O`, and ValueError is what this method documents.
		if self.gradient_accumulation_steps < 1:
			raise ValueError("`gradient_accumulation_steps` can't be lower than 1.")
		if self.backend not in AVAILABLE_BACKENDS:
			raise ValueError(
				f"Backend {self.backend} is not recognized. Available backends: {AVAILABLE_BACKENDS}"
			)

	def _setup_distributed(self):
		"""
		Initialises JAX distributed from `jax_distributed_config` and records
		the number of devices available for the chosen backend.
		"""
		# Order matters: `jax.devices()` initialises the XLA backend, and
		# `jax.distributed.initialize` must run before that happens, so the
		# distributed setup goes first and the device count is taken afterwards.
		JaxDistributedConfig.initialize(self.jax_distributed_config)
		self.available_backends = len(jax.devices(self.backend))

	def _setup_optimizer(self):
		"""
		Collects the optimizer/scheduler settings into `self.optimizer_kwargs`.
		Entries from `extra_optimizer_kwargs` are merged last and therefore
		override the standard keys.
		"""
		self.optimizer_kwargs = {
			"learning_rate": self.learning_rate,
			"learning_rate_end": self.learning_rate_end,
			"optimizer": self.optimizer,
			"scheduler": self.scheduler,
			"warmup_steps": self.warmup_steps,
			"gradient_accumulation_steps": self.gradient_accumulation_steps,
			"weight_decay": self.weight_decay,
			"steps": self.max_training_steps,
			"clip_grad": self.clip_grad,
			"mu_dtype": self.tx_mu_dtype,
			**(self.extra_optimizer_kwargs or {}),
		}

	def _setup_logging(self):
		"""
		Adjusts the logging switches: performance mode turns off both WandB and
		metric reporting, and non-zero processes stop reporting unless
		`log_all_workers` is set.
		"""
		if self.performance_mode:
			if self.use_wandb:
				logger.info("WandB logging disabled due to performance mode")
				self.use_wandb = False
			if self.report_metrics:
				logger.info("Metrics reporting disabled due to performance mode")
				self.report_metrics = False
		if self.report_metrics and not self.is_process_zero and not self.log_all_workers:
			logger.info(
				"Metrics reporting disabled and it's only working on process index 0 or "
				"admin process (`log_all_workers` is `False`)."
			)
			self.report_metrics = False

	def _ensure_variables(self):
		"""
		Fills in defaults that depend on other fields: start step 0, eval batch
		size equal to the total batch size, and a default `LossConfig`.
		"""
		self.step_start_point = self.step_start_point or 0
		if self.eval_batch_size is None:
			self.eval_batch_size = self.total_batch_size
		if self.loss_config is None:
			self.loss_config = LossConfig()

	@staticmethod
	def _time_to_seconds(time_str: str) -> int:
		"""
		Converts a duration string to seconds.

		Accepts one or more `<number><unit>` components, optionally separated by
		whitespace, e.g. "50min", "23h", "30s", "1d" or "2h30m". Matching is
		case-insensitive.

		Args:
		    time_str (str): The duration string to convert.

		Returns:
		    int: The equivalent time in seconds.

		Raises:
		    ValueError: If the string is empty or contains anything other than
		        valid `<number><unit>` components.
		"""
		unit_seconds = {
			"d": 86400,
			"day": 86400,
			"days": 86400,
			"h": 3600,
			"hr": 3600,
			"hrs": 3600,
			"hour": 3600,
			"hours": 3600,
			"m": 60,
			"min": 60,
			"mins": 60,
			"minute": 60,
			"minutes": 60,
			"s": 1,
			"sec": 1,
			"secs": 1,
			"second": 1,
			"seconds": 1,
		}
		# Longest units first so e.g. "minutes" is never consumed as just "m".
		units = "|".join(sorted(unit_seconds, key=len, reverse=True))
		component = rf"(\d+)\s*({units})"
		text = time_str.lower()
		# The whole string must consist of components; this rejects trailing
		# garbage such as "5 months" instead of reading it as 5 minutes.
		if re.fullmatch(rf"\s*(?:{component}\s*)+", text) is None:
			raise ValueError(
				"Invalid time format. Use `50min` for minutes, `23h` for hours, `30s` for "
				"seconds, `1d` for days, or a combination such as `2h30m`."
			)
		return sum(int(value) * unit_seconds[unit] for value, unit in re.findall(component, text))

	def get_path(self) -> Path:
		"""
		Returns the path to the checkpoint directory.

		Returns:
		    Path: `save_directory / model_name`.
		"""
		return Path(self.save_directory, self.model_name)

	def ensure_checkpoint_path(self):
		"""
		Creates the checkpoint directory (and parents) if it doesn't exist.
		"""
		self.get_path().mkdir(parents=True, exist_ok=True)

	def get_optimizer_and_scheduler(self, steps: tp.Optional[int] = None):
		"""
		Returns the configured optimizer and learning rate scheduler.

		Args:
		    steps (tp.Optional[int]): The number of training steps. If not
		        provided, the value already stored in `self.optimizer_kwargs`
		        is used. A provided value is written back to
		        `self.optimizer_kwargs["steps"]`.

		Returns:
		    tuple: A tuple containing the optimizer and scheduler.
		"""
		self.optimizer_kwargs["steps"] = steps or self.optimizer_kwargs["steps"]
		# Work on a copy so the pops below don't strip the stored kwargs.
		optimizer_kwargs = deepcopy(self.optimizer_kwargs)

		scheduler = optimizer_kwargs.pop("scheduler", None)
		if scheduler == "none" or scheduler == EasyDeLSchedulers.NONE:
			scheduler = None

		scheduler_config = SchedulerConfig(
			scheduler_type=scheduler,
			steps=optimizer_kwargs.pop("steps"),
			learning_rate=optimizer_kwargs.pop("learning_rate"),
			learning_rate_end=optimizer_kwargs.pop("learning_rate_end"),
			warmup_steps=optimizer_kwargs.pop("warmup_steps"),
			exponent=optimizer_kwargs.pop("exponent", 1),
		)
		# Not forwarded to the optimizer factory.
		optimizer_kwargs.pop("gradient_accumulation_steps", 0)
		optimizer, scheduler = OptimizerFactory.create(
			optimizer_type=optimizer_kwargs.pop("optimizer"),
			scheduler_config=scheduler_config,
			clip_grad=optimizer_kwargs.pop("clip_grad"),
			weight_decay=optimizer_kwargs.pop("weight_decay"),
			custom_scheduler=self.custom_scheduler,
			**optimizer_kwargs,
		)
		return optimizer, scheduler

	def get_streaming_checkpointer(self):
		"""
		Returns the checkpoint manager, responsible for saving model checkpoints.

		Returns:
		    CheckpointManager: Manager rooted at `save_directory/model_name`.
		"""
		import os.path

		from easydel.utils.checkpoint_managers import CheckpointManager

		return CheckpointManager(
			checkpoint_dir=os.path.join(self.save_directory, self.model_name),
			save_optimizer_state=self.save_optimizer_state,
			verbose=self.verbose,
		)

	@functools.cached_property
	def _tensorboard(self):
		"""SummaryWriter writing into the save directory (created on first access)."""
		from flax.metrics.tensorboard import SummaryWriter

		return SummaryWriter(log_dir=str(self._get_save_directory(create=True)))

	def get_tensorboard(self):
		"""
		Returns the TensorBoard SummaryWriter, used for logging metrics.

		Returns:
		    flax.metrics.tensorboard.SummaryWriter: The cached SummaryWriter.
		"""
		return self._tensorboard

	def get_wandb_init(self):
		"""
		Initializes Weights & Biases for experiment tracking if enabled.

		Returns:
		    tp.Optional[wandb.sdk.wandb_run.Run]: The WandB run object, or None
		    when metric reporting or WandB is disabled, or wandb is not installed.
		"""
		if not self.report_metrics or not self.use_wandb:
			# WandB was not requested; nothing to warn about.
			return None
		if wandb is None:
			warnings.warn(
				"you have used `use_wandb=True` but you haven't install wandb.",
				stacklevel=1,
			)
			return None
		return wandb.init(
			project=f"EasyDeL-{self.model_name}",
			config=self.to_dict(),
			tags=["EasyDeL", "JAX/Flax"],
			entity=self.wandb_entity,
		)

	def ensure_training_time_limit(self, time_passed):
		"""
		Raises `EasyDeLTimerError` once `time_passed` (seconds) exceeds
		`training_time_limit`. Does nothing when no limit is configured.
		"""
		limit = self.training_time_seconds
		if limit is not None and time_passed > limit:
			raise EasyDeLTimerError("Time Out")

	def log_metrics(
		self,
		metrics: MetricsType,
		step: int,
		log_as: tp.Optional[tp.Literal["summary", "config"]] = None,
	):
		"""
		Logs training metrics to Weights & Biases and TensorBoard.

		Does nothing when `report_metrics` is False. Metrics whose value is None
		are dropped and size-1 arrays are unwrapped to scalars before logging.

		Args:
		    metrics: A dictionary where keys are metric names and values are
		        metric values.
		    step (int): The current training step or iteration.
		    log_as: For WandB, "summary" or "config" routes the metrics to
		        `wandb.summary` / `wandb.config` instead of `wandb.log`.
		"""
		if not self.report_metrics:
			return
		metrics = {
			self._restructure_metric_name(name): get_safe_arr(value)
			for name, value in metrics.items()
			if value is not None
		}
		self._log_to_wandb(metrics, step, log_as)
		self._log_to_tensorboard(metrics, step, log_as)

	def _restructure_metric_name(self, metric_name: str) -> str:
		"""
		Restructures the metric name for logging.

		Args:
		    metric_name (str): The original metric name.

		Returns:
		    str: The name with "train/grad_norm/" shortened to "grad_norm/" when
		    it starts with that prefix, otherwise the name unchanged.
		"""
		if metric_name.startswith("train/grad_norm/"):
			return metric_name.replace("train/grad_norm/", "grad_norm/")
		return metric_name

	def log_weight_distribution(self, state, step: int):
		"""
		Logs weight statistics every `weight_distribution_log_steps` steps.

		Statistics come from `compute_weight_stats` over `state.graphstate`,
		filtered by `weight_distribution_pattern`, and are gathered across
		processes. Keys ending in "/values" are logged as histograms, all other
		keys as floats. Disabled when `weight_distribution_log_steps` is 0.
		"""
		every = self.weight_distribution_log_steps
		if every <= 0 or step % every != 0:
			return
		stats = compute_weight_stats(state.graphstate, self.weight_distribution_pattern)
		stats = jax.experimental.multihost_utils.process_allgather(stats)

		values_suffix = "/values"
		metrics = {}
		for key, value in stats.items():
			if key.endswith(values_suffix):
				path = key[: -len(values_suffix)].replace("/", ".")
				metrics[f"weights-histogram/{path}"] = np.array(value)
			else:
				metrics[f"weights-information/{key.replace('/', '.')}"] = float(value)
		self.log_metrics(metrics, step)

	def _log_to_wandb(
		self,
		metrics,
		step,
		log_as: tp.Optional[tp.Literal["summary", "config"]] = None,
	):
		"""
		Log metrics to Weights & Biases (wandb).

		Does nothing unless `use_wandb` is set and wandb is installed.
		Array-like values are converted to `wandb.Histogram`; values whose
		conversion fails are skipped.

		Args:
		    metrics (dict): Metric names mapped to metric values.
		    step (int): The current step or iteration number.
		    log_as: "summary" / "config" update `wandb.summary` / `wandb.config`
		        instead of calling `wandb.log`.
		"""
		if not self.use_wandb or wandb is None:
			return
		if log_as == "summary":
			wandb.summary.update(metrics)
			return
		if log_as == "config":
			wandb.config.update(metrics)
			return

		wandb_metrics = {}
		for key, value in metrics.items():
			# np.ndarray is included so host-side arrays (e.g. weight histograms)
			# are treated the same way as JAX arrays.
			if isinstance(value, (list, tuple, np.ndarray, np.generic, jax.Array)):
				histogram = self._create_wandb_histogram(value)
				if histogram is None:
					# Conversion failed (already warned about); don't log None.
					continue
				wandb_metrics[key] = histogram
			else:
				wandb_metrics[key] = value
		try:
			wandb.log(wandb_metrics, step=step)
		except Exception as e:
			# Logging stays best-effort, but the failure is made visible.
			warnings.warn(f"Failed to log metrics to wandb: {e}", stacklevel=3)

	def _log_to_tensorboard(
		self,
		metrics,
		step,
		log_as: tp.Optional[tp.Literal["summary", "config"]] = None,
	):
		"""
		Log metrics to TensorBoard.

		Python ints/floats are written as scalars, array-likes as histograms;
		other value types are ignored. A failure on one metric only emits a
		warning. `log_as` is accepted for signature parity and not used.

		Args:
		    metrics (dict): Metric names mapped to metric values.
		    step (int): The current step or iteration number.
		"""
		summary_writer = self.get_tensorboard()
		try:
			for key, value in metrics.items():
				try:
					if isinstance(value, (float, int)):
						summary_writer.scalar(key, value, step)
					elif isinstance(value, (list, tuple, np.ndarray, jnp.ndarray)):
						summary_writer.histogram(key, np.array(value), step)
				except Exception as e:
					warnings.warn(
						f"Failed to log metric {key} to TensorBoard: {e}", stacklevel=1
					)
		finally:
			# One flush per call is enough; no need to flush after every metric.
			summary_writer.flush()

	def _create_wandb_histogram(self, value):
		"""
		Create a wandb.Histogram object from the given value.

		Args:
		    value: A list, tuple, numpy array/scalar or JAX array.

		Returns:
		    wandb.Histogram or None: The histogram, or None (after a warning)
		    if the conversion fails.

		Notes:
		    - JAX arrays are fetched to host first.
		    - Values are cast to float32 so that reduced-precision dtypes
		      (float16 / bfloat16) are histogrammed without overflow.
		"""
		try:
			if isinstance(value, jax.Array):
				value = jax.device_get(value)
			value = np.asarray(value).astype(np.float32)
			return wandb.Histogram(value)
		except Exception as e:
			warnings.warn(f"Failed to create wandb histogram: {e}", stacklevel=2)
			return None

	def to_dict(self) -> tp.Dict[str, tp.Any]:
		"""
		Converts the TrainingArguments object into a dictionary.

		Returns:
		    tp.Dict[str, tp.Any]: Every instance attribute whose name does not
		    start with an underscore (this includes runtime-derived attributes
		    such as `optimizer_kwargs`).
		"""
		return {k: v for k, v in self.__dict__.items() if not k.startswith("_")}

	@classmethod
	def from_dict(cls, config: tp.Dict[str, tp.Any]) -> "TrainingArguments":
		"""
		Creates a TrainingArguments instance from a dictionary.

		Runtime-derived attributes that `to_dict` emits but the constructor
		does not accept are ignored, so `from_dict(args.to_dict())` works.
		Any other unknown key still raises `TypeError` from the constructor.

		Args:
		    config (tp.Dict[str, tp.Any]): The configuration dictionary.

		Returns:
		    TrainingArguments: An object initialized with the dictionary values.
		"""
		init_names = {f.name for f in fields(cls) if f.init}
		# Set by __post_init__ / cached properties, never constructor arguments.
		derived = {"available_backends", "optimizer_kwargs", "is_process_zero"}
		init_kwargs = {
			key: value
			for key, value in config.items()
			if key in init_names or key not in derived
		}
		return cls(**init_kwargs)

	def __repr__(self):
		"""Multi-line representation listing every dataclass field and its value."""
		cls_name = self.__class__.__name__
		field_lines = [
			f" {f.name}: {getattr(self, f.name)!r}".replace("\n", "\n ")
			for f in fields(self)
		]
		return f"{cls_name}(\n" + "\n".join(field_lines) + "\n)"

	__str__ = __repr__

	def _get_save_directory(self, create: bool = True) -> Path:
		"""Returns `save_directory/model_name`, creating it when `create` is True."""
		directory = Path(self.save_directory) / Path(self.model_name)
		if create:
			directory.mkdir(exist_ok=True, parents=True)
		return directory

	def _get_save_directory_milestone(self, step, create: bool = True) -> Path:
		"""Returns the `run-{step}` sub-directory, creating it when `create` is True."""
		save_directory = self._get_save_directory(create=create) / f"run-{step}"
		if create:
			save_directory.mkdir(exist_ok=True, parents=True)
		return save_directory

	__hash__ = hash_fn