From 37d862757b5bfe14f4ad8e82e8da991a196ee30d Mon Sep 17 00:00:00 2001 From: Nick Macholl Date: Tue, 10 Oct 2023 14:16:22 -0700 Subject: [PATCH] MOD: Add repeater mode to mock live server --- tests/conftest.py | 38 ++++++-- tests/mock_live_server.py | 184 ++++++++++++++++++++++++-------------- 2 files changed, 147 insertions(+), 75 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 521da44..99e271a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,9 +3,11 @@ """ from __future__ import annotations +import asyncio import pathlib import random import string +import threading from collections.abc import AsyncGenerator from collections.abc import Generator from collections.abc import Iterable @@ -188,15 +190,37 @@ async def fixture_mock_live_server( 1, ) + loop = asyncio.new_event_loop() + thread = threading.Thread( + name="MockLiveServer", + target=loop.run_forever, + args=(), + daemon=True, + ) + thread.start() + with caplog.at_level("DEBUG"): - mock_live_server = await MockLiveServer.create( - host="127.0.0.1", - port=unused_tcp_port, - dbn_path=TESTS_ROOT / "data", - ) - await mock_live_server.start() + mock_live_server = asyncio.run_coroutine_threadsafe( + coro=MockLiveServer.create( + host="127.0.0.1", + port=unused_tcp_port, + dbn_path=TESTS_ROOT / "data", + ), + loop=loop, + ).result() + yield mock_live_server - await mock_live_server.stop() + + asyncio.run_coroutine_threadsafe( + coro=mock_live_server.stop(), + loop=loop, + ).result() + + loop.run_in_executor( + None, + loop.stop, + ) + thread.join() @pytest.fixture(name="historical_client") diff --git a/tests/mock_live_server.py b/tests/mock_live_server.py index 4b928c6..79b1d4a 100644 --- a/tests/mock_live_server.py +++ b/tests/mock_live_server.py @@ -2,6 +2,7 @@ import argparse import asyncio +import enum import logging import os import pathlib @@ -10,6 +11,7 @@ import string import sys import threading +import time from concurrent import futures from functools import singledispatchmethod from io import BytesIO @@ -17,6 +19,7 @@ import zstandard from databento.common import cram +from databento.common.data import SCHEMA_STRUCT_MAP from databento.live.gateway import AuthenticationRequest from databento.live.gateway import AuthenticationResponse from databento.live.gateway import ChallengeRequest @@ -25,7 +28,9 @@ from databento.live.gateway import SessionStart from databento.live.gateway import SubscriptionRequest from databento.live.gateway import parse_gateway_message +from databento_dbn import Metadata from databento_dbn import Schema +from databento_dbn import SType LIVE_SERVER_VERSION: str = "1.0.0" @@ -37,6 +42,11 @@ logger = logging.getLogger(__name__) +class MockLiveMode(enum.Enum): + REPLAY = "replay" + REPEAT = "repeat" + + class MockLiveServerProtocol(asyncio.BufferedProtocol): """ The connection protocol to mock the Databento Live Subscription Gateway. @@ -55,6 +65,8 @@ class MockLiveServerProtocol(asyncio.BufferedProtocol): {ip}:{port}. version : str The server version string. + mode : MockLiveMode + The mode for the mock lsg; defaults to "replay" See Also -------- @@ -68,6 +80,7 @@ def __init__( user_api_keys: dict[str, str], message_queue: MessageQueue, dbn_path: pathlib.Path, + mode: MockLiveMode = MockLiveMode.REPLAY, ) -> None: self.__transport: asyncio.Transport self._buffer: bytearray @@ -76,11 +89,13 @@ def __init__( self._cram_challenge: str = "".join( random.choice(string.ascii_letters) for _ in range(32) # noqa: S311 ) + self._mode = mode self._message_queue = message_queue self._peer: str = "" self._version: str = version self._is_authenticated: bool = False self._is_streaming: bool = False + self._repeater_tasks: set[asyncio.Task[None]] = set() self._dbn_path = dbn_path self._user_api_keys = user_api_keys @@ -122,6 +137,18 @@ def is_streaming(self) -> bool: """ return self._is_streaming + @property + def mode(self) -> MockLiveMode: + """ + Return the mock live server replay mode. + + Returns + ------- + MockLiveMode + + """ + return self._mode + @property def peer(self) -> str: """ @@ -213,11 +240,11 @@ def connection_made( raise RuntimeError(f"cannot write to {transport}") self.__transport = transport - self._buffer = bytearray(1024) + self._buffer = bytearray(2**16) self._data = BytesIO() self._schemas: list[Schema] = [] - peer_host, peer_port = transport.get_extra_info("peername") + peer_host, peer_port, *_ = transport.get_extra_info("peername") self._peer = f"{peer_host}:{peer_port}" logger.info("%s connected to %s", type(self).__name__, self._peer) @@ -367,25 +394,47 @@ def _(self, message: SessionStart) -> None: logger.info("received session start request: %s", str(message).strip()) self._is_streaming = True - for schema in self._schemas: - for test_data_path in self._dbn_path.glob(f"*{schema}.dbn.zst"): - decompressor = zstandard.ZstdDecompressor().stream_reader( - test_data_path.read_bytes(), - ) - logger.info( - "streaming %s for %s schema", - test_data_path.name, - schema, - ) - self.__transport.write(decompressor.readall()) + if self.mode is MockLiveMode.REPLAY: + for schema in self._schemas: + for test_data_path in self._dbn_path.glob(f"*{schema}.dbn.zst"): + decompressor = zstandard.ZstdDecompressor().stream_reader( + test_data_path.read_bytes(), + ) + logger.info( + "streaming %s for %s schema", + test_data_path.name, + schema, + ) + self.__transport.write(decompressor.readall()) + + logger.info( + "data streaming for %d schema(s) completed", + len(self._schemas), + ) - logger.info( - "data streaming for %d schema(s) completed", - len(self._schemas), - ) + self.__transport.write_eof() + self.__transport.close() - self.__transport.write_eof() - self.__transport.close() + elif self.mode is MockLiveMode.REPEAT: + metadata = Metadata("UNIT.TEST", 0, SType.RAW_SYMBOL, [], [], [], []) # type: ignore [call-arg] + self.__transport.write(bytes(metadata)) + + loop = asyncio.get_event_loop() + for schema in self._schemas: + task = loop.create_task(self.repeater(schema)) + self._repeater_tasks.add(task) + task.add_done_callback(self._repeater_tasks.remove) + else: + raise ValueError(f"unsupported mode {MockLiveMode.REPEAT}") + + async def repeater(self, schema: Schema) -> None: + struct = SCHEMA_STRUCT_MAP[schema] + repeated = bytes(struct(*[0] * 12)) # for now we only support MBP_1 + + logger.info("repeating %d bytes for %s", len(repeated), schema) + while not self.__transport.is_closing(): + self.__transport.write(16 * repeated) + await asyncio.sleep(0) class MockLiveServer: @@ -401,6 +450,8 @@ class MockLiveServer: The port of the mock server. server : asyncio.base_events.Server The mock server object. + mode : MockLiveMode + The mock server mode; defaults to "replay". Methods ------- @@ -423,7 +474,7 @@ def __init__(self) -> None: self._user_api_keys: dict[str, str] self._message_queue: MessageQueue self._thread: threading.Thread - self._loop: asyncio.AbstractEventLoop + self._mode: MockLiveMode @property def host(self) -> str: @@ -437,6 +488,18 @@ def host(self) -> str: """ return self._host + @property + def mode(self) -> MockLiveMode: + """ + Return the mock live server mode. + + Returns + ------- + MockLiveMode + + """ + return self._mode + @property def port(self) -> int: """ @@ -468,6 +531,7 @@ def _protocol_factory( message_queue: MessageQueue, version: str, dbn_path: pathlib.Path, + mode: MockLiveMode, ) -> Callable[[], MockLiveServerProtocol]: def factory() -> MockLiveServerProtocol: return MockLiveServerProtocol( @@ -475,6 +539,7 @@ def factory() -> MockLiveServerProtocol: user_api_keys=user_api_keys, message_queue=message_queue, dbn_path=dbn_path, + mode=mode, ) return factory @@ -485,6 +550,7 @@ async def create( host: str = "localhost", port: int = 0, dbn_path: pathlib.Path = pathlib.Path.cwd(), + mode: MockLiveMode = MockLiveMode.REPLAY, ) -> MockLiveServer: """ Create a mock server instance. This factory method is the preferred way @@ -510,22 +576,14 @@ async def create( """ logger.info( - "creating %s with host=%s port=%s dbn_path=%s", + "creating %s with host=%s port=%s dbn_path=%s mode=%s", cls.__name__, host, port, dbn_path, + mode, ) - loop = asyncio.new_event_loop() - thread = threading.Thread( - name="MockLiveServer", - target=loop.run_forever, - args=(), - daemon=True, - ) - thread.start() - user_api_keys: dict[str, str] = {} message_queue: MessageQueue = queue.Queue() # type: ignore @@ -535,32 +593,29 @@ async def create( bucket_id = env_key[-cram.BUCKET_ID_LENGTH :] user_api_keys[bucket_id] = env_key - server = asyncio.run_coroutine_threadsafe( - loop.create_server( - protocol_factory=cls._protocol_factory( - user_api_keys=user_api_keys, - message_queue=message_queue, - version=LIVE_SERVER_VERSION, - dbn_path=dbn_path, - ), - host=host, - port=port, - start_serving=False, + loop = asyncio.get_event_loop() + server = await loop.create_server( + protocol_factory=cls._protocol_factory( + user_api_keys=user_api_keys, + message_queue=message_queue, + version=LIVE_SERVER_VERSION, + dbn_path=dbn_path, + mode=mode, ), - loop=loop, - ).result() + host=host, + port=port, + start_serving=True, + ) mock_live_server = cls() # Initialize the MockLiveServer instance mock_live_server._server = server - mock_live_server._host, mock_live_server._port = server.sockets[ + mock_live_server._host, mock_live_server._port, *_ = server.sockets[ -1 ].getsockname() mock_live_server._user_api_keys = user_api_keys mock_live_server._message_queue = message_queue - mock_live_server._thread = thread - mock_live_server._loop = loop return mock_live_server @@ -612,10 +667,10 @@ def get_message_of_type( If the timeout duration is reached, if specified. """ - start_time = self._loop.time() - end_time = self._loop.time() + timeout + start_time = time.perf_counter() + end_time = time.perf_counter() + timeout while start_time < end_time: - remaining_time = abs(end_time - self._loop.time()) + remaining_time = abs(end_time - time.perf_counter()) try: message = self._message_queue.get(timeout=remaining_time) except queue.Empty: @@ -626,21 +681,6 @@ def get_message_of_type( raise futures.TimeoutError - async def start(self) -> None: - """ - Start the mock server. - """ - logger.info( - "starting %s on %s:%s", - self.__class__.__name__, - self.host, - self.port, - ) - asyncio.run_coroutine_threadsafe( - coro=self.server.start_serving(), - loop=self._loop, - ).result() - async def stop(self) -> None: """ Stop the mock server. @@ -652,9 +692,8 @@ async def stop(self) -> None: self.port, ) - self._loop.call_soon_threadsafe(self.server.close) - self._loop.call_soon_threadsafe(self._loop.stop) - self._thread.join() + self.server.close() + await self.server.wait_closed() if __name__ == "__main__": @@ -678,6 +717,15 @@ async def stop(self) -> None: action="store", help="path to a directory containing DBN files to stream", ) + parser.add_argument( + "-m", + "--mode", + metavar="mode", + default="replay", + choices=(x.value for x in MockLiveMode), + action="store", + help="the mock server live mode", + ) params = parser.parse_args(sys.argv[1:]) @@ -696,9 +744,9 @@ async def stop(self) -> None: host=params.host, port=params.port, dbn_path=pathlib.Path(params.dbn_path), + mode=MockLiveMode(params.mode), ), ) - loop.run_until_complete(mock_live_server.start()) # Serve Forever try: