diff --git a/replicate/exceptions.py b/replicate/exceptions.py index 6302d10f..497f522d 100644 --- a/replicate/exceptions.py +++ b/replicate/exceptions.py @@ -10,6 +10,12 @@ class ReplicateException(Exception): class ModelError(ReplicateException): """An error from user's code in a model.""" + prediction_id: str + + def __init__(self, error: Optional[str], prediction_id: str) -> None: + self.prediction_id = prediction_id + super().__init__(error) + class ReplicateError(ReplicateException): """ diff --git a/replicate/prediction.py b/replicate/prediction.py index 871566d7..b5590682 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -249,7 +249,7 @@ def output_iterator(self) -> Iterator[Any]: self.reload() if self.status == "failed": - raise ModelError(self.error) + raise ModelError(self.error, self.id) output = self.output or [] new_output = output[len(previous_output) :] @@ -272,7 +272,7 @@ async def async_output_iterator(self) -> AsyncIterator[Any]: await self.async_reload() if self.status == "failed": - raise ModelError(self.error) + raise ModelError(self.error, self.id) output = self.output or [] new_output = output[len(previous_output) :] diff --git a/replicate/run.py b/replicate/run.py index 975cc4dc..7fcbbf06 100644 --- a/replicate/run.py +++ b/replicate/run.py @@ -58,7 +58,7 @@ def run( prediction.wait() if prediction.status == "failed": - raise ModelError(prediction.error) + raise ModelError(prediction.error, prediction.id) return prediction.output @@ -97,7 +97,7 @@ async def async_run( await prediction.async_wait() if prediction.status == "failed": - raise ModelError(prediction.error) + raise ModelError(prediction.error, prediction.id) return prediction.output