diff --git a/replicate/prediction.py b/replicate/prediction.py index dd7a593c..db197ae3 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -20,6 +20,7 @@ class Prediction(BaseModel): started_at: Optional[str] created_at: Optional[str] completed_at: Optional[str] + urls: Optional[Dict[str, str]] def wait(self) -> None: """Wait for prediction to finish.""" diff --git a/tests/test_prediction.py b/tests/test_prediction.py index a0d08ae9..3a336a86 100644 --- a/tests/test_prediction.py +++ b/tests/test_prediction.py @@ -156,6 +156,7 @@ def test_async_timings(): assert prediction.created_at == "2022-04-26T20:00:40.658234Z" assert prediction.completed_at is None assert prediction.output is None + assert prediction.urls["get"] == "https://api.replicate.com/v1/predictions/p1" prediction.wait() assert prediction.created_at == "2022-04-26T20:00:40.658234Z" assert prediction.completed_at == "2022-04-26T20:02:27.648305Z"