Skip to content

Commit 77d3284

Browse files
authored
Merge pull request #24 from WorkflowAI/guillaume/fix-issue-with-long-stream-chunks
Fix issue with very long streams
2 parents 0dbabf9 + e1af9ba commit 77d3284

File tree

3 files changed

+134
-31
lines changed

3 files changed

+134
-31
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "workflowai"
3-
version = "0.5.0"
3+
version = "0.5.1"
44
description = ""
55
authors = ["Guillaume Aquilina <[email protected]>"]
66
readme = "README.md"

workflowai/core/client/_api.py

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
1+
import logging
12
from typing import Any, AsyncIterator, Literal, Optional, TypeVar, Union, overload
23

34
import httpx
45
from pydantic import BaseModel, TypeAdapter, ValidationError
56

6-
from workflowai.core.client._utils import split_chunks
77
from workflowai.core.domain.errors import BaseError, ErrorResponse, WorkflowAIError
88

99
# A type for return values
1010
_R = TypeVar("_R")
1111
_M = TypeVar("_M", bound=BaseModel)
1212

13+
_logger = logging.getLogger("WorkflowAI")
14+
1315

1416
class APIClient:
1517
def __init__(self, endpoint: str, api_key: str, source_headers: Optional[dict[str, str]] = None):
@@ -106,6 +108,37 @@ def _extract_error(
106108
response=response,
107109
) from exception
108110

111+
async def _wrap_sse(self, raw: AsyncIterator[bytes], termination_chars: bytes = b"\n\n"):
112+
data = b""
113+
in_data = False
114+
async for chunk in raw:
115+
data += chunk
116+
if not in_data:
117+
if data.startswith(b"data: "):
118+
data = data[6:]
119+
in_data = True
120+
else:
121+
# We will wait for the next chunk, we might be in the middle
122+
# of 'data: '
123+
continue
124+
125+
# Splitting the chunk by separator
126+
splits = data.split(b"\n\ndata: ")
127+
if len(splits) > 1:
128+
# Yielding the rest of the splits except the last one
129+
for data in splits[0:-1]:
130+
yield data
131+
# The last split could be incomplete
132+
data = splits[-1]
133+
134+
if data.endswith(termination_chars):
135+
yield data[: -len(termination_chars)]
136+
data = b""
137+
in_data = False
138+
139+
if data:
140+
_logger.warning("Data left after processing", extra={"data": data})
141+
109142
async def stream(
110143
self,
111144
method: Literal["GET", "POST"],
@@ -122,15 +155,14 @@ async def stream(
122155
if not response.is_success:
123156
# We need to read the response to get the error message
124157
await response.aread()
125-
response.raise_for_status()
158+
await self.raise_for_status(response)
159+
return
126160

127-
async for chunk in response.aiter_bytes():
128-
payload = ""
161+
async for chunk in self._wrap_sse(response.aiter_bytes()):
129162
try:
130-
for payload in split_chunks(chunk):
131-
yield returns.model_validate_json(payload)
163+
yield returns.model_validate_json(chunk)
132164
except ValidationError as e:
133-
raise self._extract_error(response, payload, e) from None
165+
raise self._extract_error(response, chunk, e) from None
134166

135167
async def raise_for_status(self, response: httpx.Response):
136168
if response.status_code < 200 or response.status_code >= 300:

workflowai/core/client/_api_test.py

Lines changed: 94 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
from collections.abc import Awaitable, Callable
2+
13
import httpx
24
import pytest
35
from pydantic import BaseModel
4-
from pytest_httpx import HTTPXMock
6+
from pytest_httpx import HTTPXMock, IteratorStream
57

68
from workflowai.core.client._api import APIClient
79
from workflowai.core.domain.errors import WorkflowAIError
@@ -60,26 +62,95 @@ def test_extract_error_with_custom_error(self):
6062
assert e.value.response == response
6163

6264

63-
async def test_stream_404(httpx_mock: HTTPXMock):
64-
class TestInputModel(BaseModel):
65-
test_input: str
66-
67-
class TestOutputModel(BaseModel):
68-
test_output: str
69-
70-
httpx_mock.add_response(status_code=404)
71-
72-
client = APIClient(endpoint="https://blabla.com", api_key="test_api_key")
65+
@pytest.fixture
def client() -> APIClient:
    """An APIClient pointed at a dummy endpoint, shared by the tests below."""
    test_client = APIClient(endpoint="https://blabla.com", api_key="test_api_key")
    return test_client
68+
69+
70+
class TestInputModel(BaseModel):
    """Minimal request payload; the streaming tests don't depend on its content."""

    bla: str = "bla"
72+
73+
74+
class TestOutputModel(BaseModel):
    """Shape of a single streamed payload parsed by the tests below."""

    a: str
76+
77+
78+
class TestAPIClientStream:
    """Tests for APIClient.stream, including SSE chunk re-assembly."""

    async def test_stream_404(self, httpx_mock: HTTPXMock, client: APIClient):
        class TestInputModel(BaseModel):
            test_input: str

        class TestOutputModel(BaseModel):
            test_output: str

        httpx_mock.add_response(status_code=404)

        # A non-success status must surface as a WorkflowAIError rather than
        # a raw httpx error.
        with pytest.raises(WorkflowAIError) as e:  # noqa: PT012
            async for _ in client.stream(
                method="GET",
                path="test_path",
                data=TestInputModel(test_input="test"),
                returns=TestOutputModel,
            ):
                pass

        assert e.value.response
        assert e.value.response.status_code == 404
        assert e.value.response.reason_phrase == "Not Found"

    @pytest.fixture
    async def stream_fn(self, client: APIClient):
        # Returns a callable that drains the stream and collects every
        # parsed payload, so each test only has to mock the response.
        async def _stm():
            collected: list[TestOutputModel] = []
            async for chunk in client.stream(
                method="GET",
                path="test_path",
                data=TestInputModel(),
                returns=TestOutputModel,
            ):
                collected.append(chunk)
            return collected

        return _stm

    async def test_stream_with_single_chunk(
        self,
        stream_fn: Callable[[], Awaitable[list[TestOutputModel]]],
        httpx_mock: HTTPXMock,
    ):
        httpx_mock.add_response(
            stream=IteratorStream([b'data: {"a":"test"}\n\n']),
        )

        chunks = await stream_fn()
        assert chunks == [TestOutputModel(a="test")]

    @pytest.mark.parametrize(
        "streamed_chunks",
        [
            # 2 well-formed chunks
            [b'data: {"a":"test"}\n\n', b'data: {"a":"test2"}\n\n'],
            # 2 events packed into one chunk
            [b'data: {"a":"test"}\n\ndata: {"a":"test2"}\n\n'],
            # Split not at the end
            [b'data: {"a":"test"}', b'\n\ndata: {"a":"test2"}\n\n'],
            # Really messy
            [b"dat", b'a: {"a":"', b'test"}', b"\n", b"\ndata", b': {"a":"test2"}\n\n'],
        ],
    )
    async def test_stream_with_multiple_chunks(
        self,
        stream_fn: Callable[[], Awaitable[list[TestOutputModel]]],
        httpx_mock: HTTPXMock,
        streamed_chunks: list[bytes],
    ):
        assert isinstance(streamed_chunks, list), "sanity check"
        assert all(isinstance(chunk, bytes) for chunk in streamed_chunks), "sanity check"

        httpx_mock.add_response(stream=IteratorStream(streamed_chunks))
        chunks = await stream_fn()
        assert chunks == [TestOutputModel(a="test"), TestOutputModel(a="test2")]

0 commit comments

Comments
 (0)