Skip to content

Commit

Permalink
Add type hints to test_cors.py (encode#2458)
Browse files Browse the repository at this point in the history
Co-authored-by: Scirlat Danut <[email protected]>
  • Loading branch information
2 people authored and Rocky Allen committed Mar 18, 2024
1 parent 1a9f63c commit 6a9ad19
Showing 1 changed file with 64 additions and 33 deletions.
97 changes: 64 additions & 33 deletions tests/middleware/test_cors.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
from typing import Callable

from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.cors import CORSMiddleware
from starlette.requests import Request
from starlette.responses import PlainTextResponse
from starlette.routing import Route
from starlette.testclient import TestClient
from starlette.types import ASGIApp

TestClientFactory = Callable[[ASGIApp], TestClient]


def test_cors_allow_all(test_client_factory):
def homepage(request):
def test_cors_allow_all(
test_client_factory: TestClientFactory,
) -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Homepage", status_code=200)

app = Starlette(
Expand Down Expand Up @@ -64,8 +73,10 @@ def homepage(request):
assert "access-control-allow-origin" not in response.headers


def test_cors_allow_all_except_credentials(test_client_factory):
def homepage(request):
def test_cors_allow_all_except_credentials(
test_client_factory: TestClientFactory,
) -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Homepage", status_code=200)

app = Starlette(
Expand Down Expand Up @@ -113,8 +124,10 @@ def homepage(request):
assert "access-control-allow-origin" not in response.headers


def test_cors_allow_specific_origin(test_client_factory):
def homepage(request):
def test_cors_allow_specific_origin(
test_client_factory: TestClientFactory,
) -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Homepage", status_code=200)

app = Starlette(
Expand Down Expand Up @@ -160,8 +173,10 @@ def homepage(request):
assert "access-control-allow-origin" not in response.headers


def test_cors_disallowed_preflight(test_client_factory):
def homepage(request):
def test_cors_disallowed_preflight(
test_client_factory: TestClientFactory,
) -> None:
def homepage(request: Request) -> None:
pass # pragma: no cover

app = Starlette(
Expand Down Expand Up @@ -200,9 +215,9 @@ def homepage(request):


def test_preflight_allows_request_origin_if_origins_wildcard_and_credentials_allowed(
test_client_factory,
):
def homepage(request):
test_client_factory: TestClientFactory,
) -> None:
def homepage(request: Request) -> None:
return # pragma: no cover

app = Starlette(
Expand Down Expand Up @@ -234,8 +249,10 @@ def homepage(request):
assert response.headers["vary"] == "Origin"


def test_cors_preflight_allow_all_methods(test_client_factory):
def homepage(request):
def test_cors_preflight_allow_all_methods(
test_client_factory: TestClientFactory,
) -> None:
def homepage(request: Request) -> None:
pass # pragma: no cover

app = Starlette(
Expand All @@ -258,8 +275,10 @@ def homepage(request):
assert method in response.headers["access-control-allow-methods"]


def test_cors_allow_all_methods(test_client_factory):
def homepage(request):
def test_cors_allow_all_methods(
test_client_factory: TestClientFactory,
) -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Homepage", status_code=200)

app = Starlette(
Expand Down Expand Up @@ -287,8 +306,10 @@ def homepage(request):
assert response.status_code == 200


def test_cors_allow_origin_regex(test_client_factory):
def homepage(request):
def test_cors_allow_origin_regex(
test_client_factory: TestClientFactory,
) -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Homepage", status_code=200)

app = Starlette(
Expand Down Expand Up @@ -357,8 +378,10 @@ def homepage(request):
assert "access-control-allow-origin" not in response.headers


def test_cors_allow_origin_regex_fullmatch(test_client_factory):
def homepage(request):
def test_cors_allow_origin_regex_fullmatch(
test_client_factory: TestClientFactory,
) -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Homepage", status_code=200)

app = Starlette(
Expand Down Expand Up @@ -393,8 +416,10 @@ def homepage(request):
assert "access-control-allow-origin" not in response.headers


def test_cors_credentialed_requests_return_specific_origin(test_client_factory):
def homepage(request):
def test_cors_credentialed_requests_return_specific_origin(
test_client_factory: TestClientFactory,
) -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Homepage", status_code=200)

app = Starlette(
Expand All @@ -412,8 +437,10 @@ def homepage(request):
assert "access-control-allow-credentials" not in response.headers


def test_cors_vary_header_defaults_to_origin(test_client_factory):
def homepage(request):
def test_cors_vary_header_defaults_to_origin(
test_client_factory: TestClientFactory,
) -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Homepage", status_code=200)

app = Starlette(
Expand All @@ -430,8 +457,10 @@ def homepage(request):
assert response.headers["vary"] == "Origin"


def test_cors_vary_header_is_not_set_for_non_credentialed_request(test_client_factory):
def homepage(request):
def test_cors_vary_header_is_not_set_for_non_credentialed_request(
test_client_factory: TestClientFactory,
) -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse(
"Homepage", status_code=200, headers={"Vary": "Accept-Encoding"}
)
Expand All @@ -447,8 +476,10 @@ def homepage(request):
assert response.headers["vary"] == "Accept-Encoding"


def test_cors_vary_header_is_properly_set_for_credentialed_request(test_client_factory):
def homepage(request):
def test_cors_vary_header_is_properly_set_for_credentialed_request(
test_client_factory: TestClientFactory,
) -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse(
"Homepage", status_code=200, headers={"Vary": "Accept-Encoding"}
)
Expand All @@ -467,9 +498,9 @@ def homepage(request):


def test_cors_vary_header_is_properly_set_when_allow_origins_is_not_wildcard(
test_client_factory,
):
def homepage(request):
test_client_factory: TestClientFactory,
) -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse(
"Homepage", status_code=200, headers={"Vary": "Accept-Encoding"}
)
Expand All @@ -488,9 +519,9 @@ def homepage(request):


def test_cors_allowed_origin_does_not_leak_between_credentialed_requests(
test_client_factory,
):
def homepage(request):
test_client_factory: TestClientFactory,
) -> None:
def homepage(request: Request) -> PlainTextResponse:
return PlainTextResponse("Homepage", status_code=200)

app = Starlette(
Expand Down

0 comments on commit 6a9ad19

Please sign in to comment.