From fdb59c514a79f78c1891b5503899daf820c1050c Mon Sep 17 00:00:00 2001 From: crusaderky Date: Mon, 28 Mar 2022 17:44:18 +0100 Subject: [PATCH] execute() to return StateMachineEvent --- .../tests/test_worker_state_machine.py | 7 + distributed/worker.py | 223 +++++++++++------- distributed/worker_state_machine.py | 44 +++- 3 files changed, 194 insertions(+), 80 deletions(-) diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index 8d2eb8662a9..b24bfbb136e 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -5,6 +5,7 @@ Execute, ReleaseWorkerDataMsg, SendMessageToScheduler, + StateMachineEvent, TaskState, UniqueTaskHeap, ) @@ -99,3 +100,9 @@ def test_sendmsg_to_dict(): # Arbitrary sample class smsg = ReleaseWorkerDataMsg(key="x") assert smsg.to_dict() == {"op": "release-worker-data", "key": "x"} + + +@pytest.mark.parametrize("cls", StateMachineEvent.__subclasses__()) +def test_event_slots(cls): + smsg = cls(**dict.fromkeys(cls.__annotations__), stimulus_id="test") + assert not hasattr(smsg, "__dict__") diff --git a/distributed/worker.py b/distributed/worker.py index 0c5d48c0bc1..3869249ad3f 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -4,6 +4,7 @@ import bisect import builtins import errno +import functools import heapq import logging import os @@ -15,7 +16,6 @@ import weakref from collections import defaultdict, deque from collections.abc import ( - Awaitable, Callable, Collection, Container, @@ -109,13 +109,18 @@ PROCESSING, READY, AddKeysMsg, + CancelComputeEvent, Execute, + ExecuteFailureEvent, + ExecuteSuccessEvent, InvalidTransition, LongRunningMsg, ReleaseWorkerDataMsg, + RescheduleEvent, RescheduleMsg, SendMessageToScheduler, SerializedTask, + StateMachineEvent, TaskErredMsg, TaskFinishedMsg, TaskState, @@ -135,7 +140,9 @@ # {TaskState -> finish: TaskStateState | (finish: TaskStateState, transition *args)} Recs: TypeAlias = "dict[TaskState, TaskStateState | tuple]" Instructions: TypeAlias = "list[Instruction]" - +else: + Recs = dict + Instructions = list logger = logging.getLogger(__name__) @@ -1775,14 +1782,7 @@ def handle_cancel_compute(self, key: str, stimulus_id: str) -> None: is in state `waiting` or `ready`. Nothing will happen otherwise. """ - ts = self.tasks.get(key) - if ts and ts.state in READY | {"waiting"}: - self.log.append((key, "cancel-compute", stimulus_id, time())) - # All possible dependents of TS should not be in state Processing on - # scheduler side and therefore should not be assigned to a worker, - # yet. - assert not ts.dependents - self.transition(ts, "released", stimulus_id=stimulus_id) + self.handle_stimulus(CancelComputeEvent(key=key, stimulus_id=stimulus_id)) def handle_acquire_replicas( self, @@ -2597,6 +2597,24 @@ def transitions(self, recommendations: Recs, *, stimulus_id: str) -> None: else: self._handle_instructions(instructions) + def handle_stimulus(self, stim: StateMachineEvent) -> None: + with log_errors(): + # self.stimulus_history.append(stim) # TODO + recs, instructions = self.handle_event(stim) + self.transitions(recs, stimulus_id=stim.stimulus_id) + self._handle_instructions(instructions) + self.ensure_computing() + self.ensure_communicating() + + def _handle_stimulus_from_future( + self, future: asyncio.Future[StateMachineEvent | None] + ) -> None: + with log_errors(): + # This *should* never raise + stim = future.result() + if stim: + self.handle_stimulus(stim) + def _handle_instructions(self, instructions: list[Instruction]) -> None: # TODO this method is temporary. # See final design: https://github.com/dask/distributed/issues/5894 @@ -2604,24 +2622,13 @@ def _handle_instructions(self, instructions: list[Instruction]) -> None: if isinstance(inst, SendMessageToScheduler): self.batched_stream.send(inst.to_dict()) elif isinstance(inst, Execute): - self.loop.add_callback( - self._async_instruction_callback, - self.execute(inst.key, stimulus_id=inst.stimulus_id), - stimulus_id=inst.stimulus_id, - ) + coro = self.execute(inst.key, stimulus_id=inst.stimulus_id) + task = asyncio.create_task(coro) + # TODO track task (at the moment it's fire-and-forget) + task.add_done_callback(self._handle_stimulus_from_future) else: raise TypeError(inst) # pragma: nocover - async def _async_instruction_callback( - self, coro: Awaitable[tuple[Recs, Instructions]], *, stimulus_id: str - ) -> None: - with log_errors(): - recs, instructions = await coro - self.transitions(recs, stimulus_id=stimulus_id) - self._handle_instructions(instructions) - self.ensure_computing() - self.ensure_communicating() - def maybe_transition_long_running( self, ts: TaskState, *, compute_duration: float, stimulus_id: str ): @@ -3410,23 +3417,16 @@ def ensure_computing(self) -> None: pdb.set_trace() raise - async def execute(self, key: str, *, stimulus_id: str) -> tuple[Recs, Instructions]: + async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | None: if self.status in {Status.closing, Status.closed, Status.closing_gracefully}: - return {}, [] - if key not in self.tasks: - return {}, [] - ts = self.tasks[key] + return None + ts = self.tasks.get(key) + if not ts: + return None + if ts.state == "cancelled": + return CancelComputeEvent(key=ts.key, stimulus_id=stimulus_id) try: - if ts.state == "cancelled": - # This might happen if keys are canceled - logger.debug( - "Trying to execute task %s which is not in executing state anymore", - ts, - ) - ts.done = True - return {ts: "released"}, [] - if self.validate: assert not ts.waiting_for_data assert ts.state == "executing" @@ -3482,59 +3482,124 @@ async def execute(self, key: str, *, stimulus_id: str) -> tuple[Recs, Instructio finally: self.active_keys.discard(ts.key) - key = ts.key - # key *must* be still in tasks. Releasing it directly is forbidden - # without going through cancelled - ts = self.tasks.get(key) # type: ignore - assert ts, self.story(key) - ts.done = True - result["key"] = ts.key - value = result.pop("result", None) - ts.startstops.append( - {"action": "compute", "start": result["start"], "stop": result["stop"]} - ) self.threads[ts.key] = result["thread"] - recommendations: Recs = {} + if result["op"] == "task-finished": - ts.nbytes = result["nbytes"] - ts.type = result["type"] - recommendations[ts] = ("memory", value) if self.digests is not None: self.digests["task-duration"].add(result["stop"] - result["start"]) - elif isinstance(result.pop("actual-exception"), Reschedule): - recommendations[ts] = "rescheduled" - else: - logger.warning( - "Compute Failed\n" - "Key: %s\n" - "Function: %s\n" - "args: %s\n" - "kwargs: %s\n" - "Exception: %r\n", - ts.key, - str(funcname(function))[:1000], - convert_args_to_str(args2, max_len=1000), - convert_kwargs_to_str(kwargs2, max_len=1000), - result["exception_text"], - ) - recommendations[ts] = ( - "error", - result["exception"], - result["traceback"], - result["exception_text"], - result["traceback_text"], + return ExecuteSuccessEvent( + key=ts.key, + value=result["result"], + start=result["start"], + stop=result["stop"], + nbytes=result["nbytes"], + type=result["type"], + stimulus_id=stimulus_id, ) - logger.debug("Send compute response to scheduler: %s, %s", ts.key, result) + if isinstance(result["actual-exception"], Reschedule): + return RescheduleEvent(key=ts.key, stimulus_id=stimulus_id) + + logger.warning( + "Compute Failed\n" + "Key: %s\n" + "Function: %s\n" + "args: %s\n" + "kwargs: %s\n" + "Exception: %r\n", + ts.key, + str(funcname(function))[:1000], + convert_args_to_str(args2, max_len=1000), + convert_kwargs_to_str(kwargs2, max_len=1000), + result["exception_text"], + ) + return ExecuteFailureEvent( + key=ts.key, + start=result["start"], + stop=result["stop"], + exception=result["exception"], + traceback=result["traceback"], + exception_text=result["exception_text"], + traceback_text=result["traceback_text"], + stimulus_id=stimulus_id, + ) except Exception as exc: logger.error( "Exception during execution of task %s.", ts.key, exc_info=True ) msg = error_message(exc) - recommendations = {ts: tuple(msg.values())} + return ExecuteFailureEvent( + key=ts.key, + start=None, + stop=None, + exception=msg["exception"], + traceback=msg["traceback"], + exception_text=msg["exception_text"], + traceback_text=msg["traceback_text"], + stimulus_id=stimulus_id, + ) - return recommendations, [] + @functools.singledispatchmethod + def handle_event(self, ev: StateMachineEvent) -> tuple[Recs, Instructions]: + raise TypeError(ev) # pragma: nocover + + @handle_event.register + def _(self, ev: CancelComputeEvent) -> tuple[Recs, Instructions]: + ts = self.tasks.get(ev.key) + if not ts or ts.state not in READY | {"waiting"}: + return {}, [] + + self.log.append((ev.key, "cancel-compute", ev.stimulus_id, time())) + # All possible dependents of ts should not be in state Processing on + # scheduler side and therefore should not be assigned to a worker, yet. + assert not ts.dependents + ts.done = True + return {ts: "released"}, [] + + @handle_event.register + def _(self, ev: ExecuteSuccessEvent) -> tuple[Recs, Instructions]: + # key *must* be still in tasks. Releasing it directly is forbidden + # without going through cancelled + ts = self.tasks.get(ev.key) # type: ignore + assert ts, self.story(ev.key) + + ts.done = True + ts.startstops.append({"action": "compute", "start": ev.start, "stop": ev.stop}) + ts.nbytes = ev.nbytes + ts.type = ev.type + return {ts: ("memory", ev.value)}, [] + + @handle_event.register + def _(self, ev: ExecuteFailureEvent) -> tuple[Recs, Instructions]: + # key *must* be still in tasks. Releasing it directly is forbidden + # without going through cancelled + ts = self.tasks.get(ev.key) # type: ignore + assert ts, self.story(ev.key) + + ts.done = True + if ev.start is not None and ev.stop is not None: + ts.startstops.append( + {"action": "compute", "start": ev.start, "stop": ev.stop} + ) + + return { + ts: ( + "error", + ev.exception, + ev.traceback, + ev.exception_text, + ev.traceback_text, + ) + }, [] + + @handle_event.register + def _(self, ev: RescheduleEvent) -> tuple[Recs, Instructions]: + # key *must* be still in tasks. Releasing it directly is forbidden + # without going through cancelled + ts = self.tasks.get(ev.key) # type: ignore + assert ts, self.story(ev.key) + return {ts: "rescheduled"}, [] def _prepare_args_for_execution( self, ts: TaskState, args: tuple, kwargs: dict[str, Any] diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index 310116a8fbd..9ae99cf8cc5 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -321,11 +321,11 @@ class ReleaseWorkerDataMsg(SendMessageToScheduler): key: str +# Not to be confused with RescheduleEvent below or the distributed.Reschedule Exception @dataclass class RescheduleMsg(SendMessageToScheduler): op = "reschedule" - # Not to be confused with the distributed.Reschedule Exception __slots__ = ("key", "worker") key: str worker: str @@ -347,3 +347,45 @@ class AddKeysMsg(SendMessageToScheduler): __slots__ = ("keys", "stimulus_id") keys: list[str] stimulus_id: str + + +@dataclass +class StateMachineEvent: + __slots__ = ("stimulus_id",) + stimulus_id: str + + +@dataclass +class ExecuteSuccessEvent(StateMachineEvent): + key: str + value: object + start: float + stop: float + nbytes: int + type: type | None + __slots__ = tuple(__annotations__) # type: ignore + + +@dataclass +class ExecuteFailureEvent(StateMachineEvent): + key: str + start: float | None + stop: float | None + exception: bytes # serialized + traceback: bytes # serialized + exception_text: str + traceback_text: str + __slots__ = tuple(__annotations__) # type: ignore + + +@dataclass +class CancelComputeEvent(StateMachineEvent): + __slots__ = ("key",) + key: str + + +# Not to be confused with RescheduleMsg above or the distributed.Reschedule Exception +@dataclass +class RescheduleEvent(StateMachineEvent): + __slots__ = ("key",) + key: str