From f6b2e03cb659f19a516eb14c3270f9430a661af3 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Fri, 8 Apr 2022 21:00:31 +0100 Subject: [PATCH] Migrate ensure_executing transitions to new WorkerState event mechanism - part 2 (#6062) --- distributed/scheduler.py | 31 ++- distributed/tests/test_cancelled_state.py | 2 +- distributed/tests/test_cluster_dump.py | 2 +- distributed/tests/test_steal.py | 30 ++- distributed/tests/test_worker.py | 6 +- .../tests/test_worker_state_machine.py | 24 +++ distributed/worker.py | 195 ++++++++++-------- distributed/worker_state_machine.py | 22 ++ 8 files changed, 200 insertions(+), 112 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 3c535b552c9..ddf1e1ad2b9 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -3612,6 +3612,24 @@ def _reevaluate_occupancy_worker(self, ws: WorkerState): for ts in ws._processing: steal.recalculate_cost(ts) + @ccall + def bulk_schedule_after_adding_worker(self, ws: WorkerState): + """Send tasks with ts.state=='no-worker' in bulk to a worker that just joined. + Return recommendations. As the worker will start executing the new tasks + immediately, without waiting for the batch to end, we can't rely on worker-side + ordering, so the recommendations are sorted by priority order here. + """ + ts: TaskState + tasks = [] + for ts in self._unrunnable: + valid: set = self.valid_workers(ts) + if valid is None or ws in valid: + tasks.append(ts) + # These recommendations will generate {"op": "compute-task"} messages + # to the worker in reversed order + tasks.sort(key=operator.attrgetter("priority"), reverse=True) + return {ts._key: "waiting" for ts in tasks} + class Scheduler(SchedulerState, ServerNode): """Dynamic distributed task scheduler @@ -4584,10 +4602,7 @@ async def add_worker( ) if ws._status == Status.running: - for ts in parent._unrunnable: - valid: set = self.valid_workers(ts) - if valid is None or ws in valid: - recommendations[ts._key] = "waiting" + recommendations.update(self.bulk_schedule_after_adding_worker(ws)) if recommendations: parent._transitions(recommendations, client_msgs, worker_msgs) @@ -5700,13 +5715,7 @@ def handle_worker_status_change(self, status: str, worker: str) -> None: if ws._status == Status.running: parent._running.add(ws) - - recs = {} - ts: TaskState - for ts in parent._unrunnable: - valid: set = self.valid_workers(ts) - if valid is None or ws in valid: - recs[ts._key] = "waiting" + recs = self.bulk_schedule_after_adding_worker(ws) if recs: client_msgs: dict = {} worker_msgs: dict = {} diff --git a/distributed/tests/test_cancelled_state.py b/distributed/tests/test_cancelled_state.py index 1a0df078376..92ac98587ae 100644 --- a/distributed/tests/test_cancelled_state.py +++ b/distributed/tests/test_cancelled_state.py @@ -87,7 +87,7 @@ def f(ev): [ ("f1", "compute-task"), ("f1", "released", "waiting", "waiting", {"f1": "ready"}), - ("f1", "waiting", "ready", "ready", {}), + ("f1", "waiting", "ready", "ready", {"f1": "executing"}), ("f1", "ready", "executing", "executing", {}), ("free-keys", ("f1",)), ("f1", "executing", "released", "cancelled", {}), diff --git a/distributed/tests/test_cluster_dump.py b/distributed/tests/test_cluster_dump.py index b01cf2611ca..c3912116c46 100644 --- a/distributed/tests/test_cluster_dump.py +++ b/distributed/tests/test_cluster_dump.py @@ -145,7 +145,7 @@ async def test_cluster_dump_story(c, s, a, b, tmp_path): [ (k, "compute-task"), (k, "released", "waiting", "waiting", {k: "ready"}), - (k, "waiting", "ready", "ready", 
{}), + (k, "waiting", "ready", "ready", {k: "executing"}), (k, "ready", "executing", "executing", {}), (k, "put-in-memory"), (k, "executing", "memory", "memory", {}), diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 55f1bbf043a..98565bc2cfe 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -1182,17 +1182,25 @@ async def test_reschedule_concurrent_requests_deadlock(c, s, *workers): await ev.set() await c.gather(futs1) - # If this turns out to be overly flaky, the following may be relaxed or - # removed. The point of this test is to not deadlock but verifying expected - # state is still a nice thing - - # Either the last request goes through or both have been rejected since the - # computation was already done by the time the request comes in. This is - # unfortunately not stable. - if victim_ts.who_has != {wsC}: - msgs = steal.story(victim_ts) - assert len(msgs) == 2 - assert all(msg[0] == "already-aborted" for msg in msgs), msgs + assert victim_ts.who_has != {wsC} + msgs = steal.story(victim_ts) + msgs = [msg[:-1] for msg in msgs] # Remove random IDs + + # There are three possible outcomes + expect1 = [ + ("stale-response", victim_key, "executing", wsA.address), + ("already-computing", victim_key, "executing", wsB.address, wsC.address), + ] + expect2 = [ + ("already-computing", victim_key, "executing", wsB.address, wsC.address), + ("already-aborted", victim_key, "executing", wsA.address), + ] + # This outcome appears only in ~2% of the runs + expect3 = [ + ("already-computing", victim_key, "executing", wsB.address, wsC.address), + ("already-aborted", victim_key, "memory", wsA.address), + ] + assert msgs in (expect1, expect2, expect3) @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 357c34e5603..e4af0c62270 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -1687,12 +1687,12 @@ async def test_story_with_deps(c, s, a, b): # Story now includes randomized stimulus_ids and timestamps. stimulus_ids = {ev[-2] for ev in story} - assert len(stimulus_ids) == 3, stimulus_ids + assert len(stimulus_ids) == 2, stimulus_ids # This is a simple transition log expected = [ ("res", "compute-task"), ("res", "released", "waiting", "waiting", {"dep": "fetch"}), - ("res", "waiting", "ready", "ready", {}), + ("res", "waiting", "ready", "ready", {"res": "executing"}), ("res", "ready", "executing", "executing", {}), ("res", "put-in-memory"), ("res", "executing", "memory", "memory", {}), @@ -3089,7 +3089,7 @@ async def test_task_flight_compute_oserror(c, s, a, b): # Now, we actually compute the task *once*. 
This must not cycle back ("f1", "compute-task"), ("f1", "released", "waiting", "waiting", {"f1": "ready"}), - ("f1", "waiting", "ready", "ready", {}), + ("f1", "waiting", "ready", "ready", {"f1": "executing"}), ("f1", "ready", "executing", "executing", {}), ("f1", "put-in-memory"), ("f1", "executing", "memory", "memory", {}), diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index 78597a37e67..87828a17cf7 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -6,10 +6,12 @@ from distributed.worker_state_machine import ( Instruction, ReleaseWorkerDataMsg, + RescheduleMsg, SendMessageToScheduler, StateMachineEvent, TaskState, UniqueTaskHeap, + merge_recs_instructions, ) @@ -109,3 +111,25 @@ def test_sendmsg_to_dict(): # Arbitrary sample class smsg = ReleaseWorkerDataMsg(key="x") assert smsg.to_dict() == {"op": "release-worker-data", "key": "x"} + + +def test_merge_recs_instructions(): + x = TaskState("x") + y = TaskState("y") + instr1 = RescheduleMsg(key="foo", worker="a") + instr2 = RescheduleMsg(key="bar", worker="b") + assert merge_recs_instructions( + ({x: "memory"}, [instr1]), + ({y: "released"}, [instr2]), + ) == ( + {x: "memory", y: "released"}, + [instr1, instr2], + ) + + # Identical recommendations are silently ignored; incompatible ones raise + assert merge_recs_instructions(({x: "memory"}, []), ({x: "memory"}, [])) == ( + {x: "memory"}, + [], + ) + with pytest.raises(ValueError): + merge_recs_instructions(({x: "memory"}, []), ({x: "released"}, [])) diff --git a/distributed/worker.py b/distributed/worker.py index 7a0628763dd..7c5bc61ca15 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -129,6 +129,8 @@ TaskState, TaskStateState, UniqueTaskHeap, + UnpauseEvent, + merge_recs_instructions, ) if TYPE_CHECKING: @@ -921,8 +923,7 @@ def status(self, value): ServerNode.status.__set__(self, value) self._send_worker_status_change() if prev_status == Status.paused and value == Status.running: - self.ensure_computing() - self.ensure_communicating() + self.handle_stimulus(UnpauseEvent(stimulus_id=f"set-status-{time()}")) def _send_worker_status_change(self) -> None: if ( @@ -1178,9 +1179,7 @@ async def heartbeat(self): async def handle_scheduler(self, comm): try: - await self.handle_stream( - comm, every_cycle=[self.ensure_communicating, self.ensure_computing] - ) + await self.handle_stream(comm, every_cycle=[self.ensure_communicating]) except Exception as e: logger.exception(e) raise @@ -1958,7 +1957,10 @@ def transition_generic_released( if not ts.dependents: recs[ts] = "forgotten" - return recs, [] + return merge_recs_instructions( + (recs, []), + self._ensure_computing(), + ) def transition_released_waiting( self, ts: TaskState, *, stimulus_id: str @@ -2022,7 +2024,7 @@ def transition_waiting_constrained( assert ts.key not in self.ready ts.state = "constrained" self.constrained.append(ts.key) - return {}, [] + return self._ensure_computing() def transition_long_running_rescheduled( self, ts: TaskState, *, stimulus_id: str @@ -2038,9 +2040,10 @@ def transition_executing_rescheduled( self.available_resources[resource] += quantity self._executing.discard(ts) - recs: Recs = {ts: "released"} - smsg = RescheduleMsg(key=ts.key, worker=self.address) - return recs, [smsg] + return merge_recs_instructions( + ({ts: "released"}, [RescheduleMsg(key=ts.key, worker=self.address)]), + self._ensure_computing(), + ) def transition_waiting_ready( self, ts: TaskState, *, 
stimulus_id: str @@ -2057,7 +2060,7 @@ def transition_waiting_ready( assert ts.priority is not None heapq.heappush(self.ready, (ts.priority, ts.key)) - return {}, [] + return self._ensure_computing() def transition_cancelled_error( self, @@ -2133,13 +2136,17 @@ def transition_executing_error( for resource, quantity in ts.resource_restrictions.items(): self.available_resources[resource] += quantity self._executing.discard(ts) - return self.transition_generic_error( - ts, - exception, - traceback, - exception_text, - traceback_text, - stimulus_id=stimulus_id, + + return merge_recs_instructions( + self.transition_generic_error( + ts, + exception, + traceback, + exception_text, + traceback_text, + stimulus_id=stimulus_id, + ), + self._ensure_computing(), ) def _transition_from_resumed( @@ -2254,12 +2261,11 @@ def transition_cancelled_released( for resource, quantity in ts.resource_restrictions.items(): self.available_resources[resource] += quantity - recs, instructions = self.transition_generic_released( - ts, stimulus_id=stimulus_id + + return merge_recs_instructions( + self.transition_generic_released(ts, stimulus_id=stimulus_id), + ({ts: next_state} if next_state != "released" else {}, []), ) - if next_state != "released": - recs[ts] = next_state - return recs, instructions def transition_executing_released( self, ts: TaskState, *, stimulus_id: str @@ -2269,7 +2275,7 @@ def transition_executing_released( # See https://github.com/dask/distributed/pull/5046#discussion_r685093940 ts.state = "cancelled" ts.done = False - return {}, [] + return self._ensure_computing() def transition_long_running_memory( self, ts: TaskState, value=no_value, *, stimulus_id: str @@ -2292,16 +2298,19 @@ def transition_generic_memory( self._executing.discard(ts) self._in_flight_tasks.discard(ts) ts.coming_from = None + + instructions: Instructions = [] try: recs = self._put_key_in_memory(ts, value, stimulus_id=stimulus_id) except Exception as e: msg = error_message(e) recs = {ts: tuple(msg.values())} - return recs, [] - if self.validate: - assert ts.key in self.data or ts.key in self.actors - smsg = self._get_task_finished_msg(ts) - return recs, [smsg] + else: + if self.validate: + assert ts.key in self.data or ts.key in self.actors + instructions.append(self._get_task_finished_msg(ts)) + + return recs, instructions def transition_executing_memory( self, ts: TaskState, value=no_value, *, stimulus_id: str @@ -2313,7 +2322,10 @@ def transition_executing_memory( self._executing.discard(ts) self.executed_count += 1 - return self.transition_generic_memory(ts, value=value, stimulus_id=stimulus_id) + return merge_recs_instructions( + self.transition_generic_memory(ts, value=value, stimulus_id=stimulus_id), + self._ensure_computing(), + ) def transition_constrained_executing( self, ts: TaskState, *, stimulus_id: str @@ -2326,10 +2338,7 @@ def transition_constrained_executing( for dep in ts.dependencies: assert dep.key in self.data or dep.key in self.actors - for resource, quantity in ts.resource_restrictions.items(): - self.available_resources[resource] -= quantity ts.state = "executing" - self._executing.add(ts) instr = Execute(key=ts.key, stimulus_id=stimulus_id) return {}, [instr] @@ -2347,7 +2356,6 @@ def transition_ready_executing( ) ts.state = "executing" - self._executing.add(ts) instr = Execute(key=ts.key, stimulus_id=stimulus_id) return {}, [instr] @@ -2417,9 +2425,11 @@ def transition_executing_long_running( ts.state = "long-running" self._executing.discard(ts) self.long_running.add(ts.key) - smsg = 
LongRunningMsg(key=ts.key, compute_duration=compute_duration) - self.io_loop.add_callback(self.ensure_computing) - return {}, [smsg] + + return merge_recs_instructions( + ({}, [LongRunningMsg(key=ts.key, compute_duration=compute_duration)]), + self._ensure_computing(), + ) def transition_released_memory( self, ts: TaskState, value, *, stimulus_id: str @@ -2592,8 +2602,6 @@ def handle_stimulus(self, stim: StateMachineEvent) -> None: recs, instructions = self.handle_event(stim) self.transitions(recs, stimulus_id=stim.stimulus_id) self._handle_instructions(instructions) - self.ensure_computing() - self.ensure_communicating() def _handle_stimulus_from_future( self, future: asyncio.Future[StateMachineEvent | None] @@ -3035,7 +3043,6 @@ async def gather_dep( recommendations[ts] = "fetch" if ts.who_has else "missing" del data, response self.transitions(recommendations, stimulus_id=stimulus_id) - self.ensure_computing() if not busy: self.repetitively_busy = 0 @@ -3076,7 +3083,6 @@ async def find_missing(self) -> None: "find-missing" ].callback_time = self.periodic_callbacks["heartbeat"].callback_time self.ensure_communicating() - self.ensure_computing() async def query_who_has(self, *deps: str) -> dict[str, Collection[str]]: with log_errors(): @@ -3327,16 +3333,6 @@ def actor_attribute(self, actor=None, attribute=None) -> dict[str, Any]: except Exception as ex: return {"status": "error", "exception": to_serialize(ex)} - def meets_resource_constraints(self, key: str) -> bool: - ts = self.tasks[key] - if not ts.resource_restrictions: - return True - for resource, needed in ts.resource_restrictions.items(): - if self.available_resources[resource] < needed: - return False - - return True - async def _maybe_deserialize_task( self, ts: TaskState, *, stimulus_id: str ) -> tuple[Callable, tuple, dict[str, Any]] | None: @@ -3369,42 +3365,62 @@ async def _maybe_deserialize_task( ) raise - def ensure_computing(self) -> None: + def _ensure_computing(self) -> RecsInstrs: if self.status in (Status.paused, Status.closing_gracefully): - return - try: - stimulus_id = f"ensure-computing-{time()}" - while self.constrained and self.executing_count < self.nthreads: - key = self.constrained[0] - ts = self.tasks.get(key, None) - if ts is None or ts.state != "constrained": - self.constrained.popleft() - continue - if self.meets_resource_constraints(key): - self.constrained.popleft() - self.transition(ts, "executing", stimulus_id=stimulus_id) - else: - break - while self.ready and self.executing_count < self.nthreads: - priority, key = heapq.heappop(self.ready) - ts = self.tasks.get(key) - if ts is None: - # It is possible for tasks to be released while still remaining on - # `ready` The scheduler might have re-routed to a new worker and - # told this worker to release. If the task has "disappeared" just - # continue through the heap - continue - elif ts.key in self.data: - self.transition(ts, "memory", stimulus_id=stimulus_id) - elif ts.state in READY: - self.transition(ts, "executing", stimulus_id=stimulus_id) - except Exception as e: # pragma: no cover - logger.exception(e) - if LOG_PDB: - import pdb + return {}, [] - pdb.set_trace() - raise + recs: Recs = {} + while self.constrained and len(self._executing) < self.nthreads: + key = self.constrained[0] + ts = self.tasks.get(key, None) + if ts is None or ts.state != "constrained": + self.constrained.popleft() + continue + + # There may be duplicates in the self.constrained and self.ready queues. + # This happens if a task: + # 1. 
is assigned to a Worker and transitioned to ready (heappush) + # 2. is stolen (no way to pop from heap, the task stays there) + # 3. is assigned to the worker again (heappush again) + if ts in recs: + continue + + if any( + self.available_resources[resource] < needed + for resource, needed in ts.resource_restrictions.items() + ): + break + + self.constrained.popleft() + for resource, needed in ts.resource_restrictions.items(): + self.available_resources[resource] -= needed + + recs[ts] = "executing" + self._executing.add(ts) + + while self.ready and len(self._executing) < self.nthreads: + _, key = heapq.heappop(self.ready) + ts = self.tasks.get(key) + if ts is None: + # It is possible for tasks to be released while still remaining on + # `ready`. The scheduler might have re-routed to a new worker and + # told this worker to release. If the task has "disappeared", just + # continue through the heap. + continue + + if key in self.data: + # See comment above about duplicates + if self.validate: + assert ts not in recs or recs[ts] == "memory" + recs[ts] = "memory" + elif ts.state in READY: + # See comment above about duplicates + if self.validate: + assert ts not in recs or recs[ts] == "executing" + recs[ts] = "executing" + self._executing.add(ts) + + return recs, [] async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | None: if self.status in {Status.closing, Status.closed, Status.closing_gracefully}: @@ -3538,6 +3554,15 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | No def handle_event(self, ev: StateMachineEvent) -> RecsInstrs: raise TypeError(ev) # pragma: nocover + @handle_event.register + def _(self, ev: UnpauseEvent) -> RecsInstrs: + """Emerge from paused status. Do not send this event directly. Instead, just set + Worker.status back to running. + """ + assert self.status == Status.running + self.ensure_communicating() + return self._ensure_computing() + @handle_event.register def _(self, ev: CancelComputeEvent) -> RecsInstrs: """Scheduler requested to cancel a task""" diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index 8ae454417c9..a21c2acb301 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -357,6 +357,11 @@ class StateMachineEvent: stimulus_id: str +@dataclass +class UnpauseEvent(StateMachineEvent): + __slots__ = () + + @dataclass class ExecuteSuccessEvent(StateMachineEvent): key: str @@ -410,3 +415,20 @@ class RescheduleEvent(StateMachineEvent): Recs = dict Instructions = list RecsInstrs = tuple + + +def merge_recs_instructions(*args: RecsInstrs) -> RecsInstrs: + """Merge multiple (recommendations, instructions) tuples. + Collisions in recommendations are only allowed if identical. + """ + recs: Recs = {} + instr: Instructions = [] + for recs_i, instr_i in args: + for k, v in recs_i.items(): + if k in recs and recs[k] != v: + raise ValueError( + f"Mismatched recommendations for {k}: {recs[k]} vs. {v}" + ) + recs[k] = v + instr += instr_i + return recs, instr
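
Note (illustration, not part of the patch): the composition pattern introduced here, where each transition returns its own (recommendations, instructions) pair merged with whatever _ensure_computing wants to schedule next, can be exercised in isolation. The sketch below only mirrors the merging logic of the merge_recs_instructions helper added in worker_state_machine.py; the task keys and the instruction placeholder string are hypothetical stand-ins, not real worker state.

    # Minimal standalone sketch of the merge semantics used throughout this patch.
    # Identical recommendation collisions are tolerated; conflicting ones raise.
    def merge_recs_instructions(*args):
        recs = {}
        instructions = []
        for recs_i, instr_i in args:
            for key, state in recs_i.items():
                if key in recs and recs[key] != state:
                    raise ValueError(
                        f"Mismatched recommendations for {key}: {recs[key]} vs. {state}"
                    )
                recs[key] = state
            instructions += instr_i
        return recs, instructions

    # A transition's own output ...
    transition_out = ({"x": "released"}, ["<RescheduleMsg for x>"])
    # ... merged with what _ensure_computing would like to start next:
    ensure_computing_out = ({"y": "executing"}, [])

    print(merge_recs_instructions(transition_out, ensure_computing_out))
    # -> ({'x': 'released', 'y': 'executing'}, ['<RescheduleMsg for x>'])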