Skip to content

Fix issue with very long streams #24

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "workflowai"
version = "0.5.0"
version = "0.5.1"
description = ""
authors = ["Guillaume Aquilina <[email protected]>"]
readme = "README.md"
Expand Down
46 changes: 39 additions & 7 deletions workflowai/core/client/_api.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import logging
from typing import Any, AsyncIterator, Literal, Optional, TypeVar, Union, overload

import httpx
from pydantic import BaseModel, TypeAdapter, ValidationError

from workflowai.core.client._utils import split_chunks
from workflowai.core.domain.errors import BaseError, ErrorResponse, WorkflowAIError

# A type for return values
_R = TypeVar("_R")
_M = TypeVar("_M", bound=BaseModel)

_logger = logging.getLogger("WorkflowAI")


class APIClient:
def __init__(self, endpoint: str, api_key: str, source_headers: Optional[dict[str, str]] = None):
Expand Down Expand Up @@ -106,6 +108,37 @@ def _extract_error(
response=response,
) from exception

async def _wrap_sse(self, raw: AsyncIterator[bytes], termination_chars: bytes = b"\n\n"):
    """Re-assemble Server-Sent-Event payloads from an arbitrarily fragmented byte stream.

    Chunks from `raw` may split an event at any byte boundary, so incoming
    bytes are accumulated in a buffer until complete events are available.
    Yields the payload bytes of each `data: ` event, without the prefix and
    without the terminating separator. Any bytes left unconsumed when the
    stream ends are logged as a warning.
    """
    buffer = b""
    inside_event = False
    async for fragment in raw:
        buffer += fragment
        if not inside_event:
            if not buffer.startswith(b"data: "):
                # We may only have part of the 'data: ' prefix so far;
                # wait for the next fragment before deciding.
                continue
            buffer = buffer[len(b"data: "):]
            inside_event = True

        # Every payload followed by the start of a new event is complete.
        *complete, buffer = buffer.split(b"\n\ndata: ")
        for payload in complete:
            yield payload

        # The remaining buffer may itself be a fully terminated payload.
        if buffer.endswith(termination_chars):
            yield buffer[: -len(termination_chars)]
            buffer = b""
            inside_event = False

    if buffer:
        _logger.warning("Data left after processing", extra={"data": buffer})

async def stream(
self,
method: Literal["GET", "POST"],
Expand All @@ -122,15 +155,14 @@ async def stream(
if not response.is_success:
# We need to read the response to get the error message
await response.aread()
response.raise_for_status()
await self.raise_for_status(response)
return

async for chunk in response.aiter_bytes():
payload = ""
async for chunk in self._wrap_sse(response.aiter_bytes()):
try:
for payload in split_chunks(chunk):
yield returns.model_validate_json(payload)
yield returns.model_validate_json(chunk)
except ValidationError as e:
raise self._extract_error(response, payload, e) from None
raise self._extract_error(response, chunk, e) from None

async def raise_for_status(self, response: httpx.Response):
if response.status_code < 200 or response.status_code >= 300:
Expand Down
117 changes: 94 additions & 23 deletions workflowai/core/client/_api_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from collections.abc import Awaitable, Callable

import httpx
import pytest
from pydantic import BaseModel
from pytest_httpx import HTTPXMock
from pytest_httpx import HTTPXMock, IteratorStream

from workflowai.core.client._api import APIClient
from workflowai.core.domain.errors import WorkflowAIError
Expand Down Expand Up @@ -60,26 +62,95 @@ def test_extract_error_with_custom_error(self):
assert e.value.response == response


async def test_stream_404(httpx_mock: HTTPXMock):
class TestInputModel(BaseModel):
test_input: str

class TestOutputModel(BaseModel):
test_output: str

httpx_mock.add_response(status_code=404)

client = APIClient(endpoint="https://blabla.com", api_key="test_api_key")
@pytest.fixture
def client() -> APIClient:
    """An APIClient pointed at a dummy endpoint, for use with httpx-mocked tests."""
    api_client = APIClient(endpoint="https://blabla.com", api_key="test_api_key")
    return api_client


class TestInputModel(BaseModel):
    """Minimal request payload model, passed as the `data` argument in stream tests."""
    bla: str = "bla"


class TestOutputModel(BaseModel):
    """Expected shape of a streamed chunk; each SSE payload is validated into this."""
    a: str


class TestAPIClientStream:
async def test_stream_404(self, httpx_mock: HTTPXMock, client: APIClient):
    """A non-success response must surface as a WorkflowAIError carrying the response."""

    class _Input(BaseModel):
        test_input: str

    class _Output(BaseModel):
        test_output: str

    httpx_mock.add_response(status_code=404)

    with pytest.raises(WorkflowAIError) as exc_info:  # noqa: PT012
        stream = client.stream(
            method="GET",
            path="test_path",
            data=_Input(test_input="test"),
            returns=_Output,
        )
        async for _ in stream:
            pass

    err = exc_info.value
    assert err.response
    assert err.response.status_code == 404
    assert err.response.reason_phrase == "Not Found"

@pytest.fixture
async def stream_fn(self, client: APIClient):
    """Return a zero-argument coroutine that collects every streamed chunk into a list."""

    async def _collect():
        chunks = []
        async for chunk in client.stream(
            method="GET",
            path="test_path",
            data=TestInputModel(),
            returns=TestOutputModel,
        ):
            chunks.append(chunk)
        return chunks

    return _collect

async def test_stream_with_single_chunk(
self,
stream_fn: Callable[[], Awaitable[list[TestOutputModel]]],
httpx_mock: HTTPXMock,
):
httpx_mock.add_response(
stream=IteratorStream(
[
b'data: {"a":"test"}\n\n',
],
),
)

try:
async for _ in client.stream(
method="GET",
path="test_path",
data=TestInputModel(test_input="test"),
returns=TestOutputModel,
):
pass
except httpx.HTTPStatusError as e:
assert isinstance(e, httpx.HTTPStatusError)
assert e.response.status_code == 404
assert e.response.reason_phrase == "Not Found"
chunks = await stream_fn()
assert chunks == [TestOutputModel(a="test")]

@pytest.mark.parametrize(
    "streamed_chunks",
    [
        # 2 perfect chunks
        [b'data: {"a":"test"}\n\n', b'data: {"a":"test2"}\n\n'],
        # 2 chunks in one
        [b'data: {"a":"test"}\n\ndata: {"a":"test2"}\n\n'],
        # Split not at the end
        [b'data: {"a":"test"}', b'\n\ndata: {"a":"test2"}\n\n'],
        # Really messy
        [b"dat", b'a: {"a":"', b'test"}', b"\n", b"\ndata", b': {"a":"test2"}\n\n'],
    ],
)
async def test_stream_with_multiple_chunks(
    self,
    stream_fn: Callable[[], Awaitable[list[TestOutputModel]]],
    httpx_mock: HTTPXMock,
    streamed_chunks: list[bytes],
):
    # Regardless of how the SSE bytes are fragmented across network chunks,
    # the client must reassemble exactly the two payloads, in order.
    assert isinstance(streamed_chunks, list), "sanity check"
    assert all(isinstance(chunk, bytes) for chunk in streamed_chunks), "sanity check"

    httpx_mock.add_response(stream=IteratorStream(streamed_chunks))
    chunks = await stream_fn()
    assert chunks == [TestOutputModel(a="test"), TestOutputModel(a="test2")]
Loading