# Source code for easydel.inference.tools.parsers.step3_tool_parser

# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import contextlib
import json
import math
import re
from collections.abc import Sequence
from typing import Any
from uuid import uuid4

from eformer.loggings import get_logger
from transformers import AutoTokenizer as AnyTokenizer

from ...openai_api_modules import (
    ChatCompletionRequest,
    DeltaFunctionCall,
    DeltaMessage,
    DeltaToolCall,
    ExtractedToolCallInformation,
    FunctionCall,
    ToolCall,
)
from ..abstract_tool import ToolParser, ToolParserManager

logger = get_logger(__name__)


@ToolParserManager.register_module(["step3"])
class Step3ToolParser(ToolParser):
    """Tool parser for Step3 models that emit ``steptml`` XML-like tool calls.

    Expected model output (every ``|`` shown below is really U+FF5C FULLWIDTH
    VERTICAL LINE in the emitted delimiter tokens)::

        <|tool_calls_begin|>
        <|tool_call_begin|>function<|tool_sep|><steptml:invoke name="func">
        <steptml:parameter name="param">value</steptml:parameter>
        </steptml:invoke><|tool_call_end|>
        <|tool_calls_end|>

    Parameter values are extracted as text and then coerced to the JSON type
    declared for them in the request's tool schema (see ``_cast_arguments``).

    Streaming uses a cursor (``self.position``) into the accumulated text and a
    small state machine:

    1. before ``TOOL_CALLS_BEGIN``, text is forwarded as content;
    2. inside the block, each call yields a header delta (index, id, name) and
       a single delta with its complete JSON arguments once ``TOOL_CALL_END``
       has been seen (both in one delta if the call is already complete when
       the header is sent);
    3. after ``TOOL_CALLS_END``, remaining text is forwarded as content.
    """

    # Step3 delimiters use U+FF5C FULLWIDTH VERTICAL LINE, not ASCII "|". The
    # escapes make that explicit and keep confusable characters out of the source.
    TOOL_CALLS_BEGIN = "<\uff5ctool_calls_begin\uff5c>"
    TOOL_CALLS_END = "<\uff5ctool_calls_end\uff5c>"
    TOOL_CALL_BEGIN = "<\uff5ctool_call_begin\uff5c>"
    TOOL_CALL_END = "<\uff5ctool_call_end\uff5c>"
    TOOL_SEP = "<\uff5ctool_sep\uff5c>"

    SPECIAL_TOKENS = [TOOL_CALLS_BEGIN, TOOL_CALLS_END, TOOL_CALL_BEGIN, TOOL_CALL_END]  # noqa: RUF012

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)
        # Streaming cursor: offset into ``current_text`` of the first character
        # that has not yet been turned into a delta.
        self.position = 0
        # Whether TOOL_CALLS_BEGIN / TOOL_CALLS_END have been consumed.
        self.tool_block_started = False
        self.tool_block_finished = False
        # NOTE(review): ``current_tool_id``, ``current_tool_name_sent`` and
        # ``prev_tool_call_arr`` are used below but not set here; presumably
        # ``ToolParser`` initialises them (streaming assumes
        # ``current_tool_id`` starts at -1) — confirm.

    def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
        """Keep special tokens in the detokenized text when tools may be called.

        The Step3 delimiters (``SPECIAL_TOKENS``) have to reach this parser, so
        ``skip_special_tokens`` is disabled whenever the request offers tools
        and ``tool_choice`` is not ``"none"``. The request is modified in place
        and returned.
        """
        if request.tools and request.tool_choice != "none":
            request.skip_special_tokens = False
        return request

    @staticmethod
    def _parse_steptml_invoke(action_text: str) -> tuple[str | None, dict[str, str] | None]:
        """Extract the function name and raw parameter strings from a steptml invoke.

        Args:
            action_text: Text that may contain ``<steptml:invoke name="...">`` and
                ``<steptml:parameter name="...">...</steptml:parameter>`` tags.

        Returns:
            ``(None, None)`` when no invoke tag is present; otherwise the function
            name and a dict of whitespace-stripped parameter values. Only
            parameters whose closing tag is present are included. Values may span
            lines and may contain ``<``.
        """
        name_match = re.search(r'<steptml:invoke name="([^"]+)">', action_text)
        if name_match is None:
            return None, None
        params = {
            name: value.strip()
            for name, value in re.findall(
                r'<steptml:parameter name="([^"]+)">(.*?)</steptml:parameter>',
                action_text,
                flags=re.DOTALL,
            )
        }
        return name_match.group(1), params

    @staticmethod
    def _coerce_value(value: str, declared_type: Any) -> Any:
        """Convert ``value`` to the first JSON-schema type in ``declared_type`` it fits.

        ``declared_type`` is a single type name, a list of names (a union), or
        ``None``. Candidates are tried in order. ``value`` is returned unchanged
        when none of them applies.
        """
        type_names = declared_type if isinstance(declared_type, list) else [declared_type]
        for type_name in type_names:
            if type_name == "string":
                return value.strip()
            elif type_name == "integer":
                with contextlib.suppress(ValueError):
                    return int(value)
            elif type_name == "number":
                with contextlib.suppress(ValueError):
                    number = float(value)
                    # NaN/Infinity would serialize to non-standard JSON tokens.
                    if math.isfinite(number):
                        return number
            elif type_name == "boolean":
                lowered = value.lower()
                if lowered in ("true", "false"):
                    return lowered == "true"
            elif type_name == "null":
                if value.lower() == "null":
                    return None
            elif type_name in ("object", "array"):
                # Structured values arrive as JSON text; decode them so they are
                # not double-encoded as strings in the final arguments.
                with contextlib.suppress(ValueError):
                    decoded = json.loads(value)
                    if isinstance(decoded, dict if type_name == "object" else list):
                        return decoded
        return value

    def _cast_arguments(
        self,
        func_name: str,
        params: dict[str, Any],
        request: ChatCompletionRequest,
    ) -> dict[str, Any]:
        """Coerce raw string parameter values to the types declared in the tool schema.

        Looks up ``func_name`` among ``request.tools`` (first match wins). Every
        parameter that is still a string is converted according to the
        JSON-schema ``type`` of its property (see ``_coerce_value``). The
        following are left unchanged:

        - values that do not fit the declared type;
        - parameters without a declared type;
        - functions that are not listed in the request.

        ``params`` is updated in place and also returned.
        """
        for tool in request.tools or []:
            if tool.function.name != func_name:
                continue
            schema = tool.function.parameters or {}
            properties = schema.get("properties") or {}
            for key, value in params.items():
                if not isinstance(value, str):
                    continue
                prop = properties.get(key)
                declared_type = prop.get("type") if isinstance(prop, dict) else None
                # Re-assigning an existing key does not resize the dict, so this
                # is safe while iterating.
                params[key] = self._coerce_value(value, declared_type)
            break
        return params

    def _finish_current_tool(
        self,
        function_name: str,
        parameters: dict[str, Any],
        end_pos: int,
        request: ChatCompletionRequest,
    ) -> str:
        """Close the current streaming call and return its JSON arguments.

        Advances the cursor past ``TOOL_CALL_END``. ``end_pos`` is relative to
        the current cursor position. It also marks the call finished and casts
        ``parameters``, which is mutated in place.
        """
        self.position += end_pos + len(self.TOOL_CALL_END)
        self.prev_tool_call_arr[self.current_tool_id]["finished"] = True
        final_args = self._cast_arguments(function_name, parameters, request)
        return json.dumps(final_args, ensure_ascii=False)

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Turn newly generated text into at most one streaming delta.

        Only ``current_text`` (the full text decoded so far) and ``request`` are
        used. Progress is tracked with ``self.position`` rather than the delta
        arguments, so text that cannot be handled yet (a partially received
        delimiter, an unfinished call) is revisited on the next call.

        Returns:
            One of:

            - a ``DeltaMessage`` with plain content;
            - a tool-call header (index, id, name, plus the arguments if the
              call is already complete);
            - a call's complete JSON arguments;
            - ``None`` when nothing can be emitted yet.
        """
        while self.position < len(current_text):
            unprocessed = current_text[self.position :]

            # After the tool block closes, everything else is ordinary content.
            if self.tool_block_finished:
                self.position = len(current_text)
                return DeltaMessage(content=unprocessed)

            # Before the tool block opens: stream content up to the opening delimiter.
            if not self.tool_block_started:
                if unprocessed.startswith(self.TOOL_CALLS_BEGIN):
                    self.position += len(self.TOOL_CALLS_BEGIN)
                    self.tool_block_started = True
                    continue
                content_end = unprocessed.find(self.TOOL_CALLS_BEGIN)
                if content_end == -1:
                    if self.TOOL_CALLS_BEGIN.startswith(unprocessed.strip()):
                        # Possibly the start of the delimiter (or only whitespace so far).
                        return None
                    content_end = len(unprocessed)
                self.position += content_end
                return DeltaMessage(content=unprocessed[:content_end])

            # Inside the tool block, whitespace between delimiters is not significant.
            stripped = unprocessed.lstrip()
            self.position += len(unprocessed) - len(stripped)
            unprocessed = stripped
            if not unprocessed:
                return None

            if unprocessed.startswith(self.TOOL_CALLS_END):
                self.position += len(self.TOOL_CALLS_END)
                self.tool_block_finished = True
                self.current_tool_id = -1
                continue

            current_entry = self.prev_tool_call_arr[self.current_tool_id] if self.current_tool_id != -1 else None
            in_tool = current_entry is not None and not current_entry.get("finished", False)

            if not in_tool:
                if not unprocessed.startswith(self.TOOL_CALL_BEGIN):
                    # A partially received delimiter (or unexpected text): wait for more.
                    return None
                self.position += len(self.TOOL_CALL_BEGIN)
                self.current_tool_id += 1  # -1 -> 0 for the first call
                self.current_tool_name_sent = False
                while len(self.prev_tool_call_arr) <= self.current_tool_id:
                    self.prev_tool_call_arr.append({})
                self.prev_tool_call_arr[self.current_tool_id]["finished"] = False
                continue

            end_pos = unprocessed.find(self.TOOL_CALL_END)
            call_done = end_pos != -1
            tool_body = unprocessed[:end_pos] if call_done else unprocessed
            if not call_done and self.TOOL_CALL_END.startswith(tool_body):
                return None

            function_name, arguments = self._parse_steptml_invoke(tool_body)
            if not function_name:
                if not call_done:
                    return None  # invoke header not fully received yet
                # Complete call without a parsable invoke: drop it (as the
                # non-streaming path does) instead of stalling on it forever.
                self.position += end_pos + len(self.TOOL_CALL_END)
                del self.prev_tool_call_arr[self.current_tool_id]
                self.current_tool_id -= 1
                continue

            parameters = arguments or {}
            self.prev_tool_call_arr[self.current_tool_id].update({"name": function_name, "parameters": parameters})

            if not self.current_tool_name_sent:
                self.current_tool_name_sent = True
                if call_done:
                    # The whole call is already buffered: send the arguments with
                    # the header so they are not lost if the stream ends here.
                    function = DeltaFunctionCall(
                        name=function_name,
                        arguments=self._finish_current_tool(function_name, parameters, end_pos, request),
                    )
                else:
                    function = DeltaFunctionCall(name=function_name)
                return DeltaMessage(
                    tool_calls=[
                        DeltaToolCall(
                            index=self.current_tool_id,
                            type="function",
                            id=f"chatcmpl-tool-{uuid4()}",
                            function=function,
                        )
                    ]
                )

            if not call_done:
                return None  # arguments are sent in one piece once the call closes

            # Always emit the arguments, including "{}" for parameterless tools,
            # matching what the non-streaming path returns.
            arguments_json = self._finish_current_tool(function_name, parameters, end_pos, request)
            return DeltaMessage(
                tool_calls=[
                    DeltaToolCall(
                        index=self.current_tool_id,
                        function=DeltaFunctionCall(arguments=arguments_json),
                    )
                ]
            )
        return None

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Parse every tool call from a complete model response.

        Text outside the ``TOOL_CALLS_BEGIN`` ... ``TOOL_CALLS_END`` block is
        returned as (stripped) content. A call is skipped when any of these
        holds:

        - it is not closed with ``TOOL_CALL_END``;
        - it lacks ``TOOL_SEP``;
        - its type is not ``function``;
        - it has no parsable invoke.

        If the block is missing or unterminated, or yields no valid call, the
        original output is returned unchanged as content with
        ``tools_called=False``.
        """
        pre_text, has_block, rest = model_output.partition(self.TOOL_CALLS_BEGIN)
        if not has_block or self.TOOL_CALLS_END not in rest:
            return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=model_output)

        tool_block, _, post_text = rest.partition(self.TOOL_CALLS_END)
        content = (pre_text + post_text).strip()

        tool_calls: list[ToolCall] = []
        for part in tool_block.split(self.TOOL_CALL_BEGIN):
            call_content, has_end, _ = part.partition(self.TOOL_CALL_END)
            if not has_end:
                continue
            call_type, has_sep, invoke_text = call_content.partition(self.TOOL_SEP)
            if not has_sep or call_type.strip() != "function":
                continue
            function_name, params = self._parse_steptml_invoke(invoke_text)
            if not function_name or params is None:
                continue
            params = self._cast_arguments(function_name, params, request)
            tool_calls.append(
                ToolCall(
                    function=FunctionCall(
                        name=function_name,
                        arguments=json.dumps(params, ensure_ascii=False),
                    )
                )
            )

        if not tool_calls:
            return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=model_output)
        return ExtractedToolCallInformation(tools_called=True, tool_calls=tool_calls, content=content or None)