Source code for easydel.inference.tools.parsers.llama_tool_parser

# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import json
import re
from collections.abc import Sequence
from uuid import uuid4

import partial_json_parser
from eformer.loggings import get_logger
from partial_json_parser.core.options import Allow
from transformers import PreTrainedTokenizerBase

from ...openai_api_modules import (
    ChatCompletionRequest,
    DeltaFunctionCall,
    DeltaMessage,
    DeltaToolCall,
    ExtractedToolCallInformation,
    FunctionCall,
    ToolCall,
)
from ..abstract_tool import ToolParser, ToolParserManager
from ..utils import find_common_prefix, is_complete_json, partial_json_loads

logger = get_logger(__name__)


@ToolParserManager.register_module("llama3_json")
@ToolParserManager.register_module("llama4_json")
class Llama3JsonToolParser(ToolParser):
    """
    Tool call parser for Llama 3.x and 4 models with JSON format.

    Intended for use with the examples/tool_chat_template_llama.jinja template.
    Handles JSON-formatted tool calls with support for both single and multiple
    tool invocations separated by semicolons.

    Format supported:
    - Single: {"name": "func", "arguments": {...}}
    - Multiple: {"name": "func1", ...}; {"name": "func2", ...}
    - Uses <|python_tag|> token as optional marker
    - Supports both "arguments" and "parameters" field names

    Used when --enable-auto-tool-choice --tool-call-parser llama3_json or
    llama4_json are set.

    Attributes:
        bot_token (str): Special token marking tool calls
        bot_token_id (int): First token id produced by encoding ``bot_token``
        tool_call_regex (re.Pattern): Pattern for extracting JSON tool calls
        prev_tool_call_arr (list): Previous tool calls for streaming comparison
        current_tool_id (int): Index of the tool call being streamed (-1 = none yet)
        current_tool_name_sent (bool): Whether the current tool's name was emitted
        streamed_args_for_tool (list[str]): Argument JSON already emitted, per tool
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        super().__init__(tokenizer)

        # Streaming state, mutated by ``extract_tool_calls_streaming``.
        self.prev_tool_call_arr: list[dict] = []
        self.current_tool_id: int = -1
        self.current_tool_name_sent: bool = False
        self.streamed_args_for_tool: list[str] = []

        self.bot_token = "<|python_tag|>"
        self.bot_token_id = tokenizer.encode(self.bot_token, add_special_tokens=False)[0]
        # One or more brace-delimited objects (at most one level of nested
        # braces), separated by ``;`` with optional surrounding whitespace.
        self.tool_call_regex = re.compile(
            r"{[^{}]*(?:{[^{}]*}[^{}]*)*}(?:\s*;\s*{[^{}]*(?:{[^{}]*}[^{}]*)*})*",
            re.DOTALL,
        )

    @staticmethod
    def _skip_whitespace(text: str, idx: int) -> int:
        """Return the first index at or after ``idx`` that is not whitespace."""
        end = len(text)
        while idx < end and text[idx].isspace():
            idx += 1
        return idx

    @classmethod
    def _skip_separator(cls, text: str, idx: int) -> int:
        """Return the index after optional whitespace, one optional ``;`` and more whitespace."""
        idx = cls._skip_whitespace(text, idx)
        if idx < len(text) and text[idx] == ";":
            idx = cls._skip_whitespace(text, idx + 1)
        return idx

    @classmethod
    def _decode_json_objects(cls, text: str) -> list:
        """
        Decode consecutive JSON values separated by semicolons.

        Decoding is done with ``json.JSONDecoder.raw_decode`` rather than by
        splitting on ``;`` so that semicolons inside JSON string values do not
        break a tool call apart.

        Args:
            text: Text that starts with a JSON value.

        Returns:
            The decoded values, in order of appearance.

        Raises:
            ValueError: If a value is not valid JSON (``json.JSONDecodeError``)
                or if something other than ``;`` follows a decoded value.
        """
        decoder = json.JSONDecoder()
        decoded: list = []
        end = len(text)
        pos = 0
        while pos < end:
            obj, pos = decoder.raw_decode(text, pos)
            decoded.append(obj)
            pos = cls._skip_whitespace(text, pos)
            if pos >= end:
                break
            if text[pos] != ";":
                raise ValueError(f"unexpected data after JSON tool call at offset {pos}")
            pos = cls._skip_whitespace(text, pos + 1)
        return decoded

    def _arguments_delta(self, argument_diff: str) -> DeltaMessage:
        """Build a delta carrying ``argument_diff`` for the current tool index."""
        return DeltaMessage(
            tool_calls=[
                DeltaToolCall(
                    index=self.current_tool_id,
                    function=DeltaFunctionCall(arguments=argument_diff).model_dump(exclude_none=True),
                )
            ]
        )

    def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest) -> ExtractedToolCallInformation:
        """
        Extract tool calls from complete Llama model response.

        Extracts JSON content and ignores surrounding plain text.
        Supports both single JSON and multiple JSONs separated by semicolons.
        Handles both "arguments" and "parameters" field names for compatibility.

        Args:
            model_output: Complete model output text
            request: Original request (unused)

        Returns:
            ExtractedToolCallInformation with parsed JSON tool calls. If the
            output holds no JSON-looking span, or the span cannot be parsed into
            tool calls, ``tools_called`` is False and ``content`` is the
            unmodified ``model_output``.
        """
        if not (self.bot_token in model_output or "{" in model_output):
            return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=model_output)

        match = self.tool_call_regex.search(model_output)
        if not match:
            return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=model_output)

        try:
            tool_calls: list[ToolCall] = [
                ToolCall(
                    type="function",
                    function=FunctionCall(
                        name=obj["name"],
                        arguments=json.dumps(
                            obj["arguments"] if "arguments" in obj else obj["parameters"],
                            ensure_ascii=False,
                        ),
                    ),
                )
                for obj in self._decode_json_objects(match.group(0))
            ]
            return ExtractedToolCallInformation(tools_called=True, tool_calls=tool_calls, content=None)
        except Exception as exc:
            # Best effort: anything that fails to parse is handed back to the
            # caller as ordinary text instead of raising.
            logger.debug("Error in extracting tool call from response: %s", exc)
            return ExtractedToolCallInformation(tools_called=False, tool_calls=[], content=model_output)

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """
        Incrementally extract tool calls while the model output is streaming.

        The whole of ``current_text`` is re-parsed on every call and compared
        with the state kept on the instance (``current_tool_id``,
        ``current_tool_name_sent``, ``streamed_args_for_tool``,
        ``prev_tool_call_arr``) to decide what has not been emitted yet:

        1. Text that starts with neither ``bot_token`` nor ``{`` is passed
           through as plain content.
        2. When more objects are parsed than have been started, any unsent
           arguments of the previous tool are flushed and the parser moves on
           to the newest tool.
        3. Otherwise the tool name is sent once, with a freshly generated id.
        4. After that, argument JSON is streamed as a diff against what was
           already sent.

        Args:
            previous_text: Text generated before this step (unused)
            current_text: All text generated so far
            delta_text: Text added in this step
            previous_token_ids: Token ids before this step (unused)
            current_token_ids: All token ids so far (unused)
            delta_token_ids: Token ids added in this step (unused)
            request: Original request (unused)

        Returns:
            A DeltaMessage with content or a tool-call fragment, or None when
            there is nothing to emit for this step (including when the text
            cannot be parsed yet or an error occurs).
        """
        if not current_text.startswith((self.bot_token, "{")):
            return DeltaMessage(content=delta_text)

        # Partial strings are only allowed once the name has been sent, so an
        # unfinished function name is presumably never surfaced by the partial
        # parser (see partial_json_parser ``Allow`` flags).
        flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
        try:
            tool_call_arr = []
            is_complete = []
            try:
                start_idx = len(self.bot_token) if current_text.startswith(self.bot_token) else 0
                text_len = len(current_text)
                while start_idx < text_len:
                    (obj, end_idx) = partial_json_loads(current_text[start_idx:], flags)
                    is_complete.append(is_complete_json(current_text[start_idx : start_idx + end_idx]))
                    # Accept the same ``\s*;\s*`` separator as the
                    # non-streaming regex instead of assuming exactly "; ".
                    next_idx = self._skip_separator(current_text, start_idx + end_idx)
                    if "parameters" in obj:
                        if "arguments" in obj:
                            raise ValueError("model generated both parameters and arguments")
                        obj["arguments"] = obj["parameters"]
                    tool_call_arr.append(obj)
                    if next_idx <= start_idx:
                        # Defensive: never loop without making progress.
                        break
                    start_idx = next_idx
            except partial_json_parser.core.exceptions.MalformedJSON:
                logger.debug("not enough tokens to parse into JSON yet")
                return None

            if not tool_call_arr:
                return None

            # With current_tool_id == -1 this selects the last parsed object.
            current_tool_call: dict = tool_call_arr[self.current_tool_id]

            if len(tool_call_arr) > self.current_tool_id + 1:
                # A new tool call has started: flush whatever is left of the
                # previous tool's arguments, then advance to the newest tool.
                delta = None
                if self.current_tool_id >= 0:
                    cur_arguments = current_tool_call.get("arguments")
                    if cur_arguments:
                        cur_args_json = json.dumps(cur_arguments, ensure_ascii=False)
                        sent = len(self.streamed_args_for_tool[self.current_tool_id])
                        argument_diff = cur_args_json[sent:]
                        delta = self._arguments_delta(argument_diff)
                        self.streamed_args_for_tool[self.current_tool_id] += argument_diff

                self.current_tool_id = len(tool_call_arr) - 1
                self.current_tool_name_sent = False
                self.streamed_args_for_tool.append("")
                return delta

            if not self.current_tool_name_sent:
                function_name = current_tool_call.get("name")
                if function_name:
                    delta = DeltaMessage(
                        tool_calls=[
                            DeltaToolCall(
                                index=self.current_tool_id,
                                type="function",
                                id=f"chatcmpl-tool-{uuid4()}",
                                function=DeltaFunctionCall(name=function_name).model_dump(exclude_none=True),
                            )
                        ]
                    )
                    self.current_tool_name_sent = True
                else:
                    delta = None
            else:
                cur_arguments = current_tool_call.get("arguments")
                delta = None

                if cur_arguments:
                    sent = len(self.streamed_args_for_tool[self.current_tool_id])
                    cur_args_json = json.dumps(cur_arguments, ensure_ascii=False)
                    prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get("arguments")

                    argument_diff = None
                    if is_complete[self.current_tool_id]:
                        # Object is closed: everything not yet sent is final.
                        argument_diff = cur_args_json[sent:]
                    elif prev_arguments:
                        prev_args_json = json.dumps(prev_arguments, ensure_ascii=False)
                        if cur_args_json != prev_args_json:
                            # Only emit the part shared by the previous and the
                            # current serialization, beyond what was already sent.
                            prefix = find_common_prefix(prev_args_json, cur_args_json)
                            argument_diff = prefix[sent:]

                    if argument_diff is not None:
                        delta = self._arguments_delta(argument_diff)
                        self.streamed_args_for_tool[self.current_tool_id] += argument_diff

            self.prev_tool_call_arr = tool_call_arr
            return delta

        except Exception as exc:
            # Best effort: skip this chunk rather than abort the stream.
            logger.debug("Skipping chunk after tool streaming extraction error: %s", exc)
            return None