From f9d16dbc102adf0775deaf51c955261cb2a0da31 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Wed, 11 Sep 2024 15:15:55 +0100 Subject: [PATCH 01/16] Update lockfile from running `rye sync` --- requirements-dev.lock | 2 ++ requirements.lock | 2 ++ 2 files changed, 4 insertions(+) diff --git a/requirements-dev.lock b/requirements-dev.lock index 3eae4db9..90d7aeda 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -6,6 +6,8 @@ # features: [] # all-features: false # with-sources: false +# generate-hashes: false +# universal: false -e file:. annotated-types==0.6.0 diff --git a/requirements.lock b/requirements.lock index 53ab3f58..b1e20e40 100644 --- a/requirements.lock +++ b/requirements.lock @@ -6,6 +6,8 @@ # features: [] # all-features: false # with-sources: false +# generate-hashes: false +# universal: false -e file:. annotated-types==0.6.0 From 57f6f0ba34ca6f9eb7ba00e4f17a22481a5ed93e Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Wed, 11 Sep 2024 15:16:06 +0100 Subject: [PATCH 02/16] Update pyproject.toml to configure pyright --- pyproject.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index e95ca1ea..c9367d0e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,3 +86,7 @@ ignore = [ "ANN201", # Missing return type annotation for public function "ANN202", # Missing return type annotation for private function ] + +[tool.pyright] +venvPath = "." +venv = ".venv" From 3cc0b86c880ca061293b1f328cc2c3d721da2307 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Wed, 11 Sep 2024 15:16:24 +0100 Subject: [PATCH 03/16] Implement a FileOutput interface --- replicate/stream.py | 65 ++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 62 insertions(+), 3 deletions(-) diff --git a/replicate/stream.py b/replicate/stream.py index 844973d4..068f728d 100644 --- a/replicate/stream.py +++ b/replicate/stream.py @@ -1,3 +1,6 @@ +import io +import base64 +import httpx from enum import Enum from typing import ( TYPE_CHECKING, @@ -9,9 +12,10 @@ Optional, Union, ) - +from contextlib import asynccontextmanager, contextmanager from typing_extensions import Unpack + from replicate import identifier from replicate.exceptions import ReplicateError @@ -22,8 +26,6 @@ if TYPE_CHECKING: - import httpx - from replicate.client import Client from replicate.identifier import ModelVersionIdentifier from replicate.model import Model @@ -31,6 +33,63 @@ from replicate.version import Version +class FileOutputProvider: + url: str + client: "Client" + + def __init__(self, url: str, client: "Client"): + self.url = url + self.client = client + + def read(self) -> bytes: + with self.stream() as file: + return file.read() + + @contextmanager + def stream(self) -> Iterator["FileOutput"]: + with self.client._client.stream("GET", self.url) as response: + response.raise_for_status() + yield FileOutput(response) + + @asynccontextmanager + async def astream(self) -> AsyncIterator["FileOutput"]: + async with self.client._async_client.stream("GET", self.url) as response: + response.raise_for_status() + yield FileOutput(response) + + async def aread(self) -> bytes: + async with self.astream() as file: + return await file.aread() + + def __repr__(self) -> str: + return self.url + + +class FileOutput(httpx.ByteStream, httpx.AsyncByteStream): + def __init__(self, response: httpx.Response): + self.response = response + + def __iter__(self) -> Iterator[bytes]: + for bytes in self.response.iter_bytes(): + yield bytes + + def close(self): + return self.response.close() + + def read(self): + return self.response.read() + + async def __aiter__(self) -> AsyncIterator[bytes]: + async for bytes in self.response.aiter_bytes(): + yield bytes + + async def aclose(self): + return await self.response.aclose() + + async def aread(self): + return await self.response.aread() + + class ServerSentEvent(pydantic.BaseModel): # type: ignore """ A server-sent event. From 7b2e7f474e6c506ea9925b34f03677a32becbe64 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Wed, 11 Sep 2024 15:16:35 +0100 Subject: [PATCH 04/16] Implement experimental FileOutput interface --- replicate/client.py | 6 +- replicate/run.py | 27 +++++++ replicate/stream.py | 54 +++++--------- tests/test_run.py | 175 +++++++++++++++++++++++++++++++++++++++++++- 4 files changed, 222 insertions(+), 40 deletions(-) diff --git a/replicate/client.py b/replicate/client.py index 5cee7a6f..d149ff4d 100644 --- a/replicate/client.py +++ b/replicate/client.py @@ -164,25 +164,27 @@ def run( self, ref: str, input: Optional[Dict[str, Any]] = None, + use_file_output: bool = False, **params: Unpack["Predictions.CreatePredictionParams"], ) -> Union[Any, Iterator[Any]]: # noqa: ANN401 """ Run a model and wait for its output. """ - return run(self, ref, input, **params) + return run(self, ref, input, use_file_output, **params) async def async_run( self, ref: str, input: Optional[Dict[str, Any]] = None, + use_file_output: bool = False, **params: Unpack["Predictions.CreatePredictionParams"], ) -> Union[Any, AsyncIterator[Any]]: # noqa: ANN401 """ Run a model and wait for its output asynchronously. """ - return await async_run(self, ref, input, **params) + return await async_run(self, ref, input, use_file_output, **params) def stream( self, diff --git a/replicate/run.py b/replicate/run.py index ae1ca7e5..5c0559a7 100644 --- a/replicate/run.py +++ b/replicate/run.py @@ -17,6 +17,7 @@ from replicate.prediction import Prediction from replicate.schema import make_schema_backwards_compatible from replicate.version import Version, Versions +from replicate.stream import FileOutput if TYPE_CHECKING: from replicate.client import Client @@ -28,6 +29,7 @@ def run( client: "Client", ref: Union["Model", "Version", "ModelVersionIdentifier", str], input: Optional[Dict[str, Any]] = None, + use_file_output: bool = False, **params: Unpack["Predictions.CreatePredictionParams"], ) -> Union[Any, Iterator[Any]]: # noqa: ANN401 """ @@ -60,6 +62,9 @@ def run( if prediction.status == "failed": raise ModelError(prediction) + if use_file_output: + return transform_output(prediction.output, client) + return prediction.output @@ -67,6 +72,7 @@ async def async_run( client: "Client", ref: Union["Model", "Version", "ModelVersionIdentifier", str], input: Optional[Dict[str, Any]] = None, + use_file_output: bool = False, **params: Unpack["Predictions.CreatePredictionParams"], ) -> Union[Any, AsyncIterator[Any]]: # noqa: ANN401 """ @@ -99,6 +105,9 @@ async def async_run( if prediction.status == "failed": raise ModelError(prediction) + if use_file_output: + return transform_output(prediction.output, client) + return prediction.output @@ -130,4 +139,22 @@ def _make_async_output_iterator( return None +def transform(obj, func): + if isinstance(obj, dict): + return {k: transform(v, func) for k, v in obj.items()} + elif isinstance(obj, list): + return [transform(item, func) for item in obj] + else: + return func(obj) + + +def transform_output(value: Any, client: "Client"): + def wrapper(x): + if isinstance(x, str) and (x.startswith("https:") or x.startswith("data:")): + return FileOutput(x, client) + return x + + return transform(value, wrapper) + + __all__: List = [] diff --git a/replicate/stream.py b/replicate/stream.py index 068f728d..d0b9c2f1 100644 --- a/replicate/stream.py +++ b/replicate/stream.py @@ -33,7 +33,7 @@ from replicate.version import Version -class FileOutputProvider: +class FileOutput(httpx.ByteStream, httpx.AsyncByteStream): url: str client: "Client" @@ -42,52 +42,32 @@ def __init__(self, url: str, client: "Client"): self.client = client def read(self) -> bytes: - with self.stream() as file: - return file.read() - - @contextmanager - def stream(self) -> Iterator["FileOutput"]: with self.client._client.stream("GET", self.url) as response: response.raise_for_status() - yield FileOutput(response) + return response.read() - @asynccontextmanager - async def astream(self) -> AsyncIterator["FileOutput"]: - async with self.client._async_client.stream("GET", self.url) as response: + def __iter__(self) -> Iterator[bytes]: + with self.client._client.stream("GET", self.url) as response: response.raise_for_status() - yield FileOutput(response) + for chunk in response.iter_bytes(): + yield chunk async def aread(self) -> bytes: - async with self.astream() as file: - return await file.aread() - - def __repr__(self) -> str: - return self.url - - -class FileOutput(httpx.ByteStream, httpx.AsyncByteStream): - def __init__(self, response: httpx.Response): - self.response = response - - def __iter__(self) -> Iterator[bytes]: - for bytes in self.response.iter_bytes(): - yield bytes - - def close(self): - return self.response.close() - - def read(self): - return self.response.read() + async with self.client._async_client.stream("GET", self.url) as response: + response.raise_for_status() + return await response.aread() async def __aiter__(self) -> AsyncIterator[bytes]: - async for bytes in self.response.aiter_bytes(): - yield bytes + async with self.client._async_client.stream("GET", self.url) as response: + response.raise_for_status() + async for chunk in response.aiter_bytes(): + yield chunk - async def aclose(self): - return await self.response.aclose() + def __str__(self) -> str: + return self.url - async def aread(self): - return await self.response.aread() + def __repr__(self) -> str: + return self.url class ServerSentEvent(pydantic.BaseModel): # type: ignore diff --git a/tests/test_run.py b/tests/test_run.py index d117eb32..5ddf9918 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -5,9 +5,11 @@ import pytest import respx +from typing import cast import replicate from replicate.client import Client from replicate.exceptions import ModelError, ReplicateError +from replicate.stream import FileOutput @pytest.mark.vcr("run.yaml") @@ -73,7 +75,7 @@ async def test_run_concurrently(mock_replicate_api_token, record_mode): results = await asyncio.gather(*tasks) assert len(results) == len(prompts) assert all(isinstance(result, list) for result in results) - assert all(len(result) > 0 for result in results) + assert all(len(results) > 0 for result in results) @pytest.mark.vcr("run.yaml") @@ -253,3 +255,174 @@ def prediction_with_status(status: str) -> dict: assert str(excinfo.value) == "OOM" assert excinfo.value.prediction.error == "OOM" assert excinfo.value.prediction.status == "failed" + + +@pytest.mark.asyncio +async def test_run_with_file_output(mock_replicate_api_token): + def prediction_with_status( + status: str, output: str | list[str] | None = None + ) -> dict: + return { + "id": "p1", + "model": "test/example", + "version": "v1", + "urls": { + "get": "https://api.replicate.com/v1/predictions/p1", + "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", + }, + "created_at": "2023-10-05T12:00:00.000000Z", + "source": "api", + "status": status, + "input": {"text": "world"}, + "output": output, + "error": "OOM" if status == "failed" else None, + "logs": "", + } + + router = respx.Router(base_url="https://api.replicate.com/v1") + router.route(method="POST", path="/predictions").mock( + return_value=httpx.Response( + 201, + json=prediction_with_status("processing"), + ) + ) + router.route(method="GET", path="/predictions/p1").mock( + return_value=httpx.Response( + 200, + json=prediction_with_status( + "succeeded", "https://api.replicate.com/v1/assets/output.txt" + ), + ) + ) + router.route( + method="GET", + path="/models/test/example/versions/v1", + ).mock( + return_value=httpx.Response( + 201, + json={ + "id": "f2d6b24e6002f25f77ae89c2b0a5987daa6d0bf751b858b94b8416e8542434d1", + "created_at": "2024-07-18T00:35:56.210272Z", + "cog_version": "0.9.10", + "openapi_schema": { + "openapi": "3.0.2", + }, + }, + ) + ) + router.route(method="GET", path="/assets/output.txt").mock( + return_value=httpx.Response(200, content=b"Hello, world!") + ) + + client = Client( + api_token="test-token", transport=httpx.MockTransport(router.handler) + ) + client.poll_interval = 0.001 + + output = cast( + FileOutput, + client.run( + "test/example:v1", + input={ + "text": "Hello, world!", + }, + use_file_output=True, + ), + ) + + assert output.url == "https://api.replicate.com/v1/assets/output.txt" + + assert output.read() == b"Hello, world!" + for chunk in output: + assert chunk == b"Hello, world!" + + assert await output.aread() == b"Hello, world!" + async for chunk in output: + assert chunk == b"Hello, world!" + + +@pytest.mark.asyncio +async def test_run_with_file_output_array(mock_replicate_api_token): + def prediction_with_status( + status: str, output: str | list[str] | None = None + ) -> dict: + return { + "id": "p1", + "model": "test/example", + "version": "v1", + "urls": { + "get": "https://api.replicate.com/v1/predictions/p1", + "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", + }, + "created_at": "2023-10-05T12:00:00.000000Z", + "source": "api", + "status": status, + "input": {"text": "world"}, + "output": output, + "error": "OOM" if status == "failed" else None, + "logs": "", + } + + router = respx.Router(base_url="https://api.replicate.com/v1") + router.route(method="POST", path="/predictions").mock( + return_value=httpx.Response( + 201, + json=prediction_with_status("processing"), + ) + ) + router.route(method="GET", path="/predictions/p1").mock( + return_value=httpx.Response( + 200, + json=prediction_with_status( + "succeeded", + [ + "https://api.replicate.com/v1/assets/hello.txt", + "https://api.replicate.com/v1/assets/world.txt", + ], + ), + ) + ) + router.route( + method="GET", + path="/models/test/example/versions/v1", + ).mock( + return_value=httpx.Response( + 201, + json={ + "id": "f2d6b24e6002f25f77ae89c2b0a5987daa6d0bf751b858b94b8416e8542434d1", + "created_at": "2024-07-18T00:35:56.210272Z", + "cog_version": "0.9.10", + "openapi_schema": { + "openapi": "3.0.2", + }, + }, + ) + ) + router.route(method="GET", path="/assets/hello.txt").mock( + return_value=httpx.Response(200, content=b"Hello,") + ) + router.route(method="GET", path="/assets/world.txt").mock( + return_value=httpx.Response(200, content=b" world!") + ) + + client = Client( + api_token="test-token", transport=httpx.MockTransport(router.handler) + ) + client.poll_interval = 0.001 + + [output1, output2] = cast( + list[FileOutput], + client.run( + "test/example:v1", + input={ + "text": "Hello, world!", + }, + use_file_output=True, + ), + ) + + assert output1.url == "https://api.replicate.com/v1/assets/hello.txt" + assert output2.url == "https://api.replicate.com/v1/assets/world.txt" + + assert output1.read() == b"Hello," + assert output2.read() == b" world!" From 9b7c82c2cb2271e0eecb7736bfde29a9ac5848aa Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Wed, 11 Sep 2024 10:52:08 -0700 Subject: [PATCH 05/16] Ignore ANN401 warnings Signed-off-by: Mattt Zmuda --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index c9367d0e..765c0210 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,6 +75,7 @@ ignore = [ "ANN003", # Missing type annotation for `**kwargs` "ANN101", # Missing type annotation for self in method "ANN102", # Missing type annotation for cls in classmethod + "ANN401", # Dynamically typed expressions (typing.Any) are disallowed in {name} "W191", # Indentation contains tabs "UP037", # Remove quotes from type annotation ] From bd12f1db6a694e1d944710d717c0fdb1a27bf1d0 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Wed, 11 Sep 2024 10:54:51 -0700 Subject: [PATCH 06/16] Refactor transform_output Signed-off-by: Mattt Zmuda --- replicate/run.py | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/replicate/run.py b/replicate/run.py index 5c0559a7..738d4558 100644 --- a/replicate/run.py +++ b/replicate/run.py @@ -1,3 +1,4 @@ +from collections.abc import Mapping, Sequence from typing import ( TYPE_CHECKING, Any, @@ -16,8 +17,8 @@ from replicate.model import Model from replicate.prediction import Prediction from replicate.schema import make_schema_backwards_compatible -from replicate.version import Version, Versions from replicate.stream import FileOutput +from replicate.version import Version, Versions if TYPE_CHECKING: from replicate.client import Client @@ -139,22 +140,19 @@ def _make_async_output_iterator( return None -def transform(obj, func): - if isinstance(obj, dict): - return {k: transform(v, func) for k, v in obj.items()} - elif isinstance(obj, list): - return [transform(item, func) for item in obj] - else: - return func(obj) - - -def transform_output(value: Any, client: "Client"): - def wrapper(x): - if isinstance(x, str) and (x.startswith("https:") or x.startswith("data:")): - return FileOutput(x, client) - return x +def transform_output(value: Any, client: "Client") -> Any: + def transform(obj: Any) -> Any: + if isinstance(obj, Mapping): + return {k: transform(v) for k, v in obj.items()} + elif isinstance(obj, Sequence) and not isinstance(obj, str): + return [transform(item) for item in obj] + elif isinstance(obj, str) and ( + obj.startswith("https:") or obj.startswith("data:") + ): + return FileOutput(obj, client) + return obj - return transform(value, wrapper) + return transform(value) __all__: List = [] From 310aea2afb428526f42660061442784c7bffb609 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Wed, 11 Sep 2024 10:56:23 -0700 Subject: [PATCH 07/16] Fix warning: Boolean default positional argument in function definition Signed-off-by: Mattt Zmuda --- replicate/client.py | 4 ++-- replicate/run.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/replicate/client.py b/replicate/client.py index d149ff4d..3da3cc15 100644 --- a/replicate/client.py +++ b/replicate/client.py @@ -164,7 +164,7 @@ def run( self, ref: str, input: Optional[Dict[str, Any]] = None, - use_file_output: bool = False, + use_file_output: Optional[bool] = None, **params: Unpack["Predictions.CreatePredictionParams"], ) -> Union[Any, Iterator[Any]]: # noqa: ANN401 """ @@ -177,7 +177,7 @@ async def async_run( self, ref: str, input: Optional[Dict[str, Any]] = None, - use_file_output: bool = False, + use_file_output: Optional[bool] = None, **params: Unpack["Predictions.CreatePredictionParams"], ) -> Union[Any, AsyncIterator[Any]]: # noqa: ANN401 """ diff --git a/replicate/run.py b/replicate/run.py index 738d4558..5b576239 100644 --- a/replicate/run.py +++ b/replicate/run.py @@ -30,7 +30,7 @@ def run( client: "Client", ref: Union["Model", "Version", "ModelVersionIdentifier", str], input: Optional[Dict[str, Any]] = None, - use_file_output: bool = False, + use_file_output: Optional[bool] = None, **params: Unpack["Predictions.CreatePredictionParams"], ) -> Union[Any, Iterator[Any]]: # noqa: ANN401 """ @@ -73,7 +73,7 @@ async def async_run( client: "Client", ref: Union["Model", "Version", "ModelVersionIdentifier", str], input: Optional[Dict[str, Any]] = None, - use_file_output: bool = False, + use_file_output: Optional[bool] = None, **params: Unpack["Predictions.CreatePredictionParams"], ) -> Union[Any, AsyncIterator[Any]]: # noqa: ANN401 """ From 69c95e65dcf40dcb65bd754ae910649b75d7dc0e Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Wed, 11 Sep 2024 10:59:00 -0700 Subject: [PATCH 08/16] Rename json module to helpers Signed-off-by: Mattt Zmuda --- replicate/deployment.py | 2 +- replicate/{json.py => helpers.py} | 0 replicate/model.py | 2 +- replicate/prediction.py | 2 +- replicate/training.py | 2 +- tests/{test_json.py => test_helpers.py} | 2 +- 6 files changed, 5 insertions(+), 5 deletions(-) rename replicate/{json.py => helpers.py} (100%) rename tests/{test_json.py => test_helpers.py} (95%) diff --git a/replicate/deployment.py b/replicate/deployment.py index 8d9836b0..e17edcbc 100644 --- a/replicate/deployment.py +++ b/replicate/deployment.py @@ -3,7 +3,7 @@ from typing_extensions import Unpack, deprecated from replicate.account import Account -from replicate.json import async_encode_json, encode_json +from replicate.helpers import async_encode_json, encode_json from replicate.pagination import Page from replicate.prediction import ( Prediction, diff --git a/replicate/json.py b/replicate/helpers.py similarity index 100% rename from replicate/json.py rename to replicate/helpers.py diff --git a/replicate/model.py b/replicate/model.py index ccae9cd0..9847ce4b 100644 --- a/replicate/model.py +++ b/replicate/model.py @@ -4,7 +4,7 @@ from replicate.exceptions import ReplicateException from replicate.identifier import ModelVersionIdentifier -from replicate.json import async_encode_json, encode_json +from replicate.helpers import async_encode_json, encode_json from replicate.pagination import Page from replicate.prediction import ( Prediction, diff --git a/replicate/prediction.py b/replicate/prediction.py index 7028a712..9770029b 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -20,7 +20,7 @@ from replicate.exceptions import ModelError, ReplicateError from replicate.file import FileEncodingStrategy -from replicate.json import async_encode_json, encode_json +from replicate.helpers import async_encode_json, encode_json from replicate.pagination import Page from replicate.resource import Namespace, Resource from replicate.stream import EventSource diff --git a/replicate/training.py b/replicate/training.py index ba3554df..b03779a3 100644 --- a/replicate/training.py +++ b/replicate/training.py @@ -14,7 +14,7 @@ from typing_extensions import NotRequired, Unpack from replicate.identifier import ModelVersionIdentifier -from replicate.json import async_encode_json, encode_json +from replicate.helpers import async_encode_json, encode_json from replicate.model import Model from replicate.pagination import Page from replicate.resource import Namespace, Resource diff --git a/tests/test_json.py b/tests/test_helpers.py similarity index 95% rename from tests/test_json.py rename to tests/test_helpers.py index b8d76569..0c41cab7 100644 --- a/tests/test_json.py +++ b/tests/test_helpers.py @@ -2,7 +2,7 @@ import pytest -from replicate.json import base64_encode_file +from replicate.helpers import base64_encode_file @pytest.mark.parametrize( From c238c39f44fa19a636b0b4682ad17a7157d0283e Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Wed, 11 Sep 2024 11:00:22 -0700 Subject: [PATCH 09/16] Move transform_output to helpers module Signed-off-by: Mattt Zmuda --- replicate/helpers.py | 18 ++++++++++++++++++ replicate/run.py | 18 +----------------- tests/test_run.py | 2 +- 3 files changed, 20 insertions(+), 18 deletions(-) diff --git a/replicate/helpers.py b/replicate/helpers.py index 90154a84..8edc0404 100644 --- a/replicate/helpers.py +++ b/replicate/helpers.py @@ -1,6 +1,7 @@ import base64 import io import mimetypes +from collections.abc import Mapping, Sequence from pathlib import Path from types import GeneratorType from typing import TYPE_CHECKING, Any, Optional @@ -108,3 +109,20 @@ def base64_encode_file(file: io.IOBase) -> str: mimetypes.guess_type(getattr(file, "name", ""))[0] or "application/octet-stream" ) return f"data:{mime_type};base64,{encoded_body}" + + +def transform_output(value: Any, client: "Client") -> Any: + from replicate.stream import FileOutput # pylint: disable=import-outside-toplevel + + def transform(obj: Any) -> Any: + if isinstance(obj, Mapping): + return {k: transform(v) for k, v in obj.items()} + elif isinstance(obj, Sequence) and not isinstance(obj, str): + return [transform(item) for item in obj] + elif isinstance(obj, str) and ( + obj.startswith("https:") or obj.startswith("data:") + ): + return FileOutput(obj, client) + return obj + + return transform(value) diff --git a/replicate/run.py b/replicate/run.py index 5b576239..fd1accfb 100644 --- a/replicate/run.py +++ b/replicate/run.py @@ -1,4 +1,3 @@ -from collections.abc import Mapping, Sequence from typing import ( TYPE_CHECKING, Any, @@ -14,10 +13,10 @@ from replicate import identifier from replicate.exceptions import ModelError +from replicate.helpers import transform_output from replicate.model import Model from replicate.prediction import Prediction from replicate.schema import make_schema_backwards_compatible -from replicate.stream import FileOutput from replicate.version import Version, Versions if TYPE_CHECKING: @@ -140,19 +139,4 @@ def _make_async_output_iterator( return None -def transform_output(value: Any, client: "Client") -> Any: - def transform(obj: Any) -> Any: - if isinstance(obj, Mapping): - return {k: transform(v) for k, v in obj.items()} - elif isinstance(obj, Sequence) and not isinstance(obj, str): - return [transform(item) for item in obj] - elif isinstance(obj, str) and ( - obj.startswith("https:") or obj.startswith("data:") - ): - return FileOutput(obj, client) - return obj - - return transform(value) - - __all__: List = [] diff --git a/tests/test_run.py b/tests/test_run.py index 5ddf9918..b6b05c01 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -1,11 +1,11 @@ import asyncio import sys +from typing import cast import httpx import pytest import respx -from typing import cast import replicate from replicate.client import Client from replicate.exceptions import ModelError, ReplicateError From 4687e2815c25b4979d9362abf0f4b3709f3fb361 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Wed, 11 Sep 2024 11:07:39 -0700 Subject: [PATCH 10/16] Move FileOutput to helpers module Signed-off-by: Mattt Zmuda --- replicate/helpers.py | 43 ++++++++++++++++++++++++++++++++++++++++--- replicate/stream.py | 44 ++------------------------------------------ tests/test_run.py | 2 +- 3 files changed, 43 insertions(+), 46 deletions(-) diff --git a/replicate/helpers.py b/replicate/helpers.py index 8edc0404..11aefaaf 100644 --- a/replicate/helpers.py +++ b/replicate/helpers.py @@ -4,7 +4,9 @@ from collections.abc import Mapping, Sequence from pathlib import Path from types import GeneratorType -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, Optional + +import httpx if TYPE_CHECKING: from replicate.client import Client @@ -111,9 +113,44 @@ def base64_encode_file(file: io.IOBase) -> str: return f"data:{mime_type};base64,{encoded_body}" -def transform_output(value: Any, client: "Client") -> Any: - from replicate.stream import FileOutput # pylint: disable=import-outside-toplevel +class FileOutput(httpx.ByteStream, httpx.AsyncByteStream): + url: str + client: "Client" + + def __init__(self, url: str, client: "Client"): + self.url = url + self.client = client + + def read(self) -> bytes: + with self.client._client.stream("GET", self.url) as response: + response.raise_for_status() + return response.read() + + def __iter__(self) -> Iterator[bytes]: + with self.client._client.stream("GET", self.url) as response: + response.raise_for_status() + for chunk in response.iter_bytes(): + yield chunk + + async def aread(self) -> bytes: + async with self.client._async_client.stream("GET", self.url) as response: + response.raise_for_status() + return await response.aread() + async def __aiter__(self) -> AsyncIterator[bytes]: + async with self.client._async_client.stream("GET", self.url) as response: + response.raise_for_status() + async for chunk in response.aiter_bytes(): + yield chunk + + def __str__(self) -> str: + return self.url + + def __repr__(self) -> str: + return self.url + + +def transform_output(value: Any, client: "Client") -> Any: def transform(obj: Any) -> Any: if isinstance(obj, Mapping): return {k: transform(v) for k, v in obj.items()} diff --git a/replicate/stream.py b/replicate/stream.py index d0b9c2f1..3472799e 100644 --- a/replicate/stream.py +++ b/replicate/stream.py @@ -1,6 +1,3 @@ -import io -import base64 -import httpx from enum import Enum from typing import ( TYPE_CHECKING, @@ -12,9 +9,9 @@ Optional, Union, ) -from contextlib import asynccontextmanager, contextmanager -from typing_extensions import Unpack +import httpx +from typing_extensions import Unpack from replicate import identifier from replicate.exceptions import ReplicateError @@ -33,43 +30,6 @@ from replicate.version import Version -class FileOutput(httpx.ByteStream, httpx.AsyncByteStream): - url: str - client: "Client" - - def __init__(self, url: str, client: "Client"): - self.url = url - self.client = client - - def read(self) -> bytes: - with self.client._client.stream("GET", self.url) as response: - response.raise_for_status() - return response.read() - - def __iter__(self) -> Iterator[bytes]: - with self.client._client.stream("GET", self.url) as response: - response.raise_for_status() - for chunk in response.iter_bytes(): - yield chunk - - async def aread(self) -> bytes: - async with self.client._async_client.stream("GET", self.url) as response: - response.raise_for_status() - return await response.aread() - - async def __aiter__(self) -> AsyncIterator[bytes]: - async with self.client._async_client.stream("GET", self.url) as response: - response.raise_for_status() - async for chunk in response.aiter_bytes(): - yield chunk - - def __str__(self) -> str: - return self.url - - def __repr__(self) -> str: - return self.url - - class ServerSentEvent(pydantic.BaseModel): # type: ignore """ A server-sent event. diff --git a/tests/test_run.py b/tests/test_run.py index b6b05c01..1ef77fba 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -9,7 +9,7 @@ import replicate from replicate.client import Client from replicate.exceptions import ModelError, ReplicateError -from replicate.stream import FileOutput +from replicate.helpers import FileOutput @pytest.mark.vcr("run.yaml") From 46d69d1d1dc754606bd394647248d3ceed089327 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Wed, 11 Sep 2024 11:14:28 -0700 Subject: [PATCH 11/16] Inherit from SyncByteStream instead of ByteStream SyncByteStream is an abstract class; ByteStream is a concrete class that inherits abstracts sync and async byte stream classes Signed-off-by: Mattt Zmuda --- replicate/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/replicate/helpers.py b/replicate/helpers.py index 11aefaaf..eb40db03 100644 --- a/replicate/helpers.py +++ b/replicate/helpers.py @@ -113,7 +113,7 @@ def base64_encode_file(file: io.IOBase) -> str: return f"data:{mime_type};base64,{encoded_body}" -class FileOutput(httpx.ByteStream, httpx.AsyncByteStream): +class FileOutput(httpx.SyncByteStream, httpx.AsyncByteStream): url: str client: "Client" From 1ae54d52d441ddf908baaec19b1999e2b5bf9fd8 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Wed, 11 Sep 2024 11:14:42 -0700 Subject: [PATCH 12/16] Fix warnings Signed-off-by: Mattt Zmuda --- replicate/helpers.py | 5 ++--- replicate/model.py | 2 +- replicate/training.py | 2 +- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/replicate/helpers.py b/replicate/helpers.py index eb40db03..c93b9358 100644 --- a/replicate/helpers.py +++ b/replicate/helpers.py @@ -117,7 +117,7 @@ class FileOutput(httpx.SyncByteStream, httpx.AsyncByteStream): url: str client: "Client" - def __init__(self, url: str, client: "Client"): + def __init__(self, url: str, client: "Client") -> None: self.url = url self.client = client @@ -129,8 +129,7 @@ def read(self) -> bytes: def __iter__(self) -> Iterator[bytes]: with self.client._client.stream("GET", self.url) as response: response.raise_for_status() - for chunk in response.iter_bytes(): - yield chunk + yield from response.iter_bytes() async def aread(self) -> bytes: async with self.client._async_client.stream("GET", self.url) as response: diff --git a/replicate/model.py b/replicate/model.py index 9847ce4b..ba5e1113 100644 --- a/replicate/model.py +++ b/replicate/model.py @@ -3,8 +3,8 @@ from typing_extensions import NotRequired, TypedDict, Unpack, deprecated from replicate.exceptions import ReplicateException -from replicate.identifier import ModelVersionIdentifier from replicate.helpers import async_encode_json, encode_json +from replicate.identifier import ModelVersionIdentifier from replicate.pagination import Page from replicate.prediction import ( Prediction, diff --git a/replicate/training.py b/replicate/training.py index b03779a3..28e28b4a 100644 --- a/replicate/training.py +++ b/replicate/training.py @@ -13,8 +13,8 @@ from typing_extensions import NotRequired, Unpack -from replicate.identifier import ModelVersionIdentifier from replicate.helpers import async_encode_json, encode_json +from replicate.identifier import ModelVersionIdentifier from replicate.model import Model from replicate.pagination import Page from replicate.resource import Namespace, Resource From 024916e94ed4853725ed88a85de57c9e6975f229 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Wed, 11 Sep 2024 11:15:19 -0700 Subject: [PATCH 13/16] Remove custom implementation of __repr__ Signed-off-by: Mattt Zmuda --- replicate/helpers.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/replicate/helpers.py b/replicate/helpers.py index c93b9358..ef68ec81 100644 --- a/replicate/helpers.py +++ b/replicate/helpers.py @@ -145,9 +145,6 @@ async def __aiter__(self) -> AsyncIterator[bytes]: def __str__(self) -> str: return self.url - def __repr__(self) -> str: - return self.url - def transform_output(value: Any, client: "Client") -> Any: def transform(obj: Any) -> Any: From fe2323af9406e16df882cd34adda0fffcacbedfb Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Wed, 11 Sep 2024 11:35:02 -0700 Subject: [PATCH 14/16] Add docstrings Signed-off-by: Mattt Zmuda --- replicate/helpers.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/replicate/helpers.py b/replicate/helpers.py index ef68ec81..c8518857 100644 --- a/replicate/helpers.py +++ b/replicate/helpers.py @@ -114,8 +114,20 @@ def base64_encode_file(file: io.IOBase) -> str: class FileOutput(httpx.SyncByteStream, httpx.AsyncByteStream): + """ + An object that can be used to read the contents of an output file + created by running a Replicate model. + """ + url: str + """ + The file URL. + """ + client: "Client" + """ + A Replicate client used to download the file. + """ def __init__(self, url: str, client: "Client") -> None: self.url = url @@ -147,6 +159,10 @@ def __str__(self) -> str: def transform_output(value: Any, client: "Client") -> Any: + """ + Transform the output of a prediction to a `FileOutput` object if it's a URL. + """ + def transform(obj: Any) -> Any: if isinstance(obj, Mapping): return {k: transform(v) for k, v in obj.items()} From 04be83f99cdd8b2b8a1ea8891b133a1b0ed50989 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Wed, 11 Sep 2024 11:39:12 -0700 Subject: [PATCH 15/16] Rename FileOutput.client to ._client Signed-off-by: Mattt Zmuda --- replicate/helpers.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/replicate/helpers.py b/replicate/helpers.py index c8518857..1c87eff4 100644 --- a/replicate/helpers.py +++ b/replicate/helpers.py @@ -124,32 +124,29 @@ class FileOutput(httpx.SyncByteStream, httpx.AsyncByteStream): The file URL. """ - client: "Client" - """ - A Replicate client used to download the file. - """ + _client: "Client" def __init__(self, url: str, client: "Client") -> None: self.url = url - self.client = client + self._client = client def read(self) -> bytes: - with self.client._client.stream("GET", self.url) as response: + with self._client._client.stream("GET", self.url) as response: response.raise_for_status() return response.read() def __iter__(self) -> Iterator[bytes]: - with self.client._client.stream("GET", self.url) as response: + with self._client._client.stream("GET", self.url) as response: response.raise_for_status() yield from response.iter_bytes() async def aread(self) -> bytes: - async with self.client._async_client.stream("GET", self.url) as response: + async with self._client._async_client.stream("GET", self.url) as response: response.raise_for_status() return await response.aread() async def __aiter__(self) -> AsyncIterator[bytes]: - async with self.client._async_client.stream("GET", self.url) as response: + async with self._client._async_client.stream("GET", self.url) as response: response.raise_for_status() async for chunk in response.aiter_bytes(): yield chunk From c0a4b138ed4c04bbd6da7bcd4837fa87b55a0757 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Thu, 12 Sep 2024 11:34:11 +0100 Subject: [PATCH 16/16] Add support to FileOutput to read data-uris --- replicate/helpers.py | 16 +++++++++ tests/test_run.py | 81 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 97 insertions(+) diff --git a/replicate/helpers.py b/replicate/helpers.py index 1c87eff4..e0bada5d 100644 --- a/replicate/helpers.py +++ b/replicate/helpers.py @@ -131,21 +131,37 @@ def __init__(self, url: str, client: "Client") -> None: self._client = client def read(self) -> bytes: + if self.url.startswith("data:"): + _, encoded = self.url.split(",", 1) + return base64.b64decode(encoded) + with self._client._client.stream("GET", self.url) as response: response.raise_for_status() return response.read() def __iter__(self) -> Iterator[bytes]: + if self.url.startswith("data:"): + yield self.read() + return + with self._client._client.stream("GET", self.url) as response: response.raise_for_status() yield from response.iter_bytes() async def aread(self) -> bytes: + if self.url.startswith("data:"): + _, encoded = self.url.split(",", 1) + return base64.b64decode(encoded) + async with self._client._async_client.stream("GET", self.url) as response: response.raise_for_status() return await response.aread() async def __aiter__(self) -> AsyncIterator[bytes]: + if self.url.startswith("data:"): + yield await self.aread() + return + async with self._client._async_client.stream("GET", self.url) as response: response.raise_for_status() async for chunk in response.aiter_bytes(): diff --git a/tests/test_run.py b/tests/test_run.py index 1ef77fba..11fde976 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -426,3 +426,84 @@ def prediction_with_status( assert output1.read() == b"Hello," assert output2.read() == b" world!" + + +@pytest.mark.asyncio +async def test_run_with_file_output_data_uri(mock_replicate_api_token): + def prediction_with_status( + status: str, output: str | list[str] | None = None + ) -> dict: + return { + "id": "p1", + "model": "test/example", + "version": "v1", + "urls": { + "get": "https://api.replicate.com/v1/predictions/p1", + "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", + }, + "created_at": "2023-10-05T12:00:00.000000Z", + "source": "api", + "status": status, + "input": {"text": "world"}, + "output": output, + "error": "OOM" if status == "failed" else None, + "logs": "", + } + + router = respx.Router(base_url="https://api.replicate.com/v1") + router.route(method="POST", path="/predictions").mock( + return_value=httpx.Response( + 201, + json=prediction_with_status("processing"), + ) + ) + router.route(method="GET", path="/predictions/p1").mock( + return_value=httpx.Response( + 200, + json=prediction_with_status( + "succeeded", + "data:text/plain;base64,SGVsbG8sIHdvcmxkIQ==", + ), + ) + ) + router.route( + method="GET", + path="/models/test/example/versions/v1", + ).mock( + return_value=httpx.Response( + 201, + json={ + "id": "f2d6b24e6002f25f77ae89c2b0a5987daa6d0bf751b858b94b8416e8542434d1", + "created_at": "2024-07-18T00:35:56.210272Z", + "cog_version": "0.9.10", + "openapi_schema": { + "openapi": "3.0.2", + }, + }, + ) + ) + + client = Client( + api_token="test-token", transport=httpx.MockTransport(router.handler) + ) + client.poll_interval = 0.001 + + output = cast( + FileOutput, + client.run( + "test/example:v1", + input={ + "text": "Hello, world!", + }, + use_file_output=True, + ), + ) + + assert output.url == "data:text/plain;base64,SGVsbG8sIHdvcmxkIQ==" + assert output.read() == b"Hello, world!" + for chunk in output: + assert chunk == b"Hello, world!" + + assert await output.aread() == b"Hello, world!" + async for chunk in output: + assert chunk == b"Hello, world!"