eSurge Inference Engine#

eSurge is a high-performance inference engine designed for efficient batched text generation with advanced caching and scheduling capabilities. It provides fine-grained control over memory management and request scheduling, optimized for TPU, GPU, and CPU platforms.

Overview#

eSurge offers a production-ready serving solution with state-of-the-art performance characteristics:

  • Advanced KV Cache Management: Page-based allocation with prefix caching support

  • Multiple Attention Patterns: Full, sliding window, and chunked attention support

  • Flexible Request Scheduling: FCFS and priority-based scheduling algorithms

  • Continuous Batching: Dynamic batching for improved throughput

  • Multi-Backend Support: Optimized for TPU, GPU, and CPU with platform-specific kernels

  • Streaming Generation: Low-latency streaming with delta text updates

  • OpenAI-Compatible API: Drop-in replacement for OpenAI API endpoints

Key Features#

Performance Optimizations#

  • Paged Attention: Efficient memory management with page-based KV cache allocation

  • Prefix Caching: Reuse common prefixes across requests for improved efficiency

  • JIT Compilation: Fully JIT-compiled forward passes for ~20% speedup

  • Background Scheduling: Thread-safe request processing with background scheduler

  • Optimized Token Decoding: Interval-based batching for efficient token processing

Memory Management#

  • HBM Utilization Control: Fine-grained control over high-bandwidth memory usage

  • Dynamic Page Allocation: Automatic memory management with page pooling

  • Configurable Page Sizes: Adjust page size for optimal memory utilization

  • Cache Eviction Policies: Smart eviction strategies for long-running services

Request Handling#

  • Priority Scheduling: Support for request prioritization

  • Request Preemption: Ability to preempt lower-priority requests

  • Batched Processing: Efficient batching of multiple requests

  • Stream-First Design: Optimized for real-time streaming applications

Installation#

For TPU-specific optimizations:

pip install easydel[tpu]

Quick Start#

Basic Usage#

Here’s a simple example to get started with eSurge:

import os
os.environ["EASYDEL_AUTO"] = "1"  # Enable automatic sharding

from jax import lax
from jax import numpy as jnp
from transformers import AutoTokenizer
import easydel as ed

# Load model and tokenizer
model_id = "your-model-id"
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token_id = tokenizer.eos_token_id

model = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
    model_id,
    dtype=jnp.bfloat16,
    param_dtype=jnp.bfloat16,
    precision=lax.Precision.DEFAULT,
    auto_shard_model=True,
    sharding_axis_dims=(1, 1, 1, -1, 1),
    config_kwargs=ed.EasyDeLBaseConfigDict(
        attn_mechanism=ed.AttentionMechanisms.RAGGED_PAGE_ATTENTION_V3,
        attn_dtype=jnp.bfloat16,
    ),
)

# Initialize eSurge engine
engine = ed.eSurge(
    model=model,
    tokenizer=tokenizer,
    max_model_len=8192,
    max_num_seqs=16,
    hbm_utilization=0.9,
    page_size=64,
)

# Generate text
response = engine.generate(
    "What is the meaning of life?",
    sampling_params=ed.SamplingParams(
        max_tokens=100,
        temperature=0.7,
    )
)
print(response[0].outputs[0].text)

Streaming Generation#

eSurge supports streaming generation for real-time applications:

# Stream tokens as they're generated
for output in engine.stream(
    "Tell me a story about a robot",
    sampling_params=ed.SamplingParams(
        max_tokens=200,
        temperature=0.8,
    )
):
    print(output.delta_text, end="", flush=True)

Batch Processing#

Process multiple prompts efficiently:

prompts = [
    "Explain quantum computing",
    "What is machine learning?",
    "How does the internet work?",
]

# Process multiple prompts
outputs = engine.generate(
    prompts,
    sampling_params=ed.SamplingParams(
        max_tokens=100,
        temperature=0.7,
    )
)

for output in outputs:
    print(f"Request {output.request_id}: {output.outputs[0].text}")

Configuration#

Engine Configuration#

eSurge provides extensive configuration options:

