# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base interface for EasyDeL inference API servers.

This module provides abstract base classes and utilities for building
standardized inference API servers with OpenAI API compatibility.

Classes:
    ServerStatus: Enum representing server operational states
    ServerMetrics: Dataclass for tracking server performance metrics
    EndpointConfig: Configuration for API endpoints
    ErrorResponse: Standard error response format
    BaseInferenceApiServer: Abstract base class for inference servers
    InferenceEngineAdapter: Abstract adapter for different inference engines

Functions:
    create_error_response: Build a standardized JSON error response
"""
from __future__ import annotations
import asyncio
import time
import typing as tp
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from enum import Enum
from http import HTTPStatus
import uvicorn
from eformer.loggings import get_logger
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, Field
from .openai_api_modules import (
ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
CompletionResponse,
FunctionCallFormat,
)
from .sampling_params import SamplingParams
if tp.TYPE_CHECKING:
from ..utils import ReturnSample
# Value passed to uvicorn as ``timeout_keep_alive`` when the server is started
# through ``BaseInferenceApiServer.run``.
TIMEOUT_KEEP_ALIVE = 5.0
# Module-level logger used by every class and function in this file.
logger = get_logger("InferenceApiServer")
class ServerStatus(str, Enum):
    """Operational state of an inference server.

    The enum mixes in ``str``, so every member compares equal to its
    lowercase string value.

    Attributes:
        STARTING: Server is initializing
        READY: Server is ready to accept requests
        BUSY: Server is processing requests at capacity
        ERROR: Server encountered an error
        SHUTTING_DOWN: Server is gracefully shutting down
    """

    STARTING = "starting"
    READY = "ready"
    BUSY = "busy"
    ERROR = "error"
    SHUTTING_DOWN = "shutting_down"
@dataclass
class ServerMetrics:
    """Mutable counters describing how the server has been performing.

    Every counter starts at zero; ``start_time`` is captured with
    ``time.time()`` when the instance is created.

    Attributes:
        total_requests: Total number of requests received
        successful_requests: Number of successfully completed requests
        failed_requests: Number of failed requests
        total_tokens_generated: Total tokens generated across all requests
        average_tokens_per_second: Average generation speed
        uptime_seconds: Time since server started
        start_time: Unix timestamp when server started
    """

    # Request counters.
    total_requests: int = 0
    successful_requests: int = 0
    failed_requests: int = 0

    # Generation throughput.
    total_tokens_generated: int = 0
    average_tokens_per_second: float = 0.0

    # Timing.
    uptime_seconds: float = 0.0
    start_time: float = field(default_factory=time.time)
class EndpointConfig(BaseModel):
    """Declarative description of one FastAPI route.

    Instances carry the values that are later forwarded to
    ``FastAPI.add_api_route`` when the server registers its routes.

    Attributes:
        path: URL path for the endpoint
        handler: Callable that handles requests
        methods: HTTP methods supported (GET, POST, etc.)
        summary: Brief description of the endpoint
        tags: Tags for API documentation grouping
        response_model: Pydantic model for response validation
    """

    # Required routing information.
    path: str
    handler: tp.Callable
    methods: list[str]

    # Optional documentation / validation metadata.
    summary: str | None = None
    tags: list[str] | None = None
    response_model: tp.Any = None
class ErrorResponse(BaseModel):
    """Uniform error payload returned by the API.

    Attributes:
        error: Dictionary containing error message and type
        request_id: Optional unique identifier for the request
        timestamp: Unix timestamp when error occurred; defaults to
            ``time.time()`` at construction
    """

    error: dict[str, str]
    request_id: str | None = None
    timestamp: float = Field(default_factory=time.time)
def create_error_response(status_code: HTTPStatus, message: str, request_id: str | None = None) -> JSONResponse:
    """Build a standardized JSON error response.

    Args:
        status_code: HTTP status for the error; its ``name`` becomes the
            error ``type`` and its ``value`` the response status code
        message: Human-readable error message
        request_id: Optional request identifier for tracking

    Returns:
        JSONResponse whose body is a dumped ``ErrorResponse``
    """
    error_details = {"message": message, "type": status_code.name}
    payload = ErrorResponse(error=error_details, request_id=request_id).model_dump()
    return JSONResponse(status_code=status_code.value, content=payload)
class BaseInferenceApiServer(ABC):
    """Abstract base class for inference API servers.

    Defines the common structure (FastAPI app, middleware, route table and
    lifecycle hooks) together with the abstract handlers that every concrete
    inference server must implement, so that different inference modules
    expose a consistent API.
    """
def __init__(
    self,
    max_workers: int | None = None,
    enable_cors: bool = True,
    cors_origins: list[str] | None = None,
    max_request_size: int = 10 * 1024 * 1024,
    request_timeout: float = 300.0,
    enable_function_calling: bool = True,
    default_function_format: FunctionCallFormat = FunctionCallFormat.OPENAI,
    server_name: str = "EasyDeL Inference API Server",
    server_description: str = "High-performance inference server with OpenAI API compatibility",
    server_version: str = "2.0.0",
    enable_auth_ui: bool = True,
) -> None:
    """Create the FastAPI application and wire up middleware and routes.

    Args:
        max_workers: Maximum number of worker threads in the internal pool
        enable_cors: Enable CORS middleware
        cors_origins: Allowed CORS origins
        max_request_size: Maximum request size in bytes (stored on the
            instance; this constructor does not enforce it)
        request_timeout: Request timeout in seconds (stored on the
            instance; this constructor does not enforce it)
        enable_function_calling: Enable function calling support
        default_function_format: Default format for function calls
        server_name: Name of the server for FastAPI app
        server_description: Description of the server
        server_version: Version of the server
        enable_auth_ui: Enable "Authorize" button in /docs for API key input
    """
    # Worker pool used for blocking work such as token counting.
    self.thread_pool = ThreadPoolExecutor(thread_name_prefix="inference-worker", max_workers=max_workers)

    # Plain configuration values.
    self.max_request_size = max_request_size
    self.request_timeout = request_timeout

    # Runtime state tracked across requests.
    self.status = ServerStatus.STARTING
    self.metrics = ServerMetrics()
    self._active_requests: set[str] = set()
    self._request_lock = asyncio.Lock()

    self.enable_function_calling = enable_function_calling
    self.default_function_format = default_function_format

    # Swagger UI OAuth init settings are only supplied when the auth UI is on.
    oauth_ui_settings = (
        {
            "clientId": "swagger-ui",
            "appName": "Swagger UI",
            "usePkceWithAuthorizationCodeGrant": True,
        }
        if enable_auth_ui
        else None
    )

    self.app = FastAPI(
        title=server_name,
        description=server_description,
        version=server_version,
        lifespan=self._lifespan,
        swagger_ui_init_oauth=oauth_ui_settings,
    )

    if enable_auth_ui:
        # Drop any cached schema so the custom generator runs on first access.
        self.app.openapi_schema = None
        self._configure_openapi_security()

    if enable_cors:
        self._setup_cors(cors_origins)

    self._setup_middleware()
    self._register_endpoints()

    if enable_function_calling:
        self._add_function_calling_endpoints()

    logger.info(f"{server_name} initialized")
@asynccontextmanager
async def _lifespan(self, app: FastAPI):
    """Manage the server lifecycle around the FastAPI application.

    Runs ``on_startup`` and marks the server READY before yielding control
    to the application; afterwards marks it SHUTTING_DOWN, drains active
    requests via ``_graceful_shutdown`` and finally runs ``on_shutdown``.

    Args:
        app: The FastAPI application the lifespan is attached to (unused;
            ``self.app`` is the same object).

    Raises:
        Exception: Whatever ``on_startup`` raises is re-raised after the
            status has been set to ERROR.
    """
    logger.info(f"Starting {self.app.title}...")
    try:
        await self.on_startup()
    except Exception:
        # Record the failure so the status does not stay stuck at STARTING.
        self.status = ServerStatus.ERROR
        raise
    self.status = ServerStatus.READY
    try:
        yield
    finally:
        # ``finally`` guarantees cleanup even if an exception or cancellation
        # is thrown into the generator at the ``yield`` point.
        logger.info(f"Shutting down {self.app.title}...")
        self.status = ServerStatus.SHUTTING_DOWN
        await self._graceful_shutdown()
        await self.on_shutdown()
async def on_startup(self) -> None:  # noqa: B027
    """Startup hook; does nothing by default.

    Subclasses may override this to perform custom initialization such as
    loading models, establishing connections, or warming up caches. It is
    awaited once by the lifespan handler before the server is marked ready.
    """
async def on_shutdown(self) -> None:  # noqa: B027
    """Shutdown hook; does nothing by default.

    Subclasses may override this to perform cleanup such as saving state,
    closing connections, or releasing resources. It is awaited once by the
    lifespan handler after the graceful-shutdown wait has finished.
    """
def _configure_openapi_security(self) -> None:
    """Install an OpenAPI generator that advertises API-key authentication.

    Replaces ``self.app.openapi`` with a function that builds the schema via
    ``get_openapi``, adds ``BearerAuth`` (Authorization header) and
    ``ApiKeyAuth`` (X-API-Key header) security schemes, applies them
    globally, and caches the result on ``self.app.openapi_schema``. The
    security schemes are what make the "Authorize" button appear in /docs.
    """

    def custom_openapi():
        # Serve the cached schema once it has been generated.
        if self.app.openapi_schema:
            return self.app.openapi_schema

        from fastapi.openapi.utils import get_openapi

        openapi_schema = get_openapi(
            title=self.app.title,
            version=self.app.version,
            description=self.app.description,
            routes=self.app.routes,
        )

        # ``get_openapi`` omits "components" when no route contributes any,
        # so create the section on demand instead of indexing into it.
        components = openapi_schema.setdefault("components", {})
        components["securitySchemes"] = {
            "BearerAuth": {
                "type": "http",
                "scheme": "bearer",
                "bearerFormat": "API Key",
                "description": "Enter your API key as a Bearer token (e.g., `sk-...`)",
            },
            "ApiKeyAuth": {
                "type": "apiKey",
                "in": "header",
                "name": "X-API-Key",
                "description": "Enter your API key in the X-API-Key header",
            },
        }

        # Apply security globally to all endpoints (either scheme satisfies
        # the requirement); this makes the "Authorize" button appear.
        openapi_schema["security"] = [
            {"BearerAuth": []},
            {"ApiKeyAuth": []},
        ]

        self.app.openapi_schema = openapi_schema
        return self.app.openapi_schema

    self.app.openapi = custom_openapi
def _setup_cors(self, origins: list[str] | None) -> None:
    """Attach the CORS middleware to the application.

    Configures Cross-Origin Resource Sharing so browsers can call the API
    from other origins. All methods and headers are allowed.

    Credentialed requests (cookies) are only enabled when an explicit list
    of origins is supplied. With the wildcard fallback, credentials are
    turned off: combining ``["*"]`` with ``allow_credentials=True`` would
    let any website issue credentialed requests, and FastAPI's CORS
    documentation states that wildcards cannot be used together with
    credentials.

    Args:
        origins: List of allowed origin URLs. Defaults to ["*"] (all origins)
    """
    allowed_origins = origins or ["*"]
    allow_credentials = "*" not in allowed_origins
    self.app.add_middleware(
        CORSMiddleware,
        allow_origins=allowed_origins,
        allow_credentials=allow_credentials,
        allow_methods=["*"],
        allow_headers=["*"],
    )
def _setup_middleware(self) -> None:
    """Register the HTTP middleware used for tracking and metrics.

    Two middleware functions are added:
        1. ``add_request_id`` assigns a unique ID to the request, tracks it
           in ``self._active_requests`` while it is in flight, and echoes it
           in the ``X-Request-ID`` response header.
        2. ``track_metrics`` updates the request counters and the uptime in
           ``self.metrics``.
    """
    # Local import keeps this method self-contained.
    import uuid

    @self.app.middleware("http")
    async def add_request_id(request: Request, call_next):
        """Add unique request ID to each request."""
        # A random UUID cannot collide the way a microsecond timestamp can
        # when two requests arrive at the same instant; a collision would
        # make one request's completion drop the other from the active set.
        request_id = f"req_{uuid.uuid4().hex}"
        request.state.request_id = request_id
        async with self._request_lock:
            self._active_requests.add(request_id)
        try:
            response = await call_next(request)
            response.headers["X-Request-ID"] = request_id
            return response
        finally:
            # Always untrack the request, even when the handler raised.
            async with self._request_lock:
                self._active_requests.discard(request_id)

    @self.app.middleware("http")
    async def track_metrics(request: Request, call_next):
        """Track request metrics."""
        self.metrics.total_requests += 1
        try:
            response = await call_next(request)
            # Status codes below 400 count as success, everything else as failure.
            if response.status_code < 400:
                self.metrics.successful_requests += 1
            else:
                self.metrics.failed_requests += 1
            return response
        except Exception:
            self.metrics.failed_requests += 1
            raise
        finally:
            self.metrics.uptime_seconds = time.time() - self.metrics.start_time
@property
def _endpoints(self) -> list[EndpointConfig]:
    """Define all core API endpoints.

    This list is the single source of truth consumed by
    ``_register_endpoints``. Subclasses can extend or prune the route table
    by overriding the property and composing additional
    :class:`EndpointConfig` entries.

    Each entry captures the handler callable, HTTP verbs, documentation
    metadata, and optional response model, mirroring the arguments passed to
    ``FastAPI.add_api_route``.

    Note:
        The summaries are built from adjacent string literals; every
        fragment that continues a sentence ends with a space so words are
        not fused together when Python concatenates them.

    Returns:
        A fresh list of endpoint configurations bound to this instance.
    """
    return [
        EndpointConfig(
            path="/v1/chat/completions",
            handler=self.chat_completions,
            methods=["POST"],
            tags=["Chat"],
            summary=(
                "Submit a conversation expressed as OpenAI-style chat messages and "
                "receive assistant turns that honor streaming, function calling, and "
                "token usage accounting. The endpoint mirrors the semantics of the "
                "OpenAI Chat Completions API so existing SDKs and client libraries can "
                "drop in without translation.\n\n"
                "Use this route whenever you need multi-turn context, tool invocation, "
                "or delta streaming—Simple text prompts should go through the plain "
                "completions endpoint below."
            ),
        ),
        EndpointConfig(
            path="/v1/completions",
            handler=self.completions,
            methods=["POST"],
            tags=["Completions"],
            summary=(
                "Generate text from a raw prompt without chat scaffolding. This matches "
                "OpenAI's legacy completion API and is ideal for single-turn tasks such "
                "as template expansion, summarization, or logit probing.\n\n"
                "Clients receive either a full response object or a text/event-stream "
                "when `stream=true`, making it a minimal surface for classic prompt "
                "engineering workloads."
            ),
        ),
        EndpointConfig(
            path="/health",
            handler=self.health_check,
            methods=["GET"],
            tags=["Health"],
            summary=(
                "Lightweight health probe that reports server status, uptime, active "
                "request counts, and model metadata. Load balancers and orchestrators "
                "can call this endpoint to decide whether a replica should receive "
                "traffic.\n\n"
                "The payload is intentionally human-readable so operators can curl the "
                "route during incidents and immediately understand whether the server "
                "is READY, BUSY, or encountering errors."
            ),
        ),
        EndpointConfig(
            path="/v1/models",
            handler=self.list_models,
            methods=["GET"],
            tags=["Models"],
            summary=(
                "Enumerate every model the server has loaded along with ownership, "
                "capabilities, and tokenizer limits. The response mirrors the OpenAI "
                "`/v1/models` schema so existing tooling (CLI, dashboards, SDKs) can "
                "introspect deployments without custom code.\n\n"
                "Call this endpoint when building control planes or auditing which "
                "models are exposed to end users."
            ),
        ),
        EndpointConfig(
            path="/v1/models/{model_id}",
            handler=self.get_model,
            methods=["GET"],
            tags=["Models"],
            summary=(
                "Return detailed metadata for a specific model ID, including tokenizer "
                "capabilities, architecture hints, and server ownership information. "
                "Use this to confirm feature support (e.g., chat templates or tool "
                "calling) before dispatching a request.\n\n"
                "The payload is stable enough to cache in control planes or config "
                "UIs that need to display per-model characteristics."
            ),
        ),
        EndpointConfig(
            path="/metrics",
            handler=self.get_metrics,
            methods=["GET"],
            tags=["Monitoring"],
            summary=(
                "Expose aggregated counters covering request throughput, success and "
                "failure rates, generated tokens, and authentication statistics. "
                "Intended for SRE dashboards, autoscalers, or simple cron-based "
                "reporting scripts.\n\n"
                "Because it shares the same authorization story as model endpoints, "
                "operators can lock down who may read metrics without punching holes "
                "in infrastructure firewalls."
            ),
        ),
    ]
def _add_function_calling_endpoints(self) -> None:
    """Register the tool / function-calling routes on the application.

    Adds ``GET /v1/tools`` and ``POST /v1/tools/execute``. The handlers
    ``self.list_tools`` and ``self.execute_tool`` are looked up on the
    instance; they are not defined in this base class, so they are expected
    to be supplied by the concrete server.
    """
    tool_routes = (
        EndpointConfig(
            path="/v1/tools",
            handler=self.list_tools,
            methods=["GET"],
            tags=["Tools"],
            summary="List available tools/functions",
        ),
        EndpointConfig(
            path="/v1/tools/execute",
            handler=self.execute_tool,
            methods=["POST"],
            tags=["Tools"],
            summary="Execute a tool/function call",
        ),
    )

    for route in tool_routes:
        self.app.add_api_route(
            path=route.path,
            endpoint=route.handler,
            methods=route.methods,
            summary=route.summary,
            tags=route.tags,
        )
def _register_endpoints(self) -> None:
    """Register every route listed by ``self._endpoints`` on the app.

    Each ``EndpointConfig`` is translated one-to-one into a call to
    ``FastAPI.add_api_route``, including its ``response_model``.
    """
    for config in self._endpoints:
        route_options = {
            "path": config.path,
            "endpoint": config.handler,
            "methods": config.methods,
            "summary": config.summary,
            "tags": config.tags,
            "response_model": config.response_model,
        }
        self.app.add_api_route(**route_options)
async def _graceful_shutdown(self) -> None:
    """Wait for in-flight requests, then shut the worker pool down.

    Polls ``self._active_requests`` once per second for up to 30 seconds.
    Requests still active after that are only reported with a warning.
    Finally the thread pool is shut down, waiting for running tasks.
    """
    max_wait = 30
    # Use the monotonic clock so wall-clock adjustments (NTP, DST) cannot
    # shorten or extend the waiting period.
    deadline = time.monotonic() + max_wait
    while self._active_requests and time.monotonic() < deadline:
        logger.info(f"Waiting for {len(self._active_requests)} active requests...")
        await asyncio.sleep(1)

    if self._active_requests:
        logger.warning(f"Force closing {len(self._active_requests)} active requests")

    # ``shutdown(wait=True)`` blocks until the workers finish, so run it off
    # the event loop to keep the loop responsive while they drain.
    await asyncio.to_thread(self.thread_pool.shutdown, wait=True)
    logger.info("Thread pool shut down")
# Abstract methods that must be implemented by subclasses
@abstractmethod
async def chat_completions(
    self,
    request: ChatCompletionRequest,
    raw_request: Request,
) -> ChatCompletionResponse | StreamingResponse | JSONResponse:
    """Handle a chat completion request.

    Must be implemented by subclasses; it is registered as the handler for
    ``POST /v1/chat/completions``.

    Args:
        request: The chat completion request
        raw_request: Raw FastAPI request containing headers

    Returns:
        Chat completion response (streaming or non-streaming)
    """
    raise NotImplementedError
@abstractmethod
async def completions(
    self,
    request: CompletionRequest,
    raw_request: Request,
) -> CompletionResponse | StreamingResponse | JSONResponse:
    """Handle a text completion request.

    Must be implemented by subclasses; it is registered as the handler for
    ``POST /v1/completions``.

    Args:
        request: The completion request
        raw_request: Raw FastAPI request containing headers

    Returns:
        Completion response (streaming or non-streaming)
    """
    raise NotImplementedError
@abstractmethod
async def health_check(self, raw_request: Request) -> JSONResponse:
    """Report the health of the server.

    Must be implemented by subclasses; it is registered as the handler for
    ``GET /health``.

    Args:
        raw_request: Raw FastAPI request containing headers

    Returns:
        Health status information
    """
    raise NotImplementedError
@abstractmethod
async def get_metrics(self, raw_request: Request) -> JSONResponse:
    """Return server performance metrics.

    Must be implemented by subclasses; it is registered as the handler for
    ``GET /metrics``.

    Args:
        raw_request: Raw FastAPI request containing headers

    Returns:
        Server metrics information
    """
    raise NotImplementedError
@abstractmethod
async def list_models(self, raw_request: Request) -> JSONResponse:
    """List the models served by this server.

    Must be implemented by subclasses; it is registered as the handler for
    ``GET /v1/models``.

    Args:
        raw_request: Raw FastAPI request containing headers

    Returns:
        List of available models with metadata
    """
    raise NotImplementedError
@abstractmethod
async def get_model(self, model_id: str, raw_request: Request) -> JSONResponse:
    """Return details about a single model.

    Must be implemented by subclasses; it is registered as the handler for
    ``GET /v1/models/{model_id}``.

    Args:
        model_id: The model identifier taken from the URL path
        raw_request: Raw FastAPI request containing headers

    Returns:
        Model details
    """
    raise NotImplementedError
# Helper methods that can be used by subclasses
@abstractmethod
def _create_sampling_params(self, request: ChatCompletionRequest | CompletionRequest) -> SamplingParams:
    """Translate an API request into sampling parameters.

    Must be implemented by subclasses.

    Args:
        request: The chat or plain completion request

    Returns:
        Sampling parameters for the inference engine
    """
    raise NotImplementedError
def _determine_finish_reason(self, tokens_generated: int, max_tokens: float, text: str) -> str:
"""
Determine the finish reason for a generation.
Args:
tokens_generated: Number of tokens generated
max_tokens: Maximum tokens allowed
text: Generated text
Returns:
Finish reason string
"""
if tokens_generated >= max_tokens:
return "length"
return "stop"
async def _count_tokens_async(self, content: str, model_name: str | None = None) -> int:
"""
Count tokens asynchronously.
Args:
content: Text content to tokenize
model_name: Optional model name for model-specific tokenization
Returns:
Number of tokens
"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(self.thread_pool, self._count_tokens, content, model_name)
@abstractmethod
def _count_tokens(self, content: str, model_name: str | None = None) -> int:
"""
Count tokens for the given content.
Args:
content: Text content to tokenize
model_name: Optional model name for model-specific tokenization
Returns:
Number of tokens
"""
raise NotImplementedError
def run(
    self,
    host: str = "0.0.0.0",
    port: int = 11556,
    workers: int = 1,
    log_level: str = "info",
    ssl_keyfile: str | None = None,
    ssl_certfile: str | None = None,
    reload: bool = False,
) -> None:
    """Start the server with uvicorn and block until it exits.

    Args:
        host: Host address to bind to
        port: Port to listen on
        workers: Number of worker processes. Values above 1 are ignored
            with a warning, see the note below.
        log_level: Logging level
        ssl_keyfile: Path to SSL key file
        ssl_certfile: Path to SSL certificate file. TLS is enabled
            whenever this is given; ``ssl_keyfile`` is optional because the
            certificate file may also contain the private key.
        reload: Enable auto-reload for development. Ignored with a
            warning, see the note below.

    Note:
        uvicorn can only spawn multiple workers or enable auto-reload when
        the application is passed as an import string. This method passes
        the app object, for which uvicorn would log an error and exit, so
        both options fall back to a single process without reload.
    """
    if reload or workers > 1:
        logger.warning(
            "uvicorn requires an import string to enable 'reload' or multiple workers; "
            "the application is passed as an object, so it runs as a single worker without reload."
        )
        reload = False
        workers = 1

    uvicorn_config = {
        "app": self.app,
        "host": host,
        "port": port,
        "workers": workers,
        "log_level": log_level,
        "timeout_keep_alive": TIMEOUT_KEEP_ALIVE,
        "reload": reload,
        "server_header": False,
        "date_header": True,
    }

    if ssl_certfile:
        # Previously a certificate without a separate key file was dropped
        # silently and the server came up without TLS.
        uvicorn_config["ssl_certfile"] = ssl_certfile
        if ssl_keyfile:
            uvicorn_config["ssl_keyfile"] = ssl_keyfile
    elif ssl_keyfile:
        logger.warning("ssl_keyfile was provided without ssl_certfile; TLS is not enabled")

    # Prefer uvloop when it is installed; fall back to the default loop.
    try:
        import uvloop

        asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
        logger.info("Using uvloop for enhanced performance")
    except ImportError:
        logger.info("uvloop not available, using default event loop")

    uvicorn.run(**uvicorn_config)

