Skip to content

Commit

Permalink
Add more context managers
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Jun 8, 2023
1 parent de7d8c4 commit c702f4c
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 88 deletions.
9 changes: 9 additions & 0 deletions ypy_websocket/websocket_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import y_py as Y
from anyio import create_memory_object_stream, create_task_group
from anyio.abc import TaskGroup
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream

from .yutils import (
Expand All @@ -20,6 +21,7 @@ class WebsocketProvider:
_ydoc: Y.YDoc
_update_send_stream: MemoryObjectSendStream
_update_receive_stream: MemoryObjectReceiveStream
_task_group: TaskGroup

def __init__(self, ydoc: Y.YDoc, websocket, log=None):
self._ydoc = ydoc
Expand Down Expand Up @@ -56,3 +58,10 @@ async def _send(self):
await self._websocket.send(message)
except Exception:
pass

async def run(self):
async with create_task_group() as self._task_group:
self._task_group.start_soon(self._run)

def stop(self):
self._task_group.cancel_scope.cancel()
90 changes: 4 additions & 86 deletions ypy_websocket/websocket_server.py
Original file line number Diff line number Diff line change
@@ -1,92 +1,10 @@
import logging
from functools import partial
from typing import Callable, Dict, List, Optional
from typing import Dict, Optional

import y_py as Y
from anyio import create_memory_object_stream, create_task_group
from anyio.abc import TaskGroup
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from anyio import create_task_group

from .awareness import Awareness
from .ystore import BaseYStore
from .yutils import (
YMessageType,
create_update_message,
process_sync_message,
put_updates,
sync,
)


class YRoom:

clients: List
ydoc: Y.YDoc
ystore: Optional[BaseYStore]
_on_message: Optional[Callable]
_update_send_stream: MemoryObjectSendStream
_update_receive_stream: MemoryObjectReceiveStream
_ready: bool
_task_group: TaskGroup
_entered: bool

def __init__(self, ready: bool = True, ystore: Optional[BaseYStore] = None, log=None):
self.ydoc = Y.YDoc()
self.awareness = Awareness(self.ydoc)
self._update_send_stream, self._update_receive_stream = create_memory_object_stream(
max_buffer_size=65536
)
self._ready = False
self.ready = ready
self.ystore = ystore
self.log = log or logging.getLogger(__name__)
self.clients = []
self._on_message = None
self._entered = False

async def enter(self):
if self._entered:
return

async with create_task_group() as self._task_group:
self._task_group.start_soon(self._broadcast_updates)
self._entered = True

@property
def ready(self) -> bool:
return self._ready

@ready.setter
def ready(self, value: bool) -> None:
self._ready = value
if value:
self.ydoc.observe_after_transaction(partial(put_updates, self._update_send_stream))

@property
def on_message(self) -> Optional[Callable]:
return self._on_message

@on_message.setter
def on_message(self, value: Optional[Callable]):
self._on_message = value

async def _broadcast_updates(self):
async with self._update_receive_stream:
async for update in self._update_receive_stream:
if self._task_group.cancel_scope.cancel_called:
return
# broadcast internal ydoc's update to all clients, that includes changes from the
# clients and changes from the backend (out-of-band changes)
for client in self.clients:
self.log.debug("Sending Y update to client with endpoint: %s", client.path)
message = create_update_message(update)
self._task_group.start_soon(client.send, message)
if self.ystore:
self.log.debug("Writing Y update to YStore")
self._task_group.start_soon(self.ystore.write, update)

def exit(self):
self._task_group.cancel_scope.cancel()
from .yroom import YRoom
from .yutils import YMessageType, process_sync_message, sync


class WebsocketServer:
Expand Down
95 changes: 95 additions & 0 deletions ypy_websocket/yroom.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import logging
from contextlib import AsyncExitStack
from functools import partial
from typing import Callable, List, Optional

import y_py as Y
from anyio import create_memory_object_stream, create_task_group
from anyio.abc import TaskGroup
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream

from .awareness import Awareness
from .ystore import BaseYStore
from .yutils import create_update_message, put_updates


class YRoom:

clients: List
ydoc: Y.YDoc
ystore: Optional[BaseYStore]
_on_message: Optional[Callable]
_update_send_stream: MemoryObjectSendStream
_update_receive_stream: MemoryObjectReceiveStream
_ready: bool
_task_group: TaskGroup
_entered: bool

def __init__(self, ready: bool = True, ystore: Optional[BaseYStore] = None, log=None):
self.ydoc = Y.YDoc()
self.awareness = Awareness(self.ydoc)
self._update_send_stream, self._update_receive_stream = create_memory_object_stream(
max_buffer_size=65536
)
self._ready = False
self.ready = ready
self.ystore = ystore
self.log = log or logging.getLogger(__name__)
self.clients = []
self._on_message = None
self._entered = False

@property
def ready(self) -> bool:
return self._ready

@ready.setter
def ready(self, value: bool) -> None:
self._ready = value
if value:
self.ydoc.observe_after_transaction(partial(put_updates, self._update_send_stream))

@property
def on_message(self) -> Optional[Callable]:
return self._on_message

@on_message.setter
def on_message(self, value: Optional[Callable]):
self._on_message = value

async def _broadcast_updates(self):
async with self._update_receive_stream:
async for update in self._update_receive_stream:
if self._task_group.cancel_scope.cancel_called:
return
# broadcast internal ydoc's update to all clients, that includes changes from the
# clients and changes from the backend (out-of-band changes)
for client in self.clients:
self.log.debug("Sending Y update to client with endpoint: %s", client.path)
message = create_update_message(update)
self._task_group.start_soon(client.send, message)
if self.ystore:
self.log.debug("Writing Y update to YStore")
self._task_group.start_soon(self.ystore.write, update)

async def __aenter__(self):
async with AsyncExitStack() as exit_stack:
tg = create_task_group()
self._task_group = await exit_stack.enter_async_context(tg)
self._exit_stack = exit_stack.pop_all()
tg.start_soon(self._broadcast_updates)

async def __aexit__(self, exc_type, exc_value, exc_tb):
self._task_group.cancel_scope.cancel()
return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb)