engine = ed.eSurge(
    model=model,
    tokenizer=tokenizer,

    # Model configuration
    max_model_len=16384,          # Maximum sequence length
    max_num_seqs=32,               # Maximum concurrent sequences

    # Memory configuration
    hbm_utilization=0.9,           # HBM utilization ratio (0.0-1.0)
    page_size=64,                  # KV cache page size

    # Performance options
    runner_verbose=False,          # Verbose runner logging
    min_input_pad=32,              # Minimum input padding

    # Naming
    esurge_name="my-engine",       # Engine instance name
)

Advanced Configuration#

For fine-grained control, use configuration objects:

from easydel.inference.esurge import (
    Config,
    SchedulerConfig,
    CacheConfig,
)

# Advanced configuration is typically handled through engine parameters
# The engine internally creates the appropriate configuration

engine = ed.eSurge(
    model=model,
    tokenizer=tokenizer,
    max_model_len=8192,
    max_num_seqs=16,
    max_num_batched_tokens=2048,  # Optional
    enable_prefix_caching=True,
    page_size=16,
    hbm_utilization=0.9,
)

API Server#

eSurge includes an OpenAI-compatible API server for easy deployment:

Starting the Server#

# Initialize engine (as shown above)
engine = ed.eSurge(model=model, tokenizer=tokenizer, ...)

# Start monitoring (optional)
engine.start_monitoring()

# Launch API server (API keys optional)
server = ed.eSurgeApiServer(
    engine,
    require_api_key=True,                 # Optional: enforce keys at runtime
    api_keys={"demo-key": {"label": "UI"}},  # Pre-provisioned keys
)

# Additional keys can be issued programmatically
print("Generated key:", server.generate_api_key(label="notebook"))

server.fire(host="0.0.0.0", port=8000)

The server provides OpenAI-compatible endpoints:

  • POST /v1/chat/completions - Chat completions

  • POST /v1/completions - Text completions

  • GET /v1/models - List available models

  • GET /health - Health check endpoint

  • GET /metrics - Prometheus metrics (if monitoring enabled)

Using the API#

Once the server is running, you can use any OpenAI-compatible client:

import openai

client = openai.OpenAI(
    base_url="http://localhost:8000/v1",
    api_key="demo-key",  # Matches the key configured on the server
)

response = client.chat.completions.create(
    model="default",
    messages=[
        {"role": "user", "content": "Hello, how are you?"}
    ],
    temperature=0.7,
    max_tokens=100,
    stream=True,  # Streaming is supported
)

for chunk in response:
    print(chunk.choices[0].delta.content, end="")

API keys may be supplied via Authorization: Bearer <key>, X-API-Key, or an api_key query parameter. If require_api_key is False (the default), clients can still connect without a key, but registered keys will collect per-client prompt/completion token usage.

Monitoring and Metrics#

eSurge provides comprehensive monitoring capabilities:

When API key enforcement is enabled, the /metrics endpoint includes an api_key_usage section that reports prompt and completion token totals for each registered key, making it easy to attribute usage per client.

Console Monitoring#

# Start Prometheus exporter and console monitor
engine.start_monitoring(enable_console=True)

# The console will display:
# - Request throughput
# - Token generation rate
# - Memory utilization
# - Cache hit rates
# - Active request count

Grafana / Prometheus#

Export metrics for Prometheus and visualize in Grafana:

# Start Prometheus metrics server
engine.start_monitoring(prometheus_port=9090)

# Access metrics at http://localhost:9090/metrics
# Add this URL as a Prometheus data source in Grafana

Performance Tuning#

TPU Optimization#

For best performance on TPU:

engine = ed.eSurge(
    model=model,
    tokenizer=tokenizer,
    max_model_len=8192,
    max_num_seqs=64,        # TPUs handle larger batches well
    page_size=128,          # Larger pages for TPU
    use_aot_forward=True,   # Essential for TPU performance
    hbm_utilization=0.95,   # TPUs have abundant HBM
)

GPU Optimization#

For NVIDIA GPUs:

engine = ed.eSurge(
    model=model,
    tokenizer=tokenizer,
    max_model_len=4096,
    max_num_seqs=32,
    page_size=64,
    hbm_utilization=0.85,       # Leave headroom for kernels
)

Memory-Constrained Settings#

For systems with limited memory:

