# Source code for easydel.inference.vinference.api_server.api_server_test
# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An example asynchronous client script for testing the vInference API server."""
import asyncio
import json
import typing as tp
import aiohttp
class ChatCompletionClient:
    """An asynchronous client for interacting with the chat completion endpoint.

    The client POSTs to ``{base_url}/v1/chat/completions`` and exposes the
    response as an async generator of decoded JSON objects.
    """

    # Prefix of a server-sent-event data line; the space after the colon is optional.
    _SSE_DATA_PREFIX = "data:"
    # Non-JSON end-of-stream marker conventionally sent by OpenAI-style servers.
    _SSE_DONE_SENTINEL = "[DONE]"

    def __init__(self, base_url: str):
        """
        Initializes the asynchronous client.

        Args:
            base_url (str): The base URL of the vInference API server (e.g., "http://127.0.0.1:7680").
        """
        self.base_url = base_url

    @staticmethod
    def _parse_sse_line(raw_line: bytes) -> tp.Optional[tp.Dict[str, tp.Any]]:
        """
        Decodes one raw response line into a JSON payload, if it carries one.

        Args:
            raw_line (bytes): A single line read from the response body.

        Returns:
            The decoded JSON payload of a ``data:`` line, or ``None`` when the line
            is blank, is not a ``data:`` line, has an empty payload, or is the
            ``[DONE]`` end-of-stream marker.

        Raises:
            json.JSONDecodeError: If a ``data:`` line carries a payload that is not valid JSON.
        """
        text = raw_line.decode("utf-8").strip()
        if not text.startswith(ChatCompletionClient._SSE_DATA_PREFIX):
            return None
        payload = text[len(ChatCompletionClient._SSE_DATA_PREFIX):].strip()
        # "[DONE]" is not valid JSON, so it must be filtered before json.loads.
        if not payload or payload == ChatCompletionClient._SSE_DONE_SENTINEL:
            return None
        return json.loads(payload)

    async def create_chat_completion(
        self,
        messages: tp.List[tp.Dict[str, str]],
        model: str,
        stream: bool = True,
        **kwargs,
    ) -> tp.AsyncGenerator[tp.Dict[str, tp.Any], None]:
        """
        Sends a chat completion request to the server and yields the response.

        Args:
            messages (tp.List[tp.Dict[str, str]]): A list of message dictionaries, e.g.,
                `[{"role": "user", "content": "Hello!"}]`.
            model (str): The name of the model to use.
            stream (bool): Whether to request a streaming response. Defaults to True.
            **kwargs: Additional parameters to pass to the API (e.g., temperature, max_tokens).

        Yields:
            tp.Dict[str, tp.Any]: Each ``data:`` chunk of the response as a dictionary.
            When ``stream`` is False and the body contains no ``data:`` lines, the
            whole body is decoded as a single JSON document and yielded once.

        Raises:
            Exception: If the server returns a non-200 status code.
            json.JSONDecodeError: If a chunk (or the non-streaming body) is not valid JSON.
        """
        # Tolerate a trailing slash on the base URL so the path never starts with "//".
        url = f"{self.base_url.rstrip('/')}/v1/chat/completions"
        payload = {"messages": messages, "model": model, "stream": stream, **kwargs}
        async with aiohttp.ClientSession() as session:
            async with session.post(url, json=payload) as response:
                if response.status != 200:
                    raise Exception(f"Error: {response.status} - {await response.text()}")

                saw_event = False
                plain_body: tp.List[bytes] = []
                async for raw_line in response.content:
                    event = self._parse_sse_line(raw_line)
                    if event is not None:
                        saw_event = True
                        yield event
                    elif (
                        not stream
                        and not saw_event
                        and not raw_line.lstrip().startswith(b"data:")
                    ):
                        # A non-streaming reply is presumably one plain JSON body;
                        # keep its lines so it can be decoded once fully read.
                        plain_body.append(raw_line)

                if not stream and not saw_event:
                    body = b"".join(plain_body).decode("utf-8").strip()
                    if body:
                        yield json.loads(body)
async def main():
    """Run the example: stream one chat completion and print it as it arrives.

    Content deltas are printed without newlines as they are received. Once a
    chunk carries a finish reason, the reason (and usage, if present) is
    printed and the stream is abandoned. Any exception is reported on stdout
    rather than propagated.
    """
    model_name = "llama-3-8b"  # Replace with your actual running model name
    conversation = [
        {"role": "system", "content": "You are a helpful assistant."},
        {
            "role": "user",
            "content": "write a neural network in c++ and rust and compare them",
        },
    ]
    client = ChatCompletionClient("http://127.0.0.1:7680")  # Adjust URL if needed

    print(f"Sending request to model: {model_name}")
    try:
        completion = client.create_chat_completion(
            conversation,
            model=model_name,
            max_tokens=512,  # Example: setting max_tokens
        )
        async for chunk in completion:
            choices = chunk["choices"]
            if not choices:
                # Nothing to print and no finish reason to inspect.
                continue
            first_choice = choices[0]

            delta = first_choice.get("delta")
            if delta and delta.get("content"):
                print(delta["content"], end="", flush=True)

            # The finish reason is expected on the last chunk.
            finish_reason = first_choice.get("finish_reason")
            if not finish_reason:
                continue
            print("\n--- Finish Reason ---")
            print(finish_reason)
            if chunk.get("usage"):
                print("--- Usage --- ")
                print(chunk["usage"])
            break  # Stop after finish reason is received
    except Exception as exc:
        print(f"\nAn error occurred: {exc}")
if __name__ == "__main__":
    # Script entry point: drive the async example on a fresh event loop.
    asyncio.run(main())