easydel.inference.esurge.esurge_engine
eSurge Engine - High-Performance Inference Engine for EasyDeL.
This module provides the eSurge engine, a high-performance text generation system built on JAX that offers efficient batched inference with advanced features like continuous batching and comprehensive monitoring.
- Key Components:
eSurge: Main engine class for text generation
RequestOutput: Container for generation results and metrics
CompletionOutput: Individual completion within a batch
- Features:
Continuous Batching: Background scheduler thread processes requests continuously for optimal throughput.
Context Management: Automatic prompt truncation and token reservation with configurable strategies.
Streaming Support: Real-time token streaming with delta updates.
Monitoring: Built-in Prometheus metrics and console monitor (Grafana-ready).
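Here is a minimal sketch of how the three truncate_mode strategies could behave. It assumes that the prompt budget is max_model_len - reserve_tokens and that 'left' drops the oldest tokens. truncate_prompt is a hypothetical helper for illustration, not eSurge's internal code.

```python
from typing import Literal


def truncate_prompt(
    token_ids: list[int],
    max_model_len: int,
    reserve_tokens: int,
    mode: Literal["left", "right", "middle"] = "left",
) -> list[int]:
    """Hypothetical helper: keep at most max_model_len - reserve_tokens prompt tokens."""
    # Assumption: tokens reserved for generation are subtracted from the context window.
    budget = max(max_model_len - reserve_tokens, 0)
    if len(token_ids) <= budget:
        return list(token_ids)
    if mode == "left":
        # Drop the oldest tokens and keep the most recent ones.
        return token_ids[len(token_ids) - budget:]
    if mode == "right":
        # Keep the beginning of the prompt and drop the tail.
        return token_ids[:budget]
    if mode == "middle":
        # Keep the head and the tail, and drop tokens from the middle.
        head = budget // 2
        tail = budget - head
        return token_ids[:head] + token_ids[len(token_ids) - tail:]
    raise ValueError(f"unknown truncate mode: {mode!r}")


ids = list(range(10))
print(truncate_prompt(ids, max_model_len=8, reserve_tokens=2, mode="middle"))
```

Middle truncation keeps both the instructions at the start of a prompt and the most recent context at the end. This is often the reason to prefer it over a one-sided cut.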
- Usage Example:
>>> from easydel.inference.esurge import eSurge
>>> from easydel.inference.sampling_params import SamplingParams
>>>
>>> # Initialize engine
>>> engine = eSurge(
...     model="model-name",
...     max_model_len=8192,
...     reserve_tokens=800
... )
>>> engine.initiate()  # start the background scheduler
>>>
>>> # Stream generation
>>> for output in engine.stream("Tell me about AI"):
...     print(output.delta_text, end="", flush=True)
>>>
>>> # Batch generation
>>> outputs = engine.generate(
...     ["Question 1?", "Question 2?"],
...     SamplingParams(max_tokens=100, temperature=0.7)
... )
- Technical Details:
The engine uses a multi-threaded architecture with:
- Main thread: handles API calls and request submission
- Scheduler thread: continuously processes queued requests
- JAX computation: executes model forward passes
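The split between a submitting main thread and a background scheduler thread can be sketched with the standard library alone. ToyEngine and fake_forward are hypothetical names. fake_forward stands in for the JAX forward pass, and this is not how eSurge is implemented.

```python
import queue
import threading


def fake_forward(prompt: str) -> str:
    # Hypothetical stand-in for a JAX model forward pass.
    return prompt.upper()


class ToyEngine:
    """Hypothetical layout: callers enqueue requests, a daemon thread processes them."""

    def __init__(self) -> None:
        self.requests: queue.Queue = queue.Queue()
        self.results: dict[str, str] = {}
        self._lock = threading.Lock()
        self._worker = threading.Thread(target=self._scheduler_loop, daemon=True)
        self._worker.start()

    def _scheduler_loop(self) -> None:
        while True:
            item = self.requests.get()
            if item is None:  # shutdown sentinel
                self.requests.task_done()
                break
            request_id, prompt = item
            output = fake_forward(prompt)
            with self._lock:
                self.results[request_id] = output
            self.requests.task_done()

    def submit(self, request_id: str, prompt: str) -> None:
        # Main thread: enqueue and return immediately.
        self.requests.put((request_id, prompt))

    def wait_all(self) -> None:
        # Block until every submitted item has been marked done.
        self.requests.join()

    def shutdown(self) -> None:
        self.requests.put(None)
        self._worker.join()


engine = ToyEngine()
engine.submit("r1", "hello")
engine.wait_all()
print(engine.results)
engine.shutdown()
```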
- class easydel.inference.esurge.esurge_engine.CompletionOutput(index: int, text: str, token_ids: list[int], cumulative_logprob: float | None = None, logprobs: list[dict[int, float]] | None = None, finish_reason: str | None = None)
Bases: object
Output of a single completion.
Represents the generated output for a single completion within a batch request. Contains the generated text, token IDs, and optional probability information.
- index
Position of this completion in the batch (0-indexed).
- Type
int
- text
The generated text string.
- Type
str
- token_ids
List of token IDs that were generated.
- Type
list[int]
- cumulative_logprob
Cumulative log probability of the generated sequence.
- Type
float | None
- logprobs
Per-token log probabilities: one dict per generated position, mapping token_id to logprob.
- Type
list[dict[int, float]] | None
- finish_reason
Reason for completion termination (‘stop’, ‘length’, ‘eos_token’, etc.).
- Type
str | None
- index: int
- text: str
- token_ids: list[int]
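The following sketch shows how cumulative_logprob relates to logprobs. It assumes that the cumulative value is the sum of the log probabilities of the tokens that were actually sampled. ToyCompletion is a hypothetical stand-in that mirrors two CompletionOutput fields.

```python
import math
from dataclasses import dataclass


@dataclass
class ToyCompletion:
    # Hypothetical stand-in mirroring a subset of CompletionOutput's fields.
    token_ids: list[int]
    logprobs: list[dict[int, float]]


def cumulative_logprob(c: ToyCompletion) -> float:
    """Sum the logprob of each sampled token (assumed definition)."""
    return sum(step[tok] for tok, step in zip(c.token_ids, c.logprobs))


comp = ToyCompletion(
    token_ids=[5, 9],
    logprobs=[{5: math.log(0.5), 7: math.log(0.3)}, {9: math.log(0.25)}],
)
# Exponentiating gives the joint probability of the generated sequence.
print(math.exp(cumulative_logprob(comp)))
```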
- class easydel.inference.esurge.esurge_engine.RequestOutput(request_id: str, prompt: str | list[str], prompt_token_ids: list[list[int]] | list[int], outputs: list[easydel.inference.esurge.esurge_engine.CompletionOutput], finished: bool = False, metrics: dict[str, Any] | None = None, accumulated_text: str = '', delta_text: str = '', tokens_per_second: float = 0.0, num_generated_tokens: int = 0, time_spent_generating: float = 0.0, first_token_time: float | None = None, processing_time: float = 0.0, update_seq: int = 0, delta_seq: int = 0)
Bases: object
Output of a generation request with comprehensive metrics.
Contains the complete output for a generation request including generated text, performance metrics, and streaming support fields. Used for both batch and streaming generation modes.
- request_id
Unique identifier for this request.
- Type
str
- prompt
Original prompt text.
- Type
str | list[str]
- prompt_token_ids
Tokenized prompt as list of token IDs.
- Type
list[list[int]] | list[int]
- outputs
List of CompletionOutput objects (one per n in the sampling params).
- Type
list[CompletionOutput]
- finished
Whether generation has completed.
- Type
bool
- metrics
Dictionary of performance metrics (tokens, timing, etc.).
- Type
dict[str, Any] | None
- accumulated_text
Full generated text accumulated so far.
- Type
str
- delta_text
Only the latest decoded text chunk (for streaming).
- Type
str
- tokens_per_second
Current generation throughput.
- Type
float
- num_generated_tokens
Total number of tokens generated.
- Type
int
- time_spent_generating
Total time spent in generation.
- Type
float
- first_token_time
Time to first token (TTFT) in seconds.
- Type
float | None
- processing_time
Total processing time including queuing.
- Type
float
- update_seq
Sequence number incremented on any update.
- Type
int
- delta_seq
Sequence number incremented only when delta_text changes.
- Type
int
- accumulated_text: str = ''
- delta_seq: int = 0
- delta_text: str = ''
- finished: bool = False
- get_summary() dict[str, Any]
Get a summary of the request output.
- Returns
Dictionary containing key metrics: request_id, text, throughput, token count, timing, completion status, and finish reason.
- Return type
dict[str, Any]
- get_text() str
Get the generated text from the first completion output.
- Returns
Generated text string, or empty string if no outputs.
- num_generated_tokens: int = 0
- outputs: list[easydel.inference.esurge.esurge_engine.CompletionOutput]
- processing_time: float = 0.0
- prompt: str | list[str]
- prompt_token_ids: list[list[int]] | list[int]
- request_id: str
- time_spent_generating: float = 0.0
- tokens_per_second: float = 0.0
- update_seq: int = 0
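A streaming consumer can use delta_seq to tell a new text chunk apart from an update that only refreshed metrics, because update_seq advances on every update. The sketch below assumes that the updates arrive in order. Update is a hypothetical stand-in that carries only the streaming fields of RequestOutput.

```python
from dataclasses import dataclass


@dataclass
class Update:
    # Hypothetical stand-in carrying RequestOutput's streaming fields.
    delta_text: str
    update_seq: int
    delta_seq: int
    finished: bool = False


def collect_new_text(updates: list[Update]) -> str:
    """Append delta_text only when delta_seq advances; skip metric-only updates."""
    last_delta_seq = 0  # matches the RequestOutput default
    pieces: list[str] = []
    for u in updates:
        if u.delta_seq != last_delta_seq:
            pieces.append(u.delta_text)
            last_delta_seq = u.delta_seq
    return "".join(pieces)


updates = [
    Update("Hel", update_seq=1, delta_seq=1),
    Update("Hel", update_seq=2, delta_seq=1),  # metrics refresh, same chunk
    Update("lo", update_seq=3, delta_seq=2),
    Update("lo", update_seq=4, delta_seq=2, finished=True),
]
print(collect_new_text(updates))
```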
- class easydel.inference.esurge.esurge_engine.eSurge(model: str | EasyDeLBaseModule, tokenizer: str | PreTrainedTokenizerBase | None = None, dtype: jnp.dtype = <class 'jax.numpy.bfloat16'>, max_model_len: int = 8192, min_input_pad: int = 16, max_num_seqs: int = 256, max_num_batched_tokens: int | None = None, hbm_utilization: float = 0.85, page_size: int = 128, use_aot_forward: bool = True, enable_prefix_caching: bool = True, auto_shard_model: bool = True, sharding_axis_dims: tuple[int, ...] = (1, 1, 1, -1, 1), compile_runner: bool = True, runner_verbose: bool = False, overlap_execution: bool = False, sampler_metrics: bool = False, esurge_name: str | None = None, reserve_tokens: int | None = None, auto_truncate_prompt: bool = True, auto_cap_new_tokens: bool = True, strict_context: bool = False, truncate_mode: typing.Literal['left', 'right', 'middle'] = 'left', prefer_preserve_prompt: bool = True, decode_truncated_prompt: bool = True, destroy_pages_on_pause: bool = True, detokenizer_max_states: int = 65536, tokenizer_endpoint: str | None = None, detokenizer_endpoint: str | None = None, sampling_params_callback: typing.Callable[[SamplingParams, dict[str, typing.Any]], SamplingParams | None] | None = None, extra_eos_token_ids: list[int] | None = None, silent_mode: bool = False, **kwargs)
Bases: object
High-level engine interface for text generation with eSurge.
eSurge is a high-performance inference engine built on JAX that provides:
- Efficient batched inference with paged attention
- Continuous batching with background scheduling
- Streaming generation with delta text tracking
- Comprehensive monitoring and metrics
- Thread-safe request handling
- Dynamic context management with automatic prompt truncation
The engine runs a background scheduler thread that continuously processes requests from the queue, enabling high throughput and low latency.
- Key Features:
Context Management: Automatically manages context length with configurable truncation strategies and token reservation.
Streaming Support: Efficient incremental decoding with configurable intervals for optimal performance.
Monitoring: Built-in Prometheus metrics and console monitoring (visualize with Grafana).
Example
>>> # Initialize engine
>>> engine = eSurge(
...     model="model-name",
...     max_model_len=8192,
...     reserve_tokens=800  # Reserve tokens for generation
... )
>>> engine.initiate()
>>>
>>> # Generate with streaming
>>> for output in engine.stream("Tell me a story"):
...     print(output.delta_text, end="", flush=True)
- abort_request(request_id: str) None
Abort an in-progress request.
Marks the request as aborted and signals any waiting threads. The request will be removed from the scheduler queue if still waiting.
- Parameters
request_id – ID of the request to abort.
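Below is a small model of the documented abort behaviour. The request is flagged as aborted, dropped from the waiting queue if it has not been scheduled yet, and any thread waiting on it is woken. ToyScheduler and its per-request threading.Event objects are hypothetical and are not eSurge internals.

```python
import threading
from collections import deque


class ToyScheduler:
    """Hypothetical sketch: flag the request, dequeue it if waiting, wake waiters."""

    def __init__(self) -> None:
        self.waiting: deque[str] = deque()
        self.aborted: set[str] = set()
        self.events: dict[str, threading.Event] = {}

    def add(self, request_id: str) -> threading.Event:
        self.waiting.append(request_id)
        self.events[request_id] = threading.Event()
        return self.events[request_id]

    def abort_request(self, request_id: str) -> None:
        self.aborted.add(request_id)
        if request_id in self.waiting:
            self.waiting.remove(request_id)  # not yet scheduled: drop it
        event = self.events.get(request_id)
        if event is not None:
            event.set()  # wake any thread blocked on this request


sched = ToyScheduler()
done = sched.add("r1")
sched.add("r2")
sched.abort_request("r1")
print(list(sched.waiting), done.is_set())
```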
- chat(messages: list[dict[str, str]], tools: list[dict] | None = None, sampling_params: easydel.inference.sampling_params.SamplingParams | None = None, request_id: str | None = None, stream: bool = False, chat_template: str | None = None)
High-level chat interface compatible with vLLM and OpenAI APIs.
Provides a convenient chat-based interface for conversational AI applications. Automatically formats messages using the model’s chat template and handles both streaming and non-streaming responses.
- Parameters
messages – List of message dictionaries representing the conversation history. Each message must have ‘role’ and ‘content’ keys. Supported roles are typically ‘system’, ‘user’, and ‘assistant’, but may vary by model. Example: [{"role": "user", "content": "Hello!"}]
tools – Optional list of tool/function definitions for function calling. Format should match the model’s expected tool schema. Tools allow the model to request function calls as part of its response.
sampling_params – Generation parameters controlling temperature, top_p, max_tokens, etc. Defaults to SamplingParams(max_tokens=128) if None.
request_id – Optional unique identifier for tracking this request. Auto-generated if None.
stream – If True, returns an iterator yielding incremental RequestOutput objects with delta_text for real-time streaming. If False, returns a single RequestOutput with the complete response.
chat_template – Optional custom Jinja2 template to override the tokenizer’s default chat template. Useful for models with non-standard formats.
- Returns
If stream=False: a single RequestOutput containing the complete assistant response with all metrics and generated text.
If stream=True: an Iterator[RequestOutput] yielding incremental updates whose delta_text contains the newly generated text chunks.
- Return type
RequestOutput | Iterator[RequestOutput]
- Raises
ValueError – If messages format is invalid or empty.
RuntimeError – If scheduler is not running or tokenizer lacks chat template.
Example
>>> # Non-streaming chat
>>> messages = [
...     {"role": "system", "content": "You are a helpful assistant."},
...     {"role": "user", "content": "Explain quantum computing"}
... ]
>>> response = engine.chat(messages, sampling_params=SamplingParams(max_tokens=200))
>>> print(response.get_text())
>>>
>>> # Streaming chat with function calling
>>> tools = [{
...     "type": "function",
...     "function": {
...         "name": "get_weather",
...         "description": "Get weather for a location",
...         "parameters": {...}
...     }
... }]
>>> for chunk in engine.chat(messages, tools=tools, stream=True):
...     print(chunk.delta_text, end="", flush=True)
>>>
>>> # Custom chat template
>>> custom_template = "{% for message in messages %}...{% endfor %}"
>>> response = engine.chat(messages, chat_template=custom_template)
Note
This method provides compatibility with OpenAI’s chat completions API and vLLM’s chat interface, making it easy to migrate existing applications. The exact behavior of tool calling and special tokens depends on the specific model being used.
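Client code can run a pre-check on the messages argument before calling chat() so that malformed input never reaches the engine. The check mirrors the documented ValueError conditions: the list must be non-empty, and every entry must be a dict with ‘role’ and ‘content’ keys. validate_messages is a hypothetical helper, and eSurge's own validation may be stricter or looser.

```python
def validate_messages(messages: list[dict[str, str]]) -> None:
    """Hypothetical pre-check mirroring the documented ValueError conditions."""
    if not messages:
        raise ValueError("messages must be a non-empty list")
    for i, message in enumerate(messages):
        if not isinstance(message, dict) or "role" not in message or "content" not in message:
            raise ValueError(f"message {i} must be a dict with 'role' and 'content' keys")


validate_messages([
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
])
```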
- property esurge_name: str
- generate(prompts: str | list[str], sampling_params: easydel.inference.sampling_params.SamplingParams | None = None, request_id: str | list[str] | None = None, use_tqdm: bool = True) list[easydel.inference.esurge.esurge_engine.RequestOutput]
Generate completions for one or more prompts (blocking).
Synchronous batch generation that waits for all completions to finish before returning. Suitable for batch processing scenarios where you need all results at once.
- Parameters
prompts – Single prompt string or list of prompts to generate from.
sampling_params – Generation parameters controlling temperature, top_p, max_tokens, etc. Defaults to SamplingParams(max_tokens=128) if None.
request_id – Optional request ID(s) for tracking. Auto-generated if None. Can be a single string (for single prompt) or list of strings.
use_tqdm – Show progress bar for batch generation. Useful for tracking progress with multiple prompts.
- Returns
List of RequestOutput objects, each containing:
Generated text (the text field of each entry in outputs)
Token IDs (the token_ids field of each entry in outputs)
Performance metrics (tokens/sec, latency, etc.)
Finish reason per completion (‘stop’, ‘length’, ‘eos_token’)
- Return type
list[RequestOutput]
- Raises
RuntimeError – If background scheduler is not running. Call initiate() first.
ValueError – If prompts and request_ids have mismatched lengths.
Example
>>> # Single prompt generation
>>> outputs = engine.generate(
...     "What is AI?",
...     SamplingParams(max_tokens=100, temperature=0.7)
... )
>>> print(outputs[0].get_text())
>>>
>>> # Batch generation with progress bar
>>> prompts = ["Question 1?", "Question 2?", "Question 3?"]
>>> outputs = engine.generate(prompts, use_tqdm=True)
>>> for i, output in enumerate(outputs):
...     print(f"Prompt {i}: {output.get_text()[:50]}...")
- get_metrics_summary() dict[str, Any]
Get current performance metrics summary.
- Returns
Dictionary containing:
requests_per_second: Current request throughput
average_latency: Average request latency
average_ttft: Average time to first token
average_throughput: Average tokens/second
total_completed: Total completed requests
total_failed: Total failed requests
total_tokens: Total tokens generated
active_requests: Currently active requests
queue_size: Pending requests in queue
running_requests: Currently running requests
- Return type
dict[str, Any]
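The sketch below shows how averages such as average_latency, average_ttft, and average_throughput could be aggregated from per-request records. FinishedRequest and summarize are hypothetical. The per-request tokens/second followed by a mean is only one possible definition of average throughput, and eSurge may compute it differently.

```python
from dataclasses import dataclass
from statistics import mean


@dataclass
class FinishedRequest:
    # Hypothetical per-request record; field names are illustrative.
    latency: float   # seconds from submission to completion (assumed > 0)
    ttft: float      # seconds to first token
    num_tokens: int  # tokens generated


def summarize(done: list[FinishedRequest], failed: int = 0) -> dict[str, float]:
    """Aggregate finished requests into summary metrics (illustrative definitions)."""
    if not done:
        return {"total_completed": 0, "total_failed": failed, "total_tokens": 0}
    return {
        "average_latency": mean(r.latency for r in done),
        "average_ttft": mean(r.ttft for r in done),
        # Per-request tokens/second, then averaged across requests.
        "average_throughput": mean(r.num_tokens / r.latency for r in done),
        "total_completed": len(done),
        "total_failed": failed,
        "total_tokens": sum(r.num_tokens for r in done),
    }


print(summarize([FinishedRequest(2.0, 0.5, 100), FinishedRequest(4.0, 1.0, 100)], failed=1))
```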
- initiate() None
Start the background scheduler thread.
Initiates a daemon thread that continuously runs the scheduler loop, processing requests from the queue and updating outputs. This must be called before using generate() or stream() methods.
The scheduler thread will:
1. Schedule requests from the waiting queue
2. Execute model forward passes
3. Update request outputs with generated tokens
4. Signal waiting threads when updates are available
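The four steps above can be modelled as a single-threaded toy loop that shows the continuous-batching behaviour. A waiting request takes a free slot as soon as a running one finishes, without waiting for the whole batch to drain. run_continuous_batching is hypothetical. Each "forward pass" just decrements a per-request token counter, and max_num_seqs plays the same role as the constructor argument of that name.

```python
from collections import deque


def run_continuous_batching(lengths: dict[str, int], max_num_seqs: int) -> list[list[str]]:
    """Toy scheduler loop; returns the running request IDs at each step (hypothetical)."""
    waiting = deque(lengths)   # request IDs in arrival order
    remaining = dict(lengths)  # tokens still to generate per request
    running: list[str] = []
    trace: list[list[str]] = []
    while waiting or running:
        # 1. Schedule: fill free slots from the waiting queue.
        while waiting and len(running) < max_num_seqs:
            running.append(waiting.popleft())
        trace.append(list(running))
        # 2-3. "Forward pass": one new token for every running request.
        for rid in running:
            remaining[rid] -= 1
        # 4. Retire finished requests (a real engine would signal waiters here).
        running = [rid for rid in running if remaining[rid] > 0]
    return trace


print(run_continuous_batching({"a": 1, "b": 3, "c": 2}, max_num_seqs=2))
```

Request "c" enters the batch on the second step, immediately after "a" finishes. A static batcher would instead hold "c" until "b" was also done.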
- property monitoring_active: bool
- property num_pending_requests: int
Get the number of requests waiting in queue.
- Returns
Count of requests in the waiting queue.
- property num_running_requests: int
Get the number of actively running requests.
- Returns
Count of requests currently being processed.
- set_sampling_params_callback(callback: Optional[Callable[[SamplingParams, dict[str, Any]], easydel.inference.sampling_params.SamplingParams | None]]) None
Register or clear the sampling-params callback.
- Parameters
callback – Callable receiving a cloned SamplingParams and a metadata dict containing request_id, prompt, and engine. Return a new SamplingParams, mutate the provided one, or return None to keep the original values. Pass None to disable the callback.
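The example callback below caps the temperature per request and returns None when nothing needs to change. SamplingParams here is a hypothetical two-field stand-in so that the snippet runs without EasyDeL, and MAX_TEMPERATURE is an illustrative policy value. The metadata layout (request_id, prompt, engine) follows the parameter description above.

```python
from dataclasses import dataclass, replace
from typing import Any, Optional


@dataclass
class SamplingParams:
    # Hypothetical stand-in; the real easydel SamplingParams has many more fields.
    temperature: float = 1.0
    max_tokens: int = 128


MAX_TEMPERATURE = 0.8  # illustrative policy value


def clamp_temperature(params: SamplingParams, metadata: dict[str, Any]) -> Optional[SamplingParams]:
    """Callback sketch: cap temperature; metadata carries request_id, prompt, engine."""
    if params.temperature <= MAX_TEMPERATURE:
        return None  # keep the original values
    # Return a fresh object rather than relying on in-place mutation.
    return replace(params, temperature=MAX_TEMPERATURE)


meta = {"request_id": "r1", "prompt": "Tell me a story", "engine": None}
print(clamp_temperature(SamplingParams(temperature=1.2), meta))
```

With a real engine, the same function would be registered with engine.set_sampling_params_callback(clamp_temperature), or passed as the sampling_params_callback constructor argument.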
- start_monitoring(dashboard_port: int | None = None, prometheus_port: int = 11184, dashboard_host: str | None = None, enable_prometheus: bool = True, enable_dashboard: bool | None = None, enable_console: bool = False, log_file: str | None = None, log_interval: float = 10.0, history_size: int = 1000, enable_detailed_logging: bool = True, start_grafana: bool = True, grafana_port: int = 3000, grafana_host: str | None = None, grafana_image: str = 'grafana/grafana-oss:latest', grafana_use_docker: bool = False, grafana_admin_user: str = 'admin', grafana_admin_password: str = 'admin', grafana_allow_anonymous: bool = True, grafana_datasource_name: str = 'eSurge Prometheus', grafana_datasource_uid: str | None = None, grafana_datasource_url: str | None = None) dict[str, str]
Start Prometheus-based monitoring for the engine.
Initializes the Prometheus metrics exporter and the optional console monitor. By default, it also tries to auto-start a Grafana instance with a pre-provisioned Prometheus data source, trying a local grafana-server first and falling back to Docker only if enabled.
- Parameters
dashboard_port – Deprecated; no longer used (kept for compatibility).
prometheus_port – Port for Prometheus metrics endpoint.
dashboard_host – Deprecated; no longer used (kept for compatibility).
enable_prometheus – Start Prometheus metrics server.
enable_dashboard – Deprecated; no longer used (kept for compatibility).
enable_console – Start console monitor with rich display.
log_file – Optional file path for metrics logging.
log_interval – Interval in seconds between metric logs.
history_size – Number of historical metrics to retain.
enable_detailed_logging – Enable detailed metric logging.
start_grafana – Auto-start Grafana pointed at the Prometheus endpoint (local grafana-server first; Docker only when grafana_use_docker=True).
grafana_port – Host port to expose Grafana.
grafana_host – Hostname to use when reporting Grafana URL (defaults to localhost).
grafana_image – Docker image for Grafana (used when grafana_use_docker=True).
grafana_use_docker – Allow falling back to Docker for Grafana when local server is unavailable.
grafana_admin_user – Admin username for Grafana.
grafana_admin_password – Admin password for Grafana.
grafana_allow_anonymous – Allow anonymous admin access to Grafana (for quick local use).
grafana_datasource_name – Display name for the auto-provisioned Prometheus data source.
grafana_datasource_uid – UID for the Prometheus data source (auto-generated if None).
grafana_datasource_url – Override URL for the Prometheus data source inside Grafana.
- Returns
Dictionary of service URLs:
‘prometheus’: Prometheus metrics endpoint
‘grafana’: Grafana UI (when auto-start succeeds)
- Return type
dict[str, str]
- start_profiling(output_dir: str, num_batches: int = 10, host_tracer_level: int | None = None, python_tracer_level: int | None = None) None
Start a JAX profiler trace for the next num_batches scheduler updates.
- stop_monitoring() None
Stop all monitoring services.
Gracefully shuts down Prometheus server and console monitor if they are running.
- stream(prompts: str | list[str], sampling_params: easydel.inference.sampling_params.SamplingParams | None = None, request_id: str | None = None) Iterator[RequestOutput]
Stream generation output as tokens are produced.
Yields RequestOutput objects incrementally as new tokens are generated, enabling real-time streaming of generated text. Well suited to interactive applications and chat interfaces.
- Parameters
prompts – Single prompt string or list with one prompt. For multiple prompts, use generate() instead.
sampling_params – Generation parameters controlling temperature, top_p, max_tokens, etc. Defaults to SamplingParams(max_tokens=128) if None.
request_id – Optional request ID for tracking. Auto-generated if None.
- Yields
RequestOutput objects with incremental updates:
delta_text: Only the newly generated text since the last yield
accumulated_text: Full text generated so far
finished: True when generation is complete
tokens_per_second: Current generation throughput
num_generated_tokens: Total tokens generated so far
- Raises
ValueError – If an empty prompt list is provided.
RuntimeError – If the scheduler is not running or request setup fails.
Example
>>> # Basic streaming
>>> for output in engine.stream("Tell me a story"):
...     if output.delta_text:
...         print(output.delta_text, end="", flush=True)
...     if output.finished:
...         break
>>>
>>> # Monitor generation speed
>>> for output in engine.stream("Long prompt here..."):
...     if output.delta_text:
...         print(output.delta_text, end="")
...     if output.num_generated_tokens % 10 == 0:
...         print(f"\n[{output.tokens_per_second:.1f} tok/s]", end="")
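One simple way to derive delta_text is to decode the token prefix after each new token and emit the suffix beyond the previously decoded text. The sketch below shows this using a hypothetical toy vocabulary in place of a tokenizer, and it is not necessarily eSurge's detokenizer. Real tokenizers can rewrite earlier characters when more tokens arrive (for example, partial multi-byte sequences), so production incremental detokenizers need extra handling that this sketch omits.

```python
from collections.abc import Iterator

VOCAB = {1: "Hel", 2: "lo", 3: ",", 4: " world"}  # hypothetical toy vocabulary


def decode(token_ids: list[int]) -> str:
    return "".join(VOCAB[t] for t in token_ids)


def stream_deltas(token_ids: list[int]) -> Iterator[tuple[str, str]]:
    """Yield (delta_text, accumulated_text) after each token by diffing decodes."""
    previous = ""
    for n in range(1, len(token_ids) + 1):
        accumulated = decode(token_ids[:n])
        delta = accumulated[len(previous):]
        previous = accumulated
        yield delta, accumulated


for delta, _ in stream_deltas([1, 2, 3, 4]):
    print(delta, end="", flush=True)
print()
```

Re-decoding the whole prefix at every step costs time quadratic in the output length. This is one reason an engine may decode at configurable intervals rather than after every token.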
- terminate() None
Stop the background scheduler thread.
Gracefully shuts down the scheduler loop and waits for the thread to terminate. Should be called when the engine is no longer needed to free resources.
- update_model_weights(model: EasyDeLBaseModule | None = None, *, graphdef=None, graphstate=None, graphother=None, restart_scheduler: bool = True) None
Hot-swap the underlying model weights/graphs.
The engine must be idle (no pending or running requests) before calling this method. It temporarily stops the scheduler loop, refreshes runner state, rebuilds the scheduler, and optionally restarts background serving.
- Parameters
model – Optional EasyDeLBaseModule carrying the new weights.
graphdef – Optional graphdef override.
graphstate – Optional graphstate override.
graphother – Optional graphother override.
restart_scheduler – Restart the scheduler thread if it was previously running (default: True).
- Raises
RuntimeError – If there are active or pending requests.
ValueError – If no model/graph data is provided.
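The documented contract of update_model_weights can be modelled as follows. It refuses to run unless the engine is idle, requires some new model or graph data, stops the scheduler, swaps the weights and rebuilds, and then restarts the scheduler only if it was running before. ToyWeightSwapper is hypothetical. The graph overrides are simplified to a single graphstate argument, and the strings stand in for real weights.

```python
from typing import Any, Optional


class ToyWeightSwapper:
    """Hypothetical model of the documented preconditions and stop/swap/restart order."""

    def __init__(self) -> None:
        self.num_pending_requests = 0
        self.num_running_requests = 0
        self.scheduler_running = True
        self.model: Any = "v1"  # stand-in for real weights

    def update_model_weights(
        self,
        model: Optional[Any] = None,
        *,
        graphstate: Optional[Any] = None,
        restart_scheduler: bool = True,
    ) -> None:
        if self.num_pending_requests or self.num_running_requests:
            raise RuntimeError("engine must be idle before swapping weights")
        if model is None and graphstate is None:
            raise ValueError("no model or graph data provided")
        was_running = self.scheduler_running
        self.scheduler_running = False  # stop the scheduler loop
        self.model = model if model is not None else graphstate  # refresh runner state
        if restart_scheduler and was_running:
            self.scheduler_running = True  # resume background serving


swapper = ToyWeightSwapper()
swapper.update_model_weights("v2")
print(swapper.model, swapper.scheduler_running)
```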