# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""OpenAI API compatibility models and utilities.
This module provides Pydantic models and utilities for OpenAI API compatibility,
enabling EasyDeL inference engines to work with OpenAI-compatible clients and tools.

Key Components:
    - Request/Response models for chat completions and text completions
    - Function calling support with multiple format parsers
    - Token usage tracking and metrics
    - Streaming response models

Classes:
    ChatMessage: Single message in a conversation
    DeltaMessage: Incremental message for streaming
    UsageInfo: Token usage and performance metrics
    ChatCompletionRequest: Request for chat completions
    ChatCompletionResponse: Response from chat completions
    CompletionRequest: Request for text completions
    CompletionResponse: Response from text completions
    FunctionCallFormat: Supported function call formats
    FunctionCallFormatter: Formatter for function call prompts
    FunctionCallParser: Parser for extracting function calls

Example:
    >>> from easydel.inference.openai_api_modules import (
    ...     ChatCompletionRequest,
    ...     ChatMessage
    ... )
    >>> request = ChatCompletionRequest(
    ...     model="gpt-3.5-turbo",
    ...     messages=[
    ...         ChatMessage(role="user", content="Hello!")
    ...     ],
    ...     temperature=0.7
    ... )
"""
import time
import typing as tp
import uuid
from enum import Enum
from pydantic import BaseModel, ConfigDict, Field, model_validator
class OpenAIBaseModel(BaseModel):
    """Common base class for the OpenAI-compatible models in this module.

    The model is configured with ``extra="allow"``, so payload keys that are
    not declared as fields are accepted by validation rather than rejected.

    Attributes:
        field_names: Per-class cache holding every declared field name and
            alias. It stays ``None`` until the first ``dict`` payload is
            validated for that class.
    """

    model_config = ConfigDict(extra="allow")

    field_names: tp.ClassVar[set[str] | None] = None

    @model_validator(mode="wrap")
    @classmethod
    def __log_extra_fields__(cls, data, handler):
        """Run normal validation, then cache the class's declared field names.

        Args:
            data: Raw input handed to the model by pydantic.
            handler: The wrapped (inner) validator; called once with ``data``.

        Returns:
            The value produced by ``handler(data)``, unchanged.
        """
        result = handler(data)
        if not isinstance(data, dict):
            return result

        # Read the cache from this class's own namespace only. A plain
        # ``cls.field_names`` lookup walks the MRO, so a subclass would pick
        # up (and keep) a parent's already-populated set and never account
        # for the fields it declares itself.
        if cls.__dict__.get("field_names") is None:
            known_names = set()
            for field_name, field in cls.model_fields.items():
                known_names.add(field_name)
                if alias := getattr(field, "alias", None):
                    known_names.add(alias)
            cls.field_names = known_names

        # NOTE(review): despite the method name, nothing is logged here; the
        # cached set is not consumed anywhere in this method. Confirm whether
        # reporting of undeclared keys was dropped on purpose.
        return result
class ChatMessage(OpenAIBaseModel):
    """Represents a single message in a chat conversation.

    Attributes:
        role: Message role (system, user, assistant, function)
        content: Message content (text or structured)
        name: Optional name for the message sender
        function_call: Optional function call made by assistant
    """

    role: str
    # Either plain text or a list of string-to-string mappings.
    content: str | list[tp.Mapping[str, str]]
    name: str | None = None
    function_call: dict[str, tp.Any] | None = None
class DeltaMessage(OpenAIBaseModel):
    """Represents a change (delta) in a chat message.

    Used in streaming responses to send incremental updates. Every field is
    optional and defaults to ``None``.

    Attributes:
        role: Optional role if starting new message
        content: Incremental content to append
        function_call: Optional function call updates
    """

    role: str | None = None
    content: str | list[tp.Mapping[str, str]] | None = None
    function_call: dict[str, tp.Any] | None = None
class Function(OpenAIBaseModel):
    """Function definition for OpenAI-compatible function calling.

    Attributes:
        name: Function name (required).
        description: Optional free-text description; defaults to ``None``.
        parameters: Parameter mapping; defaults to a fresh empty ``dict``.
    """

    # NOTE(review): this file defines a class named ``Function`` a second time
    # further down with the same three fields; the later definition is the one
    # the module-level name ends up bound to. Consider keeping only one.
    name: str
    description: str | None = None
    parameters: dict[str, tp.Any] = Field(default_factory=dict)
class DeltaFunctionCall(OpenAIBaseModel):
    """Partial function-call payload in which both fields are optional.

    Presumably carried by streamed chunks, where the name and the argument
    text can arrive separately -- confirm against the streaming code path.

    Attributes:
        name: Function name, or ``None`` when not supplied.
        arguments: Argument text, or ``None`` when not supplied.
    """

    name: str | None = None
    arguments: str | None = None
class UsageInfo(OpenAIBaseModel):
    """Token usage and performance metrics.

    Tracks computational resources used for a request. Every field has a
    zero default, so an empty ``UsageInfo()`` is valid.

    Attributes:
        prompt_tokens: Number of tokens in the prompt
        completion_tokens: Number of tokens generated (may be ``None``)
        total_tokens: Sum of prompt and completion tokens
        tokens_per_second: Generation speed
        processing_time: Total processing time in seconds
    """

    prompt_tokens: int = 0
    completion_tokens: int | None = 0
    total_tokens: int = 0
    tokens_per_second: float = 0
    processing_time: float = 0.0
class FunctionDefinition(OpenAIBaseModel):
    """Defines a function that can be called by the model.

    Attributes:
        name: Function name
        description: Function description for the model
        parameters: JSON Schema for function parameters (defaults to an
            empty ``dict`` created per instance)
        required: List of required parameter names
    """

    name: str
    description: str | None = None
    parameters: dict[str, tp.Any] = Field(default_factory=dict)
    required: list[str] | None = None
class ChatCompletionRequest(OpenAIBaseModel):
    """Represents a request to the chat completion endpoint.

    Mirrors the OpenAI ChatCompletion request structure. Only ``model`` and
    ``messages`` are required; every other field carries the default shown
    below.
    """

    model: str
    messages: list[ChatMessage]

    # Generation parameters and their defaults.
    max_tokens: int | None = None
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
    repetition_penalty: float = 1.0
    temperature: float = 0.7
    top_p: float = 0.95
    top_k: int = 0
    min_p: float = 0.0
    suppress_tokens: list[int] = Field(default_factory=list)

    # Function / tool calling.
    functions: list[FunctionDefinition] | None = None
    function_call: str | dict[str, tp.Any] | None = None
    # ``ToolDefinition`` is neither defined nor imported above this class, so
    # an eagerly evaluated annotation would raise NameError while the class
    # body executes. Quoting it lets pydantic resolve the name lazily from the
    # module namespace once it exists.
    tools: "list[ToolDefinition] | None" = None
    tool_choice: str | dict[str, tp.Any] | None = None

    n: int | None = 1
    stream: bool | None = False
    stop: str | list[str] | None = None
    logit_bias: dict[str, float] | None = None  # Ignored by EasyDeL
    user: str | None = None  # Ignored by EasyDeL
    chat_template_kwargs: dict[str, int | float | str | bool] | None = None
class ChatCompletionResponseChoice(OpenAIBaseModel):
    """Represents a single choice within a non-streaming chat completion response.

    Attributes:
        index: Position of this choice (required).
        message: The complete message for this choice (required).
        finish_reason: One of ``"stop"``, ``"length"``, ``"function_call"``,
            ``"tool_calls"``, ``"abort"``, or ``None`` (the default).
    """

    index: int
    message: ChatMessage
    finish_reason: tp.Literal["stop", "length", "function_call", "tool_calls", "abort"] | None = None
class ChatCompletionResponse(OpenAIBaseModel):
    """Represents a non-streaming response from the chat completion endpoint.

    Attributes:
        id: Defaults to ``"chat-"`` followed by a random UUID4 hex string.
        object: Defaults to ``"chat.completion"``.
        created: Defaults to the current Unix time, truncated to an ``int``.
        model: Model identifier (required).
        choices: Response choices (required).
        usage: Usage metrics (required).
    """

    # Factories run per instance, so each response gets its own id/timestamp.
    id: str = Field(default_factory=lambda: f"chat-{uuid.uuid4().hex}")
    object: str = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[ChatCompletionResponseChoice]
    usage: UsageInfo
class ChatCompletionStreamResponseChoice(OpenAIBaseModel):
    """Represents a single choice within a streaming chat completion response chunk.

    Attributes:
        index: Position of this choice (required).
        delta: Incremental message update for this chunk (required).
        finish_reason: One of ``"stop"``, ``"length"``, ``"function_call"``,
            ``"tool_calls"``, ``"abort"``, or ``None`` (the default).
    """

    index: int
    delta: DeltaMessage
    # Accepts the same finish reasons as the non-streaming choice model in
    # this file, so a "tool_calls" or "abort" finish validates on both paths.
    finish_reason: tp.Literal["stop", "length", "function_call", "tool_calls", "abort"] | None = None
class ChatCompletionStreamResponse(OpenAIBaseModel):
    """Represents a single chunk in a streaming response from the chat completion endpoint.

    Attributes:
        id: Defaults to ``"chat-"`` followed by a random UUID4 hex string.
        object: Defaults to ``"chat.completion.chunk"``.
        created: Defaults to the current Unix time, truncated to an ``int``.
        model: Model identifier (required).
        choices: Streaming choices carried by this chunk (required).
        usage: Usage metrics (required on every chunk).
    """

    id: str = Field(default_factory=lambda: f"chat-{uuid.uuid4().hex}")
    object: str = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[ChatCompletionStreamResponseChoice]
    usage: UsageInfo  # Usage info might be included in chunks, often zero until the end
class CountTokenRequest(OpenAIBaseModel):
    """Represents a request to the token counting endpoint.

    Attributes:
        model: Model identifier (required).
        conversation: Either a plain string or a list of ``ChatMessage``
            objects whose tokens should be counted.
    """

    model: str
    conversation: str | list[ChatMessage]  # Can count tokens for a string or a list of messages
class CompletionRequest(OpenAIBaseModel):
    """Represents a request to the completions endpoint.

    Mirrors the OpenAI Completion request structure. Only ``model`` and
    ``prompt`` are required; every other field carries the default shown
    below.
    """

    model: str
    prompt: str | list[str]  # a single prompt or a list of prompts

    # Generation parameters and their defaults.
    max_tokens: int | None = None
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
    repetition_penalty: float = 1.0
    temperature: float = 0.7
    top_p: float = 0.95
    top_k: int = 0
    min_p: float = 0.0
    suppress_tokens: list[int] = Field(default_factory=list)

    n: int | None = 1
    stream: bool | None = False
    stop: str | list[str] | None = None
    logit_bias: dict[str, float] | None = None
    user: str | None = None
class CompletionLogprobs(OpenAIBaseModel):
    """Log probabilities for token generation.

    Attributes:
        tokens: Token strings (required).
        token_logprobs: One float per entry (required).
        top_logprobs: Optional list of token-to-logprob mappings.
        text_offset: Optional list of integer offsets.
    """

    tokens: list[str]
    token_logprobs: list[float]
    top_logprobs: list[dict[str, float]] | None = None
    text_offset: list[int] | None = None
class CompletionResponseChoice(OpenAIBaseModel):
    """Represents a single choice within a completion response.

    Attributes:
        text: Generated text (required).
        index: Position of this choice (required).
        logprobs: Optional log-probability details.
        finish_reason: One of ``"stop"``, ``"length"``, ``"function_call"``,
            or ``None`` (the default).
    """

    text: str
    index: int
    logprobs: CompletionLogprobs | None = None
    finish_reason: tp.Literal["stop", "length", "function_call"] | None = None
class CompletionResponse(OpenAIBaseModel):
    """Represents a response from the completions endpoint.

    Attributes:
        id: Defaults to ``"cmpl-"`` followed by a random UUID4 hex string.
        object: Defaults to ``"text_completion"``.
        created: Defaults to the current Unix time, truncated to an ``int``.
        model: Model identifier (required).
        choices: Response choices (required).
        usage: Usage metrics (required).
    """

    id: str = Field(default_factory=lambda: f"cmpl-{uuid.uuid4().hex}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[CompletionResponseChoice]
    usage: UsageInfo
class CompletionStreamResponseChoice(OpenAIBaseModel):
    """Represents a single choice within a streaming completion response chunk.

    Attributes:
        index: Position of this choice (required).
        text: The delta text content for this chunk (required).
        logprobs: Optional log-probability details.
        finish_reason: One of ``"stop"``, ``"length"``, ``"function_call"``,
            or ``None`` (the default).
    """

    index: int
    text: str  # The delta text content
    logprobs: CompletionLogprobs | None = None
    finish_reason: tp.Literal["stop", "length", "function_call"] | None = None
class CompletionStreamResponse(OpenAIBaseModel):
    """Represents a streaming response from the completions endpoint.

    Attributes:
        id: Defaults to ``"cmpl-"`` followed by a random UUID4 hex string.
        object: Defaults to ``"text_completion.chunk"``.
        created: Defaults to the current Unix time, truncated to an ``int``.
        model: Model identifier (required).
        choices: Streaming choices carried by this chunk (required).
        usage: Optional usage metrics; defaults to ``None``.
    """

    id: str = Field(default_factory=lambda: f"cmpl-{uuid.uuid4().hex}")
    object: str = "text_completion.chunk"  # Correct object type for streaming
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[CompletionStreamResponseChoice]  # Use the new streaming choice model
    # Usage is often None until the final chunk in OAI
    usage: UsageInfo | None = None
class FunctionCall(OpenAIBaseModel):
    """Represents a function call in the OpenAI format.

    Attributes:
        name: Name of the function being called (required).
        arguments: The call arguments as a JSON string (required); the
            string is stored as given and is not parsed by this model.
    """

    name: str
    arguments: str  # JSON string of arguments
class Function(OpenAIBaseModel):
    """Function definition for OpenAI-compatible function calling.

    Attributes:
        name: Function name (required).
        description: Optional free-text description; defaults to ``None``.
        parameters: Parameter mapping; defaults to a fresh empty ``dict``.
    """

    # NOTE(review): a class named ``Function`` with the same three fields is
    # already defined earlier in this file; this definition rebinds the
    # module-level name. Consider removing one of the two.
    name: str
    description: str | None = None
    parameters: dict[str, tp.Any] = Field(default_factory=dict)