Skip to content
Merged
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ disable = [
"R0801", # Similar lines in N files
"W0212", # Access to a protected member
"W0622", # Redefining built-in
"R0903", # Too few public methods
]

[tool.ruff]
Expand Down
2 changes: 1 addition & 1 deletion replicate/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .client import Client
from replicate.client import Client

default_client = Client()
run = default_client.run
Expand Down
9 changes: 0 additions & 9 deletions replicate/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,3 @@ class BaseModel(pydantic.BaseModel):

_client: "Client" = pydantic.PrivateAttr()
_collection: "Collection" = pydantic.PrivateAttr()

def reload(self) -> None:
"""
Load this object from the server again.
"""

new_model = self._collection.get(self.id) # pylint: disable=no-member
for k, v in new_model.dict().items(): # pylint: disable=invalid-name
setattr(self, k, v)
22 changes: 13 additions & 9 deletions replicate/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@

import httpx

from .__about__ import __version__
from .deployment import DeploymentCollection
from .exceptions import ModelError, ReplicateError
from .model import ModelCollection
from .prediction import PredictionCollection
from .training import TrainingCollection
from .version import Version
from replicate.__about__ import __version__
from replicate.deployment import DeploymentCollection
from replicate.exceptions import ModelError, ReplicateError
from replicate.model import ModelCollection
from replicate.prediction import PredictionCollection
from replicate.schema import make_schema_backwards_compatible
from replicate.training import TrainingCollection
from replicate.version import Version


class Client:
Expand Down Expand Up @@ -143,7 +144,9 @@ def run(self, model_version: str, **kwargs) -> Union[Any, Iterator[Any]]: # noq
version = Version(**resp.json())

# Return an iterator of the output
schema = version.get_transformed_schema()
schema = make_schema_backwards_compatible(
version.openapi_schema, version.cog_version
)
output = schema["components"]["schemas"]["Output"]
if (
output.get("type") == "array"
Expand Down Expand Up @@ -175,9 +178,10 @@ class RetryTransport(httpx.AsyncBaseTransport, httpx.BaseTransport):
)
MAX_BACKOFF_WAIT = 60

