diff --git a/starlette/concurrency.py b/starlette/concurrency.py index c44ee840f..d19020183 100644 --- a/starlette/concurrency.py +++ b/starlette/concurrency.py @@ -55,10 +55,11 @@ def _next(iterator: typing.Iterator[T]) -> T: async def iterate_in_threadpool( - iterator: typing.Iterator[T], + iterator: typing.Iterable[T], ) -> typing.AsyncIterator[T]: + as_iterator = iter(iterator) while True: try: - yield await anyio.to_thread.run_sync(_next, iterator) + yield await anyio.to_thread.run_sync(_next, as_iterator) except _StopIteration: break diff --git a/starlette/responses.py b/starlette/responses.py index b33aa8618..c99c64f58 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -204,7 +204,7 @@ def __init__( Content = typing.Union[str, bytes] -SyncContentStream = typing.Iterator[Content] +SyncContentStream = typing.Iterable[Content] AsyncContentStream = typing.AsyncIterable[Content] ContentStream = typing.Union[AsyncContentStream, SyncContentStream] diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py index 22b9da0e8..61fe5ff7b 100644 --- a/tests/test_concurrency.py +++ b/tests/test_concurrency.py @@ -4,7 +4,7 @@ import pytest from starlette.applications import Starlette -from starlette.concurrency import run_until_first_complete +from starlette.concurrency import iterate_in_threadpool, run_until_first_complete from starlette.requests import Request from starlette.responses import Response from starlette.routing import Route @@ -40,3 +40,12 @@ def endpoint(request: Request) -> Response: resp = client.get("/") assert resp.content == b"data" + + +@pytest.mark.anyio +async def test_iterate_in_threadpool() -> None: + class CustomIterable: + def __iter__(self): + yield from range(3) + + assert [v async for v in iterate_in_threadpool(CustomIterable())] == [0, 1, 2]