openai compatible parallel tool calls
Browse files- chat_template_with_tools.jinja +28 -13
- vllm_tool_parser.py +126 -0
chat_template_with_tools.jinja
CHANGED
@@ -6,21 +6,38 @@
|
|
6 |
{%- set system_message = default_system_message %}
|
7 |
{%- set loop_messages = messages %}
|
8 |
{%- endif %}
|
|
|
|
|
9 |
{% if tools is not none and tools|length > 0 %}
|
10 |
-
{%- set tool_str = tools|tojson -%}
|
11 |
{%- set tool_instructions =
|
12 |
-
'
|
13 |
-
~ '[{"name": "tool1", "
|
14 |
-
~ '
|
15 |
-
~ '[AVAILABLE_TOOLS]'
|
16 |
-
~ tool_str
|
17 |
-
~ '[/AVAILABLE_TOOLS]\n'
|
18 |
-%}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
{%- set system_message = system_message ~ tool_instructions -%}
|
20 |
{% endif %}
|
21 |
-
{{
|
22 |
-
{%- for message in loop_messages %}
|
23 |
|
|
|
24 |
{%- if message['role'] == 'user' %}
|
25 |
{{- '[INST]' + message['content'] + '[/INST]' }}
|
26 |
|
@@ -28,10 +45,8 @@
|
|
28 |
{{- '[SYSTEM_PROMPT]' + message['content'] + '[/SYSTEM_PROMPT]' }}
|
29 |
|
30 |
{%- elif message['role'] == 'assistant' %}
|
31 |
-
{%- if message.tool_calls is defined and message.tool_calls is not none -%}
|
32 |
-
[TOOL_CALLS]
|
33 |
-
{{ {"tool_calls": message.tool_calls}|tojson }}
|
34 |
-
[/TOOL_CALLS]
|
35 |
{%- elif message['content'] is defined and message['content'] is not none -%}
|
36 |
{{- message['content'] + eos_token }}
|
37 |
{%- endif %}
|
|
|
6 |
{%- set system_message = default_system_message %}
|
7 |
{%- set loop_messages = messages %}
|
8 |
{%- endif %}
|
9 |
+
{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %}
|
10 |
+
|
11 |
{% if tools is not none and tools|length > 0 %}
|
|
|
12 |
{%- set tool_instructions =
|
13 |
+
'Use the available tools appropriately to fulfill your instructions and achieve your goals by returning a JSON array pre-pended with `[TOOL_CALLS]` like this: '
|
14 |
+
~ '[TOOL_CALLS][{"name": "tool1", "arguments": {"p1": "val1"}}, {"name": "tool2", "arguments": {"p2": "val2", "p3": 23}}] '
|
15 |
+
~ 'Your available tools are: '
|
|
|
|
|
|
|
16 |
-%}
|
17 |
+
|
18 |
+
{# Build out the tool list however you want #}
|
19 |
+
{%- set tool_instructions = tool_instructions ~ '[AVAILABLE_TOOLS]' %}
|
20 |
+
{%- for tool in tools %}
|
21 |
+
{%- if not loop.first %}, {% endif %}
|
22 |
+
{%- set tool_definition = tool.function %}
|
23 |
+
{%- set tool_instructions = tool_instructions ~ '{"type":"function","function":{' %}
|
24 |
+
{%- for key, val in tool_definition.items() if key != "return" %}
|
25 |
+
{%- if not loop.first %}, {% endif %}
|
26 |
+
{%- if val is string %}
|
27 |
+
{%- set tool_instructions = tool_instructions ~ '"' ~ key ~ '":"' ~ val ~ '"' %}
|
28 |
+
{%- else %}
|
29 |
+
{%- set tool_instructions = tool_instructions ~ '"' ~ key ~ '":' ~ (val|tojson) %}
|
30 |
+
{%- endif %}
|
31 |
+
{%- endfor %}
|
32 |
+
{%- set tool_instructions = tool_instructions ~ '}}' %}
|
33 |
+
{%- endfor %}
|
34 |
+
{%- set tool_instructions = tool_instructions ~ '[/AVAILABLE_TOOLS]\n' %}
|
35 |
+
|
36 |
{%- set system_message = system_message ~ tool_instructions -%}
|
37 |
{% endif %}
|
38 |
+
{{ '[SYSTEM_PROMPT]' ~ system_message ~ '[/SYSTEM_PROMPT]' }}
|
|
|
39 |
|
40 |
+
{%- for message in loop_messages %}
|
41 |
{%- if message['role'] == 'user' %}
|
42 |
{{- '[INST]' + message['content'] + '[/INST]' }}
|
43 |
|
|
|
45 |
{{- '[SYSTEM_PROMPT]' + message['content'] + '[/SYSTEM_PROMPT]' }}
|
46 |
|
47 |
{%- elif message['role'] == 'assistant' %}
|
48 |
+
{%- if message.tool_calls is defined and message.tool_calls is not none and message.tool_calls|length > 0 -%}
|
49 |
+
[TOOL_CALLS][{{ message.tool_calls|tojson }}]
|
|
|
|
|
50 |
{%- elif message['content'] is defined and message['content'] is not none -%}
|
51 |
{{- message['content'] + eos_token }}
|
52 |
{%- endif %}
|
vllm_tool_parser.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import re
|
3 |
+
from vllm.entrypoints.openai.protocol import (
|
4 |
+
ExtractedToolCallInformation,
|
5 |
+
FunctionCall,
|
6 |
+
ToolCall,
|
7 |
+
)
|
8 |
+
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
9 |
+
ToolParser,
|
10 |
+
ToolParserManager,
|
11 |
+
)
|
12 |
+
from vllm.logger import init_logger
|
13 |
+
|
14 |
+
logger = init_logger(__name__)
|
15 |
+
|
16 |
+
|
17 |
+
@ToolParserManager.register_module("mistral_v3_debug")
|
18 |
+
class MistralV3DebugToolParser(ToolParser):
|
19 |
+
"""
|
20 |
+
Custom parser for Mistral v3 with detailed logging.
|
21 |
+
Ensures OpenAI-compatible tool calls while debugging missing arguments.
|
22 |
+
"""
|
23 |
+
|
24 |
+
def extract_tool_calls(
|
25 |
+
self, model_output: str, request
|
26 |
+
) -> ExtractedToolCallInformation:
|
27 |
+
"""
|
28 |
+
Extracts tool calls from model output using Mistral's special tokens.
|
29 |
+
Accepts multiple calls in a comma-separated list, either with or
|
30 |
+
without leading/trailing square brackets.
|
31 |
+
"""
|
32 |
+
|
33 |
+
logger.info(f"🔍 Extracting tool calls from model output... | {repr(request)}")
|
34 |
+
logger.info(f"Raw model output: {model_output}")
|
35 |
+
|
36 |
+
try:
|
37 |
+
# Find tool calls inside [TOOL_CALLS][ ... ]
|
38 |
+
tool_call_match = re.search(
|
39 |
+
r"\[TOOL_CALLS\]\[(.*?)\]", model_output, re.DOTALL
|
40 |
+
)
|
41 |
+
if not tool_call_match:
|
42 |
+
logger.warning(
|
43 |
+
"⚠️ No valid [TOOL_CALLS] block found. Treating as normal content."
|
44 |
+
)
|
45 |
+
return ExtractedToolCallInformation(
|
46 |
+
tools_called=False, tool_calls=[], content=model_output
|
47 |
+
)
|
48 |
+
|
49 |
+
# Extract JSON snippet from inside [TOOL_CALLS][...]
|
50 |
+
tool_call_json = tool_call_match.group(1).strip()
|
51 |
+
logger.debug(f"📥 Extracted JSON snippet: {tool_call_json}")
|
52 |
+
|
53 |
+
# Ensure valid JSON list format
|
54 |
+
if not tool_call_json.startswith("["):
|
55 |
+
logger.debug("🔧 Wrapping snippet with leading '['")
|
56 |
+
tool_call_json = f"[{tool_call_json}"
|
57 |
+
if not tool_call_json.endswith("]"):
|
58 |
+
logger.debug("🔧 Wrapping snippet with trailing ']'")
|
59 |
+
tool_call_json = f"{tool_call_json}]"
|
60 |
+
|
61 |
+
logger.debug(f"📝 Final JSON to parse: {tool_call_json}")
|
62 |
+
tool_call_data = json.loads(tool_call_json)
|
63 |
+
|
64 |
+
# Ensure we have a list of tool calls
|
65 |
+
if isinstance(tool_call_data, dict):
|
66 |
+
logger.debug(
|
67 |
+
"👀 Detected single tool call dictionary; converting to a list."
|
68 |
+
)
|
69 |
+
tool_call_data = [tool_call_data]
|
70 |
+
elif not isinstance(tool_call_data, list):
|
71 |
+
logger.error(
|
72 |
+
"🚨 Tool call data is neither a list nor a valid object list. Returning as content."
|
73 |
+
)
|
74 |
+
return ExtractedToolCallInformation(
|
75 |
+
tools_called=False, tool_calls=[], content=model_output
|
76 |
+
)
|
77 |
+
|
78 |
+
tool_calls = []
|
79 |
+
for i, tool_item in enumerate(tool_call_data):
|
80 |
+
logger.debug(f"🛠️ Processing item {i}: {tool_item}")
|
81 |
+
|
82 |
+
# Ensure each item is a dict with "name" and "arguments"
|
83 |
+
if not isinstance(tool_item, dict):
|
84 |
+
logger.error(f"❌ Item {i} is not a JSON object. Skipping.")
|
85 |
+
continue
|
86 |
+
|
87 |
+
name = tool_item.get("name", "unknown_tool")
|
88 |
+
args = tool_item.get("arguments", {})
|
89 |
+
|
90 |
+
# Ensure arguments is a dict
|
91 |
+
if not isinstance(args, dict):
|
92 |
+
logger.error(
|
93 |
+
f"❌ Arguments for tool '{name}' are not a dict. Using empty dict."
|
94 |
+
)
|
95 |
+
args = {}
|
96 |
+
|
97 |
+
# Convert arguments to a JSON string (for OpenAI-compatible function calls)
|
98 |
+
arguments_json = json.dumps(args, ensure_ascii=False)
|
99 |
+
logger.debug(f"✅ Parsed arguments for '{name}': {arguments_json}")
|
100 |
+
|
101 |
+
# Build a single ToolCall object
|
102 |
+
tool_calls.append(
|
103 |
+
ToolCall(
|
104 |
+
type="function",
|
105 |
+
id=f"call_{i}",
|
106 |
+
function=FunctionCall(name=name, arguments=arguments_json),
|
107 |
+
)
|
108 |
+
)
|
109 |
+
|
110 |
+
logger.info(f"✅ Successfully extracted {len(tool_calls)} tool call(s).")
|
111 |
+
# We have recognized tool calls, so set content=None
|
112 |
+
return ExtractedToolCallInformation(
|
113 |
+
tools_called=True, tool_calls=tool_calls, content=None
|
114 |
+
)
|
115 |
+
|
116 |
+
except json.JSONDecodeError as e:
|
117 |
+
logger.error(f"❌ Failed to parse tool calls JSON: {str(e)}")
|
118 |
+
return ExtractedToolCallInformation(
|
119 |
+
tools_called=False, tool_calls=[], content=model_output
|
120 |
+
)
|
121 |
+
|
122 |
+
except Exception as e:
|
123 |
+
logger.exception("🔥 Unexpected error while parsing tool calls.")
|
124 |
+
return ExtractedToolCallInformation(
|
125 |
+
tools_called=False, tool_calls=[], content=model_output
|
126 |
+
)
|