easydel.inference.vinference.metrics#

class easydel.inference.vinference.metrics.ModelMetadata(batch_size: int, sequence_length: int, dtype: str, platfrom: str)[source]#

Bases: object

A dataclass to hold basic metadata about the loaded model and runtime environment.

This information can be useful for logging, monitoring, or debugging purposes.

batch_size#

The batch size used for inference.

Type

int

sequence_length#

The maximum sequence length the model is configured for.

Type

int

dtype#

The data type (e.g., ‘float16’, ‘bfloat16’) used for model parameters/computation.

Type

str

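platform#

The JAX platform being used (e.g., ‘cpu’, ‘gpu’, ‘tpu’). Note that the constructor argument and field are spelled platfrom, as shown in the class signature.

Type

str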

batch_size: int#
dtype: str#
classmethod from_dict(data: Dict[str, Any]) T#

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T#

Deserializes a JSON string into a PyTree object.

platfrom: str#
replace(**kwargs)#

Creates a new instance with specified fields replaced.

sequence_length: int#
to_dict() Dict[str, Any]#

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str#

Serializes the PyTree object to a JSON string.
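The serialization helpers listed above follow a familiar round-trip pattern. The sketch below reproduces that pattern with a plain standard-library dataclass so it can run without EasyDeL installed. MetadataSketch is a hypothetical stand-in, not the library class, and its method bodies are assumptions about how such helpers typically behave.

```python
import dataclasses
import json
from typing import Any, Dict


@dataclasses.dataclass(frozen=True)
class MetadataSketch:
    """Hypothetical stand-in for ModelMetadata; not the EasyDeL class."""

    batch_size: int
    sequence_length: int
    dtype: str
    platfrom: str  # spelled as in the documented signature

    def to_dict(self) -> Dict[str, Any]:
        # Field names become dictionary keys.
        return dataclasses.asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "MetadataSketch":
        return cls(**data)

    def to_json(self, **kwargs) -> str:
        # Extra keyword arguments are forwarded to json.dumps (an assumption).
        return json.dumps(self.to_dict(), **kwargs)

    @classmethod
    def from_json(cls, json_str: str) -> "MetadataSketch":
        return cls.from_dict(json.loads(json_str))

    def replace(self, **kwargs) -> "MetadataSketch":
        # Returns a new instance; the original is left unchanged.
        return dataclasses.replace(self, **kwargs)


meta = MetadataSketch(batch_size=8, sequence_length=2048, dtype="bfloat16", platfrom="tpu")
restored = MetadataSketch.from_json(meta.to_json())
larger = meta.replace(batch_size=16)
```

With the real class, the same calls would presumably be made on ModelMetadata directly, using the platfrom keyword shown in its signature.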

class easydel.inference.vinference.metrics.vInferenceMetrics(model_name: str)[source]#

Bases: object

Manages and exposes Prometheus metrics for monitoring the vInference engine.

This class initializes various Prometheus metric objects (Counters, Gauges, Histograms, Info) to track key performance indicators and resource usage of a specific model during inference.

It provides decorators (track_compilation, measure_inference_first_step, measure_inference_afterward) to easily instrument functions and record relevant metrics.

It also starts background threads for monitoring system resources like JAX device memory and host memory if running in a local environment.

model_name#

The sanitized name of the model used in metric labels.

Type

str

inference_requests#

Tracks the total number of inference requests, labeled by status (success/error).

Type

Counter

inference_latency#

Measures the latency distribution of different inference stages (preprocessing, inference, postprocessing).

Type

Histogram

queue_size#

Tracks the current number of requests waiting in the queue.

Type

Gauge

jax_memory_used#

Tracks the current and peak JAX memory usage per device.

Type

Gauge

host_memory_used#

Tracks the host system’s memory usage (total, available, used).

Type

Gauge

token_throughput#

Counts the total number of tokens processed, labeled by the operation (e.g., ‘prefill’, ‘decode’).

Type

Counter

generation_length#

Measures the distribution of generated sequence lengths.

Type

Histogram

compilation_time#

Measures the time spent on JAX function compilation, labeled by the function name.

Type

Histogram

model_info#

Stores static model configuration metadata.

Type

Info

measure_inference_afterward(func)[source]#

Returns a decorator to measure metrics for subsequent inference steps (after the first step, often corresponding to decode steps).

This decorator behaves similarly to measure_inference_first_step but uses a separate latency histogram (afterward_inference_latency).

It:

  • Increments/decrements the queue_size gauge upon entry/exit.

  • Measures the latency of the decorated function using the afterward_inference_latency histogram (labeled with stage='inference'). Note: this assumes a self.afterward_inference_latency histogram exists, which might need to be added for separate afterward-step latency tracking.

  • Increments the inference_requests counter (only on errors; success is likely tracked elsewhere or per token).

  • Performs garbage collection upon exit.

Parameters

func – The function representing a subsequent inference step to be measured.

Returns

A decorator function.

measure_inference_first_step(func)[source]#

Returns a decorator to measure metrics specifically for the first step of an inference process (often corresponds to prefill).

This decorator:

  • Increments/decrements the queue_size gauge upon entry/exit.

  • Measures the latency of the decorated function using the first_step_inference_latency histogram (labeled with stage='inference'). Note: this assumes a self.first_step_inference_latency histogram exists, which might need to be added if different latency tracking for the first step is desired.

  • Increments the inference_requests counter (status='success' or 'error').

  • Performs garbage collection upon exit.

Parameters

func – The function representing the first inference step to be measured.

Returns

A decorator function.
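The steps listed above (queue gauge, latency timing, request counting, garbage collection) can be sketched with in-memory stand-ins instead of Prometheus objects. MetricsSketch and measure_step are hypothetical names. The real implementation may differ, for example in whether latency is recorded when the wrapped function raises.

```python
import functools
import gc
import time
from collections import Counter


class MetricsSketch:
    """Hypothetical in-memory stand-in for the Prometheus metric objects."""

    def __init__(self):
        self.queue_size = 0        # plays the role of the queue_size Gauge
        self.requests = Counter()  # plays the role of the status-labeled Counter
        self.latencies = []        # plays the role of the latency Histogram

    def measure_step(self, func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            self.queue_size += 1  # entry: one more request in flight
            start = time.perf_counter()
            try:
                result = func(*args, **kwargs)
            except Exception:
                self.requests["error"] += 1
                raise
            else:
                self.requests["success"] += 1
                return result
            finally:
                # Exit path runs for both success and error.
                self.latencies.append(time.perf_counter() - start)
                self.queue_size -= 1
                gc.collect()

        return wrapper


metrics = MetricsSketch()


@metrics.measure_step
def prefill(tokens):
    return len(tokens)


@metrics.measure_step
def broken_step():
    raise ValueError("boom")


n = prefill([1, 2, 3])
try:
    broken_step()
except ValueError:
    pass
```

The afterward variant documented earlier would differ mainly in skipping the success increment and writing to its own latency histogram.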

record_model_metadata(metadata: ModelMetadata)[source]#

Records static model metadata using the Prometheus Info metric.

Parameters

metadata – A ModelMetadata object containing information like batch size, sequence length, dtype, and platform.
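Prometheus Info metrics carry string-valued key/value pairs, so integer fields such as the batch size need converting before they are recorded. The sketch below shows one way to do that. MetadataSketch and to_info_labels are hypothetical names, and whether record_model_metadata converts values this way is an assumption.

```python
from dataclasses import asdict, dataclass
from typing import Dict


@dataclass
class MetadataSketch:
    """Hypothetical stand-in for ModelMetadata; not the EasyDeL class."""

    batch_size: int
    sequence_length: int
    dtype: str
    platfrom: str  # spelled as in the documented signature


def to_info_labels(metadata: MetadataSketch) -> Dict[str, str]:
    # Stringify every field so the result is a str -> str mapping.
    return {key: str(value) for key, value in asdict(metadata).items()}


labels = to_info_labels(
    MetadataSketch(batch_size=8, sequence_length=2048, dtype="bfloat16", platfrom="tpu")
)
```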

track_compilation(function_name: str)[source]#

Returns a decorator to measure and record the compilation time of a JAX function.

Usage:

```python
import jax

from easydel.inference.vinference.metrics import vInferenceMetrics

metrics = vInferenceMetrics("my_model")


@metrics.track_compilation("my_compiled_function")
@jax.jit
def my_compiled_function(x):
    # JAX computation
    return x * 2
```

Parameters

function_name – A descriptive name for the function being compiled, used as a label in the compilation_time metric.

Returns

A decorator function.
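A decorator factory of this shape can be sketched with the standard library alone. The version below times each call and files the duration under the given label. compile_times is a hypothetical stand-in for the compilation_time histogram, and timing the whole call is an assumption: for a jitted function the first call would include tracing and compilation, but the actual implementation may isolate compile time differently.

```python
import functools
import time
from collections import defaultdict

# Hypothetical stand-in for the compilation_time Histogram, keyed by label.
compile_times = defaultdict(list)


def track_compilation(function_name: str):
    """Return a decorator that records call durations under function_name."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            start = time.perf_counter()
            try:
                return func(*args, **kwargs)
            finally:
                # Record the duration even if the call raises.
                compile_times[function_name].append(time.perf_counter() - start)

        return wrapper

    return decorator


@track_compilation("double")
def double(x):
    # Plain Python here; in practice this would be a jitted JAX function.
    return x * 2


result = double(21)
```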