Skip to content

Commit b71341b

Browse files
committed
fix linter issues
1 parent e227b9a commit b71341b

File tree

5 files changed

+25
-14
lines changed

5 files changed

+25
-14
lines changed

src/replicate/_client.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
import httpx
1010

11+
from replicate.types.prediction_create_params import PredictionCreateParamsWithoutVersion
12+
1113
from . import _exceptions
1214
from ._qs import Querystring
1315
from .types import PredictionOutput, PredictionCreateParams
@@ -135,7 +137,7 @@ def run(
135137
ref: str,
136138
*,
137139
wait: Union[int, bool, NotGiven] = NOT_GIVEN,
138-
**params: Unpack[PredictionCreateParams],
140+
**params: Unpack[PredictionCreateParamsWithoutVersion],
139141
) -> PredictionOutput | FileOutput | Iterable[FileOutput] | Dict[str, FileOutput]:
140142
"""Run a model and wait for its output."""
141143
from .lib._predictions import run

src/replicate/lib/_predictions.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from typing import TYPE_CHECKING, Dict, Union, Iterable
44
from typing_extensions import Unpack
55

6+
from replicate.types.prediction_create_params import PredictionCreateParamsWithoutVersion
7+
68
from ..types import PredictionOutput, PredictionCreateParams
79
from .._types import NOT_GIVEN, NotGiven
810
from .._utils import is_given
@@ -21,7 +23,7 @@ def run(
2123
*,
2224
wait: Union[int, bool, NotGiven] = NOT_GIVEN,
2325
# use_file_output: Optional[bool] = True,
24-
**params: Unpack[PredictionCreateParams],
26+
**params: Unpack[PredictionCreateParamsWithoutVersion],
2527
) -> PredictionOutput | FileOutput | Iterable[FileOutput] | Dict[str, FileOutput]:
2628
from ._files import transform_output
2729

@@ -40,7 +42,8 @@ def run(
4042
params.setdefault("prefer", f"wait={wait}")
4143

4244
# TODO: support more ref types
43-
prediction = client.predictions.create(version=ref, **params)
45+
params_with_version: PredictionCreateParams = {**params, "version": ref}
46+
prediction = client.predictions.create(**params_with_version)
4447

4548
# Currently the "Prefer: wait" interface will return a prediction with a status
4649
# of "processing" rather than a terminal state because it returns before the
@@ -91,7 +94,8 @@ async def async_run(
9194
params.setdefault("prefer", f"wait={wait}")
9295

9396
# TODO: support more ref types
94-
prediction = await client.predictions.create(version=ref, **params)
97+
params_with_version: PredictionCreateParams = {**params, "version": ref}
98+
prediction = await client.predictions.create(**params_with_version)
9599

96100
# Currently the "Prefer: wait" interface will return a prediction with a status
97101
# of "processing" rather than a terminal state because it returns before the

src/replicate/resources/predictions.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,10 @@ def with_streaming_response(self) -> PredictionsResourceWithStreamingResponse:
5454

5555
def wait(self, prediction_id: str) -> Prediction:
5656
"""Wait for prediction to finish."""
57-
prediction = self.retrieve(prediction_id)
57+
prediction = self.get(prediction_id)
5858
while prediction.status not in PREDICTION_TERMINAL_STATES:
5959
self._sleep(self._client.poll_interval)
60-
prediction = self.retrieve(prediction.id)
60+
prediction = self.get(prediction.id)
6161
return prediction
6262

6363
def create(
@@ -473,10 +473,10 @@ def with_streaming_response(self) -> AsyncPredictionsResourceWithStreamingRespon
473473

474474
async def wait(self, prediction_id: str) -> Prediction:
475475
"""Wait for prediction to finish."""
476-
prediction = await self.retrieve(prediction_id)
476+
prediction = await self.get(prediction_id)
477477
while prediction.status not in PREDICTION_TERMINAL_STATES:
478478
await self._sleep(self._client.poll_interval)
479-
prediction = await self.retrieve(prediction.id)
479+
prediction = await self.get(prediction.id)
480480
return prediction
481481

482482
async def create(

src/replicate/types/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@
1515
from .deployment_create_params import DeploymentCreateParams as DeploymentCreateParams
1616
from .deployment_list_response import DeploymentListResponse as DeploymentListResponse
1717
from .deployment_update_params import DeploymentUpdateParams as DeploymentUpdateParams
18-
from .prediction_create_params import PredictionCreateParams as PredictionCreateParams
18+
from .prediction_create_params import (
19+
PredictionCreateParams as PredictionCreateParams,
20+
PredictionCreateParamsWithoutVersion as PredictionCreateParamsWithoutVersion,
21+
)
1922
from .training_cancel_response import TrainingCancelResponse as TrainingCancelResponse
2023
from .deployment_create_response import DeploymentCreateResponse as DeploymentCreateResponse
2124
from .deployment_update_response import DeploymentUpdateResponse as DeploymentUpdateResponse

src/replicate/types/prediction_create_params.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77

88
from .._utils import PropertyInfo
99

10-
__all__ = ["PredictionCreateParams"]
10+
__all__ = ["PredictionCreateParams", "PredictionCreateParamsWithoutVersion"]
1111

1212

13-
class PredictionCreateParams(TypedDict, total=False):
13+
class PredictionCreateParamsWithoutVersion(TypedDict, total=False):
1414
input: Required[object]
1515
"""The model's input as a JSON object.
1616
@@ -36,9 +36,6 @@ class PredictionCreateParams(TypedDict, total=False):
3636
- you don't need to use the file again (Replicate will not store it)
3737
"""
3838

39-
version: Required[str]
40-
"""The ID of the model version that you want to run."""
41-
4239
stream: bool
4340
"""**This field is deprecated.**
4441
@@ -94,3 +91,8 @@ class PredictionCreateParams(TypedDict, total=False):
9491
"""
9592

9693
prefer: Annotated[str, PropertyInfo(alias="Prefer")]
94+
95+
96+
class PredictionCreateParams(PredictionCreateParamsWithoutVersion):
97+
version: Required[str]
98+
"""The ID of the model version that you want to run."""

0 commit comments

Comments
 (0)