diff --git a/replicate/client.py b/replicate/client.py index 3656d826..b267e1c1 100644 --- a/replicate/client.py +++ b/replicate/client.py @@ -365,4 +365,4 @@ def _build_httpx_client( def _raise_for_status(resp: httpx.Response) -> None: if 400 <= resp.status_code < 600: - raise ReplicateError(resp.json()["detail"]) + raise ReplicateError.from_response(resp) diff --git a/replicate/exceptions.py b/replicate/exceptions.py index e1aa51c4..4ac839c0 100644 --- a/replicate/exceptions.py +++ b/replicate/exceptions.py @@ -1,3 +1,8 @@ +from typing import Optional + +import httpx + + class ReplicateException(Exception): """A base class for all Replicate exceptions.""" @@ -7,4 +12,84 @@ class ModelError(ReplicateException): class ReplicateError(ReplicateException): - """An error from Replicate.""" + """ + An error from Replicate's API. + + This class represents a problem details response as defined in RFC 7807. + """ + + type: Optional[str] + """A URI that identifies the error type.""" + + title: Optional[str] + """A short, human-readable summary of the error.""" + + status: Optional[int] + """The HTTP status code.""" + + detail: Optional[str] + """A human-readable explanation specific to this occurrence of the error.""" + + instance: Optional[str] + """A URI that identifies the specific occurrence of the error.""" + + def __init__( + self, + type: Optional[str] = None, + title: Optional[str] = None, + status: Optional[int] = None, + detail: Optional[str] = None, + instance: Optional[str] = None, + ) -> None: + self.type = type + self.title = title + self.status = status + self.detail = detail + self.instance = instance + + @classmethod + def from_response(cls, response: httpx.Response) -> "ReplicateError": + """Create a ReplicateError from an HTTP response.""" + try: + data = response.json() + except ValueError: + data = {} + + return cls( + type=data.get("type"), + title=data.get("title"), + detail=data.get("detail"), + status=response.status_code, + instance=data.get("instance"), + ) + + def to_dict(self) -> dict: + return { + key: value + for key, value in { + "type": self.type, + "title": self.title, + "status": self.status, + "detail": self.detail, + "instance": self.instance, + }.items() + if value is not None + } + + def __str__(self) -> str: + return "ReplicateError Details:\n" + "\n".join( + [f"{key}: {value}" for key, value in self.to_dict().items()] + ) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + params = ", ".join( + [ + f"type={repr(self.type)}", + f"title={repr(self.title)}", + f"status={repr(self.status)}", + f"detail={repr(self.detail)}", + f"instance={repr(self.instance)}", + ] + ) + return f"{class_name}({params})" diff --git a/tests/test_client.py b/tests/test_client.py index 95636771..163b185e 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -31,3 +31,57 @@ async def test_authorization_when_setting_environ_after_import(): client = replicate.Client(transport=httpx.MockTransport(router.handler)) resp = client._request("GET", "/") assert resp.status_code == 200 + + +@pytest.mark.asyncio +async def test_client_error_handling(): + import replicate + from replicate.exceptions import ReplicateError + + router = respx.Router() + router.route( + method="GET", + url="https://api.replicate.com/", + headers={"Authorization": "Token test-client-error"}, + ).mock( + return_value=httpx.Response( + 400, + json={"detail": "Client error occurred"}, + ) + ) + + token = "test-client-error" # noqa: S105 + + with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": token}): + client = replicate.Client(transport=httpx.MockTransport(router.handler)) + with pytest.raises(ReplicateError) as exc_info: + client._request("GET", "/") + assert "status: 400" in str(exc_info.value) + assert "detail: Client error occurred" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_server_error_handling(): + import replicate + from replicate.exceptions import ReplicateError + + router = respx.Router() + router.route( + method="GET", + url="https://api.replicate.com/", + headers={"Authorization": "Token test-server-error"}, + ).mock( + return_value=httpx.Response( + 500, + json={"detail": "Server error occurred"}, + ) + ) + + token = "test-server-error" # noqa: S105 + + with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": token}): + client = replicate.Client(transport=httpx.MockTransport(router.handler)) + with pytest.raises(ReplicateError) as exc_info: + client._request("GET", "/") + assert "status: 500" in str(exc_info.value) + assert "detail: Server error occurred" in str(exc_info.value)