diff --git a/replicate/client.py b/replicate/client.py index 73c9dd2d..19753c2b 100644 --- a/replicate/client.py +++ b/replicate/client.py @@ -26,6 +26,8 @@ class Client: """A Replicate API client library""" + __client: Optional[httpx.Client] = None + def __init__( self, api_token: Optional[str] = None, @@ -36,37 +38,41 @@ def __init__( ) -> None: super().__init__() - api_token = api_token or os.environ.get("REPLICATE_API_TOKEN") - - base_url = base_url or os.environ.get( - "REPLICATE_API_BASE_URL", "https://api.replicate.com" + self._api_token = api_token + self._base_url = ( + base_url + or os.environ.get("REPLICATE_API_BASE_URL") + or "https://api.replicate.com" ) - - timeout = timeout or httpx.Timeout( + self._timeout = timeout or httpx.Timeout( 5.0, read=30.0, write=30.0, connect=5.0, pool=10.0 ) + self._transport = kwargs.pop("transport", httpx.HTTPTransport()) + self._client_kwargs = kwargs self.poll_interval = float(os.environ.get("REPLICATE_POLL_INTERVAL", "0.5")) - headers = { - "User-Agent": f"replicate-python/{__version__}", - } - - if api_token is not None and api_token != "": - headers["Authorization"] = f"Token {api_token}" - - transport = kwargs.pop("transport", httpx.HTTPTransport()) - - self._client = self._build_client( - **kwargs, - base_url=base_url, - headers=headers, - timeout=timeout, - transport=RetryTransport(wrapped_transport=transport), - ) + @property + def _client(self) -> httpx.Client: + if self.__client is None: + headers = { + "User-Agent": f"replicate-python/{__version__}", + } + + api_token = self._api_token or os.environ.get("REPLICATE_API_TOKEN") + + if api_token is not None and api_token != "": + headers["Authorization"] = f"Token {api_token}" + + self.__client = httpx.Client( + **self._client_kwargs, + base_url=self._base_url, + headers=headers, + timeout=self._timeout, + transport=RetryTransport(wrapped_transport=self._transport), + ) - def _build_client(self, **kwargs) -> httpx.Client: - return httpx.Client(**kwargs) + return self.__client def _request(self, method: str, path: str, **kwargs) -> httpx.Response: resp = self._client.request(method, path, **kwargs) diff --git a/tests/test_client.py b/tests/test_client.py new file mode 100644 index 00000000..42cf28da --- /dev/null +++ b/tests/test_client.py @@ -0,0 +1,20 @@ +import os +from unittest import mock + +import httpx +import pytest + + +@pytest.mark.asyncio +async def test_authorization_when_setting_environ_after_import(): + import replicate + + token = "test-set-after-import" # noqa: S105 + + with mock.patch.dict( + os.environ, + {"REPLICATE_API_TOKEN": token}, + ): + client: httpx.Client = replicate.default_client._client + assert "Authorization" in client.headers + assert client.headers["Authorization"] == f"Token {token}"