diff --git a/replicate/schema.py b/replicate/schema.py index a48d2351..0fb784a3 100644 --- a/replicate/schema.py +++ b/replicate/schema.py @@ -1,11 +1,16 @@ +from typing import Optional + from packaging import version # TODO: this code is shared with replicate's backend. Maybe we should put it in the Cog Python package as the source of truth? -def version_has_no_array_type(cog_version: str) -> bool: +def version_has_no_array_type(cog_version: str) -> Optional[bool]: """Iterators have x-cog-array-type=iterator in the schema from 0.3.9 onward""" - return version.parse(cog_version) < version.parse("0.3.9") + try: + return version.parse(cog_version) < version.parse("0.3.9") + except version.InvalidVersion: + return None def make_schema_backwards_compatible( @@ -13,6 +18,7 @@ def make_schema_backwards_compatible( version: str, ) -> dict: """A place to add backwards compatibility logic for our openapi schema""" + # If the top-level output is an array, assume it is an iterator in old versions which didn't have an array type if version_has_no_array_type(version): output = schema["components"]["schemas"]["Output"] diff --git a/tests/test_run.py b/tests/test_run.py index 6aceec11..e000c9b8 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -1,6 +1,9 @@ +import httpx import pytest +import respx import replicate +from replicate.client import Client from replicate.exceptions import ReplicateError @@ -31,3 +34,90 @@ async def test_run(mock_replicate_api_token): def test_run_with_invalid_identifier(mock_replicate_api_token): with pytest.raises(ReplicateError): replicate.run("invalid") + + +@pytest.mark.asyncio +async def test_run_version_with_invalid_cog_version(mock_replicate_api_token): + def prediction_with_status(status: str) -> dict: + return { + "id": "p1", + "version": "v1", + "urls": { + "get": "https://api.replicate.com/v1/predictions/p1", + "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", + }, + "created_at": "2023-10-05T12:00:00.000000Z", + "source": "api", + "status": status, + "input": {"text": "world"}, + "output": "Hello, world!" if status == "succeeded" else None, + "error": None, + "logs": "", + } + + router = respx.Router(base_url="https://api.replicate.com/v1") + router.route(method="POST", path="/predictions").mock( + return_value=httpx.Response( + 201, + json=prediction_with_status("running"), + ) + ) + router.route(method="GET", path="/predictions/p1").mock( + return_value=httpx.Response( + 200, + json=prediction_with_status("succeeded"), + ) + ) + router.route( + method="GET", + path="/models/test/example/versions/invalid", + ).mock( + return_value=httpx.Response( + 201, + json={ + "id": "f2d6b24e6002f25f77ae89c2b0a5987daa6d0bf751b858b94b8416e8542434d1", + "created_at": "2022-03-16T00:35:56.210272Z", + "cog_version": "dev", + "openapi_schema": { + "openapi": "3.0.2", + "info": {"title": "Cog", "version": "0.1.0"}, + "paths": {}, + "components": { + "schemas": { + "Input": { + "type": "object", + "title": "Input", + "required": ["text"], + "properties": { + "text": { + "type": "string", + "title": "Text", + "x-order": 0, + "description": "The text input", + }, + }, + }, + "Output": { + "type": "string", + "title": "Output", + }, + } + }, + }, + }, + ) + ) + router.route(host="api.replicate.com").pass_through() + + client = Client( + api_token="test-token", transport=httpx.MockTransport(router.handler) + ) + + output = client.run( + "test/example:invalid", + input={ + "text": "Hello, world!", + }, + ) + + assert output == "Hello, world!"