From 190ce7b00c4bad7248f69c2396cc9a8d9558185c Mon Sep 17 00:00:00 2001 From: Ashley Sommer Date: Wed, 9 Jun 2021 09:56:01 +1000 Subject: [PATCH 01/16] First attempt at new Websockets implementation based on websockets >= 9.0, with sans-i/o features. Requires more work. --- sanic/app.py | 12 - sanic/config.py | 1 - sanic/mixins/routes.py | 7 +- sanic/server.py | 165 +++++--- sanic/websocket.py | 874 ++++++++++++++++++++++++++++++++++++++--- sanic/worker.py | 2 +- tests/test_app.py | 2 - tests/test_routes.py | 16 +- tests/test_worker.py | 4 +- 9 files changed, 934 insertions(+), 149 deletions(-) diff --git a/sanic/app.py b/sanic/app.py index 1bbd5e29c9..ebe5097687 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -745,22 +745,10 @@ async def handle_request(self, request: Request): async def _websocket_handler( self, handler, request, *args, subprotocols=None, **kwargs ): - request.app = self - if not getattr(handler, "__blueprintname__", False): - request._name = handler.__name__ - else: - request._name = ( - getattr(handler, "__blueprintname__", "") + handler.__name__ - ) - - pass - if self.asgi: ws = request.transport.get_websocket_connection() else: protocol = request.transport.get_protocol() - protocol.app = self - ws = await protocol.websocket_handshake(request, subprotocols) # schedule the application handler diff --git a/sanic/config.py b/sanic/config.py index c8121da977..3512e8eb7d 100644 --- a/sanic/config.py +++ b/sanic/config.py @@ -24,7 +24,6 @@ "KEEP_ALIVE": True, "KEEP_ALIVE_TIMEOUT": 5, # 5 seconds "WEBSOCKET_MAX_SIZE": 2 ** 20, # 1 megabyte - "WEBSOCKET_MAX_QUEUE": 32, "WEBSOCKET_READ_LIMIT": 2 ** 16, "WEBSOCKET_WRITE_LIMIT": 2 ** 16, "WEBSOCKET_PING_TIMEOUT": 20, diff --git a/sanic/mixins/routes.py b/sanic/mixins/routes.py index e468be69b4..7fc52c4081 100644 --- a/sanic/mixins/routes.py +++ b/sanic/mixins/routes.py @@ -115,8 +115,11 @@ def decorator(handler): "Expected either string or Iterable of host strings, " "not %s" % host ) - - if isinstance(subprotocols, (list, tuple, set)): + if isinstance(subprotocols, list): + # Ordered subprotocols, maintain order + subprotocols = tuple(subprotocols) + if isinstance(subprotocols, set): + # subprotocol is unordered, keep it unordered subprotocols = frozenset(subprotocols) route = FutureRoute( diff --git a/sanic/server.py b/sanic/server.py index 0d88259a69..a2b229f689 100644 --- a/sanic/server.py +++ b/sanic/server.py @@ -99,11 +99,7 @@ def __init__(self, transport: TransportProtocol, unix=None): self.client_port = addr[1] -class HttpProtocol(asyncio.Protocol): - """ - This class provides a basic HTTP implementation of the sanic framework. 
- """ - +class SanicProtocol(asyncio.Protocol): __slots__ = ( # app "app", @@ -111,9 +107,110 @@ class HttpProtocol(asyncio.Protocol): "loop", "transport", "connections", - "signal", "conn_info", - "ctx", + "signal", + "_can_write", + "_time", + "_task", + "_unix", + "_data_received", + ) + + def __init__( + self, + *, + loop, + app: Sanic, + signal=None, + connections=None, + unix=None, + **kwargs, + ): + asyncio.set_event_loop(loop) + self.loop = loop + self.app: Sanic = app + self.signal = signal or Signal() + self.transport: Optional[Transport] = None + self.connections = connections if connections is not None else set() + self.conn_info: Optional[ConnInfo] = None + self._can_write = asyncio.Event() + self._can_write.set() + self._unix = unix + self._time = 0.0 # type: float + self._task = None # type: Optional[asyncio.Task] + self._data_received = asyncio.Event() + + @property + def ctx(self): + if self.conn_info is not None: + return self.conn_info.ctx + else: + return self.ctx + + async def send(self, data): + """ + Writes data with backpressure control. + """ + await self._can_write.wait() + if self.transport.is_closing(): + raise CancelledError + self.transport.write(data) + self._time = current_time() + + def close(self): + """ + Force close the connection. + """ + # Cause a call to connection_lost where further cleanup occurs + if self.transport: + self.transport.close() + self.transport = None + + # asyncio.Protocol API Callbacks # + # ------------------------------ # + def connection_made(self, transport): + try: + # TODO: Benchmark to find suitable write buffer limits + transport.set_write_buffer_limits(low=16384, high=65536) + self.connections.add(self) + self.transport = transport + self.conn_info = ConnInfo(self.transport, unix=self._unix) + except Exception: + error_logger.exception("protocol.connect_made") + + def connection_lost(self, exc): + try: + self.connections.discard(self) + self.resume_writing() + if self._task: + self._task.cancel() + except BaseException: + error_logger.exception("protocol.connection_lost") + + def pause_writing(self): + self._can_write.clear() + + def resume_writing(self): + self._can_write.set() + + def data_received(self, data: bytes): + try: + self._time = current_time() + if not data: + return self.close() + + if self._data_received: + self._data_received.set() + except BaseException: + error_logger.exception("protocol.data_received") + + +class HttpProtocol(SanicProtocol): + """ + This class provides a basic HTTP implementation of the sanic framework. 
+ """ + + __slots__ = ( # request params "request", # request config @@ -131,14 +228,9 @@ class HttpProtocol(asyncio.Protocol): "state", "url", "_handler_task", - "_can_write", - "_data_received", - "_time", - "_task", "_http", "_exception", "recv_buffer", - "_unix", ) def __init__( @@ -152,16 +244,10 @@ def __init__( unix=None, **kwargs, ): - asyncio.set_event_loop(loop) - self.loop = loop - self.app: Sanic = app + super().__init__(loop=loop, app=app, signal=signal, connections=connections, unix=unix) self.url = None - self.transport: Optional[Transport] = None - self.conn_info: Optional[ConnInfo] = None self.request: Optional[Request] = None - self.signal = signal or Signal() self.access_log = self.app.config.ACCESS_LOG - self.connections = connections if connections is not None else set() self.request_handler = self.app.handle_request self.error_handler = self.app.error_handler self.request_timeout = self.app.config.REQUEST_TIMEOUT @@ -175,11 +261,7 @@ def __init__( self.state = state if state else {} if "requests_count" not in self.state: self.state["requests_count"] = 0 - self._data_received = asyncio.Event() - self._can_write = asyncio.Event() - self._can_write.set() self._exception = None - self._unix = unix def _setup_connection(self): self._http = Http(self) @@ -260,16 +342,6 @@ def check_timeouts(self): except Exception: error_logger.exception("protocol.check_timeouts") - async def send(self, data): - """ - Writes data with backpressure control. - """ - await self._can_write.wait() - if self.transport.is_closing(): - raise CancelledError - self.transport.write(data) - self._time = current_time() - def close_if_idle(self) -> bool: """ Close the connection if a request is not being sent or received @@ -281,15 +353,6 @@ def close_if_idle(self) -> bool: return True return False - def close(self): - """ - Force close the connection. 
- """ - # Cause a call to connection_lost where further cleanup occurs - if self.transport: - self.transport.close() - self.transport = None - # -------------------------------------------- # # Only asyncio.Protocol callbacks below this # -------------------------------------------- # @@ -306,21 +369,6 @@ def connection_made(self, transport): except Exception: error_logger.exception("protocol.connect_made") - def connection_lost(self, exc): - try: - self.connections.discard(self) - self.resume_writing() - if self._task: - self._task.cancel() - except Exception: - error_logger.exception("protocol.connection_lost") - - def pause_writing(self): - self._can_write.clear() - - def resume_writing(self): - self._can_write.set() - def data_received(self, data: bytes): try: self._time = current_time() @@ -606,7 +654,7 @@ def serve( coros = [] for conn in connections: if hasattr(conn, "websocket") and conn.websocket: - coros.append(conn.websocket.close_connection()) + coros.append(conn.websocket.close(code=1001)) else: conn.close() @@ -624,7 +672,6 @@ def _build_protocol_kwargs( if hasattr(protocol, "websocket_handshake"): return { "websocket_max_size": config.WEBSOCKET_MAX_SIZE, - "websocket_max_queue": config.WEBSOCKET_MAX_QUEUE, "websocket_read_limit": config.WEBSOCKET_READ_LIMIT, "websocket_write_limit": config.WEBSOCKET_WRITE_LIMIT, "websocket_ping_timeout": config.WEBSOCKET_PING_TIMEOUT, diff --git a/sanic/websocket.py b/sanic/websocket.py index 6b325e2676..1cb080ad3e 100644 --- a/sanic/websocket.py +++ b/sanic/websocket.py @@ -1,3 +1,6 @@ +import random +import struct +from email.utils import formatdate from typing import ( Any, Awaitable, @@ -7,23 +10,748 @@ MutableMapping, Optional, Union, + Iterable, + Sequence, + Mapping, AsyncIterator ) +import codecs from httptools import HttpParserUpgrade # type: ignore -from websockets import ( # type: ignore - ConnectionClosed, - InvalidHandshake, - WebSocketCommonProtocol, - handshake, -) +from websockets.server import ServerConnection +from websockets.connection import Event, OPEN, CLOSING, CLOSED +from websockets.exceptions import ConnectionClosed, InvalidHandshake, InvalidOrigin, InvalidUpgrade, \ + ConnectionClosedError +from websockets.typing import Data +from websockets.frames import Frame, Opcode, prepare_ctrl, OP_PONG +from websockets.utils import accept_key -from sanic.exceptions import InvalidUsage -from sanic.server import HttpProtocol +from sanic.exceptions import InvalidUsage, Forbidden, SanicException +from sanic.server import HttpProtocol, SanicProtocol +from sanic.log import error_logger, logger +import asyncio -__all__ = ["ConnectionClosed", "WebSocketProtocol", "WebSocketConnection"] ASIMessage = MutableMapping[str, Any] +UTF8Decoder = codecs.getincrementaldecoder("utf-8") + +class WebsocketFrameAssembler: + """ + Assemble a message from frames. + Code borrowed from aaugustin/websockets project: + https://github.com/aaugustin/websockets/blob/6eb98dd8fa5b2c896b9f6be7e8d117708da82a39/src/websockets/sync/messages.py + """ + __slots__ = ("protocol", "read_mutex", "write_mutex", "message_complete", "message_fetched", "get_in_progress", "decoder", "completed_queue", "chunks", "chunks_queue", "paused", "get_id", "put_id") + + def __init__(self, protocol) -> None: + + self.protocol = protocol + + self.read_mutex = asyncio.Lock() + self.write_mutex = asyncio.Lock() + + self.completed_queue = asyncio.Queue(maxsize=1) # type: asyncio.Queue[Data] + + + # put() sets this event to tell get() that a message can be fetched. 
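+        # get() clears it again once the completed message has been consumed.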
+ self.message_complete = asyncio.Event() + # get() sets this event to let put() + self.message_fetched = asyncio.Event() + + # This flag prevents concurrent calls to get() by user code. + self.get_in_progress = False + + # Decoder for text frames, None for binary frames. + self.decoder: Optional[codecs.IncrementalDecoder] = None + + # Buffer data from frames belonging to the same message. + self.chunks: List[Data] = [] + + # When switching from "buffering" to "streaming", we use a thread-safe + # queue for transferring frames from the writing thread (library code) + # to the reading thread (user code). We're buffering when chunks_queue + # is None and streaming when it's a Queue. None is a sentinel + # value marking the end of the stream, superseding message_complete. + + # Stream data from frames belonging to the same message. + self.chunks_queue: Optional[asyncio.Queue[Optional[Data]]] = None + + # Flag to indicate we've paused the protocol + self.paused = False + + + async def get(self, timeout: Optional[float] = None) -> Optional[Data]: + """ + Read the next message. + :meth:`get` returns a single :class:`str` or :class:`bytes`. + If the :message was fragmented, :meth:`get` waits until the last frame + is received, then it reassembles the message. + If ``timeout`` is set and elapses before a complete message is + received, :meth:`get` returns ``None``. + """ + async with self.read_mutex: + if timeout is not None and timeout <= 0: + if not self.message_complete.is_set(): + return None + assert not self.get_in_progress + self.get_in_progress = True + + # If the message_complete event isn't set yet, release the lock to + # allow put() to run and eventually set it. + # Locking with get_in_progress ensures only one thread can get here. + if timeout is None: + completed = await self.message_complete.wait() + elif timeout <= 0: + completed = self.message_complete.is_set() + else: + completed = await asyncio.wait_for(self.message_complete.wait(), timeout=timeout) + + # Unpause the transport, if its paused + if self.paused: + self.protocol.resume_frames() + self.paused = False + assert self.get_in_progress + self.get_in_progress = False + + # Waiting for a complete message timed out. + if not completed: + return None + + assert self.message_complete.is_set() + self.message_complete.clear() + + joiner: Data = b"" if self.decoder is None else "" + # mypy cannot figure out that chunks have the proper type. + message: Data = joiner.join(self.chunks) # type: ignore + + assert not self.message_fetched.is_set() + self.message_fetched.set() + self.chunks = [] + assert self.chunks_queue is None + + return message + + async def get_iter(self) -> AsyncIterator[Data]: + """ + Stream the next message. + Iterating the return value of :meth:`get_iter` yields a :class:`str` + or :class:`bytes` for each frame in the message. + """ + async with self.read_mutex: + assert not self.get_in_progress + self.get_in_progress = True + + chunks = self.chunks + self.chunks = [] + self.chunks_queue = asyncio.Queue() + + # Sending None in chunk_queue supersedes setting message_complete + # when switching to "streaming". If message is already complete + # when the switch happens, put() didn't send None, so we have to. + if self.message_complete.is_set(): + await self.chunks_queue.put(None) + + # Locking with get_in_progress ensures only one thread can get here. 
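+            # First drain the chunks that were already buffered, then stream
+            # further chunks from chunks_queue until the None sentinel arrives.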
+ for c in chunks: + yield c + while True: + chunk = await self.chunks_queue.get() + if chunk is None: + break + yield chunk + + # Unpause the transport, if its paused + if self.paused: + self.protocol.resume_frames() + self.paused = False + assert self.get_in_progress + self.get_in_progress = False + assert self.message_complete.is_set() + self.message_complete.clear() + + assert not self.message_fetched.is_set() + + self.message_fetched.set() + + assert self.chunks == [] + self.chunks_queue = None + + async def put(self, frame: Frame) -> None: + """ + Add ``frame`` to the next message. + When ``frame`` is the final frame in a message, :meth:`put` waits + until the message is fetched, either by calling :meth:`get` or by + iterating the return value of :meth:`get_iter`. + :meth:`put` assumes that the stream of frames respects the protocol. + If it doesn't, the behavior is undefined. + """ + id = self.put_id + self.put_id += 1 + async with self.write_mutex: + if frame.opcode is Opcode.TEXT: + self.decoder = UTF8Decoder(errors="strict") + elif frame.opcode is Opcode.BINARY: + self.decoder = None + elif frame.opcode is Opcode.CONT: + pass + else: + # Ignore control frames. + return + data: Data + if self.decoder is not None: + data = self.decoder.decode(frame.data, frame.fin) + else: + data = frame.data + if self.chunks_queue is None: + self.chunks.append(data) + else: + await self.chunks_queue.put(data) + + if not frame.fin: + return + if not self.get_in_progress: + self.paused = self.protocol.pause_frames() + # Message is complete. Wait until it's fetched to return. + + if self.chunks_queue is not None: + await self.chunks_queue.put(None) + + assert not self.message_complete.is_set() + self.message_complete.set() + assert not self.message_fetched.is_set() + + # Release the lock to allow get() to run and eventually set the event. 
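+            # put() resumes only after get()/get_iter() has consumed the
+            # completed message and set message_fetched.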
+ await self.message_fetched.wait() + assert self.message_fetched.is_set() + self.message_fetched.clear() + self.decoder = None + +class WebsocketImplProtocol: + def __init__(self, connection, max_queue=None, ping_interval: Optional[float] = 20, ping_timeout: Optional[float] = 20, loop=None): + self.connection = connection # type: ServerConnection + self.io_proto = None # type: Optional[SanicProtocol] + self.loop = None # type: Optional[asyncio.BaseEventLoop] + self.max_queue = max_queue + self.ping_interval = ping_interval + self.ping_timeout = ping_timeout + self.assembler = WebsocketFrameAssembler(self) + self.pings: Dict[bytes, asyncio.Future[None]] = {} + self.conn_mutex = asyncio.Lock() + self.recv_lock = asyncio.Lock() + self.process_event_mutex = asyncio.Lock() + self.data_finished_fut = None # type: Optional[asyncio.Future[None]] + self.can_pause = True + self.pause_frame_fut = None # type: Optional[asyncio.Future[None]] + self.keepalive_ping_task = None + self.close_connection_task = None + self.connection_lost_waiter = None # type: Optional[asyncio.Future[None]] + + @property + def subprotocol(self): + return self.connection.subprotocol + + def pause_frames(self): + if not self.can_pause: + return False + if self.pause_frame_fut is not None: + return False + if self.loop is None or self.io_proto is None: + return False + if self.io_proto.transport is not None: + self.io_proto.transport.pause_reading() + self.pause_frame_fut = self.loop.create_future() + return True + + def resume_frames(self): + if self.pause_frame_fut is None: + return False + if self.loop is None or self.io_proto is None: + error_logger.warning("Websocket attempting to resume reading frames, but connection is gone.") + return False + if self.io_proto.transport is not None: + self.io_proto.transport.resume_reading() + self.pause_frame_fut.set_result(None) + self.pause_frame_fut = None + return True + + async def connection_made(self, io_proto: asyncio.Protocol, loop=None): + if loop is None: + try: + loop = getattr(io_proto, "loop") + except AttributeError: + loop = asyncio.get_event_loop() + self.loop = loop + self.io_proto = io_proto # this will be a WebSocketProtocol + self.connection_lost_waiter = self.loop.create_future() + self.data_finished_fut = asyncio.shield(self.loop.create_future()) + + if self.ping_interval is not None: + self.keepalive_ping_task = asyncio.create_task(self.keepalive_ping()) + self.close_connection_task = asyncio.create_task(self.auto_close_connection()) + + async def wait_for_connection_lost(self, timeout=10) -> bool: + """ + Wait until the TCP connection is closed or ``timeout`` elapses. + + Return ``True`` if the connection is closed and ``False`` otherwise. + + """ + if self.connection_lost_waiter is None: + return False + if not self.connection_lost_waiter.done(): + try: + await asyncio.wait_for( + asyncio.shield(self.connection_lost_waiter), timeout + ) + except asyncio.TimeoutError: + pass + # Re-check self.connection_lost_waiter.done() synchronously because + # connection_lost() could run between the moment the timeout occurs + # and the moment this coroutine resumes running. + return self.connection_lost_waiter.done() + + async def process_events(self, events: Sequence[Event]) -> None: + """ + Process a list of incoming events. 
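+        Pong frames are matched against pending pings in :meth:`process_pong`;
+        everything else is handed to the frame assembler, which reassembles
+        data frames into messages. Roughly, the transport side drives this as
+        sketched below (see :meth:`data_received`)::
+
+            self.connection.receive_data(data)
+            await self.process_events(self.connection.events_received())
+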
+ """ + # Wrapped in a mutex lock, to prevent other incoming events + # from processing at the same time + async with self.process_event_mutex: + for event in events: + if event.opcode == OP_PONG: + await self.process_pong(event) + else: + await self.assembler.put(event) + + async def process_pong(self, frame: Frame) -> None: + if frame.data in self.pings: + # Acknowledge all pings up to the one matching this pong. + ping_id = None + ping_ids = [] + for ping_id, ping in self.pings.items(): + ping_ids.append(ping_id) + if not ping.done(): + ping.set_result(None) + if ping_id == frame.data: + break + else: # pragma: no cover + assert False, "ping_id is in self.pings" + # Remove acknowledged pings from self.pings. + for ping_id in ping_ids: + del self.pings[ping_id] + + async def keepalive_ping(self) -> None: + """ + Send a Ping frame and wait for a Pong frame at regular intervals. + This coroutine exits when the connection terminates and one of the + following happens: + - :meth:`ping` raises :exc:`ConnectionClosed`, or + - :meth:`auto_close_connection` cancels :attr:`keepalive_ping_task`. + """ + if self.ping_interval is None: + return + + try: + while True: + await asyncio.sleep(self.ping_interval) + + # ping() raises CancelledError if the connection is closed, + # when auto_close_connection() cancels self.keepalive_ping_task. + + # ping() raises ConnectionClosed if the connection is lost, + # when connection_lost() calls abort_pings(). + + ping_waiter = await self.ping() + + if self.ping_timeout is not None: + try: + await asyncio.wait_for(ping_waiter, self.ping_timeout) + except asyncio.TimeoutError: + error_logger.warning("Websocket timed out waiting for pong") + self.fail_connection(1011) + break + except asyncio.CancelledError: + raise + except ConnectionClosed: + pass + except BaseException: + error_logger.warning("Unexpected exception in keepalive ping task") + + def fail_connection(self, code: int = 1006, reason: str = "") -> bool: + """ + Fail the WebSocket Connection + This requires: + 1. Stopping all processing of incoming data, which means cancelling + pausing the underlying io protocol. The close code will be 1006 + unless a close frame was received earlier. + 2. Sending a close frame with an appropriate code if the opening + handshake succeeded and the other side is likely to process it. + 3. Closing the connection. :meth:`auto_close_connection` takes care + of this. + (The specification describes these steps in the opposite order.) + """ + if self.io_proto.transport is not None: + # Stop new data coming in + # In Python Version 3.7: pause_reading is idempotent + # i.e. it can be called when the transport is already paused or closed. + self.io_proto.transport.pause_reading() + + # Keeping fail_connection() synchronous guarantees it can't + # get stuck and simplifies the implementation of the callers. + # Not draining the write buffer is acceptable in this context. + + # clear the send buffer + _ = self.connection.data_to_send() + # If we're not already CLOSED or CLOSING, then send the close. + if self.connection.state is OPEN: + self.connection.fail_connection(code, reason) + for frame_data in self.connection.data_to_send(): + self.io_proto.transport.write(frame_data) + if self.close_connection_task is not None and not self.close_connection_task.done(): + if self.data_finished_fut is not None and not self.data_finished_fut.done(): + self.data_finished_fut.cancel() + # Don't close, auto_close_connection will take care of it. 
+ return False + SanicProtocol.close(self.io_proto) + return True + + async def auto_close_connection(self) -> None: + """ + Close the WebSocket Connection + When the opening handshake succeeds, :meth:`connection_open` starts + this coroutine in a task. It waits for the data transfer phase to + complete then it closes the TCP connection cleanly. + When the opening handshake fails, :meth:`fail_connection` does the + same. There's no data transfer phase in that case. + """ + try: + # Wait for the data transfer phase to complete. + if self.data_finished_fut is not None: + try: + await self.data_finished_fut + except asyncio.CancelledError: + pass + + # Cancel the keepalive ping task. + if self.keepalive_ping_task is not None: + self.keepalive_ping_task.cancel() + + # Half-close the TCP connection if possible (when there's no TLS). + if self.io_proto.transport is not None and self.io_proto.transport.can_write_eof(): + error_logger.warning("Websocket half-closing TCP connection") + self.io_proto.transport.write_eof() + if self.connection_lost_waiter is not None: + if await self.wait_for_connection_lost(timeout=0): + return + finally: + # The try/finally ensures that the transport never remains open, + # even if this coroutine is cancelled (for example). + if self.connection_lost_waiter is not None and self.connection_lost_waiter.done(): + if self.io_proto.transport is None or self.io_proto.transport.is_closing(): + return + SanicProtocol.close(self.io_proto) + if self.connection_lost_waiter is not None: + await self.wait_for_connection_lost() + if self.connection_lost_waiter.done(): + return + error_logger.warning("Timeout waiting for TCP connection to close. Aborting") + if self.io_proto.transport is not None: + self.io_proto.transport.abort() + + def abort_pings(self) -> None: + """ + Raise ConnectionClosed in pending keepalive pings. + They'll never receive a pong once the connection is closed. + """ + assert self.connection.state is CLOSED + + for ping in self.pings.values(): + ping.set_exception(ConnectionClosedError(1006, "")) + # If the exception is never retrieved, it will be logged when ping + # is garbage-collected. This is confusing for users. + # Given that ping is done (with an exception), canceling it does + # nothing, but it prevents logging the exception. + ping.cancel() + + async def close(self, code: int = 1000, reason: str = "") -> None: + """ + Perform the closing handshake. + :meth:`close` waits for the other end to complete the handshake and + for the TCP connection to terminate. + :meth:`close` is idempotent: it doesn't do anything once the + connection is closed. + :param code: WebSocket close code + :param reason: WebSocket close reason + """ + async with self.conn_mutex: + if self.connection.state is OPEN: + self.connection.send_close(code, reason) + data_to_send = self.connection.data_to_send() + await self.send_data(data_to_send) + + async def recv(self, timeout: Optional[float] = None) -> Optional[Data]: + """ + Receive the next message. + Return a :class:`str` for a text frame and :class:`bytes` for a binary + frame. + When the end of the message stream is reached, :meth:`recv` raises + :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it + raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal + connection closure and + :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol + error or a network failure. + If ``timeout`` is ``None``, block until a message is received. 
Else, + if no message is received within ``timeout`` seconds, return ``None``. + Set ``timeout`` to ``0`` to check if a message was already received. + :raises ~websockets.exceptions.ConnectionClosed: when the + connection is closed + :raises RuntimeError: if two tasks call :meth:`recv` or + :meth:`recv_streaming` concurrently + """ + + # TODO HANDLE THE SITUATION WHERE THE CONNECTION IS CLOSED + + if self.recv_lock.locked(): + raise RuntimeError( + "cannot call recv while another task " + "is already waiting for the next message" + ) + await self.recv_lock.acquire() + try: + return await self.assembler.get(timeout) + finally: + self.recv_lock.release() + + async def recv_burst(self, max_recv=256) -> Sequence[Data]: + """ + Receive the messages which have arrived since last checking. + Return a :class:`list` containing :class:`str` for a text frame + and :class:`bytes` for a binary frame. + When the end of the message stream is reached, :meth:`recv_burst` + raises :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, + it raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a + normal connection closure and + :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol + error or a network failure. + :raises ~websockets.exceptions.ConnectionClosed: when the + connection is closed + :raises RuntimeError: if two threads call :meth:`recv` or + :meth:`recv_streaming` concurrently + """ + + # TODO HANDLE THE SITUATION WHERE THE CONNECTION IS CLOSED + + if self.recv_lock.locked(): + raise RuntimeError( + "cannot call recv_burst while another task " + "is already waiting for the next message" + ) + await self.recv_lock.acquire() + + messages = [] + try: + # Prevent pausing the transport when we're + # receiving a burst of messages + self.can_pause = False + while True: + m = await self.assembler.get(timeout=0) + if m is None: + # None left in the burst. This is good! + break + messages.append(m) + if len(messages) >= max_recv: + # Too much data in the pipe. Hit our burst limit. + break + # Allow an eventloop iteration for the + # next message to pass into the Assembler + await asyncio.sleep(0) + finally: + self.can_pause = True + self.recv_lock.release() + return messages + + async def recv_streaming(self) -> AsyncIterator[Data]: + """ + Receive the next message frame by frame. + Return an iterator of :class:`str` for a text frame and :class:`bytes` + for a binary frame. The iterator should be exhausted, or else the + connection will become unusable. + With the exception of the return value, :meth:`recv_streaming` behaves + like :meth:`recv`. + """ + if self.recv_lock.locked(): + raise RuntimeError( + "cannot call recv_streaming while another task " + "is already waiting for the next message" + ) + await self.recv_lock.acquire() + try: + self.can_pause = False + async for m in self.assembler.get_iter(): + yield m + finally: + self.can_pause = True + self.recv_lock.release() + + async def send(self, message: Union[Data, Iterable[Data]]) -> None: + """ + Send a message. + A string (:class:`str`) is sent as a `Text frame`_. A bytestring or + bytes-like object (:class:`bytes`, :class:`bytearray`, or + :class:`memoryview`) is sent as a `Binary frame`_. + .. _Text frame: https://tools.ietf.org/html/rfc6455#section-5.6 + .. _Binary frame: https://tools.ietf.org/html/rfc6455#section-5.6 + :meth:`send` also accepts an iterable of strings, bytestrings, or + bytes-like objects. In that case the message is fragmented. 
Each item + is treated as a message fragment and sent in its own frame. All items + must be of the same type, or else :meth:`send` will raise a + :exc:`TypeError` and the connection will be closed. + :meth:`send` rejects dict-like objects because this is often an error. + If you wish to send the keys of a dict-like object as fragments, call + its :meth:`~dict.keys` method and pass the result to :meth:`send`. + :raises TypeError: for unsupported inputs + """ + async with self.conn_mutex: + + # TODO HANDLE THE SITUATION WHERE THE CONNECTION IS CLOSED + + # Unfragmented message -- this case must be handled first because + # strings and bytes-like objects are iterable. + + if isinstance(message, str): + self.connection.send_text(message.encode("utf-8")) + await self.send_data(self.connection.data_to_send()) + + elif isinstance(message, (bytes, bytearray, memoryview)): + self.connection.send_binary(message) + await self.send_data(self.connection.data_to_send()) + + # Catch a common mistake -- passing a dict to send(). + + elif isinstance(message, Mapping): + raise TypeError("data is a dict-like object") + + # Fragmented message -- regular iterator. + + elif isinstance(message, Iterable): + # TODO use an incremental encoder maybe? + raise NotImplementedError + + else: + raise TypeError("data must be bytes, str, or iterable") + + async def ping(self, data: Optional[Data] = None) -> asyncio.Future: + """ + Send a ping. + Return an :class:`~asyncio.Future` that will be resolved when the + corresponding pong is received. You can ignore it if you don't intend + to wait. + A ping may serve as a keepalive or as a check that the remote endpoint + received all messages up to this point:: + await pong_event = ws.ping() + await pong_event # only if you want to wait for the pong + By default, the ping contains four random bytes. This payload may be + overridden with the optional ``data`` argument which must be a string + (which will be encoded to UTF-8) or a bytes-like object. + """ + async with self.conn_mutex: + + # TODO HANDLE THE SITUATION WHERE THE CONNECTION IS CLOSED + + if data is not None: + data = prepare_ctrl(data) + + # Protect against duplicates if a payload is explicitly set. + if data in self.pings: + raise ValueError("already waiting for a pong with the same data") + + # Generate a unique random payload otherwise. + while data is None or data in self.pings: + data = struct.pack("!I", random.getrandbits(32)) + + self.pings[data] = self.io_proto.loop.create_future() + + self.connection.send_ping(data) + await self.send_data(self.connection.data_to_send()) + + return asyncio.shield(self.pings[data]) + + async def pong(self, data: Data = b"") -> None: + """ + Send a pong. + An unsolicited pong may serve as a unidirectional heartbeat. + The payload may be set with the optional ``data`` argument which must + be a string (which will be encoded to UTF-8) or a bytes-like object. 
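+        For example, with ``ws`` being an instance of this protocol::
+
+            await ws.pong(b"keepalive")
+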
+ """ + async with self.conn_mutex: + # TODO HANDLE THE SITUATION WHERE THE CONNECTION IS CLOSED + + data = prepare_ctrl(data) + + self.connection.send_pong(data) + await self.send_data(self.connection.data_to_send()) + + async def send_data(self, data_to_send): + for data in data_to_send: + if data: + await self.io_proto.send(data) + else: + # Send an EOF + # We don't actually send it, just close the connection + if self.close_connection_task is not None and not self.close_connection_task.done() and \ + self.data_finished_fut is not None and not self.data_finished_fut.done(): + # Auto-close the connection + self.data_finished_fut.set_result(None) + else: + # This will fail the connection appropriately + self.io_proto.close() + + async def async_data_received(self, data_to_send, events_to_process): + if len(data_to_send) > 0: + # receiving data can generate data to send (eg, pong for a ping) + # send connection.data_to_send() + await self.send_data(data_to_send) + if len(events_to_process) > 0: + await self.process_events(events_to_process) + + def data_received(self, data): + self.connection.receive_data(data) + data_to_send = self.connection.data_to_send() + events_to_process = self.connection.events_received() + if len(data_to_send) > 0 or len(events_to_process) > 0: + asyncio.create_task(self.async_data_received(data_to_send, events_to_process)) + + async def async_eof_received(self, data_to_send, events_to_process): + # receiving EOF can generate data to send + # send connection.data_to_send() + await self.send_data(data_to_send) + if len(events_to_process) > 0: + await self.process_events(events_to_process) + + if self.close_connection_task is not None and not self.close_connection_task.done() and \ + self.data_finished_fut is not None and not self.data_finished_fut.done(): + # Auto-close the connection + self.data_finished_fut.set_result(None) + else: + # This will fail the connection appropriately + self.io_proto.close() + + def eof_received(self) -> Optional[bool]: + self.connection.receive_eof() + data_to_send = self.connection.data_to_send() + events_to_process = self.connection.events_received() + if len(data_to_send) > 0 or len(events_to_process) > 0: + asyncio.create_task(self.async_eof_received(data_to_send, events_to_process)) + return False + + def connection_lost(self, exc): + """ + The WebSocket Connection is Closed. + """ + self.connection.set_state(CLOSED) + self.abort_pings() + if self.connection_lost_waiter is not None: + self.connection_lost_waiter.set_result(None) class WebSocketProtocol(HttpProtocol): @@ -40,11 +768,12 @@ def __init__( **kwargs ): super().__init__(*args, **kwargs) - self.websocket = None + self.websocket = None # type: Union[None, WebsocketImplProtocol] # self.app = None self.websocket_timeout = websocket_timeout self.websocket_max_size = websocket_max_size - self.websocket_max_queue = websocket_max_queue + if websocket_max_queue is not None and int(websocket_max_queue) > 0: + error_logger.warning(DeprecationWarning("websocket_max_queue is no longer used. 
No websocket message queueing is implemented.")) self.websocket_read_limit = websocket_read_limit self.websocket_write_limit = websocket_write_limit self.websocket_ping_interval = websocket_ping_interval @@ -70,73 +799,94 @@ def connection_lost(self, exc): def data_received(self, data): if self.websocket is not None: - # pass the data to the websocket protocol self.websocket.data_received(data) else: - try: - super().data_received(data) - except HttpParserUpgrade: - # this is okay, it just indicates we've got an upgrade request - pass + # Pass it to HttpProtocol handler first + # That will (hopefully) upgrade it to a websocket. + super().data_received(data) - def write_response(self, response): + def eof_received(self) -> Optional[bool]: if self.websocket is not None: - # websocket requests do not write a response - self.transport.close() + return self.websocket.eof_received() else: - super().write_response(response) + return False - async def websocket_handshake(self, request, subprotocols=None): - # let the websockets package do the handshake with the client - headers = {} + def close(self): + # Called by HttpProtocol at the end of connection_task + # If we've upgraded to websocket, we do our own closure + if self.websocket is not None: + self.websocket.fail_connection(1001) + else: + super().close() + + def close_if_idle(self): + # Called by Sanic Server when shutting down + # If we've upgraded to websocket, shut it down + if self.websocket is not None: + if self.websocket.connection.state in (CLOSING, CLOSED): + return True + else: + return self.websocket.fail_connection(1001) + else: + return super().close_if_idle() + async def websocket_handshake(self, request, subprotocols=Optional[Sequence[str]]): + # let the websockets package do the handshake with the client + headers = {"Upgrade": "websocket", "Connection": "Upgrade"} try: - key = handshake.check_request(request.headers) - handshake.build_response(headers, key) - except InvalidHandshake: - raise InvalidUsage("Invalid websocket request") - - subprotocol = None - if subprotocols and "Sec-Websocket-Protocol" in request.headers: - # select a subprotocol - client_subprotocols = [ - p.strip() - for p in request.headers["Sec-Websocket-Protocol"].split(",") - ] - for p in client_subprotocols: - if p in subprotocols: - subprotocol = p - headers["Sec-Websocket-Protocol"] = subprotocol - break + if subprotocols is not None: + # subprotocols can be a set or frozenset, but ServerConnection needs a list + subprotocols = list(subprotocols) + ws_server = ServerConnection(max_size=self.websocket_max_size, subprotocols=subprotocols, + state=OPEN, logger=error_logger) + key, extensions_header, protocol_header = ws_server.process_request(request) + except InvalidOrigin as exc: + raise Forbidden( + f"Failed to open a WebSocket connection: {exc}.\n", + ) + except InvalidUpgrade as exc: + msg = ( + f"Failed to open a WebSocket connection: {exc}.\n" + f"\n" + f"You cannot access a WebSocket server directly " + f"with a browser. 
You need a WebSocket client.\n" + ) + raise SanicException(msg, status_code=426) + except InvalidHandshake as exc: + raise InvalidUsage(f"Failed to open a WebSocket connection: {exc}.\n") + except Exception as exc: + msg = ( + "Failed to open a WebSocket connection.\n" + "See server log for more information.\n" + ) + raise SanicException(msg, status_code=500) + + headers["Sec-WebSocket-Accept"] = accept_key(key) + if extensions_header is not None: + headers["Sec-WebSocket-Extensions"] = extensions_header + + if protocol_header is not None: + headers["Sec-WebSocket-Protocol"] = protocol_header + headers["Date"] = formatdate(usegmt=True) # write the 101 response back to the client rv = b"HTTP/1.1 101 Switching Protocols\r\n" for k, v in headers.items(): rv += k.encode("utf-8") + b": " + v.encode("utf-8") + b"\r\n" rv += b"\r\n" - request.transport.write(rv) - - # hook up the websocket protocol - self.websocket = WebSocketCommonProtocol( - close_timeout=self.websocket_timeout, - max_size=self.websocket_max_size, - max_queue=self.websocket_max_queue, - read_limit=self.websocket_read_limit, - write_limit=self.websocket_write_limit, - ping_interval=self.websocket_ping_interval, - ping_timeout=self.websocket_ping_timeout, - ) - # Following two lines are required for websockets 8.x - self.websocket.is_client = False - self.websocket.side = "server" - self.websocket.subprotocol = subprotocol - self.websocket.connection_made(request.transport) - self.websocket.connection_open() + await super().send(rv) + self.websocket = WebsocketImplProtocol(ws_server, ping_interval=self.websocket_ping_interval, ping_timeout=self.websocket_ping_timeout) + loop = request.transport.loop if hasattr(request, "transport") and hasattr(request.transport, "loop") else None + await self.websocket.connection_made(self, loop=loop) return self.websocket - + class WebSocketConnection: - + """ + This is for ASGI Connections. + It provides an interface similar to WebsocketProtocol, but + sends/receives over an ASGI connection. 
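+    Handler code can stay the same for either transport; a minimal sketch of
+    an echo handler (route and handler names are illustrative only)::
+
+        @app.websocket("/echo")
+        async def echo(request, ws):
+            data = await ws.recv()
+            await ws.send(data)
+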
+ """ # TODO # - Implement ping/pong @@ -180,5 +930,5 @@ async def accept(self) -> None: } ) - async def close(self) -> None: + async def close(self, code: int = 1000, reason: str = "") -> None: pass diff --git a/sanic/worker.py b/sanic/worker.py index 342900e6b1..0ab6ab0cb5 100644 --- a/sanic/worker.py +++ b/sanic/worker.py @@ -140,7 +140,7 @@ async def close(self): coros = [] for conn in self.connections: if hasattr(conn, "websocket") and conn.websocket: - coros.append(conn.websocket.close_connection()) + coros.append(conn.websocket.close(code=1001)) else: conn.close() _shutdown = asyncio.gather(*coros, loop=self.loop) diff --git a/tests/test_app.py b/tests/test_app.py index 187267da8a..b081727f2e 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -177,7 +177,6 @@ async def handler(request, ws): @patch("sanic.app.WebSocketProtocol") def test_app_websocket_parameters(websocket_protocol_mock, app): app.config.WEBSOCKET_MAX_SIZE = 44 - app.config.WEBSOCKET_MAX_QUEUE = 45 app.config.WEBSOCKET_READ_LIMIT = 46 app.config.WEBSOCKET_WRITE_LIMIT = 47 app.config.WEBSOCKET_PING_TIMEOUT = 48 @@ -196,7 +195,6 @@ async def handler(request, ws): websocket_protocol_call_args = websocket_protocol_mock.call_args ws_kwargs = websocket_protocol_call_args[1] assert ws_kwargs["websocket_max_size"] == app.config.WEBSOCKET_MAX_SIZE - assert ws_kwargs["websocket_max_queue"] == app.config.WEBSOCKET_MAX_QUEUE assert ws_kwargs["websocket_read_limit"] == app.config.WEBSOCKET_READ_LIMIT assert ( ws_kwargs["websocket_write_limit"] == app.config.WEBSOCKET_WRITE_LIMIT diff --git a/tests/test_routes.py b/tests/test_routes.py index 06b4d799d2..3b43cc7b2c 100644 --- a/tests/test_routes.py +++ b/tests/test_routes.py @@ -668,28 +668,28 @@ async def handler(request, ws): def test_websocket_route_with_subprotocols(app): results = [] - @app.websocket("/ws", subprotocols=["foo", "bar"]) + @app.websocket("/ws", subprotocols=["zero", "one", "two", "three"]) async def handler(request, ws): results.append(ws.subprotocol) assert ws.subprotocol is not None - _, response = SanicTestClient(app).websocket("/ws", subprotocols=["bar"]) + _, response = SanicTestClient(app).websocket("/ws", subprotocols=["one"]) assert response.opened is True - assert results == ["bar"] + assert results == ["one"] _, response = SanicTestClient(app).websocket( - "/ws", subprotocols=["bar", "foo"] + "/ws", subprotocols=["three", "one"] ) assert response.opened is True - assert results == ["bar", "bar"] + assert results == ["one", "one"] - _, response = SanicTestClient(app).websocket("/ws", subprotocols=["baz"]) + _, response = SanicTestClient(app).websocket("/ws", subprotocols=["tree"]) assert response.opened is True - assert results == ["bar", "bar", None] + assert results == ["one", "one", None] _, response = SanicTestClient(app).websocket("/ws") assert response.opened is True - assert results == ["bar", "bar", None, None] + assert results == ["one", "one", None, None] @pytest.mark.parametrize("strict_slashes", [True, False, None]) diff --git a/tests/test_worker.py b/tests/test_worker.py index 252bdb3662..2db02b50bb 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -175,7 +175,7 @@ def test_worker_close(worker): worker.wsgi = mock.Mock() conn = mock.Mock() conn.websocket = mock.Mock() - conn.websocket.close_connection = mock.Mock(wraps=_a_noop) + conn.websocket.close = mock.Mock(wraps=_a_noop) worker.connections = set([conn]) worker.log = mock.Mock() worker.loop = loop @@ -190,5 +190,5 @@ def test_worker_close(worker): 
loop.run_until_complete(_close) assert worker.signal.stopped - assert conn.websocket.close_connection.called + assert conn.websocket.close.called assert len(worker.servers) == 0 From 4e9d98452cf2e165c8167df9e1e9902beb6162b1 Mon Sep 17 00:00:00 2001 From: Ashley Sommer Date: Wed, 16 Jun 2021 15:39:00 +1000 Subject: [PATCH 02/16] Update sanic/websocket.py Co-authored-by: Adam Hopkins --- sanic/websocket.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sanic/websocket.py b/sanic/websocket.py index 1cb080ad3e..c87859e852 100644 --- a/sanic/websocket.py +++ b/sanic/websocket.py @@ -281,7 +281,7 @@ async def connection_made(self, io_proto: asyncio.Protocol, loop=None): except AttributeError: loop = asyncio.get_event_loop() self.loop = loop - self.io_proto = io_proto # this will be a WebSocketProtocol + self.io_proto: WebSocketProtocol = io_proto self.connection_lost_waiter = self.loop.create_future() self.data_finished_fut = asyncio.shield(self.loop.create_future()) From 2c8f750c87977304381286a6f1c59b627330b8dd Mon Sep 17 00:00:00 2001 From: Ashley Sommer Date: Wed, 16 Jun 2021 15:40:09 +1000 Subject: [PATCH 03/16] Update sanic/websocket.py Co-authored-by: Adam Hopkins --- sanic/websocket.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sanic/websocket.py b/sanic/websocket.py index c87859e852..1ddd50d2c0 100644 --- a/sanic/websocket.py +++ b/sanic/websocket.py @@ -228,9 +228,9 @@ async def put(self, frame: Frame) -> None: class WebsocketImplProtocol: def __init__(self, connection, max_queue=None, ping_interval: Optional[float] = 20, ping_timeout: Optional[float] = 20, loop=None): - self.connection = connection # type: ServerConnection - self.io_proto = None # type: Optional[SanicProtocol] - self.loop = None # type: Optional[asyncio.BaseEventLoop] + self.connection: ServerConnection = connection + self.io_proto: Optional[SanicProtocol] = None + self.loop: Optional[asyncio.BaseEventLoop] = None self.max_queue = max_queue self.ping_interval = ping_interval self.ping_timeout = ping_timeout @@ -239,12 +239,12 @@ def __init__(self, connection, max_queue=None, ping_interval: Optional[float] = self.conn_mutex = asyncio.Lock() self.recv_lock = asyncio.Lock() self.process_event_mutex = asyncio.Lock() - self.data_finished_fut = None # type: Optional[asyncio.Future[None]] + self.data_finished_fut: Optional[asyncio.Future[None]] = None self.can_pause = True - self.pause_frame_fut = None # type: Optional[asyncio.Future[None]] + self.pause_frame_fut: Optional[asyncio.Future[None]] = None self.keepalive_ping_task = None self.close_connection_task = None - self.connection_lost_waiter = None # type: Optional[asyncio.Future[None]] + self.connection_lost_waiter: Optional[asyncio.Future[None]] = None @property def subprotocol(self): From e2d019871e954d0ec19dcf253b8dedaca5eb7a48 Mon Sep 17 00:00:00 2001 From: Ashley Sommer Date: Wed, 16 Jun 2021 15:43:09 +1000 Subject: [PATCH 04/16] Update sanic/websocket.py Co-authored-by: Adam Hopkins --- sanic/websocket.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sanic/websocket.py b/sanic/websocket.py index 1ddd50d2c0..bd962bbbee 100644 --- a/sanic/websocket.py +++ b/sanic/websocket.py @@ -768,7 +768,7 @@ def __init__( **kwargs ): super().__init__(*args, **kwargs) - self.websocket = None # type: Union[None, WebsocketImplProtocol] + self.websocket: Union[None, WebsocketImplProtocol] = None # self.app = None self.websocket_timeout = websocket_timeout self.websocket_max_size = websocket_max_size From 
cc2082c1dfbd74a58e0e0fbd541d82de0b8149fb Mon Sep 17 00:00:00 2001 From: Ashley Sommer Date: Wed, 1 Sep 2021 10:44:23 +1000 Subject: [PATCH 05/16] wip, update websockets code to new Sans/IO API --- sanic/server/runners.py | 1 - sanic/websocket.py | 137 ++++++++++++++++++++-------------------- 2 files changed, 69 insertions(+), 69 deletions(-) diff --git a/sanic/server/runners.py b/sanic/server/runners.py index c28c525ee7..68e7eef182 100644 --- a/sanic/server/runners.py +++ b/sanic/server/runners.py @@ -278,7 +278,6 @@ def _build_protocol_kwargs( if hasattr(protocol, "websocket_handshake"): return { "websocket_max_size": config.WEBSOCKET_MAX_SIZE, - "websocket_max_queue": config.WEBSOCKET_MAX_QUEUE, "websocket_read_limit": config.WEBSOCKET_READ_LIMIT, "websocket_write_limit": config.WEBSOCKET_WRITE_LIMIT, "websocket_ping_timeout": config.WEBSOCKET_PING_TIMEOUT, diff --git a/sanic/websocket.py b/sanic/websocket.py index 8b33182580..f97966057b 100644 --- a/sanic/websocket.py +++ b/sanic/websocket.py @@ -12,7 +12,8 @@ Union, Iterable, Sequence, - Mapping, AsyncIterator + Mapping, AsyncIterator, + TYPE_CHECKING ) import codecs @@ -26,12 +27,13 @@ from websockets.utils import accept_key from sanic.exceptions import InvalidUsage, Forbidden, SanicException -from sanic.server import HttpProtocol, SanicProtocol +from sanic.server import HttpProtocol +from sanic.server.protocols.base_protocol import SanicProtocol +from sanic.response import BaseHTTPResponse from sanic.log import error_logger, logger import asyncio - ASIMessage = MutableMapping[str, Any] UTF8Decoder = codecs.getincrementaldecoder("utf-8") @@ -42,6 +44,20 @@ class WebsocketFrameAssembler: https://github.com/aaugustin/websockets/blob/6eb98dd8fa5b2c896b9f6be7e8d117708da82a39/src/websockets/sync/messages.py """ __slots__ = ("protocol", "read_mutex", "write_mutex", "message_complete", "message_fetched", "get_in_progress", "decoder", "completed_queue", "chunks", "chunks_queue", "paused", "get_id", "put_id") + if TYPE_CHECKING: + protocol: "WebsocketImplProtocol" + read_mutex: asyncio.Lock + write_mutex: asyncio.Lock + message_complete: asyncio.Event + message_fetched: asyncio.Event + completed_queue: asyncio.Queue + get_in_progress: bool + decoder: Optional[codecs.IncrementalDecoder] + # For streaming chunks rather than messages: + chunks: List[Data] + chunks_queue: Optional[asyncio.Queue[Optional[Data]]] + paused: bool + def __init__(self, protocol) -> None: @@ -62,10 +78,10 @@ def __init__(self, protocol) -> None: self.get_in_progress = False # Decoder for text frames, None for binary frames. - self.decoder: Optional[codecs.IncrementalDecoder] = None + self.decoder = None # Buffer data from frames belonging to the same message. - self.chunks: List[Data] = [] + self.chunks = [] # When switching from "buffering" to "streaming", we use a thread-safe # queue for transferring frames from the writing thread (library code) @@ -74,7 +90,7 @@ def __init__(self, protocol) -> None: # value marking the end of the stream, superseding message_complete. # Stream data from frames belonging to the same message. - self.chunks_queue: Optional[asyncio.Queue[Optional[Data]]] = None + self.chunks_queue = None # Flag to indicate we've paused the protocol self.paused = False @@ -185,8 +201,8 @@ async def put(self, frame: Frame) -> None: :meth:`put` assumes that the stream of frames respects the protocol. If it doesn't, the behavior is undefined. 
""" - id = self.put_id - self.put_id += 1 + #id = self.put_id + #self.put_id += 1 async with self.write_mutex: if frame.opcode is Opcode.TEXT: self.decoder = UTF8Decoder(errors="strict") @@ -226,25 +242,45 @@ async def put(self, frame: Frame) -> None: self.message_fetched.clear() self.decoder = None + class WebsocketImplProtocol: + connection: ServerConnection + io_proto: Optional[SanicProtocol] + loop: Optional[asyncio.BaseEventLoop] + max_queue: int + ping_interval: Optional[float] + ping_timeout: Optional[float] + assembler: WebsocketFrameAssembler + pings: Dict[bytes, asyncio.Future] # Dict[bytes, asyncio.Future[None]] + conn_mutex: asyncio.Lock + recv_lock: asyncio.Lock + process_event_mutex: asyncio.Lock + can_pause: bool + data_finished_fut: Optional[asyncio.Future] # Optional[asyncio.Future[None]] + pause_frame_fut: Optional[asyncio.Future] # Optional[asyncio.Future[None]] + connection_lost_waiter: Optional[asyncio.Future] # Optional[asyncio.Future[None]] + keepalive_ping_task: Optional[asyncio.Task] + close_connection_task: Optional[asyncio.Task] + + def __init__(self, connection, max_queue=None, ping_interval: Optional[float] = 20, ping_timeout: Optional[float] = 20, loop=None): - self.connection: ServerConnection = connection - self.io_proto: Optional[SanicProtocol] = None - self.loop: Optional[asyncio.BaseEventLoop] = None + self.connection = connection + self.io_proto = None + self.loop = None self.max_queue = max_queue self.ping_interval = ping_interval self.ping_timeout = ping_timeout self.assembler = WebsocketFrameAssembler(self) - self.pings: Dict[bytes, asyncio.Future[None]] = {} + self.pings = {} self.conn_mutex = asyncio.Lock() self.recv_lock = asyncio.Lock() self.process_event_mutex = asyncio.Lock() - self.data_finished_fut: Optional[asyncio.Future[None]] = None + self.data_finished_fut = None self.can_pause = True - self.pause_frame_fut: Optional[asyncio.Future[None]] = None + self.pause_frame_fut = None self.keepalive_ping_task = None self.close_connection_task = None - self.connection_lost_waiter: Optional[asyncio.Future[None]] = None + self.connection_lost_waiter = None @property def subprotocol(self): @@ -274,14 +310,14 @@ def resume_frames(self): self.pause_frame_fut = None return True - async def connection_made(self, io_proto: asyncio.Protocol, loop=None): + async def connection_made(self, io_proto: SanicProtocol, loop=None): if loop is None: try: loop = getattr(io_proto, "loop") except AttributeError: loop = asyncio.get_event_loop() self.loop = loop - self.io_proto: WebSocketProtocol = io_proto + self.io_proto = io_proto # this will be a WebSocketProtocol self.connection_lost_waiter = self.loop.create_future() self.data_finished_fut = asyncio.shield(self.loop.create_future()) @@ -626,19 +662,15 @@ async def send(self, message: Union[Data, Iterable[Data]]) -> None: self.connection.send_binary(message) await self.send_data(self.connection.data_to_send()) - # Catch a common mistake -- passing a dict to send(). - elif isinstance(message, Mapping): + # Catch a common mistake -- passing a dict to send(). raise TypeError("data is a dict-like object") - # Fragmented message -- regular iterator. - elif isinstance(message, Iterable): - # TODO use an incremental encoder maybe? - raise NotImplementedError - + # Fragmented message -- regular iterator. 
+ raise NotImplementedError("Fragmented websocket messages are not supported.") else: - raise TypeError("data must be bytes, str, or iterable") + raise TypeError("Websocket data must be bytes, str.") async def ping(self, data: Optional[Data] = None) -> asyncio.Future: """ @@ -768,7 +800,7 @@ def __init__( **kwargs, ): super().__init__(*args, **kwargs) - self.websocket: Union[None, WebsocketImplProtocol] = None + self.websocket = None # type: Union[None, WebsocketImplProtocol] # self.app = None self.websocket_timeout = websocket_timeout self.websocket_max_size = websocket_max_size @@ -779,19 +811,6 @@ def __init__( self.websocket_ping_interval = websocket_ping_interval self.websocket_ping_timeout = websocket_ping_timeout - # timeouts make no sense for websocket routes - def request_timeout_callback(self): - if self.websocket is None: - super().request_timeout_callback() - - def response_timeout_callback(self): - if self.websocket is None: - super().response_timeout_callback() - - def keep_alive_timeout_callback(self): - if self.websocket is None: - super().keep_alive_timeout_callback() - def connection_lost(self, exc): if self.websocket is not None: self.websocket.connection_lost(exc) @@ -839,42 +858,24 @@ async def websocket_handshake(self, request, subprotocols=Optional[Sequence[str] subprotocols = list(subprotocols) ws_server = ServerConnection(max_size=self.websocket_max_size, subprotocols=subprotocols, state=OPEN, logger=error_logger) - key, extensions_header, protocol_header = ws_server.process_request(request) - except InvalidOrigin as exc: - raise Forbidden( - f"Failed to open a WebSocket connection: {exc}.\n", - ) - except InvalidUpgrade as exc: - msg = ( - f"Failed to open a WebSocket connection: {exc}.\n" - f"\n" - f"You cannot access a WebSocket server directly " - f"with a browser. 
You need a WebSocket client.\n" - ) - raise SanicException(msg, status_code=426) - except InvalidHandshake as exc: - raise InvalidUsage(f"Failed to open a WebSocket connection: {exc}.\n") + resp = ws_server.accept(request) # type: websockets.http11.Response except Exception as exc: msg = ( "Failed to open a WebSocket connection.\n" "See server log for more information.\n" ) raise SanicException(msg, status_code=500) + if 100 <= resp.status_code <= 299: + rbytes = b"".join([b"HTTP/1.1 ", b'%d' % resp.status_code, b" ", resp.reason_phrase.encode("utf-8"), b"\r\n"]) + for k, v in resp.headers.items(): + rbytes += k.encode("utf-8") + b": " + v.encode("utf-8") + b"\r\n" + if resp.body: + rbytes += b"\r\n" + resp.body + b"\r\n" + rbytes += b"\r\n" + await super().send(rbytes) + else: + raise SanicException(resp.body, resp.status_code) - headers["Sec-WebSocket-Accept"] = accept_key(key) - - if extensions_header is not None: - headers["Sec-WebSocket-Extensions"] = extensions_header - - if protocol_header is not None: - headers["Sec-WebSocket-Protocol"] = protocol_header - headers["Date"] = formatdate(usegmt=True) - # write the 101 response back to the client - rv = b"HTTP/1.1 101 Switching Protocols\r\n" - for k, v in headers.items(): - rv += k.encode("utf-8") + b": " + v.encode("utf-8") + b"\r\n" - rv += b"\r\n" - await super().send(rv) self.websocket = WebsocketImplProtocol(ws_server, ping_interval=self.websocket_ping_interval, ping_timeout=self.websocket_ping_timeout) loop = request.transport.loop if hasattr(request, "transport") and hasattr(request.transport, "loop") else None await self.websocket.connection_made(self, loop=loop) From 2436780ea330e0a890d4ec7429b3dc73ad5687c2 Mon Sep 17 00:00:00 2001 From: Ashley Sommer Date: Wed, 1 Sep 2021 17:12:39 +1000 Subject: [PATCH 06/16] Refactored new websockets impl into own modules Incorporated other suggestions made by team --- sanic/app.py | 3 +- sanic/asgi.py | 2 +- sanic/models/asgi.py | 2 +- sanic/server/protocols/websocket_protocol.py | 113 +++++ sanic/server/runners.py | 2 +- sanic/server/websockets/__init__.py | 0 sanic/server/websockets/connection.py | 70 +++ sanic/server/websockets/frame.py | 214 ++++++++ .../websockets/impl.py} | 470 ++---------------- sanic/worker.py | 4 +- tests/test_asgi.py | 2 +- 11 files changed, 449 insertions(+), 433 deletions(-) create mode 100644 sanic/server/protocols/websocket_protocol.py create mode 100644 sanic/server/websockets/__init__.py create mode 100644 sanic/server/websockets/connection.py create mode 100644 sanic/server/websockets/frame.py rename sanic/{websocket.py => server/websockets/impl.py} (59%) diff --git a/sanic/app.py b/sanic/app.py index 225f244c34..42e503c34e 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -76,7 +76,8 @@ from sanic.server import serve, serve_multiple, serve_single from sanic.signals import Signal, SignalRouter from sanic.touchup import TouchUp, TouchUpMeta -from sanic.websocket import ConnectionClosed, WebSocketProtocol +from sanic.server.protocols.websocket_protocol import WebSocketProtocol +from sanic.server.websockets.impl import ConnectionClosed class Sanic(BaseSanic, metaclass=TouchUpMeta): diff --git a/sanic/asgi.py b/sanic/asgi.py index 13d4f87c20..55c18d5cf5 100644 --- a/sanic/asgi.py +++ b/sanic/asgi.py @@ -10,7 +10,7 @@ from sanic.models.asgi import ASGIReceive, ASGIScope, ASGISend, MockTransport from sanic.request import Request from sanic.server import ConnInfo -from sanic.websocket import WebSocketConnection +from sanic.server.websockets.connection import 
WebSocketConnection class Lifespan: diff --git a/sanic/models/asgi.py b/sanic/models/asgi.py index 595b05532a..1b707ebc03 100644 --- a/sanic/models/asgi.py +++ b/sanic/models/asgi.py @@ -3,7 +3,7 @@ from typing import Any, Awaitable, Callable, MutableMapping, Optional, Union from sanic.exceptions import InvalidUsage -from sanic.websocket import WebSocketConnection +from sanic.server.websockets.connection import WebSocketConnection ASGIScope = MutableMapping[str, Any] diff --git a/sanic/server/protocols/websocket_protocol.py b/sanic/server/protocols/websocket_protocol.py new file mode 100644 index 0000000000..7faeec852c --- /dev/null +++ b/sanic/server/protocols/websocket_protocol.py @@ -0,0 +1,113 @@ +from typing import ( + Optional, + Union, + Sequence, + TYPE_CHECKING +) + +from httptools import HttpParserUpgrade # type: ignore +from websockets.server import ServerConnection +from websockets.connection import OPEN, CLOSING, CLOSED + +from sanic.exceptions import SanicException +from sanic.server import HttpProtocol +from sanic.log import error_logger +from ..websockets.impl import WebsocketImplProtocol + +if TYPE_CHECKING: + from websockets import http11 + +class WebSocketProtocol(HttpProtocol): + def __init__( + self, + *args, + websocket_timeout=10, + websocket_max_size=None, + websocket_max_queue=None, + websocket_read_limit=2 ** 16, + websocket_write_limit=2 ** 16, + websocket_ping_interval=20, + websocket_ping_timeout=20, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.websocket = None # type: Union[None, WebsocketImplProtocol] + # self.app = None + self.websocket_timeout = websocket_timeout + self.websocket_max_size = websocket_max_size + if websocket_max_queue is not None and int(websocket_max_queue) > 0: + error_logger.warning(DeprecationWarning("websocket_max_queue is no longer used. No websocket message queueing is implemented.")) + self.websocket_read_limit = websocket_read_limit + self.websocket_write_limit = websocket_write_limit + self.websocket_ping_interval = websocket_ping_interval + self.websocket_ping_timeout = websocket_ping_timeout + + def connection_lost(self, exc): + if self.websocket is not None: + self.websocket.connection_lost(exc) + super().connection_lost(exc) + + def data_received(self, data): + if self.websocket is not None: + self.websocket.data_received(data) + else: + # Pass it to HttpProtocol handler first + # That will (hopefully) upgrade it to a websocket. 
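+            # Once websocket_handshake() completes, self.websocket is set and
+            # later reads on this transport bypass the HTTP parser entirely.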
+ super().data_received(data) + + def eof_received(self) -> Optional[bool]: + if self.websocket is not None: + return self.websocket.eof_received() + else: + return False + + def close(self, timeout: Optional[float] = None): + # Called by HttpProtocol at the end of connection_task + # If we've upgraded to websocket, we do our own closing + if self.websocket is not None: + self.websocket.fail_connection(1001) + else: + super().close() + + def close_if_idle(self): + # Called by Sanic Server when shutting down + # If we've upgraded to websocket, shut it down + if self.websocket is not None: + if self.websocket.connection.state in (CLOSING, CLOSED): + return True + else: + return self.websocket.fail_connection(1001) + else: + return super().close_if_idle() + + async def websocket_handshake(self, request, subprotocols=Optional[Sequence[str]]): + # let the websockets package do the handshake with the client + headers = {"Upgrade": "websocket", "Connection": "Upgrade"} + try: + if subprotocols is not None: + # subprotocols can be a set or frozenset, but ServerConnection needs a list + subprotocols = list(subprotocols) + ws_conn = ServerConnection(max_size=self.websocket_max_size, subprotocols=subprotocols, + state=OPEN, logger=error_logger) + resp: "http11.Response" = ws_conn.accept(request) + except Exception as exc: + msg = ( + "Failed to open a WebSocket connection.\n" + "See server log for more information.\n" + ) + raise SanicException(msg, status_code=500) + if 100 <= resp.status_code <= 299: + rbytes = b"".join([b"HTTP/1.1 ", b'%d' % resp.status_code, b" ", resp.reason_phrase.encode("utf-8"), b"\r\n"]) + for k, v in resp.headers.items(): + rbytes += k.encode("utf-8") + b": " + v.encode("utf-8") + b"\r\n" + if resp.body: + rbytes += b"\r\n" + resp.body + b"\r\n" + rbytes += b"\r\n" + await super().send(rbytes) + else: + raise SanicException(resp.body, resp.status_code) + + self.websocket = WebsocketImplProtocol(ws_conn, ping_interval=self.websocket_ping_interval, ping_timeout=self.websocket_ping_timeout) + loop = request.transport.loop if hasattr(request, "transport") and hasattr(request.transport, "loop") else None + await self.websocket.connection_made(self, loop=loop) + return self.websocket diff --git a/sanic/server/runners.py b/sanic/server/runners.py index b40808a9c4..8edc6214bc 100644 --- a/sanic/server/runners.py +++ b/sanic/server/runners.py @@ -178,7 +178,7 @@ def serve( coros = [] for conn in connections: if hasattr(conn, "websocket") and conn.websocket: - coros.append(conn.websocket.close_connection()) + coros.append(conn.websocket.close(code=1001)) else: conn.abort() diff --git a/sanic/server/websockets/__init__.py b/sanic/server/websockets/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sanic/server/websockets/connection.py b/sanic/server/websockets/connection.py new file mode 100644 index 0000000000..95bd21abdd --- /dev/null +++ b/sanic/server/websockets/connection.py @@ -0,0 +1,70 @@ +from typing import Optional, List, Callable, Awaitable, Union, Dict, MutableMapping, Any + +ASIMessage = MutableMapping[str, Any] + +class WebSocketConnection: + """ + This is for ASGI Connections. + It provides an interface similar to WebsocketProtocol, but + sends/receives over an ASGI connection. 
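+    The opening handshake is performed by the ASGI server itself; this
+    class only maps send()/recv() calls onto ASGI ``websocket.*`` messages.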
+ """ + # TODO + # - Implement ping/pong + + def __init__( + self, + send: Callable[[ASIMessage], Awaitable[None]], + receive: Callable[[], Awaitable[ASIMessage]], + subprotocols: Optional[List[str]] = None, + ) -> None: + self._send = send + self._receive = receive + self._subprotocols = subprotocols or [] + + async def send(self, data: Union[str, bytes], *args, **kwargs) -> None: + message: Dict[str, Union[str, bytes]] = {"type": "websocket.send"} + + if isinstance(data, bytes): + message.update({"bytes": data}) + else: + message.update({"text": str(data)}) + + await self._send(message) + + async def recv(self, *args, **kwargs) -> Optional[str]: + message = await self._receive() + + if message["type"] == "websocket.receive": + return message["text"] + elif message["type"] == "websocket.disconnect": + pass + + return None + + receive = recv + + async def accept(self, subprotocols: Optional[List[str]] = None) -> None: + subprotocol = None + if subprotocols: + for subp in subprotocols: + if subp in self.subprotocols: + subprotocol = subp + break + + await self._send( + { + "type": "websocket.accept", + "subprotocol": subprotocol, + } + ) + + async def close(self, code: int = 1000, reason: str = "") -> None: + pass + + @property + def subprotocols(self): + return self._subprotocols + + @subprotocols.setter + def subprotocols(self, subprotocols: Optional[List[str]] = None): + self._subprotocols = subprotocols or [] diff --git a/sanic/server/websockets/frame.py b/sanic/server/websockets/frame.py new file mode 100644 index 0000000000..a28269bcce --- /dev/null +++ b/sanic/server/websockets/frame.py @@ -0,0 +1,214 @@ +import asyncio +import codecs +from typing import TYPE_CHECKING, Optional, List, AsyncIterator + +from websockets.frames import Opcode, Frame +from websockets.typing import Data + + +UTF8Decoder = codecs.getincrementaldecoder("utf-8") + +class WebsocketFrameAssembler: + """ + Assemble a message from frames. + Code borrowed from aaugustin/websockets project: + https://github.com/aaugustin/websockets/blob/6eb98dd8fa5b2c896b9f6be7e8d117708da82a39/src/websockets/sync/messages.py + """ + __slots__ = ("protocol", "read_mutex", "write_mutex", "message_complete", "message_fetched", "get_in_progress", "decoder", "completed_queue", "chunks", "chunks_queue", "paused", "get_id", "put_id") + if TYPE_CHECKING: + protocol: "WebsocketImplProtocol" + read_mutex: asyncio.Lock + write_mutex: asyncio.Lock + message_complete: asyncio.Event + message_fetched: asyncio.Event + completed_queue: asyncio.Queue + get_in_progress: bool + decoder: Optional[codecs.IncrementalDecoder] + # For streaming chunks rather than messages: + chunks: List[Data] + chunks_queue: Optional[asyncio.Queue[Optional[Data]]] + paused: bool + + + def __init__(self, protocol) -> None: + + self.protocol = protocol + + self.read_mutex = asyncio.Lock() + self.write_mutex = asyncio.Lock() + + self.completed_queue = asyncio.Queue(maxsize=1) # type: asyncio.Queue[Data] + + + # put() sets this event to tell get() that a message can be fetched. + self.message_complete = asyncio.Event() + # get() sets this event to let put() + self.message_fetched = asyncio.Event() + + # This flag prevents concurrent calls to get() by user code. + self.get_in_progress = False + + # Decoder for text frames, None for binary frames. + self.decoder = None + + # Buffer data from frames belonging to the same message. 
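+        # Each entry is one frame's payload: decoded str for text messages,
+        # raw bytes for binary ones. get() joins them into a single message.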
+ self.chunks = [] + + # When switching from "buffering" to "streaming", we use a thread-safe + # queue for transferring frames from the writing thread (library code) + # to the reading thread (user code). We're buffering when chunks_queue + # is None and streaming when it's a Queue. None is a sentinel + # value marking the end of the stream, superseding message_complete. + + # Stream data from frames belonging to the same message. + self.chunks_queue = None + + # Flag to indicate we've paused the protocol + self.paused = False + + + async def get(self, timeout: Optional[float] = None) -> Optional[Data]: + """ + Read the next message. + :meth:`get` returns a single :class:`str` or :class:`bytes`. + If the :message was fragmented, :meth:`get` waits until the last frame + is received, then it reassembles the message. + If ``timeout`` is set and elapses before a complete message is + received, :meth:`get` returns ``None``. + """ + async with self.read_mutex: + if timeout is not None and timeout <= 0: + if not self.message_complete.is_set(): + return None + assert not self.get_in_progress + self.get_in_progress = True + + # If the message_complete event isn't set yet, release the lock to + # allow put() to run and eventually set it. + # Locking with get_in_progress ensures only one thread can get here. + if timeout is None: + completed = await self.message_complete.wait() + elif timeout <= 0: + completed = self.message_complete.is_set() + else: + completed = await asyncio.wait_for(self.message_complete.wait(), timeout=timeout) + + # Unpause the transport, if its paused + if self.paused: + self.protocol.resume_frames() + self.paused = False + assert self.get_in_progress + self.get_in_progress = False + + # Waiting for a complete message timed out. + if not completed: + return None + + assert self.message_complete.is_set() + self.message_complete.clear() + + joiner: Data = b"" if self.decoder is None else "" + # mypy cannot figure out that chunks have the proper type. + message: Data = joiner.join(self.chunks) # type: ignore + + assert not self.message_fetched.is_set() + self.message_fetched.set() + self.chunks = [] + assert self.chunks_queue is None + + return message + + async def get_iter(self) -> AsyncIterator[Data]: + """ + Stream the next message. + Iterating the return value of :meth:`get_iter` yields a :class:`str` + or :class:`bytes` for each frame in the message. + """ + async with self.read_mutex: + assert not self.get_in_progress + self.get_in_progress = True + + chunks = self.chunks + self.chunks = [] + self.chunks_queue = asyncio.Queue() + + # Sending None in chunk_queue supersedes setting message_complete + # when switching to "streaming". If message is already complete + # when the switch happens, put() didn't send None, so we have to. + if self.message_complete.is_set(): + await self.chunks_queue.put(None) + + # Locking with get_in_progress ensures only one thread can get here. + for c in chunks: + yield c + while True: + chunk = await self.chunks_queue.get() + if chunk is None: + break + yield chunk + + # Unpause the transport, if its paused + if self.paused: + self.protocol.resume_frames() + self.paused = False + assert self.get_in_progress + self.get_in_progress = False + assert self.message_complete.is_set() + self.message_complete.clear() + + assert not self.message_fetched.is_set() + + self.message_fetched.set() + + assert self.chunks == [] + self.chunks_queue = None + + async def put(self, frame: Frame) -> None: + """ + Add ``frame`` to the next message. 
+ When ``frame`` is the final frame in a message, :meth:`put` waits + until the message is fetched, either by calling :meth:`get` or by + iterating the return value of :meth:`get_iter`. + :meth:`put` assumes that the stream of frames respects the protocol. + If it doesn't, the behavior is undefined. + """ + #id = self.put_id + #self.put_id += 1 + async with self.write_mutex: + if frame.opcode is Opcode.TEXT: + self.decoder = UTF8Decoder(errors="strict") + elif frame.opcode is Opcode.BINARY: + self.decoder = None + elif frame.opcode is Opcode.CONT: + pass + else: + # Ignore control frames. + return + data: Data + if self.decoder is not None: + data = self.decoder.decode(frame.data, frame.fin) + else: + data = frame.data + if self.chunks_queue is None: + self.chunks.append(data) + else: + await self.chunks_queue.put(data) + + if not frame.fin: + return + if not self.get_in_progress: + self.paused = self.protocol.pause_frames() + # Message is complete. Wait until it's fetched to return. + + if self.chunks_queue is not None: + await self.chunks_queue.put(None) + + assert not self.message_complete.is_set() + self.message_complete.set() + assert not self.message_fetched.is_set() + + # Release the lock to allow get() to run and eventually set the event. + await self.message_fetched.wait() + assert self.message_fetched.is_set() + self.message_fetched.clear() + self.decoder = None diff --git a/sanic/websocket.py b/sanic/server/websockets/impl.py similarity index 59% rename from sanic/websocket.py rename to sanic/server/websockets/impl.py index f97966057b..c05a6bb63f 100644 --- a/sanic/websocket.py +++ b/sanic/server/websockets/impl.py @@ -1,247 +1,17 @@ +import asyncio import random import struct -from email.utils import formatdate -from typing import ( - Any, - Awaitable, - Callable, - Dict, - List, - MutableMapping, - Optional, - Union, - Iterable, - Sequence, - Mapping, AsyncIterator, - TYPE_CHECKING -) -import codecs - -from httptools import HttpParserUpgrade # type: ignore +from typing import Optional, Mapping, Iterable, Union, AsyncIterator, Sequence, Dict + +from websockets.connection import CLOSED, OPEN, Event, CLOSING +from websockets.exceptions import ConnectionClosedError, ConnectionClosed +from websockets.frames import prepare_ctrl, Frame, OP_PONG from websockets.server import ServerConnection -from websockets.connection import Event, OPEN, CLOSING, CLOSED -from websockets.exceptions import ConnectionClosed, InvalidHandshake, InvalidOrigin, InvalidUpgrade, \ - ConnectionClosedError from websockets.typing import Data -from websockets.frames import Frame, Opcode, prepare_ctrl, OP_PONG -from websockets.utils import accept_key -from sanic.exceptions import InvalidUsage, Forbidden, SanicException -from sanic.server import HttpProtocol +from sanic.log import error_logger from sanic.server.protocols.base_protocol import SanicProtocol -from sanic.response import BaseHTTPResponse -from sanic.log import error_logger, logger - -import asyncio - -ASIMessage = MutableMapping[str, Any] -UTF8Decoder = codecs.getincrementaldecoder("utf-8") - -class WebsocketFrameAssembler: - """ - Assemble a message from frames. 
- Code borrowed from aaugustin/websockets project: - https://github.com/aaugustin/websockets/blob/6eb98dd8fa5b2c896b9f6be7e8d117708da82a39/src/websockets/sync/messages.py - """ - __slots__ = ("protocol", "read_mutex", "write_mutex", "message_complete", "message_fetched", "get_in_progress", "decoder", "completed_queue", "chunks", "chunks_queue", "paused", "get_id", "put_id") - if TYPE_CHECKING: - protocol: "WebsocketImplProtocol" - read_mutex: asyncio.Lock - write_mutex: asyncio.Lock - message_complete: asyncio.Event - message_fetched: asyncio.Event - completed_queue: asyncio.Queue - get_in_progress: bool - decoder: Optional[codecs.IncrementalDecoder] - # For streaming chunks rather than messages: - chunks: List[Data] - chunks_queue: Optional[asyncio.Queue[Optional[Data]]] - paused: bool - - - def __init__(self, protocol) -> None: - - self.protocol = protocol - - self.read_mutex = asyncio.Lock() - self.write_mutex = asyncio.Lock() - - self.completed_queue = asyncio.Queue(maxsize=1) # type: asyncio.Queue[Data] - - - # put() sets this event to tell get() that a message can be fetched. - self.message_complete = asyncio.Event() - # get() sets this event to let put() - self.message_fetched = asyncio.Event() - - # This flag prevents concurrent calls to get() by user code. - self.get_in_progress = False - - # Decoder for text frames, None for binary frames. - self.decoder = None - - # Buffer data from frames belonging to the same message. - self.chunks = [] - - # When switching from "buffering" to "streaming", we use a thread-safe - # queue for transferring frames from the writing thread (library code) - # to the reading thread (user code). We're buffering when chunks_queue - # is None and streaming when it's a Queue. None is a sentinel - # value marking the end of the stream, superseding message_complete. - - # Stream data from frames belonging to the same message. - self.chunks_queue = None - - # Flag to indicate we've paused the protocol - self.paused = False - - - async def get(self, timeout: Optional[float] = None) -> Optional[Data]: - """ - Read the next message. - :meth:`get` returns a single :class:`str` or :class:`bytes`. - If the :message was fragmented, :meth:`get` waits until the last frame - is received, then it reassembles the message. - If ``timeout`` is set and elapses before a complete message is - received, :meth:`get` returns ``None``. - """ - async with self.read_mutex: - if timeout is not None and timeout <= 0: - if not self.message_complete.is_set(): - return None - assert not self.get_in_progress - self.get_in_progress = True - - # If the message_complete event isn't set yet, release the lock to - # allow put() to run and eventually set it. - # Locking with get_in_progress ensures only one thread can get here. - if timeout is None: - completed = await self.message_complete.wait() - elif timeout <= 0: - completed = self.message_complete.is_set() - else: - completed = await asyncio.wait_for(self.message_complete.wait(), timeout=timeout) - - # Unpause the transport, if its paused - if self.paused: - self.protocol.resume_frames() - self.paused = False - assert self.get_in_progress - self.get_in_progress = False - - # Waiting for a complete message timed out. - if not completed: - return None - - assert self.message_complete.is_set() - self.message_complete.clear() - - joiner: Data = b"" if self.decoder is None else "" - # mypy cannot figure out that chunks have the proper type. 
- message: Data = joiner.join(self.chunks) # type: ignore - - assert not self.message_fetched.is_set() - self.message_fetched.set() - self.chunks = [] - assert self.chunks_queue is None - - return message - - async def get_iter(self) -> AsyncIterator[Data]: - """ - Stream the next message. - Iterating the return value of :meth:`get_iter` yields a :class:`str` - or :class:`bytes` for each frame in the message. - """ - async with self.read_mutex: - assert not self.get_in_progress - self.get_in_progress = True - - chunks = self.chunks - self.chunks = [] - self.chunks_queue = asyncio.Queue() - - # Sending None in chunk_queue supersedes setting message_complete - # when switching to "streaming". If message is already complete - # when the switch happens, put() didn't send None, so we have to. - if self.message_complete.is_set(): - await self.chunks_queue.put(None) - - # Locking with get_in_progress ensures only one thread can get here. - for c in chunks: - yield c - while True: - chunk = await self.chunks_queue.get() - if chunk is None: - break - yield chunk - - # Unpause the transport, if its paused - if self.paused: - self.protocol.resume_frames() - self.paused = False - assert self.get_in_progress - self.get_in_progress = False - assert self.message_complete.is_set() - self.message_complete.clear() - - assert not self.message_fetched.is_set() - - self.message_fetched.set() - - assert self.chunks == [] - self.chunks_queue = None - - async def put(self, frame: Frame) -> None: - """ - Add ``frame`` to the next message. - When ``frame`` is the final frame in a message, :meth:`put` waits - until the message is fetched, either by calling :meth:`get` or by - iterating the return value of :meth:`get_iter`. - :meth:`put` assumes that the stream of frames respects the protocol. - If it doesn't, the behavior is undefined. - """ - #id = self.put_id - #self.put_id += 1 - async with self.write_mutex: - if frame.opcode is Opcode.TEXT: - self.decoder = UTF8Decoder(errors="strict") - elif frame.opcode is Opcode.BINARY: - self.decoder = None - elif frame.opcode is Opcode.CONT: - pass - else: - # Ignore control frames. - return - data: Data - if self.decoder is not None: - data = self.decoder.decode(frame.data, frame.fin) - else: - data = frame.data - if self.chunks_queue is None: - self.chunks.append(data) - else: - await self.chunks_queue.put(data) - - if not frame.fin: - return - if not self.get_in_progress: - self.paused = self.protocol.pause_frames() - # Message is complete. Wait until it's fetched to return. - - if self.chunks_queue is not None: - await self.chunks_queue.put(None) - - assert not self.message_complete.is_set() - self.message_complete.set() - assert not self.message_fetched.is_set() - - # Release the lock to allow get() to run and eventually set the event. 
- await self.message_fetched.wait() - assert self.message_fetched.is_set() - self.message_fetched.clear() - self.decoder = None - +from .frame import WebsocketFrameAssembler class WebsocketImplProtocol: connection: ServerConnection @@ -252,6 +22,7 @@ class WebsocketImplProtocol: ping_timeout: Optional[float] assembler: WebsocketFrameAssembler pings: Dict[bytes, asyncio.Future] # Dict[bytes, asyncio.Future[None]] + pings: Dict[bytes, asyncio.Future] # Dict[bytes, asyncio.Future[None]] conn_mutex: asyncio.Lock recv_lock: asyncio.Lock process_event_mutex: asyncio.Lock @@ -334,7 +105,9 @@ async def wait_for_connection_lost(self, timeout=10) -> bool: """ if self.connection_lost_waiter is None: return False - if not self.connection_lost_waiter.done(): + if self.connection_lost_waiter.done(): + return True + else: try: await asyncio.wait_for( asyncio.shield(self.connection_lost_waiter), timeout @@ -370,7 +143,7 @@ async def process_pong(self, frame: Frame) -> None: ping.set_result(None) if ping_id == frame.data: break - else: # pragma: no cover + else: # noqa assert False, "ping_id is in self.pings" # Remove acknowledged pings from self.pings. for ping_id in ping_ids: @@ -410,7 +183,7 @@ async def keepalive_ping(self) -> None: raise except ConnectionClosed: pass - except BaseException: + except Exception: error_logger.warning("Unexpected exception in keepalive ping task") def fail_connection(self, code: int = 1006, reason: str = "") -> bool: @@ -443,12 +216,17 @@ def fail_connection(self, code: int = 1006, reason: str = "") -> bool: self.connection.fail_connection(code, reason) for frame_data in self.connection.data_to_send(): self.io_proto.transport.write(frame_data) + if code == 1006: + # Special case: 1006 consider the transport already closed + self.connection.set_state(CLOSED) if self.close_connection_task is not None and not self.close_connection_task.done(): if self.data_finished_fut is not None and not self.data_finished_fut.done(): self.data_finished_fut.cancel() # Don't close, auto_close_connection will take care of it. return False - SanicProtocol.close(self.io_proto) + + # No auto-closer available, just abort the connection + SanicProtocol.close(self.io_proto, timeout=1.0) return True async def auto_close_connection(self) -> None: @@ -485,14 +263,12 @@ async def auto_close_connection(self) -> None: if self.connection_lost_waiter is not None and self.connection_lost_waiter.done(): if self.io_proto.transport is None or self.io_proto.transport.is_closing(): return - SanicProtocol.close(self.io_proto) + SanicProtocol.close(self.io_proto, timeout=999) if self.connection_lost_waiter is not None: - await self.wait_for_connection_lost() - if self.connection_lost_waiter.done(): + if await self.wait_for_connection_lost(timeout=5): return error_logger.warning("Timeout waiting for TCP connection to close. 
Aborting") - if self.io_proto.transport is not None: - self.io_proto.transport.abort() + SanicProtocol.abort(self.io_proto) def abort_pings(self) -> None: """ @@ -545,14 +321,14 @@ async def recv(self, timeout: Optional[float] = None) -> Optional[Data]: :meth:`recv_streaming` concurrently """ - # TODO HANDLE THE SITUATION WHERE THE CONNECTION IS CLOSED - if self.recv_lock.locked(): raise RuntimeError( "cannot call recv while another task " "is already waiting for the next message" ) await self.recv_lock.acquire() + if self.connection.state in (CLOSED, CLOSING): + raise RuntimeError("Cannot receive from websocket interface after it is closed.") try: return await self.assembler.get(timeout) finally: @@ -575,15 +351,14 @@ async def recv_burst(self, max_recv=256) -> Sequence[Data]: :meth:`recv_streaming` concurrently """ - # TODO HANDLE THE SITUATION WHERE THE CONNECTION IS CLOSED - if self.recv_lock.locked(): raise RuntimeError( "cannot call recv_burst while another task " "is already waiting for the next message" ) await self.recv_lock.acquire() - + if self.connection.state in (CLOSED, CLOSING): + raise RuntimeError("Cannot receive from websocket interface after it is closed.") messages = [] try: # Prevent pausing the transport when we're @@ -621,6 +396,8 @@ async def recv_streaming(self) -> AsyncIterator[Data]: "is already waiting for the next message" ) await self.recv_lock.acquire() + if self.connection.state in (CLOSED, CLOSING): + raise RuntimeError("Cannot receive from websocket interface after it is closed.") try: self.can_pause = False async for m in self.assembler.get_iter(): @@ -649,7 +426,10 @@ async def send(self, message: Union[Data, Iterable[Data]]) -> None: """ async with self.conn_mutex: - # TODO HANDLE THE SITUATION WHERE THE CONNECTION IS CLOSED + if self.connection.state in (CLOSED, CLOSING): + raise RuntimeError("Cannot write to websocket interface after it is closed.") + if self.data_finished_fut is None or self.data_finished_fut.cancelled() or self.data_finished_fut.done(): + raise RuntimeError("Cannot write to websocket interface after it is finished.") # Unfragmented message -- this case must be handled first because # strings and bytes-like objects are iterable. @@ -687,9 +467,8 @@ async def ping(self, data: Optional[Data] = None) -> asyncio.Future: (which will be encoded to UTF-8) or a bytes-like object. """ async with self.conn_mutex: - - # TODO HANDLE THE SITUATION WHERE THE CONNECTION IS CLOSED - + if self.connection.state in (CLOSED, CLOSING): + raise RuntimeError("Cannot send a ping when the websocket interface is closed.") if data is not None: data = prepare_ctrl(data) @@ -716,7 +495,9 @@ async def pong(self, data: Data = b"") -> None: be a string (which will be encoded to UTF-8) or a bytes-like object. 
""" async with self.conn_mutex: - # TODO HANDLE THE SITUATION WHERE THE CONNECTION IS CLOSED + if self.connection.state in (CLOSED, CLOSING): + # Cannot send pong after transport is shutting down + return data = prepare_ctrl(data) @@ -729,17 +510,17 @@ async def send_data(self, data_to_send): await self.io_proto.send(data) else: # Send an EOF - # We don't actually send it, just close the connection + # We don't actually send it, just trigger to autoclose the connection if self.close_connection_task is not None and not self.close_connection_task.done() and \ self.data_finished_fut is not None and not self.data_finished_fut.done(): # Auto-close the connection self.data_finished_fut.set_result(None) else: # This will fail the connection appropriately - self.io_proto.close() + SanicProtocol.close(self.io_proto, timeout=1.0) async def async_data_received(self, data_to_send, events_to_process): - if len(data_to_send) > 0: + if self.connection.state == OPEN and len(data_to_send) > 0: # receiving data can generate data to send (eg, pong for a ping) # send connection.data_to_send() await self.send_data(data_to_send) @@ -756,7 +537,8 @@ def data_received(self, data): async def async_eof_received(self, data_to_send, events_to_process): # receiving EOF can generate data to send # send connection.data_to_send() - await self.send_data(data_to_send) + if self.connection.state == OPEN: + await self.send_data(data_to_send) if len(events_to_process) > 0: await self.process_events(events_to_process) @@ -766,7 +548,7 @@ async def async_eof_received(self, data_to_send, events_to_process): self.data_finished_fut.set_result(None) else: # This will fail the connection appropriately - self.io_proto.close() + SanicProtocol.close(self.io_proto, timeout=1.0) def eof_received(self) -> Optional[bool]: self.connection.receive_eof() @@ -784,167 +566,3 @@ def connection_lost(self, exc): self.abort_pings() if self.connection_lost_waiter is not None: self.connection_lost_waiter.set_result(None) - - -class WebSocketProtocol(HttpProtocol): - def __init__( - self, - *args, - websocket_timeout=10, - websocket_max_size=None, - websocket_max_queue=None, - websocket_read_limit=2 ** 16, - websocket_write_limit=2 ** 16, - websocket_ping_interval=20, - websocket_ping_timeout=20, - **kwargs, - ): - super().__init__(*args, **kwargs) - self.websocket = None # type: Union[None, WebsocketImplProtocol] - # self.app = None - self.websocket_timeout = websocket_timeout - self.websocket_max_size = websocket_max_size - if websocket_max_queue is not None and int(websocket_max_queue) > 0: - error_logger.warning(DeprecationWarning("websocket_max_queue is no longer used. No websocket message queueing is implemented.")) - self.websocket_read_limit = websocket_read_limit - self.websocket_write_limit = websocket_write_limit - self.websocket_ping_interval = websocket_ping_interval - self.websocket_ping_timeout = websocket_ping_timeout - - def connection_lost(self, exc): - if self.websocket is not None: - self.websocket.connection_lost(exc) - super().connection_lost(exc) - - def data_received(self, data): - if self.websocket is not None: - self.websocket.data_received(data) - else: - # Pass it to HttpProtocol handler first - # That will (hopefully) upgrade it to a websocket. 
- super().data_received(data) - - def eof_received(self) -> Optional[bool]: - if self.websocket is not None: - return self.websocket.eof_received() - else: - return False - - def close(self): - # Called by HttpProtocol at the end of connection_task - # If we've upgraded to websocket, we do our own closure - if self.websocket is not None: - self.websocket.fail_connection(1001) - else: - super().close() - - def close_if_idle(self): - # Called by Sanic Server when shutting down - # If we've upgraded to websocket, shut it down - if self.websocket is not None: - if self.websocket.connection.state in (CLOSING, CLOSED): - return True - else: - return self.websocket.fail_connection(1001) - else: - return super().close_if_idle() - - async def websocket_handshake(self, request, subprotocols=Optional[Sequence[str]]): - # let the websockets package do the handshake with the client - headers = {"Upgrade": "websocket", "Connection": "Upgrade"} - try: - if subprotocols is not None: - # subprotocols can be a set or frozenset, but ServerConnection needs a list - subprotocols = list(subprotocols) - ws_server = ServerConnection(max_size=self.websocket_max_size, subprotocols=subprotocols, - state=OPEN, logger=error_logger) - resp = ws_server.accept(request) # type: websockets.http11.Response - except Exception as exc: - msg = ( - "Failed to open a WebSocket connection.\n" - "See server log for more information.\n" - ) - raise SanicException(msg, status_code=500) - if 100 <= resp.status_code <= 299: - rbytes = b"".join([b"HTTP/1.1 ", b'%d' % resp.status_code, b" ", resp.reason_phrase.encode("utf-8"), b"\r\n"]) - for k, v in resp.headers.items(): - rbytes += k.encode("utf-8") + b": " + v.encode("utf-8") + b"\r\n" - if resp.body: - rbytes += b"\r\n" + resp.body + b"\r\n" - rbytes += b"\r\n" - await super().send(rbytes) - else: - raise SanicException(resp.body, resp.status_code) - - self.websocket = WebsocketImplProtocol(ws_server, ping_interval=self.websocket_ping_interval, ping_timeout=self.websocket_ping_timeout) - loop = request.transport.loop if hasattr(request, "transport") and hasattr(request.transport, "loop") else None - await self.websocket.connection_made(self, loop=loop) - return self.websocket - - -class WebSocketConnection: - """ - This is for ASGI Connections. - It provides an interface similar to WebsocketProtocol, but - sends/receives over an ASGI connection. 
- """ - # TODO - # - Implement ping/pong - - def __init__( - self, - send: Callable[[ASIMessage], Awaitable[None]], - receive: Callable[[], Awaitable[ASIMessage]], - subprotocols: Optional[List[str]] = None, - ) -> None: - self._send = send - self._receive = receive - self._subprotocols = subprotocols or [] - - async def send(self, data: Union[str, bytes], *args, **kwargs) -> None: - message: Dict[str, Union[str, bytes]] = {"type": "websocket.send"} - - if isinstance(data, bytes): - message.update({"bytes": data}) - else: - message.update({"text": str(data)}) - - await self._send(message) - - async def recv(self, *args, **kwargs) -> Optional[str]: - message = await self._receive() - - if message["type"] == "websocket.receive": - return message["text"] - elif message["type"] == "websocket.disconnect": - pass - - return None - - receive = recv - - async def accept(self, subprotocols: Optional[List[str]] = None) -> None: - subprotocol = None - if subprotocols: - for subp in subprotocols: - if subp in self.subprotocols: - subprotocol = subp - break - - await self._send( - { - "type": "websocket.accept", - "subprotocol": subprotocol, - } - ) - - async def close(self, code: int = 1000, reason: str = "") -> None: - pass - - @property - def subprotocols(self): - return self._subprotocols - - @subprotocols.setter - def subprotocols(self, subprotocols: Optional[List[str]] = None): - self._subprotocols = subprotocols or [] diff --git a/sanic/worker.py b/sanic/worker.py index a196a95bc0..875ebc3cfc 100644 --- a/sanic/worker.py +++ b/sanic/worker.py @@ -9,7 +9,7 @@ from sanic.log import logger from sanic.server import HttpProtocol, Signal, serve -from sanic.websocket import WebSocketProtocol +from sanic.server.protocols.websocket_protocol import WebSocketProtocol try: @@ -147,7 +147,7 @@ async def close(self): if hasattr(conn, "websocket") and conn.websocket: coros.append(conn.websocket.close(code=1001)) else: - conn.close() + conn.abort() _shutdown = asyncio.gather(*coros, loop=self.loop) await _shutdown diff --git a/tests/test_asgi.py b/tests/test_asgi.py index 0745c2edd7..3d464a4f55 100644 --- a/tests/test_asgi.py +++ b/tests/test_asgi.py @@ -10,7 +10,7 @@ from sanic.exceptions import Forbidden, InvalidUsage, ServiceUnavailable from sanic.request import Request from sanic.response import json, text -from sanic.websocket import WebSocketConnection +from sanic.server.websockets.connection import WebSocketConnection @pytest.fixture From aea3538ea11d97ba00ce3f9e85750d06cd8082f1 Mon Sep 17 00:00:00 2001 From: Ashley Sommer Date: Wed, 15 Sep 2021 15:01:49 +1000 Subject: [PATCH 07/16] Another round of work on the new websockets impl * Added websocket_timeout support (matching previous/legacy support) * Lots more comments * Incorporated suggested changes from previous round of review * Changed RuntimeError usage to ServerError * Changed SanicException usage to ServerError * Removed some redundant asserts * Change remaining asserts to ServerErrors * Fixed some timeout handling issues * Fixed websocket.close() handling, and made it more robust * Made auto_close task smarter and more error-resilient * Made fail_connection routine smarter and more error-resilient --- sanic/app.py | 4 +- sanic/mixins/routes.py | 2 +- sanic/server/protocols/websocket_protocol.py | 131 +++++--- sanic/server/runners.py | 6 +- sanic/server/websockets/connection.py | 14 +- sanic/server/websockets/frame.py | 125 ++++++-- sanic/server/websockets/impl.py | 318 ++++++++++++++----- sanic/worker.py | 5 +- 8 files changed, 433 
insertions(+), 172 deletions(-) diff --git a/sanic/app.py b/sanic/app.py index d036aceba6..effaabe86a 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -74,10 +74,10 @@ from sanic.server import AsyncioServer, HttpProtocol from sanic.server import Signal as ServerSignal from sanic.server import serve, serve_multiple, serve_single -from sanic.signals import Signal, SignalRouter -from sanic.touchup import TouchUp, TouchUpMeta from sanic.server.protocols.websocket_protocol import WebSocketProtocol from sanic.server.websockets.impl import ConnectionClosed +from sanic.signals import Signal, SignalRouter +from sanic.touchup import TouchUp, TouchUpMeta class Sanic(BaseSanic, metaclass=TouchUpMeta): diff --git a/sanic/mixins/routes.py b/sanic/mixins/routes.py index 114808cf8a..1f37ef872b 100644 --- a/sanic/mixins/routes.py +++ b/sanic/mixins/routes.py @@ -124,7 +124,7 @@ def decorator(handler): if isinstance(subprotocols, list): # Ordered subprotocols, maintain order subprotocols = tuple(subprotocols) - if isinstance(subprotocols, set): + elif isinstance(subprotocols, set): # subprotocol is unordered, keep it unordered subprotocols = frozenset(subprotocols) diff --git a/sanic/server/protocols/websocket_protocol.py b/sanic/server/protocols/websocket_protocol.py index 7faeec852c..77288cc633 100644 --- a/sanic/server/protocols/websocket_protocol.py +++ b/sanic/server/protocols/websocket_protocol.py @@ -1,44 +1,65 @@ -from typing import ( - Optional, - Union, - Sequence, - TYPE_CHECKING -) +from typing import TYPE_CHECKING, Optional, Sequence, Union from httptools import HttpParserUpgrade # type: ignore +from websockets.connection import CLOSED, CLOSING, OPEN from websockets.server import ServerConnection -from websockets.connection import OPEN, CLOSING, CLOSED -from sanic.exceptions import SanicException -from sanic.server import HttpProtocol +from sanic.exceptions import ServerError from sanic.log import error_logger +from sanic.server import HttpProtocol + from ..websockets.impl import WebsocketImplProtocol + if TYPE_CHECKING: from websockets import http11 + class WebSocketProtocol(HttpProtocol): + + websocket: Union[None, WebsocketImplProtocol] + websocket_timeout: float + websocket_max_size = Union[None, int] + websocket_ping_interval = Union[None, float] + websocket_ping_timeout = Union[None, float] + def __init__( self, *args, - websocket_timeout=10, - websocket_max_size=None, - websocket_max_queue=None, - websocket_read_limit=2 ** 16, - websocket_write_limit=2 ** 16, - websocket_ping_interval=20, - websocket_ping_timeout=20, + websocket_timeout: Optional[float] = 10.0, + websocket_max_size: Optional[int] = None, + websocket_max_queue: Optional[int] = None, # max_queue is deprecated + websocket_read_limit: Optional[int] = 2 ** 16, + websocket_write_limit: Optional[int] = 2 ** 16, + websocket_ping_interval: Optional[float] = 20.0, + websocket_ping_timeout: Optional[float] = 20.0, **kwargs, ): super().__init__(*args, **kwargs) - self.websocket = None # type: Union[None, WebsocketImplProtocol] - # self.app = None + self.websocket = None self.websocket_timeout = websocket_timeout self.websocket_max_size = websocket_max_size if websocket_max_queue is not None and int(websocket_max_queue) > 0: - error_logger.warning(DeprecationWarning("websocket_max_queue is no longer used. 
No websocket message queueing is implemented.")) - self.websocket_read_limit = websocket_read_limit - self.websocket_write_limit = websocket_write_limit + error_logger.warning( + DeprecationWarning( + "websocket_max_queue is no longer used. No websocket message queueing is implemented." + ) + ) + if websocket_read_limit is not None and int(websocket_read_limit) > 0: + error_logger.warning( + DeprecationWarning( + "websocket_read_limit is no longer used. No websocket rate limiting is implemented." + ) + ) + if ( + websocket_write_limit is not None + and int(websocket_write_limit) > 0 + ): + error_logger.warning( + DeprecationWarning( + "websocket_write_limit is no longer used. No websocket rate limiting is implemented." + ) + ) self.websocket_ping_interval = websocket_ping_interval self.websocket_ping_timeout = websocket_ping_timeout @@ -65,7 +86,11 @@ def close(self, timeout: Optional[float] = None): # Called by HttpProtocol at the end of connection_task # If we've upgraded to websocket, we do our own closing if self.websocket is not None: - self.websocket.fail_connection(1001) + if self.websocket.loop is not None: + ... + self.websocket.loop.create_task(self.websocket.close(1001)) + else: + self.websocket.fail_connection(1001) else: super().close() @@ -75,39 +100,65 @@ def close_if_idle(self): if self.websocket is not None: if self.websocket.connection.state in (CLOSING, CLOSED): return True + elif self.websocket.loop is not None: + self.websocket.loop.create_task(self.websocket.close(1001)) else: - return self.websocket.fail_connection(1001) + self.websocket.fail_connection(1001) else: return super().close_if_idle() - async def websocket_handshake(self, request, subprotocols=Optional[Sequence[str]]): + async def websocket_handshake( + self, request, subprotocols=Optional[Sequence[str]] + ): # let the websockets package do the handshake with the client headers = {"Upgrade": "websocket", "Connection": "Upgrade"} try: if subprotocols is not None: # subprotocols can be a set or frozenset, but ServerConnection needs a list subprotocols = list(subprotocols) - ws_conn = ServerConnection(max_size=self.websocket_max_size, subprotocols=subprotocols, - state=OPEN, logger=error_logger) + ws_conn = ServerConnection( + max_size=self.websocket_max_size, + subprotocols=subprotocols, + state=OPEN, + logger=error_logger, + ) resp: "http11.Response" = ws_conn.accept(request) except Exception as exc: msg = ( - "Failed to open a WebSocket connection.\n" - "See server log for more information.\n" - ) - raise SanicException(msg, status_code=500) + "Failed to open a WebSocket connection.\n" + "See server log for more information.\n" + ) + raise ServerError(msg, status_code=500) if 100 <= resp.status_code <= 299: - rbytes = b"".join([b"HTTP/1.1 ", b'%d' % resp.status_code, b" ", resp.reason_phrase.encode("utf-8"), b"\r\n"]) - for k, v in resp.headers.items(): - rbytes += k.encode("utf-8") + b": " + v.encode("utf-8") + b"\r\n" - if resp.body: - rbytes += b"\r\n" + resp.body + b"\r\n" - rbytes += b"\r\n" - await super().send(rbytes) + rbody = "".join( + [ + "HTTP/1.1 ", + str(resp.status_code), + " ", + resp.reason_phrase, + "\r\n", + ] + ) + rbody += "".join(f"{k}: {v}\r\n" for k, v in resp.headers.items()) + if resp.body is not None: + rbody += f"\r\n{resp.body}\r\n\r\n" + else: + rbody += "\r\n" + await super().send(rbody.encode()) else: - raise SanicException(resp.body, resp.status_code) + raise ServerError(resp.body, resp.status_code) - self.websocket = WebsocketImplProtocol(ws_conn, 
ping_interval=self.websocket_ping_interval, ping_timeout=self.websocket_ping_timeout) - loop = request.transport.loop if hasattr(request, "transport") and hasattr(request.transport, "loop") else None + self.websocket = WebsocketImplProtocol( + ws_conn, + ping_interval=self.websocket_ping_interval, + ping_timeout=self.websocket_ping_timeout, + close_timeout=self.websocket_timeout, + ) + loop = ( + request.transport.loop + if hasattr(request, "transport") + and hasattr(request.transport, "loop") + else None + ) await self.websocket.connection_made(self, loop=loop) return self.websocket diff --git a/sanic/server/runners.py b/sanic/server/runners.py index 8edc6214bc..0b1305a38a 100644 --- a/sanic/server/runners.py +++ b/sanic/server/runners.py @@ -175,15 +175,11 @@ def serve( # Force close non-idle connection after waiting for # graceful_shutdown_timeout - coros = [] for conn in connections: if hasattr(conn, "websocket") and conn.websocket: - coros.append(conn.websocket.close(code=1001)) + conn.websocket.fail_connection(code=1001) else: conn.abort() - - _shutdown = asyncio.gather(*coros) - loop.run_until_complete(_shutdown) loop.run_until_complete(app._server_event("shutdown", "after")) remove_unix_socket(unix) diff --git a/sanic/server/websockets/connection.py b/sanic/server/websockets/connection.py index 95bd21abdd..c53a65a58d 100644 --- a/sanic/server/websockets/connection.py +++ b/sanic/server/websockets/connection.py @@ -1,13 +1,25 @@ -from typing import Optional, List, Callable, Awaitable, Union, Dict, MutableMapping, Any +from typing import ( + Any, + Awaitable, + Callable, + Dict, + List, + MutableMapping, + Optional, + Union, +) + ASIMessage = MutableMapping[str, Any] + class WebSocketConnection: """ This is for ASGI Connections. It provides an interface similar to WebsocketProtocol, but sends/receives over an ASGI connection. """ + # TODO # - Implement ping/pong diff --git a/sanic/server/websockets/frame.py b/sanic/server/websockets/frame.py index a28269bcce..bbe74fcf33 100644 --- a/sanic/server/websockets/frame.py +++ b/sanic/server/websockets/frame.py @@ -1,20 +1,39 @@ import asyncio import codecs -from typing import TYPE_CHECKING, Optional, List, AsyncIterator -from websockets.frames import Opcode, Frame +from typing import TYPE_CHECKING, AsyncIterator, List, Optional + +from websockets.frames import Frame, Opcode from websockets.typing import Data +from sanic.exceptions import ServerError + UTF8Decoder = codecs.getincrementaldecoder("utf-8") + class WebsocketFrameAssembler: """ Assemble a message from frames. 
Code borrowed from aaugustin/websockets project: https://github.com/aaugustin/websockets/blob/6eb98dd8fa5b2c896b9f6be7e8d117708da82a39/src/websockets/sync/messages.py """ - __slots__ = ("protocol", "read_mutex", "write_mutex", "message_complete", "message_fetched", "get_in_progress", "decoder", "completed_queue", "chunks", "chunks_queue", "paused", "get_id", "put_id") + + __slots__ = ( + "protocol", + "read_mutex", + "write_mutex", + "message_complete", + "message_fetched", + "get_in_progress", + "decoder", + "completed_queue", + "chunks", + "chunks_queue", + "paused", + "get_id", + "put_id", + ) if TYPE_CHECKING: protocol: "WebsocketImplProtocol" read_mutex: asyncio.Lock @@ -29,7 +48,6 @@ class WebsocketFrameAssembler: chunks_queue: Optional[asyncio.Queue[Optional[Data]]] paused: bool - def __init__(self, protocol) -> None: self.protocol = protocol @@ -37,8 +55,9 @@ def __init__(self, protocol) -> None: self.read_mutex = asyncio.Lock() self.write_mutex = asyncio.Lock() - self.completed_queue = asyncio.Queue(maxsize=1) # type: asyncio.Queue[Data] - + self.completed_queue = asyncio.Queue( + maxsize=1 + ) # type: asyncio.Queue[Data] # put() sets this event to tell get() that a message can be fetched. self.message_complete = asyncio.Event() @@ -66,7 +85,6 @@ def __init__(self, protocol) -> None: # Flag to indicate we've paused the protocol self.paused = False - async def get(self, timeout: Optional[float] = None) -> Optional[Data]: """ Read the next message. @@ -80,41 +98,63 @@ async def get(self, timeout: Optional[float] = None) -> Optional[Data]: if timeout is not None and timeout <= 0: if not self.message_complete.is_set(): return None - assert not self.get_in_progress + if self.get_in_progress: + # This should be guarded against with the read_mutex, exception is only here as a failsafe + raise ServerError( + "Called get() on Websocket frame assembler while asynchronous get is already in progress." + ) self.get_in_progress = True # If the message_complete event isn't set yet, release the lock to # allow put() to run and eventually set it. - # Locking with get_in_progress ensures only one thread can get here. + # Locking with get_in_progress ensures only one task can get here. if timeout is None: completed = await self.message_complete.wait() elif timeout <= 0: completed = self.message_complete.is_set() else: - completed = await asyncio.wait_for(self.message_complete.wait(), timeout=timeout) + try: + await asyncio.wait_for( + self.message_complete.wait(), timeout=timeout + ) + except asyncio.TimeoutError: + ... + finally: + completed = self.message_complete.is_set() # Unpause the transport, if its paused if self.paused: self.protocol.resume_frames() self.paused = False - assert self.get_in_progress + if not self.get_in_progress: + # This should be guarded against with the read_mutex, exception is here as a failsafe + raise ServerError( + "State of Websocket frame assembler was modified while an asynchronous get was in progress." + ) self.get_in_progress = False # Waiting for a complete message timed out. if not completed: return None + if not self.message_complete.is_set(): + return None - assert self.message_complete.is_set() self.message_complete.clear() joiner: Data = b"" if self.decoder is None else "" # mypy cannot figure out that chunks have the proper type. 
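+            # joiner matches the message type: bytes for binary frames,
+            # str for text frames, so join() yields the right Data type.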
message: Data = joiner.join(self.chunks) # type: ignore - - assert not self.message_fetched.is_set() + if self.message_fetched.is_set(): + # This should be guarded against with the read_mutex, and get_in_progress check, + # this exception is here as a failsafe + raise ServerError( + "Websocket get() found a message when state was already fetched." + ) self.message_fetched.set() self.chunks = [] - assert self.chunks_queue is None + self.chunks_queue = ( + None # this should already be None, but set it here for safety + ) return message @@ -125,7 +165,11 @@ async def get_iter(self) -> AsyncIterator[Data]: or :class:`bytes` for each frame in the message. """ async with self.read_mutex: - assert not self.get_in_progress + if self.get_in_progress: + # This should be guarded against with the read_mutex, exception is only here as a failsafe + raise ServerError( + "Called get_iter on Websocket frame assembler while asynchronous get is already in progress." + ) self.get_in_progress = True chunks = self.chunks @@ -151,16 +195,29 @@ async def get_iter(self) -> AsyncIterator[Data]: if self.paused: self.protocol.resume_frames() self.paused = False - assert self.get_in_progress + if not self.get_in_progress: + # This should be guarded against with the read_mutex, exception is here as a failsafe + raise ServerError( + "State of Websocket frame assembler was modified while an asynchronous get was in progress." + ) self.get_in_progress = False - assert self.message_complete.is_set() + if not self.message_complete.is_set(): + # This should be guarded against with the read_mutex, exception is here as a failsafe + raise ServerError( + "Websocket frame assembler chunks queue ended before message was complete." + ) self.message_complete.clear() - - assert not self.message_fetched.is_set() + if self.message_fetched.is_set(): + # This should be guarded against with the read_mutex, and get_in_progress check, + # this exception is here as a failsafe + raise ServerError( + "Websocket get_iter() found a message when state was already fetched." + ) self.message_fetched.set() - - assert self.chunks == [] + self.chunks = ( + [] + ) # this should already be empty, but set it here for safety self.chunks_queue = None async def put(self, frame: Frame) -> None: @@ -172,8 +229,7 @@ async def put(self, frame: Frame) -> None: :meth:`put` assumes that the stream of frames respects the protocol. If it doesn't, the behavior is undefined. """ - #id = self.put_id - #self.put_id += 1 + async with self.write_mutex: if frame.opcode is Opcode.TEXT: self.decoder = UTF8Decoder(errors="strict") @@ -197,18 +253,25 @@ async def put(self, frame: Frame) -> None: if not frame.fin: return if not self.get_in_progress: + # nobody is waiting for this frame, so try to pause subsequent frames at the protocol level self.paused = self.protocol.pause_frames() # Message is complete. Wait until it's fetched to return. if self.chunks_queue is not None: await self.chunks_queue.put(None) - - assert not self.message_complete.is_set() - self.message_complete.set() - assert not self.message_fetched.is_set() - - # Release the lock to allow get() to run and eventually set the event. + if self.message_complete.is_set(): + # This should be guarded against with the write_mutex + raise ServerError( + "Websocket put() got a new message when a message was already in its chamber." 
+ ) + self.message_complete.set() # Signal to get() it can serve the + if self.message_fetched.is_set(): + # This should be guarded against with the write_mutex + raise ServerError( + "Websocket put() got a new message when the previous message was not yet fetched." + ) + + # Allow get() to run and eventually set the event. await self.message_fetched.wait() - assert self.message_fetched.is_set() self.message_fetched.clear() self.decoder = None diff --git a/sanic/server/websockets/impl.py b/sanic/server/websockets/impl.py index c05a6bb63f..f8ec6f3e62 100644 --- a/sanic/server/websockets/impl.py +++ b/sanic/server/websockets/impl.py @@ -1,23 +1,41 @@ import asyncio import random import struct -from typing import Optional, Mapping, Iterable, Union, AsyncIterator, Sequence, Dict -from websockets.connection import CLOSED, OPEN, Event, CLOSING -from websockets.exceptions import ConnectionClosedError, ConnectionClosed -from websockets.frames import prepare_ctrl, Frame, OP_PONG +from typing import ( + TYPE_CHECKING, + AsyncIterator, + Dict, + Iterable, + Mapping, + Optional, + Sequence, + Union, +) + +from websockets.connection import CLOSED, CLOSING, OPEN, Event +from websockets.exceptions import ConnectionClosed, ConnectionClosedError +from websockets.frames import OP_PONG from websockets.server import ServerConnection from websockets.typing import Data from sanic.log import error_logger from sanic.server.protocols.base_protocol import SanicProtocol + +from ...exceptions import ServerError from .frame import WebsocketFrameAssembler + +if TYPE_CHECKING: + from websockets.frames import Frame + + class WebsocketImplProtocol: connection: ServerConnection io_proto: Optional[SanicProtocol] loop: Optional[asyncio.BaseEventLoop] max_queue: int + close_timeout: float ping_interval: Optional[float] ping_timeout: Optional[float] assembler: WebsocketFrameAssembler @@ -27,18 +45,30 @@ class WebsocketImplProtocol: recv_lock: asyncio.Lock process_event_mutex: asyncio.Lock can_pause: bool - data_finished_fut: Optional[asyncio.Future] # Optional[asyncio.Future[None]] + data_finished_fut: Optional[ + asyncio.Future + ] # Optional[asyncio.Future[None]] pause_frame_fut: Optional[asyncio.Future] # Optional[asyncio.Future[None]] - connection_lost_waiter: Optional[asyncio.Future] # Optional[asyncio.Future[None]] + connection_lost_waiter: Optional[ + asyncio.Future + ] # Optional[asyncio.Future[None]] keepalive_ping_task: Optional[asyncio.Task] - close_connection_task: Optional[asyncio.Task] - - - def __init__(self, connection, max_queue=None, ping_interval: Optional[float] = 20, ping_timeout: Optional[float] = 20, loop=None): + auto_closer_task: Optional[asyncio.Task] + + def __init__( + self, + connection, + max_queue=None, + ping_interval: Optional[float] = 20, + ping_timeout: Optional[float] = 20, + close_timeout: Optional[float] = 10, + loop=None, + ): self.connection = connection self.io_proto = None self.loop = None self.max_queue = max_queue + self.close_timeout = close_timeout self.ping_interval = ping_interval self.ping_timeout = ping_timeout self.assembler = WebsocketFrameAssembler(self) @@ -50,7 +80,7 @@ def __init__(self, connection, max_queue=None, ping_interval: Optional[float] = self.can_pause = True self.pause_frame_fut = None self.keepalive_ping_task = None - self.close_connection_task = None + self.auto_closer_task = None self.connection_lost_waiter = None @property @@ -73,7 +103,9 @@ def resume_frames(self): if self.pause_frame_fut is None: return False if self.loop is None or self.io_proto is None: 
- error_logger.warning("Websocket attempting to resume reading frames, but connection is gone.") + error_logger.warning( + "Websocket attempting to resume reading frames, but connection is gone." + ) return False if self.io_proto.transport is not None: self.io_proto.transport.resume_reading() @@ -93,12 +125,18 @@ async def connection_made(self, io_proto: SanicProtocol, loop=None): self.data_finished_fut = asyncio.shield(self.loop.create_future()) if self.ping_interval is not None: - self.keepalive_ping_task = asyncio.create_task(self.keepalive_ping()) - self.close_connection_task = asyncio.create_task(self.auto_close_connection()) + self.keepalive_ping_task = asyncio.create_task( + self.keepalive_ping() + ) + self.auto_closer_task = asyncio.create_task( + self.auto_close_connection() + ) - async def wait_for_connection_lost(self, timeout=10) -> bool: + async def wait_for_connection_lost(self, timeout=None) -> bool: """ Wait until the TCP connection is closed or ``timeout`` elapses. + If timeout is None, wait forever. + Recommend you should pass in self.close_timeout as timeout Return ``True`` if the connection is closed and ``False`` otherwise. @@ -112,12 +150,12 @@ async def wait_for_connection_lost(self, timeout=10) -> bool: await asyncio.wait_for( asyncio.shield(self.connection_lost_waiter), timeout ) + return True except asyncio.TimeoutError: - pass - # Re-check self.connection_lost_waiter.done() synchronously because - # connection_lost() could run between the moment the timeout occurs - # and the moment this coroutine resumes running. - return self.connection_lost_waiter.done() + # Re-check self.connection_lost_waiter.done() synchronously because + # connection_lost() could run between the moment the timeout occurs + # and the moment this coroutine resumes running. + return self.connection_lost_waiter.done() async def process_events(self, events: Sequence[Event]) -> None: """ @@ -132,10 +170,9 @@ async def process_events(self, events: Sequence[Event]) -> None: else: await self.assembler.put(event) - async def process_pong(self, frame: Frame) -> None: + async def process_pong(self, frame: "Frame") -> None: if frame.data in self.pings: # Acknowledge all pings up to the one matching this pong. - ping_id = None ping_ids = [] for ping_id, ping in self.pings.items(): ping_ids.append(ping_id) @@ -144,7 +181,7 @@ async def process_pong(self, frame: Frame) -> None: if ping_id == frame.data: break else: # noqa - assert False, "ping_id is in self.pings" + raise ServerError("ping_id is not in self.pings") # Remove acknowledged pings from self.pings. for ping_id in ping_ids: del self.pings[ping_id] @@ -176,7 +213,9 @@ async def keepalive_ping(self) -> None: try: await asyncio.wait_for(ping_waiter, self.ping_timeout) except asyncio.TimeoutError: - error_logger.warning("Websocket timed out waiting for pong") + error_logger.warning( + "Websocket timed out waiting for pong" + ) self.fail_connection(1011) break except asyncio.CancelledError: @@ -199,7 +238,7 @@ def fail_connection(self, code: int = 1006, reason: str = "") -> bool: of this. (The specification describes these steps in the opposite order.) """ - if self.io_proto.transport is not None: + if self.io_proto and self.io_proto.transport is not None: # Stop new data coming in # In Python Version 3.7: pause_reading is idempotent # i.e. it can be called when the transport is already paused or closed. 
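
Note: fail_connection above drives the websockets sans-i/o state machine and then writes connection.data_to_send() to the transport; data_received further down mirrors this on the read side. A minimal illustrative sketch of that loop follows (the pump name and the transport argument are assumptions for illustration, not part of this patch):

from typing import List

from websockets.connection import Event
from websockets.server import ServerConnection


def pump(connection: ServerConnection, transport, incoming: bytes) -> List[Event]:
    # Feed raw bytes from the socket into the sans-i/o state machine.
    connection.receive_data(incoming)
    # Flush anything the state machine queued in response
    # (pongs, close frames, ...) back out to the transport.
    for outgoing in connection.data_to_send():
        transport.write(outgoing)
    # Parsed events (frames) are handled elsewhere, e.g. by the assembler.
    return connection.events_received()
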
@@ -213,21 +252,54 @@ def fail_connection(self, code: int = 1006, reason: str = "") -> bool: _ = self.connection.data_to_send() # If we're not already CLOSED or CLOSING, then send the close. if self.connection.state is OPEN: - self.connection.fail_connection(code, reason) - for frame_data in self.connection.data_to_send(): - self.io_proto.transport.write(frame_data) + if code in (1000, 1001): + self.connection.send_close(code, reason) + else: + self.connection.fail_connection(code, reason) + try: + data_to_send = self.connection.data_to_send() + while ( + len(data_to_send) + and self.io_proto + and self.io_proto.transport is not None + ): + frame_data = data_to_send.pop(0) + self.io_proto.transport.write(frame_data) + except Exception: + # sending close frames may fail if the + # transport closes during this period + ... if code == 1006: # Special case: 1006 consider the transport already closed self.connection.set_state(CLOSED) - if self.close_connection_task is not None and not self.close_connection_task.done(): - if self.data_finished_fut is not None and not self.data_finished_fut.done(): - self.data_finished_fut.cancel() - # Don't close, auto_close_connection will take care of it. + if ( + self.data_finished_fut is not None + and not self.data_finished_fut.done() + ): + # We have a graceful auto-closer. Use it to close the connection. + self.data_finished_fut.cancel() + self.data_finished_fut = None + if self.auto_closer_task is not None: + if self.auto_closer_task.done(): + # auto_closer has already closed the connection? + self.auto_closer_task = None + return True return False - - # No auto-closer available, just abort the connection - SanicProtocol.close(self.io_proto, timeout=1.0) - return True + else: + # Auto closer is not running. Do it manually. + if ( + self.loop is None + or self.io_proto is None + or self.io_proto.transport is None + ): + # We were never open, or already closed + return True + # cannot use the connection_lost_waiter future here, + # because this is a synchronous function. + self.io_proto.transport.close() + self.loop.call_later( + self.close_timeout, self.io_proto.transport.abort + ) async def auto_close_connection(self) -> None: """ @@ -244,38 +316,75 @@ async def auto_close_connection(self) -> None: try: await self.data_finished_fut except asyncio.CancelledError: - pass + # Cancelled error will be called when data phase is cancelled + # This can be if an error occurred or the client app closed the connection + ... # Cancel the keepalive ping task. if self.keepalive_ping_task is not None: self.keepalive_ping_task.cancel() # Half-close the TCP connection if possible (when there's no TLS). - if self.io_proto.transport is not None and self.io_proto.transport.can_write_eof(): + if ( + self.io_proto + and self.io_proto.transport is not None + and self.io_proto.transport.can_write_eof() + ): error_logger.warning("Websocket half-closing TCP connection") self.io_proto.transport.write_eof() if self.connection_lost_waiter is not None: if await self.wait_for_connection_lost(timeout=0): return + except asyncio.CancelledError: + ... finally: # The try/finally ensures that the transport never remains open, # even if this coroutine is cancelled (for example). 
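When no auto-closer task is available, fail_connection() falls back to closing the transport and scheduling a hard abort after close_timeout. The same pattern in isolation, with illustrative names:

import asyncio


def close_with_deadline(
    loop: asyncio.AbstractEventLoop, transport, close_timeout: float
) -> None:
    # Ask for a graceful shutdown first...
    transport.close()
    # ...and force-close later in case the peer never completes it.
    loop.call_later(close_timeout, transport.abort)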
- if self.connection_lost_waiter is not None and self.connection_lost_waiter.done(): - if self.io_proto.transport is None or self.io_proto.transport.is_closing(): - return - SanicProtocol.close(self.io_proto, timeout=999) - if self.connection_lost_waiter is not None: - if await self.wait_for_connection_lost(timeout=5): + self.auto_closer_task = None + if self.io_proto is None or self.io_proto.transport is None: + # we were never open, or already dead and buried. Can't do any finalization. + return + elif ( + self.connection_lost_waiter is not None + and self.connection_lost_waiter.done() + ): + # connection was confirmed closed already, proceed to abort waiter + ... + elif self.io_proto.transport.is_closing(): + # Connection is already closing (due to half-close above) + # proceed to abort waiter + ... + else: + self.io_proto.transport.close() + if self.connection_lost_waiter is None: + # Our connection monitor task isn't running. + try: + await asyncio.sleep(self.close_timeout) + except asyncio.CancelledError: + ... + if self.io_proto and self.io_proto.transport is not None: + self.io_proto.transport.abort() + else: + if await self.wait_for_connection_lost( + timeout=self.close_timeout + ): + # Connection aborted before the timeout expired. return - error_logger.warning("Timeout waiting for TCP connection to close. Aborting") - SanicProtocol.abort(self.io_proto) + error_logger.warning( + "Timeout waiting for TCP connection to close. Aborting" + ) + if self.io_proto and self.io_proto.transport is not None: + self.io_proto.transport.abort() def abort_pings(self) -> None: """ Raise ConnectionClosed in pending keepalive pings. They'll never receive a pong once the connection is closed. """ - assert self.connection.state is CLOSED + if self.connection.state is not CLOSED: + raise ServerError( + "webscoket about_pings should only be called after connection state is changed to CLOSED" + ) for ping in self.pings.values(): ping.set_exception(ConnectionClosedError(1006, "")) @@ -317,18 +426,19 @@ async def recv(self, timeout: Optional[float] = None) -> Optional[Data]: Set ``timeout`` to ``0`` to check if a message was already received. :raises ~websockets.exceptions.ConnectionClosed: when the connection is closed - :raises RuntimeError: if two tasks call :meth:`recv` or + :raises ServerError: if two tasks call :meth:`recv` or :meth:`recv_streaming` concurrently """ if self.recv_lock.locked(): - raise RuntimeError( - "cannot call recv while another task " - "is already waiting for the next message" + raise ServerError( + "cannot call recv while another task is already waiting for the next message" ) await self.recv_lock.acquire() if self.connection.state in (CLOSED, CLOSING): - raise RuntimeError("Cannot receive from websocket interface after it is closed.") + raise ServerError( + "Cannot receive from websocket interface after it is closed." + ) try: return await self.assembler.get(timeout) finally: @@ -347,35 +457,36 @@ async def recv_burst(self, max_recv=256) -> Sequence[Data]: error or a network failure. 
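recv(), recv_burst() and recv_streaming() all share the same guard: a second concurrent reader is treated as a programming error rather than being queued behind the lock. A stripped-down sketch of that guard, with generic names and a generic exception type:

import asyncio


class SingleReader:
    def __init__(self):
        self._recv_lock = asyncio.Lock()

    async def recv(self, get_message):
        if self._recv_lock.locked():
            # Fail fast instead of silently waiting behind another reader.
            raise RuntimeError(
                "another task is already waiting for the next message"
            )
        async with self._recv_lock:
            return await get_message()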
:raises ~websockets.exceptions.ConnectionClosed: when the connection is closed - :raises RuntimeError: if two threads call :meth:`recv` or + :raises ServerError: if two tasks call :meth:`recv_burst` or :meth:`recv_streaming` concurrently """ if self.recv_lock.locked(): - raise RuntimeError( - "cannot call recv_burst while another task " - "is already waiting for the next message" + raise ServerError( + "cannot call recv_burst while another task is already waiting for the next message" ) await self.recv_lock.acquire() if self.connection.state in (CLOSED, CLOSING): - raise RuntimeError("Cannot receive from websocket interface after it is closed.") + raise ServerError( + "Cannot receive from websocket interface after it is closed." + ) messages = [] try: # Prevent pausing the transport when we're # receiving a burst of messages self.can_pause = False while True: - m = await self.assembler.get(timeout=0) - if m is None: - # None left in the burst. This is good! - break - messages.append(m) - if len(messages) >= max_recv: - # Too much data in the pipe. Hit our burst limit. - break - # Allow an eventloop iteration for the - # next message to pass into the Assembler - await asyncio.sleep(0) + m = await self.assembler.get(timeout=0) + if m is None: + # None left in the burst. This is good! + break + messages.append(m) + if len(messages) >= max_recv: + # Too much data in the pipe. Hit our burst limit. + break + # Allow an eventloop iteration for the + # next message to pass into the Assembler + await asyncio.sleep(0) finally: self.can_pause = True self.recv_lock.release() @@ -391,13 +502,14 @@ async def recv_streaming(self) -> AsyncIterator[Data]: like :meth:`recv`. """ if self.recv_lock.locked(): - raise RuntimeError( - "cannot call recv_streaming while another task " - "is already waiting for the next message" + raise ServerError( + "cannot call recv_streaming while another task is already waiting for the next message" ) await self.recv_lock.acquire() if self.connection.state in (CLOSED, CLOSING): - raise RuntimeError("Cannot receive from websocket interface after it is closed.") + raise ServerError( + "Cannot receive from websocket interface after it is closed." + ) try: self.can_pause = False async for m in self.assembler.get_iter(): @@ -427,9 +539,17 @@ async def send(self, message: Union[Data, Iterable[Data]]) -> None: async with self.conn_mutex: if self.connection.state in (CLOSED, CLOSING): - raise RuntimeError("Cannot write to websocket interface after it is closed.") - if self.data_finished_fut is None or self.data_finished_fut.cancelled() or self.data_finished_fut.done(): - raise RuntimeError("Cannot write to websocket interface after it is finished.") + raise ServerError( + "Cannot write to websocket interface after it is closed." + ) + if ( + self.data_finished_fut is None + or self.data_finished_fut.cancelled() + or self.data_finished_fut.done() + ): + raise ServerError( + "Cannot write to websocket interface after it is finished." + ) # Unfragmented message -- this case must be handled first because # strings and bytes-like objects are iterable. @@ -448,7 +568,9 @@ async def send(self, message: Union[Data, Iterable[Data]]) -> None: elif isinstance(message, Iterable): # Fragmented message -- regular iterator. - raise NotImplementedError("Fragmented websocket messages are not supported.") + raise NotImplementedError( + "Fragmented websocket messages are not supported." 
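Because recv_burst() polls the assembler with timeout=0, it returns immediately once the buffer is drained, so a typical caller waits on recv() first and then sweeps up whatever else is already queued. A hypothetical handler showing that usage (route name and payloads are illustrative; the handler is assumed to receive the websocket implementation object, as in the handler wiring elsewhere in this series):

from sanic import Sanic

app = Sanic("burst_example")


@app.websocket("/ingest")
async def ingest(request, ws):
    while True:
        first = await ws.recv()                     # block until something arrives
        backlog = await ws.recv_burst(max_recv=63)  # drain anything already queued
        await ws.send(f"got {1 + len(backlog)} messages")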
+ ) else: raise TypeError("Websocket data must be bytes, str.") @@ -468,13 +590,20 @@ async def ping(self, data: Optional[Data] = None) -> asyncio.Future: """ async with self.conn_mutex: if self.connection.state in (CLOSED, CLOSING): - raise RuntimeError("Cannot send a ping when the websocket interface is closed.") + raise ServerError( + "Cannot send a ping when the websocket interface is closed." + ) if data is not None: - data = prepare_ctrl(data) + if isinstance(data, str): + data = data.encode("utf-8") + elif isinstance(data, (bytearray, memoryview)): + data = bytes(data) # Protect against duplicates if a payload is explicitly set. if data in self.pings: - raise ValueError("already waiting for a pong with the same data") + raise ValueError( + "already waiting for a pong with the same data" + ) # Generate a unique random payload otherwise. while data is None or data in self.pings: @@ -498,9 +627,10 @@ async def pong(self, data: Data = b"") -> None: if self.connection.state in (CLOSED, CLOSING): # Cannot send pong after transport is shutting down return - - data = prepare_ctrl(data) - + if isinstance(data, str): + data = data.encode("utf-8") + elif isinstance(data, (bytearray, memoryview)): + data = bytes(data) self.connection.send_pong(data) await self.send_data(self.connection.data_to_send()) @@ -511,8 +641,12 @@ async def send_data(self, data_to_send): else: # Send an EOF # We don't actually send it, just trigger to autoclose the connection - if self.close_connection_task is not None and not self.close_connection_task.done() and \ - self.data_finished_fut is not None and not self.data_finished_fut.done(): + if ( + self.auto_closer_task is not None + and not self.auto_closer_task.done() + and self.data_finished_fut is not None + and not self.data_finished_fut.done() + ): # Auto-close the connection self.data_finished_fut.set_result(None) else: @@ -532,7 +666,9 @@ def data_received(self, data): data_to_send = self.connection.data_to_send() events_to_process = self.connection.events_received() if len(data_to_send) > 0 or len(events_to_process) > 0: - asyncio.create_task(self.async_data_received(data_to_send, events_to_process)) + asyncio.create_task( + self.async_data_received(data_to_send, events_to_process) + ) async def async_eof_received(self, data_to_send, events_to_process): # receiving EOF can generate data to send @@ -542,8 +678,12 @@ async def async_eof_received(self, data_to_send, events_to_process): if len(events_to_process) > 0: await self.process_events(events_to_process) - if self.close_connection_task is not None and not self.close_connection_task.done() and \ - self.data_finished_fut is not None and not self.data_finished_fut.done(): + if ( + self.auto_closer_task is not None + and not self.auto_closer_task.done() + and self.data_finished_fut is not None + and not self.data_finished_fut.done() + ): # Auto-close the connection self.data_finished_fut.set_result(None) else: @@ -555,7 +695,9 @@ def eof_received(self) -> Optional[bool]: data_to_send = self.connection.data_to_send() events_to_process = self.connection.events_received() if len(data_to_send) > 0 or len(events_to_process) > 0: - asyncio.create_task(self.async_eof_received(data_to_send, events_to_process)) + asyncio.create_task( + self.async_eof_received(data_to_send, events_to_process) + ) return False def connection_lost(self, exc): diff --git a/sanic/worker.py b/sanic/worker.py index 875ebc3cfc..a3bc29b8b8 100644 --- a/sanic/worker.py +++ b/sanic/worker.py @@ -142,14 +142,11 @@ async def close(self): # Force 
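data_received() and eof_received() above are thin wrappers around the sans-I/O cycle: feed bytes into the state machine, flush data_to_send() back to the transport, and dispatch events_received() to the application layer. A condensed sketch of that cycle; receive_data() is the sans-I/O entry point in the websockets package, while the transport argument and callback are illustrative:

from typing import Callable

from websockets.connection import Event
from websockets.server import ServerConnection


def feed_bytes(
    connection: ServerConnection,
    transport,                         # an asyncio write transport
    on_event: Callable[[Event], None],
    data: bytes,
) -> None:
    # 1. Push the raw bytes into the sans-I/O state machine.
    connection.receive_data(data)
    # 2. Write back anything it wants on the wire (handshake bytes,
    #    pongs, close frames, ...).
    for payload in connection.data_to_send():
        transport.write(payload)
    # 3. Hand completed frames to the application layer, e.g. the
    #    frame assembler / event processor.
    for event in connection.events_received():
        on_event(event)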
close non-idle connection after waiting for # graceful_shutdown_timeout - coros = [] for conn in self.connections: if hasattr(conn, "websocket") and conn.websocket: - coros.append(conn.websocket.close(code=1001)) + conn.websocket.fail_connection(code=1001) else: conn.abort() - _shutdown = asyncio.gather(*coros, loop=self.loop) - await _shutdown async def _run(self): for sock in self.sockets: From 13d49b8e1c6fefad6b8adbcf35751f34fb1726f8 Mon Sep 17 00:00:00 2001 From: Ashley Sommer Date: Wed, 15 Sep 2021 17:00:23 +1000 Subject: [PATCH 08/16] Further new websockets impl fixes * Update compatibility with Websockets v10 * Track server connection state in a more precise way * Try to handle the shutdown process more gracefully * Add a new end_connection() helper, to use as an alterative to close() or fail_connection() * Kill the auto-close task and keepalive-timeout task when sanic is shutdown * Deprecate WEBSOCKET_READ_LIMIT and WEBSOCKET_WRITE_LIMIT configs, they are not used in this implementation. --- sanic/app.py | 10 +- sanic/config.py | 4 - sanic/server/protocols/websocket_protocol.py | 11 +- sanic/server/runners.py | 2 - sanic/server/websockets/impl.py | 123 ++++++++++++++----- setup.py | 2 +- tests/test_app.py | 6 - 7 files changed, 105 insertions(+), 53 deletions(-) diff --git a/sanic/app.py b/sanic/app.py index effaabe86a..675e91a744 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -884,15 +884,19 @@ async def _websocket_handler( # needs to be cancelled due to the server being stopped fut = ensure_future(handler(request, ws, *args, **kwargs)) self.websocket_tasks.add(fut) + cancelled = False try: await fut except Exception as e: self.error_handler.log(request, e) - except (CancelledError, ConnectionClosed): - pass + except (CancelledError, ConnectionClosed) as E: + cancelled = True finally: self.websocket_tasks.remove(fut) - await ws.close() + if cancelled: + ws.end_connection(1000) + else: + await ws.close() # -------------------------------------------------------------------- # # Testing diff --git a/sanic/config.py b/sanic/config.py index 5b7217034a..2a90c5fb35 100644 --- a/sanic/config.py +++ b/sanic/config.py @@ -38,8 +38,6 @@ "WEBSOCKET_MAX_SIZE": 2 ** 20, # 1 megabyte "WEBSOCKET_PING_INTERVAL": 20, "WEBSOCKET_PING_TIMEOUT": 20, - "WEBSOCKET_READ_LIMIT": 2 ** 16, - "WEBSOCKET_WRITE_LIMIT": 2 ** 16, } @@ -64,8 +62,6 @@ class Config(dict): WEBSOCKET_MAX_SIZE: int WEBSOCKET_PING_INTERVAL: int WEBSOCKET_PING_TIMEOUT: int - WEBSOCKET_READ_LIMIT: int - WEBSOCKET_WRITE_LIMIT: int def __init__( self, diff --git a/sanic/server/protocols/websocket_protocol.py b/sanic/server/protocols/websocket_protocol.py index 77288cc633..9cf33da648 100644 --- a/sanic/server/protocols/websocket_protocol.py +++ b/sanic/server/protocols/websocket_protocol.py @@ -86,11 +86,10 @@ def close(self, timeout: Optional[float] = None): # Called by HttpProtocol at the end of connection_task # If we've upgraded to websocket, we do our own closing if self.websocket is not None: - if self.websocket.loop is not None: - ... - self.websocket.loop.create_task(self.websocket.close(1001)) - else: - self.websocket.fail_connection(1001) + # Note, we don't want to use websocket.close() + # That is used for user's application code to send a + # websocket close packet. This is different. 
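The teardown choice made in _websocket_handler above boils down to: await the full close handshake on the normal path, but fall back to the synchronous end_connection() when the handler was cancelled and awaiting is not an option. A simplified sketch of just that decision (error logging and ConnectionClosed handling omitted):

from asyncio import CancelledError, ensure_future


async def run_ws_handler(handler, request, ws, *args, **kwargs):
    fut = ensure_future(handler(request, ws, *args, **kwargs))
    cancelled = False
    try:
        await fut
    except CancelledError:
        cancelled = True
    finally:
        if cancelled:
            ws.end_connection(1000)  # synchronous path; cannot await here
        else:
            await ws.close()         # normal close handshake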
+ self.websocket.end_connection(1001) else: super().close() @@ -103,7 +102,7 @@ def close_if_idle(self): elif self.websocket.loop is not None: self.websocket.loop.create_task(self.websocket.close(1001)) else: - self.websocket.fail_connection(1001) + self.websocket.end_connection(1001) else: return super().close_if_idle() diff --git a/sanic/server/runners.py b/sanic/server/runners.py index 0b1305a38a..f0bebb030c 100644 --- a/sanic/server/runners.py +++ b/sanic/server/runners.py @@ -274,8 +274,6 @@ def _build_protocol_kwargs( if hasattr(protocol, "websocket_handshake"): return { "websocket_max_size": config.WEBSOCKET_MAX_SIZE, - "websocket_read_limit": config.WEBSOCKET_READ_LIMIT, - "websocket_write_limit": config.WEBSOCKET_WRITE_LIMIT, "websocket_ping_timeout": config.WEBSOCKET_PING_TIMEOUT, "websocket_ping_interval": config.WEBSOCKET_PING_INTERVAL, } diff --git a/sanic/server/websockets/impl.py b/sanic/server/websockets/impl.py index f8ec6f3e62..1611a98f0b 100644 --- a/sanic/server/websockets/impl.py +++ b/sanic/server/websockets/impl.py @@ -103,7 +103,7 @@ def resume_frames(self): if self.pause_frame_fut is None: return False if self.loop is None or self.io_proto is None: - error_logger.warning( + error_logger.debug( "Websocket attempting to resume reading frames, but connection is gone." ) return False @@ -225,6 +225,39 @@ async def keepalive_ping(self) -> None: except Exception: error_logger.warning("Unexpected exception in keepalive ping task") + def _force_disconnect(self) -> bool: + """ + Internal methdod used by end_connection and fail_connection + only when the graceful auto-closer cannot be used + """ + if ( + self.auto_closer_task is not None + and not self.auto_closer_task.done() + ): + self.auto_closer_task.cancel() + if ( + self.data_finished_fut is not None + and not self.data_finished_fut.done() + ): + self.data_finished_fut.cancel() + self.data_finished_fut = None + if ( + self.keepalive_ping_task is not None + and not self.keepalive_ping_task.done() + ): + self.keepalive_ping_task.cancel() + self.keepalive_ping_task = None + if ( + self.loop is None + or self.io_proto is None + or self.io_proto.transport is None + ): + # We were never open, or already closed + return True + self.io_proto.transport.close() + self.loop.call_later(self.close_timeout, self.io_proto.transport.abort) + return True + def fail_connection(self, code: int = 1006, reason: str = "") -> bool: """ Fail the WebSocket Connection @@ -255,7 +288,7 @@ def fail_connection(self, code: int = 1006, reason: str = "") -> bool: if code in (1000, 1001): self.connection.send_close(code, reason) else: - self.connection.fail_connection(code, reason) + self.connection.fail(code, reason) try: data_to_send = self.connection.data_to_send() while ( @@ -271,7 +304,7 @@ def fail_connection(self, code: int = 1006, reason: str = "") -> bool: ... if code == 1006: # Special case: 1006 consider the transport already closed - self.connection.set_state(CLOSED) + self.connection.state = CLOSED if ( self.data_finished_fut is not None and not self.data_finished_fut.done() @@ -279,27 +312,50 @@ def fail_connection(self, code: int = 1006, reason: str = "") -> bool: # We have a graceful auto-closer. Use it to close the connection. self.data_finished_fut.cancel() self.data_finished_fut = None - if self.auto_closer_task is not None: - if self.auto_closer_task.done(): - # auto_closer has already closed the connection? - self.auto_closer_task = None - return True - return False - else: - # Auto closer is not running. Do it manually. 
- if ( - self.loop is None - or self.io_proto is None - or self.io_proto.transport is None - ): - # We were never open, or already closed - return True - # cannot use the connection_lost_waiter future here, - # because this is a synchronous function. - self.io_proto.transport.close() - self.loop.call_later( - self.close_timeout, self.io_proto.transport.abort - ) + if self.auto_closer_task is None or self.auto_closer_task.done(): + return self._force_disconnect() + + def end_connection(self, code=1000, reason=""): + # This is like slightly more graceful form of fail_connection + # Use this instead of close() when you need an immediate + # close and cannot await websocket.close() handshake. + + if ( + code == 1006 + or self.io_proto is None + or self.io_proto.transport is None + ): + return self.fail_connection(code, reason) + + # Stop new data coming in + # In Python Version 3.7: pause_reading is idempotent + # i.e. it can be called when the transport is already paused or closed. + self.io_proto.transport.pause_reading() + if self.connection.state == OPEN: + data_to_send = self.connection.data_to_send() + self.connection.send_close(code, reason) + try: + data_to_send.extend(self.connection.data_to_send()) + while ( + len(data_to_send) + and self.io_proto + and self.io_proto.transport is not None + ): + frame_data = data_to_send.pop(0) + self.io_proto.transport.write(frame_data) + except Exception: + # sending close frames may fail if the + # transport closes during this period + ... + if ( + self.data_finished_fut is not None + and not self.data_finished_fut.done() + ): + # We have a graceful auto-closer. Use it to close the connection. + self.data_finished_fut.cancel() + self.data_finished_fut = None + if self.auto_closer_task is None or self.auto_closer_task.done(): + return self._force_disconnect() async def auto_close_connection(self) -> None: """ @@ -323,6 +379,7 @@ async def auto_close_connection(self) -> None: # Cancel the keepalive ping task. if self.keepalive_ping_task is not None: self.keepalive_ping_task.cancel() + self.keepalive_ping_task = None # Half-close the TCP connection if possible (when there's no TLS). if ( @@ -340,7 +397,6 @@ async def auto_close_connection(self) -> None: finally: # The try/finally ensures that the transport never remains open, # even if this coroutine is cancelled (for example). - self.auto_closer_task = None if self.io_proto is None or self.io_proto.transport is None: # we were never open, or already dead and buried. Can't do any finalization. return @@ -397,6 +453,7 @@ def abort_pings(self) -> None: async def close(self, code: int = 1000, reason: str = "") -> None: """ Perform the closing handshake. + This is a websocket-protocol level close. :meth:`close` waits for the other end to complete the handshake and for the TCP connection to terminate. :meth:`close` is idempotent: it doesn't do anything once the @@ -404,6 +461,9 @@ async def close(self, code: int = 1000, reason: str = "") -> None: :param code: WebSocket close code :param reason: WebSocket close reason """ + if code == 1006: + self.fail_connection(code, reason) + return async with self.conn_mutex: if self.connection.state is OPEN: self.connection.send_close(code, reason) @@ -542,11 +602,7 @@ async def send(self, message: Union[Data, Iterable[Data]]) -> None: raise ServerError( "Cannot write to websocket interface after it is closed." 
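The difference between the two teardown entry points introduced here can be summarised as a usage sketch; the call sites are illustrative:

async def graceful(ws):
    # Full close handshake: sends the close frame and waits for the peer
    # and the TCP teardown. Idempotent once the connection is closed.
    await ws.close(code=1000, reason="done")


def immediate(ws):
    # For synchronous call sites (protocol callbacks, shutdown paths) that
    # cannot await: send the close frame if possible and let the
    # auto-closer or the forced disconnect finish the job.
    ws.end_connection(1001)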
) - if ( - self.data_finished_fut is None - or self.data_finished_fut.cancelled() - or self.data_finished_fut.done() - ): + if self.data_finished_fut is None or self.data_finished_fut.done(): raise ServerError( "Cannot write to websocket interface after it is finished." ) @@ -704,7 +760,12 @@ def connection_lost(self, exc): """ The WebSocket Connection is Closed. """ - self.connection.set_state(CLOSED) + if not self.connection.state == CLOSED: + # signal to the websocket connection handler + # we've lost the connection + self.connection.fail(code=1006) + self.connection.state = CLOSED + self.abort_pings() if self.connection_lost_waiter is not None: self.connection_lost_waiter.set_result(None) diff --git a/setup.py b/setup.py index d65703ac76..ebfe85c99d 100644 --- a/setup.py +++ b/setup.py @@ -88,7 +88,7 @@ def open_local(paths, mode="r", encoding="utf8"): uvloop, ujson, "aiofiles>=0.6.0", - "websockets>=9.0", + "websockets>=10.0", "multidict>=5.0,<6.0", ] diff --git a/tests/test_app.py b/tests/test_app.py index f6f10756af..196f34f5a6 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -178,8 +178,6 @@ async def handler(request, ws): @patch("sanic.app.WebSocketProtocol") def test_app_websocket_parameters(websocket_protocol_mock, app): app.config.WEBSOCKET_MAX_SIZE = 44 - app.config.WEBSOCKET_READ_LIMIT = 46 - app.config.WEBSOCKET_WRITE_LIMIT = 47 app.config.WEBSOCKET_PING_TIMEOUT = 48 app.config.WEBSOCKET_PING_INTERVAL = 50 @@ -196,10 +194,6 @@ async def handler(request, ws): websocket_protocol_call_args = websocket_protocol_mock.call_args ws_kwargs = websocket_protocol_call_args[1] assert ws_kwargs["websocket_max_size"] == app.config.WEBSOCKET_MAX_SIZE - assert ws_kwargs["websocket_read_limit"] == app.config.WEBSOCKET_READ_LIMIT - assert ( - ws_kwargs["websocket_write_limit"] == app.config.WEBSOCKET_WRITE_LIMIT - ) assert ( ws_kwargs["websocket_ping_timeout"] == app.config.WEBSOCKET_PING_TIMEOUT From 5f6cc068560d7c9128db2bdbf48823f3aeadac39 Mon Sep 17 00:00:00 2001 From: Ashley Sommer Date: Thu, 23 Sep 2021 14:12:38 +1000 Subject: [PATCH 09/16] Change a warning message to debug level Remove default values for deprecated websocket parameters --- sanic/server/protocols/websocket_protocol.py | 4 ++-- sanic/server/websockets/impl.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sanic/server/protocols/websocket_protocol.py b/sanic/server/protocols/websocket_protocol.py index 9cf33da648..6934aaba86 100644 --- a/sanic/server/protocols/websocket_protocol.py +++ b/sanic/server/protocols/websocket_protocol.py @@ -29,8 +29,8 @@ def __init__( websocket_timeout: Optional[float] = 10.0, websocket_max_size: Optional[int] = None, websocket_max_queue: Optional[int] = None, # max_queue is deprecated - websocket_read_limit: Optional[int] = 2 ** 16, - websocket_write_limit: Optional[int] = 2 ** 16, + websocket_read_limit: Optional[int] = None, # read_limit is deprecated + websocket_write_limit: Optional[int] = None, # write_limit is deprecated websocket_ping_interval: Optional[float] = 20.0, websocket_ping_timeout: Optional[float] = 20.0, **kwargs, diff --git a/sanic/server/websockets/impl.py b/sanic/server/websockets/impl.py index 1611a98f0b..d8e626405b 100644 --- a/sanic/server/websockets/impl.py +++ b/sanic/server/websockets/impl.py @@ -387,7 +387,7 @@ async def auto_close_connection(self) -> None: and self.io_proto.transport is not None and self.io_proto.transport.can_write_eof() ): - error_logger.warning("Websocket half-closing TCP connection") + 
error_logger.debug("Websocket half-closing TCP connection") self.io_proto.transport.write_eof() if self.connection_lost_waiter is not None: if await self.wait_for_connection_lost(timeout=0): From 37d462ae6b31e24b96c2b3d3e33a697017d15237 Mon Sep 17 00:00:00 2001 From: Ashley Sommer Date: Thu, 23 Sep 2021 14:27:25 +1000 Subject: [PATCH 10/16] Fix flake8 errors --- sanic/app.py | 2 +- sanic/server/protocols/websocket_protocol.py | 18 ++++++++++-------- sanic/server/websockets/frame.py | 8 ++++++-- sanic/server/websockets/impl.py | 8 +++++--- 4 files changed, 22 insertions(+), 14 deletions(-) diff --git a/sanic/app.py b/sanic/app.py index 675e91a744..38d4b1d27c 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -889,7 +889,7 @@ async def _websocket_handler( await fut except Exception as e: self.error_handler.log(request, e) - except (CancelledError, ConnectionClosed) as E: + except (CancelledError, ConnectionClosed): cancelled = True finally: self.websocket_tasks.remove(fut) diff --git a/sanic/server/protocols/websocket_protocol.py b/sanic/server/protocols/websocket_protocol.py index 6934aaba86..d0fdb69d5e 100644 --- a/sanic/server/protocols/websocket_protocol.py +++ b/sanic/server/protocols/websocket_protocol.py @@ -1,6 +1,5 @@ from typing import TYPE_CHECKING, Optional, Sequence, Union -from httptools import HttpParserUpgrade # type: ignore from websockets.connection import CLOSED, CLOSING, OPEN from websockets.server import ServerConnection @@ -30,7 +29,7 @@ def __init__( websocket_max_size: Optional[int] = None, websocket_max_queue: Optional[int] = None, # max_queue is deprecated websocket_read_limit: Optional[int] = None, # read_limit is deprecated - websocket_write_limit: Optional[int] = None, # write_limit is deprecated + websocket_write_limit: Optional[int] = None, # write_limit deprecated websocket_ping_interval: Optional[float] = 20.0, websocket_ping_timeout: Optional[float] = 20.0, **kwargs, @@ -42,13 +41,15 @@ def __init__( if websocket_max_queue is not None and int(websocket_max_queue) > 0: error_logger.warning( DeprecationWarning( - "websocket_max_queue is no longer used. No websocket message queueing is implemented." + "websocket_max_queue is no longer used. " + "No websocket message queueing is implemented." ) ) if websocket_read_limit is not None and int(websocket_read_limit) > 0: error_logger.warning( DeprecationWarning( - "websocket_read_limit is no longer used. No websocket rate limiting is implemented." + "websocket_read_limit is no longer used. " + "No websocket rate limiting is implemented." ) ) if ( @@ -57,7 +58,8 @@ def __init__( ): error_logger.warning( DeprecationWarning( - "websocket_write_limit is no longer used. No websocket rate limiting is implemented." + "websocket_write_limit is no longer used. " + "No websocket rate limiting is implemented." 
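With WEBSOCKET_READ_LIMIT, WEBSOCKET_WRITE_LIMIT and the message-queue setting deprecated by this series, the websocket knobs that remain are the ones shown in the config defaults earlier in the patch. A hypothetical app setup using them (values are illustrative):

from sanic import Sanic

app = Sanic("ws_config_example")
app.config.WEBSOCKET_MAX_SIZE = 2 ** 20   # max incoming message size, bytes
app.config.WEBSOCKET_PING_INTERVAL = 20   # seconds between keepalive pings
app.config.WEBSOCKET_PING_TIMEOUT = 20    # seconds to wait for a pong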
) ) self.websocket_ping_interval = websocket_ping_interval @@ -110,10 +112,10 @@ async def websocket_handshake( self, request, subprotocols=Optional[Sequence[str]] ): # let the websockets package do the handshake with the client - headers = {"Upgrade": "websocket", "Connection": "Upgrade"} try: if subprotocols is not None: - # subprotocols can be a set or frozenset, but ServerConnection needs a list + # subprotocols can be a set or frozenset, + # but ServerConnection needs a list subprotocols = list(subprotocols) ws_conn = ServerConnection( max_size=self.websocket_max_size, @@ -122,7 +124,7 @@ async def websocket_handshake( logger=error_logger, ) resp: "http11.Response" = ws_conn.accept(request) - except Exception as exc: + except Exception: msg = ( "Failed to open a WebSocket connection.\n" "See server log for more information.\n" diff --git a/sanic/server/websockets/frame.py b/sanic/server/websockets/frame.py index bbe74fcf33..98688b1c34 100644 --- a/sanic/server/websockets/frame.py +++ b/sanic/server/websockets/frame.py @@ -8,6 +8,8 @@ from sanic.exceptions import ServerError +if TYPE_CHECKING: + from .impl import WebsocketImplProtocol UTF8Decoder = codecs.getincrementaldecoder("utf-8") @@ -101,7 +103,8 @@ async def get(self, timeout: Optional[float] = None) -> Optional[Data]: if self.get_in_progress: # This should be guarded against with the read_mutex, exception is only here as a failsafe raise ServerError( - "Called get() on Websocket frame assembler while asynchronous get is already in progress." + "Called get() on Websocket frame assembler " + "while asynchronous get is already in progress." ) self.get_in_progress = True @@ -168,7 +171,8 @@ async def get_iter(self) -> AsyncIterator[Data]: if self.get_in_progress: # This should be guarded against with the read_mutex, exception is only here as a failsafe raise ServerError( - "Called get_iter on Websocket frame assembler while asynchronous get is already in progress." + "Called get_iter on Websocket frame assembler " + "while asynchronous get is already in progress." ) self.get_in_progress = True diff --git a/sanic/server/websockets/impl.py b/sanic/server/websockets/impl.py index d8e626405b..303fcef46b 100644 --- a/sanic/server/websockets/impl.py +++ b/sanic/server/websockets/impl.py @@ -439,11 +439,12 @@ def abort_pings(self) -> None: """ if self.connection.state is not CLOSED: raise ServerError( - "webscoket about_pings should only be called after connection state is changed to CLOSED" + "Webscoket about_pings should only be called " + "after connection state is changed to CLOSED" ) for ping in self.pings.values(): - ping.set_exception(ConnectionClosedError(1006, "")) + ping.set_exception(ConnectionClosedError(None)) # If the exception is never retrieved, it will be logged when ping # is garbage-collected. This is confusing for users. 
# Given that ping is done (with an exception), canceling it does @@ -563,7 +564,8 @@ async def recv_streaming(self) -> AsyncIterator[Data]: """ if self.recv_lock.locked(): raise ServerError( - "cannot call recv_streaming while another task is already waiting for the next message" + "Cannot call recv_streaming while another task " + "is already waiting for the next message" ) await self.recv_lock.acquire() if self.connection.state in (CLOSED, CLOSING): From cb495acfb10030a7e79196a7aaa67c9061bb2a53 Mon Sep 17 00:00:00 2001 From: Ashley Sommer Date: Thu, 23 Sep 2021 16:23:06 +1000 Subject: [PATCH 11/16] Fix a couple of missed failing tests --- examples/websocket_bench.html | 47 ++++++++++++++++++++++ examples/websocket_bench.py | 70 +++++++++++++++++++++++++++++++++ sanic/server/websockets/impl.py | 7 +++- tests/test_exceptions.py | 14 +++++-- tests/test_routes.py | 30 +++++++------- tests/test_worker.py | 4 +- 6 files changed, 148 insertions(+), 24 deletions(-) create mode 100644 examples/websocket_bench.html create mode 100644 examples/websocket_bench.py diff --git a/examples/websocket_bench.html b/examples/websocket_bench.html new file mode 100644 index 0000000000..98c01f3daf --- /dev/null +++ b/examples/websocket_bench.html @@ -0,0 +1,47 @@ + + + + WebSocket benchmark + + + + + diff --git a/examples/websocket_bench.py b/examples/websocket_bench.py new file mode 100644 index 0000000000..198bc9a5e9 --- /dev/null +++ b/examples/websocket_bench.py @@ -0,0 +1,70 @@ +import asyncio +import logging +from sanic import Sanic +from sanic.response import file +from sanic.log import error_logger +import time +error_logger.setLevel(logging.INFO) +app = Sanic(__name__) + +@app.route('/') +async def index(request): + return await file('websocket_bench.html') + + +@app.websocket('/bench') +async def bench_p_time(request, ws): + i = 0 + bytes_total = 0 + start = 0 + end_time = 0 + started = False + await ws.send("1") + while started is False or (end_time > time.time()): + i += 1 + in_data = await ws.recv() + if started is False: + del in_data + error_logger.info("received first data: starting benchmark now..") + started = True + start = time.time() + end_time = start + 30.0 + continue + bytes_total += len(in_data) + del in_data + end = time.time() + elapsed = end - start + error_logger.info("Done. Took {} seconds".format(elapsed)) + error_logger.info("{} bytes in 30 seconds = {}".format(bytes_total, (bytes_total/30.0))) + +@app.websocket('/benchp') +async def bench_p_time(request, ws): + i = 0 + bytes_total = 0 + real_start = 0 + start_ptime = 0 + end_ptime = 0 + started = False + await ws.send("1") + while started is False or (end_ptime > time.process_time()): + i += 1 + in_data = await ws.recv() + if started is False: + del in_data + error_logger.info("received first data: starting benchmark now..") + started = True + real_start = time.time() + start_ptime = time.process_time() + end_ptime = start_ptime + 30.0 + continue + bytes_total += len(in_data) + del in_data + real_end = time.time() + elapsed = real_end - real_start + error_logger.info("Done. 
Took {} seconds".format(elapsed)) + error_logger.info("{} bytes in 30 seconds = {}".format(bytes_total, (bytes_total/30.0))) + + +if __name__ == '__main__': + app.run(host="0.0.0.0", port=8000, debug=False, auto_reload=False) + diff --git a/sanic/server/websockets/impl.py b/sanic/server/websockets/impl.py index 303fcef46b..ca0954c087 100644 --- a/sanic/server/websockets/impl.py +++ b/sanic/server/websockets/impl.py @@ -334,8 +334,8 @@ def end_connection(self, code=1000, reason=""): if self.connection.state == OPEN: data_to_send = self.connection.data_to_send() self.connection.send_close(code, reason) + data_to_send.extend(self.connection.data_to_send()) try: - data_to_send.extend(self.connection.data_to_send()) while ( len(data_to_send) and self.io_proto @@ -346,15 +346,18 @@ def end_connection(self, code=1000, reason=""): except Exception: # sending close frames may fail if the # transport closes during this period + # But that doesn't matter at this point ... if ( self.data_finished_fut is not None and not self.data_finished_fut.done() ): - # We have a graceful auto-closer. Use it to close the connection. + # We have the ability to signal the auto-closer + # try to trigger it to auto-close the connection self.data_finished_fut.cancel() self.data_finished_fut = None if self.auto_closer_task is None or self.auto_closer_task.done(): + # Auto-closer is not running, do force disconnect return self._force_disconnect() async def auto_close_connection(self) -> None: diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 1ccd55474c..29797e1e1f 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -16,6 +16,7 @@ abort, ) from sanic.response import text +from websockets.version import version as websockets_version class SanicExceptionTestException(Exception): @@ -260,9 +261,14 @@ async def feed(request, ws): with caplog.at_level(logging.INFO): app.test_client.websocket("/feed") - - assert caplog.record_tuples[1][0] == "sanic.error" - assert caplog.record_tuples[1][1] == logging.ERROR + # Websockets v10.0 and above output an additional + # INFO message when a ws connection is accepted + ws_version_parts = websockets_version.split(".") + ws_major = int(ws_version_parts[0]) + record_index = 2 if ws_major >= 10 else 1 + assert caplog.record_tuples[record_index][0] == "sanic.error" + assert caplog.record_tuples[record_index][1] == logging.ERROR assert ( - "Exception occurred while handling uri:" in caplog.record_tuples[1][2] + "Exception occurred while handling uri:" + in caplog.record_tuples[record_index][2] ) diff --git a/tests/test_routes.py b/tests/test_routes.py index 6c25b86b7f..520ab5be1f 100644 --- a/tests/test_routes.py +++ b/tests/test_routes.py @@ -671,31 +671,29 @@ async def check(request): assert response.json["set"] -def test_websocket_route_with_subprotocols(app): +@pytest.mark.parametrize( + "subprotocols,expected", + ( + (["one"], "one"), + (["three", "one"], "one"), + (["tree"], None), + (None, None), + ), +) +def test_websocket_route_with_subprotocols(app, subprotocols, expected): results = [] @app.websocket("/ws", subprotocols=["zero", "one", "two", "three"]) async def handler(request, ws): - results.append(ws.subprotocol) + nonlocal results + results = ws.subprotocol assert ws.subprotocol is not None - _, response = SanicTestClient(app).websocket("/ws", subprotocols=["one"]) - assert response.opened is True - assert results == ["one"] - _, response = SanicTestClient(app).websocket( - "/ws", subprotocols=["three", "one"] + "/ws", 
subprotocols=subprotocols ) assert response.opened is True - assert results == ["one", "one"] - - _, response = SanicTestClient(app).websocket("/ws", subprotocols=["tree"]) - assert response.opened is True - assert results == ["one", "one", None] - - _, response = SanicTestClient(app).websocket("/ws") - assert response.opened is True - assert results == ["one", "one", None, None] + assert results == expected @pytest.mark.parametrize("strict_slashes", [True, False, None]) diff --git a/tests/test_worker.py b/tests/test_worker.py index 2db02b50bb..3850b8a691 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -175,7 +175,7 @@ def test_worker_close(worker): worker.wsgi = mock.Mock() conn = mock.Mock() conn.websocket = mock.Mock() - conn.websocket.close = mock.Mock(wraps=_a_noop) + conn.websocket.fail_connection = mock.Mock(wraps=_a_noop) worker.connections = set([conn]) worker.log = mock.Mock() worker.loop = loop @@ -190,5 +190,5 @@ def test_worker_close(worker): loop.run_until_complete(_close) assert worker.signal.stopped - assert conn.websocket.close.called + assert conn.websocket.fail_connection.called assert len(worker.servers) == 0 From 955d515c23e757b9e20353c7bc29c5e53eb09636 Mon Sep 17 00:00:00 2001 From: Ashley Sommer Date: Mon, 27 Sep 2021 09:32:58 +1000 Subject: [PATCH 12/16] remove websocket bench from examples --- examples/websocket_bench.html | 47 ----------------------- examples/websocket_bench.py | 70 ----------------------------------- 2 files changed, 117 deletions(-) delete mode 100644 examples/websocket_bench.html delete mode 100644 examples/websocket_bench.py diff --git a/examples/websocket_bench.html b/examples/websocket_bench.html deleted file mode 100644 index 98c01f3daf..0000000000 --- a/examples/websocket_bench.html +++ /dev/null @@ -1,47 +0,0 @@ - - - - WebSocket benchmark - - - - - diff --git a/examples/websocket_bench.py b/examples/websocket_bench.py deleted file mode 100644 index 198bc9a5e9..0000000000 --- a/examples/websocket_bench.py +++ /dev/null @@ -1,70 +0,0 @@ -import asyncio -import logging -from sanic import Sanic -from sanic.response import file -from sanic.log import error_logger -import time -error_logger.setLevel(logging.INFO) -app = Sanic(__name__) - -@app.route('/') -async def index(request): - return await file('websocket_bench.html') - - -@app.websocket('/bench') -async def bench_p_time(request, ws): - i = 0 - bytes_total = 0 - start = 0 - end_time = 0 - started = False - await ws.send("1") - while started is False or (end_time > time.time()): - i += 1 - in_data = await ws.recv() - if started is False: - del in_data - error_logger.info("received first data: starting benchmark now..") - started = True - start = time.time() - end_time = start + 30.0 - continue - bytes_total += len(in_data) - del in_data - end = time.time() - elapsed = end - start - error_logger.info("Done. 
Took {} seconds".format(elapsed)) - error_logger.info("{} bytes in 30 seconds = {}".format(bytes_total, (bytes_total/30.0))) - -@app.websocket('/benchp') -async def bench_p_time(request, ws): - i = 0 - bytes_total = 0 - real_start = 0 - start_ptime = 0 - end_ptime = 0 - started = False - await ws.send("1") - while started is False or (end_ptime > time.process_time()): - i += 1 - in_data = await ws.recv() - if started is False: - del in_data - error_logger.info("received first data: starting benchmark now..") - started = True - real_start = time.time() - start_ptime = time.process_time() - end_ptime = start_ptime + 30.0 - continue - bytes_total += len(in_data) - del in_data - real_end = time.time() - elapsed = real_end - real_start - error_logger.info("Done. Took {} seconds".format(elapsed)) - error_logger.info("{} bytes in 30 seconds = {}".format(bytes_total, (bytes_total/30.0))) - - -if __name__ == '__main__': - app.run(host="0.0.0.0", port=8000, debug=False, auto_reload=False) - From 19c98b905a58d2bb57e6c2627377a00f8807fa98 Mon Sep 17 00:00:00 2001 From: Ashley Sommer Date: Mon, 27 Sep 2021 09:42:14 +1000 Subject: [PATCH 13/16] Integrate suggestions from code reviews Use Optional[T] instead of union[T,None] Fix mypy type logic errors change "is not None" to truthy checks where appropriate change "is None" to falsy checks were appropriate Add more debug logging when debug mode is on Change to using sanic.logger for debug logging rather than error_logger. --- sanic/server/protocols/websocket_protocol.py | 33 ++-- sanic/server/websockets/impl.py | 178 ++++++++++--------- 2 files changed, 108 insertions(+), 103 deletions(-) diff --git a/sanic/server/protocols/websocket_protocol.py b/sanic/server/protocols/websocket_protocol.py index d0fdb69d5e..1a1fd12512 100644 --- a/sanic/server/protocols/websocket_protocol.py +++ b/sanic/server/protocols/websocket_protocol.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Optional, Sequence, Union +from typing import TYPE_CHECKING, Optional, Sequence from websockets.connection import CLOSED, CLOSING, OPEN from websockets.server import ServerConnection @@ -16,16 +16,16 @@ class WebSocketProtocol(HttpProtocol): - websocket: Union[None, WebsocketImplProtocol] + websocket: Optional[WebsocketImplProtocol] websocket_timeout: float - websocket_max_size = Union[None, int] - websocket_ping_interval = Union[None, float] - websocket_ping_timeout = Union[None, float] + websocket_max_size = Optional[int] + websocket_ping_interval = Optional[float] + websocket_ping_timeout = Optional[float] def __init__( self, *args, - websocket_timeout: Optional[float] = 10.0, + websocket_timeout: float = 10.0, websocket_max_size: Optional[int] = None, websocket_max_queue: Optional[int] = None, # max_queue is deprecated websocket_read_limit: Optional[int] = None, # read_limit is deprecated @@ -38,28 +38,25 @@ def __init__( self.websocket = None self.websocket_timeout = websocket_timeout self.websocket_max_size = websocket_max_size - if websocket_max_queue is not None and int(websocket_max_queue) > 0: + if websocket_max_queue is not None and websocket_max_queue > 0: + # TODO: Reminder remove this warning in v22.3 error_logger.warning( DeprecationWarning( - "websocket_max_queue is no longer used. " - "No websocket message queueing is implemented." + "Websocket no longer uses queueing, so websocket_max_queue is no longer required." 
) ) - if websocket_read_limit is not None and int(websocket_read_limit) > 0: + if websocket_read_limit is not None and websocket_read_limit > 0: + # TODO: Reminder remove this warning in v22.3 error_logger.warning( DeprecationWarning( - "websocket_read_limit is no longer used. " - "No websocket rate limiting is implemented." + "Websocket no longer uses read buffers, so websocket_read_limit is not required." ) ) - if ( - websocket_write_limit is not None - and int(websocket_write_limit) > 0 - ): + if websocket_write_limit is not None and websocket_write_limit > 0: + # TODO: Reminder remove this warning in v22.3 error_logger.warning( DeprecationWarning( - "websocket_write_limit is no longer used. " - "No websocket rate limiting is implemented." + "Websocket no longer uses write buffers, so websocket_write_limit is not required." ) ) self.websocket_ping_interval = websocket_ping_interval diff --git a/sanic/server/websockets/impl.py b/sanic/server/websockets/impl.py index ca0954c087..591e23b18d 100644 --- a/sanic/server/websockets/impl.py +++ b/sanic/server/websockets/impl.py @@ -19,7 +19,7 @@ from websockets.server import ServerConnection from websockets.typing import Data -from sanic.log import error_logger +from sanic.log import error_logger, logger from sanic.server.protocols.base_protocol import SanicProtocol from ...exceptions import ServerError @@ -33,25 +33,24 @@ class WebsocketImplProtocol: connection: ServerConnection io_proto: Optional[SanicProtocol] - loop: Optional[asyncio.BaseEventLoop] + loop: Optional[asyncio.AbstractEventLoop] max_queue: int close_timeout: float ping_interval: Optional[float] ping_timeout: Optional[float] assembler: WebsocketFrameAssembler - pings: Dict[bytes, asyncio.Future] # Dict[bytes, asyncio.Future[None]] - pings: Dict[bytes, asyncio.Future] # Dict[bytes, asyncio.Future[None]] + # Dict[bytes, asyncio.Future[None]] + pings: Dict[bytes, asyncio.Future] conn_mutex: asyncio.Lock recv_lock: asyncio.Lock process_event_mutex: asyncio.Lock can_pause: bool - data_finished_fut: Optional[ - asyncio.Future - ] # Optional[asyncio.Future[None]] - pause_frame_fut: Optional[asyncio.Future] # Optional[asyncio.Future[None]] - connection_lost_waiter: Optional[ - asyncio.Future - ] # Optional[asyncio.Future[None]] + # Optional[asyncio.Future[None]] + data_finished_fut: Optional[asyncio.Future] + # Optional[asyncio.Future[None]] + pause_frame_fut: Optional[asyncio.Future] + # Optional[asyncio.Future[None]] + connection_lost_waiter: Optional[asyncio.Future] keepalive_ping_task: Optional[asyncio.Task] auto_closer_task: Optional[asyncio.Task] @@ -61,7 +60,7 @@ def __init__( max_queue=None, ping_interval: Optional[float] = 20, ping_timeout: Optional[float] = 20, - close_timeout: Optional[float] = 10, + close_timeout: float = 10, loop=None, ): self.connection = connection @@ -90,41 +89,53 @@ def subprotocol(self): def pause_frames(self): if not self.can_pause: return False - if self.pause_frame_fut is not None: + if self.pause_frame_fut: return False - if self.loop is None or self.io_proto is None: + if (not self.loop) or (not self.io_proto): return False - if self.io_proto.transport is not None: + if self.io_proto.transport: self.io_proto.transport.pause_reading() self.pause_frame_fut = self.loop.create_future() return True def resume_frames(self): - if self.pause_frame_fut is None: + if not self.pause_frame_fut: return False - if self.loop is None or self.io_proto is None: - error_logger.debug( + if (not self.loop) or (not self.io_proto): + logger.debug( "Websocket attempting 
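A small aside on the truthiness style adopted in this patch: asyncio.Future defines neither __bool__ nor __len__, so for the future attributes above "if fut:" behaves the same as "if fut is not None:" whether or not the future is done. A quick demonstration:

import asyncio


async def demo():
    fut = asyncio.get_running_loop().create_future()
    assert bool(fut) is True   # a pending future is truthy
    fut.set_result(None)
    assert bool(fut) is True   # a completed future is still truthy

asyncio.run(demo())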
to resume reading frames, but connection is gone." ) return False - if self.io_proto.transport is not None: + if self.io_proto.transport: self.io_proto.transport.resume_reading() self.pause_frame_fut.set_result(None) self.pause_frame_fut = None return True - async def connection_made(self, io_proto: SanicProtocol, loop=None): - if loop is None: + async def connection_made( + self, + io_proto: SanicProtocol, + loop: Optional[asyncio.AbstractEventLoop] = None, + ): + if not loop: try: loop = getattr(io_proto, "loop") except AttributeError: loop = asyncio.get_event_loop() + if not loop: + # This catch is for mypy type checker + # to assert loop is not None here. + raise ServerError("Connection received with no asyncio loop.") + if self.auto_closer_task: + raise ServerError( + "Cannot call connection_made more than once on a websocket connection." + ) self.loop = loop - self.io_proto = io_proto # this will be a WebSocketProtocol + self.io_proto = io_proto self.connection_lost_waiter = self.loop.create_future() self.data_finished_fut = asyncio.shield(self.loop.create_future()) - if self.ping_interval is not None: + if self.ping_interval: self.keepalive_ping_task = asyncio.create_task( self.keepalive_ping() ) @@ -141,7 +152,7 @@ async def wait_for_connection_lost(self, timeout=None) -> bool: Return ``True`` if the connection is closed and ``False`` otherwise. """ - if self.connection_lost_waiter is None: + if not self.connection_lost_waiter: return False if self.connection_lost_waiter.done(): return True @@ -165,6 +176,9 @@ async def process_events(self, events: Sequence[Event]) -> None: # from processing at the same time async with self.process_event_mutex: for event in events: + if not isinstance(event, Frame): + # Event is not a frame. Ignore it. + continue if event.opcode == OP_PONG: await self.process_pong(event) else: @@ -219,43 +233,36 @@ async def keepalive_ping(self) -> None: self.fail_connection(1011) break except asyncio.CancelledError: - raise + # It is expected for this task to be cancelled during during + # normal operation, when the connection is closed. + logger.debug("Websocket keepalive ping task was cancelled.") except ConnectionClosed: - pass - except Exception: - error_logger.warning("Unexpected exception in keepalive ping task") + logger.debug("Websocket closed. Keepalive ping task exiting.") + except Exception as e: + error_logger.warning( + "Unexpected exception in websocket keepalive ping task." 
+ ) + logger.debug(str(e)) def _force_disconnect(self) -> bool: """ Internal methdod used by end_connection and fail_connection only when the graceful auto-closer cannot be used """ - if ( - self.auto_closer_task is not None - and not self.auto_closer_task.done() - ): + if self.auto_closer_task and not self.auto_closer_task.done(): self.auto_closer_task.cancel() - if ( - self.data_finished_fut is not None - and not self.data_finished_fut.done() - ): + if self.data_finished_fut and not self.data_finished_fut.done(): self.data_finished_fut.cancel() self.data_finished_fut = None - if ( - self.keepalive_ping_task is not None - and not self.keepalive_ping_task.done() - ): + if self.keepalive_ping_task and not self.keepalive_ping_task.done(): self.keepalive_ping_task.cancel() self.keepalive_ping_task = None - if ( - self.loop is None - or self.io_proto is None - or self.io_proto.transport is None - ): - # We were never open, or already closed - return True - self.io_proto.transport.close() - self.loop.call_later(self.close_timeout, self.io_proto.transport.abort) + if self.loop and self.io_proto and self.io_proto.transport: + self.io_proto.transport.close() + self.loop.call_later( + self.close_timeout, self.io_proto.transport.abort + ) + # We were never open, or already closed return True def fail_connection(self, code: int = 1006, reason: str = "") -> bool: @@ -271,7 +278,7 @@ def fail_connection(self, code: int = 1006, reason: str = "") -> bool: of this. (The specification describes these steps in the opposite order.) """ - if self.io_proto and self.io_proto.transport is not None: + if self.io_proto and self.io_proto.transport: # Stop new data coming in # In Python Version 3.7: pause_reading is idempotent # i.e. it can be called when the transport is already paused or closed. @@ -294,7 +301,7 @@ def fail_connection(self, code: int = 1006, reason: str = "") -> bool: while ( len(data_to_send) and self.io_proto - and self.io_proto.transport is not None + and self.io_proto.transport ): frame_data = data_to_send.pop(0) self.io_proto.transport.write(frame_data) @@ -305,26 +312,20 @@ def fail_connection(self, code: int = 1006, reason: str = "") -> bool: if code == 1006: # Special case: 1006 consider the transport already closed self.connection.state = CLOSED - if ( - self.data_finished_fut is not None - and not self.data_finished_fut.done() - ): + if self.data_finished_fut and not self.data_finished_fut.done(): # We have a graceful auto-closer. Use it to close the connection. self.data_finished_fut.cancel() self.data_finished_fut = None - if self.auto_closer_task is None or self.auto_closer_task.done(): + if (not self.auto_closer_task) or self.auto_closer_task.done(): return self._force_disconnect() + return False def end_connection(self, code=1000, reason=""): # This is like slightly more graceful form of fail_connection # Use this instead of close() when you need an immediate # close and cannot await websocket.close() handshake. 
- if ( - code == 1006 - or self.io_proto is None - or self.io_proto.transport is None - ): + if code == 1006 or not self.io_proto or not self.io_proto.transport: return self.fail_connection(code, reason) # Stop new data coming in @@ -339,7 +340,7 @@ def end_connection(self, code=1000, reason=""): while ( len(data_to_send) and self.io_proto - and self.io_proto.transport is not None + and self.io_proto.transport ): frame_data = data_to_send.pop(0) self.io_proto.transport.write(frame_data) @@ -348,17 +349,15 @@ def end_connection(self, code=1000, reason=""): # transport closes during this period # But that doesn't matter at this point ... - if ( - self.data_finished_fut is not None - and not self.data_finished_fut.done() - ): + if self.data_finished_fut and not self.data_finished_fut.done(): # We have the ability to signal the auto-closer # try to trigger it to auto-close the connection self.data_finished_fut.cancel() self.data_finished_fut = None - if self.auto_closer_task is None or self.auto_closer_task.done(): + if (not self.auto_closer_task) or self.auto_closer_task.done(): # Auto-closer is not running, do force disconnect return self._force_disconnect() + return False async def auto_close_connection(self) -> None: """ @@ -371,28 +370,33 @@ async def auto_close_connection(self) -> None: """ try: # Wait for the data transfer phase to complete. - if self.data_finished_fut is not None: + if self.data_finished_fut: try: await self.data_finished_fut + logger.debug( + "Websocket task finished. Closing the connection." + ) except asyncio.CancelledError: # Cancelled error will be called when data phase is cancelled # This can be if an error occurred or the client app closed the connection - ... + logger.debug( + "Websocket handler cancelled. Closing the connection." + ) # Cancel the keepalive ping task. - if self.keepalive_ping_task is not None: + if self.keepalive_ping_task: self.keepalive_ping_task.cancel() self.keepalive_ping_task = None # Half-close the TCP connection if possible (when there's no TLS). if ( self.io_proto - and self.io_proto.transport is not None + and self.io_proto.transport and self.io_proto.transport.can_write_eof() ): - error_logger.debug("Websocket half-closing TCP connection") + logger.debug("Websocket half-closing TCP connection") self.io_proto.transport.write_eof() - if self.connection_lost_waiter is not None: + if self.connection_lost_waiter: if await self.wait_for_connection_lost(timeout=0): return except asyncio.CancelledError: @@ -400,11 +404,11 @@ async def auto_close_connection(self) -> None: finally: # The try/finally ensures that the transport never remains open, # even if this coroutine is cancelled (for example). - if self.io_proto is None or self.io_proto.transport is None: + if (not self.io_proto) or (not self.io_proto.transport): # we were never open, or already dead and buried. Can't do any finalization. return elif ( - self.connection_lost_waiter is not None + self.connection_lost_waiter and self.connection_lost_waiter.done() ): # connection was confirmed closed already, proceed to abort waiter @@ -415,13 +419,13 @@ async def auto_close_connection(self) -> None: ... else: self.io_proto.transport.close() - if self.connection_lost_waiter is None: + if not self.connection_lost_waiter: # Our connection monitor task isn't running. try: await asyncio.sleep(self.close_timeout) except asyncio.CancelledError: ... 
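The half-close step in auto_close_connection() relies on the transport advertising EOF support: plain TCP transports can send a FIN while reads stay open, whereas TLS transports report can_write_eof() as False and need a full close. A sketch of that check with illustrative names:

def try_half_close(transport) -> bool:
    if transport is not None and transport.can_write_eof():
        transport.write_eof()  # send EOF, keep reading the peer's close
        return True
    return False               # e.g. TLS transports cannot half-close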
- if self.io_proto and self.io_proto.transport is not None: + if self.io_proto and self.io_proto.transport: self.io_proto.transport.abort() else: if await self.wait_for_connection_lost( @@ -432,7 +436,7 @@ async def auto_close_connection(self) -> None: error_logger.warning( "Timeout waiting for TCP connection to close. Aborting" ) - if self.io_proto and self.io_proto.transport is not None: + if self.io_proto and self.io_proto.transport: self.io_proto.transport.abort() def abort_pings(self) -> None: @@ -447,7 +451,7 @@ def abort_pings(self) -> None: ) for ping in self.pings.values(): - ping.set_exception(ConnectionClosedError(None)) + ping.set_exception(ConnectionClosedError(None, None)) # If the exception is never retrieved, it will be logged when ping # is garbage-collected. This is confusing for users. # Given that ping is done (with an exception), canceling it does @@ -607,7 +611,7 @@ async def send(self, message: Union[Data, Iterable[Data]]) -> None: raise ServerError( "Cannot write to websocket interface after it is closed." ) - if self.data_finished_fut is None or self.data_finished_fut.done(): + if (not self.data_finished_fut) or self.data_finished_fut.done(): raise ServerError( "Cannot write to websocket interface after it is finished." ) @@ -654,6 +658,10 @@ async def ping(self, data: Optional[Data] = None) -> asyncio.Future: raise ServerError( "Cannot send a ping when the websocket interface is closed." ) + if (not self.io_proto) or (not self.io_proto.loop): + raise ServerError( + "Cannot send a ping when the websocket has no io protocol attached." + ) if data is not None: if isinstance(data, str): data = data.encode("utf-8") @@ -703,9 +711,9 @@ async def send_data(self, data_to_send): # Send an EOF # We don't actually send it, just trigger to autoclose the connection if ( - self.auto_closer_task is not None + self.auto_closer_task and not self.auto_closer_task.done() - and self.data_finished_fut is not None + and self.data_finished_fut and not self.data_finished_fut.done() ): # Auto-close the connection @@ -740,9 +748,9 @@ async def async_eof_received(self, data_to_send, events_to_process): await self.process_events(events_to_process) if ( - self.auto_closer_task is not None + self.auto_closer_task and not self.auto_closer_task.done() - and self.data_finished_fut is not None + and self.data_finished_fut and not self.data_finished_fut.done() ): # Auto-close the connection @@ -772,5 +780,5 @@ def connection_lost(self, exc): self.connection.state = CLOSED self.abort_pings() - if self.connection_lost_waiter is not None: + if self.connection_lost_waiter: self.connection_lost_waiter.set_result(None) From 791b6935ce2e5ef40623e582d2515b66cb845369 Mon Sep 17 00:00:00 2001 From: Ashley Sommer Date: Tue, 28 Sep 2021 10:28:24 +1000 Subject: [PATCH 14/16] Fix long line lengths of debug messages Add some new debug messages when websocket IO is paused and unpaused for flow control Fix websocket example to use app.static() --- examples/websocket.py | 9 ++-- sanic/server/protocols/websocket_protocol.py | 9 ++-- sanic/server/websockets/frame.py | 51 +++++++++++++------- sanic/server/websockets/impl.py | 51 +++++++++++--------- 4 files changed, 72 insertions(+), 48 deletions(-) diff --git a/examples/websocket.py b/examples/websocket.py index 9cba083cfc..57da077c5d 100644 --- a/examples/websocket.py +++ b/examples/websocket.py @@ -1,13 +1,14 @@ from sanic import Sanic -from sanic.response import file +from sanic.response import file, redirect app = Sanic(__name__) -@app.route('/') -async def 
index(request): - return await file('websocket.html') +app.static('index.html', "websocket.html") +@app.route('/') +def index(request): + return redirect("index.html") @app.websocket('/feed') async def feed(request, ws): diff --git a/sanic/server/protocols/websocket_protocol.py b/sanic/server/protocols/websocket_protocol.py index 1a1fd12512..628945d29a 100644 --- a/sanic/server/protocols/websocket_protocol.py +++ b/sanic/server/protocols/websocket_protocol.py @@ -42,21 +42,24 @@ def __init__( # TODO: Reminder remove this warning in v22.3 error_logger.warning( DeprecationWarning( - "Websocket no longer uses queueing, so websocket_max_queue is no longer required." + "Websocket no longer uses queueing, so websocket_max_queue" + " is no longer required." ) ) if websocket_read_limit is not None and websocket_read_limit > 0: # TODO: Reminder remove this warning in v22.3 error_logger.warning( DeprecationWarning( - "Websocket no longer uses read buffers, so websocket_read_limit is not required." + "Websocket no longer uses read buffers, so " + "websocket_read_limit is not required." ) ) if websocket_write_limit is not None and websocket_write_limit > 0: # TODO: Reminder remove this warning in v22.3 error_logger.warning( DeprecationWarning( - "Websocket no longer uses write buffers, so websocket_write_limit is not required." + "Websocket no longer uses write buffers, so " + "websocket_write_limit is not required." ) ) self.websocket_ping_interval = websocket_ping_interval diff --git a/sanic/server/websockets/frame.py b/sanic/server/websockets/frame.py index 98688b1c34..2ecb578f6a 100644 --- a/sanic/server/websockets/frame.py +++ b/sanic/server/websockets/frame.py @@ -101,7 +101,8 @@ async def get(self, timeout: Optional[float] = None) -> Optional[Data]: if not self.message_complete.is_set(): return None if self.get_in_progress: - # This should be guarded against with the read_mutex, exception is only here as a failsafe + # This should be guarded against with the read_mutex, + # exception is only here as a failsafe raise ServerError( "Called get() on Websocket frame assembler " "while asynchronous get is already in progress." @@ -130,9 +131,11 @@ async def get(self, timeout: Optional[float] = None) -> Optional[Data]: self.protocol.resume_frames() self.paused = False if not self.get_in_progress: - # This should be guarded against with the read_mutex, exception is here as a failsafe + # This should be guarded against with the read_mutex, + # exception is here as a failsafe raise ServerError( - "State of Websocket frame assembler was modified while an asynchronous get was in progress." + "State of Websocket frame assembler was modified while an " + "asynchronous get was in progress." ) self.get_in_progress = False @@ -148,10 +151,12 @@ async def get(self, timeout: Optional[float] = None) -> Optional[Data]: # mypy cannot figure out that chunks have the proper type. message: Data = joiner.join(self.chunks) # type: ignore if self.message_fetched.is_set(): - # This should be guarded against with the read_mutex, and get_in_progress check, - # this exception is here as a failsafe + # This should be guarded against with the read_mutex, + # and get_in_progress check, this exception is here + # as a failsafe raise ServerError( - "Websocket get() found a message when state was already fetched." + "Websocket get() found a message when " + "state was already fetched." 
) self.message_fetched.set() self.chunks = [] @@ -169,7 +174,8 @@ async def get_iter(self) -> AsyncIterator[Data]: """ async with self.read_mutex: if self.get_in_progress: - # This should be guarded against with the read_mutex, exception is only here as a failsafe + # This should be guarded against with the read_mutex, + # exception is only here as a failsafe raise ServerError( "Called get_iter on Websocket frame assembler " "while asynchronous get is already in progress." @@ -186,7 +192,7 @@ async def get_iter(self) -> AsyncIterator[Data]: if self.message_complete.is_set(): await self.chunks_queue.put(None) - # Locking with get_in_progress ensures only one thread can get here. + # Locking with get_in_progress ensures only one thread can get here for c in chunks: yield c while True: @@ -200,22 +206,28 @@ async def get_iter(self) -> AsyncIterator[Data]: self.protocol.resume_frames() self.paused = False if not self.get_in_progress: - # This should be guarded against with the read_mutex, exception is here as a failsafe + # This should be guarded against with the read_mutex, + # exception is here as a failsafe raise ServerError( - "State of Websocket frame assembler was modified while an asynchronous get was in progress." + "State of Websocket frame assembler was modified while an " + "asynchronous get was in progress." ) self.get_in_progress = False if not self.message_complete.is_set(): - # This should be guarded against with the read_mutex, exception is here as a failsafe + # This should be guarded against with the read_mutex, + # exception is here as a failsafe raise ServerError( - "Websocket frame assembler chunks queue ended before message was complete." + "Websocket frame assembler chunks queue ended before " + "message was complete." ) self.message_complete.clear() if self.message_fetched.is_set(): - # This should be guarded against with the read_mutex, and get_in_progress check, - # this exception is here as a failsafe + # This should be guarded against with the read_mutex, + # and get_in_progress check, this exception is + # here as a failsafe raise ServerError( - "Websocket get_iter() found a message when state was already fetched." + "Websocket get_iter() found a message when state was " + "already fetched." ) self.message_fetched.set() @@ -257,7 +269,8 @@ async def put(self, frame: Frame) -> None: if not frame.fin: return if not self.get_in_progress: - # nobody is waiting for this frame, so try to pause subsequent frames at the protocol level + # nobody is waiting for this frame, so try to pause subsequent + # frames at the protocol level self.paused = self.protocol.pause_frames() # Message is complete. Wait until it's fetched to return. @@ -266,13 +279,15 @@ async def put(self, frame: Frame) -> None: if self.message_complete.is_set(): # This should be guarded against with the write_mutex raise ServerError( - "Websocket put() got a new message when a message was already in its chamber." + "Websocket put() got a new message when a message was " + "already in its chamber." ) self.message_complete.set() # Signal to get() it can serve the if self.message_fetched.is_set(): # This should be guarded against with the write_mutex raise ServerError( - "Websocket put() got a new message when the previous message was not yet fetched." + "Websocket put() got a new message when the previous " + "message was not yet fetched." ) # Allow get() to run and eventually set the event. 
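For context on the flow-control behaviour above: the assembler only pauses the transport when a complete message arrives while no get() is in progress, and the next get()/get_iter() call resumes it, so a handler that consumes messages slowly exercises the pause/resume path without any extra code. The sketch below is illustrative only and not part of this patch set; the echo body and the 60-second timeout are invented, and it assumes recv() returns None when its timeout elapses, as its Optional[Data] signature suggests. The route and (request, ws) handler signature follow examples/websocket.py.

    from sanic import Sanic

    app = Sanic(__name__)

    @app.websocket("/feed")
    async def feed(request, ws):
        while True:
            # recv() pulls one complete message from the frame assembler.
            # If the handler lags, put() pauses the transport via
            # pause_frames(); the next get() resumes it, emitting the
            # "paused"/"unpaused" debug messages added in this patch.
            data = await ws.recv(timeout=60)
            if data is None:
                break  # timed out or the connection is shutting down
            await ws.send(data)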
diff --git a/sanic/server/websockets/impl.py b/sanic/server/websockets/impl.py index 591e23b18d..7a9bd3117d 100644 --- a/sanic/server/websockets/impl.py +++ b/sanic/server/websockets/impl.py @@ -3,7 +3,6 @@ import struct from typing import ( - TYPE_CHECKING, AsyncIterator, Dict, Iterable, @@ -15,7 +14,7 @@ from websockets.connection import CLOSED, CLOSING, OPEN, Event from websockets.exceptions import ConnectionClosed, ConnectionClosedError -from websockets.frames import OP_PONG +from websockets.frames import Frame, OP_PONG from websockets.server import ServerConnection from websockets.typing import Data @@ -26,10 +25,6 @@ from .frame import WebsocketFrameAssembler -if TYPE_CHECKING: - from websockets.frames import Frame - - class WebsocketImplProtocol: connection: ServerConnection io_proto: Optional[SanicProtocol] @@ -90,26 +85,31 @@ def pause_frames(self): if not self.can_pause: return False if self.pause_frame_fut: + logger.debug("Websocket connection already paused.") return False if (not self.loop) or (not self.io_proto): return False if self.io_proto.transport: self.io_proto.transport.pause_reading() self.pause_frame_fut = self.loop.create_future() + logger.debug("Websocket connection paused.") return True def resume_frames(self): if not self.pause_frame_fut: + logger.debug("Websocket connection not paused.") return False if (not self.loop) or (not self.io_proto): logger.debug( - "Websocket attempting to resume reading frames, but connection is gone." + "Websocket attempting to resume reading frames, " + "but connection is gone." ) return False if self.io_proto.transport: self.io_proto.transport.resume_reading() self.pause_frame_fut.set_result(None) self.pause_frame_fut = None + logger.debug("Websocket connection unpaused.") return True async def connection_made( @@ -128,7 +128,8 @@ async def connection_made( raise ServerError("Connection received with no asyncio loop.") if self.auto_closer_task: raise ServerError( - "Cannot call connection_made more than once on a websocket connection." + "Cannot call connection_made more than once " + "on a websocket connection." ) self.loop = loop self.io_proto = io_proto @@ -163,9 +164,9 @@ async def wait_for_connection_lost(self, timeout=None) -> bool: ) return True except asyncio.TimeoutError: - # Re-check self.connection_lost_waiter.done() synchronously because - # connection_lost() could run between the moment the timeout occurs - # and the moment this coroutine resumes running. + # Re-check self.connection_lost_waiter.done() synchronously + # because connection_lost() could run between the moment the + # timeout occurs and the moment this coroutine resumes running return self.connection_lost_waiter.done() async def process_events(self, events: Sequence[Event]) -> None: @@ -216,7 +217,7 @@ async def keepalive_ping(self) -> None: await asyncio.sleep(self.ping_interval) # ping() raises CancelledError if the connection is closed, - # when auto_close_connection() cancels self.keepalive_ping_task. + # when auto_close_connection() cancels keepalive_ping_task. # ping() raises ConnectionClosed if the connection is lost, # when connection_lost() calls abort_pings(). @@ -281,7 +282,7 @@ def fail_connection(self, code: int = 1006, reason: str = "") -> bool: if self.io_proto and self.io_proto.transport: # Stop new data coming in # In Python Version 3.7: pause_reading is idempotent - # i.e. it can be called when the transport is already paused or closed. 
+            # it can be called when the transport is already paused or closed
             self.io_proto.transport.pause_reading()
 
             # Keeping fail_connection() synchronous guarantees it can't
@@ -377,8 +378,8 @@ async def auto_close_connection(self) -> None:
                         "Websocket task finished. Closing the connection."
                     )
                 except asyncio.CancelledError:
-                    # Cancelled error will be called when data phase is cancelled
-                    # This can be if an error occurred or the client app closed the connection
+                    # Cancelled error is called when data phase is cancelled
+                    # if an error occurred or the client closed the connection
                     logger.debug(
                         "Websocket handler cancelled. Closing the connection."
                     )
@@ -405,13 +406,13 @@ async def auto_close_connection(self) -> None:
             # The try/finally ensures that the transport never remains open,
             # even if this coroutine is cancelled (for example).
             if (not self.io_proto) or (not self.io_proto.transport):
-                # we were never open, or already dead and buried. Can't do any finalization.
+                # we were never open, or done. Can't do any finalization.
                 return
             elif (
                 self.connection_lost_waiter
                 and self.connection_lost_waiter.done()
             ):
-                # connection was confirmed closed already, proceed to abort waiter
+                # connection confirmed closed already, proceed to abort waiter
                 ...
             elif self.io_proto.transport.is_closing():
                 # Connection is already closing (due to half-close above)
@@ -500,7 +501,8 @@ async def recv(self, timeout: Optional[float] = None) -> Optional[Data]:
 
         if self.recv_lock.locked():
             raise ServerError(
-                "cannot call recv while another task is already waiting for the next message"
+                "cannot call recv while another task is "
+                "already waiting for the next message"
             )
         await self.recv_lock.acquire()
         if self.connection.state in (CLOSED, CLOSING):
@@ -531,7 +533,8 @@ async def recv_burst(self, max_recv=256) -> Sequence[Data]:
 
         if self.recv_lock.locked():
             raise ServerError(
-                "cannot call recv_burst while another task is already waiting for the next message"
+                "cannot call recv_burst while another task is already waiting "
+                "for the next message"
             )
         await self.recv_lock.acquire()
         if self.connection.state in (CLOSED, CLOSING):
@@ -656,11 +659,13 @@ async def ping(self, data: Optional[Data] = None) -> asyncio.Future:
         async with self.conn_mutex:
             if self.connection.state in (CLOSED, CLOSING):
                 raise ServerError(
-                    "Cannot send a ping when the websocket interface is closed."
+                    "Cannot send a ping when the websocket interface "
+                    "is closed."
                 )
             if (not self.io_proto) or (not self.io_proto.loop):
                 raise ServerError(
-                    "Cannot send a ping when the websocket has no io protocol attached."
+                    "Cannot send a ping when the websocket has no I/O "
+                    "protocol attached."
) if data is not None: if isinstance(data, str): @@ -708,8 +713,8 @@ async def send_data(self, data_to_send): if data: await self.io_proto.send(data) else: - # Send an EOF - # We don't actually send it, just trigger to autoclose the connection + # Send an EOF - We don't actually send it, + # just trigger to autoclose the connection if ( self.auto_closer_task and not self.auto_closer_task.done() From f2644775207a3a00096113b4d574aafb8ed69c92 Mon Sep 17 00:00:00 2001 From: Ashley Sommer Date: Tue, 28 Sep 2021 10:32:21 +1000 Subject: [PATCH 15/16] remove unused import in websocket example app --- examples/websocket.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/websocket.py b/examples/websocket.py index 57da077c5d..92f713756b 100644 --- a/examples/websocket.py +++ b/examples/websocket.py @@ -1,5 +1,5 @@ from sanic import Sanic -from sanic.response import file, redirect +from sanic.response import redirect app = Sanic(__name__) From 216285449f9a11ee6885a23b235648bac80a563c Mon Sep 17 00:00:00 2001 From: Ashley Sommer Date: Tue, 28 Sep 2021 13:01:51 +1000 Subject: [PATCH 16/16] re-run isort after Flake8 fixes --- sanic/server/websockets/frame.py | 1 + sanic/server/websockets/impl.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/sanic/server/websockets/frame.py b/sanic/server/websockets/frame.py index 2ecb578f6a..b4af72b130 100644 --- a/sanic/server/websockets/frame.py +++ b/sanic/server/websockets/frame.py @@ -8,6 +8,7 @@ from sanic.exceptions import ServerError + if TYPE_CHECKING: from .impl import WebsocketImplProtocol diff --git a/sanic/server/websockets/impl.py b/sanic/server/websockets/impl.py index 7a9bd3117d..a2778c5780 100644 --- a/sanic/server/websockets/impl.py +++ b/sanic/server/websockets/impl.py @@ -14,7 +14,7 @@ from websockets.connection import CLOSED, CLOSING, OPEN, Event from websockets.exceptions import ConnectionClosed, ConnectionClosedError -from websockets.frames import Frame, OP_PONG +from websockets.frames import OP_PONG, Frame from websockets.server import ServerConnection from websockets.typing import Data
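A final usage note on the ping machinery shown above: ping() returns an asyncio.Future, and abort_pings() fails any outstanding futures with ConnectionClosedError when the connection is lost. The sketch below is illustrative only and not part of any patch in this series; the /heartbeat route and handler body are invented, and it assumes the returned future is resolved once the peer's Pong frame is processed, mirroring the sans-I/O websockets model this implementation builds on.

    from sanic import Sanic

    app = Sanic(__name__)

    @app.websocket("/heartbeat")
    async def heartbeat(request, ws):
        # Queue a Ping frame; the returned future is expected to complete
        # when the matching Pong arrives, or fail with ConnectionClosedError
        # via abort_pings() if the connection drops first.
        pong_waiter = await ws.ping()
        await pong_waiter
        await ws.send("pong received")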