diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 0000000000..2f87d94ca1 --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1 @@ +github: encode diff --git a/.github/workflows/test-suite.yml b/.github/workflows/test-suite.yml index eed46850a9..751c5193be 100644 --- a/.github/workflows/test-suite.yml +++ b/.github/workflows/test-suite.yml @@ -14,7 +14,7 @@ jobs: strategy: matrix: - python-version: ["3.6", "3.7", "3.8", "3.9"] + python-version: ["3.6", "3.7", "3.8", "3.9", "3.10.0-beta.3"] steps: - uses: "actions/checkout@v2" diff --git a/README.md b/README.md index 7e315627ed..28b6f49854 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,7 @@ # Starlette Starlette is a lightweight [ASGI](https://asgi.readthedocs.io/en/latest/) framework/toolkit, -which is ideal for building high performance asyncio services. +which is ideal for building high performance async services. It is production-ready, and gives you the following: @@ -35,7 +35,8 @@ It is production-ready, and gives you the following: * Session and Cookie support. * 100% test coverage. * 100% type annotated codebase. -* Zero hard dependencies. +* Few hard dependencies. +* Compatible with `asyncio` and `trio` backends. ## Requirements @@ -83,10 +84,9 @@ For a more complete example, see [encode/starlette-example](https://github.com/e ## Dependencies -Starlette does not have any hard dependencies, but the following are optional: +Starlette only requires `anyio`, and the following are optional: * [`requests`][requests] - Required if you want to use the `TestClient`. -* [`aiofiles`][aiofiles] - Required if you want to use `FileResponse` or `StaticFiles`. * [`jinja2`][jinja2] - Required if you want to use `Jinja2Templates`. * [`python-multipart`][python-multipart] - Required if you want to support form parsing, with `request.form()`. * [`itsdangerous`][itsdangerous] - Required for `SessionMiddleware` support. @@ -165,7 +165,6 @@ gunicorn -k uvicorn.workers.UvicornH11Worker ...

Starlette is BSD licensed code. Designed & built in Brighton, England.

[requests]: http://docs.python-requests.org/en/master/ -[aiofiles]: https://github.com/Tinche/aiofiles [jinja2]: http://jinja.pocoo.org/ [python-multipart]: https://andrew-d.github.io/python-multipart/ [graphene]: https://graphene-python.org/ diff --git a/docs/config.md b/docs/config.md index 7a93b22e9e..f7c2c7b7de 100644 --- a/docs/config.md +++ b/docs/config.md @@ -160,7 +160,7 @@ organisations = sqlalchemy.Table( ```python from starlette.applications import Starlette from starlette.middleware import Middleware -from starlette.middleware.session import SessionMiddleware +from starlette.middleware.sessions import SessionMiddleware from starlette.routing import Route from myproject import settings @@ -192,7 +192,7 @@ and drop it once the tests complete. We'd also like to ensure from starlette.config import environ from starlette.testclient import TestClient from sqlalchemy import create_engine -from sqlalchemy_utils import database_exists, create_database +from sqlalchemy_utils import create_database, database_exists, drop_database # This line would raise an error if we use it after 'settings' has been imported. environ['TESTING'] = 'TRUE' diff --git a/docs/database.md b/docs/database.md index ca1b85d6a2..aa6cb74edf 100644 --- a/docs/database.md +++ b/docs/database.md @@ -142,10 +142,10 @@ async def populate_note(request): await database.execute(query) raise RuntimeError() except: - transaction.rollback() + await transaction.rollback() raise else: - transaction.commit() + await transaction.commit() ``` ## Test isolation diff --git a/docs/events.md b/docs/events.md index 4f2bce558b..c7ed49e9dc 100644 --- a/docs/events.md +++ b/docs/events.md @@ -37,6 +37,31 @@ registered startup handlers have completed. The shutdown handlers will run once all connections have been closed, and any in-process background tasks have completed. +A single lifespan asynccontextmanager handler can be used instead of +separate startup and shutdown handlers: + +```python +import contextlib +import anyio +from starlette.applications import Starlette + + +@contextlib.asynccontextmanager +async def lifespan(app): + async with some_async_resource(): + yield + + +routes = [ + ... +] + +app = Starlette(routes=routes, lifespan=lifespan) +``` + +Consider using [`anyio.create_task_group()`](https://anyio.readthedocs.io/en/stable/tasks.html) +for managing asynchronious tasks. + ## Running event handlers in tests You might want to explicitly call into your event handlers in any test setup diff --git a/docs/img/graphiql.png b/docs/img/graphiql.png deleted file mode 100644 index 7851993f7f..0000000000 Binary files a/docs/img/graphiql.png and /dev/null differ diff --git a/docs/index.md b/docs/index.md index 4ae77f0e60..b9692a1fbc 100644 --- a/docs/index.md +++ b/docs/index.md @@ -32,7 +32,7 @@ It is production-ready, and gives you the following: * Session and Cookie support. * 100% test coverage. * 100% type annotated codebase. -* Zero hard dependencies. +* Few hard dependencies. ## Requirements @@ -79,10 +79,9 @@ For a more complete example, [see here](https://github.com/encode/starlette-exam ## Dependencies -Starlette does not have any hard dependencies, but the following are optional: +Starlette only requires `anyio`, and the following dependencies are optional: * [`requests`][requests] - Required if you want to use the `TestClient`. -* [`aiofiles`][aiofiles] - Required if you want to use `FileResponse` or `StaticFiles`. * [`jinja2`][jinja2] - Required if you want to use `Jinja2Templates`. * [`python-multipart`][python-multipart] - Required if you want to support form parsing, with `request.form()`. * [`itsdangerous`][itsdangerous] - Required for `SessionMiddleware` support. @@ -161,7 +160,6 @@ gunicorn -k uvicorn.workers.UvicornH11Worker ...

Starlette is BSD licensed code. Designed & built in Brighton, England.

