@@ -65,14 +65,15 @@ def __init__(
65
65
self ._message = None
66
66
self ._content_block = {}
67
67
self ._record_message = False
68
+ self ._ended = False
68
69
69
70
def __iter__ (self ):
70
71
try :
71
72
for event in self .__wrapped__ :
72
73
self ._process_event (event )
73
74
yield event
74
75
except EventStreamError as exc :
75
- self ._stream_error_callback (exc )
76
+ self ._handle_stream_error (exc )
76
77
raise
77
78
78
79
def _process_event (self , event ):
@@ -133,15 +134,22 @@ def _process_event(self, event):
133
134
134
135
if output_tokens := usage .get ("outputTokens" ):
135
136
self ._response ["usage" ]["outputTokens" ] = output_tokens
136
-
137
- self ._stream_done_callback (self ._response )
137
+ self ._complete_stream (self ._response )
138
138
139
139
return
140
140
141
141
def close (self ):
142
142
self .__wrapped__ .close ()
143
143
# Treat the stream as done to ensure the span end.
144
- self ._stream_done_callback (self ._response )
144
+ self ._complete_stream (self ._response )
145
+
146
+ def _complete_stream (self , response ):
147
+ self ._stream_done_callback (response , self ._ended )
148
+ self ._ended = True
149
+
150
+ def _handle_stream_error (self , exc ):
151
+ self ._stream_error_callback (exc , self ._ended )
152
+ self ._ended = True
145
153
146
154
147
155
# pylint: disable=abstract-method
@@ -168,19 +176,28 @@ def __init__(
168
176
self ._content_block = {}
169
177
self ._tool_json_input_buf = ""
170
178
self ._record_message = False
179
+ self ._ended = False
171
180
172
181
def close (self ):
173
182
self .__wrapped__ .close ()
174
183
# Treat the stream as done to ensure the span end.
175
- self ._stream_done_callback (self ._response )
184
+ self ._stream_done_callback (self ._response , self ._ended )
185
+
186
+ def _complete_stream (self , response ):
187
+ self ._stream_done_callback (response , self ._ended )
188
+ self ._ended = True
189
+
190
+ def _handle_stream_error (self , exc ):
191
+ self ._stream_error_callback (exc , self ._ended )
192
+ self ._ended = True
176
193
177
194
def __iter__ (self ):
178
195
try :
179
196
for event in self .__wrapped__ :
180
197
self ._process_event (event )
181
198
yield event
182
199
except EventStreamError as exc :
183
- self ._stream_error_callback (exc )
200
+ self ._handle_stream_error (exc )
184
201
raise
185
202
186
203
def _process_event (self , event ):
@@ -223,7 +240,7 @@ def _process_amazon_titan_chunk(self, chunk):
223
240
self ._response ["output" ] = {
224
241
"message" : {"content" : [{"text" : chunk ["outputText" ]}]}
225
242
}
226
- self ._stream_done_callback (self ._response )
243
+ self ._complete_stream (self ._response )
227
244
228
245
def _process_amazon_nova_chunk (self , chunk ):
229
246
# pylint: disable=too-many-branches
@@ -293,7 +310,7 @@ def _process_amazon_nova_chunk(self, chunk):
293
310
if output_tokens := usage .get ("outputTokens" ):
294
311
self ._response ["usage" ]["outputTokens" ] = output_tokens
295
312
296
- self ._stream_done_callback (self ._response )
313
+ self ._complete_stream (self ._response )
297
314
return
298
315
299
316
def _process_anthropic_claude_chunk (self , chunk ):
@@ -365,7 +382,7 @@ def _process_anthropic_claude_chunk(self, chunk):
365
382
self ._record_message = False
366
383
self ._message = None
367
384
368
- self ._stream_done_callback (self ._response )
385
+ self ._complete_stream (self ._response )
369
386
return
370
387
371
388
0 commit comments