Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type hints to test_requests.py #2481

Merged
merged 11 commits into from
Feb 6, 2024
138 changes: 79 additions & 59 deletions tests/test_requests.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
import sys
from typing import List, Optional
from typing import Any, Callable, Dict, Iterator, List, Optional

import anyio
import pytest

from starlette.datastructures import Address, State
from starlette.requests import ClientDisconnect, Request
from starlette.responses import JSONResponse, PlainTextResponse, Response
from starlette.types import Message, Scope
from starlette.testclient import TestClient
from starlette.types import Message, Receive, Scope, Send

TestClientFactory = Callable[..., TestClient]

def test_request_url(test_client_factory):
async def app(scope, receive, send):

def test_request_url(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
data = {"method": request.method, "url": str(request.url)}
response = JSONResponse(data)
Expand All @@ -25,8 +28,8 @@ async def app(scope, receive, send):
assert response.json() == {"method": "GET", "url": "https://example.org:123/"}


def test_request_query_params(test_client_factory):
async def app(scope, receive, send):
def test_request_query_params(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
params = dict(request.query_params)
response = JSONResponse({"params": params})
Expand All @@ -41,8 +44,8 @@ async def app(scope, receive, send):
any(module in sys.modules for module in ("brotli", "brotlicffi")),
reason='urllib3 includes "br" to the "accept-encoding" headers.',
)
def test_request_headers(test_client_factory):
async def app(scope, receive, send):
def test_request_headers(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
headers = dict(request.headers)
response = JSONResponse({"headers": headers})
Expand All @@ -69,14 +72,14 @@ async def app(scope, receive, send):
({}, None),
],
)
def test_request_client(scope: Scope, expected_client: Optional[Address]):
def test_request_client(scope: Scope, expected_client: Optional[Address]) -> None:
scope.update({"type": "http"}) # required by Request's constructor
client = Request(scope).client
assert client == expected_client


