Skip to content

Commit

Permalink
websocket: Make WSH.get a coroutine
Browse files Browse the repository at this point in the history
This is necessary to convert accept_connection to native coroutines -
the handshake no longer completes within a single IOLoop iteration
with this change due to coroutine scheduling.

This has the side effect of keeping the HTTP1Connection open for the
lifetime of the websocket connection. That's not great for memory, but
might help streamline close handling. Either way, it'll be refactored
in a future change.
  • Loading branch information
bdarnell committed Dec 29, 2018
1 parent e719d82 commit e69becb
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions tornado/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""

import abc
import asyncio
import base64
import hashlib
import os
Expand Down Expand Up @@ -228,7 +229,7 @@ def __init__(
self.stream = None # type: Optional[IOStream]
self._on_close_called = False

def get(self, *args: Any, **kwargs: Any) -> None:
async def get(self, *args: Any, **kwargs: Any) -> None:
self.open_args = args
self.open_kwargs = kwargs

Expand Down Expand Up @@ -275,11 +276,10 @@ def get(self, *args: Any, **kwargs: Any) -> None:

self.ws_connection = self.get_websocket_protocol()
if self.ws_connection:
self.ws_connection.accept_connection(self)
await self.ws_connection.accept_connection(self)
else:
self.set_status(426, "Upgrade Required")
self.set_header("Sec-WebSocket-Version", "7, 8, 13")
self.finish()

stream = None

Expand Down Expand Up @@ -679,7 +679,7 @@ def is_closing(self) -> bool:
raise NotImplementedError()

@abc.abstractmethod
def accept_connection(self, handler: WebSocketHandler) -> None:
async def accept_connection(self, handler: WebSocketHandler) -> None:
raise NotImplementedError()

@abc.abstractmethod
Expand Down Expand Up @@ -833,7 +833,7 @@ def __init__(
self._masked_frame = None
self._frame_mask = None # type: Optional[bytes]
self._frame_length = None
self._fragmented_message_buffer = None
self._fragmented_message_buffer = None # type: Optional[bytes]
self._fragmented_message_opcode = None
self._waiting = None # type: object
self._compression_options = params.compression_options
Expand Down Expand Up @@ -864,7 +864,7 @@ def selected_subprotocol(self) -> Optional[str]:
def selected_subprotocol(self, value: Optional[str]) -> None:
self._selected_subprotocol = value

def accept_connection(self, handler: WebSocketHandler) -> None:
async def accept_connection(self, handler: WebSocketHandler) -> None:
try:
self._handle_websocket_headers(handler)
except ValueError:
Expand All @@ -875,7 +875,10 @@ def accept_connection(self, handler: WebSocketHandler) -> None:
return

try:
self._accept_connection(handler)
await self._accept_connection(handler)
except asyncio.CancelledError:
self._abort()
return
except ValueError:
gen_log.debug("Malformed WebSocket request received", exc_info=True)
self._abort()
Expand Down Expand Up @@ -906,10 +909,7 @@ def _challenge_response(self, handler: WebSocketHandler) -> str:
cast(str, handler.request.headers.get("Sec-Websocket-Key"))
)

@gen.coroutine
def _accept_connection(
self, handler: WebSocketHandler
) -> Generator[Any, Any, None]:
async def _accept_connection(self, handler: WebSocketHandler) -> None:
subprotocol_header = handler.request.headers.get("Sec-WebSocket-Protocol")
if subprotocol_header:
subprotocols = [s.strip() for s in subprotocol_header.split(",")]
Expand Down Expand Up @@ -953,8 +953,8 @@ def _accept_connection(
handler.open, *handler.open_args, **handler.open_kwargs
)
if open_result is not None:
yield open_result
yield self._receive_frame_loop()
await open_result
await self._receive_frame_loop()

def _parse_extensions_header(
self, headers: httputil.HTTPHeaders
Expand Down

0 comments on commit e69becb

Please sign in to comment.