Source code for easydel.inference.esurge.metrics

# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""eSurge Metrics Collection System."""

from __future__ import annotations

import json
import logging
import time
from collections import defaultdict, deque
from dataclasses import asdict, dataclass
from pathlib import Path
from threading import Lock
from typing import Any

from jax import numpy as jnp


[docs]@dataclass class RequestMetrics: """Metrics for a single request.""" request_id: str start_time: float end_time: float | None = None first_token_time: float | None = None prompt_tokens: int = 0 generated_tokens: int = 0 total_tokens: int = 0 tokens_per_second: float = 0.0 time_to_first_token: float | None = None total_latency: float | None = None finish_reason: str | None = None error: str | None = None
[docs]@dataclass class SchedulerMetrics: """Metrics for scheduler operations.""" timestamp: float num_waiting_requests: int = 0 num_running_requests: int = 0 num_scheduled_tokens: int = 0 num_preempted_requests: int = 0 batch_size: int = 0 schedule_time: float = 0.0
[docs]@dataclass class ModelRunnerMetrics: """Metrics for model runner operations.""" timestamp: float execution_time: float = 0.0 batch_size: int = 0 num_tokens: int = 0 tokens_per_second: float = 0.0 memory_usage: dict[str, Any] | None = None
[docs]@dataclass class CacheMetrics: """Metrics for KV cache operations.""" timestamp: float total_pages: int = 0 used_pages: int = 0 free_pages: int = 0 cache_hit_rate: float = 0.0 page_allocation_rate: float = 0.0 page_free_rate: float = 0.0
[docs]@dataclass class SystemMetrics: """System-wide metrics summary.""" timestamp: float total_requests_completed: int = 0 total_requests_failed: int = 0 total_tokens_generated: int = 0 average_latency: float = 0.0 average_ttft: float = 0.0 average_throughput: float = 0.0 requests_per_second: float = 0.0
[docs]class MetricsCollector: """Centralized metrics collection and logging system for eSurge.""" def __init__( self, log_file: str | None = None, log_interval: float = 10.0, history_size: int = 1000, enable_detailed_logging: bool = True, ): """Initialize the metrics collector. Args: log_file: Path to metrics log file (JSON format) log_interval: Interval in seconds to log summary metrics history_size: Number of metrics records to keep in memory enable_detailed_logging: Whether to log detailed per-request metrics """ self.log_file = log_file self.log_interval = log_interval self.history_size = history_size self.enable_detailed_logging = enable_detailed_logging # Thread-safe data structures self._lock = Lock() # Metrics storage self.request_metrics: dict[str, RequestMetrics] = {} self.completed_requests: deque[RequestMetrics] = deque(maxlen=history_size) self.scheduler_metrics: deque[SchedulerMetrics] = deque(maxlen=history_size) self.runner_metrics: deque[ModelRunnerMetrics] = deque(maxlen=history_size) self.cache_metrics: deque[CacheMetrics] = deque(maxlen=history_size) # Counters and aggregates self.counters = defaultdict(int) self.timers = defaultdict(list) # Last log time self.last_log_time = time.time() # Setup logging self.logger = logging.getLogger("eSurge.metrics") if log_file: handler = logging.FileHandler(log_file) handler.setFormatter(logging.Formatter("%(asctime)s - %(message)s")) self.logger.addHandler(handler) self.logger.setLevel(logging.INFO)
[docs] def start_request(self, request_id: str, prompt_tokens: int = 0) -> None: """Start tracking metrics for a new request.""" with self._lock: self.request_metrics[request_id] = RequestMetrics( request_id=request_id, start_time=time.time(), prompt_tokens=prompt_tokens, total_tokens=prompt_tokens, ) self.counters["total_requests"] += 1
[docs] def record_first_token(self, request_id: str) -> None: """Record when the first token is generated for a request.""" with self._lock: if request_id in self.request_metrics: current_time = time.time() metrics = self.request_metrics[request_id] metrics.first_token_time = current_time metrics.time_to_first_token = current_time - metrics.start_time
[docs] def add_generated_tokens(self, request_id: str, num_tokens: int) -> None: """Add generated tokens to a request's metrics.""" with self._lock: if request_id in self.request_metrics: metrics = self.request_metrics[request_id] metrics.generated_tokens += num_tokens metrics.total_tokens = metrics.prompt_tokens + metrics.generated_tokens
[docs] def complete_request( self, request_id: str, finish_reason: str | None = None, error: str | None = None, ) -> None: """Complete tracking for a request.""" with self._lock: if request_id not in self.request_metrics: return metrics = self.request_metrics[request_id] metrics.end_time = time.time() metrics.total_latency = metrics.end_time - metrics.start_time metrics.finish_reason = finish_reason metrics.error = error # Calculate tokens per second if metrics.generated_tokens > 0 and metrics.time_to_first_token is not None: generation_time = metrics.total_latency - metrics.time_to_first_token if generation_time > 0: metrics.tokens_per_second = metrics.generated_tokens / generation_time # Move to completed requests self.completed_requests.append(metrics) del self.request_metrics[request_id] # Update counters if error: self.counters["total_failed"] += 1 else: self.counters["total_completed"] += 1 self.counters["total_tokens_generated"] += metrics.generated_tokens
[docs] def record_scheduler_metrics( self, num_waiting: int, num_running: int, num_scheduled_tokens: int, num_preempted: int = 0, batch_size: int = 0, schedule_time: float = 0.0, ) -> None: """Record scheduler performance metrics.""" with self._lock: metrics = SchedulerMetrics( timestamp=time.time(), num_waiting_requests=num_waiting, num_running_requests=num_running, num_scheduled_tokens=num_scheduled_tokens, num_preempted_requests=num_preempted, batch_size=batch_size, schedule_time=schedule_time, ) self.scheduler_metrics.append(metrics)
[docs] def record_runner_metrics( self, execution_time: float, batch_size: int, num_tokens: int, memory_usage: dict[str, Any] | None = None, ) -> None: """Record model runner performance metrics.""" with self._lock: tokens_per_second = num_tokens / execution_time if execution_time > 0 else 0 metrics = ModelRunnerMetrics( timestamp=time.time(), execution_time=execution_time, batch_size=batch_size, num_tokens=num_tokens, tokens_per_second=tokens_per_second, memory_usage=memory_usage, ) self.runner_metrics.append(metrics)
[docs] def record_cache_metrics( self, total_pages: int, used_pages: int, cache_hit_rate: float = 0.0, page_allocation_rate: float = 0.0, page_free_rate: float = 0.0, ) -> None: """Record KV cache metrics.""" with self._lock: metrics = CacheMetrics( timestamp=time.time(), total_pages=total_pages, used_pages=used_pages, free_pages=total_pages - used_pages, cache_hit_rate=cache_hit_rate, page_allocation_rate=page_allocation_rate, page_free_rate=page_free_rate, ) self.cache_metrics.append(metrics)
[docs] def record_cache_event(self, event: str, details: dict[str, Any] | None = None) -> None: """Record lifecycle events for the KV cache.""" with self._lock: self.counters[f"cache_event_{event}"] += 1
[docs] def get_system_metrics(self, window_seconds: float = 60.0) -> SystemMetrics: """Get aggregated system metrics for the specified time window.""" current_time = time.time() cutoff_time = current_time - window_seconds with self._lock: # Filter recent completed requests recent_requests = [req for req in self.completed_requests if req.end_time and req.end_time >= cutoff_time] if not recent_requests: return SystemMetrics(timestamp=current_time) # Calculate aggregates total_completed = len(recent_requests) total_failed = len([req for req in recent_requests if req.error]) total_tokens = sum(req.generated_tokens for req in recent_requests) latencies = [req.total_latency for req in recent_requests if req.total_latency] ttfts = [req.time_to_first_token for req in recent_requests if req.time_to_first_token] throughputs = [req.tokens_per_second for req in recent_requests if req.tokens_per_second > 0] avg_latency = jnp.mean(jnp.array(latencies)) if latencies else 0.0 avg_ttft = jnp.mean(jnp.array(ttfts)) if ttfts else 0.0 avg_throughput = jnp.mean(jnp.array(throughputs)) if throughputs else 0.0 requests_per_second = total_completed / window_seconds return SystemMetrics( timestamp=current_time, total_requests_completed=total_completed, total_requests_failed=total_failed, total_tokens_generated=total_tokens, average_latency=avg_latency, average_ttft=avg_ttft, average_throughput=avg_throughput, requests_per_second=requests_per_second, )
[docs] def log_summary(self, force: bool = False) -> None: """Log a summary of current metrics.""" current_time = time.time() if not force and current_time - self.last_log_time < self.log_interval: return with self._lock: system_metrics = self.get_system_metrics() # Get latest metrics from each category latest_scheduler = self.scheduler_metrics[-1] if self.scheduler_metrics else None latest_runner = self.runner_metrics[-1] if self.runner_metrics else None latest_cache = self.cache_metrics[-1] if self.cache_metrics else None summary = { "timestamp": current_time, "system": asdict(system_metrics), "scheduler": asdict(latest_scheduler) if latest_scheduler else None, "runner": asdict(latest_runner) if latest_runner else None, "cache": asdict(latest_cache) if latest_cache else None, "active_requests": len(self.request_metrics), } if self.enable_detailed_logging and self.logger: self.logger.info(f"METRICS_SUMMARY: {json.dumps(summary)}") self.last_log_time = current_time
[docs] def export_metrics(self, file_path: str, format: str = "json") -> None: # noqa """Export all metrics to a file.""" with self._lock: data = { "timestamp": time.time(), "system_metrics": asdict(self.get_system_metrics()), "completed_requests": [asdict(req) for req in self.completed_requests], "scheduler_metrics": [asdict(m) for m in self.scheduler_metrics], "runner_metrics": [asdict(m) for m in self.runner_metrics], "cache_metrics": [asdict(m) for m in self.cache_metrics], "counters": dict(self.counters), } path = Path(file_path) if format.lower() == "json": with open(path, "w") as f: json.dump(data, f, indent=2) else: raise ValueError(f"Unsupported format: {format}")
[docs] def reset_metrics(self) -> None: """Reset all metrics and counters.""" with self._lock: self.request_metrics.clear() self.completed_requests.clear() self.scheduler_metrics.clear() self.runner_metrics.clear() self.cache_metrics.clear() self.counters.clear() self.timers.clear()
# Global metrics collector instance _global_metrics_collector: MetricsCollector | None = None
[docs]def get_metrics_collector() -> MetricsCollector | None: """Get the global metrics collector instance.""" return _global_metrics_collector
[docs]def initialize_metrics( log_file: str | None = None, log_interval: float = 10.0, history_size: int = 1000, enable_detailed_logging: bool = True, ) -> MetricsCollector: """Initialize the global metrics collector.""" global _global_metrics_collector _global_metrics_collector = MetricsCollector( log_file=log_file, log_interval=log_interval, history_size=history_size, enable_detailed_logging=enable_detailed_logging, ) return _global_metrics_collector
[docs]def log_metrics_summary() -> None: """Log a summary of current metrics if collector is initialized.""" if _global_metrics_collector: _global_metrics_collector.log_summary()