Skip to content

Commit 2667aad

Browse files
cedonley authored and weilong.yu committed
[Bugfix] Multiple fixes to tool streaming with hermes and mistral (vllm-project#10979)
Signed-off-by: cedonley <[email protected]>
1 parent 394b912 commit 2667aad

File tree

3 files changed

+69
-21
lines changed

3 files changed

+69
-21
lines changed

vllm/entrypoints/openai/serving_chat.py

Lines changed: 14 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -496,21 +496,33 @@ async def chat_completion_stream_generator(
496496

497497
if self._should_check_for_unstreamed_tool_arg_tokens(
498498
delta_message, output) and tool_parser:
499+
latest_delta_len = 0
500+
if ((isinstance(
501+
delta_message.tool_calls[0].function,
502+
DeltaFunctionCall)) and isinstance(
503+
delta_message.tool_calls[0].function.
504+
arguments, str)):
505+
latest_delta_len = len(
506+
delta_message.tool_calls[0].function.
507+
arguments)
508+
499509
# get the expected call based on partial JSON
500510
# parsing which "autocompletes" the JSON
501511
expected_call = json.dumps(
502512
tool_parser.prev_tool_call_arr[index].get(
503-
"arguments", {}))
513+
"arguments", {}),
514+
ensure_ascii=False)
504515

505516
# get what we've streamed so far for arguments
506517
# for the current tool
507518
actual_call = tool_parser.streamed_args_for_tool[
508519
index]
520+
if (latest_delta_len > 0):
521+
actual_call = actual_call[:-latest_delta_len]
509522

510523
# check to see if there's anything left to stream
511524
remaining_call = expected_call.replace(
512525
actual_call, "", 1)
513-
514526
# set that as a delta message
515527
delta_message = DeltaMessage(tool_calls=[
516528
DeltaToolCall(index=index,

vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py

Lines changed: 40 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -91,7 +91,8 @@ def extract_tool_calls(
9191
function=FunctionCall(
9292
name=function_call["name"],
9393
# function call args are JSON but as a string
94-
arguments=json.dumps(function_call["arguments"])))
94+
arguments=json.dumps(function_call["arguments"],
95+
ensure_ascii=False)))
9596
for function_call in raw_function_calls
9697
]
9798

@@ -139,13 +140,26 @@ def extract_tool_calls_streaming(
139140
self.tool_call_start_token_id)
140141
cur_tool_end_count = current_token_ids.count(
141142
self.tool_call_end_token_id)
143+
tool_call_portion = None
144+
text_portion = None
142145

143146
# case: if we're generating text, OR rounding out a tool call
144147
if (cur_tool_start_count == cur_tool_end_count
145-
and prev_tool_end_count == cur_tool_end_count):
148+
and prev_tool_end_count == cur_tool_end_count
149+
and self.tool_call_end_token not in delta_text):
146150
logger.debug("Generating text content! skipping tool parsing.")
147-
if delta_text != self.tool_call_end_token:
148-
return DeltaMessage(content=delta_text)
151+
return DeltaMessage(content=delta_text)
152+
153+
if self.tool_call_end_token in delta_text:
154+
logger.debug("tool_call_end_token in delta_text")
155+
full_text = current_text + delta_text
156+
tool_call_portion = full_text.split(
157+
self.tool_call_start_token)[-1].split(
158+
self.tool_call_end_token)[0].rstrip()
159+
delta_text = delta_text.split(
160+
self.tool_call_end_token)[0].rstrip()
161+
text_portion = delta_text.split(
162+
self.tool_call_end_token)[-1].lstrip()
149163