async def enter(self):
if self._entered:
return

async with create_task_group() as self._task_group:
self._task_group.start_soon(self._broadcast_updates)
self._entered = True

def exit(self):
self._task_group.cancel_scope.cancel()
22 changes: 20 additions & 2 deletions ypy_websocket/ystore.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import tempfile
import time
from abc import ABC, abstractmethod
from contextlib import AsyncExitStack
from pathlib import Path
from typing import AsyncIterator, Callable, Optional, Tuple

Expand Down Expand Up @@ -37,6 +38,9 @@ async def read(self) -> AsyncIterator[Tuple[bytes, bytes]]:
async def start(self):
pass

def stop(self):
pass

async def get_metadata(self) -> bytes:
metadata = b"" if not self.metadata_callback else await self.metadata_callback()
return metadata
Expand Down Expand Up @@ -181,10 +185,24 @@ def __init__(self, path: str, metadata_callback: Optional[Callable] = None, log=
self.lock = anyio.Lock()
self.db_initialized = anyio.Event()

async def start(self):
async with anyio.create_task_group() as tg:
async def __aenter__(self):
async with AsyncExitStack() as exit_stack:
tg = anyio.create_task_group()
self._task_group = await exit_stack.enter_async_context(tg)
self._exit_stack = exit_stack.pop_all()
tg.start_soon(self._init_db)

async def __aexit__(self, exc_type, exc_value, exc_tb):
self._task_group.cancel_scope.cancel()
return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb)

async def start(self):
async with anyio.create_task_group() as self._task_group:
self._task_group.start_soon(self._init_db)

def stop(self):
self._task_group.cancel_scope.cancel()

async def _init_db(self):
create_db = False
move_db = False
Expand Down

0 comments on commit c702f4c

Please sign in to comment.