Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 30 additions & 24 deletions replicate/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
class Client:
"""A Replicate API client library"""

__client: Optional[httpx.Client] = None

def __init__(
self,
api_token: Optional[str] = None,
Expand All @@ -36,37 +38,41 @@ def __init__(
) -> None:
super().__init__()

api_token = api_token or os.environ.get("REPLICATE_API_TOKEN")

base_url = base_url or os.environ.get(
"REPLICATE_API_BASE_URL", "https://api.replicate.com"
self._api_token = api_token
self._base_url = (
base_url
or os.environ.get("REPLICATE_API_BASE_URL")
or "https://api.replicate.com"
)

timeout = timeout or httpx.Timeout(
self._timeout = timeout or httpx.Timeout(
5.0, read=30.0, write=30.0, connect=5.0, pool=10.0
)
self._transport = kwargs.pop("transport", httpx.HTTPTransport())
self._client_kwargs = kwargs

self.poll_interval = float(os.environ.get("REPLICATE_POLL_INTERVAL", "0.5"))

headers = {
"User-Agent": f"replicate-python/{__version__}",
}

if api_token is not None and api_token != "":
headers["Authorization"] = f"Token {api_token}"

transport = kwargs.pop("transport", httpx.HTTPTransport())

self._client = self._build_client(
**kwargs,
base_url=base_url,
headers=headers,
timeout=timeout,
transport=RetryTransport(wrapped_transport=transport),
)
@property
def _client(self) -> httpx.Client:
if self.__client is None:
headers = {
"User-Agent": f"replicate-python/{__version__}",
}

api_token = self._api_token or os.environ.get("REPLICATE_API_TOKEN")

if api_token is not None and api_token != "":
headers["Authorization"] = f"Token {api_token}"

self.__client = httpx.Client(
**self._client_kwargs,
base_url=self._base_url,
headers=headers,
timeout=self._timeout,
transport=RetryTransport(wrapped_transport=self._transport),
)

def _build_client(self, **kwargs) -> httpx.Client:
return httpx.Client(**kwargs)
return self.__client

def _request(self, method: str, path: str, **kwargs) -> httpx.Response:
resp = self._client.request(method, path, **kwargs)
Expand Down
20 changes: 20 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import os
from unittest import mock

import httpx
import pytest


@pytest.mark.asyncio
async def test_authorization_when_setting_environ_after_import():
import replicate

token = "test-set-after-import" # noqa: S105

with mock.patch.dict(
os.environ,
{"REPLICATE_API_TOKEN": token},
):
client: httpx.Client = replicate.default_client._client
assert "Authorization" in client.headers
assert client.headers["Authorization"] == f"Token {token}"