LHC88 commited on
Commit
23dd25f
·
1 Parent(s): 2fd1197

openai compatible parallel tool calls

Browse files
Files changed (2) hide show
  1. chat_template_with_tools.jinja +28 -13
  2. vllm_tool_parser.py +126 -0
chat_template_with_tools.jinja CHANGED
@@ -6,21 +6,38 @@
6
  {%- set system_message = default_system_message %}
7
  {%- set loop_messages = messages %}
8
  {%- endif %}
 
 
9
  {% if tools is not none and tools|length > 0 %}
10
- {%- set tool_str = tools|tojson -%}
11
  {%- set tool_instructions =
12
- '\nYou can use multiple tools by responding with a JSON array like this: '
13
- ~ '[{"name": "tool1", "parameters": {"p1": "val1" }}, {"name": "tool2", "parameters": {"p2": "val2", "p3": 23}}] '
14
- ~ ' Your available tools are: '
15
- ~ '[AVAILABLE_TOOLS]'
16
- ~ tool_str
17
- ~ '[/AVAILABLE_TOOLS]\n'
18
  -%}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  {%- set system_message = system_message ~ tool_instructions -%}
20
  {% endif %}
21
- {{- '[SYSTEM_PROMPT]' + system_message + '[/SYSTEM_PROMPT]' }}
22
- {%- for message in loop_messages %}
23
 
 
24
  {%- if message['role'] == 'user' %}
25
  {{- '[INST]' + message['content'] + '[/INST]' }}
26
 
@@ -28,10 +45,8 @@
28
  {{- '[SYSTEM_PROMPT]' + message['content'] + '[/SYSTEM_PROMPT]' }}
29
 
30
  {%- elif message['role'] == 'assistant' %}
31
- {%- if message.tool_calls is defined and message.tool_calls is not none -%}
32
- [TOOL_CALLS]
33
- {{ {"tool_calls": message.tool_calls}|tojson }}
34
- [/TOOL_CALLS]
35
  {%- elif message['content'] is defined and message['content'] is not none -%}
36
  {{- message['content'] + eos_token }}
37
  {%- endif %}
 
6
  {%- set system_message = default_system_message %}
7
  {%- set loop_messages = messages %}
8
  {%- endif %}
9
+ {%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %}
10
+
11
  {% if tools is not none and tools|length > 0 %}
 
12
  {%- set tool_instructions =
13
+ 'Use the available tools appropriately to fulfill your instructions and achieve your goals by returning a JSON array pre-pended with `[TOOL_CALLS]` like this: '
14
+ ~ '[TOOL_CALLS][{"name": "tool1", "arguments": {"p1": "val1"}}, {"name": "tool2", "arguments": {"p2": "val2", "p3": 23}}] '
15
+ ~ 'Your available tools are: '
 
 
 
16
  -%}
