@@ -91,7 +91,8 @@ def extract_tool_calls(
9191 function = FunctionCall (
9292 name = function_call ["name" ],
9393 # function call args are JSON but as a string
94- arguments = json .dumps (function_call ["arguments" ])))
94+ arguments = json .dumps (function_call ["arguments" ],
95+ ensure_ascii = False )))
9596 for function_call in raw_function_calls
9697 ]
9798
@@ -139,13 +140,26 @@ def extract_tool_calls_streaming(
139140 self .tool_call_start_token_id )
140141 cur_tool_end_count = current_token_ids .count (
141142 self .tool_call_end_token_id )
143+ tool_call_portion = None
144+ text_portion = None
142145
143146 # case: if we're generating text, OR rounding out a tool call
144147 if (cur_tool_start_count == cur_tool_end_count
145- and prev_tool_end_count == cur_tool_end_count ):
148+ and prev_tool_end_count == cur_tool_end_count
149+ and self .tool_call_end_token not in delta_text ):
146150 logger .debug ("Generating text content! skipping tool parsing." )
147- if delta_text != self .tool_call_end_token :
148- return DeltaMessage (content = delta_text )
151+ return DeltaMessage (content = delta_text )
152+
153+ if self .tool_call_end_token in delta_text :
154+ logger .debug ("tool_call_end_token in delta_text" )
155+ full_text = current_text + delta_text
156+ tool_call_portion = full_text .split (
157+ self .tool_call_start_token )[- 1 ].split (
158+ self .tool_call_end_token )[0 ].rstrip ()
159+ delta_text = delta_text .split (
160+ self .tool_call_end_token )[0 ].rstrip ()
161+ text_portion = delta_text .split (
162+ self .tool_call_end_token )[- 1 ].lstrip ()
149163
150164 # case: if tool open & close tag counts don't match, we're doing
151165 # imaginary "else" block here
@@ -184,15 +198,21 @@ def extract_tool_calls_streaming(
184198
185199 # case -- the current tool call is being closed.
186200 elif (cur_tool_start_count == cur_tool_end_count
187- and cur_tool_end_count > prev_tool_end_count ):
201+ and cur_tool_end_count >= prev_tool_end_count ):
202+ if (self .prev_tool_call_arr is None
203+ or len (self .prev_tool_call_arr ) == 0 ):
204+ logger .debug (
205+ "attempting to close tool call, but no tool call" )
206+ return None
188207 diff = self .prev_tool_call_arr [self .current_tool_id ].get (
189208 "arguments" )
190209 if diff :
191210 diff = diff .encode ('utf-8' ).decode (
192211 'unicode_escape' ) if diff is str else diff
193- diff = json .dumps (
194- diff , ensure_ascii = False
195- )[len (self .streamed_args_for_tool [self .current_tool_id ]):]
212+ if ('"}' not in delta_text ):
213+ return None
214+ end_loc = delta_text .rindex ('"}' )
215+ diff = delta_text [:end_loc ] + '"}'
196216 logger .debug (
197217 "Finishing tool and found diff that had not "
198218 "been streamed yet: %s" , diff )
@@ -221,10 +241,15 @@ def extract_tool_calls_streaming(
221241 except partial_json_parser .core .exceptions .MalformedJSON :
222242 logger .debug ('not enough tokens to parse into JSON yet' )
223243 return None
244+ except json .decoder .JSONDecodeError :
245+ logger .debug ("unable to parse JSON" )
246+ return None
224247
225248 # case - we haven't sent the tool name yet. If it's available, send
226249 # it. otherwise, wait until it's available.
227250 if not self .current_tool_name_sent :
251+ if (current_tool_call is None ):
252+ return None
228253 function_name : Union [str , None ] = current_tool_call .get ("name" )
229254 if function_name :
230255 self .current_tool_name_sent = True
@@ -284,13 +309,17 @@ def extract_tool_calls_streaming(
284309 # autocompleting the JSON
285310 elif cur_arguments and not prev_arguments :
286311
287- cur_arguments_json = json .dumps (cur_arguments )
312+ cur_arguments_json = json .dumps (cur_arguments ,
313+ ensure_ascii = False )
288314 logger .debug ("finding %s in %s" , delta_text ,
289315 cur_arguments_json )
290316
291317 # get the location where previous args differ from current
292- args_delta_start_loc = cur_arguments_json .index (delta_text ) \
293- + len (delta_text )
318+ if (delta_text not in cur_arguments_json [:- 2 ]):
319+ return None
320+ args_delta_start_loc = cur_arguments_json [:- 2 ]. \
321+ rindex (delta_text ) + \
322+ len (delta_text )
294323
295324 # use that to find the actual delta
296325 arguments_delta = cur_arguments_json [:args_delta_start_loc ]
0 commit comments