diff --git a/distributed/core.py b/distributed/core.py index 419205ef4a0..56d5cf2667f 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -596,7 +596,7 @@ async def handle_comm(self, comm): "Failed while closing connection to %r: %s", address, e ) - async def handle_stream(self, comm, extra=None, every_cycle=[]): + async def handle_stream(self, comm, extra=None, every_cycle=()): extra = extra or {} logger.info("Starting established connection") diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index 2c97b217aab..4bc66f46288 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -1,9 +1,13 @@ +from itertools import chain + import pytest from distributed.utils import recursive_to_dict from distributed.worker_state_machine import ( + Instruction, ReleaseWorkerDataMsg, SendMessageToScheduler, + StateMachineEvent, TaskState, UniqueTaskHeap, ) @@ -82,13 +86,32 @@ def test_unique_task_heap(): assert repr(heap) == "" -@pytest.mark.parametrize("cls", SendMessageToScheduler.__subclasses__()) -def test_sendmsg_slots(cls): - smsg = cls(**dict.fromkeys(cls.__annotations__)) - assert not hasattr(smsg, "__dict__") +@pytest.mark.parametrize( + "cls", + chain( + [UniqueTaskHeap], + Instruction.__subclasses__(), + SendMessageToScheduler.__subclasses__(), + StateMachineEvent.__subclasses__(), + ), +) +def test_slots(cls): + params = [ + k + for k in dir(cls) + if not k.startswith("_") and k != "op" and not callable(getattr(cls, k)) + ] + inst = cls(**dict.fromkeys(params)) + assert not hasattr(inst, "__dict__") def test_sendmsg_to_dict(): # Arbitrary sample class smsg = ReleaseWorkerDataMsg(key="x") assert smsg.to_dict() == {"op": "release-worker-data", "key": "x"} + + +@pytest.mark.parametrize("cls", StateMachineEvent.__subclasses__()) +def test_event_slots(cls): + smsg = cls(**dict.fromkeys(cls.__annotations__), stimulus_id="test") + assert not hasattr(smsg, "__dict__") diff --git a/distributed/worker.py b/distributed/worker.py index c607bb5afa1..06124887bf4 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -4,6 +4,7 @@ import bisect import builtins import errno +import functools import heapq import logging import os @@ -103,36 +104,38 @@ DeprecatedMemoryMonitor, WorkerMemoryManager, ) -from distributed.worker_state_machine import Instruction # noqa: F401 from distributed.worker_state_machine import ( PROCESSING, READY, AddKeysMsg, + AlreadyCancelledEvent, + CancelComputeEvent, + Execute, + ExecuteFailureEvent, + ExecuteSuccessEvent, + Instructions, InvalidTransition, LongRunningMsg, + Recs, + RecsInstrs, ReleaseWorkerDataMsg, + RescheduleEvent, RescheduleMsg, SendMessageToScheduler, SerializedTask, + StateMachineEvent, TaskErredMsg, TaskFinishedMsg, TaskState, + TaskStateState, UniqueTaskHeap, ) if TYPE_CHECKING: - # TODO move to typing (requires Python >=3.10) - from typing_extensions import TypeAlias - from distributed.actor import Actor from distributed.client import Client from distributed.diagnostics.plugin import WorkerPlugin from distributed.nanny import Nanny - from distributed.worker_state_machine import TaskStateState - - # {TaskState -> finish: TaskStateState | (finish: TaskStateState, transition *args)} - Recs: TypeAlias = "dict[TaskState, TaskStateState | tuple]" - Instructions: TypeAlias = "list[Instruction]" logger = logging.getLogger(__name__) @@ -1786,14 +1789,7 @@ def handle_cancel_compute(self, key: str, stimulus_id: str) -> None: is in state `waiting` or `ready`. Nothing will happen otherwise. """ - ts = self.tasks.get(key) - if ts and ts.state in READY | {"waiting"}: - self.log.append((key, "cancel-compute", stimulus_id, time())) - # All possible dependents of TS should not be in state Processing on - # scheduler side and therefore should not be assigned to a worker, - # yet. - assert not ts.dependents - self.transition(ts, "released", stimulus_id=stimulus_id) + self.handle_stimulus(CancelComputeEvent(key=key, stimulus_id=stimulus_id)) def handle_acquire_replicas( self, @@ -1920,7 +1916,7 @@ def handle_compute_task( def transition_missing_fetch( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Instructions]: + ) -> RecsInstrs: if self.validate: assert ts.state == "missing" assert ts.priority is not None @@ -1933,7 +1929,7 @@ def transition_missing_fetch( def transition_missing_released( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Instructions]: + ) -> RecsInstrs: self._missing_dep_flight.discard(ts) recs, instructions = self.transition_generic_released( ts, stimulus_id=stimulus_id @@ -1943,7 +1939,7 @@ def transition_missing_released( def transition_flight_missing( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Instructions]: + ) -> RecsInstrs: assert ts.done ts.state = "missing" self._missing_dep_flight.add(ts) @@ -1952,7 +1948,7 @@ def transition_flight_missing( def transition_released_fetch( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Instructions]: + ) -> RecsInstrs: if self.validate: assert ts.state == "released" assert ts.priority is not None @@ -1965,7 +1961,7 @@ def transition_released_fetch( def transition_generic_released( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Instructions]: + ) -> RecsInstrs: self.release_key(ts.key, stimulus_id=stimulus_id) recs: Recs = {} for dependency in ts.dependencies: @@ -1982,7 +1978,7 @@ def transition_generic_released( def transition_released_waiting( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Instructions]: + ) -> RecsInstrs: if self.validate: assert ts.state == "released" assert all(d.key in self.tasks for d in ts.dependencies) @@ -2008,7 +2004,7 @@ def transition_released_waiting( def transition_fetch_flight( self, ts: TaskState, worker, *, stimulus_id: str - ) -> tuple[Recs, Instructions]: + ) -> RecsInstrs: if self.validate: assert ts.state == "fetch" assert ts.who_has @@ -2021,7 +2017,7 @@ def transition_fetch_flight( def transition_memory_released( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Instructions]: + ) -> RecsInstrs: recs, instructions = self.transition_generic_released( ts, stimulus_id=stimulus_id ) @@ -2030,7 +2026,7 @@ def transition_memory_released( def transition_waiting_constrained( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Instructions]: + ) -> RecsInstrs: if self.validate: assert ts.state == "waiting" assert not ts.waiting_for_data @@ -2046,14 +2042,14 @@ def transition_waiting_constrained( def transition_long_running_rescheduled( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Instructions]: + ) -> RecsInstrs: recs: Recs = {ts: "released"} smsg = RescheduleMsg(key=ts.key, worker=self.address) return recs, [smsg] def transition_executing_rescheduled( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Instructions]: + ) -> RecsInstrs: for resource, quantity in ts.resource_restrictions.items(): self.available_resources[resource] += quantity self._executing.discard(ts) @@ -2064,7 +2060,7 @@ def transition_executing_rescheduled( def transition_waiting_ready( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Instructions]: + ) -> RecsInstrs: if self.validate: assert ts.state == "waiting" assert ts.key not in self.ready @@ -2088,7 +2084,7 @@ def transition_cancelled_error( traceback_text, *, stimulus_id: str, - ) -> tuple[Recs, Instructions]: + ) -> RecsInstrs: recs: Recs = {} instructions: Instructions = [] if ts._previous == "executing": @@ -2122,7 +2118,7 @@ def transition_generic_error( traceback_text: str, *, stimulus_id: str, - ) -> tuple[Recs, Instructions]: + ) -> RecsInstrs: ts.exception = exception ts.traceback = traceback ts.exception_text = exception_text @@ -2149,7 +2145,7 @@ def transition_executing_error( traceback_text, *, stimulus_id: str, - ) -> tuple[Recs, Instructions]: + ) -> RecsInstrs: for resource, quantity in ts.resource_restrictions.items(): self.available_resources[resource] += quantity self._executing.discard(ts) @@ -2164,7 +2160,7 @@ def transition_executing_error( def _transition_from_resumed( self, ts: TaskState, finish: TaskStateState, *, stimulus_id: str - ) -> tuple[Recs, Instructions]: + ) -> RecsInstrs: """`resumed` is an intermediate degenerate state which splits further up into two states depending on what the last signal / next state is intended to be. There are only two viable choices depending on whether @@ -2202,7 +2198,7 @@ def _transition_from_resumed( def transition_resumed_fetch( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Instructions]: + ) -> RecsInstrs: """ See Worker._transition_from_resumed """ @@ -2210,7 +2206,7 @@ def transition_resumed_fetch( def transition_resumed_missing( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Instructions]: + ) -> RecsInstrs: """ See Worker._transition_from_resumed """ @@ -2224,7 +2220,7 @@ def transition_resumed_waiting(self, ts: TaskState, *, stimulus_id: str): def transition_cancelled_fetch( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Instructions]: + ) -> RecsInstrs: if ts.done: return {ts: "released"}, [] elif ts._previous == "flight": @@ -2236,14 +2232,14 @@ def transition_cancelled_fetch( def transition_cancelled_resumed( self, ts: TaskState, next: TaskStateState, *, stimulus_id: str - ) -> tuple[Recs, Instructions]: + ) -> RecsInstrs: ts._next = next ts.state = "resumed" return {}, [] def transition_cancelled_waiting( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Instructions]: + ) -> RecsInstrs: if ts.done: return {ts: "released"}, [] elif ts._previous == "executing": @@ -2255,7 +2251,7 @@ def transition_cancelled_waiting( def transition_cancelled_forgotten( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Instructions]: + ) -> RecsInstrs: ts._next = "forgotten" if not ts.done: return {}, [] @@ -2263,7 +2259,7 @@ def transition_cancelled_forgotten( def transition_cancelled_released( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Instructions]: + ) -> RecsInstrs: if not ts.done: ts._next = "released" return {}, [] @@ -2283,7 +2279,7 @@ def transition_cancelled_released( def transition_executing_released( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Instructions]: + ) -> RecsInstrs: ts._previous = ts.state ts._next = "released" # See https://github.com/dask/distributed/pull/5046#discussion_r685093940 @@ -2293,13 +2289,13 @@ def transition_executing_released( def transition_long_running_memory( self, ts: TaskState, value=no_value, *, stimulus_id: str - ) -> tuple[Recs, Instructions]: + ) -> RecsInstrs: self.executed_count += 1 return self.transition_generic_memory(ts, value=value, stimulus_id=stimulus_id) def transition_generic_memory( self, ts: TaskState, value=no_value, *, stimulus_id: str - ) -> tuple[Recs, Instructions]: + ) -> RecsInstrs: if value is no_value and ts.key not in self.data: raise RuntimeError( f"Tried to transition task {ts} to `memory` without data available" @@ -2325,7 +2321,7 @@ def transition_generic_memory( def transition_executing_memory( self, ts: TaskState, value=no_value, *, stimulus_id: str - ) -> tuple[Recs, Instructions]: + ) -> RecsInstrs: if self.validate: assert ts.state == "executing" or ts.key in self.long_running assert not ts.waiting_for_data @@ -2337,7 +2333,7 @@ def transition_executing_memory( def transition_constrained_executing( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Instructions]: + ) -> RecsInstrs: if self.validate: assert not ts.waiting_for_data assert ts.key not in self.data @@ -2350,12 +2346,12 @@ def transition_constrained_executing( self.available_resources[resource] -= quantity ts.state = "executing" self._executing.add(ts) - self.loop.add_callback(self.execute, ts.key, stimulus_id=stimulus_id) - return {}, [] + instr = Execute(key=ts.key, stimulus_id=stimulus_id) + return {}, [instr] def transition_ready_executing( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Instructions]: + ) -> RecsInstrs: if self.validate: assert not ts.waiting_for_data assert ts.key not in self.data @@ -2368,12 +2364,10 @@ def transition_ready_executing( ts.state = "executing" self._executing.add(ts) - self.loop.add_callback(self.execute, ts.key, stimulus_id=stimulus_id) - return {}, [] + instr = Execute(key=ts.key, stimulus_id=stimulus_id) + return {}, [instr] - def transition_flight_fetch( - self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Instructions]: + def transition_flight_fetch(self, ts: TaskState, *, stimulus_id: str) -> RecsInstrs: # If this transition is called after the flight coroutine has finished, # we can reset the task and transition to fetch again. If it is not yet # finished, this should be a no-op @@ -2401,7 +2395,7 @@ def transition_flight_error( traceback_text, *, stimulus_id: str, - ) -> tuple[Recs, Instructions]: + ) -> RecsInstrs: self._in_flight_tasks.discard(ts) ts.coming_from = None return self.transition_generic_error( @@ -2415,7 +2409,7 @@ def transition_flight_error( def transition_flight_released( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Instructions]: + ) -> RecsInstrs: if ts.done: # FIXME: Is this even possible? Would an assert instead be more # sensible? @@ -2429,13 +2423,13 @@ def transition_flight_released( def transition_cancelled_memory( self, ts: TaskState, value, *, stimulus_id: str - ) -> tuple[Recs, Instructions]: + ) -> RecsInstrs: assert ts._next return {ts: ts._next}, [] def transition_executing_long_running( self, ts: TaskState, compute_duration: float, *, stimulus_id: str - ) -> tuple[Recs, Instructions]: + ) -> RecsInstrs: ts.state = "long-running" self._executing.discard(ts) self.long_running.add(ts.key) @@ -2445,7 +2439,7 @@ def transition_executing_long_running( def transition_released_memory( self, ts: TaskState, value, *, stimulus_id: str - ) -> tuple[Recs, Instructions]: + ) -> RecsInstrs: try: recs = self._put_key_in_memory(ts, value, stimulus_id=stimulus_id) except Exception as e: @@ -2457,7 +2451,7 @@ def transition_released_memory( def transition_flight_memory( self, ts: TaskState, value, *, stimulus_id: str - ) -> tuple[Recs, Instructions]: + ) -> RecsInstrs: self._in_flight_tasks.discard(ts) ts.coming_from = None try: @@ -2471,7 +2465,7 @@ def transition_flight_memory( def transition_released_forgotten( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Instructions]: + ) -> RecsInstrs: recommendations: Recs = {} # Dependents _should_ be released by the scheduler before this if self.validate: @@ -2488,7 +2482,7 @@ def transition_released_forgotten( def _transition( self, ts: TaskState, finish: str | tuple, *args, stimulus_id: str, **kwargs - ) -> tuple[Recs, Instructions]: + ) -> RecsInstrs: if isinstance(finish, tuple): # the concatenated transition path might need to access the tuple assert not args @@ -2608,12 +2602,35 @@ def transitions(self, recommendations: Recs, *, stimulus_id: str) -> None: else: self._handle_instructions(instructions) - def _handle_instructions(self, instructions: list[Instruction]) -> None: + def handle_stimulus(self, stim: StateMachineEvent) -> None: + with log_errors(): + # self.stimulus_history.append(stim) # TODO + recs, instructions = self.handle_event(stim) + self.transitions(recs, stimulus_id=stim.stimulus_id) + self._handle_instructions(instructions) + self.ensure_computing() + self.ensure_communicating() + + def _handle_stimulus_from_future( + self, future: asyncio.Future[StateMachineEvent | None] + ) -> None: + with log_errors(): + # This *should* never raise + stim = future.result() + if stim: + self.handle_stimulus(stim) + + def _handle_instructions(self, instructions: Instructions) -> None: # TODO this method is temporary. # See final design: https://github.com/dask/distributed/issues/5894 for inst in instructions: if isinstance(inst, SendMessageToScheduler): self.batched_stream.send(inst.to_dict()) + elif isinstance(inst, Execute): + coro = self.execute(inst.key, stimulus_id=inst.stimulus_id) + task = asyncio.create_task(coro) + # TODO track task (at the moment it's fire-and-forget) + task.add_done_callback(self._handle_stimulus_from_future) else: raise TypeError(inst) # pragma: nocover @@ -3405,24 +3422,20 @@ def ensure_computing(self) -> None: pdb.set_trace() raise - async def execute(self, key: str, *, stimulus_id: str) -> None: + async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | None: if self.status in {Status.closing, Status.closed, Status.closing_gracefully}: - return - if key not in self.tasks: - return - ts = self.tasks[key] + return None + ts = self.tasks.get(key) + if not ts: + return None + if ts.state == "cancelled": + logger.debug( + "Trying to execute task %s which is not in executing state anymore", + ts, + ) + return AlreadyCancelledEvent(key=ts.key, stimulus_id=stimulus_id) try: - if ts.state == "cancelled": - # This might happen if keys are canceled - logger.debug( - "Trying to execute task %s which is not in executing state anymore", - ts, - ) - ts.done = True - self.transition(ts, "released", stimulus_id=stimulus_id) - return - if self.validate: assert not ts.waiting_for_data assert ts.state == "executing" @@ -3446,9 +3459,7 @@ async def execute(self, key: str, *, stimulus_id: str) -> None: f"expected one of: {sorted(self.executors)}" ) - self.active_keys.add(ts.key) - - result: dict + self.active_keys.add(key) try: ts.start_time = time() if iscoroutinefunction(function): @@ -3466,7 +3477,7 @@ async def execute(self, key: str, *, stimulus_id: str) -> None: args2, kwargs2, self.execution_state, - ts.key, + key, self.active_threads, self.active_threads_lock, self.scheduler_delay, @@ -3481,75 +3492,137 @@ async def execute(self, key: str, *, stimulus_id: str) -> None: self.scheduler_delay, ) finally: - self.active_keys.discard(ts.key) - - key = ts.key - # key *must* be still in tasks. Releasing it directly is forbidden - # without going through cancelled - ts = self.tasks.get(key) # type: ignore - assert ts, self.story(key) - ts.done = True - result["key"] = ts.key - value = result.pop("result", None) - ts.startstops.append( - {"action": "compute", "start": result["start"], "stop": result["stop"]} - ) - self.threads[ts.key] = result["thread"] - recommendations: Recs = {} + self.active_keys.discard(key) + + self.threads[key] = result["thread"] + if result["op"] == "task-finished": - ts.nbytes = result["nbytes"] - ts.type = result["type"] - recommendations[ts] = ("memory", value) if self.digests is not None: self.digests["task-duration"].add(result["stop"] - result["start"]) - elif isinstance(result.pop("actual-exception"), Reschedule): - recommendations[ts] = "rescheduled" - else: - logger.warning( - "Compute Failed\n" - "Key: %s\n" - "Function: %s\n" - "args: %s\n" - "kwargs: %s\n" - "Exception: %r\n", - ts.key, - str(funcname(function))[:1000], - convert_args_to_str(args2, max_len=1000), - convert_kwargs_to_str(kwargs2, max_len=1000), - result["exception_text"], + return ExecuteSuccessEvent( + key=key, + value=result["result"], + start=result["start"], + stop=result["stop"], + nbytes=result["nbytes"], + type=result["type"], + stimulus_id=stimulus_id, ) - recommendations[ts] = ( - "error", - result["exception"], - result["traceback"], - result["exception_text"], - result["traceback_text"], - ) - - self.transitions(recommendations, stimulus_id=stimulus_id) - logger.debug("Send compute response to scheduler: %s, %s", ts.key, result) + if isinstance(result["actual-exception"], Reschedule): + return RescheduleEvent(key=ts.key, stimulus_id=stimulus_id) - if self.validate: - assert ts.state != "executing" - assert not ts.waiting_for_data + logger.warning( + "Compute Failed\n" + "Key: %s\n" + "Function: %s\n" + "args: %s\n" + "kwargs: %s\n" + "Exception: %r\n", + key, + str(funcname(function))[:1000], + convert_args_to_str(args2, max_len=1000), + convert_kwargs_to_str(kwargs2, max_len=1000), + result["exception_text"], + ) + return ExecuteFailureEvent( + key=key, + start=result["start"], + stop=result["stop"], + exception=result["exception"], + traceback=result["traceback"], + exception_text=result["exception_text"], + traceback_text=result["traceback_text"], + stimulus_id=stimulus_id, + ) except Exception as exc: - assert ts - logger.error( - "Exception during execution of task %s.", ts.key, exc_info=True + logger.error("Exception during execution of task %s.", key, exc_info=True) + msg = error_message(exc) + return ExecuteFailureEvent( + key=key, + start=None, + stop=None, + exception=msg["exception"], + traceback=msg["traceback"], + exception_text=msg["exception_text"], + traceback_text=msg["traceback_text"], + stimulus_id=stimulus_id, ) - emsg = error_message(exc) - emsg.pop("status") - self.transition( - ts, + + @functools.singledispatchmethod + def handle_event(self, ev: StateMachineEvent) -> RecsInstrs: + raise TypeError(ev) # pragma: nocover + + @handle_event.register + def _(self, ev: CancelComputeEvent) -> RecsInstrs: + """Scheduler requested to cancel a task""" + ts = self.tasks.get(ev.key) + if not ts or ts.state not in READY | {"waiting"}: + return {}, [] + + self.log.append((ev.key, "cancel-compute", ev.stimulus_id, time())) + # All possible dependents of ts should not be in state Processing on + # scheduler side and therefore should not be assigned to a worker, yet. + assert not ts.dependents + return {ts: "released"}, [] + + @handle_event.register + def _(self, ev: AlreadyCancelledEvent) -> RecsInstrs: + """Task is already cancelled by the time execute() runs""" + # key *must* be still in tasks. Releasing it directly is forbidden + # without going through cancelled + ts = self.tasks.get(ev.key) # type: ignore + assert ts, self.story(ev.key) + ts.done = True + return {ts: "released"}, [] + + @handle_event.register + def _(self, ev: ExecuteSuccessEvent) -> RecsInstrs: + """Task completed successfully""" + # key *must* be still in tasks. Releasing it directly is forbidden + # without going through cancelled + ts = self.tasks.get(ev.key) # type: ignore + assert ts, self.story(ev.key) + + ts.done = True + ts.startstops.append({"action": "compute", "start": ev.start, "stop": ev.stop}) + ts.nbytes = ev.nbytes + ts.type = ev.type + return {ts: ("memory", ev.value)}, [] + + @handle_event.register + def _(self, ev: ExecuteFailureEvent) -> RecsInstrs: + """Task execution failed""" + # key *must* be still in tasks. Releasing it directly is forbidden + # without going through cancelled + ts = self.tasks.get(ev.key) # type: ignore + assert ts, self.story(ev.key) + + ts.done = True + if ev.start is not None and ev.stop is not None: + ts.startstops.append( + {"action": "compute", "start": ev.start, "stop": ev.stop} + ) + + return { + ts: ( "error", - **emsg, - stimulus_id=stimulus_id, + ev.exception, + ev.traceback, + ev.exception_text, + ev.traceback_text, ) - finally: - self.ensure_computing() - self.ensure_communicating() + }, [] + + @handle_event.register + def _(self, ev: RescheduleEvent) -> RecsInstrs: + """Task raised Reschedule exception while it was running""" + # key *must* be still in tasks. Releasing it directly is forbidden + # without going through cancelled + ts = self.tasks.get(ev.key) # type: ignore + assert ts, self.story(ev.key) + return {ts: "rescheduled"}, [] def _prepare_args_for_execution( self, ts: TaskState, args: tuple, kwargs: dict[str, Any] diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index ce8d200ba7a..2c9db8cc4c4 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -14,7 +14,7 @@ from distributed.utils import recursive_to_dict if TYPE_CHECKING: - # TODO move to typing (requires Python >=3.10) + # TODO move to typing and get out of TYPE_CHECKING (requires Python >=3.10) from typing_extensions import TypeAlias TaskStateState: TypeAlias = Literal[ @@ -34,7 +34,8 @@ "resumed", "waiting", ] - +else: + TaskStateState = str # TaskState.state subsets PROCESSING: set[TaskStateState] = { @@ -256,11 +257,11 @@ class Instruction: # __slots__ = () -# @dataclass -# class Execute(Instruction): -# __slots__ = ("key", "stimulus_id") -# key: str -# stimulus_id: str +@dataclass +class Execute(Instruction): + __slots__ = ("key", "stimulus_id") + key: str + stimulus_id: str class SendMessageToScheduler(Instruction): @@ -321,11 +322,11 @@ class ReleaseWorkerDataMsg(SendMessageToScheduler): key: str +# Not to be confused with RescheduleEvent below or the distributed.Reschedule Exception @dataclass class RescheduleMsg(SendMessageToScheduler): op = "reschedule" - # Not to be confused with the distributed.Reschedule Exception __slots__ = ("key", "worker") key: str worker: str @@ -347,3 +348,64 @@ class AddKeysMsg(SendMessageToScheduler): __slots__ = ("keys", "stimulus_id") keys: list[str] stimulus_id: str + + +@dataclass +class StateMachineEvent: + __slots__ = ("stimulus_id",) + stimulus_id: str + + +@dataclass +class ExecuteSuccessEvent(StateMachineEvent): + key: str + value: object + start: float + stop: float + nbytes: int + type: type | None + __slots__ = tuple(__annotations__) # type: ignore + + +@dataclass +class ExecuteFailureEvent(StateMachineEvent): + key: str + start: float | None + stop: float | None + exception: bytes # serialized + traceback: bytes # serialized + exception_text: str + traceback_text: str + __slots__ = tuple(__annotations__) # type: ignore + + +@dataclass +class CancelComputeEvent(StateMachineEvent): + __slots__ = ("key",) + key: str + + +@dataclass +class AlreadyCancelledEvent(StateMachineEvent): + __slots__ = ("key",) + key: str + + +# Not to be confused with RescheduleMsg above or the distributed.Reschedule Exception +@dataclass +class RescheduleEvent(StateMachineEvent): + __slots__ = ("key",) + key: str + + +if TYPE_CHECKING: + # TODO remove quotes (requires Python >=3.9) + # TODO get out of TYPE_CHECKING (requires Python >=3.10) + # {TaskState -> finish: TaskStateState | (finish: TaskStateState, transition *args)} + Recs: TypeAlias = "dict[TaskState, TaskStateState | tuple]" + Instructions: TypeAlias = "list[Instruction]" + RecsInstrs: TypeAlias = "tuple[Recs, Instructions]" +else: + Recs = dict + Instructions = list + RecsInstrs = tuple