|
|
|
|
|
import ast
import json
import math
import uuid
from collections.abc import Sequence
from typing import Any, List, Optional, Union

import regex as re

from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              ChatCompletionToolsParam,
                                              DeltaFunctionCall, DeltaMessage,
                                              DeltaToolCall,
                                              ExtractedToolCallInformation,
                                              FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
    ToolParser, ToolParserManager)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
|
|
# Module-level logger used by the tool parser below.
logger = init_logger(__name__)
|
|
|
|
|
@ToolParserManager.register_module("qwen3_coder")
class Qwen3CoderToolParser(ToolParser):
    """Tool parser for the Qwen3-Coder XML tool-call format.

    Tool calls are expected in the form::

        <tool_call>
        <function=NAME>
        <parameter=ARG>
        VALUE
        </parameter>
        </function>
        </tool_call>

    Parameter values are raw text. They are converted to the types declared
    in the request's tool schema (see ``_convert_param_value``) and then
    serialized into a JSON arguments object.
    """

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)

        self.current_tool_name_sent: bool = False
        self.prev_tool_call_arr: list[dict] = []
        self.current_tool_id: int = -1
        # NOTE(review): nothing in this parser writes to this list.
        self.streamed_args_for_tool: list[str] = []

        # Markup of the XML tool-call format.
        self.tool_call_start_token: str = "<tool_call>"
        self.tool_call_end_token: str = "</tool_call>"
        self.tool_call_prefix: str = "<function="
        self.function_end_token: str = "</function>"
        self.parameter_prefix: str = "<parameter="
        self.parameter_end_token: str = "</parameter>"
        self.is_tool_call_started: bool = False
        self.failed_count: int = 0

        # Per-stream state used by extract_tool_calls_streaming.
        self._reset_streaming_state()

        # Compiled patterns for locating tool calls, functions and parameters.
        self.tool_call_complete_regex = re.compile(
            r"<tool_call>(.*?)</tool_call>", re.DOTALL)
        self.tool_call_regex = re.compile(
            r"<tool_call>(.*?)</tool_call>|<tool_call>(.*?)$", re.DOTALL)
        self.tool_call_function_regex = re.compile(
            r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL)
        self.tool_call_parameter_regex = re.compile(
            r"<parameter=(.*?)(?:</parameter>|(?=<parameter=)|(?=</function>)|$)",
            re.DOTALL)

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction.")

        self.tool_call_start_token_id = self.vocab.get(
            self.tool_call_start_token)
        self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
        if (self.tool_call_start_token_id is None
                or self.tool_call_end_token_id is None):
            raise RuntimeError(
                "Qwen3 XML Tool parser could not locate tool call start/end "
                "tokens in the tokenizer!")

        logger.info("vLLM Successfully import tool parser %s !",
                    self.__class__.__name__)

    def _generate_tool_call_id(self) -> str:
        """Generate a unique tool call ID of the form ``call_<24 hex>``."""
        return f"call_{uuid.uuid4().hex[:24]}"

    def _reset_streaming_state(self):
        """Reset all per-stream state of the streaming parser."""
        # Position of the tool call currently being streamed.
        self.current_tool_index = 0
        self.is_tool_call_started = False
        # Whether the id/name chunk for the current tool has been emitted.
        self.header_sent = False
        self.current_tool_id = None
        self.current_function_name = None
        self.current_param_name = None
        # in_param / current_param_value are only reset, never updated; they
        # are kept so existing attribute reads keep working.
        self.current_param_value = ""
        # Number of parameters of the current tool already streamed.
        self.param_count = 0
        self.in_param = False
        # True between the function header and the closing "}".
        self.in_function = False
        self.accumulated_text = ""
        # Whether "{" / "}" of the current arguments object were emitted.
        self.json_started = False
        self.json_closed = False
        # Raw (unconverted) parameter values of the current tool.
        self.accumulated_params = {}
        self.streaming_request = None

    def _get_arguments_config(
            self, func_name: str,
            tools: Optional[list[ChatCompletionToolsParam]]) -> dict:
        """Return the parameter schemas declared for ``func_name``.

        Returns the ``properties`` mapping of the tool's ``parameters`` when
        present, otherwise ``parameters`` itself if it is a dict. Returns
        ``{}`` when there are no tools, the tool has no usable parameters, or
        the tool is not found (the last case is logged).
        """
        if tools is None:
            return {}
        for config in tools:
            if not hasattr(config, "type") or not (hasattr(
                    config, "function") and hasattr(config.function, "name")):
                continue
            if config.type != "function" or config.function.name != func_name:
                continue
            if not hasattr(config.function, "parameters"):
                return {}
            params = config.function.parameters
            if not isinstance(params, dict):
                return {}
            return params["properties"] if "properties" in params else params
        logger.warning("Tool '%s' is not defined in the tools list.",
                       func_name)
        return {}

    def _convert_param_value(self, param_value: str, param_name: str,
                             param_config: dict, func_name: str) -> Any:
        """Convert a raw parameter string to the type declared in the schema.

        Args:
            param_value: Raw text of the parameter value.
            param_name: Name of the parameter.
            param_config: Mapping of parameter name to schema, as returned by
                ``_get_arguments_config``.
            func_name: Tool name, used only in log messages.

        Returns:
            ``None`` for a (case-insensitive) ``null`` value. Otherwise the
            value converted to the declared type. The original string is
            returned when the parameter is undeclared or conversion fails.
            The result is always JSON-serializable without NaN/Infinity.
        """
        if param_value.lower() == "null":
            return None

        if param_name not in param_config:
            if param_config != {}:
                logger.warning(
                    "Parsed parameter '%s' is not defined in the tool "
                    "parameters for tool '%s', directly returning the "
                    "string value.", param_name, func_name)
            return param_value

        schema = param_config[param_name]
        if isinstance(schema, dict) and "type" in schema:
            declared_type = schema["type"]
            # JSON Schema allows a list of types such as ["string", "null"];
            # use the first non-null entry instead of stringifying the list.
            if isinstance(declared_type, list):
                non_null = [
                    t for t in declared_type
                    if str(t).strip().lower() != "null"
                ]
                declared_type = non_null[0] if non_null else "string"
            param_type = str(declared_type).strip().lower()
        else:
            param_type = "string"

        if param_type in ("string", "str", "text", "varchar", "char", "enum"):
            return param_value

        if param_type.startswith(("int", "uint", "long", "short", "unsigned")):
            try:
                return int(param_value)
            except ValueError:
                logger.warning(
                    "Parsed value '%s' of parameter '%s' is not an integer "
                    "in tool '%s', degenerating to string.", param_value,
                    param_name, func_name)
                return param_value

        if param_type.startswith(("num", "float")):
            try:
                number = float(param_value)
            except ValueError:
                number = None
            # NaN/Infinity have no JSON representation; keep the string.
            if number is None or not math.isfinite(number):
                logger.warning(
                    "Parsed value '%s' of parameter '%s' is not a float in "
                    "tool '%s', degenerating to string.", param_value,
                    param_name, func_name)
                return param_value
            # Integral numbers are reported as int (e.g. "3.0" -> 3).
            return int(number) if number.is_integer() else number

        if param_type in ("boolean", "bool", "binary"):
            lowered = param_value.lower()
            if lowered not in ("true", "false"):
                logger.warning(
                    "Parsed value '%s' of parameter '%s' is not a boolean "
                    "(`true` or `false`) in tool '%s', degenerating to "
                    "false.", lowered, param_name, func_name)
            return lowered == "true"

        if param_type in ("object", "array", "arr") or param_type.startswith(
            ("dict", "list")):
            try:
                return json.loads(param_value)
            except (ValueError, RecursionError):
                logger.warning(
                    "Parsed value '%s' of parameter '%s' cannot be parsed "
                    "with json.loads in tool '%s', will try other methods "
                    "to parse it.", param_value, param_name, func_name)

        # Fallback: Python literal syntax (e.g. single-quoted dicts). The
        # result must also be JSON-serializable (no sets, bytes, NaN, inf),
        # otherwise the arguments could not be emitted as valid JSON.
        # literal_eval never executes code, but the input is model output.
        try:
            evaluated = ast.literal_eval(param_value)
            json.dumps(evaluated, allow_nan=False)
        except (ValueError, TypeError, SyntaxError, MemoryError,
                RecursionError):
            logger.warning(
                "Parsed value '%s' of parameter '%s' cannot be converted via "
                "Python `ast.literal_eval()` in tool '%s', degenerating to "
                "string.", param_value, param_name, func_name)
            return param_value
        return evaluated

    def _parse_xml_function_call(
            self, function_call_str: str,
            tools: Optional[list[ChatCompletionToolsParam]]
    ) -> Optional[ToolCall]:
        """Parse the body of one ``<function=...>`` element into a ToolCall.

        Args:
            function_call_str: Text following ``<function=``: the function
                name, ``>``, then zero or more ``<parameter=...>`` elements.
            tools: Tool definitions from the request, used to type values.

        Raises:
            ValueError: If the function name or a parameter name is not
                terminated by ``>``.
        """
        end_index = function_call_str.index(">")
        function_name = function_call_str[:end_index]
        param_config = self._get_arguments_config(function_name, tools)
        parameters = function_call_str[end_index + 1:]

        param_dict = {}
        for match_text in self.tool_call_parameter_regex.findall(parameters):
            idx = match_text.index(">")
            param_name = match_text[:idx]
            # Strip one leading and one trailing newline around the value.
            param_value = str(match_text[idx + 1:]).removeprefix(
                "\n").removesuffix("\n")
            param_dict[param_name] = self._convert_param_value(
                param_value, param_name, param_config, function_name)

        return ToolCall(
            type="function",
            function=FunctionCall(name=function_name,
                                  arguments=json.dumps(param_dict,
                                                       ensure_ascii=False)),
        )

    def _get_function_calls(self, model_output: str) -> List[str]:
        """Return the raw text after each ``<function=`` in ``model_output``.

        Tool calls are located inside ``<tool_call>`` blocks (the last one
        may be unterminated). If there is no such block, the whole output is
        searched for functions.
        """
        matched_ranges = self.tool_call_regex.findall(model_output)
        raw_tool_calls = [
            match[0] if match[0] else match[1] for match in matched_ranges
        ]
        if not raw_tool_calls:
            raw_tool_calls = [model_output]

        raw_function_calls = []
        for tool_call in raw_tool_calls:
            raw_function_calls.extend(
                self.tool_call_function_regex.findall(tool_call))

        return [
            match[0] if match[0] else match[1] for match in raw_function_calls
        ]

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract all tool calls from a complete model response.

        Returns:
            The parsed tool calls, with ``content`` set to the text before
            the first tool call (None if empty). When no ``<function=``
            markup is present, or parsing fails, the whole output is returned
            as content with ``tools_called=False``.
        """
        if self.tool_call_prefix not in model_output:
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

        try:
            function_calls = self._get_function_calls(model_output)
            if not function_calls:
                return ExtractedToolCallInformation(tools_called=False,
                                                    tool_calls=[],
                                                    content=model_output)

            tool_calls = [
                self._parse_xml_function_call(function_call_str,
                                              request.tools)
                for function_call_str in function_calls
            ]

            # Record the parsed calls; "arguments" holds the JSON string.
            self.prev_tool_call_arr.clear()
            self.prev_tool_call_arr.extend({
                "name": tool_call.function.name,
                "arguments": tool_call.function.arguments,
            } for tool_call in tool_calls if tool_call)

            # Content is whatever precedes the first tool call; fall back to
            # the first <function= when <tool_call> itself is absent.
            content_index = model_output.find(self.tool_call_start_token)
            if content_index < 0:
                content_index = model_output.find(self.tool_call_prefix)
            content = model_output[:content_index]

            return ExtractedToolCallInformation(
                tools_called=len(tool_calls) > 0,
                tool_calls=tool_calls,
                content=content if content else None,
            )

        except Exception:
            logger.exception("Error in extracting tool call from response.")
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

    @staticmethod
    def _find_all(text: str, token: str) -> list[int]:
        """Return start offsets of non-overlapping occurrences of ``token``."""
        positions = []
        pos = text.find(token)
        while pos != -1:
            positions.append(pos)
            pos = text.find(token, pos + len(token))
        return positions

    def _current_tool_text(self, current_text: str) -> Optional[str]:
        """Return the text of the tool call at ``current_tool_index``.

        The slice starts at that call's ``<tool_call>`` and runs through its
        ``</tool_call>`` when present, else to the end of ``current_text``.
        Returns None when that tool call has not started yet.
        """
        starts = self._find_all(current_text, self.tool_call_start_token)
        if self.current_tool_index >= len(starts):
            return None
        start = starts[self.current_tool_index]
        end = current_text.find(self.tool_call_end_token, start)
        if end == -1:
            return current_text[start:]
        return current_text[start:end + len(self.tool_call_end_token)]

    def _start_streamed_function(self, tool_text: str) -> bool:
        """Read the ``<function=NAME>`` header of the current tool if complete.

        On success, records the name and a fresh call id, and enters the
        function body. It also stores a placeholder entry for this tool
        position in ``prev_tool_call_arr``. Returns whether the header was
        complete.
        """
        if self.tool_call_prefix not in tool_text:
            return False
        name_start = (tool_text.find(self.tool_call_prefix) +
                      len(self.tool_call_prefix))
        name_end = tool_text.find(">", name_start)
        if name_end == -1:
            return False

        self.current_function_name = tool_text[name_start:name_end]
        self.current_tool_id = self._generate_tool_call_id()
        self.header_sent = True
        self.in_function = True

        # One entry per tool position (not per name), so repeated calls to
        # the same function keep separate arguments.
        while len(self.prev_tool_call_arr) <= self.current_tool_index:
            self.prev_tool_call_arr.append({})
        self.prev_tool_call_arr[self.current_tool_index] = {
            "name": self.current_function_name,
            "arguments": "{}",
        }
        return True

    def _record_final_arguments(self, func_content: str) -> None:
        """Store the fully parsed arguments of the current tool.

        Best-effort: on a parse failure the error is logged at debug level
        and the placeholder arguments recorded with the header are kept.
        """
        tools = (self.streaming_request.tools
                 if self.streaming_request else None)
        try:
            parsed_tool = self._parse_xml_function_call(func_content, tools)
        except Exception:
            logger.debug("Could not parse arguments of streamed tool call %d.",
                         self.current_tool_index,
                         exc_info=True)
            return
        if (parsed_tool is not None
                and self.current_tool_index < len(self.prev_tool_call_arr)):
            self.prev_tool_call_arr[self.current_tool_index][
                "arguments"] = parsed_tool.function.arguments

    def _drain_argument_fragments(self, tool_text: str) -> str:
        """Return every argument fragment of the current tool that is ready.

        Fragments are, in order:
          * the opening ``{`` (once per tool);
          * each parameter whose value is terminated and not yet streamed;
          * the closing ``}`` once the function body has ended.

        State is advanced so no fragment is ever emitted twice. Several
        fragments can become ready in one delta (for example ``</parameter>``
        and ``</function>`` together), so all of them are drained here
        instead of one per call.

        Args:
            tool_text: Text of the current tool call; must already contain
                the complete ``<function=NAME>`` header.

        Returns:
            The concatenated fragments, or ``""`` when nothing is ready.
        """
        fragments: list[str] = []
        if not self.json_started:
            self.json_started = True
            fragments.append("{")

        func_start = (tool_text.find(self.tool_call_prefix) +
                      len(self.tool_call_prefix))
        body_start = tool_text.find(">", func_start) + 1
        # The body ends at </function>. Fall back to </tool_call> so output
        # that omits </function> still closes the arguments object.
        body_end = tool_text.find(self.function_end_token, body_start)
        if body_end == -1:
            body_end = tool_text.find(self.tool_call_end_token, body_start)
        function_done = body_end != -1
        body = (tool_text[body_start:body_end]
                if function_done else tool_text[body_start:])

        param_starts = self._find_all(body, self.parameter_prefix)
        param_config: Optional[dict] = None
        while self.param_count < len(param_starts):
            pname_start = (param_starts[self.param_count] +
                           len(self.parameter_prefix))
            pname_end = body.find(">", pname_start)
            if pname_end == -1:
                break  # parameter name still arriving (or malformed)

            value_text = body[pname_end + 1:].removeprefix("\n")
            # As in the non-streaming regex, a value ends at the earliest of
            # </parameter> or the next <parameter=, or else at the end of a
            # finished function body.
            terminators = [
                pos for pos in (value_text.find(self.parameter_end_token),
                                value_text.find(self.parameter_prefix))
                if pos != -1
            ]
            if terminators:
                value_end = min(terminators)
            elif function_done:
                value_end = len(value_text)
            else:
                break  # value still arriving

            param_name = body[pname_start:pname_end]
            param_value = value_text[:value_end].removesuffix("\n")
            self.current_param_name = param_name
            self.accumulated_params[param_name] = param_value

            if param_config is None:
                param_config = self._get_arguments_config(
                    self.current_function_name,
                    self.streaming_request.tools
                    if self.streaming_request else None)
            converted_value = self._convert_param_value(
                param_value, param_name, param_config,
                self.current_function_name)

            separator = ", " if self.param_count > 0 else ""
            fragments.append(
                f"{separator}{json.dumps(param_name, ensure_ascii=False)}: "
                f"{json.dumps(converted_value, ensure_ascii=False)}")
            self.param_count += 1

        if function_done:
            fragments.append("}")
            self._record_final_arguments(tool_text[func_start:body_end])
            self.in_function = False
            self.json_closed = True
            self.accumulated_params = {}

        return "".join(fragments)

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Incrementally parse tool calls from a streamed response.

        Each call inspects the full ``current_text`` and returns at most one
        DeltaMessage. That is either plain content outside tool calls, or a
        DeltaToolCall for the tool at ``current_tool_index``. The first
        DeltaToolCall of a tool carries its id and name. The argument
        fragments of all its DeltaToolCalls concatenate to one JSON object.

        Returns:
            The DeltaMessage to stream, or None when nothing is ready yet.
        """
        # An empty previous_text marks the first call of a new stream.
        if not previous_text:
            self._reset_streaming_state()
            # Entries are indexed by tool position, so stale ones from an
            # earlier request handled by this instance must not survive.
            self.prev_tool_call_arr.clear()
            self.streaming_request = request

        if not delta_text:
            # Token(s) that produced no text (e.g. EOS).
            if (delta_token_ids
                    and self.tool_call_end_token_id not in delta_token_ids):
                complete_calls = len(
                    self.tool_call_complete_regex.findall(current_text))
                if complete_calls > 0 and len(self.prev_tool_call_arr) > 0:
                    open_calls = (
                        current_text.count(self.tool_call_start_token) -
                        current_text.count(self.tool_call_end_token))
                    if open_calls == 0:
                        # All tool calls are closed: return an empty message.
                        return DeltaMessage(content="")
                elif not self.is_tool_call_started and current_text:
                    return DeltaMessage(content="")
            return None

        self.accumulated_text = current_text

        # The current tool's JSON is closed: once its </tool_call> has
        # arrived, move on to the next tool call.
        if self.json_closed and not self.in_function:
            tool_ends = current_text.count(self.tool_call_end_token)
            if tool_ends > self.current_tool_index:
                self.current_tool_index += 1
                self.header_sent = False
                self.param_count = 0
                self.json_started = False
                self.json_closed = False
                self.accumulated_params = {}
                tool_starts = current_text.count(self.tool_call_start_token)
                if self.current_tool_index >= tool_starts:
                    # No further tool call has begun; back to content mode.
                    self.is_tool_call_started = False
                return None

        if not self.is_tool_call_started:
            if (self.tool_call_start_token_id in delta_token_ids
                    or self.tool_call_start_token in delta_text):
                self.is_tool_call_started = True
                # Flush any content that precedes <tool_call> in this delta.
                if self.tool_call_start_token in delta_text:
                    content_before = delta_text[:delta_text.index(
                        self.tool_call_start_token)]
                    if content_before:
                        return DeltaMessage(content=content_before)
                return None
            # Plain content; drop pure whitespace right after a tool call.
            if (current_text.rstrip().endswith(self.tool_call_end_token)
                    and delta_text.strip() == ""):
                return None
            return DeltaMessage(content=delta_text)

        tool_text = self._current_tool_text(current_text)
        if tool_text is None:
            return None

        header_just_sent = False
        if not self.header_sent:
            if not self._start_streamed_function(tool_text):
                return None
            header_just_sent = True

        if not self.in_function:
            # Arguments already closed; waiting for </tool_call>.
            return None

        arguments = self._drain_argument_fragments(tool_text)

        if header_just_sent:
            return DeltaMessage(tool_calls=[
                DeltaToolCall(
                    index=self.current_tool_index,
                    id=self.current_tool_id,
                    function=DeltaFunctionCall(
                        name=self.current_function_name,
                        arguments=arguments),
                    type="function",
                )
            ])
        if not arguments:
            return None
        return DeltaMessage(tool_calls=[
            DeltaToolCall(
                index=self.current_tool_index,
                function=DeltaFunctionCall(arguments=arguments),
            )
        ])
|