diff --git a/.gitignore b/.gitignore index 86feca03..180b8005 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +# Virtualenv +.venv + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/replicate/base_model.py b/replicate/base_model.py index 51fadf21..30d1bed4 100644 --- a/replicate/base_model.py +++ b/replicate/base_model.py @@ -21,3 +21,11 @@ def reload(self): new_model = self._collection.get(self.id) for k, v in new_model.dict().items(): setattr(self, k, v) + + async def reload_async(self): + """ + Load this object from the server again. + """ + new_model = await self._collection.get_async(self.id) + for k, v in new_model.dict().items(): + setattr(self, k, v) diff --git a/replicate/client.py b/replicate/client.py index be40179d..90d09935 100644 --- a/replicate/client.py +++ b/replicate/client.py @@ -1,6 +1,7 @@ import os from json import JSONDecodeError +import httpx import requests from requests.adapters import HTTPAdapter, Retry @@ -21,6 +22,9 @@ def __init__(self, api_token=None) -> None: ) self.poll_interval = float(os.environ.get("REPLICATE_POLL_INTERVAL", "0.5")) + max_retries: int = 5 + self.httpx_transport = httpx.AsyncHTTPTransport(retries=max_retries) + # TODO: make thread safe self.read_session = requests.Session() read_retries = Retry( @@ -78,6 +82,32 @@ def _request(self, method: str, path: str, **kwargs): raise ReplicateError(f"HTTP error: {resp.status_code, resp.reason}") return resp + async def _request_async(self, method: str, path: str, **kwargs): + # from requests.Session + if method in ["GET", "OPTIONS"]: + kwargs.setdefault("allow_redirects", True) + if method in ["HEAD"]: + kwargs.setdefault("allow_redirects", False) + kwargs.setdefault("headers", {}) + kwargs["headers"].update(self._headers()) + + async with httpx.AsyncClient( + follow_redirects=True, + transport=self.httpx_transport, + ) as client: + if "allow_redirects" in kwargs: + kwargs.pop("allow_redirects") + + resp = await client.request(method, 
self.base_url + path, **kwargs) + + if 400 <= resp.status_code < 600: + try: + raise ReplicateError(resp.json()["detail"]) + except (JSONDecodeError, KeyError): + pass + raise ReplicateError(f"HTTP error: {resp.status_code, resp.reason_phrase}") + return resp + def _headers(self): return { "Authorization": f"Token {self._api_token()}", diff --git a/replicate/collection.py b/replicate/collection.py index 1b9c1368..1340a75c 100644 --- a/replicate/collection.py +++ b/replicate/collection.py @@ -21,6 +21,15 @@ def get(self, key): def create(self, attrs=None): raise NotImplementedError + async def list_async(self): + raise NotImplementedError + + async def get_async(self, key): + raise NotImplementedError + + async def create_async(self, attrs=None): + raise NotImplementedError + def prepare_model(self, attrs): """ Create a model from a set of attributes. diff --git a/replicate/files.py b/replicate/files.py index 82f70c89..8520b1ae 100644 --- a/replicate/files.py +++ b/replicate/files.py @@ -3,22 +3,16 @@ import mimetypes import os +import httpx import requests -def upload_file(fh: io.IOBase, output_file_prefix: str = None) -> str: +def to_data_url(fh: io.IOBase) -> str: """ Lifted straight from cog.files """ fh.seek(0) - if output_file_prefix is not None: - name = getattr(fh, "name", "output") - url = output_file_prefix + os.path.basename(name) - resp = requests.put(url, files={"file": fh}) - resp.raise_for_status() - return url - b = fh.read() # The file handle is strings, not bytes if isinstance(b, str): @@ -31,3 +25,60 @@ def upload_file(fh: io.IOBase, output_file_prefix: str = None) -> str: mime_type = "application/octet-stream" s = encoded_body.decode("utf-8") return f"data:{mime_type};base64,{s}" + + +def upload_file_to_server(fh: io.IOBase, output_file_prefix: str) -> str: + """ + Lifted straight from cog.files + """ + fh.seek(0) + + name = getattr(fh, "name", "output") + url = output_file_prefix + os.path.basename(name) + resp = requests.put(url, files={"file": fh}) + 
resp.raise_for_status() + return url + + +def upload_file(fh: io.IOBase, output_file_prefix: str = None) -> str: + """ + Lifted straight from cog.files + """ + fh.seek(0) + + if output_file_prefix is not None: + url = upload_file_to_server(fh, output_file_prefix) + return url + + data_url: str = to_data_url(fh) + return data_url + + +async def upload_file_to_server_async(fh: io.IOBase, output_file_prefix: str) -> str: + """ + Lifted straight from cog.files + """ + fh.seek(0) + + name = getattr(fh, "name", "output") + url = output_file_prefix + os.path.basename(name) + + # httpx does not follow redirects by default + async with httpx.AsyncClient(follow_redirects=True) as client: + (await client.put(url, files={"file": fh})).raise_for_status() + + return url + + +async def upload_file_async(fh: io.IOBase, output_file_prefix: str = None) -> str: + """ + Lifted straight from cog.files + """ + fh.seek(0) + + if output_file_prefix is not None: + url = await upload_file_to_server_async(fh, output_file_prefix) + return url + + data_url: str = to_data_url(fh) + return data_url diff --git a/replicate/prediction.py b/replicate/prediction.py index 7edddc70..2ca24c29 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -1,3 +1,4 @@ +import asyncio import time from typing import Any, Dict, Iterator, List, Optional @@ -27,6 +28,34 @@ def wait(self): time.sleep(self._client.poll_interval) self.reload() + async def wait_async(self): + """Wait for prediction to finish.""" + while self.status not in ["succeeded", "failed", "canceled"]: + await asyncio.sleep(self._client.poll_interval) + await self.reload_async() + + async def output_iterator_async(self) -> Iterator[Any]: + # TODO: check output is list + previous_output = self.output or [] + while self.status not in ["succeeded", "failed", "canceled"]: + output = self.output or [] + new_output = output[len(previous_output) :] + for output in new_output: + yield output + previous_output = output + + await asyncio.sleep(self._client.poll_interval) + await self.reload_async() + + 
if self.status == "failed": + raise ModelError(self.error) + + output = self.output or [] + new_output = output[len(previous_output) :] + for output in new_output: + yield output + + def output_iterator(self) -> Iterator[Any]: # TODO: check output is list previous_output = self.output or [] @@ -51,6 +80,10 @@ def cancel(self): """Cancel a currently running prediction""" self._client._request("POST", f"/v1/predictions/{self.id}/cancel") + async def cancel_async(self): + """Cancel a currently running prediction""" + await self._client._request_async("POST", f"/v1/predictions/{self.id}/cancel") + class PredictionCollection(Collection): model = Prediction @@ -93,3 +126,43 @@ def list(self) -> List[Prediction]: # HACK: resolve this? make it lazy somehow? del prediction["version"] return [self.prepare_model(obj) for obj in predictions] + + async def create_async( + self, + version: Version, + input: Dict[str, Any], + webhook_completed: Optional[str] = None, + ) -> Prediction: + input = encode_json(input, upload_file=upload_file) + body = { + "version": version.id, + "input": input, + } + if webhook_completed is not None: + body["webhook_completed"] = webhook_completed + + resp = await self._client._request_async( + "POST", + "/v1/predictions", + json=body, + ) + + obj = resp.json() + obj["version"] = version + return self.prepare_model(obj) + + async def get_async(self, id: str) -> Prediction: + resp = await self._client._request_async("GET", f"/v1/predictions/{id}") + obj = resp.json() + # HACK: resolve this? make it lazy somehow? + del obj["version"] + return self.prepare_model(obj) + + async def list_async(self) -> List[Prediction]: + resp = await self._client._request_async("GET", f"/v1/predictions") + # TODO: paginate + predictions = resp.json()["results"] + for prediction in predictions: + # HACK: resolve this? make it lazy somehow? 
+ del prediction["version"] + return [self.prepare_model(obj) for obj in predictions] diff --git a/replicate/version.py b/replicate/version.py index cc4cbd0c..bfdf9a28 100644 --- a/replicate/version.py +++ b/replicate/version.py @@ -31,6 +31,25 @@ def predict(self, **kwargs) -> Union[Any, Iterator[Any]]: raise ModelError(prediction.error) return prediction.output + + async def predict_async(self, **kwargs) -> Union[Any, Iterator[Any]]: + # TODO: support args + prediction = await self._client.predictions.create_async(version=self, input=kwargs) + # Return an iterator of the output + # FIXME: might just be a list, not an iterator. I wonder if we should differentiate? + schema = self.get_transformed_schema() + output = schema["components"]["schemas"]["Output"] + if ( + output.get("type") == "array" + and output.get("x-cog-array-type") == "iterator" + ): + return prediction.output_iterator_async() + + await prediction.wait_async() + if prediction.status == "failed": + raise ModelError(prediction.error) + return prediction.output + def get_transformed_schema(self): schema = self.openapi_schema schema = make_schema_backwards_compatible(schema, self.cog_version) @@ -44,6 +63,25 @@ def __init__(self, client, model): super().__init__(client=client) self._model = model + # doesn't exist yet + async def get_async(self, id: str) -> Version: + """ + Get a specific version. + """ + resp = await self._client._request_async( + "GET", f"/v1/models/{self._model.username}/{self._model.name}/versions/{id}" + ) + return self.prepare_model(resp.json()) + + async def list_async(self) -> List[Version]: + """ + Return a list of all versions for a model. 
+ """ + resp = await self._client._request_async( + "GET", f"/v1/models/{self._model.username}/{self._model.name}/versions" + ) + return [self.prepare_model(obj) for obj in resp.json()["results"]] + # doesn't exist yet def get(self, id: str) -> Version: """ diff --git a/requirements-dev.txt b/requirements-dev.txt index b3d9b4fa..a8b62281 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,3 +1,4 @@ packaging==21.3 pytest==7.1.2 +pytest-asyncio==0.21.0 responses==0.21.0 diff --git a/setup.py b/setup.py index 826f2483..62971aa6 100644 --- a/setup.py +++ b/setup.py @@ -22,6 +22,6 @@ license="BSD", url="https://github.com/replicate/replicate-python", python_requires=">=3.6", - install_requires=["requests", "pydantic", "packaging"], + install_requires=["requests", "pydantic", "packaging", "httpx"], classifiers=[], ) diff --git a/tests/test_async.py b/tests/test_async.py new file mode 100644 index 00000000..b70e9243 --- /dev/null +++ b/tests/test_async.py @@ -0,0 +1,20 @@ +import pytest + +import replicate + +@pytest.mark.asyncio +async def test_async_client(): + model = replicate.models.get("creatorrr/instructor-large") + version = await model.versions.get_async("bd2701dac1aea9d598bda71e6ae56b204287c0a79e2cadf96b1393127d044495") + + inputs = { + # Text to embed + 'text': "Hello world! How are you doing?", + + # Embedding instruction + 'instruction': "Represent the following text", + } + + output = await version.predict_async(**inputs) + + assert output["result"]