Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolve work stealing deadlock caused by race in move_task_confirm #5379

Merged
merged 9 commits into the base branch from the contributor's branch
Oct 14, 2021
11 changes: 7 additions & 4 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2765,10 +2765,10 @@ def transition_processing_memory(

s: set = self._unknown_durations.pop(ts._prefix._name, set())
tts: TaskState
steal = self.extensions.get("stealing")
for tts in s:
if tts._processing_on:
self.set_duration_estimate(tts, tts._processing_on)
steal = self.extensions.get("stealing")
if steal:
steal.put_key_in_stealable(tts)

Expand Down Expand Up @@ -7029,8 +7029,12 @@ def get_metadata(self, comm=None, keys=None, default=no_default):
raise

def set_restrictions(self, comm=None, worker=None):
ts: TaskState
for key, restrictions in worker.items():
self.tasks[key]._worker_restrictions = set(restrictions)
ts = self.tasks[key]
if isinstance(restrictions, str):
restrictions = {restrictions}
ts._worker_restrictions = set(restrictions)

def get_task_status(self, comm=None, keys=None):
parent: SchedulerState = cast(SchedulerState, self)
Expand Down Expand Up @@ -7960,8 +7964,7 @@ def _reevaluate_occupancy_worker(state: SchedulerState, ws: WorkerState):
return
if ws._occupancy > old * 1.3 or old > ws._occupancy * 1.3:
for ts in ws._processing:
steal.remove_key_from_stealable(ts)
steal.put_key_in_stealable(ts)
steal.recalculate_cost(ts)


@cfunc
Expand Down
159 changes: 91 additions & 68 deletions distributed/stealing.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,24 @@

LOG_PDB = dask.config.get("distributed.admin.pdb-on-err")

# Worker-side task states reported in a steal response for which the victim
# has relinquished the task, so the steal is confirmed and the task is handed
# to the thief (see the _WORKER_STATE_CONFIRM branch of move_task_confirm).
_WORKER_STATE_CONFIRM = {
    "ready",
    "constrained",
    "waiting",
}

# States in which the victim has already started, finished, or otherwise
# holds the task; the steal request is rejected ("already-computing").
_WORKER_STATE_REJECT = {
    "memory",
    "executing",
    "long-running",
    "cancelled",
    "resumed",
}
# States (or a missing state, i.e. None) where the task's fate on the victim
# is unknown; move_task_confirm punts and reschedules the task instead.
_WORKER_STATE_UNDEFINED = {
    "released",
    None,
}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think these constant names generate confusion with Worker.status.
Could you change them e.g. to _TASKSTATE_STATE_CONFIRM etc.?

On a related note, this IMHO highlights that TaskState.state direly needs to become an Enum - it's just too easy to miss a state here, now or in the future.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you change them e.g. to _TASKSTATE_STATE_CONFIRM etc.?

sure

this IMHO highlights that TaskState.state direly needs to become an Enum

agreed. I would prefer doing this in another PR if you don't mind

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I missed this outstanding request. @fjetter would you mind pushing up a separate PR with the updated name @crusaderky suggested?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agreed. I would prefer doing this in another PR

Yes obviously


class WorkStealing(SchedulerPlugin):
def __init__(self, scheduler):
Expand Down Expand Up @@ -78,8 +96,19 @@ def transition(
elif start == "processing":
ts = self.scheduler.tasks[key]
self.remove_key_from_stealable(ts)
if finish != "memory":
self.in_flight.pop(ts, None)
d = self.in_flight.pop(ts, None)
if d:
thief = d["thief"]
victim = d["victim"]
self.in_flight_occupancy[thief] -= d["thief_duration"]
self.in_flight_occupancy[victim] += d["victim_duration"]
if not self.in_flight:
self.in_flight_occupancy.clear()

def recalculate_cost(self, ts):
    """Re-bucket *ts* in the stealable structures after its cost estimate
    changed.

    Tasks with a steal currently in flight are left untouched: their
    stealable entry will be reconciled when the confirmation arrives.
    """
    if ts in self.in_flight:
        return
    self.remove_key_from_stealable(ts)
    self.put_key_in_stealable(ts)

def put_key_in_stealable(self, ts):
cost_multiplier, level = self.steal_time_ratio(ts)
Expand Down Expand Up @@ -138,13 +167,11 @@ def steal_time_ratio(self, ts):

return cost_multiplier, level

def move_task_request(self, ts, victim, thief):
def move_task_request(self, ts, victim, thief) -> str:
try:
if self.scheduler.validate:
if victim is not ts.processing_on and LOG_PDB:
import pdb

pdb.set_trace()
if ts in self.in_flight:
return "in-flight"
stimulus_id = f"steal-{time()}"

key = ts.key
self.remove_key_from_stealable(ts)
Expand All @@ -164,20 +191,22 @@ def move_task_request(self, ts, victim, thief):
) + self.scheduler.get_comm_cost(ts, thief)

self.scheduler.stream_comms[victim.address].send(
{"op": "steal-request", "key": key}
{"op": "steal-request", "key": key, "stimulus_id": stimulus_id}
)

self.in_flight[ts] = {
"victim": victim,
"victim": victim, # guaranteed to be processing_on
"thief": thief,
"victim_duration": victim_duration,
"thief_duration": thief_duration,
"stimulus_id": stimulus_id,
}

self.in_flight_occupancy[victim] -= victim_duration
self.in_flight_occupancy[thief] += thief_duration
return stimulus_id
except CommClosedError:
logger.info("Worker comm %r closed while stealing: %r", victim, ts)
return "comm-closed"
except Exception as e:
logger.exception(e)
if LOG_PDB:
Expand All @@ -186,30 +215,41 @@ def move_task_request(self, ts, victim, thief):
pdb.set_trace()
raise

async def move_task_confirm(self, key=None, worker=None, state=None):
async def move_task_confirm(self, *, key, state, stimulus_id, worker=None):
try:
try:
ts = self.scheduler.tasks[key]
except KeyError:
logger.debug("Key released between request and confirm: %s", key)
return
try:
d = self.in_flight.pop(ts)
except KeyError:
ts = self.scheduler.tasks[key]
except KeyError:
logger.debug("Key released between request and confirm: %s", key)
return
try:
d = self.in_flight.pop(ts)
if d["stimulus_id"] != stimulus_id:
self.log(("stale-response", key, state, worker, stimulus_id))
self.in_flight[ts] = d
return
thief = d["thief"]
victim = d["victim"]
logger.debug(
"Confirm move %s, %s -> %s. State: %s", key, victim, thief, state
)
except KeyError:
self.log(("already-aborted", key, state, stimulus_id))
return

self.in_flight_occupancy[thief] -= d["thief_duration"]
self.in_flight_occupancy[victim] += d["victim_duration"]
thief = d["thief"]
victim = d["victim"]

if not self.in_flight:
self.in_flight_occupancy = defaultdict(lambda: 0)
logger.debug("Confirm move %s, %s -> %s. State: %s", key, victim, thief, state)

if ts.state != "processing" or ts.processing_on is not victim:
self.in_flight_occupancy[thief] -= d["thief_duration"]
self.in_flight_occupancy[victim] += d["victim_duration"]

if not self.in_flight:
self.in_flight_occupancy.clear()

if self.scheduler.validate:
assert ts.processing_on == victim

try:
_log_msg = [key, state, victim.address, thief.address, stimulus_id]

if ts.state != "processing":
self.log(("not-processing", *_log_msg))
old_thief = thief.occupancy
new_thief = sum(thief.processing.values())
old_victim = victim.occupancy
Expand All @@ -219,32 +259,24 @@ async def move_task_confirm(self, key=None, worker=None, state=None):
self.scheduler.total_occupancy += (
new_thief - old_thief + new_victim - old_victim
)
return

# One of the pair has left, punt and reschedule
if (
thief.address not in self.scheduler.workers
or victim.address not in self.scheduler.workers
elif (
state in _WORKER_STATE_UNDEFINED
or state in _WORKER_STATE_CONFIRM
and thief.address not in self.scheduler.workers
):
self.log(
(
"reschedule",
thief.address not in self.scheduler.workers,
*_log_msg,
)
)
self.scheduler.reschedule(key)
return

# Victim had already started execution, reverse stealing
if state in (
"memory",
"executing",
"long-running",
"released",
"cancelled",
"resumed",
None,
):
self.log(("already-computing", key, victim.address, thief.address))
self.scheduler.check_idle_saturated(thief)
self.scheduler.check_idle_saturated(victim)

# Victim had already started execution
elif state in _WORKER_STATE_REJECT:
self.log(("already-computing", *_log_msg))
# Victim was waiting, has given up task, enact steal
elif state in ("waiting", "ready", "constrained"):
elif state in _WORKER_STATE_CONFIRM:
self.remove_key_from_stealable(ts)
ts.processing_on = thief
duration = victim.processing.pop(ts)
Expand All @@ -258,11 +290,8 @@ async def move_task_confirm(self, key=None, worker=None, state=None):
self.scheduler.total_occupancy += d["thief_duration"]
self.put_key_in_stealable(ts)

try:
self.scheduler.send_task_to_worker(thief.address, ts)
except CommClosedError:
await self.scheduler.remove_worker(thief.address)
self.log(("confirm", key, victim.address, thief.address))
self.scheduler.send_task_to_worker(thief.address, ts)
self.log(("confirm", *_log_msg))
else:
raise ValueError(f"Unexpected task state: {state}")
except Exception as e:
Expand All @@ -273,14 +302,8 @@ async def move_task_confirm(self, key=None, worker=None, state=None):
pdb.set_trace()
raise
finally:
try:
self.scheduler.check_idle_saturated(thief)
except Exception:
pass
try:
self.scheduler.check_idle_saturated(victim)
except Exception:
pass
self.scheduler.check_idle_saturated(thief)
self.scheduler.check_idle_saturated(victim)

def balance(self):
s = self.scheduler
Expand Down Expand Up @@ -413,9 +436,9 @@ def restart(self, scheduler):
self.key_stealable.clear()

def story(self, *keys):
keys = set(keys)
keys = {key.key if not isinstance(key, str) else key for key in keys}
out = []
for _, L in self.scheduler.get_event("stealing"):
for _, L in self.scheduler.get_events(topic="stealing"):
if not isinstance(L, list):
L = [L]
for t in L:
Expand Down
21 changes: 11 additions & 10 deletions distributed/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1297,16 +1297,6 @@ async def test_non_existent_worker(c, s):
assert all(ts.state == "no-worker" for ts in s.tasks.values())


@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3)
async def test_correct_bad_time_estimate(c, s, *workers):
future = c.submit(slowinc, 1, delay=0)
await wait(future)
futures = [c.submit(slowinc, future, delay=0.1, pure=False) for i in range(20)]
await asyncio.sleep(0.5)
await wait(futures)
assert all(w.data for w in workers), [sorted(w.data) for w in workers]


@pytest.mark.parametrize(
"host", ["tcp://0.0.0.0", "tcp://127.0.0.1", "tcp://127.0.0.1:38275"]
)
Expand Down Expand Up @@ -3191,3 +3181,14 @@ async def test_worker_reconnect_task_memory_with_resources(c, s, a):
assert ("no-worker", "memory") in {
(start, finish) for (_, start, finish, _, _) in s.transition_log
}


@gen_cluster(client=True, nthreads=[("", 1)] * 2)
async def test_set_restrictions(c, s, a, b):
    # Compute a task on worker b, then retroactively restrict it to worker a.

    f = c.submit(inc, 1, workers=[b.address])
    await f
    # A bare address string is accepted: set_restrictions wraps a str value
    # in a single-element set before storing it.
    s.set_restrictions(worker={f.key: a.address})
    assert s.tasks[f.key].worker_restrictions == {a.address}
    # NOTE(review): rescheduling and awaiting again presumably verifies the
    # task still completes under the new restriction — confirm intent.
    s.reschedule(f)
    await f
Loading