easydel.inference.vinference.metrics
- class easydel.inference.vinference.metrics.ModelMetadata(batch_size: int, sequence_length: int, dtype: str, platfrom: str)
Bases: object
A dataclass to hold basic metadata about the loaded model and runtime environment.
This information can be useful for logging, monitoring, or debugging purposes.
- batch_size
The batch size used for inference.
- Type
int
- sequence_length
The maximum sequence length the model is configured for.
- Type
int
- dtype
The data type (e.g., "float16", "bfloat16") used for model parameters and computation.
- Type
str
- platform
The JAX platform being used (e.g., "cpu", "gpu", "tpu"). In the class signature this field is spelled platfrom.
- Type
str
- batch_size: int
- dtype: str
- classmethod from_dict(data: Dict[str, Any]) → T
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) → T
Deserializes a JSON string into a PyTree object.
- platfrom: str
- replace(**kwargs)
Creates a new instance with the specified fields replaced.
- sequence_length: int
- to_dict() → Dict[str, Any]
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) → str
Serializes the PyTree object to a JSON string.
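The serialization helpers listed above can be approximated with the standard library. The following is a minimal sketch, assuming a plain frozen dataclass stand-in (`ModelMetadataSketch`, a hypothetical name) rather than the real PyTree base class; it only mirrors the documented field names, including the `platfrom` spelling, and is not the EasyDeL implementation.

```python
import dataclasses
import json
from typing import Any, Dict


@dataclasses.dataclass(frozen=True)
class ModelMetadataSketch:
    """Hypothetical stand-in for ModelMetadata; not the EasyDeL implementation."""

    batch_size: int
    sequence_length: int
    dtype: str
    platfrom: str  # spelled as in the documented signature

    def to_dict(self) -> Dict[str, Any]:
        # Field name -> value, suitable for logging or JSON encoding.
        return dataclasses.asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ModelMetadataSketch":
        return cls(**data)

    def to_json(self, **kwargs) -> str:
        # Extra kwargs (e.g. indent=2) are forwarded to json.dumps.
        return json.dumps(self.to_dict(), **kwargs)

    @classmethod
    def from_json(cls, json_str: str) -> "ModelMetadataSketch":
        return cls.from_dict(json.loads(json_str))

    def replace(self, **kwargs) -> "ModelMetadataSketch":
        # Returns a new instance; the original is left unchanged.
        return dataclasses.replace(self, **kwargs)


meta = ModelMetadataSketch(batch_size=8, sequence_length=2048, dtype="bfloat16", platfrom="tpu")
bigger = meta.replace(batch_size=16)
```

A frozen dataclass is used here so that `replace` is the only way to obtain a modified copy, which matches the "creates a new instance" wording above.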
- class easydel.inference.vinference.metrics.vInferenceMetrics(model_name: str)
Bases: object
Manages and exposes Prometheus metrics for monitoring the vInference engine.
This class initializes various Prometheus metric objects (Counters, Gauges, Histograms, Info) to track key performance indicators and resource usage of a specific model during inference.
It provides decorators (track_compilation, measure_inference_first_step, measure_inference_afterward) to easily instrument functions and record relevant metrics.
It also starts background threads for monitoring system resources like JAX device memory and host memory if running in a local environment.
- model_name
The sanitized name of the model used in metric labels.
- Type
str
- inference_requests
Tracks the total number of inference requests, labeled by status (success/error).
- Type
Counter
- inference_latency
Measures the latency distribution of different inference stages (preprocessing, inference, postprocessing).
- Type
Histogram
- queue_size
Tracks the current number of requests waiting in the queue.
- Type
Gauge
- jax_memory_used
Tracks the current and peak JAX memory usage per device.
- Type
Gauge
- host_memory_used
Tracks the host system's memory usage (total, available, used).
- Type
Gauge
- token_throughput
Counts the total number of tokens processed, labeled by the operation (e.g., "prefill", "decode").
- Type
Counter
- generation_length
Measures the distribution of generated sequence lengths.
- Type
Histogram
- compilation_time
Measures the time spent on JAX function compilation, labeled by the function name.
- Type
Histogram
- model_info
Stores static model configuration metadata.
- Type
Info
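The class description mentions background threads that sample JAX device memory and host memory. A minimal sketch of that pattern, assuming a plain dict in place of a Prometheus Gauge and a caller-supplied `read` function (`BackgroundSampler` and `fake_host_memory` are hypothetical names; the real class may sample differently), could look like this:

```python
import threading
from typing import Callable, Dict


class BackgroundSampler:
    """Hypothetical sketch: a daemon thread that copies readings into a gauge-like dict."""

    def __init__(self, read: Callable[[], Dict[str, float]], interval_s: float = 1.0) -> None:
        self._read = read
        self._interval_s = interval_s
        self._stop = threading.Event()
        self._lock = threading.Lock()
        self._values: Dict[str, float] = {}
        self._thread = threading.Thread(target=self._run, daemon=True)

    def _run(self) -> None:
        while True:
            reading = self._read()  # sample first, so at least one reading is taken
            with self._lock:
                self._values.update(reading)
            # Event.wait returns True once stop() has been called.
            if self._stop.wait(self._interval_s):
                break

    def start(self) -> "BackgroundSampler":
        self._thread.start()
        return self

    def stop(self) -> None:
        self._stop.set()
        self._thread.join()

    def snapshot(self) -> Dict[str, float]:
        with self._lock:
            return dict(self._values)


def fake_host_memory() -> Dict[str, float]:
    # Placeholder numbers; a real reader would query the OS.
    return {"total": 64.0, "available": 48.0, "used": 16.0}
```

In practice, `read` might wrap `psutil.virtual_memory()`, whose result has `total`, `available`, and `used` attributes, and a second sampler could poll per-device statistics such as `jax.Device.memory_stats()` where the backend provides them. Using a daemon thread means the sampler never blocks interpreter shutdown.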
- measure_inference_afterward(func)
Returns a decorator to measure metrics for subsequent inference steps (the steps after the first, which often correspond to decode steps).
This decorator behaves like measure_inference_first_step but records latency in a separate histogram (afterward_inference_latency). It:
  - Increments the queue_size gauge on entry and decrements it on exit.
  - Measures the latency of the decorated function using the afterward_inference_latency histogram (labeled with stage="inference"). (Note: this assumes a self.afterward_inference_latency histogram exists; it may need to be added to track afterward-step latency separately.)
  - Increments the inference_requests counter only on errors; success is likely tracked elsewhere or per token.
  - Performs garbage collection on exit.
- Parameters
func – The function representing a subsequent inference step to be measured.
- Returns
A decorator function.
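The latency measurement described above can be illustrated with a small stand-in for a Prometheus histogram, which counts observations into cumulative `le` (less-than-or-equal) buckets and also tracks a running sum and count. In this sketch, `LatencyHistogram`, `timed`, and `decode_step` are hypothetical names and the bucket bounds are arbitrary; a dedicated instance plays the role of afterward_inference_latency.

```python
import bisect
import functools
import time
from typing import Callable, List, Tuple


class LatencyHistogram:
    """Hypothetical stand-in for a Prometheus histogram with `le` buckets."""

    def __init__(self, bounds: List[float]):
        self.bounds = sorted(bounds) + [float("inf")]  # implicit +Inf bucket
        self.counts = [0] * len(self.bounds)  # per-bucket (non-cumulative) counts
        self.sum = 0.0
        self.count = 0

    def observe(self, value: float) -> None:
        # First bound >= value, so a value equal to a bound lands in that bucket.
        self.counts[bisect.bisect_left(self.bounds, value)] += 1
        self.sum += value
        self.count += 1

    def cumulative(self) -> List[Tuple[float, int]]:
        # Prometheus exposes each bucket as "observations <= bound".
        out, running = [], 0
        for bound, n in zip(self.bounds, self.counts):
            running += n
            out.append((bound, running))
        return out


def timed(hist: LatencyHistogram) -> Callable:
    """Decorator factory: records each call's wall-clock latency in `hist`."""

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            start = time.perf_counter()
            try:
                return func(*args, **kwargs)
            finally:
                # Recorded even if func raises.
                hist.observe(time.perf_counter() - start)

        return wrapper

    return decorator


afterward_latency = LatencyHistogram([0.005, 0.05, 0.5])


@timed(afterward_latency)
def decode_step(token_id: int) -> int:
    return token_id + 1
```

Keeping a separate histogram instance for afterward steps keeps the many fast decode latencies from being mixed with the slower first-step latencies in one distribution.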
- measure_inference_first_step(func)
Returns a decorator to measure metrics specifically for the first step of an inference process (often the prefill step).
This decorator:
  - Increments the queue_size gauge on entry and decrements it on exit.
  - Measures the latency of the decorated function using the first_step_inference_latency histogram (labeled with stage="inference"). (Note: this assumes a self.first_step_inference_latency histogram exists; it may need to be added if the first step's latency should be tracked separately.)
  - Increments the inference_requests counter with status="success" or status="error".
  - Performs garbage collection on exit.
- Parameters
func – The function representing the first inference step to be measured.
- Returns
A decorator function.
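The queue, status-counting, and garbage-collection steps of this decorator fit a try/except/else/finally pattern. The sketch below assumes a plain integer and a `collections.Counter` in place of the Prometheus Gauge and Counter; `RequestTracker` and `prefill` are hypothetical names, latency timing is omitted, and the class is not thread-safe.

```python
import functools
import gc
from collections import Counter
from typing import Callable


class RequestTracker:
    """Hypothetical stand-in for queue_size (Gauge) and inference_requests (Counter)."""

    def __init__(self) -> None:
        self.queue_size = 0
        self.requests: Counter = Counter()  # keyed by status: "success" / "error"

    def measure_first_step(self, func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            self.queue_size += 1  # request enters the queue
            try:
                result = func(*args, **kwargs)
            except Exception:
                self.requests["error"] += 1
                raise  # re-raise so callers still see the failure
            else:
                self.requests["success"] += 1
                return result
            finally:
                self.queue_size -= 1  # leaves the queue on success or error
                gc.collect()

        return wrapper


tracker = RequestTracker()


@tracker.measure_first_step
def prefill(prompt_len: int) -> int:
    if prompt_len <= 0:
        raise ValueError("empty prompt")
    return prompt_len
```

Putting the decrement and `gc.collect()` in `finally` guarantees the queue gauge returns to its previous value even when the wrapped step raises.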
- record_model_metadata(metadata: ModelMetadata)
Records static model metadata using the Prometheus Info metric.
- Parameters
metadata – A ModelMetadata object containing information like batch size, sequence length, dtype, and platform.
- track_compilation(function_name: str)
Returns a decorator to measure and record the compilation time of a JAX function.
Usage:
```python
import jax

from easydel.inference.vinference.metrics import vInferenceMetrics

metrics = vInferenceMetrics("my_model")

@metrics.track_compilation("my_compiled_function")
@jax.jit
def my_compiled_function(x):
    # ... JAX computation ...
    return x * 2
```
- Parameters
function_name – A descriptive name for the function being compiled, used as a label in the compilation_time metric.
- Returns
A decorator function.
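Because `jax.jit` traces and compiles lazily on the first call (and again for new input shapes or dtypes), one simple way to approximate compilation time is to time the first invocation of the wrapped function. The sketch below does this with the standard library only; `compilation_times` and this `track_compilation` are hypothetical stand-ins, not the EasyDeL implementation, and recompilations triggered by later shape changes are not captured.

```python
import functools
import time
from collections import defaultdict
from typing import Callable, DefaultDict, List

# function_name -> recorded first-call durations in seconds
compilation_times: DefaultDict[str, List[float]] = defaultdict(list)


def track_compilation(function_name: str) -> Callable:
    """Hypothetical sketch: time only the first call, assumed to include compilation."""

    def decorator(func: Callable) -> Callable:
        first_call = True

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            nonlocal first_call
            if not first_call:
                return func(*args, **kwargs)
            first_call = False
            start = time.perf_counter()
            try:
                return func(*args, **kwargs)
            finally:
                compilation_times[function_name].append(time.perf_counter() - start)

        return wrapper

    return decorator


@track_compilation("double")
def double(x: int) -> int:
    return x * 2
```

When this decorator sits above `@jax.jit`, as in the documented usage, JAX's asynchronous dispatch means the timed first call mostly reflects tracing, compilation, and dispatch rather than full device execution.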