def __init__(
def __init__( # pylint: disable=too-many-arguments
self,
wrapped_transport: Union[httpx.BaseTransport, httpx.AsyncBaseTransport],
*,
max_attempts: int = 10,
max_backoff_wait: float = MAX_BACKOFF_WAIT,
backoff_factor: float = 0.1,
Expand Down
26 changes: 5 additions & 21 deletions replicate/collection.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import abc
from typing import TYPE_CHECKING, Dict, Generic, List, TypeVar, Union, cast
from typing import TYPE_CHECKING, Dict, Generic, TypeVar, Union, cast

if TYPE_CHECKING:
from replicate.client import Client
Expand All @@ -15,29 +15,13 @@ class Collection(abc.ABC, Generic[Model]):
A base class for representing objects of a particular type on the server.
"""

_client: "Client"
model: Model

def __init__(self, client: "Client") -> None:
self._client = client

@property
@abc.abstractmethod
def model(self) -> Model: # pylint: disable=missing-function-docstring
pass

@abc.abstractmethod
def list(self) -> List[Model]: # pylint: disable=missing-function-docstring
pass

@abc.abstractmethod
def get(self, key: str) -> Model: # pylint: disable=missing-function-docstring
pass

@abc.abstractmethod
def create( # pylint: disable=missing-function-docstring
self, *args, **kwargs
) -> Model:
pass

def prepare_model(self, attrs: Union[Model, Dict]) -> Model:
def _prepare_model(self, attrs: Union[Model, Dict]) -> Model:
"""
Create a model from a set of attributes.
"""
Expand Down
110 changes: 24 additions & 86 deletions replicate/deployment.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, overload

from typing_extensions import Unpack
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from replicate.base_model import BaseModel
from replicate.collection import Collection
from replicate.files import upload_file
from replicate.json import encode_json
from replicate.prediction import Prediction, PredictionCollection
from replicate.prediction import Prediction

if TYPE_CHECKING:
from replicate.client import Client
Expand All @@ -17,6 +15,8 @@ class Deployment(BaseModel):
A deployment of a model hosted on Replicate.
"""

_collection: "DeploymentCollection"

username: str
"""
The name of the user or organization that owns the deployment.
Expand All @@ -43,15 +43,6 @@ class DeploymentCollection(Collection):

model = Deployment

def list(self) -> List[Deployment]:
"""
List deployments.

Raises:
NotImplementedError: This method is not implemented.
"""
raise NotImplementedError()

def get(self, name: str) -> Deployment:
"""
Get a deployment by name.
Expand All @@ -65,89 +56,35 @@ def get(self, name: str) -> Deployment:
# TODO: fetch model from server
# TODO: support permanent IDs
username, name = name.split("/")
return self.prepare_model({"username": username, "name": name})

def create(
self,
*args,
**kwargs,
) -> Deployment:
"""
Create a deployment.

Raises:
NotImplementedError: This method is not implemented.
"""
raise NotImplementedError()
return self._prepare_model({"username": username, "name": name})

def prepare_model(self, attrs: Union[Deployment, Dict]) -> Deployment:
def _prepare_model(self, attrs: Union[Deployment, Dict]) -> Deployment:
if isinstance(attrs, BaseModel):
attrs.id = f"{attrs.username}/{attrs.name}"
elif isinstance(attrs, dict):
attrs["id"] = f"{attrs['username']}/{attrs['name']}"
return super().prepare_model(attrs)
return super()._prepare_model(attrs)


class DeploymentPredictionCollection(Collection):
"""
Namespace for operations related to predictions in a deployment.
"""

model = Prediction

def __init__(self, client: "Client", deployment: Deployment) -> None:
super().__init__(client=client)
self._deployment = deployment

def list(self) -> List[Prediction]:
"""
List predictions in a deployment.

Raises:
NotImplementedError: This method is not implemented.
"""
raise NotImplementedError()

def get(self, id: str) -> Prediction:
"""
Get a prediction by ID.

Args:
id: The ID of the prediction.
Returns:
Prediction: The prediction object.
"""

resp = self._client._request("GET", f"/v1/predictions/{id}")
obj = resp.json()
# HACK: resolve this? make it lazy somehow?
del obj["version"]
return self.prepare_model(obj)

@overload
def create( # pylint: disable=arguments-differ disable=too-many-arguments
def create(
self,
input: Dict[str, Any],
*,
webhook: Optional[str] = None,
webhook_completed: Optional[str] = None,
webhook_events_filter: Optional[List[str]] = None,
stream: Optional[bool] = None,
) -> Prediction:
...

@overload
def create( # pylint: disable=arguments-differ disable=too-many-arguments
self,
*,
input: Dict[str, Any],
webhook: Optional[str] = None,
webhook_completed: Optional[str] = None,
webhook_events_filter: Optional[List[str]] = None,
stream: Optional[bool] = None,
) -> Prediction:
...

def create(
self,
*args,
**kwargs: Unpack[PredictionCollection.CreateParams], # type: ignore[misc]
) -> Prediction:
"""
Create a new prediction with the deployment.
Expand All @@ -163,20 +100,21 @@ def create(
Prediction: The created prediction object.
"""

input = args[0] if len(args) > 0 else kwargs.get("input")
if input is None:
raise ValueError(
"An input must be provided as a positional or keyword argument."
)

body = {
"input": encode_json(input, upload_file=upload_file),
}

for key in ["webhook", "webhook_completed", "webhook_events_filter", "stream"]:
value = kwargs.get(key)
if value is not None:
body[key] = value
if webhook is not None:
body["webhook"] = webhook

if webhook_completed is not None:
body["webhook_completed"] = webhook_completed

if webhook_events_filter is not None:
body["webhook_events_filter"] = webhook_events_filter

if stream is not None:
body["stream"] = stream

resp = self._client._request(
"POST",
Expand All @@ -186,4 +124,4 @@ def create(
obj = resp.json()
obj["deployment"] = self._deployment
del obj["version"]
return self.prepare_model(obj)
return self._prepare_model(obj)
2 changes: 1 addition & 1 deletion replicate/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
class ReplicateException(Exception):
pass
"""A base class for all Replicate exceptions."""


class ModelError(ReplicateException):
Expand Down
32 changes: 15 additions & 17 deletions replicate/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ class Model(BaseModel):
A machine learning model hosted on Replicate.
"""

_collection: "ModelCollection"

url: str
"""
The URL of the model.
Expand Down Expand Up @@ -105,6 +107,15 @@ def versions(self) -> VersionCollection:

return VersionCollection(client=self._client, model=self)

def reload(self) -> None:
"""
Load this object from the server.
"""

obj = self._collection.get(f"{self.owner}/{self.name}") # pylint: disable=no-member
for name, value in obj.dict().items():
setattr(self, name, value)


class ModelCollection(Collection):
"""
Expand All @@ -124,7 +135,7 @@ def list(self) -> List[Model]:
resp = self._client._request("GET", "/v1/models")
# TODO: paginate
models = resp.json()["results"]
return [self.prepare_model(obj) for obj in models]
return [self._prepare_model(obj) for obj in models]

def get(self, key: str) -> Model:
"""
Expand All @@ -137,22 +148,9 @@ def get(self, key: str) -> Model:
"""

resp = self._client._request("GET", f"/v1/models/{key}")
return self.prepare_model(resp.json())

def create(
self,
*args,
**kwargs,
) -> Model:
"""
Create a model.

Raises:
NotImplementedError: This method is not implemented.
"""
raise NotImplementedError()
return self._prepare_model(resp.json())

def prepare_model(self, attrs: Union[Model, Dict]) -> Model:
def _prepare_model(self, attrs: Union[Model, Dict]) -> Model:
if isinstance(attrs, BaseModel):
attrs.id = f"{attrs.owner}/{attrs.name}"
elif isinstance(attrs, dict):
Expand All @@ -165,7 +163,7 @@ def prepare_model(self, attrs: Union[Model, Dict]) -> Model:
if "latest_version" in attrs and attrs["latest_version"] == {}:
attrs.pop("latest_version")

model = super().prepare_model(attrs)
model = super()._prepare_model(attrs)

if model.default_example is not None:
model.default_example._client = self._client
Expand Down
Loading