diff --git a/replicate/collection.py b/replicate/collection.py index 799b7b63..19e8e6c1 100644 --- a/replicate/collection.py +++ b/replicate/collection.py @@ -32,7 +32,9 @@ def get(self, key: str) -> Model: # pylint: disable=missing-function-docstring pass @abc.abstractmethod - def create(self, **kwargs) -> Model: # pylint: disable=missing-function-docstring + def create( # pylint: disable=missing-function-docstring + self, *args, **kwargs + ) -> Model: pass def prepare_model(self, attrs: Union[Model, Dict]) -> Model: diff --git a/replicate/deployment.py b/replicate/deployment.py index 1a0766c7..cfdd0da1 100644 --- a/replicate/deployment.py +++ b/replicate/deployment.py @@ -1,10 +1,12 @@ -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, overload + +from typing_extensions import Unpack from replicate.base_model import BaseModel from replicate.collection import Collection from replicate.files import upload_file from replicate.json import encode_json -from replicate.prediction import Prediction +from replicate.prediction import Prediction, PredictionCollection if TYPE_CHECKING: from replicate.client import Client @@ -65,7 +67,11 @@ def get(self, name: str) -> Deployment: username, name = name.split("/") return self.prepare_model({"username": username, "name": name}) - def create(self, **kwargs) -> Deployment: + def create( + self, + *args, + **kwargs, + ) -> Deployment: """ Create a deployment. @@ -114,15 +120,34 @@ def get(self, id: str) -> Prediction: del obj["version"] return self.prepare_model(obj) - def create( # type: ignore + @overload + def create( # pylint: disable=arguments-differ disable=too-many-arguments self, input: Dict[str, Any], + *, webhook: Optional[str] = None, webhook_completed: Optional[str] = None, webhook_events_filter: Optional[List[str]] = None, + stream: Optional[bool] = None, + ) -> Prediction: + ... + + @overload + def create( # pylint: disable=arguments-differ disable=too-many-arguments + self, *, + input: Dict[str, Any], + webhook: Optional[str] = None, + webhook_completed: Optional[str] = None, + webhook_events_filter: Optional[List[str]] = None, stream: Optional[bool] = None, - **kwargs, + ) -> Prediction: + ... + + def create( + self, + *args, + **kwargs: Unpack[PredictionCollection.CreateParams], # type: ignore[misc] ) -> Prediction: """ Create a new prediction with the deployment. @@ -138,18 +163,20 @@ def create( # type: ignore Prediction: The created prediction object. """ - input = encode_json(input, upload_file=upload_file) - body: Dict[str, Any] = { - "input": input, + input = args[0] if len(args) > 0 else kwargs.get("input") + if input is None: + raise ValueError( + "An input must be provided as a positional or keyword argument." + ) + + body = { + "input": encode_json(input, upload_file=upload_file), } - if webhook is not None: - body["webhook"] = webhook - if webhook_completed is not None: - body["webhook_completed"] = webhook_completed - if webhook_events_filter is not None: - body["webhook_events_filter"] = webhook_events_filter - if stream is True: - body["stream"] = True + + for key in ["webhook", "webhook_completed", "webhook_events_filter", "stream"]: + value = kwargs.get(key) + if value is not None: + body[key] = value resp = self._client._request( "POST", diff --git a/replicate/model.py b/replicate/model.py index 3dcc2427..887e2bfc 100644 --- a/replicate/model.py +++ b/replicate/model.py @@ -139,7 +139,11 @@ def get(self, key: str) -> Model: resp = self._client._request("GET", f"/v1/models/{key}") return self.prepare_model(resp.json()) - def create(self, **kwargs) -> Model: + def create( + self, + *args, + **kwargs, + ) -> Model: """ Create a model. diff --git a/replicate/prediction.py b/replicate/prediction.py index f8afa5e8..05cdec2b 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -1,7 +1,9 @@ import re import time from dataclasses import dataclass -from typing import Any, Dict, Iterator, List, Optional, Union +from typing import Any, Dict, Iterator, List, Optional, TypedDict, Union, overload + +from typing_extensions import Unpack from replicate.base_model import BaseModel from replicate.collection import Collection @@ -137,6 +139,16 @@ class PredictionCollection(Collection): Namespace for operations related to predictions. """ + class CreateParams(TypedDict): + """Parameters for creating a prediction.""" + + version: Union[Version, str] + input: Dict[str, Any] + webhook: Optional[str] + webhook_completed: Optional[str] + webhook_events_filter: Optional[List[str]] + stream: Optional[bool] + model = Prediction def list(self) -> List[Prediction]: @@ -171,16 +183,36 @@ def get(self, id: str) -> Prediction: del obj["version"] return self.prepare_model(obj) - def create( # type: ignore + @overload + def create( # pylint: disable=arguments-differ disable=too-many-arguments self, version: Union[Version, str], input: Dict[str, Any], + *, webhook: Optional[str] = None, webhook_completed: Optional[str] = None, webhook_events_filter: Optional[List[str]] = None, + stream: Optional[bool] = None, + ) -> Prediction: + ... + + @overload + def create( # pylint: disable=arguments-differ disable=too-many-arguments + self, *, + version: Union[Version, str], + input: Dict[str, Any], + webhook: Optional[str] = None, + webhook_completed: Optional[str] = None, + webhook_events_filter: Optional[List[str]] = None, stream: Optional[bool] = None, - **kwargs, + ) -> Prediction: + ... + + def create( + self, + *args, + **kwargs: Unpack[CreateParams], # type: ignore[misc] ) -> Prediction: """ Create a new prediction for the specified model version. @@ -197,19 +229,28 @@ def create( # type: ignore Prediction: The created prediction object. """ - input = encode_json(input, upload_file=upload_file) - body: Dict[str, Any] = { + # Support positional arguments for backwards compatibility + version = args[0] if args else kwargs.get("version") + if version is None: + raise ValueError( + "A version identifier must be provided as a positional or keyword argument." + ) + + input = args[1] if len(args) > 1 else kwargs.get("input") + if input is None: + raise ValueError( + "An input must be provided as a positional or keyword argument." + ) + + body = { "version": version if isinstance(version, str) else version.id, - "input": input, + "input": encode_json(input, upload_file=upload_file), } - if webhook is not None: - body["webhook"] = webhook - if webhook_completed is not None: - body["webhook_completed"] = webhook_completed - if webhook_events_filter is not None: - body["webhook_events_filter"] = webhook_events_filter - if stream is True: - body["stream"] = True + + for key in ["webhook", "webhook_completed", "webhook_events_filter", "stream"]: + value = kwargs.get(key) + if value is not None: + body[key] = value resp = self._client._request( "POST", diff --git a/replicate/training.py b/replicate/training.py index 4499a79e..ee3fe7e4 100644 --- a/replicate/training.py +++ b/replicate/training.py @@ -1,5 +1,7 @@ import re -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, TypedDict, Union + +from typing_extensions import NotRequired, Unpack, overload from replicate.base_model import BaseModel from replicate.collection import Collection @@ -68,6 +70,16 @@ class TrainingCollection(Collection): model = Training + class CreateParams(TypedDict): + """Parameters for creating a prediction.""" + + version: Union[Version, str] + destination: str + input: Dict[str, Any] + webhook: NotRequired[str] + webhook_completed: NotRequired[str] + webhook_events_filter: NotRequired[List[str]] + def list(self) -> List[Training]: """ List your trainings. @@ -103,14 +115,36 @@ def get(self, id: str) -> Training: del obj["version"] return self.prepare_model(obj) - def create( # type: ignore + @overload + def create( # pylint: disable=arguments-differ disable=too-many-arguments + self, + version: Union[Version, str], + input: Dict[str, Any], + destination: str, + *, + webhook: Optional[str] = None, + webhook_completed: Optional[str] = None, + webhook_events_filter: Optional[List[str]] = None, + ) -> Training: + ... + + @overload + def create( # pylint: disable=arguments-differ disable=too-many-arguments self, - version: str, + *, + version: Union[Version, str], input: Dict[str, Any], destination: str, webhook: Optional[str] = None, + webhook_completed: Optional[str] = None, webhook_events_filter: Optional[List[str]] = None, - **kwargs, + ) -> Training: + ... + + def create( + self, + *args, + **kwargs: Unpack[CreateParams], # type: ignore[misc] ) -> Training: """ Create a new training using the specified model version as a base. @@ -120,24 +154,45 @@ def create( # type: ignore input: The input to the training. destination: The desired model to push to in the format `{owner}/{model_name}`. This should be an existing model owned by the user or organization making the API request. webhook: The URL to send a POST request to when the training is completed. Defaults to None. + webhook_completed: The URL to receive a POST request when the prediction is completed. webhook_events_filter: The events to send to the webhook. Defaults to None. Returns: The training object. """ - input = encode_json(input, upload_file=upload_file) + # Support positional arguments for backwards compatibility + version = args[0] if args else kwargs.get("version") + if version is None: + raise ValueError( + "A version identifier must be provided as a positional or keyword argument." + ) + + destination = args[1] if len(args) > 1 else kwargs.get("destination") + if destination is None: + raise ValueError( + "A destination must be provided as a positional or keyword argument." + ) + + input = args[2] if len(args) > 2 else kwargs.get("input") + if input is None: + raise ValueError( + "An input must be provided as a positional or keyword argument." + ) + body = { - "input": input, + "input": encode_json(input, upload_file=upload_file), "destination": destination, } - if webhook is not None: - body["webhook"] = webhook - if webhook_events_filter is not None: - body["webhook_events_filter"] = webhook_events_filter + + for key in ["webhook", "webhook_completed", "webhook_events_filter"]: + value = kwargs.get(key) + if value is not None: + body[key] = value # Split version in format "username/model_name:version_id" match = re.match( - r"^(?P[^/]+)/(?P[^:]+):(?P.+)$", version + r"^(?P[^/]+)/(?P[^:]+):(?P.+)$", + version.id if isinstance(version, Version) else version, ) if not match: raise ReplicateException( diff --git a/replicate/version.py b/replicate/version.py index c3be8b2e..2d7cd3ae 100644 --- a/replicate/version.py +++ b/replicate/version.py @@ -94,7 +94,11 @@ def get(self, id: str) -> Version: ) return self.prepare_model(resp.json()) - def create(self, **kwargs) -> Version: + def create( + self, + *args, + **kwargs, + ) -> Version: """ Create a model version. diff --git a/tests/test_training.py b/tests/test_training.py index da215749..0c4a4782 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -10,7 +10,7 @@ @pytest.mark.asyncio async def test_trainings_create(mock_replicate_api_token): training = replicate.trainings.create( - "stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + version="stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", input={ "input_images": input_images_url, "use_face_detection_instead": True, @@ -22,6 +22,22 @@ async def test_trainings_create(mock_replicate_api_token): assert training.status == "starting" +@pytest.mark.vcr("trainings-create.yaml") +@pytest.mark.asyncio +async def test_trainings_create_with_positional_argument(mock_replicate_api_token): + training = replicate.trainings.create( + "stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + { + "input_images": input_images_url, + "use_face_detection_instead": True, + }, + "replicate/dreambooth-sdxl", + ) + + assert training.id is not None + assert training.status == "starting" + + @pytest.mark.vcr("trainings-create__invalid-destination.yaml") @pytest.mark.asyncio async def test_trainings_create_with_invalid_destination(mock_replicate_api_token): @@ -57,7 +73,7 @@ async def test_trainings_cancel(mock_replicate_api_token): destination = "replicate/dreambooth-sdxl" training = replicate.trainings.create( - "stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + version="stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", destination=destination, input=input, )