Skip to content

Commit a74430e

Browse files
committed
Update client to use httpx
Signed-off-by: Mattt Zmuda <[email protected]> Add respx dependency Signed-off-by: Mattt Zmuda <[email protected]>
1 parent 2e2f40b commit a74430e

File tree

6 files changed

+83
-131
lines changed

6 files changed

+83
-131
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ optional-dependencies = { dev = [
1717
"pytest",
1818
"pytest-asyncio",
1919
"pytest-recording",
20+
"respx",
2021
"ruff",
2122
] }
2223

replicate/client.py

Lines changed: 47 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -1,106 +1,67 @@
11
import os
22
import re
3-
from json import JSONDecodeError
4-
from typing import Any, Dict, Iterator, Optional, Union
3+
from typing import Any, Iterator, Optional, Union
54

6-
import requests
7-
from requests.adapters import HTTPAdapter, Retry
8-
from requests.cookies import RequestsCookieJar
5+
import httpx
96

10-
from replicate.__about__ import __version__
11-
from replicate.deployment import DeploymentCollection
12-
from replicate.exceptions import ModelError, ReplicateError
13-
from replicate.model import ModelCollection
14-
from replicate.prediction import PredictionCollection
15-
from replicate.training import TrainingCollection
7+
from .__about__ import __version__
8+
from .deployment import DeploymentCollection
9+
from .exceptions import ModelError, ReplicateError
10+
from .model import ModelCollection
11+
from .prediction import PredictionCollection
12+
from .training import TrainingCollection
1613

1714

