# Source code for easydel.inference.tools.parsers.step3_tool_parser
# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import contextlib
import json
import re
from collections.abc import Sequence
from typing import Any
from uuid import uuid4
from eformer.loggings import get_logger
from transformers import AutoTokenizer as AnyTokenizer
from ...openai_api_modules import (
ChatCompletionRequest,
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from ..abstract_tool import ToolParser, ToolParserManager
# Module-level logger, obtained from eformer's get_logger and named after this module.
logger = get_logger(__name__)
@ToolParserManager.register_module(["step3"])
class Step3ToolParser(ToolParser):
    """
    Tool parser for Step3 models with XML-like format.

    Uses a stateful, cursor-based streaming parser that emits each tool's
    arguments as a single JSON message once the call is closed. Handles
    nested tool call structures delimited by special tokens.

    Features:
        - Cursor-based position tracking
        - Hierarchical token structure with special characters
        - steptml:invoke format parsing
        - Automatic type casting based on the request's tool schema
        - State machine for streaming parsing

    Format:
        <｜tool_calls_begin｜>
        <｜tool_call_begin｜>
        function<｜tool_sep｜>
        <steptml:invoke name="func">
        <steptml:parameter name="param">value</steptml:parameter>
        </steptml:invoke>
        <｜tool_call_end｜>
        <｜tool_calls_end｜>

    Note: The delimiters use the fullwidth vertical line character
    (``｜``, U+FF5C), not the ASCII ``|``.
    """  # noqa: RUF002

    TOOL_CALLS_BEGIN = "<｜tool_calls_begin｜>"  # noqa: RUF001
    TOOL_CALLS_END = "<｜tool_calls_end｜>"  # noqa: RUF001
    TOOL_CALL_BEGIN = "<｜tool_call_begin｜>"  # noqa: RUF001
    TOOL_CALL_END = "<｜tool_call_end｜>"  # noqa: RUF001
    TOOL_SEP = "<｜tool_sep｜>"  # noqa: RUF001
    SPECIAL_TOKENS = [TOOL_CALLS_BEGIN, TOOL_CALLS_END, TOOL_CALL_BEGIN, TOOL_CALL_END]  # noqa: RUF012

    # Compiled once and shared by every parse.
    _INVOKE_NAME_RE = re.compile(r'<steptml:invoke name="([^"]+)">')
    # Non-greedy ``.*?`` with DOTALL so values containing "<" (code, HTML,
    # comparisons) or newlines are captured instead of silently dropped.
    _PARAMETER_RE = re.compile(r'<steptml:parameter name="([^"]+)">(.*?)</steptml:parameter>', re.DOTALL)

    def __init__(self, tokenizer: AnyTokenizer):
        """Initialize the streaming cursor state.

        Args:
            tokenizer: Tokenizer forwarded to the base ``ToolParser``.
        """
        super().__init__(tokenizer)
        # Offset into ``current_text`` up to which streaming output has been consumed.
        self.position = 0
        # Whether TOOL_CALLS_BEGIN / TOOL_CALLS_END have been consumed while streaming.
        self.tool_block_started = False
        self.tool_block_finished = False

    def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
        """Keep special tokens in the decoded text when tools may be called.

        The tool-call delimiters (``SPECIAL_TOKENS``) must survive
        detokenization for the parser to see them.

        Args:
            request: Incoming chat request; modified in place.

        Returns:
            The same request object.
        """
        if request.tools and request.tool_choice != "none":
            request.skip_special_tokens = False
        return request

    @staticmethod
    def _partial_prefix_len(text: str, marker: str) -> int:
        """Return the length of the longest suffix of ``text`` that is a proper prefix of ``marker``.

        Used to hold back a delimiter that has only partially arrived in the stream.
        """
        for size in range(min(len(text), len(marker) - 1), 0, -1):
            if text.endswith(marker[:size]):
                return size
        return 0

    @staticmethod
    def _parse_steptml_invoke(action_text: str) -> tuple[str | None, dict[str, str] | None]:
        """Extract the function name and raw parameter strings from a steptml invoke.

        Args:
            action_text: Text containing a ``<steptml:invoke name="...">`` tag.
                The invoke element need not be closed; only fully closed
                ``<steptml:parameter>`` elements are collected.

        Returns:
            ``(name, params)`` where ``params`` maps parameter names to
            whitespace-stripped string values, or ``(None, None)`` when no
            invoke opening tag is present.
        """
        func_name_match = Step3ToolParser._INVOKE_NAME_RE.search(action_text)
        if not func_name_match:
            return None, None
        params: dict[str, str] = {}
        for name, value in Step3ToolParser._PARAMETER_RE.findall(action_text):
            params[name] = value.strip()
        return func_name_match.group(1), params

    @staticmethod
    def _cast_value(value: str, typ: Any) -> Any:
        """Cast one raw parameter string according to a JSON-schema ``type``.

        Values that cannot be cast, or whose type is unknown or absent,
        are returned unchanged.
        """
        if typ == "string":
            return value.strip()
        if typ == "integer":
            with contextlib.suppress(ValueError):
                return int(value)
            return value
        if typ == "number":
            with contextlib.suppress(ValueError):
                return float(value)
            return value
        if typ == "boolean":
            lower_val = value.lower()
            return lower_val == "true" if lower_val in ("true", "false") else value
        if typ == "null":
            return None if value.lower() == "null" else value
        if typ in ("object", "array"):
            # Structured values arrive as JSON text; keep the raw string unless
            # it decodes to the container type the schema asks for.
            try:
                decoded = json.loads(value)
            except ValueError:  # json.JSONDecodeError subclasses ValueError
                return value
            expected_type = dict if typ == "object" else list
            return decoded if isinstance(decoded, expected_type) else value
        return value

    def _cast_arguments(
        self,
        func_name: str,
        params: dict[str, Any],
        request: ChatCompletionRequest,
    ) -> dict[str, Any]:
        """Cast string parameters using the matching tool's schema from ``request``.

        Args:
            func_name: Name of the called function.
            params: Parameter mapping; string values are cast **in place**.
            request: Request whose ``tools`` provide the parameter schemas.

        Returns:
            ``params`` (the same dict). It is untouched if no tool named
            ``func_name`` exists in the request.
        """
        for tool in request.tools or []:
            if tool.function.name != func_name:
                continue
            schema = tool.function.parameters or {}
            properties = schema.get("properties") or {}
            # Reassigning existing keys while iterating is safe: the dict size never changes.
            for key, value in params.items():
                if isinstance(value, str):
                    params[key] = self._cast_value(value, (properties.get(key) or {}).get("type"))
            break
        return params

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Incrementally extract tool calls from streamed model output.

        Progress is tracked by the cursor ``self.position`` into
        ``current_text``. Delimiters are consumed silently while looping, and
        at most one delta is returned per call. A delta is either content, a
        tool-call header (name, plus arguments if the call already closed),
        or a closed call's arguments.

        Args:
            previous_text: Text generated before this step (unused; the cursor tracks progress).
            current_text: All text generated so far.
            delta_text: Newly generated text (unused).
            previous_token_ids: Unused.
            current_token_ids: Unused.
            delta_token_ids: Unused.
            request: Chat request whose tool schemas drive argument casting.

        Returns:
            A ``DeltaMessage``, or ``None`` when more text is needed before
            anything can be emitted.
        """
        while True:
            if self.position >= len(current_text):
                return None
            unprocessed_text = current_text[self.position :]

            # Once the tool block is closed, everything left is plain content.
            if self.tool_block_finished:
                self.position = len(current_text)
                return DeltaMessage(content=unprocessed_text)

            if not self.tool_block_started:
                if unprocessed_text.startswith(self.TOOL_CALLS_BEGIN):
                    self.position += len(self.TOOL_CALLS_BEGIN)
                    self.tool_block_started = True
                    continue
                start_pos = unprocessed_text.find(self.TOOL_CALLS_BEGIN)
                if start_pos != -1:
                    # Flush the content before the block; the marker is consumed on the next call.
                    self.position += start_pos
                    return DeltaMessage(content=unprocessed_text[:start_pos])
                if self.TOOL_CALLS_BEGIN.startswith(unprocessed_text.strip()):
                    # Only whitespace and/or the beginning of the marker so far.
                    return None
                # Emit the content but hold back a trailing fragment that could be the
                # start of TOOL_CALLS_BEGIN. Otherwise a marker split across deltas would
                # leak into the content and never be recognized.
                held = self._partial_prefix_len(unprocessed_text, self.TOOL_CALLS_BEGIN)
                content = unprocessed_text[: len(unprocessed_text) - held]
                if not content:
                    return None
                self.position += len(content)
                return DeltaMessage(content=content)

            # Inside the tool block, whitespace between delimiters is insignificant.
            stripped = unprocessed_text.lstrip()
            self.position += len(unprocessed_text) - len(stripped)
            unprocessed_text = stripped

            if unprocessed_text.startswith(self.TOOL_CALLS_END):
                self.position += len(self.TOOL_CALLS_END)
                self.tool_block_finished = True
                self.current_tool_id = -1
                continue

            awaiting_call = self.current_tool_id == -1 or self.prev_tool_call_arr[self.current_tool_id].get(
                "finished", False
            )
            if awaiting_call:
                if not unprocessed_text.startswith(self.TOOL_CALL_BEGIN):
                    # A partial TOOL_CALL_BEGIN (or unexpected text): wait for more input.
                    return None
                self.position += len(self.TOOL_CALL_BEGIN)
                self.current_tool_id += 1  # -1 -> 0 for the first call
                self.current_tool_name_sent = False
                while len(self.prev_tool_call_arr) <= self.current_tool_id:
                    self.prev_tool_call_arr.append({})
                self.prev_tool_call_arr[self.current_tool_id]["finished"] = False
                continue

            # Inside an open tool call.
            end_tool_pos = unprocessed_text.find(self.TOOL_CALL_END)
            call_closed = end_tool_pos != -1
            tool_body = unprocessed_text[:end_tool_pos] if call_closed else unprocessed_text
            function_name, arguments = self._parse_steptml_invoke(tool_body)

            if not function_name:
                if not call_closed:
                    # The <steptml:invoke name="..."> header has not fully arrived yet.
                    return None
                # The call closed without a parsable invoke header. Drop it and release
                # its index so the parser does not stall here and lose all later output.
                logger.warning("Skipping malformed Step3 tool call without a <steptml:invoke> header.")
                self.position += end_tool_pos + len(self.TOOL_CALL_END)
                self.prev_tool_call_arr.pop(self.current_tool_id)
                self.current_tool_id -= 1
                continue

            tool_state = self.prev_tool_call_arr[self.current_tool_id]
            tool_state.update({"name": function_name, "parameters": arguments or {}})

            arguments_json: str | None = None
            if call_closed:
                self.position += end_tool_pos + len(self.TOOL_CALL_END)
                tool_state["finished"] = True
                final_args = self._cast_arguments(function_name, tool_state["parameters"], request)
                # Always emit a JSON object, even "{}" for parameterless calls, so clients can parse it.
                arguments_json = json.dumps(final_args, ensure_ascii=False)

            if not self.current_tool_name_sent:
                self.current_tool_name_sent = True
                # If the call already closed in this same pass, attach its arguments to the
                # header delta. Otherwise they would wait for a later call that may never come.
                function = (
                    DeltaFunctionCall(name=function_name)
                    if arguments_json is None
                    else DeltaFunctionCall(name=function_name, arguments=arguments_json)
                )
                return DeltaMessage(
                    tool_calls=[
                        DeltaToolCall(
                            index=self.current_tool_id,
                            type="function",
                            id=f"chatcmpl-tool-{uuid4()}",
                            function=function,
                        )
                    ]
                )

            if arguments_json is None:
                # Arguments are sent in one piece once the call closes.
                return None
            return DeltaMessage(
                tool_calls=[
                    DeltaToolCall(index=self.current_tool_id, function=DeltaFunctionCall(arguments=arguments_json))
                ]
            )

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract all tool calls from a complete (non-streamed) model output.

        A missing ``TOOL_CALLS_END`` (e.g. truncated generation) does not
        discard fully closed calls, which matches what the streaming path
        would have emitted.

        Args:
            model_output: Full decoded model output.
            request: Chat request whose tool schemas drive argument casting.

        Returns:
            Tool-call information. The text outside the tool block is returned
            as content (``None`` if empty). When no call is parsed,
            ``tools_called`` is False and the raw output is returned as content.
        """
        if self.TOOL_CALLS_BEGIN not in model_output:
            return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=model_output)

        pre_text, _, rest = model_output.partition(self.TOOL_CALLS_BEGIN)
        tool_block, _, post_text = rest.partition(self.TOOL_CALLS_END)
        content = (pre_text + post_text).strip()

        tool_calls: list[ToolCall] = []
        for part in tool_block.split(self.TOOL_CALL_BEGIN):
            call_content, found_end, _ = part.partition(self.TOOL_CALL_END)
            if not found_end:
                # Text before the first call, or a call that never closed.
                continue
            type_part, found_sep, invoke_part = call_content.partition(self.TOOL_SEP)
            if not found_sep or type_part.strip() != "function":
                continue
            function_name, params_dict = self._parse_steptml_invoke(invoke_part)
            if not function_name or params_dict is None:
                continue
            params_dict = self._cast_arguments(function_name, params_dict, request)
            params_str = json.dumps(params_dict, ensure_ascii=False)
            tool_calls.append(ToolCall(function=FunctionCall(name=function_name, arguments=params_str)))

        if not tool_calls:
            return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=model_output)
        return ExtractedToolCallInformation(tools_called=True, tool_calls=tool_calls, content=content or None)