From ff4987cc7d8f6beeaab2161eceffd391a7d1c7d0 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 10 Mar 2022 00:43:37 +0000 Subject: [PATCH 1/9] Refactor worker scheduler messages and TaskState --- distributed/active_memory_manager.py | 2 +- distributed/batched.py | 2 +- distributed/scheduler.py | 14 +- distributed/tests/test_worker.py | 70 +-- .../tests/test_worker_state_machine.py | 90 +++ distributed/worker.py | 564 ++++++------------ distributed/worker_state_machine.py | 332 +++++++++++ docs/source/worker.rst | 9 +- 8 files changed, 612 insertions(+), 471 deletions(-) create mode 100644 distributed/tests/test_worker_state_machine.py create mode 100644 distributed/worker_state_machine.py diff --git a/distributed/active_memory_manager.py b/distributed/active_memory_manager.py index 4a616095908..e31f4802fb3 100644 --- a/distributed/active_memory_manager.py +++ b/distributed/active_memory_manager.py @@ -14,7 +14,7 @@ from .metrics import time from .utils import import_term, log_errors -if TYPE_CHECKING: # pragma: nocover +if TYPE_CHECKING: from .client import Client from .scheduler import Scheduler, TaskState, WorkerState diff --git a/distributed/batched.py b/distributed/batched.py index 3e6cbcfd30b..0b1fc1da0f5 100644 --- a/distributed/batched.py +++ b/distributed/batched.py @@ -128,7 +128,7 @@ def _background_send(self): self.stopped.set() self.abort() - def send(self, *msgs): + def send(self, *msgs: dict) -> None: """Schedule a message for sending to the other side This completes quickly and synchronously diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 59cd77419b7..248242a2955 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1209,6 +1209,8 @@ def _to_dict_no_nest(self, *, exclude: "Container[str]" = ()) -> dict: class TaskState: """ A simple object holding information about a task. + Not to be confused with :class:`distributed.worker_state_machine.TaskState`, which + holds similar information on the Worker side. .. 
attribute:: key: str @@ -5505,7 +5507,9 @@ def handle_task_finished(self, key=None, worker=None, **msg): client_msgs: dict worker_msgs: dict - r: tuple = self.stimulus_task_finished(key=key, worker=worker, **msg) + r: tuple = self.stimulus_task_finished( + key=key, worker=worker, status="OK", **msg + ) recommendations, client_msgs, worker_msgs = r parent._transitions(recommendations, client_msgs, worker_msgs) @@ -5516,7 +5520,7 @@ def handle_task_erred(self, key=None, **msg): recommendations: dict client_msgs: dict worker_msgs: dict - r: tuple = self.stimulus_task_erred(key=key, **msg) + r: tuple = self.stimulus_task_erred(key=key, status="error", **msg) recommendations, client_msgs, worker_msgs = r parent._transitions(recommendations, client_msgs, worker_msgs) @@ -7025,7 +7029,9 @@ async def _track_retire_worker( logger.info("Retired worker %s", ws._address) return ws._address, ws.identity() - def add_keys(self, worker=None, keys=(), stimulus_id=None): + def add_keys( + self, worker: str, keys: "Iterable[str]" = (), stimulus_id: "str | None" = None + ) -> str: """ Learn that a worker has certain keys @@ -7038,7 +7044,7 @@ def add_keys(self, worker=None, keys=(), stimulus_id=None): ws: WorkerState = parent._workers_dv[worker] redundant_replicas = [] for key in keys: - ts: TaskState = parent._tasks.get(key) + ts: TaskState = parent._tasks.get(key) # type: ignore if ts is not None and ts._state == "memory": if ws not in ts._who_has: parent.add_replica(ts, ws) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index c75d8abd7ea..3441652b656 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -59,14 +59,7 @@ slowinc, slowsum, ) -from distributed.worker import ( - TaskState, - UniqueTaskHeap, - Worker, - error_message, - logger, - parse_memory_limit, -) +from distributed.worker import Worker, error_message, logger, parse_memory_limit pytestmark = pytest.mark.ci1 @@ -3746,67 +3739,6 @@ async def test_Worker__to_dict(c, s, a): assert d["tasks"]["x"]["key"] == "x" -@gen_cluster(client=True, nthreads=[("", 1)]) -async def test_TaskState__to_dict(c, s, a): - """tasks that are listed as dependencies of other tasks are dumped as a short repr - and always appear in full under Worker.tasks - """ - x = c.submit(inc, 1, key="x") - y = c.submit(inc, x, key="y") - z = c.submit(inc, 2, key="z") - await wait([x, y, z]) - - tasks = a._to_dict()["tasks"] - - assert isinstance(tasks["x"], dict) - assert isinstance(tasks["y"], dict) - assert isinstance(tasks["z"], dict) - assert tasks["x"]["dependents"] == [""] - assert tasks["y"]["dependencies"] == [""] - - -def test_unique_task_heap(): - heap = UniqueTaskHeap() - - for x in range(10): - ts = TaskState(f"f{x}") - ts.priority = (0, 0, 1, x % 3) - heap.push(ts) - - heap_list = list(heap) - # iteration does not empty heap - assert len(heap) == 10 - assert heap_list == sorted(heap_list, key=lambda ts: ts.priority) - - seen = set() - last_prio = (0, 0, 0, 0) - while heap: - peeked = heap.peek() - ts = heap.pop() - assert peeked == ts - seen.add(ts.key) - assert ts.priority - assert last_prio <= ts.priority - last_prio = last_prio - - ts = TaskState("foo") - heap.push(ts) - heap.push(ts) - assert len(heap) == 1 - - assert repr(heap) == "" - - assert heap.pop() == ts - assert not heap - - # Test that we're cleaning the seen set on pop - heap.push(ts) - assert len(heap) == 1 - assert heap.pop() == ts - - assert repr(heap) == "" - - @gen_cluster(nthreads=[]) async def 
test_do_not_block_event_loop_during_shutdown(s): loop = asyncio.get_running_loop() diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py new file mode 100644 index 00000000000..153f7c5f108 --- /dev/null +++ b/distributed/tests/test_worker_state_machine.py @@ -0,0 +1,90 @@ +from distributed.utils import recursive_to_dict +from distributed.worker_state_machine import ( + ReleaseWorkerDataMsg, + TaskState, + UniqueTaskHeap, +) + + +def test_TaskState_get_nbytes(): + assert TaskState("x", nbytes=123).get_nbytes() == 123 + # Default to distributed.scheduler.default-data-size + assert TaskState("y").get_nbytes() == 1024 + + +def test_TaskState__to_dict(): + """Tasks that are listed as dependencies or dependents of other tasks are dumped as + a short repr and always appear in full directly under Worker.tasks. Uninteresting + fields are omitted. + """ + x = TaskState("x", state="memory", done=True) + y = TaskState("y", priority=(0,), dependencies={x}) + x.dependents.add(y) + actual = recursive_to_dict([x, y]) + assert actual == [ + { + "key": "x", + "state": "memory", + "done": True, + "dependents": [""], + }, + { + "key": "y", + "state": "released", + "dependencies": [""], + "priority": [0], + }, + ] + + +def test_unique_task_heap(): + heap = UniqueTaskHeap() + + for x in range(10): + ts = TaskState(f"f{x}", priority=(0,)) + ts.priority = (0, 0, 1, x % 3) + heap.push(ts) + + heap_list = list(heap) + # iteration does not empty heap + assert len(heap) == 10 + assert heap_list == sorted(heap_list, key=lambda ts: ts.priority) + + seen = set() + last_prio = (0, 0, 0, 0) + while heap: + peeked = heap.peek() + ts = heap.pop() + assert peeked == ts + seen.add(ts.key) + assert ts.priority + assert last_prio <= ts.priority + last_prio = last_prio + + ts = TaskState("foo", priority=(0,)) + heap.push(ts) + heap.push(ts) + assert len(heap) == 1 + + assert repr(heap) == "" + + assert heap.pop() == ts + assert not heap + + # Test that we're cleaning the seen set on pop + heap.push(ts) + assert len(heap) == 1 + assert heap.pop() == ts + + assert repr(heap) == "" + + +def test_sendmsg_slots(): + # Sample test on one of the subclasses + smsg = ReleaseWorkerDataMsg(key="x") + assert not hasattr(smsg, "__dict__") + + +def test_sendmsg_to_dict(): + smsg = ReleaseWorkerDataMsg(key="x") + assert smsg.to_dict() == {"op": "release-worker-data", "key": "x"} diff --git a/distributed/worker.py b/distributed/worker.py index 589ab3aee54..50f840d7864 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -18,7 +18,6 @@ Collection, Container, Iterable, - Iterator, Mapping, MutableMapping, ) @@ -27,7 +26,7 @@ from datetime import timedelta from inspect import isawaitable from pickle import PicklingError -from typing import TYPE_CHECKING, Any, ClassVar, Literal, NamedTuple, TypedDict, cast +from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast from tlz import first, keymap, merge, pluck # noqa: F401 from tornado.ioloop import IOLoop, PeriodicCallback @@ -95,18 +94,36 @@ from .utils_comm import gather_from_workers, pack_data, retry_operation from .utils_perf import ThrottledGC, disable_gc_diagnosis, enable_gc_diagnosis from .versions import get_versions +from .worker_state_machine import Instruction # noqa: F401 +from .worker_state_machine import ( + PROCESSING, + READY, + AddKeysMsg, + InvalidTransition, + LongRunningMsg, + ReleaseWorkerDataMsg, + RescheduleMsg, + SendMessageToScheduler, + SerializedTask, + TaskErredMsg, + TaskFinishedMsg, + TaskState, + 
UniqueTaskHeap, +) if TYPE_CHECKING: + # TODO move to typing (requires Python >=3.10) from typing_extensions import TypeAlias from .actor import Actor from .client import Client from .diagnostics.plugin import WorkerPlugin from .nanny import Nanny + from .worker_state_machine import TaskStateState - # {TaskState -> finish: str | (finish: str, *args)} - Recs: TypeAlias = "dict[TaskState, str | tuple]" - Smsgs: TypeAlias = "list[dict[str, Any]]" + # {TaskState -> finish: TaskStateState | (finish: TaskStateState, transition *args)} + Recs: TypeAlias = "dict[TaskState, TaskStateState | tuple]" + Instructions: TypeAlias = "list[Instruction]" logger = logging.getLogger(__name__) @@ -115,241 +132,12 @@ no_value = "--no-value-sentinel--" -# TaskState.state subsets -PROCESSING = { - "waiting", - "ready", - "constrained", - "executing", - "long-running", - "cancelled", - "resumed", -} -READY = {"ready", "constrained"} - DEFAULT_EXTENSIONS: list[type] = [PubSubWorkerExtension, ShuffleWorkerExtension] DEFAULT_METRICS: dict[str, Callable[[Worker], Any]] = {} DEFAULT_STARTUP_INFORMATION: dict[str, Callable[[Worker], Any]] = {} -DEFAULT_DATA_SIZE = parse_bytes( - dask.config.get("distributed.scheduler.default-data-size") -) - - -class SerializedTask(NamedTuple): - function: Callable - args: tuple - kwargs: dict[str, Any] - task: object # distributed.scheduler.TaskState.run_spec - - -class StartStop(TypedDict, total=False): - action: str - start: float - stop: float - source: str # optional - - -class InvalidTransition(Exception): - pass - - -class TaskState: - """Holds volatile state relating to an individual Dask task - - - * **dependencies**: ``set(TaskState instances)`` - The data needed by this key to run - * **dependents**: ``set(TaskState instances)`` - The keys that use this dependency. - * **duration**: ``float`` - Expected duration the a task - * **priority**: ``tuple`` - The priority this task given by the scheduler. Determines run order. - * **state**: ``str`` - The current state of the task. One of ["waiting", "ready", "executing", - "fetch", "memory", "flight", "long-running", "rescheduled", "error"] - * **who_has**: ``set(worker)`` - Workers that we believe have this data - * **coming_from**: ``str`` - The worker that current task data is coming from if task is in flight - * **waiting_for_data**: ``set(keys of dependencies)`` - A dynamic version of dependencies. All dependencies that we still don't - have for a particular key. - * **resource_restrictions**: ``{str: number}`` - Abstract resources required to run a task - * **exception**: ``str`` - The exception caused by running a task if it erred - * **traceback**: ``str`` - The exception caused by running a task if it erred - * **type**: ``type`` - The type of a particular piece of data - * **suspicious_count**: ``int`` - The number of times a dependency has not been where we expected it - * **startstops**: ``[{startstop}]`` - Log of transfer, load, and compute times for a task - * **start_time**: ``float`` - Time at which task begins running - * **stop_time**: ``float`` - Time at which task finishes running - * **metadata**: ``dict`` - Metadata related to task. Stored metadata should be msgpack - serializable (e.g. int, string, list, dict). - * **nbytes**: ``int`` - The size of a particular piece of data - * **annotations**: ``dict`` - Task annotations - - Parameters - ---------- - key: str - run_spec: SerializedTask - A named tuple containing the ``function``, ``args``, ``kwargs`` and - ``task`` associated with this `TaskState` instance. 
This defaults to - ``None`` and can remain empty if it is a dependency that this worker - will receive from another worker. - - """ - - key: str - run_spec: SerializedTask | None - dependencies: set[TaskState] - dependents: set[TaskState] - duration: float | None - priority: tuple[int, ...] | None - state: str - who_has: set[str] - coming_from: str | None - waiting_for_data: set[TaskState] - waiters: set[TaskState] - resource_restrictions: dict[str, float] - exception: Exception | None - exception_text: str | None - traceback: object | None - traceback_text: str | None - type: type | None - suspicious_count: int - startstops: list[StartStop] - start_time: float | None - stop_time: float | None - metadata: dict - nbytes: float | None - annotations: dict | None - done: bool - _previous: str | None - _next: str | None - - def __init__(self, key: str, run_spec: SerializedTask | None = None): - assert key is not None - self.key = key - self.run_spec = run_spec - self.dependencies = set() - self.dependents = set() - self.duration = None - self.priority = None - self.state = "released" - self.who_has = set() - self.coming_from = None - self.waiting_for_data = set() - self.waiters = set() - self.resource_restrictions = {} - self.exception = None - self.exception_text = "" - self.traceback = None - self.traceback_text = "" - self.type = None - self.suspicious_count = 0 - self.startstops = [] - self.start_time = None - self.stop_time = None - self.metadata = {} - self.nbytes = None - self.annotations = None - self.done = False - self._previous = None - self._next = None - - def __repr__(self) -> str: - return f"" - - def get_nbytes(self) -> int: - nbytes = self.nbytes - return nbytes if nbytes is not None else DEFAULT_DATA_SIZE - - def _to_dict_no_nest(self, *, exclude: Container[str] = ()) -> dict: - """Dictionary representation for debugging purposes. - Not type stable and not intended for roundtrips. - - See also - -------- - Client.dump_cluster_state - distributed.utils.recursive_to_dict - - Notes - ----- - This class uses ``_to_dict_no_nest`` instead of ``_to_dict``. - When a task references another task, just print the task repr. All tasks - should neatly appear under Worker.tasks. This also prevents a RecursionError - during particularly heavy loads, which have been observed to happen whenever - there's an acyclic dependency chain of ~200+ tasks. - """ - return recursive_to_dict(self, exclude=exclude, members=True) - - def is_protected(self) -> bool: - return self.state in PROCESSING or any( - dep_ts.state in PROCESSING for dep_ts in self.dependents - ) - - -class UniqueTaskHeap(Collection): - """A heap of TaskState objects ordered by TaskState.priority - Ties are broken by string comparison of the key. Keys are guaranteed to be - unique. Iterating over this object returns the elements in priority order. - """ - - def __init__(self, collection: Collection[TaskState] = ()): - self._known = {ts.key for ts in collection} - self._heap = [(ts.priority, ts.key, ts) for ts in collection] - heapq.heapify(self._heap) - - def push(self, ts: TaskState) -> None: - """Add a new TaskState instance to the heap. If the key is already - known, no object is added. - - Note: This does not update the priority / heap order in case priority - changes. 
- """ - assert isinstance(ts, TaskState) - if ts.key not in self._known: - heapq.heappush(self._heap, (ts.priority, ts.key, ts)) - self._known.add(ts.key) - - def pop(self) -> TaskState: - """Pop the task with highest priority from the heap.""" - _, key, ts = heapq.heappop(self._heap) - self._known.remove(key) - return ts - - def peek(self) -> TaskState: - """Get the highest priority TaskState without removing it from the heap""" - return self._heap[0][2] - - def __contains__(self, x: object) -> bool: - if isinstance(x, TaskState): - x = x.key - return x in self._known - - def __iter__(self) -> Iterator[TaskState]: - return (ts for _, _, ts in sorted(self._heap)) - - def __len__(self) -> int: - return len(self._known) - - def __repr__(self) -> str: - return f"<{type(self).__name__}: {len(self)} items>" - class Worker(ServerNode): """Worker node in a Dask distributed cluster @@ -1917,7 +1705,7 @@ def update_data( if stimulus_id is None: stimulus_id = f"update-data-{time()}" recommendations: Recs = {} - scheduler_messages = [] + instructions: Instructions = [] for key, value in data.items(): try: ts = self.tasks[key] @@ -1936,13 +1724,10 @@ def update_data( self.log.append((key, "receive-from-scatter", stimulus_id, time())) if report: - scheduler_messages.append( - {"op": "add-keys", "keys": list(data), "stimulus_id": stimulus_id} - ) + instructions.append(AddKeysMsg(keys=list(data), stimulus_id=stimulus_id)) self.transitions(recommendations, stimulus_id=stimulus_id) - for msg in scheduler_messages: - self.batched_stream.send(msg) + self._handle_instructions(instructions) return {"nbytes": {k: sizeof(v) for k, v in data.items()}, "status": "OK"} def handle_free_keys(self, keys: list[str], stimulus_id: str) -> None: @@ -2002,9 +1787,8 @@ def handle_remove_replicas(self, keys: list[str], stimulus_id: str) -> str: if rejected: self.log.append(("remove-replica-rejected", rejected, stimulus_id, time())) - self.batched_stream.send( - {"op": "add-keys", "keys": rejected, "stimulus_id": stimulus_id} - ) + smsg = AddKeysMsg(keys=rejected, stimulus_id=stimulus_id) + self._handle_instructions([smsg]) self.transitions(recommendations, stimulus_id=stimulus_id) @@ -2129,7 +1913,7 @@ def handle_compute_task( ts.annotations = annotations recommendations: Recs = {} - scheduler_msgs: Smsgs = [] + instructions: Instructions = [] for dependency in who_has: dep_ts = self.ensure_task_exists( key=dependency, @@ -2145,7 +1929,7 @@ def handle_compute_task( pass elif ts.state == "memory": recommendations[ts] = "memory" - scheduler_msgs.append(self._get_task_finished_msg(ts)) + instructions.append(self._get_task_finished_msg(ts)) elif ts.state in { "released", "fetch", @@ -2158,9 +1942,7 @@ def handle_compute_task( else: # pragma: no cover raise RuntimeError(f"Unexpected task state encountered {ts} {stimulus_id}") - for msg in scheduler_msgs: - self.batched_stream.send(msg) - + self._handle_instructions(instructions) self.update_who_has(who_has) self.transitions(recommendations, stimulus_id=stimulus_id) @@ -2170,7 +1952,7 @@ def handle_compute_task( def transition_missing_fetch( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: if self.validate: assert ts.state == "missing" assert ts.priority is not None @@ -2183,15 +1965,17 @@ def transition_missing_fetch( def transition_missing_released( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: self._missing_dep_flight.discard(ts) - recs, smsgs = 
self.transition_generic_released(ts, stimulus_id=stimulus_id) + recs, instructions = self.transition_generic_released( + ts, stimulus_id=stimulus_id + ) assert ts.key in self.tasks - return recs, smsgs + return recs, instructions def transition_flight_missing( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: assert ts.done ts.state = "missing" self._missing_dep_flight.add(ts) @@ -2200,7 +1984,7 @@ def transition_flight_missing( def transition_released_fetch( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: if self.validate: assert ts.state == "released" assert ts.priority is not None @@ -2213,7 +1997,7 @@ def transition_released_fetch( def transition_generic_released( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: self.release_key(ts.key, stimulus_id=stimulus_id) recs: Recs = {} for dependency in ts.dependencies: @@ -2230,7 +2014,7 @@ def transition_generic_released( def transition_released_waiting( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: if self.validate: assert ts.state == "released" assert all(d.key in self.tasks for d in ts.dependencies) @@ -2238,7 +2022,7 @@ def transition_released_waiting( recommendations: Recs = {} ts.waiting_for_data.clear() for dep_ts in ts.dependencies: - if not dep_ts.state == "memory": + if dep_ts.state != "memory": ts.waiting_for_data.add(dep_ts) dep_ts.waiters.add(ts) if dep_ts.state not in {"fetch", "flight"}: @@ -2256,7 +2040,7 @@ def transition_released_waiting( def transition_fetch_flight( self, ts: TaskState, worker, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: if self.validate: assert ts.state == "fetch" assert ts.who_has @@ -2269,14 +2053,16 @@ def transition_fetch_flight( def transition_memory_released( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: - recs, smsgs = self.transition_generic_released(ts, stimulus_id=stimulus_id) - smsgs.append({"op": "release-worker-data", "key": ts.key}) - return recs, smsgs + ) -> tuple[Recs, Instructions]: + recs, instructions = self.transition_generic_released( + ts, stimulus_id=stimulus_id + ) + instructions.append(ReleaseWorkerDataMsg(ts.key)) + return recs, instructions def transition_waiting_constrained( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: if self.validate: assert ts.state == "waiting" assert not ts.waiting_for_data @@ -2292,25 +2078,25 @@ def transition_waiting_constrained( def transition_long_running_rescheduled( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: recs: Recs = {ts: "released"} - smsgs = [{"op": "reschedule", "key": ts.key, "worker": self.address}] - return recs, smsgs + smsg = RescheduleMsg(key=ts.key, worker=self.address) + return recs, [smsg] def transition_executing_rescheduled( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: for resource, quantity in ts.resource_restrictions.items(): self.available_resources[resource] += quantity self._executing.discard(ts) recs: Recs = {ts: "released"} - smsgs: Smsgs = [{"op": "reschedule", "key": ts.key, "worker": self.address}] - return recs, smsgs + smsg = RescheduleMsg(key=ts.key, worker=self.address) + return recs, [smsg] def transition_waiting_ready( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, 
Smsgs]: + ) -> tuple[Recs, Instructions]: if self.validate: assert ts.state == "waiting" assert ts.key not in self.ready @@ -2334,11 +2120,11 @@ def transition_cancelled_error( traceback_text, *, stimulus_id: str, - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: recs: Recs = {} - smsgs: Smsgs = [] + instructions: Instructions = [] if ts._previous == "executing": - recs, smsgs = self.transition_executing_error( + recs, instructions = self.transition_executing_error( ts, exception, traceback, @@ -2347,7 +2133,7 @@ def transition_cancelled_error( stimulus_id=stimulus_id, ) elif ts._previous == "flight": - recs, smsgs = self.transition_flight_error( + recs, instructions = self.transition_flight_error( ts, exception, traceback, @@ -2357,36 +2143,32 @@ def transition_cancelled_error( ) if ts._next: recs[ts] = ts._next - return recs, smsgs + return recs, instructions def transition_generic_error( self, ts: TaskState, - exception, - traceback, - exception_text, - traceback_text, + exception: Exception, + traceback: object, + exception_text: str, + traceback_text: str, *, stimulus_id: str, - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: ts.exception = exception ts.traceback = traceback ts.exception_text = exception_text ts.traceback_text = traceback_text ts.state = "error" - smsg = { - "op": "task-erred", - "status": "error", - "key": ts.key, - "thread": self.threads.get(ts.key), - "exception": ts.exception, - "traceback": ts.traceback, - "exception_text": ts.exception_text, - "traceback_text": ts.traceback_text, - } - - if ts.startstops: - smsg["startstops"] = ts.startstops + smsg = TaskErredMsg( + key=ts.key, + thread=self.threads.get(ts.key), + exception=ts.exception, + traceback=ts.traceback, + exception_text=ts.exception_text, + traceback_text=ts.traceback_text, + startstops=ts.startstops, + ) return {}, [smsg] @@ -2399,7 +2181,7 @@ def transition_executing_error( traceback_text, *, stimulus_id: str, - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: for resource, quantity in ts.resource_restrictions.items(): self.available_resources[resource] += quantity self._executing.discard(ts) @@ -2413,8 +2195,8 @@ def transition_executing_error( ) def _transition_from_resumed( - self, ts: TaskState, finish: str, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + self, ts: TaskState, finish: TaskStateState, *, stimulus_id: str + ) -> tuple[Recs, Instructions]: """`resumed` is an intermediate degenerate state which splits further up into two states depending on what the last signal / next state is intended to be. 
There are only two viable choices depending on whether @@ -2435,24 +2217,24 @@ def _transition_from_resumed( See also `transition_resumed_waiting` """ recs: Recs = {} - smsgs: Smsgs = [] + instructions: Instructions = [] if ts.done: next_state = ts._next # if the next state is already intended to be waiting or if the # coro/thread is still running (ts.done==False), this is a noop if ts._next != finish: - recs, smsgs = self.transition_generic_released( + recs, instructions = self.transition_generic_released( ts, stimulus_id=stimulus_id ) assert next_state recs[ts] = next_state else: ts._next = finish - return recs, smsgs + return recs, instructions def transition_resumed_fetch( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: """ See Worker._transition_from_resumed """ @@ -2460,7 +2242,7 @@ def transition_resumed_fetch( def transition_resumed_missing( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: """ See Worker._transition_from_resumed """ @@ -2474,7 +2256,7 @@ def transition_resumed_waiting(self, ts: TaskState, *, stimulus_id: str): def transition_cancelled_fetch( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: if ts.done: return {ts: "released"}, [] elif ts._previous == "flight": @@ -2485,15 +2267,15 @@ def transition_cancelled_fetch( return {ts: ("resumed", "fetch")}, [] def transition_cancelled_resumed( - self, ts: TaskState, next: str, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + self, ts: TaskState, next: TaskStateState, *, stimulus_id: str + ) -> tuple[Recs, Instructions]: ts._next = next ts.state = "resumed" return {}, [] def transition_cancelled_waiting( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: if ts.done: return {ts: "released"}, [] elif ts._previous == "executing": @@ -2505,7 +2287,7 @@ def transition_cancelled_waiting( def transition_cancelled_forgotten( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: ts._next = "forgotten" if not ts.done: return {}, [] @@ -2513,7 +2295,7 @@ def transition_cancelled_forgotten( def transition_cancelled_released( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: if not ts.done: ts._next = "released" return {}, [] @@ -2524,14 +2306,16 @@ def transition_cancelled_released( for resource, quantity in ts.resource_restrictions.items(): self.available_resources[resource] += quantity - recs, smsgs = self.transition_generic_released(ts, stimulus_id=stimulus_id) + recs, instructions = self.transition_generic_released( + ts, stimulus_id=stimulus_id + ) if next_state != "released": recs[ts] = next_state - return recs, smsgs + return recs, instructions def transition_executing_released( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: ts._previous = ts.state ts._next = "released" # See https://github.com/dask/distributed/pull/5046#discussion_r685093940 @@ -2541,13 +2325,13 @@ def transition_executing_released( def transition_long_running_memory( self, ts: TaskState, value=no_value, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: self.executed_count += 1 return self.transition_generic_memory(ts, value=value, stimulus_id=stimulus_id) def transition_generic_memory( self, ts: TaskState, value=no_value, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) 
-> tuple[Recs, Instructions]: if value is no_value and ts.key not in self.data: raise RuntimeError( f"Tried to transition task {ts} to `memory` without data available" @@ -2566,13 +2350,14 @@ def transition_generic_memory( msg = error_message(e) recs = {ts: tuple(msg.values())} return recs, [] - assert ts.key in self.data or ts.key in self.actors - smsgs = [self._get_task_finished_msg(ts)] - return recs, smsgs + if self.validate: + assert ts.key in self.data or ts.key in self.actors + smsg = self._get_task_finished_msg(ts) + return recs, [smsg] def transition_executing_memory( self, ts: TaskState, value=no_value, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: if self.validate: assert ts.state == "executing" or ts.key in self.long_running assert not ts.waiting_for_data @@ -2584,7 +2369,7 @@ def transition_executing_memory( def transition_constrained_executing( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: if self.validate: assert not ts.waiting_for_data assert ts.key not in self.data @@ -2602,7 +2387,7 @@ def transition_constrained_executing( def transition_ready_executing( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: if self.validate: assert not ts.waiting_for_data assert ts.key not in self.data @@ -2620,7 +2405,7 @@ def transition_ready_executing( def transition_flight_fetch( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: # If this transition is called after the flight coroutine has finished, # we can reset the task and transition to fetch again. If it is not yet # finished, this should be a no-op @@ -2648,7 +2433,7 @@ def transition_flight_error( traceback_text, *, stimulus_id: str, - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: self._in_flight_tasks.discard(ts) ts.coming_from = None return self.transition_generic_error( @@ -2662,7 +2447,7 @@ def transition_flight_error( def transition_flight_released( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: if ts.done: # FIXME: Is this even possible? Would an assert instead be more # sensible? 
@@ -2676,70 +2461,49 @@ def transition_flight_released( def transition_cancelled_memory( self, ts: TaskState, value, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: assert ts._next return {ts: ts._next}, [] def transition_executing_long_running( - self, ts: TaskState, compute_duration, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + self, ts: TaskState, compute_duration: float, *, stimulus_id: str + ) -> tuple[Recs, Instructions]: ts.state = "long-running" self._executing.discard(ts) self.long_running.add(ts.key) - smsgs = [ - { - "op": "long-running", - "key": ts.key, - "compute_duration": compute_duration, - } - ] - + smsg = LongRunningMsg(key=ts.key, compute_duration=compute_duration) self.io_loop.add_callback(self.ensure_computing) - return {}, smsgs + return {}, [smsg] def transition_released_memory( self, ts: TaskState, value, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: - recs: Recs = {} + ) -> tuple[Recs, Instructions]: try: recs = self._put_key_in_memory(ts, value, stimulus_id=stimulus_id) except Exception as e: msg = error_message(e) - recs[ts] = ( - "error", - msg["exception"], - msg["traceback"], - msg["exception_text"], - msg["traceback_text"], - ) + recs = {ts: tuple(msg.values())} return recs, [] - smsgs = [{"op": "add-keys", "keys": [ts.key], "stimulus_id": stimulus_id}] - return recs, smsgs + smsg = AddKeysMsg(keys=[ts.key], stimulus_id=stimulus_id) + return recs, [smsg] def transition_flight_memory( self, ts: TaskState, value, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: self._in_flight_tasks.discard(ts) ts.coming_from = None - recs: Recs = {} try: recs = self._put_key_in_memory(ts, value, stimulus_id=stimulus_id) except Exception as e: msg = error_message(e) - recs[ts] = ( - "error", - msg["exception"], - msg["traceback"], - msg["exception_text"], - msg["traceback_text"], - ) + recs = {ts: tuple(msg.values())} return recs, [] - smsgs = [{"op": "add-keys", "keys": [ts.key], "stimulus_id": stimulus_id}] - return recs, smsgs + smsg = AddKeysMsg(keys=[ts.key], stimulus_id=stimulus_id) + return recs, [smsg] def transition_released_forgotten( self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: recommendations: Recs = {} # Dependents _should_ be released by the scheduler before this if self.validate: @@ -2756,7 +2520,7 @@ def transition_released_forgotten( def _transition( self, ts: TaskState, finish: str | tuple, *args, stimulus_id: str, **kwargs - ) -> tuple[Recs, Smsgs]: + ) -> tuple[Recs, Instructions]: if isinstance(finish, tuple): # the concatenated transition path might need to access the tuple assert not args @@ -2770,13 +2534,15 @@ def _transition( if func is not None: self._transition_counter += 1 - recs, smsgs = func(ts, *args, stimulus_id=stimulus_id, **kwargs) + recs, instructions = func(ts, *args, stimulus_id=stimulus_id, **kwargs) self._notify_plugins("transition", ts.key, start, finish, **kwargs) elif "released" not in (start, finish): # start -> "released" -> finish try: - recs, smsgs = self._transition(ts, "released", stimulus_id=stimulus_id) + recs, instructions = self._transition( + ts, "released", stimulus_id=stimulus_id + ) v = recs.get(ts, (finish, *args)) v_state: str v_args: list | tuple @@ -2784,11 +2550,11 @@ def _transition( v_state, *v_args = v else: v_state, v_args = v, () - b_recs, b_smsgs = self._transition( + b_recs, b_instructions = self._transition( ts, v_state, *v_args, stimulus_id=stimulus_id ) recs.update(b_recs) - smsgs 
+= b_smsgs + instructions += b_instructions except InvalidTransition: raise InvalidTransition( f"Impossible transition from {start} to {finish} for {ts.key}" @@ -2815,7 +2581,7 @@ def _transition( time(), ) ) - return recs, smsgs + return recs, instructions def transition( self, ts: TaskState, finish: str, *, stimulus_id: str, **kwargs @@ -2835,9 +2601,10 @@ def transition( -------- Scheduler.transitions: transitive version of this function """ - recs, smsgs = self._transition(ts, finish, stimulus_id=stimulus_id, **kwargs) - for msg in smsgs: - self.batched_stream.send(msg) + recs, instructions = self._transition( + ts, finish, stimulus_id=stimulus_id, **kwargs + ) + self._handle_instructions(instructions) self.transitions(recs, stimulus_id=stimulus_id) def transitions(self, recommendations: Recs, *, stimulus_id: str) -> None: @@ -2846,34 +2613,44 @@ def transitions(self, recommendations: Recs, *, stimulus_id: str) -> None: This includes feedback from previous transitions and continues until we reach a steady state """ - smsgs = [] + instructions = [] remaining_recs = recommendations.copy() tasks = set() while remaining_recs: ts, finish = remaining_recs.popitem() tasks.add(ts) - a_recs, a_smsgs = self._transition(ts, finish, stimulus_id=stimulus_id) + a_recs, a_instructions = self._transition( + ts, finish, stimulus_id=stimulus_id + ) remaining_recs.update(a_recs) - smsgs += a_smsgs + instructions += a_instructions if self.validate: # Full state validation is very expensive for ts in tasks: self.validate_task(ts) - if not self.batched_stream.closed(): - for msg in smsgs: - self.batched_stream.send(msg) - else: + if self.batched_stream.closed(): logger.debug( "BatchedSend closed while transitioning tasks. %d tasks not sent.", - len(smsgs), + len(instructions), ) + else: + self._handle_instructions(instructions) + + def _handle_instructions(self, instructions: list[Instruction]) -> None: + # TODO this method is temporary. + # See final design: https://github.com/dask/distributed/issues/5894 + for inst in instructions: + if isinstance(inst, SendMessageToScheduler): + self.batched_stream.send(inst.to_dict()) + else: + raise TypeError(inst) # pragma: nocover def maybe_transition_long_running( - self, ts: TaskState, *, stimulus_id: str, compute_duration=None + self, ts: TaskState, *, compute_duration: float, stimulus_id: str ): if ts.state == "executing": self.transition( @@ -2965,7 +2742,7 @@ def ensure_communicating(self) -> None: for el in skipped_worker_in_flight: self.data_needed.push(el) - def _get_task_finished_msg(self, ts: TaskState) -> dict[str, Any]: + def _get_task_finished_msg(self, ts: TaskState) -> TaskFinishedMsg: if ts.key not in self.data and ts.key not in self.actors: raise RuntimeError(f"Task {ts} not ready") typ = ts.type @@ -2983,19 +2760,15 @@ def _get_task_finished_msg(self, ts: TaskState) -> dict[str, Any]: # Some types fail pickling (example: _thread.lock objects), # send their name as a best effort. 
typ_serialized = pickle.dumps(typ.__name__, protocol=4) - d = { - "op": "task-finished", - "status": "OK", - "key": ts.key, - "nbytes": ts.nbytes, - "thread": self.threads.get(ts.key), - "type": typ_serialized, - "typename": typename(typ), - "metadata": ts.metadata, - } - if ts.startstops: - d["startstops"] = ts.startstops - return d + return TaskFinishedMsg( + key=ts.key, + nbytes=ts.nbytes, + thread=self.threads.get(ts.key), + type=typ_serialized, + typename=typename(typ), + metadata=ts.metadata, + startstops=ts.startstops, + ) def _put_key_in_memory(self, ts: TaskState, value, *, stimulus_id: str) -> Recs: """ @@ -3009,13 +2782,14 @@ def _put_key_in_memory(self, ts: TaskState, value, *, stimulus_id: str) -> Recs: Raises ------ - TypeError: - In case the data is put into the in memory buffer and an exception - occurs during spilling, this raises an exception. This has to be - handled by the caller since most callers generate scheduler messages - on success (see comment above) but we need to signal that this was - not successful. - Can only trigger if spill to disk is enabled and the task is not an + Exception: + In case the data is put into the in memory buffer and a serialization error + occurs during spilling, this raises that error. This has to be handled by + the caller since most callers generate scheduler messages on success (see + comment above) but we need to signal that this was not successful. + + Can only trigger if distributed.worker.memory.target is enabled, the value + is individually larger than target * memory_limit, and the task is not an actor. """ if ts.key in self.data: diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py new file mode 100644 index 00000000000..8b09b02677f --- /dev/null +++ b/distributed/worker_state_machine.py @@ -0,0 +1,332 @@ +from __future__ import annotations + +import heapq +import sys +from collections.abc import Callable, Container, Iterator +from dataclasses import dataclass, field +from functools import lru_cache +from typing import Collection # TODO move to collections.abc (requires Python >=3.9) +from typing import TYPE_CHECKING, Any, ClassVar, Literal, NamedTuple, TypedDict + +import dask +from dask.utils import parse_bytes + +from .utils import recursive_to_dict + +if TYPE_CHECKING: + # TODO move to typing (requires Python >=3.10) + from typing_extensions import TypeAlias + + TaskStateState: TypeAlias = Literal[ + "cancelled", + "constrained", + "error", + "executing", + "fetch", + "flight", + "forgotten", + "long-running", + "memory", + "missing", + "ready", + "released", + "rescheduled", + "resumed", + "waiting", + ] + + +# TaskState.state subsets +PROCESSING: set[TaskStateState] = { + "waiting", + "ready", + "constrained", + "executing", + "long-running", + "cancelled", + "resumed", +} +READY: set[TaskStateState] = {"ready", "constrained"} + + +class SerializedTask(NamedTuple): + function: Callable + args: tuple + kwargs: dict[str, Any] + task: object # distributed.scheduler.TaskState.run_spec + + +class StartStop(TypedDict, total=False): + action: str + start: float + stop: float + source: str # optional + + +class InvalidTransition(Exception): + pass + + +@lru_cache +def _default_data_size() -> int: + return parse_bytes(dask.config.get("distributed.scheduler.default-data-size")) + + +# Note: can't specify __slots__ manually to enable slots in Python <3.10 in a @dataclass +# that defines any default values +dc_slots = {"slots": True} if sys.version_info >= (3, 10) else {} + + +@dataclass(repr=False, 
eq=False, **dc_slots) +class TaskState: + """Holds volatile state relating to an individual Dask task. + + Not to be confused with :class:`distributed.scheduler.TaskState`, which holds + similar information on the scheduler side. + """ + + #: Task key. Mandatory. + key: str + #: A named tuple containing the ``function``, ``args``, ``kwargs`` and ``task`` + #: associated with this `TaskState` instance. This defaults to ``None`` and can + #: remain empty if it is a dependency that this worker will receive from another + #: worker. + run_spec: SerializedTask | None = None + + #: The data needed by this key to run + dependencies: set[TaskState] = field(default_factory=set) + #: The keys that use this dependency + dependents: set[TaskState] = field(default_factory=set) + #: Subset of dependencies that are not in memory + waiting_for_data: set[TaskState] = field(default_factory=set) + #: Subset of dependents that are not in memory + waiters: set[TaskState] = field(default_factory=set) + + #: The current state of the task + state: TaskStateState = "released" + #: The previous state of the task. This is a state machine implementation detail. + _previous: TaskStateState | None = None + #: The next state of the task. This is a state machine implementation detail. + _next: TaskStateState | None = None + + #: Expected duration of the task + duration: float | None = None + #: The priority this task given by the scheduler. Determines run order. + priority: tuple[int, ...] | None = None + #: Addresses of workers that we believe have this data + who_has: set[str] = field(default_factory=set) + #: The worker that current task data is coming from if task is in flight + coming_from: str | None = None + #: Abstract resources required to run a task + resource_restrictions: dict[str, float] = field(default_factory=dict) + #: The exception caused by running a task if it erred + exception: Exception | None = None + #: string representation of exception + exception_text: str = "" + #: The traceback caused by running a task if it erred + traceback: object | None = None + #: string representation of traceback + traceback_text: str = "" + #: The type of a particular piece of data + type: type | None = None + #: The number of times a dependency has not been where we expected it + suspicious_count: int = 0 + #: Log of transfer, load, and compute times for a task + startstops: list[StartStop] = field(default_factory=list) + #: Time at which task begins running + start_time: float | None = None + #: Time at which task finishes running + stop_time: float | None = None + #: Metadata related to the task. + #: Stored metadata should be msgpack serializable (e.g. int, string, list, dict). + metadata: dict = field(default_factory=dict) + #: The size of the value of the task, if in memory + nbytes: int | None = None + #: Arbitrary task annotations + annotations: dict | None = None + #: True if the task is in memory or erred; False otherwise + done: bool = False + + # Support for weakrefs to a class with __slots__ + __weakref__: Any = field(init=False) + + def __repr__(self) -> str: + return f"" + + def get_nbytes(self) -> int: + nbytes = self.nbytes + return nbytes if nbytes is not None else _default_data_size() + + def _to_dict_no_nest(self, *, exclude: Container[str] = ()) -> dict: + """Dictionary representation for debugging purposes. + Not type stable and not intended for roundtrips. 
+ + See also + -------- + Client.dump_cluster_state + distributed.utils.recursive_to_dict + + Notes + ----- + This class uses ``_to_dict_no_nest`` instead of ``_to_dict``. + When a task references another task, just print the task repr. All tasks + should neatly appear under Worker.tasks. This also prevents a RecursionError + during particularly heavy loads, which have been observed to happen whenever + there's an acyclic dependency chain of ~200+ tasks. + """ + out = recursive_to_dict(self, exclude=exclude, members=True) + # Remove all Nones and empty containers + return {k: v for k, v in out.items() if v} + + def is_protected(self) -> bool: + return self.state in PROCESSING or any( + dep_ts.state in PROCESSING for dep_ts in self.dependents + ) + + +class UniqueTaskHeap(Collection[TaskState]): + """A heap of TaskState objects ordered by TaskState.priority. + Ties are broken by string comparison of the key. Keys are guaranteed to be + unique. Iterating over this object returns the elements in priority order. + """ + + __slots__ = ("_known", "_heap") + _known: set[str] + _heap: list[tuple[tuple[int, ...], str, TaskState]] + + def __init__(self): + self._known = set() + self._heap = [] + + def push(self, ts: TaskState) -> None: + """Add a new TaskState instance to the heap. If the key is already + known, no object is added. + + Note: This does not update the priority / heap order in case priority + changes. + """ + assert isinstance(ts, TaskState) + if ts.key not in self._known: + assert ts.priority + heapq.heappush(self._heap, (ts.priority, ts.key, ts)) + self._known.add(ts.key) + + def pop(self) -> TaskState: + """Pop the task with highest priority from the heap.""" + _, key, ts = heapq.heappop(self._heap) + self._known.remove(key) + return ts + + def peek(self) -> TaskState: + """Get the highest priority TaskState without removing it from the heap""" + return self._heap[0][2] + + def __contains__(self, x: object) -> bool: + if isinstance(x, TaskState): + x = x.key + return x in self._known + + def __iter__(self) -> Iterator[TaskState]: + return (ts for _, _, ts in sorted(self._heap)) + + def __len__(self) -> int: + return len(self._known) + + def __repr__(self) -> str: + return f"<{type(self).__name__}: {len(self)} items>" + + +class Instruction: + """Command from the worker state machine to the Worker, in response to an event""" + + __slots__ = () + + +# TODO https://github.com/dask/distributed/issues/5736 + +# @dataclass +# class GatherDep(Instruction): +# __slots__ = ("worker", "to_gather") +# worker: str +# to_gather: set[str] + + +# @dataclass +# class FindMissing(Instruction): +# __slots__ = () + + +# @dataclass +# class Execute(Instruction): +# __slots__ = ("key", "stimulus_id") +# key: str +# stimulus_id: str + + +class SendMessageToScheduler(Instruction): + __slots__ = () + #: Matches a key in Scheduler.stream_handlers + op: ClassVar[str] + + def __init_subclass__(cls, op: str): + cls.op = op + + def to_dict(self) -> dict[str, Any]: + """Convert object to dict so that it can be serialized with msgpack""" + d = {k: getattr(self, k) for k in self.__annotations__} + d["op"] = self.op + return d + + +# Note: as of Python 3.10.2, @dataclass(slots=True) doesn't work with __init__subclass__ +# https://bugs.python.org/issue46970 +@dataclass +class TaskFinishedMsg(SendMessageToScheduler, op="task-finished"): + key: str + nbytes: int | None + thread: int | None + type: bytes + typename: str + metadata: dict + startstops: list[StartStop] + __slots__ = tuple(__annotations__) # type: ignore + 
+ +@dataclass +class TaskErredMsg(SendMessageToScheduler, op="task-erred"): + key: str + thread: int | None + exception: Exception + exception_text: str + traceback: object + traceback_text: str + startstops: list[StartStop] + __slots__ = tuple(__annotations__) # type: ignore + + +@dataclass +class ReleaseWorkerDataMsg(SendMessageToScheduler, op="release-worker-data"): + __slots__ = ("key",) + key: str + + +@dataclass +class RescheduleMsg(SendMessageToScheduler, op="reschedule"): + # Not to be confused with the distributed.Reschedule Exception + __slots__ = ("key", "worker") + key: str + worker: str + + +@dataclass +class LongRunningMsg(SendMessageToScheduler, op="long-running"): + __slots__ = ("key", "compute_duration") + key: str + compute_duration: float + + +@dataclass +class AddKeysMsg(SendMessageToScheduler, op="add-keys"): + __slots__ = ("key", "stimulus_id") + keys: list[str] + stimulus_id: str diff --git a/docs/source/worker.rst b/docs/source/worker.rst index 6e223743694..98d1b0cb0b3 100644 --- a/docs/source/worker.rst +++ b/docs/source/worker.rst @@ -415,13 +415,20 @@ Dask workers are by default launched, monitored, and managed by a small Nanny process. .. autoclass:: distributed.nanny.Nanny + :members: API Documentation ----------------- -.. autoclass:: distributed.worker.TaskState +.. autoclass:: distributed.worker_state_machine.TaskState + :members: + +.. autoclass:: distributed.worker_state_machine.UniqueTaskHeap + :members: + .. autoclass:: distributed.worker.Worker + :members: .. _malloc: https://www.man7.org/linux/man-pages/man3/malloc.3.html From 3ae3686daa0211d80e2f644011e2de3cb841ad5a Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 10 Mar 2022 12:03:31 +0000 Subject: [PATCH 2/9] fix slots --- distributed/tests/test_worker_state_machine.py | 10 +++++++--- distributed/worker_state_machine.py | 2 +- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index 153f7c5f108..2c97b217aab 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -1,6 +1,9 @@ +import pytest + from distributed.utils import recursive_to_dict from distributed.worker_state_machine import ( ReleaseWorkerDataMsg, + SendMessageToScheduler, TaskState, UniqueTaskHeap, ) @@ -79,12 +82,13 @@ def test_unique_task_heap(): assert repr(heap) == "" -def test_sendmsg_slots(): - # Sample test on one of the subclasses - smsg = ReleaseWorkerDataMsg(key="x") +@pytest.mark.parametrize("cls", SendMessageToScheduler.__subclasses__()) +def test_sendmsg_slots(cls): + smsg = cls(**dict.fromkeys(cls.__annotations__)) assert not hasattr(smsg, "__dict__") def test_sendmsg_to_dict(): + # Arbitrary sample class smsg = ReleaseWorkerDataMsg(key="x") assert smsg.to_dict() == {"op": "release-worker-data", "key": "x"} diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index 8b09b02677f..53b3f0e215c 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -327,6 +327,6 @@ class LongRunningMsg(SendMessageToScheduler, op="long-running"): @dataclass class AddKeysMsg(SendMessageToScheduler, op="add-keys"): - __slots__ = ("key", "stimulus_id") + __slots__ = ("keys", "stimulus_id") keys: list[str] stimulus_id: str From 3c8abb86c59b778b2778bf4d455e1cc13772095a Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 10 Mar 2022 12:33:59 +0000 Subject: [PATCH 3/9] move TaskState in imports --- 
distributed/tests/test_failed_workers.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/distributed/tests/test_failed_workers.py b/distributed/tests/test_failed_workers.py index e8dc492aab1..9912f77aa1c 100644 --- a/distributed/tests/test_failed_workers.py +++ b/distributed/tests/test_failed_workers.py @@ -25,6 +25,7 @@ slowadd, slowinc, ) +from distributed.worker_state_machine import TaskState pytestmark = pytest.mark.ci1 @@ -494,19 +495,13 @@ async def test_worker_time_to_live(c, s, a, b): @gen_cluster() async def test_forget_data_not_supposed_to_have(s, a, b): - """ - If a depednecy fetch finishes on a worker after the scheduler already - released everything, the worker might be stuck with a redundant replica - which is never cleaned up. + """If a dependency fetch finishes on a worker after the scheduler already released + everything, the worker might be stuck with a redundant replica which is never + cleaned up. """ # FIXME: Replace with "blackbox test" which shows an actual example where - # this situation is provoked if this is even possible. - # If this cannot be constructed, the entire superfuous_data handler and its - # corresponding pieces on the scheduler side may be removed - from distributed.worker import TaskState - - ts = TaskState("key") - ts.state = "flight" + # this situation is provoked if this is even possible. + ts = TaskState("key", state="flight") a.tasks["key"] = ts recommendations = {ts: ("memory", 123)} a.transitions(recommendations, stimulus_id="test") From 3f15676385e618daa10af65afa5de5ecd0539117 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 10 Mar 2022 16:02:50 +0000 Subject: [PATCH 4/9] fix sphinx docs --- docs/source/worker.rst | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/docs/source/worker.rst b/docs/source/worker.rst index 98d1b0cb0b3..40b4381a17e 100644 --- a/docs/source/worker.rst +++ b/docs/source/worker.rst @@ -92,19 +92,19 @@ more details on the command line options, please have a look at the Internal Scheduling ------------------- -Internally tasks that come to the scheduler proceed through the following -pipeline as :py:class:`distributed.worker.TaskState` objects. Tasks which -follow this path have a :py:attr:`distributed.worker.TaskState.runspec` defined -which instructs the worker how to execute them. +Internally tasks that come to the scheduler proceed through the following pipeline as +:class:`distributed.worker_state_machine.TaskState` objects. Tasks which follow this +path have a :attr:`~distributed.worker_state_machine.TaskState.runspec` defined which +instructs the worker how to execute them. .. image:: images/worker-task-state.svg :alt: Dask worker task states Data dependencies are also represented as -:py:class:`distributed.worker.TaskState` objects and follow a simpler path -through the execution pipeline. These tasks do not have a -:py:attr:`distributed.worker.TaskState.runspec` defined and instead contain a -listing of workers to collect their result from. +:class:`~distributed.worker_state_machine.TaskState` objects and follow a simpler path +through the execution pipeline. These tasks do not have a +:attr:`~distributed.worker_state_machine.TaskState.runspec` defined and instead contain +a listing of workers to collect their result from. .. 
image:: images/worker-dep-state.svg From 2f750349b9e67e42b05f98e523654109276850f4 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 10 Mar 2022 16:40:39 +0000 Subject: [PATCH 5/9] Don't expose UniqueTaskHeap as public API --- docs/source/worker.rst | 3 --- 1 file changed, 3 deletions(-) diff --git a/docs/source/worker.rst b/docs/source/worker.rst index 40b4381a17e..e6892220eb3 100644 --- a/docs/source/worker.rst +++ b/docs/source/worker.rst @@ -424,9 +424,6 @@ API Documentation .. autoclass:: distributed.worker_state_machine.TaskState :members: -.. autoclass:: distributed.worker_state_machine.UniqueTaskHeap - :members: - .. autoclass:: distributed.worker.Worker :members: From 5af52de2e3d7e1d6b4a6f3f962696404d91d1966 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 10 Mar 2022 17:55:22 +0000 Subject: [PATCH 6/9] Don't swallow kwargs --- distributed/scheduler.py | 82 +++++++++++++++++------------ distributed/worker.py | 3 -- distributed/worker_state_machine.py | 5 +- 3 files changed, 49 insertions(+), 41 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index f10244d8f28..ef49af64685 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -2558,7 +2558,7 @@ def transition_no_worker_waiting(self, key): raise def transition_no_worker_memory( - self, key, nbytes=None, type=None, typename: str = None, worker=None + self, key: str, nbytes: "int | None", typename: str, worker: str ): try: ws: WorkerState = self._workers_dv[worker] @@ -2580,7 +2580,7 @@ def transition_no_worker_memory( self.check_idle_saturated(ws) _add_to_memory( - self, ts, ws, recommendations, client_msgs, type=type, typename=typename + self, ts, ws, recommendations, client_msgs, typename=typename ) ts.state = "memory" @@ -2765,8 +2765,15 @@ def transition_waiting_processing(self, key): pdb.set_trace() raise + # This method must have the same signature as transition_processing_memory def transition_waiting_memory( - self, key, nbytes=None, type=None, typename: str = None, worker=None, **kwargs + self, + key: str, + nbytes: "int | None", + type: bytes, + typename: str, + startstops: list, + worker: str, ): try: ws: WorkerState = self._workers_dv[worker] @@ -2805,15 +2812,15 @@ def transition_waiting_memory( pdb.set_trace() raise + # This method must have the same signature as transition_waiting_memory def transition_processing_memory( self, - key, - nbytes=None, - type=None, - typename: str = None, - worker=None, - startstops=None, - **kwargs, + key: str, + nbytes: "int | None", + type: bytes, + typename: str, + startstops: list, # list[distributed.worker_state_machine.StartStop] + worker: str, ): ws: WorkerState wws: WorkerState @@ -2858,14 +2865,13 @@ def transition_processing_memory( ############################# # Update Timing Information # ############################# - if startstops: - startstop: dict - for startstop in startstops: - ts._group.add_duration( - stop=startstop["stop"], - start=startstop["start"], - action=startstop["action"], - ) + startstop: dict + for startstop in startstops: + ts._group.add_duration( + stop=startstop["stop"], + start=startstop["start"], + action=startstop["action"], + ) s: set = self._unknown_durations.pop(ts._prefix._name, set()) tts: TaskState @@ -3161,7 +3167,6 @@ def transition_processing_erred( exception_text: str = None, traceback_text: str = None, worker: str = None, - **kwargs, ): ws: WorkerState try: @@ -4544,6 +4549,7 @@ async def add_worker( if ts.state == "memory": self.add_keys(worker=address, keys=[key]) else: + # Call 
transition_no_worker_memory t: tuple = parent._transition( key, "memory", @@ -4966,7 +4972,7 @@ def update_graph( # TODO: balance workers - def stimulus_task_finished(self, key=None, worker=None, **kwargs): + def stimulus_task_finished(self, key: str, worker: str, metadata: dict, **kwargs): """Mark that a task has finished execution on a particular worker""" parent: SchedulerState = cast(SchedulerState, self) logger.debug("Stimulus task finished %s, %s", key, worker) @@ -4976,7 +4982,7 @@ def stimulus_task_finished(self, key=None, worker=None, **kwargs): worker_msgs: dict = {} ws: WorkerState = parent._workers_dv[worker] - ts: TaskState = parent._tasks.get(key) + ts: TaskState = parent._tasks.get(key) # type: ignore if ts is None or ts._state == "released": logger.debug( "Received already computed task, worker: %s, state: %s" @@ -4996,7 +5002,7 @@ def stimulus_task_finished(self, key=None, worker=None, **kwargs): elif ts._state == "memory": self.add_keys(worker=worker, keys=[key]) else: - ts._metadata.update(kwargs["metadata"]) + ts._metadata.update(metadata) r: tuple = parent._transition(key, "memory", worker=worker, **kwargs) recommendations, client_msgs, worker_msgs = r @@ -5005,13 +5011,19 @@ def stimulus_task_finished(self, key=None, worker=None, **kwargs): return recommendations, client_msgs, worker_msgs def stimulus_task_erred( - self, key=None, worker=None, exception=None, traceback=None, **kwargs + self, + key: str, + worker: str, + exception: Exception, + traceback: object, + exception_text: str, + traceback_text: str, ): """Mark that a task has erred on a particular worker""" parent: SchedulerState = cast(SchedulerState, self) logger.debug("Stimulus task erred %s, %s", key, worker) - ts: TaskState = parent._tasks.get(key) + ts: TaskState = parent._tasks.get(key) # type: ignore if ts is None or ts._state != "processing": return {}, {}, {} @@ -5022,11 +5034,13 @@ def stimulus_task_erred( return parent._transition( key, "erred", + # kwargs to transition_processing_erred cause=key, exception=exception, traceback=traceback, worker=worker, - **kwargs, + exception_text=exception_text, + traceback_text=traceback_text, ) def stimulus_retry(self, keys, client=None): @@ -5549,7 +5563,7 @@ def send_task_to_worker(self, worker, ts: TaskState, duration: double = -1): def handle_uncaught_error(self, **msg): logger.exception(clean_exception(**msg)[1]) - def handle_task_finished(self, key=None, worker=None, **msg): + def handle_task_finished(self, key: str, worker: str, **msg): parent: SchedulerState = cast(SchedulerState, self) if worker not in parent._workers_dv: return @@ -5559,23 +5573,22 @@ def handle_task_finished(self, key=None, worker=None, **msg): client_msgs: dict worker_msgs: dict - r: tuple = self.stimulus_task_finished( - key=key, worker=worker, status="OK", **msg - ) + r: tuple = self.stimulus_task_finished(key=key, worker=worker, **msg) recommendations, client_msgs, worker_msgs = r parent._transitions(recommendations, client_msgs, worker_msgs) self.send_all(client_msgs, worker_msgs) - def handle_task_erred(self, key=None, **msg): + def handle_task_erred(self, **msg): parent: SchedulerState = cast(SchedulerState, self) recommendations: dict client_msgs: dict worker_msgs: dict - r: tuple = self.stimulus_task_erred(key=key, status="error", **msg) + # msg is the output of distributed.worker_state_machine.TaskErredMsg.to_dict(), + # plus worker which is added by the RPC + r: tuple = self.stimulus_task_erred(**msg) recommendations, client_msgs, worker_msgs = r 
parent._transitions(recommendations, client_msgs, worker_msgs) - self.send_all(client_msgs, worker_msgs) def handle_missing_data(self, key=None, errant_worker=None, **kwargs): @@ -8195,8 +8208,9 @@ def _add_to_memory( ws: WorkerState, recommendations: dict, client_msgs: dict, - type=None, - typename: str = None, + *, + typename: str, + type: "bytes | None" = None, ): """ Add *ts* to the set of in-memory tasks. diff --git a/distributed/worker.py b/distributed/worker.py index 50f840d7864..097f9b5754d 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -2162,12 +2162,10 @@ def transition_generic_error( ts.state = "error" smsg = TaskErredMsg( key=ts.key, - thread=self.threads.get(ts.key), exception=ts.exception, traceback=ts.traceback, exception_text=ts.exception_text, traceback_text=ts.traceback_text, - startstops=ts.startstops, ) return {}, [smsg] @@ -2763,7 +2761,6 @@ def _get_task_finished_msg(self, ts: TaskState) -> TaskFinishedMsg: return TaskFinishedMsg( key=ts.key, nbytes=ts.nbytes, - thread=self.threads.get(ts.key), type=typ_serialized, typename=typename(typ), metadata=ts.metadata, diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index 53b3f0e215c..b635cda3800 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -284,8 +284,7 @@ def to_dict(self) -> dict[str, Any]: class TaskFinishedMsg(SendMessageToScheduler, op="task-finished"): key: str nbytes: int | None - thread: int | None - type: bytes + type: bytes # serialized class typename: str metadata: dict startstops: list[StartStop] @@ -295,12 +294,10 @@ class TaskFinishedMsg(SendMessageToScheduler, op="task-finished"): @dataclass class TaskErredMsg(SendMessageToScheduler, op="task-erred"): key: str - thread: int | None exception: Exception exception_text: str traceback: object traceback_text: str - startstops: list[StartStop] __slots__ = tuple(__annotations__) # type: ignore From 349dd15d0ee4d76459f35b5f6afcee4ecc6d3e33 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Fri, 11 Mar 2022 10:48:16 +0000 Subject: [PATCH 7/9] revert revert revert --- distributed/scheduler.py | 84 ++++++++++++++++------------------------ 1 file changed, 33 insertions(+), 51 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index ef49af64685..912ba378da4 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -2558,7 +2558,7 @@ def transition_no_worker_waiting(self, key): raise def transition_no_worker_memory( - self, key: str, nbytes: "int | None", typename: str, worker: str + self, key, nbytes=None, type=None, typename: str = None, worker=None ): try: ws: WorkerState = self._workers_dv[worker] @@ -2580,7 +2580,7 @@ def transition_no_worker_memory( self.check_idle_saturated(ws) _add_to_memory( - self, ts, ws, recommendations, client_msgs, typename=typename + self, ts, ws, recommendations, client_msgs, type=type, typename=typename ) ts.state = "memory" @@ -2765,15 +2765,8 @@ def transition_waiting_processing(self, key): pdb.set_trace() raise - # This method must have the same signature as transition_processing_memory def transition_waiting_memory( - self, - key: str, - nbytes: "int | None", - type: bytes, - typename: str, - startstops: list, - worker: str, + self, key, nbytes=None, type=None, typename: str = None, worker=None, **kwargs ): try: ws: WorkerState = self._workers_dv[worker] @@ -2812,15 +2805,15 @@ def transition_waiting_memory( pdb.set_trace() raise - # This method must have the same signature as 
transition_waiting_memory def transition_processing_memory( self, - key: str, - nbytes: "int | None", - type: bytes, - typename: str, - startstops: list, # list[distributed.worker_state_machine.StartStop] - worker: str, + key, + nbytes=None, + type=None, + typename: str = None, + worker=None, + startstops=None, + **kwargs, ): ws: WorkerState wws: WorkerState @@ -2865,13 +2858,14 @@ def transition_processing_memory( ############################# # Update Timing Information # ############################# - startstop: dict - for startstop in startstops: - ts._group.add_duration( - stop=startstop["stop"], - start=startstop["start"], - action=startstop["action"], - ) + if startstops: + startstop: dict + for startstop in startstops: + ts._group.add_duration( + stop=startstop["stop"], + start=startstop["start"], + action=startstop["action"], + ) s: set = self._unknown_durations.pop(ts._prefix._name, set()) tts: TaskState @@ -3167,6 +3161,7 @@ def transition_processing_erred( exception_text: str = None, traceback_text: str = None, worker: str = None, + **kwargs, ): ws: WorkerState try: @@ -4549,7 +4544,6 @@ async def add_worker( if ts.state == "memory": self.add_keys(worker=address, keys=[key]) else: - # Call transition_no_worker_memory t: tuple = parent._transition( key, "memory", @@ -4972,7 +4966,7 @@ def update_graph( # TODO: balance workers - def stimulus_task_finished(self, key: str, worker: str, metadata: dict, **kwargs): + def stimulus_task_finished(self, key=None, worker=None, **kwargs): """Mark that a task has finished execution on a particular worker""" parent: SchedulerState = cast(SchedulerState, self) logger.debug("Stimulus task finished %s, %s", key, worker) @@ -4982,7 +4976,7 @@ def stimulus_task_finished(self, key: str, worker: str, metadata: dict, **kwargs worker_msgs: dict = {} ws: WorkerState = parent._workers_dv[worker] - ts: TaskState = parent._tasks.get(key) # type: ignore + ts: TaskState = parent._tasks.get(key) if ts is None or ts._state == "released": logger.debug( "Received already computed task, worker: %s, state: %s" @@ -5002,7 +4996,7 @@ def stimulus_task_finished(self, key: str, worker: str, metadata: dict, **kwargs elif ts._state == "memory": self.add_keys(worker=worker, keys=[key]) else: - ts._metadata.update(metadata) + ts._metadata.update(kwargs["metadata"]) r: tuple = parent._transition(key, "memory", worker=worker, **kwargs) recommendations, client_msgs, worker_msgs = r @@ -5011,19 +5005,13 @@ def stimulus_task_finished(self, key: str, worker: str, metadata: dict, **kwargs return recommendations, client_msgs, worker_msgs def stimulus_task_erred( - self, - key: str, - worker: str, - exception: Exception, - traceback: object, - exception_text: str, - traceback_text: str, + self, key=None, worker=None, exception=None, traceback=None, **kwargs ): """Mark that a task has erred on a particular worker""" parent: SchedulerState = cast(SchedulerState, self) logger.debug("Stimulus task erred %s, %s", key, worker) - ts: TaskState = parent._tasks.get(key) # type: ignore + ts: TaskState = parent._tasks.get(key) if ts is None or ts._state != "processing": return {}, {}, {} @@ -5034,13 +5022,11 @@ def stimulus_task_erred( return parent._transition( key, "erred", - # kwargs to transition_processing_erred cause=key, exception=exception, traceback=traceback, worker=worker, - exception_text=exception_text, - traceback_text=traceback_text, + **kwargs, ) def stimulus_retry(self, keys, client=None): @@ -5563,7 +5549,7 @@ def send_task_to_worker(self, worker, ts: TaskState, duration: 
double = -1): def handle_uncaught_error(self, **msg): logger.exception(clean_exception(**msg)[1]) - def handle_task_finished(self, key: str, worker: str, **msg): + def handle_task_finished(self, key=None, worker=None, **msg): parent: SchedulerState = cast(SchedulerState, self) if worker not in parent._workers_dv: return @@ -5579,16 +5565,15 @@ def handle_task_finished(self, key: str, worker: str, **msg): self.send_all(client_msgs, worker_msgs) - def handle_task_erred(self, **msg): + def handle_task_erred(self, key=None, **msg): parent: SchedulerState = cast(SchedulerState, self) recommendations: dict client_msgs: dict worker_msgs: dict - # msg is the output of distributed.worker_state_machine.TaskErredMsg.to_dict(), - # plus worker which is added by the RPC - r: tuple = self.stimulus_task_erred(**msg) + r: tuple = self.stimulus_task_erred(key=key, **msg) recommendations, client_msgs, worker_msgs = r parent._transitions(recommendations, client_msgs, worker_msgs) + self.send_all(client_msgs, worker_msgs) def handle_missing_data(self, key=None, errant_worker=None, **kwargs): @@ -7094,9 +7079,7 @@ async def _track_retire_worker( logger.info("Retired worker %s", ws._address) return ws._address, ws.identity() - def add_keys( - self, worker: str, keys: "Iterable[str]" = (), stimulus_id: "str | None" = None - ) -> str: + def add_keys(self, worker=None, keys=(), stimulus_id=None): """ Learn that a worker has certain keys @@ -7109,7 +7092,7 @@ def add_keys( ws: WorkerState = parent._workers_dv[worker] redundant_replicas = [] for key in keys: - ts: TaskState = parent._tasks.get(key) # type: ignore + ts: TaskState = parent._tasks.get(key) if ts is not None and ts._state == "memory": if ws not in ts._who_has: parent.add_replica(ts, ws) @@ -8208,9 +8191,8 @@ def _add_to_memory( ws: WorkerState, recommendations: dict, client_msgs: dict, - *, - typename: str, - type: "bytes | None" = None, + type=None, + typename: str = None, ): """ Add *ts* to the set of in-memory tasks. 
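The back-and-forth between [PATCH 6/9] ("Don't swallow kwargs") and [PATCH 7/9] above comes down to how tolerant the scheduler-side handlers are of extra fields in worker messages. A minimal, self-contained sketch of that trade-off, using made-up handler names (strict_handler / permissive_handler) rather than the real transition methods:

def strict_handler(key: str, nbytes: int, worker: str) -> None:
    # Every field in the incoming message must be declared explicitly.
    print(f"{key}: {nbytes} bytes on {worker}")

def permissive_handler(key: str, nbytes: int, worker: str, **kwargs) -> None:
    # Extra fields are accepted and ignored, mirroring the **kwargs
    # signatures restored by the revert.
    print(f"{key}: {nbytes} bytes on {worker} (ignored: {sorted(kwargs)})")

msg = {"key": "x", "nbytes": 8, "worker": "tcp://10.0.0.1:1234", "status": "OK"}

permissive_handler(**msg)  # fine
try:
    strict_handler(**msg)  # TypeError: unexpected keyword argument 'status'
except TypeError as exc:
    print(exc)

With the explicit signatures from [PATCH 6/9], every field later added to a worker message (as [PATCH 8/9] below does with "status", "thread" and "startstops") would need a matching parameter in the corresponding handler and transition functions; the revert keeps those entry points forgiving instead.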
From 6b4f310ce8e8414ab9f5a7163edbd55a5fd965fc Mon Sep 17 00:00:00 2001 From: crusaderky Date: Fri, 11 Mar 2022 11:57:10 +0000 Subject: [PATCH 8/9] fix regressions --- distributed/worker.py | 3 +++ distributed/worker_state_machine.py | 13 +++++++++++++ 2 files changed, 16 insertions(+) diff --git a/distributed/worker.py b/distributed/worker.py index 097f9b5754d..b5e30c45763 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -2166,6 +2166,8 @@ def transition_generic_error( traceback=ts.traceback, exception_text=ts.exception_text, traceback_text=ts.traceback_text, + thread=self.threads.get(ts.key), + startstops=ts.startstops, ) return {}, [smsg] @@ -2764,6 +2766,7 @@ def _get_task_finished_msg(self, ts: TaskState) -> TaskFinishedMsg: type=typ_serialized, typename=typename(typ), metadata=ts.metadata, + thread=self.threads.get(ts.key), startstops=ts.startstops, ) diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index b635cda3800..0133afab965 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -287,9 +287,15 @@ class TaskFinishedMsg(SendMessageToScheduler, op="task-finished"): type: bytes # serialized class typename: str metadata: dict + thread: int | None startstops: list[StartStop] __slots__ = tuple(__annotations__) # type: ignore + def to_dict(self) -> dict[str, Any]: + d = super().to_dict() + d["status"] = "OK" + return d + @dataclass class TaskErredMsg(SendMessageToScheduler, op="task-erred"): @@ -298,8 +304,15 @@ class TaskErredMsg(SendMessageToScheduler, op="task-erred"): exception_text: str traceback: object traceback_text: str + thread: int | None + startstops: list[StartStop] __slots__ = tuple(__annotations__) # type: ignore + def to_dict(self) -> dict[str, Any]: + d = super().to_dict() + d["status"] = "error" + return d + @dataclass class ReleaseWorkerDataMsg(SendMessageToScheduler, op="release-worker-data"): From 297bed6436a9b29609a10e2ff6bf607a2b3c98f5 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Mon, 14 Mar 2022 13:30:00 +0000 Subject: [PATCH 9/9] Remove metaclass magic --- distributed/worker_state_machine.py | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index 40f2fbe96a4..ce8d200ba7a 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -268,9 +268,6 @@ class SendMessageToScheduler(Instruction): #: Matches a key in Scheduler.stream_handlers op: ClassVar[str] - def __init_subclass__(cls, op: str): - cls.op = op - def to_dict(self) -> dict[str, Any]: """Convert object to dict so that it can be serialized with msgpack""" d = {k: getattr(self, k) for k in self.__annotations__} @@ -278,10 +275,10 @@ def to_dict(self) -> dict[str, Any]: return d -# Note: as of Python 3.10.2, @dataclass(slots=True) doesn't work with __init__subclass__ -# https://bugs.python.org/issue46970 @dataclass -class TaskFinishedMsg(SendMessageToScheduler, op="task-finished"): +class TaskFinishedMsg(SendMessageToScheduler): + op = "task-finished" + key: str nbytes: int | None type: bytes # serialized class @@ -298,7 +295,9 @@ def to_dict(self) -> dict[str, Any]: @dataclass -class TaskErredMsg(SendMessageToScheduler, op="task-erred"): +class TaskErredMsg(SendMessageToScheduler): + op = "task-erred" + key: str exception: Exception exception_text: str @@ -315,13 +314,17 @@ def to_dict(self) -> dict[str, Any]: @dataclass -class 
ReleaseWorkerDataMsg(SendMessageToScheduler, op="release-worker-data"): +class ReleaseWorkerDataMsg(SendMessageToScheduler): + op = "release-worker-data" + __slots__ = ("key",) key: str @dataclass -class RescheduleMsg(SendMessageToScheduler, op="reschedule"): +class RescheduleMsg(SendMessageToScheduler): + op = "reschedule" + # Not to be confused with the distributed.Reschedule Exception __slots__ = ("key", "worker") key: str @@ -329,14 +332,18 @@ class RescheduleMsg(SendMessageToScheduler, op="reschedule"): @dataclass -class LongRunningMsg(SendMessageToScheduler, op="long-running"): +class LongRunningMsg(SendMessageToScheduler): + op = "long-running" + __slots__ = ("key", "compute_duration") key: str compute_duration: float @dataclass -class AddKeysMsg(SendMessageToScheduler, op="add-keys"): +class AddKeysMsg(SendMessageToScheduler): + op = "add-keys" + __slots__ = ("keys", "stimulus_id") keys: list[str] stimulus_id: str
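The last patch ([PATCH 9/9]) trades the __init_subclass__ hook for a plain class attribute on each message dataclass, per the removed comment pointing at https://bugs.python.org/issue46970. A minimal, self-contained sketch of the two styles, with illustrative class names rather than the real message classes:

from dataclasses import dataclass
from typing import ClassVar

# Style removed by the patch: the op name travels as a class keyword
# argument and is stored on the subclass by __init_subclass__ on the base.
class MsgWithHook:
    op: ClassVar[str]

    def __init_subclass__(cls, op: str):
        cls.op = op

@dataclass
class FinishedViaHook(MsgWithHook, op="task-finished"):
    key: str

# Style introduced by the patch: each subclass declares op as a plain class
# attribute; without an annotation it is not picked up as a dataclass field,
# so it stays a class-level constant.
class Msg:
    op: ClassVar[str]

@dataclass
class FinishedPlain(Msg):
    op = "task-finished"
    key: str

print(FinishedViaHook(key="x").op)  # task-finished
print(FinishedPlain(key="x").op)    # task-finished

Both spellings behave the same at runtime in this sketch; the attribute form simply avoids the @dataclass(slots=True) / __init_subclass__ interaction that the removed comment reports as broken on Python 3.10.2.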