diff --git a/tests/protocols/test_http.py b/tests/protocols/test_http.py index 8a00f38dc..bc794c52b 100644 --- a/tests/protocols/test_http.py +++ b/tests/protocols/test_http.py @@ -16,6 +16,7 @@ from uvicorn.lifespan.on import LifespanOn from uvicorn.main import ServerState from uvicorn.protocols.http.h11_impl import H11Protocol +from uvicorn.protocols.utils import ClientDisconnected try: from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol @@ -369,9 +370,7 @@ async def test_close(http_protocol_cls: HTTPProtocol): @pytest.mark.anyio -async def test_chunked_encoding( - http_protocol_cls: HTTPProtocol, -): +async def test_chunked_encoding(http_protocol_cls: HTTPProtocol): app = Response( b"Hello, world!", status_code=200, headers={"transfer-encoding": "chunked"} ) @@ -385,9 +384,7 @@ async def test_chunked_encoding( @pytest.mark.anyio -async def test_chunked_encoding_empty_body( - http_protocol_cls: HTTPProtocol, -): +async def test_chunked_encoding_empty_body(http_protocol_cls: HTTPProtocol): app = Response( b"Hello, world!", status_code=200, headers={"transfer-encoding": "chunked"} ) @@ -416,9 +413,7 @@ async def test_chunked_encoding_head_request( @pytest.mark.anyio -async def test_pipelined_requests( - http_protocol_cls: HTTPProtocol, -): +async def test_pipelined_requests(http_protocol_cls: HTTPProtocol): app = Response("Hello, world", media_type="text/plain") protocol = get_connected_protocol(app, http_protocol_cls) @@ -440,9 +435,7 @@ async def test_pipelined_requests( @pytest.mark.anyio -async def test_undersized_request( - http_protocol_cls: HTTPProtocol, -): +async def test_undersized_request(http_protocol_cls: HTTPProtocol): app = Response(b"xxx", headers={"content-length": "10"}) protocol = get_connected_protocol(app, http_protocol_cls) @@ -452,9 +445,7 @@ async def test_undersized_request( @pytest.mark.anyio -async def test_oversized_request( - http_protocol_cls: HTTPProtocol, -): +async def test_oversized_request(http_protocol_cls: HTTPProtocol): app = Response(b"xxx" * 20, headers={"content-length": "10"}) protocol = get_connected_protocol(app, http_protocol_cls) @@ -464,9 +455,7 @@ async def test_oversized_request( @pytest.mark.anyio -async def test_large_post_request( - http_protocol_cls: HTTPProtocol, -): +async def test_large_post_request(http_protocol_cls: HTTPProtocol): app = Response("Hello, world", media_type="text/plain") protocol = get_connected_protocol(app, http_protocol_cls) @@ -486,9 +475,7 @@ async def test_invalid_http(http_protocol_cls: HTTPProtocol): @pytest.mark.anyio -async def test_app_exception( - http_protocol_cls: HTTPProtocol, -): +async def test_app_exception(http_protocol_cls: HTTPProtocol): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): raise Exception() @@ -500,9 +487,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable @pytest.mark.anyio -async def test_exception_during_response( - http_protocol_cls: HTTPProtocol, -): +async def test_exception_during_response(http_protocol_cls: HTTPProtocol): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): await send({"type": "http.response.start", "status": 200}) await send({"type": "http.response.body", "body": b"1", "more_body": True}) @@ -516,9 +501,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable @pytest.mark.anyio -async def test_no_response_returned( - http_protocol_cls: HTTPProtocol, -): +async def test_no_response_returned(http_protocol_cls: HTTPProtocol): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): ... @@ -530,9 +513,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable @pytest.mark.anyio -async def test_partial_response_returned( - http_protocol_cls: HTTPProtocol, -): +async def test_partial_response_returned(http_protocol_cls: HTTPProtocol): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): await send({"type": "http.response.start", "status": 200}) @@ -544,9 +525,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable @pytest.mark.anyio -async def test_duplicate_start_message( - http_protocol_cls: HTTPProtocol, -): +async def test_duplicate_start_message(http_protocol_cls: HTTPProtocol): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): await send({"type": "http.response.start", "status": 200}) await send({"type": "http.response.start", "status": 200}) @@ -559,9 +538,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable @pytest.mark.anyio -async def test_missing_start_message( - http_protocol_cls: HTTPProtocol, -): +async def test_missing_start_message(http_protocol_cls: HTTPProtocol): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): await send({"type": "http.response.body", "body": b""}) @@ -573,9 +550,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable @pytest.mark.anyio -async def test_message_after_body_complete( - http_protocol_cls: HTTPProtocol, -): +async def test_message_after_body_complete(http_protocol_cls: HTTPProtocol): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): await send({"type": "http.response.start", "status": 200}) await send({"type": "http.response.body", "body": b""}) @@ -589,9 +564,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable @pytest.mark.anyio -async def test_value_returned( - http_protocol_cls: HTTPProtocol, -): +async def test_value_returned(http_protocol_cls: HTTPProtocol): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): await send({"type": "http.response.start", "status": 200}) await send({"type": "http.response.body", "body": b""}) @@ -605,9 +578,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable @pytest.mark.anyio -async def test_early_disconnect( - http_protocol_cls: HTTPProtocol, -): +async def test_early_disconnect(http_protocol_cls: HTTPProtocol): got_disconnect_event = False async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): @@ -629,9 +600,26 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable @pytest.mark.anyio -async def test_early_response( - http_protocol_cls: HTTPProtocol, -): +async def test_disconnect_on_send(http_protocol_cls: HTTPProtocol) -> None: + got_disconnected = False + + async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): + try: + await send({"type": "http.response.start", "status": 200}) + except ClientDisconnected: + nonlocal got_disconnected + got_disconnected = True + + protocol = get_connected_protocol(app, http_protocol_cls) + protocol.data_received(SIMPLE_GET_REQUEST) + protocol.eof_received() + protocol.connection_lost(None) + await protocol.loop.run_one() + assert got_disconnected + + +@pytest.mark.anyio +async def test_early_response(http_protocol_cls: HTTPProtocol): app = Response("Hello, world", media_type="text/plain") protocol = get_connected_protocol(app, http_protocol_cls) @@ -643,9 +631,7 @@ async def test_early_response( @pytest.mark.anyio -async def test_read_after_response( - http_protocol_cls: HTTPProtocol, -): +async def test_read_after_response(http_protocol_cls: HTTPProtocol): message_after_response = None async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): @@ -663,9 +649,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable @pytest.mark.anyio -async def test_http10_request( - http_protocol_cls: HTTPProtocol, -): +async def test_http10_request(http_protocol_cls: HTTPProtocol): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): assert scope["type"] == "http" content = "Version: %s" % scope["http_version"] @@ -876,8 +860,8 @@ async def asgi(receive: ASGIReceiveCallable, send: ASGISendCallable): @pytest.mark.parametrize( "asgi2or3_app, expected_scopes", [ - (asgi3app, {"version": "3.0", "spec_version": "2.3"}), - (asgi2app, {"version": "2.0", "spec_version": "2.3"}), + (asgi3app, {"version": "3.0", "spec_version": "2.4"}), + (asgi2app, {"version": "2.0", "spec_version": "2.4"}), ], ) async def test_scopes( diff --git a/uvicorn/protocols/http/h11_impl.py b/uvicorn/protocols/http/h11_impl.py index 4922d1781..90bfaeadf 100644 --- a/uvicorn/protocols/http/h11_impl.py +++ b/uvicorn/protocols/http/h11_impl.py @@ -27,6 +27,7 @@ service_unavailable, ) from uvicorn.protocols.utils import ( + ClientDisconnected, get_client_addr, get_local_addr, get_path_with_query_string, @@ -205,7 +206,7 @@ def handle_events(self) -> None: "type": "http", "asgi": { "version": self.config.asgi_version, - "spec_version": "2.3", + "spec_version": "2.4", }, "http_version": event.http_version.decode("ascii"), "server": self.server, @@ -412,6 +413,8 @@ async def run_asgi(self, app: "ASGI3Application") -> None: result = await app( # type: ignore[func-returns-value] self.scope, self.receive, self.send ) + except ClientDisconnected: + pass except BaseException as exc: msg = "Exception in ASGI application\n" self.logger.error(msg, exc_info=exc) @@ -436,7 +439,7 @@ async def run_asgi(self, app: "ASGI3Application") -> None: self.on_response = lambda: None async def send_500_response(self) -> None: - response_start_event: "HTTPResponseStartEvent" = { + response_start_event: HTTPResponseStartEvent = { "type": "http.response.start", "status": 500, "headers": [ @@ -445,7 +448,7 @@ async def send_500_response(self) -> None: ], } await self.send(response_start_event) - response_body_event: "HTTPResponseBodyEvent" = { + response_body_event: HTTPResponseBodyEvent = { "type": "http.response.body", "body": b"Internal Server Error", "more_body": False, @@ -453,14 +456,14 @@ async def send_500_response(self) -> None: await self.send(response_body_event) # ASGI interface - async def send(self, message: "ASGISendEvent") -> None: + async def send(self, message: ASGISendEvent) -> None: message_type = message["type"] if self.flow.write_paused and not self.disconnected: await self.flow.drain() if self.disconnected: - return + raise ClientDisconnected if not self.response_started: # Sending response status line and headers @@ -527,7 +530,7 @@ async def send(self, message: "ASGISendEvent") -> None: self.transport.close() self.on_response() - async def receive(self) -> "ASGIReceiveEvent": + async def receive(self) -> ASGIReceiveEvent: if self.waiting_for_100_continue and not self.transport.is_closing(): headers: list[tuple[str, str]] = [] event = h11.InformationalResponse( @@ -545,7 +548,7 @@ async def receive(self) -> "ASGIReceiveEvent": if self.disconnected or self.response_complete: return {"type": "http.disconnect"} - message: "HTTPRequestEvent" = { + message: HTTPRequestEvent = { "type": "http.request", "body": self.body, "more_body": self.more_body, diff --git a/uvicorn/protocols/http/httptools_impl.py b/uvicorn/protocols/http/httptools_impl.py index e203745b1..78e38154d 100644 --- a/uvicorn/protocols/http/httptools_impl.py +++ b/uvicorn/protocols/http/httptools_impl.py @@ -15,7 +15,6 @@ ASGI3Application, ASGIReceiveEvent, ASGISendEvent, - HTTPDisconnectEvent, HTTPRequestEvent, HTTPResponseBodyEvent, HTTPResponseStartEvent, @@ -30,6 +29,7 @@ service_unavailable, ) from uvicorn.protocols.utils import ( + ClientDisconnected, get_client_addr, get_local_addr, get_path_with_query_string, @@ -227,7 +227,7 @@ def on_message_begin(self) -> None: self.headers = [] self.scope = { # type: ignore[typeddict-item] "type": "http", - "asgi": {"version": self.config.asgi_version, "spec_version": "2.3"}, + "asgi": {"version": self.config.asgi_version, "spec_version": "2.4"}, "http_version": "1.1", "server": self.server, "client": self.client, @@ -414,11 +414,13 @@ def __init__( self.expected_content_length = 0 # ASGI exception wrapper - async def run_asgi(self, app: "ASGI3Application") -> None: + async def run_asgi(self, app: ASGI3Application) -> None: try: result = await app( # type: ignore[func-returns-value] self.scope, self.receive, self.send ) + except ClientDisconnected: + pass except BaseException as exc: msg = "Exception in ASGI application\n" self.logger.error(msg, exc_info=exc) @@ -443,7 +445,7 @@ async def run_asgi(self, app: "ASGI3Application") -> None: self.on_response = lambda: None async def send_500_response(self) -> None: - response_start_event: "HTTPResponseStartEvent" = { + response_start_event: HTTPResponseStartEvent = { "type": "http.response.start", "status": 500, "headers": [ @@ -452,7 +454,7 @@ async def send_500_response(self) -> None: ], } await self.send(response_start_event) - response_body_event: "HTTPResponseBodyEvent" = { + response_body_event: HTTPResponseBodyEvent = { "type": "http.response.body", "body": b"Internal Server Error", "more_body": False, @@ -460,14 +462,14 @@ async def send_500_response(self) -> None: await self.send(response_body_event) # ASGI interface - async def send(self, message: "ASGISendEvent") -> None: + async def send(self, message: ASGISendEvent) -> None: message_type = message["type"] if self.flow.write_paused and not self.disconnected: await self.flow.drain() if self.disconnected: - return + raise ClientDisconnected if not self.response_started: # Sending response status line and headers @@ -570,7 +572,7 @@ async def send(self, message: "ASGISendEvent") -> None: msg = "Unexpected ASGI message '%s' sent, after response already completed." raise RuntimeError(msg % message_type) - async def receive(self) -> "ASGIReceiveEvent": + async def receive(self) -> ASGIReceiveEvent: if self.waiting_for_100_continue and not self.transport.is_closing(): self.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n") self.waiting_for_100_continue = False @@ -580,15 +582,13 @@ async def receive(self) -> "ASGIReceiveEvent": await self.message_event.wait() self.message_event.clear() - message: HTTPDisconnectEvent | HTTPRequestEvent if self.disconnected or self.response_complete: - message = {"type": "http.disconnect"} - else: - message = { - "type": "http.request", - "body": self.body, - "more_body": self.more_body, - } - self.body = b"" + return {"type": "http.disconnect"} + message: HTTPRequestEvent = { + "type": "http.request", + "body": self.body, + "more_body": self.more_body, + } + self.body = b"" return message