# Alias so the server can also be started via ``fire``.
fire = run
class InferenceEngineAdapter(ABC):
    """Abstract adapter that hides a concrete inference engine.

    Implementations wrap a specific engine (eSurge, vLLM, TGI, etc.) so that
    it can be driven through the same API server interface.
    """

    @abstractmethod
    async def generate(
        self,
        prompts: str | list[str],
        sampling_params: SamplingParams,
        stream: bool = False,
    ) -> list[ReturnSample] | tp.AsyncGenerator[list[ReturnSample], None]:
        """Generate text for one or more prompts.

        Args:
            prompts: Input prompts
            sampling_params: Sampling parameters
            stream: Whether to stream the response

        Returns:
            Generated samples (list or async generator)
        """
        raise NotImplementedError

    @abstractmethod
    def count_tokens(self, content: str) -> int:
        """Return the number of tokens in ``content``.

        Args:
            content: Text content

        Returns:
            Number of tokens
        """
        raise NotImplementedError

    @abstractmethod
    def get_model_info(self) -> dict[str, tp.Any]:
        """Describe the loaded model.

        Returns:
            Model information dictionary
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def model_name(self) -> str:
        """Name of the model."""
        raise NotImplementedError

    @property
    @abstractmethod
    def processor(self) -> tp.Any:
        """Processor/tokenizer for the model."""
        raise NotImplementedError