Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,32 @@ Some models, like [methexis-inc/img2prompt](https://replicate.com/methexis-inc/i
"an astronaut riding a horse"
```

> [!NOTE]
> You can also use the Replicate client asynchronously by prepending `async_` to the method name (for example, `replicate.async_run` instead of `replicate.run`).
>
> Here's an example of how to run several predictions concurrently and wait for them all to complete:
>
> ```python
> import asyncio
> import replicate
>
> # https://replicate.com/stability-ai/sdxl
> model_version = "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b"
> prompts = [
> f"A chariot pulled by a team of {count} rainbow unicorns"
> for count in ["two", "four", "six", "eight"]
> ]
>
> async with asyncio.TaskGroup() as tg:
> tasks = [
> tg.create_task(replicate.async_run(model_version, input={"prompt": prompt}))
> for prompt in prompts
> ]
>
> results = await asyncio.gather(*tasks)
> print(results)
> ```

## Run a model in the background

You can start a model and run it in the background:
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ packages = ["replicate"]
[tool.mypy]
plugins = "pydantic.mypy"
exclude = ["tests/"]
enable_incomplete_feature = ["Unpack"]

[tool.pylint.main]
disable = [
Expand All @@ -48,6 +49,7 @@ disable = [
"W0622", # Redefining built-in
"R0903", # Too few public methods
]
good-names = ["id"]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does good-names mean?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They're good names, Zeke.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(But seriously, this is to stop pylint from complaining about variables named id)


[tool.ruff]
select = [
Expand Down
3 changes: 3 additions & 0 deletions replicate/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from replicate.client import Client

default_client = Client()

run = default_client.run
async_run = default_client.async_run

collections = default_client.collections
hardware = default_client.hardware
deployments = default_client.deployments
Expand Down
177 changes: 99 additions & 78 deletions replicate/client.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,37 @@
import os
import random
import re
import time
from datetime import datetime
from typing import (
Any,
Dict,
Iterable,
Iterator,
Mapping,
Optional,
Type,
Union,
)

import httpx
from typing_extensions import Unpack

from replicate.__about__ import __version__
from replicate.collection import Collections
from replicate.deployment import Deployments
from replicate.exceptions import ModelError, ReplicateError
from replicate.hardware import Hardwares
from replicate.exceptions import ReplicateError
from replicate.hardware import HardwareNamespace as Hardware
from replicate.model import Models
from replicate.prediction import Predictions
from replicate.schema import make_schema_backwards_compatible
from replicate.run import async_run, run
from replicate.training import Trainings
from replicate.version import Version


class Client:
"""A Replicate API client library"""

__client: Optional[httpx.Client] = None
__async_client: Optional[httpx.AsyncClient] = None

def __init__(
self,
Expand All @@ -42,46 +44,45 @@ def __init__(
super().__init__()

self._api_token = api_token
self._base_url = (
base_url
or os.environ.get("REPLICATE_API_BASE_URL")
or "https://api.replicate.com"
)
self._timeout = timeout or httpx.Timeout(
5.0, read=30.0, write=30.0, connect=5.0, pool=10.0
)
self._transport = kwargs.pop("transport", httpx.HTTPTransport())
self._base_url = base_url
self._timeout = timeout
self._client_kwargs = kwargs

self.poll_interval = float(os.environ.get("REPLICATE_POLL_INTERVAL", "0.5"))

@property
def _client(self) -> httpx.Client:
if self.__client is None:
headers = {
"User-Agent": f"replicate-python/{__version__}",
}

api_token = self._api_token or os.environ.get("REPLICATE_API_TOKEN")

if api_token is not None and api_token != "":
headers["Authorization"] = f"Token {api_token}"

self.__client = httpx.Client(
if not self.__client:
self.__client = _build_httpx_client(
httpx.Client,
self._api_token,
self._base_url,
self._timeout,
**self._client_kwargs,
base_url=self._base_url,
headers=headers,
timeout=self._timeout,
transport=RetryTransport(wrapped_transport=self._transport),
)
) # type: ignore[assignment]
return self.__client # type: ignore[return-value]

return self.__client
@property
def _async_client(self) -> httpx.AsyncClient:
    """
    The asynchronous HTTP client used for API requests.

    The client is built on first access from the token, base URL, timeout
    and extra keyword arguments stored on this instance, then reused.
    """
    if self.__async_client:
        return self.__async_client  # type: ignore[return-value]

    self.__async_client = _build_httpx_client(  # type: ignore[assignment]
        httpx.AsyncClient,
        api_token=self._api_token,
        base_url=self._base_url,
        timeout=self._timeout,
        **self._client_kwargs,
    )
    return self.__async_client  # type: ignore[return-value]

def _request(self, method: str, path: str, **kwargs) -> httpx.Response:
    """
    Send a request with the synchronous client and return the response.

    Args:
        method: The HTTP method.
        path: The request path, passed to the client as-is.
        kwargs: Extra keyword arguments forwarded to the client's `request`.
    Returns:
        The response, after it has been checked by `_raise_for_status`.
    """
    response = self._client.request(method, path, **kwargs)
    _raise_for_status(response)
    return response

if 400 <= resp.status_code < 600:
raise ReplicateError(resp.json()["detail"])
async def _async_request(self, method: str, path: str, **kwargs) -> httpx.Response:
    """
    Send a request with the asynchronous client and return the response.

    Args:
        method: The HTTP method.
        path: The request path, passed to the client as-is.
        kwargs: Extra keyword arguments forwarded to the client's `request`.
    Returns:
        The response, after it has been checked by `_raise_for_status`.
    """
    response = await self._async_client.request(method, path, **kwargs)
    _raise_for_status(response)
    return response

Expand All @@ -100,11 +101,11 @@ def deployments(self) -> Deployments:
return Deployments(client=self)

@property
def hardware(self) -> Hardwares:
def hardware(self) -> Hardware:
"""
Namespace for operations related to hardware.
"""
return Hardwares(client=self)
return Hardware(client=self)

@property
def models(self) -> Models:
Expand All @@ -127,55 +128,29 @@ def trainings(self) -> Trainings:
"""
return Trainings(client=self)

def run(self, model_version: str, **kwargs) -> Union[Any, Iterator[Any]]: # noqa: ANN401
def run(
self,
ref: str,
input: Optional[Dict[str, Any]] = None,
**params: Unpack["Predictions.CreatePredictionParams"],
) -> Union[Any, Iterator[Any]]: # noqa: ANN401
"""
Run a model and wait for its output.

Args:
model_version: The model version to run, in the format `owner/name:version`
kwargs: The input to the model, as a dictionary
Returns:
The output of the model
"""
# Split model_version into owner, name, version in format owner/name:version
match = re.match(
r"^(?P<owner>[^/]+)/(?P<name>[^:]+):(?P<version>.+)$", model_version
)
if not match:
raise ReplicateError(
f"Invalid model_version: {model_version}. Expected format: owner/name:version"
)

owner = match.group("owner")
name = match.group("name")
version_id = match.group("version")

prediction = self.predictions.create(version=version_id, **kwargs)
return run(self, ref, input, **params)

if owner and name:
# FIXME: There should be a method for fetching a version without first fetching its model
resp = self._request(
"GET", f"/v1/models/{owner}/{name}/versions/{version_id}"
)
version = Version(**resp.json())

# Return an iterator of the output
schema = make_schema_backwards_compatible(
version.openapi_schema, version.cog_version
)
output = schema["components"]["schemas"]["Output"]
if (
output.get("type") == "array"
and output.get("x-cog-array-type") == "iterator"
):
return prediction.output_iterator()

prediction.wait()

if prediction.status == "failed":
raise ModelError(prediction.error)
async def async_run(
self,
ref: str,
input: Optional[Dict[str, Any]] = None,
**params: Unpack["Predictions.CreatePredictionParams"],
) -> Union[Any, Iterator[Any]]: # noqa: ANN401
"""
Run a model and wait for its output asynchronously.
"""

return prediction.output
return await async_run(self, ref, input, **params)


# Adapted from https://github.com/encode/httpx/issues/108#issuecomment-1132753155
Expand Down Expand Up @@ -305,3 +280,49 @@ async def aclose(self) -> None:

def close(self) -> None:
self._wrapped_transport.close() # type: ignore


def _build_httpx_client(
    client_type: Type[Union[httpx.Client, httpx.AsyncClient]],
    api_token: Optional[str] = None,
    base_url: Optional[str] = None,
    timeout: Optional[httpx.Timeout] = None,
    **kwargs,
) -> Union[httpx.Client, httpx.AsyncClient]:
    """
    Build an httpx client of the requested type, configured for the API.

    Args:
        client_type: The httpx client class to instantiate.
        api_token: The API token. When empty or omitted, the
            `REPLICATE_API_TOKEN` environment variable is used; if that is
            also empty, no `Authorization` header is sent.
        base_url: The API base URL. When empty or omitted, falls back to the
            `REPLICATE_BASE_URL` environment variable, then to
            `REPLICATE_API_BASE_URL`, then to `https://api.replicate.com`.
        timeout: The request timeout. Defaults to 5s overall and for connect,
            30s for read and write, and 10s for acquiring a pooled connection.
        kwargs: Extra keyword arguments forwarded to the client constructor.
            A truthy `transport` entry is used as the transport that gets
            wrapped in `RetryTransport`; otherwise a default one is created.
    Returns:
        The configured client instance.
    """
    headers = {
        "User-Agent": f"replicate-python/{__version__}",
    }

    # Truthiness covers both a missing token and an empty string.
    api_token = api_token or os.environ.get("REPLICATE_API_TOKEN")
    if api_token:
        headers["Authorization"] = f"Token {api_token}"

    # The `or` chain skips empty strings, so the result is never "".
    # `REPLICATE_API_BASE_URL` is still read so that environments configured
    # with that variable name keep pointing at the same host.
    base_url = (
        base_url
        or os.environ.get("REPLICATE_BASE_URL")
        or os.environ.get("REPLICATE_API_BASE_URL")
        or "https://api.replicate.com"
    )

    timeout = timeout or httpx.Timeout(
        5.0, read=30.0, write=30.0, connect=5.0, pool=10.0
    )

    # Choose the default transport by class hierarchy rather than identity so
    # that a subclass of the sync client does not get an async transport.
    transport = kwargs.pop("transport", None) or (
        httpx.AsyncHTTPTransport()
        if issubclass(client_type, httpx.AsyncClient)
        else httpx.HTTPTransport()
    )

    return client_type(
        base_url=base_url,
        headers=headers,
        timeout=timeout,
        transport=RetryTransport(wrapped_transport=transport),  # type: ignore[arg-type]
        **kwargs,
    )


def _raise_for_status(resp: httpx.Response) -> None:
    """
    Raise a ReplicateError if the response has a 4xx or 5xx status code.

    The error message is the `detail` field of the JSON response body. If the
    body is not JSON, or has no `detail` field, the raw response text is used
    instead, or `HTTP <status code>` when the body is empty.

    Args:
        resp: The response to check.
    Raises:
        ReplicateError: If the status code is in the range 400-599.
    """
    if not 400 <= resp.status_code < 600:
        return

    try:
        detail = resp.json()["detail"]
    except (ValueError, KeyError, TypeError):
        # ValueError: the body is not valid JSON (e.g. an HTML error page).
        # KeyError / TypeError: the JSON has no "detail" key or is not an object.
        detail = resp.text or f"HTTP {resp.status_code}"

    raise ReplicateError(detail)
Loading