Source code for easydel.inference.tools.parsers.kimi_k2_tool_parser

# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import re
from collections.abc import Sequence

from eformer.loggings import get_logger
from transformers import AutoTokenizer as AnyTokenizer

from ...openai_api_modules import (
    ChatCompletionRequest,
    DeltaFunctionCall,
    DeltaMessage,
    DeltaToolCall,
    ExtractedToolCallInformation,
    FunctionCall,
    ToolCall,
)
from ..abstract_tool import ToolParser, ToolParserManager

logger = get_logger(__name__)


@ToolParserManager.register_module(["kimi_k2"])
class KimiK2ToolParser(ToolParser):
    """
    Tool parser for Kimi K2 models.

    Handles tool calls with hierarchical token structure:
    - Tool calls section: <|tool_calls_section_begin|> ... <|tool_calls_section_end|>
    - Individual calls: <|tool_call_begin|> ... <|tool_call_end|>
    - Arguments after <|tool_call_argument_begin|>

    Features:
    - Hierarchical token-based parsing
    - Tool ID extraction from format: namespace.function:id
      (the namespace is optional; a bare ``function:id`` is accepted too)
    - Streaming with state tracking for nested structures
    - Regex patterns for structured extraction

    Format:
        <|tool_calls_section_begin|>
        <|tool_call_begin|>namespace.function:123<|tool_call_argument_begin|>{...}<|tool_call_end|>
        <|tool_calls_section_end|>
    """

    def __init__(self, tokenizer: AnyTokenizer):
        """Set up streaming state, marker tokens, regexes and token ids.

        Args:
            tokenizer: Tokenizer forwarded to the ``ToolParser`` base class.

        Raises:
            ValueError: If ``self.model_tokenizer`` is falsy after base-class
                initialisation.
            RuntimeError: If the section begin/end markers are missing from
                ``self.vocab``.
        """
        super().__init__(tokenizer)

        # Streaming state, mutated by ``extract_tool_calls_streaming``.
        self.current_tool_name_sent: bool = False
        self.prev_tool_call_arr: list[dict] = []
        self.current_tool_id: int = -1
        self.streamed_args_for_tool: list[str] = []

        self.tool_calls_start_token: str = "<|tool_calls_section_begin|>"
        self.tool_calls_end_token: str = "<|tool_calls_section_end|>"

        self.tool_call_start_token: str = "<|tool_call_begin|>"
        self.tool_call_end_token: str = "<|tool_call_end|>"

        # Complete tool call. ``[^<]+`` keeps the id from running over a
        # marker into the next call, the negative lookahead keeps the
        # arguments of an unterminated call from swallowing the following
        # call, and DOTALL lets arguments span several lines.
        self.tool_call_regex = re.compile(
            r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[^<]+:\d+)\s*"
            r"<\|tool_call_argument_begin\|>\s*"
            r"(?P<function_arguments>(?:(?!<\|tool_call_begin\|>).)*?)\s*"
            r"<\|tool_call_end\|>",
            re.DOTALL,
        )

        # Partial tool call whose argument marker has already been produced.
        # DOTALL so that multi-line arguments keep growing while streaming.
        self.stream_tool_call_portion_regex = re.compile(
            r"\s*(?P<tool_call_id>[^<]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>.*)",
            re.DOTALL,
        )

        # Partial tool call for which only the id has been produced so far.
        self.stream_tool_call_name_regex = re.compile(r"\s*(?P<tool_call_id>[^<]+:\d+)\s*")

        # ``model_tokenizer`` and ``vocab`` are not defined in this class;
        # presumably they are provided by the ToolParser base class.
        if not self.model_tokenizer:
            raise ValueError("The model tokenizer must be passed to the ToolParser constructor during construction.")

        self.tool_calls_start_token_id = self.vocab.get(self.tool_calls_start_token)
        self.tool_calls_end_token_id = self.vocab.get(self.tool_calls_end_token)

        self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
        self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)

        if self.tool_calls_start_token_id is None or self.tool_calls_end_token_id is None:
            raise RuntimeError("Kimi-K2 Tool parser could not locate tool call start/end tokens in the tokenizer!")

    @staticmethod
    def _function_name_from_id(tool_call_id: str) -> str:
        """Return the bare function name contained in a tool call id.

        ``functions.get_weather:0`` and ``get_weather:0`` both yield
        ``get_weather``: the trailing ``:index`` is dropped, then everything
        up to and including the last ``.`` (the namespace) is dropped.
        """
        qualified_name, separator, _ = tool_call_id.rpartition(":")
        if not separator:
            # No index suffix at all; treat the whole id as the qualified name.
            qualified_name = tool_call_id
        return qualified_name.rpartition(".")[2]

    def _arguments_delta(self, arguments: str) -> DeltaMessage:
        """Build a delta carrying only an argument fragment for the current tool."""
        return DeltaMessage(
            tool_calls=[
                DeltaToolCall(
                    index=self.current_tool_id,
                    function=DeltaFunctionCall(arguments=arguments).model_dump(exclude_none=True),
                )
            ]
        )

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract every complete tool call from a finished model response.

        Args:
            model_output: Full text generated by the model.
            request: The originating chat request (not used by this parser).

        Returns:
            If the section-begin marker is absent, or if parsing raises, the
            output is returned unchanged as content with ``tools_called=False``.
            Otherwise ``tools_called=True`` with one ``ToolCall`` per regex
            match and, as content, the text preceding the section-begin
            marker (``None`` when that text is empty).
        """
        if self.tool_calls_start_token not in model_output:
            return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=model_output)

        try:
            function_call_tuples = self.tool_call_regex.findall(model_output)
            logger.debug("function_call_tuples: %s", function_call_tuples)

            tool_calls = [
                ToolCall(
                    id=function_id,
                    type="function",
                    function=FunctionCall(
                        name=self._function_name_from_id(function_id),
                        arguments=function_args,
                    ),
                )
                for function_id, function_args in function_call_tuples
            ]

            content = model_output[: model_output.find(self.tool_calls_start_token)]
            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=content if content else None,
            )
        except Exception:
            # Best effort: hand the raw text back rather than failing the request.
            logger.exception("Error in extracting tool call from response.")
            return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=model_output)

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Turn one streaming step into a content or tool-call delta.

        The begin/end tool-call token ids are counted in the previous and
        current token sequences to decide whether the model is emitting plain
        text, opening a new tool call, extending the current one, or closing
        it. Per-call progress is kept on the instance (``current_tool_id``,
        ``current_tool_name_sent``, ``prev_tool_call_arr``,
        ``streamed_args_for_tool``).

        Args:
            previous_text: Text generated before this step (not used here).
            current_text: Text generated so far, including this step.
            delta_text: Text added by this step.
            previous_token_ids: Token ids generated before this step.
            current_token_ids: Token ids generated so far.
            delta_token_ids: Token ids added by this step.
            request: The originating chat request (not used by this parser).

        Returns:
            A ``DeltaMessage`` with plain content, with the tool name and id,
            or with an argument fragment; ``None`` when nothing should be
            emitted for this step or when an error was logged.
        """
        logger.debug("delta_text: %s", delta_text)
        logger.debug("delta_token_ids: %s", delta_token_ids)

        if self.tool_calls_start_token_id not in current_token_ids:
            logger.debug("No tool call tokens found!")
            return DeltaMessage(content=delta_text)

        # The section markers are structural only and are never forwarded.
        delta_text = delta_text.replace(self.tool_calls_start_token, "").replace(self.tool_calls_end_token, "")

        try:
            prev_tool_start_count = previous_token_ids.count(self.tool_call_start_token_id)
            prev_tool_end_count = previous_token_ids.count(self.tool_call_end_token_id)
            cur_tool_start_count = current_token_ids.count(self.tool_call_start_token_id)
            cur_tool_end_count = current_token_ids.count(self.tool_call_end_token_id)

            tool_call_portion = None
            text_portion = None

            # Case: plain text, no tool call open and none just closed.
            if (
                cur_tool_start_count == cur_tool_end_count
                and prev_tool_end_count == cur_tool_end_count
                and self.tool_call_end_token not in delta_text
            ):
                logger.debug("Generating text content! skipping tool parsing.")
                return DeltaMessage(content=delta_text)

            if self.tool_call_end_token in delta_text:
                logger.debug("tool_call_end_token in delta_text")
                full_text = current_text + delta_text
                tool_call_portion = (
                    full_text.split(self.tool_call_start_token)[-1].split(self.tool_call_end_token)[0].rstrip()
                )
                delta_text = delta_text.split(self.tool_call_end_token)[0].rstrip()
                text_portion = delta_text.split(self.tool_call_end_token)[-1].lstrip()

            # Case: a new tool call has just been opened.
            if cur_tool_start_count > cur_tool_end_count and cur_tool_start_count > prev_tool_start_count:
                if len(delta_token_ids) > 1:
                    tool_call_portion = current_text.split(self.tool_call_start_token)[-1]
                else:
                    tool_call_portion = None

                text_portion = None

                # Advance the cursor and reset per-call state.
                self.current_tool_id += 1
                self.current_tool_name_sent = False
                self.streamed_args_for_tool.append("")
                logger.debug("Starting on a new tool %s", self.current_tool_id)

            # Case: the currently open tool call is being extended.
            elif cur_tool_start_count > cur_tool_end_count and cur_tool_start_count == prev_tool_start_count:
                tool_call_portion = current_text.split(self.tool_call_start_token)[-1]
                text_portion = None

            # Case: the current tool call is being closed.
            elif cur_tool_start_count == cur_tool_end_count and cur_tool_end_count >= prev_tool_end_count:
                if not self.prev_tool_call_arr:
                    logger.debug("attempting to close tool call, but no tool call")
                    return None
                if self.prev_tool_call_arr[self.current_tool_id].get("arguments"):
                    # Flush whatever precedes the closing '"}' in this delta.
                    if '"}' not in delta_text:
                        return None
                    end_loc = delta_text.rindex('"}')
                    diff = delta_text[:end_loc] + '"}'
                    logger.debug(
                        "Finishing tool and found diff that had not been streamed yet: %s",
                        diff,
                    )
                    self.streamed_args_for_tool[self.current_tool_id] += diff
                    return self._arguments_delta(diff)

            # Case: anything else is treated as plain text.
            else:
                text = delta_text.replace(self.tool_call_start_token, "")
                text = text.replace(self.tool_call_end_token, "")
                return DeltaMessage(tool_calls=[], content=text)

            current_tool_call: dict = {}
            if tool_call_portion:
                portion_match = self.stream_tool_call_portion_regex.match(tool_call_portion)
                if portion_match:
                    tool_id, tool_args = portion_match.groups()
                    current_tool_call["id"] = tool_id
                    current_tool_call["name"] = self._function_name_from_id(tool_id)
                    current_tool_call["arguments"] = tool_args
                else:
                    name_match = self.stream_tool_call_name_regex.match(tool_call_portion)
                    if name_match:
                        (tool_id_str,) = name_match.groups()
                        current_tool_call["id"] = tool_id_str
                        current_tool_call["name"] = self._function_name_from_id(tool_id_str)
                        current_tool_call["arguments"] = ""
                    else:
                        logger.debug("Not enough token")
                        return None

            # Case: the tool name has not been sent yet; it goes out first,
            # on its own, before any argument fragment.
            if not self.current_tool_name_sent:
                function_name: str | None = current_tool_call.get("name")
                tool_id = current_tool_call.get("id")
                if not function_name:
                    return None
                self.current_tool_name_sent = True
                return DeltaMessage(
                    tool_calls=[
                        DeltaToolCall(
                            index=self.current_tool_id,
                            type="function",
                            id=tool_id,
                            function=DeltaFunctionCall(name=function_name).model_dump(exclude_none=True),
                        )
                    ]
                )

            # No tool call text at all: forward the delta as content if some
            # text portion was identified, otherwise skip the chunk.
            if tool_call_portion is None:
                return DeltaMessage(content=delta_text) if text_portion is not None else None

            logger.debug("Trying to parse current tool call with ID %s", self.current_tool_id)

            if len(self.prev_tool_call_arr) <= self.current_tool_id:
                self.prev_tool_call_arr.append({})

            prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get("arguments")
            cur_arguments = current_tool_call.get("arguments")

            logger.debug("diffing old arguments: %s", prev_arguments)
            logger.debug("against new ones: %s", cur_arguments)

            if not cur_arguments and not prev_arguments:
                logger.debug("Skipping text %s - no arguments", delta_text)
                delta = None
            elif not cur_arguments and prev_arguments:
                logger.error("should be impossible to have arguments reset mid-call. skipping streaming anything.")
                delta = None
            elif cur_arguments and not prev_arguments:
                # First argument fragment for this call: send it whole.
                delta = self._arguments_delta(cur_arguments)
                self.streamed_args_for_tool[self.current_tool_id] = cur_arguments
            elif (
                cur_arguments != prev_arguments
                and len(cur_arguments) > len(prev_arguments)
                and cur_arguments.startswith(prev_arguments)
            ):
                # Arguments grew: send only the newly appended suffix.
                delta_arguments = cur_arguments[len(prev_arguments) :]
                logger.debug("got diff %s", delta_text)
                delta = self._arguments_delta(delta_arguments)
                self.streamed_args_for_tool[self.current_tool_id] = cur_arguments
            else:
                delta = None

            # Remember the current parse so the next step can diff against it.
            if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
                self.prev_tool_call_arr[self.current_tool_id] = current_tool_call
            else:
                self.prev_tool_call_arr.append(current_tool_call)

            return delta

        except Exception:
            # Do not stream a delta for this step; skip it.
            logger.exception("Error trying to handle streaming tool call.")
            return None