diff --git a/replicate/client.py b/replicate/client.py index b267e1c1..2441a4c7 100644 --- a/replicate/client.py +++ b/replicate/client.py @@ -329,9 +329,8 @@ def _build_httpx_client( timeout: Optional[httpx.Timeout] = None, **kwargs, ) -> Union[httpx.Client, httpx.AsyncClient]: - headers = { - "User-Agent": f"replicate-python/{__version__}", - } + headers = kwargs.pop("headers", {}) + headers["User-Agent"] = f"replicate-python/{__version__}" if ( api_token := api_token or os.environ.get("REPLICATE_API_TOKEN") diff --git a/tests/test_client.py b/tests/test_client.py index 163b185e..ee345203 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -85,3 +85,27 @@ async def test_server_error_handling(): client._request("GET", "/") assert "status: 500" in str(exc_info.value) assert "detail: Server error occurred" in str(exc_info.value) + + +def test_custom_headers_are_applied(): + import replicate + from replicate.exceptions import ReplicateError + + custom_headers = {"Custom-Header": "CustomValue"} + + def mock_send(request: httpx.Request, **kwargs) -> httpx.Response: + assert "Custom-Header" in request.headers + assert request.headers["Custom-Header"] == "CustomValue" + + return httpx.Response(401, json={}) + + client = replicate.Client( + api_token="dummy_token", + headers=custom_headers, + transport=httpx.MockTransport(mock_send), + ) + + try: + client.accounts.current() + except ReplicateError: + pass