From 15f9c5990f1e3d9f403dc05df87f254f8bfeee7a Mon Sep 17 00:00:00 2001 From: Mattt Date: Fri, 17 Nov 2023 07:28:26 -0500 Subject: [PATCH] Add `paginate` and `async_paginate` methods (#197) This PR adds `paginate` and `async_paginate` methods that let you iterate over a paginated list of resources. **sync** ```python import replicate for page in replicate.paginate(replicate.collections.list): for collection in page: print(collection.name) ``` **async** ```python import replicate async for page in replicate.async_paginate(replicate.collections.async_list): for collection in page: print(collection.name) ``` --------- Signed-off-by: Mattt Zmuda --- README.md | 14 +++++++++++--- replicate/__init__.py | 5 +++++ replicate/collection.py | 4 ++-- replicate/model.py | 7 +++++-- replicate/pagination.py | 37 +++++++++++++++++++++++++++++++++++++ replicate/prediction.py | 4 ++-- replicate/training.py | 7 +++++-- tests/test_pagination.py | 29 +++++++++++++++++++++++++++++ 8 files changed, 96 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 2bf5146d..d3324fdd 100644 --- a/README.md +++ b/README.md @@ -223,21 +223,29 @@ You can the models you've created: replicate.models.list() ``` -Lists of models are paginated. You can get the next page of models by passing the `next` property as an argument to the `list` method. Here's how you can get all the models you've created: +Lists of models are paginated. You can get the next page of models by passing the `next` property as an argument to the `list` method, or you can use the `paginate` method to fetch pages automatically. 
```python +# Automatic pagination using `replicate.paginate` (recommended) models = [] -page = replicate.models.list() +for page in replicate.paginate(replicate.models.list): + models.extend(page.results) + if len(models) > 100: + break +# Manual pagination using `next` cursors +page = replicate.models.list() while page: models.extend(page.results) + if len(models) > 100: + break page = replicate.models.list(page.next) if page.next else None ``` You can also find collections of featured models on Replicate: ```python ->>> collections = replicate.collections.list() +>>> collections = [collection for page in replicate.paginate(replicate.collections.list) for collection in page] >>> collections[0].slug "vision-models" >>> collections[0].description diff --git a/replicate/__init__.py b/replicate/__init__.py index d78432d9..ea27f7db 100644 --- a/replicate/__init__.py +++ b/replicate/__init__.py @@ -1,10 +1,15 @@ from replicate.client import Client +from replicate.pagination import async_paginate as _async_paginate +from replicate.pagination import paginate as _paginate default_client = Client() run = default_client.run async_run = default_client.async_run +paginate = _paginate +async_paginate = _async_paginate + collections = default_client.collections hardware = default_client.hardware deployments = default_client.deployments diff --git a/replicate/collection.py b/replicate/collection.py index 56001d15..6c6ac912 100644 --- a/replicate/collection.py +++ b/replicate/collection.py @@ -55,7 +55,7 @@ class Collections(Namespace): def list( self, - cursor: Union[str, "ellipsis"] = ..., # noqa: F821 + cursor: Union[str, "ellipsis", None] = ..., # noqa: F821 ) -> Page[Collection]: """ List collections of models. @@ -82,7 +82,7 @@ def list( async def async_list( self, - cursor: Union[str, "ellipsis"] = ..., # noqa: F821 + cursor: Union[str, "ellipsis", None] = ..., # noqa: F821 ) -> Page[Collection]: """ List collections of models. 
diff --git a/replicate/model.py b/replicate/model.py index c771704c..43db9616 100644 --- a/replicate/model.py +++ b/replicate/model.py @@ -140,7 +140,7 @@ class Models(Namespace): model = Model - def list(self, cursor: Union[str, "ellipsis"] = ...) -> Page[Model]: # noqa: F821 + def list(self, cursor: Union[str, "ellipsis", None] = ...) -> Page[Model]: # noqa: F821 """ List all public models. @@ -164,7 +164,10 @@ def list(self, cursor: Union[str, "ellipsis"] = ...) -> Page[Model]: # noqa: F8 return Page[Model](**obj) - async def async_list(self, cursor: Union[str, "ellipsis"] = ...) -> Page[Model]: # noqa: F821 + async def async_list( + self, + cursor: Union[str, "ellipsis", None] = ..., # noqa: F821 + ) -> Page[Model]: """ List all public models. diff --git a/replicate/pagination.py b/replicate/pagination.py index b70b97b1..7472aa4d 100644 --- a/replicate/pagination.py +++ b/replicate/pagination.py @@ -1,9 +1,14 @@ from typing import ( TYPE_CHECKING, + AsyncGenerator, + Awaitable, + Callable, + Generator, Generic, List, Optional, TypeVar, + Union, ) try: @@ -41,3 +46,35 @@ def __getitem__(self, index: int) -> T: def __len__(self) -> int: return len(self.results) + + +def paginate( + list_method: Callable[[Union[str, "ellipsis", None]], Page[T]], # noqa: F821 +) -> Generator[Page[T], None, None]: + """ + Iterate over all items using the provided list method. + + Args: + list_method: A method that takes a cursor argument and returns a Page of items. + """ + cursor: Union[str, "ellipsis", None] = ... # noqa: F821 + while cursor is not None: + page = list_method(cursor) + yield page + cursor = page.next + + +async def async_paginate( + list_method: Callable[[Union[str, "ellipsis", None]], Awaitable[Page[T]]], # noqa: F821 +) -> AsyncGenerator[Page[T], None]: + """ + Asynchronously iterate over all items using the provided list method. + + Args: + list_method: An async method that takes a cursor argument and returns a Page of items. 
+ """ + cursor: Union[str, "ellipsis", None] = ... # noqa: F821 + while cursor is not None: + page = await list_method(cursor) + yield page + cursor = page.next diff --git a/replicate/prediction.py b/replicate/prediction.py index a6239892..488e9a3b 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -172,7 +172,7 @@ class Predictions(Namespace): Namespace for operations related to predictions. """ - def list(self, cursor: Union[str, "ellipsis"] = ...) -> Page[Prediction]: # noqa: F821 + def list(self, cursor: Union[str, "ellipsis", None] = ...) -> Page[Prediction]: # noqa: F821 """ List your predictions. @@ -200,7 +200,7 @@ def list(self, cursor: Union[str, "ellipsis"] = ...) -> Page[Prediction]: # noq async def async_list( self, - cursor: Union[str, "ellipsis"] = ..., # noqa: F821 + cursor: Union[str, "ellipsis", None] = ..., # noqa: F821 ) -> Page[Prediction]: """ List your predictions. diff --git a/replicate/training.py b/replicate/training.py index 8a24180e..dff61a5b 100644 --- a/replicate/training.py +++ b/replicate/training.py @@ -101,7 +101,7 @@ class Trainings(Namespace): Namespace for operations related to trainings. """ - def list(self, cursor: Union[str, "ellipsis"] = ...) -> Page[Training]: # noqa: F821 + def list(self, cursor: Union[str, "ellipsis", None] = ...) -> Page[Training]: # noqa: F821 """ List your trainings. @@ -127,7 +127,10 @@ def list(self, cursor: Union[str, "ellipsis"] = ...) -> Page[Training]: # noqa: return Page[Training](**obj) - async def async_list(self, cursor: Union[str, "ellipsis"] = ...) -> Page[Training]: # noqa: F821 + async def async_list( + self, + cursor: Union[str, "ellipsis", None] = ..., # noqa: F821 + ) -> Page[Training]: """ List your trainings. 
diff --git a/tests/test_pagination.py b/tests/test_pagination.py index 86d86b6d..44b0c775 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -7,3 +7,32 @@ async def test_paginate_with_none_cursor(mock_replicate_api_token): with pytest.raises(ValueError): replicate.models.list(None) + + +@pytest.mark.vcr("collections-list.yaml") +@pytest.mark.asyncio +@pytest.mark.parametrize("async_flag", [True, False]) +async def test_paginate(async_flag): + found = False + + if async_flag: + async for page in replicate.async_paginate(replicate.collections.async_list): + assert page.next is None + assert page.previous is None + + for collection in page: + if collection.slug == "text-to-image": + found = True + break + + else: + for page in replicate.paginate(replicate.collections.list): + assert page.next is None + assert page.previous is None + + for collection in page: + if collection.slug == "text-to-image": + found = True + break + + assert found