# Source code for easydel.inference.tools.abstract_tool
# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import importlib
import importlib.metadata
import importlib.util
import logging
import os
import sys
from collections.abc import Callable, Sequence
from functools import cached_property
from typing import Literal, assert_never

from transformers import AutoTokenizer as AnyTokenizer

from ..openai_api_modules import ChatCompletionRequest, DeltaMessage, ExtractedToolCallInformation
class ToolParser:
    """Common base for parsers that pull tool/function calls out of model text.

    Model-specific subclasses (e.g. HermesToolParser, MistralToolParser) supply
    the actual parsing for their call format (JSON, XML, pythonic, ...) by
    overriding :meth:`extract_tool_calls` (whole output at once) and
    :meth:`extract_tool_calls_streaming` (incremental chunks). Both base
    implementations raise ``NotImplementedError``.

    The instance attributes below are bookkeeping for streaming extraction.

    Attributes:
        prev_tool_call_arr (list[dict]): Tool calls parsed so far (starts empty).
        current_tool_id (int): Counter for the tool call being parsed; starts at -1.
        current_tool_name_sent (bool): Whether the current tool's name has been
            emitted yet; starts False.
        streamed_args_for_tool (list[str]): Buffer of streamed tool arguments
            (starts empty).
        model_tokenizer (AnyTokenizer): The tokenizer given to the constructor.

    Note:
        Treat this class as abstract; instantiate a model-specific subclass.
    """

    def __init__(self, tokenizer: AnyTokenizer):
        self.model_tokenizer = tokenizer
        # Streaming bookkeeping (see ``extract_tool_calls_streaming``).
        self.prev_tool_call_arr: list[dict] = []
        self.current_tool_id: int = -1
        self.current_tool_name_sent: bool = False
        self.streamed_args_for_tool: list[str] = []

    @cached_property
    def vocab(self) -> dict[str, int]:
        """Token-string to token-id mapping of the wrapped tokenizer.

        Obtained through ``get_vocab()`` rather than a ``.vocab`` attribute,
        which only some tokenizer classes (e.g. PreTrainedTokenizerFast) are
        guaranteed to expose. Computed on first access and then cached on the
        instance.

        Returns:
            dict[str, int]: Mapping of tokens to their IDs.
        """
        tokenizer = self.model_tokenizer
        return tokenizer.get_vocab()

    def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
        """Hook for model-specific request adjustments; a no-op by default.

        Subclasses can override this to rewrite system prompts, tool
        definitions or other formatting before generation.

        Args:
            request: The incoming chat completion request.

        Returns:
            ChatCompletionRequest: The request to use. The base implementation
            hands back ``request`` itself, unmodified.
        """
        return request

    def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest) -> ExtractedToolCallInformation:
        """Parse tool calls from a complete (non-streamed) model output.

        Args:
            model_output: Full text produced by the model.
            request: Original request, including the tool definitions.

        Returns:
            ExtractedToolCallInformation: Parsed tool calls plus leftover content.

        Raises:
            NotImplementedError: Always, in the base class; subclasses override.

        Note:
            Intended to be stateless: implementations should not rely on the
            streaming bookkeeping attributes.
        """
        message = "AbstractToolParser.extract_tool_calls has not been implemented!"
        raise NotImplementedError(message)

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Parse tool calls incrementally from a streamed model output.

        Implementations are stateful: they use the instance's bookkeeping
        attributes to carry partially-parsed JSON/XML across chunks.

        Args:
            previous_text: Text accumulated before this chunk.
            current_text: Text accumulated including this chunk.
            delta_text: Text of this chunk only.
            previous_token_ids: Token IDs before this chunk.
            current_token_ids: Token IDs including this chunk.
            delta_token_ids: Token IDs of this chunk only.
            request: Original request, including the tool definitions.

        Returns:
            DeltaMessage | None: An incremental update, or None when there is
            nothing to emit for this chunk.

        Raises:
            NotImplementedError: Always, in the base class; subclasses override.
        """
        message = "AbstractToolParser.extract_tool_calls_streaming has not been implemented!"
        raise NotImplementedError(message)
class ToolParserManager:
    """Registry and manager for tool parser implementations.

    Keeps a class-level mapping from names to :class:`ToolParser` subclasses
    and provides:

    - registration of parser classes, as a decorator or by direct call;
    - retrieval of parsers by name;
    - best-effort loading of parser plugins from external Python files;
    - validation that registered classes inherit from ``ToolParser``.

    Attributes:
        tool_parsers (dict[str, type]): Registry mapping parser names to
            classes. It is a class attribute, so all callers share it.

    Example:
        ```python
        # Register using decorator
        @ToolParserManager.register_module("my_parser")
        class MyCustomParser(ToolParser):
            ...

        # Retrieve parser
        parser_class = ToolParserManager.get_tool_parser("my_parser")
        parser = parser_class(tokenizer)
        ```
    """

    # Shared registry: name -> ToolParser subclass.
    tool_parsers: dict[str, type] = {}  # noqa: RUF012

    @classmethod
    def get_tool_parser(cls, name: str) -> type:
        """Retrieve a registered tool parser class by name.

        Args:
            name: Name the parser was registered under.

        Returns:
            type: The parser class registered with the given name.

        Raises:
            KeyError: If no parser is registered with the given name.

        Example:
            ```python
            HermesParser = ToolParserManager.get_tool_parser("hermes")
            parser_instance = HermesParser(tokenizer)
            ```
        """
        try:
            return cls.tool_parsers[name]
        except KeyError:
            raise KeyError(f"tool helper: '{name}' not found in tool_parsers") from None

    @classmethod
    def _register_module(cls, module: type, module_name: str | list[str] | None = None, force: bool = True) -> None:
        """Register a parser class under one or more names (internal helper).

        Args:
            module: Parser class to register; must be a subclass of ToolParser.
            module_name: Name(s) to register the parser under. Defaults to the
                class's ``__name__``.
            force: If True, overwrite existing registrations. If False, raise
                when any requested name is already registered.

        Raises:
            TypeError: If ``module`` is not a class deriving from ToolParser.
            KeyError: If ``force`` is False and a requested name is already
                registered. In that case no name is registered.
        """
        # Check for a class first: ``issubclass`` on a non-class raises its own,
        # less informative, TypeError before our message could be produced.
        if not (isinstance(module, type) and issubclass(module, ToolParser)):
            raise TypeError(f"module must be subclass of ToolParser, but got {module!r}")
        if module_name is None:
            module_name = module.__name__
        names = [module_name] if isinstance(module_name, str) else list(module_name)
        if not force:
            # Validate every name before touching the registry so a conflict on
            # a later name doesn't leave earlier names half-registered.
            for name in names:
                if name in cls.tool_parsers:
                    existed_module = cls.tool_parsers[name]
                    raise KeyError(f"{name} is already registered at {existed_module.__module__}")
        for name in names:
            cls.tool_parsers[name] = module

    @classmethod
    def register_module(
        cls, name: str | list[str] | None = None, force: bool = True, module: type | None = None
    ) -> type | Callable:
        """Register a tool parser class, directly or as a decorator.

        Args:
            name: Name(s) to register the parser under (defaults to class name).
            force: If True, overwrite existing names; if False, raise on conflict.
            module: Parser class to register. If None, a decorator is returned.

        Returns:
            type | Callable: ``module`` itself when given, otherwise a decorator
            that registers and returns the decorated class.

        Raises:
            TypeError: If ``force`` is not a bool, or ``name`` is not None, a
                str, or a list whose elements are all str.
            KeyError: If a name conflicts and ``force`` is False.

        Examples:
            ```python
            # As decorator
            @ToolParserManager.register_module("custom")
            class CustomParser(ToolParser):
                ...

            # Direct registration
            ToolParserManager.register_module(
                name="alternate",
                module=CustomParser
            )

            # Multiple names
            @ToolParserManager.register_module(["v1", "version1"])
            class V1Parser(ToolParser):
                ...
            ```
        """
        if not isinstance(force, bool):
            raise TypeError(f"force must be a boolean, but got {type(force)}")
        # check="all": validate every element, not just the first one.
        if not (name is None or isinstance(name, str) or is_list_of(name, str, check="all")):
            raise TypeError(f"name must be None, an instance of str, or a sequence of str, but got {type(name)}")

        def _register(module: type) -> type:
            cls._register_module(module=module, module_name=name, force=force)
            return module

        if module is not None:
            return _register(module)
        return _register

    @classmethod
    def import_tool_parser(cls, plugin_path: str) -> None:
        """Import an external file that defines (and registers) tool parsers.

        The module is registered in ``sys.modules`` under the file's base name
        without extension. Parser classes inside it should use
        ``@register_module`` (or call ``register_module``) at import time.

        Loading is best-effort. Any exception raised while importing is logged
        with its traceback and then swallowed, not propagated.

        Args:
            plugin_path: File path to the Python module containing parser(s).

        Example:
            ```python
            # In external_parser.py:
            @ToolParserManager.register_module("external")
            class ExternalParser(ToolParser):
                ...

            # Load it:
            ToolParserManager.import_tool_parser("/path/to/external_parser.py")
            ```
        """
        module_name = os.path.splitext(os.path.basename(plugin_path))[0]
        try:
            import_from_path(module_name, plugin_path)
        except Exception:
            logging.getLogger(__name__).exception(
                "Failed to load module '%s' from %s.", module_name, plugin_path
            )
def import_from_path(module_name: str, file_path: str | os.PathLike):
    """Import a Python source file as a module and register it in ``sys.modules``.

    Used for loading external tool parser plugins. Based on the "importing a
    source file directly" recipe from the official :mod:`importlib`
    documentation, plus cleanup when the module body fails to execute.

    Args:
        module_name: Name to register the module under in ``sys.modules``.
        file_path: Path to the Python file to import.

    Returns:
        module: The imported module object.

    Raises:
        ModuleNotFoundError: If no import spec or loader can be created for
            ``file_path`` (e.g. the file has an unsupported suffix).
        Exception: Anything raised while executing the module body is
            re-raised. Before that, the partially-initialised module is removed
            from ``sys.modules``, and any module previously registered under
            ``module_name`` is put back.
    """
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    # Explicit check instead of ``assert``, which is stripped under ``python -O``.
    if spec is None or spec.loader is None:
        raise ModuleNotFoundError(f"No module named '{module_name}'")
    module = importlib.util.module_from_spec(spec)

    missing = object()
    previous = sys.modules.get(module_name, missing)
    # Register before executing, as the importlib recipe does.
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Don't leave a half-initialised module behind for later imports.
        if previous is missing:
            sys.modules.pop(module_name, None)
        else:
            sys.modules[module_name] = previous
        raise
    return module
[docs]def is_list_of(value: object, typ, *, check: Literal["first", "all"] = "first") -> bool:
"""Check if a value is a list of specific type.
Args:
value: Value to check
typ: Expected type of list elements
check: "first" to check only first element, "all" to check all elements
Returns:
bool: True if value is a list of the specified type
Examples:
```python
is_list_of(["a", "b"], str) # True
is_list_of(["a", 1], str, check="all") # False
is_list_of([], str) # True (empty list)
is_list_of("not a list", str) # False
```
"""
if not isinstance(value, list):
return False
if check == "first":
return len(value) == 0 or isinstance(value[0], typ)
elif check == "all":
return all(isinstance(v, typ) for v in value)
assert_never(check)