Source code for easydel.inference.tools.parsers.deepseekv3_tool_parser

# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import re
from collections.abc import Sequence
from uuid import uuid4

from eformer.loggings import get_logger
from transformers import AutoTokenizer as AnyTokenizer

from ...openai_api_modules import (
    ChatCompletionRequest,
    DeltaFunctionCall,
    DeltaMessage,
    DeltaToolCall,
    ExtractedToolCallInformation,
    FunctionCall,
    ToolCall,
)
from ..abstract_tool import ToolParser, ToolParserManager

logger = get_logger(__name__)


@ToolParserManager.register_module("deepseek_v3")
class DeepSeekV3ToolParser(ToolParser):
    """
    Tool parser for DeepSeek V3 models.

    This parser handles the specific tool call format used by DeepSeek V3 models,
    which uses custom delimiters for tool calls and arguments. The format includes:

    - Tool calls wrapped in <|tool▁calls▁begin|> and <|tool▁calls▁end|>
    - Individual calls wrapped in <|tool▁call▁begin|> and <|tool▁call▁end|>
    - Tool type and name separated by <|tool▁sep|>
    - Arguments in JSON format wrapped in ```json blocks

    Attributes:
        current_tool_name_sent (bool): Tracks if tool name has been sent in streaming
        prev_tool_call_arr (list): Previous tool calls for comparison in streaming
        current_tool_id (int): Index of current tool being processed
        streamed_args_for_tool (list): Arguments streamed so far for each tool
        tool_calls_start_token (str): Token marking start of tool calls section
        tool_calls_end_token (str): Token marking end of tool calls section
        tool_call_start_token (str): Token marking start of individual tool call
        tool_call_end_token (str): Token marking end of individual tool call
        tool_sep_token (str): Token separating the tool type from the function name
    """  # noqa

    def __init__(self, tokenizer: AnyTokenizer):
        """
        Set up streaming state, delimiter tokens, regexes and token ids.

        Args:
            tokenizer: Tokenizer forwarded to ``ToolParser.__init__``.

        Raises:
            ValueError: If no tokenizer is available after base initialisation.
            RuntimeError: If the section start/end tokens are missing from the vocab.
        """
        super().__init__(tokenizer)

        # Mutable state carried across streaming chunks.
        self.current_tool_name_sent: bool = False
        self.prev_tool_call_arr: list[dict] = []
        self.current_tool_id: int = -1
        self.streamed_args_for_tool: list[str] = []

        # NOTE(review): the vocab lookups below are exact-match, so confirm these
        # spellings (including the bar characters) match the tokenizer's vocab.
        self.tool_calls_start_token: str = "<|tool▁calls▁begin|>"  # noqa
        self.tool_calls_end_token: str = "<|tool▁calls▁end|>"  # noqa
        self.tool_call_start_token: str = "<|tool▁call▁begin|>"  # noqa
        self.tool_call_end_token: str = "<|tool▁call▁end|>"  # noqa
        self.tool_sep_token: str = "<|tool▁sep|>"  # noqa

        # The delimiters must be escaped before being embedded in a pattern:
        # an unescaped "|" is regex alternation and would split each pattern
        # into unrelated alternatives instead of matching the literal token.
        call_begin = re.escape(self.tool_call_start_token)
        call_end = re.escape(self.tool_call_end_token)
        sep = re.escape(self.tool_sep_token)
        type_and_name = r"(?P<type>.*)" + sep + r"(?P<function_name>.*)\n"

        # Complete call: begin, type, sep, name, fenced JSON arguments, end.
        self.tool_call_regex = re.compile(
            call_begin + type_and_name + r"```json\n(?P<function_arguments>.*)\n```" + call_end
        )
        # Partial call while streaming: arguments may still be growing; the
        # trailing class keeps a closing newline/backtick out of the capture.
        self.stream_tool_call_portion_regex = re.compile(
            type_and_name + r"```json\n(?P<function_arguments>.*[^\n`])"
        )
        # Partial call where only the type/name header line is complete.
        self.stream_tool_call_name_regex = re.compile(type_and_name)

        if not self.model_tokenizer:
            raise ValueError("The model tokenizer must be passed to the ToolParser constructor during construction.")

        self.tool_calls_start_token_id = self.vocab.get(self.tool_calls_start_token)
        self.tool_calls_end_token_id = self.vocab.get(self.tool_calls_end_token)
        self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
        self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)

        if self.tool_calls_start_token_id is None or self.tool_calls_end_token_id is None:
            raise RuntimeError("DeepSeek-V3 Tool parser could not locate tool call start/end tokens in the tokenizer!")

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """
        Extract tool calls from complete model output.

        Parses the DeepSeek V3 format to extract function names and arguments
        from the model's response. Handles the specific token structure with
        tool type, separator, and JSON argument blocks.

        Args:
            model_output: Complete text output from the model
            request: Original chat completion request with tool definitions

        Returns:
            ExtractedToolCallInformation containing:
                - tools_called: Whether any tools were invoked
                - tool_calls: List of ToolCall objects with function details
                - content: Text content before tool calls (if any)
        """
        if self.tool_calls_start_token not in model_output:
            return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=model_output)

        try:
            # Each match is a (type, function_name, function_arguments) tuple.
            tool_calls = [
                ToolCall(
                    type=tool_type,
                    function=FunctionCall(name=function_name, arguments=function_args),
                )
                for tool_type, function_name, function_args in self.tool_call_regex.findall(model_output)
            ]

            # Everything before the tool-calls section is ordinary content.
            content = model_output[: model_output.find(self.tool_calls_start_token)]
            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=content if content else None,
            )
        except Exception:
            # Best effort: fall back to returning the raw output as content.
            logger.exception("Error in extracting tool call from response.")
            return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=model_output)

    def _parse_tool_call_portion(self, tool_call_portion: str | None) -> dict | None:
        """Parse a partial tool call into ``{"name", "arguments"}``; ``{}`` if empty, ``None`` if unparseable yet."""
        if not tool_call_portion:
            return {}

        full_match = self.stream_tool_call_portion_regex.match(tool_call_portion)
        if full_match:
            _tool_type, tool_name, tool_args = full_match.groups()
            return {"name": tool_name, "arguments": tool_args}

        name_match = self.stream_tool_call_name_regex.match(tool_call_portion)
        if name_match:
            _tool_type, tool_name = name_match.groups()
            return {"name": tool_name, "arguments": ""}

        logger.debug("Not enough token")
        return None

    def _arguments_delta(self, arguments: str) -> DeltaMessage:
        """Build a DeltaMessage carrying an arguments fragment for the current tool."""
        return DeltaMessage(
            tool_calls=[
                DeltaToolCall(
                    index=self.current_tool_id,
                    function=DeltaFunctionCall(arguments=arguments).model_dump(exclude_none=True),
                )
            ]
        )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """
        Extract tool calls from streaming model output.

        Handles incremental parsing of DeepSeek V3 tool call format during
        streaming generation. Maintains state across chunks to properly
        identify tool boundaries and stream arguments progressively.

        Args:
            previous_text: Text generated up to previous chunk
            current_text: All text generated so far
            delta_text: New text in this chunk
            previous_token_ids: Token IDs up to previous chunk
            current_token_ids: All token IDs so far
            delta_token_ids: New token IDs in this chunk
            request: Original request with tool definitions

        Returns:
            DeltaMessage with incremental tool call information,
            or None if more data needed for parsing
        """
        logger.debug("delta_text: %s", delta_text)
        logger.debug("delta_token_ids: %s", delta_token_ids)

        # Until the tool-calls section opens, everything is plain content.
        if self.tool_calls_start_token_id not in current_token_ids:
            logger.debug("No tool call tokens found!")
            return DeltaMessage(content=delta_text)

        # The section-level delimiters are never surfaced to the client.
        delta_text = delta_text.replace(self.tool_calls_start_token, "").replace(self.tool_calls_end_token, "")
        try:
            start_token = self.tool_call_start_token
            end_token = self.tool_call_end_token

            # Comparing begin/end token counts between the previous and the
            # current chunk tells us which phase of a tool call we are in.
            prev_tool_start_count = previous_token_ids.count(self.tool_call_start_token_id)
            prev_tool_end_count = previous_token_ids.count(self.tool_call_end_token_id)
            cur_tool_start_count = current_token_ids.count(self.tool_call_start_token_id)
            cur_tool_end_count = current_token_ids.count(self.tool_call_end_token_id)
            tool_call_portion = None
            text_portion = None

            # Balanced calls and nothing closing in this chunk: plain text.
            if (
                cur_tool_start_count == cur_tool_end_count
                and prev_tool_end_count == cur_tool_end_count
                and end_token not in delta_text
            ):
                logger.debug("Generating text content! skipping tool parsing.")
                return DeltaMessage(content=delta_text)

            if end_token in delta_text:
                logger.debug("tool_call_end_token in delta_text")
                # current_text already ends with this chunk, so it is the full text.
                tool_call_portion = current_text.split(start_token)[-1].split(end_token)[0].rstrip()
                before_end, _, after_end = delta_text.partition(end_token)
                delta_text = before_end.rstrip()
                # Only whether this is None is consulted further down.
                text_portion = after_end.lstrip()

            if cur_tool_start_count > cur_tool_end_count and cur_tool_start_count > prev_tool_start_count:
                # A new tool call was opened in this chunk.
                if len(delta_token_ids) > 1:
                    tool_call_portion = current_text.split(start_token)[-1]
                else:
                    tool_call_portion = None
                text_portion = None

                self.current_tool_id += 1
                self.current_tool_name_sent = False
                self.streamed_args_for_tool.append("")
                logger.debug("Starting on a new tool %s", self.current_tool_id)

            elif cur_tool_start_count > cur_tool_end_count and cur_tool_start_count == prev_tool_start_count:
                # Still inside a tool call that was opened in an earlier chunk.
                tool_call_portion = current_text.split(start_token)[-1]
                text_portion = None

            elif cur_tool_start_count == cur_tool_end_count and cur_tool_end_count >= prev_tool_end_count:
                # The current tool call is being closed.
                if not self.prev_tool_call_arr:
                    logger.debug("attempting to close tool call, but no tool call")
                    return None
                if self.prev_tool_call_arr[self.current_tool_id].get("arguments"):
                    # Flush whatever precedes the final '"}' in this chunk.
                    if '"}' not in delta_text:
                        return None
                    end_loc = delta_text.rindex('"}')
                    diff = delta_text[:end_loc] + '"}'
                    logger.debug(
                        "Finishing tool and found diff that had not been streamed yet: %s",
                        diff,
                    )
                    self.streamed_args_for_tool[self.current_tool_id] += diff
                    return self._arguments_delta(diff)
                # No arguments recorded yet: fall through to the parsing below.

            else:
                text = delta_text.replace(start_token, "")
                text = text.replace(end_token, "")
                return DeltaMessage(tool_calls=[], content=text)

            current_tool_call = self._parse_tool_call_portion(tool_call_portion)
            if current_tool_call is None:
                return None

            # The function name is sent exactly once, before any arguments.
            if not self.current_tool_name_sent:
                function_name: str | None = current_tool_call.get("name")
                if not function_name:
                    return None
                self.current_tool_name_sent = True
                return DeltaMessage(
                    tool_calls=[
                        DeltaToolCall(
                            index=self.current_tool_id,
                            type="function",
                            id=f"chatcmpl-tool-{uuid4()}",
                            function=DeltaFunctionCall(name=function_name).model_dump(exclude_none=True),
                        )
                    ]
                )

            if tool_call_portion is None:
                return DeltaMessage(content=delta_text) if text_portion is not None else None

            logger.debug("Trying to parse current tool call with ID %s", self.current_tool_id)

            if len(self.prev_tool_call_arr) <= self.current_tool_id:
                self.prev_tool_call_arr.append({})

            prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get("arguments")
            cur_arguments = current_tool_call.get("arguments")

            logger.debug("diffing old arguments: %s", prev_arguments)
            logger.debug("against new ones: %s", cur_arguments)

            delta = None
            if not cur_arguments and not prev_arguments:
                logger.debug("Skipping text %s - no arguments", delta_text)
            elif not cur_arguments and prev_arguments:
                logger.error("should be impossible to have arguments reset mid-call. skipping streaming anything.")
            elif cur_arguments and not prev_arguments:
                # First arguments seen for this tool: stream them whole.
                delta = self._arguments_delta(cur_arguments)
                self.streamed_args_for_tool[self.current_tool_id] = cur_arguments
            elif len(cur_arguments) > len(prev_arguments) and cur_arguments.startswith(prev_arguments):
                # Arguments grew: stream only the newly appended suffix.
                delta_arguments = cur_arguments[len(prev_arguments) :]
                logger.debug("got diff %s", delta_text)
                delta = self._arguments_delta(delta_arguments)
                self.streamed_args_for_tool[self.current_tool_id] = cur_arguments

            # Remember this snapshot so the next chunk can be diffed against it.
            if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
                self.prev_tool_call_arr[self.current_tool_id] = current_tool_call
            else:
                self.prev_tool_call_arr.append(current_tool_call)

            return delta

        except Exception:
            # Best effort: a parsing failure must not break the stream.
            logger.exception("Error trying to handle streaming tool call.")
            return None