From 49a45305416f70a5c7c969f1017e5b3e871db812 Mon Sep 17 00:00:00 2001 From: David Brochart Date: Thu, 1 Jun 2023 23:07:55 +0200 Subject: [PATCH] Implement ASGI server (#75) --- README.md | 15 +++++++- pyproject.toml | 1 + tests/test_asgi.py | 42 ++++++++++++++++++++++ ypy_websocket/__init__.py | 1 + ypy_websocket/asgi.py | 76 +++++++++++++++++++++++++++++++++++++++ 5 files changed, 134 insertions(+), 1 deletion(-) create mode 100644 tests/test_asgi.py create mode 100644 ypy_websocket/asgi.py diff --git a/README.md b/README.md index 7d91c3c..7ecfa01 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ # ypy-websocket -ypy-websocket is an async WebSocket connector for Ypy. +ypy-websocket is an ASGI-compatible async WebSocket connector for Ypy. ## Usage @@ -46,6 +46,19 @@ async def server(): asyncio.run(server()) ``` +Or with an ASGI server: + +```py +# main.py +import uvicorn +from ypy_websocket.asgi import Server + +app = Server() + +if __name__ == "__main__": + uvicorn.run("main:app", port=5000, log_level="info") +``` + ### WebSocket API The WebSocket object passed to `WebsocketProvider` and `WebsocketServer.serve` must respect the diff --git a/pyproject.toml b/pyproject.toml index 3f00458..2d0e02e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ test = [ "pytest", "pytest-asyncio", "websockets >=10.0", + "uvicorn", ] [project.urls] diff --git a/tests/test_asgi.py b/tests/test_asgi.py new file mode 100644 index 0000000..4904bc7 --- /dev/null +++ b/tests/test_asgi.py @@ -0,0 +1,42 @@ +import asyncio + +import pytest +import uvicorn +import y_py as Y +from websockets import connect # type: ignore + +from ypy_websocket import ASGIServer, WebsocketProvider, WebsocketServer + +websocket_server = WebsocketServer(auto_clean_rooms=False) +app = ASGIServer(websocket_server) + + +@pytest.mark.asyncio +async def test_asgi(unused_tcp_port): + # server + config = uvicorn.Config("test_asgi:app", port=unused_tcp_port, log_level="info") + server = uvicorn.Server(config) + server_task = asyncio.create_task(server.serve()) + while not server.started: + await asyncio.sleep(0) + + # clients + # client 1 + ydoc1 = Y.YDoc() + ymap1 = ydoc1.get_map("map") + with ydoc1.begin_transaction() as t: + ymap1.set(t, "key", "value") + async with connect(f"ws://localhost:{unused_tcp_port}/my-roomname") as websocket1: + WebsocketProvider(ydoc1, websocket1) + await asyncio.sleep(0.1) + + # client 2 + ydoc2 = Y.YDoc() + async with connect(f"ws://localhost:{unused_tcp_port}/my-roomname") as websocket2: + WebsocketProvider(ydoc2, websocket2) + await asyncio.sleep(0.1) + + ymap2 = ydoc2.get_map("map") + assert ymap2.to_json() == '{"key":"value"}' + + server_task.cancel() diff --git a/ypy_websocket/__init__.py b/ypy_websocket/__init__.py index 6bf5378..56b151a 100644 --- a/ypy_websocket/__init__.py +++ b/ypy_websocket/__init__.py @@ -1,3 +1,4 @@ +from .asgi import Server as ASGIServer # noqa from .websocket_provider import WebsocketProvider # noqa from .websocket_server import WebsocketServer, YRoom # noqa from .yutils import YMessageType # noqa diff --git a/ypy_websocket/asgi.py b/ypy_websocket/asgi.py new file mode 100644 index 0000000..8c1bf3a --- /dev/null +++ b/ypy_websocket/asgi.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +from typing import Any, Awaitable, Callable + +from .websocket_server import WebsocketServer + + +class WebSocket: + def __init__( + self, + receive: Callable[[], Awaitable[dict[str, Any]]], + send: Callable[[dict[str, Any]], Awaitable[None]], + path: str, + on_disconnect: Callable[[dict[str, Any]], Awaitable[None]] | None = None, + ): + self._receive = receive + self._send = send + self._path = path + self._on_disconnect = on_disconnect + + @property + def path(self) -> str: + return self._path + + def __aiter__(self): + return self + + async def __anext__(self) -> bytes: + return await self.recv() + + async def send(self, message: bytes) -> None: + await self._send( + dict( + type="websocket.send", + bytes=message, + ) + ) + + async def recv(self) -> bytes: + message = await self._receive() + if message["type"] == "websocket.receive": + return message["bytes"] + if message["type"] == "websocket.disconnect": + if self._on_disconnect is not None: + await self._on_disconnect(message) + raise StopAsyncIteration() + return b"" + + +class Server: + def __init__( + self, + websocket_server: WebsocketServer, + on_connect: Callable[[dict[str, Any], dict[str, Any]], Awaitable[bool]] | None = None, + on_disconnect: Callable[[dict[str, Any]], Awaitable[None]] | None = None, + ): + self._websocket_server = websocket_server + self._on_connect = on_connect + self._on_disconnect = on_disconnect + + async def __call__( + self, + scope: dict[str, Any], + receive: Callable[[], Awaitable[dict[str, Any]]], + send: Callable[[dict[str, Any]], Awaitable[None]], + ): + msg = await receive() + if msg["type"] == "websocket.connect": + if self._on_connect is not None: + close = await self._on_connect(msg, scope) + if close: + return + + await send({"type": "websocket.accept"}) + websocket = WebSocket(receive, send, scope["path"], self._on_disconnect) + await self._websocket_server.serve(websocket)