150164
# case: if tool open & close tag counts don't match, we're doing
151165
# imaginary "else" block here
@@ -184,15 +198,21 @@ def extract_tool_calls_streaming(
184198

185199
# case -- the current tool call is being closed.
186200
elif (cur_tool_start_count == cur_tool_end_count
187-
and cur_tool_end_count > prev_tool_end_count):
201+
and cur_tool_end_count >= prev_tool_end_count):
202+
if (self.prev_tool_call_arr is None
203+
or len(self.prev_tool_call_arr) == 0):
204+
logger.debug(
205+
"attempting to close tool call, but no tool call")
206+
return None
188207
diff = self.prev_tool_call_arr[self.current_tool_id].get(
189208
"arguments")
190209
if diff:
191210
diff = diff.encode('utf-8').decode(
192211
'unicode_escape') if diff is str else diff
193-
diff = json.dumps(
194-
diff, ensure_ascii=False
195-
)[len(self.streamed_args_for_tool[self.current_tool_id]):]
212+
if ('"}' not in delta_text):
213+
return None
214+
end_loc = delta_text.rindex('"}')
215+
diff = delta_text[:end_loc] + '"}'
196216
logger.debug(
197217
"Finishing tool and found diff that had not "
198218
"been streamed yet: %s", diff)
@@ -221,10 +241,15 @@ def extract_tool_calls_streaming(
221241
except partial_json_parser.core.exceptions.MalformedJSON:
222242
logger.debug('not enough tokens to parse into JSON yet')
223243
return None
244+
except json.decoder.JSONDecodeError:
245+
logger.debug("unable to parse JSON")
246+
return None
224247

225248
# case - we haven't sent the tool name yet. If it's available, send
226249
# it. otherwise, wait until it's available.
227250
if not self.current_tool_name_sent:
251+
if (current_tool_call is None):
252+
return None
228253
function_name: Union[str, None] = current_tool_call.get("name")
229254
if function_name:
230255
self.current_tool_name_sent = True
@@ -284,13 +309,17 @@ def extract_tool_calls_streaming(
284309
# autocompleting the JSON
285310
elif cur_arguments and not prev_arguments:
286311

287-
cur_arguments_json = json.dumps(cur_arguments)
312+
cur_arguments_json = json.dumps(cur_arguments,
313+
ensure_ascii=False)
288314
logger.debug("finding %s in %s", delta_text,
289315
cur_arguments_json)
290316

291317
# get the location where previous args differ from current
292-
args_delta_start_loc = cur_arguments_json.index(delta_text) \
293-
+ len(delta_text)
318+
if (delta_text not in cur_arguments_json[:-2]):
319+
return None
320+
args_delta_start_loc = cur_arguments_json[:-2]. \
321+
rindex(delta_text) + \
322+
len(delta_text)
294323

295324
# use that to find the actual delta
296325
arguments_delta = cur_arguments_json[:args_delta_start_loc]

vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py

Lines changed: 15 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,6 @@
1919
extract_intermediate_diff)
2020
from vllm.logger import init_logger
2121
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
22-
from vllm.utils import random_uuid
2322

2423
logger = init_logger(__name__)
2524

@@ -109,7 +108,8 @@ def extract_tool_calls(
109108
function=FunctionCall(
110109
name=raw_function_call["name"],
111110
# function call args are JSON but as a string
112-
arguments=json.dumps(raw_function_call["arguments"])))
111+
arguments=json.dumps(raw_function_call["arguments"],
112+
ensure_ascii=False)))
113113
for raw_function_call in function_call_arr
114114
]
115115

@@ -199,7 +199,7 @@ def extract_tool_calls_streaming(
199199
diff: Union[str, None] = current_tool_call.get("arguments")
200200

201201
if diff:
202-
diff = json.dumps(diff).replace(
202+
diff = json.dumps(diff, ensure_ascii=False).replace(
203203
self.streamed_args_for_tool[self.current_tool_id],
204204
"")
205205
delta = DeltaMessage(tool_calls=[
@@ -232,7 +232,7 @@ def extract_tool_calls_streaming(
232232
delta = DeltaMessage(tool_calls=[
233233
DeltaToolCall(index=self.current_tool_id,
234234
type="function",
235-
id=f"chatcmpl-tool-{random_uuid()}",
235+
id=MistralToolCall.generate_random_id(),
236236
function=DeltaFunctionCall(
237237
name=function_name).model_dump(
238238
exclude_none=True))
@@ -250,6 +250,8 @@ def extract_tool_calls_streaming(
250250
cur_arguments = current_tool_call.get("arguments")
251251

252252
new_text = delta_text.replace("\'", "\"")
253+
if ('"}' in new_text):
254+
new_text = new_text[:new_text.rindex('"}')]
253255

254256
if not cur_arguments and not prev_arguments:
255257

@@ -260,12 +262,15 @@ def extract_tool_calls_streaming(
260262
"mid-arguments")
261263
delta = None
262264
elif cur_arguments and not prev_arguments:
263-
cur_arguments_json = json.dumps(cur_arguments)
265+
cur_arguments_json = json.dumps(cur_arguments,
266+
ensure_ascii=False)[:-2]
264267
logger.debug("finding %s in %s", new_text,
265268
cur_arguments_json)
266269

270+
if (new_text not in cur_arguments_json):
271+
return None
267272
arguments_delta = cur_arguments_json[:cur_arguments_json.
268-
index(new_text) +
273+
rindex(new_text) +
269274
len(new_text)]
270275
logger.debug("First tokens in arguments received: %s",
271276
arguments_delta)
@@ -279,8 +284,10 @@ def extract_tool_calls_streaming(
279284
self.current_tool_id] += arguments_delta
280285

281286
elif cur_arguments and prev_arguments:
282-
cur_args_json = json.dumps(cur_arguments)
283-
prev_args_json = json.dumps(prev_arguments)
287+
cur_args_json = json.dumps(cur_arguments,
288+
ensure_ascii=False)
289+
prev_args_json = json.dumps(prev_arguments,
290+
ensure_ascii=False)
284291
logger.debug("Searching for diff between \n%s\n%s",
285292
cur_args_json, prev_args_json)
286293

0 commit comments

Comments
 (0)