Skip to content

Commit a0a06fe

Browse files
committed
Switch to using file outputs and blocking api by default
1 parent 08ee31a commit a0a06fe

File tree

3 files changed

+16
-11
lines changed

3 files changed

+16
-11
lines changed

replicate/prediction.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -395,11 +395,13 @@ class CreatePredictionParams(TypedDict):
395395

396396
wait: NotRequired[Union[int, bool]]
397397
"""
398-
Wait until the prediction is completed before returning.
398+
Block until the prediction is completed before returning.
399399
400-
If `True`, wait a predetermined number of seconds until the prediction
401-
is completed before returning.
402-
If an `int`, wait for the specified number of seconds.
400+
If `True`, keep the request open for up to 60 seconds, falling back to
401+
polling until the prediction is completed.
402+
If an `int`, same as True but hold the request for a specified number of
403+
seconds (between 1 and 60).
404+
If `False`, poll for the prediction status until completed.
403405
"""
404406

405407
file_encoding_strategy: NotRequired[FileEncodingStrategy]

replicate/run.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,17 @@ def run(
2929
client: "Client",
3030
ref: Union["Model", "Version", "ModelVersionIdentifier", str],
3131
input: Optional[Dict[str, Any]] = None,
32-
use_file_output: Optional[bool] = None,
32+
use_file_output: Optional[bool] = True,
3333
**params: Unpack["Predictions.CreatePredictionParams"],
3434
) -> Union[Any, Iterator[Any]]: # noqa: ANN401
3535
"""
3636
Run a model and wait for its output.
3737
"""
3838

39-
is_blocking = "wait" in params
39+
if "wait" not in params:
40+
params["wait"] = True
41+
is_blocking = params["wait"] != False
42+
4043
version, owner, name, version_id = identifier._resolve(ref)
4144

4245
if version_id is not None:
@@ -74,7 +77,7 @@ async def async_run(
7477
client: "Client",
7578
ref: Union["Model", "Version", "ModelVersionIdentifier", str],
7679
input: Optional[Dict[str, Any]] = None,
77-
use_file_output: Optional[bool] = None,
80+
use_file_output: Optional[bool] = True,
7881
**params: Unpack["Predictions.CreatePredictionParams"],
7982
) -> Union[Any, AsyncIterator[Any]]: # noqa: ANN401
8083
"""

tests/test_run.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def prediction_with_status(status: str) -> dict:
123123
router.route(method="POST", path="/predictions").mock(
124124
return_value=httpx.Response(
125125
201,
126-
json=prediction_with_status("processing"),
126+
json=prediction_with_status("starting"),
127127
)
128128
)
129129
router.route(method="GET", path="/predictions/p1").mock(
@@ -212,7 +212,7 @@ def prediction_with_status(status: str) -> dict:
212212
router.route(method="POST", path="/predictions").mock(
213213
return_value=httpx.Response(
214214
201,
215-
json=prediction_with_status("processing"),
215+
json=prediction_with_status("starting"),
216216
)
217217
)
218218
router.route(method="GET", path="/predictions/p1").mock(
@@ -454,7 +454,7 @@ def prediction_with_status(
454454
router.route(method="POST", path="/predictions").mock(
455455
return_value=httpx.Response(
456456
201,
457-
json=prediction_with_status("processing"),
457+
json=prediction_with_status("starting"),
458458
)
459459
)
460460
router.route(method="GET", path="/predictions/p1").mock(
@@ -541,7 +541,7 @@ def prediction_with_status(
541541
router.route(method="POST", path="/predictions").mock(
542542
return_value=httpx.Response(
543543
201,
544-
json=prediction_with_status("processing"),
544+
json=prediction_with_status("starting"),
545545
)
546546
)
547547
router.route(method="GET", path="/predictions/p1").mock(

0 commit comments

Comments (0)