diff --git a/pyproject.toml b/pyproject.toml index ba13eab4..e0996c36 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,7 @@ disable = [ "R0801", # Similar lines in N files "W0212", # Access to a protected member "W0622", # Redefining built-in + "R0903", # Too few public methods ] [tool.ruff] diff --git a/replicate/__init__.py b/replicate/__init__.py index c8aaeb01..d7cfeb8d 100644 --- a/replicate/__init__.py +++ b/replicate/__init__.py @@ -1,4 +1,4 @@ -from .client import Client +from replicate.client import Client default_client = Client() run = default_client.run diff --git a/replicate/base_model.py b/replicate/base_model.py index 3a954b2c..9fac634b 100644 --- a/replicate/base_model.py +++ b/replicate/base_model.py @@ -19,12 +19,3 @@ class BaseModel(pydantic.BaseModel): _client: "Client" = pydantic.PrivateAttr() _collection: "Collection" = pydantic.PrivateAttr() - - def reload(self) -> None: - """ - Load this object from the server again. - """ - - new_model = self._collection.get(self.id) # pylint: disable=no-member - for k, v in new_model.dict().items(): # pylint: disable=invalid-name - setattr(self, k, v) diff --git a/replicate/client.py b/replicate/client.py index 8f023e47..dccc1433 100644 --- a/replicate/client.py +++ b/replicate/client.py @@ -14,13 +14,14 @@ import httpx -from .__about__ import __version__ -from .deployment import DeploymentCollection -from .exceptions import ModelError, ReplicateError -from .model import ModelCollection -from .prediction import PredictionCollection -from .training import TrainingCollection -from .version import Version +from replicate.__about__ import __version__ +from replicate.deployment import DeploymentCollection +from replicate.exceptions import ModelError, ReplicateError +from replicate.model import ModelCollection +from replicate.prediction import PredictionCollection +from replicate.schema import make_schema_backwards_compatible +from replicate.training import TrainingCollection +from replicate.version import Version class Client: @@ -143,7 +144,9 @@ def run(self, model_version: str, **kwargs) -> Union[Any, Iterator[Any]]: # noq version = Version(**resp.json()) # Return an iterator of the output - schema = version.get_transformed_schema() + schema = make_schema_backwards_compatible( + version.openapi_schema, version.cog_version + ) output = schema["components"]["schemas"]["Output"] if ( output.get("type") == "array" @@ -175,9 +178,10 @@ class RetryTransport(httpx.AsyncBaseTransport, httpx.BaseTransport): ) MAX_BACKOFF_WAIT = 60 - def __init__( + def __init__( # pylint: disable=too-many-arguments self, wrapped_transport: Union[httpx.BaseTransport, httpx.AsyncBaseTransport], + *, max_attempts: int = 10, max_backoff_wait: float = MAX_BACKOFF_WAIT, backoff_factor: float = 0.1, diff --git a/replicate/collection.py b/replicate/collection.py index 19e8e6c1..b9b8c40a 100644 --- a/replicate/collection.py +++ b/replicate/collection.py @@ -1,5 +1,5 @@ import abc -from typing import TYPE_CHECKING, Dict, Generic, List, TypeVar, Union, cast +from typing import TYPE_CHECKING, Dict, Generic, TypeVar, Union, cast if TYPE_CHECKING: from replicate.client import Client @@ -15,29 +15,13 @@ class Collection(abc.ABC, Generic[Model]): A base class for representing objects of a particular type on the server. """ + _client: "Client" + model: Model + def __init__(self, client: "Client") -> None: self._client = client - @property - @abc.abstractmethod - def model(self) -> Model: # pylint: disable=missing-function-docstring - pass - - @abc.abstractmethod - def list(self) -> List[Model]: # pylint: disable=missing-function-docstring - pass - - @abc.abstractmethod - def get(self, key: str) -> Model: # pylint: disable=missing-function-docstring - pass - - @abc.abstractmethod - def create( # pylint: disable=missing-function-docstring - self, *args, **kwargs - ) -> Model: - pass - - def prepare_model(self, attrs: Union[Model, Dict]) -> Model: + def _prepare_model(self, attrs: Union[Model, Dict]) -> Model: """ Create a model from a set of attributes. """ diff --git a/replicate/deployment.py b/replicate/deployment.py index cfdd0da1..df9bb3c6 100644 --- a/replicate/deployment.py +++ b/replicate/deployment.py @@ -1,12 +1,10 @@ -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, overload - -from typing_extensions import Unpack +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from replicate.base_model import BaseModel from replicate.collection import Collection from replicate.files import upload_file from replicate.json import encode_json -from replicate.prediction import Prediction, PredictionCollection +from replicate.prediction import Prediction if TYPE_CHECKING: from replicate.client import Client @@ -17,6 +15,8 @@ class Deployment(BaseModel): A deployment of a model hosted on Replicate. """ + _collection: "DeploymentCollection" + username: str """ The name of the user or organization that owns the deployment. @@ -43,15 +43,6 @@ class DeploymentCollection(Collection): model = Deployment - def list(self) -> List[Deployment]: - """ - List deployments. - - Raises: - NotImplementedError: This method is not implemented. - """ - raise NotImplementedError() - def get(self, name: str) -> Deployment: """ Get a deployment by name. @@ -65,63 +56,28 @@ def get(self, name: str) -> Deployment: # TODO: fetch model from server # TODO: support permanent IDs username, name = name.split("/") - return self.prepare_model({"username": username, "name": name}) - - def create( - self, - *args, - **kwargs, - ) -> Deployment: - """ - Create a deployment. - - Raises: - NotImplementedError: This method is not implemented. - """ - raise NotImplementedError() + return self._prepare_model({"username": username, "name": name}) - def prepare_model(self, attrs: Union[Deployment, Dict]) -> Deployment: + def _prepare_model(self, attrs: Union[Deployment, Dict]) -> Deployment: if isinstance(attrs, BaseModel): attrs.id = f"{attrs.username}/{attrs.name}" elif isinstance(attrs, dict): attrs["id"] = f"{attrs['username']}/{attrs['name']}" - return super().prepare_model(attrs) + return super()._prepare_model(attrs) class DeploymentPredictionCollection(Collection): + """ + Namespace for operations related to predictions in a deployment. + """ + model = Prediction def __init__(self, client: "Client", deployment: Deployment) -> None: super().__init__(client=client) self._deployment = deployment - def list(self) -> List[Prediction]: - """ - List predictions in a deployment. - - Raises: - NotImplementedError: This method is not implemented. - """ - raise NotImplementedError() - - def get(self, id: str) -> Prediction: - """ - Get a prediction by ID. - - Args: - id: The ID of the prediction. - Returns: - Prediction: The prediction object. - """ - - resp = self._client._request("GET", f"/v1/predictions/{id}") - obj = resp.json() - # HACK: resolve this? make it lazy somehow? - del obj["version"] - return self.prepare_model(obj) - - @overload - def create( # pylint: disable=arguments-differ disable=too-many-arguments + def create( self, input: Dict[str, Any], *, @@ -129,25 +85,6 @@ def create( # pylint: disable=arguments-differ disable=too-many-arguments webhook_completed: Optional[str] = None, webhook_events_filter: Optional[List[str]] = None, stream: Optional[bool] = None, - ) -> Prediction: - ... - - @overload - def create( # pylint: disable=arguments-differ disable=too-many-arguments - self, - *, - input: Dict[str, Any], - webhook: Optional[str] = None, - webhook_completed: Optional[str] = None, - webhook_events_filter: Optional[List[str]] = None, - stream: Optional[bool] = None, - ) -> Prediction: - ... - - def create( - self, - *args, - **kwargs: Unpack[PredictionCollection.CreateParams], # type: ignore[misc] ) -> Prediction: """ Create a new prediction with the deployment. @@ -163,20 +100,21 @@ def create( Prediction: The created prediction object. """ - input = args[0] if len(args) > 0 else kwargs.get("input") - if input is None: - raise ValueError( - "An input must be provided as a positional or keyword argument." - ) - body = { "input": encode_json(input, upload_file=upload_file), } - for key in ["webhook", "webhook_completed", "webhook_events_filter", "stream"]: - value = kwargs.get(key) - if value is not None: - body[key] = value + if webhook is not None: + body["webhook"] = webhook + + if webhook_completed is not None: + body["webhook_completed"] = webhook_completed + + if webhook_events_filter is not None: + body["webhook_events_filter"] = webhook_events_filter + + if stream is not None: + body["stream"] = stream resp = self._client._request( "POST", @@ -186,4 +124,4 @@ def create( obj = resp.json() obj["deployment"] = self._deployment del obj["version"] - return self.prepare_model(obj) + return self._prepare_model(obj) diff --git a/replicate/exceptions.py b/replicate/exceptions.py index 665c1b30..e1aa51c4 100644 --- a/replicate/exceptions.py +++ b/replicate/exceptions.py @@ -1,5 +1,5 @@ class ReplicateException(Exception): - pass + """A base class for all Replicate exceptions.""" class ModelError(ReplicateException): diff --git a/replicate/model.py b/replicate/model.py index 887e2bfc..babe9931 100644 --- a/replicate/model.py +++ b/replicate/model.py @@ -14,6 +14,8 @@ class Model(BaseModel): A machine learning model hosted on Replicate. """ + _collection: "ModelCollection" + url: str """ The URL of the model. @@ -105,6 +107,15 @@ def versions(self) -> VersionCollection: return VersionCollection(client=self._client, model=self) + def reload(self) -> None: + """ + Load this object from the server. + """ + + obj = self._collection.get(f"{self.owner}/{self.name}") # pylint: disable=no-member + for name, value in obj.dict().items(): + setattr(self, name, value) + class ModelCollection(Collection): """ @@ -124,7 +135,7 @@ def list(self) -> List[Model]: resp = self._client._request("GET", "/v1/models") # TODO: paginate models = resp.json()["results"] - return [self.prepare_model(obj) for obj in models] + return [self._prepare_model(obj) for obj in models] def get(self, key: str) -> Model: """ @@ -137,22 +148,9 @@ def get(self, key: str) -> Model: """ resp = self._client._request("GET", f"/v1/models/{key}") - return self.prepare_model(resp.json()) - - def create( - self, - *args, - **kwargs, - ) -> Model: - """ - Create a model. - - Raises: - NotImplementedError: This method is not implemented. - """ - raise NotImplementedError() + return self._prepare_model(resp.json()) - def prepare_model(self, attrs: Union[Model, Dict]) -> Model: + def _prepare_model(self, attrs: Union[Model, Dict]) -> Model: if isinstance(attrs, BaseModel): attrs.id = f"{attrs.owner}/{attrs.name}" elif isinstance(attrs, dict): @@ -165,7 +163,7 @@ def prepare_model(self, attrs: Union[Model, Dict]) -> Model: if "latest_version" in attrs and attrs["latest_version"] == {}: attrs.pop("latest_version") - model = super().prepare_model(attrs) + model = super()._prepare_model(attrs) if model.default_example is not None: model.default_example._client = self._client diff --git a/replicate/prediction.py b/replicate/prediction.py index 05cdec2b..9fc69385 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -1,9 +1,7 @@ import re import time from dataclasses import dataclass -from typing import Any, Dict, Iterator, List, Optional, TypedDict, Union, overload - -from typing_extensions import Unpack +from typing import Any, Dict, Iterator, List, Optional, Union from replicate.base_model import BaseModel from replicate.collection import Collection @@ -18,6 +16,8 @@ class Prediction(BaseModel): A prediction made by a model hosted on Replicate. """ + _collection: "PredictionCollection" + id: str """The unique ID of the prediction.""" @@ -62,6 +62,10 @@ class Prediction(BaseModel): @dataclass class Progress: + """ + The progress of a prediction. + """ + percentage: float """The percentage of the prediction that has completed.""" @@ -109,6 +113,10 @@ def wait(self) -> None: self.reload() def output_iterator(self) -> Iterator[Any]: + """ + Return an iterator of the prediction output. + """ + # TODO: check output is list previous_output = self.output or [] while self.status not in ["succeeded", "failed", "canceled"]: @@ -133,22 +141,21 @@ def cancel(self) -> None: """ self._client._request("POST", f"/v1/predictions/{self.id}/cancel") # pylint: disable=no-member + def reload(self) -> None: + """ + Load this prediction from the server. + """ + + obj = self._collection.get(self.id) # pylint: disable=no-member + for name, value in obj.dict().items(): + setattr(self, name, value) + class PredictionCollection(Collection): """ Namespace for operations related to predictions. """ - class CreateParams(TypedDict): - """Parameters for creating a prediction.""" - - version: Union[Version, str] - input: Dict[str, Any] - webhook: Optional[str] - webhook_completed: Optional[str] - webhook_events_filter: Optional[List[str]] - stream: Optional[bool] - model = Prediction def list(self) -> List[Prediction]: @@ -165,9 +172,9 @@ def list(self) -> List[Prediction]: for prediction in predictions: # HACK: resolve this? make it lazy somehow? del prediction["version"] - return [self.prepare_model(obj) for obj in predictions] + return [self._prepare_model(obj) for obj in predictions] - def get(self, id: str) -> Prediction: + def get(self, id: str) -> Prediction: # pylint: disable=invalid-name """ Get a prediction by ID. @@ -181,10 +188,9 @@ def get(self, id: str) -> Prediction: obj = resp.json() # HACK: resolve this? make it lazy somehow? del obj["version"] - return self.prepare_model(obj) + return self._prepare_model(obj) - @overload - def create( # pylint: disable=arguments-differ disable=too-many-arguments + def create( self, version: Union[Version, str], input: Dict[str, Any], @@ -193,26 +199,6 @@ def create( # pylint: disable=arguments-differ disable=too-many-arguments webhook_completed: Optional[str] = None, webhook_events_filter: Optional[List[str]] = None, stream: Optional[bool] = None, - ) -> Prediction: - ... - - @overload - def create( # pylint: disable=arguments-differ disable=too-many-arguments - self, - *, - version: Union[Version, str], - input: Dict[str, Any], - webhook: Optional[str] = None, - webhook_completed: Optional[str] = None, - webhook_events_filter: Optional[List[str]] = None, - stream: Optional[bool] = None, - ) -> Prediction: - ... - - def create( - self, - *args, - **kwargs: Unpack[CreateParams], # type: ignore[misc] ) -> Prediction: """ Create a new prediction for the specified model version. @@ -229,28 +215,22 @@ def create( Prediction: The created prediction object. """ - # Support positional arguments for backwards compatibility - version = args[0] if args else kwargs.get("version") - if version is None: - raise ValueError( - "A version identifier must be provided as a positional or keyword argument." - ) - - input = args[1] if len(args) > 1 else kwargs.get("input") - if input is None: - raise ValueError( - "An input must be provided as a positional or keyword argument." - ) - body = { "version": version if isinstance(version, str) else version.id, "input": encode_json(input, upload_file=upload_file), } - for key in ["webhook", "webhook_completed", "webhook_events_filter", "stream"]: - value = kwargs.get(key) - if value is not None: - body[key] = value + if webhook is not None: + body["webhook"] = webhook + + if webhook_completed is not None: + body["webhook_completed"] = webhook_completed + + if webhook_events_filter is not None: + body["webhook_events_filter"] = webhook_events_filter + + if stream is not None: + body["stream"] = stream resp = self._client._request( "POST", @@ -263,4 +243,4 @@ def create( else: del obj["version"] - return self.prepare_model(obj) + return self._prepare_model(obj) diff --git a/replicate/training.py b/replicate/training.py index ee3fe7e4..6b847fab 100644 --- a/replicate/training.py +++ b/replicate/training.py @@ -16,6 +16,8 @@ class Training(BaseModel): A training made for a model hosted on Replicate. """ + _collection: "TrainingCollection" + id: str """The unique ID of the training.""" @@ -62,6 +64,15 @@ def cancel(self) -> None: """Cancel a running training""" self._client._request("POST", f"/v1/trainings/{self.id}/cancel") # pylint: disable=no-member + def reload(self) -> None: + """ + Load the training from the server. + """ + + obj = self._collection.get(self.id) # pylint: disable=no-member + for name, value in obj.dict().items(): + setattr(self, name, value) + class TrainingCollection(Collection): """ @@ -94,9 +105,9 @@ def list(self) -> List[Training]: for training in trainings: # HACK: resolve this? make it lazy somehow? del training["version"] - return [self.prepare_model(obj) for obj in trainings] + return [self._prepare_model(obj) for obj in trainings] - def get(self, id: str) -> Training: + def get(self, id: str) -> Training: # pylint: disable=invalid-name """ Get a training by ID. @@ -113,7 +124,7 @@ def get(self, id: str) -> Training: obj = resp.json() # HACK: resolve this? make it lazy somehow? del obj["version"] - return self.prepare_model(obj) + return self._prepare_model(obj) @overload def create( # pylint: disable=arguments-differ disable=too-many-arguments @@ -209,4 +220,4 @@ def create( ) obj = resp.json() del obj["version"] - return self.prepare_model(obj) + return self._prepare_model(obj) diff --git a/replicate/version.py b/replicate/version.py index 2d7cd3ae..8579b6ae 100644 --- a/replicate/version.py +++ b/replicate/version.py @@ -18,6 +18,8 @@ class Version(BaseModel): A version of a model. """ + _collection: "VersionCollection" + id: str """The unique ID of the version.""" @@ -50,7 +52,7 @@ def predict(self, **kwargs) -> Union[Any, Iterator[Any]]: # noqa: ANN401 prediction = self._client.predictions.create(version=self, input=kwargs) # pylint: disable=no-member # Return an iterator of the output - schema = self.get_transformed_schema() + schema = make_schema_backwards_compatible(self.openapi_schema, self.cog_version) output = schema["components"]["schemas"]["Output"] if ( output.get("type") == "array" @@ -63,10 +65,14 @@ def predict(self, **kwargs) -> Union[Any, Iterator[Any]]: # noqa: ANN401 raise ModelError(prediction.error) return prediction.output - def get_transformed_schema(self) -> dict: - schema = self.openapi_schema - schema = make_schema_backwards_compatible(schema, self.cog_version) - return schema + def reload(self) -> None: + """ + Load this object from the server. + """ + + obj = self._collection.get(self.id) # pylint: disable=no-member + for name, value in obj.dict().items(): + setattr(self, name, value) class VersionCollection(Collection): @@ -80,7 +86,7 @@ def __init__(self, client: "Client", model: "Model") -> None: super().__init__(client=client) self._model = model - def get(self, id: str) -> Version: + def get(self, id: str) -> Version: # pylint: disable=invalid-name """ Get a specific model version. @@ -92,20 +98,7 @@ def get(self, id: str) -> Version: resp = self._client._request( "GET", f"/v1/models/{self._model.owner}/{self._model.name}/versions/{id}" ) - return self.prepare_model(resp.json()) - - def create( - self, - *args, - **kwargs, - ) -> Version: - """ - Create a model version. - - Raises: - NotImplementedError: This method is not implemented. - """ - raise NotImplementedError() + return self._prepare_model(resp.json()) def list(self) -> List[Version]: """ @@ -117,4 +110,4 @@ def list(self) -> List[Version]: resp = self._client._request( "GET", f"/v1/models/{self._model.owner}/{self._model.name}/versions" ) - return [self.prepare_model(obj) for obj in resp.json()["results"]] + return [self._prepare_model(obj) for obj in resp.json()["results"]] diff --git a/tests/test_run.py b/tests/test_run.py index 67c89af5..ecb712f2 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -10,6 +10,8 @@ @pytest.mark.vcr("run.yaml") @pytest.mark.asyncio async def test_run(mock_replicate_api_token): + replicate.default_client.poll_interval = 0.001 + version = "a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5" input = {