Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove EnsureCommunicatingAfterTransitions #6462

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 36 additions & 6 deletions distributed/tests/test_worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
ComputeTaskEvent,
ExecuteFailureEvent,
ExecuteSuccessEvent,
GatherDep,
Instruction,
RecommendationsConflict,
RefreshWhoHasEvent,
Expand Down Expand Up @@ -116,14 +117,16 @@ def test_WorkerState__to_dict(ws):
"busy_workers": [],
"constrained": [],
"data": {"y": None},
"data_needed": ["x"],
"data_needed_per_worker": {"127.0.0.1:1235": ["x"]},
"data_needed": [],
"data_needed_per_worker": {"127.0.0.1:1235": []},
"executing": [],
"in_flight_tasks": [],
"in_flight_workers": {},
"in_flight_tasks": ["x"],
"in_flight_workers": {"127.0.0.1:1235": ["x"]},
"log": [
["x", "ensure-task-exists", "released", "s1"],
["x", "released", "fetch", "fetch", {}, "s1"],
["gather-dependencies", "127.0.0.1:1235", ["x"], "s1"],
["x", "fetch", "flight", "flight", {}, "s1"],
["y", "put-in-memory", "s2"],
["y", "receive-from-scatter", "s2"],
],
Expand All @@ -147,10 +150,11 @@ def test_WorkerState__to_dict(ws):
],
"tasks": {
"x": {
"coming_from": "127.0.0.1:1235",
"key": "x",
"nbytes": 123,
"priority": [1],
"state": "fetch",
"state": "flight",
"who_has": ["127.0.0.1:1235"],
},
"y": {
Expand All @@ -159,7 +163,7 @@ def test_WorkerState__to_dict(ws):
"state": "memory",
},
},
"transition_counter": 1,
"transition_counter": 2,
}
assert actual == expect

Expand Down Expand Up @@ -855,3 +859,29 @@ async def test_deprecated_worker_attributes(s, a, b):
)
with pytest.warns(FutureWarning, match=msg):
assert a.in_flight_tasks == 0


@pytest.mark.parametrize(
    "nbytes,n_in_flight",
    [
        # Note: target_message_size = 50e6 bytes
        (int(10e6), 3),
        (int(20e6), 2),
        (int(30e6), 1),
    ],
)
def test_aggregate_gather_deps(ws, nbytes, n_in_flight):
    """Check how many same-worker fetches end up in flight together.

    Three keys of ``nbytes`` each, all held by the same peer worker, are
    requested through a single ``AcquireReplicasEvent``. The state machine
    must answer with exactly one ``GatherDep`` instruction, and the number of
    tasks left in flight must equal ``n_in_flight``.

    NOTE(review): the parameter table suggests keys are aggregated while the
    running total stays below target_message_size (50e6 bytes) -- confirm
    against _select_keys_for_gather().
    """
    peer = "127.0.0.1:1235"
    keys = ("x1", "x2", "x3")

    # One stimulus carrying all three replicas, so that they can be
    # aggregated into a single transfer from the same worker.
    event = AcquireReplicasEvent(
        who_has={key: [peer] for key in keys},
        nbytes={key: nbytes for key in keys},
        stimulus_id="test",
    )
    instructions = ws.handle_stimulus(event)

    assert len(instructions) == 1
    assert isinstance(instructions[0], GatherDep)
    assert len(ws.in_flight_tasks) == n_in_flight
185 changes: 69 additions & 116 deletions distributed/worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,12 +350,6 @@ class RetryBusyWorkerLater(Instruction):
worker: str


@dataclass
class EnsureCommunicatingAfterTransitions(Instruction):
__slots__ = ()


@dataclass
class SendMessageToScheduler(Instruction):
#: Matches a key in Scheduler.stream_handlers
op: ClassVar[str]
Expand Down Expand Up @@ -1512,13 +1506,7 @@ def _transition_generic_fetch(self, ts: TaskState, stimulus_id: str) -> RecsInst
self.data_needed.add(ts)
for w in ts.who_has:
self.data_needed_per_worker[w].add(ts)

# This is the same as `return self._ensure_communicating()`, except that when
# many tasks transition to fetch at the same time, e.g. from a single
# compute-task or acquire-replicas command from the scheduler, it allows
# clustering the transfers into fewer GatherDep instructions; see
# _select_keys_for_gather().
return {}, [EnsureCommunicatingAfterTransitions(stimulus_id=stimulus_id)]
return {}, []

