diff --git a/CHANGES/9543.feature.rst b/CHANGES/9543.feature.rst new file mode 100644 index 00000000000..ee624ddc48d --- /dev/null +++ b/CHANGES/9543.feature.rst @@ -0,0 +1 @@ +Improved performance of reading WebSocket messages with a Cython implementation -- by :user:`bdraco`. diff --git a/Makefile b/Makefile index 3a8803756ba..4876f999cde 100644 --- a/Makefile +++ b/Makefile @@ -56,6 +56,11 @@ endif aiohttp/_find_header.c: $(call to-hash,aiohttp/hdrs.py ./tools/gen.py) ./tools/gen.py +# Special case for reader since we want to be able to disable +# the extension with AIOHTTP_NO_EXTENSIONS +aiohttp/_websocket/reader_c.c: aiohttp/_websocket/reader_c.py + cython -3 -o $@ $< -I aiohttp -Werror + # _find_headers generator creates _headers.pyi as well aiohttp/%.c: aiohttp/%.pyx $(call to-hash,$(CYS)) aiohttp/_find_header.c cython -3 -o $@ $< -I aiohttp -Werror @@ -74,7 +79,7 @@ vendor/llhttp/node_modules: vendor/llhttp/package.json generate-llhttp: .llhttp-gen .PHONY: cythonize -cythonize: .install-cython $(PYXS:.pyx=.c) +cythonize: .install-cython $(PYXS:.pyx=.c) aiohttp/_websocket/reader_c.c .install-deps: .install-cython $(PYXS:.pyx=.c) $(call to-hash,$(CYS) $(REQS)) @python -m pip install -r requirements/dev.in -c requirements/dev.txt @@ -157,6 +162,7 @@ clean: @rm -f aiohttp/_http_parser.c @rm -f aiohttp/_http_writer.c @rm -f aiohttp/_websocket.c + @rm -f aiohttp/_websocket/reader_c.c @rm -rf .tox @rm -f .develop @rm -f .flake diff --git a/aiohttp/_websocket/mask.pxd b/aiohttp/_websocket/mask.pxd new file mode 100644 index 00000000000..90983de9ac7 --- /dev/null +++ b/aiohttp/_websocket/mask.pxd @@ -0,0 +1,3 @@ +"""Cython declarations for websocket masking.""" + +cpdef void _websocket_mask_cython(bytes mask, bytearray data) diff --git a/aiohttp/_websocket/mask.pyx b/aiohttp/_websocket/mask.pyx index 94318d2b1be..2d956c88996 100644 --- a/aiohttp/_websocket/mask.pyx +++ b/aiohttp/_websocket/mask.pyx @@ -8,7 +8,7 @@ cdef extern from "Python.h": from libc.stdint 
cimport uint32_t, uint64_t, uintmax_t -def _websocket_mask_cython(object mask, object data): +cpdef void _websocket_mask_cython(bytes mask, bytearray data): """Note, this function mutates its `data` argument """ cdef: @@ -21,14 +21,6 @@ def _websocket_mask_cython(object mask, object data): assert len(mask) == 4 - if not isinstance(mask, bytes): - mask = bytes(mask) - - if isinstance(data, bytearray): - data = data - else: - data = bytearray(data) - data_len = len(data) in_buf = PyByteArray_AsString(data) mask_buf = PyBytes_AsString(mask) diff --git a/aiohttp/_websocket/reader.py b/aiohttp/_websocket/reader.py index bd6d29e9a77..254288ac7e7 100644 --- a/aiohttp/_websocket/reader.py +++ b/aiohttp/_websocket/reader.py @@ -1,350 +1,21 @@ """Reader for WebSocket protocol versions 13 and 8.""" -from enum import IntEnum -from typing import Final, List, Optional, Set, Tuple +from typing import TYPE_CHECKING -from ..compression_utils import ZLibDecompressor -from ..helpers import set_exception -from ..streams import DataQueue -from .helpers import UNPACK_CLOSE_CODE, UNPACK_LEN2, UNPACK_LEN3, websocket_mask -from .models import ( - WS_DEFLATE_TRAILING, - WebSocketError, - WSCloseCode, - WSMessage, - WSMessageBinary, - WSMessageClose, - WSMessagePing, - WSMessagePong, - WSMessageText, - WSMsgType, -) +from ..helpers import NO_EXTENSIONS -MESSAGE_TYPES_WITH_CONTENT: Final = frozenset( - { - WSMsgType.BINARY, - WSMsgType.TEXT, - WSMsgType.CONTINUATION, - } -) +if TYPE_CHECKING or NO_EXTENSIONS: # pragma: no cover + from .reader_py import WebSocketReader as WebSocketReaderPython -ALLOWED_CLOSE_CODES: Final[Set[int]] = {int(i) for i in WSCloseCode} + WebSocketReader = WebSocketReaderPython +else: + try: + from .reader_c import ( # type: ignore[import-not-found] + WebSocketReader as WebSocketReaderCython, + ) + WebSocketReader = WebSocketReaderCython + except ImportError: # pragma: no cover + from .reader_py import WebSocketReader as WebSocketReaderPython -class 
WSParserState(IntEnum): - READ_HEADER = 1 - READ_PAYLOAD_LENGTH = 2 - READ_PAYLOAD_MASK = 3 - READ_PAYLOAD = 4 - - -class WebSocketReader: - def __init__( - self, queue: DataQueue[WSMessage], max_msg_size: int, compress: bool = True - ) -> None: - self.queue = queue - self._max_msg_size = max_msg_size - - self._exc: Optional[BaseException] = None - self._partial = bytearray() - self._state = WSParserState.READ_HEADER - - self._opcode: Optional[int] = None - self._frame_fin = False - self._frame_opcode: Optional[int] = None - self._frame_payload = bytearray() - - self._tail: bytes = b"" - self._has_mask = False - self._frame_mask: Optional[bytes] = None - self._payload_length = 0 - self._payload_length_flag = 0 - self._compressed: Optional[bool] = None - self._decompressobj: Optional[ZLibDecompressor] = None - self._compress = compress - - def feed_eof(self) -> None: - self.queue.feed_eof() - - def feed_data(self, data: bytes) -> Tuple[bool, bytes]: - if self._exc: - return True, data - - try: - self._feed_data(data) - except Exception as exc: - self._exc = exc - set_exception(self.queue, exc) - return True, b"" - - return False, b"" - - def _feed_data(self, data: bytes) -> None: - msg: WSMessage - for fin, opcode, payload, compressed in self.parse_frame(data): - if opcode in MESSAGE_TYPES_WITH_CONTENT: - # load text/binary - is_continuation = opcode == WSMsgType.CONTINUATION - if not fin: - # got partial frame payload - if not is_continuation: - self._opcode = opcode - self._partial += payload - if self._max_msg_size and len(self._partial) >= self._max_msg_size: - raise WebSocketError( - WSCloseCode.MESSAGE_TOO_BIG, - "Message size {} exceeds limit {}".format( - len(self._partial), self._max_msg_size - ), - ) - continue - - has_partial = bool(self._partial) - if is_continuation: - if self._opcode is None: - raise WebSocketError( - WSCloseCode.PROTOCOL_ERROR, - "Continuation frame for non started message", - ) - opcode = self._opcode - self._opcode = None - # 
previous frame was non finished - # we should get continuation opcode - elif has_partial: - raise WebSocketError( - WSCloseCode.PROTOCOL_ERROR, - "The opcode in non-fin frame is expected " - "to be zero, got {!r}".format(opcode), - ) - - if has_partial: - assembled_payload = self._partial + payload - self._partial.clear() - else: - assembled_payload = payload - - if self._max_msg_size and len(assembled_payload) >= self._max_msg_size: - raise WebSocketError( - WSCloseCode.MESSAGE_TOO_BIG, - "Message size {} exceeds limit {}".format( - len(assembled_payload), self._max_msg_size - ), - ) - - # Decompress process must to be done after all packets - # received. - if compressed: - if not self._decompressobj: - self._decompressobj = ZLibDecompressor( - suppress_deflate_header=True - ) - payload_merged = self._decompressobj.decompress_sync( - assembled_payload + WS_DEFLATE_TRAILING, self._max_msg_size - ) - if self._decompressobj.unconsumed_tail: - left = len(self._decompressobj.unconsumed_tail) - raise WebSocketError( - WSCloseCode.MESSAGE_TOO_BIG, - "Decompressed message size {} exceeds limit {}".format( - self._max_msg_size + left, self._max_msg_size - ), - ) - else: - payload_merged = bytes(assembled_payload) - - if opcode == WSMsgType.TEXT: - try: - text = payload_merged.decode("utf-8") - except UnicodeDecodeError as exc: - raise WebSocketError( - WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message" - ) from exc - - # XXX: The Text and Binary messages here can be a performance - # bottleneck, so we use tuple.__new__ to improve performance. - # This is not type safe, but many tests should fail in - # test_client_ws_functional.py if this is wrong. 
- msg = tuple.__new__(WSMessageText, (text, "", WSMsgType.TEXT)) - self.queue.feed_data(msg) - continue - - msg = tuple.__new__( - WSMessageBinary, (payload_merged, "", WSMsgType.BINARY) - ) - self.queue.feed_data(msg) - elif opcode == WSMsgType.CLOSE: - if len(payload) >= 2: - close_code = UNPACK_CLOSE_CODE(payload[:2])[0] - if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES: - raise WebSocketError( - WSCloseCode.PROTOCOL_ERROR, - f"Invalid close code: {close_code}", - ) - try: - close_message = payload[2:].decode("utf-8") - except UnicodeDecodeError as exc: - raise WebSocketError( - WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message" - ) from exc - msg = WSMessageClose(data=close_code, extra=close_message) - elif payload: - raise WebSocketError( - WSCloseCode.PROTOCOL_ERROR, - f"Invalid close frame: {fin} {opcode} {payload!r}", - ) - else: - msg = WSMessageClose(data=0, extra="") - - self.queue.feed_data(msg) - - elif opcode == WSMsgType.PING: - msg = WSMessagePing(data=payload, extra="") - self.queue.feed_data(msg) - - elif opcode == WSMsgType.PONG: - msg = WSMessagePong(data=payload, extra="") - self.queue.feed_data(msg) - - else: - raise WebSocketError( - WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}" - ) - - def parse_frame( - self, buf: bytes - ) -> List[Tuple[bool, Optional[int], bytearray, Optional[bool]]]: - """Return the next frame from the socket.""" - frames: List[Tuple[bool, Optional[int], bytearray, Optional[bool]]] = [] - if self._tail: - buf, self._tail = self._tail + buf, b"" - - start_pos: int = 0 - buf_length = len(buf) - - while True: - # read header - if self._state is WSParserState.READ_HEADER: - if buf_length - start_pos < 2: - break - data = buf[start_pos : start_pos + 2] - start_pos += 2 - first_byte, second_byte = data - - fin = (first_byte >> 7) & 1 - rsv1 = (first_byte >> 6) & 1 - rsv2 = (first_byte >> 5) & 1 - rsv3 = (first_byte >> 4) & 1 - opcode = first_byte & 0xF - - # frame-fin = %x0 ; more frames of 
this message follow - # / %x1 ; final frame of this message - # frame-rsv1 = %x0 ; - # 1 bit, MUST be 0 unless negotiated otherwise - # frame-rsv2 = %x0 ; - # 1 bit, MUST be 0 unless negotiated otherwise - # frame-rsv3 = %x0 ; - # 1 bit, MUST be 0 unless negotiated otherwise - # - # Remove rsv1 from this test for deflate development - if rsv2 or rsv3 or (rsv1 and not self._compress): - raise WebSocketError( - WSCloseCode.PROTOCOL_ERROR, - "Received frame with non-zero reserved bits", - ) - - if opcode > 0x7 and fin == 0: - raise WebSocketError( - WSCloseCode.PROTOCOL_ERROR, - "Received fragmented control frame", - ) - - has_mask = (second_byte >> 7) & 1 - length = second_byte & 0x7F - - # Control frames MUST have a payload - # length of 125 bytes or less - if opcode > 0x7 and length > 125: - raise WebSocketError( - WSCloseCode.PROTOCOL_ERROR, - "Control frame payload cannot be larger than 125 bytes", - ) - - # Set compress status if last package is FIN - # OR set compress status if this is first fragment - # Raise error if not first fragment with rsv1 = 0x1 - if self._frame_fin or self._compressed is None: - self._compressed = True if rsv1 else False - elif rsv1: - raise WebSocketError( - WSCloseCode.PROTOCOL_ERROR, - "Received frame with non-zero reserved bits", - ) - - self._frame_fin = bool(fin) - self._frame_opcode = opcode - self._has_mask = bool(has_mask) - self._payload_length_flag = length - self._state = WSParserState.READ_PAYLOAD_LENGTH - - # read payload length - if self._state is WSParserState.READ_PAYLOAD_LENGTH: - length_flag = self._payload_length_flag - if length_flag == 126: - if buf_length - start_pos < 2: - break - data = buf[start_pos : start_pos + 2] - start_pos += 2 - self._payload_length = UNPACK_LEN2(data)[0] - elif length_flag > 126: - if buf_length - start_pos < 8: - break - data = buf[start_pos : start_pos + 8] - start_pos += 8 - self._payload_length = UNPACK_LEN3(data)[0] - else: - self._payload_length = length_flag - - self._state = ( - 
WSParserState.READ_PAYLOAD_MASK - if self._has_mask - else WSParserState.READ_PAYLOAD - ) - - # read payload mask - if self._state is WSParserState.READ_PAYLOAD_MASK: - if buf_length - start_pos < 4: - break - self._frame_mask = buf[start_pos : start_pos + 4] - start_pos += 4 - self._state = WSParserState.READ_PAYLOAD - - if self._state is WSParserState.READ_PAYLOAD: - length = self._payload_length - payload = self._frame_payload - - chunk_len = buf_length - start_pos - if length >= chunk_len: - self._payload_length = length - chunk_len - payload += buf[start_pos:] - start_pos = buf_length - else: - self._payload_length = 0 - payload += buf[start_pos : start_pos + length] - start_pos = start_pos + length - - if self._payload_length != 0: - break - - if self._has_mask: - assert self._frame_mask is not None - websocket_mask(self._frame_mask, payload) - - frames.append( - (self._frame_fin, self._frame_opcode, payload, self._compressed) - ) - self._frame_payload = bytearray() - self._state = WSParserState.READ_HEADER - - self._tail = buf[start_pos:] - - return frames + WebSocketReader = WebSocketReaderPython diff --git a/aiohttp/_websocket/reader_c.pxd b/aiohttp/_websocket/reader_c.pxd new file mode 100644 index 00000000000..d4a82397f92 --- /dev/null +++ b/aiohttp/_websocket/reader_c.pxd @@ -0,0 +1,90 @@ +import cython + +from .mask cimport _websocket_mask_cython as websocket_mask + + +cdef unsigned int READ_HEADER +cdef unsigned int READ_PAYLOAD_LENGTH +cdef unsigned int READ_PAYLOAD_MASK +cdef unsigned int READ_PAYLOAD + +cdef unsigned int OP_CODE_CONTINUATION +cdef unsigned int OP_CODE_TEXT +cdef unsigned int OP_CODE_BINARY +cdef unsigned int OP_CODE_CLOSE +cdef unsigned int OP_CODE_PING +cdef unsigned int OP_CODE_PONG + +cdef object UNPACK_LEN2 +cdef object UNPACK_LEN3 +cdef object UNPACK_CLOSE_CODE +cdef object TUPLE_NEW + +cdef object WSMsgType + +cdef object WSMessageText +cdef object WSMessageBinary +cdef object WSMessagePing +cdef object WSMessagePong +cdef 
object WSMessageClose + +cdef object WS_MSG_TYPE_TEXT +cdef object WS_MSG_TYPE_BINARY + +cdef set ALLOWED_CLOSE_CODES +cdef set MESSAGE_TYPES_WITH_CONTENT + +cdef tuple EMPTY_FRAME +cdef tuple EMPTY_FRAME_ERROR + + +cdef class WebSocketReader: + + cdef object queue + cdef object _queue_feed_data + cdef unsigned int _max_msg_size + + cdef Exception _exc + cdef bytearray _partial + cdef unsigned int _state + + cdef object _opcode + cdef object _frame_fin + cdef object _frame_opcode + cdef bytearray _frame_payload + + cdef bytes _tail + cdef bint _has_mask + cdef bytes _frame_mask + cdef unsigned int _payload_length + cdef unsigned int _payload_length_flag + cdef object _compressed + cdef object _decompressobj + cdef bint _compress + + cpdef tuple feed_data(self, object data) + + @cython.locals( + is_continuation=bint, + fin=bint, + has_partial=bint, + payload_merged=bytes, + opcode="unsigned int", + ) + cpdef void _feed_data(self, bytes data) + + @cython.locals( + start_pos="unsigned int", + buf_len="unsigned int", + length="unsigned int", + chunk_size="unsigned int", + chunk_len="unsigned int", + buf_length="unsigned int", + data=bytes, + payload=bytearray, + first_byte=char, + second_byte=char, + has_mask=bint, + fin=bint, + ) + cpdef list parse_frame(self, bytes buf) diff --git a/aiohttp/_websocket/reader_c.py b/aiohttp/_websocket/reader_c.py new file mode 120000 index 00000000000..083cbb4331f --- /dev/null +++ b/aiohttp/_websocket/reader_c.py @@ -0,0 +1 @@ +reader_py.py \ No newline at end of file diff --git a/aiohttp/_websocket/reader_py.py b/aiohttp/_websocket/reader_py.py new file mode 100644 index 00000000000..f7136365129 --- /dev/null +++ b/aiohttp/_websocket/reader_py.py @@ -0,0 +1,367 @@ +"""Reader for WebSocket protocol versions 13 and 8.""" + +from typing import Final, List, Optional, Set, Tuple, Union + +from ..compression_utils import ZLibDecompressor +from ..helpers import set_exception +from ..streams import DataQueue +from .helpers import 
UNPACK_CLOSE_CODE, UNPACK_LEN2, UNPACK_LEN3, websocket_mask +from .models import ( + WS_DEFLATE_TRAILING, + WebSocketError, + WSCloseCode, + WSMessage, + WSMessageBinary, + WSMessageClose, + WSMessagePing, + WSMessagePong, + WSMessageText, + WSMsgType, +) + +ALLOWED_CLOSE_CODES: Final[Set[int]] = {int(i) for i in WSCloseCode} + +# States for the reader, used to parse the WebSocket frame +# integer values are used so they can be cythonized +READ_HEADER = 1 +READ_PAYLOAD_LENGTH = 2 +READ_PAYLOAD_MASK = 3 +READ_PAYLOAD = 4 + +WS_MSG_TYPE_BINARY = WSMsgType.BINARY +WS_MSG_TYPE_TEXT = WSMsgType.TEXT + +# WSMsgType values unpacked so they can be cythonized to ints +OP_CODE_CONTINUATION = WSMsgType.CONTINUATION.value +OP_CODE_TEXT = WSMsgType.TEXT.value +OP_CODE_BINARY = WSMsgType.BINARY.value +OP_CODE_CLOSE = WSMsgType.CLOSE.value +OP_CODE_PING = WSMsgType.PING.value +OP_CODE_PONG = WSMsgType.PONG.value + +EMPTY_FRAME_ERROR = (True, b"") +EMPTY_FRAME = (False, b"") + +TUPLE_NEW = tuple.__new__ + + +class WebSocketReader: + def __init__( + self, queue: DataQueue[WSMessage], max_msg_size: int, compress: bool = True + ) -> None: + self.queue = queue + self._queue_feed_data = queue.feed_data + self._max_msg_size = max_msg_size + + self._exc: Optional[Exception] = None + self._partial = bytearray() + self._state = READ_HEADER + + self._opcode: Optional[int] = None + self._frame_fin = False + self._frame_opcode: Optional[int] = None + self._frame_payload = bytearray() + + self._tail: bytes = b"" + self._has_mask = False + self._frame_mask: Optional[bytes] = None + self._payload_length = 0 + self._payload_length_flag = 0 + self._compressed: Optional[bool] = None + self._decompressobj: Optional[ZLibDecompressor] = None + self._compress = compress + + def feed_eof(self) -> None: + self.queue.feed_eof() + + # data can be bytearray on Windows because proactor event loop uses bytearray + # and asyncio types this to Union[bytes, bytearray, memoryview] so we need to + # coerce data to 
bytes if it is not + def feed_data( + self, data: Union[bytes, bytearray, memoryview] + ) -> Tuple[bool, bytes]: + if type(data) is not bytes: + data = bytes(data) + + if self._exc is not None: + return True, data + + try: + self._feed_data(data) + except Exception as exc: + self._exc = exc + set_exception(self.queue, exc) + return EMPTY_FRAME_ERROR + + return EMPTY_FRAME + + def _feed_data(self, data: bytes) -> None: + msg: WSMessage + for frame in self.parse_frame(data): + fin = frame[0] + opcode = frame[1] + payload = frame[2] + compressed = frame[3] + + is_continuation = opcode == OP_CODE_CONTINUATION + if opcode == OP_CODE_TEXT or opcode == OP_CODE_BINARY or is_continuation: + # load text/binary + if not fin: + # got partial frame payload + if not is_continuation: + self._opcode = opcode + self._partial += payload + if self._max_msg_size and len(self._partial) >= self._max_msg_size: + raise WebSocketError( + WSCloseCode.MESSAGE_TOO_BIG, + "Message size {} exceeds limit {}".format( + len(self._partial), self._max_msg_size + ), + ) + continue + + has_partial = bool(self._partial) + if is_continuation: + if self._opcode is None: + raise WebSocketError( + WSCloseCode.PROTOCOL_ERROR, + "Continuation frame for non started message", + ) + opcode = self._opcode + self._opcode = None + # previous frame was non finished + # we should get continuation opcode + elif has_partial: + raise WebSocketError( + WSCloseCode.PROTOCOL_ERROR, + "The opcode in non-fin frame is expected " + "to be zero, got {!r}".format(opcode), + ) + + if has_partial: + assembled_payload = self._partial + payload + self._partial.clear() + else: + assembled_payload = payload + + if self._max_msg_size and len(assembled_payload) >= self._max_msg_size: + raise WebSocketError( + WSCloseCode.MESSAGE_TOO_BIG, + "Message size {} exceeds limit {}".format( + len(assembled_payload), self._max_msg_size + ), + ) + + # Decompression must be done after all packets are + # received. 
+ if compressed: + if not self._decompressobj: + self._decompressobj = ZLibDecompressor( + suppress_deflate_header=True + ) + payload_merged = self._decompressobj.decompress_sync( + assembled_payload + WS_DEFLATE_TRAILING, self._max_msg_size + ) + if self._decompressobj.unconsumed_tail: + left = len(self._decompressobj.unconsumed_tail) + raise WebSocketError( + WSCloseCode.MESSAGE_TOO_BIG, + "Decompressed message size {} exceeds limit {}".format( + self._max_msg_size + left, self._max_msg_size + ), + ) + else: + payload_merged = bytes(assembled_payload) + + if opcode == OP_CODE_TEXT: + try: + text = payload_merged.decode("utf-8") + except UnicodeDecodeError as exc: + raise WebSocketError( + WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message" + ) from exc + + # XXX: The Text and Binary messages here can be a performance + # bottleneck, so we use tuple.__new__ to improve performance. + # This is not type safe, but many tests should fail in + # test_client_ws_functional.py if this is wrong. 
+ msg = TUPLE_NEW(WSMessageText, (text, "", WS_MSG_TYPE_TEXT)) + else: + msg = TUPLE_NEW( + WSMessageBinary, (payload_merged, "", WS_MSG_TYPE_BINARY) + ) + + self._queue_feed_data(msg) + elif opcode == OP_CODE_CLOSE: + if len(payload) >= 2: + close_code = UNPACK_CLOSE_CODE(payload[:2])[0] + if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES: + raise WebSocketError( + WSCloseCode.PROTOCOL_ERROR, + f"Invalid close code: {close_code}", + ) + try: + close_message = payload[2:].decode("utf-8") + except UnicodeDecodeError as exc: + raise WebSocketError( + WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message" + ) from exc + msg = WSMessageClose(data=close_code, extra=close_message) + elif payload: + raise WebSocketError( + WSCloseCode.PROTOCOL_ERROR, + f"Invalid close frame: {fin} {opcode} {payload!r}", + ) + else: + msg = WSMessageClose(data=0, extra="") + + self._queue_feed_data(msg) + + elif opcode == OP_CODE_PING: + msg = WSMessagePing(data=payload, extra="") + self._queue_feed_data(msg) + + elif opcode == OP_CODE_PONG: + msg = WSMessagePong(data=payload, extra="") + self._queue_feed_data(msg) + + else: + raise WebSocketError( + WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}" + ) + + def parse_frame( + self, buf: bytes + ) -> List[Tuple[bool, Optional[int], bytearray, Optional[bool]]]: + """Return the next frame from the socket.""" + frames: List[Tuple[bool, Optional[int], bytearray, Optional[bool]]] = [] + if self._tail: + buf, self._tail = self._tail + buf, b"" + + start_pos: int = 0 + buf_length = len(buf) + + while True: + # read header + if self._state == READ_HEADER: + if buf_length - start_pos < 2: + break + data = buf[start_pos : start_pos + 2] + start_pos += 2 + first_byte = data[0] + second_byte = data[1] + + fin = (first_byte >> 7) & 1 + rsv1 = (first_byte >> 6) & 1 + rsv2 = (first_byte >> 5) & 1 + rsv3 = (first_byte >> 4) & 1 + opcode = first_byte & 0xF + + # frame-fin = %x0 ; more frames of this message follow + # / %x1 ; final 
frame of this message + # frame-rsv1 = %x0 ; + # 1 bit, MUST be 0 unless negotiated otherwise + # frame-rsv2 = %x0 ; + # 1 bit, MUST be 0 unless negotiated otherwise + # frame-rsv3 = %x0 ; + # 1 bit, MUST be 0 unless negotiated otherwise + # + # Remove rsv1 from this test for deflate development + if rsv2 or rsv3 or (rsv1 and not self._compress): + raise WebSocketError( + WSCloseCode.PROTOCOL_ERROR, + "Received frame with non-zero reserved bits", + ) + + if opcode > 0x7 and fin == 0: + raise WebSocketError( + WSCloseCode.PROTOCOL_ERROR, + "Received fragmented control frame", + ) + + has_mask = (second_byte >> 7) & 1 + length = second_byte & 0x7F + + # Control frames MUST have a payload + # length of 125 bytes or less + if opcode > 0x7 and length > 125: + raise WebSocketError( + WSCloseCode.PROTOCOL_ERROR, + "Control frame payload cannot be larger than 125 bytes", + ) + + # Set compress status if last package is FIN + # OR set compress status if this is first fragment + # Raise error if not first fragment with rsv1 = 0x1 + if self._frame_fin or self._compressed is None: + self._compressed = True if rsv1 else False + elif rsv1: + raise WebSocketError( + WSCloseCode.PROTOCOL_ERROR, + "Received frame with non-zero reserved bits", + ) + + self._frame_fin = bool(fin) + self._frame_opcode = opcode + self._has_mask = bool(has_mask) + self._payload_length_flag = length + self._state = READ_PAYLOAD_LENGTH + + # read payload length + if self._state == READ_PAYLOAD_LENGTH: + length_flag = self._payload_length_flag + if length_flag == 126: + if buf_length - start_pos < 2: + break + data = buf[start_pos : start_pos + 2] + start_pos += 2 + self._payload_length = UNPACK_LEN2(data)[0] + elif length_flag > 126: + if buf_length - start_pos < 8: + break + data = buf[start_pos : start_pos + 8] + start_pos += 8 + self._payload_length = UNPACK_LEN3(data)[0] + else: + self._payload_length = length_flag + + self._state = READ_PAYLOAD_MASK if self._has_mask else READ_PAYLOAD + + # read 
payload mask + if self._state == READ_PAYLOAD_MASK: + if buf_length - start_pos < 4: + break + self._frame_mask = buf[start_pos : start_pos + 4] + start_pos += 4 + self._state = READ_PAYLOAD + + if self._state == READ_PAYLOAD: + length = self._payload_length + payload = self._frame_payload + + chunk_len = buf_length - start_pos + if length >= chunk_len: + self._payload_length = length - chunk_len + payload += buf[start_pos:] + start_pos = buf_length + else: + self._payload_length = 0 + payload += buf[start_pos : start_pos + length] + start_pos = start_pos + length + + if self._payload_length != 0: + break + + if self._has_mask: + assert self._frame_mask is not None + websocket_mask(self._frame_mask, payload) + + frames.append( + (self._frame_fin, self._frame_opcode, payload, self._compressed) + ) + self._frame_payload = bytearray() + self._state = READ_HEADER + + self._tail = buf[start_pos:] + + return frames diff --git a/setup.py b/setup.py index cbb1944762a..c9a2c5c856c 100644 --- a/setup.py +++ b/setup.py @@ -41,6 +41,7 @@ include_dirs=["vendor/llhttp/build"], ), Extension("aiohttp._http_writer", ["aiohttp/_http_writer.c"]), + Extension("aiohttp._websocket.reader_c", ["aiohttp/_websocket/reader_c.c"]), ] diff --git a/tests/test_websocket_parser.py b/tests/test_websocket_parser.py index 767c1843076..6d490fd15e1 100644 --- a/tests/test_websocket_parser.py +++ b/tests/test_websocket_parser.py @@ -3,6 +3,7 @@ import random import struct import zlib +from typing import Union from unittest import mock import pytest @@ -28,6 +29,10 @@ ) +class PatchableWebSocketReader(WebSocketReader): + """WebSocketReader subclass that allows for patching parse_frame.""" + + def build_frame( message: bytes, opcode: int, @@ -97,8 +102,19 @@ def out(loop: asyncio.AbstractEventLoop) -> aiohttp.DataQueue[WSMessage]: @pytest.fixture() -def parser(out: aiohttp.DataQueue[WSMessage]) -> WebSocketReader: - return WebSocketReader(out, 4 * 1024 * 1024) +def parser(out: 
aiohttp.DataQueue[WSMessage]) -> PatchableWebSocketReader: + return PatchableWebSocketReader(out, 4 * 1024 * 1024) + + +def test_feed_data_remembers_exception(parser: WebSocketReader) -> None: + """Verify that feed_data remembers an exception was already raised internally.""" + error, data = parser.feed_data(struct.pack("!BB", 0b01100000, 0b00000000)) + assert error is True + assert data == b"" + + error, data = parser.feed_data(b"") + assert error is True + assert data == b"" def test_parse_frame(parser: WebSocketReader) -> None: @@ -171,11 +187,22 @@ def test_parse_frame_header_payload_size( parser.parse_frame(struct.pack("!BB", 0b10001000, 0b01111110)) -def test_ping_frame(out: aiohttp.DataQueue[WSMessage], parser: WebSocketReader) -> None: +# Proactor event loop will call feed_data with bytearray. Since +# asyncio technically supports memoryview as well, we should test that. +@pytest.mark.parametrize( + argnames="data", + argvalues=[b"", bytearray(b""), memoryview(b"")], + ids=["bytes", "bytearray", "memoryview"], +) +def test_ping_frame( + out: aiohttp.DataQueue[WSMessage], + parser: WebSocketReader, + data: Union[bytes, bytearray, memoryview], +) -> None: with mock.patch.object(parser, "parse_frame", autospec=True) as m: m.return_value = [(1, WSMsgType.PING, b"data", False)] - parser.feed_data(b"") + parser.feed_data(data) res = out._buffer[0] assert res == WSMessagePing(data=b"data", extra="")