From ed9b8fa3640407cdf45d395e90763a39581ea258 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Sun, 31 Mar 2019 18:33:09 +0400 Subject: [PATCH] Add concurrency.iterator_to_async, tests and docs --- docs/responses.md | 25 +++++++++++++++++++++++++ starlette/concurrency.py | 22 ++++++++++++++++++++++ tests/test_responses.py | 19 +++++++++++++++++++ 3 files changed, 66 insertions(+) diff --git a/docs/responses.md b/docs/responses.md index 595c9101c..d56020b99 100644 --- a/docs/responses.md +++ b/docs/responses.md @@ -180,6 +180,31 @@ class App: await response(receive, send) ``` +If you have a standard generator or iterator (instead of an async generator), you can wrap it with `starlette.concurrency.iterator_to_async` to convert it to an async generator. + +Then you can use it with a `StreamingResponse`. + +This is specially useful for synchronous file-like or streaming objects, like those provided by cloud storage providers. + +```python +from starlette.responses import StreamingResponse +from starlette.concurrency import iterator_to_async + +def get_stream(): + # this would return an iterator or file-like object, etc. + pass + +class App: + def __init__(self, scope): + assert scope['type'] == 'http' + self.scope = scope + + async def __call__(self, receive, send): + generator = iterator_to_async(get_stream()) + response = StreamingResponse(generator, media_type='application/octet-stream') + await response(receive, send) +``` + ### FileResponse Asynchronously streams a file as the response. diff --git a/starlette/concurrency.py b/starlette/concurrency.py index 35b589956..968659046 100644 --- a/starlette/concurrency.py +++ b/starlette/concurrency.py @@ -1,6 +1,7 @@ import asyncio import functools import typing +from typing import Any, AsyncGenerator, Iterator try: import contextvars # Python 3.7+ only. @@ -22,3 +23,24 @@ async def run_in_threadpool( # loop.run_in_executor doesn't accept 'kwargs', so bind them in here func = functools.partial(func, **kwargs) return await loop.run_in_executor(None, func, *args) + + +class _StopSyncIteration(Exception): + pass + + +def _interceptable_next(iterator: Iterator) -> Any: + try: + result = next(iterator) + return result + except StopIteration: + raise _StopSyncIteration + + +async def iterator_to_async(iterator: Iterator) -> AsyncGenerator: + while True: + try: + result = await run_in_threadpool(_interceptable_next, iterator) + yield result + except _StopSyncIteration: + break diff --git a/tests/test_responses.py b/tests/test_responses.py index 300975afe..081b97dba 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -5,6 +5,7 @@ from starlette import status from starlette.background import BackgroundTask +from starlette.concurrency import iterator_to_async from starlette.requests import Request from starlette.responses import ( FileResponse, @@ -90,6 +91,24 @@ async def numbers_for_cleanup(start=1, stop=5): assert filled_by_bg_task == "6, 7, 8, 9" +def test_streaming_response_from_sync_stream(): + async def app(scope, receive, send): + def numbers(minimum, maximum): + for i in range(minimum, maximum + 1): + yield str(i) + if i != maximum: + yield ", " + + generator = numbers(1, 5) + aio_generator = iterator_to_async(generator) + response = StreamingResponse(aio_generator, media_type="text/plain") + await response(scope, receive, send) + + client = TestClient(app) + response = client.get("/") + assert response.text == "1, 2, 3, 4, 5" + + def test_response_headers(): async def app(scope, receive, send): headers = {"x-header-1": "123", "x-header-2": "456"}