engine = ed.eSurge(
    model=model,
    tokenizer=tokenizer,
    max_model_len=2048,     # Reduce max length
    max_num_seqs=8,         # Fewer concurrent sequences
    page_size=32,           # Smaller pages
    hbm_utilization=0.7,    # Conservative memory usage
)

Advanced Features#

Prefix Caching#

Enable prefix caching for improved efficiency with common prefixes:

from easydel.inference.esurge import CacheConfig

cache_config = CacheConfig(
    enable_prefix_caching=True,
    prefix_cache_capacity=0.3,  # 30% of cache for prefixes
)

engine = ed.eSurge(
    model=model,
    tokenizer=tokenizer,
    cache_config=cache_config,
)

Custom Scheduling#

Implement custom scheduling logic:

from easydel.inference.esurge import PriorityRequestQueue

# Priority scheduling is configured in the scheduler
# The scheduler automatically handles request prioritization
# based on the scheduling policy (FCFS or Priority)

# Note: Direct priority setting is handled internally
# by the scheduler based on arrival time and request characteristics

Function Calling#

eSurge supports function calling for tool use:

# Function calling is available through the API server
# When using the API server with OpenAI-compatible clients,
# function calling follows the OpenAI format:

# client.chat.completions.create(
#     model="default",
#     messages=[{"role": "user", "content": "What's the weather?"}],
#     functions=[{...}],  # OpenAI function schema
#     function_call="auto"
# )

Troubleshooting#

Common Issues#

Out of Memory Errors

Reduce memory usage by:

  • Decreasing max_num_seqs

  • Reducing max_model_len

  • Lowering hbm_utilization

  • Using smaller page_size

Slow Generation

Improve performance by:

  • Enabling use_aot_forward for TPU

  • Increasing page_size for better memory access patterns

  • Adjusting max_num_batched_tokens

  • Using appropriate precision (bfloat16 recommended)

Request Timeouts

Handle long-running requests:

  • Increase timeout values in sampling parameters

  • Monitor request queue with Prometheus + Grafana

  • Implement request preemption for fairness

Best Practices#

  1. Memory Management

    • Monitor HBM utilization and adjust hbm_utilization parameter

    • Use prefix caching for repetitive prompts

    • Enable request preemption for long-running services

  2. Performance Optimization

    • Use platform-specific settings (TPU vs GPU)

    • Enable JIT compilation with use_aot_forward

    • Batch similar-length sequences together

  3. Production Deployment

    • Enable monitoring and metrics collection

    • Implement health checks and auto-restart

    • Use load balancing for multiple instances

    • Set up proper logging and alerting

  4. Scaling

    • Horizontal scaling with multiple engine instances

    • Use model parallelism for large models

    • Implement request routing based on load

API Reference#

For detailed API documentation, see:

  • api_docs/inference_esurge - Core eSurge APIs

  • api_docs/inference_esurge_engine - Engine implementation

  • api_docs/inference_esurge_scheduler - Scheduling components

  • api_docs/inference_esurge_server - API server implementation

Example Applications#

Chat Application#

Complete example of a chat application:

import os
os.environ["EASYDEL_AUTO"] = "1"

from jax import lax, numpy as jnp
from transformers import AutoTokenizer
import easydel as ed

def create_chat_engine(model_id="microsoft/phi-2"):
    """Create an eSurge engine for chat."""
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token_id = tokenizer.eos_token_id

    model = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
        model_id,
        dtype=jnp.bfloat16,
        param_dtype=jnp.bfloat16,
        auto_shard_model=True,
        config_kwargs=ed.EasyDeLBaseConfigDict(
            attn_mechanism=ed.AttentionMechanisms.RAGGED_PAGE_ATTENTION_V2,
        ),
    )

    engine = ed.eSurge(
        model=model,
        tokenizer=tokenizer,
        max_model_len=4096,
        max_num_seqs=16,
        hbm_utilization=0.9,
        page_size=64,
        esurge_name="chat-engine",
    )

    return engine

