Skip to content

Commit afd3619

Browse files
authored
Merge pull request #25 from WorkflowAI/guillaume/fix-retry-on-connect
Retry on connect error
2 parents 77d3284 + 4b57e3b commit afd3619

File tree

8 files changed

+138
-47
lines changed

8 files changed

+138
-47
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "workflowai"
3-
version = "0.5.1"
3+
version = "0.5.2"
44
description = ""
55
authors = ["Guillaume Aquilina <[email protected]>"]
66
readme = "README.md"

workflowai/core/client/_api.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
from contextlib import asynccontextmanager
23
from typing import Any, AsyncIterator, Literal, Optional, TypeVar, Union, overload
34

45
import httpx
@@ -19,17 +20,26 @@ def __init__(self, endpoint: str, api_key: str, source_headers: Optional[dict[st
1920
self.api_key = api_key
2021
self.source_headers = source_headers or {}
2122

22-
def _client(self) -> httpx.AsyncClient:
23+
@asynccontextmanager
24+
async def _client(self):
2325
source_headers = self.source_headers or {}
24-
client = httpx.AsyncClient(
26+
async with httpx.AsyncClient(
2527
base_url=self.endpoint,
2628
headers={
2729
"Authorization": f"Bearer {self.api_key}",
2830
**source_headers,
2931
},
3032
timeout=120.0,
31-
)
32-
return client
33+
) as client:
34+
try:
35+
yield client
36+
except (httpx.ReadError, httpx.ConnectError) as e:
37+
raise WorkflowAIError(
38+
response=None,
39+
error=BaseError(message="Could not read response", code="connection_error"),
40+
# We can retry after 10ms
41+
retry_after_delay_seconds=0.010,
42+
) from e
3343

3444
async def get(self, path: str, returns: type[_R], query: Union[dict[str, Any], None] = None) -> _R:
3545
async with self._client() as client:

workflowai/core/client/_api_test.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,3 +154,38 @@ async def test_stream_with_multiple_chunks(
154154
httpx_mock.add_response(stream=IteratorStream(streamed_chunks))
155155
chunks = await stream_fn()
156156
assert chunks == [TestOutputModel(a="test"), TestOutputModel(a="test2")]
157+
158+
159+
class TestReadAndConnectError:
160+
@pytest.mark.parametrize("exception", [httpx.ReadError("arg"), httpx.ConnectError("arg")])
161+
async def test_get(self, httpx_mock: HTTPXMock, client: APIClient, exception: Exception):
162+
httpx_mock.add_exception(exception)
163+
164+
with pytest.raises(WorkflowAIError) as e:
165+
await client.get(path="test_path", returns=TestOutputModel)
166+
167+
assert e.value.error.code == "connection_error"
168+
169+
@pytest.mark.parametrize("exception", [httpx.ReadError("arg"), httpx.ConnectError("arg")])
170+
async def test_post(self, httpx_mock: HTTPXMock, client: APIClient, exception: Exception):
171+
httpx_mock.add_exception(exception)
172+
173+
with pytest.raises(WorkflowAIError) as e:
174+
await client.post(path="test_path", data=TestInputModel(), returns=TestOutputModel)
175+
176+
assert e.value.error.code == "connection_error"
177+
178+
@pytest.mark.parametrize("exception", [httpx.ReadError("arg"), httpx.ConnectError("arg")])
179+
async def test_stream(self, httpx_mock: HTTPXMock, client: APIClient, exception: Exception):
180+
httpx_mock.add_exception(exception)
181+
182+
with pytest.raises(WorkflowAIError) as e: # noqa: PT012
183+
async for _ in client.stream(
184+
method="GET",
185+
path="test_path",
186+
data=TestInputModel(),
187+
returns=TestOutputModel,
188+
):
189+
pass
190+
191+
assert e.value.error.code == "connection_error"

workflowai/core/client/_utils.py

Lines changed: 5 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,8 @@
33
# the delimiter is not withing a quoted string
44
import asyncio
55
import re
6-
from email.utils import parsedate_to_datetime
76
from json import JSONDecodeError
87
from time import time
9-
from typing import Any, Optional
108

119
from workflowai.core.client._types import OutputValidator
1210
from workflowai.core.domain.errors import BaseError, WorkflowAIError
@@ -24,22 +22,6 @@ def split_chunks(chunk: bytes):
2422
yield chunk_str[start:]
2523

2624

27-
def retry_after_to_delay_seconds(retry_after: Any) -> Optional[float]:
28-
if retry_after is None:
29-
return None
30-
31-
try:
32-
return float(retry_after)
33-
except ValueError:
34-
pass
35-
try:
36-
retry_after_date = parsedate_to_datetime(retry_after)
37-
current_time = time()
38-
return retry_after_date.timestamp() - current_time
39-
except (TypeError, ValueError, OverflowError):
40-
return None
41-
42-
4325
# Returns two functions:
4426
# - _should_retry: returns True if we should retry
4527
# - _wait_for_exception: waits after an exception only if we should retry, otherwise raises
@@ -60,13 +42,16 @@ def _should_retry():
6042
return retry_count < max_retry_count and _leftover_delay() >= 0
6143

6244
async def _wait_for_exception(e: WorkflowAIError):
63-
if not e.response:
45+
retry_after = e.retry_after_delay_seconds
46+
if retry_after is None:
6447
raise e
6548

6649
nonlocal retry_count
67-
retry_after = retry_after_to_delay_seconds(e.response.headers.get("Retry-After"))
6850
leftover_delay = _leftover_delay()
6951
if not retry_after or leftover_delay < 0 or retry_count >= max_retry_count:
52+
if not e.response:
53+
raise e
54+
7055
# Convert error to WorkflowAIError
7156
try:
7257
response_json = e.response.json()

workflowai/core/client/_utils_test.py

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,9 @@
1-
from typing import Optional
21
from unittest.mock import Mock
32

43
import pytest
5-
from freezegun import freeze_time
6-
from httpx import HTTPStatusError
74

8-
from workflowai.core.client._utils import build_retryable_wait, retry_after_to_delay_seconds, split_chunks
9-
from workflowai.core.domain.errors import WorkflowAIError
5+
from workflowai.core.client._utils import build_retryable_wait, split_chunks
6+
from workflowai.core.domain.errors import BaseError, WorkflowAIError
107

118

129
@pytest.mark.parametrize(
@@ -23,25 +20,12 @@ def test_split_chunks(chunk: bytes, expected: list[bytes]):
2320
assert list(split_chunks(chunk)) == expected
2421

2522

26-
@freeze_time("2024-01-01T00:00:00Z")
27-
@pytest.mark.parametrize(
28-
("retry_after", "expected"),
29-
[
30-
(None, None),
31-
("10", 10),
32-
("Wed, 01 Jan 2024 00:00:10 UTC", 10),
33-
],
34-
)
35-
def test_retry_after_to_delay_seconds(retry_after: Optional[str], expected: Optional[float]):
36-
assert retry_after_to_delay_seconds(retry_after) == expected
37-
38-
3923
class TestBuildRetryableWait:
4024
@pytest.fixture
4125
def request_error(self):
4226
response = Mock()
4327
response.headers = {"Retry-After": "0.01"}
44-
return HTTPStatusError(message="", request=Mock(), response=response)
28+
return WorkflowAIError(response=response, error=BaseError(message=""))
4529

4630
async def test_should_retry_count(self, request_error: WorkflowAIError):
4731
should_retry, wait_for_exception = build_retryable_wait(60, 1)

workflowai/core/client/client_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Any, AsyncIterator
44
from unittest.mock import AsyncMock, patch
55

6+
import httpx
67
import pytest
78
from pytest_httpx import HTTPXMock, IteratorStream
89

@@ -13,6 +14,7 @@
1314
WorkflowAIClient,
1415
_compute_default_version_reference, # pyright: ignore [reportPrivateUsage]
1516
)
17+
from workflowai.core.domain.errors import WorkflowAIError
1618
from workflowai.core.domain.run import Run
1719

1820

@@ -178,6 +180,26 @@ async def test_run_retries_on_too_many_requests(self, httpx_mock: HTTPXMock, cli
178180
assert reqs[0].url == "http://localhost:8000/v1/_/tasks/123/schemas/1/run"
179181
assert reqs[1].url == "http://localhost:8000/v1/_/tasks/123/schemas/1/run"
180182

183+
async def test_run_retries_on_connection_error(self, httpx_mock: HTTPXMock, client: Client):
184+
task = HelloTask(id="123", schema_id=1)
185+
186+
httpx_mock.add_exception(httpx.ConnectError("arg"))
187+
httpx_mock.add_response(json=fixtures_json("task_run.json"))
188+
189+
task_run = await client.run(task, task_input=HelloTaskInput(name="Alice"), max_retry_count=5)
190+
assert task_run.id == "8f635b73-f403-47ee-bff9-18320616c6cc"
191+
192+
async def test_max_retries(self, httpx_mock: HTTPXMock, client: Client):
193+
task = HelloTask(id="123", schema_id=1)
194+
195+
httpx_mock.add_exception(httpx.ConnectError("arg"), is_reusable=True)
196+
197+
with pytest.raises(WorkflowAIError):
198+
await client.run(task, task_input=HelloTaskInput(name="Alice"), max_retry_count=5)
199+
200+
reqs = httpx_mock.get_requests()
201+
assert len(reqs) == 5
202+
181203

182204
class TestTask:
183205
@pytest.fixture

workflowai/core/domain/errors.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
from email.utils import parsedate_to_datetime
12
from json import JSONDecodeError
3+
from time import time
24
from typing import Any, Literal, Optional, Union
35

46
from httpx import Response
@@ -53,6 +55,7 @@
5355
# The request was invalid
5456
"bad_request",
5557
"invalid_file",
58+
"connection_error",
5659
],
5760
str, # Using as a fallback to avoid validation error if an error code is added to the API
5861
]
@@ -70,11 +73,34 @@ class ErrorResponse(BaseModel):
7073
task_run_id: Optional[str] = None
7174

7275

76+
def _retry_after_to_delay_seconds(retry_after: Any) -> Optional[float]:
77+
if retry_after is None:
78+
return None
79+
80+
try:
81+
return float(retry_after)
82+
except ValueError:
83+
pass
84+
try:
85+
retry_after_date = parsedate_to_datetime(retry_after)
86+
current_time = time()
87+
return retry_after_date.timestamp() - current_time
88+
except (TypeError, ValueError, OverflowError):
89+
return None
90+
91+
7392
class WorkflowAIError(Exception):
74-
def __init__(self, response: Optional[Response], error: BaseError, task_run_id: Optional[str] = None):
93+
def __init__(
94+
self,
95+
response: Optional[Response],
96+
error: BaseError,
97+
task_run_id: Optional[str] = None,
98+
retry_after_delay_seconds: Optional[float] = None,
99+
):
75100
self.error = error
76101
self.task_run_id = task_run_id
77102
self.response = response
103+
self._retry_after_delay_seconds = retry_after_delay_seconds
78104

79105
def __str__(self):
80106
return f"WorkflowAIError : [{self.error.code}] ({self.error.status_code}): [{self.error.message}]"
@@ -106,3 +132,13 @@ def from_response(cls, response: Response):
106132
),
107133
task_run_id=task_run_id,
108134
)
135+
136+
@property
137+
def retry_after_delay_seconds(self) -> Optional[float]:
138+
if self._retry_after_delay_seconds:
139+
return self._retry_after_delay_seconds
140+
141+
if self.response:
142+
return _retry_after_to_delay_seconds(self.response.headers.get("Retry-After"))
143+
144+
return None

workflowai/core/domain/errors_test.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
1-
1+
from typing import Optional
22
from unittest.mock import Mock
33

4-
from workflowai.core.domain.errors import WorkflowAIError
4+
import pytest
5+
from freezegun import freeze_time
6+
7+
from workflowai.core.domain.errors import (
8+
WorkflowAIError,
9+
_retry_after_to_delay_seconds, # pyright: ignore [reportPrivateUsage]
10+
)
511

612

713
def test_workflow_ai_error_404():
@@ -14,3 +20,16 @@ def test_workflow_ai_error_404():
1420
assert error.error.message == "None"
1521
assert error.error.status_code == 404
1622
assert error.error.code == "object_not_found"
23+
24+
25+
@freeze_time("2024-01-01T00:00:00Z")
26+
@pytest.mark.parametrize(
27+
("retry_after", "expected"),
28+
[
29+
(None, None),
30+
("10", 10),
31+
("Wed, 01 Jan 2024 00:00:10 UTC", 10),
32+
],
33+
)
34+
def test_retry_after_to_delay_seconds(retry_after: Optional[str], expected: Optional[float]):
35+
assert _retry_after_to_delay_seconds(retry_after) == expected

0 commit comments

Comments
 (0)