diff --git a/README.md b/README.md index 011ea98a..9241fdb6 100644 --- a/README.md +++ b/README.md @@ -25,88 +25,103 @@ We recommend not adding the token directly to your source code, because you don' Create a new Python file and add the following code: ```python -import replicate -model = replicate.models.get("stability-ai/stable-diffusion") -version = model.versions.get("27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478") -version.predict(prompt="a 19th century portrait of a wombat gentleman") +>>> import replicate +>>> replicate.run( + "stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478", + input={"prompt": "a 19th century portrait of a wombat gentleman"} + ) -# ['https://replicate.com/api/models/stability-ai/stable-diffusion/files/50fcac81-865d-499e-81ac-49de0cb79264/out-0.png'] +['https://replicate.com/api/models/stability-ai/stable-diffusion/files/50fcac81-865d-499e-81ac-49de0cb79264/out-0.png'] ``` Some models, like [methexis-inc/img2prompt](https://replicate.com/methexis-inc/img2prompt), receive images as inputs. 
To pass a file as an input, use a file handle or URL: ```python -model = replicate.models.get("methexis-inc/img2prompt") -version = model.versions.get("50adaf2d3ad20a6f911a8a9e3ccf777b263b8596fbd2c8fc26e8888f8a0edbb5") -inputs = { - "image": open("path/to/mystery.jpg", "rb"), -} -output = version.predict(**inputs) - -# [['n02123597', 'Siamese_cat', 0.8829364776611328], -# ['n02123394', 'Persian_cat', 0.09810526669025421], -# ['n02123045', 'tabby', 0.005758069921284914]] +>>> output = replicate.run( + "salesforce/blip:2e1dddc8621f72155f24cf2e0adbde548458d3cab9f00c0139eea840d0ac4746", + input={"image": open("path/to/mystery.jpg", "rb")}, + ) + +"an astronaut riding a horse" ``` -## Compose models into a pipeline +## Run a model in the background -You can run a model and feed the output into another model: +You can start a model and run it in the background: ```python -laionide = replicate.models.get("afiaka87/laionide-v4").versions.get("b21cbe271e65c1718f2999b038c18b45e21e4fba961181fbfae9342fc53b9e05") -swinir = replicate.models.get("jingyunliang/swinir").versions.get("660d922d33153019e8c263a3bba265de882e7f4f70396546b6c9c8f9d47a021a") -image = laionide.predict(prompt="avocado armchair") -upscaled_image = swinir.predict(image=image) -``` +>>> model = replicate.models.get("kvfrans/clipdraw") +>>> version = model.versions.get("5797a99edc939ea0e9242d5e8c9cb3bc7d125b1eac21bda852e5cb79ede2cd9b") +>>> prediction = replicate.predictions.create( + version=version, + input={"prompt":"Watercolor painting of an underwater submarine"}) -## Get output from a running model +>>> prediction +Prediction(...) 
-Run a model and get its output while it's running: +>>> prediction.status +'starting' -```python -model = replicate.models.get("pixray/text2image") -version = model.versions.get("5c347a4bfa1d4523a58ae614c2194e15f2ae682b57e3797a5bb468920aa70ebf") -for image in version.predict(prompts="san francisco sunset"): - display(image) +>>> dict(prediction) +{"id": "...", "status": "starting", ...} + +>>> prediction.reload() +>>> prediction.status +'processing' + +>>> print(prediction.logs) +iteration: 0, render:loss: -0.6171875 +iteration: 10, render:loss: -0.92236328125 +iteration: 20, render:loss: -1.197265625 +iteration: 30, render:loss: -1.3994140625 + +>>> prediction.wait() + +>>> prediction.status +'succeeded' + +>>> prediction.output +'https://.../output.png' ``` -## Run a model in the background +## Run a model in the background and get a webhook -You can start a model and run it in the background: +You can run a model and get a webhook when it completes, instead of waiting for it to finish: ```python model = replicate.models.get("kvfrans/clipdraw") version = model.versions.get("5797a99edc939ea0e9242d5e8c9cb3bc7d125b1eac21bda852e5cb79ede2cd9b") prediction = replicate.predictions.create( version=version, - input={"prompt":"Watercolor painting of an underwater submarine"}) - -# >>> prediction -# Prediction(...) 
+ input={"prompt":"Watercolor painting of an underwater submarine"}, + webhook="https://example.com/your-webhook", + webhook_events_filter=["completed"] +) +``` -# >>> prediction.status -# 'starting' +## Compose models into a pipeline -# >>> dict(prediction) -# {"id": "...", "status": "starting", ...} +You can run a model and feed the output into another model: -# >>> prediction.reload() -# >>> prediction.status -# 'processing' +```python +laionide = replicate.models.get("afiaka87/laionide-v4").versions.get("b21cbe271e65c1718f2999b038c18b45e21e4fba961181fbfae9342fc53b9e05") +swinir = replicate.models.get("jingyunliang/swinir").versions.get("660d922d33153019e8c263a3bba265de882e7f4f70396546b6c9c8f9d47a021a") +image = laionide.predict(prompt="avocado armchair") +upscaled_image = swinir.predict(image=image) +``` -# >>> print(prediction.logs) -# iteration: 0, render:loss: -0.6171875 -# iteration: 10, render:loss: -0.92236328125 -# iteration: 20, render:loss: -1.197265625 -# iteration: 30, render:loss: -1.3994140625 +## Get output from a running model -# >>> prediction.wait() +Run a model and get its output while it's running: -# >>> prediction.status -# 'succeeded' +```python +iterator = replicate.run( + "pixray/text2image:5c347a4bfa1d4523a58ae614c2194e15f2ae682b57e3797a5bb468920aa70ebf", + input={"prompts": "san francisco sunset"} +) -# >>> prediction.output -# 'https://.../output.png' +for image in iterator: + display(image) ``` ## Cancel a prediction @@ -114,20 +129,21 @@ prediction = replicate.predictions.create( You can cancel a running prediction: ```python -model = replicate.models.get("kvfrans/clipdraw") -version = model.versions.get("5797a99edc939ea0e9242d5e8c9cb3bc7d125b1eac21bda852e5cb79ede2cd9b") -prediction = replicate.predictions.create( - version=version, - input={"prompt":"Watercolor painting of an underwater submarine"}) +>>> model = replicate.models.get("kvfrans/clipdraw") +>>> version = 
model.versions.get("5797a99edc939ea0e9242d5e8c9cb3bc7d125b1eac21bda852e5cb79ede2cd9b") +>>> prediction = replicate.predictions.create( + version=version, + input={"prompt":"Watercolor painting of an underwater submarine"} + ) -# >>> prediction.status -# 'starting' +>>> prediction.status +'starting' -# >>> prediction.cancel() +>>> prediction.cancel() -# >>> prediction.reload() -# >>> prediction.status -# 'canceled' +>>> prediction.reload() +>>> prediction.status +'canceled' ``` ## List predictions diff --git a/replicate/__init__.py b/replicate/__init__.py index 1e5a502a..34db2000 100644 --- a/replicate/__init__.py +++ b/replicate/__init__.py @@ -2,5 +2,6 @@ from .client import Client default_client = Client() +run = default_client.run models = default_client.models predictions = default_client.predictions diff --git a/replicate/client.py b/replicate/client.py index 3e3694d5..edeadaf0 100644 --- a/replicate/client.py +++ b/replicate/client.py @@ -1,11 +1,13 @@ import os +import re from json import JSONDecodeError +from typing import Any, Iterator, Union import requests from requests.adapters import HTTPAdapter, Retry from replicate.__about__ import __version__ -from replicate.exceptions import ReplicateError +from replicate.exceptions import ModelError, ReplicateError from replicate.model import ModelCollection from replicate.prediction import PredictionCollection @@ -35,7 +37,20 @@ def __init__(self, api_token=None) -> None: # TODO: Only retry on GET so we don't unintionally mutute data method_whitelist=["GET", "POST", "PUT"], # https://support.cloudflare.com/hc/en-us/articles/115003011431-Troubleshooting-Cloudflare-5XX-errors - status_forcelist=[429, 500, 502, 503, 504, 520, 521, 522, 523, 524, 526, 527], + status_forcelist=[ + 429, + 500, + 502, + 503, + 504, + 520, + 521, + 522, + 523, + 524, + 526, + 527, + ], ) self.session.mount("http://", HTTPAdapter(max_retries=retries)) @@ -84,3 +99,30 @@ def models(self) -> ModelCollection: @property def predictions(self) 
-> PredictionCollection: return PredictionCollection(client=self) + + def run(self, model_version, **kwargs) -> Union[Any, Iterator[Any]]: + """ + Run a model in the format owner/name:version. + """ + # Split model_version into owner, name, version in format owner/name:version + m = re.match(r"^(?P<model>[^/]+/[^:]+):(?P<version>.+)$", model_version) + if not m: + raise ReplicateError( + f"Invalid model_version: {model_version}. Expected format: owner/name:version" + ) + model = self.models.get(m.group("model")) + version = model.versions.get(m.group("version")) + prediction = self.predictions.create(version=version, **kwargs) + # Return an iterator of the output + schema = version.get_transformed_schema() + output = schema["components"]["schemas"]["Output"] + if ( + output.get("type") == "array" + and output.get("x-cog-array-type") == "iterator" + ): + return prediction.output_iterator() + + prediction.wait() + if prediction.status == "failed": + raise ModelError(prediction.error) + return prediction.output diff --git a/replicate/version.py b/replicate/version.py index cc4cbd0c..cf3376a8 100644 --- a/replicate/version.py +++ b/replicate/version.py @@ -1,4 +1,5 @@ import datetime +import warnings from typing import Any, Iterator, List, Union from replicate.base_model import BaseModel @@ -14,10 +15,13 @@ class Version(BaseModel): openapi_schema: Any def predict(self, **kwargs) -> Union[Any, Iterator[Any]]: - # TODO: support args + warnings.warn( + "version.predict() is deprecated. Use replicate.run() instead. It will be removed before version 1.0.", + DeprecationWarning, + ) + prediction = self._client.predictions.create(version=self, input=kwargs) # Return an iterator of the output - # FIXME: might just be a list, not an iterator. I wonder if we should differentiate?
schema = self.get_transformed_schema() output = schema["components"]["schemas"]["Output"] if ( diff --git a/tests/factories.py b/tests/factories.py index 6379b8fe..349fb7f4 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -1,5 +1,8 @@ import datetime +import responses +from responses import matchers + from replicate.client import Client from replicate.version import Version @@ -9,148 +12,195 @@ def create_client(): return client -def create_version(client=None, openapi_schema=None, cog_version="0.3.0"): - if client is None: - client = create_client() - version = Version( - id="v1", - created_at=datetime.datetime.now(), - cog_version=cog_version, - openapi_schema=openapi_schema - or { - "info": {"title": "Cog", "version": "0.1.0"}, - "paths": { - "/": { - "get": { - "summary": "Root", - "responses": { - "200": { - "content": {"application/json": {"schema": {}}}, - "description": "Successful Response", - } - }, - "operationId": "root__get", - } - }, - "/predictions": { - "post": { - "summary": "Predict", - "responses": { - "200": { - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/Response" - } - } - }, - "description": "Successful Response", - }, - "422": { - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/HTTPValidationError" - } - } - }, - "description": "Validation Error", +def get_mock_schema(): + return { + "info": {"title": "Cog", "version": "0.1.0"}, + "paths": { + "/": { + "get": { + "summary": "Root", + "responses": { + "200": { + "content": {"application/json": {"schema": {}}}, + "description": "Successful Response", + } + }, + "operationId": "root__get", + } + }, + "/predictions": { + "post": { + "summary": "Predict", + "responses": { + "200": { + "content": { + "application/json": { + "schema": {"$ref": "#/components/schemas/Response"} + } }, + "description": "Successful Response", }, - "description": "Run a single prediction on the model", - "operationId": 
"predict_predictions_post", - "requestBody": { + "422": { "content": { "application/json": { - "schema": {"$ref": "#/components/schemas/Request"} + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } } - } + }, + "description": "Validation Error", }, - } - }, - }, - "openapi": "3.0.2", - "components": { - "schemas": { - "Input": { - "type": "object", - "title": "Input", - "required": ["text"], - "properties": { - "text": { - "type": "string", - "title": "Text", - "x-order": 0, - "description": "Text to prefix with 'hello '", + }, + "description": "Run a single prediction on the model", + "operationId": "predict_predictions_post", + "requestBody": { + "content": { + "application/json": { + "schema": {"$ref": "#/components/schemas/Request"} } - }, + } }, - "Output": {"type": "string", "title": "Output"}, - "Status": { - "enum": ["processing", "succeeded", "failed"], - "type": "string", - "title": "Status", - "description": "An enumeration.", + } + }, + }, + "openapi": "3.0.2", + "components": { + "schemas": { + "Input": { + "type": "object", + "title": "Input", + "required": ["text"], + "properties": { + "text": { + "type": "string", + "title": "Text", + "x-order": 0, + "description": "Text to prefix with 'hello '", + } }, - "Request": { - "type": "object", - "title": "Request", - "properties": { - "input": {"$ref": "#/components/schemas/Input"}, - "output_file_prefix": { - "type": "string", - "title": "Output File Prefix", - }, + }, + "Output": {"type": "string", "title": "Output"}, + "Status": { + "enum": ["processing", "succeeded", "failed"], + "type": "string", + "title": "Status", + "description": "An enumeration.", + }, + "Request": { + "type": "object", + "title": "Request", + "properties": { + "input": {"$ref": "#/components/schemas/Input"}, + "output_file_prefix": { + "type": "string", + "title": "Output File Prefix", }, - "description": "The request body for a prediction", }, - "Response": { - "type": "object", - "title": "Response", - 
"required": ["status"], - "properties": { - "error": {"type": "string", "title": "Error"}, - "output": {"$ref": "#/components/schemas/Output"}, - "status": {"$ref": "#/components/schemas/Status"}, - }, - "description": "The response body for a prediction", + "description": "The request body for a prediction", + }, + "Response": { + "type": "object", + "title": "Response", + "required": ["status"], + "properties": { + "error": {"type": "string", "title": "Error"}, + "output": {"$ref": "#/components/schemas/Output"}, + "status": {"$ref": "#/components/schemas/Status"}, }, - "ValidationError": { - "type": "object", - "title": "ValidationError", - "required": ["loc", "msg", "type"], - "properties": { - "loc": { - "type": "array", - "items": { - "anyOf": [ - {"type": "string"}, - {"type": "integer"}, - ] - }, - "title": "Location", + "description": "The response body for a prediction", + }, + "ValidationError": { + "type": "object", + "title": "ValidationError", + "required": ["loc", "msg", "type"], + "properties": { + "loc": { + "type": "array", + "items": { + "anyOf": [ + {"type": "string"}, + {"type": "integer"}, + ] }, - "msg": {"type": "string", "title": "Message"}, - "type": {"type": "string", "title": "Error Type"}, + "title": "Location", }, + "msg": {"type": "string", "title": "Message"}, + "type": {"type": "string", "title": "Error Type"}, }, - "HTTPValidationError": { - "type": "object", - "title": "HTTPValidationError", - "properties": { - "detail": { - "type": "array", - "items": { - "$ref": "#/components/schemas/ValidationError" - }, - "title": "Detail", - } - }, + }, + "HTTPValidationError": { + "type": "object", + "title": "HTTPValidationError", + "properties": { + "detail": { + "type": "array", + "items": {"$ref": "#/components/schemas/ValidationError"}, + "title": "Detail", + } }, - } - }, + }, + } + }, + } + + +def mock_version_get( + owner="test", model="model", version="v1", openapi_schema=None, cog_version="0.3.9" +): + responses.get( + 
f"https://api.replicate.com/v1/models/{owner}/{model}/versions/{version}", + match=[ + matchers.header_matcher({"Authorization": "Token abc123"}), + ], + json={ + "id": version, + "created_at": "2022-04-26T19:29:04.418669Z", + "cog_version": cog_version, + "openapi_schema": openapi_schema or get_mock_schema(), }, ) + + +def mock_version_get_with_iterator_output(**kwargs): + schema = get_mock_schema() + schema["components"]["schemas"]["Output"] = { + "type": "array", + "items": {"type": "string"}, + "title": "Output", + "x-cog-array-type": "iterator", + } + mock_version_get(openapi_schema=schema, cog_version="0.3.9", **kwargs) + + +def mock_version_get_with_list_output(**kwargs): + schema = get_mock_schema() + schema["components"]["schemas"]["Output"] = { + "type": "array", + "items": {"type": "string"}, + "title": "Output", + } + mock_version_get(openapi_schema=schema, cog_version="0.3.9", **kwargs) + + +def mock_version_get_with_iterator_output_backwards_compatibility_0_3_8(**kwargs): + schema = get_mock_schema() + schema["components"]["schemas"]["Output"] = { + "type": "array", + "items": {"type": "string"}, + "title": "Output", + } + mock_version_get(openapi_schema=schema, cog_version="0.3.8", **kwargs) + + +def create_version(client=None, openapi_schema=None, cog_version="0.3.0"): + if client is None: + client = create_client() + version = Version( + id="v1", + created_at=datetime.datetime.now(), + cog_version=cog_version, + openapi_schema=openapi_schema or get_mock_schema(), + ) version._client = client return version diff --git a/tests/test_client.py b/tests/test_client.py index 42866f43..1e4d57ce 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,7 +1,19 @@ +from collections.abc import Iterable + +import pytest import responses +from responses import matchers + from replicate.__about__ import __version__ from replicate.client import Client -from responses import matchers +from replicate.exceptions import ModelError + +from .factories import (
mock_version_get, + mock_version_get_with_iterator_output, + mock_version_get_with_iterator_output_backwards_compatibility_0_3_8, + mock_version_get_with_list_output, +) @responses.activate @@ -19,3 +31,260 @@ def test_client_sets_authorization_token_and_user_agent_headers(): ) model.versions.list() + + +@responses.activate +def test_run(): + mock_version_get(owner="test", model="model", version="v1") + responses.post( + "https://api.replicate.com/v1/predictions", + match=[ + matchers.json_params_matcher({"version": "v1", "input": {"text": "world"}}) + ], + json={ + "id": "p1", + "version": "v1", + "urls": { + "get": "https://api.replicate.com/v1/predictions/p1", + "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", + }, + "created_at": "2022-04-26T20:00:40.658234Z", + "completed_at": "2022-04-26T20:02:27.648305Z", + "source": "api", + "status": "processing", + "input": {"text": "world"}, + "output": None, + "error": None, + "logs": "", + }, + ) + responses.get( + "https://api.replicate.com/v1/predictions/p1", + json={ + "id": "p1", + "version": "v1", + "urls": { + "get": "https://api.replicate.com/v1/predictions/p1", + "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", + }, + "created_at": "2022-04-26T20:00:40.658234Z", + "completed_at": "2022-04-26T20:02:27.648305Z", + "source": "api", + "status": "succeeded", + "input": {"text": "world"}, + "output": "hello world", + "error": None, + "logs": "", + }, + ) + + client = Client(api_token="abc123") + assert client.run("test/model:v1", input={"text": "world"}) == "hello world" + + +@responses.activate +def test_run_with_iterator(): + mock_version_get_with_iterator_output(owner="test", model="model", version="v1") + responses.post( + "https://api.replicate.com/v1/predictions", + match=[ + matchers.json_params_matcher({"version": "v1", "input": {"text": "world"}}) + ], + json={ + "id": "p1", + "version": "v1", + "urls": { + "get": "https://api.replicate.com/v1/predictions/p1", + "cancel": 
"https://api.replicate.com/v1/predictions/p1/cancel", + }, + "created_at": "2022-04-26T20:00:40.658234Z", + "completed_at": "2022-04-26T20:02:27.648305Z", + "source": "api", + "status": "processing", + "input": {"text": "world"}, + "output": None, + "error": None, + "logs": "", + }, + ) + responses.get( + "https://api.replicate.com/v1/predictions/p1", + json={ + "id": "p1", + "version": "v1", + "urls": { + "get": "https://api.replicate.com/v1/predictions/p1", + "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", + }, + "created_at": "2022-04-26T20:00:40.658234Z", + "completed_at": "2022-04-26T20:02:27.648305Z", + "source": "api", + "status": "succeeded", + "input": {"text": "world"}, + "output": ["hello world"], + "error": None, + "logs": "", + }, + ) + + client = Client(api_token="abc123") + output = client.run("test/model:v1", input={"text": "world"}) + assert isinstance(output, Iterable) + assert list(output) == ["hello world"] + + +@responses.activate +def test_run_with_list(): + mock_version_get_with_list_output(owner="test", model="model", version="v1") + responses.post( + "https://api.replicate.com/v1/predictions", + match=[ + matchers.json_params_matcher({"version": "v1", "input": {"text": "world"}}) + ], + json={ + "id": "p1", + "version": "v1", + "urls": { + "get": "https://api.replicate.com/v1/predictions/p1", + "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", + }, + "created_at": "2022-04-26T20:00:40.658234Z", + "completed_at": "2022-04-26T20:02:27.648305Z", + "source": "api", + "status": "processing", + "input": {"text": "world"}, + "output": None, + "error": None, + "logs": "", + }, + ) + responses.get( + "https://api.replicate.com/v1/predictions/p1", + json={ + "id": "p1", + "version": "v1", + "urls": { + "get": "https://api.replicate.com/v1/predictions/p1", + "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", + }, + "created_at": "2022-04-26T20:00:40.658234Z", + "completed_at": "2022-04-26T20:02:27.648305Z", + 
"source": "api", + "status": "succeeded", + "input": {"text": "world"}, + "output": ["hello world"], + "error": None, + "logs": "", + }, + ) + + client = Client(api_token="abc123") + output = client.run("test/model:v1", input={"text": "world"}) + assert isinstance(output, list) + assert output == ["hello world"] + + +@responses.activate +def test_run_with_iterator_backwards_compatibility_cog_0_3_8(): + mock_version_get_with_iterator_output_backwards_compatibility_0_3_8( + owner="test", model="model", version="v1" + ) + responses.post( + "https://api.replicate.com/v1/predictions", + match=[ + matchers.json_params_matcher({"version": "v1", "input": {"text": "world"}}) + ], + json={ + "id": "p1", + "version": "v1", + "urls": { + "get": "https://api.replicate.com/v1/predictions/p1", + "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", + }, + "created_at": "2022-04-26T20:00:40.658234Z", + "completed_at": "2022-04-26T20:02:27.648305Z", + "source": "api", + "status": "processing", + "input": {"text": "world"}, + "output": None, + "error": None, + "logs": "", + }, + ) + responses.get( + "https://api.replicate.com/v1/predictions/p1", + json={ + "id": "p1", + "version": "v1", + "urls": { + "get": "https://api.replicate.com/v1/predictions/p1", + "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", + }, + "created_at": "2022-04-26T20:00:40.658234Z", + "completed_at": "2022-04-26T20:02:27.648305Z", + "source": "api", + "status": "succeeded", + "input": {"text": "world"}, + "output": ["hello world"], + "error": None, + "logs": "", + }, + ) + + client = Client(api_token="abc123") + output = client.run("test/model:v1", input={"text": "world"}) + assert isinstance(output, Iterable) + assert list(output) == ["hello world"] + + +@responses.activate +def test_predict_with_iterator_with_failed_prediction(): + mock_version_get_with_iterator_output(owner="test", model="model", version="v1") + responses.post( + "https://api.replicate.com/v1/predictions", + match=[ 
+ matchers.json_params_matcher({"version": "v1", "input": {"text": "world"}}) + ], + json={ + "id": "p1", + "version": "v1", + "urls": { + "get": "https://api.replicate.com/v1/predictions/p1", + "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", + }, + "created_at": "2022-04-26T20:00:40.658234Z", + "completed_at": "2022-04-26T20:02:27.648305Z", + "source": "api", + "status": "processing", + "input": {"text": "world"}, + "output": None, + "error": None, + "logs": "", + }, + ) + responses.get( + "https://api.replicate.com/v1/predictions/p1", + json={ + "id": "p1", + "version": "v1", + "urls": { + "get": "https://api.replicate.com/v1/predictions/p1", + "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", + }, + "created_at": "2022-04-26T20:00:40.658234Z", + "completed_at": "2022-04-26T20:02:27.648305Z", + "source": "api", + "status": "failed", + "input": {"text": "world"}, + "output": None, + "error": "it broke", + "logs": "", + }, + ) + + client = Client(api_token="abc123") + output = client.run("test/model:v1", input={"text": "world"}) + assert isinstance(output, Iterable) + with pytest.raises(ModelError) as excinfo: + list(output) + assert "it broke" in str(excinfo.value)