Skip to content

Commit

Permalink
Make API uniform
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Jun 9, 2023
1 parent 8713800 commit f0fa069
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 46 deletions.
27 changes: 24 additions & 3 deletions ypy_websocket/websocket_provider.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from __future__ import annotations

import logging
from contextlib import AsyncExitStack
from functools import partial

import y_py as Y
from anyio import create_memory_object_stream, create_task_group
from anyio import Event, create_memory_object_stream, create_task_group
from anyio.abc import TaskGroup
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream

Expand All @@ -21,7 +23,8 @@ class WebsocketProvider:
_ydoc: Y.YDoc
_update_send_stream: MemoryObjectSendStream
_update_receive_stream: MemoryObjectReceiveStream
_task_group: TaskGroup
_started: Event | None
_task_group: TaskGroup | None

def __init__(self, ydoc: Y.YDoc, websocket, log=None):
self._ydoc = ydoc
Expand All @@ -30,17 +33,30 @@ def __init__(self, ydoc: Y.YDoc, websocket, log=None):
self._update_send_stream, self._update_receive_stream = create_memory_object_stream(
max_buffer_size=65536
)
self._started = None
self._task_group = None
ydoc.observe_after_transaction(partial(put_updates, self._update_send_stream))

@property
def started(self):
if self._started is None:
self._started = Event()
return self._started

async def __aenter__(self):
if self._task_group is not None:
raise RuntimeError("WebsocketProvider already running")

async with AsyncExitStack() as exit_stack:
tg = create_task_group()
self._task_group = await exit_stack.enter_async_context(tg)
self._exit_stack = exit_stack.pop_all()
tg.start_soon(self._run)
self.started.set()

async def __aexit__(self, exc_type, exc_value, exc_tb):
self._task_group.cancel_scope.cancel()
self._task_group = None
return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb)

async def _run(self):
Expand All @@ -59,9 +75,14 @@ async def _send(self):
except Exception:
pass

async def run(self):
async def start(self):
if self._task_group is not None:
raise RuntimeError("WebsocketProvider already running")

async with create_task_group() as self._task_group:
self._task_group.start_soon(self._run)
self.started.set()

def stop(self):
self._task_group.cancel_scope.cancel()
self._task_group = None
28 changes: 24 additions & 4 deletions ypy_websocket/websocket_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,23 @@ class WebsocketServer:

auto_clean_rooms: bool
rooms: dict[str, YRoom]
_started: Event | None
_task_group: TaskGroup | None

def __init__(self, rooms_ready: bool = True, auto_clean_rooms: bool = True, log=None):
self.rooms_ready = rooms_ready
self.auto_clean_rooms = auto_clean_rooms
self.log = log or logging.getLogger(__name__)
self.rooms = {}
self._started = None
self._task_group = None

@property
def started(self):
if self._started is None:
self._started = Event()
return self._started

def get_room(self, path: str) -> YRoom:
if path not in self.rooms.keys():
self.rooms[path] = YRoom(ready=self.rooms_ready, log=self.log)
Expand Down Expand Up @@ -52,15 +60,17 @@ def delete_room(self, *, name: str | None = None, room: YRoom | None = None):
async def serve(self, websocket):
if self._task_group is None:
raise RuntimeError(
"The WebsocketServer is not running: use `async with websocket_server:` or `await websocket_server.run()`"
"The WebsocketServer is not running: use `async with websocket_server:` or `await websocket_server.start()`"
)

await self._task_group.start(self._serve, websocket)

