diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml
new file mode 100644
index 0000000000..2f87d94ca1
--- /dev/null
+++ b/.github/FUNDING.yml
@@ -0,0 +1 @@
+github: encode
diff --git a/.github/workflows/test-suite.yml b/.github/workflows/test-suite.yml
index eed46850a9..751c5193be 100644
--- a/.github/workflows/test-suite.yml
+++ b/.github/workflows/test-suite.yml
@@ -14,7 +14,7 @@ jobs:
strategy:
matrix:
- python-version: ["3.6", "3.7", "3.8", "3.9"]
+ python-version: ["3.6", "3.7", "3.8", "3.9", "3.10.0-beta.3"]
steps:
- uses: "actions/checkout@v2"
diff --git a/README.md b/README.md
index 7e315627ed..28b6f49854 100644
--- a/README.md
+++ b/README.md
@@ -22,7 +22,7 @@
# Starlette
Starlette is a lightweight [ASGI](https://asgi.readthedocs.io/en/latest/) framework/toolkit,
-which is ideal for building high performance asyncio services.
+which is ideal for building high performance async services.
It is production-ready, and gives you the following:
@@ -35,7 +35,8 @@ It is production-ready, and gives you the following:
* Session and Cookie support.
* 100% test coverage.
* 100% type annotated codebase.
-* Zero hard dependencies.
+* Few hard dependencies.
+* Compatible with `asyncio` and `trio` backends.
## Requirements
@@ -83,10 +84,9 @@ For a more complete example, see [encode/starlette-example](https://github.com/e
## Dependencies
-Starlette does not have any hard dependencies, but the following are optional:
+Starlette only requires `anyio`, and the following are optional:
* [`requests`][requests] - Required if you want to use the `TestClient`.
-* [`aiofiles`][aiofiles] - Required if you want to use `FileResponse` or `StaticFiles`.
* [`jinja2`][jinja2] - Required if you want to use `Jinja2Templates`.
* [`python-multipart`][python-multipart] - Required if you want to support form parsing, with `request.form()`.
* [`itsdangerous`][itsdangerous] - Required for `SessionMiddleware` support.
@@ -165,7 +165,6 @@ gunicorn -k uvicorn.workers.UvicornH11Worker ...
Starlette is BSD licensed code. Designed & built in Brighton, England.
[requests]: http://docs.python-requests.org/en/master/
-[aiofiles]: https://github.com/Tinche/aiofiles
[jinja2]: http://jinja.pocoo.org/
[python-multipart]: https://andrew-d.github.io/python-multipart/
[graphene]: https://graphene-python.org/
diff --git a/docs/config.md b/docs/config.md
index 7a93b22e9e..f7c2c7b7de 100644
--- a/docs/config.md
+++ b/docs/config.md
@@ -160,7 +160,7 @@ organisations = sqlalchemy.Table(
```python
from starlette.applications import Starlette
from starlette.middleware import Middleware
-from starlette.middleware.session import SessionMiddleware
+from starlette.middleware.sessions import SessionMiddleware
from starlette.routing import Route
from myproject import settings
@@ -192,7 +192,7 @@ and drop it once the tests complete. We'd also like to ensure
from starlette.config import environ
from starlette.testclient import TestClient
from sqlalchemy import create_engine
-from sqlalchemy_utils import database_exists, create_database
+from sqlalchemy_utils import create_database, database_exists, drop_database
# This line would raise an error if we use it after 'settings' has been imported.
environ['TESTING'] = 'TRUE'
diff --git a/docs/database.md b/docs/database.md
index ca1b85d6a2..aa6cb74edf 100644
--- a/docs/database.md
+++ b/docs/database.md
@@ -142,10 +142,10 @@ async def populate_note(request):
await database.execute(query)
raise RuntimeError()
except:
- transaction.rollback()
+ await transaction.rollback()
raise
else:
- transaction.commit()
+ await transaction.commit()
```
## Test isolation
diff --git a/docs/events.md b/docs/events.md
index 4f2bce558b..c7ed49e9dc 100644
--- a/docs/events.md
+++ b/docs/events.md
@@ -37,6 +37,31 @@ registered startup handlers have completed.
The shutdown handlers will run once all connections have been closed, and
any in-process background tasks have completed.
+A single lifespan asynccontextmanager handler can be used instead of
+separate startup and shutdown handlers:
+
+```python
+import contextlib
+import anyio
+from starlette.applications import Starlette
+
+
+@contextlib.asynccontextmanager
+async def lifespan(app):
+ async with some_async_resource():
+ yield
+
+
+routes = [
+ ...
+]
+
+app = Starlette(routes=routes, lifespan=lifespan)
+```
+
+Consider using [`anyio.create_task_group()`](https://anyio.readthedocs.io/en/stable/tasks.html)
+for managing asynchronious tasks.
+
## Running event handlers in tests
You might want to explicitly call into your event handlers in any test setup
diff --git a/docs/img/graphiql.png b/docs/img/graphiql.png
deleted file mode 100644
index 7851993f7f..0000000000
Binary files a/docs/img/graphiql.png and /dev/null differ
diff --git a/docs/index.md b/docs/index.md
index 4ae77f0e60..b9692a1fbc 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -32,7 +32,7 @@ It is production-ready, and gives you the following:
* Session and Cookie support.
* 100% test coverage.
* 100% type annotated codebase.
-* Zero hard dependencies.
+* Few hard dependencies.
## Requirements
@@ -79,10 +79,9 @@ For a more complete example, [see here](https://github.com/encode/starlette-exam
## Dependencies
-Starlette does not have any hard dependencies, but the following are optional:
+Starlette only requires `anyio`, and the following dependencies are optional:
* [`requests`][requests] - Required if you want to use the `TestClient`.
-* [`aiofiles`][aiofiles] - Required if you want to use `FileResponse` or `StaticFiles`.
* [`jinja2`][jinja2] - Required if you want to use `Jinja2Templates`.
* [`python-multipart`][python-multipart] - Required if you want to support form parsing, with `request.form()`.
* [`itsdangerous`][itsdangerous] - Required for `SessionMiddleware` support.
@@ -161,7 +160,6 @@ gunicorn -k uvicorn.workers.UvicornH11Worker ...
Starlette is BSD licensed code. Designed & built in Brighton, England.
[requests]: http://docs.python-requests.org/en/master/
-[aiofiles]: https://github.com/Tinche/aiofiles
[jinja2]: http://jinja.pocoo.org/
[python-multipart]: https://andrew-d.github.io/python-multipart/
[graphene]: https://graphene-python.org/
diff --git a/docs/middleware.md b/docs/middleware.md
index 7d1233d100..4c6fb8a9e0 100644
--- a/docs/middleware.md
+++ b/docs/middleware.md
@@ -183,6 +183,8 @@ To implement a middleware class using `BaseHTTPMiddleware`, you must override th
`async def dispatch(request, call_next)` method.
```python
+from starlette.middleware.base import BaseHTTPMiddleware
+
class CustomHeaderMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request, call_next):
response = await call_next(request)
@@ -256,7 +258,7 @@ when proxy servers are being used, based on the `X-Forwarded-Proto` and `X-Forwa
A middleware class to emit timing information (cpu and wall time) for each request which
passes through it. Includes examples for how to emit these timings as statsd metrics.
-#### [datasette-auth-github](https://github.com/simonw/datasette-auth-github)
+#### [asgi-auth-github](https://github.com/simonw/asgi-auth-github)
This middleware adds authentication to any ASGI application, requiring users to sign in
using their GitHub account (via [OAuth](https://developer.github.com/apps/building-oauth-apps/authorizing-oauth-apps/)).
diff --git a/docs/overrides/partials/nav.html b/docs/overrides/partials/nav.html
new file mode 100644
index 0000000000..d4684d0a68
--- /dev/null
+++ b/docs/overrides/partials/nav.html
@@ -0,0 +1,52 @@
+
+ {% set class = "md-nav md-nav--primary" %}
+ {% if "navigation.tabs" in features %}
+ {% set class = class ~ " md-nav--lifted" %}
+ {% endif %}
+ {% if "toc.integrate" in features %}
+ {% set class = class ~ " md-nav--integrated" %}
+ {% endif %}
+
+
+
+
+
+
+
+ {% include "partials/logo.html" %}
+
+ {{ config.site_name }}
+
+
+
+ {% if config.repo_url %}
+
+ {% include "partials/source.html" %}
+
+ {% endif %}
+
+
+
+ {% for nav_item in nav %}
+ {% set path = "__nav_" ~ loop.index %}
+ {% set level = 1 %}
+ {% include "partials/nav-item.html" %}
+ {% endfor %}
+
+
+
+
+
+
+
+
diff --git a/docs/release-notes.md b/docs/release-notes.md
index 08f1e2895b..7305046f9d 100644
--- a/docs/release-notes.md
+++ b/docs/release-notes.md
@@ -1,11 +1,72 @@
+## 0.16.0
+
+July 19, 2021
+
+### Added
+ * Added [Encode](https://github.com/sponsors/encode) funding option
+ [#1219](https://github.com/encode/starlette/pull/1219)
+
+### Fixed
+ * `starlette.websockets.WebSocket` instances are now hashable and compare by identity
+ [#1039](https://github.com/encode/starlette/pull/1039)
+ * A number of fixes related to running task groups in lifespan
+ [#1213](https://github.com/encode/starlette/pull/1213),
+ [#1227](https://github.com/encode/starlette/pull/1227)
+
+### Deprecated/removed
+ * The method `starlette.templates.Jinja2Templates.get_env` was removed
+ [#1218](https://github.com/encode/starlette/pull/1218)
+ * The ClassVar `starlette.testclient.TestClient.async_backend` was removed,
+ the backend is now configured using constructor kwargs
+ [#1211](https://github.com/encode/starlette/pull/1211)
+ * Passing an Async Generator Function or a Generator Function to `starlette.router.Router(lifespan_context=)` is deprecated. You should wrap your lifespan in `@contextlib.asynccontextmanager`.
+ [#1227](https://github.com/encode/starlette/pull/1227)
+ [#1110](https://github.com/encode/starlette/pull/1110)
+
## 0.15.0
-Unreleased
+June 23, 2021
+
+This release includes major changes to the low-level asynchronous parts of Starlette. As a result,
+**Starlette now depends on [AnyIO](https://anyio.readthedocs.io/en/stable/)** and some minor API
+changes have occurred. Another significant change with this release is the
+**deprecation of built-in GraphQL support**.
-### Deprecated
+### Added
+* Starlette now supports [Trio](https://trio.readthedocs.io/en/stable/) as an async runtime via
+ AnyIO - [#1157](https://github.com/encode/starlette/pull/1157).
+* `TestClient.websocket_connect()` now must be used as a context manager.
+* Initial support for Python 3.10 - [#1201](https://github.com/encode/starlette/pull/1201).
+* The compression level used in `GZipMiddleware` is now adjustable -
+ [#1128](https://github.com/encode/starlette/pull/1128).
+
+### Fixed
+* Several fixes to `CORSMiddleware`. See [#1111](https://github.com/encode/starlette/pull/1111),
+ [#1112](https://github.com/encode/starlette/pull/1112),
+ [#1113](https://github.com/encode/starlette/pull/1113),
+ [#1199](https://github.com/encode/starlette/pull/1199).
+* Improved exception messages in the case of duplicated path parameter names -
+ [#1177](https://github.com/encode/starlette/pull/1177).
+* `RedirectResponse` now uses `quote` instead of `quote_plus` encoding for the `Location` header
+ to better match the behaviour in other frameworks such as Django -
+ [#1164](https://github.com/encode/starlette/pull/1164).
+* Exception causes are now preserved in more cases -
+ [#1158](https://github.com/encode/starlette/pull/1158).
+* Session cookies now use the ASGI root path in the case of mounted applications -
+ [#1147](https://github.com/encode/starlette/pull/1147).
+* Fixed a cache invalidation bug when static files were deleted in certain circumstances -
+ [#1023](https://github.com/encode/starlette/pull/1023).
+* Improved memory usage of `BaseHTTPMiddleware` when handling large responses -
+ [#1012](https://github.com/encode/starlette/issues/1012) fixed via #1157
+
+### Deprecated/removed
* Built-in GraphQL support via the `GraphQLApp` class has been deprecated and will be removed in a
- future release. Please see [#619](https://github.com/encode/starlette/issues/619).
+ future release. Please see [#619](https://github.com/encode/starlette/issues/619). GraphQL is not
+ supported on Python 3.10.
+* The `executor` parameter to `GraphQLApp` was removed. Use `executor_class` instead.
+* The `workers` parameter to `WSGIMiddleware` was removed. This hasn't had any effect since
+ Starlette v0.6.3.
## 0.14.2
diff --git a/docs/responses.md b/docs/responses.md
index c4cd84ed32..5284ac5044 100644
--- a/docs/responses.md
+++ b/docs/responses.md
@@ -182,9 +182,8 @@ async def app(scope, receive, send):
await response(scope, receive, send)
```
-## Third party middleware
+## Third party responses
-### [SSEResponse(EventSourceResponse)](https://github.com/sysid/sse-starlette)
+#### [EventSourceResponse](https://github.com/sysid/sse-starlette)
-Server Sent Response implements the ServerSentEvent Protocol: https://www.w3.org/TR/2009/WD-eventsource-20090421.
-It enables event streaming from the server to the client without the complexity of websockets.
+A response class that implements [Server-Sent Events](https://html.spec.whatwg.org/multipage/server-sent-events.html). It enables event streaming from the server to the client without the complexity of websockets.
diff --git a/docs/schemas.md b/docs/schemas.md
index 2530ba8d1f..275e7b2968 100644
--- a/docs/schemas.md
+++ b/docs/schemas.md
@@ -51,7 +51,7 @@ routes = [
Route("/schema", endpoint=openapi_schema, include_in_schema=False)
]
-app = Starlette()
+app = Starlette(routes=routes)
```
We can now access an OpenAPI schema at the "/schema" endpoint.
diff --git a/docs/sponsors/fastapi.png b/docs/sponsors/fastapi.png
new file mode 100644
index 0000000000..a5b2af17eb
Binary files /dev/null and b/docs/sponsors/fastapi.png differ
diff --git a/docs/testclient.md b/docs/testclient.md
index 61f7201c62..a1861efec7 100644
--- a/docs/testclient.md
+++ b/docs/testclient.md
@@ -31,6 +31,29 @@ application. Occasionally you might want to test the content of 500 error
responses, rather than allowing client to raise the server exception. In this
case you should use `client = TestClient(app, raise_server_exceptions=False)`.
+### Selecting the Async backend
+
+`TestClient` takes arguments `backend` (a string) and `backend_options` (a dictionary).
+These options are passed to `anyio.start_blocking_portal()`. See the [anyio documentation](https://anyio.readthedocs.io/en/stable/basics.html#backend-options)
+for more information about the accepted backend options.
+By default, `asyncio` is used with default options.
+
+To run `Trio`, pass `backend="trio"`. For example:
+
+```python
+def test_app()
+ with TestClient(app, backend="trio") as client:
+ ...
+```
+
+To run `asyncio` with `uvloop`, pass `backend_options={"use_uvloop": True}`. For example:
+
+```python
+def test_app()
+ with TestClient(app, backend_options={"use_uvloop": True}) as client:
+ ...
+```
+
### Testing WebSocket sessions
You can also test websocket sessions with the test client.
@@ -72,6 +95,8 @@ always raised by the test client.
May raise `starlette.websockets.WebSocketDisconnect` if the application does not accept the websocket connection.
+`websocket_connect()` must be used as a context manager (in a `with` block).
+
#### Sending data
* `.send_text(data)` - Send the given text to the application.
diff --git a/docs/third-party-packages.md b/docs/third-party-packages.md
index 71902ce83f..a21f31ba2c 100644
--- a/docs/third-party-packages.md
+++ b/docs/third-party-packages.md
@@ -20,7 +20,7 @@ Simple APISpec integration for Starlette.
Document your REST API built with Starlette by declaring OpenAPI (Swagger)
schemas in YAML format in your endpoint's docstrings.
-### SpecTree
+### SpecTree
GitHub
@@ -43,7 +43,7 @@ Checkout nejma
GitHub
-Another solution for websocket broadcast. Send messages to channel groups from any part of your code.
+Another solution for websocket broadcast. Send messages to channel groups from any part of your code.
Checkout channel-box-chat , a simple chat application built using `channel-box` and `starlette`.
### Scout APM
@@ -90,6 +90,21 @@ It relies solely on an auth provider to issue access and/or id tokens to clients
Middleware for Starlette that allows you to store and access the context data of a request.
Can be used with logging so logs automatically use request headers such as x-request-id or x-correlation-id.
+
+### Starsessions
+
+GitHub
+
+An alternate session support implementation with customizable storage backends.
+
+
+### Starlette Cramjam
+
+GitHub
+
+A Starlette middleware that allows **brotli**, **gzip** and **deflate** compression algorithm with a minimal requirements.
+
+
## Frameworks
### Responder
@@ -116,3 +131,9 @@ Inspired by **APIStar**'s previous server system with type declarations for rout
Formerly Starlette API.
Flama aims to bring a layer on top of Starlette to provide an **easy to learn** and **fast to develop** approach for building **highly performant** GraphQL and REST APIs. In the same way of Starlette is, Flama is a perfect option for developing **asynchronous** and **production-ready** services.
+
+### Starlette-apps
+
+Roll your own framework with a simple app system, like [Django-GDAPS](https://gdaps.readthedocs.io/en/latest/) or [CakePHP](https://cakephp.org/).
+
+GitHub
diff --git a/mkdocs.yml b/mkdocs.yml
index 0f10c19ff6..2ab5faa680 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -1,8 +1,22 @@
site_name: Starlette
site_description: The little ASGI library that shines.
+site_url: https://www.starlette.io
theme:
name: 'material'
+ custom_dir: docs/overrides
+ palette:
+ - scheme: 'default'
+ media: '(prefers-color-scheme: light)'
+ toggle:
+ icon: 'material/lightbulb'
+ name: "Switch to dark mode"
+ - scheme: 'slate'
+ media: '(prefers-color-scheme: dark)'
+ primary: 'blue'
+ toggle:
+ icon: 'material/lightbulb-outline'
+ name: 'Switch to light mode'
repo_name: encode/starlette
repo_url: https://github.com/encode/starlette
diff --git a/requirements.txt b/requirements.txt
index 6ec5bf09ee..abc7a3b0a5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,5 @@
# Optionals
-aiofiles
-graphene
+graphene; python_version<'3.10'
itsdangerous
jinja2
python-multipart
@@ -10,17 +9,17 @@ requests
# Testing
autoflake
black==20.8b1
+coverage>=5.3
databases[sqlite]
flake8
isort==5.*
mypy
types-requests
types-contextvars
-types-aiofiles
types-PyYAML
+types-dataclasses
pytest
-pytest-cov
-pytest-asyncio
+trio
# Documentation
mkdocs
diff --git a/scripts/test b/scripts/test
index f9c9917233..720a66392d 100755
--- a/scripts/test
+++ b/scripts/test
@@ -11,7 +11,7 @@ if [ -z $GITHUB_ACTIONS ]; then
scripts/check
fi
-${PREFIX}pytest $@
+${PREFIX}coverage run -m pytest $@
if [ -z $GITHUB_ACTIONS ]; then
scripts/coverage
diff --git a/setup.cfg b/setup.cfg
index 196414c0a7..39f67c2150 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -17,9 +17,6 @@ combine_as_imports = True
[tool:pytest]
addopts =
- --cov-report=term-missing:skip-covered
- --cov=starlette
- --cov=tests
-rxXs
--strict-config
--strict-markers
@@ -32,3 +29,8 @@ filterwarnings=
ignore: Using or importing the ABCs from 'collections' instead of from 'collections\.abc' is deprecated.*:DeprecationWarning
ignore: The 'context' alias has been deprecated. Please use 'context_value' instead\.:DeprecationWarning
ignore: The 'variables' alias has been deprecated. Please use 'variable_values' instead\.:DeprecationWarning
+ # Workaround for Python 3.9.7 (see https://bugs.python.org/issue45097)
+ ignore:The loop argument is deprecated since Python 3\.8, and scheduled for removal in Python 3\.10\.:DeprecationWarning:asyncio
+
+[coverage:run]
+source_pkgs = starlette, tests
diff --git a/setup.py b/setup.py
index c48356370c..31789fe09d 100644
--- a/setup.py
+++ b/setup.py
@@ -37,10 +37,14 @@ def get_long_description():
packages=find_packages(exclude=["tests*"]),
package_data={"starlette": ["py.typed"]},
include_package_data=True,
+ install_requires=[
+ "anyio>=3.0.0,<4",
+ "typing_extensions; python_version < '3.8'",
+ "contextlib2 >= 21.6.0; python_version < '3.7'",
+ ],
extras_require={
"full": [
- "aiofiles",
- "graphene",
+ "graphene; python_version<'3.10'",
"itsdangerous",
"jinja2",
"python-multipart",
@@ -60,6 +64,7 @@ def get_long_description():
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
+ "Programming Language :: Python :: 3.10",
],
zip_safe=False,
)
diff --git a/starlette/__init__.py b/starlette/__init__.py
index 745162e731..5a313cc7ef 100644
--- a/starlette/__init__.py
+++ b/starlette/__init__.py
@@ -1 +1 @@
-__version__ = "0.14.2"
+__version__ = "0.16.0"
diff --git a/starlette/applications.py b/starlette/applications.py
index 34c3e38bd9..ea52ee70ee 100644
--- a/starlette/applications.py
+++ b/starlette/applications.py
@@ -46,7 +46,7 @@ def __init__(
] = None,
on_startup: typing.Sequence[typing.Callable] = None,
on_shutdown: typing.Sequence[typing.Callable] = None,
- lifespan: typing.Callable[["Starlette"], typing.AsyncGenerator] = None,
+ lifespan: typing.Callable[["Starlette"], typing.AsyncContextManager] = None,
) -> None:
# The lifespan context function is a newer style that replaces
# on_startup / on_shutdown handlers. Use one or the other, not both.
diff --git a/starlette/concurrency.py b/starlette/concurrency.py
index c8c5d57acb..e89d1e0471 100644
--- a/starlette/concurrency.py
+++ b/starlette/concurrency.py
@@ -1,33 +1,32 @@
-import asyncio
import functools
-import sys
import typing
from typing import Any, AsyncGenerator, Iterator
+import anyio
+
try:
import contextvars # Python 3.7+ only or via contextvars backport.
except ImportError: # pragma: no cover
contextvars = None # type: ignore
-if sys.version_info >= (3, 7): # pragma: no cover
- from asyncio import create_task
-else: # pragma: no cover
- from asyncio import ensure_future as create_task
T = typing.TypeVar("T")
async def run_until_first_complete(*args: typing.Tuple[typing.Callable, dict]) -> None:
- tasks = [create_task(handler(**kwargs)) for handler, kwargs in args]
- (done, pending) = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
- [task.cancel() for task in pending]
- [task.result() for task in done]
+ async with anyio.create_task_group() as task_group:
+
+ async def run(func: typing.Callable[[], typing.Coroutine]) -> None:
+ await func()
+ task_group.cancel_scope.cancel()
+
+ for func, kwargs in args:
+ task_group.start_soon(run, functools.partial(func, **kwargs))
async def run_in_threadpool(
func: typing.Callable[..., T], *args: typing.Any, **kwargs: typing.Any
) -> T:
- loop = asyncio.get_event_loop()
if contextvars is not None: # pragma: no cover
# Ensure we run in the same context
child = functools.partial(func, *args, **kwargs)
@@ -35,9 +34,9 @@ async def run_in_threadpool(
func = context.run
args = (child,)
elif kwargs: # pragma: no cover
- # loop.run_in_executor doesn't accept 'kwargs', so bind them in here
+ # run_sync doesn't accept 'kwargs', so bind them in here
func = functools.partial(func, **kwargs)
- return await loop.run_in_executor(None, func, *args)
+ return await anyio.to_thread.run_sync(func, *args)
class _StopIteration(Exception):
@@ -57,6 +56,6 @@ def _next(iterator: Iterator) -> Any:
async def iterate_in_threadpool(iterator: Iterator) -> AsyncGenerator:
while True:
try:
- yield await run_in_threadpool(_next, iterator)
+ yield await anyio.to_thread.run_sync(_next, iterator)
except _StopIteration:
break
diff --git a/starlette/config.py b/starlette/config.py
index e9894e0773..7444ae06d2 100644
--- a/starlette/config.py
+++ b/starlette/config.py
@@ -46,6 +46,8 @@ def __len__(self) -> int:
environ = Environ()
+T = typing.TypeVar("T")
+
class Config:
def __init__(
@@ -58,6 +60,24 @@ def __init__(
if env_file is not None and os.path.isfile(env_file):
self.file_values = self._read_file(env_file)
+ @typing.overload
+ def __call__(
+ self, key: str, cast: typing.Type[T], default: T = ...
+ ) -> T: # pragma: no cover
+ ...
+
+ @typing.overload
+ def __call__(
+ self, key: str, cast: typing.Type[str] = ..., default: str = ...
+ ) -> str: # pragma: no cover
+ ...
+
+ @typing.overload
+ def __call__(
+ self, key: str, cast: typing.Type[str] = ..., default: T = ...
+ ) -> typing.Union[T, str]: # pragma: no cover
+ ...
+
def __call__(
self, key: str, cast: typing.Callable = None, default: typing.Any = undefined
) -> typing.Any:
diff --git a/starlette/datastructures.py b/starlette/datastructures.py
index 5149a6e2e5..17dc46eb6a 100644
--- a/starlette/datastructures.py
+++ b/starlette/datastructures.py
@@ -266,7 +266,7 @@ def __init__(
self._dict = {k: v for k, v in _items}
self._list = _items
- def getlist(self, key: typing.Any) -> typing.List[str]:
+ def getlist(self, key: typing.Any) -> typing.List[typing.Any]:
return [item_value for item_key, item_value in self._list if item_key == key]
def keys(self) -> typing.KeysView:
diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py
index b347a6a2dd..77ba669251 100644
--- a/starlette/middleware/base.py
+++ b/starlette/middleware/base.py
@@ -1,9 +1,10 @@
-import asyncio
import typing
+import anyio
+
from starlette.requests import Request
from starlette.responses import Response, StreamingResponse
-from starlette.types import ASGIApp, Message, Receive, Scope, Send
+from starlette.types import ASGIApp, Receive, Scope, Send
RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
DispatchFunction = typing.Callable[
@@ -21,45 +22,39 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await self.app(scope, receive, send)
return
- request = Request(scope, receive=receive)
- response = await self.dispatch_func(request, self.call_next)
- await response(scope, receive, send)
+ async def call_next(request: Request) -> Response:
+ send_stream, recv_stream = anyio.create_memory_object_stream()
- async def call_next(self, request: Request) -> Response:
- loop = asyncio.get_event_loop()
- queue: "asyncio.Queue[typing.Optional[Message]]" = asyncio.Queue()
+ async def coro() -> None:
+ async with send_stream:
+ await self.app(scope, request.receive, send_stream.send)
- scope = request.scope
- receive = request.receive
- send = queue.put
+ task_group.start_soon(coro)
- async def coro() -> None:
try:
- await self.app(scope, receive, send)
- finally:
- await queue.put(None)
-
- task = loop.create_task(coro())
- message = await queue.get()
- if message is None:
- task.result()
- raise RuntimeError("No response returned.")
- assert message["type"] == "http.response.start"
-
- async def body_stream() -> typing.AsyncGenerator[bytes, None]:
- while True:
- message = await queue.get()
- if message is None:
- break
- assert message["type"] == "http.response.body"
- yield message.get("body", b"")
- task.result()
-
- response = StreamingResponse(
- status_code=message["status"], content=body_stream()
- )
- response.raw_headers = message["headers"]
- return response
+ message = await recv_stream.receive()
+ except anyio.EndOfStream:
+ raise RuntimeError("No response returned.")
+
+ assert message["type"] == "http.response.start"
+
+ async def body_stream() -> typing.AsyncGenerator[bytes, None]:
+ async with recv_stream:
+ async for message in recv_stream:
+ assert message["type"] == "http.response.body"
+ yield message.get("body", b"")
+
+ response = StreamingResponse(
+ status_code=message["status"], content=body_stream()
+ )
+ response.raw_headers = message["headers"]
+ return response
+
+ async with anyio.create_task_group() as task_group:
+ request = Request(scope, receive=receive)
+ response = await self.dispatch_func(request, call_next)
+ await response(scope, receive, send)
+ task_group.cancel_scope.cancel()
async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
diff --git a/starlette/middleware/cors.py b/starlette/middleware/cors.py
index 0b3f505e71..c850579c80 100644
--- a/starlette/middleware/cors.py
+++ b/starlette/middleware/cors.py
@@ -129,6 +129,7 @@ def preflight_response(self, request_headers: Headers) -> Response:
for header in [h.lower() for h in requested_headers.split(",")]:
if header.strip() not in self.allow_headers:
failures.append("headers")
+ break
# We don't strictly need to use 400 responses here, since its up to
# the browser to enforce the CORS policy, but its more informative
diff --git a/starlette/middleware/sessions.py b/starlette/middleware/sessions.py
index a13ec5c0ed..ad7a6ee899 100644
--- a/starlette/middleware/sessions.py
+++ b/starlette/middleware/sessions.py
@@ -3,7 +3,7 @@
from base64 import b64decode, b64encode
import itsdangerous
-from itsdangerous.exc import BadTimeSignature, SignatureExpired
+from itsdangerous.exc import BadSignature
from starlette.datastructures import MutableHeaders, Secret
from starlette.requests import HTTPConnection
@@ -42,7 +42,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
data = self.signer.unsign(data, max_age=self.max_age)
scope["session"] = json.loads(b64decode(data))
initial_session_was_empty = False
- except (BadTimeSignature, SignatureExpired):
+ except BadSignature:
scope["session"] = {}
else:
scope["session"] = {}
diff --git a/starlette/middleware/wsgi.py b/starlette/middleware/wsgi.py
index 6b30610bc6..7e69e1a6b2 100644
--- a/starlette/middleware/wsgi.py
+++ b/starlette/middleware/wsgi.py
@@ -1,10 +1,11 @@
-import asyncio
import io
+import math
import sys
import typing
-from starlette.concurrency import run_in_threadpool
-from starlette.types import Message, Receive, Scope, Send
+import anyio
+
+from starlette.types import Receive, Scope, Send
def build_environ(scope: Scope, body: bytes) -> dict:
@@ -54,7 +55,7 @@ def build_environ(scope: Scope, body: bytes) -> dict:
class WSGIMiddleware:
- def __init__(self, app: typing.Callable, workers: int = 10) -> None:
+ def __init__(self, app: typing.Callable) -> None:
self.app = app
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
@@ -69,9 +70,9 @@ def __init__(self, app: typing.Callable, scope: Scope) -> None:
self.scope = scope
self.status = None
self.response_headers = None
- self.send_event = asyncio.Event()
- self.send_queue: typing.List[typing.Optional[Message]] = []
- self.loop = asyncio.get_event_loop()
+ self.stream_send, self.stream_receive = anyio.create_memory_object_stream(
+ math.inf
+ )
self.response_started = False
self.exc_info: typing.Any = None
@@ -83,31 +84,18 @@ async def __call__(self, receive: Receive, send: Send) -> None:
body += message.get("body", b"")
more_body = message.get("more_body", False)
environ = build_environ(self.scope, body)
- sender = None
- try:
- sender = self.loop.create_task(self.sender(send))
- await run_in_threadpool(self.wsgi, environ, self.start_response)
- self.send_queue.append(None)
- self.send_event.set()
- await asyncio.wait_for(sender, None)
- if self.exc_info is not None:
- raise self.exc_info[0].with_traceback(
- self.exc_info[1], self.exc_info[2]
- )
- finally:
- if sender and not sender.done():
- sender.cancel() # pragma: no cover
+
+ async with anyio.create_task_group() as task_group:
+ task_group.start_soon(self.sender, send)
+ async with self.stream_send:
+ await anyio.to_thread.run_sync(self.wsgi, environ, self.start_response)
+ if self.exc_info is not None:
+ raise self.exc_info[0].with_traceback(self.exc_info[1], self.exc_info[2])
async def sender(self, send: Send) -> None:
- while True:
- if self.send_queue:
- message = self.send_queue.pop(0)
- if message is None:
- return
+ async with self.stream_receive:
+ async for message in self.stream_receive:
await send(message)
- else:
- await self.send_event.wait()
- self.send_event.clear()
def start_response(
self,
@@ -124,21 +112,22 @@ def start_response(
(name.strip().encode("ascii").lower(), value.strip().encode("ascii"))
for name, value in response_headers
]
- self.send_queue.append(
+ anyio.from_thread.run(
+ self.stream_send.send,
{
"type": "http.response.start",
"status": status_code,
"headers": headers,
- }
+ },
)
- self.loop.call_soon_threadsafe(self.send_event.set)
def wsgi(self, environ: dict, start_response: typing.Callable) -> None:
for chunk in self.app(environ, start_response):
- self.send_queue.append(
- {"type": "http.response.body", "body": chunk, "more_body": True}
+ anyio.from_thread.run(
+ self.stream_send.send,
+ {"type": "http.response.body", "body": chunk, "more_body": True},
)
- self.loop.call_soon_threadsafe(self.send_event.set)
- self.send_queue.append({"type": "http.response.body", "body": b""})
- self.loop.call_soon_threadsafe(self.send_event.set)
+ anyio.from_thread.run(
+ self.stream_send.send, {"type": "http.response.body", "body": b""}
+ )
diff --git a/starlette/requests.py b/starlette/requests.py
index ab6f51424b..676f4e9aa3 100644
--- a/starlette/requests.py
+++ b/starlette/requests.py
@@ -1,9 +1,10 @@
-import asyncio
import json
import typing
from collections.abc import Mapping
from http import cookies as http_cookies
+import anyio
+
from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State
from starlette.formparsers import FormParser, MultiPartParser
from starlette.types import Message, Receive, Scope, Send
@@ -64,7 +65,7 @@ def __init__(self, scope: Scope, receive: Receive = None) -> None:
assert scope["type"] in ("http", "websocket")
self.scope = scope
- def __getitem__(self, key: str) -> str:
+ def __getitem__(self, key: str) -> typing.Any:
return self.scope[key]
def __iter__(self) -> typing.Iterator[str]:
@@ -73,6 +74,12 @@ def __iter__(self) -> typing.Iterator[str]:
def __len__(self) -> int:
return len(self.scope)
+ # Don't use the `abc.Mapping.__eq__` implementation.
+ # Connection instances should never be considered equal
+ # unless `self is other`.
+ __eq__ = object.__eq__
+ __hash__ = object.__hash__
+
@property
def app(self) -> typing.Any:
return self.scope["app"]
@@ -251,10 +258,12 @@ async def close(self) -> None:
async def is_disconnected(self) -> bool:
if not self._is_disconnected:
- try:
- message = await asyncio.wait_for(self._receive(), timeout=0.0000001)
- except asyncio.TimeoutError:
- message = {}
+ message: Message = {}
+
+ # If message isn't immediately available, move on
+ with anyio.CancelScope() as cs:
+ cs.cancel()
+ message = await self._receive()
if message.get("type") == "http.disconnect":
self._is_disconnected = True
diff --git a/starlette/responses.py b/starlette/responses.py
index 00f6be4dbc..d03df23294 100644
--- a/starlette/responses.py
+++ b/starlette/responses.py
@@ -6,24 +6,20 @@
import sys
import typing
from email.utils import formatdate
+from functools import partial
from mimetypes import guess_type as mimetypes_guess_type
from urllib.parse import quote
+import anyio
+
from starlette.background import BackgroundTask
-from starlette.concurrency import iterate_in_threadpool, run_until_first_complete
+from starlette.concurrency import iterate_in_threadpool
from starlette.datastructures import URL, MutableHeaders
from starlette.types import Receive, Scope, Send
# Workaround for adding samesite support to pre 3.8 python
http.cookies.Morsel._reserved["samesite"] = "SameSite" # type: ignore
-try:
- import aiofiles
- from aiofiles.os import stat as aio_stat
-except ImportError: # pragma: nocover
- aiofiles = None # type: ignore
- aio_stat = None # type: ignore
-
# Compatibility wrapper for `mimetypes.guess_type` to support `os.PathLike` on None:
await send({"type": "http.response.body", "body": b"", "more_body": False})
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
- await run_until_first_complete(
- (self.stream_response, {"send": send}),
- (self.listen_for_disconnect, {"receive": receive}),
- )
+ async with anyio.create_task_group() as task_group:
+
+ async def wrap(func: typing.Callable[[], typing.Coroutine]) -> None:
+ await func()
+ task_group.cancel_scope.cancel()
+
+ task_group.start_soon(wrap, partial(self.stream_response, send))
+ await wrap(partial(self.listen_for_disconnect, receive))
if self.background is not None:
await self.background()
@@ -244,7 +244,6 @@ def __init__(
stat_result: os.stat_result = None,
method: str = None,
) -> None:
- assert aiofiles is not None, "'aiofiles' must be installed to use FileResponse"
self.path = path
self.status_code = status_code
self.filename = filename
@@ -280,7 +279,7 @@ def set_stat_headers(self, stat_result: os.stat_result) -> None:
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if self.stat_result is None:
try:
- stat_result = await aio_stat(self.path)
+ stat_result = await anyio.to_thread.run_sync(os.stat, self.path)
self.set_stat_headers(stat_result)
except FileNotFoundError:
raise RuntimeError(f"File at path {self.path} does not exist.")
@@ -298,10 +297,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if self.send_header_only:
await send({"type": "http.response.body", "body": b"", "more_body": False})
else:
- # Tentatively ignoring type checking failure to work around the wrong type
- # definitions for aiofile that come with typeshed. See
- # https://github.com/python/typeshed/pull/4650
- async with aiofiles.open(self.path, mode="rb") as file: # type: ignore
+ async with await anyio.open_file(self.path, mode="rb") as file:
more_body = True
while more_body:
chunk = await file.read(self.chunk_size)
diff --git a/starlette/routing.py b/starlette/routing.py
index cef1ef4848..9a1a5e12df 100644
--- a/starlette/routing.py
+++ b/starlette/routing.py
@@ -1,9 +1,13 @@
import asyncio
+import contextlib
import functools
import inspect
import re
+import sys
import traceback
+import types
import typing
+import warnings
from enum import Enum
from starlette.concurrency import run_in_threadpool
@@ -15,6 +19,11 @@
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.websockets import WebSocket, WebSocketClose
+if sys.version_info >= (3, 7):
+ from contextlib import asynccontextmanager # pragma: no cover
+else:
+ from contextlib2 import asynccontextmanager # pragma: no cover
+
class NoMatchFound(Exception):
"""
@@ -470,6 +479,51 @@ def __eq__(self, other: typing.Any) -> bool:
)
+_T = typing.TypeVar("_T")
+
+
+class _AsyncLiftContextManager(typing.AsyncContextManager[_T]):
+ def __init__(self, cm: typing.ContextManager[_T]):
+ self._cm = cm
+
+ async def __aenter__(self) -> _T:
+ return self._cm.__enter__()
+
+ async def __aexit__(
+ self,
+ exc_type: typing.Optional[typing.Type[BaseException]],
+ exc_value: typing.Optional[BaseException],
+ traceback: typing.Optional[types.TracebackType],
+ ) -> typing.Optional[bool]:
+ return self._cm.__exit__(exc_type, exc_value, traceback)
+
+
+def _wrap_gen_lifespan_context(
+ lifespan_context: typing.Callable[[typing.Any], typing.Generator]
+) -> typing.Callable[[typing.Any], typing.AsyncContextManager]:
+ cmgr = contextlib.contextmanager(lifespan_context)
+
+ @functools.wraps(cmgr)
+ def wrapper(app: typing.Any) -> _AsyncLiftContextManager:
+ return _AsyncLiftContextManager(cmgr(app))
+
+ return wrapper
+
+
+class _DefaultLifespan:
+ def __init__(self, router: "Router"):
+ self._router = router
+
+ async def __aenter__(self) -> None:
+ await self._router.startup()
+
+ async def __aexit__(self, *exc_info: object) -> None:
+ await self._router.shutdown()
+
+ def __call__(self: _T, app: object) -> _T:
+ return self
+
+
class Router:
def __init__(
self,
@@ -478,7 +532,7 @@ def __init__(
default: ASGIApp = None,
on_startup: typing.Sequence[typing.Callable] = None,
on_shutdown: typing.Sequence[typing.Callable] = None,
- lifespan: typing.Callable[[typing.Any], typing.AsyncGenerator] = None,
+ lifespan: typing.Callable[[typing.Any], typing.AsyncContextManager] = None,
) -> None:
self.routes = [] if routes is None else list(routes)
self.redirect_slashes = redirect_slashes
@@ -486,12 +540,31 @@ def __init__(
self.on_startup = [] if on_startup is None else list(on_startup)
self.on_shutdown = [] if on_shutdown is None else list(on_shutdown)
- async def default_lifespan(app: typing.Any) -> typing.AsyncGenerator:
- await self.startup()
- yield
- await self.shutdown()
+ if lifespan is None:
+ self.lifespan_context: typing.Callable[
+ [typing.Any], typing.AsyncContextManager
+ ] = _DefaultLifespan(self)
- self.lifespan_context = default_lifespan if lifespan is None else lifespan
+ elif inspect.isasyncgenfunction(lifespan):
+ warnings.warn(
+ "async generator function lifespans are deprecated, "
+ "use an @contextlib.asynccontextmanager function instead",
+ DeprecationWarning,
+ )
+ self.lifespan_context = asynccontextmanager(
+ lifespan, # type: ignore[arg-type]
+ )
+ elif inspect.isgeneratorfunction(lifespan):
+ warnings.warn(
+ "generator function lifespans are deprecated, "
+ "use an @contextlib.asynccontextmanager function instead",
+ DeprecationWarning,
+ )
+ self.lifespan_context = _wrap_gen_lifespan_context(
+ lifespan, # type: ignore[arg-type]
+ )
+ else:
+ self.lifespan_context = lifespan
async def not_found(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] == "websocket":
@@ -541,25 +614,19 @@ async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None:
Handle ASGI lifespan messages, which allows us to manage application
startup and shutdown events.
"""
- first = True
+ started = False
app = scope.get("app")
await receive()
try:
- if inspect.isasyncgenfunction(self.lifespan_context):
- async for item in self.lifespan_context(app):
- assert first, "Lifespan context yielded multiple times."
- first = False
- await send({"type": "lifespan.startup.complete"})
- await receive()
- else:
- for item in self.lifespan_context(app): # type: ignore
- assert first, "Lifespan context yielded multiple times."
- first = False
- await send({"type": "lifespan.startup.complete"})
- await receive()
+ async with self.lifespan_context(app):
+ await send({"type": "lifespan.startup.complete"})
+ started = True
+ await receive()
except BaseException:
- if first:
- exc_text = traceback.format_exc()
+ exc_text = traceback.format_exc()
+ if started:
+ await send({"type": "lifespan.shutdown.failed", "message": exc_text})
+ else:
await send({"type": "lifespan.startup.failed", "message": exc_text})
raise
else:
diff --git a/starlette/staticfiles.py b/starlette/staticfiles.py
index 15a67fe35d..33ea0b0337 100644
--- a/starlette/staticfiles.py
+++ b/starlette/staticfiles.py
@@ -4,7 +4,7 @@
import typing
from email.utils import parsedate
-from aiofiles.os import stat as aio_stat
+import anyio
from starlette.datastructures import URL, Headers
from starlette.responses import (
@@ -154,7 +154,7 @@ async def lookup_path(
# directory.
continue
try:
- stat_result = await aio_stat(full_path)
+ stat_result = await anyio.to_thread.run_sync(os.stat, full_path)
return full_path, stat_result
except FileNotFoundError:
pass
@@ -187,7 +187,7 @@ async def check_config(self) -> None:
return
try:
- stat_result = await aio_stat(self.directory)
+ stat_result = await anyio.to_thread.run_sync(os.stat, self.directory)
except FileNotFoundError:
raise RuntimeError(
f"StaticFiles directory '{self.directory}' does not exist."
diff --git a/starlette/templating.py b/starlette/templating.py
index 64fdbd14f6..18d5eb40c0 100644
--- a/starlette/templating.py
+++ b/starlette/templating.py
@@ -1,4 +1,5 @@
import typing
+from os import PathLike
from starlette.background import BackgroundTask
from starlette.responses import Response
@@ -54,11 +55,13 @@ class Jinja2Templates:
return templates.TemplateResponse("index.html", {"request": request})
"""
- def __init__(self, directory: str) -> None:
+ def __init__(self, directory: typing.Union[str, PathLike]) -> None:
assert jinja2 is not None, "jinja2 must be installed to use Jinja2Templates"
- self.env = self.get_env(directory)
+ self.env = self._create_env(directory)
- def get_env(self, directory: str) -> "jinja2.Environment":
+ def _create_env(
+ self, directory: typing.Union[str, PathLike]
+ ) -> "jinja2.Environment":
@pass_context
def url_for(context: dict, name: str, **path_params: typing.Any) -> str:
request = context["request"]
diff --git a/starlette/testclient.py b/starlette/testclient.py
index 77c038b17f..08d03fa5c4 100644
--- a/starlette/testclient.py
+++ b/starlette/testclient.py
@@ -1,19 +1,35 @@
import asyncio
+import contextlib
import http
import inspect
import io
import json
+import math
import queue
-import threading
+import sys
import types
import typing
+from concurrent.futures import Future
from urllib.parse import unquote, urljoin, urlsplit
+import anyio.abc
import requests
+from anyio.streams.stapled import StapledObjectStream
from starlette.types import Message, Receive, Scope, Send
from starlette.websockets import WebSocketDisconnect
+if sys.version_info >= (3, 8): # pragma: no cover
+ from typing import TypedDict
+else: # pragma: no cover
+ from typing_extensions import TypedDict
+
+
+_PortalFactoryType = typing.Callable[
+ [], typing.ContextManager[anyio.abc.BlockingPortal]
+]
+
+
# Annotations for `Session.request()`
Cookies = typing.Union[
typing.MutableMapping[str, str], requests.cookies.RequestsCookieJar
@@ -87,13 +103,23 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await instance(receive, send)
+class _AsyncBackend(TypedDict):
+ backend: str
+ backend_options: typing.Dict[str, typing.Any]
+
+
class _ASGIAdapter(requests.adapters.HTTPAdapter):
def __init__(
- self, app: ASGI3App, raise_server_exceptions: bool = True, root_path: str = ""
+ self,
+ app: ASGI3App,
+ portal_factory: _PortalFactoryType,
+ raise_server_exceptions: bool = True,
+ root_path: str = "",
) -> None:
self.app = app
self.raise_server_exceptions = raise_server_exceptions
self.root_path = root_path
+ self.portal_factory = portal_factory
def send(
self, request: requests.PreparedRequest, *args: typing.Any, **kwargs: typing.Any
@@ -142,7 +168,7 @@ def send(
"server": [host, port],
"subprotocols": subprotocols,
}
- session = WebSocketTestSession(self.app, scope)
+ session = WebSocketTestSession(self.app, scope, self.portal_factory)
raise _Upgrade(session)
scope = {
@@ -161,17 +187,17 @@ def send(
request_complete = False
response_started = False
- response_complete = False
+ response_complete: anyio.Event
raw_kwargs: typing.Dict[str, typing.Any] = {"body": io.BytesIO()}
template = None
context = None
async def receive() -> Message:
- nonlocal request_complete, response_complete
+ nonlocal request_complete
if request_complete:
- while not response_complete:
- await asyncio.sleep(0.0001)
+ if not response_complete.is_set():
+ await response_complete.wait()
return {"type": "http.disconnect"}
body = request.body
@@ -195,7 +221,7 @@ async def receive() -> Message:
return {"type": "http.request", "body": body_bytes}
async def send(message: Message) -> None:
- nonlocal raw_kwargs, response_started, response_complete, template, context
+ nonlocal raw_kwargs, response_started, template, context
if message["type"] == "http.response.start":
assert (
@@ -205,7 +231,8 @@ async def send(message: Message) -> None:
raw_kwargs["status"] = message["status"]
raw_kwargs["reason"] = _get_reason_phrase(message["status"])
raw_kwargs["headers"] = [
- (key.decode(), value.decode()) for key, value in message["headers"]
+ (key.decode(), value.decode())
+ for key, value in message.get("headers", [])
]
raw_kwargs["preload_content"] = False
raw_kwargs["original_response"] = _MockOriginalResponse(
@@ -217,7 +244,7 @@ async def send(message: Message) -> None:
response_started
), 'Received "http.response.body" without "http.response.start".'
assert (
- not response_complete
+ not response_complete.is_set()
), 'Received "http.response.body" after response completed.'
body = message.get("body", b"")
more_body = message.get("more_body", False)
@@ -225,19 +252,15 @@ async def send(message: Message) -> None:
raw_kwargs["body"].write(body)
if not more_body:
raw_kwargs["body"].seek(0)
- response_complete = True
+ response_complete.set()
elif message["type"] == "http.response.template":
template = message["template"]
context = message["context"]
try:
- loop = asyncio.get_event_loop()
- except RuntimeError:
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
-
- try:
- loop.run_until_complete(self.app(scope, receive, send))
+ with self.portal_factory() as portal:
+ response_complete = portal.call(anyio.Event)
+ portal.call(self.app, scope, receive, send)
except BaseException as exc:
if self.raise_server_exceptions:
raise exc
@@ -264,48 +287,60 @@ async def send(message: Message) -> None:
class WebSocketTestSession:
- def __init__(self, app: ASGI3App, scope: Scope) -> None:
+ def __init__(
+ self,
+ app: ASGI3App,
+ scope: Scope,
+ portal_factory: _PortalFactoryType,
+ ) -> None:
self.app = app
self.scope = scope
self.accepted_subprotocol = None
+ self.portal_factory = portal_factory
self._receive_queue: "queue.Queue[typing.Any]" = queue.Queue()
self._send_queue: "queue.Queue[typing.Any]" = queue.Queue()
- self._thread = threading.Thread(target=self._run)
- self.send({"type": "websocket.connect"})
- self._thread.start()
- message = self.receive()
- self._raise_on_close(message)
- self.accepted_subprotocol = message.get("subprotocol", None)
def __enter__(self) -> "WebSocketTestSession":
+ self.exit_stack = contextlib.ExitStack()
+ self.portal = self.exit_stack.enter_context(self.portal_factory())
+
+ try:
+ _: "Future[None]" = self.portal.start_task_soon(self._run)
+ self.send({"type": "websocket.connect"})
+ message = self.receive()
+ self._raise_on_close(message)
+ except Exception:
+ self.exit_stack.close()
+ raise
+ self.accepted_subprotocol = message.get("subprotocol", None)
return self
def __exit__(self, *args: typing.Any) -> None:
- self.close(1000)
- self._thread.join()
+ try:
+ self.close(1000)
+ finally:
+ self.exit_stack.close()
while not self._send_queue.empty():
message = self._send_queue.get()
if isinstance(message, BaseException):
raise message
- def _run(self) -> None:
+ async def _run(self) -> None:
"""
The sub-thread in which the websocket session runs.
"""
- loop = asyncio.new_event_loop()
scope = self.scope
receive = self._asgi_receive
send = self._asgi_send
try:
- loop.run_until_complete(self.app(scope, receive, send))
+ await self.app(scope, receive, send)
except BaseException as exc:
self._send_queue.put(exc)
- finally:
- loop.close()
+ raise
async def _asgi_receive(self) -> Message:
while self._receive_queue.empty():
- await asyncio.sleep(0)
+ await anyio.sleep(0)
return self._receive_queue.get()
async def _asgi_send(self, message: Message) -> None:
@@ -364,6 +399,8 @@ def receive_json(self, mode: str = "text") -> typing.Any:
class TestClient(requests.Session):
__test__ = False # For pytest to not discover this up.
+ task: "Future[None]"
+ portal: typing.Optional[anyio.abc.BlockingPortal] = None
def __init__(
self,
@@ -371,8 +408,13 @@ def __init__(
base_url: str = "http://testserver",
raise_server_exceptions: bool = True,
root_path: str = "",
+ backend: str = "asyncio",
+ backend_options: typing.Optional[typing.Dict[str, typing.Any]] = None,
) -> None:
super().__init__()
+ self.async_backend = _AsyncBackend(
+ backend=backend, backend_options=backend_options or {}
+ )
if _is_asgi3(app):
app = typing.cast(ASGI3App, app)
asgi_app = app
@@ -381,6 +423,7 @@ def __init__(
asgi_app = _WrapASGI2(app) # type: ignore
adapter = _ASGIAdapter(
asgi_app,
+ portal_factory=self._portal_factory,
raise_server_exceptions=raise_server_exceptions,
root_path=root_path,
)
@@ -392,6 +435,16 @@ def __init__(
self.app = asgi_app
self.base_url = base_url
+ @contextlib.contextmanager
+ def _portal_factory(
+ self,
+ ) -> typing.Generator[anyio.abc.BlockingPortal, None, None]:
+ if self.portal is not None:
+ yield self.portal
+ else:
+ with anyio.start_blocking_portal(**self.async_backend) as portal:
+ yield portal
+
def request( # type: ignore
self,
method: str,
@@ -452,42 +505,72 @@ def websocket_connect(
return session
def __enter__(self) -> "TestClient":
- loop = asyncio.get_event_loop()
- self.send_queue: "asyncio.Queue[typing.Any]" = asyncio.Queue()
- self.receive_queue: "asyncio.Queue[typing.Any]" = asyncio.Queue()
- self.task = loop.create_task(self.lifespan())
- loop.run_until_complete(self.wait_startup())
+ with contextlib.ExitStack() as stack:
+ self.portal = portal = stack.enter_context(
+ anyio.start_blocking_portal(**self.async_backend)
+ )
+
+ @stack.callback
+ def reset_portal() -> None:
+ self.portal = None
+
+ self.stream_send = StapledObjectStream(
+ *anyio.create_memory_object_stream(math.inf)
+ )
+ self.stream_receive = StapledObjectStream(
+ *anyio.create_memory_object_stream(math.inf)
+ )
+ self.task = portal.start_task_soon(self.lifespan)
+ portal.call(self.wait_startup)
+
+ @stack.callback
+ def wait_shutdown() -> None:
+ portal.call(self.wait_shutdown)
+
+ self.exit_stack = stack.pop_all()
+
return self
def __exit__(self, *args: typing.Any) -> None:
- loop = asyncio.get_event_loop()
- loop.run_until_complete(self.wait_shutdown())
+ self.exit_stack.close()
async def lifespan(self) -> None:
scope = {"type": "lifespan"}
try:
- await self.app(scope, self.receive_queue.get, self.send_queue.put)
+ await self.app(scope, self.stream_receive.receive, self.stream_send.send)
finally:
- await self.send_queue.put(None)
+ await self.stream_send.send(None)
async def wait_startup(self) -> None:
- await self.receive_queue.put({"type": "lifespan.startup"})
- message = await self.send_queue.get()
- if message is None:
- self.task.result()
+ await self.stream_receive.send({"type": "lifespan.startup"})
+
+ async def receive() -> typing.Any:
+ message = await self.stream_send.receive()
+ if message is None:
+ self.task.result()
+ return message
+
+ message = await receive()
assert message["type"] in (
"lifespan.startup.complete",
"lifespan.startup.failed",
)
if message["type"] == "lifespan.startup.failed":
- message = await self.send_queue.get()
- if message is None:
- self.task.result()
+ await receive()
async def wait_shutdown(self) -> None:
- await self.receive_queue.put({"type": "lifespan.shutdown"})
- message = await self.send_queue.get()
- if message is None:
- self.task.result()
- assert message["type"] == "lifespan.shutdown.complete"
- await self.task
+ async def receive() -> typing.Any:
+ message = await self.stream_send.receive()
+ if message is None:
+ self.task.result()
+ return message
+
+ async with self.stream_send:
+ await self.stream_receive.send({"type": "lifespan.shutdown"})
+ message = await receive()
+ assert message["type"] in (
+ "lifespan.shutdown.complete",
+ "lifespan.shutdown.failed",
+ )
+ if message["type"] == "lifespan.shutdown.failed":
+ await receive()
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000000..8b9872aebe
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,23 @@
+import functools
+import sys
+
+import pytest
+
+from starlette.testclient import TestClient
+
+
+@pytest.fixture
+def no_trio_support(anyio_backend_name):
+ if anyio_backend_name == "trio":
+ pytest.skip("Trio not supported (yet!)")
+
+
+@pytest.fixture
+def test_client_factory(anyio_backend_name, anyio_backend_options):
+ # anyio_backend_name defined by:
+ # https://anyio.readthedocs.io/en/stable/testing.html#specifying-the-backends-to-run-on
+ return functools.partial(
+ TestClient,
+ backend=anyio_backend_name,
+ backend_options=anyio_backend_options,
+ )
diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py
index 048dd9ffb9..8a8df4ea66 100644
--- a/tests/middleware/test_base.py
+++ b/tests/middleware/test_base.py
@@ -5,7 +5,6 @@
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import PlainTextResponse
from starlette.routing import Route
-from starlette.testclient import TestClient
class CustomMiddleware(BaseHTTPMiddleware):
@@ -48,8 +47,8 @@ async def websocket_endpoint(session):
await session.close()
-def test_custom_middleware():
- client = TestClient(app)
+def test_custom_middleware(test_client_factory):
+ client = test_client_factory(app)
response = client.get("/")
assert response.headers["Custom-Header"] == "Example"
@@ -64,7 +63,7 @@ def test_custom_middleware():
assert text == "Hello, world!"
-def test_middleware_decorator():
+def test_middleware_decorator(test_client_factory):
app = Starlette()
@app.route("/homepage")
@@ -79,7 +78,7 @@ async def plaintext(request, call_next):
response.headers["Custom"] = "Example"
return response
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/")
assert response.text == "OK"
@@ -88,7 +87,7 @@ async def plaintext(request, call_next):
assert response.headers["Custom"] == "Example"
-def test_state_data_across_multiple_middlewares():
+def test_state_data_across_multiple_middlewares(test_client_factory):
expected_value1 = "foo"
expected_value2 = "bar"
@@ -120,14 +119,14 @@ async def dispatch(self, request, call_next):
def homepage(request):
return PlainTextResponse("OK")
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/")
assert response.text == "OK"
assert response.headers["X-State-Foo"] == expected_value1
assert response.headers["X-State-Bar"] == expected_value2
-def test_app_middleware_argument():
+def test_app_middleware_argument(test_client_factory):
def homepage(request):
return PlainTextResponse("Homepage")
@@ -135,7 +134,7 @@ def homepage(request):
routes=[Route("/", homepage)], middleware=[Middleware(CustomMiddleware)]
)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/")
assert response.headers["Custom-Header"] == "Example"
@@ -143,3 +142,18 @@ def homepage(request):
def test_middleware_repr():
middleware = Middleware(CustomMiddleware)
assert repr(middleware) == "Middleware(CustomMiddleware)"
+
+
+def test_fully_evaluated_response(test_client_factory):
+ # Test for https://github.com/encode/starlette/issues/1022
+ class CustomMiddleware(BaseHTTPMiddleware):
+ async def dispatch(self, request, call_next):
+ await call_next(request)
+ return PlainTextResponse("Custom")
+
+ app = Starlette()
+ app.add_middleware(CustomMiddleware)
+
+ client = test_client_factory(app)
+ response = client.get("/does_not_exist")
+ assert response.text == "Custom"
diff --git a/tests/middleware/test_cors.py b/tests/middleware/test_cors.py
index 7a250a2416..65252e5024 100644
--- a/tests/middleware/test_cors.py
+++ b/tests/middleware/test_cors.py
@@ -1,10 +1,9 @@
from starlette.applications import Starlette
from starlette.middleware.cors import CORSMiddleware
from starlette.responses import PlainTextResponse
-from starlette.testclient import TestClient
-def test_cors_allow_all():
+def test_cors_allow_all(test_client_factory):
app = Starlette()
app.add_middleware(
@@ -20,7 +19,7 @@ def test_cors_allow_all():
def homepage(request):
return PlainTextResponse("Homepage", status_code=200)
- client = TestClient(app)
+ client = test_client_factory(app)
# Test pre-flight response
headers = {
@@ -61,7 +60,7 @@ def homepage(request):
assert "access-control-allow-origin" not in response.headers
-def test_cors_allow_all_except_credentials():
+def test_cors_allow_all_except_credentials(test_client_factory):
app = Starlette()
app.add_middleware(
@@ -76,7 +75,7 @@ def test_cors_allow_all_except_credentials():
def homepage(request):
return PlainTextResponse("Homepage", status_code=200)
- client = TestClient(app)
+ client = test_client_factory(app)
# Test pre-flight response
headers = {
@@ -108,7 +107,7 @@ def homepage(request):
assert "access-control-allow-origin" not in response.headers
-def test_cors_allow_specific_origin():
+def test_cors_allow_specific_origin(test_client_factory):
app = Starlette()
app.add_middleware(
@@ -121,7 +120,7 @@ def test_cors_allow_specific_origin():
def homepage(request):
return PlainTextResponse("Homepage", status_code=200)
- client = TestClient(app)
+ client = test_client_factory(app)
# Test pre-flight response
headers = {
@@ -153,7 +152,7 @@ def homepage(request):
assert "access-control-allow-origin" not in response.headers
-def test_cors_disallowed_preflight():
+def test_cors_disallowed_preflight(test_client_factory):
app = Starlette()
app.add_middleware(
@@ -166,7 +165,7 @@ def test_cors_disallowed_preflight():
def homepage(request):
pass # pragma: no cover
- client = TestClient(app)
+ client = test_client_factory(app)
# Test pre-flight response
headers = {
@@ -179,8 +178,20 @@ def homepage(request):
assert response.text == "Disallowed CORS origin, method, headers"
assert "access-control-allow-origin" not in response.headers
+ # Bug specific test, https://github.com/encode/starlette/pull/1199
+ # Test preflight response text with multiple disallowed headers
+ headers = {
+ "Origin": "https://example.org",
+ "Access-Control-Request-Method": "GET",
+ "Access-Control-Request-Headers": "X-Nope-1, X-Nope-2",
+ }
+ response = client.options("/", headers=headers)
+ assert response.text == "Disallowed CORS headers"
+
-def test_preflight_allows_request_origin_if_origins_wildcard_and_credentials_allowed():
+def test_preflight_allows_request_origin_if_origins_wildcard_and_credentials_allowed(
+ test_client_factory,
+):
app = Starlette()
app.add_middleware(
@@ -194,7 +205,7 @@ def test_preflight_allows_request_origin_if_origins_wildcard_and_credentials_all
def homepage(request):
return # pragma: no cover
- client = TestClient(app)
+ client = test_client_factory(app)
# Test pre-flight response
headers = {
@@ -211,7 +222,7 @@ def homepage(request):
assert response.headers["vary"] == "Origin"
-def test_cors_preflight_allow_all_methods():
+def test_cors_preflight_allow_all_methods(test_client_factory):
app = Starlette()
app.add_middleware(
@@ -224,7 +235,7 @@ def test_cors_preflight_allow_all_methods():
def homepage(request):
pass # pragma: no cover
- client = TestClient(app)
+ client = test_client_factory(app)
headers = {
"Origin": "https://example.org",
@@ -237,7 +248,7 @@ def homepage(request):
assert method in response.headers["access-control-allow-methods"]
-def test_cors_allow_all_methods():
+def test_cors_allow_all_methods(test_client_factory):
app = Starlette()
app.add_middleware(
@@ -252,7 +263,7 @@ def test_cors_allow_all_methods():
def homepage(request):
return PlainTextResponse("Homepage", status_code=200)
- client = TestClient(app)
+ client = test_client_factory(app)
headers = {"Origin": "https://example.org"}
@@ -261,7 +272,7 @@ def homepage(request):
assert response.status_code == 200
-def test_cors_allow_origin_regex():
+def test_cors_allow_origin_regex(test_client_factory):
app = Starlette()
app.add_middleware(
@@ -275,7 +286,7 @@ def test_cors_allow_origin_regex():
def homepage(request):
return PlainTextResponse("Homepage", status_code=200)
- client = TestClient(app)
+ client = test_client_factory(app)
# Test standard response
headers = {"Origin": "https://example.org"}
@@ -329,7 +340,7 @@ def homepage(request):
assert "access-control-allow-origin" not in response.headers
-def test_cors_allow_origin_regex_fullmatch():
+def test_cors_allow_origin_regex_fullmatch(test_client_factory):
app = Starlette()
app.add_middleware(
@@ -342,7 +353,7 @@ def test_cors_allow_origin_regex_fullmatch():
def homepage(request):
return PlainTextResponse("Homepage", status_code=200)
- client = TestClient(app)
+ client = test_client_factory(app)
# Test standard response
headers = {"Origin": "https://subdomain.example.org"}
@@ -363,7 +374,7 @@ def homepage(request):
assert "access-control-allow-origin" not in response.headers
-def test_cors_credentialed_requests_return_specific_origin():
+def test_cors_credentialed_requests_return_specific_origin(test_client_factory):
app = Starlette()
app.add_middleware(CORSMiddleware, allow_origins=["*"])
@@ -372,7 +383,7 @@ def test_cors_credentialed_requests_return_specific_origin():
def homepage(request):
return PlainTextResponse("Homepage", status_code=200)
- client = TestClient(app)
+ client = test_client_factory(app)
# Test credentialed request
headers = {"Origin": "https://example.org", "Cookie": "star_cookie=sugar"}
@@ -383,7 +394,7 @@ def homepage(request):
assert "access-control-allow-credentials" not in response.headers
-def test_cors_vary_header_defaults_to_origin():
+def test_cors_vary_header_defaults_to_origin(test_client_factory):
app = Starlette()
app.add_middleware(CORSMiddleware, allow_origins=["https://example.org"])
@@ -394,14 +405,14 @@ def test_cors_vary_header_defaults_to_origin():
def homepage(request):
return PlainTextResponse("Homepage", status_code=200)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/", headers=headers)
assert response.status_code == 200
assert response.headers["vary"] == "Origin"
-def test_cors_vary_header_is_not_set_for_non_credentialed_request():
+def test_cors_vary_header_is_not_set_for_non_credentialed_request(test_client_factory):
app = Starlette()
app.add_middleware(CORSMiddleware, allow_origins=["*"])
@@ -412,14 +423,14 @@ def homepage(request):
"Homepage", status_code=200, headers={"Vary": "Accept-Encoding"}
)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/", headers={"Origin": "https://someplace.org"})
assert response.status_code == 200
assert response.headers["vary"] == "Accept-Encoding"
-def test_cors_vary_header_is_properly_set_for_credentialed_request():
+def test_cors_vary_header_is_properly_set_for_credentialed_request(test_client_factory):
app = Starlette()
app.add_middleware(CORSMiddleware, allow_origins=["*"])
@@ -430,7 +441,7 @@ def homepage(request):
"Homepage", status_code=200, headers={"Vary": "Accept-Encoding"}
)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get(
"/", headers={"Cookie": "foo=bar", "Origin": "https://someplace.org"}
@@ -439,7 +450,9 @@ def homepage(request):
assert response.headers["vary"] == "Accept-Encoding, Origin"
-def test_cors_vary_header_is_properly_set_when_allow_origins_is_not_wildcard():
+def test_cors_vary_header_is_properly_set_when_allow_origins_is_not_wildcard(
+ test_client_factory,
+):
app = Starlette()
app.add_middleware(CORSMiddleware, allow_origins=["https://example.org"])
@@ -450,14 +463,16 @@ def homepage(request):
"Homepage", status_code=200, headers={"Vary": "Accept-Encoding"}
)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/", headers={"Origin": "https://example.org"})
assert response.status_code == 200
assert response.headers["vary"] == "Accept-Encoding, Origin"
-def test_cors_allowed_origin_does_not_leak_between_credentialed_requests():
+def test_cors_allowed_origin_does_not_leak_between_credentialed_requests(
+ test_client_factory,
+):
app = Starlette()
app.add_middleware(
@@ -468,7 +483,7 @@ def test_cors_allowed_origin_does_not_leak_between_credentialed_requests():
def homepage(request):
return PlainTextResponse("Homepage", status_code=200)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/", headers={"Origin": "https://someplace.org"})
assert response.headers["access-control-allow-origin"] == "*"
assert "access-control-allow-credentials" not in response.headers
diff --git a/tests/middleware/test_errors.py b/tests/middleware/test_errors.py
index c178ef9da2..2c926a9b2d 100644
--- a/tests/middleware/test_errors.py
+++ b/tests/middleware/test_errors.py
@@ -2,10 +2,9 @@
from starlette.middleware.errors import ServerErrorMiddleware
from starlette.responses import JSONResponse, Response
-from starlette.testclient import TestClient
-def test_handler():
+def test_handler(test_client_factory):
async def app(scope, receive, send):
raise RuntimeError("Something went wrong")
@@ -13,49 +12,49 @@ def error_500(request, exc):
return JSONResponse({"detail": "Server Error"}, status_code=500)
app = ServerErrorMiddleware(app, handler=error_500)
- client = TestClient(app, raise_server_exceptions=False)
+ client = test_client_factory(app, raise_server_exceptions=False)
response = client.get("/")
assert response.status_code == 500
assert response.json() == {"detail": "Server Error"}
-def test_debug_text():
+def test_debug_text(test_client_factory):
async def app(scope, receive, send):
raise RuntimeError("Something went wrong")
app = ServerErrorMiddleware(app, debug=True)
- client = TestClient(app, raise_server_exceptions=False)
+ client = test_client_factory(app, raise_server_exceptions=False)
response = client.get("/")
assert response.status_code == 500
assert response.headers["content-type"].startswith("text/plain")
assert "RuntimeError: Something went wrong" in response.text
-def test_debug_html():
+def test_debug_html(test_client_factory):
async def app(scope, receive, send):
raise RuntimeError("Something went wrong")
app = ServerErrorMiddleware(app, debug=True)
- client = TestClient(app, raise_server_exceptions=False)
+ client = test_client_factory(app, raise_server_exceptions=False)
response = client.get("/", headers={"Accept": "text/html, */*"})
assert response.status_code == 500
assert response.headers["content-type"].startswith("text/html")
assert "RuntimeError" in response.text
-def test_debug_after_response_sent():
+def test_debug_after_response_sent(test_client_factory):
async def app(scope, receive, send):
response = Response(b"", status_code=204)
await response(scope, receive, send)
raise RuntimeError("Something went wrong")
app = ServerErrorMiddleware(app, debug=True)
- client = TestClient(app)
+ client = test_client_factory(app)
with pytest.raises(RuntimeError):
client.get("/")
-def test_debug_not_http():
+def test_debug_not_http(test_client_factory):
"""
DebugMiddleware should just pass through any non-http messages as-is.
"""
@@ -66,5 +65,6 @@ async def app(scope, receive, send):
app = ServerErrorMiddleware(app)
with pytest.raises(RuntimeError):
- client = TestClient(app)
- client.websocket_connect("/")
+ client = test_client_factory(app)
+ with client.websocket_connect("/"):
+ pass # pragma: nocover
diff --git a/tests/middleware/test_gzip.py b/tests/middleware/test_gzip.py
index cd989b8c1f..b917ea4dbb 100644
--- a/tests/middleware/test_gzip.py
+++ b/tests/middleware/test_gzip.py
@@ -1,10 +1,9 @@
from starlette.applications import Starlette
from starlette.middleware.gzip import GZipMiddleware
from starlette.responses import PlainTextResponse, StreamingResponse
-from starlette.testclient import TestClient
-def test_gzip_responses():
+def test_gzip_responses(test_client_factory):
app = Starlette()
app.add_middleware(GZipMiddleware)
@@ -13,7 +12,7 @@ def test_gzip_responses():
def homepage(request):
return PlainTextResponse("x" * 4000, status_code=200)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/", headers={"accept-encoding": "gzip"})
assert response.status_code == 200
assert response.text == "x" * 4000
@@ -21,7 +20,7 @@ def homepage(request):
assert int(response.headers["Content-Length"]) < 4000
-def test_gzip_not_in_accept_encoding():
+def test_gzip_not_in_accept_encoding(test_client_factory):
app = Starlette()
app.add_middleware(GZipMiddleware)
@@ -30,7 +29,7 @@ def test_gzip_not_in_accept_encoding():
def homepage(request):
return PlainTextResponse("x" * 4000, status_code=200)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/", headers={"accept-encoding": "identity"})
assert response.status_code == 200
assert response.text == "x" * 4000
@@ -38,7 +37,7 @@ def homepage(request):
assert int(response.headers["Content-Length"]) == 4000
-def test_gzip_ignored_for_small_responses():
+def test_gzip_ignored_for_small_responses(test_client_factory):
app = Starlette()
app.add_middleware(GZipMiddleware)
@@ -47,7 +46,7 @@ def test_gzip_ignored_for_small_responses():
def homepage(request):
return PlainTextResponse("OK", status_code=200)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/", headers={"accept-encoding": "gzip"})
assert response.status_code == 200
assert response.text == "OK"
@@ -55,7 +54,7 @@ def homepage(request):
assert int(response.headers["Content-Length"]) == 2
-def test_gzip_streaming_response():
+def test_gzip_streaming_response(test_client_factory):
app = Starlette()
app.add_middleware(GZipMiddleware)
@@ -69,7 +68,7 @@ async def generator(bytes, count):
streaming = generator(bytes=b"x" * 400, count=10)
return StreamingResponse(streaming, status_code=200)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/", headers={"accept-encoding": "gzip"})
assert response.status_code == 200
assert response.text == "x" * 4000
diff --git a/tests/middleware/test_https_redirect.py b/tests/middleware/test_https_redirect.py
index 757770b853..8db9506342 100644
--- a/tests/middleware/test_https_redirect.py
+++ b/tests/middleware/test_https_redirect.py
@@ -1,10 +1,9 @@
from starlette.applications import Starlette
from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware
from starlette.responses import PlainTextResponse
-from starlette.testclient import TestClient
-def test_https_redirect_middleware():
+def test_https_redirect_middleware(test_client_factory):
app = Starlette()
app.add_middleware(HTTPSRedirectMiddleware)
@@ -13,26 +12,26 @@ def test_https_redirect_middleware():
def homepage(request):
return PlainTextResponse("OK", status_code=200)
- client = TestClient(app, base_url="https://testserver")
+ client = test_client_factory(app, base_url="https://testserver")
response = client.get("/")
assert response.status_code == 200
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/", allow_redirects=False)
assert response.status_code == 307
assert response.headers["location"] == "https://testserver/"
- client = TestClient(app, base_url="http://testserver:80")
+ client = test_client_factory(app, base_url="http://testserver:80")
response = client.get("/", allow_redirects=False)
assert response.status_code == 307
assert response.headers["location"] == "https://testserver/"
- client = TestClient(app, base_url="http://testserver:443")
+ client = test_client_factory(app, base_url="http://testserver:443")
response = client.get("/", allow_redirects=False)
assert response.status_code == 307
assert response.headers["location"] == "https://testserver/"
- client = TestClient(app, base_url="http://testserver:123")
+ client = test_client_factory(app, base_url="http://testserver:123")
response = client.get("/", allow_redirects=False)
assert response.status_code == 307
assert response.headers["location"] == "https://testserver:123/"
diff --git a/tests/middleware/test_session.py b/tests/middleware/test_session.py
index 68cf36df99..42f4447e5c 100644
--- a/tests/middleware/test_session.py
+++ b/tests/middleware/test_session.py
@@ -3,7 +3,6 @@
from starlette.applications import Starlette
from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import JSONResponse
-from starlette.testclient import TestClient
def view_session(request):
@@ -29,10 +28,10 @@ def create_app():
return app
-def test_session():
+def test_session(test_client_factory):
app = create_app()
app.add_middleware(SessionMiddleware, secret_key="example")
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/view_session")
assert response.json() == {"session": {}}
@@ -56,10 +55,10 @@ def test_session():
assert response.json() == {"session": {}}
-def test_session_expires():
+def test_session_expires(test_client_factory):
app = create_app()
app.add_middleware(SessionMiddleware, secret_key="example", max_age=-1)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.post("/update_session", json={"some": "data"})
assert response.json() == {"session": {"some": "data"}}
@@ -72,11 +71,11 @@ def test_session_expires():
assert response.json() == {"session": {}}
-def test_secure_session():
+def test_secure_session(test_client_factory):
app = create_app()
app.add_middleware(SessionMiddleware, secret_key="example", https_only=True)
- secure_client = TestClient(app, base_url="https://testserver")
- unsecure_client = TestClient(app, base_url="http://testserver")
+ secure_client = test_client_factory(app, base_url="https://testserver")
+ unsecure_client = test_client_factory(app, base_url="http://testserver")
response = unsecure_client.get("/view_session")
assert response.json() == {"session": {}}
@@ -103,13 +102,26 @@ def test_secure_session():
assert response.json() == {"session": {}}
-def test_session_cookie_subpath():
+def test_session_cookie_subpath(test_client_factory):
app = create_app()
second_app = create_app()
second_app.add_middleware(SessionMiddleware, secret_key="example")
app.mount("/second_app", second_app)
- client = TestClient(app, base_url="http://testserver/second_app")
+ client = test_client_factory(app, base_url="http://testserver/second_app")
response = client.post("second_app/update_session", json={"some": "data"})
cookie = response.headers["set-cookie"]
cookie_path = re.search(r"; path=(\S+);", cookie).groups()[0]
assert cookie_path == "/second_app"
+
+
+def test_invalid_session_cookie(test_client_factory):
+ app = create_app()
+ app.add_middleware(SessionMiddleware, secret_key="example")
+ client = test_client_factory(app)
+
+ response = client.post("/update_session", json={"some": "data"})
+ assert response.json() == {"session": {"some": "data"}}
+
+ # we expect it to not raise an exception if we provide a bogus session cookie
+ response = client.get("/view_session", cookies={"session": "invalid"})
+ assert response.json() == {"session": {}}
diff --git a/tests/middleware/test_trusted_host.py b/tests/middleware/test_trusted_host.py
index 934f2477bf..de9c79e66a 100644
--- a/tests/middleware/test_trusted_host.py
+++ b/tests/middleware/test_trusted_host.py
@@ -1,10 +1,9 @@
from starlette.applications import Starlette
from starlette.middleware.trustedhost import TrustedHostMiddleware
from starlette.responses import PlainTextResponse
-from starlette.testclient import TestClient
-def test_trusted_host_middleware():
+def test_trusted_host_middleware(test_client_factory):
app = Starlette()
app.add_middleware(
@@ -15,15 +14,15 @@ def test_trusted_host_middleware():
def homepage(request):
return PlainTextResponse("OK", status_code=200)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/")
assert response.status_code == 200
- client = TestClient(app, base_url="http://subdomain.testserver")
+ client = test_client_factory(app, base_url="http://subdomain.testserver")
response = client.get("/")
assert response.status_code == 200
- client = TestClient(app, base_url="http://invalidhost")
+ client = test_client_factory(app, base_url="http://invalidhost")
response = client.get("/")
assert response.status_code == 400
@@ -34,7 +33,7 @@ def test_default_allowed_hosts():
assert middleware.allowed_hosts == ["*"]
-def test_www_redirect():
+def test_www_redirect(test_client_factory):
app = Starlette()
app.add_middleware(TrustedHostMiddleware, allowed_hosts=["www.example.com"])
@@ -43,7 +42,7 @@ def test_www_redirect():
def homepage(request):
return PlainTextResponse("OK", status_code=200)
- client = TestClient(app, base_url="https://example.com")
+ client = test_client_factory(app, base_url="https://example.com")
response = client.get("/")
assert response.status_code == 200
assert response.url == "https://www.example.com/"
diff --git a/tests/middleware/test_wsgi.py b/tests/middleware/test_wsgi.py
index 615805a94d..bcb4cd6ff2 100644
--- a/tests/middleware/test_wsgi.py
+++ b/tests/middleware/test_wsgi.py
@@ -3,7 +3,6 @@
import pytest
from starlette.middleware.wsgi import WSGIMiddleware, build_environ
-from starlette.testclient import TestClient
def hello_world(environ, start_response):
@@ -46,41 +45,41 @@ def return_exc_info(environ, start_response):
return [output]
-def test_wsgi_get():
+def test_wsgi_get(test_client_factory):
app = WSGIMiddleware(hello_world)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/")
assert response.status_code == 200
assert response.text == "Hello World!\n"
-def test_wsgi_post():
+def test_wsgi_post(test_client_factory):
app = WSGIMiddleware(echo_body)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.post("/", json={"example": 123})
assert response.status_code == 200
assert response.text == '{"example": 123}'
-def test_wsgi_exception():
+def test_wsgi_exception(test_client_factory):
# Note that we're testing the WSGI app directly here.
# The HTTP protocol implementations would catch this error and return 500.
app = WSGIMiddleware(raise_exception)
- client = TestClient(app)
+ client = test_client_factory(app)
with pytest.raises(RuntimeError):
client.get("/")
-def test_wsgi_exc_info():
+def test_wsgi_exc_info(test_client_factory):
# Note that we're testing the WSGI app directly here.
# The HTTP protocol implementations would catch this error and return 500.
app = WSGIMiddleware(return_exc_info)
- client = TestClient(app)
+ client = test_client_factory(app)
with pytest.raises(RuntimeError):
response = client.get("/")
app = WSGIMiddleware(return_exc_info)
- client = TestClient(app, raise_server_exceptions=False)
+ client = test_client_factory(app, raise_server_exceptions=False)
response = client.get("/")
assert response.status_code == 500
assert response.text == "Internal Server Error"
diff --git a/tests/test_applications.py b/tests/test_applications.py
index ad8504cbd5..f5f4c7fbea 100644
--- a/tests/test_applications.py
+++ b/tests/test_applications.py
@@ -1,4 +1,7 @@
import os
+import sys
+
+import pytest
from starlette.applications import Starlette
from starlette.endpoints import HTTPEndpoint
@@ -7,7 +10,11 @@
from starlette.responses import JSONResponse, PlainTextResponse
from starlette.routing import Host, Mount, Route, Router, WebSocketRoute
from starlette.staticfiles import StaticFiles
-from starlette.testclient import TestClient
+
+if sys.version_info >= (3, 7):
+ from contextlib import asynccontextmanager # pragma: no cover
+else:
+ from contextlib2 import asynccontextmanager # pragma: no cover
app = Starlette()
@@ -86,14 +93,17 @@ async def websocket_endpoint(session):
await session.close()
-client = TestClient(app)
+@pytest.fixture
+def client(test_client_factory):
+ with test_client_factory(app) as client:
+ yield client
def test_url_path_for():
assert app.url_path_for("func_homepage") == "/func"
-def test_func_route():
+def test_func_route(client):
response = client.get("/func")
assert response.status_code == 200
assert response.text == "Hello, world!"
@@ -103,51 +113,51 @@ def test_func_route():
assert response.text == ""
-def test_async_route():
+def test_async_route(client):
response = client.get("/async")
assert response.status_code == 200
assert response.text == "Hello, world!"
-def test_class_route():
+def test_class_route(client):
response = client.get("/class")
assert response.status_code == 200
assert response.text == "Hello, world!"
-def test_mounted_route():
+def test_mounted_route(client):
response = client.get("/users/")
assert response.status_code == 200
assert response.text == "Hello, everyone!"
-def test_mounted_route_path_params():
+def test_mounted_route_path_params(client):
response = client.get("/users/tomchristie")
assert response.status_code == 200
assert response.text == "Hello, tomchristie!"
-def test_subdomain_route():
- client = TestClient(app, base_url="https://foo.example.org/")
+def test_subdomain_route(test_client_factory):
+ client = test_client_factory(app, base_url="https://foo.example.org/")
response = client.get("/")
assert response.status_code == 200
assert response.text == "Subdomain: foo"
-def test_websocket_route():
+def test_websocket_route(client):
with client.websocket_connect("/ws") as session:
text = session.receive_text()
assert text == "Hello, world!"
-def test_400():
+def test_400(client):
response = client.get("/404")
assert response.status_code == 404
assert response.json() == {"detail": "Not Found"}
-def test_405():
+def test_405(client):
response = client.post("/func")
assert response.status_code == 405
assert response.json() == {"detail": "Custom message"}
@@ -157,15 +167,15 @@ def test_405():
assert response.json() == {"detail": "Custom message"}
-def test_500():
- client = TestClient(app, raise_server_exceptions=False)
+def test_500(test_client_factory):
+ client = test_client_factory(app, raise_server_exceptions=False)
response = client.get("/500")
assert response.status_code == 500
assert response.json() == {"detail": "Server Error"}
-def test_middleware():
- client = TestClient(app, base_url="http://incorrecthost")
+def test_middleware(test_client_factory):
+ client = test_client_factory(app, base_url="http://incorrecthost")
response = client.get("/func")
assert response.status_code == 400
assert response.text == "Invalid host header"
@@ -194,7 +204,7 @@ def test_routes():
]
-def test_app_mount(tmpdir):
+def test_app_mount(tmpdir, test_client_factory):
path = os.path.join(tmpdir, "example.txt")
with open(path, "w") as file:
file.write("")
@@ -202,7 +212,7 @@ def test_app_mount(tmpdir):
app = Starlette()
app.mount("/static", StaticFiles(directory=tmpdir))
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/static/example.txt")
assert response.status_code == 200
@@ -213,7 +223,7 @@ def test_app_mount(tmpdir):
assert response.text == "Method Not Allowed"
-def test_app_debug():
+def test_app_debug(test_client_factory):
app = Starlette()
app.debug = True
@@ -221,27 +231,27 @@ def test_app_debug():
async def homepage(request):
raise RuntimeError()
- client = TestClient(app, raise_server_exceptions=False)
+ client = test_client_factory(app, raise_server_exceptions=False)
response = client.get("/")
assert response.status_code == 500
assert "RuntimeError" in response.text
assert app.debug
-def test_app_add_route():
+def test_app_add_route(test_client_factory):
app = Starlette()
async def homepage(request):
return PlainTextResponse("Hello, World!")
app.add_route("/", homepage)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/")
assert response.status_code == 200
assert response.text == "Hello, World!"
-def test_app_add_websocket_route():
+def test_app_add_websocket_route(test_client_factory):
app = Starlette()
async def websocket_endpoint(session):
@@ -250,14 +260,14 @@ async def websocket_endpoint(session):
await session.close()
app.add_websocket_route("/ws", websocket_endpoint)
- client = TestClient(app)
+ client = test_client_factory(app)
with client.websocket_connect("/ws") as session:
text = session.receive_text()
assert text == "Hello, world!"
-def test_app_add_event_handler():
+def test_app_add_event_handler(test_client_factory):
startup_complete = False
cleanup_complete = False
app = Starlette()
@@ -275,14 +285,46 @@ def run_cleanup():
assert not startup_complete
assert not cleanup_complete
- with TestClient(app):
+ with test_client_factory(app):
assert startup_complete
assert not cleanup_complete
assert startup_complete
assert cleanup_complete
-def test_app_async_lifespan():
+def test_app_async_cm_lifespan(test_client_factory):
+ startup_complete = False
+ cleanup_complete = False
+
+ @asynccontextmanager
+ async def lifespan(app):
+ nonlocal startup_complete, cleanup_complete
+ startup_complete = True
+ yield
+ cleanup_complete = True
+
+ app = Starlette(lifespan=lifespan)
+
+ assert not startup_complete
+ assert not cleanup_complete
+ with test_client_factory(app):
+ assert startup_complete
+ assert not cleanup_complete
+ assert startup_complete
+ assert cleanup_complete
+
+
+deprecated_lifespan = pytest.mark.filterwarnings(
+ r"ignore"
+ r":(async )?generator function lifespans are deprecated, use an "
+ r"@contextlib\.asynccontextmanager function instead"
+ r":DeprecationWarning"
+ r":starlette.routing"
+)
+
+
+@deprecated_lifespan
+def test_app_async_gen_lifespan(test_client_factory):
startup_complete = False
cleanup_complete = False
@@ -296,14 +338,15 @@ async def lifespan(app):
assert not startup_complete
assert not cleanup_complete
- with TestClient(app):
+ with test_client_factory(app):
assert startup_complete
assert not cleanup_complete
assert startup_complete
assert cleanup_complete
-def test_app_sync_lifespan():
+@deprecated_lifespan
+def test_app_sync_gen_lifespan(test_client_factory):
startup_complete = False
cleanup_complete = False
@@ -317,7 +360,7 @@ def lifespan(app):
assert not startup_complete
assert not cleanup_complete
- with TestClient(app):
+ with test_client_factory(app):
assert startup_complete
assert not cleanup_complete
assert startup_complete
diff --git a/tests/test_authentication.py b/tests/test_authentication.py
index 3373f67c50..43c7ab96dc 100644
--- a/tests/test_authentication.py
+++ b/tests/test_authentication.py
@@ -15,7 +15,6 @@
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse
-from starlette.testclient import TestClient
from starlette.websockets import WebSocketDisconnect
@@ -195,8 +194,8 @@ def foo():
pass # pragma: nocover
-def test_user_interface():
- with TestClient(app) as client:
+def test_user_interface(test_client_factory):
+ with test_client_factory(app) as client:
response = client.get("/")
assert response.status_code == 200
assert response.json() == {"authenticated": False, "user": ""}
@@ -206,8 +205,8 @@ def test_user_interface():
assert response.json() == {"authenticated": True, "user": "tomchristie"}
-def test_authentication_required():
- with TestClient(app) as client:
+def test_authentication_required(test_client_factory):
+ with test_client_factory(app) as client:
response = client.get("/dashboard")
assert response.status_code == 403
@@ -258,13 +257,17 @@ def test_authentication_required():
assert response.text == "Invalid basic auth credentials"
-def test_websocket_authentication_required():
- with TestClient(app) as client:
+def test_websocket_authentication_required(test_client_factory):
+ with test_client_factory(app) as client:
with pytest.raises(WebSocketDisconnect):
- client.websocket_connect("/ws")
+ with client.websocket_connect("/ws"):
+ pass # pragma: nocover
with pytest.raises(WebSocketDisconnect):
- client.websocket_connect("/ws", headers={"Authorization": "basic foobar"})
+ with client.websocket_connect(
+ "/ws", headers={"Authorization": "basic foobar"}
+ ):
+ pass # pragma: nocover
with client.websocket_connect(
"/ws", auth=("tomchristie", "example")
@@ -273,12 +276,14 @@ def test_websocket_authentication_required():
assert data == {"authenticated": True, "user": "tomchristie"}
with pytest.raises(WebSocketDisconnect):
- client.websocket_connect("/ws/decorated")
+ with client.websocket_connect("/ws/decorated"):
+ pass # pragma: nocover
with pytest.raises(WebSocketDisconnect):
- client.websocket_connect(
+ with client.websocket_connect(
"/ws/decorated", headers={"Authorization": "basic foobar"}
- )
+ ):
+ pass # pragma: nocover
with client.websocket_connect(
"/ws/decorated", auth=("tomchristie", "example")
@@ -291,8 +296,8 @@ def test_websocket_authentication_required():
}
-def test_authentication_redirect():
- with TestClient(app) as client:
+def test_authentication_redirect(test_client_factory):
+ with test_client_factory(app) as client:
response = client.get("/admin")
assert response.status_code == 200
assert response.url == "http://testserver/"
@@ -331,8 +336,8 @@ def control_panel(request):
)
-def test_custom_on_error():
- with TestClient(other_app) as client:
+def test_custom_on_error(test_client_factory):
+ with test_client_factory(other_app) as client:
response = client.get("/control-panel", auth=("tomchristie", "example"))
assert response.status_code == 200
assert response.json() == {"authenticated": True, "user": "tomchristie"}
diff --git a/tests/test_background.py b/tests/test_background.py
index d9d7ddd872..e299ec3628 100644
--- a/tests/test_background.py
+++ b/tests/test_background.py
@@ -1,9 +1,8 @@
from starlette.background import BackgroundTask, BackgroundTasks
from starlette.responses import Response
-from starlette.testclient import TestClient
-def test_async_task():
+def test_async_task(test_client_factory):
TASK_COMPLETE = False
async def async_task():
@@ -16,13 +15,13 @@ async def app(scope, receive, send):
response = Response("task initiated", media_type="text/plain", background=task)
await response(scope, receive, send)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/")
assert response.text == "task initiated"
assert TASK_COMPLETE
-def test_sync_task():
+def test_sync_task(test_client_factory):
TASK_COMPLETE = False
def sync_task():
@@ -35,13 +34,13 @@ async def app(scope, receive, send):
response = Response("task initiated", media_type="text/plain", background=task)
await response(scope, receive, send)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/")
assert response.text == "task initiated"
assert TASK_COMPLETE
-def test_multiple_tasks():
+def test_multiple_tasks(test_client_factory):
TASK_COUNTER = 0
def increment(amount):
@@ -58,7 +57,7 @@ async def app(scope, receive, send):
)
await response(scope, receive, send)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/")
assert response.text == "tasks initiated"
assert TASK_COUNTER == 1 + 2 + 3
diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py
new file mode 100644
index 0000000000..cc5eba974f
--- /dev/null
+++ b/tests/test_concurrency.py
@@ -0,0 +1,22 @@
+import anyio
+import pytest
+
+from starlette.concurrency import run_until_first_complete
+
+
+@pytest.mark.anyio
+async def test_run_until_first_complete():
+ task1_finished = anyio.Event()
+ task2_finished = anyio.Event()
+
+ async def task1():
+ task1_finished.set()
+
+ async def task2():
+ await task1_finished.wait()
+ await anyio.sleep(0) # pragma: nocover
+ task2_finished.set() # pragma: nocover
+
+ await run_until_first_complete((task1, {}), (task2, {}))
+ assert task1_finished.is_set()
+ assert not task2_finished.is_set()
diff --git a/tests/test_database.py b/tests/test_database.py
index 258a71ec51..c0a4745d11 100644
--- a/tests/test_database.py
+++ b/tests/test_database.py
@@ -4,7 +4,6 @@
from starlette.applications import Starlette
from starlette.responses import JSONResponse
-from starlette.testclient import TestClient
DATABASE_URL = "sqlite:///test.db"
@@ -19,6 +18,9 @@
)
+pytestmark = pytest.mark.usefixtures("no_trio_support")
+
+
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
@@ -87,8 +89,8 @@ async def read_note_text(request):
return JSONResponse(result[0])
-def test_database():
- with TestClient(app) as client:
+def test_database(test_client_factory):
+ with test_client_factory(app) as client:
response = client.post(
"/notes", json={"text": "buy the milk", "completed": True}
)
@@ -122,10 +124,8 @@ def test_database():
assert response.json() == "buy the milk"
-def test_database_execute_many():
- with TestClient(app) as client:
- response = client.get("/notes")
-
+def test_database_execute_many(test_client_factory):
+ with test_client_factory(app) as client:
data = [
{"text": "buy the milk", "completed": True},
{"text": "walk the dog", "completed": False},
@@ -141,11 +141,11 @@ def test_database_execute_many():
]
-def test_database_isolated_during_test_cases():
+def test_database_isolated_during_test_cases(test_client_factory):
"""
Using `TestClient` as a context manager
"""
- with TestClient(app) as client:
+ with test_client_factory(app) as client:
response = client.post(
"/notes", json={"text": "just one note", "completed": True}
)
@@ -155,7 +155,7 @@ def test_database_isolated_during_test_cases():
assert response.status_code == 200
assert response.json() == [{"text": "just one note", "completed": True}]
- with TestClient(app) as client:
+ with test_client_factory(app) as client:
response = client.post(
"/notes", json={"text": "just one note", "completed": True}
)
diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py
index b0e6baf985..bb71ba870c 100644
--- a/tests/test_datastructures.py
+++ b/tests/test_datastructures.py
@@ -217,7 +217,7 @@ class BigUploadFile(UploadFile):
spool_max_size = 1024
-@pytest.mark.asyncio
+@pytest.mark.anyio
async def test_upload_file():
big_file = BigUploadFile("big-file")
await big_file.write(b"big-data" * 512)
diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py
index e491c085f3..e57d47486a 100644
--- a/tests/test_endpoints.py
+++ b/tests/test_endpoints.py
@@ -3,7 +3,6 @@
from starlette.endpoints import HTTPEndpoint, WebSocketEndpoint
from starlette.responses import PlainTextResponse
from starlette.routing import Route, Router
-from starlette.testclient import TestClient
class Homepage(HTTPEndpoint):
@@ -18,46 +17,50 @@ async def get(self, request):
routes=[Route("/", endpoint=Homepage), Route("/{username}", endpoint=Homepage)]
)
-client = TestClient(app)
+@pytest.fixture
+def client(test_client_factory):
+ with test_client_factory(app) as client:
+ yield client
-def test_http_endpoint_route():
+
+def test_http_endpoint_route(client):
response = client.get("/")
assert response.status_code == 200
assert response.text == "Hello, world!"
-def test_http_endpoint_route_path_params():
+def test_http_endpoint_route_path_params(client):
response = client.get("/tomchristie")
assert response.status_code == 200
assert response.text == "Hello, tomchristie!"
-def test_http_endpoint_route_method():
+def test_http_endpoint_route_method(client):
response = client.post("/")
assert response.status_code == 405
assert response.text == "Method Not Allowed"
-def test_websocket_endpoint_on_connect():
+def test_websocket_endpoint_on_connect(test_client_factory):
class WebSocketApp(WebSocketEndpoint):
async def on_connect(self, websocket):
assert websocket["subprotocols"] == ["soap", "wamp"]
await websocket.accept(subprotocol="wamp")
- client = TestClient(WebSocketApp)
+ client = test_client_factory(WebSocketApp)
with client.websocket_connect("/ws", subprotocols=["soap", "wamp"]) as websocket:
assert websocket.accepted_subprotocol == "wamp"
-def test_websocket_endpoint_on_receive_bytes():
+def test_websocket_endpoint_on_receive_bytes(test_client_factory):
class WebSocketApp(WebSocketEndpoint):
encoding = "bytes"
async def on_receive(self, websocket, data):
await websocket.send_bytes(b"Message bytes was: " + data)
- client = TestClient(WebSocketApp)
+ client = test_client_factory(WebSocketApp)
with client.websocket_connect("/ws") as websocket:
websocket.send_bytes(b"Hello, world!")
_bytes = websocket.receive_bytes()
@@ -68,14 +71,14 @@ async def on_receive(self, websocket, data):
websocket.send_text("Hello world")
-def test_websocket_endpoint_on_receive_json():
+def test_websocket_endpoint_on_receive_json(test_client_factory):
class WebSocketApp(WebSocketEndpoint):
encoding = "json"
async def on_receive(self, websocket, data):
await websocket.send_json({"message": data})
- client = TestClient(WebSocketApp)
+ client = test_client_factory(WebSocketApp)
with client.websocket_connect("/ws") as websocket:
websocket.send_json({"hello": "world"})
data = websocket.receive_json()
@@ -86,28 +89,28 @@ async def on_receive(self, websocket, data):
websocket.send_text("Hello world")
-def test_websocket_endpoint_on_receive_json_binary():
+def test_websocket_endpoint_on_receive_json_binary(test_client_factory):
class WebSocketApp(WebSocketEndpoint):
encoding = "json"
async def on_receive(self, websocket, data):
await websocket.send_json({"message": data}, mode="binary")
- client = TestClient(WebSocketApp)
+ client = test_client_factory(WebSocketApp)
with client.websocket_connect("/ws") as websocket:
websocket.send_json({"hello": "world"}, mode="binary")
data = websocket.receive_json(mode="binary")
assert data == {"message": {"hello": "world"}}
-def test_websocket_endpoint_on_receive_text():
+def test_websocket_endpoint_on_receive_text(test_client_factory):
class WebSocketApp(WebSocketEndpoint):
encoding = "text"
async def on_receive(self, websocket, data):
await websocket.send_text(f"Message text was: {data}")
- client = TestClient(WebSocketApp)
+ client = test_client_factory(WebSocketApp)
with client.websocket_connect("/ws") as websocket:
websocket.send_text("Hello, world!")
_text = websocket.receive_text()
@@ -118,26 +121,26 @@ async def on_receive(self, websocket, data):
websocket.send_bytes(b"Hello world")
-def test_websocket_endpoint_on_default():
+def test_websocket_endpoint_on_default(test_client_factory):
class WebSocketApp(WebSocketEndpoint):
encoding = None
async def on_receive(self, websocket, data):
await websocket.send_text(f"Message text was: {data}")
- client = TestClient(WebSocketApp)
+ client = test_client_factory(WebSocketApp)
with client.websocket_connect("/ws") as websocket:
websocket.send_text("Hello, world!")
_text = websocket.receive_text()
assert _text == "Message text was: Hello, world!"
-def test_websocket_endpoint_on_disconnect():
+def test_websocket_endpoint_on_disconnect(test_client_factory):
class WebSocketApp(WebSocketEndpoint):
async def on_disconnect(self, websocket, close_code):
assert close_code == 1001
await websocket.close(code=close_code)
- client = TestClient(WebSocketApp)
+ client = test_client_factory(WebSocketApp)
with client.websocket_connect("/ws") as websocket:
websocket.close(code=1001)
diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py
index 841c9a5cf6..5fba9981b1 100644
--- a/tests/test_exceptions.py
+++ b/tests/test_exceptions.py
@@ -3,7 +3,6 @@
from starlette.exceptions import ExceptionMiddleware, HTTPException
from starlette.responses import PlainTextResponse
from starlette.routing import Route, Router, WebSocketRoute
-from starlette.testclient import TestClient
def raise_runtime_error(request):
@@ -37,27 +36,33 @@ async def __call__(self, scope, receive, send):
app = ExceptionMiddleware(router)
-client = TestClient(app)
-def test_not_acceptable():
+@pytest.fixture
+def client(test_client_factory):
+ with test_client_factory(app) as client:
+ yield client
+
+
+def test_not_acceptable(client):
response = client.get("/not_acceptable")
assert response.status_code == 406
assert response.text == "Not Acceptable"
-def test_not_modified():
+def test_not_modified(client):
response = client.get("/not_modified")
assert response.status_code == 304
assert response.text == ""
-def test_websockets_should_raise():
+def test_websockets_should_raise(client):
with pytest.raises(RuntimeError):
- client.websocket_connect("/runtime_error")
+ with client.websocket_connect("/runtime_error"):
+ pass # pragma: nocover
-def test_handled_exc_after_response():
+def test_handled_exc_after_response(test_client_factory, client):
# A 406 HttpException is raised *after* the response has already been sent.
# The exception middleware should raise a RuntimeError.
with pytest.raises(RuntimeError):
@@ -65,17 +70,17 @@ def test_handled_exc_after_response():
# If `raise_server_exceptions=False` then the test client will still allow
# us to see the response as it will have been seen by the client.
- allow_200_client = TestClient(app, raise_server_exceptions=False)
+ allow_200_client = test_client_factory(app, raise_server_exceptions=False)
response = allow_200_client.get("/handled_exc_after_response")
assert response.status_code == 200
assert response.text == "OK"
-def test_force_500_response():
+def test_force_500_response(test_client_factory):
def app(scope):
raise RuntimeError()
- force_500_client = TestClient(app, raise_server_exceptions=False)
+ force_500_client = test_client_factory(app, raise_server_exceptions=False)
response = force_500_client.get("/")
assert response.status_code == 500
assert response.text == ""
diff --git a/tests/test_formparsers.py b/tests/test_formparsers.py
index 73a720fd13..8a1174e1d2 100644
--- a/tests/test_formparsers.py
+++ b/tests/test_formparsers.py
@@ -3,7 +3,6 @@
from starlette.formparsers import UploadFile, _user_safe_decode
from starlette.requests import Request
from starlette.responses import JSONResponse
-from starlette.testclient import TestClient
class ForceMultipartDict(dict):
@@ -70,18 +69,18 @@ async def app_read_body(scope, receive, send):
await response(scope, receive, send)
-def test_multipart_request_data(tmpdir):
- client = TestClient(app)
+def test_multipart_request_data(tmpdir, test_client_factory):
+ client = test_client_factory(app)
response = client.post("/", data={"some": "data"}, files=FORCE_MULTIPART)
assert response.json() == {"some": "data"}
-def test_multipart_request_files(tmpdir):
+def test_multipart_request_files(tmpdir, test_client_factory):
path = os.path.join(tmpdir, "test.txt")
with open(path, "wb") as file:
file.write(b"")
- client = TestClient(app)
+ client = test_client_factory(app)
with open(path, "rb") as f:
response = client.post("/", files={"test": f})
assert response.json() == {
@@ -93,12 +92,12 @@ def test_multipart_request_files(tmpdir):
}
-def test_multipart_request_files_with_content_type(tmpdir):
+def test_multipart_request_files_with_content_type(tmpdir, test_client_factory):
path = os.path.join(tmpdir, "test.txt")
with open(path, "wb") as file:
file.write(b"")
- client = TestClient(app)
+ client = test_client_factory(app)
with open(path, "rb") as f:
response = client.post("/", files={"test": ("test.txt", f, "text/plain")})
assert response.json() == {
@@ -110,7 +109,7 @@ def test_multipart_request_files_with_content_type(tmpdir):
}
-def test_multipart_request_multiple_files(tmpdir):
+def test_multipart_request_multiple_files(tmpdir, test_client_factory):
path1 = os.path.join(tmpdir, "test1.txt")
with open(path1, "wb") as file:
file.write(b"")
@@ -119,7 +118,7 @@ def test_multipart_request_multiple_files(tmpdir):
with open(path2, "wb") as file:
file.write(b"")
- client = TestClient(app)
+ client = test_client_factory(app)
with open(path1, "rb") as f1, open(path2, "rb") as f2:
response = client.post(
"/", files={"test1": f1, "test2": ("test2.txt", f2, "text/plain")}
@@ -138,7 +137,7 @@ def test_multipart_request_multiple_files(tmpdir):
}
-def test_multi_items(tmpdir):
+def test_multi_items(tmpdir, test_client_factory):
path1 = os.path.join(tmpdir, "test1.txt")
with open(path1, "wb") as file:
file.write(b"")
@@ -147,7 +146,7 @@ def test_multi_items(tmpdir):
with open(path2, "wb") as file:
file.write(b"")
- client = TestClient(multi_items_app)
+ client = test_client_factory(multi_items_app)
with open(path1, "rb") as f1, open(path2, "rb") as f2:
response = client.post(
"/",
@@ -171,8 +170,8 @@ def test_multi_items(tmpdir):
}
-def test_multipart_request_mixed_files_and_data(tmpdir):
- client = TestClient(app)
+def test_multipart_request_mixed_files_and_data(tmpdir, test_client_factory):
+ client = test_client_factory(app)
response = client.post(
"/",
data=(
@@ -208,8 +207,8 @@ def test_multipart_request_mixed_files_and_data(tmpdir):
}
-def test_multipart_request_with_charset_for_filename(tmpdir):
- client = TestClient(app)
+def test_multipart_request_with_charset_for_filename(tmpdir, test_client_factory):
+ client = test_client_factory(app)
response = client.post(
"/",
data=(
@@ -236,8 +235,8 @@ def test_multipart_request_with_charset_for_filename(tmpdir):
}
-def test_multipart_request_without_charset_for_filename(tmpdir):
- client = TestClient(app)
+def test_multipart_request_without_charset_for_filename(tmpdir, test_client_factory):
+ client = test_client_factory(app)
response = client.post(
"/",
data=(
@@ -263,8 +262,8 @@ def test_multipart_request_without_charset_for_filename(tmpdir):
}
-def test_multipart_request_with_encoded_value(tmpdir):
- client = TestClient(app)
+def test_multipart_request_with_encoded_value(tmpdir, test_client_factory):
+ client = test_client_factory(app)
response = client.post(
"/",
data=(
@@ -284,38 +283,38 @@ def test_multipart_request_with_encoded_value(tmpdir):
assert response.json() == {"value": "Transférer"}
-def test_urlencoded_request_data(tmpdir):
- client = TestClient(app)
+def test_urlencoded_request_data(tmpdir, test_client_factory):
+ client = test_client_factory(app)
response = client.post("/", data={"some": "data"})
assert response.json() == {"some": "data"}
-def test_no_request_data(tmpdir):
- client = TestClient(app)
+def test_no_request_data(tmpdir, test_client_factory):
+ client = test_client_factory(app)
response = client.post("/")
assert response.json() == {}
-def test_urlencoded_percent_encoding(tmpdir):
- client = TestClient(app)
+def test_urlencoded_percent_encoding(tmpdir, test_client_factory):
+ client = test_client_factory(app)
response = client.post("/", data={"some": "da ta"})
assert response.json() == {"some": "da ta"}
-def test_urlencoded_percent_encoding_keys(tmpdir):
- client = TestClient(app)
+def test_urlencoded_percent_encoding_keys(tmpdir, test_client_factory):
+ client = test_client_factory(app)
response = client.post("/", data={"so me": "data"})
assert response.json() == {"so me": "data"}
-def test_urlencoded_multi_field_app_reads_body(tmpdir):
- client = TestClient(app_read_body)
+def test_urlencoded_multi_field_app_reads_body(tmpdir, test_client_factory):
+ client = test_client_factory(app_read_body)
response = client.post("/", data={"some": "data", "second": "key pair"})
assert response.json() == {"some": "data", "second": "key pair"}
-def test_multipart_multi_field_app_reads_body(tmpdir):
- client = TestClient(app_read_body)
+def test_multipart_multi_field_app_reads_body(tmpdir, test_client_factory):
+ client = test_client_factory(app_read_body)
response = client.post(
"/", data={"some": "data", "second": "key pair"}, files=FORCE_MULTIPART
)
diff --git a/tests/test_requests.py b/tests/test_requests.py
index a83a2c480b..d7c69fbeb2 100644
--- a/tests/test_requests.py
+++ b/tests/test_requests.py
@@ -1,20 +1,18 @@
-import asyncio
-
+import anyio
import pytest
from starlette.requests import ClientDisconnect, Request, State
from starlette.responses import JSONResponse, Response
-from starlette.testclient import TestClient
-def test_request_url():
+def test_request_url(test_client_factory):
async def app(scope, receive, send):
request = Request(scope, receive)
data = {"method": request.method, "url": str(request.url)}
response = JSONResponse(data)
await response(scope, receive, send)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/123?a=abc")
assert response.json() == {"method": "GET", "url": "http://testserver/123?a=abc"}
@@ -22,26 +20,26 @@ async def app(scope, receive, send):
assert response.json() == {"method": "GET", "url": "https://example.org:123/"}
-def test_request_query_params():
+def test_request_query_params(test_client_factory):
async def app(scope, receive, send):
request = Request(scope, receive)
params = dict(request.query_params)
response = JSONResponse({"params": params})
await response(scope, receive, send)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/?a=123&b=456")
assert response.json() == {"params": {"a": "123", "b": "456"}}
-def test_request_headers():
+def test_request_headers(test_client_factory):
async def app(scope, receive, send):
request = Request(scope, receive)
headers = dict(request.headers)
response = JSONResponse({"headers": headers})
await response(scope, receive, send)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/", headers={"host": "example.org"})
assert response.json() == {
"headers": {
@@ -54,7 +52,7 @@ async def app(scope, receive, send):
}
-def test_request_client():
+def test_request_client(test_client_factory):
async def app(scope, receive, send):
request = Request(scope, receive)
response = JSONResponse(
@@ -62,19 +60,19 @@ async def app(scope, receive, send):
)
await response(scope, receive, send)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/")
assert response.json() == {"host": "testclient", "port": 50000}
-def test_request_body():
+def test_request_body(test_client_factory):
async def app(scope, receive, send):
request = Request(scope, receive)
body = await request.body()
response = JSONResponse({"body": body.decode()})
await response(scope, receive, send)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/")
assert response.json() == {"body": ""}
@@ -86,7 +84,7 @@ async def app(scope, receive, send):
assert response.json() == {"body": "abc"}
-def test_request_stream():
+def test_request_stream(test_client_factory):
async def app(scope, receive, send):
request = Request(scope, receive)
body = b""
@@ -95,7 +93,7 @@ async def app(scope, receive, send):
response = JSONResponse({"body": body.decode()})
await response(scope, receive, send)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/")
assert response.json() == {"body": ""}
@@ -107,20 +105,20 @@ async def app(scope, receive, send):
assert response.json() == {"body": "abc"}
-def test_request_form_urlencoded():
+def test_request_form_urlencoded(test_client_factory):
async def app(scope, receive, send):
request = Request(scope, receive)
form = await request.form()
response = JSONResponse({"form": dict(form)})
await response(scope, receive, send)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.post("/", data={"abc": "123 @"})
assert response.json() == {"form": {"abc": "123 @"}}
-def test_request_body_then_stream():
+def test_request_body_then_stream(test_client_factory):
async def app(scope, receive, send):
request = Request(scope, receive)
body = await request.body()
@@ -130,13 +128,13 @@ async def app(scope, receive, send):
response = JSONResponse({"body": body.decode(), "stream": chunks.decode()})
await response(scope, receive, send)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.post("/", data="abc")
assert response.json() == {"body": "abc", "stream": "abc"}
-def test_request_stream_then_body():
+def test_request_stream_then_body(test_client_factory):
async def app(scope, receive, send):
request = Request(scope, receive)
chunks = b""
@@ -149,20 +147,20 @@ async def app(scope, receive, send):
response = JSONResponse({"body": body.decode(), "stream": chunks.decode()})
await response(scope, receive, send)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.post("/", data="abc")
assert response.json() == {"body": "", "stream": "abc"}
-def test_request_json():
+def test_request_json(test_client_factory):
async def app(scope, receive, send):
request = Request(scope, receive)
data = await request.json()
response = JSONResponse({"json": data})
await response(scope, receive, send)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.post("/", json={"a": "123"})
assert response.json() == {"json": {"a": "123"}}
@@ -178,7 +176,7 @@ def test_request_scope_interface():
assert len(request) == 3
-def test_request_without_setting_receive():
+def test_request_without_setting_receive(test_client_factory):
"""
If Request is instantiated without the receive channel, then .body()
is not available.
@@ -193,12 +191,12 @@ async def app(scope, receive, send):
response = JSONResponse({"json": data})
await response(scope, receive, send)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.post("/", json={"a": "123"})
assert response.json() == {"json": "Receive channel not available"}
-def test_request_disconnect():
+def test_request_disconnect(anyio_backend_name, anyio_backend_options):
"""
If a client disconnect occurs while reading request body
then ClientDisconnect should be raised.
@@ -212,12 +210,18 @@ async def receiver():
return {"type": "http.disconnect"}
scope = {"type": "http", "method": "POST", "path": "/"}
- loop = asyncio.get_event_loop()
with pytest.raises(ClientDisconnect):
- loop.run_until_complete(app(scope, receiver, None))
+ anyio.run(
+ app,
+ scope,
+ receiver,
+ None,
+ backend=anyio_backend_name,
+ backend_options=anyio_backend_options,
+ )
-def test_request_is_disconnected():
+def test_request_is_disconnected(test_client_factory):
"""
If a client disconnect occurs while reading request body
then ClientDisconnect should be raised.
@@ -234,7 +238,7 @@ async def app(scope, receive, send):
await response(scope, receive, send)
disconnected_after_response = await request.is_disconnected()
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/")
assert response.json() == {"disconnected": False}
assert disconnected_after_response
@@ -254,19 +258,19 @@ def test_request_state_object():
s.new
-def test_request_state():
+def test_request_state(test_client_factory):
async def app(scope, receive, send):
request = Request(scope, receive)
request.state.example = 123
response = JSONResponse({"state.example": request.state.example})
await response(scope, receive, send)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/123?a=abc")
assert response.json() == {"state.example": 123}
-def test_request_cookies():
+def test_request_cookies(test_client_factory):
async def app(scope, receive, send):
request = Request(scope, receive)
mycookie = request.cookies.get("mycookie")
@@ -278,14 +282,14 @@ async def app(scope, receive, send):
await response(scope, receive, send)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/")
assert response.text == "Hello, world!"
response = client.get("/")
assert response.text == "Hello, cookies!"
-def test_cookie_lenient_parsing():
+def test_cookie_lenient_parsing(test_client_factory):
"""
The following test is based on a cookie set by Okta, a well-known authorization
service. It turns out that it's common practice to set cookies that would be
@@ -312,7 +316,7 @@ async def app(scope, receive, send):
response = JSONResponse({"cookies": request.cookies})
await response(scope, receive, send)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/", headers={"cookie": tough_cookie})
result = response.json()
assert len(result["cookies"]) == 4
@@ -341,13 +345,13 @@ async def app(scope, receive, send):
("a=b; h=i; a=c", {"a": "c", "h": "i"}),
],
)
-def test_cookies_edge_cases(set_cookie, expected):
+def test_cookies_edge_cases(set_cookie, expected, test_client_factory):
async def app(scope, receive, send):
request = Request(scope, receive)
response = JSONResponse({"cookies": request.cookies})
await response(scope, receive, send)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/", headers={"cookie": set_cookie})
result = response.json()
assert result["cookies"] == expected
@@ -376,7 +380,7 @@ async def app(scope, receive, send):
# (" = b ; ; = ; c = ; ", {"": "b", "c": ""}),
],
)
-def test_cookies_invalid(set_cookie, expected):
+def test_cookies_invalid(set_cookie, expected, test_client_factory):
"""
Cookie strings that are against the RFC6265 spec but which browsers will send if set
via document.cookie.
@@ -387,20 +391,20 @@ async def app(scope, receive, send):
response = JSONResponse({"cookies": request.cookies})
await response(scope, receive, send)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/", headers={"cookie": set_cookie})
result = response.json()
assert result["cookies"] == expected
-def test_chunked_encoding():
+def test_chunked_encoding(test_client_factory):
async def app(scope, receive, send):
request = Request(scope, receive)
body = await request.body()
response = JSONResponse({"body": body.decode()})
await response(scope, receive, send)
- client = TestClient(app)
+ client = test_client_factory(app)
def post_body():
yield b"foo"
@@ -410,7 +414,7 @@ def post_body():
assert response.json() == {"body": "foobar"}
-def test_request_send_push_promise():
+def test_request_send_push_promise(test_client_factory):
async def app(scope, receive, send):
# the server is push-enabled
scope["extensions"]["http.response.push"] = {}
@@ -421,12 +425,12 @@ async def app(scope, receive, send):
response = JSONResponse({"json": "OK"})
await response(scope, receive, send)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/")
assert response.json() == {"json": "OK"}
-def test_request_send_push_promise_without_push_extension():
+def test_request_send_push_promise_without_push_extension(test_client_factory):
"""
If server does not support the `http.response.push` extension,
.send_push_promise() does nothing.
@@ -439,12 +443,12 @@ async def app(scope, receive, send):
response = JSONResponse({"json": "OK"})
await response(scope, receive, send)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/")
assert response.json() == {"json": "OK"}
-def test_request_send_push_promise_without_setting_send():
+def test_request_send_push_promise_without_setting_send(test_client_factory):
"""
If Request is instantiated without the send channel, then
.send_push_promise() is not available.
@@ -463,6 +467,6 @@ async def app(scope, receive, send):
response = JSONResponse({"json": data})
await response(scope, receive, send)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/")
assert response.json() == {"json": "Send channel not available"}
diff --git a/tests/test_responses.py b/tests/test_responses.py
index fd2ba0e424..baba549baf 100644
--- a/tests/test_responses.py
+++ b/tests/test_responses.py
@@ -1,6 +1,6 @@
-import asyncio
import os
+import anyio
import pytest
from starlette import status
@@ -13,40 +13,39 @@
Response,
StreamingResponse,
)
-from starlette.testclient import TestClient
-def test_text_response():
+def test_text_response(test_client_factory):
async def app(scope, receive, send):
response = Response("hello, world", media_type="text/plain")
await response(scope, receive, send)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/")
assert response.text == "hello, world"
-def test_bytes_response():
+def test_bytes_response(test_client_factory):
async def app(scope, receive, send):
response = Response(b"xxxxx", media_type="image/png")
await response(scope, receive, send)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/")
assert response.content == b"xxxxx"
-def test_json_none_response():
+def test_json_none_response(test_client_factory):
async def app(scope, receive, send):
response = JSONResponse(None)
await response(scope, receive, send)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/")
assert response.json() is None
-def test_redirect_response():
+def test_redirect_response(test_client_factory):
async def app(scope, receive, send):
if scope["path"] == "/":
response = Response("hello, world", media_type="text/plain")
@@ -54,13 +53,13 @@ async def app(scope, receive, send):
response = RedirectResponse("/")
await response(scope, receive, send)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/redirect")
assert response.text == "hello, world"
assert response.url == "http://testserver/"
-def test_quoting_redirect_response():
+def test_quoting_redirect_response(test_client_factory):
async def app(scope, receive, send):
if scope["path"] == "/I ♥ Starlette/":
response = Response("hello, world", media_type="text/plain")
@@ -68,13 +67,13 @@ async def app(scope, receive, send):
response = RedirectResponse("/I ♥ Starlette/")
await response(scope, receive, send)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/redirect")
assert response.text == "hello, world"
assert response.url == "http://testserver/I%20%E2%99%A5%20Starlette/"
-def test_streaming_response():
+def test_streaming_response(test_client_factory):
filled_by_bg_task = ""
async def app(scope, receive, send):
@@ -83,7 +82,7 @@ async def numbers(minimum, maximum):
yield str(i)
if i != maximum:
yield ", "
- await asyncio.sleep(0)
+ await anyio.sleep(0)
async def numbers_for_cleanup(start=1, stop=5):
nonlocal filled_by_bg_task
@@ -98,13 +97,13 @@ async def numbers_for_cleanup(start=1, stop=5):
await response(scope, receive, send)
assert filled_by_bg_task == ""
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/")
assert response.text == "1, 2, 3, 4, 5"
assert filled_by_bg_task == "6, 7, 8, 9"
-def test_streaming_response_custom_iterator():
+def test_streaming_response_custom_iterator(test_client_factory):
async def app(scope, receive, send):
class CustomAsyncIterator:
def __init__(self):
@@ -122,12 +121,12 @@ async def __anext__(self):
response = StreamingResponse(CustomAsyncIterator(), media_type="text/plain")
await response(scope, receive, send)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/")
assert response.text == "12345"
-def test_streaming_response_custom_iterable():
+def test_streaming_response_custom_iterable(test_client_factory):
async def app(scope, receive, send):
class CustomAsyncIterable:
async def __aiter__(self):
@@ -137,12 +136,12 @@ async def __aiter__(self):
response = StreamingResponse(CustomAsyncIterable(), media_type="text/plain")
await response(scope, receive, send)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/")
assert response.text == "12345"
-def test_sync_streaming_response():
+def test_sync_streaming_response(test_client_factory):
async def app(scope, receive, send):
def numbers(minimum, maximum):
for i in range(minimum, maximum + 1):
@@ -154,37 +153,37 @@ def numbers(minimum, maximum):
response = StreamingResponse(generator, media_type="text/plain")
await response(scope, receive, send)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/")
assert response.text == "1, 2, 3, 4, 5"
-def test_response_headers():
+def test_response_headers(test_client_factory):
async def app(scope, receive, send):
headers = {"x-header-1": "123", "x-header-2": "456"}
response = Response("hello, world", media_type="text/plain", headers=headers)
response.headers["x-header-2"] = "789"
await response(scope, receive, send)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/")
assert response.headers["x-header-1"] == "123"
assert response.headers["x-header-2"] == "789"
-def test_response_phrase():
+def test_response_phrase(test_client_factory):
app = Response(status_code=204)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/")
assert response.reason == "No Content"
app = Response(b"", status_code=123)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/")
assert response.reason == ""
-def test_file_response(tmpdir):
+def test_file_response(tmpdir, test_client_factory):
path = os.path.join(tmpdir, "xyz")
content = b"" * 1000
with open(path, "wb") as file:
@@ -197,7 +196,7 @@ async def numbers(minimum, maximum):
yield str(i)
if i != maximum:
yield ", "
- await asyncio.sleep(0)
+ await anyio.sleep(0)
async def numbers_for_cleanup(start=1, stop=5):
nonlocal filled_by_bg_task
@@ -213,7 +212,7 @@ async def app(scope, receive, send):
await response(scope, receive, send)
assert filled_by_bg_task == ""
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/")
expected_disposition = 'attachment; filename="example.png"'
assert response.status_code == status.HTTP_200_OK
@@ -226,31 +225,31 @@ async def app(scope, receive, send):
assert filled_by_bg_task == "6, 7, 8, 9"
-def test_file_response_with_directory_raises_error(tmpdir):
+def test_file_response_with_directory_raises_error(tmpdir, test_client_factory):
app = FileResponse(path=tmpdir, filename="example.png")
- client = TestClient(app)
+ client = test_client_factory(app)
with pytest.raises(RuntimeError) as exc_info:
client.get("/")
assert "is not a file" in str(exc_info.value)
-def test_file_response_with_missing_file_raises_error(tmpdir):
+def test_file_response_with_missing_file_raises_error(tmpdir, test_client_factory):
path = os.path.join(tmpdir, "404.txt")
app = FileResponse(path=path, filename="404.txt")
- client = TestClient(app)
+ client = test_client_factory(app)
with pytest.raises(RuntimeError) as exc_info:
client.get("/")
assert "does not exist" in str(exc_info.value)
-def test_file_response_with_chinese_filename(tmpdir):
+def test_file_response_with_chinese_filename(tmpdir, test_client_factory):
content = b"file content"
filename = "你好.txt" # probably "Hello.txt" in Chinese
path = os.path.join(tmpdir, filename)
with open(path, "wb") as f:
f.write(content)
app = FileResponse(path=path, filename=filename)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/")
expected_disposition = "attachment; filename*=utf-8''%E4%BD%A0%E5%A5%BD.txt"
assert response.status_code == status.HTTP_200_OK
@@ -258,7 +257,7 @@ def test_file_response_with_chinese_filename(tmpdir):
assert response.headers["content-disposition"] == expected_disposition
-def test_set_cookie():
+def test_set_cookie(test_client_factory):
async def app(scope, receive, send):
response = Response("Hello, world!", media_type="text/plain")
response.set_cookie(
@@ -274,12 +273,12 @@ async def app(scope, receive, send):
)
await response(scope, receive, send)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/")
assert response.text == "Hello, world!"
-def test_delete_cookie():
+def test_delete_cookie(test_client_factory):
async def app(scope, receive, send):
request = Request(scope, receive)
response = Response("Hello, world!", media_type="text/plain")
@@ -289,24 +288,24 @@ async def app(scope, receive, send):
response.set_cookie("mycookie", "myvalue")
await response(scope, receive, send)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/")
assert response.cookies["mycookie"]
response = client.get("/")
assert not response.cookies.get("mycookie")
-def test_populate_headers():
+def test_populate_headers(test_client_factory):
app = Response(content="hi", headers={}, media_type="text/html")
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/")
assert response.text == "hi"
assert response.headers["content-length"] == "2"
assert response.headers["content-type"] == "text/html; charset=utf-8"
-def test_head_method():
+def test_head_method(test_client_factory):
app = Response("hello, world", media_type="text/plain")
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.head("/")
assert response.text == ""
diff --git a/tests/test_routing.py b/tests/test_routing.py
index fff3332dbd..9e734b9cc9 100644
--- a/tests/test_routing.py
+++ b/tests/test_routing.py
@@ -6,7 +6,6 @@
from starlette.applications import Starlette
from starlette.responses import JSONResponse, PlainTextResponse, Response
from starlette.routing import Host, Mount, NoMatchFound, Route, Router, WebSocketRoute
-from starlette.testclient import TestClient
from starlette.websockets import WebSocket, WebSocketDisconnect
@@ -105,10 +104,19 @@ async def websocket_params(session):
await session.close()
-client = TestClient(app)
+@pytest.fixture
+def client(test_client_factory):
+ with test_client_factory(app) as client:
+ yield client
-def test_router():
+@pytest.mark.filterwarnings(
+ r"ignore"
+ r":Trying to detect encoding from a tiny portion of \(5\) byte\(s\)\."
+ r":UserWarning"
+ r":charset_normalizer.api"
+)
+def test_router(client):
response = client.get("/")
assert response.status_code == 200
assert response.text == "Hello, world"
@@ -147,7 +155,7 @@ def test_router():
assert response.text == "xxxxx"
-def test_route_converters():
+def test_route_converters(client):
# Test integer conversion
response = client.get("/int/5")
assert response.status_code == 200
@@ -232,19 +240,19 @@ def test_url_for():
)
-def test_router_add_route():
+def test_router_add_route(client):
response = client.get("/func")
assert response.status_code == 200
assert response.text == "Hello, world!"
-def test_router_duplicate_path():
+def test_router_duplicate_path(client):
response = client.post("/func")
assert response.status_code == 200
assert response.text == "Hello, POST!"
-def test_router_add_websocket_route():
+def test_router_add_websocket_route(client):
with client.websocket_connect("/ws") as session:
text = session.receive_text()
assert text == "Hello, world!"
@@ -275,8 +283,8 @@ async def __call__(self, scope, receive, send):
)
-def test_protocol_switch():
- client = TestClient(mixed_protocol_app)
+def test_protocol_switch(test_client_factory):
+ client = test_client_factory(mixed_protocol_app)
response = client.get("/")
assert response.status_code == 200
@@ -286,15 +294,16 @@ def test_protocol_switch():
assert session.receive_json() == {"URL": "ws://testserver/"}
with pytest.raises(WebSocketDisconnect):
- client.websocket_connect("/404")
+ with client.websocket_connect("/404"):
+ pass # pragma: nocover
ok = PlainTextResponse("OK")
-def test_mount_urls():
+def test_mount_urls(test_client_factory):
mounted = Router([Mount("/users", ok, name="users")])
- client = TestClient(mounted)
+ client = test_client_factory(mounted)
assert client.get("/users").status_code == 200
assert client.get("/users").url == "http://testserver/users/"
assert client.get("/users/").status_code == 200
@@ -317,9 +326,9 @@ def test_reverse_mount_urls():
)
-def test_mount_at_root():
+def test_mount_at_root(test_client_factory):
mounted = Router([Mount("/", ok, name="users")])
- client = TestClient(mounted)
+ client = test_client_factory(mounted)
assert client.get("/").status_code == 200
@@ -347,8 +356,8 @@ def users_api(request):
)
-def test_host_routing():
- client = TestClient(mixed_hosts_app, base_url="https://api.example.org/")
+def test_host_routing(test_client_factory):
+ client = test_client_factory(mixed_hosts_app, base_url="https://api.example.org/")
response = client.get("/users")
assert response.status_code == 200
@@ -357,7 +366,7 @@ def test_host_routing():
response = client.get("/")
assert response.status_code == 404
- client = TestClient(mixed_hosts_app, base_url="https://www.example.org/")
+ client = test_client_factory(mixed_hosts_app, base_url="https://www.example.org/")
response = client.get("/users")
assert response.status_code == 200
@@ -392,8 +401,8 @@ async def subdomain_app(scope, receive, send):
)
-def test_subdomain_routing():
- client = TestClient(subdomain_app, base_url="https://foo.example.org/")
+def test_subdomain_routing(test_client_factory):
+ client = test_client_factory(subdomain_app, base_url="https://foo.example.org/")
response = client.get("/")
assert response.status_code == 200
@@ -428,9 +437,11 @@ async def echo_urls(request):
]
-def test_url_for_with_root_path():
+def test_url_for_with_root_path(test_client_factory):
app = Starlette(routes=echo_url_routes)
- client = TestClient(app, base_url="https://www.example.org/", root_path="/sub_path")
+ client = test_client_factory(
+ app, base_url="https://www.example.org/", root_path="/sub_path"
+ )
response = client.get("/")
assert response.json() == {
"index": "https://www.example.org/sub_path/",
@@ -458,17 +469,17 @@ def test_url_for_with_double_mount():
assert url == "/mount/static/123"
-def test_standalone_route_matches():
+def test_standalone_route_matches(test_client_factory):
app = Route("/", PlainTextResponse("Hello, World!"))
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/")
assert response.status_code == 200
assert response.text == "Hello, World!"
-def test_standalone_route_does_not_match():
+def test_standalone_route_does_not_match(test_client_factory):
app = Route("/", PlainTextResponse("Hello, World!"))
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/invalid")
assert response.status_code == 404
assert response.text == "Not Found"
@@ -480,22 +491,23 @@ async def ws_helloworld(websocket):
await websocket.close()
-def test_standalone_ws_route_matches():
+def test_standalone_ws_route_matches(test_client_factory):
app = WebSocketRoute("/", ws_helloworld)
- client = TestClient(app)
+ client = test_client_factory(app)
with client.websocket_connect("/") as websocket:
text = websocket.receive_text()
assert text == "Hello, world!"
-def test_standalone_ws_route_does_not_match():
+def test_standalone_ws_route_does_not_match(test_client_factory):
app = WebSocketRoute("/", ws_helloworld)
- client = TestClient(app)
+ client = test_client_factory(app)
with pytest.raises(WebSocketDisconnect):
- client.websocket_connect("/invalid")
+ with client.websocket_connect("/invalid"):
+ pass # pragma: nocover
-def test_lifespan_async():
+def test_lifespan_async(test_client_factory):
startup_complete = False
shutdown_complete = False
@@ -518,7 +530,7 @@ async def run_shutdown():
assert not startup_complete
assert not shutdown_complete
- with TestClient(app) as client:
+ with test_client_factory(app) as client:
assert startup_complete
assert not shutdown_complete
client.get("/")
@@ -526,7 +538,7 @@ async def run_shutdown():
assert shutdown_complete
-def test_lifespan_sync():
+def test_lifespan_sync(test_client_factory):
startup_complete = False
shutdown_complete = False
@@ -549,7 +561,7 @@ def run_shutdown():
assert not startup_complete
assert not shutdown_complete
- with TestClient(app) as client:
+ with test_client_factory(app) as client:
assert startup_complete
assert not shutdown_complete
client.get("/")
@@ -557,7 +569,7 @@ def run_shutdown():
assert shutdown_complete
-def test_raise_on_startup():
+def test_raise_on_startup(test_client_factory):
def run_startup():
raise RuntimeError()
@@ -574,19 +586,19 @@ async def _send(message):
startup_failed = False
with pytest.raises(RuntimeError):
- with TestClient(app):
+ with test_client_factory(app):
pass # pragma: nocover
assert startup_failed
-def test_raise_on_shutdown():
+def test_raise_on_shutdown(test_client_factory):
def run_shutdown():
raise RuntimeError()
app = Router(on_shutdown=[run_shutdown])
with pytest.raises(RuntimeError):
- with TestClient(app):
+ with test_client_factory(app):
pass # pragma: nocover
@@ -613,8 +625,8 @@ async def _partial_async_endpoint(arg, request):
)
-def test_partial_async_endpoint():
- test_client = TestClient(partial_async_app)
+def test_partial_async_endpoint(test_client_factory):
+ test_client = test_client_factory(partial_async_app)
response = test_client.get("/")
assert response.status_code == 200
assert response.json() == {"arg": "foo"}
diff --git a/tests/test_schemas.py b/tests/test_schemas.py
index 0ae43238fe..28fe777f09 100644
--- a/tests/test_schemas.py
+++ b/tests/test_schemas.py
@@ -1,7 +1,6 @@
from starlette.applications import Starlette
from starlette.endpoints import HTTPEndpoint
from starlette.schemas import SchemaGenerator
-from starlette.testclient import TestClient
schemas = SchemaGenerator(
{"openapi": "3.0.0", "info": {"title": "Example API", "version": "1.0"}}
@@ -213,8 +212,8 @@ def test_schema_generation():
"""
-def test_schema_endpoint():
- client = TestClient(app)
+def test_schema_endpoint(test_client_factory):
+ client = test_client_factory(app)
response = client.get("/schema")
assert response.headers["Content-Type"] == "application/vnd.oai.openapi"
assert response.text.strip() == EXPECTED_SCHEMA.strip()
diff --git a/tests/test_staticfiles.py b/tests/test_staticfiles.py
index 6b325071fa..d5ec1afc5e 100644
--- a/tests/test_staticfiles.py
+++ b/tests/test_staticfiles.py
@@ -1,43 +1,42 @@
-import asyncio
import os
import pathlib
import time
+import anyio
import pytest
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.routing import Mount
from starlette.staticfiles import StaticFiles
-from starlette.testclient import TestClient
-def test_staticfiles(tmpdir):
+def test_staticfiles(tmpdir, test_client_factory):
path = os.path.join(tmpdir, "example.txt")
with open(path, "w") as file:
file.write("")
app = StaticFiles(directory=tmpdir)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/example.txt")
assert response.status_code == 200
assert response.text == ""
-def test_staticfiles_with_pathlib(tmpdir):
+def test_staticfiles_with_pathlib(tmpdir, test_client_factory):
base_dir = pathlib.Path(tmpdir)
path = base_dir / "example.txt"
with open(path, "w") as file:
file.write("")
app = StaticFiles(directory=base_dir)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/example.txt")
assert response.status_code == 200
assert response.text == ""
-def test_staticfiles_head_with_middleware(tmpdir):
+def test_staticfiles_head_with_middleware(tmpdir, test_client_factory):
"""
see https://github.com/encode/starlette/pull/935
"""
@@ -53,51 +52,51 @@ async def does_nothing_middleware(request: Request, call_next):
response = await call_next(request)
return response
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.head("/static/example.txt")
assert response.status_code == 200
assert response.headers.get("content-length") == "100"
-def test_staticfiles_with_package():
+def test_staticfiles_with_package(test_client_factory):
app = StaticFiles(packages=["tests"])
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/example.txt")
assert response.status_code == 200
assert response.text == "123\n"
-def test_staticfiles_post(tmpdir):
+def test_staticfiles_post(tmpdir, test_client_factory):
path = os.path.join(tmpdir, "example.txt")
with open(path, "w") as file:
file.write("")
app = StaticFiles(directory=tmpdir)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.post("/example.txt")
assert response.status_code == 405
assert response.text == "Method Not Allowed"
-def test_staticfiles_with_directory_returns_404(tmpdir):
+def test_staticfiles_with_directory_returns_404(tmpdir, test_client_factory):
path = os.path.join(tmpdir, "example.txt")
with open(path, "w") as file:
file.write("")
app = StaticFiles(directory=tmpdir)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/")
assert response.status_code == 404
assert response.text == "Not Found"
-def test_staticfiles_with_missing_file_returns_404(tmpdir):
+def test_staticfiles_with_missing_file_returns_404(tmpdir, test_client_factory):
path = os.path.join(tmpdir, "example.txt")
with open(path, "w") as file:
file.write("")
app = StaticFiles(directory=tmpdir)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/404.txt")
assert response.status_code == 404
assert response.text == "Not Found"
@@ -110,30 +109,32 @@ def test_staticfiles_instantiated_with_missing_directory(tmpdir):
assert "does not exist" in str(exc_info.value)
-def test_staticfiles_configured_with_missing_directory(tmpdir):
+def test_staticfiles_configured_with_missing_directory(tmpdir, test_client_factory):
path = os.path.join(tmpdir, "no_such_directory")
app = StaticFiles(directory=path, check_dir=False)
- client = TestClient(app)
+ client = test_client_factory(app)
with pytest.raises(RuntimeError) as exc_info:
client.get("/example.txt")
assert "does not exist" in str(exc_info.value)
-def test_staticfiles_configured_with_file_instead_of_directory(tmpdir):
+def test_staticfiles_configured_with_file_instead_of_directory(
+ tmpdir, test_client_factory
+):
path = os.path.join(tmpdir, "example.txt")
with open(path, "w") as file:
file.write("")
app = StaticFiles(directory=path, check_dir=False)
- client = TestClient(app)
+ client = test_client_factory(app)
with pytest.raises(RuntimeError) as exc_info:
client.get("/example.txt")
assert "is not a directory" in str(exc_info.value)
-def test_staticfiles_config_check_occurs_only_once(tmpdir):
+def test_staticfiles_config_check_occurs_only_once(tmpdir, test_client_factory):
app = StaticFiles(directory=tmpdir)
- client = TestClient(app)
+ client = test_client_factory(app)
assert not app.config_checked
client.get("/")
assert app.config_checked
@@ -153,32 +154,31 @@ def test_staticfiles_prevents_breaking_out_of_directory(tmpdir):
# We can't test this with 'requests', so we test the app directly here.
path = app.get_path({"path": "/../example.txt"})
scope = {"method": "GET"}
- loop = asyncio.get_event_loop()
- response = loop.run_until_complete(app.get_response(path, scope))
+ response = anyio.run(app.get_response, path, scope)
assert response.status_code == 404
assert response.body == b"Not Found"
-def test_staticfiles_never_read_file_for_head_method(tmpdir):
+def test_staticfiles_never_read_file_for_head_method(tmpdir, test_client_factory):
path = os.path.join(tmpdir, "example.txt")
with open(path, "w") as file:
file.write("")
app = StaticFiles(directory=tmpdir)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.head("/example.txt")
assert response.status_code == 200
assert response.content == b""
assert response.headers["content-length"] == "14"
-def test_staticfiles_304_with_etag_match(tmpdir):
+def test_staticfiles_304_with_etag_match(tmpdir, test_client_factory):
path = os.path.join(tmpdir, "example.txt")
with open(path, "w") as file:
file.write("")
app = StaticFiles(directory=tmpdir)
- client = TestClient(app)
+ client = test_client_factory(app)
first_resp = client.get("/example.txt")
assert first_resp.status_code == 200
last_etag = first_resp.headers["etag"]
@@ -187,7 +187,9 @@ def test_staticfiles_304_with_etag_match(tmpdir):
assert second_resp.content == b""
-def test_staticfiles_304_with_last_modified_compare_last_req(tmpdir):
+def test_staticfiles_304_with_last_modified_compare_last_req(
+ tmpdir, test_client_factory
+):
path = os.path.join(tmpdir, "example.txt")
file_last_modified_time = time.mktime(
time.strptime("2013-10-10 23:40:00", "%Y-%m-%d %H:%M:%S")
@@ -197,7 +199,7 @@ def test_staticfiles_304_with_last_modified_compare_last_req(tmpdir):
os.utime(path, (file_last_modified_time, file_last_modified_time))
app = StaticFiles(directory=tmpdir)
- client = TestClient(app)
+ client = test_client_factory(app)
# last modified less than last request, 304
response = client.get(
"/example.txt", headers={"If-Modified-Since": "Thu, 11 Oct 2013 15:30:19 GMT"}
@@ -212,7 +214,7 @@ def test_staticfiles_304_with_last_modified_compare_last_req(tmpdir):
assert response.content == b""
-def test_staticfiles_html(tmpdir):
+def test_staticfiles_html(tmpdir, test_client_factory):
path = os.path.join(tmpdir, "404.html")
with open(path, "w") as file:
file.write("Custom not found page ")
@@ -223,7 +225,7 @@ def test_staticfiles_html(tmpdir):
file.write("Hello ")
app = StaticFiles(directory=tmpdir, html=True)
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/dir/")
assert response.url == "http://testserver/dir/"
@@ -245,7 +247,9 @@ def test_staticfiles_html(tmpdir):
assert response.text == "Custom not found page "
-def test_staticfiles_cache_invalidation_for_deleted_file_html_mode(tmpdir):
+def test_staticfiles_cache_invalidation_for_deleted_file_html_mode(
+ tmpdir, test_client_factory
+):
path_404 = os.path.join(tmpdir, "404.html")
with open(path_404, "w") as file:
file.write("404 file
")
@@ -260,7 +264,7 @@ def test_staticfiles_cache_invalidation_for_deleted_file_html_mode(tmpdir):
os.utime(path_some, (common_modified_time, common_modified_time))
app = StaticFiles(directory=tmpdir, html=True)
- client = TestClient(app)
+ client = test_client_factory(app)
resp_exists = client.get("/some.html")
assert resp_exists.status_code == 200
diff --git a/tests/test_templates.py b/tests/test_templates.py
index a0ab3e1b0b..073482d65a 100644
--- a/tests/test_templates.py
+++ b/tests/test_templates.py
@@ -4,10 +4,9 @@
from starlette.applications import Starlette
from starlette.templating import Jinja2Templates
-from starlette.testclient import TestClient
-def test_templates(tmpdir):
+def test_templates(tmpdir, test_client_factory):
path = os.path.join(tmpdir, "index.html")
with open(path, "w") as file:
file.write("Hello, world ")
@@ -19,7 +18,7 @@ def test_templates(tmpdir):
async def homepage(request):
return templates.TemplateResponse("index.html", {"request": request})
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/")
assert response.text == "Hello, world "
assert response.template.name == "index.html"
diff --git a/tests/test_testclient.py b/tests/test_testclient.py
index 00f4e0125b..8c06667896 100644
--- a/tests/test_testclient.py
+++ b/tests/test_testclient.py
@@ -1,13 +1,24 @@
import asyncio
+import itertools
+import sys
+import anyio
import pytest
+import sniffio
+import trio.lowlevel
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.responses import JSONResponse
-from starlette.testclient import TestClient
from starlette.websockets import WebSocket, WebSocketDisconnect
+if sys.version_info >= (3, 7): # pragma: no cover
+ from asyncio import current_task as asyncio_current_task
+ from contextlib import asynccontextmanager
+else: # pragma: no cover
+ asyncio_current_task = asyncio.Task.current_task
+ from contextlib2 import asynccontextmanager
+
mock_service = Starlette()
@@ -16,14 +27,19 @@ def mock_service_endpoint(request):
return JSONResponse({"mock": "example"})
-app = Starlette()
-
+def current_task():
+ # anyio's TaskInfo comparisons are invalid after their associated native
+ # task object is GC'd https://github.com/agronholm/anyio/issues/324
+ asynclib_name = sniffio.current_async_library()
+ if asynclib_name == "trio":
+ return trio.lowlevel.current_task()
-@app.route("/")
-def homepage(request):
- client = TestClient(mock_service)
- response = client.get("/")
- return JSONResponse(response.json())
+ if asynclib_name == "asyncio":
+ task = asyncio_current_task()
+ if task is None:
+ raise RuntimeError("must be called from a running task") # pragma: no cover
+ return task
+ raise RuntimeError(f"unsupported asynclib={asynclib_name}") # pragma: no cover
startup_error_app = Starlette()
@@ -34,30 +50,110 @@ def startup():
raise RuntimeError()
-def test_use_testclient_in_endpoint():
+def test_use_testclient_in_endpoint(test_client_factory):
"""
We should be able to use the test client within applications.
This is useful if we need to mock out other services,
during tests or in development.
"""
- client = TestClient(app)
+
+ app = Starlette()
+
+ @app.route("/")
+ def homepage(request):
+ client = test_client_factory(mock_service)
+ response = client.get("/")
+ return JSONResponse(response.json())
+
+ client = test_client_factory(app)
response = client.get("/")
assert response.json() == {"mock": "example"}
-def test_use_testclient_as_contextmanager():
- with TestClient(app):
- pass
+def test_use_testclient_as_contextmanager(test_client_factory, anyio_backend_name):
+ """
+ This test asserts a number of properties that are important for an
+ app level task_group
+ """
+ counter = itertools.count()
+ identity_runvar = anyio.lowlevel.RunVar[int]("identity_runvar")
+
+ def get_identity():
+ try:
+ return identity_runvar.get()
+ except LookupError:
+ token = next(counter)
+ identity_runvar.set(token)
+ return token
+
+ startup_task = object()
+ startup_loop = None
+ shutdown_task = object()
+ shutdown_loop = None
+
+ @asynccontextmanager
+ async def lifespan_context(app):
+ nonlocal startup_task, startup_loop, shutdown_task, shutdown_loop
+
+ startup_task = current_task()
+ startup_loop = get_identity()
+ async with anyio.create_task_group() as app.task_group:
+ yield
+ shutdown_task = current_task()
+ shutdown_loop = get_identity()
+
+ app = Starlette(lifespan=lifespan_context)
+
+ @app.route("/loop_id")
+ async def loop_id(request):
+ return JSONResponse(get_identity())
+
+ client = test_client_factory(app)
+
+ with client:
+ # within a TestClient context every async request runs in the same thread
+ assert client.get("/loop_id").json() == 0
+ assert client.get("/loop_id").json() == 0
+
+ # that thread is also the same as the lifespan thread
+ assert startup_loop == 0
+ assert shutdown_loop == 0
+
+ # lifespan events run in the same task, this is important because a task
+ # group must be entered and exited in the same task.
+ assert startup_task is shutdown_task
+
+ # outside the TestClient context, new requests continue to spawn in new
+ # eventloops in new threads
+ assert client.get("/loop_id").json() == 1
+ assert client.get("/loop_id").json() == 2
+
+ first_task = startup_task
+
+ with client:
+ # the TestClient context can be re-used, starting a new lifespan task
+ # in a new thread
+ assert client.get("/loop_id").json() == 3
+ assert client.get("/loop_id").json() == 3
+
+ assert startup_loop == 3
+ assert shutdown_loop == 3
+
+ # lifespan events still run in the same task, with the context but...
+ assert startup_task is shutdown_task
+
+ # ... the second TestClient context creates a new lifespan task.
+ assert first_task is not startup_task
-def test_error_on_startup():
+def test_error_on_startup(test_client_factory):
with pytest.raises(RuntimeError):
- with TestClient(startup_error_app):
+ with test_client_factory(startup_error_app):
pass # pragma: no cover
-def test_exception_in_middleware():
+def test_exception_in_middleware(test_client_factory):
class MiddlewareException(Exception):
pass
@@ -71,11 +167,11 @@ async def __call__(self, scope, receive, send):
broken_middleware = Starlette(middleware=[Middleware(BrokenMiddleware)])
with pytest.raises(MiddlewareException):
- with TestClient(broken_middleware):
+ with test_client_factory(broken_middleware):
pass # pragma: no cover
-def test_testclient_asgi2():
+def test_testclient_asgi2(test_client_factory):
def app(scope):
async def inner(receive, send):
await send(
@@ -89,12 +185,12 @@ async def inner(receive, send):
return inner
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/")
assert response.text == "Hello, world!"
-def test_testclient_asgi3():
+def test_testclient_asgi3(test_client_factory):
async def app(scope, receive, send):
await send(
{
@@ -105,12 +201,12 @@ async def app(scope, receive, send):
)
await send({"type": "http.response.body", "body": b"Hello, world!"})
- client = TestClient(app)
+ client = test_client_factory(app)
response = client.get("/")
assert response.text == "Hello, world!"
-def test_websocket_blocking_receive():
+def test_websocket_blocking_receive(test_client_factory):
def app(scope):
async def respond(websocket):
await websocket.send_json({"message": "test"})
@@ -118,17 +214,18 @@ async def respond(websocket):
async def asgi(receive, send):
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
- asyncio.ensure_future(respond(websocket))
- try:
- # this will block as the client does not send us data
- # it should not prevent `respond` from executing though
- await websocket.receive_json()
- except WebSocketDisconnect:
- pass
+ async with anyio.create_task_group() as task_group:
+ task_group.start_soon(respond, websocket)
+ try:
+ # this will block as the client does not send us data
+ # it should not prevent `respond` from executing though
+ await websocket.receive_json()
+ except WebSocketDisconnect:
+ pass
return asgi
- client = TestClient(app)
+ client = test_client_factory(app)
with client.websocket_connect("/") as websocket:
data = websocket.receive_json()
assert data == {"message": "test"}
diff --git a/tests/test_websockets.py b/tests/test_websockets.py
index ffb1a44a8c..e02d433d57 100644
--- a/tests/test_websockets.py
+++ b/tests/test_websockets.py
@@ -1,14 +1,11 @@
-import asyncio
-
+import anyio
import pytest
from starlette import status
-from starlette.concurrency import run_until_first_complete
-from starlette.testclient import TestClient
from starlette.websockets import WebSocket, WebSocketDisconnect
-def test_websocket_url():
+def test_websocket_url(test_client_factory):
def app(scope):
async def asgi(receive, send):
websocket = WebSocket(scope, receive=receive, send=send)
@@ -18,13 +15,13 @@ async def asgi(receive, send):
return asgi
- client = TestClient(app)
+ client = test_client_factory(app)
with client.websocket_connect("/123?a=abc") as websocket:
data = websocket.receive_json()
assert data == {"url": "ws://testserver/123?a=abc"}
-def test_websocket_binary_json():
+def test_websocket_binary_json(test_client_factory):
def app(scope):
async def asgi(receive, send):
websocket = WebSocket(scope, receive=receive, send=send)
@@ -35,14 +32,14 @@ async def asgi(receive, send):
return asgi
- client = TestClient(app)
+ client = test_client_factory(app)
with client.websocket_connect("/123?a=abc") as websocket:
websocket.send_json({"test": "data"}, mode="binary")
data = websocket.receive_json(mode="binary")
assert data == {"test": "data"}
-def test_websocket_query_params():
+def test_websocket_query_params(test_client_factory):
def app(scope):
async def asgi(receive, send):
websocket = WebSocket(scope, receive=receive, send=send)
@@ -53,13 +50,13 @@ async def asgi(receive, send):
return asgi
- client = TestClient(app)
+ client = test_client_factory(app)
with client.websocket_connect("/?a=abc&b=456") as websocket:
data = websocket.receive_json()
assert data == {"params": {"a": "abc", "b": "456"}}
-def test_websocket_headers():
+def test_websocket_headers(test_client_factory):
def app(scope):
async def asgi(receive, send):
websocket = WebSocket(scope, receive=receive, send=send)
@@ -70,7 +67,7 @@ async def asgi(receive, send):
return asgi
- client = TestClient(app)
+ client = test_client_factory(app)
with client.websocket_connect("/") as websocket:
expected_headers = {
"accept": "*/*",
@@ -85,7 +82,7 @@ async def asgi(receive, send):
assert data == {"headers": expected_headers}
-def test_websocket_port():
+def test_websocket_port(test_client_factory):
def app(scope):
async def asgi(receive, send):
websocket = WebSocket(scope, receive=receive, send=send)
@@ -95,13 +92,13 @@ async def asgi(receive, send):
return asgi
- client = TestClient(app)
+ client = test_client_factory(app)
with client.websocket_connect("ws://example.com:123/123?a=abc") as websocket:
data = websocket.receive_json()
assert data == {"port": 123}
-def test_websocket_send_and_receive_text():
+def test_websocket_send_and_receive_text(test_client_factory):
def app(scope):
async def asgi(receive, send):
websocket = WebSocket(scope, receive=receive, send=send)
@@ -112,14 +109,14 @@ async def asgi(receive, send):
return asgi
- client = TestClient(app)
+ client = test_client_factory(app)
with client.websocket_connect("/") as websocket:
websocket.send_text("Hello, world!")
data = websocket.receive_text()
assert data == "Message was: Hello, world!"
-def test_websocket_send_and_receive_bytes():
+def test_websocket_send_and_receive_bytes(test_client_factory):
def app(scope):
async def asgi(receive, send):
websocket = WebSocket(scope, receive=receive, send=send)
@@ -130,14 +127,14 @@ async def asgi(receive, send):
return asgi
- client = TestClient(app)
+ client = test_client_factory(app)
with client.websocket_connect("/") as websocket:
websocket.send_bytes(b"Hello, world!")
data = websocket.receive_bytes()
assert data == b"Message was: Hello, world!"
-def test_websocket_send_and_receive_json():
+def test_websocket_send_and_receive_json(test_client_factory):
def app(scope):
async def asgi(receive, send):
websocket = WebSocket(scope, receive=receive, send=send)
@@ -148,14 +145,14 @@ async def asgi(receive, send):
return asgi
- client = TestClient(app)
+ client = test_client_factory(app)
with client.websocket_connect("/") as websocket:
websocket.send_json({"hello": "world"})
data = websocket.receive_json()
assert data == {"message": {"hello": "world"}}
-def test_websocket_iter_text():
+def test_websocket_iter_text(test_client_factory):
def app(scope):
async def asgi(receive, send):
websocket = WebSocket(scope, receive=receive, send=send)
@@ -165,14 +162,14 @@ async def asgi(receive, send):
return asgi
- client = TestClient(app)
+ client = test_client_factory(app)
with client.websocket_connect("/") as websocket:
websocket.send_text("Hello, world!")
data = websocket.receive_text()
assert data == "Message was: Hello, world!"
-def test_websocket_iter_bytes():
+def test_websocket_iter_bytes(test_client_factory):
def app(scope):
async def asgi(receive, send):
websocket = WebSocket(scope, receive=receive, send=send)
@@ -182,14 +179,14 @@ async def asgi(receive, send):
return asgi
- client = TestClient(app)
+ client = test_client_factory(app)
with client.websocket_connect("/") as websocket:
websocket.send_bytes(b"Hello, world!")
data = websocket.receive_bytes()
assert data == b"Message was: Hello, world!"
-def test_websocket_iter_json():
+def test_websocket_iter_json(test_client_factory):
def app(scope):
async def asgi(receive, send):
websocket = WebSocket(scope, receive=receive, send=send)
@@ -199,44 +196,45 @@ async def asgi(receive, send):
return asgi
- client = TestClient(app)
+ client = test_client_factory(app)
with client.websocket_connect("/") as websocket:
websocket.send_json({"hello": "world"})
data = websocket.receive_json()
assert data == {"message": {"hello": "world"}}
-def test_websocket_concurrency_pattern():
+def test_websocket_concurrency_pattern(test_client_factory):
def app(scope):
- async def reader(websocket, queue):
- async for data in websocket.iter_json():
- await queue.put(data)
+ stream_send, stream_receive = anyio.create_memory_object_stream()
+
+ async def reader(websocket):
+ async with stream_send:
+ async for data in websocket.iter_json():
+ await stream_send.send(data)
- async def writer(websocket, queue):
- while True:
- message = await queue.get()
- await websocket.send_json(message)
+ async def writer(websocket):
+ async with stream_receive:
+ async for message in stream_receive:
+ await websocket.send_json(message)
async def asgi(receive, send):
websocket = WebSocket(scope, receive=receive, send=send)
- queue = asyncio.Queue()
await websocket.accept()
- await run_until_first_complete(
- (reader, {"websocket": websocket, "queue": queue}),
- (writer, {"websocket": websocket, "queue": queue}),
- )
+ async with anyio.create_task_group() as task_group:
+ task_group.start_soon(reader, websocket)
+ await writer(websocket)
await websocket.close()
return asgi
- client = TestClient(app)
+ client = test_client_factory(app)
with client.websocket_connect("/") as websocket:
websocket.send_json({"hello": "world"})
data = websocket.receive_json()
assert data == {"hello": "world"}
-def test_client_close():
+def test_client_close(test_client_factory):
close_code = None
def app(scope):
@@ -251,13 +249,13 @@ async def asgi(receive, send):
return asgi
- client = TestClient(app)
+ client = test_client_factory(app)
with client.websocket_connect("/") as websocket:
websocket.close(code=status.WS_1001_GOING_AWAY)
assert close_code == status.WS_1001_GOING_AWAY
-def test_application_close():
+def test_application_close(test_client_factory):
def app(scope):
async def asgi(receive, send):
websocket = WebSocket(scope, receive=receive, send=send)
@@ -266,14 +264,14 @@ async def asgi(receive, send):
return asgi
- client = TestClient(app)
+ client = test_client_factory(app)
with client.websocket_connect("/") as websocket:
with pytest.raises(WebSocketDisconnect) as exc:
websocket.receive_text()
assert exc.value.code == status.WS_1001_GOING_AWAY
-def test_rejected_connection():
+def test_rejected_connection(test_client_factory):
def app(scope):
async def asgi(receive, send):
websocket = WebSocket(scope, receive=receive, send=send)
@@ -281,13 +279,14 @@ async def asgi(receive, send):
return asgi
- client = TestClient(app)
+ client = test_client_factory(app)
with pytest.raises(WebSocketDisconnect) as exc:
- client.websocket_connect("/")
+ with client.websocket_connect("/"):
+ pass # pragma: nocover
assert exc.value.code == status.WS_1001_GOING_AWAY
-def test_subprotocol():
+def test_subprotocol(test_client_factory):
def app(scope):
async def asgi(receive, send):
websocket = WebSocket(scope, receive=receive, send=send)
@@ -297,24 +296,25 @@ async def asgi(receive, send):
return asgi
- client = TestClient(app)
+ client = test_client_factory(app)
with client.websocket_connect("/", subprotocols=["soap", "wamp"]) as websocket:
assert websocket.accepted_subprotocol == "wamp"
-def test_websocket_exception():
+def test_websocket_exception(test_client_factory):
def app(scope):
async def asgi(receive, send):
assert False
return asgi
- client = TestClient(app)
+ client = test_client_factory(app)
with pytest.raises(AssertionError):
- client.websocket_connect("/123?a=abc")
+ with client.websocket_connect("/123?a=abc"):
+ pass # pragma: nocover
-def test_duplicate_close():
+def test_duplicate_close(test_client_factory):
def app(scope):
async def asgi(receive, send):
websocket = WebSocket(scope, receive=receive, send=send)
@@ -324,13 +324,13 @@ async def asgi(receive, send):
return asgi
- client = TestClient(app)
+ client = test_client_factory(app)
with pytest.raises(RuntimeError):
with client.websocket_connect("/"):
- pass
+ pass # pragma: nocover
-def test_duplicate_disconnect():
+def test_duplicate_disconnect(test_client_factory):
def app(scope):
async def asgi(receive, send):
websocket = WebSocket(scope, receive=receive, send=send)
@@ -341,7 +341,7 @@ async def asgi(receive, send):
return asgi
- client = TestClient(app)
+ client = test_client_factory(app)
with pytest.raises(RuntimeError):
with client.websocket_connect("/") as websocket:
websocket.close()
@@ -367,3 +367,13 @@ async def mock_send(message):
assert websocket["type"] == "websocket"
assert dict(websocket) == {"type": "websocket", "path": "/abc/", "headers": []}
assert len(websocket) == 3
+
+ # check __eq__ and __hash__
+ assert websocket != WebSocket(
+ {"type": "websocket", "path": "/abc/", "headers": []},
+ receive=mock_receive,
+ send=mock_send,
+ )
+ assert websocket == websocket
+ assert websocket in {websocket}
+ assert {websocket} == {websocket}