Skip to content

Commit

Permalink
Add unit test for issue with exceptions not bubbling out of root task
Browse files Browse the repository at this point in the history
- fix logging bug where task_done() could be called too many times if cancelled during queue.get()
- update unwrap_exception_group to treat as tree instead of list
- fix runner_test needlessly marking all tests integration
- refactor event_loop mock patching into helper method
- add test for exceptions not raising, reproducible with the previous implementation
  • Loading branch information
linkous8 committed May 3, 2024
1 parent 59259f4 commit a84a8b8
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 41 deletions.
2 changes: 1 addition & 1 deletion servo/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,8 @@ async def shutdown(self) -> None:

async def _process_queue(self) -> None:
while True:
progress = await self._queue.get()
try:
progress = await self._queue.get()
if progress is None:
logger.info(
f"retrieved None from progress queue. halting progress reporting"
Expand Down
28 changes: 22 additions & 6 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,20 +188,36 @@ def json_key_path(json_str: str, key_path: str) -> Any:
def unwrap_exception_group(
    excg: ExceptionGroup, expected_type: type[E], expected_count: int | None = None
) -> E | list[E]:
    """Flatten *excg* and return its leaf exception(s), verifying their type.

    Args:
        excg: The exception group to unwrap; may contain nested groups,
            which are traversed as a tree rather than a flat list.
        expected_type: Every leaf exception must be an instance of this type.
        expected_count: Optional exact number of leaf exceptions expected.

    Returns:
        The single leaf exception when the group holds exactly one,
        otherwise the list of all leaf exceptions.

    Raises:
        AssertionError: If the leaf count does not match *expected_count*
            or the group is empty.
        ExceptionGroup: *excg* is re-raised unchanged when any leaf is not
            an instance of *expected_type*, preserving the full traceback.
    """
    # Flatten nested groups first so counting and type checks operate on
    # leaf exceptions, never on intermediate ExceptionGroup nodes.
    flattened_group = flatten_exception_group(exception_group=excg)
    excg_len = len(flattened_group)
    if expected_count is not None:
        assert (
            excg_len == expected_count
        ), f"Exception group count {excg_len} did not match expected count {expected_count}"

    assert excg_len > 0
    # Re-raise the original group (not a new error) so pytest reports the
    # complete nested traceback when an unexpected exception type appears.
    if not all(isinstance(e, expected_type) for e in flattened_group):
        raise excg

    if excg_len > 1:
        return flattened_group
    else:
        return flattened_group[0]


def flatten_exception_group(exception_group: ExceptionGroup) -> list[Exception]:
exc_list = []
# traverse tree of exceptions depth first. Evaluate priority and raise as the "main" exception with others included as property additional exceptions
visit_list = list(exception_group.exceptions)
while visit_list:
exc = visit_list.pop(0)
if isinstance(exc, ExceptionGroup):
visit_list = list(exc.exceptions) + visit_list
else:
exc_list.append(exc)

return exc_list


class Subprocess:
Expand Down
96 changes: 62 additions & 34 deletions tests/runner_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import asyncio
from contextlib import contextmanager
import pathlib
from typing import AsyncGenerator
import typing

import devtools
import pytest
import pytest_mock
import unittest.mock
Expand All @@ -14,8 +17,6 @@
import tests.fake
import tests.helpers

pytestmark = [pytest.mark.asyncio, pytest.mark.integration]


@pytest.fixture()
async def assembly(servo_yaml: pathlib.Path) -> servo.assembly.Assembly:
Expand Down Expand Up @@ -52,12 +53,9 @@ async def assembly_runner(
await runner.progress_handler.shutdown()


@tests.helpers.api_mock
async def test_assembly_shutdown_with_non_running_servo(
assembly_runner: servo.runner.AssemblyRunner,
):
@contextmanager
def patch_event_loop() -> typing.Iterator[asyncio.AbstractEventLoop]:
event_loop = asyncio.get_event_loop()

# NOTE: using the pytest_mocker fixture for mocking the event loop can cause side effects with pytest-asyncio
# (eg. when fixture mocking the `stop` method, the test will run forever).
# By using unittest.mock, we can ensure the event_loop is restored before exiting this method
Expand All @@ -69,34 +67,58 @@ async def test_assembly_shutdown_with_non_running_servo(
with unittest.mock.patch.object(event_loop, "run_forever", return_value=None):
# run_forever no longer blocks causing loop.close() to be called immediately, stop runner from closing it to prevent errors
with unittest.mock.patch.object(event_loop, "close", return_value=None):
yield event_loop


@pytest.mark.asyncio
@tests.helpers.api_mock
async def test_assembly_shutdown_with_non_running_servo(
    assembly_runner: servo.runner.AssemblyRunner,
):
    """Verify assembly shutdown proceeds cleanly when its servo was already shut down."""
    with patch_event_loop() as event_loop:

        async def wait_for_servo_running():
            # Poll until the runner has actually marked the servo as running.
            while not assembly_runner.assembly.servos[0].is_running:
                await asyncio.sleep(0.01)

        try:
            assembly_runner.run()
        except ValueError as e:
            if "add_signal_handler() can only be called from the main thread" in str(e):
                # https://github.com/pytest-dev/pytest-xdist/issues/620
                pytest.xfail("not running in the main thread")
            else:
                raise

        await asyncio.wait_for(wait_for_servo_running(), timeout=2)

        # Shutdown the servo to produce edge case error
        await assembly_runner.assembly.servos[0].shutdown()
        try:
            await assembly_runner.assembly.shutdown()
        except:
            raise
        finally:
            # Teardown runner asyncio tasks so they don't raise errors when the loop is closed by pytest
            await assembly_runner.shutdown(event_loop)


@pytest.mark.timeout(5)
@tests.helpers.api_mock
def test_assembly_run_raises_task_errors(
    mocker: pytest_mock.MockerFixture, assembly_runner: servo.runner.AssemblyRunner
):
    """Verify an exception raised inside the root servo task bubbles out of run().

    Regression test for errors raised in the servo main loop being swallowed
    instead of propagating to the caller of AssemblyRunner.run().
    """
    # Force the servo main loop to fail immediately after startup.
    mock = mocker.patch.object(servo.runner.ServoRunner, "run_main_loop")
    mock.side_effect = RuntimeError("KABOOM")

    with pytest.raises(ExceptionGroup) as eg:
        assembly_runner.run()

    # The mocked loop must actually have been entered for the test to be valid.
    assert mock.called, "run_main_loop was never invoked"
    # Exactly one RuntimeError leaf should be present in the raised group.
    assert tests.helpers.unwrap_exception_group(
        eg.value, RuntimeError, 1
    ), devtools.pformat(eg.value)


@pytest.fixture
Expand All @@ -105,6 +127,7 @@ async def servo_runner(assembly: servo.Assembly) -> servo.runner.ServoRunner:
return servo.runner.ServoRunner(assembly.servos[0])


@pytest.mark.asyncio
@pytest.fixture
async def running_servo(
event_loop: asyncio.AbstractEventLoop,
Expand All @@ -131,6 +154,7 @@ async def running_servo(


# TODO: Switch this over to using a FakeAPI
@pytest.mark.asyncio
@pytest.mark.xfail(reason="too brittle.")
async def test_out_of_order_operations(servo_runner: servo.runner.ServoRunner) -> None:
await servo_runner.servo.startup()
Expand Down Expand Up @@ -169,6 +193,7 @@ async def test_out_of_order_operations(servo_runner: servo.runner.ServoRunner) -
servo_runner.logger.info("test logging", operation="ADJUST", progress=55)


@pytest.mark.asyncio
async def test_hello(
servo_runner: servo.runner.ServoRunner,
fakeapi_url: str,
Expand Down Expand Up @@ -226,6 +251,7 @@ async def test_hello(
# # fire up runner.run and check .run, etc.


@pytest.mark.asyncio
async def test_authorization_redacted(
servo_runner: servo.runner.ServoRunner,
fakeapi_url: str,
Expand All @@ -251,6 +277,7 @@ async def test_authorization_redacted(
assert servo_runner.optimizer.token.get_secret_value() not in curlify_log


@pytest.mark.asyncio
async def test_control_sent_on_adjust(
servo_runner: servo.runner.ServoRunner,
fakeapi_url: str,
Expand Down Expand Up @@ -287,6 +314,7 @@ async def wait_for_optimizer_done():


# TODO: This doesn't need to be integration test
@pytest.mark.asyncio
@tests.helpers.api_mock
async def test_adjustment_rejected(
mocker, servo_runner: servo.runner.ServoRunner
Expand Down

0 comments on commit a84a8b8

Please sign in to comment.