async def _serve(self, websocket, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED):
async with create_task_group() as tg:
room = self.get_room(websocket.path)
tg.start_soon(room.enter)
if not room.started.is_set():
tg.start_soon(room.start)
await room.started.wait()
room.clients.append(websocket)
await sync(room.ydoc, websocket, self.log)
try:
Expand Down Expand Up @@ -105,19 +115,29 @@ async def _serve(self, websocket, *, task_status: TaskStatus[None] = TASK_STATUS
task_status.started()

async def __aenter__(self):
if self._task_group is not None:
raise RuntimeError("WebsocketServer already running")

async with AsyncExitStack() as exit_stack:
tg = create_task_group()
self._task_group = await exit_stack.enter_async_context(tg)
self._exit_stack = exit_stack.pop_all()
self.started.set()

async def __aexit__(self, exc_type, exc_value, exc_tb):
self._task_group.cancel_scope.cancel()
self._task_group = None
return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb)

async def run(self):
async def start(self):
if self._task_group is not None:
raise RuntimeError("WebsocketServer already running")

# create the task group and wait forever
async with create_task_group() as self._task_group:
# wait forever
self._task_group.start_soon(Event().wait)
self.started.set()

def stop(self):
self._task_group.cancel_scope.cancel()
self._task_group = None
45 changes: 30 additions & 15 deletions ypy_websocket/yroom.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

import logging
from contextlib import AsyncExitStack
from functools import partial
from typing import Callable, List, Optional

import y_py as Y
from anyio import create_memory_object_stream, create_task_group
from anyio import Event, create_memory_object_stream, create_task_group
from anyio.abc import TaskGroup
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream

Expand All @@ -15,17 +17,17 @@

class YRoom:

clients: List
clients: list
ydoc: Y.YDoc
ystore: Optional[BaseYStore]
_on_message: Optional[Callable]
ystore: BaseYStore | None
_on_message: Callable | None
_update_send_stream: MemoryObjectSendStream
_update_receive_stream: MemoryObjectReceiveStream
_ready: bool
_task_group: TaskGroup
_entered: bool
_task_group: TaskGroup | None
_started: Event | None

def __init__(self, ready: bool = True, ystore: Optional[BaseYStore] = None, log=None):
def __init__(self, ready: bool = True, ystore: BaseYStore | None = None, log=None):
self.ydoc = Y.YDoc()
self.awareness = Awareness(self.ydoc)
self._update_send_stream, self._update_receive_stream = create_memory_object_stream(
Expand All @@ -37,7 +39,14 @@ def __init__(self, ready: bool = True, ystore: Optional[BaseYStore] = None, log=
self.log = log or logging.getLogger(__name__)
self.clients = []
self._on_message = None
self._entered = False
self._started = None
self._task_group = None

@property
def started(self):
if self._started is None:
self._started = Event()
return self._started

@property
def ready(self) -> bool:
Expand All @@ -50,11 +59,11 @@ def ready(self, value: bool) -> None:
self.ydoc.observe_after_transaction(partial(put_updates, self._update_send_stream))

@property
def on_message(self) -> Optional[Callable]:
def on_message(self) -> Callable | None:
return self._on_message

@on_message.setter
def on_message(self, value: Optional[Callable]):
def on_message(self, value: Callable | None):
self._on_message = value

async def _broadcast_updates(self):
Expand All @@ -73,23 +82,29 @@ async def _broadcast_updates(self):
self._task_group.start_soon(self.ystore.write, update)

async def __aenter__(self):
if self._task_group is not None:
raise RuntimeError("YRoom already running")

async with AsyncExitStack() as exit_stack:
tg = create_task_group()
self._task_group = await exit_stack.enter_async_context(tg)
self._exit_stack = exit_stack.pop_all()
tg.start_soon(self._broadcast_updates)
self._started.set()

async def __aexit__(self, exc_type, exc_value, exc_tb):
self._task_group.cancel_scope.cancel()
self._task_group = None
return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb)

async def enter(self):
if self._entered:
return
async def start(self):
if self._task_group is not None:
raise RuntimeError("YRoom already running")

async with create_task_group() as self._task_group:
self._task_group.start_soon(self._broadcast_updates)
self._entered = True
self._started.set()

def exit(self):
def stop(self):
self._task_group.cancel_scope.cancel()
self._task_group = None
Loading

0 comments on commit f0fa069

Please sign in to comment.