17
+
18
+ {# Build out the tool list however you want #}
19
+ {%- set tool_instructions = tool_instructions ~ '[AVAILABLE_TOOLS]' %}
20
+ {%- for tool in tools %}
21
+ {%- if not loop.first %}, {% endif %}
22
+ {%- set tool_definition = tool.function %}
23
+ {%- set tool_instructions = tool_instructions ~ '{"type":"function","function":{' %}
24
+ {%- for key, val in tool_definition.items() if key != "return" %}
25
+ {%- if not loop.first %}, {% endif %}
26
+ {%- if val is string %}
27
+ {%- set tool_instructions = tool_instructions ~ '"' ~ key ~ '":"' ~ val ~ '"' %}
28
+ {%- else %}
29
+ {%- set tool_instructions = tool_instructions ~ '"' ~ key ~ '":' ~ (val|tojson) %}
30
+ {%- endif %}
31
+ {%- endfor %}
32
+ {%- set tool_instructions = tool_instructions ~ '}}' %}
33
+ {%- endfor %}
34
+ {%- set tool_instructions = tool_instructions ~ '[/AVAILABLE_TOOLS]\n' %}
35
+
36
  {%- set system_message = system_message ~ tool_instructions -%}
37
  {% endif %}
38
+ {{ '[SYSTEM_PROMPT]' ~ system_message ~ '[/SYSTEM_PROMPT]' }}
 
39
 
40
+ {%- for message in loop_messages %}
41
  {%- if message['role'] == 'user' %}
42
  {{- '[INST]' + message['content'] + '[/INST]' }}
43
 
 
45
  {{- '[SYSTEM_PROMPT]' + message['content'] + '[/SYSTEM_PROMPT]' }}
46
 
47
  {%- elif message['role'] == 'assistant' %}
48
+ {%- if message.tool_calls is defined and message.tool_calls is not none and message.tool_calls|length > 0 -%}
49
+ [TOOL_CALLS][{{ message.tool_calls|tojson }}]
 
 
50
  {%- elif message['content'] is defined and message['content'] is not none -%}
51
  {{- message['content'] + eos_token }}
52
  {%- endif %}
vllm_tool_parser.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import re
3
+ from vllm.entrypoints.openai.protocol import (
4
+ ExtractedToolCallInformation,
5
+ FunctionCall,
6
+ ToolCall,
7
+ )
8
+ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
9
+ ToolParser,
10
+ ToolParserManager,
11
+ )
12
+ from vllm.logger import init_logger
13
+
14
+ logger = init_logger(__name__)
15
+
16
+
17
+ @ToolParserManager.register_module("mistral_v3_debug")
18
+ class MistralV3DebugToolParser(ToolParser):
19
+ """
20
+ Custom parser for Mistral v3 with detailed logging.
21
+ Ensures OpenAI-compatible tool calls while debugging missing arguments.
22
+ """
23
+
24
+ def extract_tool_calls(
25
+ self, model_output: str, request
26
+ ) -> ExtractedToolCallInformation:
27
+ """
28
+ Extracts tool calls from model output using Mistral's special tokens.
29
+ Accepts multiple calls in a comma-separated list, either with or
30
+ without leading/trailing square brackets.
31
+ """
32
+
33
+ logger.info(f"🔍 Extracting tool calls from model output... | {repr(request)}")
34
+ logger.info(f"Raw model output: {model_output}")
35
+
36
+ try:
37
+ # Find tool calls inside [TOOL_CALLS][ ... ]
38
+ tool_call_match = re.search(
39
+ r"\[TOOL_CALLS\]\[(.*?)\]", model_output, re.DOTALL
40
+ )
41
+ if not tool_call_match:
42
+ logger.warning(
43
+ "⚠️ No valid [TOOL_CALLS] block found. Treating as normal content."
44
+ )
45
+ return ExtractedToolCallInformation(
46
+ tools_called=False, tool_calls=[], content=model_output
47
+ )
48
+
49
+ # Extract JSON snippet from inside [TOOL_CALLS][...]
50
+ tool_call_json = tool_call_match.group(1).strip()
51
+ logger.debug(f"📥 Extracted JSON snippet: {tool_call_json}")
52
+
53
+ # Ensure valid JSON list format
54
+ if not tool_call_json.startswith("["):
55
+ logger.debug("🔧 Wrapping snippet with leading '['")
56
+ tool_call_json = f"[{tool_call_json}"
57
+ if not tool_call_json.endswith("]"):
58
+ logger.debug("🔧 Wrapping snippet with trailing ']'")
59
+ tool_call_json = f"{tool_call_json}]"
60
+
61
+ logger.debug(f"📝 Final JSON to parse: {tool_call_json}")
62
+ tool_call_data = json.loads(tool_call_json)
63
+
64
+ # Ensure we have a list of tool calls
65
+ if isinstance(tool_call_data, dict):
66
+ logger.debug(
67
+ "👀 Detected single tool call dictionary; converting to a list."
68
+ )
69
+ tool_call_data = [tool_call_data]
70
+ elif not isinstance(tool_call_data, list):
71
+ logger.error(
72
+ "🚨 Tool call data is neither a list nor a valid object list. Returning as content."
73
+ )
74
+ return ExtractedToolCallInformation(
75
+ tools_called=False, tool_calls=[], content=model_output
76
+ )
77
+
78
+ tool_calls = []
79
+ for i, tool_item in enumerate(tool_call_data):
80
+ logger.debug(f"🛠️ Processing item {i}: {tool_item}")
81
+
82
+ # Ensure each item is a dict with "name" and "arguments"
83
+ if not isinstance(tool_item, dict):
84
+ logger.error(f"❌ Item {i} is not a JSON object. Skipping.")
85
+ continue
86
+
87
+ name = tool_item.get("name", "unknown_tool")
88
+ args = tool_item.get("arguments", {})
89
+
90
+ # Ensure arguments is a dict
91
+ if not isinstance(args, dict):
92
+ logger.error(
93
+ f"❌ Arguments for tool '{name}' are not a dict. Using empty dict."
94
+ )
95
+ args = {}
96
+
97
+ # Convert arguments to a JSON string (for OpenAI-compatible function calls)
98
+ arguments_json = json.dumps(args, ensure_ascii=False)
99
+ logger.debug(f"✅ Parsed arguments for '{name}': {arguments_json}")
100
+
101
+ # Build a single ToolCall object
102
+ tool_calls.append(
103
+ ToolCall(
104
+ type="function",
105
+ id=f"call_{i}",
106
+ function=FunctionCall(name=name, arguments=arguments_json),
107
+ )
108
+ )
109
+
110
+ logger.info(f"✅ Successfully extracted {len(tool_calls)} tool call(s).")
111
+ # We have recognized tool calls, so set content=None
112
+ return ExtractedToolCallInformation(
113
+ tools_called=True, tool_calls=tool_calls, content=None
114
+ )
115
+
116
+ except json.JSONDecodeError as e:
117
+ logger.error(f"❌ Failed to parse tool calls JSON: {str(e)}")
118
+ return ExtractedToolCallInformation(
119
+ tools_called=False, tool_calls=[], content=model_output
120
+ )
121
+
122
+ except Exception as e:
123
+ logger.exception("🔥 Unexpected error while parsing tool calls.")
124
+ return ExtractedToolCallInformation(
125
+ tools_called=False, tool_calls=[], content=model_output
126
+ )