def chat_loop(engine):
    """Interactive chat loop."""
    conversation = []

    while True:
        user_input = input("\nYou: ")
        if user_input.lower() in ['quit', 'exit']:
            break

        conversation.append({"role": "user", "content": user_input})

        # Format conversation for model
        prompt = format_conversation(conversation)

        # Generate response
        print("Assistant: ", end="", flush=True)
        response_text = ""

        for output in engine.stream(
            prompt,
            sampling_params=ed.SamplingParams(
                max_tokens=200,
                temperature=0.7,
                stop=["</s>", "\n\nYou:"],
            )
        ):
            print(output.delta_text, end="", flush=True)
            response_text += output.delta_text

        print()  # New line after response
        conversation.append({"role": "assistant", "content": response_text})

def format_conversation(messages):
    """Format messages for model input."""
    formatted = ""
    for msg in messages:
        role = msg["role"].capitalize()
        formatted += f"{role}: {msg['content']}\n"
    formatted += "Assistant: "
    return formatted

if __name__ == "__main__":
    engine = create_chat_engine()
    engine.start_monitoring()
    print("Chat engine ready! Type 'quit' to exit.")
    chat_loop(engine)

Batch Processing Service#

Example of a batch processing service:

import asyncio
from typing import List, Dict
import easydel as ed

class BatchProcessor:
    """Batch text processing service."""

    def __init__(self, engine: ed.eSurge):
        self.engine = engine
        self.pending_requests = {}

    async def process_batch(
        self,
        texts: List[str],
        sampling_params: ed.SamplingParams = None,
    ) -> List[str]:
        """Process a batch of texts."""
        if sampling_params is None:
            sampling_params = ed.SamplingParams(
                max_new_tokens=100,
                temperature=0.7,
            )

        # Submit all requests
        request_ids = []
        for text in texts:
            request_id = self.engine.add_request(
                prompt=text,
                sampling_params=sampling_params,
            )
            request_ids.append(request_id)
            self.pending_requests[request_id] = None

        # Wait for all completions
        results = {}
        while len(results) < len(request_ids):
            outputs = self.engine.step()

            for output in outputs:
                if output.finished:
                    results[output.request_id] = output.outputs[0].text
                    self.pending_requests.pop(output.request_id, None)

            await asyncio.sleep(0.01)  # Small delay

        # Return in original order
        return [results[rid] for rid in request_ids]

    async def process_stream(
        self,
        texts: List[str],
        callback=None,
    ):
        """Process texts with streaming callbacks."""
        for text in texts:
            request_id = self.engine.add_request(
                prompt=text,
                sampling_params=ed.SamplingParams(
                    max_new_tokens=100,
                    temperature=0.7,
                ),
            )

            # Stream this request to completion
            while True:
                outputs = self.engine.step()

                for output in outputs:
                    if output.request_id == request_id:
                        if callback:
                            await callback(output)

                        if output.finished:
                            break
                else:
                    await asyncio.sleep(0.01)
                    continue
                break

# Usage example
async def main():
    engine = create_chat_engine()  # From previous example
    processor = BatchProcessor(engine)

    # Batch processing
    prompts = [
        "Explain quantum computing in simple terms",
        "What are the benefits of exercise?",
        "How does photosynthesis work?",
    ]

    results = await processor.process_batch(prompts)
    for prompt, result in zip(prompts, results):
        print(f"Q: {prompt}\nA: {result}\n")

if __name__ == "__main__":
    asyncio.run(main())

Performance Comparison#

eSurge vs Other Engines#

Performance characteristics compared to other inference engines:

Future Roadmap#

Planned features for future releases:

  • Speculative Decoding: Faster generation with draft models

  • Multi-LoRA Support: Serve multiple LoRA adapters efficiently

  • Tensor Parallelism: Scale to larger models

  • Quantization Methods: Additional quantization schemes (AWQ, GPTQ)

  • Model Hotswapping: Switch models without restarting

  • Request Caching: Cache and reuse common request patterns

  • Pipeline Parallelism: Better scaling across multiple devices

Contributing#

We welcome contributions to eSurge! Areas where help is appreciated:

  • Performance optimizations for specific hardware

  • Additional scheduling algorithms

  • New attention mechanisms

  • Documentation and examples

  • Bug reports and fixes

See the main EasyDeL contributing guide for more information.

Support#

For issues and questions:

License#

eSurge is part of EasyDeL and is licensed under the Apache License 2.0.