1815
class Client:
19-
def __init__(self, api_token: Optional[str] = None) -> None:
16+
"""A Replicate API client library"""
17+
18+
def __init__(
19+
self,
20+
api_token: Optional[str] = None,
21+
*,
22+
base_url: Optional[str] = None,
23+
timeout: Optional[httpx.Timeout] = None,
24+
**kwargs,
25+
) -> None:
2026
super().__init__()
21-
# Client is instantiated at import time, so do as little as possible.
22-
# This includes resolving environment variables -- they might be set programmatically.
23-
self.api_token = api_token
24-
self.base_url = os.environ.get(
27+
28+
api_token = api_token or os.environ.get("REPLICATE_API_TOKEN")
29+
30+
base_url = base_url or os.environ.get(
2531
"REPLICATE_API_BASE_URL", "https://api.replicate.com"
2632
)
27-
self.poll_interval = float(os.environ.get("REPLICATE_POLL_INTERVAL", "0.5"))
2833

29-
# TODO: make thread safe
30-
self.read_session = _create_session()
31-
read_retries = Retry(
32-
total=5,
33-
backoff_factor=2,
34-
# Only retry 500s on GET so we don't unintionally mutute data
35-
allowed_methods=["GET"],
36-
# https://support.cloudflare.com/hc/en-us/articles/115003011431-Troubleshooting-Cloudflare-5XX-errors
37-
status_forcelist=[
38-
429,
39-
500,
40-
502,
41-
503,
42-
504,
43-
520,
44-
521,
45-
522,
46-
523,
47-
524,
48-
526,
49-
527,
50-
],
34+
timeout = timeout or httpx.Timeout(
35+
5.0, read=30.0, write=30.0, connect=5.0, pool=10.0
5136
)
52-
self.read_session.mount("http://", HTTPAdapter(max_retries=read_retries))
53-
self.read_session.mount("https://", HTTPAdapter(max_retries=read_retries))
54-
55-
self.write_session = _create_session()
56-
write_retries = Retry(
57-
total=5,
58-
backoff_factor=2,
59-
allowed_methods=["POST", "PUT"],
60-
# Only retry POST/PUT requests on rate limits, so we don't unintionally mutute data
61-
status_forcelist=[429],
62-
)
63-
self.write_session.mount("http://", HTTPAdapter(max_retries=write_retries))
64-
self.write_session.mount("https://", HTTPAdapter(max_retries=write_retries))
65-
66-
def _request(self, method: str, path: str, **kwargs) -> requests.Response:
67-
# from requests.Session
68-
if method in ["GET", "OPTIONS"]:
69-
kwargs.setdefault("allow_redirects", True)
70-
if method in ["HEAD"]:
71-
kwargs.setdefault("allow_redirects", False)
72-
kwargs.setdefault("headers", {})
73-
kwargs["headers"].update(self._headers())
74-
session = self.read_session
75-
if method in ["POST", "PUT", "DELETE", "PATCH"]:
76-
session = self.write_session
77-
resp = session.request(method, self.base_url + path, **kwargs)
78-
if 400 <= resp.status_code < 600:
79-
try:
80-
raise ReplicateError(resp.json()["detail"])
81-
except (JSONDecodeError, KeyError):
82-
pass
83-
raise ReplicateError(f"HTTP error: {resp.status_code, resp.reason}")
84-
return resp
8537

86-
def _headers(self) -> Dict[str, str]:
87-
return {
88-
"Authorization": f"Token {self._api_token()}",
38+
self.poll_interval = float(os.environ.get("REPLICATE_POLL_INTERVAL", "0.5"))
39+
40+
headers = {
41+
"Authorization": f"Token {api_token}",
8942
"User-Agent": f"replicate-python/{__version__}",
9043
}
9144

92-
def _api_token(self) -> str:
93-
token = self.api_token
94-
# Evaluate lazily in case environment variable is set with dotenv, or something
95-
if token is None:
96-
token = os.environ.get("REPLICATE_API_TOKEN")
97-
if not token:
98-
raise ReplicateError(
99-
"""No API token provided. You need to set the REPLICATE_API_TOKEN environment variable or create a client with `replicate.Client(api_token=...)`.
45+
transport = kwargs.pop("transport", httpx.HTTPTransport())
10046

101-
You can find your API key on https://replicate.com"""
102-
)
103-
return token
47+
self._client = self._build_client(
48+
**kwargs,
49+
base_url=base_url,
50+
headers=headers,
51+
timeout=timeout,
52+
transport=transport,
53+
)
54+
55+
def _build_client(self, **kwargs) -> httpx.Client:
56+
return httpx.Client(**kwargs)
57+
58+
def _request(self, method: str, path: str, **kwargs) -> httpx.Response:
59+
resp = self._client.request(method, path, **kwargs)
60+
61+
if 400 <= resp.status_code < 600:
62+
raise ReplicateError(resp.json()["detail"])
63+
64+
return resp
10465

10566
@property
10667
def models(self) -> ModelCollection:
@@ -150,21 +111,3 @@ def run(self, model_version: str, **kwargs) -> Union[Any, Iterator[Any]]:
150111
if prediction.status == "failed":
151112
raise ModelError(prediction.error)
152113
return prediction.output
153-
154-
155-
class _NonpersistentCookieJar(RequestsCookieJar):
156-
"""
157-
A cookie jar that doesn't persist cookies between requests.
158-
"""
159-
160-
def set(self, name, value, **kwargs) -> None:
161-
return
162-
163-
def set_cookie(self, cookie, *args, **kwargs) -> None:
164-
return
165-
166-
167-
def _create_session() -> requests.Session:
168-
s = requests.Session()
169-
s.cookies = _NonpersistentCookieJar()
170-
return s

replicate/files.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import os
55
from typing import Optional
66

7-
import requests
7+
import httpx
88

99

1010
def upload_file(fh: io.IOBase, output_file_prefix: Optional[str] = None) -> str:
@@ -24,7 +24,7 @@ def upload_file(fh: io.IOBase, output_file_prefix: Optional[str] = None) -> str:
2424
if output_file_prefix is not None:
2525
name = getattr(fh, "name", "output")
2626
url = output_file_prefix + os.path.basename(name)
27-
resp = requests.put(url, files={"file": fh}, timeout=None)
27+
resp = httpx.put(url, files={"file": fh}, timeout=None) # type: ignore
2828
resp.raise_for_status()
2929
return url
3030

requirements-dev.txt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@ h11==0.14.0
2121
httpcore==0.17.3
2222
# via httpx
2323
httpx==0.24.1
24-
# via replicate (pyproject.toml)
24+
# via
25+
# replicate (pyproject.toml)
26+
# respx
2527
idna==3.4
2628
# via
2729
# anyio
@@ -63,6 +65,8 @@ pytest-recording==0.13.0
6365
# via replicate (pyproject.toml)
6466
pyyaml==6.0.1
6567
# via vcrpy
68+
respx==0.20.2
69+
# via replicate (pyproject.toml)
6670
ruff==0.0.278
6771
# via replicate (pyproject.toml)
6872
sniffio==1.3.0

tests/conftest.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@ def mock_replicate_api_token(scope="class"):
99
if os.environ.get("REPLICATE_API_TOKEN", "") != "":
1010
yield
1111
else:
12-
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "test-token"}):
12+
with mock.patch.dict(
13+
os.environ,
14+
{"REPLICATE_API_TOKEN": "test-token", "REPLICATE_POLL_INTERVAL": "0.0"},
15+
):
1316
yield
1417

1518

tests/test_deployment.py

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,16 @@
1-
import responses
2-
from responses import matchers
1+
import httpx
2+
import respx
33

44
from replicate.client import Client
55

6-
7-
@responses.activate
8-
def test_deployment_predictions_create():
9-
client = Client(api_token="abc123")
10-
11-
deployment = client.deployments.get("test/model")
12-
13-
rsp = responses.post(
14-
"https://api.replicate.com/v1/deployments/test/model/predictions",
15-
match=[
16-
matchers.json_params_matcher(
17-
{
18-
"input": {"text": "world"},
19-
"webhook": "https://example.com/webhook",
20-
"webhook_events_filter": ["completed"],
21-
}
22-
),
23-
],
6+
router = respx.Router(base_url="https://api.replicate.com/v1")
7+
router.route(
8+
method="POST",
9+
path="/deployments/test/model/predictions",
10+
name="deployments.predictions.create",
11+
).mock(
12+
return_value=httpx.Response(
13+
201,
2414
json={
2515
"id": "p1",
2616
"version": "v1",
@@ -31,17 +21,28 @@ def test_deployment_predictions_create():
3121
"created_at": "2022-04-26T20:00:40.658234Z",
3222
"source": "api",
3323
"status": "processing",
34-
"input": {"text": "hello"},
24+
"input": {"text": "world"},
3525
"output": None,
3626
"error": None,
3727
"logs": "",
3828
},
3929
)
30+
)
31+
router.route(host="api.replicate.com").pass_through()
32+
33+
34+
def test_deployment_predictions_create():
35+
client = Client(
36+
api_token="test-token", transport=httpx.MockTransport(router.handler)
37+
)
38+
deployment = client.deployments.get("test/model")
4039

41-
deployment.predictions.create(
40+
prediction = deployment.predictions.create(
4241
input={"text": "world"},
4342
webhook="https://example.com/webhook",
4443
webhook_events_filter=["completed"],
4544
)
4645

47-
assert rsp.call_count == 1
46+
assert router["deployments.predictions.create"].called
47+
assert prediction.id == "p1"
48+
assert prediction.input == {"text": "world"}

0 commit comments

Comments
 (0)