Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,24 @@ or a handle to a file on your local device.
"an astronaut riding a horse"
```

`replicate.run` raises `ModelError` if the prediction fails.
You can access the exception's `prediction` property
to get more information about the failure.

```python
import replicate
from replicate.exceptions import ModelError

try:
    output = replicate.run("stability-ai/stable-diffusion-3", { "prompt": "An astronaut riding a rainbow unicorn" })
except ModelError as e:
    if "(some known issue)" in e.prediction.logs:
        pass

    print("Failed prediction: " + e.prediction.id)
```


## Run a model and stream its output

Replicate’s API supports server-sent event streams (SSEs) for language models.
Expand Down
11 changes: 10 additions & 1 deletion replicate/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from typing import Optional
from typing import TYPE_CHECKING, Optional

import httpx

if TYPE_CHECKING:
from replicate.prediction import Prediction


class ReplicateException(Exception):
"""A base class for all Replicate exceptions."""
Expand All @@ -10,6 +13,12 @@ class ReplicateException(Exception):
class ModelError(ReplicateException):
"""An error from user's code in a model."""

prediction: "Prediction"

def __init__(self, prediction: "Prediction") -> None:
self.prediction = prediction
super().__init__(prediction.error)


class ReplicateError(ReplicateException):
"""
Expand Down
4 changes: 2 additions & 2 deletions replicate/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def output_iterator(self) -> Iterator[Any]:
self.reload()

if self.status == "failed":
raise ModelError(self.error)
raise ModelError(self)

output = self.output or []
new_output = output[len(previous_output) :]
Expand All @@ -272,7 +272,7 @@ async def async_output_iterator(self) -> AsyncIterator[Any]:
await self.async_reload()

if self.status == "failed":
raise ModelError(self.error)
raise ModelError(self)

output = self.output or []
new_output = output[len(previous_output) :]
Expand Down
4 changes: 2 additions & 2 deletions replicate/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def run(
prediction.wait()

if prediction.status == "failed":
raise ModelError(prediction.error)
raise ModelError(prediction)

return prediction.output

Expand Down Expand Up @@ -97,7 +97,7 @@ async def async_run(
await prediction.async_wait()

if prediction.status == "failed":
raise ModelError(prediction.error)
raise ModelError(prediction)

return prediction.output

Expand Down
71 changes: 70 additions & 1 deletion tests/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import replicate
from replicate.client import Client
from replicate.exceptions import ReplicateError
from replicate.exceptions import ModelError, ReplicateError


@pytest.mark.vcr("run.yaml")
Expand Down Expand Up @@ -184,3 +184,72 @@ def prediction_with_status(status: str) -> dict:
)

assert output == "Hello, world!"


@pytest.mark.asyncio
async def test_run_with_model_error(mock_replicate_api_token):
def prediction_with_status(status: str) -> dict:
return {
"id": "p1",
"model": "test/example",
"version": "v1",
"urls": {
"get": "https://api.replicate.com/v1/predictions/p1",
"cancel": "https://api.replicate.com/v1/predictions/p1/cancel",
},
"created_at": "2023-10-05T12:00:00.000000Z",
"source": "api",
"status": status,
"input": {"text": "world"},
"output": None,
"error": "OOM" if status == "failed" else None,
"logs": "",
}

router = respx.Router(base_url="https://api.replicate.com/v1")
router.route(method="POST", path="/predictions").mock(
return_value=httpx.Response(
201,
json=prediction_with_status("processing"),
)
)
router.route(method="GET", path="/predictions/p1").mock(
return_value=httpx.Response(
200,
json=prediction_with_status("failed"),
)
)
router.route(
method="GET",
path="/models/test/example/versions/v1",
).mock(
return_value=httpx.Response(
201,
json={
"id": "f2d6b24e6002f25f77ae89c2b0a5987daa6d0bf751b858b94b8416e8542434d1",
"created_at": "2024-07-18T00:35:56.210272Z",
"cog_version": "0.9.10",
"openapi_schema": {
"openapi": "3.0.2",
},
},
)
)
router.route(host="api.replicate.com").pass_through()

client = Client(
api_token="test-token", transport=httpx.MockTransport(router.handler)
)
client.poll_interval = 0.001

with pytest.raises(ModelError) as excinfo:
client.run(
"test/example:v1",
input={
"text": "Hello, world!",
},
)

assert str(excinfo.value) == "OOM"
assert excinfo.value.prediction.error == "OOM"
assert excinfo.value.prediction.status == "failed"