[requests]: http://docs.python-requests.org/en/master/ -[aiofiles]: https://github.com/Tinche/aiofiles [jinja2]: http://jinja.pocoo.org/ [python-multipart]: https://andrew-d.github.io/python-multipart/ [graphene]: https://graphene-python.org/ diff --git a/docs/middleware.md b/docs/middleware.md index 7d1233d100..4c6fb8a9e0 100644 --- a/docs/middleware.md +++ b/docs/middleware.md @@ -183,6 +183,8 @@ To implement a middleware class using `BaseHTTPMiddleware`, you must override th `async def dispatch(request, call_next)` method. ```python +from starlette.middleware.base import BaseHTTPMiddleware + class CustomHeaderMiddleware(BaseHTTPMiddleware): async def dispatch(self, request, call_next): response = await call_next(request) @@ -256,7 +258,7 @@ when proxy servers are being used, based on the `X-Forwarded-Proto` and `X-Forwa A middleware class to emit timing information (cpu and wall time) for each request which passes through it. Includes examples for how to emit these timings as statsd metrics. -#### [datasette-auth-github](https://github.com/simonw/datasette-auth-github) +#### [asgi-auth-github](https://github.com/simonw/asgi-auth-github) This middleware adds authentication to any ASGI application, requiring users to sign in using their GitHub account (via [OAuth](https://developer.github.com/apps/building-oauth-apps/authorizing-oauth-apps/)). diff --git a/docs/overrides/partials/nav.html b/docs/overrides/partials/nav.html new file mode 100644 index 0000000000..d4684d0a68 --- /dev/null +++ b/docs/overrides/partials/nav.html @@ -0,0 +1,52 @@ + + {% set class = "md-nav md-nav--primary" %} + {% if "navigation.tabs" in features %} + {% set class = class ~ " md-nav--lifted" %} + {% endif %} + {% if "toc.integrate" in features %} + {% set class = class ~ " md-nav--integrated" %} + {% endif %} + + + diff --git a/docs/release-notes.md b/docs/release-notes.md index 08f1e2895b..7305046f9d 100644 --- a/docs/release-notes.md +++ b/docs/release-notes.md @@ -1,11 +1,72 @@ +## 0.16.0 + +July 19, 2021 + +### Added + * Added [Encode](https://github.com/sponsors/encode) funding option + [#1219](https://github.com/encode/starlette/pull/1219) + +### Fixed + * `starlette.websockets.WebSocket` instances are now hashable and compare by identity + [#1039](https://github.com/encode/starlette/pull/1039) + * A number of fixes related to running task groups in lifespan + [#1213](https://github.com/encode/starlette/pull/1213), + [#1227](https://github.com/encode/starlette/pull/1227) + +### Deprecated/removed + * The method `starlette.templates.Jinja2Templates.get_env` was removed + [#1218](https://github.com/encode/starlette/pull/1218) + * The ClassVar `starlette.testclient.TestClient.async_backend` was removed, + the backend is now configured using constructor kwargs + [#1211](https://github.com/encode/starlette/pull/1211) + * Passing an Async Generator Function or a Generator Function to `starlette.router.Router(lifespan_context=)` is deprecated. You should wrap your lifespan in `@contextlib.asynccontextmanager`. + [#1227](https://github.com/encode/starlette/pull/1227) + [#1110](https://github.com/encode/starlette/pull/1110) + ## 0.15.0 -Unreleased +June 23, 2021 + +This release includes major changes to the low-level asynchronous parts of Starlette. As a result, +**Starlette now depends on [AnyIO](https://anyio.readthedocs.io/en/stable/)** and some minor API +changes have occurred. Another significant change with this release is the +**deprecation of built-in GraphQL support**. -### Deprecated +### Added +* Starlette now supports [Trio](https://trio.readthedocs.io/en/stable/) as an async runtime via + AnyIO - [#1157](https://github.com/encode/starlette/pull/1157). +* `TestClient.websocket_connect()` now must be used as a context manager. +* Initial support for Python 3.10 - [#1201](https://github.com/encode/starlette/pull/1201). +* The compression level used in `GZipMiddleware` is now adjustable - + [#1128](https://github.com/encode/starlette/pull/1128). + +### Fixed +* Several fixes to `CORSMiddleware`. See [#1111](https://github.com/encode/starlette/pull/1111), + [#1112](https://github.com/encode/starlette/pull/1112), + [#1113](https://github.com/encode/starlette/pull/1113), + [#1199](https://github.com/encode/starlette/pull/1199). +* Improved exception messages in the case of duplicated path parameter names - + [#1177](https://github.com/encode/starlette/pull/1177). +* `RedirectResponse` now uses `quote` instead of `quote_plus` encoding for the `Location` header + to better match the behaviour in other frameworks such as Django - + [#1164](https://github.com/encode/starlette/pull/1164). +* Exception causes are now preserved in more cases - + [#1158](https://github.com/encode/starlette/pull/1158). +* Session cookies now use the ASGI root path in the case of mounted applications - + [#1147](https://github.com/encode/starlette/pull/1147). +* Fixed a cache invalidation bug when static files were deleted in certain circumstances - + [#1023](https://github.com/encode/starlette/pull/1023). +* Improved memory usage of `BaseHTTPMiddleware` when handling large responses - + [#1012](https://github.com/encode/starlette/issues/1012) fixed via #1157 + +### Deprecated/removed * Built-in GraphQL support via the `GraphQLApp` class has been deprecated and will be removed in a - future release. Please see [#619](https://github.com/encode/starlette/issues/619). + future release. Please see [#619](https://github.com/encode/starlette/issues/619). GraphQL is not + supported on Python 3.10. +* The `executor` parameter to `GraphQLApp` was removed. Use `executor_class` instead. +* The `workers` parameter to `WSGIMiddleware` was removed. This hasn't had any effect since + Starlette v0.6.3. ## 0.14.2 diff --git a/docs/responses.md b/docs/responses.md index c4cd84ed32..5284ac5044 100644 --- a/docs/responses.md +++ b/docs/responses.md @@ -182,9 +182,8 @@ async def app(scope, receive, send): await response(scope, receive, send) ``` -## Third party middleware +## Third party responses -### [SSEResponse(EventSourceResponse)](https://github.com/sysid/sse-starlette) +#### [EventSourceResponse](https://github.com/sysid/sse-starlette) -Server Sent Response implements the ServerSentEvent Protocol: https://www.w3.org/TR/2009/WD-eventsource-20090421. -It enables event streaming from the server to the client without the complexity of websockets. +A response class that implements [Server-Sent Events](https://html.spec.whatwg.org/multipage/server-sent-events.html). It enables event streaming from the server to the client without the complexity of websockets. diff --git a/docs/schemas.md b/docs/schemas.md index 2530ba8d1f..275e7b2968 100644 --- a/docs/schemas.md +++ b/docs/schemas.md @@ -51,7 +51,7 @@ routes = [ Route("/schema", endpoint=openapi_schema, include_in_schema=False) ] -app = Starlette() +app = Starlette(routes=routes) ``` We can now access an OpenAPI schema at the "/schema" endpoint. diff --git a/docs/sponsors/fastapi.png b/docs/sponsors/fastapi.png new file mode 100644 index 0000000000..a5b2af17eb Binary files /dev/null and b/docs/sponsors/fastapi.png differ diff --git a/docs/testclient.md b/docs/testclient.md index 61f7201c62..a1861efec7 100644 --- a/docs/testclient.md +++ b/docs/testclient.md @@ -31,6 +31,29 @@ application. Occasionally you might want to test the content of 500 error responses, rather than allowing client to raise the server exception. In this case you should use `client = TestClient(app, raise_server_exceptions=False)`. +### Selecting the Async backend + +`TestClient` takes arguments `backend` (a string) and `backend_options` (a dictionary). +These options are passed to `anyio.start_blocking_portal()`. See the [anyio documentation](https://anyio.readthedocs.io/en/stable/basics.html#backend-options) +for more information about the accepted backend options. +By default, `asyncio` is used with default options. + +To run `Trio`, pass `backend="trio"`. For example: + +```python +def test_app() + with TestClient(app, backend="trio") as client: + ... +``` + +To run `asyncio` with `uvloop`, pass `backend_options={"use_uvloop": True}`. For example: + +```python +def test_app() + with TestClient(app, backend_options={"use_uvloop": True}) as client: + ... +``` + ### Testing WebSocket sessions You can also test websocket sessions with the test client. @@ -72,6 +95,8 @@ always raised by the test client. May raise `starlette.websockets.WebSocketDisconnect` if the application does not accept the websocket connection. +`websocket_connect()` must be used as a context manager (in a `with` block). + #### Sending data * `.send_text(data)` - Send the given text to the application. diff --git a/docs/third-party-packages.md b/docs/third-party-packages.md index 71902ce83f..a21f31ba2c 100644 --- a/docs/third-party-packages.md +++ b/docs/third-party-packages.md @@ -20,7 +20,7 @@ Simple APISpec integration for Starlette. Document your REST API built with Starlette by declaring OpenAPI (Swagger) schemas in YAML format in your endpoint's docstrings. -### SpecTree +### SpecTree GitHub @@ -43,7 +43,7 @@ Checkout nejma GitHub -Another solution for websocket broadcast. Send messages to channel groups from any part of your code. +Another solution for websocket broadcast. Send messages to channel groups from any part of your code. Checkout channel-box-chat, a simple chat application built using `channel-box` and `starlette`. ### Scout APM @@ -90,6 +90,21 @@ It relies solely on an auth provider to issue access and/or id tokens to clients Middleware for Starlette that allows you to store and access the context data of a request. Can be used with logging so logs automatically use request headers such as x-request-id or x-correlation-id. + +### Starsessions + +GitHub + +An alternate session support implementation with customizable storage backends. + + +### Starlette Cramjam + +GitHub + +A Starlette middleware that allows **brotli**, **gzip** and **deflate** compression algorithm with a minimal requirements. + + ## Frameworks ### Responder @@ -116,3 +131,9 @@ Inspired by **APIStar**'s previous server system with type declarations for rout Formerly Starlette API. Flama aims to bring a layer on top of Starlette to provide an **easy to learn** and **fast to develop** approach for building **highly performant** GraphQL and REST APIs. In the same way of Starlette is, Flama is a perfect option for developing **asynchronous** and **production-ready** services. + +### Starlette-apps + +Roll your own framework with a simple app system, like [Django-GDAPS](https://gdaps.readthedocs.io/en/latest/) or [CakePHP](https://cakephp.org/). + +GitHub diff --git a/mkdocs.yml b/mkdocs.yml index 0f10c19ff6..2ab5faa680 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -1,8 +1,22 @@ site_name: Starlette site_description: The little ASGI library that shines. +site_url: https://www.starlette.io theme: name: 'material' + custom_dir: docs/overrides + palette: + - scheme: 'default' + media: '(prefers-color-scheme: light)' + toggle: + icon: 'material/lightbulb' + name: "Switch to dark mode" + - scheme: 'slate' + media: '(prefers-color-scheme: dark)' + primary: 'blue' + toggle: + icon: 'material/lightbulb-outline' + name: 'Switch to light mode' repo_name: encode/starlette repo_url: https://github.com/encode/starlette diff --git a/requirements.txt b/requirements.txt index 6ec5bf09ee..abc7a3b0a5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,5 @@ # Optionals -aiofiles -graphene +graphene; python_version<'3.10' itsdangerous jinja2 python-multipart @@ -10,17 +9,17 @@ requests # Testing autoflake black==20.8b1 +coverage>=5.3 databases[sqlite] flake8 isort==5.* mypy types-requests types-contextvars -types-aiofiles types-PyYAML +types-dataclasses pytest -pytest-cov -pytest-asyncio +trio # Documentation mkdocs diff --git a/scripts/test b/scripts/test index f9c9917233..720a66392d 100755 --- a/scripts/test +++ b/scripts/test @@ -11,7 +11,7 @@ if [ -z $GITHUB_ACTIONS ]; then scripts/check fi -${PREFIX}pytest $@ +${PREFIX}coverage run -m pytest $@ if [ -z $GITHUB_ACTIONS ]; then scripts/coverage diff --git a/setup.cfg b/setup.cfg index 196414c0a7..39f67c2150 100644 --- a/setup.cfg +++ b/setup.cfg @@ -17,9 +17,6 @@ combine_as_imports = True [tool:pytest] addopts = - --cov-report=term-missing:skip-covered - --cov=starlette - --cov=tests -rxXs --strict-config --strict-markers @@ -32,3 +29,8 @@ filterwarnings= ignore: Using or importing the ABCs from 'collections' instead of from 'collections\.abc' is deprecated.*:DeprecationWarning ignore: The 'context' alias has been deprecated. Please use 'context_value' instead\.:DeprecationWarning ignore: The 'variables' alias has been deprecated. Please use 'variable_values' instead\.:DeprecationWarning + # Workaround for Python 3.9.7 (see https://bugs.python.org/issue45097) + ignore:The loop argument is deprecated since Python 3\.8, and scheduled for removal in Python 3\.10\.:DeprecationWarning:asyncio + +[coverage:run] +source_pkgs = starlette, tests diff --git a/setup.py b/setup.py index c48356370c..31789fe09d 100644 --- a/setup.py +++ b/setup.py @@ -37,10 +37,14 @@ def get_long_description(): packages=find_packages(exclude=["tests*"]), package_data={"starlette": ["py.typed"]}, include_package_data=True, + install_requires=[ + "anyio>=3.0.0,<4", + "typing_extensions; python_version < '3.8'", + "contextlib2 >= 21.6.0; python_version < '3.7'", + ], extras_require={ "full": [ - "aiofiles", - "graphene", + "graphene; python_version<'3.10'", "itsdangerous", "jinja2", "python-multipart", @@ -60,6 +64,7 @@ def get_long_description(): "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", ], zip_safe=False, ) diff --git a/starlette/__init__.py b/starlette/__init__.py index 745162e731..5a313cc7ef 100644 --- a/starlette/__init__.py +++ b/starlette/__init__.py @@ -1 +1 @@ -__version__ = "0.14.2" +__version__ = "0.16.0" diff --git a/starlette/applications.py b/starlette/applications.py index 34c3e38bd9..ea52ee70ee 100644 --- a/starlette/applications.py +++ b/starlette/applications.py @@ -46,7 +46,7 @@ def __init__( ] = None, on_startup: typing.Sequence[typing.Callable] = None, on_shutdown: typing.Sequence[typing.Callable] = None, - lifespan: typing.Callable[["Starlette"], typing.AsyncGenerator] = None, + lifespan: typing.Callable[["Starlette"], typing.AsyncContextManager] = None, ) -> None: # The lifespan context function is a newer style that replaces # on_startup / on_shutdown handlers. Use one or the other, not both. diff --git a/starlette/concurrency.py b/starlette/concurrency.py index c8c5d57acb..e89d1e0471 100644 --- a/starlette/concurrency.py +++ b/starlette/concurrency.py @@ -1,33 +1,32 @@ -import asyncio import functools -import sys import typing from typing import Any, AsyncGenerator, Iterator +import anyio + try: import contextvars # Python 3.7+ only or via contextvars backport. except ImportError: # pragma: no cover contextvars = None # type: ignore -if sys.version_info >= (3, 7): # pragma: no cover - from asyncio import create_task -else: # pragma: no cover - from asyncio import ensure_future as create_task T = typing.TypeVar("T") async def run_until_first_complete(*args: typing.Tuple[typing.Callable, dict]) -> None: - tasks = [create_task(handler(**kwargs)) for handler, kwargs in args] - (done, pending) = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) - [task.cancel() for task in pending] - [task.result() for task in done] + async with anyio.create_task_group() as task_group: + + async def run(func: typing.Callable[[], typing.Coroutine]) -> None: + await func() + task_group.cancel_scope.cancel() + + for func, kwargs in args: + task_group.start_soon(run, functools.partial(func, **kwargs)) async def run_in_threadpool( func: typing.Callable[..., T], *args: typing.Any, **kwargs: typing.Any ) -> T: - loop = asyncio.get_event_loop() if contextvars is not None: # pragma: no cover # Ensure we run in the same context child = functools.partial(func, *args, **kwargs) @@ -35,9 +34,9 @@ async def run_in_threadpool( func = context.run args = (child,) elif kwargs: # pragma: no cover - # loop.run_in_executor doesn't accept 'kwargs', so bind them in here + # run_sync doesn't accept 'kwargs', so bind them in here func = functools.partial(func, **kwargs) - return await loop.run_in_executor(None, func, *args) + return await anyio.to_thread.run_sync(func, *args) class _StopIteration(Exception): @@ -57,6 +56,6 @@ def _next(iterator: Iterator) -> Any: async def iterate_in_threadpool(iterator: Iterator) -> AsyncGenerator: while True: try: - yield await run_in_threadpool(_next, iterator) + yield await anyio.to_thread.run_sync(_next, iterator) except _StopIteration: break diff --git a/starlette/config.py b/starlette/config.py index e9894e0773..7444ae06d2 100644 --- a/starlette/config.py +++ b/starlette/config.py @@ -46,6 +46,8 @@ def __len__(self) -> int: environ = Environ() +T = typing.TypeVar("T") + class Config: def __init__( @@ -58,6 +60,24 @@ def __init__( if env_file is not None and os.path.isfile(env_file): self.file_values = self._read_file(env_file) + @typing.overload + def __call__( + self, key: str, cast: typing.Type[T], default: T = ... + ) -> T: # pragma: no cover + ... + + @typing.overload + def __call__( + self, key: str, cast: typing.Type[str] = ..., default: str = ... + ) -> str: # pragma: no cover + ... + + @typing.overload + def __call__( + self, key: str, cast: typing.Type[str] = ..., default: T = ... + ) -> typing.Union[T, str]: # pragma: no cover + ... + def __call__( self, key: str, cast: typing.Callable = None, default: typing.Any = undefined ) -> typing.Any: diff --git a/starlette/datastructures.py b/starlette/datastructures.py index 5149a6e2e5..17dc46eb6a 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -266,7 +266,7 @@ def __init__( self._dict = {k: v for k, v in _items} self._list = _items - def getlist(self, key: typing.Any) -> typing.List[str]: + def getlist(self, key: typing.Any) -> typing.List[typing.Any]: return [item_value for item_key, item_value in self._list if item_key == key] def keys(self) -> typing.KeysView: diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index b347a6a2dd..77ba669251 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -1,9 +1,10 @@ -import asyncio import typing +import anyio + from starlette.requests import Request from starlette.responses import Response, StreamingResponse -from starlette.types import ASGIApp, Message, Receive, Scope, Send +from starlette.types import ASGIApp, Receive, Scope, Send RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]] DispatchFunction = typing.Callable[ @@ -21,45 +22,39 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await self.app(scope, receive, send) return - request = Request(scope, receive=receive) - response = await self.dispatch_func(request, self.call_next) - await response(scope, receive, send) + async def call_next(request: Request) -> Response: + send_stream, recv_stream = anyio.create_memory_object_stream() - async def call_next(self, request: Request) -> Response: - loop = asyncio.get_event_loop() - queue: "asyncio.Queue[typing.Optional[Message]]" = asyncio.Queue() + async def coro() -> None: + async with send_stream: + await self.app(scope, request.receive, send_stream.send) - scope = request.scope - receive = request.receive - send = queue.put + task_group.start_soon(coro) - async def coro() -> None: try: - await self.app(scope, receive, send) - finally: - await queue.put(None) - - task = loop.create_task(coro()) - message = await queue.get() - if message is None: - task.result() - raise RuntimeError("No response returned.") - assert message["type"] == "http.response.start" - - async def body_stream() -> typing.AsyncGenerator[bytes, None]: - while True: - message = await queue.get() - if message is None: - break - assert message["type"] == "http.response.body" - yield message.get("body", b"") - task.result() - - response = StreamingResponse( - status_code=message["status"], content=body_stream() - ) - response.raw_headers = message["headers"] - return response + message = await recv_stream.receive() + except anyio.EndOfStream: + raise RuntimeError("No response returned.") + + assert message["type"] == "http.response.start" + + async def body_stream() -> typing.AsyncGenerator[bytes, None]: + async with recv_stream: + async for message in recv_stream: + assert message["type"] == "http.response.body" + yield message.get("body", b"") + + response = StreamingResponse( + status_code=message["status"], content=body_stream() + ) + response.raw_headers = message["headers"] + return response + + async with anyio.create_task_group() as task_group: + request = Request(scope, receive=receive) + response = await self.dispatch_func(request, call_next) + await response(scope, receive, send) + task_group.cancel_scope.cancel() async def dispatch( self, request: Request, call_next: RequestResponseEndpoint diff --git a/starlette/middleware/cors.py b/starlette/middleware/cors.py index 0b3f505e71..c850579c80 100644 --- a/starlette/middleware/cors.py +++ b/starlette/middleware/cors.py @@ -129,6 +129,7 @@ def preflight_response(self, request_headers: Headers) -> Response: for header in [h.lower() for h in requested_headers.split(",")]: if header.strip() not in self.allow_headers: failures.append("headers") + break # We don't strictly need to use 400 responses here, since its up to # the browser to enforce the CORS policy, but its more informative diff --git a/starlette/middleware/sessions.py b/starlette/middleware/sessions.py index a13ec5c0ed..ad7a6ee899 100644 --- a/starlette/middleware/sessions.py +++ b/starlette/middleware/sessions.py @@ -3,7 +3,7 @@ from base64 import b64decode, b64encode import itsdangerous -from itsdangerous.exc import BadTimeSignature, SignatureExpired +from itsdangerous.exc import BadSignature from starlette.datastructures import MutableHeaders, Secret from starlette.requests import HTTPConnection @@ -42,7 +42,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: data = self.signer.unsign(data, max_age=self.max_age) scope["session"] = json.loads(b64decode(data)) initial_session_was_empty = False - except (BadTimeSignature, SignatureExpired): + except BadSignature: scope["session"] = {} else: scope["session"] = {} diff --git a/starlette/middleware/wsgi.py b/starlette/middleware/wsgi.py index 6b30610bc6..7e69e1a6b2 100644 --- a/starlette/middleware/wsgi.py +++ b/starlette/middleware/wsgi.py @@ -1,10 +1,11 @@ -import asyncio import io +import math import sys import typing -from starlette.concurrency import run_in_threadpool -from starlette.types import Message, Receive, Scope, Send +import anyio + +from starlette.types import Receive, Scope, Send def build_environ(scope: Scope, body: bytes) -> dict: @@ -54,7 +55,7 @@ def build_environ(scope: Scope, body: bytes) -> dict: class WSGIMiddleware: - def __init__(self, app: typing.Callable, workers: int = 10) -> None: + def __init__(self, app: typing.Callable) -> None: self.app = app async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: @@ -69,9 +70,9 @@ def __init__(self, app: typing.Callable, scope: Scope) -> None: self.scope = scope self.status = None self.response_headers = None - self.send_event = asyncio.Event() - self.send_queue: typing.List[typing.Optional[Message]] = [] - self.loop = asyncio.get_event_loop() + self.stream_send, self.stream_receive = anyio.create_memory_object_stream( + math.inf + ) self.response_started = False self.exc_info: typing.Any = None @@ -83,31 +84,18 @@ async def __call__(self, receive: Receive, send: Send) -> None: body += message.get("body", b"") more_body = message.get("more_body", False) environ = build_environ(self.scope, body) - sender = None - try: - sender = self.loop.create_task(self.sender(send)) - await run_in_threadpool(self.wsgi, environ, self.start_response) - self.send_queue.append(None) - self.send_event.set() - await asyncio.wait_for(sender, None) - if self.exc_info is not None: - raise self.exc_info[0].with_traceback( - self.exc_info[1], self.exc_info[2] - ) - finally: - if sender and not sender.done(): - sender.cancel() # pragma: no cover + + async with anyio.create_task_group() as task_group: + task_group.start_soon(self.sender, send) + async with self.stream_send: + await anyio.to_thread.run_sync(self.wsgi, environ, self.start_response) + if self.exc_info is not None: + raise self.exc_info[0].with_traceback(self.exc_info[1], self.exc_info[2]) async def sender(self, send: Send) -> None: - while True: - if self.send_queue: - message = self.send_queue.pop(0) - if message is None: - return + async with self.stream_receive: + async for message in self.stream_receive: await send(message) - else: - await self.send_event.wait() - self.send_event.clear() def start_response( self, @@ -124,21 +112,22 @@ def start_response( (name.strip().encode("ascii").lower(), value.strip().encode("ascii")) for name, value in response_headers ] - self.send_queue.append( + anyio.from_thread.run( + self.stream_send.send, { "type": "http.response.start", "status": status_code, "headers": headers, - } + }, ) - self.loop.call_soon_threadsafe(self.send_event.set) def wsgi(self, environ: dict, start_response: typing.Callable) -> None: for chunk in self.app(environ, start_response): - self.send_queue.append( - {"type": "http.response.body", "body": chunk, "more_body": True} + anyio.from_thread.run( + self.stream_send.send, + {"type": "http.response.body", "body": chunk, "more_body": True}, ) - self.loop.call_soon_threadsafe(self.send_event.set) - self.send_queue.append({"type": "http.response.body", "body": b""}) - self.loop.call_soon_threadsafe(self.send_event.set) + anyio.from_thread.run( + self.stream_send.send, {"type": "http.response.body", "body": b""} + ) diff --git a/starlette/requests.py b/starlette/requests.py index ab6f51424b..676f4e9aa3 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -1,9 +1,10 @@ -import asyncio import json import typing from collections.abc import Mapping from http import cookies as http_cookies +import anyio + from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State from starlette.formparsers import FormParser, MultiPartParser from starlette.types import Message, Receive, Scope, Send @@ -64,7 +65,7 @@ def __init__(self, scope: Scope, receive: Receive = None) -> None: assert scope["type"] in ("http", "websocket") self.scope = scope - def __getitem__(self, key: str) -> str: + def __getitem__(self, key: str) -> typing.Any: return self.scope[key] def __iter__(self) -> typing.Iterator[str]: @@ -73,6 +74,12 @@ def __iter__(self) -> typing.Iterator[str]: def __len__(self) -> int: return len(self.scope) + # Don't use the `abc.Mapping.__eq__` implementation. + # Connection instances should never be considered equal + # unless `self is other`. + __eq__ = object.__eq__ + __hash__ = object.__hash__ + @property def app(self) -> typing.Any: return self.scope["app"] @@ -251,10 +258,12 @@ async def close(self) -> None: async def is_disconnected(self) -> bool: if not self._is_disconnected: - try: - message = await asyncio.wait_for(self._receive(), timeout=0.0000001) - except asyncio.TimeoutError: - message = {} + message: Message = {} + + # If message isn't immediately available, move on + with anyio.CancelScope() as cs: + cs.cancel() + message = await self._receive() if message.get("type") == "http.disconnect": self._is_disconnected = True diff --git a/starlette/responses.py b/starlette/responses.py index 00f6be4dbc..d03df23294 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -6,24 +6,20 @@ import sys import typing from email.utils import formatdate +from functools import partial from mimetypes import guess_type as mimetypes_guess_type from urllib.parse import quote +import anyio + from starlette.background import BackgroundTask -from starlette.concurrency import iterate_in_threadpool, run_until_first_complete +from starlette.concurrency import iterate_in_threadpool from starlette.datastructures import URL, MutableHeaders from starlette.types import Receive, Scope, Send # Workaround for adding samesite support to pre 3.8 python http.cookies.Morsel._reserved["samesite"] = "SameSite" # type: ignore -try: - import aiofiles - from aiofiles.os import stat as aio_stat -except ImportError: # pragma: nocover - aiofiles = None # type: ignore - aio_stat = None # type: ignore - # Compatibility wrapper for `mimetypes.guess_type` to support `os.PathLike` on None: await send({"type": "http.response.body", "body": b"", "more_body": False}) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - await run_until_first_complete( - (self.stream_response, {"send": send}), - (self.listen_for_disconnect, {"receive": receive}), - ) + async with anyio.create_task_group() as task_group: + + async def wrap(func: typing.Callable[[], typing.Coroutine]) -> None: + await func() + task_group.cancel_scope.cancel() + + task_group.start_soon(wrap, partial(self.stream_response, send)) + await wrap(partial(self.listen_for_disconnect, receive)) if self.background is not None: await self.background() @@ -244,7 +244,6 @@ def __init__( stat_result: os.stat_result = None, method: str = None, ) -> None: - assert aiofiles is not None, "'aiofiles' must be installed to use FileResponse" self.path = path self.status_code = status_code self.filename = filename @@ -280,7 +279,7 @@ def set_stat_headers(self, stat_result: os.stat_result) -> None: async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if self.stat_result is None: try: - stat_result = await aio_stat(self.path) + stat_result = await anyio.to_thread.run_sync(os.stat, self.path) self.set_stat_headers(stat_result) except FileNotFoundError: raise RuntimeError(f"File at path {self.path} does not exist.") @@ -298,10 +297,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if self.send_header_only: await send({"type": "http.response.body", "body": b"", "more_body": False}) else: - # Tentatively ignoring type checking failure to work around the wrong type - # definitions for aiofile that come with typeshed. See - # https://github.com/python/typeshed/pull/4650 - async with aiofiles.open(self.path, mode="rb") as file: # type: ignore + async with await anyio.open_file(self.path, mode="rb") as file: more_body = True while more_body: chunk = await file.read(self.chunk_size) diff --git a/starlette/routing.py b/starlette/routing.py index cef1ef4848..9a1a5e12df 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -1,9 +1,13 @@ import asyncio +import contextlib import functools import inspect import re +import sys import traceback +import types import typing +import warnings from enum import Enum from starlette.concurrency import run_in_threadpool @@ -15,6 +19,11 @@ from starlette.types import ASGIApp, Receive, Scope, Send from starlette.websockets import WebSocket, WebSocketClose +if sys.version_info >= (3, 7): + from contextlib import asynccontextmanager # pragma: no cover +else: + from contextlib2 import asynccontextmanager # pragma: no cover + class NoMatchFound(Exception): """ @@ -470,6 +479,51 @@ def __eq__(self, other: typing.Any) -> bool: ) +_T = typing.TypeVar("_T") + + +class _AsyncLiftContextManager(typing.AsyncContextManager[_T]): + def __init__(self, cm: typing.ContextManager[_T]): + self._cm = cm + + async def __aenter__(self) -> _T: + return self._cm.__enter__() + + async def __aexit__( + self, + exc_type: typing.Optional[typing.Type[BaseException]], + exc_value: typing.Optional[BaseException], + traceback: typing.Optional[types.TracebackType], + ) -> typing.Optional[bool]: + return self._cm.__exit__(exc_type, exc_value, traceback) + + +def _wrap_gen_lifespan_context( + lifespan_context: typing.Callable[[typing.Any], typing.Generator] +) -> typing.Callable[[typing.Any], typing.AsyncContextManager]: + cmgr = contextlib.contextmanager(lifespan_context) + + @functools.wraps(cmgr) + def wrapper(app: typing.Any) -> _AsyncLiftContextManager: + return _AsyncLiftContextManager(cmgr(app)) + + return wrapper + + +class _DefaultLifespan: + def __init__(self, router: "Router"): + self._router = router + + async def __aenter__(self) -> None: + await self._router.startup() + + async def __aexit__(self, *exc_info: object) -> None: + await self._router.shutdown() + + def __call__(self: _T, app: object) -> _T: + return self + + class Router: def __init__( self, @@ -478,7 +532,7 @@ def __init__( default: ASGIApp = None, on_startup: typing.Sequence[typing.Callable] = None, on_shutdown: typing.Sequence[typing.Callable] = None, - lifespan: typing.Callable[[typing.Any], typing.AsyncGenerator] = None, + lifespan: typing.Callable[[typing.Any], typing.AsyncContextManager] = None, ) -> None: self.routes = [] if routes is None else list(routes) self.redirect_slashes = redirect_slashes @@ -486,12 +540,31 @@ def __init__( self.on_startup = [] if on_startup is None else list(on_startup) self.on_shutdown = [] if on_shutdown is None else list(on_shutdown) - async def default_lifespan(app: typing.Any) -> typing.AsyncGenerator: - await self.startup() - yield - await self.shutdown() + if lifespan is None: + self.lifespan_context: typing.Callable[ + [typing.Any], typing.AsyncContextManager + ] = _DefaultLifespan(self) - self.lifespan_context = default_lifespan if lifespan is None else lifespan + elif inspect.isasyncgenfunction(lifespan): + warnings.warn( + "async generator function lifespans are deprecated, " + "use an @contextlib.asynccontextmanager function instead", + DeprecationWarning, + ) + self.lifespan_context = asynccontextmanager( + lifespan, # type: ignore[arg-type] + ) + elif inspect.isgeneratorfunction(lifespan): + warnings.warn( + "generator function lifespans are deprecated, " + "use an @contextlib.asynccontextmanager function instead", + DeprecationWarning, + ) + self.lifespan_context = _wrap_gen_lifespan_context( + lifespan, # type: ignore[arg-type] + ) + else: + self.lifespan_context = lifespan async def not_found(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] == "websocket": @@ -541,25 +614,19 @@ async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None: Handle ASGI lifespan messages, which allows us to manage application startup and shutdown events. """ - first = True + started = False app = scope.get("app") await receive() try: - if inspect.isasyncgenfunction(self.lifespan_context): - async for item in self.lifespan_context(app): - assert first, "Lifespan context yielded multiple times." - first = False - await send({"type": "lifespan.startup.complete"}) - await receive() - else: - for item in self.lifespan_context(app): # type: ignore - assert first, "Lifespan context yielded multiple times." - first = False - await send({"type": "lifespan.startup.complete"}) - await receive() + async with self.lifespan_context(app): + await send({"type": "lifespan.startup.complete"}) + started = True + await receive() except BaseException: - if first: - exc_text = traceback.format_exc() + exc_text = traceback.format_exc() + if started: + await send({"type": "lifespan.shutdown.failed", "message": exc_text}) + else: await send({"type": "lifespan.startup.failed", "message": exc_text}) raise else: diff --git a/starlette/staticfiles.py b/starlette/staticfiles.py index 15a67fe35d..33ea0b0337 100644 --- a/starlette/staticfiles.py +++ b/starlette/staticfiles.py @@ -4,7 +4,7 @@ import typing from email.utils import parsedate -from aiofiles.os import stat as aio_stat +import anyio from starlette.datastructures import URL, Headers from starlette.responses import ( @@ -154,7 +154,7 @@ async def lookup_path( # directory. continue try: - stat_result = await aio_stat(full_path) + stat_result = await anyio.to_thread.run_sync(os.stat, full_path) return full_path, stat_result except FileNotFoundError: pass @@ -187,7 +187,7 @@ async def check_config(self) -> None: return try: - stat_result = await aio_stat(self.directory) + stat_result = await anyio.to_thread.run_sync(os.stat, self.directory) except FileNotFoundError: raise RuntimeError( f"StaticFiles directory '{self.directory}' does not exist." diff --git a/starlette/templating.py b/starlette/templating.py index 64fdbd14f6..18d5eb40c0 100644 --- a/starlette/templating.py +++ b/starlette/templating.py @@ -1,4 +1,5 @@ import typing +from os import PathLike from starlette.background import BackgroundTask from starlette.responses import Response @@ -54,11 +55,13 @@ class Jinja2Templates: return templates.TemplateResponse("index.html", {"request": request}) """ - def __init__(self, directory: str) -> None: + def __init__(self, directory: typing.Union[str, PathLike]) -> None: assert jinja2 is not None, "jinja2 must be installed to use Jinja2Templates" - self.env = self.get_env(directory) + self.env = self._create_env(directory) - def get_env(self, directory: str) -> "jinja2.Environment": + def _create_env( + self, directory: typing.Union[str, PathLike] + ) -> "jinja2.Environment": @pass_context def url_for(context: dict, name: str, **path_params: typing.Any) -> str: request = context["request"] diff --git a/starlette/testclient.py b/starlette/testclient.py index 77c038b17f..08d03fa5c4 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -1,19 +1,35 @@ import asyncio +import contextlib import http import inspect import io import json +import math import queue -import threading +import sys import types import typing +from concurrent.futures import Future from urllib.parse import unquote, urljoin, urlsplit +import anyio.abc import requests +from anyio.streams.stapled import StapledObjectStream from starlette.types import Message, Receive, Scope, Send from starlette.websockets import WebSocketDisconnect +if sys.version_info >= (3, 8): # pragma: no cover + from typing import TypedDict +else: # pragma: no cover + from typing_extensions import TypedDict + + +_PortalFactoryType = typing.Callable[ + [], typing.ContextManager[anyio.abc.BlockingPortal] +] + + # Annotations for `Session.request()` Cookies = typing.Union[ typing.MutableMapping[str, str], requests.cookies.RequestsCookieJar @@ -87,13 +103,23 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await instance(receive, send) +class _AsyncBackend(TypedDict): + backend: str + backend_options: typing.Dict[str, typing.Any] + + class _ASGIAdapter(requests.adapters.HTTPAdapter): def __init__( - self, app: ASGI3App, raise_server_exceptions: bool = True, root_path: str = "" + self, + app: ASGI3App, + portal_factory: _PortalFactoryType, + raise_server_exceptions: bool = True, + root_path: str = "", ) -> None: self.app = app self.raise_server_exceptions = raise_server_exceptions self.root_path = root_path + self.portal_factory = portal_factory def send( self, request: requests.PreparedRequest, *args: typing.Any, **kwargs: typing.Any @@ -142,7 +168,7 @@ def send( "server": [host, port], "subprotocols": subprotocols, } - session = WebSocketTestSession(self.app, scope) + session = WebSocketTestSession(self.app, scope, self.portal_factory) raise _Upgrade(session) scope = { @@ -161,17 +187,17 @@ def send( request_complete = False response_started = False - response_complete = False + response_complete: anyio.Event raw_kwargs: typing.Dict[str, typing.Any] = {"body": io.BytesIO()} template = None context = None async def receive() -> Message: - nonlocal request_complete, response_complete + nonlocal request_complete if request_complete: - while not response_complete: - await asyncio.sleep(0.0001) + if not response_complete.is_set(): + await response_complete.wait() return {"type": "http.disconnect"} body = request.body @@ -195,7 +221,7 @@ async def receive() -> Message: return {"type": "http.request", "body": body_bytes} async def send(message: Message) -> None: - nonlocal raw_kwargs, response_started, response_complete, template, context + nonlocal raw_kwargs, response_started, template, context if message["type"] == "http.response.start": assert ( @@ -205,7 +231,8 @@ async def send(message: Message) -> None: raw_kwargs["status"] = message["status"] raw_kwargs["reason"] = _get_reason_phrase(message["status"]) raw_kwargs["headers"] = [ - (key.decode(), value.decode()) for key, value in message["headers"] + (key.decode(), value.decode()) + for key, value in message.get("headers", []) ] raw_kwargs["preload_content"] = False raw_kwargs["original_response"] = _MockOriginalResponse( @@ -217,7 +244,7 @@ async def send(message: Message) -> None: response_started ), 'Received "http.response.body" without "http.response.start".' assert ( - not response_complete + not response_complete.is_set() ), 'Received "http.response.body" after response completed.' body = message.get("body", b"") more_body = message.get("more_body", False) @@ -225,19 +252,15 @@ async def send(message: Message) -> None: raw_kwargs["body"].write(body) if not more_body: raw_kwargs["body"].seek(0) - response_complete = True + response_complete.set() elif message["type"] == "http.response.template": template = message["template"] context = message["context"] try: - loop = asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete(self.app(scope, receive, send)) + with self.portal_factory() as portal: + response_complete = portal.call(anyio.Event) + portal.call(self.app, scope, receive, send) except BaseException as exc: if self.raise_server_exceptions: raise exc @@ -264,48 +287,60 @@ async def send(message: Message) -> None: class WebSocketTestSession: - def __init__(self, app: ASGI3App, scope: Scope) -> None: + def __init__( + self, + app: ASGI3App, + scope: Scope, + portal_factory: _PortalFactoryType, + ) -> None: self.app = app self.scope = scope self.accepted_subprotocol = None + self.portal_factory = portal_factory self._receive_queue: "queue.Queue[typing.Any]" = queue.Queue() self._send_queue: "queue.Queue[typing.Any]" = queue.Queue() - self._thread = threading.Thread(target=self._run) - self.send({"type": "websocket.connect"}) - self._thread.start() - message = self.receive() - self._raise_on_close(message) - self.accepted_subprotocol = message.get("subprotocol", None) def __enter__(self) -> "WebSocketTestSession": + self.exit_stack = contextlib.ExitStack() + self.portal = self.exit_stack.enter_context(self.portal_factory()) + + try: + _: "Future[None]" = self.portal.start_task_soon(self._run) + self.send({"type": "websocket.connect"}) + message = self.receive() + self._raise_on_close(message) + except Exception: + self.exit_stack.close() + raise + self.accepted_subprotocol = message.get("subprotocol", None) return self def __exit__(self, *args: typing.Any) -> None: - self.close(1000) - self._thread.join() + try: + self.close(1000) + finally: + self.exit_stack.close() while not self._send_queue.empty(): message = self._send_queue.get() if isinstance(message, BaseException): raise message - def _run(self) -> None: + async def _run(self) -> None: """ The sub-thread in which the websocket session runs. """ - loop = asyncio.new_event_loop() scope = self.scope receive = self._asgi_receive send = self._asgi_send try: - loop.run_until_complete(self.app(scope, receive, send)) + await self.app(scope, receive, send) except BaseException as exc: self._send_queue.put(exc) - finally: - loop.close() + raise async def _asgi_receive(self) -> Message: while self._receive_queue.empty(): - await asyncio.sleep(0) + await anyio.sleep(0) return self._receive_queue.get() async def _asgi_send(self, message: Message) -> None: @@ -364,6 +399,8 @@ def receive_json(self, mode: str = "text") -> typing.Any: class TestClient(requests.Session): __test__ = False # For pytest to not discover this up. + task: "Future[None]" + portal: typing.Optional[anyio.abc.BlockingPortal] = None def __init__( self, @@ -371,8 +408,13 @@ def __init__( base_url: str = "http://testserver", raise_server_exceptions: bool = True, root_path: str = "", + backend: str = "asyncio", + backend_options: typing.Optional[typing.Dict[str, typing.Any]] = None, ) -> None: super().__init__() + self.async_backend = _AsyncBackend( + backend=backend, backend_options=backend_options or {} + ) if _is_asgi3(app): app = typing.cast(ASGI3App, app) asgi_app = app @@ -381,6 +423,7 @@ def __init__( asgi_app = _WrapASGI2(app) #  type: ignore adapter = _ASGIAdapter( asgi_app, + portal_factory=self._portal_factory, raise_server_exceptions=raise_server_exceptions, root_path=root_path, ) @@ -392,6 +435,16 @@ def __init__( self.app = asgi_app self.base_url = base_url + @contextlib.contextmanager + def _portal_factory( + self, + ) -> typing.Generator[anyio.abc.BlockingPortal, None, None]: + if self.portal is not None: + yield self.portal + else: + with anyio.start_blocking_portal(**self.async_backend) as portal: + yield portal + def request( # type: ignore self, method: str, @@ -452,42 +505,72 @@ def websocket_connect( return session def __enter__(self) -> "TestClient": - loop = asyncio.get_event_loop() - self.send_queue: "asyncio.Queue[typing.Any]" = asyncio.Queue() - self.receive_queue: "asyncio.Queue[typing.Any]" = asyncio.Queue() - self.task = loop.create_task(self.lifespan()) - loop.run_until_complete(self.wait_startup()) + with contextlib.ExitStack() as stack: + self.portal = portal = stack.enter_context( + anyio.start_blocking_portal(**self.async_backend) + ) + + @stack.callback + def reset_portal() -> None: + self.portal = None + + self.stream_send = StapledObjectStream( + *anyio.create_memory_object_stream(math.inf) + ) + self.stream_receive = StapledObjectStream( + *anyio.create_memory_object_stream(math.inf) + ) + self.task = portal.start_task_soon(self.lifespan) + portal.call(self.wait_startup) + + @stack.callback + def wait_shutdown() -> None: + portal.call(self.wait_shutdown) + + self.exit_stack = stack.pop_all() + return self def __exit__(self, *args: typing.Any) -> None: - loop = asyncio.get_event_loop() - loop.run_until_complete(self.wait_shutdown()) + self.exit_stack.close() async def lifespan(self) -> None: scope = {"type": "lifespan"} try: - await self.app(scope, self.receive_queue.get, self.send_queue.put) + await self.app(scope, self.stream_receive.receive, self.stream_send.send) finally: - await self.send_queue.put(None) + await self.stream_send.send(None) async def wait_startup(self) -> None: - await self.receive_queue.put({"type": "lifespan.startup"}) - message = await self.send_queue.get() - if message is None: - self.task.result() + await self.stream_receive.send({"type": "lifespan.startup"}) + + async def receive() -> typing.Any: + message = await self.stream_send.receive() + if message is None: + self.task.result() + return message + + message = await receive() assert message["type"] in ( "lifespan.startup.complete", "lifespan.startup.failed", ) if message["type"] == "lifespan.startup.failed": - message = await self.send_queue.get() - if message is None: - self.task.result() + await receive() async def wait_shutdown(self) -> None: - await self.receive_queue.put({"type": "lifespan.shutdown"}) - message = await self.send_queue.get() - if message is None: - self.task.result() - assert message["type"] == "lifespan.shutdown.complete" - await self.task + async def receive() -> typing.Any: + message = await self.stream_send.receive() + if message is None: + self.task.result() + return message + + async with self.stream_send: + await self.stream_receive.send({"type": "lifespan.shutdown"}) + message = await receive() + assert message["type"] in ( + "lifespan.shutdown.complete", + "lifespan.shutdown.failed", + ) + if message["type"] == "lifespan.shutdown.failed": + await receive() diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000000..8b9872aebe --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,23 @@ +import functools +import sys + +import pytest + +from starlette.testclient import TestClient + + +@pytest.fixture +def no_trio_support(anyio_backend_name): + if anyio_backend_name == "trio": + pytest.skip("Trio not supported (yet!)") + + +@pytest.fixture +def test_client_factory(anyio_backend_name, anyio_backend_options): + # anyio_backend_name defined by: + # https://anyio.readthedocs.io/en/stable/testing.html#specifying-the-backends-to-run-on + return functools.partial( + TestClient, + backend=anyio_backend_name, + backend_options=anyio_backend_options, + ) diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 048dd9ffb9..8a8df4ea66 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -5,7 +5,6 @@ from starlette.middleware.base import BaseHTTPMiddleware from starlette.responses import PlainTextResponse from starlette.routing import Route -from starlette.testclient import TestClient class CustomMiddleware(BaseHTTPMiddleware): @@ -48,8 +47,8 @@ async def websocket_endpoint(session): await session.close() -def test_custom_middleware(): - client = TestClient(app) +def test_custom_middleware(test_client_factory): + client = test_client_factory(app) response = client.get("/") assert response.headers["Custom-Header"] == "Example" @@ -64,7 +63,7 @@ def test_custom_middleware(): assert text == "Hello, world!" -def test_middleware_decorator(): +def test_middleware_decorator(test_client_factory): app = Starlette() @app.route("/homepage") @@ -79,7 +78,7 @@ async def plaintext(request, call_next): response.headers["Custom"] = "Example" return response - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.text == "OK" @@ -88,7 +87,7 @@ async def plaintext(request, call_next): assert response.headers["Custom"] == "Example" -def test_state_data_across_multiple_middlewares(): +def test_state_data_across_multiple_middlewares(test_client_factory): expected_value1 = "foo" expected_value2 = "bar" @@ -120,14 +119,14 @@ async def dispatch(self, request, call_next): def homepage(request): return PlainTextResponse("OK") - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.text == "OK" assert response.headers["X-State-Foo"] == expected_value1 assert response.headers["X-State-Bar"] == expected_value2 -def test_app_middleware_argument(): +def test_app_middleware_argument(test_client_factory): def homepage(request): return PlainTextResponse("Homepage") @@ -135,7 +134,7 @@ def homepage(request): routes=[Route("/", homepage)], middleware=[Middleware(CustomMiddleware)] ) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.headers["Custom-Header"] == "Example" @@ -143,3 +142,18 @@ def homepage(request): def test_middleware_repr(): middleware = Middleware(CustomMiddleware) assert repr(middleware) == "Middleware(CustomMiddleware)" + + +def test_fully_evaluated_response(test_client_factory): + # Test for https://github.com/encode/starlette/issues/1022 + class CustomMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request, call_next): + await call_next(request) + return PlainTextResponse("Custom") + + app = Starlette() + app.add_middleware(CustomMiddleware) + + client = test_client_factory(app) + response = client.get("/does_not_exist") + assert response.text == "Custom" diff --git a/tests/middleware/test_cors.py b/tests/middleware/test_cors.py index 7a250a2416..65252e5024 100644 --- a/tests/middleware/test_cors.py +++ b/tests/middleware/test_cors.py @@ -1,10 +1,9 @@ from starlette.applications import Starlette from starlette.middleware.cors import CORSMiddleware from starlette.responses import PlainTextResponse -from starlette.testclient import TestClient -def test_cors_allow_all(): +def test_cors_allow_all(test_client_factory): app = Starlette() app.add_middleware( @@ -20,7 +19,7 @@ def test_cors_allow_all(): def homepage(request): return PlainTextResponse("Homepage", status_code=200) - client = TestClient(app) + client = test_client_factory(app) # Test pre-flight response headers = { @@ -61,7 +60,7 @@ def homepage(request): assert "access-control-allow-origin" not in response.headers -def test_cors_allow_all_except_credentials(): +def test_cors_allow_all_except_credentials(test_client_factory): app = Starlette() app.add_middleware( @@ -76,7 +75,7 @@ def test_cors_allow_all_except_credentials(): def homepage(request): return PlainTextResponse("Homepage", status_code=200) - client = TestClient(app) + client = test_client_factory(app) # Test pre-flight response headers = { @@ -108,7 +107,7 @@ def homepage(request): assert "access-control-allow-origin" not in response.headers -def test_cors_allow_specific_origin(): +def test_cors_allow_specific_origin(test_client_factory): app = Starlette() app.add_middleware( @@ -121,7 +120,7 @@ def test_cors_allow_specific_origin(): def homepage(request): return PlainTextResponse("Homepage", status_code=200) - client = TestClient(app) + client = test_client_factory(app) # Test pre-flight response headers = { @@ -153,7 +152,7 @@ def homepage(request): assert "access-control-allow-origin" not in response.headers -def test_cors_disallowed_preflight(): +def test_cors_disallowed_preflight(test_client_factory): app = Starlette() app.add_middleware( @@ -166,7 +165,7 @@ def test_cors_disallowed_preflight(): def homepage(request): pass # pragma: no cover - client = TestClient(app) + client = test_client_factory(app) # Test pre-flight response headers = { @@ -179,8 +178,20 @@ def homepage(request): assert response.text == "Disallowed CORS origin, method, headers" assert "access-control-allow-origin" not in response.headers + # Bug specific test, https://github.com/encode/starlette/pull/1199 + # Test preflight response text with multiple disallowed headers + headers = { + "Origin": "https://example.org", + "Access-Control-Request-Method": "GET", + "Access-Control-Request-Headers": "X-Nope-1, X-Nope-2", + } + response = client.options("/", headers=headers) + assert response.text == "Disallowed CORS headers" + -def test_preflight_allows_request_origin_if_origins_wildcard_and_credentials_allowed(): +def test_preflight_allows_request_origin_if_origins_wildcard_and_credentials_allowed( + test_client_factory, +): app = Starlette() app.add_middleware( @@ -194,7 +205,7 @@ def test_preflight_allows_request_origin_if_origins_wildcard_and_credentials_all def homepage(request): return # pragma: no cover - client = TestClient(app) + client = test_client_factory(app) # Test pre-flight response headers = { @@ -211,7 +222,7 @@ def homepage(request): assert response.headers["vary"] == "Origin" -def test_cors_preflight_allow_all_methods(): +def test_cors_preflight_allow_all_methods(test_client_factory): app = Starlette() app.add_middleware( @@ -224,7 +235,7 @@ def test_cors_preflight_allow_all_methods(): def homepage(request): pass # pragma: no cover - client = TestClient(app) + client = test_client_factory(app) headers = { "Origin": "https://example.org", @@ -237,7 +248,7 @@ def homepage(request): assert method in response.headers["access-control-allow-methods"] -def test_cors_allow_all_methods(): +def test_cors_allow_all_methods(test_client_factory): app = Starlette() app.add_middleware( @@ -252,7 +263,7 @@ def test_cors_allow_all_methods(): def homepage(request): return PlainTextResponse("Homepage", status_code=200) - client = TestClient(app) + client = test_client_factory(app) headers = {"Origin": "https://example.org"} @@ -261,7 +272,7 @@ def homepage(request): assert response.status_code == 200 -def test_cors_allow_origin_regex(): +def test_cors_allow_origin_regex(test_client_factory): app = Starlette() app.add_middleware( @@ -275,7 +286,7 @@ def test_cors_allow_origin_regex(): def homepage(request): return PlainTextResponse("Homepage", status_code=200) - client = TestClient(app) + client = test_client_factory(app) # Test standard response headers = {"Origin": "https://example.org"} @@ -329,7 +340,7 @@ def homepage(request): assert "access-control-allow-origin" not in response.headers -def test_cors_allow_origin_regex_fullmatch(): +def test_cors_allow_origin_regex_fullmatch(test_client_factory): app = Starlette() app.add_middleware( @@ -342,7 +353,7 @@ def test_cors_allow_origin_regex_fullmatch(): def homepage(request): return PlainTextResponse("Homepage", status_code=200) - client = TestClient(app) + client = test_client_factory(app) # Test standard response headers = {"Origin": "https://subdomain.example.org"} @@ -363,7 +374,7 @@ def homepage(request): assert "access-control-allow-origin" not in response.headers -def test_cors_credentialed_requests_return_specific_origin(): +def test_cors_credentialed_requests_return_specific_origin(test_client_factory): app = Starlette() app.add_middleware(CORSMiddleware, allow_origins=["*"]) @@ -372,7 +383,7 @@ def test_cors_credentialed_requests_return_specific_origin(): def homepage(request): return PlainTextResponse("Homepage", status_code=200) - client = TestClient(app) + client = test_client_factory(app) # Test credentialed request headers = {"Origin": "https://example.org", "Cookie": "star_cookie=sugar"} @@ -383,7 +394,7 @@ def homepage(request): assert "access-control-allow-credentials" not in response.headers -def test_cors_vary_header_defaults_to_origin(): +def test_cors_vary_header_defaults_to_origin(test_client_factory): app = Starlette() app.add_middleware(CORSMiddleware, allow_origins=["https://example.org"]) @@ -394,14 +405,14 @@ def test_cors_vary_header_defaults_to_origin(): def homepage(request): return PlainTextResponse("Homepage", status_code=200) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/", headers=headers) assert response.status_code == 200 assert response.headers["vary"] == "Origin" -def test_cors_vary_header_is_not_set_for_non_credentialed_request(): +def test_cors_vary_header_is_not_set_for_non_credentialed_request(test_client_factory): app = Starlette() app.add_middleware(CORSMiddleware, allow_origins=["*"]) @@ -412,14 +423,14 @@ def homepage(request): "Homepage", status_code=200, headers={"Vary": "Accept-Encoding"} ) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/", headers={"Origin": "https://someplace.org"}) assert response.status_code == 200 assert response.headers["vary"] == "Accept-Encoding" -def test_cors_vary_header_is_properly_set_for_credentialed_request(): +def test_cors_vary_header_is_properly_set_for_credentialed_request(test_client_factory): app = Starlette() app.add_middleware(CORSMiddleware, allow_origins=["*"]) @@ -430,7 +441,7 @@ def homepage(request): "Homepage", status_code=200, headers={"Vary": "Accept-Encoding"} ) - client = TestClient(app) + client = test_client_factory(app) response = client.get( "/", headers={"Cookie": "foo=bar", "Origin": "https://someplace.org"} @@ -439,7 +450,9 @@ def homepage(request): assert response.headers["vary"] == "Accept-Encoding, Origin" -def test_cors_vary_header_is_properly_set_when_allow_origins_is_not_wildcard(): +def test_cors_vary_header_is_properly_set_when_allow_origins_is_not_wildcard( + test_client_factory, +): app = Starlette() app.add_middleware(CORSMiddleware, allow_origins=["https://example.org"]) @@ -450,14 +463,16 @@ def homepage(request): "Homepage", status_code=200, headers={"Vary": "Accept-Encoding"} ) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/", headers={"Origin": "https://example.org"}) assert response.status_code == 200 assert response.headers["vary"] == "Accept-Encoding, Origin" -def test_cors_allowed_origin_does_not_leak_between_credentialed_requests(): +def test_cors_allowed_origin_does_not_leak_between_credentialed_requests( + test_client_factory, +): app = Starlette() app.add_middleware( @@ -468,7 +483,7 @@ def test_cors_allowed_origin_does_not_leak_between_credentialed_requests(): def homepage(request): return PlainTextResponse("Homepage", status_code=200) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/", headers={"Origin": "https://someplace.org"}) assert response.headers["access-control-allow-origin"] == "*" assert "access-control-allow-credentials" not in response.headers diff --git a/tests/middleware/test_errors.py b/tests/middleware/test_errors.py index c178ef9da2..2c926a9b2d 100644 --- a/tests/middleware/test_errors.py +++ b/tests/middleware/test_errors.py @@ -2,10 +2,9 @@ from starlette.middleware.errors import ServerErrorMiddleware from starlette.responses import JSONResponse, Response -from starlette.testclient import TestClient -def test_handler(): +def test_handler(test_client_factory): async def app(scope, receive, send): raise RuntimeError("Something went wrong") @@ -13,49 +12,49 @@ def error_500(request, exc): return JSONResponse({"detail": "Server Error"}, status_code=500) app = ServerErrorMiddleware(app, handler=error_500) - client = TestClient(app, raise_server_exceptions=False) + client = test_client_factory(app, raise_server_exceptions=False) response = client.get("/") assert response.status_code == 500 assert response.json() == {"detail": "Server Error"} -def test_debug_text(): +def test_debug_text(test_client_factory): async def app(scope, receive, send): raise RuntimeError("Something went wrong") app = ServerErrorMiddleware(app, debug=True) - client = TestClient(app, raise_server_exceptions=False) + client = test_client_factory(app, raise_server_exceptions=False) response = client.get("/") assert response.status_code == 500 assert response.headers["content-type"].startswith("text/plain") assert "RuntimeError: Something went wrong" in response.text -def test_debug_html(): +def test_debug_html(test_client_factory): async def app(scope, receive, send): raise RuntimeError("Something went wrong") app = ServerErrorMiddleware(app, debug=True) - client = TestClient(app, raise_server_exceptions=False) + client = test_client_factory(app, raise_server_exceptions=False) response = client.get("/", headers={"Accept": "text/html, */*"}) assert response.status_code == 500 assert response.headers["content-type"].startswith("text/html") assert "RuntimeError" in response.text -def test_debug_after_response_sent(): +def test_debug_after_response_sent(test_client_factory): async def app(scope, receive, send): response = Response(b"", status_code=204) await response(scope, receive, send) raise RuntimeError("Something went wrong") app = ServerErrorMiddleware(app, debug=True) - client = TestClient(app) + client = test_client_factory(app) with pytest.raises(RuntimeError): client.get("/") -def test_debug_not_http(): +def test_debug_not_http(test_client_factory): """ DebugMiddleware should just pass through any non-http messages as-is. """ @@ -66,5 +65,6 @@ async def app(scope, receive, send): app = ServerErrorMiddleware(app) with pytest.raises(RuntimeError): - client = TestClient(app) - client.websocket_connect("/") + client = test_client_factory(app) + with client.websocket_connect("/"): + pass # pragma: nocover diff --git a/tests/middleware/test_gzip.py b/tests/middleware/test_gzip.py index cd989b8c1f..b917ea4dbb 100644 --- a/tests/middleware/test_gzip.py +++ b/tests/middleware/test_gzip.py @@ -1,10 +1,9 @@ from starlette.applications import Starlette from starlette.middleware.gzip import GZipMiddleware from starlette.responses import PlainTextResponse, StreamingResponse -from starlette.testclient import TestClient -def test_gzip_responses(): +def test_gzip_responses(test_client_factory): app = Starlette() app.add_middleware(GZipMiddleware) @@ -13,7 +12,7 @@ def test_gzip_responses(): def homepage(request): return PlainTextResponse("x" * 4000, status_code=200) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/", headers={"accept-encoding": "gzip"}) assert response.status_code == 200 assert response.text == "x" * 4000 @@ -21,7 +20,7 @@ def homepage(request): assert int(response.headers["Content-Length"]) < 4000 -def test_gzip_not_in_accept_encoding(): +def test_gzip_not_in_accept_encoding(test_client_factory): app = Starlette() app.add_middleware(GZipMiddleware) @@ -30,7 +29,7 @@ def test_gzip_not_in_accept_encoding(): def homepage(request): return PlainTextResponse("x" * 4000, status_code=200) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/", headers={"accept-encoding": "identity"}) assert response.status_code == 200 assert response.text == "x" * 4000 @@ -38,7 +37,7 @@ def homepage(request): assert int(response.headers["Content-Length"]) == 4000 -def test_gzip_ignored_for_small_responses(): +def test_gzip_ignored_for_small_responses(test_client_factory): app = Starlette() app.add_middleware(GZipMiddleware) @@ -47,7 +46,7 @@ def test_gzip_ignored_for_small_responses(): def homepage(request): return PlainTextResponse("OK", status_code=200) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/", headers={"accept-encoding": "gzip"}) assert response.status_code == 200 assert response.text == "OK" @@ -55,7 +54,7 @@ def homepage(request): assert int(response.headers["Content-Length"]) == 2 -def test_gzip_streaming_response(): +def test_gzip_streaming_response(test_client_factory): app = Starlette() app.add_middleware(GZipMiddleware) @@ -69,7 +68,7 @@ async def generator(bytes, count): streaming = generator(bytes=b"x" * 400, count=10) return StreamingResponse(streaming, status_code=200) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/", headers={"accept-encoding": "gzip"}) assert response.status_code == 200 assert response.text == "x" * 4000 diff --git a/tests/middleware/test_https_redirect.py b/tests/middleware/test_https_redirect.py index 757770b853..8db9506342 100644 --- a/tests/middleware/test_https_redirect.py +++ b/tests/middleware/test_https_redirect.py @@ -1,10 +1,9 @@ from starlette.applications import Starlette from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware from starlette.responses import PlainTextResponse -from starlette.testclient import TestClient -def test_https_redirect_middleware(): +def test_https_redirect_middleware(test_client_factory): app = Starlette() app.add_middleware(HTTPSRedirectMiddleware) @@ -13,26 +12,26 @@ def test_https_redirect_middleware(): def homepage(request): return PlainTextResponse("OK", status_code=200) - client = TestClient(app, base_url="https://testserver") + client = test_client_factory(app, base_url="https://testserver") response = client.get("/") assert response.status_code == 200 - client = TestClient(app) + client = test_client_factory(app) response = client.get("/", allow_redirects=False) assert response.status_code == 307 assert response.headers["location"] == "https://testserver/" - client = TestClient(app, base_url="http://testserver:80") + client = test_client_factory(app, base_url="http://testserver:80") response = client.get("/", allow_redirects=False) assert response.status_code == 307 assert response.headers["location"] == "https://testserver/" - client = TestClient(app, base_url="http://testserver:443") + client = test_client_factory(app, base_url="http://testserver:443") response = client.get("/", allow_redirects=False) assert response.status_code == 307 assert response.headers["location"] == "https://testserver/" - client = TestClient(app, base_url="http://testserver:123") + client = test_client_factory(app, base_url="http://testserver:123") response = client.get("/", allow_redirects=False) assert response.status_code == 307 assert response.headers["location"] == "https://testserver:123/" diff --git a/tests/middleware/test_session.py b/tests/middleware/test_session.py index 68cf36df99..42f4447e5c 100644 --- a/tests/middleware/test_session.py +++ b/tests/middleware/test_session.py @@ -3,7 +3,6 @@ from starlette.applications import Starlette from starlette.middleware.sessions import SessionMiddleware from starlette.responses import JSONResponse -from starlette.testclient import TestClient def view_session(request): @@ -29,10 +28,10 @@ def create_app(): return app -def test_session(): +def test_session(test_client_factory): app = create_app() app.add_middleware(SessionMiddleware, secret_key="example") - client = TestClient(app) + client = test_client_factory(app) response = client.get("/view_session") assert response.json() == {"session": {}} @@ -56,10 +55,10 @@ def test_session(): assert response.json() == {"session": {}} -def test_session_expires(): +def test_session_expires(test_client_factory): app = create_app() app.add_middleware(SessionMiddleware, secret_key="example", max_age=-1) - client = TestClient(app) + client = test_client_factory(app) response = client.post("/update_session", json={"some": "data"}) assert response.json() == {"session": {"some": "data"}} @@ -72,11 +71,11 @@ def test_session_expires(): assert response.json() == {"session": {}} -def test_secure_session(): +def test_secure_session(test_client_factory): app = create_app() app.add_middleware(SessionMiddleware, secret_key="example", https_only=True) - secure_client = TestClient(app, base_url="https://testserver") - unsecure_client = TestClient(app, base_url="http://testserver") + secure_client = test_client_factory(app, base_url="https://testserver") + unsecure_client = test_client_factory(app, base_url="http://testserver") response = unsecure_client.get("/view_session") assert response.json() == {"session": {}} @@ -103,13 +102,26 @@ def test_secure_session(): assert response.json() == {"session": {}} -def test_session_cookie_subpath(): +def test_session_cookie_subpath(test_client_factory): app = create_app() second_app = create_app() second_app.add_middleware(SessionMiddleware, secret_key="example") app.mount("/second_app", second_app) - client = TestClient(app, base_url="http://testserver/second_app") + client = test_client_factory(app, base_url="http://testserver/second_app") response = client.post("second_app/update_session", json={"some": "data"}) cookie = response.headers["set-cookie"] cookie_path = re.search(r"; path=(\S+);", cookie).groups()[0] assert cookie_path == "/second_app" + + +def test_invalid_session_cookie(test_client_factory): + app = create_app() + app.add_middleware(SessionMiddleware, secret_key="example") + client = test_client_factory(app) + + response = client.post("/update_session", json={"some": "data"}) + assert response.json() == {"session": {"some": "data"}} + + # we expect it to not raise an exception if we provide a bogus session cookie + response = client.get("/view_session", cookies={"session": "invalid"}) + assert response.json() == {"session": {}} diff --git a/tests/middleware/test_trusted_host.py b/tests/middleware/test_trusted_host.py index 934f2477bf..de9c79e66a 100644 --- a/tests/middleware/test_trusted_host.py +++ b/tests/middleware/test_trusted_host.py @@ -1,10 +1,9 @@ from starlette.applications import Starlette from starlette.middleware.trustedhost import TrustedHostMiddleware from starlette.responses import PlainTextResponse -from starlette.testclient import TestClient -def test_trusted_host_middleware(): +def test_trusted_host_middleware(test_client_factory): app = Starlette() app.add_middleware( @@ -15,15 +14,15 @@ def test_trusted_host_middleware(): def homepage(request): return PlainTextResponse("OK", status_code=200) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.status_code == 200 - client = TestClient(app, base_url="http://subdomain.testserver") + client = test_client_factory(app, base_url="http://subdomain.testserver") response = client.get("/") assert response.status_code == 200 - client = TestClient(app, base_url="http://invalidhost") + client = test_client_factory(app, base_url="http://invalidhost") response = client.get("/") assert response.status_code == 400 @@ -34,7 +33,7 @@ def test_default_allowed_hosts(): assert middleware.allowed_hosts == ["*"] -def test_www_redirect(): +def test_www_redirect(test_client_factory): app = Starlette() app.add_middleware(TrustedHostMiddleware, allowed_hosts=["www.example.com"]) @@ -43,7 +42,7 @@ def test_www_redirect(): def homepage(request): return PlainTextResponse("OK", status_code=200) - client = TestClient(app, base_url="https://example.com") + client = test_client_factory(app, base_url="https://example.com") response = client.get("/") assert response.status_code == 200 assert response.url == "https://www.example.com/" diff --git a/tests/middleware/test_wsgi.py b/tests/middleware/test_wsgi.py index 615805a94d..bcb4cd6ff2 100644 --- a/tests/middleware/test_wsgi.py +++ b/tests/middleware/test_wsgi.py @@ -3,7 +3,6 @@ import pytest from starlette.middleware.wsgi import WSGIMiddleware, build_environ -from starlette.testclient import TestClient def hello_world(environ, start_response): @@ -46,41 +45,41 @@ def return_exc_info(environ, start_response): return [output] -def test_wsgi_get(): +def test_wsgi_get(test_client_factory): app = WSGIMiddleware(hello_world) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.status_code == 200 assert response.text == "Hello World!\n" -def test_wsgi_post(): +def test_wsgi_post(test_client_factory): app = WSGIMiddleware(echo_body) - client = TestClient(app) + client = test_client_factory(app) response = client.post("/", json={"example": 123}) assert response.status_code == 200 assert response.text == '{"example": 123}' -def test_wsgi_exception(): +def test_wsgi_exception(test_client_factory): # Note that we're testing the WSGI app directly here. # The HTTP protocol implementations would catch this error and return 500. app = WSGIMiddleware(raise_exception) - client = TestClient(app) + client = test_client_factory(app) with pytest.raises(RuntimeError): client.get("/") -def test_wsgi_exc_info(): +def test_wsgi_exc_info(test_client_factory): # Note that we're testing the WSGI app directly here. # The HTTP protocol implementations would catch this error and return 500. app = WSGIMiddleware(return_exc_info) - client = TestClient(app) + client = test_client_factory(app) with pytest.raises(RuntimeError): response = client.get("/") app = WSGIMiddleware(return_exc_info) - client = TestClient(app, raise_server_exceptions=False) + client = test_client_factory(app, raise_server_exceptions=False) response = client.get("/") assert response.status_code == 500 assert response.text == "Internal Server Error" diff --git a/tests/test_applications.py b/tests/test_applications.py index ad8504cbd5..f5f4c7fbea 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -1,4 +1,7 @@ import os +import sys + +import pytest from starlette.applications import Starlette from starlette.endpoints import HTTPEndpoint @@ -7,7 +10,11 @@ from starlette.responses import JSONResponse, PlainTextResponse from starlette.routing import Host, Mount, Route, Router, WebSocketRoute from starlette.staticfiles import StaticFiles -from starlette.testclient import TestClient + +if sys.version_info >= (3, 7): + from contextlib import asynccontextmanager # pragma: no cover +else: + from contextlib2 import asynccontextmanager # pragma: no cover app = Starlette() @@ -86,14 +93,17 @@ async def websocket_endpoint(session): await session.close() -client = TestClient(app) +@pytest.fixture +def client(test_client_factory): + with test_client_factory(app) as client: + yield client def test_url_path_for(): assert app.url_path_for("func_homepage") == "/func" -def test_func_route(): +def test_func_route(client): response = client.get("/func") assert response.status_code == 200 assert response.text == "Hello, world!" @@ -103,51 +113,51 @@ def test_func_route(): assert response.text == "" -def test_async_route(): +def test_async_route(client): response = client.get("/async") assert response.status_code == 200 assert response.text == "Hello, world!" -def test_class_route(): +def test_class_route(client): response = client.get("/class") assert response.status_code == 200 assert response.text == "Hello, world!" -def test_mounted_route(): +def test_mounted_route(client): response = client.get("/users/") assert response.status_code == 200 assert response.text == "Hello, everyone!" -def test_mounted_route_path_params(): +def test_mounted_route_path_params(client): response = client.get("/users/tomchristie") assert response.status_code == 200 assert response.text == "Hello, tomchristie!" -def test_subdomain_route(): - client = TestClient(app, base_url="https://foo.example.org/") +def test_subdomain_route(test_client_factory): + client = test_client_factory(app, base_url="https://foo.example.org/") response = client.get("/") assert response.status_code == 200 assert response.text == "Subdomain: foo" -def test_websocket_route(): +def test_websocket_route(client): with client.websocket_connect("/ws") as session: text = session.receive_text() assert text == "Hello, world!" -def test_400(): +def test_400(client): response = client.get("/404") assert response.status_code == 404 assert response.json() == {"detail": "Not Found"} -def test_405(): +def test_405(client): response = client.post("/func") assert response.status_code == 405 assert response.json() == {"detail": "Custom message"} @@ -157,15 +167,15 @@ def test_405(): assert response.json() == {"detail": "Custom message"} -def test_500(): - client = TestClient(app, raise_server_exceptions=False) +def test_500(test_client_factory): + client = test_client_factory(app, raise_server_exceptions=False) response = client.get("/500") assert response.status_code == 500 assert response.json() == {"detail": "Server Error"} -def test_middleware(): - client = TestClient(app, base_url="http://incorrecthost") +def test_middleware(test_client_factory): + client = test_client_factory(app, base_url="http://incorrecthost") response = client.get("/func") assert response.status_code == 400 assert response.text == "Invalid host header" @@ -194,7 +204,7 @@ def test_routes(): ] -def test_app_mount(tmpdir): +def test_app_mount(tmpdir, test_client_factory): path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") @@ -202,7 +212,7 @@ def test_app_mount(tmpdir): app = Starlette() app.mount("/static", StaticFiles(directory=tmpdir)) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/static/example.txt") assert response.status_code == 200 @@ -213,7 +223,7 @@ def test_app_mount(tmpdir): assert response.text == "Method Not Allowed" -def test_app_debug(): +def test_app_debug(test_client_factory): app = Starlette() app.debug = True @@ -221,27 +231,27 @@ def test_app_debug(): async def homepage(request): raise RuntimeError() - client = TestClient(app, raise_server_exceptions=False) + client = test_client_factory(app, raise_server_exceptions=False) response = client.get("/") assert response.status_code == 500 assert "RuntimeError" in response.text assert app.debug -def test_app_add_route(): +def test_app_add_route(test_client_factory): app = Starlette() async def homepage(request): return PlainTextResponse("Hello, World!") app.add_route("/", homepage) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.status_code == 200 assert response.text == "Hello, World!" -def test_app_add_websocket_route(): +def test_app_add_websocket_route(test_client_factory): app = Starlette() async def websocket_endpoint(session): @@ -250,14 +260,14 @@ async def websocket_endpoint(session): await session.close() app.add_websocket_route("/ws", websocket_endpoint) - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("/ws") as session: text = session.receive_text() assert text == "Hello, world!" -def test_app_add_event_handler(): +def test_app_add_event_handler(test_client_factory): startup_complete = False cleanup_complete = False app = Starlette() @@ -275,14 +285,46 @@ def run_cleanup(): assert not startup_complete assert not cleanup_complete - with TestClient(app): + with test_client_factory(app): assert startup_complete assert not cleanup_complete assert startup_complete assert cleanup_complete -def test_app_async_lifespan(): +def test_app_async_cm_lifespan(test_client_factory): + startup_complete = False + cleanup_complete = False + + @asynccontextmanager + async def lifespan(app): + nonlocal startup_complete, cleanup_complete + startup_complete = True + yield + cleanup_complete = True + + app = Starlette(lifespan=lifespan) + + assert not startup_complete + assert not cleanup_complete + with test_client_factory(app): + assert startup_complete + assert not cleanup_complete + assert startup_complete + assert cleanup_complete + + +deprecated_lifespan = pytest.mark.filterwarnings( + r"ignore" + r":(async )?generator function lifespans are deprecated, use an " + r"@contextlib\.asynccontextmanager function instead" + r":DeprecationWarning" + r":starlette.routing" +) + + +@deprecated_lifespan +def test_app_async_gen_lifespan(test_client_factory): startup_complete = False cleanup_complete = False @@ -296,14 +338,15 @@ async def lifespan(app): assert not startup_complete assert not cleanup_complete - with TestClient(app): + with test_client_factory(app): assert startup_complete assert not cleanup_complete assert startup_complete assert cleanup_complete -def test_app_sync_lifespan(): +@deprecated_lifespan +def test_app_sync_gen_lifespan(test_client_factory): startup_complete = False cleanup_complete = False @@ -317,7 +360,7 @@ def lifespan(app): assert not startup_complete assert not cleanup_complete - with TestClient(app): + with test_client_factory(app): assert startup_complete assert not cleanup_complete assert startup_complete diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 3373f67c50..43c7ab96dc 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -15,7 +15,6 @@ from starlette.middleware.authentication import AuthenticationMiddleware from starlette.requests import Request from starlette.responses import JSONResponse -from starlette.testclient import TestClient from starlette.websockets import WebSocketDisconnect @@ -195,8 +194,8 @@ def foo(): pass # pragma: nocover -def test_user_interface(): - with TestClient(app) as client: +def test_user_interface(test_client_factory): + with test_client_factory(app) as client: response = client.get("/") assert response.status_code == 200 assert response.json() == {"authenticated": False, "user": ""} @@ -206,8 +205,8 @@ def test_user_interface(): assert response.json() == {"authenticated": True, "user": "tomchristie"} -def test_authentication_required(): - with TestClient(app) as client: +def test_authentication_required(test_client_factory): + with test_client_factory(app) as client: response = client.get("/dashboard") assert response.status_code == 403 @@ -258,13 +257,17 @@ def test_authentication_required(): assert response.text == "Invalid basic auth credentials" -def test_websocket_authentication_required(): - with TestClient(app) as client: +def test_websocket_authentication_required(test_client_factory): + with test_client_factory(app) as client: with pytest.raises(WebSocketDisconnect): - client.websocket_connect("/ws") + with client.websocket_connect("/ws"): + pass # pragma: nocover with pytest.raises(WebSocketDisconnect): - client.websocket_connect("/ws", headers={"Authorization": "basic foobar"}) + with client.websocket_connect( + "/ws", headers={"Authorization": "basic foobar"} + ): + pass # pragma: nocover with client.websocket_connect( "/ws", auth=("tomchristie", "example") @@ -273,12 +276,14 @@ def test_websocket_authentication_required(): assert data == {"authenticated": True, "user": "tomchristie"} with pytest.raises(WebSocketDisconnect): - client.websocket_connect("/ws/decorated") + with client.websocket_connect("/ws/decorated"): + pass # pragma: nocover with pytest.raises(WebSocketDisconnect): - client.websocket_connect( + with client.websocket_connect( "/ws/decorated", headers={"Authorization": "basic foobar"} - ) + ): + pass # pragma: nocover with client.websocket_connect( "/ws/decorated", auth=("tomchristie", "example") @@ -291,8 +296,8 @@ def test_websocket_authentication_required(): } -def test_authentication_redirect(): - with TestClient(app) as client: +def test_authentication_redirect(test_client_factory): + with test_client_factory(app) as client: response = client.get("/admin") assert response.status_code == 200 assert response.url == "http://testserver/" @@ -331,8 +336,8 @@ def control_panel(request): ) -def test_custom_on_error(): - with TestClient(other_app) as client: +def test_custom_on_error(test_client_factory): + with test_client_factory(other_app) as client: response = client.get("/control-panel", auth=("tomchristie", "example")) assert response.status_code == 200 assert response.json() == {"authenticated": True, "user": "tomchristie"} diff --git a/tests/test_background.py b/tests/test_background.py index d9d7ddd872..e299ec3628 100644 --- a/tests/test_background.py +++ b/tests/test_background.py @@ -1,9 +1,8 @@ from starlette.background import BackgroundTask, BackgroundTasks from starlette.responses import Response -from starlette.testclient import TestClient -def test_async_task(): +def test_async_task(test_client_factory): TASK_COMPLETE = False async def async_task(): @@ -16,13 +15,13 @@ async def app(scope, receive, send): response = Response("task initiated", media_type="text/plain", background=task) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.text == "task initiated" assert TASK_COMPLETE -def test_sync_task(): +def test_sync_task(test_client_factory): TASK_COMPLETE = False def sync_task(): @@ -35,13 +34,13 @@ async def app(scope, receive, send): response = Response("task initiated", media_type="text/plain", background=task) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.text == "task initiated" assert TASK_COMPLETE -def test_multiple_tasks(): +def test_multiple_tasks(test_client_factory): TASK_COUNTER = 0 def increment(amount): @@ -58,7 +57,7 @@ async def app(scope, receive, send): ) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.text == "tasks initiated" assert TASK_COUNTER == 1 + 2 + 3 diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py new file mode 100644 index 0000000000..cc5eba974f --- /dev/null +++ b/tests/test_concurrency.py @@ -0,0 +1,22 @@ +import anyio +import pytest + +from starlette.concurrency import run_until_first_complete + + +@pytest.mark.anyio +async def test_run_until_first_complete(): + task1_finished = anyio.Event() + task2_finished = anyio.Event() + + async def task1(): + task1_finished.set() + + async def task2(): + await task1_finished.wait() + await anyio.sleep(0) # pragma: nocover + task2_finished.set() # pragma: nocover + + await run_until_first_complete((task1, {}), (task2, {})) + assert task1_finished.is_set() + assert not task2_finished.is_set() diff --git a/tests/test_database.py b/tests/test_database.py index 258a71ec51..c0a4745d11 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -4,7 +4,6 @@ from starlette.applications import Starlette from starlette.responses import JSONResponse -from starlette.testclient import TestClient DATABASE_URL = "sqlite:///test.db" @@ -19,6 +18,9 @@ ) +pytestmark = pytest.mark.usefixtures("no_trio_support") + + @pytest.fixture(autouse=True, scope="module") def create_test_database(): engine = sqlalchemy.create_engine(DATABASE_URL) @@ -87,8 +89,8 @@ async def read_note_text(request): return JSONResponse(result[0]) -def test_database(): - with TestClient(app) as client: +def test_database(test_client_factory): + with test_client_factory(app) as client: response = client.post( "/notes", json={"text": "buy the milk", "completed": True} ) @@ -122,10 +124,8 @@ def test_database(): assert response.json() == "buy the milk" -def test_database_execute_many(): - with TestClient(app) as client: - response = client.get("/notes") - +def test_database_execute_many(test_client_factory): + with test_client_factory(app) as client: data = [ {"text": "buy the milk", "completed": True}, {"text": "walk the dog", "completed": False}, @@ -141,11 +141,11 @@ def test_database_execute_many(): ] -def test_database_isolated_during_test_cases(): +def test_database_isolated_during_test_cases(test_client_factory): """ Using `TestClient` as a context manager """ - with TestClient(app) as client: + with test_client_factory(app) as client: response = client.post( "/notes", json={"text": "just one note", "completed": True} ) @@ -155,7 +155,7 @@ def test_database_isolated_during_test_cases(): assert response.status_code == 200 assert response.json() == [{"text": "just one note", "completed": True}] - with TestClient(app) as client: + with test_client_factory(app) as client: response = client.post( "/notes", json={"text": "just one note", "completed": True} ) diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py index b0e6baf985..bb71ba870c 100644 --- a/tests/test_datastructures.py +++ b/tests/test_datastructures.py @@ -217,7 +217,7 @@ class BigUploadFile(UploadFile): spool_max_size = 1024 -@pytest.mark.asyncio +@pytest.mark.anyio async def test_upload_file(): big_file = BigUploadFile("big-file") await big_file.write(b"big-data" * 512) diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index e491c085f3..e57d47486a 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -3,7 +3,6 @@ from starlette.endpoints import HTTPEndpoint, WebSocketEndpoint from starlette.responses import PlainTextResponse from starlette.routing import Route, Router -from starlette.testclient import TestClient class Homepage(HTTPEndpoint): @@ -18,46 +17,50 @@ async def get(self, request): routes=[Route("/", endpoint=Homepage), Route("/{username}", endpoint=Homepage)] ) -client = TestClient(app) +@pytest.fixture +def client(test_client_factory): + with test_client_factory(app) as client: + yield client -def test_http_endpoint_route(): + +def test_http_endpoint_route(client): response = client.get("/") assert response.status_code == 200 assert response.text == "Hello, world!" -def test_http_endpoint_route_path_params(): +def test_http_endpoint_route_path_params(client): response = client.get("/tomchristie") assert response.status_code == 200 assert response.text == "Hello, tomchristie!" -def test_http_endpoint_route_method(): +def test_http_endpoint_route_method(client): response = client.post("/") assert response.status_code == 405 assert response.text == "Method Not Allowed" -def test_websocket_endpoint_on_connect(): +def test_websocket_endpoint_on_connect(test_client_factory): class WebSocketApp(WebSocketEndpoint): async def on_connect(self, websocket): assert websocket["subprotocols"] == ["soap", "wamp"] await websocket.accept(subprotocol="wamp") - client = TestClient(WebSocketApp) + client = test_client_factory(WebSocketApp) with client.websocket_connect("/ws", subprotocols=["soap", "wamp"]) as websocket: assert websocket.accepted_subprotocol == "wamp" -def test_websocket_endpoint_on_receive_bytes(): +def test_websocket_endpoint_on_receive_bytes(test_client_factory): class WebSocketApp(WebSocketEndpoint): encoding = "bytes" async def on_receive(self, websocket, data): await websocket.send_bytes(b"Message bytes was: " + data) - client = TestClient(WebSocketApp) + client = test_client_factory(WebSocketApp) with client.websocket_connect("/ws") as websocket: websocket.send_bytes(b"Hello, world!") _bytes = websocket.receive_bytes() @@ -68,14 +71,14 @@ async def on_receive(self, websocket, data): websocket.send_text("Hello world") -def test_websocket_endpoint_on_receive_json(): +def test_websocket_endpoint_on_receive_json(test_client_factory): class WebSocketApp(WebSocketEndpoint): encoding = "json" async def on_receive(self, websocket, data): await websocket.send_json({"message": data}) - client = TestClient(WebSocketApp) + client = test_client_factory(WebSocketApp) with client.websocket_connect("/ws") as websocket: websocket.send_json({"hello": "world"}) data = websocket.receive_json() @@ -86,28 +89,28 @@ async def on_receive(self, websocket, data): websocket.send_text("Hello world") -def test_websocket_endpoint_on_receive_json_binary(): +def test_websocket_endpoint_on_receive_json_binary(test_client_factory): class WebSocketApp(WebSocketEndpoint): encoding = "json" async def on_receive(self, websocket, data): await websocket.send_json({"message": data}, mode="binary") - client = TestClient(WebSocketApp) + client = test_client_factory(WebSocketApp) with client.websocket_connect("/ws") as websocket: websocket.send_json({"hello": "world"}, mode="binary") data = websocket.receive_json(mode="binary") assert data == {"message": {"hello": "world"}} -def test_websocket_endpoint_on_receive_text(): +def test_websocket_endpoint_on_receive_text(test_client_factory): class WebSocketApp(WebSocketEndpoint): encoding = "text" async def on_receive(self, websocket, data): await websocket.send_text(f"Message text was: {data}") - client = TestClient(WebSocketApp) + client = test_client_factory(WebSocketApp) with client.websocket_connect("/ws") as websocket: websocket.send_text("Hello, world!") _text = websocket.receive_text() @@ -118,26 +121,26 @@ async def on_receive(self, websocket, data): websocket.send_bytes(b"Hello world") -def test_websocket_endpoint_on_default(): +def test_websocket_endpoint_on_default(test_client_factory): class WebSocketApp(WebSocketEndpoint): encoding = None async def on_receive(self, websocket, data): await websocket.send_text(f"Message text was: {data}") - client = TestClient(WebSocketApp) + client = test_client_factory(WebSocketApp) with client.websocket_connect("/ws") as websocket: websocket.send_text("Hello, world!") _text = websocket.receive_text() assert _text == "Message text was: Hello, world!" -def test_websocket_endpoint_on_disconnect(): +def test_websocket_endpoint_on_disconnect(test_client_factory): class WebSocketApp(WebSocketEndpoint): async def on_disconnect(self, websocket, close_code): assert close_code == 1001 await websocket.close(code=close_code) - client = TestClient(WebSocketApp) + client = test_client_factory(WebSocketApp) with client.websocket_connect("/ws") as websocket: websocket.close(code=1001) diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 841c9a5cf6..5fba9981b1 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -3,7 +3,6 @@ from starlette.exceptions import ExceptionMiddleware, HTTPException from starlette.responses import PlainTextResponse from starlette.routing import Route, Router, WebSocketRoute -from starlette.testclient import TestClient def raise_runtime_error(request): @@ -37,27 +36,33 @@ async def __call__(self, scope, receive, send): app = ExceptionMiddleware(router) -client = TestClient(app) -def test_not_acceptable(): +@pytest.fixture +def client(test_client_factory): + with test_client_factory(app) as client: + yield client + + +def test_not_acceptable(client): response = client.get("/not_acceptable") assert response.status_code == 406 assert response.text == "Not Acceptable" -def test_not_modified(): +def test_not_modified(client): response = client.get("/not_modified") assert response.status_code == 304 assert response.text == "" -def test_websockets_should_raise(): +def test_websockets_should_raise(client): with pytest.raises(RuntimeError): - client.websocket_connect("/runtime_error") + with client.websocket_connect("/runtime_error"): + pass # pragma: nocover -def test_handled_exc_after_response(): +def test_handled_exc_after_response(test_client_factory, client): # A 406 HttpException is raised *after* the response has already been sent. # The exception middleware should raise a RuntimeError. with pytest.raises(RuntimeError): @@ -65,17 +70,17 @@ def test_handled_exc_after_response(): # If `raise_server_exceptions=False` then the test client will still allow # us to see the response as it will have been seen by the client. - allow_200_client = TestClient(app, raise_server_exceptions=False) + allow_200_client = test_client_factory(app, raise_server_exceptions=False) response = allow_200_client.get("/handled_exc_after_response") assert response.status_code == 200 assert response.text == "OK" -def test_force_500_response(): +def test_force_500_response(test_client_factory): def app(scope): raise RuntimeError() - force_500_client = TestClient(app, raise_server_exceptions=False) + force_500_client = test_client_factory(app, raise_server_exceptions=False) response = force_500_client.get("/") assert response.status_code == 500 assert response.text == "" diff --git a/tests/test_formparsers.py b/tests/test_formparsers.py index 73a720fd13..8a1174e1d2 100644 --- a/tests/test_formparsers.py +++ b/tests/test_formparsers.py @@ -3,7 +3,6 @@ from starlette.formparsers import UploadFile, _user_safe_decode from starlette.requests import Request from starlette.responses import JSONResponse -from starlette.testclient import TestClient class ForceMultipartDict(dict): @@ -70,18 +69,18 @@ async def app_read_body(scope, receive, send): await response(scope, receive, send) -def test_multipart_request_data(tmpdir): - client = TestClient(app) +def test_multipart_request_data(tmpdir, test_client_factory): + client = test_client_factory(app) response = client.post("/", data={"some": "data"}, files=FORCE_MULTIPART) assert response.json() == {"some": "data"} -def test_multipart_request_files(tmpdir): +def test_multipart_request_files(tmpdir, test_client_factory): path = os.path.join(tmpdir, "test.txt") with open(path, "wb") as file: file.write(b"") - client = TestClient(app) + client = test_client_factory(app) with open(path, "rb") as f: response = client.post("/", files={"test": f}) assert response.json() == { @@ -93,12 +92,12 @@ def test_multipart_request_files(tmpdir): } -def test_multipart_request_files_with_content_type(tmpdir): +def test_multipart_request_files_with_content_type(tmpdir, test_client_factory): path = os.path.join(tmpdir, "test.txt") with open(path, "wb") as file: file.write(b"") - client = TestClient(app) + client = test_client_factory(app) with open(path, "rb") as f: response = client.post("/", files={"test": ("test.txt", f, "text/plain")}) assert response.json() == { @@ -110,7 +109,7 @@ def test_multipart_request_files_with_content_type(tmpdir): } -def test_multipart_request_multiple_files(tmpdir): +def test_multipart_request_multiple_files(tmpdir, test_client_factory): path1 = os.path.join(tmpdir, "test1.txt") with open(path1, "wb") as file: file.write(b"") @@ -119,7 +118,7 @@ def test_multipart_request_multiple_files(tmpdir): with open(path2, "wb") as file: file.write(b"") - client = TestClient(app) + client = test_client_factory(app) with open(path1, "rb") as f1, open(path2, "rb") as f2: response = client.post( "/", files={"test1": f1, "test2": ("test2.txt", f2, "text/plain")} @@ -138,7 +137,7 @@ def test_multipart_request_multiple_files(tmpdir): } -def test_multi_items(tmpdir): +def test_multi_items(tmpdir, test_client_factory): path1 = os.path.join(tmpdir, "test1.txt") with open(path1, "wb") as file: file.write(b"") @@ -147,7 +146,7 @@ def test_multi_items(tmpdir): with open(path2, "wb") as file: file.write(b"") - client = TestClient(multi_items_app) + client = test_client_factory(multi_items_app) with open(path1, "rb") as f1, open(path2, "rb") as f2: response = client.post( "/", @@ -171,8 +170,8 @@ def test_multi_items(tmpdir): } -def test_multipart_request_mixed_files_and_data(tmpdir): - client = TestClient(app) +def test_multipart_request_mixed_files_and_data(tmpdir, test_client_factory): + client = test_client_factory(app) response = client.post( "/", data=( @@ -208,8 +207,8 @@ def test_multipart_request_mixed_files_and_data(tmpdir): } -def test_multipart_request_with_charset_for_filename(tmpdir): - client = TestClient(app) +def test_multipart_request_with_charset_for_filename(tmpdir, test_client_factory): + client = test_client_factory(app) response = client.post( "/", data=( @@ -236,8 +235,8 @@ def test_multipart_request_with_charset_for_filename(tmpdir): } -def test_multipart_request_without_charset_for_filename(tmpdir): - client = TestClient(app) +def test_multipart_request_without_charset_for_filename(tmpdir, test_client_factory): + client = test_client_factory(app) response = client.post( "/", data=( @@ -263,8 +262,8 @@ def test_multipart_request_without_charset_for_filename(tmpdir): } -def test_multipart_request_with_encoded_value(tmpdir): - client = TestClient(app) +def test_multipart_request_with_encoded_value(tmpdir, test_client_factory): + client = test_client_factory(app) response = client.post( "/", data=( @@ -284,38 +283,38 @@ def test_multipart_request_with_encoded_value(tmpdir): assert response.json() == {"value": "Transférer"} -def test_urlencoded_request_data(tmpdir): - client = TestClient(app) +def test_urlencoded_request_data(tmpdir, test_client_factory): + client = test_client_factory(app) response = client.post("/", data={"some": "data"}) assert response.json() == {"some": "data"} -def test_no_request_data(tmpdir): - client = TestClient(app) +def test_no_request_data(tmpdir, test_client_factory): + client = test_client_factory(app) response = client.post("/") assert response.json() == {} -def test_urlencoded_percent_encoding(tmpdir): - client = TestClient(app) +def test_urlencoded_percent_encoding(tmpdir, test_client_factory): + client = test_client_factory(app) response = client.post("/", data={"some": "da ta"}) assert response.json() == {"some": "da ta"} -def test_urlencoded_percent_encoding_keys(tmpdir): - client = TestClient(app) +def test_urlencoded_percent_encoding_keys(tmpdir, test_client_factory): + client = test_client_factory(app) response = client.post("/", data={"so me": "data"}) assert response.json() == {"so me": "data"} -def test_urlencoded_multi_field_app_reads_body(tmpdir): - client = TestClient(app_read_body) +def test_urlencoded_multi_field_app_reads_body(tmpdir, test_client_factory): + client = test_client_factory(app_read_body) response = client.post("/", data={"some": "data", "second": "key pair"}) assert response.json() == {"some": "data", "second": "key pair"} -def test_multipart_multi_field_app_reads_body(tmpdir): - client = TestClient(app_read_body) +def test_multipart_multi_field_app_reads_body(tmpdir, test_client_factory): + client = test_client_factory(app_read_body) response = client.post( "/", data={"some": "data", "second": "key pair"}, files=FORCE_MULTIPART ) diff --git a/tests/test_requests.py b/tests/test_requests.py index a83a2c480b..d7c69fbeb2 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -1,20 +1,18 @@ -import asyncio - +import anyio import pytest from starlette.requests import ClientDisconnect, Request, State from starlette.responses import JSONResponse, Response -from starlette.testclient import TestClient -def test_request_url(): +def test_request_url(test_client_factory): async def app(scope, receive, send): request = Request(scope, receive) data = {"method": request.method, "url": str(request.url)} response = JSONResponse(data) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/123?a=abc") assert response.json() == {"method": "GET", "url": "http://testserver/123?a=abc"} @@ -22,26 +20,26 @@ async def app(scope, receive, send): assert response.json() == {"method": "GET", "url": "https://example.org:123/"} -def test_request_query_params(): +def test_request_query_params(test_client_factory): async def app(scope, receive, send): request = Request(scope, receive) params = dict(request.query_params) response = JSONResponse({"params": params}) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/?a=123&b=456") assert response.json() == {"params": {"a": "123", "b": "456"}} -def test_request_headers(): +def test_request_headers(test_client_factory): async def app(scope, receive, send): request = Request(scope, receive) headers = dict(request.headers) response = JSONResponse({"headers": headers}) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/", headers={"host": "example.org"}) assert response.json() == { "headers": { @@ -54,7 +52,7 @@ async def app(scope, receive, send): } -def test_request_client(): +def test_request_client(test_client_factory): async def app(scope, receive, send): request = Request(scope, receive) response = JSONResponse( @@ -62,19 +60,19 @@ async def app(scope, receive, send): ) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.json() == {"host": "testclient", "port": 50000} -def test_request_body(): +def test_request_body(test_client_factory): async def app(scope, receive, send): request = Request(scope, receive) body = await request.body() response = JSONResponse({"body": body.decode()}) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.json() == {"body": ""} @@ -86,7 +84,7 @@ async def app(scope, receive, send): assert response.json() == {"body": "abc"} -def test_request_stream(): +def test_request_stream(test_client_factory): async def app(scope, receive, send): request = Request(scope, receive) body = b"" @@ -95,7 +93,7 @@ async def app(scope, receive, send): response = JSONResponse({"body": body.decode()}) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.json() == {"body": ""} @@ -107,20 +105,20 @@ async def app(scope, receive, send): assert response.json() == {"body": "abc"} -def test_request_form_urlencoded(): +def test_request_form_urlencoded(test_client_factory): async def app(scope, receive, send): request = Request(scope, receive) form = await request.form() response = JSONResponse({"form": dict(form)}) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.post("/", data={"abc": "123 @"}) assert response.json() == {"form": {"abc": "123 @"}} -def test_request_body_then_stream(): +def test_request_body_then_stream(test_client_factory): async def app(scope, receive, send): request = Request(scope, receive) body = await request.body() @@ -130,13 +128,13 @@ async def app(scope, receive, send): response = JSONResponse({"body": body.decode(), "stream": chunks.decode()}) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.post("/", data="abc") assert response.json() == {"body": "abc", "stream": "abc"} -def test_request_stream_then_body(): +def test_request_stream_then_body(test_client_factory): async def app(scope, receive, send): request = Request(scope, receive) chunks = b"" @@ -149,20 +147,20 @@ async def app(scope, receive, send): response = JSONResponse({"body": body.decode(), "stream": chunks.decode()}) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.post("/", data="abc") assert response.json() == {"body": "", "stream": "abc"} -def test_request_json(): +def test_request_json(test_client_factory): async def app(scope, receive, send): request = Request(scope, receive) data = await request.json() response = JSONResponse({"json": data}) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.post("/", json={"a": "123"}) assert response.json() == {"json": {"a": "123"}} @@ -178,7 +176,7 @@ def test_request_scope_interface(): assert len(request) == 3 -def test_request_without_setting_receive(): +def test_request_without_setting_receive(test_client_factory): """ If Request is instantiated without the receive channel, then .body() is not available. @@ -193,12 +191,12 @@ async def app(scope, receive, send): response = JSONResponse({"json": data}) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.post("/", json={"a": "123"}) assert response.json() == {"json": "Receive channel not available"} -def test_request_disconnect(): +def test_request_disconnect(anyio_backend_name, anyio_backend_options): """ If a client disconnect occurs while reading request body then ClientDisconnect should be raised. @@ -212,12 +210,18 @@ async def receiver(): return {"type": "http.disconnect"} scope = {"type": "http", "method": "POST", "path": "/"} - loop = asyncio.get_event_loop() with pytest.raises(ClientDisconnect): - loop.run_until_complete(app(scope, receiver, None)) + anyio.run( + app, + scope, + receiver, + None, + backend=anyio_backend_name, + backend_options=anyio_backend_options, + ) -def test_request_is_disconnected(): +def test_request_is_disconnected(test_client_factory): """ If a client disconnect occurs while reading request body then ClientDisconnect should be raised. @@ -234,7 +238,7 @@ async def app(scope, receive, send): await response(scope, receive, send) disconnected_after_response = await request.is_disconnected() - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.json() == {"disconnected": False} assert disconnected_after_response @@ -254,19 +258,19 @@ def test_request_state_object(): s.new -def test_request_state(): +def test_request_state(test_client_factory): async def app(scope, receive, send): request = Request(scope, receive) request.state.example = 123 response = JSONResponse({"state.example": request.state.example}) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/123?a=abc") assert response.json() == {"state.example": 123} -def test_request_cookies(): +def test_request_cookies(test_client_factory): async def app(scope, receive, send): request = Request(scope, receive) mycookie = request.cookies.get("mycookie") @@ -278,14 +282,14 @@ async def app(scope, receive, send): await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.text == "Hello, world!" response = client.get("/") assert response.text == "Hello, cookies!" -def test_cookie_lenient_parsing(): +def test_cookie_lenient_parsing(test_client_factory): """ The following test is based on a cookie set by Okta, a well-known authorization service. It turns out that it's common practice to set cookies that would be @@ -312,7 +316,7 @@ async def app(scope, receive, send): response = JSONResponse({"cookies": request.cookies}) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/", headers={"cookie": tough_cookie}) result = response.json() assert len(result["cookies"]) == 4 @@ -341,13 +345,13 @@ async def app(scope, receive, send): ("a=b; h=i; a=c", {"a": "c", "h": "i"}), ], ) -def test_cookies_edge_cases(set_cookie, expected): +def test_cookies_edge_cases(set_cookie, expected, test_client_factory): async def app(scope, receive, send): request = Request(scope, receive) response = JSONResponse({"cookies": request.cookies}) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/", headers={"cookie": set_cookie}) result = response.json() assert result["cookies"] == expected @@ -376,7 +380,7 @@ async def app(scope, receive, send): # (" = b ; ; = ; c = ; ", {"": "b", "c": ""}), ], ) -def test_cookies_invalid(set_cookie, expected): +def test_cookies_invalid(set_cookie, expected, test_client_factory): """ Cookie strings that are against the RFC6265 spec but which browsers will send if set via document.cookie. @@ -387,20 +391,20 @@ async def app(scope, receive, send): response = JSONResponse({"cookies": request.cookies}) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/", headers={"cookie": set_cookie}) result = response.json() assert result["cookies"] == expected -def test_chunked_encoding(): +def test_chunked_encoding(test_client_factory): async def app(scope, receive, send): request = Request(scope, receive) body = await request.body() response = JSONResponse({"body": body.decode()}) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) def post_body(): yield b"foo" @@ -410,7 +414,7 @@ def post_body(): assert response.json() == {"body": "foobar"} -def test_request_send_push_promise(): +def test_request_send_push_promise(test_client_factory): async def app(scope, receive, send): # the server is push-enabled scope["extensions"]["http.response.push"] = {} @@ -421,12 +425,12 @@ async def app(scope, receive, send): response = JSONResponse({"json": "OK"}) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.json() == {"json": "OK"} -def test_request_send_push_promise_without_push_extension(): +def test_request_send_push_promise_without_push_extension(test_client_factory): """ If server does not support the `http.response.push` extension, .send_push_promise() does nothing. @@ -439,12 +443,12 @@ async def app(scope, receive, send): response = JSONResponse({"json": "OK"}) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.json() == {"json": "OK"} -def test_request_send_push_promise_without_setting_send(): +def test_request_send_push_promise_without_setting_send(test_client_factory): """ If Request is instantiated without the send channel, then .send_push_promise() is not available. @@ -463,6 +467,6 @@ async def app(scope, receive, send): response = JSONResponse({"json": data}) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.json() == {"json": "Send channel not available"} diff --git a/tests/test_responses.py b/tests/test_responses.py index fd2ba0e424..baba549baf 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -1,6 +1,6 @@ -import asyncio import os +import anyio import pytest from starlette import status @@ -13,40 +13,39 @@ Response, StreamingResponse, ) -from starlette.testclient import TestClient -def test_text_response(): +def test_text_response(test_client_factory): async def app(scope, receive, send): response = Response("hello, world", media_type="text/plain") await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.text == "hello, world" -def test_bytes_response(): +def test_bytes_response(test_client_factory): async def app(scope, receive, send): response = Response(b"xxxxx", media_type="image/png") await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.content == b"xxxxx" -def test_json_none_response(): +def test_json_none_response(test_client_factory): async def app(scope, receive, send): response = JSONResponse(None) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.json() is None -def test_redirect_response(): +def test_redirect_response(test_client_factory): async def app(scope, receive, send): if scope["path"] == "/": response = Response("hello, world", media_type="text/plain") @@ -54,13 +53,13 @@ async def app(scope, receive, send): response = RedirectResponse("/") await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/redirect") assert response.text == "hello, world" assert response.url == "http://testserver/" -def test_quoting_redirect_response(): +def test_quoting_redirect_response(test_client_factory): async def app(scope, receive, send): if scope["path"] == "/I ♥ Starlette/": response = Response("hello, world", media_type="text/plain") @@ -68,13 +67,13 @@ async def app(scope, receive, send): response = RedirectResponse("/I ♥ Starlette/") await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/redirect") assert response.text == "hello, world" assert response.url == "http://testserver/I%20%E2%99%A5%20Starlette/" -def test_streaming_response(): +def test_streaming_response(test_client_factory): filled_by_bg_task = "" async def app(scope, receive, send): @@ -83,7 +82,7 @@ async def numbers(minimum, maximum): yield str(i) if i != maximum: yield ", " - await asyncio.sleep(0) + await anyio.sleep(0) async def numbers_for_cleanup(start=1, stop=5): nonlocal filled_by_bg_task @@ -98,13 +97,13 @@ async def numbers_for_cleanup(start=1, stop=5): await response(scope, receive, send) assert filled_by_bg_task == "" - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.text == "1, 2, 3, 4, 5" assert filled_by_bg_task == "6, 7, 8, 9" -def test_streaming_response_custom_iterator(): +def test_streaming_response_custom_iterator(test_client_factory): async def app(scope, receive, send): class CustomAsyncIterator: def __init__(self): @@ -122,12 +121,12 @@ async def __anext__(self): response = StreamingResponse(CustomAsyncIterator(), media_type="text/plain") await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.text == "12345" -def test_streaming_response_custom_iterable(): +def test_streaming_response_custom_iterable(test_client_factory): async def app(scope, receive, send): class CustomAsyncIterable: async def __aiter__(self): @@ -137,12 +136,12 @@ async def __aiter__(self): response = StreamingResponse(CustomAsyncIterable(), media_type="text/plain") await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.text == "12345" -def test_sync_streaming_response(): +def test_sync_streaming_response(test_client_factory): async def app(scope, receive, send): def numbers(minimum, maximum): for i in range(minimum, maximum + 1): @@ -154,37 +153,37 @@ def numbers(minimum, maximum): response = StreamingResponse(generator, media_type="text/plain") await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.text == "1, 2, 3, 4, 5" -def test_response_headers(): +def test_response_headers(test_client_factory): async def app(scope, receive, send): headers = {"x-header-1": "123", "x-header-2": "456"} response = Response("hello, world", media_type="text/plain", headers=headers) response.headers["x-header-2"] = "789" await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.headers["x-header-1"] == "123" assert response.headers["x-header-2"] == "789" -def test_response_phrase(): +def test_response_phrase(test_client_factory): app = Response(status_code=204) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.reason == "No Content" app = Response(b"", status_code=123) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.reason == "" -def test_file_response(tmpdir): +def test_file_response(tmpdir, test_client_factory): path = os.path.join(tmpdir, "xyz") content = b"" * 1000 with open(path, "wb") as file: @@ -197,7 +196,7 @@ async def numbers(minimum, maximum): yield str(i) if i != maximum: yield ", " - await asyncio.sleep(0) + await anyio.sleep(0) async def numbers_for_cleanup(start=1, stop=5): nonlocal filled_by_bg_task @@ -213,7 +212,7 @@ async def app(scope, receive, send): await response(scope, receive, send) assert filled_by_bg_task == "" - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") expected_disposition = 'attachment; filename="example.png"' assert response.status_code == status.HTTP_200_OK @@ -226,31 +225,31 @@ async def app(scope, receive, send): assert filled_by_bg_task == "6, 7, 8, 9" -def test_file_response_with_directory_raises_error(tmpdir): +def test_file_response_with_directory_raises_error(tmpdir, test_client_factory): app = FileResponse(path=tmpdir, filename="example.png") - client = TestClient(app) + client = test_client_factory(app) with pytest.raises(RuntimeError) as exc_info: client.get("/") assert "is not a file" in str(exc_info.value) -def test_file_response_with_missing_file_raises_error(tmpdir): +def test_file_response_with_missing_file_raises_error(tmpdir, test_client_factory): path = os.path.join(tmpdir, "404.txt") app = FileResponse(path=path, filename="404.txt") - client = TestClient(app) + client = test_client_factory(app) with pytest.raises(RuntimeError) as exc_info: client.get("/") assert "does not exist" in str(exc_info.value) -def test_file_response_with_chinese_filename(tmpdir): +def test_file_response_with_chinese_filename(tmpdir, test_client_factory): content = b"file content" filename = "你好.txt" # probably "Hello.txt" in Chinese path = os.path.join(tmpdir, filename) with open(path, "wb") as f: f.write(content) app = FileResponse(path=path, filename=filename) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") expected_disposition = "attachment; filename*=utf-8''%E4%BD%A0%E5%A5%BD.txt" assert response.status_code == status.HTTP_200_OK @@ -258,7 +257,7 @@ def test_file_response_with_chinese_filename(tmpdir): assert response.headers["content-disposition"] == expected_disposition -def test_set_cookie(): +def test_set_cookie(test_client_factory): async def app(scope, receive, send): response = Response("Hello, world!", media_type="text/plain") response.set_cookie( @@ -274,12 +273,12 @@ async def app(scope, receive, send): ) await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.text == "Hello, world!" -def test_delete_cookie(): +def test_delete_cookie(test_client_factory): async def app(scope, receive, send): request = Request(scope, receive) response = Response("Hello, world!", media_type="text/plain") @@ -289,24 +288,24 @@ async def app(scope, receive, send): response.set_cookie("mycookie", "myvalue") await response(scope, receive, send) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.cookies["mycookie"] response = client.get("/") assert not response.cookies.get("mycookie") -def test_populate_headers(): +def test_populate_headers(test_client_factory): app = Response(content="hi", headers={}, media_type="text/html") - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.text == "hi" assert response.headers["content-length"] == "2" assert response.headers["content-type"] == "text/html; charset=utf-8" -def test_head_method(): +def test_head_method(test_client_factory): app = Response("hello, world", media_type="text/plain") - client = TestClient(app) + client = test_client_factory(app) response = client.head("/") assert response.text == "" diff --git a/tests/test_routing.py b/tests/test_routing.py index fff3332dbd..9e734b9cc9 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -6,7 +6,6 @@ from starlette.applications import Starlette from starlette.responses import JSONResponse, PlainTextResponse, Response from starlette.routing import Host, Mount, NoMatchFound, Route, Router, WebSocketRoute -from starlette.testclient import TestClient from starlette.websockets import WebSocket, WebSocketDisconnect @@ -105,10 +104,19 @@ async def websocket_params(session): await session.close() -client = TestClient(app) +@pytest.fixture +def client(test_client_factory): + with test_client_factory(app) as client: + yield client -def test_router(): +@pytest.mark.filterwarnings( + r"ignore" + r":Trying to detect encoding from a tiny portion of \(5\) byte\(s\)\." + r":UserWarning" + r":charset_normalizer.api" +) +def test_router(client): response = client.get("/") assert response.status_code == 200 assert response.text == "Hello, world" @@ -147,7 +155,7 @@ def test_router(): assert response.text == "xxxxx" -def test_route_converters(): +def test_route_converters(client): # Test integer conversion response = client.get("/int/5") assert response.status_code == 200 @@ -232,19 +240,19 @@ def test_url_for(): ) -def test_router_add_route(): +def test_router_add_route(client): response = client.get("/func") assert response.status_code == 200 assert response.text == "Hello, world!" -def test_router_duplicate_path(): +def test_router_duplicate_path(client): response = client.post("/func") assert response.status_code == 200 assert response.text == "Hello, POST!" -def test_router_add_websocket_route(): +def test_router_add_websocket_route(client): with client.websocket_connect("/ws") as session: text = session.receive_text() assert text == "Hello, world!" @@ -275,8 +283,8 @@ async def __call__(self, scope, receive, send): ) -def test_protocol_switch(): - client = TestClient(mixed_protocol_app) +def test_protocol_switch(test_client_factory): + client = test_client_factory(mixed_protocol_app) response = client.get("/") assert response.status_code == 200 @@ -286,15 +294,16 @@ def test_protocol_switch(): assert session.receive_json() == {"URL": "ws://testserver/"} with pytest.raises(WebSocketDisconnect): - client.websocket_connect("/404") + with client.websocket_connect("/404"): + pass # pragma: nocover ok = PlainTextResponse("OK") -def test_mount_urls(): +def test_mount_urls(test_client_factory): mounted = Router([Mount("/users", ok, name="users")]) - client = TestClient(mounted) + client = test_client_factory(mounted) assert client.get("/users").status_code == 200 assert client.get("/users").url == "http://testserver/users/" assert client.get("/users/").status_code == 200 @@ -317,9 +326,9 @@ def test_reverse_mount_urls(): ) -def test_mount_at_root(): +def test_mount_at_root(test_client_factory): mounted = Router([Mount("/", ok, name="users")]) - client = TestClient(mounted) + client = test_client_factory(mounted) assert client.get("/").status_code == 200 @@ -347,8 +356,8 @@ def users_api(request): ) -def test_host_routing(): - client = TestClient(mixed_hosts_app, base_url="https://api.example.org/") +def test_host_routing(test_client_factory): + client = test_client_factory(mixed_hosts_app, base_url="https://api.example.org/") response = client.get("/users") assert response.status_code == 200 @@ -357,7 +366,7 @@ def test_host_routing(): response = client.get("/") assert response.status_code == 404 - client = TestClient(mixed_hosts_app, base_url="https://www.example.org/") + client = test_client_factory(mixed_hosts_app, base_url="https://www.example.org/") response = client.get("/users") assert response.status_code == 200 @@ -392,8 +401,8 @@ async def subdomain_app(scope, receive, send): ) -def test_subdomain_routing(): - client = TestClient(subdomain_app, base_url="https://foo.example.org/") +def test_subdomain_routing(test_client_factory): + client = test_client_factory(subdomain_app, base_url="https://foo.example.org/") response = client.get("/") assert response.status_code == 200 @@ -428,9 +437,11 @@ async def echo_urls(request): ] -def test_url_for_with_root_path(): +def test_url_for_with_root_path(test_client_factory): app = Starlette(routes=echo_url_routes) - client = TestClient(app, base_url="https://www.example.org/", root_path="/sub_path") + client = test_client_factory( + app, base_url="https://www.example.org/", root_path="/sub_path" + ) response = client.get("/") assert response.json() == { "index": "https://www.example.org/sub_path/", @@ -458,17 +469,17 @@ def test_url_for_with_double_mount(): assert url == "/mount/static/123" -def test_standalone_route_matches(): +def test_standalone_route_matches(test_client_factory): app = Route("/", PlainTextResponse("Hello, World!")) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.status_code == 200 assert response.text == "Hello, World!" -def test_standalone_route_does_not_match(): +def test_standalone_route_does_not_match(test_client_factory): app = Route("/", PlainTextResponse("Hello, World!")) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/invalid") assert response.status_code == 404 assert response.text == "Not Found" @@ -480,22 +491,23 @@ async def ws_helloworld(websocket): await websocket.close() -def test_standalone_ws_route_matches(): +def test_standalone_ws_route_matches(test_client_factory): app = WebSocketRoute("/", ws_helloworld) - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("/") as websocket: text = websocket.receive_text() assert text == "Hello, world!" -def test_standalone_ws_route_does_not_match(): +def test_standalone_ws_route_does_not_match(test_client_factory): app = WebSocketRoute("/", ws_helloworld) - client = TestClient(app) + client = test_client_factory(app) with pytest.raises(WebSocketDisconnect): - client.websocket_connect("/invalid") + with client.websocket_connect("/invalid"): + pass # pragma: nocover -def test_lifespan_async(): +def test_lifespan_async(test_client_factory): startup_complete = False shutdown_complete = False @@ -518,7 +530,7 @@ async def run_shutdown(): assert not startup_complete assert not shutdown_complete - with TestClient(app) as client: + with test_client_factory(app) as client: assert startup_complete assert not shutdown_complete client.get("/") @@ -526,7 +538,7 @@ async def run_shutdown(): assert shutdown_complete -def test_lifespan_sync(): +def test_lifespan_sync(test_client_factory): startup_complete = False shutdown_complete = False @@ -549,7 +561,7 @@ def run_shutdown(): assert not startup_complete assert not shutdown_complete - with TestClient(app) as client: + with test_client_factory(app) as client: assert startup_complete assert not shutdown_complete client.get("/") @@ -557,7 +569,7 @@ def run_shutdown(): assert shutdown_complete -def test_raise_on_startup(): +def test_raise_on_startup(test_client_factory): def run_startup(): raise RuntimeError() @@ -574,19 +586,19 @@ async def _send(message): startup_failed = False with pytest.raises(RuntimeError): - with TestClient(app): + with test_client_factory(app): pass # pragma: nocover assert startup_failed -def test_raise_on_shutdown(): +def test_raise_on_shutdown(test_client_factory): def run_shutdown(): raise RuntimeError() app = Router(on_shutdown=[run_shutdown]) with pytest.raises(RuntimeError): - with TestClient(app): + with test_client_factory(app): pass # pragma: nocover @@ -613,8 +625,8 @@ async def _partial_async_endpoint(arg, request): ) -def test_partial_async_endpoint(): - test_client = TestClient(partial_async_app) +def test_partial_async_endpoint(test_client_factory): + test_client = test_client_factory(partial_async_app) response = test_client.get("/") assert response.status_code == 200 assert response.json() == {"arg": "foo"} diff --git a/tests/test_schemas.py b/tests/test_schemas.py index 0ae43238fe..28fe777f09 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -1,7 +1,6 @@ from starlette.applications import Starlette from starlette.endpoints import HTTPEndpoint from starlette.schemas import SchemaGenerator -from starlette.testclient import TestClient schemas = SchemaGenerator( {"openapi": "3.0.0", "info": {"title": "Example API", "version": "1.0"}} @@ -213,8 +212,8 @@ def test_schema_generation(): """ -def test_schema_endpoint(): - client = TestClient(app) +def test_schema_endpoint(test_client_factory): + client = test_client_factory(app) response = client.get("/schema") assert response.headers["Content-Type"] == "application/vnd.oai.openapi" assert response.text.strip() == EXPECTED_SCHEMA.strip() diff --git a/tests/test_staticfiles.py b/tests/test_staticfiles.py index 6b325071fa..d5ec1afc5e 100644 --- a/tests/test_staticfiles.py +++ b/tests/test_staticfiles.py @@ -1,43 +1,42 @@ -import asyncio import os import pathlib import time +import anyio import pytest from starlette.applications import Starlette from starlette.requests import Request from starlette.routing import Mount from starlette.staticfiles import StaticFiles -from starlette.testclient import TestClient -def test_staticfiles(tmpdir): +def test_staticfiles(tmpdir, test_client_factory): path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") app = StaticFiles(directory=tmpdir) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/example.txt") assert response.status_code == 200 assert response.text == "" -def test_staticfiles_with_pathlib(tmpdir): +def test_staticfiles_with_pathlib(tmpdir, test_client_factory): base_dir = pathlib.Path(tmpdir) path = base_dir / "example.txt" with open(path, "w") as file: file.write("") app = StaticFiles(directory=base_dir) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/example.txt") assert response.status_code == 200 assert response.text == "" -def test_staticfiles_head_with_middleware(tmpdir): +def test_staticfiles_head_with_middleware(tmpdir, test_client_factory): """ see https://github.com/encode/starlette/pull/935 """ @@ -53,51 +52,51 @@ async def does_nothing_middleware(request: Request, call_next): response = await call_next(request) return response - client = TestClient(app) + client = test_client_factory(app) response = client.head("/static/example.txt") assert response.status_code == 200 assert response.headers.get("content-length") == "100" -def test_staticfiles_with_package(): +def test_staticfiles_with_package(test_client_factory): app = StaticFiles(packages=["tests"]) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/example.txt") assert response.status_code == 200 assert response.text == "123\n" -def test_staticfiles_post(tmpdir): +def test_staticfiles_post(tmpdir, test_client_factory): path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") app = StaticFiles(directory=tmpdir) - client = TestClient(app) + client = test_client_factory(app) response = client.post("/example.txt") assert response.status_code == 405 assert response.text == "Method Not Allowed" -def test_staticfiles_with_directory_returns_404(tmpdir): +def test_staticfiles_with_directory_returns_404(tmpdir, test_client_factory): path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") app = StaticFiles(directory=tmpdir) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.status_code == 404 assert response.text == "Not Found" -def test_staticfiles_with_missing_file_returns_404(tmpdir): +def test_staticfiles_with_missing_file_returns_404(tmpdir, test_client_factory): path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") app = StaticFiles(directory=tmpdir) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/404.txt") assert response.status_code == 404 assert response.text == "Not Found" @@ -110,30 +109,32 @@ def test_staticfiles_instantiated_with_missing_directory(tmpdir): assert "does not exist" in str(exc_info.value) -def test_staticfiles_configured_with_missing_directory(tmpdir): +def test_staticfiles_configured_with_missing_directory(tmpdir, test_client_factory): path = os.path.join(tmpdir, "no_such_directory") app = StaticFiles(directory=path, check_dir=False) - client = TestClient(app) + client = test_client_factory(app) with pytest.raises(RuntimeError) as exc_info: client.get("/example.txt") assert "does not exist" in str(exc_info.value) -def test_staticfiles_configured_with_file_instead_of_directory(tmpdir): +def test_staticfiles_configured_with_file_instead_of_directory( + tmpdir, test_client_factory +): path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") app = StaticFiles(directory=path, check_dir=False) - client = TestClient(app) + client = test_client_factory(app) with pytest.raises(RuntimeError) as exc_info: client.get("/example.txt") assert "is not a directory" in str(exc_info.value) -def test_staticfiles_config_check_occurs_only_once(tmpdir): +def test_staticfiles_config_check_occurs_only_once(tmpdir, test_client_factory): app = StaticFiles(directory=tmpdir) - client = TestClient(app) + client = test_client_factory(app) assert not app.config_checked client.get("/") assert app.config_checked @@ -153,32 +154,31 @@ def test_staticfiles_prevents_breaking_out_of_directory(tmpdir): # We can't test this with 'requests', so we test the app directly here. path = app.get_path({"path": "/../example.txt"}) scope = {"method": "GET"} - loop = asyncio.get_event_loop() - response = loop.run_until_complete(app.get_response(path, scope)) + response = anyio.run(app.get_response, path, scope) assert response.status_code == 404 assert response.body == b"Not Found" -def test_staticfiles_never_read_file_for_head_method(tmpdir): +def test_staticfiles_never_read_file_for_head_method(tmpdir, test_client_factory): path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") app = StaticFiles(directory=tmpdir) - client = TestClient(app) + client = test_client_factory(app) response = client.head("/example.txt") assert response.status_code == 200 assert response.content == b"" assert response.headers["content-length"] == "14" -def test_staticfiles_304_with_etag_match(tmpdir): +def test_staticfiles_304_with_etag_match(tmpdir, test_client_factory): path = os.path.join(tmpdir, "example.txt") with open(path, "w") as file: file.write("") app = StaticFiles(directory=tmpdir) - client = TestClient(app) + client = test_client_factory(app) first_resp = client.get("/example.txt") assert first_resp.status_code == 200 last_etag = first_resp.headers["etag"] @@ -187,7 +187,9 @@ def test_staticfiles_304_with_etag_match(tmpdir): assert second_resp.content == b"" -def test_staticfiles_304_with_last_modified_compare_last_req(tmpdir): +def test_staticfiles_304_with_last_modified_compare_last_req( + tmpdir, test_client_factory +): path = os.path.join(tmpdir, "example.txt") file_last_modified_time = time.mktime( time.strptime("2013-10-10 23:40:00", "%Y-%m-%d %H:%M:%S") @@ -197,7 +199,7 @@ def test_staticfiles_304_with_last_modified_compare_last_req(tmpdir): os.utime(path, (file_last_modified_time, file_last_modified_time)) app = StaticFiles(directory=tmpdir) - client = TestClient(app) + client = test_client_factory(app) # last modified less than last request, 304 response = client.get( "/example.txt", headers={"If-Modified-Since": "Thu, 11 Oct 2013 15:30:19 GMT"} @@ -212,7 +214,7 @@ def test_staticfiles_304_with_last_modified_compare_last_req(tmpdir): assert response.content == b"" -def test_staticfiles_html(tmpdir): +def test_staticfiles_html(tmpdir, test_client_factory): path = os.path.join(tmpdir, "404.html") with open(path, "w") as file: file.write("

