diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index 8971a7d97..e7285449c 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -7,6 +7,8 @@ import httpx import pytest import websockets +import websockets.asyncio +import websockets.asyncio.client import websockets.client import websockets.exceptions from typing_extensions import TypedDict @@ -603,18 +605,15 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable async def websocket_session(uri: str): async with httpx.AsyncClient() as client: - try: - await client.get( - f"http://127.0.0.1:{unused_tcp_port}", - headers={ - "upgrade": "websocket", - "connection": "upgrade", - "sec-websocket-version": "13", - "sec-websocket-key": "dGhlIHNhbXBsZSBub25jZQ==", - }, - ) - except httpx.RemoteProtocolError: - pass # pragma: no cover + await client.get( + f"http://127.0.0.1:{unused_tcp_port}", + headers={ + "upgrade": "websocket", + "connection": "upgrade", + "sec-websocket-version": "13", + "sec-websocket-key": "dGhlIHNhbXBsZSBub25jZQ==", + }, + ) config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port) async with run_server(config): diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index ea70236b2..994af07e7 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -119,17 +119,14 @@ def eof_received(self) -> None: pass def shutdown(self) -> None: - if not self.transport.is_closing(): - if self.handshake_complete: - self.queue.put_nowait({"type": "websocket.disconnect", "code": 1012}) - self.close_sent = True - self.conn.send_close(1012) - output = self.conn.data_to_send() - self.transport.write(b"".join(output)) - elif not self.handshake_initiated: - self.send_500_response() - self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006}) - self.transport.close() + if self.handshake_complete: + self.queue.put_nowait({"type": "websocket.disconnect", "code": 1012}) + self.conn.send_close(1012) + output = self.conn.data_to_send() + self.transport.write(b"".join(output)) + else: + self.send_500_response() + self.transport.close() def data_received(self, data: bytes) -> None: self.conn.receive_data(data) @@ -161,7 +158,6 @@ def handle_connect(self, event: Request) -> None: self.request = event self.response = self.conn.accept(event) self.handshake_initiated = True - # if status_code is not 101 return response if self.response.status_code != 101: self.handshake_complete = True self.close_sent = True