From 062947648b7cd9bcbf837a9760151aa8466bbb8b Mon Sep 17 00:00:00 2001 From: Linkous Sharp Date: Tue, 23 Apr 2024 14:33:52 -0500 Subject: [PATCH 1/6] First pass update gather methods to prevent task leaks on error This update replaces instances of gather() with (default behavior) return_exceptions=false since only the future that raises the exception will be stopped. TaskGroup ensures all taks complete or are cancelled before the with block is exited. This update is also applied to gather()s that are wrapped by task leak handling identical to TaskGroup to reduce the boilerplate added by said handling This update is not applied to gather invocations that need to return (all) exceptions since that is not supported by TaskGroup. It is also not applied to gathers used as a shorthand for try catch pass logic (except for a case where logging was improved) NOTE we avoid the use of create_task in generators in favor of list comprehensions which call create_task right away instead of lazily (doesn't matter for task result) NOTE TaskGroup also prevents the task from being garbage collected which allows cleanup of variables for whose only purpose was ensuring the lifetime of the task reference matched the scope of the function --- servo/assembly.py | 6 +- servo/cli.py | 14 ++--- servo/connectors/kubernetes.py | 108 ++++++++++++++------------------- servo/connectors/prometheus.py | 35 ++++++----- servo/events.py | 31 ++++------ servo/runner.py | 6 +- servo/utilities/subprocess.py | 41 ++++++------- 7 files changed, 105 insertions(+), 136 deletions(-) diff --git a/servo/assembly.py b/servo/assembly.py index 67cfeb0cc..99acbf800 100644 --- a/servo/assembly.py +++ b/servo/assembly.py @@ -150,9 +150,9 @@ async def assemble( ) # Attach all connectors to the servo - await asyncio.gather( - *list(map(lambda s: s.dispatch_event(servo.servo.Events.attach, s), servos)) - ) + async with asyncio.TaskGroup() as tg: + for s in servos: + _ = tg.create_task(s.dispatch_event(servo.servo.Events.attach, s)) return assembly diff --git a/servo/cli.py b/servo/cli.py index d573eee5e..127d341de 100644 --- a/servo/cli.py +++ b/servo/cli.py @@ -1123,14 +1123,12 @@ def print_callback(input: str) -> None: # gather() expects a loop to exist at invocation time which is not compatible with the run_async # execution model. wrap the gather in a standard async function to work around this async def gather_checks(): - return await asyncio.gather( - *list( - map( - lambda s: s.check_servo(print_callback), - context.assembly.servos, - ) - ), - ) + async with asyncio.TaskGroup() as tg: + tasks = [ + tg.create_task(s.check_servo(print_callback)) + for s in context.assembly.servos + ] + return (t.result() for t in tasks) results = run_async(gather_checks()) ready = functools.reduce(lambda x, y: x and y, results) diff --git a/servo/connectors/kubernetes.py b/servo/connectors/kubernetes.py index ad112c6f6..1406692d0 100644 --- a/servo/connectors/kubernetes.py +++ b/servo/connectors/kubernetes.py @@ -1141,27 +1141,15 @@ async def create_tuning_pod(self) -> V1Pod: ) ) progress.start() - - task = asyncio.create_task(PodHelper.wait_until_ready(tuning_pod)) - task.add_done_callback(lambda _: progress.complete()) - gather_task = asyncio.gather( - task, - progress.watch(progress_logger), - ) - try: - await asyncio.wait_for(gather_task, timeout=self.timeout.total_seconds()) + async with asyncio.timeout(delay=self.timeout.total_seconds()): + async with asyncio.TaskGroup() as tg: + task = tg.create_task(PodHelper.wait_until_ready(tuning_pod)) + task.add_done_callback(lambda _: progress.complete()) + _ = tg.create_task(progress.watch(progress_logger)) except asyncio.TimeoutError: servo.logger.error(f"Timed out waiting for Tuning Pod to become ready...") - servo.logger.debug(f"Cancelling Task: {task}, progress: {progress}") - for t in {task, gather_task}: - t.cancel() - with contextlib.suppress(asyncio.CancelledError): - await t - servo.logger.debug(f"Cancelled Task: {t}, progress: {progress}") - - # get latest status of tuning pod for raise_for_status await self.raise_for_status() # Hydrate local state @@ -1631,34 +1619,32 @@ async def apply(self, adjustments: List[servo.Adjustment]) -> None: # TODO: Run sanity checks to look for out of band changes async def raise_for_status(self) -> None: - handle_error_tasks = [] - - def _raise_for_task(task: asyncio.Task, optimization: BaseOptimization) -> None: - if task.done() and not task.cancelled(): - if exception := task.exception(): - handle_error_tasks.append( - asyncio.create_task(optimization.handle_error(exception)) - ) - - tasks = [] - for optimization in self.optimizations: - task = asyncio.create_task(optimization.raise_for_status()) - task.add_done_callback( - functools.partial(_raise_for_task, optimization=optimization) - ) - tasks.append(task) - - for future in asyncio.as_completed( - tasks, timeout=self.config.timeout.total_seconds() - ): - try: - await future - except Exception as error: - servo.logger.exception(f"Optimization failed with error: {error}") + # TODO: first handle_error_task to raise will likely interrupt other tasks. + # Gather with return_exceptions=True and aggregate resulting exceptions into group before raising + async with asyncio.TaskGroup() as tg: + + def _raise_for_task( + task: asyncio.Task, optimization: BaseOptimization + ) -> None: + if task.done() and not task.cancelled(): + if exception := task.exception(): + _ = tg.create_task(optimization.handle_error(exception)) + + tasks = [] + for optimization in self.optimizations: + task = asyncio.create_task(optimization.raise_for_status()) + task.add_done_callback( + functools.partial(_raise_for_task, optimization=optimization) + ) + tasks.append(task) - # TODO: first handler to raise will likely interrupt other tasks. - # Gather with return_exceptions=True and aggregate resulting exceptions before raising - await asyncio.gather(*handle_error_tasks) + for future in asyncio.as_completed( + tasks, timeout=self.config.timeout.total_seconds() + ): + try: + await future + except Exception as error: + servo.logger.exception(f"Optimization failed with error: {error}") async def is_ready(self): if self.optimizations: @@ -1666,14 +1652,13 @@ async def is_ready(self): f"Checking for readiness of {len(self.optimizations)} optimizations" ) try: - results = await asyncio.wait_for( - asyncio.gather( - *list(map(lambda a: a.is_ready(), self.optimizations)), - ), - timeout=self.config.timeout.total_seconds(), - ) + async with asyncio.timeout(delay=self.config.timeout.total_seconds()): + async with asyncio.TaskGroup() as tg: + results = [ + tg.create_task(o.is_ready()) for o in self.optimizations + ] - return all(results) + return all((r.result() for r in results)) except asyncio.TimeoutError: return False @@ -2297,15 +2282,13 @@ async def adjust( progress=p.progress, ) progress = servo.EventProgress(timeout=self.config.timeout) - future = asyncio.create_task(state.apply(adjustments)) - future.add_done_callback(lambda _: progress.trigger()) # Catch-all for spaghettified non-EventError usage try: - await asyncio.gather( - future, - progress.watch(progress_logger), - ) + async with asyncio.TaskGroup() as tg: + future = tg.create_task(state.apply(adjustments)) + future.add_done_callback(lambda _: progress.trigger()) + _ = tg.create_task(progress.watch(progress_logger)) # Handle settlement settlement = control.settlement or self.config.settlement @@ -2383,13 +2366,10 @@ async def _create_optimizations(self) -> KubernetesOptimizations: ) progress = servo.EventProgress(timeout=self.config.timeout) try: - future = asyncio.create_task(KubernetesOptimizations.create(self.config)) - future.add_done_callback(lambda _: progress.trigger()) - - await asyncio.gather( - future, - progress.watch(progress_logger), - ) + async with asyncio.TaskGroup() as tg: + future = tg.create_task(KubernetesOptimizations.create(self.config)) + future.add_done_callback(lambda _: progress.trigger()) + _ = tg.create_task(progress.watch(progress_logger)) return future.result() except Exception as e: diff --git a/servo/connectors/prometheus.py b/servo/connectors/prometheus.py index 83094a3a1..fc3961b7d 100644 --- a/servo/connectors/prometheus.py +++ b/servo/connectors/prometheus.py @@ -994,28 +994,25 @@ async def measure( ), ) fast_fail_progress = servo.EventProgress(timeout=measurement_duration) - gather_tasks = [ - asyncio.create_task(progress.watch(self.observe)), - asyncio.create_task( + async with asyncio.TaskGroup() as tg: + _ = tg.create_task(progress.watch(self.observe)) + _ = tg.create_task( fast_fail_progress.watch( - fast_fail_observer.observe, every=self.config.fast_fail.period + fast_fail_observer.observe, + every=self.config.fast_fail.period, ) - ), - ] - try: - await asyncio.gather(*gather_tasks) - except: - [task.cancel() for task in gather_tasks] - await asyncio.gather(*gather_tasks, return_exceptions=True) - raise + ) + else: await progress.watch(self.observe) # Capture the measurements self.logger.info(f"Querying Prometheus for {len(metrics__)} metrics...") - readings = await asyncio.gather( - *list(map(lambda m: self._query_prometheus(m, start, end), metrics__)) - ) + async with asyncio.TaskGroup() as tg: + q_tasks = [ + tg.create_task(self._query_prometheus(m, start, end)) for m in metrics__ + ] + readings = (qt.result() for qt in q_tasks) all_readings = ( functools.reduce(lambda x, y: x + y, readings) if readings else [] ) @@ -1077,9 +1074,11 @@ async def _query_slo_metrics( self, start: datetime, end: datetime, metrics: List[PrometheusMetric] ) -> Dict[str, List[servo.TimeSeries]]: """Query prometheus for the provided metrics and return mapping of metric names to their corresponding readings""" - readings = await asyncio.gather( - *list(map(lambda m: self._query_prometheus(m, start, end), metrics)) - ) + async with asyncio.TaskGroup() as tg: + q_tasks = [ + tg.create_task(self._query_prometheus(m, start, end)) for m in metrics + ] + readings = (qt.result() for qt in q_tasks) return dict(map(lambda tup: (tup[0].name, tup[1]), zip(metrics, readings))) diff --git a/servo/events.py b/servo/events.py index ba01dc1c9..47fb3ba17 100644 --- a/servo/events.py +++ b/servo/events.py @@ -1038,36 +1038,31 @@ async def run(self) -> List[EventResult]: if results: break else: - group = asyncio.gather( - *list( - map( - lambda c: c.run_event_handlers( + async with asyncio.TaskGroup() as tg: + ev_tasks = [ + tg.create_task( + c.run_event_handlers( self.event, Preposition.on, return_exceptions=self._return_exceptions, *self._args, **self._kwargs, - ), - self._connectors, + ) ) - ), - ) - results = await group + for c in self._connectors + ] + + results = (et.result() for et in ev_tasks) results = list(filter(lambda r: r is not None, results)) results = functools.reduce(lambda x, y: x + y, results, []) # Invoke the after event handlers if self._prepositions & Preposition.after: - await asyncio.gather( - *list( - map( - lambda c: c.run_event_handlers( - self.event, Preposition.after, results - ), - self._connectors, + async with asyncio.TaskGroup() as tg: + for c in self._connectors: + _ = tg.create_task( + c.run_event_handlers(self.event, Preposition.after, results) ) - ) - ) if self.channel: await self.channel.close() diff --git a/servo/runner.py b/servo/runner.py index c2706c4d0..6acd83f9f 100644 --- a/servo/runner.py +++ b/servo/runner.py @@ -745,7 +745,11 @@ async def _shutdown(self, loop: asyncio.AbstractEventLoop, signal=None) -> None: except Exception as error: self.logger.critical(f"Failed assembly shutdown with error: {error}") - await asyncio.gather(self.progress_handler.shutdown(), return_exceptions=True) + try: + await self.progress_handler.shutdown() + except Exception as error: + self.logger.warning(f"Failed progress handler shutdown with error: {error}") + self.logger.remove(self.progress_handler_id) # Cancel any outstanding tasks -- under a clean, graceful shutdown this list will be empty diff --git a/servo/utilities/subprocess.py b/servo/utilities/subprocess.py index b3cc02692..7ac108598 100644 --- a/servo/utilities/subprocess.py +++ b/servo/utilities/subprocess.py @@ -314,29 +314,28 @@ async def stream_subprocess_output( :raises asyncio.TimeoutError: Raised if the timeout expires before the subprocess exits. :return: The exit status of the subprocess. """ - tasks = [] - if process.stdout: - tasks.append( - asyncio.create_task( - _read_lines_from_output_stream(process.stdout, stdout_callback), - name="stdout", - ) - ) - if process.stderr: - tasks.append( - asyncio.create_task( - _read_lines_from_output_stream(process.stderr, stderr_callback), - name="stderr", - ) - ) timeout_in_seconds = ( timeout.total_seconds() if isinstance(timeout, datetime.timedelta) else timeout ) try: - # Gather the stream output tasks and the parent process - gather_task = asyncio.gather(*tasks, process.wait()) - await asyncio.wait_for(gather_task, timeout=timeout_in_seconds) + async with asyncio.timeout(delay=timeout_in_seconds): + async with asyncio.TaskGroup() as tg: + if process.stdout: + tg.create_task( + _read_lines_from_output_stream(process.stdout, stdout_callback), + name="stdout", + ) + + if process.stderr: + tg.create_task( + _read_lines_from_output_stream(process.stderr, stderr_callback), + name="stderr", + ) + + tg.create_task(process.wait()) + + # Gather the stream output tasks and the parent process (with block does not exit until error or all complete) except (asyncio.TimeoutError, asyncio.CancelledError): with contextlib.suppress(ProcessLookupError): @@ -351,12 +350,6 @@ async def stream_subprocess_output( process.kill() await process.wait() - with contextlib.suppress(asyncio.CancelledError): - await gather_task - - [task.cancel() for task in tasks] - await asyncio.gather(*tasks, return_exceptions=True) - raise return cast(int, process.returncode) From ff281546b73b3ef37bcfe18fcb2a6a1002fd2117 Mon Sep 17 00:00:00 2001 From: Linkous Sharp Date: Wed, 24 Apr 2024 11:26:26 -0500 Subject: [PATCH 2/6] Update except certain except blocks to use except-star to catch ExceptionGroups raised by TaskGroups --- servo/cli.py | 4 ++-- servo/connectors/kubernetes.py | 7 ++++++- servo/events.py | 2 +- servo/runner.py | 8 +++++--- 4 files changed, 14 insertions(+), 7 deletions(-) diff --git a/servo/cli.py b/servo/cli.py index 127d341de..13566ca17 100644 --- a/servo/cli.py +++ b/servo/cli.py @@ -1133,12 +1133,12 @@ async def gather_checks(): results = run_async(gather_checks()) ready = functools.reduce(lambda x, y: x and y, results) - except servo.ConnectorNotFoundError as e: + except* servo.ConnectorNotFoundError as e: typer.echo( "A connector named within the checks config was not found in the current Assembly" ) raise typer.Exit(1) from e - except servo.EventHandlersNotFoundError as e: + except* servo.EventHandlersNotFoundError as e: typer.echo( "At least one configured connector must respond to the Check event (Note the servo " "responds to checks so this error should never raise unless something is well and truly wrong" diff --git a/servo/connectors/kubernetes.py b/servo/connectors/kubernetes.py index 1406692d0..2b7ca236a 100644 --- a/servo/connectors/kubernetes.py +++ b/servo/connectors/kubernetes.py @@ -2308,7 +2308,7 @@ async def readiness_monitor() -> None: # Raise a specific exception if the optimization defines one try: await state.raise_for_status() - except servo.AdjustmentRejectedError as e: + except* servo.AdjustmentRejectedError as e: # Update rejections with start-failed to indicate the initial rollout was successful if e.reason == "start-failed": e.reason = "unstable" @@ -2333,6 +2333,11 @@ async def readiness_monitor() -> None: ) description = state.to_description() + except ExceptionGroup as eg: + if any(isinstance(se, servo.EventError) for se in eg.exceptions): + raise + else: + raise servo.AdjustmentFailedError(str(eg.message)) from eg except servo.EventError: # this is recognized by the runner raise except Exception as e: diff --git a/servo/events.py b/servo/events.py index 47fb3ba17..0f166cf6a 100644 --- a/servo/events.py +++ b/servo/events.py @@ -1016,7 +1016,7 @@ async def run(self) -> List[EventResult]: **self._kwargs, ) - except servo.errors.EventCancelledError as error: + except* servo.errors.EventCancelledError as error: # Return an empty result set servo.logger.warning( f'event cancelled by before event handler on connector "{connector.name}": {error}' diff --git a/servo/runner.py b/servo/runner.py index 6acd83f9f..dc619d82b 100644 --- a/servo/runner.py +++ b/servo/runner.py @@ -160,7 +160,7 @@ async def exec_command(self) -> servo.api.Status: descriptor=description.__opsani_repr__(), command_uid=cmd_response.command_uid, ) - except servo.errors.EventError as error: + except* servo.errors.EventError as error: self.logger.error(f"Describe failed: {error}") status = servo.api.Status.from_error( error=error, @@ -183,7 +183,7 @@ async def exec_command(self) -> servo.api.Status: command_uid=cmd_response.command_uid, **measurement.__opsani_repr__(), ) - except servo.errors.EventError as error: + except* servo.errors.EventError as error: self.logger.error(f"Measurement failed: {error}") status = servo.api.Status.from_error( error=error, @@ -215,7 +215,7 @@ async def exec_command(self) -> servo.api.Status: self.logger.success( f"Adjusted: {components_count} components, {settings_count} settings" ) - except servo.EventError as error: + except* servo.EventError as error: self.logger.error(f"Adjustment failed: {error}") status = servo.api.Status.from_error( error, @@ -557,6 +557,7 @@ async def handle_progress_exception( ) return + # TODO try to abort a TaskGroup here tasks = [ t for t in asyncio.all_tasks() if t is not asyncio.current_task() ] @@ -754,6 +755,7 @@ async def _shutdown(self, loop: asyncio.AbstractEventLoop, signal=None) -> None: # Cancel any outstanding tasks -- under a clean, graceful shutdown this list will be empty # The shutdown of the assembly and the servo should clean up its tasks + # TODO try killing a task group here instead tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] if len(tasks): [task.cancel() for task in tasks] From efb315b985c7c4a5ea5f92301780e5e6bf4bb9cb Mon Sep 17 00:00:00 2001 From: Linkous Sharp Date: Wed, 24 Apr 2024 11:55:41 -0500 Subject: [PATCH 3/6] Fix except-start cannot have return statements --- servo/events.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/servo/events.py b/servo/events.py index 0f166cf6a..4fa0f6554 100644 --- a/servo/events.py +++ b/servo/events.py @@ -1016,12 +1016,20 @@ async def run(self) -> List[EventResult]: **self._kwargs, ) - except* servo.errors.EventCancelledError as error: - # Return an empty result set - servo.logger.warning( - f'event cancelled by before event handler on connector "{connector.name}": {error}' - ) - return [] + except ExceptionGroup as eg: + if any( + ( + isinstance(se, servo.errors.EventCancelledError) + for se in eg.exceptions + ) + ): + # Return an empty result set + servo.logger.warning( + f'event cancelled by before event handler on connector "{connector.name}": {eg}' + ) + return [] + else: + raise # Invoke the on event handlers and gather results if self._prepositions & Preposition.on: From 8c4d7121bf6d07504b7202738876896b6b33722a Mon Sep 17 00:00:00 2001 From: Linkous Sharp Date: Wed, 24 Apr 2024 12:52:13 -0500 Subject: [PATCH 4/6] Fix test failures --- servo/events.py | 6 ++++++ tests/servo_test.py | 13 ++++++++----- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/servo/events.py b/servo/events.py index 4fa0f6554..3bd609238 100644 --- a/servo/events.py +++ b/servo/events.py @@ -1016,6 +1016,12 @@ async def run(self) -> List[EventResult]: **self._kwargs, ) + except servo.errors.EventCancelledError as error: + # Return an empty result set + servo.logger.warning( + f'event cancelled by before event handler on connector "{connector.name}": {error}' + ) + return [] except ExceptionGroup as eg: if any( ( diff --git a/tests/servo_test.py b/tests/servo_test.py index 582e9626e..89a5367bf 100644 --- a/tests/servo_test.py +++ b/tests/servo_test.py @@ -314,9 +314,9 @@ async def test_cannot_cancel_from_on_handlers(mocker, servo: Servo): mock = mocker.patch.object(event_handler, "handler") mock.side_effect = EventCancelledError() - with pytest.raises(TypeError) as error: + with pytest.raises(ExceptionGroup) as error: await servo.dispatch_event("promote") - assert str(error.value) == "Cannot cancel an event from an on handler" + assert str(error.value.exceptions[0]) == "Cannot cancel an event from an on handler" async def test_cannot_cancel_from_after_handlers_warning(mocker, servo: Servo): @@ -326,9 +326,11 @@ async def test_cannot_cancel_from_after_handlers_warning(mocker, servo: Servo): mock = mocker.patch.object(event_handler, "handler") mock.side_effect = EventCancelledError() - with pytest.raises(TypeError) as error: + with pytest.raises(ExceptionGroup) as error: await servo.dispatch_event("promote") - assert str(error.value) == "Cannot cancel an event from an after handler" + assert ( + str(error.value.exceptions[0]) == "Cannot cancel an event from an after handler" + ) async def test_after_handlers_are_not_called_on_failure_raises(mocker, servo: Servo): @@ -340,9 +342,10 @@ async def test_after_handlers_are_not_called_on_failure_raises(mocker, servo: Se on_handler = connector.get_event_handlers("promote", Preposition.on)[0] mock = mocker.patch.object(on_handler, "handler") mock.side_effect = EventError() - with pytest.raises(EventError): + with pytest.raises(ExceptionGroup) as error: await servo.dispatch_event("promote", return_exceptions=False) + assert isinstance(error.value.exceptions[0], EventError) spy.assert_not_called() From 3c14eed07fb8f9270cfcddf0f6e0335341feed12 Mon Sep 17 00:00:00 2001 From: Linkous Sharp Date: Wed, 24 Apr 2024 14:05:47 -0500 Subject: [PATCH 5/6] Further hardening and fixes of asyncio and events - convert before event handlers to run concurrently - fix task group broke (unused) return_exceptions=True case of dispatch_event - rename sub-exception vars for clarity - update relevant test --- servo/connectors/kubernetes.py | 2 +- servo/events.py | 80 +++++++++++++++++----------------- tests/servo_test.py | 4 +- 3 files changed, 44 insertions(+), 42 deletions(-) diff --git a/servo/connectors/kubernetes.py b/servo/connectors/kubernetes.py index 2b7ca236a..9e2eb185c 100644 --- a/servo/connectors/kubernetes.py +++ b/servo/connectors/kubernetes.py @@ -2334,7 +2334,7 @@ async def readiness_monitor() -> None: description = state.to_description() except ExceptionGroup as eg: - if any(isinstance(se, servo.EventError) for se in eg.exceptions): + if any(isinstance(sub_e, servo.EventError) for sub_e in eg.exceptions): raise else: raise servo.AdjustmentFailedError(str(eg.message)) from eg diff --git a/servo/events.py b/servo/events.py index 3bd609238..6dde9b39e 100644 --- a/servo/events.py +++ b/servo/events.py @@ -550,6 +550,7 @@ def decorator(fn: EventCallable) -> EventCallable: if preposition == Preposition.before: # 'before' event takes same args as 'on' event, but returns None ref_signature = ref_signature.replace(return_annotation="None") + servo.utilities.inspect.assert_equal_callable_descriptors( servo.utilities.inspect.CallableDescriptor( signature=ref_signature, @@ -891,6 +892,7 @@ async def run_event_handlers( value=error, ) + # TODO refactor to use ExceptionGroups with events retrievable from exceptions if return_exceptions: results.append(error.__event_result__) else: @@ -1006,36 +1008,34 @@ async def run(self) -> List[EventResult]: # Invoke the before event handlers if self._prepositions & Preposition.before: - for connector in self._connectors: - try: - results = await connector.run_event_handlers( - self.event, - Preposition.before, - *self._args, - return_exceptions=False, - **self._kwargs, - ) + try: + async with asyncio.TaskGroup() as tg: + for connector in self._connectors: + tg.create_task( + connector.run_event_handlers( + self.event, + Preposition.before, + *self._args, + return_exceptions=False, + **self._kwargs, + ) + ) - except servo.errors.EventCancelledError as error: + except ExceptionGroup as eg: + cancelled_errs = [ + sub_e + for sub_e in eg.exceptions + if isinstance(sub_e, servo.errors.EventCancelledError) + ] + if cancelled_errs: # Return an empty result set + canceller_names = (ce.connector.name for ce in cancelled_errs) servo.logger.warning( - f'event cancelled by before event handler on connector "{connector.name}": {error}' + f'event cancelled by before event handler on connector "{", ".join(canceller_names)}": {eg.exceptions}' ) return [] - except ExceptionGroup as eg: - if any( - ( - isinstance(se, servo.errors.EventCancelledError) - for se in eg.exceptions - ) - ): - # Return an empty result set - servo.logger.warning( - f'event cancelled by before event handler on connector "{connector.name}": {eg}' - ) - return [] - else: - raise + else: + raise # Invoke the on event handlers and gather results if self._prepositions & Preposition.on: @@ -1052,21 +1052,23 @@ async def run(self) -> List[EventResult]: if results: break else: - async with asyncio.TaskGroup() as tg: - ev_tasks = [ - tg.create_task( - c.run_event_handlers( - self.event, - Preposition.on, - return_exceptions=self._return_exceptions, - *self._args, - **self._kwargs, - ) - ) - for c in self._connectors - ] + tasks = ( + c.run_event_handlers( + self.event, + Preposition.on, + return_exceptions=self._return_exceptions, + *self._args, + **self._kwargs, + ) + for c in self._connectors + ) + if self._return_exceptions: + results = await asyncio.gather(*tasks) + else: + async with asyncio.TaskGroup() as tg: + tg_tasks = [tg.create_task(t) for t in tasks] + results = (tt.result() for tt in tg_tasks) - results = (et.result() for et in ev_tasks) results = list(filter(lambda r: r is not None, results)) results = functools.reduce(lambda x, y: x + y, results, []) diff --git a/tests/servo_test.py b/tests/servo_test.py index 89a5367bf..64d305923 100644 --- a/tests/servo_test.py +++ b/tests/servo_test.py @@ -272,7 +272,7 @@ async def test_cancellation_of_event_from_before_handler(mocker, servo: Servo): # Mock the before handler to throw a cancel exception mock = mocker.patch.object(before_handler, "handler") - mock.side_effect = EventCancelledError("it burns when I pee") + mock.side_effect = EventCancelledError("it burns when I pee", connector=connector) results = await servo.dispatch_event("promote") # Check that on and after callbacks were never called @@ -284,7 +284,7 @@ async def test_cancellation_of_event_from_before_handler(mocker, servo: Servo): assert messages[0].record["level"].name == "WARNING" assert ( messages[0].record["message"] - == 'event cancelled by before event handler on connector "first_test_servo": it burns when I pee' + == "event cancelled by before event handler on connector \"first_test_servo\": (EventCancelledError('it burns when I pee'),)" ) From 04b98117bfbae25be04d8a5fb586ba36b9645693 Mon Sep 17 00:00:00 2001 From: Linkous Sharp Date: Wed, 24 Apr 2024 19:56:56 -0500 Subject: [PATCH 6/6] Add missing _ var for clarity --- servo/events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/servo/events.py b/servo/events.py index 6dde9b39e..357505fcd 100644 --- a/servo/events.py +++ b/servo/events.py @@ -1011,7 +1011,7 @@ async def run(self) -> List[EventResult]: try: async with asyncio.TaskGroup() as tg: for connector in self._connectors: - tg.create_task( + _ = tg.create_task( connector.run_event_handlers( self.event, Preposition.before,