diff --git a/pyproject.toml b/pyproject.toml index d319137..e7fce42 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "workflowai" -version = "0.5.0" +version = "0.5.1" description = "" authors = ["Guillaume Aquilina "] readme = "README.md" diff --git a/workflowai/core/client/_api.py b/workflowai/core/client/_api.py index 1634636..edb0789 100644 --- a/workflowai/core/client/_api.py +++ b/workflowai/core/client/_api.py @@ -1,15 +1,17 @@ +import logging from typing import Any, AsyncIterator, Literal, Optional, TypeVar, Union, overload import httpx from pydantic import BaseModel, TypeAdapter, ValidationError -from workflowai.core.client._utils import split_chunks from workflowai.core.domain.errors import BaseError, ErrorResponse, WorkflowAIError # A type for return values _R = TypeVar("_R") _M = TypeVar("_M", bound=BaseModel) +_logger = logging.getLogger("WorkflowAI") + class APIClient: def __init__(self, endpoint: str, api_key: str, source_headers: Optional[dict[str, str]] = None): @@ -106,6 +108,37 @@ def _extract_error( response=response, ) from exception + async def _wrap_sse(self, raw: AsyncIterator[bytes], termination_chars: bytes = b"\n\n"): + data = b"" + in_data = False + async for chunk in raw: + data += chunk + if not in_data: + if data.startswith(b"data: "): + data = data[6:] + in_data = True + else: + # We will wait for the next chunk, we might be in the middle + # of 'data: ' + continue + + # Splitting the chunk by separator + splits = data.split(b"\n\ndata: ") + if len(splits) > 1: + # Yielding the rest of the splits except the last one + for data in splits[0:-1]: + yield data + # The last split could be incomplete + data = splits[-1] + + if data.endswith(termination_chars): + yield data[: -len(termination_chars)] + data = b"" + in_data = False + + if data: + _logger.warning("Data left after processing", extra={"data": data}) + async def stream( self, method: Literal["GET", "POST"], @@ -122,15 +155,14 @@ async def 
stream( if not response.is_success: # We need to read the response to get the error message await response.aread() - response.raise_for_status() + await self.raise_for_status(response) + return - async for chunk in response.aiter_bytes(): - payload = "" + async for chunk in self._wrap_sse(response.aiter_bytes()): try: - for payload in split_chunks(chunk): - yield returns.model_validate_json(payload) + yield returns.model_validate_json(chunk) except ValidationError as e: - raise self._extract_error(response, payload, e) from None + raise self._extract_error(response, chunk, e) from None async def raise_for_status(self, response: httpx.Response): if response.status_code < 200 or response.status_code >= 300: diff --git a/workflowai/core/client/_api_test.py b/workflowai/core/client/_api_test.py index 4b57eb0..88ded0f 100644 --- a/workflowai/core/client/_api_test.py +++ b/workflowai/core/client/_api_test.py @@ -1,7 +1,9 @@ +from collections.abc import Awaitable, Callable + import httpx import pytest from pydantic import BaseModel -from pytest_httpx import HTTPXMock +from pytest_httpx import HTTPXMock, IteratorStream from workflowai.core.client._api import APIClient from workflowai.core.domain.errors import WorkflowAIError @@ -60,26 +62,95 @@ def test_extract_error_with_custom_error(self): assert e.value.response == response -async def test_stream_404(httpx_mock: HTTPXMock): - class TestInputModel(BaseModel): - test_input: str - - class TestOutputModel(BaseModel): - test_output: str - - httpx_mock.add_response(status_code=404) - - client = APIClient(endpoint="https://blabla.com", api_key="test_api_key") +@pytest.fixture +def client() -> APIClient: + return APIClient(endpoint="https://blabla.com", api_key="test_api_key") + + +class TestInputModel(BaseModel): + bla: str = "bla" + + +class TestOutputModel(BaseModel): + a: str + + +class TestAPIClientStream: + async def test_stream_404(self, httpx_mock: HTTPXMock, client: APIClient): + class TestInputModel(BaseModel): + 
# 2 perfect chunks
isinstance(streamed_chunks, list), "sanity check" + assert all(isinstance(chunk, bytes) for chunk in streamed_chunks), "sanity check" + + httpx_mock.add_response(stream=IteratorStream(streamed_chunks)) + chunks = await stream_fn() + assert chunks == [TestOutputModel(a="test"), TestOutputModel(a="test2")]