# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implements a FastAPI server for serving vEngine models, mimicking OpenAI API."""
from __future__ import annotations
import asyncio
import typing as tp
from concurrent.futures import ThreadPoolExecutor
from http import HTTPStatus
import uvicorn
from eformer.pytree import auto_pytree
from fastapi import FastAPI
from fastapi.responses import JSONResponse, StreamingResponse
from transformers import ProcessorMixin
from easydel.inference.utilities import SamplingParams
from easydel.utils.helpers import get_logger
from easydel.utils.lazy_import import is_package_available
from ..openai_api_modules import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionStreamResponse,
ChatCompletionStreamResponseChoice,
ChatMessage,
CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
CompletionStreamResponse,
CompletionStreamResponseChoice,
CountTokenRequest,
DeltaMessage,
UsageInfo,
)
from .vsurge import vSurge, vSurgeRequest
# Keep-alive timeout handed to uvicorn as `timeout_keep_alive` when the server is started.
TIMEOUT_KEEP_ALIVE = 5.0
# Module-level FastAPI application: `vSurgeApiServer` registers its routes on this object
# and passes it to `uvicorn.run`.
APP = FastAPI(title="EasyDeL vSurge API Server")
# Logger shared by the helpers and the server class in this module.
logger = get_logger("vSurgeApiServer")
@auto_pytree
class EndpointConfig:
    """Configuration for a FastAPI endpoint.

    Each instance describes one route; the server forwards these fields to
    `APP.add_api_route` when it registers its endpoints.

    Attributes:
        path: URL path of the route (e.g. ``"/v1/chat/completions"``).
        handler: Callable invoked to serve requests on `path`.
        methods: HTTP methods accepted by the route (e.g. ``["POST"]``).
        summary: Optional short description forwarded as the route summary.
        tags: Optional tags forwarded as the route tags.
    """

    path: str
    handler: tp.Callable
    methods: list[str]
    summary: tp.Optional[str] = None
    tags: tp.Optional[list[str]] = None
def create_error_response(status_code: HTTPStatus, message: str) -> JSONResponse:
    """Creates a standardized JSON error response.

    Args:
        status_code (HTTPStatus): Status to report; its numeric value becomes the
            HTTP status code of the response.
        message (str): Human-readable description of the failure.

    Returns:
        JSONResponse: A response whose body is ``{"error": {"message": message}}``.
    """
    return JSONResponse({"error": {"message": message}}, status_code=status_code.value)
