diff --git a/replicate/base_model.py b/replicate/base_model.py deleted file mode 100644 index 9fac634b..00000000 --- a/replicate/base_model.py +++ /dev/null @@ -1,21 +0,0 @@ -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from replicate.client import Client - from replicate.collection import Collection - -try: - from pydantic import v1 as pydantic # type: ignore -except ImportError: - import pydantic # type: ignore - - -class BaseModel(pydantic.BaseModel): - """ - A base class for representing a single object on the server. - """ - - id: str - - _client: "Client" = pydantic.PrivateAttr() - _collection: "Collection" = pydantic.PrivateAttr() diff --git a/replicate/client.py b/replicate/client.py index 0f2277c2..0a327fef 100644 --- a/replicate/client.py +++ b/replicate/client.py @@ -15,13 +15,13 @@ import httpx from replicate.__about__ import __version__ -from replicate.deployment import DeploymentCollection +from replicate.deployment import Deployments from replicate.exceptions import ModelError, ReplicateError -from replicate.hardware import HardwareCollection -from replicate.model import ModelCollection -from replicate.prediction import PredictionCollection +from replicate.hardware import Hardwares +from replicate.model import Models +from replicate.prediction import Predictions from replicate.schema import make_schema_backwards_compatible -from replicate.training import TrainingCollection +from replicate.training import Trainings from replicate.version import Version @@ -85,39 +85,39 @@ def _request(self, method: str, path: str, **kwargs) -> httpx.Response: return resp @property - def deployments(self) -> DeploymentCollection: + def deployments(self) -> Deployments: """ Namespace for operations related to deployments. """ - return DeploymentCollection(client=self) + return Deployments(client=self) @property - def hardware(self) -> HardwareCollection: + def hardware(self) -> Hardwares: """ Namespace for operations related to hardware. """ - return HardwareCollection(client=self) + return Hardwares(client=self) @property - def models(self) -> ModelCollection: + def models(self) -> Models: """ Namespace for operations related to models. """ - return ModelCollection(client=self) + return Models(client=self) @property - def predictions(self) -> PredictionCollection: + def predictions(self) -> Predictions: """ Namespace for operations related to predictions. """ - return PredictionCollection(client=self) + return Predictions(client=self) @property - def trainings(self) -> TrainingCollection: + def trainings(self) -> Trainings: """ Namespace for operations related to trainings. """ - return TrainingCollection(client=self) + return Trainings(client=self) def run(self, model_version: str, **kwargs) -> Union[Any, Iterator[Any]]: # noqa: ANN401 """ diff --git a/replicate/deployment.py b/replicate/deployment.py index df9bb3c6..62a651a4 100644 --- a/replicate/deployment.py +++ b/replicate/deployment.py @@ -1,21 +1,20 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union -from replicate.base_model import BaseModel -from replicate.collection import Collection from replicate.files import upload_file from replicate.json import encode_json from replicate.prediction import Prediction +from replicate.resource import Namespace, Resource if TYPE_CHECKING: from replicate.client import Client -class Deployment(BaseModel): +class Deployment(Resource): """ A deployment of a model hosted on Replicate. """ - _collection: "DeploymentCollection" + _namespace: "Deployments" username: str """ @@ -28,15 +27,15 @@ class Deployment(BaseModel): """ @property - def predictions(self) -> "DeploymentPredictionCollection": + def predictions(self) -> "DeploymentPredictions": """ Get the predictions for this deployment. """ - return DeploymentPredictionCollection(client=self._client, deployment=self) + return DeploymentPredictions(client=self._client, deployment=self) -class DeploymentCollection(Collection): +class Deployments(Namespace): """ Namespace for operations related to deployments. """ @@ -59,14 +58,14 @@ def get(self, name: str) -> Deployment: return self._prepare_model({"username": username, "name": name}) def _prepare_model(self, attrs: Union[Deployment, Dict]) -> Deployment: - if isinstance(attrs, BaseModel): + if isinstance(attrs, Resource): attrs.id = f"{attrs.username}/{attrs.name}" elif isinstance(attrs, dict): attrs["id"] = f"{attrs['username']}/{attrs['name']}" return super()._prepare_model(attrs) -class DeploymentPredictionCollection(Collection): +class DeploymentPredictions(Namespace): """ Namespace for operations related to predictions in a deployment. """ diff --git a/replicate/hardware.py b/replicate/hardware.py index 5feee199..1807a2a7 100644 --- a/replicate/hardware.py +++ b/replicate/hardware.py @@ -1,10 +1,9 @@ from typing import Dict, List, Union -from replicate.base_model import BaseModel -from replicate.collection import Collection +from replicate.resource import Namespace, Resource -class Hardware(BaseModel): +class Hardware(Resource): """ Hardware for running a model on Replicate. """ @@ -20,7 +19,7 @@ class Hardware(BaseModel): """ -class HardwareCollection(Collection): +class Hardwares(Namespace): """ Namespace for operations related to hardware. """ @@ -40,7 +39,7 @@ def list(self) -> List[Hardware]: return [self._prepare_model(obj) for obj in hardware] def _prepare_model(self, attrs: Union[Hardware, Dict]) -> Hardware: - if isinstance(attrs, BaseModel): + if isinstance(attrs, Resource): attrs.id = attrs.sku elif isinstance(attrs, dict): attrs["id"] = attrs["sku"] diff --git a/replicate/model.py b/replicate/model.py index 485e2c83..31825ddd 100644 --- a/replicate/model.py +++ b/replicate/model.py @@ -2,19 +2,18 @@ from typing_extensions import deprecated -from replicate.base_model import BaseModel -from replicate.collection import Collection from replicate.exceptions import ReplicateException from replicate.prediction import Prediction -from replicate.version import Version, VersionCollection +from replicate.resource import Namespace, Resource +from replicate.version import Version, Versions -class Model(BaseModel): +class Model(Resource): """ A machine learning model hosted on Replicate. """ - _collection: "ModelCollection" + _namespace: "Models" url: str """ @@ -100,24 +99,24 @@ def predict(self, *args, **kwargs) -> None: ) @property - def versions(self) -> VersionCollection: + def versions(self) -> Versions: """ Get the versions of this model. """ - return VersionCollection(client=self._client, model=self) + return Versions(client=self._client, model=self) def reload(self) -> None: """ Load this object from the server. """ - obj = self._collection.get(f"{self.owner}/{self.name}") # pylint: disable=no-member + obj = self._namespace.get(f"{self.owner}/{self.name}") # pylint: disable=no-member for name, value in obj.dict().items(): setattr(self, name, value) -class ModelCollection(Collection): +class Models(Namespace): """ Namespace for operations related to models. """ @@ -208,7 +207,7 @@ def create( # pylint: disable=arguments-differ disable=too-many-arguments return self._prepare_model(resp.json()) def _prepare_model(self, attrs: Union[Model, Dict]) -> Model: - if isinstance(attrs, BaseModel): + if isinstance(attrs, Resource): attrs.id = f"{attrs.owner}/{attrs.name}" elif isinstance(attrs, dict): attrs["id"] = f"{attrs['owner']}/{attrs['name']}" diff --git a/replicate/prediction.py b/replicate/prediction.py index 9fc69385..d6202158 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -3,20 +3,19 @@ from dataclasses import dataclass from typing import Any, Dict, Iterator, List, Optional, Union -from replicate.base_model import BaseModel -from replicate.collection import Collection from replicate.exceptions import ModelError from replicate.files import upload_file from replicate.json import encode_json +from replicate.resource import Namespace, Resource from replicate.version import Version -class Prediction(BaseModel): +class Prediction(Resource): """ A prediction made by a model hosted on Replicate. """ - _collection: "PredictionCollection" + _namespace: "Predictions" id: str """The unique ID of the prediction.""" @@ -146,12 +145,12 @@ def reload(self) -> None: Load this prediction from the server. """ - obj = self._collection.get(self.id) # pylint: disable=no-member + obj = self._namespace.get(self.id) # pylint: disable=no-member for name, value in obj.dict().items(): setattr(self, name, value) -class PredictionCollection(Collection): +class Predictions(Namespace): """ Namespace for operations related to predictions. """ diff --git a/replicate/collection.py b/replicate/resource.py similarity index 64% rename from replicate/collection.py rename to replicate/resource.py index b9b8c40a..9a3473fc 100644 --- a/replicate/collection.py +++ b/replicate/resource.py @@ -1,16 +1,32 @@ import abc from typing import TYPE_CHECKING, Dict, Generic, TypeVar, Union, cast +from replicate.exceptions import ReplicateException + +try: + from pydantic import v1 as pydantic # type: ignore +except ImportError: + import pydantic # type: ignore + if TYPE_CHECKING: from replicate.client import Client -from replicate.base_model import BaseModel -from replicate.exceptions import ReplicateException -Model = TypeVar("Model", bound=BaseModel) +class Resource(pydantic.BaseModel): + """ + A base class for representing a single object on the server. + """ + + id: str + + _client: "Client" = pydantic.PrivateAttr() + _namespace: "Namespace" = pydantic.PrivateAttr() + + +Model = TypeVar("Model", bound=Resource) -class Collection(abc.ABC, Generic[Model]): +class Namespace(abc.ABC, Generic[Model]): """ A base class for representing objects of a particular type on the server. """ @@ -25,15 +41,15 @@ def _prepare_model(self, attrs: Union[Model, Dict]) -> Model: """ Create a model from a set of attributes. """ - if isinstance(attrs, BaseModel): + if isinstance(attrs, Resource): attrs._client = self._client - attrs._collection = self + attrs._namespace = self return cast(Model, attrs) if isinstance(attrs, dict) and self.model is not None and callable(self.model): model = self.model(**attrs) model._client = self._client - model._collection = self + model._namespace = self return model name = self.model.__name__ if hasattr(self.model, "__name__") else "model" diff --git a/replicate/training.py b/replicate/training.py index 6b847fab..18223bf4 100644 --- a/replicate/training.py +++ b/replicate/training.py @@ -3,20 +3,19 @@ from typing_extensions import NotRequired, Unpack, overload -from replicate.base_model import BaseModel -from replicate.collection import Collection from replicate.exceptions import ReplicateException from replicate.files import upload_file from replicate.json import encode_json +from replicate.resource import Namespace, Resource from replicate.version import Version -class Training(BaseModel): +class Training(Resource): """ A training made for a model hosted on Replicate. """ - _collection: "TrainingCollection" + _namespace: "Trainings" id: str """The unique ID of the training.""" @@ -69,12 +68,12 @@ def reload(self) -> None: Load the training from the server. """ - obj = self._collection.get(self.id) # pylint: disable=no-member + obj = self._namespace.get(self.id) # pylint: disable=no-member for name, value in obj.dict().items(): setattr(self, name, value) -class TrainingCollection(Collection): +class Trainings(Namespace): """ Namespace for operations related to trainings. """ diff --git a/replicate/version.py b/replicate/version.py index 8579b6ae..0496a6a4 100644 --- a/replicate/version.py +++ b/replicate/version.py @@ -7,18 +7,17 @@ from replicate.model import Model -from replicate.base_model import BaseModel -from replicate.collection import Collection from replicate.exceptions import ModelError +from replicate.resource import Namespace, Resource from replicate.schema import make_schema_backwards_compatible -class Version(BaseModel): +class Version(Resource): """ A version of a model. """ - _collection: "VersionCollection" + _namespace: "Versions" id: str """The unique ID of the version.""" @@ -70,12 +69,12 @@ def reload(self) -> None: Load this object from the server. """ - obj = self._collection.get(self.id) # pylint: disable=no-member + obj = self._namespace.get(self.id) # pylint: disable=no-member for name, value in obj.dict().items(): setattr(self, name, value) -class VersionCollection(Collection): +class Versions(Namespace): """ Namespace for operations related to model versions. """