Skip to content

Commit

Permalink
Add a bit more coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex committed Dec 15, 2024
1 parent 97c6117 commit c523508
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 24 deletions.
23 changes: 11 additions & 12 deletions tests/protocols/test_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import httpx
import pytest
import websockets
import websockets.asyncio
import websockets.asyncio.client
import websockets.client
import websockets.exceptions
from typing_extensions import TypedDict
Expand Down Expand Up @@ -603,18 +605,15 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable

async def websocket_session(uri: str):
async with httpx.AsyncClient() as client:
try:
await client.get(
f"http://127.0.0.1:{unused_tcp_port}",
headers={
"upgrade": "websocket",
"connection": "upgrade",
"sec-websocket-version": "13",
"sec-websocket-key": "dGhlIHNhbXBsZSBub25jZQ==",
},
)
except httpx.RemoteProtocolError:
pass # pragma: no cover
await client.get(
f"http://127.0.0.1:{unused_tcp_port}",
headers={
"upgrade": "websocket",
"connection": "upgrade",
"sec-websocket-version": "13",
"sec-websocket-key": "dGhlIHNhbXBsZSBub25jZQ==",
},
)

config = Config(app=app, ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", port=unused_tcp_port)
async with run_server(config):
Expand Down
20 changes: 8 additions & 12 deletions uvicorn/protocols/websockets/websockets_sansio_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,17 +119,14 @@ def eof_received(self) -> None:
pass

def shutdown(self) -> None:
if not self.transport.is_closing():
if self.handshake_complete:
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1012})
self.close_sent = True
self.conn.send_close(1012)
output = self.conn.data_to_send()
self.transport.write(b"".join(output))
elif not self.handshake_initiated:
self.send_500_response()
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006})
self.transport.close()
if self.handshake_complete:
self.queue.put_nowait({"type": "websocket.disconnect", "code": 1012})
self.conn.send_close(1012)
output = self.conn.data_to_send()
self.transport.write(b"".join(output))
else:
self.send_500_response()
self.transport.close()

def data_received(self, data: bytes) -> None:
self.conn.receive_data(data)
Expand Down Expand Up @@ -161,7 +158,6 @@ def handle_connect(self, event: Request) -> None:
self.request = event
self.response = self.conn.accept(event)
self.handshake_initiated = True
# if status_code is not 101 return response
if self.response.status_code != 101:
self.handshake_complete = True
self.close_sent = True
Expand Down

0 comments on commit c523508

Please sign in to comment.