Skip to content

Commit

Permalink
Use Iterable instead Iterator on iterate_in_threadpool (#2362)
Browse files Browse the repository at this point in the history
* Fixed AsyncContentStream to be AsyncIterator

* Updating isinstance check too

* Standardizing on Iterable/AsyncIterable

* Moved iterate_in_threadpool to make an iter internally

* Added test of iterate_in_threadpool accepting an Iterable

* Renamed arg to iterator, and fixed type hint in return to be AsyncIterator

---------

Co-authored-by: Marcelo Trylesinski <[email protected]>
  • Loading branch information
jamesbraza and Kludex authored Dec 20, 2023
1 parent 866a15f commit 966f0fc
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 4 deletions.
5 changes: 3 additions & 2 deletions starlette/concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,11 @@ def _next(iterator: typing.Iterator[T]) -> T:


async def iterate_in_threadpool(
iterator: typing.Iterator[T],
iterator: typing.Iterable[T],
) -> typing.AsyncIterator[T]:
as_iterator = iter(iterator)
while True:
try:
yield await anyio.to_thread.run_sync(_next, iterator)
yield await anyio.to_thread.run_sync(_next, as_iterator)
except _StopIteration:
break
2 changes: 1 addition & 1 deletion starlette/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def __init__(


Content = typing.Union[str, bytes]
SyncContentStream = typing.Iterator[Content]
SyncContentStream = typing.Iterable[Content]
AsyncContentStream = typing.AsyncIterable[Content]
ContentStream = typing.Union[AsyncContentStream, SyncContentStream]

Expand Down
11 changes: 10 additions & 1 deletion tests/test_concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest

from starlette.applications import Starlette
from starlette.concurrency import run_until_first_complete
from starlette.concurrency import iterate_in_threadpool, run_until_first_complete
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import Route
Expand Down Expand Up @@ -40,3 +40,12 @@ def endpoint(request: Request) -> Response:

resp = client.get("/")
assert resp.content == b"data"


@pytest.mark.anyio
async def test_iterate_in_threadpool() -> None:
class CustomIterable:
def __iter__(self):
yield from range(3)

assert [v async for v in iterate_in_threadpool(CustomIterable())] == [0, 1, 2]

0 comments on commit 966f0fc

Please sign in to comment.