diff --git a/README.md b/README.md index a4e80c54..eb411cf7 100644 --- a/README.md +++ b/README.md @@ -82,6 +82,24 @@ or a handle to a file on your local device. "an astronaut riding a horse" ``` +`replicate.run` raises `ModelError` if the prediction fails. +You can access the exception's `prediction` property +to get more information about the failure. + +```python +import replicate +from replicate.exceptions import ModelError + +try: + output = replicate.run("stability-ai/stable-diffusion-3", { "prompt": "An astronaut riding a rainbow unicorn" }) +except ModelError as e: + if "(some known issue)" in e.prediction.logs: + pass + + print("Failed prediction: " + e.prediction.id) +``` + + ## Run a model and stream its output Replicate’s API supports server-sent event streams (SSEs) for language models. diff --git a/replicate/exceptions.py b/replicate/exceptions.py index 6302d10f..f52f9fb4 100644 --- a/replicate/exceptions.py +++ b/replicate/exceptions.py @@ -1,7 +1,10 @@ -from typing import Optional +from typing import TYPE_CHECKING, Optional import httpx +if TYPE_CHECKING: + from replicate.prediction import Prediction + class ReplicateException(Exception): """A base class for all Replicate exceptions.""" @@ -10,6 +13,12 @@ class ReplicateException(Exception): class ModelError(ReplicateException): """An error from user's code in a model.""" + prediction: "Prediction" + + def __init__(self, prediction: "Prediction") -> None: + self.prediction = prediction + super().__init__(prediction.error) + class ReplicateError(ReplicateException): """ diff --git a/replicate/prediction.py b/replicate/prediction.py index 871566d7..74c1946e 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -249,7 +249,7 @@ def output_iterator(self) -> Iterator[Any]: self.reload() if self.status == "failed": - raise ModelError(self.error) + raise ModelError(self) output = self.output or [] new_output = output[len(previous_output) :] @@ -272,7 +272,7 @@ async def async_output_iterator(self) -> 
AsyncIterator[Any]: await self.async_reload() if self.status == "failed": - raise ModelError(self.error) + raise ModelError(self) output = self.output or [] new_output = output[len(previous_output) :] diff --git a/replicate/run.py b/replicate/run.py index 975cc4dc..ae1ca7e5 100644 --- a/replicate/run.py +++ b/replicate/run.py @@ -58,7 +58,7 @@ def run( prediction.wait() if prediction.status == "failed": - raise ModelError(prediction.error) + raise ModelError(prediction) return prediction.output @@ -97,7 +97,7 @@ async def async_run( await prediction.async_wait() if prediction.status == "failed": - raise ModelError(prediction.error) + raise ModelError(prediction) return prediction.output diff --git a/tests/test_run.py b/tests/test_run.py index 84c8f3ab..d117eb32 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -7,7 +7,7 @@ import replicate from replicate.client import Client -from replicate.exceptions import ReplicateError +from replicate.exceptions import ModelError, ReplicateError @pytest.mark.vcr("run.yaml") @@ -184,3 +184,72 @@ def prediction_with_status(status: str) -> dict: ) assert output == "Hello, world!" 
+ + +@pytest.mark.asyncio +async def test_run_with_model_error(mock_replicate_api_token): + def prediction_with_status(status: str) -> dict: + return { + "id": "p1", + "model": "test/example", + "version": "v1", + "urls": { + "get": "https://api.replicate.com/v1/predictions/p1", + "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", + }, + "created_at": "2023-10-05T12:00:00.000000Z", + "source": "api", + "status": status, + "input": {"text": "world"}, + "output": None, + "error": "OOM" if status == "failed" else None, + "logs": "", + } + + router = respx.Router(base_url="https://api.replicate.com/v1") + router.route(method="POST", path="/predictions").mock( + return_value=httpx.Response( + 201, + json=prediction_with_status("processing"), + ) + ) + router.route(method="GET", path="/predictions/p1").mock( + return_value=httpx.Response( + 200, + json=prediction_with_status("failed"), + ) + ) + router.route( + method="GET", + path="/models/test/example/versions/v1", + ).mock( + return_value=httpx.Response( + 201, + json={ + "id": "f2d6b24e6002f25f77ae89c2b0a5987daa6d0bf751b858b94b8416e8542434d1", + "created_at": "2024-07-18T00:35:56.210272Z", + "cog_version": "0.9.10", + "openapi_schema": { + "openapi": "3.0.2", + }, + }, + ) + ) + router.route(host="api.replicate.com").pass_through() + + client = Client( + api_token="test-token", transport=httpx.MockTransport(router.handler) + ) + client.poll_interval = 0.001 + + with pytest.raises(ModelError) as excinfo: + client.run( + "test/example:v1", + input={ + "text": "Hello, world!", + }, + ) + + assert str(excinfo.value) == "OOM" + assert excinfo.value.prediction.error == "OOM" + assert excinfo.value.prediction.status == "failed"