diff --git a/pyproject.toml b/pyproject.toml
index 11c97bf..fde2865 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "replicate"
-version = "0.34.2"
+version = "1.0.0b3"
 description = "Python client for Replicate"
 readme = "README.md"
 license = { file = "LICENSE" }
diff --git a/replicate/client.py b/replicate/client.py
index 52d07f7..3e767d6 100644
--- a/replicate/client.py
+++ b/replicate/client.py
@@ -164,53 +164,59 @@ def run(
         self,
         ref: str,
         input: Optional[Dict[str, Any]] = None,
-        use_file_output: Optional[bool] = None,
+        *,
+        use_file_output: Optional[bool] = True,
         **params: Unpack["Predictions.CreatePredictionParams"],
     ) -> Union[Any, Iterator[Any]]:  # noqa: ANN401
         """
         Run a model and wait for its output.
         """
 
-        return run(self, ref, input, use_file_output, **params)
+        return run(self, ref, input, use_file_output=use_file_output, **params)
 
     async def async_run(
         self,
         ref: str,
         input: Optional[Dict[str, Any]] = None,
-        use_file_output: Optional[bool] = None,
+        *,
+        use_file_output: Optional[bool] = True,
         **params: Unpack["Predictions.CreatePredictionParams"],
     ) -> Union[Any, AsyncIterator[Any]]:  # noqa: ANN401
         """
         Run a model and wait for its output asynchronously.
         """
 
-        return await async_run(self, ref, input, use_file_output, **params)
+        return await async_run(
+            self, ref, input, use_file_output=use_file_output, **params
+        )
 
     def stream(
         self,
         ref: str,
+        *,
         input: Optional[Dict[str, Any]] = None,
-        use_file_output: Optional[bool] = None,
+        use_file_output: Optional[bool] = True,
         **params: Unpack["Predictions.CreatePredictionParams"],
     ) -> Iterator["ServerSentEvent"]:
         """
         Stream a model's output.
         """
 
-        return stream(self, ref, input, use_file_output, **params)
+        return stream(self, ref, input, use_file_output=use_file_output, **params)
 
     async def async_stream(
         self,
         ref: str,
         input: Optional[Dict[str, Any]] = None,
-        use_file_output: Optional[bool] = None,
+        *,
+        use_file_output: Optional[bool] = True,
         **params: Unpack["Predictions.CreatePredictionParams"],
     ) -> AsyncIterator["ServerSentEvent"]:
         """
         Stream a model's output asynchronously.
         """
 
-        return async_stream(self, ref, input, use_file_output, **params)
+        return async_stream(self, ref, input, use_file_output=use_file_output, **params)
 
 
 # Adapted from https://github.com/encode/httpx/issues/108#issuecomment-1132753155
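As a rough usage sketch of the client.py signature changes above (the model ref and prompt are placeholders, not part of this diff, and the module-level replicate.run helper is assumed): use_file_output is now keyword-only and defaults to True, so file results come back wrapped (FileOutput objects in this release) unless you opt out.

import replicate

# use_file_output now defaults to True and must be passed by keyword.
# Passing False (illustrative opt-out) keeps plain URL strings instead of
# wrapped file objects.
output = replicate.run(
    "black-forest-labs/flux-schnell",  # illustrative model ref
    input={"prompt": "an astronaut riding a rainbow unicorn"},
    use_file_output=False,
)
print(output)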
""" file_encoding_strategy: NotRequired[FileEncodingStrategy] diff --git a/replicate/run.py b/replicate/run.py index d159f11..3b6bddb 100644 --- a/replicate/run.py +++ b/replicate/run.py @@ -29,14 +29,18 @@ def run( client: "Client", ref: Union["Model", "Version", "ModelVersionIdentifier", str], input: Optional[Dict[str, Any]] = None, - use_file_output: Optional[bool] = None, + *, + use_file_output: Optional[bool] = True, **params: Unpack["Predictions.CreatePredictionParams"], ) -> Union[Any, Iterator[Any]]: # noqa: ANN401 """ Run a model and wait for its output. """ - is_blocking = "wait" in params + if "wait" not in params: + params["wait"] = True + is_blocking = params["wait"] != False # noqa: E712 + version, owner, name, version_id = identifier._resolve(ref) if version_id is not None: @@ -74,13 +78,18 @@ async def async_run( client: "Client", ref: Union["Model", "Version", "ModelVersionIdentifier", str], input: Optional[Dict[str, Any]] = None, - use_file_output: Optional[bool] = None, + *, + use_file_output: Optional[bool] = True, **params: Unpack["Predictions.CreatePredictionParams"], ) -> Union[Any, AsyncIterator[Any]]: # noqa: ANN401 """ Run a model and wait for its output asynchronously. """ + if "wait" not in params: + params["wait"] = True + is_blocking = params["wait"] != False # noqa: E712 + version, owner, name, version_id = identifier._resolve(ref) if version or version_id: @@ -102,7 +111,8 @@ async def async_run( if version and (iterator := _make_async_output_iterator(version, prediction)): return iterator - await prediction.async_wait() + if not (is_blocking and prediction.status != "starting"): + await prediction.async_wait() if prediction.status == "failed": raise ModelError(prediction) diff --git a/replicate/stream.py b/replicate/stream.py index 4cf0d15..e837abd 100644 --- a/replicate/stream.py +++ b/replicate/stream.py @@ -71,11 +71,12 @@ def __init__( self, client: "Client", response: "httpx.Response", - use_file_output: Optional[bool] = None, + *, + use_file_output: Optional[bool] = True, ) -> None: self.client = client self.response = response - self.use_file_output = use_file_output or False + self.use_file_output = use_file_output or True content_type, _, _ = response.headers["content-type"].partition(";") if content_type != "text/event-stream": raise ValueError( @@ -193,7 +194,8 @@ def stream( client: "Client", ref: Union["Model", "Version", "ModelVersionIdentifier", str], input: Optional[Dict[str, Any]] = None, - use_file_output: Optional[bool] = None, + *, + use_file_output: Optional[bool] = True, **params: Unpack["Predictions.CreatePredictionParams"], ) -> Iterator[ServerSentEvent]: """ @@ -234,7 +236,8 @@ async def async_stream( client: "Client", ref: Union["Model", "Version", "ModelVersionIdentifier", str], input: Optional[Dict[str, Any]] = None, - use_file_output: Optional[bool] = None, + *, + use_file_output: Optional[bool] = True, **params: Unpack["Predictions.CreatePredictionParams"], ) -> AsyncIterator[ServerSentEvent]: """ diff --git a/tests/test_run.py b/tests/test_run.py index 7d963a4..0f9aed2 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -123,7 +123,7 @@ def prediction_with_status(status: str) -> dict: router.route(method="POST", path="/predictions").mock( return_value=httpx.Response( 201, - json=prediction_with_status("processing"), + json=prediction_with_status("starting"), ) ) router.route(method="GET", path="/predictions/p1").mock( @@ -212,7 +212,7 @@ def prediction_with_status(status: str) -> dict: router.route(method="POST", 
path="/predictions").mock( return_value=httpx.Response( 201, - json=prediction_with_status("processing"), + json=prediction_with_status("starting"), ) ) router.route(method="GET", path="/predictions/p1").mock( @@ -454,7 +454,7 @@ def prediction_with_status( router.route(method="POST", path="/predictions").mock( return_value=httpx.Response( 201, - json=prediction_with_status("processing"), + json=prediction_with_status("starting"), ) ) router.route(method="GET", path="/predictions/p1").mock( @@ -541,7 +541,7 @@ def prediction_with_status( router.route(method="POST", path="/predictions").mock( return_value=httpx.Response( 201, - json=prediction_with_status("processing"), + json=prediction_with_status("starting"), ) ) router.route(method="GET", path="/predictions/p1").mock(