# Source code for easydel.utils.helpers

# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helper utilities for EasyDeL framework.

Provides logging, timing, caching, and general utility functions used
throughout the EasyDeL framework.

Classes:
    ColorFormatter: Colored console logging formatter
    LazyLogger: Deferred initialization logger
    Timer: Simple timing utility
    Timers: Multiple timer management with logging
    DummyStream: Null output stream for suppression

Functions:
    get_logger: Create a lazy logger instance
    set_loggers_level: Set logging level globally
    capture_time: Context manager for timing
    get_cache_dir: Get EasyDeL cache directory
    quiet: Context manager to suppress output
    check_bool_flag: Parse boolean environment variables

Constants:
    COLORS: Terminal color codes
    LEVEL_COLORS: Log level to color mapping
    _LOGGING_LEVELS: String to log level mapping

Example:
    >>> from easydel.utils.helpers import Timer
    >>>
    >>> with Timer("computation") as timer:
    ...     result = expensive_computation()
    >>> print(f"Took {timer.elapsed_time()} seconds")

"""

from __future__ import annotations

import contextlib
import os
import sys
import time
import typing as tp
import warnings
from contextlib import contextmanager
from pathlib import Path

from eformer.loggings import get_logger

# Silence UserWarnings whose message mentions a 'repr' or a 'frozen' attribute.
# NOTE(review): the library emitting these warnings is not identifiable from
# this file — confirm the filters are still needed.
warnings.filterwarnings("ignore", message=".*'repr' attribute.*", category=UserWarning)
warnings.filterwarnings("ignore", message=".*'frozen' attribute.*", category=UserWarning)

# Only needed for type annotations, so this module does not import flax at runtime.
if tp.TYPE_CHECKING:
    from flax.metrics.tensorboard import SummaryWriter
# wandb is an optional dependency: when it is not installed the name is bound to
# None, and code that logs to it checks `wandb is None` first.
try:
    import wandb  # type: ignore
except ModuleNotFoundError:
    wandb = None

# Module-level logger shared by the helpers defined in this file.
logger = get_logger(__name__)


@contextlib.contextmanager
def capture_time():
    """Measure the time spent inside a ``with`` block.

    The context yields a zero-argument callable. While the block is still
    executing, each call reports the seconds elapsed since entry; once the
    block has exited, every call reports the same frozen total.

    Yields:
        Callable returning the elapsed time in seconds as a float.

    Example:
        >>> with capture_time() as get_time:
        ...     do_some_work()
        ...     print(f"So far: {get_time()}")
        >>> print(f"Total time: {get_time()}")
    """
    began_ns = time.perf_counter_ns()
    # Stays None while the block runs; set exactly once on exit to freeze readings.
    finished_ns = None

    def get_elapsed():
        """Return seconds since entry, frozen once the block has exited."""
        reference_ns = time.perf_counter_ns() if finished_ns is None else finished_ns
        return (reference_ns - began_ns) / 1e9

    try:
        yield get_elapsed
    finally:
        finished_ns = time.perf_counter_ns()
class Timer:
    """Simple timer for measuring execution time.

    Can be used as a context manager or manually with start/stop methods.
    Elapsed time accumulates across multiple start/stop cycles until the
    timer is reset.

    Attributes:
        name: Timer name for identification.
        elapsed: Total accumulated time in seconds.
        started: Whether the timer is currently running.
        start_time: ``time.time()`` reading taken when the current cycle
            started (0.0 before the first start and after a reset).

    Example:
        >>> timer = Timer("training")
        >>> timer.start()
        >>> # Do work
        >>> timer.stop()
        >>> print(f"Elapsed: {timer.elapsed_time()} seconds")
        >>>
        >>> # Or as context manager
        >>> with Timer("inference") as t:
        ...     result = model(input)
    """

    def __init__(self, name: str) -> None:
        """Initialize Timer.

        Args:
            name: Name for this timer.
        """
        self.name = name
        self.elapsed = 0.0
        self.started = False
        self.start_time = 0.0

    def start(self) -> None:
        """Start the timer.

        Raises:
            RuntimeError: If timer is already running.
        """
        if self.started:
            raise RuntimeError(f"Timer '{self.name}' is already running")
        self.start_time = time.time()
        self.started = True

    def stop(self) -> None:
        """Stop the timer and accumulate elapsed time.

        Raises:
            RuntimeError: If timer is not running.
        """
        if not self.started:
            raise RuntimeError(f"Timer '{self.name}' is not running")
        self.elapsed += time.time() - self.start_time
        self.started = False

    def reset(self) -> None:
        """Clear the accumulated time and mark the timer as not running."""
        self.elapsed = 0.0
        self.started = False
        self.start_time = 0.0

    def elapsed_time(self, reset: bool = True) -> float:
        """Get total elapsed time.

        A running timer is stopped first so the current cycle is included in
        the result; it is not restarted afterwards.

        Args:
            reset: Whether to reset timer after reading.

        Returns:
            Total elapsed time in seconds.
        """
        if self.started:
            self.stop()
        total_time = self.elapsed
        if reset:
            self.reset()
        return total_time

    def __enter__(self):
        """Start the timer and return it for use inside the ``with`` block."""
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Stop the timer on leaving the ``with`` block, if it is still running."""
        # The block may already have stopped the timer (elapsed_time() does so
        # implicitly). Calling stop() again would raise RuntimeError here and
        # hide any exception that is propagating out of the block.
        if self.started:
            self.stop()
