Skip to content

Commit f53f829

Browse files
committed
client: Retry network errors
Add retry decorator for network-related exceptions in the client module, including `RemoteProtocolError`, `ConnectError`, `TimeoutException`, and `NetworkError`. Abstract HTTP request logic in the client module into a single method and apply the retry decorator to it. Signed-off-by: Phoevos Kalemkeris <[email protected]>
1 parent 892a31b commit f53f829

File tree

4 files changed

+86
-12
lines changed

4 files changed

+86
-12
lines changed

client/cogstack_model_gateway_client/client.py

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
import httpx
77

8+
from cogstack_model_gateway_client.exceptions import retry_if_network_error
9+
810

911
class GatewayClient:
1012
def __init__(
@@ -41,6 +43,33 @@ async def wrapper(self, *args, **kwargs):
4143

4244
return wrapper
4345

46+
@require_client
@retry_if_network_error
async def _request(
    self,
    method: str,
    url: str,
    *,
    data=None,
    json=None,
    files=None,
    params=None,
    headers=None,
    **kwargs,
) -> httpx.Response:
    """Send a single HTTP request through the underlying client.

    Every argument is forwarded by keyword to ``self._client.request``; any
    extra ``**kwargs`` are passed along unchanged.

    Returns the result of calling ``raise_for_status()`` on the received
    response, so an error status surfaces as an exception instead of being
    handed back to the caller.

    The coroutine is wrapped by ``retry_if_network_error``, so the whole
    body (request plus status check) is what gets re-attempted.

    NOTE(review): the retry wrapper applies to every HTTP method, POST
    included -- confirm the Gateway tolerates a repeated submission when a
    request fails after the server has already received it.
    """
    request_options = {
        "method": method,
        "url": url,
        "data": data,
        "json": json,
        "files": files,
        "params": params,
        "headers": headers,
        **kwargs,
    }
    response = await self._client.request(**request_options)
    return response.raise_for_status()
72+
4473
@require_client
4574
async def submit_task(
4675
self,
@@ -59,11 +88,12 @@ async def submit_task(
5988
if not model_name:
6089
raise ValueError("Please provide a model name or set a default model for the client.")
6190
url = f"{self.base_url}/models/{model_name}/tasks/{task}"
62-
resp = await self._client.post(
63-
url, data=data, json=json, files=files, params=params, headers=headers
91+
92+
resp = await self._request(
93+
"POST", url, data=data, json=json, files=files, params=params, headers=headers
6494
)
65-
resp.raise_for_status()
6695
task_info = resp.json()
96+
6797
if wait_for_completion:
6898
task_uuid = task_info["uuid"]
6999
task_info = await self.wait_for_task(task_uuid)
@@ -143,8 +173,7 @@ async def _get_task(self, task_uuid: str, detail: bool = True, download: bool =
143173
"""Get a Gateway task."""
144174
url = f"{self.base_url}/tasks/{task_uuid}"
145175
params = {"detail": detail, "download": download}
146-
resp = await self._client.get(url, params=params)
147-
return resp.raise_for_status()
176+
return await self._request("GET", url, params=params)
148177

149178
@require_client
150179
async def get_task(self, task_uuid: str, detail: bool = True):
@@ -207,8 +236,7 @@ async def wait_for_task(
207236
async def get_models(self, verbose: bool = False):
208237
"""Get the list of available models from the Gateway."""
209238
url = f"{self.base_url}/models/"
210-
resp = await self._client.get(url, params={"verbose": verbose})
211-
resp.raise_for_status()
239+
resp = await self._request("GET", url, params={"verbose": verbose})
212240
return resp.json()
213241

214242
@require_client
@@ -218,8 +246,7 @@ async def get_model(self, model_name: str = None):
218246
if not model_name:
219247
raise ValueError("Please provide a model name or set a default model for the client.")
220248
url = f"{self.base_url}/models/{model_name}/info"
221-
resp = await self._client.get(url)
222-
resp.raise_for_status()
249+
resp = await self._request("GET", url)
223250
return resp.json()
224251

225252
@require_client
@@ -236,8 +263,7 @@ async def deploy_model(
236263
raise ValueError("Please provide a model name or set a default model for the client.")
237264
url = f"{self.base_url}/models/{model_name}"
238265
data = {"tracking_id": tracking_id, "model_uri": model_uri, "ttl": ttl}
239-
resp = await self._client.post(url, json=data)
240-
resp.raise_for_status()
266+
resp = await self._request("POST", url, json=data)
241267
return resp.json()
242268

243269

client/cogstack_model_gateway_client/exceptions.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import logging
2+
3+
import httpx
4+
from tenacity import (
5+
before_sleep_log,
6+
retry,
7+
retry_if_exception,
8+
stop_after_attempt,
9+
wait_fixed,
10+
)
11+
12+
# Logger named "cmg.client"; it is handed to tenacity's before-sleep hook in
# this module so that pending retries are logged.
log = logging.getLogger("cmg.client")
13+
14+
15+
def is_network_error(exception: Exception) -> bool:
    """Tell whether ``exception`` is one of the network-level httpx errors.

    Returns ``True`` when the exception is an instance of
    ``httpx.RemoteProtocolError``, ``httpx.ConnectError``,
    ``httpx.TimeoutException`` or ``httpx.NetworkError`` (subclasses
    included), and ``False`` for anything else.
    """
    network_error_types = (
        httpx.RemoteProtocolError,
        httpx.ConnectError,
        httpx.TimeoutException,
        httpx.NetworkError,
    )
    return isinstance(exception, network_error_types)
24+
25+
26+
# Decorator that re-runs the wrapped callable when it fails with a network
# error (as decided by ``is_network_error``): at most 3 attempts in total,
# with a fixed 2-second pause between attempts, and a DEBUG log entry before
# each pause.
#
# ``reraise=True`` makes the last underlying exception propagate once the
# attempts are exhausted. Without it tenacity raises ``tenacity.RetryError``
# wrapping the real failure, which would change the exception type seen by
# callers that handle httpx errors directly.
retry_if_network_error = retry(
    stop=stop_after_attempt(3),
    wait=wait_fixed(2),
    retry=retry_if_exception(is_network_error),
    before_sleep=before_sleep_log(log, logging.DEBUG),
    reraise=True,
)

client/poetry.lock

Lines changed: 17 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

client/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ pre-commit = "^4.2.0"
3535
pytest = "^8.4.1"
3636
pytest-asyncio = "^1.0.0"
3737
pytest-mock = "^3.14.1"
38+
tenacity = "^9.1.2"
3839

3940
[tool.poetry.requires-plugins]
4041
poetry-dynamic-versioning = { version = ">=1.0.0,<2.0.0", extras = ["plugin"] }

0 commit comments

Comments
 (0)