Skip to content

Commit

Permalink
Support the WebSocket Denial Response ASGI extension (#2041)
Browse files Browse the repository at this point in the history
* supply asgi_extensions to TestClient

* Add WebSocket.send_response()

* Add response support for WebSocket testclient

* fix test for filesystem line-endings

* lintint

* support websocket.http.response extension by default

* Improve coverate

* Apply suggestions from code review

Co-authored-by: Marcelo Trylesinski <[email protected]>

* Undo unrelated change

* fix incorrect error message

* Update starlette/websockets.py

Co-authored-by: Marcelo Trylesinski <[email protected]>

* formatting

* Re-introduce close-code and close-reason to WebSocketReject

* Make sure the "websocket.connect" message is received in tests

* Deliver a websocket.disconnect message to the app even if it closes/rejects itself.

* Add test for filling out missing `websocket.disconnect` code

* Add rejection headers.  Expand tests.

* Fix types, headers in message are `bytes` tuples.

* Minimal WebSocket Denial Response implementation

* Revert "Minimal WebSocket Denial Response implementation"

This reverts commit 7af10dd.

* Rename to send_denial_response and update documentation

* Remove the app_disconnect_msg.  This can be added later in a separate PR

* Remove status code 1005 from this PR

* Assume that the application has tested for the extension before sending websocket.http.response.start

* Rename WebSocketReject to WebSocketDenialResponse

* Remove code and status from WebSocketDenialResponse.
Just send a regular WebSocketDisconnect even when connection is rejected with close()

* Raise an exception if attempting to send a http response and server does not support it.

* WebSocketDenialClose and WebSocketDenialResponse
These are both instances of WebSocketDenial.

* Update starlette/testclient.py

Co-authored-by: Marcelo Trylesinski <[email protected]>

* Revert "WebSocketDenialClose and WebSocketDenialResponse"

This reverts commit 71b76e3.

* Rename parameters, member variables

* Use httpx.Response as the base for WebSocketDenialResponse.

* Apply suggestions from code review

Co-authored-by: Marcelo Trylesinski <[email protected]>

* Update sanity check message

* Remove un-needed function

* Expand error message test regex

* Add type hings to test methods

* Add doc string to test.

* Fix mypy complaining about mismatching parent methods.

* nitpick & remove test

* Simplify the documentation

* Update starlette/testclient.py

* Update starlette/testclient.py

* Remove an unnecessary test

* there is no special "close because of rejection" in the testclient anymore.

---------

Co-authored-by: Marcelo Trylesinski <[email protected]>
  • Loading branch information
kristjanvalur and Kludex authored Feb 4, 2024
1 parent 043c800 commit 93e74a4
Show file tree
Hide file tree
Showing 5 changed files with 183 additions and 7 deletions.
14 changes: 14 additions & 0 deletions docs/websockets.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,17 @@ correctly updated.

* `await websocket.send(message)`
* `await websocket.receive()`

### Send Denial Response

If you call `websocket.close()` before calling `websocket.accept()` then
the server will automatically send a HTTP 403 error to the client.

If you want to send a different error response, you can use the
`websocket.send_denial_response()` method. This will send the response
and then close the connection.

* `await websocket.send_denial_response(response)`

This requires the ASGI server to support the WebSocket Denial Response
extension. If it is not supported a `RuntimeError` will be raised.
5 changes: 3 additions & 2 deletions starlette/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,14 +148,15 @@ def delete_cookie(
)

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
prefix = "websocket." if scope["type"] == "websocket" else ""
await send(
{
"type": "http.response.start",
"type": prefix + "http.response.start",
"status": self.status_code,
"headers": self.raw_headers,
}
)
await send({"type": "http.response.body", "body": self.body})
await send({"type": prefix + "http.response.body", "body": self.body})

if self.background is not None:
await self.background()
Expand Down
28 changes: 27 additions & 1 deletion starlette/testclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,16 @@ def __init__(self, session: WebSocketTestSession) -> None:
self.session = session


class WebSocketDenialResponse( # type: ignore[misc]
httpx.Response,
WebSocketDisconnect,
):
"""
A special case of `WebSocketDisconnect`, raised in the `TestClient` if the
`WebSocket` is closed before being accepted with a `send_denial_response()`.
"""


class WebSocketTestSession:
def __init__(
self,
Expand Down Expand Up @@ -159,7 +169,22 @@ async def _asgi_send(self, message: Message) -> None:
def _raise_on_close(self, message: Message) -> None:
if message["type"] == "websocket.close":
raise WebSocketDisconnect(
message.get("code", 1000), message.get("reason", "")
code=message.get("code", 1000), reason=message.get("reason", "")
)
elif message["type"] == "websocket.http.response.start":
status_code: int = message["status"]
headers: list[tuple[bytes, bytes]] = message["headers"]
body: list[bytes] = []
while True:
message = self.receive()
assert message["type"] == "websocket.http.response.body"
body.append(message["body"])
if not message.get("more_body", False):
break
raise WebSocketDenialResponse(
status_code=status_code,
headers=headers,
content=b"".join(body),
)

def send(self, message: Message) -> None:
Expand Down Expand Up @@ -277,6 +302,7 @@ def handle_request(self, request: httpx.Request) -> httpx.Response:
"server": [host, port],
"subprotocols": subprotocols,
"state": self.app_state.copy(),
"extensions": {"websocket.http.response": {}},
}
session = WebSocketTestSession(self.app, scope, self.portal_factory)
raise _Upgrade(session)
Expand Down
33 changes: 30 additions & 3 deletions starlette/websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
import typing

from starlette.requests import HTTPConnection
from starlette.responses import Response
from starlette.types import Message, Receive, Scope, Send


class WebSocketState(enum.Enum):
CONNECTING = 0
CONNECTED = 1
DISCONNECTED = 2
RESPONSE = 3


class WebSocketDisconnect(Exception):
Expand Down Expand Up @@ -65,13 +67,20 @@ async def send(self, message: Message) -> None:
"""
if self.application_state == WebSocketState.CONNECTING:
message_type = message["type"]
if message_type not in {"websocket.accept", "websocket.close"}:
if message_type not in {
"websocket.accept",
"websocket.close",
"websocket.http.response.start",
}:
raise RuntimeError(
'Expected ASGI message "websocket.accept" or '
f'"websocket.close", but got {message_type!r}'
'Expected ASGI message "websocket.accept",'
'"websocket.close" or "websocket.http.response.start",'
f"but got {message_type!r}"
)
if message_type == "websocket.close":
self.application_state = WebSocketState.DISCONNECTED
elif message_type == "websocket.http.response.start":
self.application_state = WebSocketState.RESPONSE
else:
self.application_state = WebSocketState.CONNECTED
await self._send(message)
Expand All @@ -89,6 +98,16 @@ async def send(self, message: Message) -> None:
except IOError:
self.application_state = WebSocketState.DISCONNECTED
raise WebSocketDisconnect(code=1006)
elif self.application_state == WebSocketState.RESPONSE:
message_type = message["type"]
if message_type != "websocket.http.response.body":
raise RuntimeError(
'Expected ASGI message "websocket.http.response.body", '
f"but got {message_type!r}"
)
if not message.get("more_body", False):
self.application_state = WebSocketState.DISCONNECTED
await self._send(message)
else:
raise RuntimeError('Cannot call "send" once a close message has been sent.')

Expand Down Expand Up @@ -185,6 +204,14 @@ async def close(self, code: int = 1000, reason: str | None = None) -> None:
{"type": "websocket.close", "code": code, "reason": reason or ""}
)

async def send_denial_response(self, response: Response) -> None:
if "websocket.http.response" in self.scope.get("extensions", {}):
await response(self.scope, self.receive, self.send)
else:
raise RuntimeError(
"The server doesn't support the Websocket Denial Response extension."
)


class WebSocketClose:
def __init__(self, code: int = 1000, reason: str | None = None) -> None:
Expand Down
110 changes: 109 additions & 1 deletion tests/test_websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from anyio.abc import ObjectReceiveStream, ObjectSendStream

from starlette import status
from starlette.testclient import TestClient
from starlette.responses import Response
from starlette.testclient import TestClient, WebSocketDenialResponse
from starlette.types import Message, Receive, Scope, Send
from starlette.websockets import WebSocket, WebSocketDisconnect, WebSocketState

Expand Down Expand Up @@ -293,6 +294,8 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
def test_rejected_connection(test_client_factory: Callable[..., TestClient]):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
msg = await websocket.receive()
assert msg == {"type": "websocket.connect"}
await websocket.close(status.WS_1001_GOING_AWAY)

client = test_client_factory(app)
Expand All @@ -302,6 +305,111 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
assert exc.value.code == status.WS_1001_GOING_AWAY


def test_send_denial_response(test_client_factory: Callable[..., TestClient]):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
msg = await websocket.receive()
assert msg == {"type": "websocket.connect"}
response = Response(status_code=404, content="foo")
await websocket.send_denial_response(response)

client = test_client_factory(app)
with pytest.raises(WebSocketDenialResponse) as exc:
with client.websocket_connect("/"):
pass # pragma: no cover
assert exc.value.status_code == 404
assert exc.value.content == b"foo"


def test_send_response_multi(test_client_factory: Callable[..., TestClient]):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
msg = await websocket.receive()
assert msg == {"type": "websocket.connect"}
await websocket.send(
{
"type": "websocket.http.response.start",
"status": 404,
"headers": [(b"content-type", b"text/plain"), (b"foo", b"bar")],
}
)
await websocket.send(
{
"type": "websocket.http.response.body",
"body": b"hard",
"more_body": True,
}
)
await websocket.send(
{
"type": "websocket.http.response.body",
"body": b"body",
}
)

client = test_client_factory(app)
with pytest.raises(WebSocketDenialResponse) as exc:
with client.websocket_connect("/"):
pass # pragma: no cover
assert exc.value.status_code == 404
assert exc.value.content == b"hardbody"
assert exc.value.headers["foo"] == "bar"


def test_send_response_unsupported(test_client_factory: Callable[..., TestClient]):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
del scope["extensions"]["websocket.http.response"]
websocket = WebSocket(scope, receive=receive, send=send)
msg = await websocket.receive()
assert msg == {"type": "websocket.connect"}
response = Response(status_code=404, content="foo")
with pytest.raises(
RuntimeError,
match="The server doesn't support the Websocket Denial Response extension.",
):
await websocket.send_denial_response(response)
await websocket.close()

client = test_client_factory(app)
with pytest.raises(WebSocketDisconnect) as exc:
with client.websocket_connect("/"):
pass # pragma: no cover
assert exc.value.code == status.WS_1000_NORMAL_CLOSURE


def test_send_response_duplicate_start(test_client_factory: Callable[..., TestClient]):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
msg = await websocket.receive()
assert msg == {"type": "websocket.connect"}
response = Response(status_code=404, content="foo")
await websocket.send(
{
"type": "websocket.http.response.start",
"status": response.status_code,
"headers": response.raw_headers,
}
)
await websocket.send(
{
"type": "websocket.http.response.start",
"status": response.status_code,
"headers": response.raw_headers,
}
)

client = test_client_factory(app)
with pytest.raises(
RuntimeError,
match=(
'Expected ASGI message "websocket.http.response.body", but got '
"'websocket.http.response.start'"
),
):
with client.websocket_connect("/"):
pass # pragma: no cover


def test_subprotocol(test_client_factory: Callable[..., TestClient]):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
Expand Down

0 comments on commit 93e74a4

Please sign in to comment.