easydel.inference.esurge.esurge_engine


eSurge Engine - High-Performance Inference Engine for EasyDeL.

This module provides the eSurge engine, a high-performance text generation system built on JAX that offers efficient batched inference with advanced features like continuous batching and comprehensive monitoring.

Key Components:
  • eSurge: Main engine class for text generation

  • RequestOutput: Container for generation results and metrics

  • CompletionOutput: Individual completion within a batch

Features:
  • Continuous Batching: Background scheduler thread processes requests continuously for optimal throughput.

  • Context Management: Automatic prompt truncation and token reservation with configurable strategies.

  • Streaming Support: Real-time token streaming with delta updates.

  • Monitoring: Built-in Prometheus metrics and console monitor (Grafana-ready).

Usage Example:
>>> from easydel.inference.esurge import eSurge
>>> from easydel.inference.sampling_params import SamplingParams
>>>
>>> # Initialize engine
>>> engine = eSurge(
...     model="model-name",
...     max_model_len=8192,
...     reserve_tokens=800
... )
>>> engine.initiate()  # start the background scheduler before generating
>>>
>>> # Stream generation
>>> for output in engine.stream("Tell me about AI"):
...     print(output.delta_text, end="", flush=True)
>>>
>>> # Batch generation
>>> outputs = engine.generate(
...     ["Question 1?", "Question 2?"],
...     SamplingParams(max_tokens=100, temperature=0.7)
... )
Technical Details:

The engine uses a multi-threaded architecture with:
  • Main thread: Handles API calls and request submission

  • Scheduler thread: Continuously processes queued requests

  • JAX computation: Executes model forward passes
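The division of work between a submitting thread and a background scheduler thread can be illustrated with a minimal sketch built only on the standard library. Everything here (`scheduler_loop`, the `None` sentinel, the upper-casing stand-in for a forward pass) is hypothetical and is not taken from the eSurge implementation.

```python
import queue
import threading


def scheduler_loop(pending: queue.Queue, results: dict, done: threading.Event) -> None:
    """Hypothetical scheduler loop: process queued requests until a None sentinel arrives."""
    while True:
        item = pending.get()
        if item is None:  # sentinel: shut down the loop
            break
        request_id, prompt = item
        # Stand-in for a model forward pass; a real engine would run JAX here.
        results[request_id] = prompt.upper()
    done.set()


pending: queue.Queue = queue.Queue()
results: dict = {}
done = threading.Event()

# Scheduler thread: runs in the background as a daemon.
worker = threading.Thread(target=scheduler_loop, args=(pending, results, done), daemon=True)
worker.start()

# Main thread: submit requests, then signal shutdown.
pending.put(("req-0", "hello"))
pending.put(("req-1", "world"))
pending.put(None)

done.wait(timeout=5)
worker.join(timeout=5)
print(results)
```

The queue decouples submission from processing, which is the property that lets a scheduler keep working while callers continue to submit.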

class easydel.inference.esurge.esurge_engine.CompletionOutput(index: int, text: str, token_ids: list[int], cumulative_logprob: float | None = None, logprobs: list[dict[int, float]] | None = None, finish_reason: str | None = None)[source]#

Bases: object

Output of a single completion.

Represents the generated output for a single completion within a batch request. Contains the generated text, token IDs, and optional probability information.

index#

Position of this completion in the batch (0-indexed).

Type

int

text#

The generated text string.

Type

str

token_ids#

List of token IDs that were generated.

Type

list[int]

cumulative_logprob#

Cumulative log probability of the generated sequence.

Type

float | None

logprobs#

Per-token log probabilities as dict mapping token_id to logprob.

Type

list[dict[int, float]] | None

finish_reason#

Reason for completion termination (‘stop’, ‘length’, ‘eos_token’, etc.).

Type

str | None

cumulative_logprob: float | None = None#
finish_reason: str | None = None#
index: int#
logprobs: list[dict[int, float]] | None = None#
text: str#
token_ids: list[int]#
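As an illustration of how these fields might be combined, the sketch below derives a per-token average log probability and a perplexity from cumulative_logprob and token_ids. `CompletionStandIn` is a local stand-in that mirrors a few of the documented fields so the example runs without EasyDeL installed; it is not the real class, and `mean_logprob` is a hypothetical helper.

```python
import math
from dataclasses import dataclass
from typing import Optional


@dataclass
class CompletionStandIn:
    """Local stand-in mirroring a few CompletionOutput fields (not the real class)."""

    index: int
    text: str
    token_ids: list
    cumulative_logprob: Optional[float] = None
    finish_reason: Optional[str] = None


def mean_logprob(completion: CompletionStandIn) -> Optional[float]:
    """Average log probability per generated token, or None when unavailable."""
    if completion.cumulative_logprob is None or not completion.token_ids:
        return None
    return completion.cumulative_logprob / len(completion.token_ids)


sample = CompletionStandIn(
    index=0,
    text="hi there",
    token_ids=[11, 22, 33, 44],
    cumulative_logprob=-2.0,
    finish_reason="stop",
)
avg = mean_logprob(sample)      # -2.0 / 4 tokens
perplexity = math.exp(-avg)     # exp of the negated mean log probability
print(avg, perplexity)
```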
class easydel.inference.esurge.esurge_engine.RequestOutput(request_id: str, prompt: str | list[str], prompt_token_ids: list[list[int]] | list[int], outputs: list[easydel.inference.esurge.esurge_engine.CompletionOutput], finished: bool = False, metrics: dict[str, Any] | None = None, accumulated_text: str = '', delta_text: str = '', tokens_per_second: float = 0.0, num_generated_tokens: int = 0, time_spent_generating: float = 0.0, first_token_time: float | None = None, processing_time: float = 0.0, update_seq: int = 0, delta_seq: int = 0)[source]#

Bases: object

Output of a generation request with comprehensive metrics.

Contains the complete output for a generation request including generated text, performance metrics, and streaming support fields. Used for both batch and streaming generation modes.

request_id#

Unique identifier for this request.

Type

str

prompt#

Original prompt text.

Type

str | list[str]

prompt_token_ids#

Tokenized prompt as list of token IDs.

Type

list[list[int]] | list[int]

outputs#

List of CompletionOutput objects (one per n in sampling params).

Type

list[easydel.inference.esurge.esurge_engine.CompletionOutput]

finished#

Whether generation has completed.

Type

bool

metrics#

Dictionary of performance metrics (tokens, timing, etc.).

Type

dict[str, Any] | None

accumulated_text#

Full generated text accumulated so far.

Type

str

delta_text#

Only the latest decoded text chunk (for streaming).

Type

str

tokens_per_second#

Current generation throughput.

Type

float

num_generated_tokens#

Total number of tokens generated.

Type

int

time_spent_generating#

Total time spent in generation.

Type

float

first_token_time#

Time to first token (TTFT) in seconds.

Type

float | None

processing_time#

Total processing time including queuing.

Type

float

update_seq#

Sequence number incremented on any update.

Type

int

delta_seq#

Sequence number incremented only when delta_text changes.

Type

int

accumulated_text: str = ''#
delta_seq: int = 0#
delta_text: str = ''#
finished: bool = False#
first_token_time: float | None = None#
get_summary() dict[str, Any][source]#

Get a summary of the request output.

Returns

Dictionary containing key metrics: request_id, text, throughput, token count, timing, completion status, and finish reason.

Return type

dict[str, Any]

get_text() str[source]#

Get the generated text from the first completion output.

Returns

Generated text string, or empty string if no outputs.

metrics: dict[str, Any] | None = None#
num_generated_tokens: int = 0#
outputs: list[easydel.inference.esurge.esurge_engine.CompletionOutput]#
processing_time: float = 0.0#
prompt: str | list[str]#
prompt_token_ids: list[list[int]] | list[int]#
request_id: str#
time_spent_generating: float = 0.0#
tokens_per_second: float = 0.0#
update_seq: int = 0#
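The relationship between update_seq and delta_seq suggests one way a consumer could avoid printing the same chunk twice: only act when delta_seq advances. The sketch below uses `SimpleNamespace` objects as stand-ins for RequestOutput, and it assumes (this is not confirmed by the source) that a metrics-only update can repeat the previous delta_text while leaving delta_seq unchanged.

```python
from types import SimpleNamespace


def collect_new_text(updates):
    """Join delta_text from updates whose delta_seq advanced past the last one seen."""
    last_seq = 0
    parts = []
    for update in updates:
        if update.delta_seq > last_seq:  # delta_text changed since the previous update
            parts.append(update.delta_text)
            last_seq = update.delta_seq
    return "".join(parts)


# Hypothetical update stream: the second entry only bumps update_seq.
updates = [
    SimpleNamespace(delta_text="Hel", accumulated_text="Hel", update_seq=1, delta_seq=1),
    SimpleNamespace(delta_text="Hel", accumulated_text="Hel", update_seq=2, delta_seq=1),
    SimpleNamespace(delta_text="lo", accumulated_text="Hello", update_seq=3, delta_seq=2),
]
text = collect_new_text(updates)
print(text)
```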
class easydel.inference.esurge.esurge_engine.eSurge(model: str | EasyDeLBaseModule, tokenizer: str | PreTrainedTokenizerBase | None = None, dtype: jnp.dtype = <class 'jax.numpy.bfloat16'>, max_model_len: int = 8192, min_input_pad: int = 16, max_num_seqs: int = 256, max_num_batched_tokens: int | None = None, hbm_utilization: float = 0.85, page_size: int = 128, use_aot_forward: bool = True, enable_prefix_caching: bool = True, auto_shard_model: bool = True, sharding_axis_dims: tuple[int, ...] = (1, 1, 1, -1, 1), compile_runner: bool = True, runner_verbose: bool = False, overlap_execution: bool = False, sampler_metrics: bool = False, esurge_name: str | None = None, reserve_tokens: int | None = None, auto_truncate_prompt: bool = True, auto_cap_new_tokens: bool = True, strict_context: bool = False, truncate_mode: typing.Literal['left', 'right', 'middle'] = 'left', prefer_preserve_prompt: bool = True, decode_truncated_prompt: bool = True, destroy_pages_on_pause: bool = True, detokenizer_max_states: int = 65536, tokenizer_endpoint: str | None = None, detokenizer_endpoint: str | None = None, sampling_params_callback: typing.Callable[[SamplingParams, dict[str, typing.Any]], SamplingParams | None] | None = None, extra_eos_token_ids: list[int] | None = None, silent_mode: bool = False, **kwargs)[source]#

Bases: object

High-level engine interface for text generation with eSurge.

eSurge is a high-performance inference engine built on JAX that provides:
  • Efficient batched inference with paged attention

  • Continuous batching with background scheduling

  • Streaming generation with delta text tracking

  • Comprehensive monitoring and metrics

  • Thread-safe request handling

  • Dynamic context management with automatic prompt truncation

The engine runs a background scheduler thread that continuously processes requests from the queue, enabling high throughput and low latency.

Key Features:
  • Context Management: Automatically manages context length with configurable truncation strategies and token reservation.

  • Streaming Support: Efficient incremental decoding with configurable intervals for optimal performance.

  • Monitoring: Built-in Prometheus metrics and console monitoring (visualize with Grafana).
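To make the idea of token reservation and truncation strategies concrete, here is a minimal sketch of how a prompt budget could be derived from max_model_len and reserve_tokens. The function `truncate_prompt` is hypothetical, and the meaning assigned to each mode ('left' drops the oldest tokens, 'right' drops the newest, 'middle' keeps both ends) is an assumption; the source does not define these semantics.

```python
def truncate_prompt(token_ids, max_model_len, reserve_tokens, mode="left"):
    """Trim token_ids so the prompt plus reserved generation tokens fits max_model_len.

    Assumed semantics (not confirmed by the source): "left" drops tokens from the
    start, "right" drops tokens from the end, "middle" keeps both ends.
    """
    budget = max_model_len - reserve_tokens
    if budget <= 0:
        raise ValueError("reserve_tokens must be smaller than max_model_len")
    if len(token_ids) <= budget:
        return list(token_ids)
    if mode == "left":
        return list(token_ids[-budget:])
    if mode == "right":
        return list(token_ids[:budget])
    if mode == "middle":
        head = budget // 2
        tail = budget - head
        return list(token_ids[:head]) + list(token_ids[-tail:])
    raise ValueError(f"unknown truncate mode: {mode!r}")


tokens = list(range(10))
# Budget is 8 - 2 = 6 prompt tokens in each call below.
left = truncate_prompt(tokens, max_model_len=8, reserve_tokens=2, mode="left")
right = truncate_prompt(tokens, max_model_len=8, reserve_tokens=2, mode="right")
middle = truncate_prompt(tokens, max_model_len=8, reserve_tokens=2, mode="middle")
print(left, right, middle)
```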

Example

>>> # Initialize engine
>>> engine = eSurge(
...     model="model-name",
...     max_model_len=8192,
...     reserve_tokens=800  # Reserve tokens for generation
... )
>>> engine.initiate()
>>>
>>> # Generate with streaming
>>> for output in engine.stream("Tell me a story"):
...     print(output.delta_text, end="", flush=True)
abort_request(request_id: str) None[source]#

Abort an in-progress request.

Marks the request as aborted and signals any waiting threads. The request will be removed from the scheduler queue if still waiting.

Parameters

request_id – ID of the request to abort.

chat(messages: list[dict[str, str]], tools: list[dict] | None = None, sampling_params: easydel.inference.sampling_params.SamplingParams | None = None, request_id: str | None = None, stream: bool = False, chat_template: str | None = None)[source]#

High-level chat interface compatible with vLLM and OpenAI APIs.

Provides a convenient chat-based interface for conversational AI applications. Automatically formats messages using the model’s chat template and handles both streaming and non-streaming responses.

Parameters
  • messages – List of message dictionaries representing the conversation history. Each message must have ‘role’ and ‘content’ keys. Supported roles are typically ‘system’, ‘user’, and ‘assistant’, but may vary by model. Example: [{"role": "user", "content": "Hello!"}]

  • tools – Optional list of tool/function definitions for function calling. Format should match the model’s expected tool schema. Tools allow the model to request function calls as part of its response.

  • sampling_params – Generation parameters controlling temperature, top_p, max_tokens, etc. Defaults to SamplingParams(max_tokens=128) if None.

  • request_id – Optional unique identifier for tracking this request. Auto-generated if None.

  • stream – If True, returns an iterator yielding incremental RequestOutput objects with delta_text for real-time streaming. If False, returns a single RequestOutput with the complete response.

  • chat_template – Optional custom Jinja2 template to override the tokenizer’s default chat template. Useful for models with non-standard formats.

Returns
  • If stream=False: Single RequestOutput object containing the complete assistant response with all metrics and generated text.

  • If stream=True: Iterator[RequestOutput] yielding incremental updates with delta_text containing newly generated text chunks.

Return type

RequestOutput | Iterator[RequestOutput]

Raises
  • ValueError – If messages format is invalid or empty.

  • RuntimeError – If scheduler is not running or tokenizer lacks chat template.

Example

>>> # Non-streaming chat
>>> messages = [
...     {"role": "system", "content": "You are a helpful assistant."},
...     {"role": "user", "content": "Explain quantum computing"}
... ]
>>> response = engine.chat(messages, sampling_params=SamplingParams(max_tokens=200))
>>> print(response.get_text())
>>>
>>> # Streaming chat with function calling
>>> tools = [{
...     "type": "function",
...     "function": {
...         "name": "get_weather",
...         "description": "Get weather for a location",
...         "parameters": {...}
...     }
... }]
>>> for chunk in engine.chat(messages, tools=tools, stream=True):
...     print(chunk.delta_text, end="", flush=True)
>>>
>>> # Custom chat template
>>> custom_template = "{% for message in messages %}...{% endfor %}"
>>> response = engine.chat(messages, chat_template=custom_template)

Note

This method provides compatibility with OpenAI’s chat completions API and vLLM’s chat interface, making it easy to migrate existing applications. The exact behavior of tool calling and special tokens depends on the specific model being used.
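Because an invalid or empty messages list raises ValueError, a caller may want to check the structure before submitting. The helper below is a hypothetical client-side pre-check that enforces only the documented requirement (a non-empty list of dicts with ‘role’ and ‘content’ keys); it is not the engine's own validation logic.

```python
def validate_messages(messages):
    """Client-side pre-check: non-empty list of dicts with 'role' and 'content' keys."""
    if not isinstance(messages, list) or not messages:
        raise ValueError("messages must be a non-empty list")
    for position, message in enumerate(messages):
        if not isinstance(message, dict):
            raise ValueError(f"message {position} is not a dict")
        missing = {"role", "content"} - message.keys()
        if missing:
            raise ValueError(f"message {position} is missing keys: {sorted(missing)}")
    return messages


conversation = validate_messages(
    [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Explain quantum computing"},
    ]
)
print(len(conversation))
```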

property esurge_name: str#
generate(prompts: str | list[str], sampling_params: easydel.inference.sampling_params.SamplingParams | None = None, request_id: str | list[str] | None = None, use_tqdm: bool = True) list[easydel.inference.esurge.esurge_engine.RequestOutput][source]#

Generate completions for one or more prompts (blocking).

Synchronous batch generation that waits for all completions to finish before returning. Suitable for batch processing scenarios where you need all results at once.

Parameters
  • prompts – Single prompt string or list of prompts to generate from.

  • sampling_params – Generation parameters controlling temperature, top_p, max_tokens, etc. Defaults to SamplingParams(max_tokens=128) if None.

  • request_id – Optional request ID(s) for tracking. Auto-generated if None. Can be a single string (for single prompt) or list of strings.

  • use_tqdm – Show progress bar for batch generation. Useful for tracking progress with multiple prompts.

Returns

List of RequestOutput objects containing:
  • Generated text in the text field

  • Token IDs in the token_ids field

  • Performance metrics (tokens/sec, latency, etc.)

  • Finish reason (‘stop’, ‘length’, ‘eos_token’)

Return type

list[RequestOutput]

Raises
  • RuntimeError – If background scheduler is not running. Call initiate() first.

  • ValueError – If prompts and request_ids have mismatched lengths.

Example

>>> # Single prompt generation
>>> outputs = engine.generate(
...     "What is AI?",
...     SamplingParams(max_tokens=100, temperature=0.7)
... )
>>> print(outputs[0].get_text())
>>>
>>> # Batch generation with progress bar
>>> prompts = ["Question 1?", "Question 2?", "Question 3?"]
>>> outputs = engine.generate(prompts, use_tqdm=True)
>>> for i, output in enumerate(outputs):
...     print(f"Prompt {i}: {output.get_text()[:50]}...")
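Since mismatched prompts and request IDs raise ValueError, it can help to normalize both arguments before calling generate(). The sketch below is a hypothetical helper, not part of the library: it accepts the same shapes the signature allows (a single string or a list) and generates IDs with `uuid` when none are given. The ID format is an arbitrary choice for illustration.

```python
import uuid


def pair_request_ids(prompts, request_ids=None):
    """Return (request_id, prompt) pairs, raising ValueError on a length mismatch."""
    if isinstance(prompts, str):
        prompts = [prompts]
    if request_ids is None:
        # Hypothetical ID format; the engine's own auto-generated IDs may differ.
        request_ids = [f"req-{uuid.uuid4().hex}" for _ in prompts]
    elif isinstance(request_ids, str):
        request_ids = [request_ids]
    if len(prompts) != len(request_ids):
        raise ValueError(
            f"got {len(prompts)} prompts but {len(request_ids)} request IDs"
        )
    return list(zip(request_ids, prompts))


pairs = pair_request_ids(["Question 1?", "Question 2?"], ["a", "b"])
print(pairs)
```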
get_metrics_summary() dict[str, Any][source]#

Get current performance metrics summary.

Returns

Dictionary containing:
  • requests_per_second: Current request throughput

  • average_latency: Average request latency

  • average_ttft: Average time to first token

  • average_throughput: Average tokens/second

  • total_completed: Total completed requests

  • total_failed: Total failed requests

  • total_tokens: Total tokens generated

  • active_requests: Currently active requests

  • queue_size: Pending requests in queue

  • running_requests: Currently running requests

Return type

dict[str, Any]
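A summary dictionary with the keys listed above can be condensed into a one-line status string for logs. The sketch below operates on a hand-written sample dictionary rather than a live engine, and it assumes average_ttft is expressed in seconds (the source does not state the unit); `format_metrics` is a hypothetical helper.

```python
def format_metrics(summary):
    """Render a few documented summary keys as a single status line."""
    return (
        f"{summary['total_completed']} done, {summary['total_failed']} failed | "
        f"{summary['average_throughput']:.1f} tok/s | "
        # Assumes average_ttft is in seconds; converted to milliseconds for display.
        f"TTFT {summary['average_ttft'] * 1000:.0f} ms | "
        f"queue {summary['queue_size']}"
    )


# Sample values invented for illustration only.
sample_summary = {
    "total_completed": 12,
    "total_failed": 1,
    "average_throughput": 85.3,
    "average_ttft": 0.25,
    "queue_size": 3,
}
line = format_metrics(sample_summary)
print(line)
```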

initiate() None[source]#

Start the background scheduler thread.

Initiates a daemon thread that continuously runs the scheduler loop, processing requests from the queue and updating outputs. This must be called before using generate() or stream() methods.

The scheduler thread will:
  1. Schedule requests from the waiting queue

  2. Execute model forward passes

  3. Update request outputs with generated tokens

  4. Signal waiting threads when updates are available

property monitoring_active: bool#
property num_pending_requests: int#

Get the number of requests waiting in queue.

Returns

Count of requests in the waiting queue.

property num_running_requests: int#

Get the number of actively running requests.

Returns

Count of requests currently being processed.

pause() None[source]#

Pause the background scheduler without clearing queued state.

resume() None[source]#

Resume the scheduler if it was paused.

set_sampling_params_callback(callback: Optional[Callable[[SamplingParams, dict[str, Any]], easydel.inference.sampling_params.SamplingParams | None]]) None[source]#

Register or clear the sampling-params callback.

Parameters

callback – Callable receiving a cloned SamplingParams and metadata dict (request_id, prompt, engine). Return a new SamplingParams, mutate the provided one, or return None to keep the original values. Pass None to disable the callback.
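A callback matching the described contract might cap max_tokens for every request. In the sketch below, `SamplingParamsStandIn` is a local dataclass standing in for SamplingParams so the example runs without EasyDeL; its fields and defaults, the `cap_max_tokens` name, and the limit of 256 are all assumptions made for illustration.

```python
from dataclasses import dataclass, replace


@dataclass
class SamplingParamsStandIn:
    """Local stand-in for SamplingParams (fields and defaults are assumptions)."""

    max_tokens: int = 128
    temperature: float = 1.0


def cap_max_tokens(params, metadata):
    """Return a capped copy when max_tokens exceeds the limit, else None to keep it."""
    limit = 256  # hypothetical per-request ceiling
    if params.max_tokens > limit:
        return replace(params, max_tokens=limit)
    return None  # None means: keep the original values


metadata = {"request_id": "req-0", "prompt": "Hello", "engine": None}
capped = cap_max_tokens(SamplingParamsStandIn(max_tokens=1000, temperature=0.7), metadata)
unchanged = cap_max_tokens(SamplingParamsStandIn(max_tokens=64), metadata)
print(capped, unchanged)
```

With a real engine, such a function would be passed to set_sampling_params_callback(), and passing None would disable it again.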

start_monitoring(dashboard_port: int | None = None, prometheus_port: int = 11184, dashboard_host: str | None = None, enable_prometheus: bool = True, enable_dashboard: bool | None = None, enable_console: bool = False, log_file: str | None = None, log_interval: float = 10.0, history_size: int = 1000, enable_detailed_logging: bool = True, start_grafana: bool = True, grafana_port: int = 3000, grafana_host: str | None = None, grafana_image: str = 'grafana/grafana-oss:latest', grafana_use_docker: bool = False, grafana_admin_user: str = 'admin', grafana_admin_password: str = 'admin', grafana_allow_anonymous: bool = True, grafana_datasource_name: str = 'eSurge Prometheus', grafana_datasource_uid: str | None = None, grafana_datasource_url: str | None = None) dict[str, str][source]#

Start Prometheus-based monitoring for the engine.

Initializes the Prometheus metrics exporter, optional console monitor, and (by default) tries to auto-start a Grafana instance with a pre-provisioned Prometheus data source (local grafana-server first, optionally Docker if enabled).

Parameters
  • dashboard_port – Deprecated; no longer used (kept for compatibility).

  • prometheus_port – Port for Prometheus metrics endpoint.

  • dashboard_host – Deprecated; no longer used (kept for compatibility).

  • enable_prometheus – Start Prometheus metrics server.

  • enable_dashboard – Deprecated; no longer used (kept for compatibility).

  • enable_console – Start console monitor with rich display.

  • log_file – Optional file path for metrics logging.

  • log_interval – Interval in seconds between metric logs.

  • history_size – Number of historical metrics to retain.

  • enable_detailed_logging – Enable detailed metric logging.

  • start_grafana – Auto-start Grafana pointed at the Prometheus endpoint (tries a local grafana-server first; falls back to Docker only when grafana_use_docker=True).

  • grafana_port – Host port to expose Grafana.

  • grafana_host – Hostname to use when reporting Grafana URL (defaults to localhost).

  • grafana_image – Docker image for Grafana (used when grafana_use_docker=True).

  • grafana_use_docker – Allow falling back to Docker for Grafana when local server is unavailable.

  • grafana_admin_user – Admin username for Grafana.

  • grafana_admin_password – Admin password for Grafana.

  • grafana_allow_anonymous – Allow anonymous admin access to Grafana (for quick local use).

  • grafana_datasource_name – Display name for the auto-provisioned Prometheus data source.

  • grafana_datasource_uid – UID for the Prometheus data source (auto-generated if None).

  • grafana_datasource_url – Override URL for the Prometheus data source inside Grafana.

Returns

Dictionary of service URLs:
  • ‘prometheus’: Prometheus metrics endpoint

  • ‘grafana’: Grafana UI (when auto-start succeeds)

Return type

dict[str, str]

start_profiling(output_dir: str, num_batches: int = 10, host_tracer_level: int | None = None, python_tracer_level: int | None = None) None[source]#

Start a JAX profiler trace for the next num_batches scheduler updates.

stop_monitoring() None[source]#

Stop all monitoring services.

Gracefully shuts down Prometheus server and console monitor if they are running.

stop_profiling() None[source]#

Stop the active JAX profiler trace, if any.

stream(prompts: str | list[str], sampling_params: easydel.inference.sampling_params.SamplingParams | None = None, request_id: str | None = None) Iterator[RequestOutput][source]#

Stream generation output as tokens are produced.

Yields RequestOutput objects incrementally as new tokens are generated, enabling real-time streaming of generated text. Perfect for interactive applications and chat interfaces.

Parameters
  • prompts – Single prompt string or list with one prompt. For multiple prompts, use generate() instead.

  • sampling_params – Generation parameters controlling temperature, top_p, max_tokens, etc. Defaults to SamplingParams(max_tokens=128).

  • request_id – Optional request ID for tracking. Auto-generated if None.

Yields

RequestOutput objects with incremental updates:
  • delta_text: Only the newly generated text since last yield

  • accumulated_text: Full text generated so far

  • finished: True when generation is complete

  • tokens_per_second: Current generation throughput

  • num_generated_tokens: Total tokens generated so far

Raises
  • ValueError – If empty prompt list provided.

  • RuntimeError – If scheduler not running or request setup fails.

Example

>>> # Basic streaming
>>> for output in engine.stream("Tell me a story"):
...     if output.delta_text:
...         print(output.delta_text, end="", flush=True)
...     if output.finished:
...         break
>>>
>>> # Monitor generation speed
>>> for output in engine.stream("Long prompt here..."):
...     if output.delta_text:
...         print(output.delta_text, end="")
...     if output.num_generated_tokens % 10 == 0:
...         print(f"\n[{output.tokens_per_second:.1f} tok/s]", end="")

terminate() None[source]#

Stop the background scheduler thread.

Gracefully shuts down the scheduler loop and waits for the thread to terminate. Should be called when the engine is no longer needed to free resources.
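Because initiate() and terminate() form a start/stop pair, one way to guarantee cleanup is a small context manager that calls terminate() even when an exception occurs. The `running` helper and `FakeEngine` below are hypothetical; `FakeEngine` only records calls so the sketch can run without a model.

```python
from contextlib import contextmanager


@contextmanager
def running(engine):
    """Start the scheduler on entry and always stop it on exit."""
    engine.initiate()
    try:
        yield engine
    finally:
        engine.terminate()


class FakeEngine:
    """Hypothetical stand-in that records lifecycle calls instead of running a model."""

    def __init__(self):
        self.calls = []

    def initiate(self):
        self.calls.append("initiate")

    def terminate(self):
        self.calls.append("terminate")


fake = FakeEngine()
try:
    with running(fake) as engine:
        raise RuntimeError("simulated failure during generation")
except RuntimeError:
    pass
print(fake.calls)
```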

update_model_weights(model: EasyDeLBaseModule | None = None, *, graphdef=None, graphstate=None, graphother=None, restart_scheduler: bool = True) None[source]#

Hot-swap the underlying model weights/graphs.

The engine must be idle (no pending or running requests) before calling this method. It temporarily stops the scheduler loop, refreshes runner state, rebuilds the scheduler, and optionally restarts background serving.

Parameters
  • model – Optional EasyDeLBaseModule carrying the new weights.

  • graphdef – Optional graphdef override.

  • graphstate – Optional graphstate override.

  • graphother – Optional graphother override.

  • restart_scheduler – Restart the scheduler thread if it was previously running (default: True).

Raises
  • RuntimeError – If there are active or pending requests.

  • ValueError – If no model/graph data is provided.
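Since a weight swap requires an idle engine, a caller could poll num_pending_requests and num_running_requests before attempting it. The `wait_until_idle` helper and `FakeEngine` below are hypothetical and exist only to make the sketch runnable; the timeout and poll interval are arbitrary.

```python
import time


def wait_until_idle(engine, timeout=30.0, poll_interval=0.05):
    """Poll until no requests are pending or running; return False on timeout."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if engine.num_pending_requests == 0 and engine.num_running_requests == 0:
            return True
        time.sleep(poll_interval)
    return False


class FakeEngine:
    """Hypothetical stand-in whose running-request count drains by one per poll."""

    def __init__(self, running):
        self._running = running
        self.num_pending_requests = 0

    @property
    def num_running_requests(self):
        current = self._running
        self._running = max(0, self._running - 1)
        return current


idle = wait_until_idle(FakeEngine(running=2), timeout=2.0, poll_interval=0.01)
print(idle)
```

A check like this narrows the window but does not close it: a request submitted between the final poll and the swap would still trigger the documented RuntimeError, so callers should also stop submitting new work first.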