diff --git a/src/openai/_streaming.py b/src/openai/_streaming.py index cee737f4f5..095746630b 100644 --- a/src/openai/_streaming.py +++ b/src/openai/_streaming.py @@ -47,8 +47,9 @@ def __stream__(self) -> Iterator[ResponseT]: cast_to = self._cast_to response = self.response process_data = self._client._process_response_data + iterator = self._iter_events() - for sse in self._iter_events(): + for sse in iterator: if sse.data.startswith("[DONE]"): break @@ -63,6 +64,10 @@ def __stream__(self) -> Iterator[ResponseT]: yield process_data(data=data, cast_to=cast_to, response=response) + # Ensure the entire stream is consumed + for sse in iterator: + ... + class AsyncStream(Generic[ResponseT]): """Provides the core interface to iterate over an asynchronous stream response.""" @@ -97,8 +102,9 @@ async def __stream__(self) -> AsyncIterator[ResponseT]: cast_to = self._cast_to response = self.response process_data = self._client._process_response_data + iterator = self._iter_events() - async for sse in self._iter_events(): + async for sse in iterator: if sse.data.startswith("[DONE]"): break @@ -113,6 +119,10 @@ async def __stream__(self) -> AsyncIterator[ResponseT]: yield process_data(data=data, cast_to=cast_to, response=response) + # Ensure the entire stream is consumed + async for sse in iterator: + ... + class ServerSentEvent: def __init__(