easydel.trainers.metrics
Metrics tracking and visualization for EasyDeL training.
This module provides comprehensive metrics collection, aggregation, and visualization tools for training large language models. It includes:
- Real-time metrics calculation (loss, accuracy, throughput, TFLOPs)
- Progress bar implementations (tqdm, rich, JSON)
- Weight distribution analysis and histograms
- Performance profiling and benchmarking utilities
- class easydel.trainers.metrics.BaseProgressBar
Bases: ABC
Abstract base class for progress bar implementations.
Defines the interface for different progress bar backends (tqdm, rich, JSON logging).
- class easydel.trainers.metrics.JSONProgressBar(desc='')
Bases: BaseProgressBar
JSON-based progress reporting.
Outputs progress as JSON logs instead of a visual progress bar. Useful for structured logging and CI/CD environments.
- desc
Description text for the progress.
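The idea of emitting progress as JSON lines can be sketched as follows. This is a hypothetical stand-in, not EasyDeL's `JSONProgressBar`: the class name `JsonProgressSketch`, the `update` and `set_postfix` methods, the `stream` parameter, and the record fields are all assumptions made for illustration.

```python
import io
import json
import sys
import time


class JsonProgressSketch:
    """Hypothetical stand-in; not EasyDeL's JSONProgressBar."""

    def __init__(self, desc="", stream=None):
        self.desc = desc
        self.stream = stream if stream is not None else sys.stdout
        self.step = 0
        self.postfix = {}

    def set_postfix(self, **metrics):
        # Metrics attached here appear in every subsequent record.
        self.postfix = dict(metrics)

    def update(self, n=1):
        # Advance the counter and emit one JSON object per line.
        self.step += n
        record = {
            "desc": self.desc,
            "step": self.step,
            "time": time.time(),
            **self.postfix,
        }
        self.stream.write(json.dumps(record) + "\n")


# Write to an in-memory buffer so the output can be inspected.
buffer = io.StringIO()
bar = JsonProgressSketch(desc="train", stream=buffer)
bar.set_postfix(loss=2.5)
bar.update(1)
first_record = json.loads(buffer.getvalue().splitlines()[0])
```

One JSON object per line keeps the output easy to parse with standard log tooling, which is the main appeal in CI/CD settings.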
- class easydel.trainers.metrics.MetricsColumn(metrics_to_show=None)
Bases: ProgressColumn
A custom Rich progress column for displaying metrics.
Formats and displays training metrics in a readable format within Rich progress bars.
- metrics_to_show
Optional list of metric names to display. If None, shows all metrics.
- class easydel.trainers.metrics.MetricsHistogram(bin_counts: Array, bin_edges: Array, size: int, min: Array, max: Array, sum: Array, sum_squares: Array)
Bases: object
Compute and store histogram data for model weights or activations.
This class provides a PyTree-compatible way to compute histograms and statistics for JAX arrays, optimized for use within JIT-compiled functions.
- classmethod from_array(arr: Array) -> MetricsHistogram
Create a histogram from an array.
- Parameters
arr – Input array
- Returns
MetricsHistogram instance
- classmethod from_dict(data: dict[str, Any]) -> T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) -> T
Deserializes a JSON string into a PyTree object.
- numpy_histogram() -> tuple[jax.Array, jax.Array]
Return histogram data in numpy-compatible format.
- Returns
Tuple of (bin_counts, bin_edges)
- replace(**kwargs)
Creates a new instance with specified fields replaced.
- size: int
- property std: Array
Calculate standard deviation of the original array.
- Returns
Standard deviation value
- to_dict() -> dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) -> str
Serializes the PyTree object to a JSON string.
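Because the class stores `sum`, `sum_squares`, and `size`, a standard deviation can be recovered without keeping the original array. The sketch below shows that arithmetic in plain Python. The helper name `std_from_sums` is hypothetical, and whether EasyDeL uses this population formula (rather than a sample estimate) is an assumption.

```python
import math


def std_from_sums(total, total_squares, size):
    """Population standard deviation from running sums (illustrative helper)."""
    mean = total / size
    # Var[x] = E[x^2] - (E[x])^2; clamp at zero to absorb rounding error.
    variance = max(total_squares / size - mean * mean, 0.0)
    return math.sqrt(variance)


values = [1.0, 2.0, 3.0, 4.0]
result = std_from_sums(sum(values), sum(v * v for v in values), len(values))
```

Storing only these scalar sums keeps the structure small and fixed in shape, which suits use inside JIT-compiled functions.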
- class easydel.trainers.metrics.MetricsTracker
Bases: object
Tracks and aggregates training metrics over time.
Maintains running averages of loss and accuracy across training steps, useful for monitoring training progress and convergence.
- loss_sum
Cumulative loss sum.
- accuracy_sum
Cumulative accuracy sum.
- metrics_history
Historical metrics for analysis.
- step_offset
Step number offset for averaging.
- reset(step)
Reset tracked metrics.
- Parameters
step – New step offset for averaging.
Note
Typically called at the start of each epoch or evaluation phase.
- update(loss, accuracy, step)
Update tracked metrics with new values.
- Parameters
loss – Current step loss.
accuracy – Current step accuracy (can be None or inf).
step – Current step number.
- Returns
(mean_loss, mean_accuracy) if accuracy is valid, otherwise just mean_loss.
- Return type
tuple | float
Note
Handles missing accuracy values gracefully.
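The running-average behaviour described for `reset` and `update` can be sketched as below. This is a hypothetical re-creation, not EasyDeL's code: the class name `RunningMetricsSketch` is invented, and it assumes 0-based step numbers so that the averaging window holds `step + 1 - step_offset` entries.

```python
import math


class RunningMetricsSketch:
    """Hypothetical re-creation of the documented behaviour."""

    def __init__(self):
        self.loss_sum = 0.0
        self.accuracy_sum = 0.0
        self.step_offset = 0

    def reset(self, step):
        # Start a new averaging window at `step`.
        self.loss_sum = 0.0
        self.accuracy_sum = 0.0
        self.step_offset = step

    def update(self, loss, accuracy, step):
        # Assumes 0-based steps: the window holds step + 1 - offset entries.
        count = step + 1 - self.step_offset
        self.loss_sum += loss
        mean_loss = self.loss_sum / count
        if accuracy is None or math.isinf(accuracy):
            # Missing accuracy: report the mean loss only.
            return mean_loss
        self.accuracy_sum += accuracy
        return mean_loss, self.accuracy_sum / count


tracker = RunningMetricsSketch()
first = tracker.update(loss=2.0, accuracy=0.5, step=0)
second = tracker.update(loss=1.0, accuracy=0.7, step=1)
tracker.reset(step=2)
loss_only = tracker.update(loss=0.8, accuracy=None, step=2)
```

Since the return type differs depending on whether accuracy is valid, callers need to check for a tuple before unpacking.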
- class easydel.trainers.metrics.NullProgressBar
Bases: BaseProgressBar
Dummy progress bar that does nothing.
Useful for multiprocessing scenarios where only the main process should display progress, or when progress display is disabled.
- class easydel.trainers.metrics.RichProgressBar(progress: Progress, task_id: TaskID)
Bases: BaseProgressBar
Wrapper for Rich library progress bar.
Provides beautiful, customizable progress bars with support for multiple columns and custom rendering.
- progress
Rich Progress instance.
- task_id
ID of the task being tracked.
- _postfix
Current postfix metrics.
- class easydel.trainers.metrics.StepMetrics(arguments)
Bases: object
Handles calculation and tracking of training metrics.
This class computes various metrics for each training/evaluation step, including loss, accuracy, performance metrics (TFLOPs, throughput), and gradient statistics.
- arguments
Training configuration arguments.
- start_time
Global start time for training.
- step_start_time
Start time for current step.
Note
Designed to work efficiently within JAX training loops. Automatically handles metric aggregation across devices.
- calculate(metrics: LossMetrics, current_step: int, epoch: int, flops_per_token: float, extra_flops_per_token: float, batch_size: int, seq_length: int, learning_rate: float, mode: Optional[Literal['eval', 'train']] = None, **extras) -> dict[str, float]
Calculate comprehensive metrics for the training step.
Computes performance metrics, loss statistics, and optional detailed metrics like gradient norms.
- Parameters
metrics – Loss metrics from the training step.
current_step – Current training/evaluation step number.
epoch – Current epoch number.
flops_per_token – FLOPs required per token for forward pass.
extra_flops_per_token – Additional FLOPs for backward pass.
batch_size – Number of samples in the batch.
seq_length – Sequence length of inputs.
learning_rate – Current learning rate value.
mode – ‘train’ or ‘eval’ to prefix metric names.
**extras – Additional metrics to include.
- Returns
Comprehensive metrics including:
- Basic metrics (loss, perplexity, accuracy)
- Performance metrics (TFLOPs, throughput)
- MLPerf benchmark metrics
- Optional gradient norms and detailed statistics
- Return type
dict
Note
In performance mode, detailed metrics are skipped for efficiency.
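How the `calculate` parameters could combine into performance figures is sketched below. These are common textbook formulas, not a transcription of EasyDeL's implementation: the function name `step_performance`, the `step_seconds` argument, and the output key names are hypothetical.

```python
import math


def step_performance(loss, batch_size, seq_length, flops_per_token,
                     extra_flops_per_token, step_seconds):
    """Illustrative formulas only; metric names here are hypothetical."""
    tokens = batch_size * seq_length
    total_flops = (flops_per_token + extra_flops_per_token) * tokens
    return {
        # Perplexity is the exponential of the mean cross-entropy loss.
        "perplexity": math.exp(loss),
        "tokens_per_second": tokens / step_seconds,
        "tflops": total_flops / step_seconds / 1e12,
    }


perf = step_performance(
    loss=0.0, batch_size=8, seq_length=1024,
    flops_per_token=2e9, extra_flops_per_token=4e9, step_seconds=2.0,
)
```

With these inputs the step processes 8 × 1024 = 8192 tokens, giving 4096 tokens per second and about 24.6 TFLOPs.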
- class easydel.trainers.metrics.TqdmProgressBar(pbar: tqdm)
Bases: BaseProgressBar
Wrapper for tqdm progress bar.
Adapts tqdm progress bars to the BaseProgressBar interface.
- pbar
Underlying tqdm progress bar instance.
- easydel.trainers.metrics.compute_weight_stats(params: dict[str, Any], repattern: str) -> dict[str, easydel.trainers.metrics.MetricsHistogram]
Compute statistics for model weights in a JIT-compatible way.
Analyzes model parameters matching the given pattern and computes histograms and statistical measures for monitoring training stability.
- Parameters
params – Model parameters as nested dictionary or PyTree.
repattern – Regular expression pattern to match parameter paths. Use ‘.*’ to match all parameters.
- Returns
Weight statistics with keys formatted as ‘path/to/param/histogram’, containing MetricsHistogram objects.
- Return type
dict
Note
JIT-compiled with static pattern argument for efficiency. Useful for detecting gradient explosion, vanishing gradients, or monitoring weight distributions during training.
Example
>>> stats = compute_weight_stats(model.params, r'.*dense.*')
>>> # Gets statistics for all dense layer weights
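The path-matching step can be illustrated without JAX. The sketch below flattens a nested dictionary into slash-joined paths and keeps the leaves whose path matches the pattern. The helpers `flatten_paths` and `select_params` are hypothetical, and the use of `re.search` is an assumption; the actual function may match paths differently.

```python
import re


def flatten_paths(tree, prefix=""):
    """Yield ('path/to/leaf', value) pairs from a nested dict."""
    for key, value in tree.items():
        path = f"{prefix}/{key}" if prefix else key
        if isinstance(value, dict):
            yield from flatten_paths(value, path)
        else:
            yield path, value


def select_params(params, repattern):
    # Keep only leaves whose slash-joined path matches the pattern.
    pattern = re.compile(repattern)
    return {
        path: leaf
        for path, leaf in flatten_paths(params)
        if pattern.search(path)
    }


# Toy parameter tree; plain lists stand in for weight arrays.
params = {
    "layer_0": {
        "dense": {"kernel": [0.1, -0.2], "bias": [0.0]},
        "norm": {"scale": [1.0]},
    }
}
selected = select_params(params, r".*dense.*")
everything = select_params(params, r".*")
```

Here `selected` holds only the two `dense` leaves, while the `'.*'` pattern keeps all three.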