# Source code for easydel.inference.tools.utils

# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     https://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utility functions for tool call parsing in EasyDeL inference.

This module provides essential helper functions for parsing and processing
tool calls from Large Language Model outputs. It includes utilities for:

- Partial JSON parsing and validation
- String manipulation for streaming token differences
- Common prefix/suffix extraction for incremental parsing
- Whitespace handling in structured outputs

These utilities are fundamental to the streaming tool call extraction
process, enabling parsers to handle incomplete JSON/XML structures
as they are being generated token-by-token.

Key Functions:
    find_common_prefix: Find shared prefix between strings
    find_common_suffix: Find shared suffix between strings
    extract_intermediate_diff: Extract new content for streaming
    partial_json_loads: Parse potentially incomplete JSON
    is_complete_json: Validate JSON completeness
"""

import json
from json import JSONDecodeError, JSONDecoder
from typing import Any

import partial_json_parser
from partial_json_parser.core.options import Allow


def find_common_prefix(s1: str, s2: str) -> str:
    """
    Find common prefix shared between two strings.

    Used for extracting information from JSON generated by partial_json_parser
    to ensure proper token streaming without premature closure of quotes,
    brackets, or braces. Order of arguments does not matter.

    Args:
        s1: First string to compare
        s2: Second string to compare

    Returns:
        The common prefix substring, empty string if none

    Example:
        >>> find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}')
        '{"fruit": "ap'
    """
    # Find the first mismatching position and slice once, rather than growing
    # the result one character at a time (quadratic string concatenation).
    # `zip` stops at the shorter string; if no mismatch is found there, the
    # whole shorter string is the common prefix.
    end = next(
        (idx for idx, (c1, c2) in enumerate(zip(s1, s2)) if c1 != c2),
        min(len(s1), len(s2)),
    )
    return s1[:end]
def find_common_suffix(s1: str, s2: str) -> str:
    """
    Find common suffix shared between two strings.

    Stops when the suffix ends or when an alphanumeric character is reached,
    so the returned suffix only ever contains non-alphanumeric characters
    (e.g. closing quotes, brackets, braces). Order of arguments does not
    matter.

    Args:
        s1: First string to compare
        s2: Second string to compare

    Returns:
        The common suffix substring, empty string if none

    Example:
        >>> find_common_suffix('{"fruit": "ap"}', '{"fruit": "apple"}')
        '"}'
    """
    # Walk both strings backwards in lockstep (`zip` stops at the shorter one)
    # and count matching non-alphanumeric characters. Slicing once at the end
    # avoids the quadratic cost of repeatedly prepending to a string.
    length = 0
    for c1, c2 in zip(reversed(s1), reversed(s2)):
        if c1 != c2 or c1.isalnum():
            break
        length += 1
    # Use an explicit start index: `s1[-0:]` would return the whole string.
    return s1[len(s1) - length:]
def extract_intermediate_diff(curr: str, old: str) -> str:
    """
    Extract the difference between two strings with common prefix/suffix.

    Used for streaming partial JSON to extract only new tokens that should be
    sent to the client. The shared trailing run of non-alphanumeric
    characters (as computed by ``find_common_suffix``) is removed from both
    strings. The leading prefix shared by the two stripped strings is then
    removed from ``curr``, and whatever remains is returned.

    Args:
        curr: New/current version of partially-parsed JSON (must be first)
        old: Previous version from earlier generation (must be second)

    Returns:
        The intermediate difference that should be streamed. This is the
        empty string when the suffix-stripped ``curr`` is itself a prefix of
        the suffix-stripped ``old`` (i.e. ``curr`` adds nothing new).

    Example:
        >>> extract_intermediate_diff('{"fruit": "apple"}', '{"fruit": "ap"}')
        'ple'

    Note:
        Argument order is important - new version must be first.
    """
    suffix = find_common_suffix(curr, old)
    # The suffix is, by construction, a trailing substring of both inputs, so
    # slicing it off is equivalent to removing its last occurrence.
    if suffix:
        curr = curr[: len(curr) - len(suffix)]
        old = old[: len(old) - len(suffix)]
    # Compute the prefix on the *stripped* strings so it can never extend past
    # the end of the stripped `curr`. Otherwise, when `curr` is shorter than
    # `old`, already-streamed content would be returned again.
    prefix = find_common_prefix(curr, old)
    return curr[len(prefix):]
def find_all_indices(string: str, substring: str) -> list[int]:
    """
    Return every starting position at which ``substring`` occurs in ``string``.

    Each search resumes one character after the previous hit, so overlapping
    occurrences are all reported. Handy for locating repeated tool-call
    markers in model output.

    Args:
        string: Text to scan
        substring: Pattern to look for

    Returns:
        Ascending list of start indices (empty if there is no match)
    """
    positions: list[int] = []
    start = 0
    while (hit := string.find(substring, start)) != -1:
        positions.append(hit)
        # Advance by one (not by len(substring)) to keep overlapping matches.
        start = hit + 1
    return positions
def partial_json_loads(input_str: str, flags: Allow) -> tuple[Any, int]:
    """
    Decode possibly-incomplete JSON, tolerating trailing extra data.

    The text is first handed to ``partial_json_parser.loads`` with the given
    ``flags``. If that raises a ``JSONDecodeError`` whose message mentions
    "Extra data", the standard library's ``JSONDecoder.raw_decode`` is used
    instead. It decodes the leading JSON value and reports where that value
    ended.

    Args:
        input_str: JSON text to decode (may be truncated)
        flags: ``partial_json_parser`` ``Allow`` flags controlling which
            partial constructs are accepted

    Returns:
        A ``(value, end_index)`` pair. On the primary path ``end_index`` is
        ``len(input_str)``. On the fallback path it is the index returned by
        ``raw_decode``.

    Raises:
        JSONDecodeError: When decoding fails for any reason other than
            trailing extra data, or when the fallback decode itself fails.
    """
    try:
        parsed = partial_json_parser.loads(input_str, flags)
    except JSONDecodeError as err:
        if "Extra data" not in err.msg:
            raise
        # Trailing content after a complete value: decode just the first value.
        return JSONDecoder().raw_decode(input_str)
    return parsed, len(input_str)
def is_complete_json(input_str: str) -> bool:
    """
    Report whether a string decodes as a complete, valid JSON document.

    Args:
        input_str: Text to validate

    Returns:
        ``True`` if ``json.loads`` accepts the text, ``False`` if it raises
        ``JSONDecodeError``.
    """
    try:
        json.loads(input_str)
    except JSONDecodeError:
        # Truncated or malformed text cannot be decoded.
        return False
    return True
def consume_space(i: int, s: str) -> int:
    """
    Advance past whitespace in ``s`` beginning at position ``i``.

    Args:
        i: Index at which to start scanning
        s: Text being scanned

    Returns:
        The first index at or after ``i`` whose character is not whitespace
        (per ``str.isspace``). Returns ``len(s)`` if only whitespace remains,
        or ``i`` itself when ``i`` is already at or beyond the end.
    """
    end = len(s)
    pos = i
    while pos < end:
        if not s[pos].isspace():
            break
        pos += 1
    return pos