class vSurgeApiServer:
    """
    FastAPI server for serving `vSurge` instances.

    This server provides endpoints mimicking the OpenAI API structure for chat
    completions and text completions, plus liveness/readiness checks, token
    counting, and listing available models. Routes are registered on the
    module-level `APP` object when an instance is created.

    Attributes:
        vsurge_map: Mapping from model name to the `vSurge` instance serving it.
        thread_pool: Executor used for chat-template rendering and token counting.
        logger: Logger used by every handler.
        oai_like_processor: Whether conversations are converted to the OpenAI
            message format before a `ProcessorMixin` chat template is applied.
    """

    def __init__(
        self,
        vsurge_map: tp.Optional[tp.Union[tp.Dict[str, vSurge], vSurge]] = None,
        max_workers: int = 10,
        oai_like_processor: bool = True,
    ) -> None:
        """
        Builds the server and registers its endpoints.

        Args:
            vsurge_map: Either a single `vSurge` (registered under its own
                `vsurge_name`) or a mapping of model name -> `vSurge`. `None`
                results in a server with no models.
            max_workers: Size of the thread pool used for blocking helper work.
            oai_like_processor: See the class-level attribute description.

        Raises:
            TypeError: If a value of `vsurge_map` is not a `vSurge` instance.
        """
        if vsurge_map is None:
            # The declared default used to crash on `.items()`; an empty map is the
            # closest well-defined meaning of "nothing was provided".
            logger.warning("No vsurge instance was provided; the server will not serve any model.")
            vsurge_map = {}
        elif isinstance(vsurge_map, vSurge):
            vsurge_map = {vsurge_map.vsurge_name: vsurge_map}

        self.vsurge_map: tp.Dict[str, vSurge] = {}
        for name, vsurge in vsurge_map.items():
            # Raise explicitly instead of `assert`, which is removed under `python -O`.
            if not isinstance(vsurge, vSurge):
                raise TypeError(
                    f"Value for key '{name}' in vsurge_map must be an instance of `vSurge`, "
                    f"got {type(vsurge).__name__} instead."
                )
            self.vsurge_map[name] = vsurge
            logger.info(f"Added vsurge: {name}")

        self.thread_pool = ThreadPoolExecutor(max_workers=max_workers)
        self.logger = logger
        self.oai_like_processor = oai_like_processor
        self._register_endpoints()

    @property
    def _endpoints(self) -> tp.List[EndpointConfig]:
        """Defines all API endpoints for the server."""
        return [
            EndpointConfig(
                path="/v1/chat/completions",
                handler=self.chat_completions,
                methods=["POST"],
                tags=["Chat"],
                summary="Creates a model response for the given chat conversation.",
            ),
            EndpointConfig(
                path="/v1/completions",
                handler=self.completions,
                methods=["POST"],
                tags=["Completions"],
                summary="Creates a completion for the provided prompt.",
            ),
            EndpointConfig(
                path="/liveness",
                handler=self.liveness,
                methods=["GET"],
                tags=["Health"],
                summary="Checks if the API server is running.",
            ),
            EndpointConfig(
                path="/readiness",
                handler=self.readiness,
                methods=["GET"],
                tags=["Health"],
                summary="Checks if the API server is ready to receive requests.",
            ),
            EndpointConfig(
                path="/v1/count_tokens",
                handler=self.count_tokens,
                methods=["POST"],
                tags=["Utility"],
                summary="Counts the number of tokens in a given text or conversation.",
            ),
            EndpointConfig(
                path="/v1/models",
                handler=self.available_inference,
                methods=["GET"],
                tags=["Utility"],
                summary="Lists the models available through this API.",
            ),
        ]

    def _create_sampling_params_from_request(
        self,
        request: tp.Union[ChatCompletionRequest, CompletionRequest],
    ) -> SamplingParams:
        """
        Creates SamplingParams from a chat or plain completion request.

        Every numeric field is coerced with `int`/`float`, so the request must
        carry a non-`None` value for each of them.
        """
        return SamplingParams(
            max_tokens=int(request.max_tokens),
            temperature=float(request.temperature),
            presence_penalty=float(request.presence_penalty),
            frequency_penalty=float(request.frequency_penalty),
            repetition_penalty=float(request.repetition_penalty),
            top_k=int(request.top_k),
            top_p=float(request.top_p),
            min_p=float(request.min_p),
            suppress_tokens=request.suppress_tokens,
        )

    async def completions(self, request: CompletionRequest):
        """
        Handles completion requests (POST /v1/completions).

        Processes the prompt for completion and returns generated text. When the
        prompt is a list only its first element is used.

        Args:
            request (CompletionRequest): The incoming request data.

        Returns:
            Union[CompletionResponse, StreamingResponse, JSONResponse]: The
            generated response, or a JSON error response (HTTP 500) on failure.
        """
        try:
            vsurge = self._get_vsurge(request.model)
            prompt = request.prompt
            if isinstance(prompt, list):
                prompt = prompt[0]
            if request.stream:
                return await self._handle_completion_streaming_async(request, vsurge, prompt)
            return await self._handle_completion_response_async(request, vsurge, prompt)
        except Exception as e:
            self.logger.exception(f"Error during completion for model {request.model}: {e}")
            return create_error_response(HTTPStatus.INTERNAL_SERVER_ERROR, str(e))

    async def chat_completions(self, request: ChatCompletionRequest):
        """
        Handles chat completion requests (POST /v1/chat/completions).

        Retrieves the requested `vSurge`, renders the conversation through the
        chat template, and delegates to the streaming or non-streaming handler.

        Args:
            request (ChatCompletionRequest): The incoming request data.

        Returns:
            Union[ChatCompletionResponse, StreamingResponse, JSONResponse]: The
            generated response, either a complete object or an event stream, or
            a JSON error response (HTTP 500) on failure.
        """
        try:
            vsurge = self._get_vsurge(request.model)
            prompt = await self._prepare_vsurge_input_async(request, vsurge)
            if request.stream:
                return await self._handle_completion_streaming_async(request, vsurge, prompt)
            return await self._handle_completion_response_async(request, vsurge, prompt)
        except Exception as e:
            self.logger.exception(
                f"Error during chat completion for model {request.model}: {e}"
            )
            return create_error_response(HTTPStatus.INTERNAL_SERVER_ERROR, str(e))

    def _get_vsurge(self, model_name: str) -> vSurge:
        """
        Retrieves the `vSurge` instance for the given model name.

        Args:
            model_name (str): The requested model name.

        Returns:
            vSurge: The corresponding instance.

        Raises:
            RuntimeError: If the model name is not found in the `vsurge_map`.
        """
        vsurge = self.vsurge_map.get(model_name)
        if vsurge is None:
            available_models = list(self.vsurge_map.keys())
            error_msg = (
                f"Invalid model name: '{model_name}'. Available models: {available_models}"
            )
            self.logger.error(error_msg)
            raise RuntimeError(error_msg)
        return vsurge

    def _apply_chat_template(self, conversation: tp.Any, vsurge: vSurge) -> str:
        """
        Renders `conversation` with the processor's chat template, without tokenizing.

        Args:
            conversation: Message list as accepted by `processor.apply_chat_template`.
            vsurge (vSurge): Instance whose `processor` owns the template.

        Returns:
            str: The rendered prompt, including the generation prompt suffix.

        Raises:
            RuntimeError: If applying the chat template fails.
        """
        processor = vsurge.processor
        if isinstance(processor, ProcessorMixin) and self.oai_like_processor:
            # Imported lazily, exactly where it is needed.
            from easydel.trainers.prompt_utils import convert_to_openai_format

            conversation = convert_to_openai_format(conversation)
        try:
            return processor.apply_chat_template(
                conversation=conversation,
                add_generation_prompt=True,
                tokenize=False,
            )
        except Exception as e:
            self.logger.exception(
                f"Error applying chat template for model {vsurge.vsurge_name}: {e}"
            )
            raise RuntimeError(f"Error tokenizing input: {e}") from e

    def _prepare_vsurge_input(
        self,
        request: ChatCompletionRequest,
        vsurge: vSurge,
    ) -> str:
        """
        Turns a chat request into the prompt string fed to the engine.

        Only explicitly-set request fields are dumped; the `messages` entry is
        then rendered with the chat template.

        Raises:
            RuntimeError: If applying the chat template fails.
        """
        conversation = request.model_dump(exclude_unset=True)["messages"]
        return self._apply_chat_template(conversation, vsurge)

    async def _prepare_vsurge_input_async(self, request, vsurge) -> str:
        """Runs `_prepare_vsurge_input` in the thread pool."""
        return await asyncio.get_running_loop().run_in_executor(
            self.thread_pool,
            self._prepare_vsurge_input,
            request,
            vsurge,
        )

    def _prepare_count_tokens_input(self, conversation: tp.Any, vsurge: vSurge) -> str:
        """
        Normalizes the payload of a token-count request into a prompt string.

        NOTE(review): the schema of `CountTokenRequest.conversation` is not
        visible in this module, so the shapes below are handled defensively --
        confirm against the request model.

        * a plain string is counted as-is;
        * an object exposing `model_dump` is treated like a chat request
          (it must provide a `messages` entry);
        * anything else is treated as an iterable of messages.
        """
        if isinstance(conversation, str):
            return conversation
        if hasattr(conversation, "model_dump"):
            return self._prepare_vsurge_input(conversation, vsurge)
        messages = [
            message.model_dump(exclude_unset=True) if hasattr(message, "model_dump") else message
            for message in conversation
        ]
        return self._apply_chat_template(messages, vsurge)

    def _create_usage_info(
        self,
        prompt_tokens: int,
        completion_tokens: int,
        total_tokens: int,
        time_spent_computing: float,
        tokens_per_second: float,
    ) -> UsageInfo:
        """Creates a UsageInfo object (`time_spent_computing` maps to `processing_time`)."""
        return UsageInfo(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
            tokens_per_second=tokens_per_second,
            processing_time=time_spent_computing,
        )

    @staticmethod
    def _sse_event(payload: str) -> bytes:
        """Frames `payload` as one UTF-8 encoded server-sent-event `data:` record."""
        return ("data: " + payload + "\n\n").encode("utf-8")

    async def _handle_completion_response_async(
        self,
        request: ChatCompletionRequest | CompletionRequest,
        vsurge: vSurge,
        content: str,
    ):
        """
        Generates a complete, non-streaming response.

        Drives `vsurge.complete` to exhaustion, concatenates the text emitted
        at each step and formats the result for the request type.

        Args:
            request: The original chat or plain completion request.
            vsurge (vSurge): The instance that performs generation.
            content (str): The prompt text.

        Returns:
            ChatCompletionResponse | CompletionResponse: Matches the request type.

        Raises:
            RuntimeError: If generation yields no state at all.
            NotImplementedError: If `request` is of neither supported type.
        """
        prompt_tokens = vsurge.count_tokens(content)
        sampling_params = self._create_sampling_params_from_request(request)

        last_state = None
        text_parts: list[str] = []
        async for states in vsurge.complete(
            request=vSurgeRequest.from_sampling_params(
                prompt=content,
                sampling_params=sampling_params,
            )
        ):
            # Only the first entry of every yielded item is consumed.
            last_state = states[0]
            text_parts.append(last_state.text)

        if last_state is None:
            raise RuntimeError("Generation failed to produce any output state.")

        final_response = "".join(text_parts)
        generated_tokens = last_state.num_generated_tokens
        finish_reason = (
            "length" if generated_tokens >= sampling_params.max_tokens else "stop"
        )
        usage = self._create_usage_info(
            prompt_tokens=prompt_tokens,
            completion_tokens=generated_tokens,
            total_tokens=prompt_tokens + generated_tokens,
            time_spent_computing=0,  # compute time is not tracked here
            tokens_per_second=last_state.tokens_per_second,
        )

        # `index` is the position of the choice inside `choices` (a single choice
        # is returned, hence 0), not a token counter.
        if isinstance(request, ChatCompletionRequest):
            return ChatCompletionResponse(
                model=request.model,
                choices=[
                    ChatCompletionResponseChoice(
                        index=0,
                        message=ChatMessage(
                            role="assistant",
                            content=final_response,
                            function_call=None,
                        ),
                        finish_reason=finish_reason,
                    )
                ],
                usage=usage,
            )
        if isinstance(request, CompletionRequest):
            return CompletionResponse(
                model=request.model,
                choices=[
                    CompletionResponseChoice(
                        index=0,
                        text=final_response,
                        finish_reason=finish_reason,
                    )
                ],
                usage=usage,
            )
        raise NotImplementedError("UnKnown request type!")

    async def _handle_completion_streaming_async(
        self,
        request: ChatCompletionRequest | CompletionRequest,
        vsurge: vSurge,
        content: str,
    ) -> StreamingResponse:
        """
        Handle streaming response generation asynchronously.

        One `data:` event is emitted per generation step, followed by a final
        event carrying the finish reason. A failure during generation is
        reported as a single error event, after which the stream ends.

        Args:
            request: The original chat or plain completion request.
            vsurge (vSurge): The instance that performs generation.
            content (str): The prompt text.

        Returns:
            StreamingResponse: A `text/event-stream` response.

        Raises:
            NotImplementedError: If `request` is of neither supported type.
        """
        # Everything that can fail before generation starts is done up front, so
        # the calling endpoint can still answer with a regular JSON error.
        is_chat = isinstance(request, ChatCompletionRequest)
        if not is_chat and not isinstance(request, CompletionRequest):
            raise NotImplementedError("UnKnown request type!")
        prompt_tokens = vsurge.count_tokens(content)
        sampling_params = self._create_sampling_params_from_request(request)

        async def stream_results() -> tp.AsyncGenerator[bytes, tp.Any]:
            # Must exist even when the engine yields nothing at all.
            last_response_state = None
            try:
                async for states in vsurge.complete(
                    request=vSurgeRequest.from_sampling_params(
                        prompt=content,
                        sampling_params=sampling_params,
                    )
                ):
                    state = states[0]
                    chunk_usage = await self._create_usage_info_async(
                        prompt_tokens=prompt_tokens,
                        completion_tokens=state.num_generated_tokens,
                        total_tokens=prompt_tokens + state.num_generated_tokens,
                        time_spent_computing=0,
                        tokens_per_second=state.tokens_per_second,
                    )
                    if is_chat:
                        stream_resp = ChatCompletionStreamResponse(
                            model=request.model,
                            choices=[
                                ChatCompletionStreamResponseChoice(
                                    index=0,
                                    delta=DeltaMessage(
                                        # The role is only announced while no token has
                                        # been generated yet.
                                        role="assistant"
                                        if state.num_generated_tokens == 0
                                        else None,
                                        content=state.text,
                                        function_call=None,
                                    ),
                                    finish_reason=None,
                                )
                            ],
                            usage=chunk_usage,
                        )
                    else:
                        stream_resp = CompletionStreamResponse(
                            model=request.model,
                            choices=[
                                CompletionStreamResponseChoice(
                                    text=state.text,
                                    index=0,
                                    finish_reason=None,
                                )
                            ],
                            usage=chunk_usage,
                        )
                    last_response_state = state
                    yield self._sse_event(stream_resp.model_dump_json(exclude_unset=True))
            except Exception as e:
                self.logger.exception(f"Error during streaming generation: {e}")
                error_body = create_error_response(
                    HTTPStatus.INTERNAL_SERVER_ERROR, str(e)
                ).body.decode()  # type: ignore
                yield self._sse_event(error_body)
                return

            if last_response_state is None:
                self.logger.warning("Streaming finished without producing any response state.")
                return

            generated_tokens = last_response_state.num_generated_tokens
            finish_reason = (
                "length" if generated_tokens >= sampling_params.max_tokens else "stop"
            )
            final_usage = await self._create_usage_info_async(
                prompt_tokens=prompt_tokens,
                completion_tokens=generated_tokens,
                total_tokens=prompt_tokens + generated_tokens,
                time_spent_computing=0,
                tokens_per_second=last_response_state.tokens_per_second,
            )
            if is_chat:
                final_resp = ChatCompletionStreamResponse(
                    model=request.model,
                    choices=[
                        ChatCompletionStreamResponseChoice(
                            index=0,
                            # Left empty on purpose: the closing chunk carries no delta fields.
                            delta=DeltaMessage(),
                            finish_reason=finish_reason,
                        )
                    ],
                    usage=final_usage,
                )
            else:
                final_resp = CompletionStreamResponse(
                    model=request.model,
                    choices=[
                        CompletionStreamResponseChoice(
                            text="",
                            index=0,
                            finish_reason=finish_reason,
                        )
                    ],
                    usage=final_usage,
                )
            yield self._sse_event(final_resp.model_dump_json(exclude_unset=True))

        return StreamingResponse(stream_results(), media_type="text/event-stream")

    async def _create_usage_info_async(
        self,
        prompt_tokens: int,
        completion_tokens: int,
        total_tokens: int,
        time_spent_computing: float,
        tokens_per_second: float,
    ) -> UsageInfo:
        """Async wrapper around `_create_usage_info`; the work is done inline."""
        # UsageInfo creation is trivial, so nothing is offloaded to the executor.
        return self._create_usage_info(
            prompt_tokens,
            completion_tokens,
            total_tokens,
            time_spent_computing,
            tokens_per_second,
        )

    async def liveness(self):
        """Liveness check endpoint (GET /liveness)."""
        return JSONResponse({"status": "alive"}, status_code=200)

    async def readiness(self):
        """Readiness check endpoint (GET /readiness)."""
        # Basic check: the server is running. Model status is not inspected.
        return JSONResponse({"status": "ready"}, status_code=200)

    async def available_inference(self):
        """Lists available models (GET /v1/models)."""
        models_data = [
            {
                "id": model_id,
                "object": "model",
                "owned_by": "easydel",
                "permission": [],
            }
            for model_id in self.vsurge_map
        ]
        return JSONResponse({"object": "list", "data": models_data}, status_code=200)

    async def count_tokens(self, request: CountTokenRequest):
        """
        Token counting endpoint (POST /v1/count_tokens).

        Args:
            request (CountTokenRequest): Carries the model name and the
                conversation whose tokens are counted.

        Returns:
            JSONResponse: ``{"model": ..., "count": ...}`` on success, or a JSON
            error response (HTTP 500) on failure.
        """
        try:
            conversation = request.conversation
            model_name = request.model
            vsurge = self._get_vsurge(model_name)

            def _count():
                # Template rendering and counting both run in the worker thread, so
                # neither blocks the event loop.
                prompt = self._prepare_count_tokens_input(conversation, vsurge)
                return vsurge.count_tokens(prompt)

            num_tokens = await asyncio.get_running_loop().run_in_executor(
                self.thread_pool,
                _count,
            )
            return JSONResponse({"model": model_name, "count": num_tokens}, status_code=200)
        except Exception as e:
            self.logger.exception(f"Error counting tokens for model {request.model}: {e}")
            return create_error_response(HTTPStatus.INTERNAL_SERVER_ERROR, str(e))

    def _register_endpoints(self):
        """Registers all defined API endpoints with the FastAPI application."""
        for endpoint in self._endpoints:
            # Handlers are bound methods, so they are attached with `add_api_route`
            # rather than with decorators.
            APP.add_api_route(
                path=endpoint.path,
                endpoint=endpoint.handler,
                methods=endpoint.methods,
                summary=endpoint.summary,
                tags=endpoint.tags,
                response_model=None,  # handlers build their responses themselves
            )

    def fire(
        self,
        host: str = "0.0.0.0",
        port: int = 11556,
        metrics_port: tp.Optional[int] = None,
        log_level: str = "info",
        ssl_keyfile: tp.Optional[str] = None,
        ssl_certfile: tp.Optional[str] = None,
    ):
        """
        Starts the uvicorn server to run the FastAPI application.

        Args:
            host (str): The host address to bind to. Defaults to "0.0.0.0".
            port (int): The port to listen on. Defaults to 11556.
            metrics_port (tp.Optional[int]): The port for the Prometheus metrics server.
                If None, defaults to `port + 1`. Set to -1 to disable.
            log_level (str): The logging level for uvicorn. Defaults to "info".
            ssl_keyfile (tp.Optional[str]): Path to the SSL key file for HTTPS.
            ssl_certfile (tp.Optional[str]): Path to the SSL certificate file for HTTPS.
                HTTPS is only enabled when both SSL paths are given.
        """
        if metrics_port is None:
            metrics_port = port + 1

        if metrics_port > 0:
            if is_package_available("prometheus_client"):
                try:
                    from prometheus_client import start_http_server  # type:ignore

                    start_http_server(metrics_port)
                    self.logger.info(f"Prometheus metrics server started on port {metrics_port}")
                except Exception as e:
                    # Metrics are best-effort: the API server still starts.
                    self.logger.error(f"Failed to start Prometheus metrics server: {e}")
            else:
                self.logger.warning(
                    "Prometheus metrics requested but `prometheus_client` is not installed. "
                    "Metrics server will not start. Install with `pip install prometheus-client`."
                )

        uvicorn_config = {
            "host": host,
            "port": port,
            "log_level": log_level,
            "timeout_keep_alive": TIMEOUT_KEEP_ALIVE,
        }

        # Use uvloop if available for better performance
        try:
            import uvloop  # type:ignore #noqa

            uvicorn_config["loop"] = "uvloop"
            self.logger.info("Using uvloop for the event loop.")
        except ImportError:
            self.logger.info("uvloop not found, using default asyncio event loop.")

        if ssl_keyfile and ssl_certfile:
            uvicorn_config["ssl_keyfile"] = ssl_keyfile
            uvicorn_config["ssl_certfile"] = ssl_certfile
            self.logger.info(f"Running with HTTPS enabled on {host}:{port}")
        else:
            if ssl_keyfile or ssl_certfile:
                # Make the silent downgrade visible instead of ignoring the lone file.
                self.logger.warning(
                    "Both `ssl_keyfile` and `ssl_certfile` are required for HTTPS; "
                    "falling back to HTTP."
                )
            self.logger.info(f"Running with HTTP on {host}:{port}")

        uvicorn.run(APP, **uvicorn_config)