From 2c27b99beb0542741f777c12745330e295337154 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Mon, 22 Apr 2024 04:12:17 -0700 Subject: [PATCH 1/8] Add overloads for predictions.create with version, model, or deployment Signed-off-by: Mattt Zmuda --- replicate/prediction.py | 122 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 120 insertions(+), 2 deletions(-) diff --git a/replicate/prediction.py b/replicate/prediction.py index 2d59791e..3b23f081 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -11,7 +11,9 @@ List, Literal, Optional, + Tuple, Union, + overload, ) from typing_extensions import NotRequired, TypedDict, Unpack @@ -31,6 +33,8 @@ if TYPE_CHECKING: from replicate.client import Client + from replicate.deployment import Deployment + from replicate.model import Model from replicate.stream import ServerSentEvent @@ -380,21 +384,78 @@ class CreatePredictionParams(TypedDict): stream: NotRequired[bool] """Enable streaming of prediction output.""" + @overload def create( self, version: Union[Version, str], input: Optional[Dict[str, Any]], **params: Unpack["Predictions.CreatePredictionParams"], + ) -> Prediction: ... + + @overload + def create( + self, + *, + model: Union[str, Tuple[str, str], "Model"], + input: Optional[Dict[str, Any]], + **params: Unpack["Predictions.CreatePredictionParams"], + ) -> Prediction: ... + + @overload + def create( + self, + *, + deployment: Union[str, Tuple[str, str], "Deployment"], + input: Optional[Dict[str, Any]], + **params: Unpack["Predictions.CreatePredictionParams"], + ) -> Prediction: ... + + def create( # type: ignore + self, + *args, + model: Optional[Union[str, Tuple[str, str], "Model"]] = None, + version: Optional[Union[Version, str, "Version"]] = None, + deployment: Optional[Union[str, Tuple[str, str], "Deployment"]] = None, + input: Optional[Dict[str, Any]] = None, + **params: Unpack["Predictions.CreatePredictionParams"], ) -> Prediction: """ - Create a new prediction for the specified model version. + Create a new prediction for the specified model, version, or deployment. """ + if args: + version = args[0] if len(args) > 0 else None + input = args[1] if len(args) > 1 else input + + if sum(bool(x) for x in [model, version, deployment]) != 1: + raise ValueError( + "Exactly one of 'model', 'version', or 'deployment' must be specified." + ) + + if model is not None: + from replicate.model import Models + + return Models(self._client).predictions.create( + model=model, + input=input or {}, + **params, + ) + + if deployment is not None: + from replicate.deployment import Deployments + + return Deployments(self._client).predictions.create( + deployment=deployment, + input=input or {}, + **params, + ) + body = _create_prediction_body( version, input, **params, ) + resp = self._client._request( "POST", "/v1/predictions", @@ -403,21 +464,78 @@ def create( return _json_to_prediction(self._client, resp.json()) + @overload async def async_create( self, version: Union[Version, str], input: Optional[Dict[str, Any]], **params: Unpack["Predictions.CreatePredictionParams"], + ) -> Prediction: ... + + @overload + async def async_create( + self, + *, + model: Union[str, Tuple[str, str], "Model"], + input: Optional[Dict[str, Any]], + **params: Unpack["Predictions.CreatePredictionParams"], + ) -> Prediction: ... + + @overload + async def async_create( + self, + *, + deployment: Union[str, Tuple[str, str], "Deployment"], + input: Optional[Dict[str, Any]], + **params: Unpack["Predictions.CreatePredictionParams"], + ) -> Prediction: ... + + async def async_create( # type: ignore + self, + *args, + model: Optional[Union[str, Tuple[str, str], "Model"]] = None, + version: Optional[Union[Version, str, "Version"]] = None, + deployment: Optional[Union[str, Tuple[str, str], "Deployment"]] = None, + input: Optional[Dict[str, Any]] = None, + **params: Unpack["Predictions.CreatePredictionParams"], ) -> Prediction: """ - Create a new prediction for the specified model version. + Create a new prediction for the specified model, version, or deployment. """ + if args: + version = args[0] if len(args) > 0 else None + input = args[1] if len(args) > 1 else input + + if sum(bool(x) for x in [model, version, deployment]) != 1: + raise ValueError( + "Exactly one of 'model', 'version', or 'deployment' must be specified." + ) + + if model is not None: + from replicate.model import Models + + return await Models(self._client).predictions.async_create( + model=model, + input=input or {}, + **params, + ) + + if deployment is not None: + from replicate.deployment import Deployments + + return await Deployments(self._client).predictions.async_create( + deployment=deployment, + input=input or {}, + **params, + ) + body = _create_prediction_body( version, input, **params, ) + resp = await self._client._async_request( "POST", "/v1/predictions", From a007e94c7af916cb324ea29ce0e68b38520ac785 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Mon, 22 Apr 2024 04:13:02 -0700 Subject: [PATCH 2/8] Fix 'RuntimeError: Event loop is closed' in tests Signed-off-by: Mattt Zmuda --- tests/conftest.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index a29ed640..103d1693 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,17 @@ +import asyncio import os from unittest import mock import pytest +import pytest_asyncio + + +@pytest_asyncio.fixture(scope="session", autouse=True) +def event_loop(): + event_loop_policy = asyncio.get_event_loop_policy() + loop = event_loop_policy.new_event_loop() + yield loop + loop.close() @pytest.fixture(scope="session") From 26b7d2739ad6040847cf810fb74208967d7beb7d Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Mon, 22 Apr 2024 04:13:39 -0700 Subject: [PATCH 3/8] Update version of SDXL Signed-off-by: Mattt Zmuda --- tests/test_prediction.py | 14 +++++++------- tests/test_run.py | 2 +- tests/test_training.py | 18 +++++++++--------- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/tests/test_prediction.py b/tests/test_prediction.py index c64a5989..7a7b824a 100644 --- a/tests/test_prediction.py +++ b/tests/test_prediction.py @@ -17,7 +17,7 @@ async def test_predictions_create(async_flag): if async_flag: model = await replicate.models.async_get("stability-ai/sdxl") version = await model.versions.async_get( - "a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5" + "39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b" ) prediction = await replicate.predictions.async_create( version=version, @@ -26,7 +26,7 @@ async def test_predictions_create(async_flag): else: model = replicate.models.get("stability-ai/sdxl") version = model.versions.get( - "a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5" + "39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b" ) prediction = replicate.predictions.create( version=version, @@ -42,7 +42,7 @@ async def test_predictions_create(async_flag): @pytest.mark.asyncio @pytest.mark.parametrize("async_flag", [True, False]) async def test_predictions_create_with_positional_argument(async_flag): - version = "a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5" + version = "39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b" input = { "prompt": "a studio photo of a rainbow colored corgi", @@ -95,7 +95,7 @@ async def test_predictions_cancel(async_flag): if async_flag: model = await replicate.models.async_get("stability-ai/sdxl") version = await model.versions.async_get( - "a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5" + "39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b" ) prediction = await replicate.predictions.async_create( version=version, @@ -111,7 +111,7 @@ async def test_predictions_cancel(async_flag): else: model = replicate.models.get("stability-ai/sdxl") version = model.versions.get( - "a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5" + "39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b" ) prediction = replicate.predictions.create( version=version, @@ -140,7 +140,7 @@ async def test_predictions_cancel_instance_method(async_flag): if async_flag: model = await replicate.models.async_get("stability-ai/sdxl") version = await model.versions.async_get( - "a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5" + "39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b" ) prediction = await replicate.predictions.async_create( version=version, @@ -154,7 +154,7 @@ async def test_predictions_cancel_instance_method(async_flag): else: model = replicate.models.get("stability-ai/sdxl") version = model.versions.get( - "a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5" + "39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b" ) prediction = replicate.predictions.create( version=version, diff --git a/tests/test_run.py b/tests/test_run.py index 00c93cbc..84c8f3ab 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -17,7 +17,7 @@ async def test_run(async_flag, record_mode): if record_mode == "none": replicate.default_client.poll_interval = 0.001 - version = "a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5" + version = "39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b" input = { "prompt": "a studio photo of a rainbow colored corgi", diff --git a/tests/test_training.py b/tests/test_training.py index 1955ffe6..64926c64 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -13,7 +13,7 @@ async def test_trainings_create(async_flag, mock_replicate_api_token): if async_flag: training = await replicate.trainings.async_create( model="stability-ai/sdxl", - version="a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + version="39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", input={ "input_images": input_images_url, "use_face_detection_instead": True, @@ -23,7 +23,7 @@ async def test_trainings_create(async_flag, mock_replicate_api_token): else: training = replicate.trainings.create( model="stability-ai/sdxl", - version="a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + version="39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", input={ "input_images": input_images_url, "use_face_detection_instead": True, @@ -47,7 +47,7 @@ async def test_trainings_create_with_named_version_argument( return else: training = replicate.trainings.create( - version="stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + version="stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", input={ "input_images": input_images_url, "use_face_detection_instead": True, @@ -71,7 +71,7 @@ async def test_trainings_create_with_positional_argument( return else: training = replicate.trainings.create( - "stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", { "input_images": input_images_url, "use_face_detection_instead": True, @@ -93,7 +93,7 @@ async def test_trainings_create_with_invalid_destination( if async_flag: await replicate.trainings.async_create( model="stability-ai/sdxl", - version="a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + version="39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", input={ "input_images": input_images_url, "use_face_detection_instead": True, @@ -103,7 +103,7 @@ async def test_trainings_create_with_invalid_destination( else: replicate.trainings.create( model="stability-ai/sdxl", - version="a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + version="39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", input={ "input_images": input_images_url, }, @@ -140,7 +140,7 @@ async def test_trainings_cancel(async_flag, mock_replicate_api_token): if async_flag: training = await replicate.trainings.async_create( model="stability-ai/sdxl", - version="a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + version="39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", input=input, destination=destination, ) @@ -151,7 +151,7 @@ async def test_trainings_cancel(async_flag, mock_replicate_api_token): assert training.status == "canceled" else: training = replicate.trainings.create( - version="stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + version="stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", destination=destination, input=input, ) @@ -179,7 +179,7 @@ async def test_trainings_cancel_instance_method(async_flag, mock_replicate_api_t return else: training = replicate.trainings.create( - version="stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + version="stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", destination=destination, input=input, ) From 4ef74980b70cfea7e6d5a1a06f2a1b61be8fb3b9 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Mon, 22 Apr 2024 04:13:57 -0700 Subject: [PATCH 4/8] Add test coverage for predictions.create overloads Signed-off-by: Mattt Zmuda --- tests/test_prediction.py | 96 +++++++++++++++++++++++++++++++++++++--- 1 file changed, 89 insertions(+), 7 deletions(-) diff --git a/tests/test_prediction.py b/tests/test_prediction.py index 7a7b824a..7e04b8fa 100644 --- a/tests/test_prediction.py +++ b/tests/test_prediction.py @@ -1,4 +1,6 @@ +import httpx import pytest +import respx import replicate @@ -67,21 +69,86 @@ async def test_predictions_create_with_positional_argument(async_flag): assert prediction.status == "starting" -@pytest.mark.vcr("predictions-get.yaml") +@pytest.mark.vcr("predictions-create-by-model.yaml") @pytest.mark.asyncio @pytest.mark.parametrize("async_flag", [True, False]) -async def test_predictions_get(async_flag): - id = "vgcm4plb7tgzlyznry5d5jkgvu" +async def test_predictions_create_by_model(async_flag): + model = "meta/meta-llama-3-8b-instruct" + input = { + "prompt": "write a haiku about llamas", + } if async_flag: - prediction = await replicate.predictions.async_get(id) + prediction = await replicate.predictions.async_create( + model=model, + input=input, + ) else: - prediction = replicate.predictions.get(id) + prediction = replicate.predictions.create( + model=model, + input=input, + ) - assert prediction.id == id + assert prediction.id is not None + # assert prediction.model == model + assert prediction.status == "starting" -@pytest.mark.vcr("predictions-cancel.yaml") +@pytest.mark.asyncio +@pytest.mark.parametrize("async_flag", [True, False]) +async def test_predictions_create_by_deployment(async_flag): + router = respx.Router(base_url="https://api.replicate.com/v1") + + router.route( + method="POST", + path="/deployments/replicate/my-app-image-generator/predictions", + name="deployments.predictions.create", + ).mock( + return_value=httpx.Response( + 201, + json={ + "id": "p1", + "model": "replicate/my-app-image-generator", + "version": "v1", + "urls": { + "get": "https://api.replicate.com/v1/predictions/p1", + "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", + }, + "created_at": "2022-04-26T20:00:40.658234Z", + "source": "api", + "status": "starting", + "input": {"text": "world"}, + "output": None, + "error": None, + "logs": "", + }, + ) + ) + + router.route(host="api.replicate.com").pass_through() + + client = replicate.Client( + api_token="test-token", transport=httpx.MockTransport(router.handler) + ) + + input = {"text": "world"} + + if async_flag: + prediction = await client.predictions.async_create( + deployment="replicate/my-app-image-generator", + input=input, + ) + else: + prediction = client.predictions.create( + deployment="replicate/my-app-image-generator", + input=input, + ) + + assert prediction.id is not None + assert prediction.status == "starting" + + +@pytest.mark.vcr("models-predictions-create.yaml") @pytest.mark.asyncio @pytest.mark.parametrize("async_flag", [True, False]) async def test_predictions_cancel(async_flag): @@ -126,6 +193,20 @@ async def test_predictions_cancel(async_flag): assert prediction.status == "canceled" +@pytest.mark.vcr("predictions-get.yaml") +@pytest.mark.asyncio +@pytest.mark.parametrize("async_flag", [True, False]) +async def test_predictions_get(async_flag): + id = "vgcm4plb7tgzlyznry5d5jkgvu" + + if async_flag: + prediction = await replicate.predictions.async_get(id) + else: + prediction = replicate.predictions.get(id) + + assert prediction.id == id + + @pytest.mark.vcr("predictions-cancel.yaml") @pytest.mark.asyncio @pytest.mark.parametrize("async_flag", [True, False]) @@ -199,6 +280,7 @@ async def test_predictions_stream(async_flag): assert prediction.id is not None assert prediction.version == version.id assert prediction.status == "starting" + assert prediction.urls is not None assert prediction.urls["stream"] is not None From 74eac4e5c924a00452c8c4713609a2f2d04988f2 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Mon, 22 Apr 2024 04:14:10 -0700 Subject: [PATCH 5/8] Update existing cassettes Signed-off-by: Mattt Zmuda --- tests/cassettes/predictions-cancel.yaml | 512 ----- tests/cassettes/predictions-create.yaml | 816 ++++---- tests/cassettes/predictions-get.yaml | 304 +-- tests/cassettes/run.yaml | 1760 +++++++++-------- tests/cassettes/trainings-cancel.yaml | 206 +- tests/cassettes/trainings-create.yaml | 106 +- ...trainings-create__invalid-destination.yaml | 106 +- tests/cassettes/trainings-get.yaml | 142 +- 8 files changed, 1734 insertions(+), 2218 deletions(-) delete mode 100644 tests/cassettes/predictions-cancel.yaml diff --git a/tests/cassettes/predictions-cancel.yaml b/tests/cassettes/predictions-cancel.yaml deleted file mode 100644 index f67e671e..00000000 --- a/tests/cassettes/predictions-cancel.yaml +++ /dev/null @@ -1,512 +0,0 @@ -interactions: -- request: - body: '' - headers: - accept: - - '*/*' - accept-encoding: - - gzip, deflate - connection: - - keep-alive - host: - - api.replicate.com - user-agent: - - replicate-python/0.11.0 - method: GET - uri: https://api.replicate.com/v1/models/stability-ai/sdxl - response: - content: "{\"url\":\"https://replicate.com/stability-ai/sdxl\",\"owner\":\"stability-ai\",\"name\":\"sdxl\",\"description\":\"A - text-to-image generative AI model that creates beautiful 1024x1024 images\",\"visibility\":\"public\",\"github_url\":\"https://github.com/Stability-AI/generative-models\",\"paper_url\":\"https://arxiv.org/abs/2307.01952\",\"license_url\":\"https://github.com/Stability-AI/generative-models/blob/main/model_licenses/LICENSE-SDXL1.0\",\"run_count\":918101,\"cover_image_url\":\"https://tjzk.replicate.delivery/models_models_cover_image/61004930-fb88-4e09-9bd4-74fd8b4aa677/sdxl_cover.png\",\"default_example\":{\"completed_at\":\"2023-07-26T21:04:37.933562Z\",\"created_at\":\"2023-07-26T21:04:23.762683Z\",\"error\":null,\"id\":\"vu42q7dbkm6iicbpal4v6uvbqm\",\"input\":{\"width\":1024,\"height\":1024,\"prompt\":\"An - astronaut riding a rainbow unicorn, cinematic, dramatic\",\"refine\":\"expert_ensemble_refiner\",\"scheduler\":\"DDIM\",\"num_outputs\":1,\"guidance_scale\":7.5,\"high_noise_frac\":0.8,\"prompt_strength\":0.8,\"num_inference_steps\":50},\"logs\":\"Using - seed: 12103\\ntxt2img mode\\n 0%| | 0/40 [00:00"}' - headers: - accept: - - '*/*' - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '148' - content-type: - - application/json - host: - - api.replicate.com - user-agent: - - replicate-python/0.11.0 - method: POST - uri: https://api.replicate.com/v1/models/stability-ai/sdxl/versions/a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5/trainings - response: - content: '{"detail":"The specified training destination does not exist","status":404} + - request: + body: + '{"input": {"input_images": "https://replicate.delivery/pbxt/JMV5OrEWpBAC5gO8rre0tPOyJIOkaXvG0TWfVJ9b4zhLeEUY/data.zip"}, + "destination": ""}' + headers: + accept: + - "*/*" + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - "148" + content-type: + - application/json + host: + - api.replicate.com + user-agent: + - replicate-python/0.11.0 + method: POST + uri: https://api.replicate.com/v1/models/stability-ai/sdxl/versions/39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b/trainings + response: + content: + '{"detail":"The specified training destination does not exist","status":404} - ' - headers: - CF-Cache-Status: - - DYNAMIC - CF-RAY: - - 7f7c2190ed8c281a-SEA - Connection: - - keep-alive - Content-Length: - - '76' - Content-Type: - - application/problem+json - Date: - - Wed, 16 Aug 2023 19:37:18 GMT - NEL: - - '{"success_fraction":0,"report_to":"cf-nel","max_age":604800}' - Report-To: - - '{"endpoints":[{"url":"https:\/\/a.nel.cloudflare.com\/report\/v3?s=0vMWFlGDyffyF0A%2FL4%2FH830OVHnZd0gZDww4oocSSHq7eMAt327ut6v%2B2qAda7fThmH4WcElLTM%2B3PFyrsa1w1SHgfEdWyJSv8TYYi2nWXMqeP5EJc1SDjV958HGKSKDnjH5"}],"group":"cf-nel","max_age":604800}' - Server: - - cloudflare - Strict-Transport-Security: - - max-age=15552000 - ratelimit-remaining: - - '2999' - ratelimit-reset: - - '1' - via: - - 1.1 google - http_version: HTTP/1.1 - status_code: 404 + ' + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 7f7c2190ed8c281a-SEA + Connection: + - keep-alive + Content-Length: + - "76" + Content-Type: + - application/problem+json + Date: + - Wed, 16 Aug 2023 19:37:18 GMT + NEL: + - '{"success_fraction":0,"report_to":"cf-nel","max_age":604800}' + Report-To: + - '{"endpoints":[{"url":"https:\/\/a.nel.cloudflare.com\/report\/v3?s=0vMWFlGDyffyF0A%2FL4%2FH830OVHnZd0gZDww4oocSSHq7eMAt327ut6v%2B2qAda7fThmH4WcElLTM%2B3PFyrsa1w1SHgfEdWyJSv8TYYi2nWXMqeP5EJc1SDjV958HGKSKDnjH5"}],"group":"cf-nel","max_age":604800}' + Server: + - cloudflare + Strict-Transport-Security: + - max-age=15552000 + ratelimit-remaining: + - "2999" + ratelimit-reset: + - "1" + via: + - 1.1 google + http_version: HTTP/1.1 + status_code: 404 version: 1 diff --git a/tests/cassettes/trainings-get.yaml b/tests/cassettes/trainings-get.yaml index 3777f743..933591e8 100644 --- a/tests/cassettes/trainings-get.yaml +++ b/tests/cassettes/trainings-get.yaml @@ -1,73 +1,73 @@ interactions: -- request: - body: '' - headers: - accept: - - '*/*' - accept-encoding: - - gzip, deflate - connection: - - keep-alive - host: - - api.replicate.com - user-agent: - - replicate-python/0.11.0 - method: GET - uri: https://api.replicate.com/v1/trainings/medrnz3bm5dd6ultvad2tejrte - response: - content: '{"completed_at":null,"created_at":"2023-08-16T19:33:26.906823Z","error":null,"id":"medrnz3bm5dd6ultvad2tejrte","input":{"input_images":"https://replicate.delivery/pbxt/JMV5OrEWpBAC5gO8rre0tPOyJIOkaXvG0TWfVJ9b4zhLeEUY/data.zip","use_face_detection_instead":true},"logs":null,"metrics":{},"output":null,"started_at":"2023-08-16T19:33:42.114513Z","status":"processing","urls":{"get":"https://api.replicate.com/v1/trainings/medrnz3bm5dd6ultvad2tejrte","cancel":"https://api.replicate.com/v1/trainings/medrnz3bm5dd6ultvad2tejrte/cancel"},"model":"stability-ai/sdxl","version":"a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5","webhook_completed":null}' - headers: - CF-Cache-Status: - - DYNAMIC - CF-RAY: - - 7f7c1beaedff279c-SEA - Connection: - - keep-alive - Content-Encoding: - - gzip - Content-Type: - - application/json - Date: - - Wed, 16 Aug 2023 19:33:26 GMT - NEL: - - '{"success_fraction":0,"report_to":"cf-nel","max_age":604800}' - Report-To: - - '{"endpoints":[{"url":"https:\/\/a.nel.cloudflare.com\/report\/v3?s=SntiwLHCR4wiv49Qmn%2BR1ZblcX%2FgoVlIgsek4yZliZiWts2SqPjqTjrSkB%2Bwch8oHqR%2BBNVs1cSbihlHd8MWPXsbwC2uShz0c6tD4nclaecblb3FnEp4Mccy9hlZ39izF9Tm"}],"group":"cf-nel","max_age":604800}' - Server: - - cloudflare - Strict-Transport-Security: - - max-age=15552000 - Transfer-Encoding: - - chunked - allow: - - OPTIONS, GET - content-security-policy-report-only: - - 'style-src ''report-sample'' ''self'' ''unsafe-inline'' https://fonts.googleapis.com; - img-src ''report-sample'' ''self'' data: https://replicate.delivery https://*.replicate.delivery - https://*.githubusercontent.com https://github.com; worker-src ''none''; media-src - ''report-sample'' ''self'' https://replicate.delivery https://*.replicate.delivery - https://*.mux.com https://*.gstatic.com https://*.sentry.io; connect-src ''report-sample'' - ''self'' https://replicate.delivery https://*.replicate.delivery https://*.rudderlabs.com - https://*.rudderstack.com https://*.mux.com https://*.sentry.io; script-src - ''report-sample'' ''self'' https://cdn.rudderlabs.com/v1.1/rudder-analytics.min.js; - font-src ''report-sample'' ''self'' data: https://fonts.replicate.ai https://fonts.gstatic.com; - default-src ''self''; report-uri' - cross-origin-opener-policy: - - same-origin - ratelimit-remaining: - - '2999' - ratelimit-reset: - - '1' - referrer-policy: - - same-origin - vary: - - Cookie, origin - via: - - 1.1 vegur, 1.1 google - x-content-type-options: - - nosniff - x-frame-options: - - DENY - http_version: HTTP/1.1 - status_code: 200 + - request: + body: "" + headers: + accept: + - "*/*" + accept-encoding: + - gzip, deflate + connection: + - keep-alive + host: + - api.replicate.com + user-agent: + - replicate-python/0.11.0 + method: GET + uri: https://api.replicate.com/v1/trainings/medrnz3bm5dd6ultvad2tejrte + response: + content: '{"completed_at":null,"created_at":"2023-08-16T19:33:26.906823Z","error":null,"id":"medrnz3bm5dd6ultvad2tejrte","input":{"input_images":"https://replicate.delivery/pbxt/JMV5OrEWpBAC5gO8rre0tPOyJIOkaXvG0TWfVJ9b4zhLeEUY/data.zip","use_face_detection_instead":true},"logs":null,"metrics":{},"output":null,"started_at":"2023-08-16T19:33:42.114513Z","status":"processing","urls":{"get":"https://api.replicate.com/v1/trainings/medrnz3bm5dd6ultvad2tejrte","cancel":"https://api.replicate.com/v1/trainings/medrnz3bm5dd6ultvad2tejrte/cancel"},"model":"stability-ai/sdxl","version":"39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b","webhook_completed":null}' + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 7f7c1beaedff279c-SEA + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json + Date: + - Wed, 16 Aug 2023 19:33:26 GMT + NEL: + - '{"success_fraction":0,"report_to":"cf-nel","max_age":604800}' + Report-To: + - '{"endpoints":[{"url":"https:\/\/a.nel.cloudflare.com\/report\/v3?s=SntiwLHCR4wiv49Qmn%2BR1ZblcX%2FgoVlIgsek4yZliZiWts2SqPjqTjrSkB%2Bwch8oHqR%2BBNVs1cSbihlHd8MWPXsbwC2uShz0c6tD4nclaecblb3FnEp4Mccy9hlZ39izF9Tm"}],"group":"cf-nel","max_age":604800}' + Server: + - cloudflare + Strict-Transport-Security: + - max-age=15552000 + Transfer-Encoding: + - chunked + allow: + - OPTIONS, GET + content-security-policy-report-only: + - "style-src 'report-sample' 'self' 'unsafe-inline' https://fonts.googleapis.com; + img-src 'report-sample' 'self' data: https://replicate.delivery https://*.replicate.delivery + https://*.githubusercontent.com https://github.com; worker-src 'none'; media-src + 'report-sample' 'self' https://replicate.delivery https://*.replicate.delivery + https://*.mux.com https://*.gstatic.com https://*.sentry.io; connect-src 'report-sample' + 'self' https://replicate.delivery https://*.replicate.delivery https://*.rudderlabs.com + https://*.rudderstack.com https://*.mux.com https://*.sentry.io; script-src + 'report-sample' 'self' https://cdn.rudderlabs.com/v1.1/rudder-analytics.min.js; + font-src 'report-sample' 'self' data: https://fonts.replicate.ai https://fonts.gstatic.com; + default-src 'self'; report-uri" + cross-origin-opener-policy: + - same-origin + ratelimit-remaining: + - "2999" + ratelimit-reset: + - "1" + referrer-policy: + - same-origin + vary: + - Cookie, origin + via: + - 1.1 vegur, 1.1 google + x-content-type-options: + - nosniff + x-frame-options: + - DENY + http_version: HTTP/1.1 + status_code: 200 version: 1 From 2584ffb0613d994275d14399cfb9733bab72cd60 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Mon, 22 Apr 2024 04:16:04 -0700 Subject: [PATCH 6/8] Add new cassettes Signed-off-by: Mattt Zmuda --- tests/cassettes/predictions-cancel.yaml | 529 ++++++++++++++++++ .../test_predictions_cancel[False].yaml | 529 ++++++++++++++++++ .../test_predictions_cancel[True].yaml | 529 ++++++++++++++++++ ...st_predictions_create_by_model[False].yaml | 61 ++ ...est_predictions_create_by_model[True].yaml | 120 ++++ 5 files changed, 1768 insertions(+) create mode 100644 tests/cassettes/predictions-cancel.yaml create mode 100644 tests/cassettes/test_predictions_cancel[False].yaml create mode 100644 tests/cassettes/test_predictions_cancel[True].yaml create mode 100644 tests/cassettes/test_predictions_create_by_model[False].yaml create mode 100644 tests/cassettes/test_predictions_create_by_model[True].yaml diff --git a/tests/cassettes/predictions-cancel.yaml b/tests/cassettes/predictions-cancel.yaml new file mode 100644 index 00000000..cdbcf9f3 --- /dev/null +++ b/tests/cassettes/predictions-cancel.yaml @@ -0,0 +1,529 @@ +interactions: +- request: + body: '{"input": {"prompt": "Please write a haiku about llamas."}}' + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '59' + content-type: + - application/json + host: + - api.replicate.com + user-agent: + - replicate-python/0.21.0 + method: POST + uri: https://api.replicate.com/v1/models/meta/llama-2-70b-chat/predictions + response: + content: '{"id":"heat2o3bzn3ahtr6bjfftvbaci","model":"replicate/lifeboat-70b","version":"d-c6559c5791b50af57b69f4a73f8e021c","input":{"prompt":"Please + write a haiku about llamas."},"logs":"","error":null,"status":"starting","created_at":"2023-11-27T13:35:45.99397566Z","urls":{"cancel":"https://api.replicate.com/v1/predictions/heat2o3bzn3ahtr6bjfftvbaci/cancel","get":"https://api.replicate.com/v1/predictions/heat2o3bzn3ahtr6bjfftvbaci"}} + + ' + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 82cac197efaec53d-SEA + Connection: + - keep-alive + Content-Length: + - '431' + Content-Type: + - application/json + Date: + - Mon, 27 Nov 2023 13:35:46 GMT + NEL: + - '{"success_fraction":0,"report_to":"cf-nel","max_age":604800}' + Report-To: + - '{"endpoints":[{"url":"https:\/\/a.nel.cloudflare.com\/report\/v3?s=7R5RONMF6xaGRc39n0wnSe3jU1FbpX64Xz4U%2B%2F2nasvFaz0pKARxPhnzDgYkLaWgdK9zWrD2jxU04aKOy5HMPHAXboJ993L4zfsOyto56lBtdqSjNgkptzzxYEsKD%2FxIhe2F"}],"group":"cf-nel","max_age":604800}' + Server: + - cloudflare + Strict-Transport-Security: + - max-age=15552000 + ratelimit-remaining: + - '599' + ratelimit-reset: + - '1' + via: + - 1.1 google + http_version: HTTP/1.1 + status_code: 201 +- request: + body: '' + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + host: + - api.replicate.com + user-agent: + - replicate-python/0.25.1 + method: GET + uri: https://api.replicate.com/v1/models/stability-ai/sdxl + response: + content: '{"cover_image_url": "https://tjzk.replicate.delivery/models_models_cover_image/61004930-fb88-4e09-9bd4-74fd8b4aa677/sdxl_cover.png", + "created_at": "2023-07-26T17:53:09.882651Z", "default_example": {"completed_at": + "2023-10-12T17:10:12.909279Z", "created_at": "2023-10-12T17:10:07.956869Z", + "error": null, "id": "dzsqmb3bg4lqpjkz2iptjqgccm", "input": {"width": 768, "height": + 768, "prompt": "An astronaut riding a rainbow unicorn, cinematic, dramatic", + "refine": "expert_ensemble_refiner", "scheduler": "K_EULER", "lora_scale": 0.6, + "num_outputs": 1, "guidance_scale": 7.5, "apply_watermark": false, "high_noise_frac": + 0.8, "negative_prompt": "", "prompt_strength": 0.8, "num_inference_steps": 25}, + "logs": "Using seed: 16010\nPrompt: An astronaut riding a rainbow unicorn, cinematic, + dramatic\ntxt2img mode\n 0%| | 0/16 [00:00 Date: Mon, 22 Apr 2024 04:40:50 -0700 Subject: [PATCH 7/8] Add test case for specifying model and version and deployment Signed-off-by: Mattt Zmuda --- tests/test_prediction.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/tests/test_prediction.py b/tests/test_prediction.py index 7e04b8fa..c07e02c6 100644 --- a/tests/test_prediction.py +++ b/tests/test_prediction.py @@ -148,6 +148,41 @@ async def test_predictions_create_by_deployment(async_flag): assert prediction.status == "starting" +@pytest.mark.asyncio +@pytest.mark.parametrize("async_flag", [True, False]) +async def test_predictions_create_fail_with_too_many_arguments(async_flag): + router = respx.Router(base_url="https://api.replicate.com/v1") + + client = replicate.Client( + api_token="test-token", transport=httpx.MockTransport(router.handler) + ) + + version = "02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3" + model = "meta/meta-llama-3-8b-instruct" + deployment = "replicate/my-app-image-generator" + input = {} + + with pytest.raises(ValueError) as exc_info: + if async_flag: + await client.predictions.async_create( + version=version, + model=model, + deployment=deployment, + input=input, + ) + else: + client.predictions.create( + version=version, + model=model, + deployment=deployment, + input=input, + ) + assert ( + str(exc_info.value) + == "Exactly one of 'model', 'version', or 'deployment' must be specified." + ) + + @pytest.mark.vcr("models-predictions-create.yaml") @pytest.mark.asyncio @pytest.mark.parametrize("async_flag", [True, False]) From 9d713aaac29a958c1427374fab55c2b79db287a9 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Mon, 22 Apr 2024 04:42:34 -0700 Subject: [PATCH 8/8] Ignore pylint import-outside-toplevel warning Signed-off-by: Mattt Zmuda --- replicate/prediction.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/replicate/prediction.py b/replicate/prediction.py index 3b23f081..b1825c30 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -433,7 +433,9 @@ def create( # type: ignore ) if model is not None: - from replicate.model import Models + from replicate.model import ( # pylint: disable=import-outside-toplevel + Models, + ) return Models(self._client).predictions.create( model=model, @@ -442,7 +444,9 @@ def create( # type: ignore ) if deployment is not None: - from replicate.deployment import Deployments + from replicate.deployment import ( # pylint: disable=import-outside-toplevel + Deployments, + ) return Deployments(self._client).predictions.create( deployment=deployment, @@ -513,7 +517,9 @@ async def async_create( # type: ignore ) if model is not None: - from replicate.model import Models + from replicate.model import ( # pylint: disable=import-outside-toplevel + Models, + ) return await Models(self._client).predictions.async_create( model=model, @@ -522,7 +528,9 @@ async def async_create( # type: ignore ) if deployment is not None: - from replicate.deployment import Deployments + from replicate.deployment import ( # pylint: disable=import-outside-toplevel + Deployments, + ) return await Deployments(self._client).predictions.async_create( deployment=deployment,