From 6a9ad192461871a70543a4a7802026df5293f450 Mon Sep 17 00:00:00 2001 From: Scirlat Danut Date: Sat, 3 Feb 2024 22:51:19 +0200 Subject: [PATCH] Add type hints to `test_cors.py` (#2458) Co-authored-by: Scirlat Danut --- tests/middleware/test_cors.py | 97 +++++++++++++++++++++++------------ 1 file changed, 64 insertions(+), 33 deletions(-) diff --git a/tests/middleware/test_cors.py b/tests/middleware/test_cors.py index ca3d4f47b..09ec9513f 100644 --- a/tests/middleware/test_cors.py +++ b/tests/middleware/test_cors.py @@ -1,12 +1,21 @@ +from typing import Callable + from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.cors import CORSMiddleware +from starlette.requests import Request from starlette.responses import PlainTextResponse from starlette.routing import Route +from starlette.testclient import TestClient +from starlette.types import ASGIApp + +TestClientFactory = Callable[[ASGIApp], TestClient] -def test_cors_allow_all(test_client_factory): - def homepage(request): +def test_cors_allow_all( + test_client_factory: TestClientFactory, +) -> None: + def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Homepage", status_code=200) app = Starlette( @@ -64,8 +73,10 @@ def homepage(request): assert "access-control-allow-origin" not in response.headers -def test_cors_allow_all_except_credentials(test_client_factory): - def homepage(request): +def test_cors_allow_all_except_credentials( + test_client_factory: TestClientFactory, +) -> None: + def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Homepage", status_code=200) app = Starlette( @@ -113,8 +124,10 @@ def homepage(request): assert "access-control-allow-origin" not in response.headers -def test_cors_allow_specific_origin(test_client_factory): - def homepage(request): +def test_cors_allow_specific_origin( + test_client_factory: TestClientFactory, +) -> None: + def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Homepage", status_code=200) app = Starlette( @@ -160,8 +173,10 @@ def homepage(request): assert "access-control-allow-origin" not in response.headers -def test_cors_disallowed_preflight(test_client_factory): - def homepage(request): +def test_cors_disallowed_preflight( + test_client_factory: TestClientFactory, +) -> None: + def homepage(request: Request) -> None: pass # pragma: no cover app = Starlette( @@ -200,9 +215,9 @@ def homepage(request): def test_preflight_allows_request_origin_if_origins_wildcard_and_credentials_allowed( - test_client_factory, -): - def homepage(request): + test_client_factory: TestClientFactory, +) -> None: + def homepage(request: Request) -> None: return # pragma: no cover app = Starlette( @@ -234,8 +249,10 @@ def homepage(request): assert response.headers["vary"] == "Origin" -def test_cors_preflight_allow_all_methods(test_client_factory): - def homepage(request): +def test_cors_preflight_allow_all_methods( + test_client_factory: TestClientFactory, +) -> None: + def homepage(request: Request) -> None: pass # pragma: no cover app = Starlette( @@ -258,8 +275,10 @@ def homepage(request): assert method in response.headers["access-control-allow-methods"] -def test_cors_allow_all_methods(test_client_factory): - def homepage(request): +def test_cors_allow_all_methods( + test_client_factory: TestClientFactory, +) -> None: + def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Homepage", status_code=200) app = Starlette( @@ -287,8 +306,10 @@ def homepage(request): assert response.status_code == 200 -def test_cors_allow_origin_regex(test_client_factory): - def homepage(request): +def test_cors_allow_origin_regex( + test_client_factory: TestClientFactory, +) -> None: + def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Homepage", status_code=200) app = Starlette( @@ -357,8 +378,10 @@ def homepage(request): assert "access-control-allow-origin" not in response.headers -def test_cors_allow_origin_regex_fullmatch(test_client_factory): - def homepage(request): +def test_cors_allow_origin_regex_fullmatch( + test_client_factory: TestClientFactory, +) -> None: + def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Homepage", status_code=200) app = Starlette( @@ -393,8 +416,10 @@ def homepage(request): assert "access-control-allow-origin" not in response.headers -def test_cors_credentialed_requests_return_specific_origin(test_client_factory): - def homepage(request): +def test_cors_credentialed_requests_return_specific_origin( + test_client_factory: TestClientFactory, +) -> None: + def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Homepage", status_code=200) app = Starlette( @@ -412,8 +437,10 @@ def homepage(request): assert "access-control-allow-credentials" not in response.headers -def test_cors_vary_header_defaults_to_origin(test_client_factory): - def homepage(request): +def test_cors_vary_header_defaults_to_origin( + test_client_factory: TestClientFactory, +) -> None: + def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Homepage", status_code=200) app = Starlette( @@ -430,8 +457,10 @@ def homepage(request): assert response.headers["vary"] == "Origin" -def test_cors_vary_header_is_not_set_for_non_credentialed_request(test_client_factory): - def homepage(request): +def test_cors_vary_header_is_not_set_for_non_credentialed_request( + test_client_factory: TestClientFactory, +) -> None: + def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse( "Homepage", status_code=200, headers={"Vary": "Accept-Encoding"} ) @@ -447,8 +476,10 @@ def homepage(request): assert response.headers["vary"] == "Accept-Encoding" -def test_cors_vary_header_is_properly_set_for_credentialed_request(test_client_factory): - def homepage(request): +def test_cors_vary_header_is_properly_set_for_credentialed_request( + test_client_factory: TestClientFactory, +) -> None: + def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse( "Homepage", status_code=200, headers={"Vary": "Accept-Encoding"} ) @@ -467,9 +498,9 @@ def homepage(request): def test_cors_vary_header_is_properly_set_when_allow_origins_is_not_wildcard( - test_client_factory, -): - def homepage(request): + test_client_factory: TestClientFactory, +) -> None: + def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse( "Homepage", status_code=200, headers={"Vary": "Accept-Encoding"} ) @@ -488,9 +519,9 @@ def homepage(request): def test_cors_allowed_origin_does_not_leak_between_credentialed_requests( - test_client_factory, -): - def homepage(request): + test_client_factory: TestClientFactory, +) -> None: + def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("Homepage", status_code=200) app = Starlette(