Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -223,21 +223,29 @@ You can the models you've created:
replicate.models.list()
```

Lists of models are paginated. You can get the next page of models by passing the `next` property as an argument to the `list` method. Here's how you can get all the models you've created:
Lists of models are paginated. You can get the next page of models by passing the `next` property as an argument to the `list` method, or you can use the `paginate` method to fetch pages automatically.

```python
# Automatic pagination using `replicate.paginate` (recommended)
models = []
page = replicate.models.list()
for page in replicate.paginate(replicate.models.list):
models.extend(page.results)
if len(models) > 100:
break

# Manual pagination using `next` cursors
page = replicate.models.list()
while page:
models.extend(page.results)
if len(models) > 100:
break
page = replicate.models.list(page.next) if page.next else None
```

You can also find collections of featured models on Replicate:

```python
>>> collections = replicate.collections.list()
>>> collections = [collection for page in replicate.paginate(replicate.collections.list) for collection in page]
>>> collections[0].slug
"vision-models"
>>> collections[0].description
Expand Down
5 changes: 5 additions & 0 deletions replicate/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from replicate.client import Client
from replicate.pagination import async_paginate as _async_paginate
from replicate.pagination import paginate as _paginate

default_client = Client()

run = default_client.run
async_run = default_client.async_run

paginate = _paginate
async_paginate = _async_paginate

collections = default_client.collections
hardware = default_client.hardware
deployments = default_client.deployments
Expand Down
4 changes: 2 additions & 2 deletions replicate/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class Collections(Namespace):

def list(
self,
cursor: Union[str, "ellipsis"] = ..., # noqa: F821
cursor: Union[str, "ellipsis", None] = ..., # noqa: F821
) -> Page[Collection]:
"""
List collections of models.
Expand All @@ -82,7 +82,7 @@ def list(

async def async_list(
self,
cursor: Union[str, "ellipsis"] = ..., # noqa: F821
cursor: Union[str, "ellipsis", None] = ..., # noqa: F821
) -> Page[Collection]:
"""
List collections of models.
Expand Down
7 changes: 5 additions & 2 deletions replicate/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ class Models(Namespace):

model = Model

def list(self, cursor: Union[str, "ellipsis"] = ...) -> Page[Model]: # noqa: F821
def list(self, cursor: Union[str, "ellipsis", None] = ...) -> Page[Model]: # noqa: F821
"""
List all public models.

Expand All @@ -164,7 +164,10 @@ def list(self, cursor: Union[str, "ellipsis"] = ...) -> Page[Model]: # noqa: F8

return Page[Model](**obj)

async def async_list(self, cursor: Union[str, "ellipsis"] = ...) -> Page[Model]: # noqa: F821
async def async_list(
self,
cursor: Union[str, "ellipsis", None] = ..., # noqa: F821
) -> Page[Model]:
"""
List all public models.

Expand Down
37 changes: 37 additions & 0 deletions replicate/pagination.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from typing import (
TYPE_CHECKING,
AsyncGenerator,
Awaitable,
Callable,
Generator,
Generic,
List,
Optional,
TypeVar,
Union,
)

try:
Expand Down Expand Up @@ -41,3 +46,35 @@ def __getitem__(self, index: int) -> T:

def __len__(self) -> int:
return len(self.results)


def paginate(
list_method: Callable[[Union[str, "ellipsis", None]], Page[T]], # noqa: F821
) -> Generator[Page[T], None, None]:
"""
Iterate over all items using the provided list method.

Args:
list_method: A method that takes a cursor argument and returns a Page of items.
"""
cursor: Union[str, "ellipsis", None] = ... # noqa: F821
while cursor is not None:
page = list_method(cursor)
yield page
cursor = page.next


async def async_paginate(
list_method: Callable[[Union[str, "ellipsis", None]], Awaitable[Page[T]]], # noqa: F821
) -> AsyncGenerator[Page[T], None]:
"""
Asynchronously iterate over all items using the provided list method.

Args:
list_method: An async method that takes a cursor argument and returns a Page of items.
"""
cursor: Union[str, "ellipsis", None] = ... # noqa: F821
while cursor is not None:
page = await list_method(cursor)
yield page
cursor = page.next
4 changes: 2 additions & 2 deletions replicate/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ class Predictions(Namespace):
Namespace for operations related to predictions.
"""

def list(self, cursor: Union[str, "ellipsis"] = ...) -> Page[Prediction]: # noqa: F821
def list(self, cursor: Union[str, "ellipsis", None] = ...) -> Page[Prediction]: # noqa: F821
"""
List your predictions.

Expand Down Expand Up @@ -197,7 +197,7 @@ def list(self, cursor: Union[str, "ellipsis"] = ...) -> Page[Prediction]: # noq

async def async_list(
self,
cursor: Union[str, "ellipsis"] = ..., # noqa: F821
cursor: Union[str, "ellipsis", None] = ..., # noqa: F821
) -> Page[Prediction]:
"""
List your predictions.
Expand Down
7 changes: 5 additions & 2 deletions replicate/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class Trainings(Namespace):
Namespace for operations related to trainings.
"""

def list(self, cursor: Union[str, "ellipsis"] = ...) -> Page[Training]: # noqa: F821
def list(self, cursor: Union[str, "ellipsis", None] = ...) -> Page[Training]: # noqa: F821
"""
List your trainings.

Expand All @@ -124,7 +124,10 @@ def list(self, cursor: Union[str, "ellipsis"] = ...) -> Page[Training]: # noqa:

return Page[Training](**obj)

async def async_list(self, cursor: Union[str, "ellipsis"] = ...) -> Page[Training]: # noqa: F821
async def async_list(
self,
cursor: Union[str, "ellipsis", None] = ..., # noqa: F821
) -> Page[Training]:
"""
List your trainings.

Expand Down
29 changes: 29 additions & 0 deletions tests/test_pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,32 @@
async def test_paginate_with_none_cursor(mock_replicate_api_token):
with pytest.raises(ValueError):
replicate.models.list(None)


@pytest.mark.vcr("collections-list.yaml")
@pytest.mark.asyncio
@pytest.mark.parametrize("async_flag", [True, False])
async def test_paginate(async_flag):
found = False

if async_flag:
async for page in replicate.async_paginate(replicate.collections.async_list):
assert page.next is None
assert page.previous is None

for collection in page:
if collection.slug == "text-to-image":
found = True
break

else:
for page in replicate.paginate(replicate.collections.list):
assert page.next is None
assert page.previous is None

for collection in page:
if collection.slug == "text-to-image":
found = True
break

assert found