From 67619f2af7a4a6f4fa2dde465f7a53dc39aa3a03 Mon Sep 17 00:00:00 2001 From: Nick Macholl Date: Wed, 8 Nov 2023 10:45:14 -0700 Subject: [PATCH] FIX: Live client stream flush on timeout --- CHANGELOG.md | 7 ++++-- databento/live/client.py | 15 +++++++++-- tests/test_live_client.py | 53 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 71 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2539120..26905a2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,10 +1,13 @@ # Changelog -### 0.24.0 - TBD +## 0.24.0 - TBD -##### Enhancements +#### Enhancements - Added new publishers for consolidated DBEQ.BASIC and DBEQ.PLUS +#### Bug fixes +- Fixed an issue where `Live.block_for_close` and `Live.wait_for_close` would not flush streams if the timeout was reached. + ## 0.23.0 - 2023-10-26 #### Enhancements diff --git a/databento/live/client.py b/databento/live/client.py index 03a74d3..ee7436e 100644 --- a/databento/live/client.py +++ b/databento/live/client.py @@ -498,6 +498,7 @@ def terminate(self) -> None: if self._session is None: raise ValueError("cannot terminate a live client before it is connected") self._session.abort() + self._cleanup_client() def block_for_close( self, @@ -539,6 +540,8 @@ def block_for_close( raise except Exception: raise BentoError("connection lost") from None + finally: + self._cleanup_client() async def wait_for_close( self, @@ -581,9 +584,13 @@ async def wait_for_close( self.terminate() if isinstance(exc, KeyboardInterrupt): raise + except BentoError: + raise except Exception: logger.exception("exception encountered waiting for close") raise BentoError("connection lost") from None + finally: + self._cleanup_client() async def _shutdown(self) -> None: """ @@ -597,6 +604,12 @@ async def _shutdown(self) -> None: return await self._session.wait_for_close() + def _cleanup_client(self) -> None: + """ + Cleanup any stateful client data. + """ + self._symbology_map.clear() + to_remove = [] for stream in self._user_streams: stream_name = getattr(stream, "name", str(stream)) @@ -609,8 +622,6 @@ async def _shutdown(self) -> None: for key in to_remove: self._user_streams.pop(key) - self._symbology_map.clear() - def _map_symbol(self, record: DBNRecord) -> None: if isinstance(record, databento_dbn.SymbolMappingMsg): out_symbol = record.stype_out_symbol diff --git a/tests/test_live_client.py b/tests/test_live_client.py index c270a04..82f4629 100644 --- a/tests/test_live_client.py +++ b/tests/test_live_client.py @@ -525,6 +525,31 @@ def test_live_block_for_close_timeout( live_client.terminate.assert_called_once() # type: ignore +@pytest.mark.usefixtures("mock_live_server") +def test_live_block_for_close_timeout_stream( + live_client: client.Live, + monkeypatch: pytest.MonkeyPatch, + tmp_path: pathlib.Path, +) -> None: + """ + Test that block_for_close flushes user streams on timeout. + """ + live_client.subscribe( + dataset=Dataset.GLBX_MDP3, + schema=Schema.MBO, + stype_in=SType.INSTRUMENT_ID, + symbols="ALL_SYMBOLS", + start=None, + ) + path = tmp_path / "test.dbn" + stream = path.open("wb") + monkeypatch.setattr(stream, "flush", MagicMock()) + live_client.add_stream(stream) + + live_client.block_for_close(timeout=0) + stream.flush.assert_called() # type: ignore [attr-defined] + + @pytest.mark.usefixtures("mock_live_server") async def test_live_wait_for_close( live_client: client.Live, @@ -571,6 +596,32 @@ async def test_live_wait_for_close_timeout( live_client.terminate.assert_called_once() # type: ignore +@pytest.mark.usefixtures("mock_live_server") +async def test_live_wait_for_close_timeout_stream( + live_client: client.Live, + monkeypatch: pytest.MonkeyPatch, + tmp_path: pathlib.Path, +) -> None: + """ + Test that wait_for_close flushes user streams on timeout. + """ + live_client.subscribe( + dataset=Dataset.GLBX_MDP3, + schema=Schema.MBO, + stype_in=SType.INSTRUMENT_ID, + symbols="ALL_SYMBOLS", + start=None, + ) + + path = tmp_path / "test.dbn" + stream = path.open("wb") + monkeypatch.setattr(stream, "flush", MagicMock()) + live_client.add_stream(stream) + + await live_client.wait_for_close(timeout=0) + stream.flush.assert_called() # type: ignore [attr-defined] + + def test_live_add_callback( live_client: client.Live, ) -> None: @@ -615,6 +666,7 @@ def test_live_add_stream_invalid( with pytest.raises(ValueError): live_client.add_stream(readable_file.open(mode="rb")) + def test_live_add_stream_path_directory( tmp_path: pathlib.Path, live_client: client.Live, @@ -625,6 +677,7 @@ def test_live_add_stream_path_directory( with pytest.raises(OSError): live_client.add_stream(tmp_path) + @pytest.mark.skipif(platform.system() == "Windows", reason="flaky on windows runner") async def test_live_async_iteration( live_client: client.Live,