From d79d86eee1ad8cdcc9668de3237b0c6203257992 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Sat, 14 Dec 2024 13:44:22 +0100 Subject: [PATCH 1/6] Add WebSocketsSansIOProtocol --- docs/deployment.md | 2 +- docs/index.md | 2 +- pyproject.toml | 3 + requirements.txt | 2 +- tests/conftest.py | 4 +- tests/middleware/test_logging.py | 11 +- tests/middleware/test_proxy_headers.py | 5 +- uvicorn/config.py | 3 +- .../websockets/websockets_sansio_impl.py | 386 ++++++++++++++++++ uvicorn/server.py | 3 +- 10 files changed, 408 insertions(+), 13 deletions(-) create mode 100644 uvicorn/protocols/websockets/websockets_sansio_impl.py diff --git a/docs/deployment.md b/docs/deployment.md index d69fcf88e..99dfbf33e 100644 --- a/docs/deployment.md +++ b/docs/deployment.md @@ -60,7 +60,7 @@ Options: --loop [auto|asyncio|uvloop] Event loop implementation. [default: auto] --http [auto|h11|httptools] HTTP protocol implementation. [default: auto] - --ws [auto|none|websockets|wsproto] + --ws [auto|none|websockets|websockets-sansio|wsproto] WebSocket protocol implementation. [default: auto] --ws-max-size INTEGER WebSocket max size message in bytes diff --git a/docs/index.md b/docs/index.md index bb6fc321a..50e2ab967 100644 --- a/docs/index.md +++ b/docs/index.md @@ -130,7 +130,7 @@ Options: --loop [auto|asyncio|uvloop] Event loop implementation. [default: auto] --http [auto|h11|httptools] HTTP protocol implementation. [default: auto] - --ws [auto|none|websockets|wsproto] + --ws [auto|none|websockets|websockets-sansio|wsproto] WebSocket protocol implementation. [default: auto] --ws-max-size INTEGER WebSocket max size message in bytes diff --git a/pyproject.toml b/pyproject.toml index 6f809030e..3e30b658c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,6 +93,9 @@ filterwarnings = [ "ignore:Uvicorn's native WSGI implementation is deprecated.*:DeprecationWarning", "ignore: 'cgi' is deprecated and slated for removal in Python 3.13:DeprecationWarning", "ignore: remove second argument of ws_handler:DeprecationWarning:websockets", + "ignore: websockets.legacy is deprecated.*:DeprecationWarning", + "ignore: websockets.server.WebSocketServerProtocol is deprecated.*:DeprecationWarning", + "ignore: websockets.client.connect is deprecated.*:DeprecationWarning", ] [tool.coverage.run] diff --git a/requirements.txt b/requirements.txt index b3a464c0b..fd2334d02 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,7 @@ h11 @ git+https://github.com/python-hyper/h11.git@master # Explicit optionals a2wsgi==1.10.7 wsproto==1.2.0 -websockets==13.1 +websockets==14.1 # Packaging build==1.2.2.post1 diff --git a/tests/conftest.py b/tests/conftest.py index 1b0c0e84e..84bda4dc2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -233,9 +233,9 @@ def unused_tcp_port() -> int: marks=pytest.mark.skipif(not importlib.util.find_spec("wsproto"), reason="wsproto not installed."), id="wsproto", ), + pytest.param("uvicorn.protocols.websockets.websockets_impl:WebSocketProtocol", id="websockets"), pytest.param( - "uvicorn.protocols.websockets.websockets_impl:WebSocketProtocol", - id="websockets", + "uvicorn.protocols.websockets.websockets_sansio_impl:WebSocketSansIOProtocol", id="websockets-sansio" ), ] ) diff --git a/tests/middleware/test_logging.py b/tests/middleware/test_logging.py index f27633aa5..c8126f9e6 100644 --- a/tests/middleware/test_logging.py +++ b/tests/middleware/test_logging.py @@ -49,7 +49,9 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable await send({"type": "http.response.body", "body": b"", "more_body": False}) -async def test_trace_logging(caplog: pytest.LogCaptureFixture, logging_config, unused_tcp_port: int): +async def test_trace_logging( + caplog: pytest.LogCaptureFixture, logging_config: dict[str, typing.Any], unused_tcp_port: int +): config = Config( app=app, log_level="trace", @@ -89,10 +91,11 @@ async def test_trace_logging_on_http_protocol(http_protocol_cls, caplog, logging assert any(" - HTTP connection lost" in message for message in messages) +@pytest.mark.skip() async def test_trace_logging_on_ws_protocol( ws_protocol_cls: WSProtocol, - caplog, - logging_config, + caplog: pytest.LogCaptureFixture, + logging_config: dict[str, typing.Any], unused_tcp_port: int, ): async def websocket_app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): @@ -104,7 +107,7 @@ async def websocket_app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISe elif message["type"] == "websocket.disconnect": break - async def open_connection(url): + async def open_connection(url: str): async with websockets.client.connect(url) as websocket: return websocket.open diff --git a/tests/middleware/test_proxy_headers.py b/tests/middleware/test_proxy_headers.py index 0ade97450..4b5f195f6 100644 --- a/tests/middleware/test_proxy_headers.py +++ b/tests/middleware/test_proxy_headers.py @@ -5,7 +5,7 @@ import httpx import httpx._transports.asgi import pytest -import websockets.client +from websockets.asyncio.client import connect from tests.response import Response from tests.utils import run_server @@ -465,6 +465,7 @@ async def websocket_app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISe host, port = scope["client"] await send({"type": "websocket.accept"}) await send({"type": "websocket.send", "text": f"{scheme}://{host}:{port}"}) + await send({"type": "websocket.close"}) app_with_middleware = ProxyHeadersMiddleware(websocket_app, trusted_hosts="*") config = Config( @@ -478,7 +479,7 @@ async def websocket_app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISe async with run_server(config): url = f"ws://127.0.0.1:{unused_tcp_port}" headers = {X_FORWARDED_FOR: "1.2.3.4", X_FORWARDED_PROTO: forwarded_proto} - async with websockets.client.connect(url, extra_headers=headers) as websocket: + async with connect(url, additional_headers=headers) as websocket: data = await websocket.recv() assert data == expected diff --git a/uvicorn/config.py b/uvicorn/config.py index b08a8426b..3480b5392 100644 --- a/uvicorn/config.py +++ b/uvicorn/config.py @@ -24,7 +24,7 @@ from uvicorn.middleware.wsgi import WSGIMiddleware HTTPProtocolType = Literal["auto", "h11", "httptools"] -WSProtocolType = Literal["auto", "none", "websockets", "wsproto"] +WSProtocolType = Literal["auto", "none", "websockets", "websockets-sansio", "wsproto"] LifespanType = Literal["auto", "on", "off"] LoopSetupType = Literal["none", "auto", "asyncio", "uvloop"] InterfaceType = Literal["auto", "asgi3", "asgi2", "wsgi"] @@ -46,6 +46,7 @@ "auto": "uvicorn.protocols.websockets.auto:AutoWebSocketsProtocol", "none": None, "websockets": "uvicorn.protocols.websockets.websockets_impl:WebSocketProtocol", + "websockets-sansio": "uvicorn.protocols.websockets.websockets_sansio_impl:WebSocketSansIOProtocol", "wsproto": "uvicorn.protocols.websockets.wsproto_impl:WSProtocol", } LIFESPAN: dict[LifespanType, str] = { diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py new file mode 100644 index 000000000..49e8a71a1 --- /dev/null +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -0,0 +1,386 @@ +from __future__ import annotations + +import asyncio +import logging +from asyncio.transports import BaseTransport, Transport +from http import HTTPStatus +from typing import Any, Literal, cast +from urllib.parse import unquote + +from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory +from websockets.frames import Frame, Opcode +from websockets.http11 import Request +from websockets.server import ServerProtocol + +from uvicorn._types import ( + ASGIReceiveEvent, + ASGISendEvent, + WebSocketAcceptEvent, + WebSocketCloseEvent, + WebSocketDisconnectEvent, + WebSocketReceiveEvent, + WebSocketResponseBodyEvent, + WebSocketResponseStartEvent, + WebSocketScope, + WebSocketSendEvent, +) +from uvicorn.config import Config +from uvicorn.logging import TRACE_LOG_LEVEL +from uvicorn.protocols.utils import get_local_addr, get_path_with_query_string, get_remote_addr, is_ssl +from uvicorn.server import ServerState + + +class WebSocketSansIOProtocol(asyncio.Protocol): + def __init__( + self, + config: Config, + server_state: ServerState, + app_state: dict[str, Any], + _loop: asyncio.AbstractEventLoop | None = None, + ) -> None: + if not config.loaded: + config.load() # pragma: no cover + + self.config = config + self.app = config.loaded_app + self.loop = _loop or asyncio.get_event_loop() + self.logger = logging.getLogger("uvicorn.error") + self.root_path = config.root_path + self.app_state = app_state + + # Shared server state + self.connections = server_state.connections + self.tasks = server_state.tasks + self.default_headers = server_state.default_headers + + # Connection state + self.transport: asyncio.Transport = None # type: ignore[assignment] + self.server: tuple[str, int] | None = None + self.client: tuple[str, int] | None = None + self.scheme: Literal["wss", "ws"] = None # type: ignore[assignment] + + # WebSocket state + self.queue: asyncio.Queue[ASGIReceiveEvent] = asyncio.Queue() + self.handshake_initiated = False + self.handshake_complete = False + self.close_sent = False + self.initial_response: tuple[int, list[tuple[str, str]], bytes] | None = None + + extensions = [] + if self.config.ws_per_message_deflate: + extensions = [ServerPerMessageDeflateFactory()] + self.conn = ServerProtocol( + extensions=extensions, + max_size=self.config.ws_max_size, + logger=logging.getLogger("uvicorn.error"), + ) + + self.read_paused = False + self.writable = asyncio.Event() + self.writable.set() + + # Buffers + self.bytes = b"" + + def connection_made(self, transport: BaseTransport) -> None: + """Called when a connection is made.""" + transport = cast(Transport, transport) + self.connections.add(self) + self.transport = transport + self.server = get_local_addr(transport) + self.client = get_remote_addr(transport) + self.scheme = "wss" if is_ssl(transport) else "ws" + + if self.logger.level <= TRACE_LOG_LEVEL: + prefix = "%s:%d - " % self.client if self.client else "" + self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix) + + def connection_lost(self, exc: Exception | None) -> None: + self.connections.remove(self) + if self.logger.level <= TRACE_LOG_LEVEL: + prefix = "%s:%d - " % self.client if self.client else "" + self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix) + if self.handshake_initiated and not self.close_sent: + self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006}) + + def shutdown(self) -> None: + if not self.transport.is_closing(): + if self.handshake_complete: + self.queue.put_nowait({"type": "websocket.disconnect", "code": 1012}) + self.close_sent = True + self.conn.send_close(1012) + output = self.conn.data_to_send() + self.transport.writelines(output) + elif self.handshake_initiated: + self.send_500_response() + self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006}) + self.transport.close() + + def data_received(self, data: bytes) -> None: + self.conn.receive_data(data) + parser_exc = self.conn.parser_exc + if parser_exc is not None: + self.handle_parser_exception() + return + self.handle_events() + + def handle_events(self) -> None: + for event in self.conn.events_received(): + if isinstance(event, Request): + self.handle_connect(event) + if isinstance(event, Frame): + if event.opcode == Opcode.CONT: + self.handle_cont(event) + elif event.opcode == Opcode.TEXT: + self.handle_text(event) + elif event.opcode == Opcode.BINARY: + self.handle_bytes(event) + elif event.opcode == Opcode.PING: + self.handle_ping(event) + elif event.opcode == Opcode.CLOSE: + self.handle_close(event) + + # Event handlers + + def handle_connect(self, event: Request) -> None: + self.request = event + self.response = self.conn.accept(event) + self.handshake_initiated = True + # if status_code is not 101 return response + if self.response.status_code != 101: + self.handshake_complete = True + self.close_sent = True + self.conn.send_response(self.response) + output = self.conn.data_to_send() + self.transport.writelines(output) + self.transport.close() + return + + headers = [ + (key.encode("ascii"), value.encode("ascii", errors="surrogateescape")) + for key, value in event.headers.raw_items() + ] + raw_path, _, query_string = event.path.partition("?") + self.scope: WebSocketScope = { + "type": "websocket", + "asgi": {"version": self.config.asgi_version, "spec_version": "2.3"}, + "http_version": "1.1", + "scheme": self.scheme, + "server": self.server, + "client": self.client, + "root_path": self.root_path, + "path": unquote(raw_path), + "raw_path": raw_path.encode("ascii"), + "query_string": query_string.encode("ascii"), + "headers": headers, + "subprotocols": event.headers.get_all("Sec-WebSocket-Protocol"), + "state": self.app_state.copy(), + "extensions": {"websocket.http.response": {}}, + } + self.queue.put_nowait({"type": "websocket.connect"}) + task = self.loop.create_task(self.run_asgi()) + task.add_done_callback(self.on_task_complete) + self.tasks.add(task) + + def handle_cont(self, event: Frame) -> None: + self.bytes += event.data + if event.fin: + self.send_receive_event_to_app() + + def handle_text(self, event: Frame) -> None: + self.bytes = event.data + self.curr_msg_data_type: Literal["text", "bytes"] = "text" + if event.fin: + self.send_receive_event_to_app() + + def handle_bytes(self, event: Frame) -> None: + self.bytes = event.data + self.curr_msg_data_type = "bytes" + if event.fin: + self.send_receive_event_to_app() + + def send_receive_event_to_app(self) -> None: + data_type = self.curr_msg_data_type + msg: WebSocketReceiveEvent + if data_type == "text": + msg = {"type": "websocket.receive", data_type: self.bytes.decode()} + else: + msg = {"type": "websocket.receive", data_type: self.bytes} + self.queue.put_nowait(msg) + if not self.read_paused: + self.read_paused = True + self.transport.pause_reading() + + def handle_ping(self, event: Frame) -> None: + output = self.conn.data_to_send() + self.transport.writelines(output) + + def handle_close(self, event: Frame) -> None: + if not self.close_sent and self.conn.close_rcvd and not self.transport.is_closing(): + disconnect_event: WebSocketDisconnectEvent = { + "type": "websocket.disconnect", + "code": self.conn.close_rcvd.code, + "reason": self.conn.close_rcvd.reason, + } + self.queue.put_nowait(disconnect_event) + output = self.conn.data_to_send() + self.transport.writelines(output) + self.close_sent = True + self.transport.close() + + def handle_parser_exception(self) -> None: + disconnect_event: WebSocketDisconnectEvent = { + "type": "websocket.disconnect", + "code": self.conn.close_sent.code if self.conn.close_sent else 1006, + } + self.queue.put_nowait(disconnect_event) + output = self.conn.data_to_send() + self.transport.writelines(output) + self.close_sent = True + self.transport.close() + + def on_task_complete(self, task: asyncio.Task[None]) -> None: + self.tasks.discard(task) + + async def run_asgi(self) -> None: + try: + result = await self.app(self.scope, self.receive, self.send) + except BaseException: + self.logger.exception("Exception in ASGI application\n") + if not self.handshake_complete: + self.send_500_response() + self.transport.close() + else: + if not self.handshake_complete: + msg = "ASGI callable returned without completing handshake." + self.logger.error(msg) + self.send_500_response() + self.transport.close() + elif result is not None: + msg = "ASGI callable should return None, but returned '%s'." + self.logger.error(msg, result) + self.transport.close() + + def send_500_response(self) -> None: + response = self.conn.reject(500, "Internal Server Error") + self.conn.send_response(response) + output = self.conn.data_to_send() + self.transport.writelines(output) + + async def send(self, message: ASGISendEvent) -> None: + await self.writable.wait() + + message_type = message["type"] + + if not self.handshake_complete and self.initial_response is None: + if message_type == "websocket.accept": + message = cast(WebSocketAcceptEvent, message) + self.logger.info( + '%s - "WebSocket %s" [accepted]', + self.scope["client"], + get_path_with_query_string(self.scope), + ) + headers = [ + (name.decode("latin-1").lower(), value.decode("latin-1").lower()) + for name, value in (self.default_headers + list(message.get("headers", []))) + ] + accepted_subprotocol = message.get("subprotocol") + if accepted_subprotocol: + headers.append(("Sec-WebSocket-Protocol", accepted_subprotocol)) + self.response.headers.update(headers) + + if not self.transport.is_closing(): + self.handshake_complete = True + self.conn.send_response(self.response) + output = self.conn.data_to_send() + self.transport.writelines(output) + + elif message_type == "websocket.close": + message = cast(WebSocketCloseEvent, message) + self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006}) + self.logger.info( + '%s - "WebSocket %s" 403', + self.scope["client"], + get_path_with_query_string(self.scope), + ) + response = self.conn.reject(HTTPStatus.FORBIDDEN, "") + self.conn.send_response(response) + output = self.conn.data_to_send() + self.close_sent = True + self.handshake_complete = True + self.transport.writelines(output) + self.transport.close() + elif message_type == "websocket.http.response.start": + message = cast(WebSocketResponseStartEvent, message) + self.logger.info( + '%s - "WebSocket %s" %d', + self.scope["client"], + get_path_with_query_string(self.scope), + message["status"], + ) + headers = [ + (name.decode("latin-1"), value.decode("latin-1")) + for name, value in list(message.get("headers", [])) + ] + self.initial_response = (message["status"], headers, b"") + else: + msg = ( + "Expected ASGI message 'websocket.accept', 'websocket.close' " + "or 'websocket.http.response.start' " + "but got '%s'." + ) + print(message) + raise RuntimeError(msg % message_type) + + elif not self.close_sent and self.initial_response is None: + if message_type == "websocket.send" and not self.transport.is_closing(): + message = cast(WebSocketSendEvent, message) + bytes_data = message.get("bytes") + text_data = message.get("text") + if text_data: + self.conn.send_text(text_data.encode()) + elif bytes_data: + self.conn.send_binary(bytes_data) + output = self.conn.data_to_send() + self.transport.writelines(output) + + elif message_type == "websocket.close" and not self.transport.is_closing(): + message = cast(WebSocketCloseEvent, message) + code = message.get("code", 1000) + reason = message.get("reason", "") or "" + self.queue.put_nowait({"type": "websocket.disconnect", "code": code}) + self.conn.send_close(code, reason) + output = self.conn.data_to_send() + self.transport.writelines(output) + self.close_sent = True + self.transport.close() + else: + msg = "Expected ASGI message 'websocket.send' or 'websocket.close'," " but got '%s'." + raise RuntimeError(msg % message_type) + elif self.initial_response is not None: + if message_type == "websocket.http.response.body": + message = cast(WebSocketResponseBodyEvent, message) + body = self.initial_response[2] + message["body"] + self.initial_response = self.initial_response[:2] + (body,) + if not message.get("more_body", False): + response = self.conn.reject(self.initial_response[0], body.decode()) + response.headers.update(self.initial_response[1]) + self.conn.send_response(response) + output = self.conn.data_to_send() + self.close_sent = True + self.transport.writelines(output) + self.transport.close() + else: + msg = "Expected ASGI message 'websocket.http.response.body' " "but got '%s'." + raise RuntimeError(msg % message_type) + + else: + msg = "Unexpected ASGI message '%s', after sending 'websocket.close'." + raise RuntimeError(msg % message_type) + + async def receive(self) -> ASGIReceiveEvent: + message = await self.queue.get() + if self.read_paused and self.queue.empty(): + self.read_paused = False + self.transport.resume_reading() + return message diff --git a/uvicorn/server.py b/uvicorn/server.py index f14026f16..2250e2dc7 100644 --- a/uvicorn/server.py +++ b/uvicorn/server.py @@ -22,9 +22,10 @@ from uvicorn.protocols.http.h11_impl import H11Protocol from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol + from uvicorn.protocols.websockets.websockets_sansio_impl import WebSocketSansIOProtocol from uvicorn.protocols.websockets.wsproto_impl import WSProtocol - Protocols = Union[H11Protocol, HttpToolsProtocol, WSProtocol, WebSocketProtocol] + Protocols = Union[H11Protocol, HttpToolsProtocol, WSProtocol, WebSocketProtocol, WebSocketSansIOProtocol] HANDLED_SIGNALS = ( signal.SIGINT, # Unix signal 2. Sent by Ctrl+C. From 7ee1e15a850d78754b757849886abb2011dd2e55 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Sat, 14 Dec 2024 17:54:55 +0100 Subject: [PATCH 2/6] Add WebSocketsSansIOProtocol --- tests/conftest.py | 2 +- tests/middleware/test_logging.py | 1 - tests/middleware/test_proxy_headers.py | 4 +- tests/protocols/test_websocket.py | 35 +++--- uvicorn/config.py | 2 +- .../websockets/websockets_sansio_impl.py | 117 +++++++++++------- uvicorn/server.py | 4 +- 7 files changed, 96 insertions(+), 69 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 84bda4dc2..7061a143b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -235,7 +235,7 @@ def unused_tcp_port() -> int: ), pytest.param("uvicorn.protocols.websockets.websockets_impl:WebSocketProtocol", id="websockets"), pytest.param( - "uvicorn.protocols.websockets.websockets_sansio_impl:WebSocketSansIOProtocol", id="websockets-sansio" + "uvicorn.protocols.websockets.websockets_sansio_impl:WebSocketsSansIOProtocol", id="websockets-sansio" ), ] ) diff --git a/tests/middleware/test_logging.py b/tests/middleware/test_logging.py index c8126f9e6..63d7daf83 100644 --- a/tests/middleware/test_logging.py +++ b/tests/middleware/test_logging.py @@ -91,7 +91,6 @@ async def test_trace_logging_on_http_protocol(http_protocol_cls, caplog, logging assert any(" - HTTP connection lost" in message for message in messages) -@pytest.mark.skip() async def test_trace_logging_on_ws_protocol( ws_protocol_cls: WSProtocol, caplog: pytest.LogCaptureFixture, diff --git a/tests/middleware/test_proxy_headers.py b/tests/middleware/test_proxy_headers.py index 4b5f195f6..62a51ab20 100644 --- a/tests/middleware/test_proxy_headers.py +++ b/tests/middleware/test_proxy_headers.py @@ -5,8 +5,8 @@ import httpx import httpx._transports.asgi import pytest -from websockets.asyncio.client import connect +import websockets.client from tests.response import Response from tests.utils import run_server from uvicorn._types import ASGIReceiveCallable, ASGISendCallable, Scope @@ -479,7 +479,7 @@ async def websocket_app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISe async with run_server(config): url = f"ws://127.0.0.1:{unused_tcp_port}" headers = {X_FORWARDED_FOR: "1.2.3.4", X_FORWARDED_PROTO: forwarded_proto} - async with connect(url, additional_headers=headers) as websocket: + async with websockets.client.connect(url, extra_headers=headers) as websocket: data = await websocket.recv() assert data == expected diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index 15ccfdd7d..8971a7d97 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -601,20 +601,20 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable await send_accept_task.wait() disconnect_message = await receive() # type: ignore - response: httpx.Response | None = None - async def websocket_session(uri: str): - nonlocal response async with httpx.AsyncClient() as client: - response = await client.get( - f"http://127.0.0.1:{unused_tcp_port}", - headers={ - "upgrade": "websocket", - "connection": "upgrade", - "sec-websocket-version": "13", - "sec-websocket-key": "dGhlIHNhbXBsZSBub25jZQ==", - }, - ) + try: + await client.get( + f"http://127.0.0.1:{unused_tcp_port}", + headers={ + "upgrade": "websocket", + "connection": "upgrade", + "sec-websocket-version": "13", + "sec-websocket-key": "dGhlIHNhbXBsZSBub25jZQ==", + }, + ) + except httpx.RemoteProtocolError: + pass # pragma: no cover config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port) async with run_server(config): @@ -623,9 +623,6 @@ async def websocket_session(uri: str): send_accept_task.set() await asyncio.sleep(0.1) - assert response is not None - assert response.status_code == 500, response.text - assert response.text == "Internal Server Error" assert disconnect_message == {"type": "websocket.disconnect", "code": 1006} await task @@ -920,6 +917,9 @@ async def websocket_session(url: str): async def test_server_reject_connection_with_invalid_msg( ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int ): + if ws_protocol_cls.__name__ == "WebSocketsSansIOProtocol": + pytest.skip("WebSocketsSansIOProtocol sends both start and body messages in one message.") + async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): assert scope["type"] == "websocket" assert "extensions" in scope and "websocket.http.response" in scope["extensions"] @@ -951,6 +951,9 @@ async def websocket_session(url: str): async def test_server_reject_connection_with_missing_body( ws_protocol_cls: WSProtocol, http_protocol_cls: HTTPProtocol, unused_tcp_port: int ): + if ws_protocol_cls.__name__ == "WebSocketsSansIOProtocol": + pytest.skip("WebSocketsSansIOProtocol sends both start and body messages in one message.") + async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): assert scope["type"] == "websocket" assert "extensions" in scope and "websocket.http.response" in scope["extensions"] @@ -986,6 +989,8 @@ async def test_server_multiple_websocket_http_response_start_events( The server should raise an exception if it sends multiple websocket.http.response.start events. """ + if ws_protocol_cls.__name__ == "WebSocketsSansIOProtocol": + pytest.skip("WebSocketsSansIOProtocol sends both start and body messages in one message.") exception_message: str | None = None async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): diff --git a/uvicorn/config.py b/uvicorn/config.py index 3480b5392..187b94972 100644 --- a/uvicorn/config.py +++ b/uvicorn/config.py @@ -46,7 +46,7 @@ "auto": "uvicorn.protocols.websockets.auto:AutoWebSocketsProtocol", "none": None, "websockets": "uvicorn.protocols.websockets.websockets_impl:WebSocketProtocol", - "websockets-sansio": "uvicorn.protocols.websockets.websockets_sansio_impl:WebSocketSansIOProtocol", + "websockets-sansio": "uvicorn.protocols.websockets.websockets_sansio_impl:WebSocketsSansIOProtocol", "wsproto": "uvicorn.protocols.websockets.wsproto_impl:WSProtocol", } LIFESPAN: dict[LifespanType, str] = { diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index 49e8a71a1..ea70236b2 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -7,6 +7,7 @@ from typing import Any, Literal, cast from urllib.parse import unquote +from websockets import InvalidState from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory from websockets.frames import Frame, Opcode from websockets.http11 import Request @@ -26,11 +27,17 @@ ) from uvicorn.config import Config from uvicorn.logging import TRACE_LOG_LEVEL -from uvicorn.protocols.utils import get_local_addr, get_path_with_query_string, get_remote_addr, is_ssl +from uvicorn.protocols.utils import ( + ClientDisconnected, + get_local_addr, + get_path_with_query_string, + get_remote_addr, + is_ssl, +) from uvicorn.server import ServerState -class WebSocketSansIOProtocol(asyncio.Protocol): +class WebSocketsSansIOProtocol(asyncio.Protocol): def __init__( self, config: Config, @@ -96,12 +103,20 @@ def connection_made(self, transport: BaseTransport) -> None: self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix) def connection_lost(self, exc: Exception | None) -> None: + code = 1005 if self.handshake_complete else 1006 + self.queue.put_nowait({"type": "websocket.disconnect", "code": code}) self.connections.remove(self) + if self.logger.level <= TRACE_LOG_LEVEL: prefix = "%s:%d - " % self.client if self.client else "" self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix) - if self.handshake_initiated and not self.close_sent: - self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006}) + + self.handshake_complete = True + if exc is None: + self.transport.close() + + def eof_received(self) -> None: + pass def shutdown(self) -> None: if not self.transport.is_closing(): @@ -110,8 +125,8 @@ def shutdown(self) -> None: self.close_sent = True self.conn.send_close(1012) output = self.conn.data_to_send() - self.transport.writelines(output) - elif self.handshake_initiated: + self.transport.write(b"".join(output)) + elif not self.handshake_initiated: self.send_500_response() self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006}) self.transport.close() @@ -152,7 +167,7 @@ def handle_connect(self, event: Request) -> None: self.close_sent = True self.conn.send_response(self.response) output = self.conn.data_to_send() - self.transport.writelines(output) + self.transport.write(b"".join(output)) self.transport.close() return @@ -213,29 +228,29 @@ def send_receive_event_to_app(self) -> None: def handle_ping(self, event: Frame) -> None: output = self.conn.data_to_send() - self.transport.writelines(output) + self.transport.write(b"".join(output)) def handle_close(self, event: Frame) -> None: - if not self.close_sent and self.conn.close_rcvd and not self.transport.is_closing(): + if not self.close_sent and not self.transport.is_closing(): disconnect_event: WebSocketDisconnectEvent = { "type": "websocket.disconnect", - "code": self.conn.close_rcvd.code, - "reason": self.conn.close_rcvd.reason, + "code": self.conn.close_rcvd.code, # type: ignore[union-attr] + "reason": self.conn.close_rcvd.reason, # type: ignore[union-attr] } self.queue.put_nowait(disconnect_event) output = self.conn.data_to_send() - self.transport.writelines(output) - self.close_sent = True + self.transport.write(b"".join(output)) self.transport.close() def handle_parser_exception(self) -> None: disconnect_event: WebSocketDisconnectEvent = { "type": "websocket.disconnect", - "code": self.conn.close_sent.code if self.conn.close_sent else 1006, + "code": self.conn.close_sent.code, # type: ignore[union-attr] + "reason": self.conn.close_sent.reason, # type: ignore[union-attr] } self.queue.put_nowait(disconnect_event) output = self.conn.data_to_send() - self.transport.writelines(output) + self.transport.write(b"".join(output)) self.close_sent = True self.transport.close() @@ -245,10 +260,11 @@ def on_task_complete(self, task: asyncio.Task[None]) -> None: async def run_asgi(self) -> None: try: result = await self.app(self.scope, self.receive, self.send) + except ClientDisconnected: + self.transport.close() except BaseException: self.logger.exception("Exception in ASGI application\n") - if not self.handshake_complete: - self.send_500_response() + self.send_500_response() self.transport.close() else: if not self.handshake_complete: @@ -262,10 +278,12 @@ async def run_asgi(self) -> None: self.transport.close() def send_500_response(self) -> None: + if self.initial_response or self.handshake_complete: + return response = self.conn.reject(500, "Internal Server Error") self.conn.send_response(response) output = self.conn.data_to_send() - self.transport.writelines(output) + self.transport.write(b"".join(output)) async def send(self, message: ASGISendEvent) -> None: await self.writable.wait() @@ -293,7 +311,7 @@ async def send(self, message: ASGISendEvent) -> None: self.handshake_complete = True self.conn.send_response(self.response) output = self.conn.data_to_send() - self.transport.writelines(output) + self.transport.write(b"".join(output)) elif message_type == "websocket.close": message = cast(WebSocketCloseEvent, message) @@ -308,10 +326,12 @@ async def send(self, message: ASGISendEvent) -> None: output = self.conn.data_to_send() self.close_sent = True self.handshake_complete = True - self.transport.writelines(output) + self.transport.write(b"".join(output)) self.transport.close() - elif message_type == "websocket.http.response.start": + elif message_type == "websocket.http.response.start" and self.initial_response is None: message = cast(WebSocketResponseStartEvent, message) + if not (100 <= message["status"] < 600): + raise RuntimeError("Invalid HTTP status code '%d' in response." % message["status"]) self.logger.info( '%s - "WebSocket %s" %d', self.scope["client"], @@ -329,34 +349,36 @@ async def send(self, message: ASGISendEvent) -> None: "or 'websocket.http.response.start' " "but got '%s'." ) - print(message) raise RuntimeError(msg % message_type) elif not self.close_sent and self.initial_response is None: - if message_type == "websocket.send" and not self.transport.is_closing(): - message = cast(WebSocketSendEvent, message) - bytes_data = message.get("bytes") - text_data = message.get("text") - if text_data: - self.conn.send_text(text_data.encode()) - elif bytes_data: - self.conn.send_binary(bytes_data) - output = self.conn.data_to_send() - self.transport.writelines(output) - - elif message_type == "websocket.close" and not self.transport.is_closing(): - message = cast(WebSocketCloseEvent, message) - code = message.get("code", 1000) - reason = message.get("reason", "") or "" - self.queue.put_nowait({"type": "websocket.disconnect", "code": code}) - self.conn.send_close(code, reason) - output = self.conn.data_to_send() - self.transport.writelines(output) - self.close_sent = True - self.transport.close() - else: - msg = "Expected ASGI message 'websocket.send' or 'websocket.close'," " but got '%s'." - raise RuntimeError(msg % message_type) + try: + if message_type == "websocket.send": + message = cast(WebSocketSendEvent, message) + bytes_data = message.get("bytes") + text_data = message.get("text") + if text_data: + self.conn.send_text(text_data.encode()) + elif bytes_data: + self.conn.send_binary(bytes_data) + output = self.conn.data_to_send() + self.transport.write(b"".join(output)) + + elif message_type == "websocket.close" and not self.transport.is_closing(): + message = cast(WebSocketCloseEvent, message) + code = message.get("code", 1000) + reason = message.get("reason", "") or "" + self.queue.put_nowait({"type": "websocket.disconnect", "code": code}) + self.conn.send_close(code, reason) + output = self.conn.data_to_send() + self.transport.write(b"".join(output)) + self.close_sent = True + self.transport.close() + else: + msg = "Expected ASGI message 'websocket.send' or 'websocket.close'," " but got '%s'." + raise RuntimeError(msg % message_type) + except InvalidState: + raise ClientDisconnected() elif self.initial_response is not None: if message_type == "websocket.http.response.body": message = cast(WebSocketResponseBodyEvent, message) @@ -365,10 +387,11 @@ async def send(self, message: ASGISendEvent) -> None: if not message.get("more_body", False): response = self.conn.reject(self.initial_response[0], body.decode()) response.headers.update(self.initial_response[1]) + self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006}) self.conn.send_response(response) output = self.conn.data_to_send() self.close_sent = True - self.transport.writelines(output) + self.transport.write(b"".join(output)) self.transport.close() else: msg = "Expected ASGI message 'websocket.http.response.body' " "but got '%s'." diff --git a/uvicorn/server.py b/uvicorn/server.py index 2250e2dc7..e33716fd4 100644 --- a/uvicorn/server.py +++ b/uvicorn/server.py @@ -22,10 +22,10 @@ from uvicorn.protocols.http.h11_impl import H11Protocol from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol - from uvicorn.protocols.websockets.websockets_sansio_impl import WebSocketSansIOProtocol + from uvicorn.protocols.websockets.websockets_sansio_impl import WebSocketsSansIOProtocol from uvicorn.protocols.websockets.wsproto_impl import WSProtocol - Protocols = Union[H11Protocol, HttpToolsProtocol, WSProtocol, WebSocketProtocol, WebSocketSansIOProtocol] + Protocols = Union[H11Protocol, HttpToolsProtocol, WSProtocol, WebSocketProtocol, WebSocketsSansIOProtocol] HANDLED_SIGNALS = ( signal.SIGINT, # Unix signal 2. Sent by Ctrl+C. From 035e7c38e98dbdbe81eae19f63ee8e1bfa1d1e4c Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Sat, 14 Dec 2024 17:55:28 +0100 Subject: [PATCH 3/6] lint --- tests/middleware/test_proxy_headers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/middleware/test_proxy_headers.py b/tests/middleware/test_proxy_headers.py index 62a51ab20..d300c45f8 100644 --- a/tests/middleware/test_proxy_headers.py +++ b/tests/middleware/test_proxy_headers.py @@ -5,8 +5,8 @@ import httpx import httpx._transports.asgi import pytest - import websockets.client + from tests.response import Response from tests.utils import run_server from uvicorn._types import ASGIReceiveCallable, ASGISendCallable, Scope From eac77b7d86edb6a2ebcf9e2f380db18bb46fe684 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Sat, 14 Dec 2024 17:57:33 +0100 Subject: [PATCH 4/6] pin python versions --- requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index fd2334d02..366a0963a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,8 @@ h11 @ git+https://github.com/python-hyper/h11.git@master # Explicit optionals a2wsgi==1.10.7 wsproto==1.2.0 -websockets==14.1 +websockets==13.1; python_version < '3.9' +websockets==14.1; python_version >= '3.9' # Packaging build==1.2.2.post1 From 032c00c5a13167deb638a1a5e13c9ccee091598a Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Sun, 15 Dec 2024 13:18:34 +0100 Subject: [PATCH 5/6] Update requirements.txt --- requirements.txt | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 366a0963a..b3a464c0b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,8 +7,7 @@ h11 @ git+https://github.com/python-hyper/h11.git@master # Explicit optionals a2wsgi==1.10.7 wsproto==1.2.0 -websockets==13.1; python_version < '3.9' -websockets==14.1; python_version >= '3.9' +websockets==13.1 # Packaging build==1.2.2.post1 From c523508d5b5959bba65ebe1da2b4d3550e3c6674 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Sun, 15 Dec 2024 14:27:48 +0100 Subject: [PATCH 6/6] Add a bit more coverage --- tests/protocols/test_websocket.py | 23 +++++++++---------- .../websockets/websockets_sansio_impl.py | 20 +++++++--------- 2 files changed, 19 insertions(+), 24 deletions(-) diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index 8971a7d97..e7285449c 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -7,6 +7,8 @@ import httpx import pytest import websockets +import websockets.asyncio +import websockets.asyncio.client import websockets.client import websockets.exceptions from typing_extensions import TypedDict @@ -603,18 +605,15 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable async def websocket_session(uri: str): async with httpx.AsyncClient() as client: - try: - await client.get( - f"http://127.0.0.1:{unused_tcp_port}", - headers={ - "upgrade": "websocket", - "connection": "upgrade", - "sec-websocket-version": "13", - "sec-websocket-key": "dGhlIHNhbXBsZSBub25jZQ==", - }, - ) - except httpx.RemoteProtocolError: - pass # pragma: no cover + await client.get( + f"http://127.0.0.1:{unused_tcp_port}", + headers={ + "upgrade": "websocket", + "connection": "upgrade", + "sec-websocket-version": "13", + "sec-websocket-key": "dGhlIHNhbXBsZSBub25jZQ==", + }, + ) config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port) async with run_server(config): diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index ea70236b2..994af07e7 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -119,17 +119,14 @@ def eof_received(self) -> None: pass def shutdown(self) -> None: - if not self.transport.is_closing(): - if self.handshake_complete: - self.queue.put_nowait({"type": "websocket.disconnect", "code": 1012}) - self.close_sent = True - self.conn.send_close(1012) - output = self.conn.data_to_send() - self.transport.write(b"".join(output)) - elif not self.handshake_initiated: - self.send_500_response() - self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006}) - self.transport.close() + if self.handshake_complete: + self.queue.put_nowait({"type": "websocket.disconnect", "code": 1012}) + self.conn.send_close(1012) + output = self.conn.data_to_send() + self.transport.write(b"".join(output)) + else: + self.send_500_response() + self.transport.close() def data_received(self, data: bytes) -> None: self.conn.receive_data(data) @@ -161,7 +158,6 @@ def handle_connect(self, event: Request) -> None: self.request = event self.response = self.conn.accept(event) self.handshake_initiated = True - # if status_code is not 101 return response if self.response.status_code != 101: self.handshake_complete = True self.close_sent = True