Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,15 @@ replicate.predictions.list()
# [<Prediction: 8b0ba5ab4d85>, <Prediction: 494900564e8c>]
```

Lists of predictions are paginated. You can get the next page of predictions by passing the `next` property as an argument to the `list` method:

```python
page1 = replicate.predictions.list()

if page1.next:
page2 = replicate.predictions.list(page1.next)
```

## Load output files

Output files are returned as HTTPS URLs. You can load an output file as a buffer:
Expand Down
4 changes: 2 additions & 2 deletions replicate/hardware.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ class Hardwares(Namespace):

def list(self) -> List[Hardware]:
"""
List all public models.
List all hardware available for you to run models on Replicate.

Returns:
A list of models.
List[Hardware]: A list of hardware.
"""

resp = self._client._request("GET", "/v1/hardware")
Expand Down
20 changes: 13 additions & 7 deletions replicate/model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import Dict, List, Optional, Union
from typing import Dict, Optional, Union

from typing_extensions import deprecated

from replicate.exceptions import ReplicateException
from replicate.pagination import Page
from replicate.prediction import Prediction
from replicate.resource import Namespace, Resource
from replicate.version import Version, Versions
Expand Down Expand Up @@ -123,18 +124,23 @@ class Models(Namespace):

model = Model

def list(self) -> List[Model]:
def list(self, cursor: Union[str, "ellipsis"] = ...) -> Page[Model]: # noqa: F821
"""
List all public models.

Parameters:
cursor: The cursor to use for pagination. Use the value of `Page.next` or `Page.previous`.
Returns:
A list of models.
Page[Model]: A page of of models.
Raises:
ValueError: If `cursor` is `None`.
"""

resp = self._client._request("GET", "/v1/models")
# TODO: paginate
models = resp.json()["results"]
return [self._prepare_model(obj) for obj in models]
if cursor is None:
raise ValueError("cursor cannot be None")

resp = self._client._request("GET", "/v1/models" if cursor is ... else cursor)
return Page[Model](self._client, self, **resp.json())

def get(self, key: str) -> Model:
"""
Expand Down
66 changes: 66 additions & 0 deletions replicate/pagination.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from typing import (
TYPE_CHECKING,
Dict,
Generic,
List,
Optional,
TypeVar,
Union,
)

try:
from pydantic import v1 as pydantic # type: ignore
except ImportError:
import pydantic # type: ignore

from replicate.resource import Namespace, Resource

T = TypeVar("T", bound=Resource)

if TYPE_CHECKING:
from .client import Client


class Page(pydantic.BaseModel, Generic[T]):
"""
A page of results from the API.
"""

_client: "Client" = pydantic.PrivateAttr()
_namespace: Namespace = pydantic.PrivateAttr()

previous: Optional[str] = None
"""A pointer to the previous page of results"""

next: Optional[str] = None
"""A pointer to the next page of results"""

results: List[T]
"""The results on this page"""

def __init__(
self,
client: "Client",
namespace: Namespace[T],
*,
results: Optional[List[Union[T, Dict]]] = None,
**kwargs,
) -> None:
self._client = client
self._namespace = namespace

super().__init__(
results=[self._namespace._prepare_model(r) for r in results]
if results
else None,
**kwargs,
)

def __iter__(self): # noqa: ANN204
return iter(self.results)

def __getitem__(self, index: int) -> T:
return self.results[index]

def __len__(self) -> int:
return len(self.results)
23 changes: 14 additions & 9 deletions replicate/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from replicate.exceptions import ModelError
from replicate.files import upload_file
from replicate.json import encode_json
from replicate.pagination import Page
from replicate.resource import Namespace, Resource
from replicate.version import Version

Expand Down Expand Up @@ -157,21 +158,25 @@ class Predictions(Namespace):

model = Prediction

def list(self) -> List[Prediction]:
def list(self, cursor: Union[str, "ellipsis"] = ...) -> Page[Prediction]: # noqa: F821
"""
List your predictions.

Parameters:
cursor: The cursor to use for pagination. Use the value of `Page.next` or `Page.previous`.
Returns:
A list of prediction objects.
Page[Prediction]: A page of of predictions.
Raises:
ValueError: If `cursor` is `None`.
"""

resp = self._client._request("GET", "/v1/predictions")
# TODO: paginate
predictions = resp.json()["results"]
for prediction in predictions:
# HACK: resolve this? make it lazy somehow?
del prediction["version"]
return [self._prepare_model(obj) for obj in predictions]
if cursor is None:
raise ValueError("cursor cannot be None")

resp = self._client._request(
"GET", "/v1/predictions" if cursor is ... else cursor
)
return Page[Prediction](self._client, self, **resp.json())

def get(self, id: str) -> Prediction: # pylint: disable=invalid-name
"""
Expand Down
23 changes: 14 additions & 9 deletions replicate/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from replicate.exceptions import ReplicateException
from replicate.files import upload_file
from replicate.json import encode_json
from replicate.pagination import Page
from replicate.resource import Namespace, Resource
from replicate.version import Version

Expand Down Expand Up @@ -90,21 +91,25 @@ class CreateParams(TypedDict):
webhook_completed: NotRequired[str]
webhook_events_filter: NotRequired[List[str]]

def list(self) -> List[Training]:
def list(self, cursor: Union[str, "ellipsis"] = ...) -> Page[Training]: # noqa: F821
"""
List your trainings.

Parameters:
cursor: The cursor to use for pagination. Use the value of `Page.next` or `Page.previous`.
Returns:
List[Training]: A list of training objects.
Page[Training]: A page of trainings.
Raises:
ValueError: If `cursor` is `None`.
"""

resp = self._client._request("GET", "/v1/trainings")
# TODO: paginate
trainings = resp.json()["results"]
for training in trainings:
# HACK: resolve this? make it lazy somehow?
del training["version"]
return [self._prepare_model(obj) for obj in trainings]
if cursor is None:
raise ValueError("cursor cannot be None")

resp = self._client._request(
"GET", "/v1/trainings" if cursor is ... else cursor
)
return Page[Training](self._client, self, **resp.json())

def get(self, id: str) -> Training: # pylint: disable=invalid-name
"""
Expand Down
Loading