Source code for easydel.inference.tools.parsers.hermes_tool_parser

# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from __future__ import annotations

import json
import re
from collections.abc import Sequence
from uuid import uuid4

import partial_json_parser
from partial_json_parser.core.options import Allow
from transformers import AutoTokenizer as AnyTokenizer

from ...openai_api_modules import (
    ChatCompletionRequest,
    DeltaFunctionCall,
    DeltaMessage,
    DeltaToolCall,
    ExtractedToolCallInformation,
    FunctionCall,
    ToolCall,
)
from ..abstract_tool import ToolParser, ToolParserManager


[docs]@ToolParserManager.register_module("hermes") class HermesToolParser(ToolParser): """ Tool call parser for Hermes models. Handles tool calls wrapped in <tool_call> XML-style tags with JSON content. Designed for NousResearch Hermes models and similar architectures that use XML-style delimiters for function calling. Format: <tool_call>{"name": "function_name", "arguments": {...}}</tool_call> Features: - XML-style token boundary detection (<tool_call> and </tool_call>) - Token-level buffering for accurate boundary detection - Supports multiple tool calls in a single response - Handles partial JSON parsing for streaming - Scratch pad support for intermediate reasoning Attributes: current_tool_name_sent: Tracks if function name was sent in stream prev_tool_call_arr: Previous tool calls for streaming comparison current_tool_id: Index of current tool being processed streamed_args_for_tool: Arguments sent so far for each tool tool_call_start_token: Opening delimiter for tool calls tool_call_end_token: Closing delimiter for tool calls buffered_delta_text: Buffer for multi-token delimiter detection """ def __init__(self, tokenizer: AnyTokenizer): """ Initialize the Hermes tool parser. Args: tokenizer: The model tokenizer for encoding/decoding tokens Raises: ValueError: If tokenizer is not provided """ super().__init__(tokenizer) self.current_tool_name_sent: bool = False self.prev_tool_call_arr: list[dict] = [] self.current_tool_id: int = -1 self.streamed_args_for_tool: list[str] = [] self.tool_call_start_token: str = "<tool_call>" self.tool_call_end_token: str = "</tool_call>" self.tool_call_regex = re.compile(r"<tool_call>(.*?)</tool_call>|<tool_call>(.*)", re.DOTALL) self.scratch_pad_regex = re.compile(r"<scratch_pad>(.*?)</scratch_pad>", re.DOTALL) if not self.model_tokenizer: raise ValueError("The model tokenizer must be passed to the ToolParser constructor during construction.") self.tool_call_start_token_ids = self.model_tokenizer.encode( self.tool_call_start_token, add_special_tokens=False ) self.tool_call_end_token_ids = self.model_tokenizer.encode(self.tool_call_end_token, add_special_tokens=False) self.tool_call_start_token_array = [ self.model_tokenizer.decode([token_id]) for token_id in self.tool_call_start_token_ids ] self.tool_call_end_token_array = [ self.model_tokenizer.decode([token_id]) for token_id in self.tool_call_end_token_ids ] self.buffered_delta_text = ""
[docs] def tool_call_delta_buffer(self, delta_text: str) -> str: """ Buffer delta text to handle multi-token delimiters. This method accumulates partial tokens that might form tool call delimiters, ensuring accurate boundary detection when delimiters span multiple tokens. Args: delta_text: The new text delta from streaming Returns: Processed text with complete delimiters or empty string if buffering """ if delta_text in self.tool_call_start_token_array or delta_text in self.tool_call_end_token_array: if delta_text == self.tool_call_start_token_array[-1] or delta_text == self.tool_call_end_token_array[-1]: buffered_text = self.buffered_delta_text self.buffered_delta_text = "" return buffered_text + delta_text else: self.buffered_delta_text = self.buffered_delta_text + delta_text return "" else: if self.buffered_delta_text: buffered_text = self.buffered_delta_text self.buffered_delta_text = "" return buffered_text + delta_text else: return delta_text
[docs] def extract_tool_calls( self, model_output: str, request: ChatCompletionRequest, ) -> ExtractedToolCallInformation: """ Extract tool calls from complete model response. Parses XML-style tool call tags and extracts JSON function calls. Supports multiple tool calls and returns remaining content. Args: model_output: Complete model output containing tool calls request: Original chat completion request (unused) Returns: ExtractedToolCallInformation with: - tools_called: Whether tool calls were found - tool_calls: List of ToolCall objects - content: Text content before tool calls (if any) Example: Input: "Let me help. <tool_call>{"name": "search", "arguments": {"q": "weather"}}</tool_call>" Output: tools_called=True, tool_calls=[ToolCall(...)], content="Let me help. " """ if self.tool_call_start_token not in model_output: return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=model_output) else: try: function_call_tuples = self.tool_call_regex.findall(model_output) raw_function_calls = [json.loads(match[0] if match[0] else match[1]) for match in function_call_tuples] tool_calls = [ ToolCall( type="function", function=FunctionCall( name=function_call["name"], arguments=json.dumps(function_call["arguments"], ensure_ascii=False), ), ) for function_call in raw_function_calls ] content = model_output[: model_output.find(self.tool_call_start_token)] return ExtractedToolCallInformation( tools_called=True, tool_calls=tool_calls, content=content if content else None ) except Exception: return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=model_output)
[docs] def extract_tool_calls_streaming( self, previous_text: str, current_text: str, delta_text: str, previous_token_ids: Sequence[int], current_token_ids: Sequence[int], delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> DeltaMessage | None: """ Extract tool calls from streaming model output. Handles incremental parsing of tool calls during streaming generation. Maintains state across calls to track partial tool calls and arguments. Uses buffering to handle multi-token delimiters correctly. Args: previous_text: Text generated before this delta current_text: Text including this delta delta_text: New text in this streaming chunk previous_token_ids: Token IDs before this delta current_token_ids: Token IDs including this delta delta_token_ids: New token IDs in this chunk request: Original chat completion request Returns: DeltaMessage with incremental tool call updates or content, or None if more tokens needed for parsing State Management: - Tracks tool call boundaries with start/end token counts - Maintains current tool ID for multi-tool responses - Buffers partial arguments until complete - Handles transition between content and tool calls """ delta_text = self.tool_call_delta_buffer(delta_text) if ( len(previous_text) >= len(self.buffered_delta_text) and previous_text[-len(self.buffered_delta_text) :] == self.buffered_delta_text ): previous_text = previous_text[: -len(self.buffered_delta_text)] current_text = previous_text + delta_text if self.tool_call_start_token not in current_text: return DeltaMessage(content=delta_text) try: prev_tool_start_count = previous_text.count(self.tool_call_start_token) prev_tool_end_count = previous_text.count(self.tool_call_end_token) cur_tool_start_count = current_text.count(self.tool_call_start_token) cur_tool_end_count = current_text.count(self.tool_call_end_token) tool_call_portion = None text_portion = None if ( cur_tool_start_count == cur_tool_end_count and prev_tool_end_count == cur_tool_end_count and self.tool_call_end_token not in delta_text ): return DeltaMessage(content=delta_text) if self.tool_call_end_token in delta_text: full_text = current_text + delta_text tool_call_portion = ( full_text.split(self.tool_call_start_token)[-1].split(self.tool_call_end_token)[0].rstrip() ) delta_text = delta_text.split(self.tool_call_end_token)[0].rstrip() text_portion = delta_text.split(self.tool_call_end_token)[-1].lstrip() flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR if cur_tool_start_count > cur_tool_end_count and cur_tool_start_count > prev_tool_start_count: if len(delta_token_ids) > 1: tool_call_portion = current_text.split(self.tool_call_start_token)[-1] else: tool_call_portion = None delta = None text_portion = None self.current_tool_id += 1 self.current_tool_name_sent = False self.streamed_args_for_tool.append("") elif cur_tool_start_count > cur_tool_end_count and cur_tool_start_count == prev_tool_start_count: tool_call_portion = current_text.split(self.tool_call_start_token)[-1] text_portion = None elif cur_tool_start_count == cur_tool_end_count and cur_tool_end_count >= prev_tool_end_count: if self.prev_tool_call_arr is None or len(self.prev_tool_call_arr) == 0: return None diff = self.prev_tool_call_arr[self.current_tool_id].get("arguments") if diff: diff = diff.encode("utf-8").decode("unicode_escape") if diff is str else diff if '"}' not in delta_text: return None end_loc = delta_text.rindex('"}') diff = delta_text[:end_loc] + '"}' self.streamed_args_for_tool[self.current_tool_id] += diff return DeltaMessage( tool_calls=[ DeltaToolCall( index=self.current_tool_id, function=DeltaFunctionCall(arguments=diff).model_dump(exclude_none=True), ) ] ) else: text = delta_text.replace(self.tool_call_start_token, "") text = text.replace(self.tool_call_end_token, "") delta = DeltaMessage(tool_calls=[], content=text) return delta try: current_tool_call = ( partial_json_parser.loads(tool_call_portion or "{}", flags) if tool_call_portion else None ) except partial_json_parser.core.exceptions.MalformedJSON: return None except json.decoder.JSONDecodeError: return None if not self.current_tool_name_sent: if current_tool_call is None: return None function_name: str | None = current_tool_call.get("name") if function_name: self.current_tool_name_sent = True return DeltaMessage( tool_calls=[ DeltaToolCall( index=self.current_tool_id, type="function", id=f"chatcmpl-tool-{uuid4()}", function=DeltaFunctionCall(name=function_name).model_dump(exclude_none=True), ) ] ) else: return None if tool_call_portion is None: delta = DeltaMessage(content=delta_text) if text_portion is not None else None return delta if len(self.prev_tool_call_arr) <= self.current_tool_id: self.prev_tool_call_arr.append({}) prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get("arguments") cur_arguments = current_tool_call.get("arguments") if not cur_arguments and not prev_arguments: delta = None elif not cur_arguments and prev_arguments: delta = None elif cur_arguments and not prev_arguments: cur_arguments_json = json.dumps(cur_arguments, ensure_ascii=False) if delta_text not in cur_arguments_json[:-2]: return None args_delta_start_loc = cur_arguments_json[:-2].rindex(delta_text) + len(delta_text) arguments_delta = cur_arguments_json[:args_delta_start_loc] delta = DeltaMessage( tool_calls=[ DeltaToolCall( index=self.current_tool_id, function=DeltaFunctionCall(arguments=arguments_delta).model_dump(exclude_none=True), ) ] ) self.streamed_args_for_tool[self.current_tool_id] += arguments_delta elif cur_arguments and prev_arguments: if isinstance(delta_text, str) and len(delta_text.rstrip()) >= 1 and delta_text.rstrip()[-1] == "}": delta_text = delta_text.rstrip()[:-1] delta = DeltaMessage( tool_calls=[ DeltaToolCall( index=self.current_tool_id, function=DeltaFunctionCall(arguments=delta_text).model_dump(exclude_none=True), ) ] ) self.streamed_args_for_tool[self.current_tool_id] += delta_text if self.current_tool_id == len(self.prev_tool_call_arr) - 1: self.prev_tool_call_arr[self.current_tool_id] = current_tool_call else: self.prev_tool_call_arr.append(current_tool_call) return delta except Exception: return None