Refactor ensure_computing() -> None to _ensure_computing() -> RecsInstrs
crusaderky committed Apr 7, 2022
1 parent c4e07a1 commit c017629
Showing 8 changed files with 193 additions and 111 deletions.
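The heart of this commit is a calling-convention change: instead of applying worker state transitions as a side effect and returning None, _ensure_computing() now returns a RecsInstrs pair, a dict of transition recommendations together with a list of Instruction messages, which the caller merges and applies. The following is a minimal, self-contained sketch of that shape; the type aliases and the ensure_computing_sketch helper are illustrative stand-ins, not the actual distributed.worker_state_machine API.

# Illustrative sketch of the "return recommendations and instructions" pattern;
# names and types are simplified stand-ins, not the real worker state machine.
from typing import Dict, List, Tuple

Recs = Dict[str, str]            # task key -> recommended next state
Instructions = List[dict]        # e.g. messages destined for the scheduler
RecsInstrs = Tuple[Recs, Instructions]


def ensure_computing_sketch(ready: List[str], free_slots: int) -> RecsInstrs:
    """Recommend 'executing' for as many ready tasks as there are free slots."""
    recs: Recs = {key: "executing" for key in ready[:free_slots]}
    return recs, []  # nothing to tell the scheduler in this toy example


# The caller applies the returned recommendations; the function no longer
# mutates worker state behind the caller's back.
recs, instructions = ensure_computing_sketch(["x", "y", "z"], free_slots=2)
assert recs == {"x": "executing", "y": "executing"}
assert instructions == []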
32 changes: 21 additions & 11 deletions distributed/scheduler.py
@@ -3611,6 +3611,25 @@ def _reevaluate_occupancy_worker(self, ws: WorkerState):
        for ts in ws._processing:
            steal.recalculate_cost(ts)

    @ccall
    def bulk_schedule_after_adding_worker(self, ws: WorkerState):
        """Send tasks with ts.state=='no-worker' in bulk to a worker that just joined.
        Return recommendations. As the worker will start executing the new tasks
        immediately, without waiting for the batch to end, we can't rely on worker-side
        ordering, so the recommendations are sorted by priority order here.
        """
        ts: TaskState
        tasks = []
        for ts in self._unrunnable:
            valid: set = self.valid_workers(ts)
            if valid is None or ws in valid:
                # id(ts) is to prevent calling TaskState.__gt__ given equal priority
                tasks.append(ts)
        # These recommendations will generate {"op": "compute-task"} messages
        # to the worker in reversed order
        tasks.sort(key=operator.attrgetter("priority"), reverse=True)
        return {ts._key: "waiting" for ts in tasks}


class Scheduler(SchedulerState, ServerNode):
"""Dynamic distributed task scheduler
@@ -4583,10 +4602,7 @@ async def add_worker(
        )

        if ws._status == Status.running:
            for ts in parent._unrunnable:
                valid: set = self.valid_workers(ts)
                if valid is None or ws in valid:
                    recommendations[ts._key] = "waiting"
            recommendations.update(self.bulk_schedule_after_adding_worker(ws))

        if recommendations:
            parent._transitions(recommendations, client_msgs, worker_msgs)
@@ -5699,13 +5715,7 @@ def handle_worker_status_change(self, status: str, worker: str) -> None:

        if ws._status == Status.running:
            parent._running.add(ws)

            recs = {}
            ts: TaskState
            for ts in parent._unrunnable:
                valid: set = self.valid_workers(ts)
                if valid is None or ws in valid:
                    recs[ts._key] = "waiting"
            recs = self.bulk_schedule_after_adding_worker(ws)
            if recs:
                client_msgs: dict = {}
                worker_msgs: dict = {}
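The sort in bulk_schedule_after_adding_worker encodes an ordering argument worth spelling out. The toy illustration below assumes, as in dask, that a smaller priority tuple means "run sooner", and takes the in-code comment at face value that the transition engine emits the compute-task messages in reversed insertion order; it is not scheduler code.

# Toy illustration: recommendations are inserted lowest-priority-first
# (reverse=True on the priority tuple), and the resulting
# {"op": "compute-task"} messages are assumed to go out in reversed
# insertion order, so the highest-priority task reaches the worker,
# and starts executing, first.
tasks = [("c", (0, 3)), ("a", (0, 1)), ("b", (0, 2))]    # (key, priority tuple)

tasks.sort(key=lambda t: t[1], reverse=True)             # order: c, b, a
recommendations = {key: "waiting" for key, _ in tasks}   # insertion-ordered dict

messages = [{"op": "compute-task", "key": key} for key in reversed(recommendations)]
assert [m["key"] for m in messages] == ["a", "b", "c"]   # highest priority first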
2 changes: 1 addition & 1 deletion distributed/tests/test_cancelled_state.py
@@ -87,7 +87,7 @@ def f(ev):
[
("f1", "compute-task"),
("f1", "released", "waiting", "waiting", {"f1": "ready"}),
("f1", "waiting", "ready", "ready", {}),
("f1", "waiting", "ready", "ready", {"f1": "executing"}),
("f1", "ready", "executing", "executing", {}),
("free-keys", ("f1",)),
("f1", "executing", "released", "cancelled", {}),
2 changes: 1 addition & 1 deletion distributed/tests/test_cluster_dump.py
@@ -145,7 +145,7 @@ async def test_cluster_dump_story(c, s, a, b, tmp_path):
[
(k, "compute-task"),
(k, "released", "waiting", "waiting", {k: "ready"}),
(k, "waiting", "ready", "ready", {}),
(k, "waiting", "ready", "ready", {k: "executing"}),
(k, "ready", "executing", "executing", {}),
(k, "put-in-memory"),
(k, "executing", "memory", "memory", {}),
30 changes: 19 additions & 11 deletions distributed/tests/test_steal.py
@@ -1182,17 +1182,25 @@ async def test_reschedule_concurrent_requests_deadlock(c, s, *workers):
    await ev.set()
    await c.gather(futs1)

    # If this turns out to be overly flaky, the following may be relaxed or
    # removed. The point of this test is to not deadlock but verifying expected
    # state is still a nice thing

    # Either the last request goes through or both have been rejected since the
    # computation was already done by the time the request comes in. This is
    # unfortunately not stable.
    if victim_ts.who_has != {wsC}:
        msgs = steal.story(victim_ts)
        assert len(msgs) == 2
        assert all(msg[0] == "already-aborted" for msg in msgs), msgs
    assert victim_ts.who_has != {wsC}
    msgs = steal.story(victim_ts)
    msgs = [msg[:-1] for msg in msgs]  # Remove random IDs

    # There are three possible outcomes
    expect1 = [
        ("stale-response", victim_key, "executing", wsA.address),
        ("already-computing", victim_key, "executing", wsB.address, wsC.address),
    ]
    expect2 = [
        ("already-computing", victim_key, "executing", wsB.address, wsC.address),
        ("already-aborted", victim_key, "executing", wsA.address),
    ]
    # This outcome appears only in ~2% of the runs
    expect3 = [
        ("already-computing", victim_key, "executing", wsB.address, wsC.address),
        ("already-aborted", victim_key, "memory", wsA.address),
    ]
    assert msgs in (expect1, expect2, expect3)


@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3)
6 changes: 3 additions & 3 deletions distributed/tests/test_worker.py
@@ -1687,12 +1687,12 @@ async def test_story_with_deps(c, s, a, b):

# Story now includes randomized stimulus_ids and timestamps.
stimulus_ids = {ev[-2] for ev in story}
assert len(stimulus_ids) == 3, stimulus_ids
assert len(stimulus_ids) == 2, stimulus_ids
# This is a simple transition log
expected = [
("res", "compute-task"),
("res", "released", "waiting", "waiting", {"dep": "fetch"}),
("res", "waiting", "ready", "ready", {}),
("res", "waiting", "ready", "ready", {"res": "executing"}),
("res", "ready", "executing", "executing", {}),
("res", "put-in-memory"),
("res", "executing", "memory", "memory", {}),
@@ -3089,7 +3089,7 @@ async def test_task_flight_compute_oserror(c, s, a, b):
# Now, we actually compute the task *once*. This must not cycle back
("f1", "compute-task"),
("f1", "released", "waiting", "waiting", {"f1": "ready"}),
("f1", "waiting", "ready", "ready", {}),
("f1", "waiting", "ready", "ready", {"f1": "executing"}),
("f1", "ready", "executing", "executing", {}),
("f1", "put-in-memory"),
("f1", "executing", "memory", "memory", {}),
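The expectation changes in this file mirror the other story tests: the waiting->ready transition now carries the follow-up recommendation ({"res": "executing"} or {"f1": "executing"} above) returned by _ensure_computing(), and since that recommendation is applied within the same stimulus rather than as a fresh one, the number of distinct stimulus_ids in the story drops from 3 to 2. Below is a self-contained toy model of that chaining, not the actual worker code.

# Toy model: a transition returns its follow-up recommendations and the loop
# applies them under the same stimulus_id, so "waiting" -> "ready" and
# "ready" -> "executing" end up sharing a single stimulus.
def run_transitions(key, start, table, stimulus_id, log):
    state = start
    while True:
        next_state, recs = table[state]
        log.append((key, state, next_state, next_state, recs, stimulus_id))
        state = next_state
        if key not in recs:  # no further recommendation for this task
            return


log = []
table = {
    "waiting": ("ready", {"f1": "executing"}),  # recommends the next hop
    "ready": ("executing", {}),                 # nothing more to recommend
}
run_transitions("f1", "waiting", table, stimulus_id="compute-task-1", log=log)
assert [ev[:5] for ev in log] == [
    ("f1", "waiting", "ready", "ready", {"f1": "executing"}),
    ("f1", "ready", "executing", "executing", {}),
]
assert {ev[-1] for ev in log} == {"compute-task-1"}  # one stimulus_id for both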
24 changes: 24 additions & 0 deletions distributed/tests/test_worker_state_machine.py
@@ -6,10 +6,12 @@
from distributed.worker_state_machine import (
    Instruction,
    ReleaseWorkerDataMsg,
    RescheduleMsg,
    SendMessageToScheduler,
    StateMachineEvent,
    TaskState,
    UniqueTaskHeap,
    merge_recs_instructions,
)


@@ -109,3 +111,25 @@ def test_sendmsg_to_dict():
    # Arbitrary sample class
    smsg = ReleaseWorkerDataMsg(key="x")
    assert smsg.to_dict() == {"op": "release-worker-data", "key": "x"}


def test_merge_recs_instructions():
    x = TaskState("x")
    y = TaskState("y")
    instr1 = RescheduleMsg(key="foo", worker="a")
    instr2 = RescheduleMsg(key="bar", worker="b")
    assert merge_recs_instructions(
        ({x: "memory"}, [instr1]),
        ({y: "released"}, [instr2]),
    ) == (
        {x: "memory", y: "released"},
        [instr1, instr2],
    )

    # Identical recommendations are silently ignored; incompatible ones raise
    assert merge_recs_instructions(({x: "memory"}, []), ({x: "memory"}, [])) == (
        {x: "memory"},
        [],
    )
    with pytest.raises(ValueError):
        merge_recs_instructions(({x: "memory"}, []), ({x: "released"}, []))
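For readers unfamiliar with the helper exercised above, the behaviour the test pins down (recommendation dicts merged, duplicate identical entries tolerated, conflicting entries rejected, instruction lists concatenated in order) can be sketched roughly as follows. This is an illustration only, not the actual merge_recs_instructions from distributed.worker_state_machine.

# Rough sketch of merge semantics consistent with the test above; the real
# implementation may differ in details.
def merge_recs_instructions_sketch(*pairs):
    recs = {}
    instructions = []
    for rec, instr in pairs:
        for key, state in rec.items():
            if key in recs and recs[key] != state:
                # Two callers disagree about where this task should go next
                raise ValueError(f"conflicting recommendations for {key!r}")
            recs[key] = state
        instructions += instr
    return recs, instructions


assert merge_recs_instructions_sketch(
    ({"x": "memory"}, ["instr1"]), ({"y": "released"}, ["instr2"])
) == ({"x": "memory", "y": "released"}, ["instr1", "instr2"])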