Source code for easydel.inference.tools.abstract_tool

# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     https://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import importlib
import importlib.metadata
import importlib.util
import os
import sys
from collections.abc import Callable, Sequence
from functools import cached_property
from typing import Literal, assert_never

from transformers import AutoTokenizer as AnyTokenizer

from ..openai_api_modules import ChatCompletionRequest, DeltaMessage, ExtractedToolCallInformation


class ToolParser:
    """Base class for parsers that pull tool/function calls out of LLM output.

    Concrete subclasses supply the model-specific parsing logic by overriding
    :meth:`extract_tool_calls` (complete output) and
    :meth:`extract_tool_calls_streaming` (incremental output). The base class
    only stores the tokenizer, initialises the bookkeeping attributes that
    streaming parsers are expected to maintain, and offers a pass-through
    :meth:`adjust_request` hook.

    Attributes:
        prev_tool_call_arr (list[dict]): Tool calls parsed so far; starts empty.
        current_tool_id (int): Index of the tool call being processed; starts
            at ``-1``.
        current_tool_name_sent (bool): Whether the current tool's name has
            already been emitted; starts ``False``.
        streamed_args_for_tool (list[str]): Per-tool buffers of argument text
            already streamed; starts empty.
        model_tokenizer (AnyTokenizer): The tokenizer handed to the constructor.

    Note:
        The base class never mutates the bookkeeping attributes itself; they
        exist for subclasses. Both ``extract_*`` methods raise
        ``NotImplementedError`` here, so instantiate a concrete subclass.
    """

    def __init__(self, tokenizer: AnyTokenizer):
        """Store the tokenizer and reset all parsing state.

        Args:
            tokenizer: Tokenizer of the model whose output will be parsed.
        """
        # State shared with subclasses; every parser instance starts "clean".
        self.prev_tool_call_arr: list[dict] = []
        self.current_tool_id: int = -1
        self.current_tool_name_sent: bool = False
        self.streamed_args_for_tool: list[str] = []

        self.model_tokenizer = tokenizer

    @cached_property
    def vocab(self) -> dict[str, int]:
        """Token-to-id mapping of the wrapped tokenizer.

        The mapping is obtained through ``get_vocab()`` rather than a ``.vocab``
        attribute, because only ``PreTrainedTokenizerFast`` is guaranteed to
        expose ``.vocab``. The result is computed once per instance and cached.

        Returns:
            dict[str, int]: Mapping of tokens to their IDs.
        """
        token_to_id = self.model_tokenizer.get_vocab()
        return token_to_id

    def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
        """Hook for model-specific tweaks to an incoming request.

        Subclasses may override this to rewrite the request (for example its
        tool definitions or prompts) before it is processed. The base
        implementation hands the very same object back untouched.

        Args:
            request: The original chat completion request.

        Returns:
            ChatCompletionRequest: The request to use; here, ``request`` itself.
        """
        return request

    def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest) -> ExtractedToolCallInformation:
        """Parse tool calls from a complete (non-streaming) model response.

        Args:
            model_output: The full text generated by the model.
            request: The originating request, which carries the tool
                definitions.

        Returns:
            ExtractedToolCallInformation: The parsed tool calls together with
            any remaining content (produced by subclasses).

        Raises:
            NotImplementedError: Always, in the base class; subclasses must
                override this method.
        """
        raise NotImplementedError("AbstractToolParser.extract_tool_calls has not been implemented!")

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Parse tool calls from one increment of a streamed model response.

        Each call receives the text/token state before the new chunk, after the
        new chunk, and the chunk itself, so an implementation can decide what
        (if anything) to emit for this step.

        Args:
            previous_text: Text accumulated before this chunk.
            current_text: Text accumulated including this chunk.
            delta_text: The newly generated text of this chunk.
            previous_token_ids: Token IDs accumulated before this chunk.
            current_token_ids: Token IDs accumulated including this chunk.
            delta_token_ids: The newly generated token IDs of this chunk.
            request: The originating request, which carries the tool
                definitions.

        Returns:
            DeltaMessage | None: An incremental tool-call update, or ``None``
            when there is nothing to emit (produced by subclasses).

        Raises:
            NotImplementedError: Always, in the base class; subclasses must
                override this method.
        """
        raise NotImplementedError("AbstractToolParser.extract_tool_calls_streaming has not been implemented!")
class ToolParserManager:
    """Registry and manager for tool parser implementations.

    Parser classes are stored in the class-level ``tool_parsers`` mapping and
    can be registered either with the :meth:`register_module` decorator or by
    calling it directly. Parsers can also be pulled in from an external Python
    file via :meth:`import_tool_parser`.

    Attributes:
        tool_parsers (dict[str, type]): Registry mapping parser names to
            classes. It is shared by every user of this class.

    Example:
        ```python
        # Register using decorator
        @ToolParserManager.register_module("my_parser")
        class MyCustomParser(ToolParser):
            ...

        # Retrieve parser
        parser_class = ToolParserManager.get_tool_parser("my_parser")
        parser = parser_class(tokenizer)
        ```
    """

    tool_parsers: dict[str, type] = {}  # noqa: RUF012

    @classmethod
    def get_tool_parser(cls, name: str) -> type:
        """Retrieve a registered tool parser by name.

        Args:
            name: Name of the parser to retrieve.

        Returns:
            type: The parser class registered under ``name``.

        Raises:
            KeyError: If no parser is registered under ``name``.

        Example:
            ```python
            HermesParser = ToolParserManager.get_tool_parser("hermes")
            parser_instance = HermesParser(tokenizer)
            ```
        """
        if name in cls.tool_parsers:
            return cls.tool_parsers[name]
        raise KeyError(f"tool helper: '{name}' not found in tool_parsers")

    @classmethod
    def _register_module(cls, module: type, module_name: str | list[str] | None = None, force: bool = True) -> None:
        """Validate a parser class and store it in the registry.

        Args:
            module: Parser class to register (must inherit from ToolParser).
            module_name: Name(s) to register the parser under; defaults to the
                class's ``__name__``.
            force: If True, an existing registration is overwritten; if False,
                a name clash raises ``KeyError``.

        Raises:
            TypeError: If ``module`` is not a class derived from ToolParser.
            KeyError: If a name is already registered and ``force`` is False.
        """
        # Check ``isinstance(..., type)`` first so a non-class argument yields
        # our own message instead of issubclass()'s unrelated one, and report
        # the offending object itself: ``type(module)`` is always
        # ``<class 'type'>`` for a class and says nothing about what was passed.
        if not (isinstance(module, type) and issubclass(module, ToolParser)):
            raise TypeError(f"module must be subclass of ToolParser, but got {module!r}")

        if module_name is None:
            module_name = module.__name__
        if isinstance(module_name, str):
            module_name = [module_name]

        for name in module_name:
            if not force and name in cls.tool_parsers:
                existed_module = cls.tool_parsers[name]
                raise KeyError(f"{name} is already registered at {existed_module.__module__}")
            cls.tool_parsers[name] = module

    @classmethod
    def register_module(
        cls, name: str | list[str] | None = None, force: bool = True, module: type | None = None
    ) -> type | Callable:
        """Register a tool parser module.

        Can be used as a decorator or called directly, and supports
        registering one parser under several names.

        Args:
            name: Name(s) to register the parser under (defaults to the class
                name).
            force: If True, overwrites an existing entry; if False, raises on
                conflict.
            module: Parser class to register. When None, a decorator is
                returned instead.

        Returns:
            type | Callable: ``module`` itself when it was supplied, otherwise
            a decorator that registers the class it is applied to and returns
            that class.

        Raises:
            TypeError: If ``force`` is not a bool or ``name`` has the wrong type.
            KeyError: If a name conflicts and ``force`` is False.

        Examples:
            ```python
            # As decorator
            @ToolParserManager.register_module("custom")
            class CustomParser(ToolParser):
                ...

            # Direct registration
            ToolParserManager.register_module(
                name="alternate",
                module=CustomParser
            )

            # Multiple names
            @ToolParserManager.register_module(["v1", "version1"])
            class V1Parser(ToolParser):
                ...
            ```
        """
        if not isinstance(force, bool):
            raise TypeError(f"force must be a boolean, but got {type(force)}")

        if not (name is None or isinstance(name, str) or is_list_of(name, str)):
            raise TypeError(f"name must be None, an instance of str, or a sequence of str, but got {type(name)}")

        # Direct-call form: register immediately and hand the class back.
        if module is not None:
            cls._register_module(module=module, module_name=name, force=force)
            return module

        # Decorator form: defer registration until the class is available.
        def _register(module):
            cls._register_module(module=module, module_name=name, force=force)
            return module

        return _register

    @classmethod
    def import_tool_parser(cls, plugin_path: str) -> None:
        """Import a Python file that registers one or more tool parsers.

        The file is executed as a module named after its base name (without
        extension). It is expected to register its parsers as a side effect of
        being imported, e.g. via the ``@register_module`` decorator.

        Loading is best-effort: any exception raised while importing is caught
        and reported on standard output, and the method returns normally.

        Args:
            plugin_path: File path to the Python module containing parser(s).

        Example:
            ```python
            # In external_parser.py:
            @ToolParserManager.register_module("external")
            class ExternalParser(ToolParser):
                ...

            # Load it:
            ToolParserManager.import_tool_parser("/path/to/external_parser.py")
            ```
        """
        module_name = os.path.splitext(os.path.basename(plugin_path))[0]

        try:
            import_from_path(module_name, plugin_path)
        except Exception as exc:
            # print() does not interpolate logger-style "%s" arguments, so the
            # message is formatted explicitly. The cause is included so that a
            # failed plugin load can be diagnosed.
            print(f"Failed to load module '{module_name}' from {plugin_path}. ({type(exc).__name__}: {exc})")
            return
def import_from_path(module_name: str, file_path: str | os.PathLike):
    """Import a Python module from a file path.

    Builds a module spec for ``file_path``, registers the new module in
    ``sys.modules`` under ``module_name`` and then executes it. If execution
    fails, ``sys.modules`` is put back the way it was before the call, so a
    half-initialised module is never left behind and a module previously
    registered under the same name is not lost.

    Args:
        module_name: Name to register the module under in ``sys.modules``.
        file_path: Path to the Python file to import.

    Returns:
        module: The imported module object.

    Raises:
        ModuleNotFoundError: If no import spec (or no loader) can be created
            for ``file_path``.
        Exception: Whatever the module's own top-level code raises while it is
            being executed is propagated unchanged.

    Note:
        Based on the official Python importlib recipe:
        https://docs.python.org/3/library/importlib.html
    """
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    # Explicit check instead of ``assert``: assertions are stripped under -O.
    if spec is None or spec.loader is None:
        raise ModuleNotFoundError(f"No module named '{module_name}'")

    module = importlib.util.module_from_spec(spec)

    # Remember the prior entry. A flag is needed because sys.modules may
    # legitimately map a name to None.
    had_previous = module_name in sys.modules
    previous = sys.modules.get(module_name)

    # The module must be registered before execution so that code inside it
    # can resolve itself through sys.modules.
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Roll back so a broken module is not served to later imports.
        if had_previous:
            sys.modules[module_name] = previous
        else:
            sys.modules.pop(module_name, None)
        raise
    return module
[docs]def is_list_of(value: object, typ, *, check: Literal["first", "all"] = "first") -> bool: """Check if a value is a list of specific type. Args: value: Value to check typ: Expected type of list elements check: "first" to check only first element, "all" to check all elements Returns: bool: True if value is a list of the specified type Examples: ```python is_list_of(["a", "b"], str) # True is_list_of(["a", 1], str, check="all") # False is_list_of([], str) # True (empty list) is_list_of("not a list", str) # False ``` """ if not isinstance(value, list): return False if check == "first": return len(value) == 0 or isinstance(value[0], typ) elif check == "all": return all(isinstance(v, typ) for v in value) assert_never(check)