# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Metrics tracking and visualization for EasyDeL training.
This module provides comprehensive metrics collection, aggregation, and
visualization tools for training large language models. It includes:
- Real-time metrics calculation (loss, accuracy, throughput, TFLOPs)
- Progress bar implementations (tqdm, rich, JSON)
- Weight distribution analysis and histograms
- Performance profiling and benchmarking utilities
"""
from __future__ import annotations
import abc
import re
import time
import typing as tp
from collections import defaultdict
import jax
import numpy as np
from eformer.pytree import auto_pytree
from rich.progress import Progress, ProgressColumn, Task, TaskID
from rich.text import Text
from tqdm.autonotebook import tqdm
from easydel.infra.loss_utils import LossMetrics
from easydel.utils import traversals
from easydel.utils.compiling_utils import ejit
from easydel.utils.traversals import flatten_dict
try:
import wandb # type:ignore
except ImportError:
wandb = None
from eformer.loggings import get_logger
from jax import numpy as jnp
logger = get_logger("TrainerMetrics")
class StepMetrics:
    """Build the per-step metrics dictionary for training/evaluation loops.

    The object remembers two wall-clock timestamps (creation time and the
    start of the current step) and turns a loss-metrics object plus a few
    run parameters into a flat ``dict`` of loggable values.

    Attributes:
        arguments: Training configuration; ``performance_mode`` and
            ``log_grad_norms`` are read from it.
        start_time: Wall-clock time at which this object was created.
        step_start_time: Wall-clock time of the last ``start_step`` call
            (or of creation if ``start_step`` was never called).
    """

    def __init__(self, arguments):
        """Store the configuration and start both timers.

        Args:
            arguments: Training configuration object.
        """
        self.arguments = arguments
        self.start_time = time.time()
        self.step_start_time = time.time()

    def start_step(self):
        """Record the current time as the start of a step."""
        self.step_start_time = time.time()

    def calculate(
        self,
        metrics: LossMetrics,
        current_step: int,
        epoch: int,
        flops_per_token: float,
        extra_flops_per_token: float,
        batch_size: int,
        seq_length: int,
        learning_rate: float,
        mode: tp.Literal["eval", "train"] | None = None,
        **extras,
    ) -> dict[str, float]:
        """Assemble the metrics dictionary for one step.

        Args:
            metrics: Loss metrics produced by the step.
            current_step: Current step number.
            epoch: Current epoch number.
            flops_per_token: FLOPs per token counted in the base estimate.
            extra_flops_per_token: Additional FLOPs per token added on top
                of ``flops_per_token`` in the total.
            batch_size: Number of samples in the batch.
            seq_length: Sequence length of the inputs.
            learning_rate: Current learning rate (scalar or 0-d array).
            mode: ``"train"`` or ``"eval"``; used as a key prefix. With
                ``None`` the basic metrics are left unprefixed, the
                performance metrics live under ``"mlperf/"`` and the step
                keys are plain ``"step"`` / ``"step_time"``.
            **extras: Additional values merged into the basic metrics.

        Returns:
            dict: Basic metrics (loss, perplexity, learning rate, ...),
            performance metrics (TFLOPs, throughput, timings) and, when not
            in performance mode and ``mode`` is ``"train"`` or ``None``,
            the detailed gradient-norm metrics.
        """
        now = time.time()
        step_time = now - self.step_start_time
        total_time = time.time() - self.start_time

        preprocessing_time = 0
        if metrics.other_metrics is not None:
            preprocessing_time = metrics.other_metrics.get("preprocessing_time", 0)
        execution_time = metrics.execution_time - preprocessing_time

        flops = flops_per_token * seq_length
        total_flops = flops * batch_size
        extra_flops = extra_flops_per_token * seq_length
        total_flops += extra_flops * batch_size

        total_tokens = batch_size * seq_length
        visited_tokens = total_tokens * current_step

        # A zero (or negative) execution time would otherwise raise
        # ZeroDivisionError or report nonsensical rates; report 0 instead.
        if execution_time > 0:
            tflops = (total_flops / execution_time) / 1e12
            throughput = total_tokens / execution_time
        else:
            tflops = 0.0
            throughput = 0.0

        # ``mode`` may legitimately be None (it is the default), so never
        # concatenate it directly into key names.
        perf_key = "mlperf" if mode is None else f"{mode}-mlperf"
        step_prefix = "" if mode is None else f"{mode}_"

        mlperf_metrics = {
            f"{perf_key}/execution_time": float(execution_time),
            f"{perf_key}/flops": float(flops),
            f"{perf_key}/flops_per_token": float(flops_per_token),
            f"{perf_key}/extra_flops": float(extra_flops),
            f"{perf_key}/extra_flops_per_token": float(extra_flops_per_token),
            f"{perf_key}/step_time": float(step_time),
            f"{perf_key}/tflops": float(tflops),
            f"{perf_key}/throughput": throughput,
            f"{perf_key}/total_flops": float(total_flops),
            f"{perf_key}/total_time": float(total_time),
            f"{perf_key}/total_tokens": total_tokens,
        }

        loss = metrics.loss
        z_loss = metrics.z_loss
        basic_metrics = {
            "epoch": int(epoch),
            "execution_time": float(execution_time),
            "learning_rate": float(np.array(learning_rate).item()),
            "loss": float(loss),
            "perplexity": float(jnp.exp(loss)),
            f"{step_prefix}step": int(current_step),
            f"{step_prefix}step_time": float(step_time),
            "tflops": float(tflops),
            "visited_tokens": visited_tokens,
            "z_loss": float(z_loss) if z_loss is not None else None,
            **extras,
        }

        if metrics.accuracy is not None:
            basic_metrics["accuracy"] = float(metrics.accuracy)
        if metrics.chosen_rewards is not None:
            basic_metrics["chosen_rewards"] = float(jnp.mean(metrics.chosen_rewards).item())
        if metrics.rejected_rewards is not None:
            basic_metrics["rejected_rewards"] = float(jnp.mean(metrics.rejected_rewards).item())
        if metrics.other_metrics is not None:
            basic_metrics.update(metrics.other_metrics)

        if not self.arguments.performance_mode and (mode == "train" or mode is None):
            basic_metrics.update(self._calculate_detailed_metrics(metrics))

        if mode is not None:
            basic_metrics = {f"{mode}/{k}": v for k, v in basic_metrics.items()}

        # Performance metrics already carry their own prefix, so they are
        # merged after the mode prefix has been applied.
        basic_metrics.update(mlperf_metrics)
        return basic_metrics

    def _calculate_detailed_metrics(self, metrics: LossMetrics):
        """Collect gradient-norm metrics when ``log_grad_norms`` is enabled.

        Args:
            metrics: Loss metrics that may carry ``max_grad_norm``,
                ``mean_grad_norm`` and a nested ``grad_norms`` mapping.

        Returns:
            dict: Empty unless ``arguments.log_grad_norms`` is truthy;
            otherwise the available max/mean norms plus one
            ``grad_norm/<dotted.path>`` entry per non-None leaf of
            ``grad_norms``.
        """

        def _unwrap(obj):
            # Some leaves are wrappers exposing the array via ``.value``.
            return obj.value if hasattr(obj, "value") else obj

        detailed_metrics = {}
        if not self.arguments.log_grad_norms:
            return detailed_metrics

        if metrics.max_grad_norm is not None:
            detailed_metrics["train/max_grad_norm"] = _unwrap(metrics.max_grad_norm).tolist()
        if metrics.mean_grad_norm is not None:
            detailed_metrics["train/mean_grad_norm"] = _unwrap(metrics.mean_grad_norm).tolist()

        # Per-layer gradient norms, keyed by the dotted parameter path.
        if metrics.grad_norms is not None:
            for layer_name, grad_norm in flatten_dict(metrics.grad_norms).items():
                norm_value = _unwrap(grad_norm)
                if norm_value is None:
                    continue
                dotted = ".".join([str(s) for s in layer_name])
                detailed_metrics[f"grad_norm/{dotted}"] = norm_value.tolist()
        return detailed_metrics
class MetricsTracker:
    """Keep running sums of loss and accuracy and report their means.

    Attributes:
        loss_sum: Cumulative loss since the last reset (``None`` if empty).
        accuracy_sum: Cumulative accuracy since the last reset
            (``None`` if empty).
        metrics_history: ``defaultdict(list)`` reserved for historical
            values; this class itself never writes to it.
        step_offset: Step number subtracted from ``step`` to obtain the
            divisor used for the means.
    """

    def __init__(self):
        """Start with empty sums and a zero step offset."""
        self.loss_sum = None
        self.accuracy_sum = None
        self.metrics_history = defaultdict(list)
        self.step_offset = 0

    def update(self, loss, accuracy, step):
        """Add one step's values and return the running means.

        Args:
            loss: Loss of the current step.
            accuracy: Accuracy of the current step. ``None`` is counted as
                ``0.0``; ``float("inf")`` means "no accuracy available" and
                suppresses the accuracy part of the result.
            step: Current step number.

        Returns:
            tuple | float: ``(mean_loss, mean_accuracy)`` unless
            ``accuracy`` is ``float("inf")``, in which case only
            ``mean_loss`` is returned.
        """
        # The divisor is the number of steps since the last reset. Clamp it
        # to at least 1 so a 0-based first step (or an update at the same
        # step as ``reset``) does not divide by zero.
        steps_elapsed = max(step - self.step_offset, 1)

        self.loss_sum = loss if self.loss_sum is None else self.loss_sum + loss
        mean_loss = self.loss_sum / steps_elapsed

        if accuracy == float("inf"):
            return float(mean_loss)

        if accuracy is None:
            accuracy = 0.0
        self.accuracy_sum = accuracy if self.accuracy_sum is None else self.accuracy_sum + accuracy
        mean_accuracy = self.accuracy_sum / steps_elapsed
        return float(mean_loss), float(mean_accuracy)

    def reset(self, step):
        """Clear both sums and restart averaging from ``step``.

        Args:
            step: Step number to use as the new offset.
        """
        self.loss_sum = None
        self.accuracy_sum = None
        self.step_offset = step
class MetricsColumn(ProgressColumn):
    """Rich progress column that renders a task's ``metrics`` field.

    Attributes:
        metrics_to_show: Optional list of substrings. When given, only
            metrics whose key contains at least one of them are rendered;
            when ``None`` every metric is rendered.
    """

    def __init__(self, metrics_to_show=None):
        """Create the column.

        Args:
            metrics_to_show: Optional list of key substrings used to filter
                which metrics are displayed.
        """
        super().__init__()
        self.metrics_to_show = metrics_to_show

    @staticmethod
    def _format_value(value) -> str:
        """Format one metric value for display.

        Floats use 4-digit scientific notation when their magnitude is
        below 0.01 or above 1000 and 4-digit fixed notation otherwise;
        every other type is rendered with ``str``.
        """
        if not isinstance(value, float):
            return str(value)
        if abs(value) < 0.01 or abs(value) > 1000:
            return f"{value:.4e}"
        return f"{value:.4f}"

    def _is_shown(self, key) -> bool:
        """Return whether ``key`` passes the ``metrics_to_show`` filter."""
        if self.metrics_to_show is None:
            return True
        return any(metric in key for metric in self.metrics_to_show)

    def render(self, task: Task) -> Text:
        """Render the task's metrics as ``key=value`` pairs.

        Args:
            task: Rich task whose ``fields["metrics"]`` mapping is shown.

        Returns:
            Text: Cyan text with the selected metrics joined by `` • ``,
            or empty text when the task has no (or empty) metrics.
        """
        metrics = task.fields.get("metrics")
        if not metrics:
            return Text("")

        display_items = [
            f"{key}={self._format_value(value)}"
            for key, value in metrics.items()
            if self._is_shown(key)
        ]
        return Text(" • ".join(display_items), style="cyan")
class BaseProgressBar(abc.ABC):
    """Common interface that every progress-bar backend must implement.

    Subclasses provide the four operations below; the base class cannot
    be instantiated on its own.
    """

    @abc.abstractmethod
    def update(self, n: int = 1) -> None:
        """Advance the progress bar.

        Args:
            n: How many steps to advance by.
        """

    @abc.abstractmethod
    def set_postfix(self, **kwargs) -> None:
        """Attach metrics to be shown alongside the progress.

        Args:
            **kwargs: Metric names mapped to their current values.
        """

    @abc.abstractmethod
    def reset(self) -> None:
        """Return the progress bar to its initial state."""

    @abc.abstractmethod
    def close(self) -> None:
        """Release the progress bar and any resources it holds."""
class NullProgressBar(BaseProgressBar):
    """Progress bar whose every operation is a no-op.

    Stands in wherever a progress-bar object is required but nothing
    should be displayed.
    """

    def update(self, n: int = 1) -> None:
        """Ignore the advance request."""
        return None

    def set_postfix(self, **kwargs) -> None:
        """Discard the supplied metrics."""
        return None

    def reset(self) -> None:
        """Do nothing; there is no state to reset."""
        return None

    def close(self) -> None:
        """Do nothing; there is nothing to release."""
        return None
class TqdmProgressBar(BaseProgressBar):
    """Adapter exposing a tqdm bar through the ``BaseProgressBar`` interface.

    Attributes:
        pbar: Underlying tqdm progress bar instance.
    """

    def __init__(self, pbar: tqdm):
        """Wrap an existing tqdm progress bar.

        Args:
            pbar: tqdm progress bar instance to drive.
        """
        self.pbar = pbar

    def update(self, n: int = 1) -> None:
        """Advance the bar by ``n`` steps."""
        self.pbar.update(n)

    def set_postfix(self, **kwargs) -> None:
        """Show metrics after the bar, rounding floats to 3 decimals.

        ``learning_rate`` is passed through unrounded so small values are
        not flattened to ``0.0``.

        Args:
            **kwargs: Metric names mapped to their current values.
        """
        shown = {
            name: round(value, 3) if isinstance(value, float) and name != "learning_rate" else value
            for name, value in kwargs.items()
        }
        self.pbar.set_postfix(**shown)

    def reset(self) -> None:
        """Reset the bar to zero iterations and restart its timer.

        Delegates to tqdm's public ``reset()`` instead of poking ``n`` and
        the private ``_time`` attribute directly: the public call also
        resets the bar's internal print bookkeeping and works on bars
        created with ``disable=True``.
        """
        self.pbar.reset()

    def close(self) -> None:
        """Close the underlying tqdm bar."""
        self.pbar.close()
class JSONProgressBar(BaseProgressBar):
    """Progress reporter that logs metrics instead of drawing a bar.

    Only ``set_postfix`` produces output (one log record per call via the
    module logger); ``update``, ``reset`` and ``close`` are no-ops.

    Attributes:
        desc: Description text for the progress.
    """

    def __init__(self, desc=""):
        """Create the reporter.

        Args:
            desc: Description text for the progress.
        """
        self.desc = desc

    def update(self, n: int = 1) -> None: ...

    def set_postfix(self, **kwargs) -> None:
        """Log the given metrics as a single dictionary.

        Values exposing a ``size`` of 1 (array scalars) are converted to
        plain Python scalars with ``.item()``. Floats other than
        ``learning_rate`` are rounded to 3 decimals.

        Args:
            **kwargs: Metric names mapped to their current values.
        """
        for name, value in list(kwargs.items()):
            if hasattr(value, "size") and value.size == 1:
                # Rebind so the rounding check below sees the unwrapped
                # scalar rather than the original array object.
                value = value.item()
            if isinstance(value, float) and name != "learning_rate":
                value = round(value, 3)
            kwargs[name] = value
        logger.info(kwargs)

    def reset(self) -> None: ...

    def close(self) -> None: ...
class RichProgressBar(BaseProgressBar):
    """Adapter that drives one task of a Rich ``Progress`` display.

    Attributes:
        progress: Rich Progress instance that owns the task.
        task_id: Identifier of the task this bar controls.
        _postfix: Metrics accumulated across ``set_postfix`` calls.
    """

    def __init__(self, progress: Progress, task_id: TaskID):
        """Bind the adapter to an existing Progress instance and task.

        Args:
            progress: Rich Progress instance managing the display.
            task_id: ID of the task to control within ``progress``.
        """
        self._postfix = {}
        self.task_id = task_id
        self.progress = progress

    def update(self, n: int = 1) -> None:
        """Advance the task by ``n`` steps."""
        progress, task_id = self.progress, self.task_id
        progress.update(task_id, advance=n)

    def set_postfix(self, **kwargs) -> None:
        """Merge new metrics into the stored ones and push them to the task.

        Args:
            **kwargs: Metric names mapped to their current values.
        """
        accumulated = self._postfix
        accumulated.update(kwargs)
        self.progress.update(self.task_id, metrics=accumulated)

    def reset(self) -> None:
        """Reset the task and forget all stored metrics."""
        self.progress.reset(self.task_id)
        self._postfix = {}

    def close(self) -> None:
        """Remove the task from the display; a missing task is ignored."""
        try:
            self.progress.remove_task(self.task_id)
        except KeyError:
            # The task was already removed; nothing left to clean up.
            pass
@auto_pytree
class MetricsHistogram:
    """Histogram plus summary statistics of a flattened array.

    Stores 64-bin histogram counts and edges together with the element
    count, minimum, maximum, sum and sum of squares, from which mean,
    variance and standard deviation are derived.
    """

    # Per-bin element counts (one entry per histogram bin).
    bin_counts: jax.Array
    # Bin boundaries; one more entry than ``bin_counts``.
    bin_edges: jax.Array
    # Number of elements in the flattened source array.
    size: int
    # Smallest element of the source array.
    min: jax.Array
    # Largest element of the source array.
    max: jax.Array
    # Sum of all elements.
    sum: jax.Array
    # Sum of the squares of all elements.
    sum_squares: jax.Array

    @staticmethod
    @ejit
    def _create_histogram_bin_edges(arr: jax.Array) -> tuple[jax.Array, jax.Array]:
        """Compute 64-bin histogram counts and edges for a flat array.

        Args:
            arr: Flat input array.

        Returns:
            Tuple of (bin_counts, bin_edges); counts use ``arr``'s dtype.
        """
        bin_edges = jnp.histogram_bin_edges(arr, 64)
        left_edges = bin_edges[:-1, None]
        right_edges = bin_edges[1:, None]
        # Row i marks the elements falling in the half-open bin
        # [left_i, right_i); summing each row gives that bin's count.
        index = ((arr >= left_edges) & (arr < right_edges)).astype(arr.dtype)
        out = index.sum(axis=1, dtype=arr.dtype)
        # The half-open test above excludes elements equal to the maximum,
        # so they are added to the last bin explicitly.
        out = jax.lax.cond(
            out.size >= 1,
            lambda o: o.at[-1].add(jnp.sum(arr == arr.max())),
            lambda o: o,
            out,
        )
        return out, bin_edges

    @classmethod
    def from_array(cls, arr: jax.Array) -> MetricsHistogram:
        """Build a histogram and its statistics from an array.

        Args:
            arr: Input array of any shape; it is flattened first.

        Returns:
            MetricsHistogram instance describing ``arr``.
        """
        flat_arr = arr.reshape(-1)
        bin_counts, bin_edges = cls._create_histogram_bin_edges(flat_arr)
        return cls(
            bin_counts=bin_counts,
            bin_edges=bin_edges,
            size=flat_arr.size,
            min=jnp.min(flat_arr),
            max=jnp.max(flat_arr),
            sum=jnp.sum(flat_arr),
            sum_squares=jnp.sum(flat_arr**2),
        )

    def numpy_histogram(self) -> tuple[jax.Array, jax.Array]:
        """Return the histogram as a ``(bin_counts, bin_edges)`` pair."""
        return self.bin_counts, self.bin_edges

    @property
    def mean(self) -> jax.Array:
        """Mean of the source array (``sum / size``)."""
        return self.sum / self.size

    @property
    def variance(self) -> jax.Array:
        """Population variance of the source array.

        Computed as ``E[x^2] - E[x]^2``. Floating-point cancellation can
        make that difference slightly negative when the spread is tiny
        relative to the mean, so the result is clamped at zero to keep
        ``std`` from becoming NaN.
        """
        mean_of_squares = self.sum_squares / self.size
        return jnp.maximum(mean_of_squares - (self.mean**2), 0)

    @property
    def std(self) -> jax.Array:
        """Standard deviation of the source array, as a 1-D array."""
        return jnp.sqrt(self.variance).reshape(-1)
@ejit(static_argnums=(1,))
def compute_weight_stats(params: dict[str, tp.Any], repattern: str) -> dict[str, MetricsHistogram]:
    """Build histograms for every parameter whose path matches a pattern.

    The parameter tree is flattened; each leaf's path is joined with ``.``
    and tested with ``re.match`` (anchored at the start of the path).
    Matching leaves are summarised with ``MetricsHistogram.from_array``.

    Args:
        params: Model parameters as a nested dictionary / PyTree.
        repattern: Regular expression applied to the dotted parameter
            path. Use ``'.*'`` to match every parameter.

    Returns:
        dict: Maps ``'path/to/param/histogram'`` (path joined with ``/``)
        to the ``MetricsHistogram`` of that parameter.

    Example:
        >>> stats = compute_weight_stats(model.params, r'.*dense.*')
        >>> # Gets statistics for all dense layer weights
    """
    matcher = re.compile(repattern)
    collected = {}
    for key_path, leaf in traversals.flatten_dict(params).items():
        segments = [str(part) for part in key_path]
        if matcher.match(".".join(segments)) is None:
            continue
        # Leaves may be wrappers exposing the array through ``.value``.
        array = leaf.value if hasattr(leaf, "value") else leaf
        output_path = "/".join(segments)
        collected[f"{output_path}/histogram"] = MetricsHistogram.from_array(array)
    return collected