From f0fa069908c9dfbb1f9782abbbcda809b947227c Mon Sep 17 00:00:00 2001 From: David Brochart Date: Fri, 9 Jun 2023 17:22:46 +0200 Subject: [PATCH] Make API uniform --- ypy_websocket/websocket_provider.py | 27 +++++++++-- ypy_websocket/websocket_server.py | 28 ++++++++++-- ypy_websocket/yroom.py | 45 ++++++++++++------- ypy_websocket/ystore.py | 69 +++++++++++++++++++---------- 4 files changed, 123 insertions(+), 46 deletions(-) diff --git a/ypy_websocket/websocket_provider.py b/ypy_websocket/websocket_provider.py index bef0933..78cd605 100644 --- a/ypy_websocket/websocket_provider.py +++ b/ypy_websocket/websocket_provider.py @@ -1,9 +1,11 @@ +from __future__ import annotations + import logging from contextlib import AsyncExitStack from functools import partial import y_py as Y -from anyio import create_memory_object_stream, create_task_group +from anyio import Event, create_memory_object_stream, create_task_group from anyio.abc import TaskGroup from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream @@ -21,7 +23,8 @@ class WebsocketProvider: _ydoc: Y.YDoc _update_send_stream: MemoryObjectSendStream _update_receive_stream: MemoryObjectReceiveStream - _task_group: TaskGroup + _started: Event | None + _task_group: TaskGroup | None def __init__(self, ydoc: Y.YDoc, websocket, log=None): self._ydoc = ydoc @@ -30,17 +33,30 @@ def __init__(self, ydoc: Y.YDoc, websocket, log=None): self._update_send_stream, self._update_receive_stream = create_memory_object_stream( max_buffer_size=65536 ) + self._started = None + self._task_group = None ydoc.observe_after_transaction(partial(put_updates, self._update_send_stream)) + @property + def started(self): + if self._started is None: + self._started = Event() + return self._started + async def __aenter__(self): + if self._task_group is not None: + raise RuntimeError("WebsocketProvider already running") + async with AsyncExitStack() as exit_stack: tg = create_task_group() self._task_group = await exit_stack.enter_async_context(tg) self._exit_stack = exit_stack.pop_all() tg.start_soon(self._run) + self.started.set() async def __aexit__(self, exc_type, exc_value, exc_tb): self._task_group.cancel_scope.cancel() + self._task_group = None return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb) async def _run(self): @@ -59,9 +75,14 @@ async def _send(self): except Exception: pass - async def run(self): + async def start(self): + if self._task_group is not None: + raise RuntimeError("WebsocketProvider already running") + async with create_task_group() as self._task_group: self._task_group.start_soon(self._run) + self.started.set() def stop(self): self._task_group.cancel_scope.cancel() + self._task_group = None diff --git a/ypy_websocket/websocket_server.py b/ypy_websocket/websocket_server.py index 006d89c..9974f93 100644 --- a/ypy_websocket/websocket_server.py +++ b/ypy_websocket/websocket_server.py @@ -14,6 +14,7 @@ class WebsocketServer: auto_clean_rooms: bool rooms: dict[str, YRoom] + _started: Event | None _task_group: TaskGroup | None def __init__(self, rooms_ready: bool = True, auto_clean_rooms: bool = True, log=None): @@ -21,8 +22,15 @@ def __init__(self, rooms_ready: bool = True, auto_clean_rooms: bool = True, log= self.auto_clean_rooms = auto_clean_rooms self.log = log or logging.getLogger(__name__) self.rooms = {} + self._started = None self._task_group = None + @property + def started(self): + if self._started is None: + self._started = Event() + return self._started + def get_room(self, path: str) -> YRoom: if path not in self.rooms.keys(): self.rooms[path] = YRoom(ready=self.rooms_ready, log=self.log) @@ -52,7 +60,7 @@ def delete_room(self, *, name: str | None = None, room: YRoom | None = None): async def serve(self, websocket): if self._task_group is None: raise RuntimeError( - "The WebsocketServer is not running: use `async with websocket_server:` or `await websocket_server.run()`" + "The WebsocketServer is not running: use `async with websocket_server:` or `await websocket_server.start()`" ) await self._task_group.start(self._serve, websocket) @@ -60,7 +68,9 @@ async def serve(self, websocket): async def _serve(self, websocket, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED): async with create_task_group() as tg: room = self.get_room(websocket.path) - tg.start_soon(room.enter) + if not room.started.is_set(): + tg.start_soon(room.start) + await room.started.wait() room.clients.append(websocket) await sync(room.ydoc, websocket, self.log) try: @@ -105,19 +115,29 @@ async def _serve(self, websocket, *, task_status: TaskStatus[None] = TASK_STATUS task_status.started() async def __aenter__(self): + if self._task_group is not None: + raise RuntimeError("WebsocketServer already running") + async with AsyncExitStack() as exit_stack: tg = create_task_group() self._task_group = await exit_stack.enter_async_context(tg) self._exit_stack = exit_stack.pop_all() + self.started.set() async def __aexit__(self, exc_type, exc_value, exc_tb): self._task_group.cancel_scope.cancel() + self._task_group = None return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb) - async def run(self): + async def start(self): + if self._task_group is not None: + raise RuntimeError("WebsocketServer already running") + + # create the task group and wait forever async with create_task_group() as self._task_group: - # wait forever self._task_group.start_soon(Event().wait) + self.started.set() def stop(self): self._task_group.cancel_scope.cancel() + self._task_group = None diff --git a/ypy_websocket/yroom.py b/ypy_websocket/yroom.py index e8a5b9d..adf982e 100644 --- a/ypy_websocket/yroom.py +++ b/ypy_websocket/yroom.py @@ -1,10 +1,12 @@ +from __future__ import annotations + import logging from contextlib import AsyncExitStack from functools import partial from typing import Callable, List, Optional import y_py as Y -from anyio import create_memory_object_stream, create_task_group +from anyio import Event, create_memory_object_stream, create_task_group from anyio.abc import TaskGroup from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream @@ -15,17 +17,17 @@ class YRoom: - clients: List + clients: list ydoc: Y.YDoc - ystore: Optional[BaseYStore] - _on_message: Optional[Callable] + ystore: BaseYStore | None + _on_message: Callable | None _update_send_stream: MemoryObjectSendStream _update_receive_stream: MemoryObjectReceiveStream _ready: bool - _task_group: TaskGroup - _entered: bool + _task_group: TaskGroup | None + _started: Event | None - def __init__(self, ready: bool = True, ystore: Optional[BaseYStore] = None, log=None): + def __init__(self, ready: bool = True, ystore: BaseYStore | None = None, log=None): self.ydoc = Y.YDoc() self.awareness = Awareness(self.ydoc) self._update_send_stream, self._update_receive_stream = create_memory_object_stream( @@ -37,7 +39,14 @@ def __init__(self, ready: bool = True, ystore: Optional[BaseYStore] = None, log= self.log = log or logging.getLogger(__name__) self.clients = [] self._on_message = None - self._entered = False + self._started = None + self._task_group = None + + @property + def started(self): + if self._started is None: + self._started = Event() + return self._started @property def ready(self) -> bool: @@ -50,11 +59,11 @@ def ready(self, value: bool) -> None: self.ydoc.observe_after_transaction(partial(put_updates, self._update_send_stream)) @property - def on_message(self) -> Optional[Callable]: + def on_message(self) -> Callable | None: return self._on_message @on_message.setter - def on_message(self, value: Optional[Callable]): + def on_message(self, value: Callable | None): self._on_message = value async def _broadcast_updates(self): @@ -73,23 +82,29 @@ async def _broadcast_updates(self): self._task_group.start_soon(self.ystore.write, update) async def __aenter__(self): + if self._task_group is not None: + raise RuntimeError("YRoom already running") + async with AsyncExitStack() as exit_stack: tg = create_task_group() self._task_group = await exit_stack.enter_async_context(tg) self._exit_stack = exit_stack.pop_all() tg.start_soon(self._broadcast_updates) + self._started.set() async def __aexit__(self, exc_type, exc_value, exc_tb): self._task_group.cancel_scope.cancel() + self._task_group = None return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb) - async def enter(self): - if self._entered: - return + async def start(self): + if self._task_group is not None: + raise RuntimeError("YRoom already running") async with create_task_group() as self._task_group: self._task_group.start_soon(self._broadcast_updates) - self._entered = True + self._started.set() - def exit(self): + def stop(self): self._task_group.cancel_scope.cancel() + self._task_group = None diff --git a/ypy_websocket/ystore.py b/ypy_websocket/ystore.py index c82888f..3e004f1 100644 --- a/ypy_websocket/ystore.py +++ b/ypy_websocket/ystore.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging import struct import tempfile @@ -10,6 +12,8 @@ import aiosqlite import anyio import y_py as Y +from anyio import Event, Lock, create_task_group +from anyio.abc import TaskGroup from .yutils import Decoder, get_new_path, write_var_uint @@ -20,8 +24,10 @@ class YDocNotFound(Exception): class BaseYStore(ABC): - metadata_callback: Optional[Callable] = None + metadata_callback: Callable | None = None version = 2 + _started: Event | None = None + _task_group: TaskGroup | None = None @abstractmethod def __init__(self, path: str, metadata_callback=None): @@ -32,25 +38,39 @@ async def write(self, data: bytes) -> None: ... @abstractmethod - async def read(self) -> AsyncIterator[Tuple[bytes, bytes]]: + async def read(self) -> AsyncIterator[tuple[bytes, bytes]]: ... + @property + def started(self) -> Event: + if self._started is None: + self._started = Event() + return self._started + async def __aenter__(self): + if self._task_group is not None: + raise RuntimeError("YStore already running") + async with AsyncExitStack() as exit_stack: - tg = anyio.create_task_group() + tg = create_task_group() self._task_group = await exit_stack.enter_async_context(tg) self._exit_stack = exit_stack.pop_all() tg.start_soon(self.start) async def __aexit__(self, exc_type, exc_value, exc_tb): self._task_group.cancel_scope.cancel() + self._task_group = None return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb) async def start(self): - pass + if self._task_group is not None: + raise RuntimeError("YStore already running") + + self.started.set() def stop(self): - pass + self._task_group.cancel_scope.cancel() + self._task_group = None async def get_metadata(self) -> bytes: metadata = b"" if not self.metadata_callback else await self.metadata_callback() @@ -69,14 +89,14 @@ class FileYStore(BaseYStore): """A YStore which uses one file per document.""" path: str - metadata_callback: Optional[Callable] - lock: anyio.Lock + metadata_callback: Callable | None + lock: Lock - def __init__(self, path: str, metadata_callback: Optional[Callable] = None, log=None): + def __init__(self, path: str, metadata_callback: Callable | None = None, log=None): self.path = path self.metadata_callback = metadata_callback self.log = log or logging.getLogger(__name__) - self.lock = anyio.Lock() + self.lock = Lock() async def check_version(self) -> int: if not await anyio.Path(self.path).exists(): @@ -107,7 +127,7 @@ async def check_version(self) -> int: offset = len(version_bytes) return offset - async def read(self) -> AsyncIterator[Tuple[bytes, bytes, float]]: # type: ignore + async def read(self) -> AsyncIterator[tuple[bytes, bytes, float]]: # type: ignore async with self.lock: if not await anyio.Path(self.path).exists(): raise YDocNotFound @@ -153,10 +173,10 @@ class PrefixTempFileYStore(TempFileYStore): prefix_dir = "my_prefix_" """ - prefix_dir: Optional[str] = None - base_dir: Optional[str] = None + prefix_dir: str | None = None + base_dir: str | None = None - def __init__(self, path: str, metadata_callback: Optional[Callable] = None, log=None): + def __init__(self, path: str, metadata_callback: Callable | None = None, log=None): full_path = str(Path(self.get_base_dir()) / path) super().__init__(full_path, metadata_callback=metadata_callback, log=log) @@ -184,24 +204,25 @@ class MySQLiteYStore(SQLiteYStore): # Determines the "time to live" for all documents, i.e. how recent the # latest update of a document must be before purging document history. # Defaults to never purging document history (None). - document_ttl: Optional[int] = None + document_ttl: int | None = None path: str - lock: anyio.Lock - db_initialized: anyio.Event + lock: Lock + db_initialized: Event - def __init__(self, path: str, metadata_callback: Optional[Callable] = None, log=None): + def __init__(self, path: str, metadata_callback: Callable | None = None, log=None): self.path = path self.metadata_callback = metadata_callback self.log = log or logging.getLogger(__name__) - self.lock = anyio.Lock() - self.db_initialized = anyio.Event() + self.lock = Lock() + self.db_initialized = Event() async def start(self): - async with anyio.create_task_group() as self._task_group: - self._task_group.start_soon(self._init_db) + if self._task_group is not None: + raise RuntimeError("YStore already running") - def stop(self): - self._task_group.cancel_scope.cancel() + async with create_task_group() as self._task_group: + self._task_group.start_soon(self._init_db) + self.started.set() async def _init_db(self): create_db = False @@ -240,7 +261,7 @@ async def _init_db(self): await db.commit() self.db_initialized.set() - async def read(self) -> AsyncIterator[Tuple[bytes, bytes, float]]: # type: ignore + async def read(self) -> AsyncIterator[tuple[bytes, bytes, float]]: # type: ignore await self.db_initialized.wait() try: async with self.lock: