diff --git a/replicate/prediction.py b/replicate/prediction.py index 298a4655..f40a587a 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -1,4 +1,6 @@ +import re import time +from dataclasses import dataclass from typing import Any, Dict, Iterator, List, Optional from replicate.base_model import BaseModel @@ -56,6 +58,46 @@ class Prediction(BaseModel): - `cancel`: A URL to cancel the prediction. """ + @dataclass + class Progress: + percentage: float + """The percentage of the prediction that has completed.""" + + current: int + """The number of items that have been processed.""" + + total: int + """The total number of items to process.""" + + _pattern = re.compile( + r"^\s*(?P\d+)%\s*\|.+?\|\s*(?P\d+)\/(?P\d+)" + ) + + @classmethod + def parse(cls, logs: str) -> Optional["Prediction.Progress"]: + """Parse the progress from the logs of a prediction.""" + + lines = logs.split("\n") + for i in reversed(range(len(lines))): + line = lines[i].strip() + if cls._pattern.match(line): + matches = cls._pattern.findall(line) + if len(matches) == 1: + percentage, current, total = map(int, matches[0]) + return cls(percentage / 100.0, current, total) + + return None + + @property + def progress(self) -> Optional[Progress]: + """ + The progress of the prediction, if available. + """ + if self.logs is None or self.logs == "": + return None + + return Prediction.Progress.parse(self.logs) + def wait(self) -> None: """ Wait for prediction to finish. diff --git a/tests/test_prediction.py b/tests/test_prediction.py index ad6ccba6..4b330015 100644 --- a/tests/test_prediction.py +++ b/tests/test_prediction.py @@ -1,6 +1,8 @@ import responses from responses import matchers +from replicate.prediction import Prediction + from .factories import create_client, create_version @@ -214,3 +216,63 @@ def test_async_timings(): assert prediction.completed_at == "2022-04-26T20:02:27.648305Z" assert prediction.output == "hello world" assert prediction.metrics["predict_time"] == 1.2345 + + +def test_prediction_progress(): + client = create_client() + version = create_version(client) + prediction = Prediction( + id="ufawqhfynnddngldkgtslldrkq", version=version, status="starting" + ) + + lines = [ + "Using seed: 12345", + "0%| | 0/5 [00:00