diff --git a/replicate/client.py b/replicate/client.py index a08dacf0..91a2bf07 100644 --- a/replicate/client.py +++ b/replicate/client.py @@ -115,7 +115,13 @@ def trainings(self) -> TrainingCollection: def run(self, model_version: str, **kwargs) -> Union[Any, Iterator[Any]]: """ - Run a model in the format owner/name:version. + Run a model and wait for its output. + + Args: + model_version: The model version to run, in the format `owner/name:version` + kwargs: The input to the model, as a dictionary + Returns: + The output of the model """ # Split model_version into owner, name, version in format owner/name:version m = re.match(r"^(?P[^/]+/[^:]+):(?P.+)$", model_version) diff --git a/replicate/collection.py b/replicate/collection.py index 92e7a88a..32596f89 100644 --- a/replicate/collection.py +++ b/replicate/collection.py @@ -11,8 +11,7 @@ class Collection(abc.ABC, Generic[Model]): """ - A base class for representing all objects of a particular type on the - server. + A base class for representing objects of a particular type on the server. """ def __init__(self, client: "Client") -> None: diff --git a/replicate/files.py b/replicate/files.py index 55a6612c..27dbb6db 100644 --- a/replicate/files.py +++ b/replicate/files.py @@ -9,8 +9,16 @@ def upload_file(fh: io.IOBase, output_file_prefix: Optional[str] = None) -> str: """ - Lifted straight from cog.files + Upload a file to the server. + + Args: + fh: A file handle to upload. + output_file_prefix: A string to prepend to the output file name. + Returns: + str: A URL to the uploaded file. """ + # Lifted straight from cog.files + fh.seek(0) if output_file_prefix is not None: diff --git a/replicate/json.py b/replicate/json.py index cd0b864e..8884cd06 100644 --- a/replicate/json.py +++ b/replicate/json.py @@ -15,8 +15,10 @@ def encode_json( obj: Any, upload_file: Callable[[io.IOBase], str] # noqa: ANN401 ) -> Any: # noqa: ANN401 """ - Returns a JSON-compatible version of the object. Effectively the same thing as cog.json.encode_json. + Return a JSON-compatible version of the object. """ + # Effectively the same thing as cog.json.encode_json. + if isinstance(obj, dict): return {key: encode_json(value, upload_file) for key, value in obj.items()} if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)): diff --git a/replicate/model.py b/replicate/model.py index d6b32fcd..4787e337 100644 --- a/replicate/model.py +++ b/replicate/model.py @@ -7,16 +7,35 @@ class Model(BaseModel): + """ + A machine learning model hosted on Replicate. + """ + username: str + """ + The name of the user or organization that owns the model. + """ + name: str + """ + The name of the model. + """ def predict(self, *args, **kwargs) -> None: + """ + DEPRECATED: Use `version.predict()` instead. + """ + raise ReplicateException( "The `model.predict()` method has been removed, because it's unstable: if a new version of the model you're using is pushed and its API has changed, your code may break. Use `version.predict()` instead. See https://github.com/replicate/replicate-python#readme" ) @property def versions(self) -> VersionCollection: + """ + Get the versions of this model. + """ + return VersionCollection(client=self._client, model=self) @@ -27,6 +46,15 @@ def list(self) -> List[Model]: raise NotImplementedError() def get(self, name: str) -> Model: + """ + Get a model by name. + + Args: + name: The name of the model, in the format `owner/model-name`. + Returns: + The model. + """ + # TODO: fetch model from server # TODO: support permanent IDs username, name = name.split("/") diff --git a/replicate/prediction.py b/replicate/prediction.py index db197ae3..9f2fc8a7 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -10,20 +10,53 @@ class Prediction(BaseModel): + """ + A prediction made by a model hosted on Replicate. + """ + id: str - error: Optional[str] + """The unique ID of the prediction.""" + + version: Optional[Version] + """The version of the model used to create the prediction.""" + + status: str + """The status of the prediction.""" + input: Optional[Dict[str, Any]] - logs: Optional[str] + """The input to the prediction.""" + output: Optional[Any] - status: str - version: Optional[Version] - started_at: Optional[str] + """The output of the prediction.""" + + logs: Optional[str] + """The logs of the prediction.""" + + error: Optional[str] + """The error encountered during the prediction, if any.""" + created_at: Optional[str] + """When the prediction was created.""" + + started_at: Optional[str] + """When the prediction was started.""" + completed_at: Optional[str] + """When the prediction was completed, if finished.""" + urls: Optional[Dict[str, str]] + """ + URLs associated with the prediction. + + The following keys are available: + - `get`: A URL to fetch the prediction. + - `cancel`: A URL to cancel the prediction. + """ def wait(self) -> None: - """Wait for prediction to finish.""" + """ + Wait for prediction to finish. + """ while self.status not in ["succeeded", "failed", "canceled"]: time.sleep(self._client.poll_interval) self.reload() @@ -48,7 +81,9 @@ def output_iterator(self) -> Iterator[Any]: yield output def cancel(self) -> None: - """Cancel a currently running prediction""" + """ + Cancels a running prediction. + """ self._client._request("POST", f"/v1/predictions/{self.id}/cancel") @@ -56,6 +91,13 @@ class PredictionCollection(Collection): model = Prediction def list(self) -> List[Prediction]: + """ + List your predictions. + + Returns: + A list of prediction objects. + """ + resp = self._client._request("GET", "/v1/predictions") # TODO: paginate predictions = resp.json()["results"] @@ -65,6 +107,15 @@ def list(self) -> List[Prediction]: return [self.prepare_model(obj) for obj in predictions] def get(self, id: str) -> Prediction: + """ + Get a prediction by ID. + + Args: + id: The ID of the prediction. + Returns: + Prediction: The prediction object. + """ + resp = self._client._request("GET", f"/v1/predictions/{id}") obj = resp.json() # HACK: resolve this? make it lazy somehow? @@ -80,6 +131,21 @@ def create( # type: ignore webhook_events_filter: Optional[List[str]] = None, **kwargs, ) -> Prediction: + """ + Create a new prediction for the specified model version. + + Args: + version: The model version to use for the prediction. + input: The input data for the prediction. + webhook: The URL to receive a POST request with prediction updates. + webhook_completed: The URL to receive a POST request when the prediction is completed. + webhook_events_filter: List of events to trigger webhooks. + stream: Set to True to enable streaming of prediction output. + + Returns: + Prediction: The created prediction object. + """ + input = encode_json(input, upload_file=upload_file) body = { "version": version.id, diff --git a/replicate/training.py b/replicate/training.py index d7e97bc3..d93b56ab 100644 --- a/replicate/training.py +++ b/replicate/training.py @@ -10,17 +10,51 @@ class Training(BaseModel): - completed_at: Optional[str] - created_at: Optional[str] - destination: Optional[str] - error: Optional[str] + """ + A training made for a model hosted on Replicate. + """ + id: str + """The unique ID of the training.""" + + version: Optional[Version] + """The version of the model used to create the training.""" + + destination: Optional[str] + """The model destination of the training.""" + + status: str + """The status of the training.""" + input: Optional[Dict[str, Any]] - logs: Optional[str] + """The input to the training.""" + output: Optional[Any] + """The output of the training.""" + + logs: Optional[str] + """The logs of the training.""" + + error: Optional[str] + """The error encountered during the training, if any.""" + + created_at: Optional[str] + """When the training was created.""" + started_at: Optional[str] - status: str - version: Optional[Version] + """When the training was started.""" + + completed_at: Optional[str] + """When the training was completed, if finished.""" + + urls: Optional[Dict[str, str]] + """ + URLs associated with the training. + + The following keys are available: + - `get`: A URL to fetch the training. + - `cancel`: A URL to cancel the training. + """ def cancel(self) -> None: """Cancel a running training""" @@ -31,6 +65,13 @@ class TrainingCollection(Collection): model = Training def list(self) -> List[Training]: + """ + List your trainings. + + Returns: + List[Training]: A list of training objects. + """ + resp = self._client._request("GET", "/v1/trainings") # TODO: paginate trainings = resp.json()["results"] @@ -40,6 +81,15 @@ def list(self) -> List[Training]: return [self.prepare_model(obj) for obj in trainings] def get(self, id: str) -> Training: + """ + Get a training by ID. + + Args: + id: The ID of the training. + Returns: + Training: The training object. + """ + resp = self._client._request( "GET", f"/v1/trainings/{id}", @@ -58,6 +108,19 @@ def create( # type: ignore webhook_events_filter: Optional[List[str]] = None, **kwargs, ) -> Training: + """ + Create a new training using the specified model version as a base. + + Args: + version: The ID of the base model version that you're using to train a new model version. + input: The input to the training. + destination: The desired model to push to in the format `{owner}/{model_name}`. This should be an existing model owned by the user or organization making the API request. + webhook: The URL to send a POST request to when the training is completed. Defaults to None. + webhook_events_filter: The events to send to the webhook. Defaults to None. + Returns: + The training object. + """ + input = encode_json(input, upload_file=upload_file) body = { "input": input, diff --git a/replicate/version.py b/replicate/version.py index d4ed9108..71cee3ad 100644 --- a/replicate/version.py +++ b/replicate/version.py @@ -14,12 +14,32 @@ class Version(BaseModel): + """ + A version of a model. + """ + id: str + """The unique ID of the version.""" + created_at: datetime.datetime + """When the version was created.""" + cog_version: str + """The version of the Cog used to create the version.""" + openapi_schema: dict + """An OpenAPI description of the model inputs and outputs.""" def predict(self, **kwargs) -> Union[Any, Iterator[Any]]: + """ + Create a prediction using this model version. + + Args: + kwargs: The input to the model. + Returns: + The output of the model. + """ + warnings.warn( "version.predict() is deprecated. Use replicate.run() instead. It will be removed before version 1.0.", DeprecationWarning, @@ -57,7 +77,12 @@ def __init__(self, client: "Client", model: "Model") -> None: # doesn't exist yet def get(self, id: str) -> Version: """ - Get a specific version. + Get a specific model version. + + Args: + id: The version ID. + Returns: + The model version. """ resp = self._client._request( "GET", f"/v1/models/{self._model.username}/{self._model.name}/versions/{id}" @@ -70,6 +95,9 @@ def create(self, **kwargs) -> Version: def list(self) -> List[Version]: """ Return a list of all versions for a model. + + Returns: + List[Version]: A list of version objects. """ resp = self._client._request( "GET", f"/v1/models/{self._model.username}/{self._model.name}/versions"