diff --git a/README.md b/README.md index ae288708..cea980ab 100644 --- a/README.md +++ b/README.md @@ -270,7 +270,7 @@ background = Image.open("/tmp/out.png") ## List models -You can the models you've created: +You can list the models you've created: ```python replicate.models.list() diff --git a/replicate/model.py b/replicate/model.py index 2349fe5e..f408140f 100644 --- a/replicate/model.py +++ b/replicate/model.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple, Union, overload from typing_extensions import NotRequired, TypedDict, Unpack, deprecated @@ -207,31 +207,77 @@ async def async_list( return Page[Model](**obj) - def get(self, key: str) -> Model: + @overload + def get(self, key: str) -> Model: ... + + @overload + def get(self, owner: str, name: str) -> Model: ... + + def get(self, *args, **kwargs) -> Model: + """ + Get a model by name. + """ + + url = _get_model_url(*args, **kwargs) + resp = self._client._request("GET", url) + + return _json_to_model(self._client, resp.json()) + + @overload + async def async_get(self, key: str) -> Model: ... + + @overload + async def async_get(self, owner: str, name: str) -> Model: ... + + async def async_get(self, *args, **kwargs) -> Model: """ Get a model by name. Args: - key: The qualified name of the model, in the format `owner/model-name`. + key: The qualified name of the model, in the format `owner/name`. Returns: The model. """ - resp = self._client._request("GET", f"/v1/models/{key}") + url = _get_model_url(*args, **kwargs) + resp = await self._client._async_request("GET", url) return _json_to_model(self._client, resp.json()) - async def async_get(self, key: str) -> Model: + @overload + def delete(self, key: str) -> Model: ... + + @overload + def delete(self, owner: str, name: str) -> Model: ... + + def delete(self, *args, **kwargs) -> Model: """ - Get a model by name. + Delete a model by name. + """ + + url = _delete_model_url(*args, **kwargs) + resp = self._client._request("DELETE", url) + + return _json_to_model(self._client, resp.json()) + + @overload + async def async_delete(self, key: str) -> Model: ... + + @overload + async def async_delete(self, owner: str, name: str) -> Model: ... + + async def async_delete(self, *args, **kwargs) -> Model: + """ + Delete a model by name. Args: - key: The qualified name of the model, in the format `owner/model-name`. + key: The qualified name of the model, in the format `owner/name`. Returns: The model. """ - resp = await self._client._async_request("GET", f"/v1/models/{key}") + url = _delete_model_url(*args, **kwargs) + resp = await self._client._async_request("DELETE", url) return _json_to_model(self._client, resp.json()) @@ -374,6 +420,41 @@ def _create_model_body( # pylint: disable=too-many-arguments return body +def _get_model_url(*args, **kwargs) -> str: + if len(args) > 0 and len(kwargs) > 0: + raise ValueError("Cannot mix positional and keyword arguments") + + owner = kwargs.get("owner", None) + name = kwargs.get("name", None) + key = kwargs.get("key", None) + + if key and (owner or name): + raise ValueError( + "Must specify exactly one of 'owner' and 'name' or single 'key' in the format 'owner/name'" + ) + + if args: + if len(args) == 1: + key = args[0] + elif len(args) == 2: + owner, name = args + else: + raise ValueError("Invalid number of arguments") + + if not key: + if not (owner and name): + raise ValueError( + "Both 'owner' and 'name' must be provided if 'key' is not specified." + ) + key = f"{owner}/{name}" + + return f"/v1/models/{key}" + + +def _delete_model_url(*args, **kwargs) -> str: + return _get_model_url(*args, **kwargs) + + def _json_to_model(client: "Client", json: Dict[str, Any]) -> Model: model = Model(**json) model._client = client