From 1a9f63ce81046acc7e9176541c8fa0ff832eba7f Mon Sep 17 00:00:00 2001 From: Scirlat Danut Date: Sat, 3 Feb 2024 18:54:21 +0200 Subject: [PATCH] Add type hints to `test_base.py` (#2445) * added type annotations to test-base.py * deleted unused imports * fixed import order * conditional import * conditional import TestClient on types.py * using string literals when importing in types * added missing imports * deleted starlette/types, refactored test_base types * deleted types --------- Co-authored-by: Scirlat Danut Co-authored-by: Marcelo Trylesinski --- tests/middleware/test_base.py | 292 ++++++++++++++++++++++++---------- 1 file changed, 208 insertions(+), 84 deletions(-) diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 012e2c1ff..6e5e42b94 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -1,9 +1,18 @@ import contextvars from contextlib import AsyncExitStack -from typing import Any, AsyncGenerator, Awaitable, Callable, List, Type, Union +from typing import ( + Any, + AsyncGenerator, + Callable, + Generator, + List, + Type, + Union, +) import anyio import pytest +from anyio.abc import TaskStatus from starlette.applications import Starlette from starlette.background import BackgroundTask @@ -14,44 +23,56 @@ from starlette.routing import Route, WebSocketRoute from starlette.testclient import TestClient from starlette.types import ASGIApp, Message, Receive, Scope, Send +from starlette.websockets import WebSocket + +TestClientFactory = Callable[[ASGIApp], TestClient] class CustomMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request, call_next): + async def dispatch( + self, + request: Request, + call_next: RequestResponseEndpoint, + ) -> Response: response = await call_next(request) response.headers["Custom-Header"] = "Example" return response -def homepage(request): +def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Homepage") -def exc(request): +def exc(request: Request) -> None: raise Exception("Exc") -def exc_stream(request): +def exc_stream(request: Request) -> StreamingResponse: return StreamingResponse(_generate_faulty_stream()) -def _generate_faulty_stream(): +def _generate_faulty_stream() -> Generator[bytes, None, None]: yield b"Ok" raise Exception("Faulty Stream") class NoResponse: - def __init__(self, scope, receive, send): + def __init__( + self, + scope: Scope, + receive: Receive, + send: Send, + ): pass - def __await__(self): + def __await__(self) -> Generator[Any, None, None]: return self.dispatch().__await__() - async def dispatch(self): + async def dispatch(self) -> None: pass -async def websocket_endpoint(session): +async def websocket_endpoint(session: WebSocket) -> None: await session.accept() await session.send_text("Hello, world!") await session.close() @@ -69,7 +90,7 @@ async def websocket_endpoint(session): ) -def test_custom_middleware(test_client_factory): +def test_custom_middleware(test_client_factory: TestClientFactory) -> None: client = test_client_factory(app) response = client.get("/") assert response.headers["Custom-Header"] == "Example" @@ -90,30 +111,44 @@ def test_custom_middleware(test_client_factory): assert text == "Hello, world!" -def test_state_data_across_multiple_middlewares(test_client_factory): +def test_state_data_across_multiple_middlewares( + test_client_factory: TestClientFactory, +) -> None: expected_value1 = "foo" expected_value2 = "bar" class aMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request, call_next): + async def dispatch( + self, + request: Request, + call_next: RequestResponseEndpoint, + ) -> Response: request.state.foo = expected_value1 response = await call_next(request) return response class bMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request, call_next): + async def dispatch( + self, + request: Request, + call_next: RequestResponseEndpoint, + ) -> Response: request.state.bar = expected_value2 response = await call_next(request) response.headers["X-State-Foo"] = request.state.foo return response class cMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request, call_next): + async def dispatch( + self, + request: Request, + call_next: RequestResponseEndpoint, + ) -> Response: response = await call_next(request) response.headers["X-State-Bar"] = request.state.bar return response - def homepage(request): + def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("OK") app = Starlette( @@ -132,8 +167,8 @@ def homepage(request): assert response.headers["X-State-Bar"] == expected_value2 -def test_app_middleware_argument(test_client_factory): - def homepage(request): +def test_app_middleware_argument(test_client_factory: TestClientFactory) -> None: + def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Homepage") app = Starlette( @@ -145,10 +180,14 @@ def homepage(request): assert response.headers["Custom-Header"] == "Example" -def test_fully_evaluated_response(test_client_factory): +def test_fully_evaluated_response(test_client_factory: TestClientFactory) -> None: # Test for https://github.com/encode/starlette/issues/1022 class CustomMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request, call_next): + async def dispatch( + self, + request: Request, + call_next: RequestResponseEndpoint, + ) -> PlainTextResponse: await call_next(request) return PlainTextResponse("Custom") @@ -173,7 +212,11 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: class CustomMiddlewareUsingBaseHTTPMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request, call_next): + async def dispatch( + self, + request: Request, + call_next: RequestResponseEndpoint, + ) -> Response: ctxvar.set("set by middleware") resp = await call_next(request) assert ctxvar.get() == "set by endpoint" @@ -196,11 +239,14 @@ async def dispatch(self, request, call_next): ), ], ) -def test_contextvars(test_client_factory, middleware_cls: Type[_MiddlewareClass[Any]]): +def test_contextvars( + test_client_factory: TestClientFactory, + middleware_cls: Type[_MiddlewareClass[Any]], +) -> None: # this has to be an async endpoint because Starlette calls run_in_threadpool # on sync endpoints which has it's own set of peculiarities w.r.t propagating # contextvars (it propagates them forwards but not backwards) - async def homepage(request): + async def homepage(request: Request) -> PlainTextResponse: assert ctxvar.get() == "set by middleware" ctxvar.set("set by endpoint") return PlainTextResponse("Homepage") @@ -215,23 +261,26 @@ async def homepage(request): @pytest.mark.anyio -async def test_run_background_tasks_even_if_client_disconnects(): +async def test_run_background_tasks_even_if_client_disconnects() -> None: # test for https://github.com/encode/starlette/issues/1438 request_body_sent = False response_complete = anyio.Event() background_task_run = anyio.Event() - async def sleep_and_set(): + async def sleep_and_set() -> None: # small delay to give BaseHTTPMiddleware a chance to cancel us # this is required to make the test fail prior to fixing the issue # so do not be surprised if you remove it and the test still passes await anyio.sleep(0.1) background_task_run.set() - async def endpoint_with_background_task(_): + async def endpoint_with_background_task(_: Request) -> PlainTextResponse: return PlainTextResponse(background=BackgroundTask(sleep_and_set)) - async def passthrough(request, call_next): + async def passthrough( + request: Request, + call_next: RequestResponseEndpoint, + ) -> Response: return await call_next(request) app = Starlette( @@ -246,7 +295,7 @@ async def passthrough(request, call_next): "path": "/", } - async def receive(): + async def receive() -> Message: nonlocal request_body_sent if not request_body_sent: request_body_sent = True @@ -255,7 +304,7 @@ async def receive(): await response_complete.wait() return {"type": "http.disconnect"} - async def send(message): + async def send(message: Message) -> None: if message["type"] == "http.response.body": if not message.get("more_body", False): response_complete.set() @@ -266,23 +315,23 @@ async def send(message): @pytest.mark.anyio -async def test_do_not_block_on_background_tasks(): +async def test_do_not_block_on_background_tasks() -> None: request_body_sent = False response_complete = anyio.Event() events: List[Union[str, Message]] = [] - async def sleep_and_set(): + async def sleep_and_set() -> None: events.append("Background task started") await anyio.sleep(0.1) events.append("Background task finished") - async def endpoint_with_background_task(_): + async def endpoint_with_background_task(_: Request) -> PlainTextResponse: return PlainTextResponse( content="Hello", background=BackgroundTask(sleep_and_set) ) async def passthrough( - request: Request, call_next: Callable[[Request], Awaitable[Response]] + request: Request, call_next: RequestResponseEndpoint ) -> Response: return await call_next(request) @@ -306,7 +355,7 @@ async def receive() -> Message: await response_complete.wait() return {"type": "http.disconnect"} - async def send(message: Message): + async def send(message: Message) -> None: if message["type"] == "http.response.body": events.append(message) if not message.get("more_body", False): @@ -331,13 +380,13 @@ async def send(message: Message): @pytest.mark.anyio -async def test_run_context_manager_exit_even_if_client_disconnects(): +async def test_run_context_manager_exit_even_if_client_disconnects() -> None: # test for https://github.com/encode/starlette/issues/1678#issuecomment-1172916042 request_body_sent = False response_complete = anyio.Event() context_manager_exited = anyio.Event() - async def sleep_and_set(): + async def sleep_and_set() -> None: # small delay to give BaseHTTPMiddleware a chance to cancel us # this is required to make the test fail prior to fixing the issue # so do not be surprised if you remove it and the test still passes @@ -345,18 +394,21 @@ async def sleep_and_set(): context_manager_exited.set() class ContextManagerMiddleware: - def __init__(self, app): + def __init__(self, app: ASGIApp): self.app = app - async def __call__(self, scope: Scope, receive: Receive, send: Send): + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: async with AsyncExitStack() as stack: stack.push_async_callback(sleep_and_set) await self.app(scope, receive, send) - async def simple_endpoint(_): + async def simple_endpoint(_: Request) -> PlainTextResponse: return PlainTextResponse(background=BackgroundTask(sleep_and_set)) - async def passthrough(request, call_next): + async def passthrough( + request: Request, + call_next: RequestResponseEndpoint, + ) -> Response: return await call_next(request) app = Starlette( @@ -374,7 +426,7 @@ async def passthrough(request, call_next): "path": "/", } - async def receive(): + async def receive() -> Message: nonlocal request_body_sent if not request_body_sent: request_body_sent = True @@ -383,7 +435,7 @@ async def receive(): await response_complete.wait() return {"type": "http.disconnect"} - async def send(message): + async def send(message: Message) -> None: if message["type"] == "http.response.body": if not message.get("more_body", False): response_complete.set() @@ -393,9 +445,15 @@ async def send(message): assert context_manager_exited.is_set() -def test_app_receives_http_disconnect_while_sending_if_discarded(test_client_factory): +def test_app_receives_http_disconnect_while_sending_if_discarded( + test_client_factory: TestClientFactory, +) -> None: class DiscardingMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request, call_next): + async def dispatch( + self, + request: Request, + call_next: Any, + ) -> PlainTextResponse: # As a matter of ordering, this test targets the case where the downstream # app response is discarded while it is sending a response body. # We need to wait for the downstream app to begin sending a response body @@ -410,7 +468,11 @@ async def dispatch(self, request, call_next): return PlainTextResponse("Custom") - async def downstream_app(scope, receive, send): + async def downstream_app( + scope: Scope, + receive: Receive, + send: Send, + ) -> None: await send( { "type": "http.response.start", @@ -422,7 +484,10 @@ async def downstream_app(scope, receive, send): ) async with anyio.create_task_group() as task_group: - async def cancel_on_disconnect(*, task_status=anyio.TASK_STATUS_IGNORED): + async def cancel_on_disconnect( + *, + task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED, + ) -> None: task_status.started() while True: message = await receive() @@ -458,13 +523,23 @@ async def cancel_on_disconnect(*, task_status=anyio.TASK_STATUS_IGNORED): assert response.text == "Custom" -def test_app_receives_http_disconnect_after_sending_if_discarded(test_client_factory): +def test_app_receives_http_disconnect_after_sending_if_discarded( + test_client_factory: TestClientFactory, +) -> None: class DiscardingMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request, call_next): + async def dispatch( + self, + request: Request, + call_next: RequestResponseEndpoint, + ) -> PlainTextResponse: await call_next(request) return PlainTextResponse("Custom") - async def downstream_app(scope, receive, send): + async def downstream_app( + scope: Scope, + receive: Receive, + send: Send, + ) -> None: await send( { "type": "http.response.start", @@ -499,9 +574,9 @@ async def downstream_app(scope, receive, send): def test_read_request_stream_in_app_after_middleware_calls_stream( - test_client_factory: Callable[[ASGIApp], TestClient], + test_client_factory: TestClientFactory, ) -> None: - async def homepage(request: Request): + async def homepage(request: Request) -> PlainTextResponse: expected = [b""] async for chunk in request.stream(): assert chunk == expected.pop(0) @@ -509,7 +584,11 @@ async def homepage(request: Request): return PlainTextResponse("Homepage") class ConsumingMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request: Request, call_next: RequestResponseEndpoint): + async def dispatch( + self, + request: Request, + call_next: RequestResponseEndpoint, + ) -> Response: expected = [b"a", b""] async for chunk in request.stream(): assert chunk == expected.pop(0) @@ -527,9 +606,9 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint): def test_read_request_stream_in_app_after_middleware_calls_body( - test_client_factory: Callable[[ASGIApp], TestClient], + test_client_factory: TestClientFactory, ) -> None: - async def homepage(request: Request): + async def homepage(request: Request) -> PlainTextResponse: expected = [b"a", b""] async for chunk in request.stream(): assert chunk == expected.pop(0) @@ -537,7 +616,11 @@ async def homepage(request: Request): return PlainTextResponse("Homepage") class ConsumingMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request: Request, call_next: RequestResponseEndpoint): + async def dispatch( + self, + request: Request, + call_next: RequestResponseEndpoint, + ) -> Response: assert await request.body() == b"a" return await call_next(request) @@ -552,14 +635,18 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint): def test_read_request_body_in_app_after_middleware_calls_stream( - test_client_factory: Callable[[ASGIApp], TestClient], + test_client_factory: TestClientFactory, ) -> None: - async def homepage(request: Request): + async def homepage(request: Request) -> PlainTextResponse: assert await request.body() == b"" return PlainTextResponse("Homepage") class ConsumingMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request: Request, call_next: RequestResponseEndpoint): + async def dispatch( + self, + request: Request, + call_next: RequestResponseEndpoint, + ) -> Response: expected = [b"a", b""] async for chunk in request.stream(): assert chunk == expected.pop(0) @@ -577,14 +664,18 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint): def test_read_request_body_in_app_after_middleware_calls_body( - test_client_factory: Callable[[ASGIApp], TestClient], + test_client_factory: TestClientFactory, ) -> None: - async def homepage(request: Request): + async def homepage(request: Request) -> PlainTextResponse: assert await request.body() == b"a" return PlainTextResponse("Homepage") class ConsumingMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request: Request, call_next: RequestResponseEndpoint): + async def dispatch( + self, + request: Request, + call_next: RequestResponseEndpoint, + ) -> Response: assert await request.body() == b"a" return await call_next(request) @@ -599,9 +690,9 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint): def test_read_request_stream_in_dispatch_after_app_calls_stream( - test_client_factory: Callable[[ASGIApp], TestClient], + test_client_factory: TestClientFactory, ) -> None: - async def homepage(request: Request): + async def homepage(request: Request) -> PlainTextResponse: expected = [b"a", b""] async for chunk in request.stream(): assert chunk == expected.pop(0) @@ -609,7 +700,11 @@ async def homepage(request: Request): return PlainTextResponse("Homepage") class ConsumingMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request: Request, call_next: RequestResponseEndpoint): + async def dispatch( + self, + request: Request, + call_next: RequestResponseEndpoint, + ) -> Response: resp = await call_next(request) with pytest.raises(RuntimeError, match="Stream consumed"): async for _ in request.stream(): @@ -627,14 +722,18 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint): def test_read_request_stream_in_dispatch_after_app_calls_body( - test_client_factory: Callable[[ASGIApp], TestClient], + test_client_factory: TestClientFactory, ) -> None: - async def homepage(request: Request): + async def homepage(request: Request) -> PlainTextResponse: assert await request.body() == b"a" return PlainTextResponse("Homepage") class ConsumingMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request: Request, call_next: RequestResponseEndpoint): + async def dispatch( + self, + request: Request, + call_next: RequestResponseEndpoint, + ) -> Response: resp = await call_next(request) with pytest.raises(RuntimeError, match="Stream consumed"): async for _ in request.stream(): @@ -661,7 +760,11 @@ async def endpoint(scope: Scope, receive: Receive, send: Send) -> None: await Response()(scope, receive, send) class ConsumingMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request: Request, call_next: RequestResponseEndpoint): + async def dispatch( + self, + request: Request, + call_next: RequestResponseEndpoint, + ) -> Response: expected = b"1" response: Union[Response, None] = None async for chunk in request.stream(): @@ -705,14 +808,18 @@ async def send(msg: Message) -> None: def test_read_request_stream_in_dispatch_after_app_calls_body_with_middleware_calling_body_before_call_next( # noqa: E501 - test_client_factory: Callable[[ASGIApp], TestClient], + test_client_factory: TestClientFactory, ) -> None: - async def homepage(request: Request): + async def homepage(request: Request) -> PlainTextResponse: assert await request.body() == b"a" return PlainTextResponse("Homepage") class ConsumingMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request: Request, call_next: RequestResponseEndpoint): + async def dispatch( + self, + request: Request, + call_next: RequestResponseEndpoint, + ) -> Response: assert ( await request.body() == b"a" ) # this buffers the request body in memory @@ -733,14 +840,18 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint): def test_read_request_body_in_dispatch_after_app_calls_body_with_middleware_calling_body_before_call_next( # noqa: E501 - test_client_factory: Callable[[ASGIApp], TestClient], + test_client_factory: TestClientFactory, ) -> None: - async def homepage(request: Request): + async def homepage(request: Request) -> PlainTextResponse: assert await request.body() == b"a" return PlainTextResponse("Homepage") class ConsumingMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request: Request, call_next: RequestResponseEndpoint): + async def dispatch( + self, + request: Request, + call_next: RequestResponseEndpoint, + ) -> Response: assert ( await request.body() == b"a" ) # this buffers the request body in memory @@ -773,7 +884,11 @@ async def endpoint(scope: Scope, receive: Receive, send: Send) -> None: await Response()(scope, receive, send) class ConsumingMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request: Request, call_next: RequestResponseEndpoint): + async def dispatch( + self, + request: Request, + call_next: RequestResponseEndpoint, + ) -> Response: response = await call_next(request) disconnected = await request.is_disconnected() assert disconnected is True @@ -785,7 +900,7 @@ async def receive() -> AsyncGenerator[Message, None]: yield {"type": "http.disconnect"} raise AssertionError("Should not be called, would hang") # pragma: no cover - async def send(msg: Message): + async def send(msg: Message) -> None: if msg["type"] == "http.response.start": assert msg["status"] == 200 @@ -809,7 +924,11 @@ async def endpoint(scope: Scope, receive: Receive, send: Send) -> None: await Response()(scope, receive, send) class ConsumingMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request: Request, call_next: RequestResponseEndpoint): + async def dispatch( + self, + request: Request, + call_next: RequestResponseEndpoint, + ) -> Response: await request.body() disconnected = await request.is_disconnected() assert disconnected is True @@ -823,7 +942,7 @@ async def receive() -> AsyncGenerator[Message, None]: yield {"type": "http.disconnect"} raise AssertionError("Should not be called, would hang") # pragma: no cover - async def send(msg: Message): + async def send(msg: Message) -> None: if msg["type"] == "http.response.start": assert msg["status"] == 200 @@ -837,7 +956,7 @@ async def send(msg: Message): def test_downstream_middleware_modifies_receive( - test_client_factory: Callable[[ASGIApp], TestClient], + test_client_factory: TestClientFactory, ) -> None: """If a downstream middleware modifies receive() the final ASGI app should see the modified version. @@ -850,7 +969,11 @@ async def endpoint(scope: Scope, receive: Receive, send: Send) -> None: await Response()(scope, receive, send) class ConsumingMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request: Request, call_next: RequestResponseEndpoint): + async def dispatch( + self, + request: Request, + call_next: RequestResponseEndpoint, + ) -> Response: body = await request.body() assert body == b"foo " return await call_next(request) @@ -873,9 +996,6 @@ async def wrapped_receive() -> Message: assert resp.status_code == 200 -CallNext = Callable[[Request], Awaitable[Response]] - - def test_pr_1519_comment_1236166180_example() -> None: """ https://github.com/encode/starlette/pull/1519#issuecomment-1236166180 @@ -883,7 +1003,11 @@ def test_pr_1519_comment_1236166180_example() -> None: bodies: List[bytes] = [] class LogRequestBodySize(BaseHTTPMiddleware): - async def dispatch(self, request: Request, call_next: CallNext) -> Response: + async def dispatch( + self, + request: Request, + call_next: RequestResponseEndpoint, + ) -> Response: print(len(await request.body())) return await call_next(request)