class Timers:
    """Manager for multiple named timers with logging support.

    Keeps a dictionary of named timers and can report their readings to the
    module logger, to a TensorBoard writer and to Weights & Biases.

    Attributes:
        timers: Dictionary mapping timer names to timer instances.
        use_wandb: Whether to log to Weights & Biases.
        tensorboard_writer: TensorBoard summary writer; skipped when falsy.

    Example:
        >>> timers = Timers(use_wandb=True, tensorboard_writer=writer)
        >>> with timers.timed("forward_pass"):
        ...     output = model(input)
        >>> timers.write(["forward_pass"], iteration=100)
    """

    def __init__(self, use_wandb, tensorboard_writer: SummaryWriter):
        """Initialize Timers.

        Args:
            use_wandb: Enable Weights & Biases logging.
            tensorboard_writer: TensorBoard writer instance.
        """
        self.timers = {}
        self.use_wandb = use_wandb
        self.tensorboard_writer = tensorboard_writer

    def __call__(self, name):
        """Get or create a timer by name.

        Args:
            name: Timer name.

        Returns:
            Timer instance.
        """
        if name not in self.timers:
            self.timers[name] = Timer(name)
        return self.timers[name]

    def write(self, names, iteration, normalizer=1.0, reset=False):
        """Send timer readings to the configured metric backends.

        Each reading is the timer's elapsed seconds divided by ``normalizer``
        and is recorded under the key ``timers/<name>``.

        Args:
            names: A timer name or an iterable of timer names.
            iteration: Step value passed to the backends.
            normalizer: Positive divisor applied to every reading.
            reset: Whether to reset each timer after reading it.

        Raises:
            KeyError: If a name does not refer to an existing timer.
        """
        assert normalizer > 0.0, "normalizer must be positive"
        # Accept a single name like `log` does; iterating a bare string would
        # look up one timer per character.
        if isinstance(names, str):
            names = [names]
        for name in names:
            value = self.timers[name].elapsed_time(reset=reset) / normalizer
            if self.tensorboard_writer:
                self.tensorboard_writer.scalar(f"timers/{name}", value, iteration)
            if self.use_wandb:
                if wandb is None:
                    # stacklevel=2 attributes the warning to the caller of write().
                    warnings.warn(
                        "`wandb` is not installed use `pip install wandb` (use_wandb=True will be ignored)",
                        stacklevel=2,
                    )
                    # Disable so the warning is not repeated for every timer.
                    self.use_wandb = False
                else:
                    wandb.log({f"timers/{name}": value}, step=iteration)

    def log(self, names, normalizer=1.0, reset=True):
        """Print timer readings through the module logger.

        Readings are converted to milliseconds and divided by ``normalizer``
        before being formatted.

        Args:
            names: A timer name or an iterable of timer names.
            normalizer: Positive divisor applied to every reading.
            reset: Whether to reset each timer after reading it.

        Raises:
            KeyError: If a name does not refer to an existing timer.
        """
        assert normalizer > 0.0, "normalizer must be positive"
        if isinstance(names, str):
            names = [names]
        for name in names:
            elapsed_time = self.timers[name].elapsed_time(reset=reset) * 1000.0 / normalizer
            self._print_log(name, elapsed_time)

    def _print_log(self, name, elapsed_time):
        """Log one reading, picking the unit and ANSI colour from its size.

        Args:
            name: Timer name shown in the message.
            elapsed_time: Duration in milliseconds.
        """
        if elapsed_time < 1000:
            time_str = f"{elapsed_time:.4f} ms"
            color = "\033[94m"  # Blue
        elif elapsed_time < 60000:
            time_str = f"{elapsed_time / 1000:.4f} sec"
            color = "\033[92m"  # Green
        elif elapsed_time < 3600000:
            time_str = f"{elapsed_time / 60000:.4f} min"
            color = "\033[93m"  # Yellow
        else:
            time_str = f"{elapsed_time / 3600000:.4f} hr"
            color = "\033[91m"  # Red
        logger.info(f"time took for {name} : {color}{time_str}\033[0m")

    @contextlib.contextmanager
    def timed(self, name, log=True, reset=True):
        """Time the body of a ``with`` block using the named timer.

        Args:
            name: Timer name; the timer is created on first use.
            log: Whether to print the reading when the block exits.
            reset: Whether to reset the timer after printing. Only takes
                effect when ``log`` is true.

        Yields:
            The running timer instance.

        Raises:
            RuntimeError: If the named timer is already running.
        """
        timer = self(name)
        # Start outside the try: if start() fails because the timer is already
        # running, the cleanup below must not stop a measurement that this call
        # never began.
        timer.start()
        try:
            yield timer
        finally:
            # The body may already have stopped the timer (elapsed_time() does
            # so); guard the stop so cleanup cannot raise and mask an exception
            # coming out of the body.
            if timer.started:
                timer.stop()
            if log:
                elapsed_time = timer.elapsed_time(reset=reset) * 1000.0  # Convert to milliseconds
                self._print_log(name, elapsed_time)
def get_cache_dir() -> Path:
    """Get the EasyDeL cache directory.

    Picks a per-platform location under the user's home directory and
    creates it (including parents) if it does not exist yet:

    * Windows: ``%LOCALAPPDATA%/easydel``, falling back to
      ``~/AppData/Local/easydel`` when the variable is unset or empty.
    * macOS: ``~/Library/Caches/easydel``.
    * Anything else: ``~/.cache/easydel``.

    Returns:
        Path to the cache directory.

    Example:
        >>> cache_dir = get_cache_dir()
        >>> print(cache_dir)
        /home/user/.cache/easydel
    """
    app_name = "easydel"
    home_dir = Path.home()

    if os.name == "nt":  # Windows
        # `or` instead of a getenv default also covers LOCALAPPDATA being set
        # but empty, which would otherwise yield a relative "easydel" path.
        base_dir = Path(os.getenv("LOCALAPPDATA") or home_dir / "AppData" / "Local")
        cache_dir = base_dir / app_name
    elif os.name == "posix" and sys.platform.startswith("darwin"):  # macOS
        # Read the platform from `sys` directly; `os.sys` only works because
        # the os module happens to import sys internally.
        cache_dir = home_dir / "Library" / "Caches" / app_name
    else:  # Linux and any other platform
        cache_dir = home_dir / ".cache" / app_name

    cache_dir.mkdir(parents=True, exist_ok=True)
    return cache_dir
class DummyStream:
    """Write-only sink that throws away everything sent to it.

    Exposes just the ``write`` and ``flush`` methods so it can stand in for
    ``sys.stdout`` or ``sys.stderr`` when output should be silenced. Nothing
    is stored, and both methods return ``None``.
    """

    def write(self, *args, **kwargs):
        """Accept any arguments and drop them."""
        return None

    def flush(self, *args, **kwargs):
        """Accept any arguments and do nothing; there is no buffer to flush."""
        return None
@contextmanager
def quiet(suppress_stdout=True, suppress_stderr=True):
    """Temporarily silence ``sys.stdout`` and/or ``sys.stderr``.

    Each selected stream is rebound to a fresh ``DummyStream`` for the
    duration of the ``with`` block, and the stream object that was in place
    on entry is put back on exit, even if the block raises.

    Args:
        suppress_stdout: Whether to silence ``sys.stdout``.
        suppress_stderr: Whether to silence ``sys.stderr``.

    Yields:
        None

    Example:
        >>> with quiet():
        ...     print("This won't be displayed")
        ...     noisy_function()
        >>> print("This will be displayed")

    Note:
        Only writes that go through the ``sys.stdout`` / ``sys.stderr``
        objects are discarded. Output that bypasses those objects (for
        example, direct writes to the underlying file descriptors) is not
        affected by this rebinding.
    """
    # Remember only the streams that will actually be replaced.
    saved_streams = {}
    if suppress_stdout:
        saved_streams["stdout"] = sys.stdout
    if suppress_stderr:
        saved_streams["stderr"] = sys.stderr

    try:
        for stream_name in saved_streams:
            setattr(sys, stream_name, DummyStream())
        yield
    finally:
        for stream_name, previous_stream in saved_streams.items():
            setattr(sys, stream_name, previous_stream)
def check_bool_flag(name: str, default: bool = True) -> bool:
    """Read an environment variable and interpret it as a boolean.

    The value is compared case-insensitively against the accepted truthy
    spellings ``'true'``, ``'yes'``, ``'ok'``, ``'1'`` and ``'easy'``; any
    other value, including an empty string, counts as false.

    Args:
        name: Environment variable name.
        default: Result to use when the variable is not set.

    Returns:
        Boolean interpretation of the environment variable.

    Example:
        >>> os.environ['DEBUG'] = 'yes'
        >>> check_bool_flag('DEBUG')
        True
        >>> check_bool_flag('MISSING', default=False)
        False
    """
    raw_value = os.getenv(name)
    if raw_value is None:
        # Variable not set: the answer is simply the truthiness of the default.
        return bool(default)
    return raw_value.lower() in ("true", "yes", "ok", "1", "easy")