def test_request_body(test_client_factory):
async def app(scope, receive, send):
def test_request_body(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
body = await request.body()
response = JSONResponse({"body": body.decode()})
Expand All @@ -90,12 +93,12 @@ async def app(scope, receive, send):
response = client.post("/", json={"a": "123"})
assert response.json() == {"body": '{"a": "123"}'}

response = client.post("/", data="abc")
response = client.post("/", data="abc") # type: ignore
assert response.json() == {"body": "abc"}


def test_request_stream(test_client_factory):
async def app(scope, receive, send):
def test_request_stream(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
body = b""
async for chunk in request.stream():
Expand All @@ -111,12 +114,12 @@ async def app(scope, receive, send):
response = client.post("/", json={"a": "123"})
assert response.json() == {"body": '{"a": "123"}'}

response = client.post("/", data="abc")
response = client.post("/", data="abc") # type: ignore
assert response.json() == {"body": "abc"}


def test_request_form_urlencoded(test_client_factory):
async def app(scope, receive, send):
def test_request_form_urlencoded(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
form = await request.form()
response = JSONResponse({"form": dict(form)})
Expand All @@ -128,8 +131,8 @@ async def app(scope, receive, send):
assert response.json() == {"form": {"abc": "123 @"}}


def test_request_form_context_manager(test_client_factory):
async def app(scope, receive, send):
def test_request_form_context_manager(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
async with request.form() as form:
response = JSONResponse({"form": dict(form)})
Expand All @@ -141,8 +144,8 @@ async def app(scope, receive, send):
assert response.json() == {"form": {"abc": "123 @"}}


def test_request_body_then_stream(test_client_factory):
async def app(scope, receive, send):
def test_request_body_then_stream(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
body = await request.body()
chunks = b""
Expand All @@ -153,12 +156,12 @@ async def app(scope, receive, send):

client = test_client_factory(app)

response = client.post("/", data="abc")
response = client.post("/", data="abc") # type: ignore
assert response.json() == {"body": "abc", "stream": "abc"}


def test_request_stream_then_body(test_client_factory):
async def app(scope, receive, send):
def test_request_stream_then_body(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
chunks = b""
async for chunk in request.stream():
Expand All @@ -172,12 +175,12 @@ async def app(scope, receive, send):

client = test_client_factory(app)

response = client.post("/", data="abc")
response = client.post("/", data="abc") # type: ignore
assert response.json() == {"body": "<stream consumed>", "stream": "abc"}


def test_request_json(test_client_factory):
async def app(scope, receive, send):
def test_request_json(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
data = await request.json()
response = JSONResponse({"json": data})
Expand All @@ -188,7 +191,7 @@ async def app(scope, receive, send):
assert response.json() == {"json": {"a": "123"}}


def test_request_scope_interface():
def test_request_scope_interface() -> None:
"""
A Request can be instantiated with a scope, and presents a `Mapping`
interface.
Expand All @@ -199,8 +202,8 @@ def test_request_scope_interface():
assert len(request) == 3


def test_request_raw_path(test_client_factory):
async def app(scope, receive, send):
def test_request_raw_path(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
path = request.scope["path"]
raw_path = request.scope["raw_path"]
Expand All @@ -212,13 +215,15 @@ async def app(scope, receive, send):
assert response.text == "/he/llo, b'/he%2Fllo'"


def test_request_without_setting_receive(test_client_factory):
def test_request_without_setting_receive(
test_client_factory: TestClientFactory,
) -> None:
"""
If Request is instantiated without the receive channel, then .body()
is not available.
"""

async def app(scope, receive, send):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope)
try:
data = await request.json()
Expand All @@ -232,23 +237,26 @@ async def app(scope, receive, send):
assert response.json() == {"json": "Receive channel not available"}


def test_request_disconnect(anyio_backend_name, anyio_backend_options):
def test_request_disconnect(
anyio_backend_name: str,
anyio_backend_options: Dict[str, Any],
) -> None:
"""
If a client disconnect occurs while reading request body
then ClientDisconnect should be raised.
"""

async def app(scope, receive, send):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
await request.body()

async def receiver():
async def receiver() -> Message:
return {"type": "http.disconnect"}

scope = {"type": "http", "method": "POST", "path": "/"}
with pytest.raises(ClientDisconnect):
anyio.run(
app,
app, # type: ignore
scope,
receiver,
None,
Expand All @@ -257,14 +265,14 @@ async def receiver():
)


def test_request_is_disconnected(test_client_factory):
def test_request_is_disconnected(test_client_factory: TestClientFactory) -> None:
"""
If a client disconnect occurs while reading request body
then ClientDisconnect should be raised.
"""
disconnected_after_response = None

async def app(scope, receive, send):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
nonlocal disconnected_after_response

request = Request(scope, receive)
Expand All @@ -280,7 +288,7 @@ async def app(scope, receive, send):
assert disconnected_after_response


def test_request_state_object():
def test_request_state_object() -> None:
scope = {"state": {"old": "foo"}}

s = State(scope["state"])
Expand All @@ -294,8 +302,8 @@ def test_request_state_object():
s.new


def test_request_state(test_client_factory):
async def app(scope, receive, send):
def test_request_state(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
request.state.example = 123
response = JSONResponse({"state.example": request.state.example})
Expand All @@ -306,8 +314,8 @@ async def app(scope, receive, send):
assert response.json() == {"state.example": 123}


def test_request_cookies(test_client_factory):
async def app(scope, receive, send):
def test_request_cookies(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
mycookie = request.cookies.get("mycookie")
if mycookie:
Expand All @@ -325,7 +333,7 @@ async def app(scope, receive, send):
assert response.text == "Hello, cookies!"


def test_cookie_lenient_parsing(test_client_factory):
def test_cookie_lenient_parsing(test_client_factory: TestClientFactory) -> None:
"""
The following test is based on a cookie set by Okta, a well-known authorization
service. It turns out that it's common practice to set cookies that would be
Expand All @@ -347,7 +355,7 @@ def test_cookie_lenient_parsing(test_client_factory):
"sessionCookie",
}

async def app(scope, receive, send):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
response = JSONResponse({"cookies": request.cookies})
await response(scope, receive, send)
Expand Down Expand Up @@ -381,8 +389,12 @@ async def app(scope, receive, send):
("a=b; h=i; a=c", {"a": "c", "h": "i"}),
],
)
def test_cookies_edge_cases(set_cookie, expected, test_client_factory):
async def app(scope, receive, send):
def test_cookies_edge_cases(
set_cookie: str,
expected: Dict[str, str],
test_client_factory: TestClientFactory,
) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
response = JSONResponse({"cookies": request.cookies})
await response(scope, receive, send)
Expand Down Expand Up @@ -416,13 +428,17 @@ async def app(scope, receive, send):
# (" = b ; ; = ; c = ; ", {"": "b", "c": ""}),
],
)
def test_cookies_invalid(set_cookie, expected, test_client_factory):
def test_cookies_invalid(
set_cookie: str,
expected: Dict[str, str],
test_client_factory: TestClientFactory,
) -> None:
"""
Cookie strings that are against the RFC6265 spec but which browsers will send if set
via document.cookie.
"""

async def app(scope, receive, send):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
response = JSONResponse({"cookies": request.cookies})
await response(scope, receive, send)
Expand All @@ -433,25 +449,25 @@ async def app(scope, receive, send):
assert result["cookies"] == expected


def test_chunked_encoding(test_client_factory):
async def app(scope, receive, send):
def test_chunked_encoding(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
body = await request.body()
response = JSONResponse({"body": body.decode()})
await response(scope, receive, send)

client = test_client_factory(app)

def post_body():
def post_body() -> Iterator[bytes]:
yield b"foo"
yield b"bar"

response = client.post("/", data=post_body())
response = client.post("/", data=post_body()) # type: ignore
assert response.json() == {"body": "foobar"}


def test_request_send_push_promise(test_client_factory):
async def app(scope, receive, send):
def test_request_send_push_promise(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
# the server is push-enabled
scope["extensions"]["http.response.push"] = {}

Expand All @@ -466,13 +482,15 @@ async def app(scope, receive, send):
assert response.json() == {"json": "OK"}


def test_request_send_push_promise_without_push_extension(test_client_factory):
def test_request_send_push_promise_without_push_extension(
test_client_factory: TestClientFactory,
) -> None:
"""
If server does not support the `http.response.push` extension,
.send_push_promise() does nothing.
"""

async def app(scope, receive, send):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope)
await request.send_push_promise("/style.css")

Expand All @@ -484,13 +502,15 @@ async def app(scope, receive, send):
assert response.json() == {"json": "OK"}


def test_request_send_push_promise_without_setting_send(test_client_factory):
def test_request_send_push_promise_without_setting_send(
test_client_factory: TestClientFactory,
) -> None:
"""
If Request is instantiated without the send channel, then
.send_push_promise() is not available.
"""

async def app(scope, receive, send):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
# the server is push-enabled
scope["extensions"]["http.response.push"] = {}

Expand Down