def _transition_missing_waiting(
self, ts: TaskState, *, stimulus_id: str
Expand Down Expand Up @@ -2299,18 +2287,30 @@ def _transitions(self, recommendations: Recs, *, stimulus_id: str) -> Instructio
reach a steady state
"""
instructions = []

remaining_recs = recommendations.copy()
tasks = set()
while remaining_recs:
ts, finish = remaining_recs.popitem()
tasks.add(ts)
a_recs, a_instructions = self._transition(
ts, finish, stimulus_id=stimulus_id
)

remaining_recs.update(a_recs)
instructions += a_instructions
def process_recs(recs: Recs) -> None:
while recs:
ts, finish = recs.popitem()
tasks.add(ts)
a_recs, a_instructions = self._transition(
ts, finish, stimulus_id=stimulus_id
)
recs.update(a_recs)
instructions.extend(a_instructions)

process_recs(recommendations.copy())

# We could call _ensure_communicating after we change something that could
# trigger a new call to gather_dep (e.g. on transitions to fetch,
# GatherDepDoneEvent, or RetryBusyWorkerEvent). However, by doing so we'd
# potentially call it too early, before all tasks have transitioned to fetch.
# This in turn would hurt aggregation of multiple tasks into a single GatherDep
# instruction.
# Read: https://github.com/dask/distributed/issues/6497
a_recs, a_instructions = self._ensure_communicating(stimulus_id=stimulus_id)
instructions += a_instructions
process_recs(a_recs)

if self.validate:
# Full state validation is very expensive
Expand Down Expand Up @@ -2554,10 +2554,7 @@ def _handle_gather_dep_success(self, ev: GatherDepSuccessEvent) -> RecsInstrs:
self.has_what[ev.worker].discard(ts.key)
recommendations[ts] = "fetch"

return merge_recs_instructions(
(recommendations, []),
self._ensure_communicating(stimulus_id=ev.stimulus_id),
)
Comment on lines -2557 to -2560
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I love that this is not everywhere anymore ❤️

return recommendations, []

@_handle_event.register
def _handle_gather_dep_busy(self, ev: GatherDepBusyEvent) -> RecsInstrs:
Expand Down Expand Up @@ -2586,10 +2583,7 @@ def _handle_gather_dep_busy(self, ev: GatherDepBusyEvent) -> RecsInstrs:
)
)

return merge_recs_instructions(
(recommendations, instructions),
self._ensure_communicating(stimulus_id=ev.stimulus_id),
)
return recommendations, instructions

@_handle_event.register
def _handle_gather_dep_network_failure(
Expand All @@ -2616,10 +2610,7 @@ def _handle_gather_dep_network_failure(
self.log.append((ts.key, "missing-dep", ev.stimulus_id, time()))
recommendations[ts] = "fetch"

return merge_recs_instructions(
(recommendations, []),
self._ensure_communicating(stimulus_id=ev.stimulus_id),
)
return recommendations, []

@_handle_event.register
def _handle_gather_dep_failure(self, ev: GatherDepFailureEvent) -> RecsInstrs:
Expand All @@ -2637,10 +2628,7 @@ def _handle_gather_dep_failure(self, ev: GatherDepFailureEvent) -> RecsInstrs:
for ts in self._gather_dep_done_common(ev)
}

return merge_recs_instructions(
(recommendations, []),
self._ensure_communicating(stimulus_id=ev.stimulus_id),
)
return recommendations, []

@_handle_event.register
def _handle_secede(self, ev: SecedeEvent) -> RecsInstrs:
Expand Down Expand Up @@ -2680,15 +2668,12 @@ def _handle_pause(self, ev: PauseEvent) -> RecsInstrs:
def _handle_unpause(self, ev: UnpauseEvent) -> RecsInstrs:
"""Emerge from paused status"""
self.running = True
return merge_recs_instructions(
self._ensure_computing(),
self._ensure_communicating(stimulus_id=ev.stimulus_id),
)
return self._ensure_computing()

@_handle_event.register
def _handle_retry_busy_worker(self, ev: RetryBusyWorkerEvent) -> RecsInstrs:
self.busy_workers.discard(ev.worker)
return self._ensure_communicating(stimulus_id=ev.stimulus_id)
return {}, []

@_handle_event.register
def _handle_cancel_compute(self, ev: CancelComputeEvent) -> RecsInstrs:
Expand Down Expand Up @@ -2790,17 +2775,13 @@ def _handle_refresh_who_has(self, ev: RefreshWhoHasEvent) -> RecsInstrs:

if ts.who_has and ts.state == "missing":
recommendations[ts] = "fetch"
elif ts.who_has and ts.state == "fetch":
# We potentially just acquired new replicas whereas all previously known
# workers are in flight or busy. We're deliberately not testing the
# minute use cases here for the sake of simplicity; instead we rely on
# _ensure_communicating to be a no-op when there's nothing to do.
recommendations, instructions = merge_recs_instructions(
(recommendations, instructions),
self._ensure_communicating(stimulus_id=ev.stimulus_id),
)
elif not ts.who_has and ts.state == "fetch":
recommendations[ts] = "missing"
# Note: if ts.who_has and ts.state == "fetch", we may have just acquired new
# replicas whereas all previously known workers are in flight or busy. We
# rely on _transitions to call _ensure_communicating every time, even in
# the absence of recommendations, to potentially kick off a new call to
# gather_dep.

return recommendations, instructions

Expand Down Expand Up @@ -3090,73 +3071,45 @@ def handle_stimulus(self, *stims: StateMachineEvent) -> None:
"""
instructions = self.state.handle_stimulus(*stims)

while instructions:
ensure_communicating: EnsureCommunicatingAfterTransitions | None = None
for inst in instructions:
task: asyncio.Task | None = None

if isinstance(inst, SendMessageToScheduler):
self.batched_send(inst.to_dict())

elif isinstance(inst, EnsureCommunicatingAfterTransitions):
# A single compute-task or acquire-replicas command may cause
# multiple tasks to transition to fetch; this in turn means that we
# will receive multiple instances of this instruction.
# _ensure_communicating is a no-op if it runs twice in a row; we're
# not calling it inside the for loop to avoid an O(n^2) condition
# when
# 1. there are many fetches queued because all workers are in flight
# 2. a single compute-task or acquire-replicas command just sent
# many dependencies to fetch at once.
ensure_communicating = inst

elif isinstance(inst, GatherDep):
assert inst.to_gather
keys_str = ", ".join(peekn(27, inst.to_gather)[0])
if len(keys_str) > 80:
keys_str = keys_str[:77] + "..."
task = asyncio.create_task(
self.gather_dep(
inst.worker,
inst.to_gather,
total_nbytes=inst.total_nbytes,
stimulus_id=inst.stimulus_id,
),
name=f"gather_dep({inst.worker}, {{{keys_str}}})",
)

elif isinstance(inst, Execute):
task = asyncio.create_task(
self.execute(inst.key, stimulus_id=inst.stimulus_id),
name=f"execute({inst.key})",
)

elif isinstance(inst, RetryBusyWorkerLater):
task = asyncio.create_task(
self.retry_busy_worker_later(inst.worker),
name=f"retry_busy_worker_later({inst.worker})",
)
for inst in instructions:
task: asyncio.Task | None = None

if isinstance(inst, SendMessageToScheduler):
self.batched_send(inst.to_dict())

elif isinstance(inst, GatherDep):
assert inst.to_gather
keys_str = ", ".join(peekn(27, inst.to_gather)[0])
if len(keys_str) > 80:
keys_str = keys_str[:77] + "..."
task = asyncio.create_task(
self.gather_dep(
inst.worker,
inst.to_gather,
total_nbytes=inst.total_nbytes,
stimulus_id=inst.stimulus_id,
),
name=f"gather_dep({inst.worker}, {{{keys_str}}})",
)

else:
raise TypeError(inst) # pragma: nocover

if task is not None:
self._async_instructions.add(task)
task.add_done_callback(self._handle_stimulus_from_task)

if ensure_communicating:
# Potentially re-fill instructions, causing a second iteration of `while
# instructions` at the top of this method
# FIXME access to private methods
# https://github.com/dask/distributed/issues/6497
recs, instructions = self.state._ensure_communicating(
stimulus_id=ensure_communicating.stimulus_id
elif isinstance(inst, Execute):
task = asyncio.create_task(
self.execute(inst.key, stimulus_id=inst.stimulus_id),
name=f"execute({inst.key})",
)
instructions += self.state._transitions(
recs, stimulus_id=ensure_communicating.stimulus_id

elif isinstance(inst, RetryBusyWorkerLater):
task = asyncio.create_task(
self.retry_busy_worker_later(inst.worker),
name=f"retry_busy_worker_later({inst.worker})",
)

else:
return
raise TypeError(inst) # pragma: nocover

if task is not None:
self._async_instructions.add(task)
task.add_done_callback(self._handle_stimulus_from_task)

async def close(self, timeout: float = 30) -> None:
"""Cancel all asynchronous instructions"""
Expand Down