Skip to content

Commit

Permalink
Add paginate and async_paginate method (#197)
Browse files Browse the repository at this point in the history
This PR adds `paginate` and `async_paginate` methods that let you
iterate over a paginated list of resources.

**sync**

```python
import replicate

for page in replicate.paginate(replicate.collections.list):
  for collection in page:
      print(collection.name)
```

**async**

```python
import replicate

async for page in replicate.async_paginate(replicate.collections.async_list):
  for collection in page:
      print(collection.name)
```

---------

Signed-off-by: Mattt Zmuda <[email protected]>
  • Loading branch information
mattt authored Nov 17, 2023
1 parent 45e3020 commit 15f9c59
Show file tree
Hide file tree
Showing 8 changed files with 96 additions and 11 deletions.
14 changes: 11 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -223,21 +223,29 @@ You can the models you've created:
replicate.models.list()
```

Lists of models are paginated. You can get the next page of models by passing the `next` property as an argument to the `list` method. Here's how you can get all the models you've created:
Lists of models are paginated. You can get the next page of models by passing the `next` property as an argument to the `list` method, or you can use the `paginate` method to fetch pages automatically.

```python
# Automatic pagination using `replicate.paginate` (recommended)
models = []
page = replicate.models.list()
for page in replicate.paginate(replicate.models.list):
models.extend(page.results)
if len(models) > 100:
break

# Manual pagination using `next` cursors
page = replicate.models.list()
while page:
models.extend(page.results)
if len(models) > 100:
break
page = replicate.models.list(page.next) if page.next else None
```

You can also find collections of featured models on Replicate:

```python
>>> collections = replicate.collections.list()
>>> collections = [collection for page in replicate.paginate(replicate.collections.list) for collection in page]
>>> collections[0].slug
"vision-models"
>>> collections[0].description
Expand Down
5 changes: 5 additions & 0 deletions replicate/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from replicate.client import Client
from replicate.pagination import async_paginate as _async_paginate
from replicate.pagination import paginate as _paginate

default_client = Client()

run = default_client.run
async_run = default_client.async_run

paginate = _paginate
async_paginate = _async_paginate

collections = default_client.collections
hardware = default_client.hardware
deployments = default_client.deployments
Expand Down
4 changes: 2 additions & 2 deletions replicate/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class Collections(Namespace):

def list(
self,
cursor: Union[str, "ellipsis"] = ..., # noqa: F821
cursor: Union[str, "ellipsis", None] = ..., # noqa: F821
) -> Page[Collection]:
"""
List collections of models.
Expand All @@ -82,7 +82,7 @@ def list(

async def async_list(
self,
cursor: Union[str, "ellipsis"] = ..., # noqa: F821
cursor: Union[str, "ellipsis", None] = ..., # noqa: F821
) -> Page[Collection]:
"""
List collections of models.
Expand Down
7 changes: 5 additions & 2 deletions replicate/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ class Models(Namespace):

model = Model

def list(self, cursor: Union[str, "ellipsis"] = ...) -> Page[Model]: # noqa: F821
def list(self, cursor: Union[str, "ellipsis", None] = ...) -> Page[Model]: # noqa: F821
"""
List all public models.
Expand All @@ -164,7 +164,10 @@ def list(self, cursor: Union[str, "ellipsis"] = ...) -> Page[Model]: # noqa: F8

return Page[Model](**obj)

async def async_list(self, cursor: Union[str, "ellipsis"] = ...) -> Page[Model]: # noqa: F821
async def async_list(
self,
cursor: Union[str, "ellipsis", None] = ..., # noqa: F821
) -> Page[Model]:
"""
List all public models.
Expand Down
37 changes: 37 additions & 0 deletions replicate/pagination.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from typing import (
TYPE_CHECKING,
AsyncGenerator,
Awaitable,
Callable,
Generator,
Generic,
List,
Optional,
TypeVar,
Union,
)

try:
Expand Down Expand Up @@ -41,3 +46,35 @@ def __getitem__(self, index: int) -> T:

def __len__(self) -> int:
return len(self.results)


def paginate(
list_method: Callable[[Union[str, "ellipsis", None]], Page[T]], # noqa: F821
) -> Generator[Page[T], None, None]:
"""
Iterate over all items using the provided list method.
Args:
list_method: A method that takes a cursor argument and returns a Page of items.
"""
cursor: Union[str, "ellipsis", None] = ... # noqa: F821
while cursor is not None:
page = list_method(cursor)
yield page
cursor = page.next


async def async_paginate(
list_method: Callable[[Union[str, "ellipsis", None]], Awaitable[Page[T]]], # noqa: F821
) -> AsyncGenerator[Page[T], None]:
"""
Asynchronously iterate over all items using the provided list method.
Args:
list_method: An async method that takes a cursor argument and returns a Page of items.
"""
cursor: Union[str, "ellipsis", None] = ... # noqa: F821
while cursor is not None:
page = await list_method(cursor)
yield page
cursor = page.next
4 changes: 2 additions & 2 deletions replicate/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ class Predictions(Namespace):
Namespace for operations related to predictions.
"""

def list(self, cursor: Union[str, "ellipsis"] = ...) -> Page[Prediction]: # noqa: F821
def list(self, cursor: Union[str, "ellipsis", None] = ...) -> Page[Prediction]: # noqa: F821
"""
List your predictions.
Expand Down Expand Up @@ -200,7 +200,7 @@ def list(self, cursor: Union[str, "ellipsis"] = ...) -> Page[Prediction]: # noq

async def async_list(
self,
cursor: Union[str, "ellipsis"] = ..., # noqa: F821
cursor: Union[str, "ellipsis", None] = ..., # noqa: F821
) -> Page[Prediction]:
"""
List your predictions.
Expand Down
7 changes: 5 additions & 2 deletions replicate/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ class Trainings(Namespace):
Namespace for operations related to trainings.
"""

def list(self, cursor: Union[str, "ellipsis"] = ...) -> Page[Training]: # noqa: F821
def list(self, cursor: Union[str, "ellipsis", None] = ...) -> Page[Training]: # noqa: F821
"""
List your trainings.
Expand All @@ -127,7 +127,10 @@ def list(self, cursor: Union[str, "ellipsis"] = ...) -> Page[Training]: # noqa:

return Page[Training](**obj)

async def async_list(self, cursor: Union[str, "ellipsis"] = ...) -> Page[Training]: # noqa: F821
async def async_list(
self,
cursor: Union[str, "ellipsis", None] = ..., # noqa: F821
) -> Page[Training]:
"""
List your trainings.
Expand Down
29 changes: 29 additions & 0 deletions tests/test_pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,32 @@
async def test_paginate_with_none_cursor(mock_replicate_api_token):
with pytest.raises(ValueError):
replicate.models.list(None)


@pytest.mark.vcr("collections-list.yaml")
@pytest.mark.asyncio
@pytest.mark.parametrize("async_flag", [True, False])
async def test_paginate(async_flag):
found = False

if async_flag:
async for page in replicate.async_paginate(replicate.collections.async_list):
assert page.next is None
assert page.previous is None

for collection in page:
if collection.slug == "text-to-image":
found = True
break

else:
for page in replicate.paginate(replicate.collections.list):
assert page.next is None
assert page.previous is None

for collection in page:
if collection.slug == "text-to-image":
found = True
break

assert found

0 comments on commit 15f9c59

Please sign in to comment.