easydel.trainers.metrics


Metrics tracking and visualization for EasyDeL training.

This module provides comprehensive metrics collection, aggregation, and visualization tools for training large language models. It includes:
  • Real-time metrics calculation (loss, accuracy, throughput, TFLOPs)

  • Progress bar implementations (tqdm, rich, JSON)

  • Weight distribution analysis and histograms

  • Performance profiling and benchmarking utilities

class easydel.trainers.metrics.BaseProgressBar[source]#

Bases: ABC

Abstract base class for progress bar implementations.

Defines the interface for different progress bar backends (tqdm, rich, JSON logging).

abstract close() None[source]#

Close and cleanup the progress bar.

abstract reset() None[source]#

Reset the progress bar to initial state.

abstract set_postfix(**kwargs) None[source]#

Set postfix metrics to display.

Parameters

**kwargs – Metric key-value pairs to display.

abstract update(n: int = 1) None[source]#

Update the progress bar.

Parameters

n – Number of steps to advance.
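The four abstract methods above are the whole contract a backend has to satisfy. The sketch below shows one way a custom backend could implement them. So that it runs without EasyDeL installed, it defines a local stand-in ABC (`ProgressBarInterface`) that mirrors the documented method signatures; both that stand-in and `ListProgressBar` are hypothetical names, not part of the library.

```python
from abc import ABC, abstractmethod


class ProgressBarInterface(ABC):
    """Local stand-in mirroring the documented BaseProgressBar methods."""

    @abstractmethod
    def update(self, n: int = 1) -> None: ...

    @abstractmethod
    def set_postfix(self, **kwargs) -> None: ...

    @abstractmethod
    def reset(self) -> None: ...

    @abstractmethod
    def close(self) -> None: ...


class ListProgressBar(ProgressBarInterface):
    """Hypothetical backend that records progress in memory."""

    def __init__(self) -> None:
        self.steps = 0
        self.postfix: dict = {}
        self.closed = False

    def update(self, n: int = 1) -> None:
        # Advance by n steps.
        self.steps += n

    def set_postfix(self, **kwargs) -> None:
        # Merge the latest metric values into the displayed set.
        self.postfix.update(kwargs)

    def reset(self) -> None:
        # Return to the initial state.
        self.steps = 0
        self.postfix.clear()

    def close(self) -> None:
        self.closed = True


bar = ListProgressBar()
bar.update()
bar.update(4)
bar.set_postfix(loss=0.5)
```

With the real library, the subclass would presumably derive from `BaseProgressBar` instead of the stand-in.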

class easydel.trainers.metrics.JSONProgressBar(desc='')[source]#

Bases: BaseProgressBar

JSON-based progress reporting.

Outputs progress as JSON logs instead of a visual progress bar. Useful for structured logging and CI/CD environments.

desc#

Description text for the progress.

close() None[source]#

Close and cleanup the progress bar.

reset() None[source]#

Reset the progress bar to initial state.

set_postfix(**kwargs) None[source]#

Set postfix metrics to display.

Parameters

**kwargs – Metric key-value pairs to display.

update(n: int = 1) None[source]#

Update the progress bar.

Parameters

n – Number of steps to advance.
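To illustrate what JSON-based progress reporting can look like, here is a minimal sketch that emits one JSON object per update. The helper `progress_record` and the field names `desc` and `step` are assumptions made for illustration; the actual log schema of `JSONProgressBar` is not documented here.

```python
import json


def progress_record(desc: str, step: int, **metrics) -> str:
    """Serialize one progress update as a single JSON line (hypothetical schema)."""
    return json.dumps({"desc": desc, "step": step, **metrics}, sort_keys=True)


line = progress_record("train", 10, loss=0.25)
print(line)
```

One object per line keeps the output easy to parse with standard log tooling in CI/CD pipelines.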

class easydel.trainers.metrics.MetricsColumn(metrics_to_show=None)[source]#

Bases: ProgressColumn

A custom Rich progress column for displaying metrics.

Formats and displays training metrics in a readable format within Rich progress bars.

metrics_to_show#

Optional list of metric names to display. If None, shows all metrics.

render(task: Task) Text[source]#

Render the metrics in an organized way.

Parameters

task – Rich Task object containing metrics to display.

Returns

Formatted metrics text with styling.

Return type

Text

Note

Automatically formats floats with scientific notation for very small or large values.
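The note above can be sketched as a small formatting helper. The function name `format_metric`, the thresholds (`1e-3` and `1e4`), and the number of digits are assumptions chosen for illustration, not the values used by `MetricsColumn`.

```python
def format_metric(value: float, small: float = 1e-3, large: float = 1e4) -> str:
    """Format a float, switching to scientific notation for extreme magnitudes."""
    magnitude = abs(value)
    # Zero is printed in fixed-point form; only non-zero extremes switch notation.
    if magnitude != 0 and (magnitude < small or magnitude >= large):
        return f"{value:.2e}"
    return f"{value:.4f}"


print(format_metric(0.5))       # 0.5000
print(format_metric(0.00001))   # 1.00e-05
print(format_metric(123456.0))  # 1.23e+05
```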

class easydel.trainers.metrics.MetricsHistogram(bin_counts: Array, bin_edges: Array, size: int, min: Array, max: Array, sum: Array, sum_squares: Array)[source]#

Bases: object

Compute and store histogram data for model weights or activations.

This class provides a PyTree-compatible way to compute histograms and statistics for JAX arrays, optimized for use within JIT-compiled functions.

bin_counts: Array#
bin_edges: Array#
classmethod from_array(arr: Array) MetricsHistogram[source]#

Create a histogram from an array.

Parameters

arr – Input array

Returns

MetricsHistogram instance

classmethod from_dict(data: dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

max: Array#
property mean: Array#

Calculate mean of the original array.

Returns

Mean value

min: Array#
numpy_histogram() tuple[jax.Array, jax.Array][source]#

Return histogram data in numpy-compatible format.

Returns

Tuple of (bin_counts, bin_edges)

replace(**kwargs)#

Creates a new instance with specified fields replaced.

size: int#
property std: Array#

Calculate standard deviation of the original array.

Returns

Standard deviation value

sum: Array#
sum_squares: Array#
to_dict() dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.

property variance: Array#

Calculate variance of the original array.

Returns

Variance value
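The stored fields `sum`, `sum_squares`, and `size` are enough to recover the mean, variance, and standard deviation without keeping the original array. The sketch below uses the standard moment identities in plain Python; it assumes the population variance and a hypothetical helper name `moments`, and the class itself may differ in detail (for example, by operating on JAX arrays).

```python
import math


def moments(total: float, total_squares: float, size: int) -> tuple[float, float, float]:
    """Return (mean, variance, std) from a running sum and sum of squares."""
    mean = total / size
    # E[x^2] - E[x]^2; clamp at zero to absorb floating-point round-off.
    variance = max(total_squares / size - mean * mean, 0.0)
    return mean, variance, math.sqrt(variance)


values = [1.0, 2.0, 3.0, 4.0]
mean, variance, std = moments(sum(values), sum(v * v for v in values), len(values))
print(mean, variance)  # 2.5 1.25
```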

class easydel.trainers.metrics.MetricsTracker[source]#

Bases: object

Tracks and aggregates training metrics over time.

Maintains running averages of loss and accuracy across training steps, useful for monitoring training progress and convergence.

loss_sum#

Cumulative loss sum.

accuracy_sum#

Cumulative accuracy sum.

metrics_history#

Historical metrics for analysis.

step_offset#

Step number offset for averaging.

reset(step)[source]#

Reset tracked metrics.

Parameters

step – New step offset for averaging.

Note

Typically called at the start of each epoch or evaluation phase.

update(loss, accuracy, step)[source]#

Update tracked metrics with new values.

Parameters
  • loss – Current step loss.

  • accuracy – Current step accuracy (can be None or inf).

  • step – Current step number.

Returns

(mean_loss, mean_accuracy) if accuracy is valid, otherwise just mean_loss.

Return type

tuple | float

Note

Handles missing accuracy values gracefully.
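The running-average behaviour described above can be sketched as follows. `RunningMeans` is a hypothetical stand-in, not the library class, and it assumes the average is taken over `step - step_offset` steps with step numbers that start one past the offset.

```python
import math


class RunningMeans:
    """Hypothetical stand-in for the averaging described above."""

    def __init__(self) -> None:
        self.loss_sum = 0.0
        self.accuracy_sum = 0.0
        self.step_offset = 0

    def reset(self, step: int) -> None:
        # Start a fresh averaging window at the given step.
        self.loss_sum = 0.0
        self.accuracy_sum = 0.0
        self.step_offset = step

    def update(self, loss, accuracy, step):
        # Assumption: steps seen since the last reset (step must exceed the offset).
        count = step - self.step_offset
        self.loss_sum += loss
        mean_loss = self.loss_sum / count
        # A missing or infinite accuracy yields only the mean loss.
        if accuracy is None or math.isinf(accuracy):
            return mean_loss
        self.accuracy_sum += accuracy
        return mean_loss, self.accuracy_sum / count


tracker = RunningMeans()
r1 = tracker.update(2.0, 0.5, step=1)
r2 = tracker.update(4.0, 1.0, step=2)
tracker.reset(2)
r3 = tracker.update(1.0, None, step=3)
```

Because the return type changes with the validity of `accuracy`, callers need to handle both a tuple and a bare float.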

class easydel.trainers.metrics.NullProgressBar[source]#

Bases: BaseProgressBar

Dummy progress bar that does nothing.

Useful for multiprocessing scenarios where only the main process should display progress, or when progress display is disabled.

close() None[source]#

Close and cleanup the progress bar.

reset() None[source]#

Reset the progress bar to initial state.

set_postfix(**kwargs) None[source]#

Set postfix metrics to display.

Parameters

**kwargs – Metric key-value pairs to display.

update(n: int = 1) None[source]#

Update the progress bar.

Parameters

n – Number of steps to advance.

class easydel.trainers.metrics.RichProgressBar(progress: Progress, task_id: TaskID)[source]#

Bases: BaseProgressBar

Wrapper for Rich library progress bar.

Provides beautiful, customizable progress bars with support for multiple columns and custom rendering.

progress#

Rich Progress instance.

task_id#

ID of the task being tracked.

_postfix#

Current postfix metrics.

close() None[source]#

Close and cleanup the progress bar.

reset() None[source]#

Reset the progress bar to initial state.

set_postfix(**kwargs) None[source]#

Set postfix metrics to display.

Parameters

**kwargs – Metric key-value pairs to display.

update(n: int = 1) None[source]#

Update the progress bar.

Parameters

n – Number of steps to advance.

class easydel.trainers.metrics.StepMetrics(arguments)[source]#

Bases: object

Handles calculation and tracking of training metrics.

This class computes various metrics for each training/evaluation step, including loss, accuracy, performance metrics (TFLOPs, throughput), and gradient statistics.

arguments#

Training configuration arguments.

start_time#

Global start time for training.

step_start_time#

Start time for current step.

Note

Designed to work efficiently within JAX training loops. Automatically handles metric aggregation across devices.

calculate(metrics: LossMetrics, current_step: int, epoch: int, flops_per_token: float, extra_flops_per_token: float, batch_size: int, seq_length: int, learning_rate: float, mode: Optional[Literal['eval', 'train']] = None, **extras) dict[str, float][source]#

Calculate comprehensive metrics for the training step.

Computes performance metrics, loss statistics, and optional detailed metrics like gradient norms.

Parameters
  • metrics – Loss metrics from the training step.

  • current_step – Current training/evaluation step number.

  • epoch – Current epoch number.

  • flops_per_token – FLOPs required per token for forward pass.

  • extra_flops_per_token – Additional FLOPs for backward pass.

  • batch_size – Number of samples in the batch.

  • seq_length – Sequence length of inputs.

  • learning_rate – Current learning rate value.

  • mode – ‘train’ or ‘eval’; used as a prefix for metric names.

  • **extras – Additional metrics to include.

Returns

Comprehensive metrics including:
  • Basic metrics (loss, perplexity, accuracy)

  • Performance metrics (TFLOPs, throughput)

  • MLPerf benchmark metrics

  • Optional gradient norms and detailed statistics

Return type

dict

Note

In performance mode, detailed metrics are skipped for efficiency.
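To make the relationship between the arguments and the performance metrics concrete, here is a minimal sketch of how throughput, TFLOPs, and perplexity could be derived from them. The function `performance_metrics`, the metric key names, the `mode/name` prefix format, and the explicit `step_time` argument are all assumptions for illustration; the actual keys and formulas used by `calculate` may differ.

```python
import math


def performance_metrics(
    loss: float,
    step_time: float,
    batch_size: int,
    seq_length: int,
    flops_per_token: float,
    extra_flops_per_token: float,
    mode=None,
) -> dict:
    """Derive per-step performance numbers from the step configuration."""
    tokens = batch_size * seq_length
    total_flops = tokens * (flops_per_token + extra_flops_per_token)
    metrics = {
        "loss": loss,
        "perplexity": math.exp(loss),
        "tokens_per_second": tokens / step_time,
        "tflops": total_flops / step_time / 1e12,
    }
    if mode is None:
        return metrics
    # Hypothetical prefix format: "train/loss", "eval/loss", ...
    return {f"{mode}/{name}": value for name, value in metrics.items()}


out = performance_metrics(
    loss=0.0,
    step_time=2.0,
    batch_size=8,
    seq_length=1024,
    flops_per_token=6e9,
    extra_flops_per_token=0.0,
    mode="train",
)
```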

start_step()[source]#

Mark the start of a training step.

Records the current time for step duration calculation. Should be called at the beginning of each training/evaluation step.

class easydel.trainers.metrics.TqdmProgressBar(pbar: tqdm)[source]#

Bases: BaseProgressBar

Wrapper for tqdm progress bar.

Adapts tqdm progress bars to the BaseProgressBar interface.

pbar#

Underlying tqdm progress bar instance.

close() None[source]#

Close and cleanup the progress bar.

reset() None[source]#

Reset the progress bar to initial state.

set_postfix(**kwargs) None[source]#

Set postfix metrics to display.

Parameters

**kwargs – Metric key-value pairs to display.

update(n: int = 1) None[source]#

Update the progress bar.

Parameters

n – Number of steps to advance.

easydel.trainers.metrics.compute_weight_stats(params: dict[str, Any], repattern: str) dict[str, easydel.trainers.metrics.MetricsHistogram][source]#

Compute statistics for model weights in a JIT-compatible way.

Analyzes model parameters matching the given pattern and computes histograms and statistical measures for monitoring training stability.

Parameters
  • params – Model parameters as nested dictionary or PyTree.

  • repattern – Regular expression pattern to match parameter paths. Use ‘.*’ to match all parameters.

Returns

Weight statistics with keys formatted as ‘path/to/param/histogram’ containing MetricsHistogram objects.

Return type

dict

Note

JIT-compiled with the pattern as a static argument for efficiency. Useful for detecting gradient explosion or vanishing gradients, and for monitoring weight distributions during training.

Example

>>> stats = compute_weight_stats(model.params, r'.*dense.*')
>>> # Gets statistics for all dense layer weights
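The path matching and key format described above can be illustrated with a plain-Python sketch over a nested dictionary. The helpers `iter_paths` and `select_params` are hypothetical; the sketch assumes paths are joined with `/` and matched with `re.search`, and it keeps each leaf as-is where the real function would store a `MetricsHistogram`.

```python
import re
from typing import Any, Iterator


def iter_paths(tree: dict[str, Any], prefix: str = "") -> Iterator[tuple[str, Any]]:
    """Yield (path, leaf) pairs from a nested dict, joining keys with '/'."""
    for key, value in tree.items():
        path = f"{prefix}/{key}" if prefix else key
        if isinstance(value, dict):
            yield from iter_paths(value, path)
        else:
            yield path, value


def select_params(params: dict[str, Any], repattern: str) -> dict[str, Any]:
    """Keep leaves whose path matches the pattern, keyed as 'path/histogram'."""
    pattern = re.compile(repattern)
    return {
        f"{path}/histogram": leaf
        for path, leaf in iter_paths(params)
        if pattern.search(path)
    }


params = {
    "dense_1": {"kernel": [1.0, 2.0], "bias": [0.0]},
    "norm": {"scale": [1.0]},
}
selected = select_params(params, r".*dense.*")
print(sorted(selected))
```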