Custom not found page

") @@ -223,7 +225,7 @@ def test_staticfiles_html(tmpdir): file.write("

Hello

") app = StaticFiles(directory=tmpdir, html=True) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/dir/") assert response.url == "http://testserver/dir/" @@ -245,7 +247,9 @@ def test_staticfiles_html(tmpdir): assert response.text == "

Custom not found page

" -def test_staticfiles_cache_invalidation_for_deleted_file_html_mode(tmpdir): +def test_staticfiles_cache_invalidation_for_deleted_file_html_mode( + tmpdir, test_client_factory +): path_404 = os.path.join(tmpdir, "404.html") with open(path_404, "w") as file: file.write("

404 file

") @@ -260,7 +264,7 @@ def test_staticfiles_cache_invalidation_for_deleted_file_html_mode(tmpdir): os.utime(path_some, (common_modified_time, common_modified_time)) app = StaticFiles(directory=tmpdir, html=True) - client = TestClient(app) + client = test_client_factory(app) resp_exists = client.get("/some.html") assert resp_exists.status_code == 200 diff --git a/tests/test_templates.py b/tests/test_templates.py index a0ab3e1b0b..073482d65a 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -4,10 +4,9 @@ from starlette.applications import Starlette from starlette.templating import Jinja2Templates -from starlette.testclient import TestClient -def test_templates(tmpdir): +def test_templates(tmpdir, test_client_factory): path = os.path.join(tmpdir, "index.html") with open(path, "w") as file: file.write("Hello, world") @@ -19,7 +18,7 @@ def test_templates(tmpdir): async def homepage(request): return templates.TemplateResponse("index.html", {"request": request}) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.text == "Hello, world" assert response.template.name == "index.html" diff --git a/tests/test_testclient.py b/tests/test_testclient.py index 00f4e0125b..8c06667896 100644 --- a/tests/test_testclient.py +++ b/tests/test_testclient.py @@ -1,13 +1,24 @@ import asyncio +import itertools +import sys +import anyio import pytest +import sniffio +import trio.lowlevel from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.responses import JSONResponse -from starlette.testclient import TestClient from starlette.websockets import WebSocket, WebSocketDisconnect +if sys.version_info >= (3, 7): # pragma: no cover + from asyncio import current_task as asyncio_current_task + from contextlib import asynccontextmanager +else: # pragma: no cover + asyncio_current_task = asyncio.Task.current_task + from contextlib2 import asynccontextmanager + mock_service = Starlette() @@ -16,14 +27,19 @@ def mock_service_endpoint(request): return JSONResponse({"mock": "example"}) -app = Starlette() - +def current_task(): + # anyio's TaskInfo comparisons are invalid after their associated native + # task object is GC'd https://github.com/agronholm/anyio/issues/324 + asynclib_name = sniffio.current_async_library() + if asynclib_name == "trio": + return trio.lowlevel.current_task() -@app.route("/") -def homepage(request): - client = TestClient(mock_service) - response = client.get("/") - return JSONResponse(response.json()) + if asynclib_name == "asyncio": + task = asyncio_current_task() + if task is None: + raise RuntimeError("must be called from a running task") # pragma: no cover + return task + raise RuntimeError(f"unsupported asynclib={asynclib_name}") # pragma: no cover startup_error_app = Starlette() @@ -34,30 +50,110 @@ def startup(): raise RuntimeError() -def test_use_testclient_in_endpoint(): +def test_use_testclient_in_endpoint(test_client_factory): """ We should be able to use the test client within applications. This is useful if we need to mock out other services, during tests or in development. """ - client = TestClient(app) + + app = Starlette() + + @app.route("/") + def homepage(request): + client = test_client_factory(mock_service) + response = client.get("/") + return JSONResponse(response.json()) + + client = test_client_factory(app) response = client.get("/") assert response.json() == {"mock": "example"} -def test_use_testclient_as_contextmanager(): - with TestClient(app): - pass +def test_use_testclient_as_contextmanager(test_client_factory, anyio_backend_name): + """ + This test asserts a number of properties that are important for an + app level task_group + """ + counter = itertools.count() + identity_runvar = anyio.lowlevel.RunVar[int]("identity_runvar") + + def get_identity(): + try: + return identity_runvar.get() + except LookupError: + token = next(counter) + identity_runvar.set(token) + return token + + startup_task = object() + startup_loop = None + shutdown_task = object() + shutdown_loop = None + + @asynccontextmanager + async def lifespan_context(app): + nonlocal startup_task, startup_loop, shutdown_task, shutdown_loop + + startup_task = current_task() + startup_loop = get_identity() + async with anyio.create_task_group() as app.task_group: + yield + shutdown_task = current_task() + shutdown_loop = get_identity() + + app = Starlette(lifespan=lifespan_context) + + @app.route("/loop_id") + async def loop_id(request): + return JSONResponse(get_identity()) + + client = test_client_factory(app) + + with client: + # within a TestClient context every async request runs in the same thread + assert client.get("/loop_id").json() == 0 + assert client.get("/loop_id").json() == 0 + + # that thread is also the same as the lifespan thread + assert startup_loop == 0 + assert shutdown_loop == 0 + + # lifespan events run in the same task, this is important because a task + # group must be entered and exited in the same task. + assert startup_task is shutdown_task + + # outside the TestClient context, new requests continue to spawn in new + # eventloops in new threads + assert client.get("/loop_id").json() == 1 + assert client.get("/loop_id").json() == 2 + + first_task = startup_task + + with client: + # the TestClient context can be re-used, starting a new lifespan task + # in a new thread + assert client.get("/loop_id").json() == 3 + assert client.get("/loop_id").json() == 3 + + assert startup_loop == 3 + assert shutdown_loop == 3 + + # lifespan events still run in the same task, with the context but... + assert startup_task is shutdown_task + + # ... the second TestClient context creates a new lifespan task. + assert first_task is not startup_task -def test_error_on_startup(): +def test_error_on_startup(test_client_factory): with pytest.raises(RuntimeError): - with TestClient(startup_error_app): + with test_client_factory(startup_error_app): pass # pragma: no cover -def test_exception_in_middleware(): +def test_exception_in_middleware(test_client_factory): class MiddlewareException(Exception): pass @@ -71,11 +167,11 @@ async def __call__(self, scope, receive, send): broken_middleware = Starlette(middleware=[Middleware(BrokenMiddleware)]) with pytest.raises(MiddlewareException): - with TestClient(broken_middleware): + with test_client_factory(broken_middleware): pass # pragma: no cover -def test_testclient_asgi2(): +def test_testclient_asgi2(test_client_factory): def app(scope): async def inner(receive, send): await send( @@ -89,12 +185,12 @@ async def inner(receive, send): return inner - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.text == "Hello, world!" -def test_testclient_asgi3(): +def test_testclient_asgi3(test_client_factory): async def app(scope, receive, send): await send( { @@ -105,12 +201,12 @@ async def app(scope, receive, send): ) await send({"type": "http.response.body", "body": b"Hello, world!"}) - client = TestClient(app) + client = test_client_factory(app) response = client.get("/") assert response.text == "Hello, world!" -def test_websocket_blocking_receive(): +def test_websocket_blocking_receive(test_client_factory): def app(scope): async def respond(websocket): await websocket.send_json({"message": "test"}) @@ -118,17 +214,18 @@ async def respond(websocket): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() - asyncio.ensure_future(respond(websocket)) - try: - # this will block as the client does not send us data - # it should not prevent `respond` from executing though - await websocket.receive_json() - except WebSocketDisconnect: - pass + async with anyio.create_task_group() as task_group: + task_group.start_soon(respond, websocket) + try: + # this will block as the client does not send us data + # it should not prevent `respond` from executing though + await websocket.receive_json() + except WebSocketDisconnect: + pass return asgi - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("/") as websocket: data = websocket.receive_json() assert data == {"message": "test"} diff --git a/tests/test_websockets.py b/tests/test_websockets.py index ffb1a44a8c..e02d433d57 100644 --- a/tests/test_websockets.py +++ b/tests/test_websockets.py @@ -1,14 +1,11 @@ -import asyncio - +import anyio import pytest from starlette import status -from starlette.concurrency import run_until_first_complete -from starlette.testclient import TestClient from starlette.websockets import WebSocket, WebSocketDisconnect -def test_websocket_url(): +def test_websocket_url(test_client_factory): def app(scope): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) @@ -18,13 +15,13 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("/123?a=abc") as websocket: data = websocket.receive_json() assert data == {"url": "ws://testserver/123?a=abc"} -def test_websocket_binary_json(): +def test_websocket_binary_json(test_client_factory): def app(scope): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) @@ -35,14 +32,14 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("/123?a=abc") as websocket: websocket.send_json({"test": "data"}, mode="binary") data = websocket.receive_json(mode="binary") assert data == {"test": "data"} -def test_websocket_query_params(): +def test_websocket_query_params(test_client_factory): def app(scope): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) @@ -53,13 +50,13 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("/?a=abc&b=456") as websocket: data = websocket.receive_json() assert data == {"params": {"a": "abc", "b": "456"}} -def test_websocket_headers(): +def test_websocket_headers(test_client_factory): def app(scope): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) @@ -70,7 +67,7 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("/") as websocket: expected_headers = { "accept": "*/*", @@ -85,7 +82,7 @@ async def asgi(receive, send): assert data == {"headers": expected_headers} -def test_websocket_port(): +def test_websocket_port(test_client_factory): def app(scope): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) @@ -95,13 +92,13 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("ws://example.com:123/123?a=abc") as websocket: data = websocket.receive_json() assert data == {"port": 123} -def test_websocket_send_and_receive_text(): +def test_websocket_send_and_receive_text(test_client_factory): def app(scope): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) @@ -112,14 +109,14 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("/") as websocket: websocket.send_text("Hello, world!") data = websocket.receive_text() assert data == "Message was: Hello, world!" -def test_websocket_send_and_receive_bytes(): +def test_websocket_send_and_receive_bytes(test_client_factory): def app(scope): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) @@ -130,14 +127,14 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("/") as websocket: websocket.send_bytes(b"Hello, world!") data = websocket.receive_bytes() assert data == b"Message was: Hello, world!" -def test_websocket_send_and_receive_json(): +def test_websocket_send_and_receive_json(test_client_factory): def app(scope): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) @@ -148,14 +145,14 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("/") as websocket: websocket.send_json({"hello": "world"}) data = websocket.receive_json() assert data == {"message": {"hello": "world"}} -def test_websocket_iter_text(): +def test_websocket_iter_text(test_client_factory): def app(scope): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) @@ -165,14 +162,14 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("/") as websocket: websocket.send_text("Hello, world!") data = websocket.receive_text() assert data == "Message was: Hello, world!" -def test_websocket_iter_bytes(): +def test_websocket_iter_bytes(test_client_factory): def app(scope): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) @@ -182,14 +179,14 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("/") as websocket: websocket.send_bytes(b"Hello, world!") data = websocket.receive_bytes() assert data == b"Message was: Hello, world!" -def test_websocket_iter_json(): +def test_websocket_iter_json(test_client_factory): def app(scope): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) @@ -199,44 +196,45 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("/") as websocket: websocket.send_json({"hello": "world"}) data = websocket.receive_json() assert data == {"message": {"hello": "world"}} -def test_websocket_concurrency_pattern(): +def test_websocket_concurrency_pattern(test_client_factory): def app(scope): - async def reader(websocket, queue): - async for data in websocket.iter_json(): - await queue.put(data) + stream_send, stream_receive = anyio.create_memory_object_stream() + + async def reader(websocket): + async with stream_send: + async for data in websocket.iter_json(): + await stream_send.send(data) - async def writer(websocket, queue): - while True: - message = await queue.get() - await websocket.send_json(message) + async def writer(websocket): + async with stream_receive: + async for message in stream_receive: + await websocket.send_json(message) async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) - queue = asyncio.Queue() await websocket.accept() - await run_until_first_complete( - (reader, {"websocket": websocket, "queue": queue}), - (writer, {"websocket": websocket, "queue": queue}), - ) + async with anyio.create_task_group() as task_group: + task_group.start_soon(reader, websocket) + await writer(websocket) await websocket.close() return asgi - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("/") as websocket: websocket.send_json({"hello": "world"}) data = websocket.receive_json() assert data == {"hello": "world"} -def test_client_close(): +def test_client_close(test_client_factory): close_code = None def app(scope): @@ -251,13 +249,13 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("/") as websocket: websocket.close(code=status.WS_1001_GOING_AWAY) assert close_code == status.WS_1001_GOING_AWAY -def test_application_close(): +def test_application_close(test_client_factory): def app(scope): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) @@ -266,14 +264,14 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("/") as websocket: with pytest.raises(WebSocketDisconnect) as exc: websocket.receive_text() assert exc.value.code == status.WS_1001_GOING_AWAY -def test_rejected_connection(): +def test_rejected_connection(test_client_factory): def app(scope): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) @@ -281,13 +279,14 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with pytest.raises(WebSocketDisconnect) as exc: - client.websocket_connect("/") + with client.websocket_connect("/"): + pass # pragma: nocover assert exc.value.code == status.WS_1001_GOING_AWAY -def test_subprotocol(): +def test_subprotocol(test_client_factory): def app(scope): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) @@ -297,24 +296,25 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with client.websocket_connect("/", subprotocols=["soap", "wamp"]) as websocket: assert websocket.accepted_subprotocol == "wamp" -def test_websocket_exception(): +def test_websocket_exception(test_client_factory): def app(scope): async def asgi(receive, send): assert False return asgi - client = TestClient(app) + client = test_client_factory(app) with pytest.raises(AssertionError): - client.websocket_connect("/123?a=abc") + with client.websocket_connect("/123?a=abc"): + pass # pragma: nocover -def test_duplicate_close(): +def test_duplicate_close(test_client_factory): def app(scope): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) @@ -324,13 +324,13 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with pytest.raises(RuntimeError): with client.websocket_connect("/"): - pass + pass # pragma: nocover -def test_duplicate_disconnect(): +def test_duplicate_disconnect(test_client_factory): def app(scope): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) @@ -341,7 +341,7 @@ async def asgi(receive, send): return asgi - client = TestClient(app) + client = test_client_factory(app) with pytest.raises(RuntimeError): with client.websocket_connect("/") as websocket: websocket.close() @@ -367,3 +367,13 @@ async def mock_send(message): assert websocket["type"] == "websocket" assert dict(websocket) == {"type": "websocket", "path": "/abc/", "headers": []} assert len(websocket) == 3 + + # check __eq__ and __hash__ + assert websocket != WebSocket( + {"type": "websocket", "path": "/abc/", "headers": []}, + receive=mock_receive, + send=mock_send, + ) + assert websocket == websocket + assert websocket in {websocket} + assert {websocket} == {websocket}