diff --git a/pyproject.toml b/pyproject.toml index 48562e714..23fceaeea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,6 +56,10 @@ select = ["E", "F", "I"] [tool.ruff.lint.isort] combine-as-imports = true +split-on-trailing-comma = false + +[tool.ruff.format] +skip-magic-trailing-comma = true [tool.mypy] strict = true diff --git a/starlette/datastructures.py b/starlette/datastructures.py index e430d09b6..5841b0de5 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -22,10 +22,7 @@ class Address(typing.NamedTuple): class URL: def __init__( - self, - url: str = "", - scope: Scope | None = None, - **components: typing.Any, + self, url: str = "", scope: Scope | None = None, **components: typing.Any ) -> None: if scope is not None: assert not url, 'Cannot set both "url" and "scope".' diff --git a/starlette/middleware/__init__.py b/starlette/middleware/__init__.py index 3d0342dc3..815bce0e1 100644 --- a/starlette/middleware/__init__.py +++ b/starlette/middleware/__init__.py @@ -23,10 +23,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: class Middleware: def __init__( - self, - cls: type[_MiddlewareClass[P]], - *args: P.args, - **kwargs: P.kwargs, + self, cls: type[_MiddlewareClass[P]], *args: P.args, **kwargs: P.kwargs ) -> None: self.cls = cls self.args = args diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index ad3ffcfee..d3c1380bf 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -60,21 +60,13 @@ async def wrapped_receive(self) -> Message: if getattr(self, "_body", None) is not None: # body() was called, we return it even if the client disconnected self._wrapped_rcv_consumed = True - return { - "type": "http.request", - "body": self._body, - "more_body": False, - } + return {"type": "http.request", "body": self._body, "more_body": False} elif self._stream_consumed: # stream() was called to completion # return an empty body so that downstream apps don't hang # waiting for a disconnect self._wrapped_rcv_consumed = True - return { - "type": "http.request", - "body": b"", - "more_body": False, - } + return {"type": "http.request", "body": b"", "more_body": False} else: # body() was never called and stream() wasn't consumed try: diff --git a/starlette/routing.py b/starlette/routing.py index b5467bb05..9aace575f 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -659,18 +659,14 @@ def __init__( "use an @contextlib.asynccontextmanager function instead", DeprecationWarning, ) - self.lifespan_context = asynccontextmanager( - lifespan, - ) + self.lifespan_context = asynccontextmanager(lifespan) elif inspect.isgeneratorfunction(lifespan): warnings.warn( "generator function lifespans are deprecated, " "use an @contextlib.asynccontextmanager function instead", DeprecationWarning, ) - self.lifespan_context = _wrap_gen_lifespan_context( - lifespan, - ) + self.lifespan_context = _wrap_gen_lifespan_context(lifespan) else: self.lifespan_context = lifespan diff --git a/starlette/templating.py b/starlette/templating.py index fe31ab5ee..e21df0df1 100644 --- a/starlette/templating.py +++ b/starlette/templating.py @@ -47,10 +47,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await send( { "type": "http.response.debug", - "info": { - "template": self.template, - "context": self.context, - }, + "info": {"template": self.template, "context": self.context}, } ) await super().__call__(scope, receive, send) diff --git a/starlette/testclient.py b/starlette/testclient.py index 90eb53e3d..39812771f 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -79,8 +79,7 @@ def __init__(self, session: WebSocketTestSession) -> None: class WebSocketDenialResponse( # type: ignore[misc] - httpx.Response, - WebSocketDisconnect, + httpx.Response, WebSocketDisconnect ): """ A special case of `WebSocketDisconnect`, raised in the `TestClient` if the @@ -90,10 +89,7 @@ class WebSocketDenialResponse( # type: ignore[misc] class WebSocketTestSession: def __init__( - self, - app: ASGI3App, - scope: Scope, - portal_factory: _PortalFactoryType, + self, app: ASGI3App, scope: Scope, portal_factory: _PortalFactoryType ) -> None: self.app = app self.scope = scope @@ -182,9 +178,7 @@ def _raise_on_close(self, message: Message) -> None: if not message.get("more_body", False): break raise WebSocketDenialResponse( - status_code=status_code, - headers=headers, - content=b"".join(body), + status_code=status_code, headers=headers, content=b"".join(body) ) def send(self, message: Message) -> None: @@ -400,11 +394,7 @@ async def send(message: Message) -> None: if self.raise_server_exceptions: assert response_started, "TestClient did not receive any response." elif not response_started: - raw_kwargs = { - "status_code": 500, - "headers": [], - "stream": io.BytesIO(), - } + raw_kwargs = {"status_code": 500, "headers": [], "stream": io.BytesIO()} raw_kwargs["stream"] = httpx.ByteStream(raw_kwargs["stream"].read()) diff --git a/tests/conftest.py b/tests/conftest.py index 1a61664d1..08bdd67cb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,7 +16,5 @@ def test_client_factory( # anyio_backend_name defined by: # https://anyio.readthedocs.io/en/stable/testing.html#specifying-the-backends-to-run-on return functools.partial( - TestClient, - backend=anyio_backend_name, - backend_options=anyio_backend_options, + TestClient, backend=anyio_backend_name, backend_options=anyio_backend_options ) diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 6e5e42b94..d48a32fe5 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -1,14 +1,6 @@ import contextvars from contextlib import AsyncExitStack -from typing import ( - Any, - AsyncGenerator, - Callable, - Generator, - List, - Type, - Union, -) +from typing import Any, AsyncGenerator, Callable, Generator, List, Type, Union import anyio import pytest @@ -30,9 +22,7 @@ class CustomMiddleware(BaseHTTPMiddleware): async def dispatch( - self, - request: Request, - call_next: RequestResponseEndpoint, + self, request: Request, call_next: RequestResponseEndpoint ) -> Response: response = await call_next(request) response.headers["Custom-Header"] = "Example" @@ -57,12 +47,7 @@ def _generate_faulty_stream() -> Generator[bytes, None, None]: class NoResponse: - def __init__( - self, - scope: Scope, - receive: Receive, - send: Send, - ): + def __init__(self, scope: Scope, receive: Receive, send: Send): pass def __await__(self) -> Generator[Any, None, None]: @@ -119,9 +104,7 @@ def test_state_data_across_multiple_middlewares( class aMiddleware(BaseHTTPMiddleware): async def dispatch( - self, - request: Request, - call_next: RequestResponseEndpoint, + self, request: Request, call_next: RequestResponseEndpoint ) -> Response: request.state.foo = expected_value1 response = await call_next(request) @@ -129,9 +112,7 @@ async def dispatch( class bMiddleware(BaseHTTPMiddleware): async def dispatch( - self, - request: Request, - call_next: RequestResponseEndpoint, + self, request: Request, call_next: RequestResponseEndpoint ) -> Response: request.state.bar = expected_value2 response = await call_next(request) @@ -140,9 +121,7 @@ async def dispatch( class cMiddleware(BaseHTTPMiddleware): async def dispatch( - self, - request: Request, - call_next: RequestResponseEndpoint, + self, request: Request, call_next: RequestResponseEndpoint ) -> Response: response = await call_next(request) response.headers["X-State-Bar"] = request.state.bar @@ -184,9 +163,7 @@ def test_fully_evaluated_response(test_client_factory: TestClientFactory) -> Non # Test for https://github.com/encode/starlette/issues/1022 class CustomMiddleware(BaseHTTPMiddleware): async def dispatch( - self, - request: Request, - call_next: RequestResponseEndpoint, + self, request: Request, call_next: RequestResponseEndpoint ) -> PlainTextResponse: await call_next(request) return PlainTextResponse("Custom") @@ -213,9 +190,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: class CustomMiddlewareUsingBaseHTTPMiddleware(BaseHTTPMiddleware): async def dispatch( - self, - request: Request, - call_next: RequestResponseEndpoint, + self, request: Request, call_next: RequestResponseEndpoint ) -> Response: ctxvar.set("set by middleware") resp = await call_next(request) @@ -240,8 +215,7 @@ async def dispatch( ], ) def test_contextvars( - test_client_factory: TestClientFactory, - middleware_cls: Type[_MiddlewareClass[Any]], + test_client_factory: TestClientFactory, middleware_cls: Type[_MiddlewareClass[Any]] ) -> None: # this has to be an async endpoint because Starlette calls run_in_threadpool # on sync endpoints which has it's own set of peculiarities w.r.t propagating @@ -278,8 +252,7 @@ async def endpoint_with_background_task(_: Request) -> PlainTextResponse: return PlainTextResponse(background=BackgroundTask(sleep_and_set)) async def passthrough( - request: Request, - call_next: RequestResponseEndpoint, + request: Request, call_next: RequestResponseEndpoint ) -> Response: return await call_next(request) @@ -288,12 +261,7 @@ async def passthrough( routes=[Route("/", endpoint_with_background_task)], ) - scope = { - "type": "http", - "version": "3", - "method": "GET", - "path": "/", - } + scope = {"type": "http", "version": "3", "method": "GET", "path": "/"} async def receive() -> Message: nonlocal request_body_sent @@ -340,12 +308,7 @@ async def passthrough( routes=[Route("/", endpoint_with_background_task)], ) - scope = { - "type": "http", - "version": "3", - "method": "GET", - "path": "/", - } + scope = {"type": "http", "version": "3", "method": "GET", "path": "/"} async def receive() -> Message: nonlocal request_body_sent @@ -406,8 +369,7 @@ async def simple_endpoint(_: Request) -> PlainTextResponse: return PlainTextResponse(background=BackgroundTask(sleep_and_set)) async def passthrough( - request: Request, - call_next: RequestResponseEndpoint, + request: Request, call_next: RequestResponseEndpoint ) -> Response: return await call_next(request) @@ -419,12 +381,7 @@ async def passthrough( routes=[Route("/", simple_endpoint)], ) - scope = { - "type": "http", - "version": "3", - "method": "GET", - "path": "/", - } + scope = {"type": "http", "version": "3", "method": "GET", "path": "/"} async def receive() -> Message: nonlocal request_body_sent @@ -449,11 +406,7 @@ def test_app_receives_http_disconnect_while_sending_if_discarded( test_client_factory: TestClientFactory, ) -> None: class DiscardingMiddleware(BaseHTTPMiddleware): - async def dispatch( - self, - request: Request, - call_next: Any, - ) -> PlainTextResponse: + async def dispatch(self, request: Request, call_next: Any) -> PlainTextResponse: # As a matter of ordering, this test targets the case where the downstream # app response is discarded while it is sending a response body. # We need to wait for the downstream app to begin sending a response body @@ -468,25 +421,18 @@ async def dispatch( return PlainTextResponse("Custom") - async def downstream_app( - scope: Scope, - receive: Receive, - send: Send, - ) -> None: + async def downstream_app(scope: Scope, receive: Receive, send: Send) -> None: await send( { "type": "http.response.start", "status": 200, - "headers": [ - (b"content-type", b"text/plain"), - ], + "headers": [(b"content-type", b"text/plain")], } ) async with anyio.create_task_group() as task_group: async def cancel_on_disconnect( - *, - task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED, + *, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED ) -> None: task_status.started() while True: @@ -528,40 +474,24 @@ def test_app_receives_http_disconnect_after_sending_if_discarded( ) -> None: class DiscardingMiddleware(BaseHTTPMiddleware): async def dispatch( - self, - request: Request, - call_next: RequestResponseEndpoint, + self, request: Request, call_next: RequestResponseEndpoint ) -> PlainTextResponse: await call_next(request) return PlainTextResponse("Custom") - async def downstream_app( - scope: Scope, - receive: Receive, - send: Send, - ) -> None: + async def downstream_app(scope: Scope, receive: Receive, send: Send) -> None: await send( { "type": "http.response.start", "status": 200, - "headers": [ - (b"content-type", b"text/plain"), - ], + "headers": [(b"content-type", b"text/plain")], } ) await send( - { - "type": "http.response.body", - "body": b"first chunk, ", - "more_body": True, - } + {"type": "http.response.body", "body": b"first chunk, ", "more_body": True} ) await send( - { - "type": "http.response.body", - "body": b"second chunk", - "more_body": True, - } + {"type": "http.response.body", "body": b"second chunk", "more_body": True} ) message = await receive() assert message["type"] == "http.disconnect" @@ -585,9 +515,7 @@ async def homepage(request: Request) -> PlainTextResponse: class ConsumingMiddleware(BaseHTTPMiddleware): async def dispatch( - self, - request: Request, - call_next: RequestResponseEndpoint, + self, request: Request, call_next: RequestResponseEndpoint ) -> Response: expected = [b"a", b""] async for chunk in request.stream(): @@ -617,9 +545,7 @@ async def homepage(request: Request) -> PlainTextResponse: class ConsumingMiddleware(BaseHTTPMiddleware): async def dispatch( - self, - request: Request, - call_next: RequestResponseEndpoint, + self, request: Request, call_next: RequestResponseEndpoint ) -> Response: assert await request.body() == b"a" return await call_next(request) @@ -643,9 +569,7 @@ async def homepage(request: Request) -> PlainTextResponse: class ConsumingMiddleware(BaseHTTPMiddleware): async def dispatch( - self, - request: Request, - call_next: RequestResponseEndpoint, + self, request: Request, call_next: RequestResponseEndpoint ) -> Response: expected = [b"a", b""] async for chunk in request.stream(): @@ -672,9 +596,7 @@ async def homepage(request: Request) -> PlainTextResponse: class ConsumingMiddleware(BaseHTTPMiddleware): async def dispatch( - self, - request: Request, - call_next: RequestResponseEndpoint, + self, request: Request, call_next: RequestResponseEndpoint ) -> Response: assert await request.body() == b"a" return await call_next(request) @@ -701,9 +623,7 @@ async def homepage(request: Request) -> PlainTextResponse: class ConsumingMiddleware(BaseHTTPMiddleware): async def dispatch( - self, - request: Request, - call_next: RequestResponseEndpoint, + self, request: Request, call_next: RequestResponseEndpoint ) -> Response: resp = await call_next(request) with pytest.raises(RuntimeError, match="Stream consumed"): @@ -730,9 +650,7 @@ async def homepage(request: Request) -> PlainTextResponse: class ConsumingMiddleware(BaseHTTPMiddleware): async def dispatch( - self, - request: Request, - call_next: RequestResponseEndpoint, + self, request: Request, call_next: RequestResponseEndpoint ) -> Response: resp = await call_next(request) with pytest.raises(RuntimeError, match="Stream consumed"): @@ -761,9 +679,7 @@ async def endpoint(scope: Scope, receive: Receive, send: Send) -> None: class ConsumingMiddleware(BaseHTTPMiddleware): async def dispatch( - self, - request: Request, - call_next: RequestResponseEndpoint, + self, request: Request, call_next: RequestResponseEndpoint ) -> Response: expected = b"1" response: Union[Response, None] = None @@ -816,9 +732,7 @@ async def homepage(request: Request) -> PlainTextResponse: class ConsumingMiddleware(BaseHTTPMiddleware): async def dispatch( - self, - request: Request, - call_next: RequestResponseEndpoint, + self, request: Request, call_next: RequestResponseEndpoint ) -> Response: assert ( await request.body() == b"a" @@ -848,9 +762,7 @@ async def homepage(request: Request) -> PlainTextResponse: class ConsumingMiddleware(BaseHTTPMiddleware): async def dispatch( - self, - request: Request, - call_next: RequestResponseEndpoint, + self, request: Request, call_next: RequestResponseEndpoint ) -> Response: assert ( await request.body() == b"a" @@ -885,9 +797,7 @@ async def endpoint(scope: Scope, receive: Receive, send: Send) -> None: class ConsumingMiddleware(BaseHTTPMiddleware): async def dispatch( - self, - request: Request, - call_next: RequestResponseEndpoint, + self, request: Request, call_next: RequestResponseEndpoint ) -> Response: response = await call_next(request) disconnected = await request.is_disconnected() @@ -925,9 +835,7 @@ async def endpoint(scope: Scope, receive: Receive, send: Send) -> None: class ConsumingMiddleware(BaseHTTPMiddleware): async def dispatch( - self, - request: Request, - call_next: RequestResponseEndpoint, + self, request: Request, call_next: RequestResponseEndpoint ) -> Response: await request.body() disconnected = await request.is_disconnected() @@ -970,9 +878,7 @@ async def endpoint(scope: Scope, receive: Receive, send: Send) -> None: class ConsumingMiddleware(BaseHTTPMiddleware): async def dispatch( - self, - request: Request, - call_next: RequestResponseEndpoint, + self, request: Request, call_next: RequestResponseEndpoint ) -> Response: body = await request.body() assert body == b"foo " @@ -1004,9 +910,7 @@ def test_pr_1519_comment_1236166180_example() -> None: class LogRequestBodySize(BaseHTTPMiddleware): async def dispatch( - self, - request: Request, - call_next: RequestResponseEndpoint, + self, request: Request, call_next: RequestResponseEndpoint ) -> Response: print(len(await request.body())) return await call_next(request) diff --git a/tests/middleware/test_cors.py b/tests/middleware/test_cors.py index 09ec9513f..59855d3eb 100644 --- a/tests/middleware/test_cors.py +++ b/tests/middleware/test_cors.py @@ -12,9 +12,7 @@ TestClientFactory = Callable[[ASGIApp], TestClient] -def test_cors_allow_all( - test_client_factory: TestClientFactory, -) -> None: +def test_cors_allow_all(test_client_factory: TestClientFactory) -> None: def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Homepage", status_code=200) @@ -124,9 +122,7 @@ def homepage(request: Request) -> PlainTextResponse: assert "access-control-allow-origin" not in response.headers -def test_cors_allow_specific_origin( - test_client_factory: TestClientFactory, -) -> None: +def test_cors_allow_specific_origin(test_client_factory: TestClientFactory) -> None: def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Homepage", status_code=200) @@ -173,9 +169,7 @@ def homepage(request: Request) -> PlainTextResponse: assert "access-control-allow-origin" not in response.headers -def test_cors_disallowed_preflight( - test_client_factory: TestClientFactory, -) -> None: +def test_cors_disallowed_preflight(test_client_factory: TestClientFactory) -> None: def homepage(request: Request) -> None: pass # pragma: no cover @@ -235,14 +229,8 @@ def homepage(request: Request) -> None: client = test_client_factory(app) # Test pre-flight response - headers = { - "Origin": "https://example.org", - "Access-Control-Request-Method": "POST", - } - response = client.options( - "/", - headers=headers, - ) + headers = {"Origin": "https://example.org", "Access-Control-Request-Method": "POST"} + response = client.options("/", headers=headers) assert response.status_code == 200 assert response.headers["access-control-allow-origin"] == "https://example.org" assert response.headers["access-control-allow-credentials"] == "true" @@ -264,10 +252,7 @@ def homepage(request: Request) -> None: client = test_client_factory(app) - headers = { - "Origin": "https://example.org", - "Access-Control-Request-Method": "POST", - } + headers = {"Origin": "https://example.org", "Access-Control-Request-Method": "POST"} for method in ("DELETE", "GET", "HEAD", "OPTIONS", "PATCH", "POST", "PUT"): response = client.options("/", headers=headers) @@ -275,9 +260,7 @@ def homepage(request: Request) -> None: assert method in response.headers["access-control-allow-methods"] -def test_cors_allow_all_methods( - test_client_factory: TestClientFactory, -) -> None: +def test_cors_allow_all_methods(test_client_factory: TestClientFactory) -> None: def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Homepage", status_code=200) @@ -306,9 +289,7 @@ def homepage(request: Request) -> PlainTextResponse: assert response.status_code == 200 -def test_cors_allow_origin_regex( - test_client_factory: TestClientFactory, -) -> None: +def test_cors_allow_origin_regex(test_client_factory: TestClientFactory) -> None: def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Homepage", status_code=200) @@ -506,9 +487,7 @@ def homepage(request: Request) -> PlainTextResponse: ) app = Starlette( - routes=[ - Route("/", endpoint=homepage), - ], + routes=[Route("/", endpoint=homepage)], middleware=[Middleware(CORSMiddleware, allow_origins=["https://example.org"])], ) client = test_client_factory(app) @@ -525,9 +504,7 @@ def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Homepage", status_code=200) app = Starlette( - routes=[ - Route("/", endpoint=homepage), - ], + routes=[Route("/", endpoint=homepage)], middleware=[ Middleware( CORSMiddleware, diff --git a/tests/middleware/test_errors.py b/tests/middleware/test_errors.py index a2dbabd8a..f60b9f0b0 100644 --- a/tests/middleware/test_errors.py +++ b/tests/middleware/test_errors.py @@ -14,9 +14,7 @@ TestClientFactory = Callable[..., TestClient] -def test_handler( - test_client_factory: TestClientFactory, -) -> None: +def test_handler(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: raise RuntimeError("Something went wrong") diff --git a/tests/middleware/test_gzip.py b/tests/middleware/test_gzip.py index 5bfecadb7..0462ef886 100644 --- a/tests/middleware/test_gzip.py +++ b/tests/middleware/test_gzip.py @@ -17,8 +17,7 @@ def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("x" * 4000, status_code=200) app = Starlette( - routes=[Route("/", endpoint=homepage)], - middleware=[Middleware(GZipMiddleware)], + routes=[Route("/", endpoint=homepage)], middleware=[Middleware(GZipMiddleware)] ) client = test_client_factory(app) @@ -34,8 +33,7 @@ def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("x" * 4000, status_code=200) app = Starlette( - routes=[Route("/", endpoint=homepage)], - middleware=[Middleware(GZipMiddleware)], + routes=[Route("/", endpoint=homepage)], middleware=[Middleware(GZipMiddleware)] ) client = test_client_factory(app) @@ -53,8 +51,7 @@ def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("OK", status_code=200) app = Starlette( - routes=[Route("/", endpoint=homepage)], - middleware=[Middleware(GZipMiddleware)], + routes=[Route("/", endpoint=homepage)], middleware=[Middleware(GZipMiddleware)] ) client = test_client_factory(app) @@ -75,8 +72,7 @@ async def generator(bytes: bytes, count: int) -> ContentStream: return StreamingResponse(streaming, status_code=200) app = Starlette( - routes=[Route("/", endpoint=homepage)], - middleware=[Middleware(GZipMiddleware)], + routes=[Route("/", endpoint=homepage)], middleware=[Middleware(GZipMiddleware)] ) client = test_client_factory(app) @@ -101,8 +97,7 @@ async def generator(bytes: bytes, count: int) -> ContentStream: ) app = Starlette( - routes=[Route("/", endpoint=homepage)], - middleware=[Middleware(GZipMiddleware)], + routes=[Route("/", endpoint=homepage)], middleware=[Middleware(GZipMiddleware)] ) client = test_client_factory(app) diff --git a/tests/middleware/test_session.py b/tests/middleware/test_session.py index 4fbeec88c..5cfa337a4 100644 --- a/tests/middleware/test_session.py +++ b/tests/middleware/test_session.py @@ -125,9 +125,7 @@ def test_secure_session(test_client_factory: TestClientFactory) -> None: def test_session_cookie_subpath(test_client_factory: TestClientFactory) -> None: second_app = Starlette( - routes=[ - Route("/update_session", endpoint=update_session, methods=["POST"]), - ], + routes=[Route("/update_session", endpoint=update_session, methods=["POST"])], middleware=[ Middleware(SessionMiddleware, secret_key="example", path="/second_app") ], diff --git a/tests/middleware/test_wsgi.py b/tests/middleware/test_wsgi.py index 69842d3ad..e4361c703 100644 --- a/tests/middleware/test_wsgi.py +++ b/tests/middleware/test_wsgi.py @@ -13,10 +13,7 @@ Environment = Dict[str, Any] -def hello_world( - environ: Environment, - start_response: StartResponse, -) -> WSGIResponse: +def hello_world(environ: Environment, start_response: StartResponse) -> WSGIResponse: status = "200 OK" output = b"Hello World!\n" headers = [ @@ -27,10 +24,7 @@ def hello_world( return [output] -def echo_body( - environ: Environment, - start_response: StartResponse, -) -> WSGIResponse: +def echo_body(environ: Environment, start_response: StartResponse) -> WSGIResponse: status = "200 OK" output = environ["wsgi.input"].read() headers = [ @@ -42,15 +36,13 @@ def echo_body( def raise_exception( - environ: Environment, - start_response: StartResponse, + environ: Environment, start_response: StartResponse ) -> WSGIResponse: raise RuntimeError("Something went wrong") def return_exc_info( - environ: Environment, - start_response: StartResponse, + environ: Environment, start_response: StartResponse ) -> WSGIResponse: try: raise RuntimeError("Something went wrong") diff --git a/tests/test__utils.py b/tests/test__utils.py index 06fece58b..12ba59ced 100644 --- a/tests/test__utils.py +++ b/tests/test__utils.py @@ -57,19 +57,11 @@ def __call__(self) -> None: def test_async_partial_object_call() -> None: class Async: - async def __call__( - self, - a: Any, - b: Any, - ) -> None: + async def __call__(self, a: Any, b: Any) -> None: ... # pragma: no cover class Sync: - def __call__( - self, - a: Any, - b: Any, - ) -> None: + def __call__(self, a: Any, b: Any) -> None: ... # pragma: no cover partial = functools.partial(Async(), 1) @@ -80,10 +72,7 @@ def __call__( def test_async_nested_partial() -> None: - async def async_func( - a: Any, - b: Any, - ) -> None: + async def async_func(a: Any, b: Any) -> None: ... # pragma: no cover partial = functools.partial(async_func, b=2) diff --git a/tests/test_applications.py b/tests/test_applications.py index 5b6c9d545..59075d74b 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -97,11 +97,7 @@ def custom_ws_exception_handler(websocket: WebSocket, exc: CustomWSException) -> ] ) -subdomain = Router( - routes=[ - Route("/", custom_subdomain), - ] -) +subdomain = Router(routes=[Route("/", custom_subdomain)]) exception_handlers = { 500: error_500, @@ -269,11 +265,7 @@ def test_app_mount(tmpdir: Path, test_client_factory: TestClientFactory) -> None with open(path, "w") as file: file.write("") - app = Starlette( - routes=[ - Mount("/static", StaticFiles(directory=tmpdir)), - ] - ) + app = Starlette(routes=[Mount("/static", StaticFiles(directory=tmpdir))]) client = test_client_factory(app) @@ -290,11 +282,7 @@ def test_app_debug(test_client_factory: TestClientFactory) -> None: async def homepage(request: Request) -> None: raise RuntimeError() - app = Starlette( - routes=[ - Route("/", homepage), - ], - ) + app = Starlette(routes=[Route("/", homepage)]) app.debug = True client = test_client_factory(app, raise_server_exceptions=False) @@ -308,11 +296,7 @@ def test_app_add_route(test_client_factory: TestClientFactory) -> None: async def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Hello, World!") - app = Starlette( - routes=[ - Route("/", endpoint=homepage), - ] - ) + app = Starlette(routes=[Route("/", endpoint=homepage)]) client = test_client_factory(app) response = client.get("/") @@ -326,11 +310,7 @@ async def websocket_endpoint(session: WebSocket) -> None: await session.send_text("Hello, world!") await session.close() - app = Starlette( - routes=[ - WebSocketRoute("/ws", endpoint=websocket_endpoint), - ] - ) + app = Starlette(routes=[WebSocketRoute("/ws", endpoint=websocket_endpoint)]) client = test_client_factory(app) with client.websocket_connect("/ws") as session: @@ -353,10 +333,7 @@ def run_cleanup() -> None: with pytest.deprecated_call( match="The on_startup and on_shutdown parameters are deprecated" ): - app = Starlette( - on_startup=[run_startup], - on_shutdown=[run_cleanup], - ) + app = Starlette(on_startup=[run_startup], on_shutdown=[run_cleanup]) assert not startup_complete assert not cleanup_complete diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 27b033762..7718f5308 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -29,8 +29,7 @@ class BasicAuth(AuthenticationBackend): async def authenticate( - self, - request: HTTPConnection, + self, request: HTTPConnection ) -> Optional[Tuple[AuthCredentials, SimpleUser]]: if "Authorization" not in request.headers: return None diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 401ad8212..465560d78 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -118,8 +118,7 @@ def test_websockets_should_raise(client: TestClient) -> None: def test_handled_exc_after_response( - test_client_factory: TestClientFactory, - client: TestClient, + test_client_factory: TestClientFactory, client: TestClient ) -> None: # A 406 HttpException is raised *after* the response has already been sent. # The exception middleware should raise a RuntimeError. diff --git a/tests/test_requests.py b/tests/test_requests.py index b3ce3a04a..d9fe65cc0 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -238,8 +238,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: def test_request_disconnect( - anyio_backend_name: str, - anyio_backend_options: Dict[str, Any], + anyio_backend_name: str, anyio_backend_options: Dict[str, Any] ) -> None: """ If a client disconnect occurs while reading request body @@ -390,9 +389,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: ], ) def test_cookies_edge_cases( - set_cookie: str, - expected: Dict[str, str], - test_client_factory: TestClientFactory, + set_cookie: str, expected: Dict[str, str], test_client_factory: TestClientFactory ) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) @@ -429,9 +426,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: ], ) def test_cookies_invalid( - set_cookie: str, - expected: Dict[str, str], - test_client_factory: TestClientFactory, + set_cookie: str, expected: Dict[str, str], test_client_factory: TestClientFactory ) -> None: """ Cookie strings that are against the RFC6265 spec but which browsers will send if set diff --git a/tests/test_routing.py b/tests/test_routing.py index 8c3f16639..3a7942cc7 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -343,8 +343,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await response(scope, receive, send) app = Router( - routes=[Route("/", homepage)], - middleware=[Middleware(CustomMiddleware)], + routes=[Route("/", homepage)], middleware=[Middleware(CustomMiddleware)] ) client = test_client_factory(app) @@ -585,7 +584,7 @@ async def stub_app(scope: Scope, receive: Receive, send: Send) -> None: double_mount_routes = [ - Mount("/mount", name="mount", routes=[Mount("/static", stub_app, name="static")]), + Mount("/mount", name="mount", routes=[Mount("/static", stub_app, name="static")]) ] @@ -595,9 +594,7 @@ def test_url_for_with_double_mount() -> None: assert url == "/mount/static/123" -def test_standalone_route_matches( - test_client_factory: TestClientFactory, -) -> None: +def test_standalone_route_matches(test_client_factory: TestClientFactory) -> None: app = Route("/", PlainTextResponse("Hello, World!")) client = test_client_factory(app) response = client.get("/") @@ -621,9 +618,7 @@ async def ws_helloworld(websocket: WebSocket) -> None: await websocket.close() -def test_standalone_ws_route_matches( - test_client_factory: TestClientFactory, -) -> None: +def test_standalone_ws_route_matches(test_client_factory: TestClientFactory) -> None: app = WebSocketRoute("/", ws_helloworld) client = test_client_factory(app) with client.websocket_connect("/") as websocket: @@ -756,9 +751,7 @@ def run_shutdown() -> None: assert shutdown_complete -def test_lifespan_state_unsupported( - test_client_factory: TestClientFactory, -) -> None: +def test_lifespan_state_unsupported(test_client_factory: TestClientFactory) -> None: @contextlib.asynccontextmanager async def lifespan( app: ASGIApp, @@ -766,8 +759,7 @@ async def lifespan( yield {"foo": "bar"} app = Router( - lifespan=lifespan, - routes=[Mount("/", PlainTextResponse("hello, world"))], + lifespan=lifespan, routes=[Mount("/", PlainTextResponse("hello, world"))] ) async def no_state_wrapper(scope: Scope, receive: Receive, send: Send) -> None: @@ -812,10 +804,7 @@ async def lifespan(app: Starlette) -> typing.AsyncIterator[State]: # via state assert state["items"] == [1, 1] - app = Router( - lifespan=lifespan, - routes=[Route("/", hello_world)], - ) + app = Router(lifespan=lifespan, routes=[Route("/", hello_world)]) assert not startup_complete assert not shutdown_complete @@ -879,9 +868,7 @@ def test_partial_async_endpoint(test_client_factory: TestClientFactory) -> None: assert cls_method_response.json() == {"arg": "foo"} -def test_partial_async_ws_endpoint( - test_client_factory: TestClientFactory, -) -> None: +def test_partial_async_ws_endpoint(test_client_factory: TestClientFactory) -> None: test_client = test_client_factory(app) with test_client.websocket_connect("/partial/ws") as websocket: data = websocket.receive_json() @@ -893,10 +880,7 @@ def test_partial_async_ws_endpoint( def test_duplicated_param_names() -> None: - with pytest.raises( - ValueError, - match="Duplicated param name id at path /{id}/{id}", - ): + with pytest.raises(ValueError, match="Duplicated param name id at path /{id}/{id}"): Route("/{id}/{id}", user) with pytest.raises( @@ -928,11 +912,7 @@ def __call__(self, request: Request) -> None: pytest.param(func_homepage, "func_homepage", id="function"), pytest.param(Endpoint().my_method, "my_method", id="method"), pytest.param(Endpoint.my_classmethod, "my_classmethod", id="classmethod"), - pytest.param( - Endpoint.my_staticmethod, - "my_staticmethod", - id="staticmethod", - ), + pytest.param(Endpoint.my_staticmethod, "my_staticmethod", id="staticmethod"), pytest.param(Endpoint(), "Endpoint", id="object"), pytest.param(lambda request: ..., "", id="lambda"), ], @@ -985,7 +965,7 @@ def assert_middleware_header_route(request: Request) -> Response: endpoint=assert_middleware_header_route, methods=["GET"], name="route", - ), + ) ], middleware=[Middleware(AddHeadersMiddleware)], ), @@ -1020,8 +1000,7 @@ def assert_middleware_header_route(request: Request) -> Response: ], ) def test_base_route_middleware( - test_client_factory: TestClientFactory, - app: Starlette, + test_client_factory: TestClientFactory, app: Starlette ) -> None: test_client = test_client_factory(app) @@ -1055,19 +1034,13 @@ def test_add_route_to_app_after_mount( """ inner_app = Router() app = Mount("/http", app=inner_app) - inner_app.add_route( - "/inner", - endpoint=homepage, - methods=["GET"], - ) + inner_app.add_route("/inner", endpoint=homepage, methods=["GET"]) client = test_client_factory(app) response = client.get("/http/inner") assert response.status_code == 200 -def test_exception_on_mounted_apps( - test_client_factory: TestClientFactory, -) -> None: +def test_exception_on_mounted_apps(test_client_factory: TestClientFactory) -> None: def exc(request: Request) -> None: raise Exception("Exc") @@ -1104,10 +1077,7 @@ async def modified_send(msg: Message) -> None: routes=[ Mount( "/mount", - routes=[ - Route("/err", exc), - Route("/home", homepage), - ], + routes=[Route("/err", exc), Route("/home", homepage)], middleware=[Middleware(NamedMiddleware, name="Mounted")], ), Route("/err", exc), @@ -1135,9 +1105,7 @@ async def modified_send(msg: Message) -> None: assert "X-Mounted" in resp.headers -def test_websocket_route_middleware( - test_client_factory: TestClientFactory, -) -> None: +def test_websocket_route_middleware(test_client_factory: TestClientFactory) -> None: async def websocket_endpoint(session: WebSocket) -> None: await session.accept() await session.send_text("Hello, world!") @@ -1192,51 +1160,25 @@ def test_websocket_route_repr() -> None: def test_mount_repr() -> None: - route = Mount( - "/app", - routes=[ - Route("/", endpoint=homepage), - ], - ) + route = Mount("/app", routes=[Route("/", endpoint=homepage)]) # test for substring because repr(Router) returns unique object ID assert repr(route).startswith("Mount(path='/app', name='', app=") def test_mount_named_repr() -> None: - route = Mount( - "/app", - name="app", - routes=[ - Route("/", endpoint=homepage), - ], - ) + route = Mount("/app", name="app", routes=[Route("/", endpoint=homepage)]) # test for substring because repr(Router) returns unique object ID assert repr(route).startswith("Mount(path='/app', name='app', app=") def test_host_repr() -> None: - route = Host( - "example.com", - app=Router( - [ - Route("/", endpoint=homepage), - ] - ), - ) + route = Host("example.com", app=Router([Route("/", endpoint=homepage)])) # test for substring because repr(Router) returns unique object ID assert repr(route).startswith("Host(host='example.com', name='', app=") def test_host_named_repr() -> None: - route = Host( - "example.com", - name="app", - app=Router( - [ - Route("/", endpoint=homepage), - ] - ), - ) + route = Host("example.com", name="app", app=Router([Route("/", endpoint=homepage)])) # test for substring because repr(Router) returns unique object ID assert repr(route).startswith("Host(host='example.com', name='app', app=") @@ -1300,7 +1242,7 @@ async def pure_asgi_echo_paths( functools.partial(echo_paths, name="subpath"), name="subpath", methods=["GET"], - ), + ) ], ), ] diff --git a/tests/test_schemas.py b/tests/test_schemas.py index e00b2b8de..84a70d42c 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -113,11 +113,7 @@ def schema(request: Request) -> Response: return schemas.OpenAPIResponse(request=request) -subapp = Starlette( - routes=[ - Route("/subapp-endpoint", endpoint=subapp_endpoint), - ] -) +subapp = Starlette(routes=[Route("/subapp-endpoint", endpoint=subapp_endpoint)]) app = Starlette( routes=[ @@ -199,12 +195,9 @@ def test_schema_generation() -> None: "/users/{id}": { "get": { "responses": { - 200: { - "description": "A user.", - "examples": {"username": "tom"}, - } + 200: {"description": "A user.", "examples": {"username": "tom"}} } - }, + } }, }, } diff --git a/tests/test_staticfiles.py b/tests/test_staticfiles.py index d20bb7ef7..ffa730ed8 100644 --- a/tests/test_staticfiles.py +++ b/tests/test_staticfiles.py @@ -421,11 +421,7 @@ def test_staticfiles_with_invalid_dir_permissions_returns_401( tmp_path.chmod(stat.S_IRWXO) try: routes = [ - Mount( - "/", - app=StaticFiles(directory=os.fsdecode(tmp_path)), - name="static", - ) + Mount("/", app=StaticFiles(directory=os.fsdecode(tmp_path)), name="static") ] app = Starlette(routes=routes) client = test_client_factory(app) diff --git a/tests/test_templates.py b/tests/test_templates.py index ab0b38a91..544fcf5be 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -49,15 +49,9 @@ async def homepage(request: Request) -> Response: def hello_world_processor(request: Request) -> typing.Dict[str, str]: return {"username": "World"} - app = Starlette( - debug=True, - routes=[Route("/", endpoint=homepage)], - ) + app = Starlette(debug=True, routes=[Route("/", endpoint=homepage)]) templates = Jinja2Templates( - directory=tmp_path, - context_processors=[ - hello_world_processor, - ], + directory=tmp_path, context_processors=[hello_world_processor] ) client = test_client_factory(app) @@ -117,8 +111,7 @@ async def page_b(request: Request) -> Response: return templates.TemplateResponse(request, "template_b.html") app = Starlette( - debug=True, - routes=[Route("/a", endpoint=page_a), Route("/b", endpoint=page_b)], + debug=True, routes=[Route("/a", endpoint=page_a), Route("/b", endpoint=page_b)] ) templates = Jinja2Templates(directory=[dir_a, dir_b]) @@ -162,10 +155,7 @@ async def homepage(request: Request) -> Response: return templates.TemplateResponse(request, "index.html") env = jinja2.Environment(loader=jinja2.FileSystemLoader(str(tmpdir))) - app = Starlette( - debug=True, - routes=[Route("/", endpoint=homepage)], - ) + app = Starlette(debug=True, routes=[Route("/", endpoint=homepage)]) templates = Jinja2Templates(env=env) client = test_client_factory(app) response = client.get("/") @@ -216,10 +206,7 @@ def test_templates_with_kwargs_only_requires_request_in_context(tmpdir: Path) -> # MAINTAINERS: remove after 1.0 templates = Jinja2Templates(directory=str(tmpdir)) - with pytest.warns( - DeprecationWarning, - match="requires the `request` argument", - ): + with pytest.warns(DeprecationWarning, match="requires the `request` argument"): with pytest.raises(ValueError): templates.TemplateResponse(name="index.html", context={"a": "b"}) @@ -243,10 +230,7 @@ def page(request: Request) -> Response: app = Starlette(routes=[Route("/", page)]) client = test_client_factory(app) - with pytest.warns( - DeprecationWarning, - match="requires the `request` argument", - ): + with pytest.warns(DeprecationWarning, match="requires the `request` argument"): client.get("/") diff --git a/tests/test_testclient.py b/tests/test_testclient.py index e8956cd30..ea1c5c24d 100644 --- a/tests/test_testclient.py +++ b/tests/test_testclient.py @@ -127,8 +127,7 @@ async def loop_id(request: Request) -> JSONResponse: return JSONResponse(get_identity()) app = Starlette( - lifespan=lifespan_context, - routes=[Route("/loop_id", endpoint=loop_id)], + lifespan=lifespan_context, routes=[Route("/loop_id", endpoint=loop_id)] ) client = test_client_factory(app) @@ -296,7 +295,7 @@ def homepage(request: Request) -> Response: sys.version_info < (3, 11), reason="Fails due to domain handling in http.cookiejar module (see " "#2152)", - ), + ) ], ), ("testserver.local", True), @@ -318,12 +317,7 @@ def test_domain_restricted_cookies( async def app(scope: Scope, receive: Receive, send: Send) -> None: response = Response("Hello, world!", media_type="text/plain") - response.set_cookie( - "mycookie", - "myvalue", - path="/", - domain=domain, - ) + response.set_cookie("mycookie", "myvalue", path="/", domain=domain) await response(scope, receive, send) client = test_client_factory(app) diff --git a/tests/test_websockets.py b/tests/test_websockets.py index c4b6c16bd..57af51532 100644 --- a/tests/test_websockets.py +++ b/tests/test_websockets.py @@ -336,18 +336,9 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: } ) await websocket.send( - { - "type": "websocket.http.response.body", - "body": b"hard", - "more_body": True, - } - ) - await websocket.send( - { - "type": "websocket.http.response.body", - "body": b"body", - } + {"type": "websocket.http.response.body", "body": b"hard", "more_body": True} ) + await websocket.send({"type": "websocket.http.response.body", "body": b"body"}) client = test_client_factory(app) with pytest.raises(WebSocketDenialResponse) as exc: