# Source code for easydel.inference.tools.tool_calling_mixin

# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     https://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Mixin class for tool calling functionality in inference servers."""

from __future__ import annotations

import typing as tp
from http import HTTPStatus

from eformer.loggings import get_logger
from fastapi.responses import JSONResponse

from ..openai_api_modules import ChatCompletionRequest, ChatMessage, DeltaMessage
from . import ToolParser, ToolParserManager

logger = get_logger("ToolCallingMixin")


class ToolCallingMixin:
    """Mixin class providing tool calling functionality for inference API servers.

    This mixin provides:
    - Tool parser initialization and management
    - Tool call extraction for batch responses
    - Tool call extraction for streaming responses
    - Tool listing and metadata endpoints

    Classes using this mixin should have:
    - self.tool_parsers: dict[str, ToolParser]
    - self.tool_parser_name: str
    - self.enable_function_calling: bool

    Every method tolerates a host object on which ``tool_parsers`` has not been
    set yet: in that case the model is treated as having no parser.
    """

    # Maps model name -> parser instance. Declared only; the host class assigns it.
    tool_parsers: dict[str, ToolParser]

    def initialize_tool_parsers(
        self,
        model_processors: dict[str, tp.Any],
        tool_parser_name: str,
        enable_function_calling: bool,
    ) -> dict[str, ToolParser]:
        """Initialize tool parsers for models.

        The parser class is resolved once from ``ToolParserManager`` and then
        instantiated with each model's processor. Failures are logged as
        warnings and never raised; a model whose parser could not be built is
        simply absent from the returned mapping.

        Args:
            model_processors: Dictionary mapping model names to their processors/tokenizers
            tool_parser_name: Name of the tool parser to use (e.g., "hermes", "qwen")
            enable_function_calling: Whether to enable function calling

        Returns:
            Dictionary mapping model names to their tool parsers. Empty when
            function calling is disabled or the parser class cannot be resolved.
        """
        tool_parsers: dict[str, ToolParser] = {}
        if not enable_function_calling:
            return tool_parsers

        # The lookup depends only on the parser name, so do it once. Keeping it
        # in its own ``try`` also means a KeyError raised later by a parser
        # constructor is not misreported as "parser not found".
        try:
            parser_class = ToolParserManager.get_tool_parser(tool_parser_name)
        except KeyError:
            for model_name in model_processors:
                logger.warning(f"Tool parser '{tool_parser_name}' not found, function calling disabled for {model_name}")
            return tool_parsers
        except Exception as e:
            for model_name in model_processors:
                logger.warning(f"Failed to initialize tool parser for {model_name}: {e}")
            return tool_parsers

        for model_name, processor in model_processors.items():
            try:
                tool_parsers[model_name] = parser_class(processor)
            except Exception as e:
                # Best effort: one broken processor must not disable the others.
                logger.warning(f"Failed to initialize tool parser for {model_name}: {e}")
            else:
                logger.info(f"Initialized {tool_parser_name} tool parser for model {model_name}")

        return tool_parsers

    def extract_tool_calls_batch(
        self,
        response_text: str,
        request: ChatCompletionRequest,
        model_name: str,
    ) -> tuple[ChatMessage, str]:
        """Extract tool calls from a batch response.

        Args:
            response_text: The generated text response
            request: The original chat completion request
            model_name: The model name to get the appropriate parser

        Returns:
            Tuple of (ChatMessage with potential tool calls, finish_reason).
            ``finish_reason`` is ``"function_call"`` when the parser reports
            tool calls and ``"stop"`` otherwise (including when the model has
            no parser, in which case the raw text is returned as content).
        """
        tool_parser = self.get_tool_parser_for_model(model_name)
        if tool_parser is None:
            return ChatMessage(role="assistant", content=response_text), "stop"

        extracted = tool_parser.extract_tool_calls(response_text, request)

        if extracted.tools_called and extracted.tool_calls:
            message = ChatMessage(
                role="assistant",
                content=extracted.content,
                tool_calls=extracted.tool_calls,
            )
            # NOTE(review): OpenAI's Chat Completions API reports "tool_calls"
            # when `tool_calls` are returned; "function_call" is the deprecated
            # value. Left unchanged because callers may depend on it - confirm.
            return message, "function_call"

        # No tool call: fall back to the raw text when the parser left the
        # content empty.
        message = ChatMessage(
            role="assistant",
            content=extracted.content if extracted.content else response_text,
        )
        return message, "stop"

    def extract_tool_calls_streaming(
        self,
        model_name: str,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: list[int] | None = None,
        current_token_ids: list[int] | None = None,
        delta_token_ids: list[int] | None = None,
        request: ChatCompletionRequest | None = None,
    ) -> DeltaMessage | None:
        """Extract tool calls from streaming response.

        Args:
            model_name: The model name to get the appropriate parser
            previous_text: Previously accumulated text
            current_text: Current accumulated text
            delta_text: New text in this chunk
            previous_token_ids: Previous token IDs (optional)
            current_token_ids: Current token IDs (optional)
            delta_token_ids: Delta token IDs (optional)
            request: The original request (optional)

        Returns:
            Whatever the model's parser returns for this chunk, or None when
            the model has no parser.
        """
        tool_parser = self.get_tool_parser_for_model(model_name)
        if tool_parser is None:
            return None

        # Omitted (or empty) token-id lists are forwarded as empty lists, so the
        # parser never receives None for them.
        return tool_parser.extract_tool_calls_streaming(
            previous_text=previous_text,
            current_text=current_text,
            delta_text=delta_text,
            previous_token_ids=previous_token_ids or [],
            current_token_ids=current_token_ids or [],
            delta_token_ids=delta_token_ids or [],
            request=request,
        )

    def get_tool_parser_for_model(self, model_name: str) -> ToolParser | None:
        """Get the tool parser for a specific model.

        Args:
            model_name: Name of the model

        Returns:
            ToolParser instance or None if not available (either
            ``tool_parsers`` is not set on this object or it has no entry for
            ``model_name``).
        """
        if not hasattr(self, "tool_parsers"):
            return None
        return self.tool_parsers.get(model_name)

    @staticmethod
    def _example_tools() -> list[dict[str, tp.Any]]:
        """Build the hard-coded demonstration tool list used by the listing endpoint.

        A new structure is built on every call so callers never share nested
        dictionaries.
        """
        return [
            {
                "type": "function",
                "function": {
                    "name": "example_function",
                    "description": "An example function for demonstration",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "param1": {
                                "type": "string",
                                "description": "First parameter",
                            },
                            "param2": {
                                "type": "number",
                                "description": "Second parameter",
                            },
                        },
                        "required": ["param1"],
                    },
                },
            }
        ]

    def create_tools_response(self, model_names: list[str]) -> dict[str, tp.Any]:
        """Create a standardized tools response for listing endpoints.

        The ``tools`` entry is the same hard-coded example for every model; it
        is not derived from any registered tool.

        Args:
            model_names: List of available model names

        Returns:
            Dictionary with a ``models`` mapping (one entry per model name,
            holding ``tools``, ``tool_parser``, ``formats_supported`` and
            ``parallel_calls``) and a ``default_format`` of ``"openai"``.
        """
        tools_by_model = {}
        for model_name in model_names:
            has_parser = self.get_tool_parser_for_model(model_name) is not None
            tools_by_model[model_name] = {
                "tools": self._example_tools(),
                # Only advertise the parser name for models that actually have one.
                "tool_parser": getattr(self, "tool_parser_name", None) if has_parser else None,
                "formats_supported": list(ToolParserManager.tool_parsers),
                "parallel_calls": True,
            }

        return {"models": tools_by_model, "default_format": "openai"}

    def create_tool_execution_placeholder(self) -> JSONResponse:
        """Create a placeholder response for tool execution endpoints.

        Returns:
            JSONResponse with NOT_IMPLEMENTED status
        """
        error_response = {
            "error": {
                "message": "Tool execution endpoint is a placeholder. Implement based on your needs.",
                "type": HTTPStatus.NOT_IMPLEMENTED.name,
            }
        }
        return JSONResponse(content=error_response, status_code=HTTPStatus.NOT_IMPLEMENTED.value)