Skip to content

Commit

Permalink
Make WebsocketServer an async context manager
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Jun 9, 2023
1 parent 58cec1a commit 8713800
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 8 deletions.
7 changes: 5 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,11 @@ async def yws_server(request):
except Exception:
kwargs = {}
websocket_server = WebsocketServer(**kwargs)
async with serve(websocket_server.serve, "127.0.0.1", 1234):
yield websocket_server
try:
async with serve(websocket_server.serve, "127.0.0.1", 1234), websocket_server:
yield websocket_server
except Exception:
pass


@pytest.fixture
Expand Down
2 changes: 1 addition & 1 deletion tests/test_asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ async def test_asgi(unused_tcp_port):
# server
config = uvicorn.Config("test_asgi:app", port=unused_tcp_port, log_level="info")
server = uvicorn.Server(config)
async with create_task_group() as tg:
async with create_task_group() as tg, websocket_server:
tg.start_soon(server.serve)
while not server.started:
await sleep(0)
Expand Down
42 changes: 37 additions & 5 deletions ypy_websocket/websocket_server.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

import logging
from typing import Dict, Optional
from contextlib import AsyncExitStack

from anyio import create_task_group
from anyio import TASK_STATUS_IGNORED, Event, create_task_group
from anyio.abc import TaskGroup, TaskStatus

from .yroom import YRoom
from .yutils import YMessageType, process_sync_message, sync
Expand All @@ -10,13 +13,15 @@
class WebsocketServer:

auto_clean_rooms: bool
rooms: Dict[str, YRoom]
rooms: dict[str, YRoom]
_task_group: TaskGroup | None

def __init__(self, rooms_ready: bool = True, auto_clean_rooms: bool = True, log=None):
self.rooms_ready = rooms_ready
self.auto_clean_rooms = auto_clean_rooms
self.log = log or logging.getLogger(__name__)
self.rooms = {}
self._task_group = None

def get_room(self, path: str) -> YRoom:
if path not in self.rooms.keys():
Expand All @@ -27,15 +32,15 @@ def get_room_name(self, room):
return list(self.rooms.keys())[list(self.rooms.values()).index(room)]

def rename_room(
self, to_name: str, *, from_name: Optional[str] = None, from_room: Optional[YRoom] = None
self, to_name: str, *, from_name: str | None = None, from_room: YRoom | None = None
):
if from_name is not None and from_room is not None:
raise RuntimeError("Cannot pass from_name and from_room")
if from_name is None:
from_name = self.get_room_name(from_room)
self.rooms[to_name] = self.rooms.pop(from_name)

def delete_room(self, *, name: Optional[str] = None, room: Optional[YRoom] = None):
def delete_room(self, *, name: str | None = None, room: YRoom | None = None):
if name is not None and room is not None:
raise RuntimeError("Cannot pass name and room")
if name is None:
Expand All @@ -45,6 +50,14 @@ def delete_room(self, *, name: Optional[str] = None, room: Optional[YRoom] = Non
del self.rooms[name]

async def serve(self, websocket):
if self._task_group is None:
raise RuntimeError(
"The WebsocketServer is not running: use `async with websocket_server:` or `await websocket_server.run()`"
)

await self._task_group.start(self._serve, websocket)

async def _serve(self, websocket, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED):
async with create_task_group() as tg:
room = self.get_room(websocket.path)
tg.start_soon(room.enter)
Expand Down Expand Up @@ -89,3 +102,22 @@ async def serve(self, websocket):
if self.auto_clean_rooms and not room.clients:
self.delete_room(room=room)
tg.cancel_scope.cancel()
task_status.started()

async def __aenter__(self):
async with AsyncExitStack() as exit_stack:
tg = create_task_group()
self._task_group = await exit_stack.enter_async_context(tg)
self._exit_stack = exit_stack.pop_all()

async def __aexit__(self, exc_type, exc_value, exc_tb):
self._task_group.cancel_scope.cancel()
return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb)

async def run(self):
async with create_task_group() as self._task_group:
# wait forever
self._task_group.start_soon(Event().wait)

def stop(self):
self._task_group.cancel_scope.cancel()

0 comments on commit 8713800

Please sign in to comment.