Skip to content

Commit

Permalink
Optimize transitions (#4451)
Browse files Browse the repository at this point in the history
* Annotate `keys` & `new`

* Combine `recommendations` annotation with others

* Annotate `dependents` & `dependencies`

* Annotate `start` & `finish`

* Create empty `dict` for `recommendations` once

Instead of letting Cython generate empty `dict`s for each of these
cases, just create an empty `dict` once and assign it to
`recommendations`. That way we can just `return` it simply and avoid the
C boilerplate that would otherwise be needed.

* Use `.get(...)` to retrieve `TaskState`

* Assign `start, finish` to a variable

* Just use `.get(...)` to retrieve transition func

Avoids checking for the presence of the key and then retrieving the
function corresponding the key by simply trying to get the function in
the first place or `None` if it is absent. As it is pretty quick to
check if something is `None` both in Python and Cython, this should
speed up the check and function retrieval time.

* Annotate `a` & `b`

* Use `.get(...)` to get `key` from `a`

Avoids looking up `key` twice. Once to see if it is there and a second
time to grab it. This way we just grab the value corresponding to `key`
or `None` if it is missing. The following `None` check is quite fast in
both Python and Cython.

* Just `update` `recommendations` with `a` & `b`

* Drop unneeded `KeyError` handling

Neither of these statements should raise a `KeyError`. So just drop this
`try...except...`.

* Annotate `finish2`

* Replace generator with simple `for`-loop

This avoids building a `list`, which makes it easier for Cython to
optimize.

* Bind `tuple` results to typed variable

This should simplify the C code generated by Cython to unpack the
`tuple` as it no longer needs to check if it is a `list` or some other
sequence that needs to be unpacked and can simply use the `tuple`
unpacking logic.

* Collect `list` of messages for clients and workers

* Extend `BatchedSend`'s `send` to take many msgs

* Add `send_all` method and use in `transition`

This allows us to batch all worker and client sends into a single
function.

* Deliver all messages to batched send

* Refactor out private `_transition` function

* Send all messages after processing all transitions

* `declare` `ALL_TASK_STATES` a `set`
  • Loading branch information
jakirkham authored Feb 2, 2021
1 parent 5eaba1a commit 98570fb
Show file tree
Hide file tree
Showing 2 changed files with 174 additions and 57 deletions.
6 changes: 3 additions & 3 deletions distributed/batched.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,16 +127,16 @@ def _background_send(self):
self.stopped.set()
self.abort()

def send(self, msg):
def send(self, *msgs):
"""Schedule a message for sending to the other side
This completes quickly and synchronously
"""
if self.comm is not None and self.comm.closed():
raise CommClosedError

self.message_count += 1
self.buffer.append(msg)
self.message_count += len(msgs)
self.buffer.extend(msgs)
# Avoid spurious wakeups if possible
if self.next_deadline is None:
self.waker.set()
Expand Down
225 changes: 171 additions & 54 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,10 @@ def nogil(func):
EventExtension,
]

ALL_TASK_STATES = {"released", "waiting", "no-worker", "processing", "erred", "memory"}
ALL_TASK_STATES = declare(
set, {"released", "waiting", "no-worker", "processing", "erred", "memory"}
)
globals()["ALL_TASK_STATES"] = ALL_TASK_STATES


@final
Expand Down Expand Up @@ -1961,7 +1964,7 @@ def transition_waiting_processing(self, key):

# logger.debug("Send job to worker: %s, %s", worker, key)

worker_msgs[worker] = _task_to_msg(self, ts)
worker_msgs[worker] = [_task_to_msg(self, ts)]

return {}, worker_msgs, client_msgs
except Exception as e:
Expand Down Expand Up @@ -2168,11 +2171,13 @@ def transition_memory_released(self, key, safe: bint = False):
ws._has_what.remove(ts)
ws._nbytes -= ts.get_nbytes()
ts._group._nbytes_in_memory -= ts.get_nbytes()
worker_msgs[ws._address] = {
"op": "delete-data",
"keys": [key],
"report": False,
}
worker_msgs[ws._address] = [
{
"op": "delete-data",
"keys": [key],
"report": False,
}
]

ts._who_has.clear()

Expand All @@ -2181,7 +2186,7 @@ def transition_memory_released(self, key, safe: bint = False):
report_msg = {"op": "lost-data", "key": key}
cs: ClientState
for cs in ts._who_wants:
client_msgs[cs._client_key] = report_msg
client_msgs[cs._client_key] = [report_msg]

if not ts._run_spec: # pure data
recommendations[key] = "forgotten"
Expand Down Expand Up @@ -2234,7 +2239,7 @@ def transition_released_erred(self, key):
}
cs: ClientState
for cs in ts._who_wants:
client_msgs[cs._client_key] = report_msg
client_msgs[cs._client_key] = [report_msg]

ts.state = "erred"

Expand Down Expand Up @@ -2276,7 +2281,7 @@ def transition_erred_released(self, key):
report_msg = {"op": "task-retried", "key": key}
cs: ClientState
for cs in ts._who_wants:
client_msgs[cs._client_key] = report_msg
client_msgs[cs._client_key] = [report_msg]

ts.state = "released"

Expand Down Expand Up @@ -2343,7 +2348,7 @@ def transition_processing_released(self, key):

w: str = _remove_from_processing(self, ts)
if w:
worker_msgs[w] = {"op": "release-task", "key": key}
worker_msgs[w] = [{"op": "release-task", "key": key}]

ts.state = "released"

Expand Down Expand Up @@ -2432,7 +2437,7 @@ def transition_processing_erred(
}
cs: ClientState
for cs in ts._who_wants:
client_msgs[cs._client_key] = report_msg
client_msgs[cs._client_key] = [report_msg]

cs = self._clients["fire-and-forget"]
if ts in cs._wants_what:
Expand Down Expand Up @@ -4706,6 +4711,29 @@ def client_send(self, client, msg):
if self.status == Status.running:
logger.critical("Tried writing to closed comm: %s", msg)

def send_all(self, client_msgs: dict, worker_msgs: dict):
"""Send messages to client and workers"""
stream_comms: dict = self.stream_comms
client_comms: dict = self.client_comms
msgs: list

for worker, msgs in worker_msgs.items():
try:
w = stream_comms[worker]
w.send(*msgs)
except (CommClosedError, AttributeError):
self.loop.add_callback(self.remove_worker, address=worker)

for client, msgs in client_msgs.items():
c = client_comms.get(client)
if c is None:
continue
try:
c.send(*msgs)
except CommClosedError:
if self.status == Status.running:
logger.critical("Tried writing to closed comm: %s", msgs)

############################
# Less common interactions #
############################
Expand Down Expand Up @@ -5814,12 +5842,12 @@ async def register_worker_plugin(self, comm, plugin, name=None):
# State Transitions #
#####################

def transition(self, key, finish, *args, **kwargs):
def _transition(self, key, finish: str, *args, **kwargs):
"""Transition a key from its current state to the finish state
Examples
--------
>>> self.transition('x', 'waiting')
>>> self._transition('x', 'waiting')
{'x': 'processing'}
Returns
Expand All @@ -5832,47 +5860,85 @@ def transition(self, key, finish, *args, **kwargs):
"""
parent: SchedulerState = cast(SchedulerState, self)
ts: TaskState
start: str
start_finish: tuple
finish2: str
recommendations: dict
worker_msgs: dict
client_msgs: dict
msgs: list
new_msgs: list
dependents: set
dependencies: set
try:
try:
ts = parent._tasks[key]
except KeyError:
return {}
recommendations = {}
worker_msgs = {}
client_msgs = {}

ts = parent._tasks.get(key)
if ts is None:
return recommendations, worker_msgs, client_msgs
start = ts._state
if start == finish:
return {}
return recommendations, worker_msgs, client_msgs

if self.plugins:
dependents = set(ts._dependents)
dependencies = set(ts._dependencies)

recommendations: dict = {}
worker_msgs = {}
client_msgs = {}
if (start, finish) in self._transitions:
func = self._transitions[start, finish]
recommendations, worker_msgs, client_msgs = func(key, *args, **kwargs)
elif "released" not in (start, finish):
start_finish = (start, finish)
func = self._transitions.get(start_finish)
if func is not None:
a: tuple = func(key, *args, **kwargs)
recommendations, worker_msgs, client_msgs = a
elif "released" not in start_finish:
func = self._transitions["released", finish]
assert not args and not kwargs
a = self.transition(key, "released")
if key in a:
func = self._transitions["released", a[key]]
b, worker_msgs, client_msgs = func(key)
a = a.copy()
a.update(b)
recommendations = a
a_recs: dict
a_wmsgs: dict
a_cmsgs: dict
a: tuple = self._transition(key, "released")
a_recs, a_wmsgs, a_cmsgs = a
v = a_recs.get(key)
if v is not None:
func = self._transitions["released", v]
b_recs: dict
b_wmsgs: dict
b_cmsgs: dict
b: tuple = func(key)
b_recs, b_wmsgs, b_cmsgs = b

recommendations.update(a_recs)
for w, new_msgs in a_wmsgs.items():
msgs = worker_msgs.get(w)
if msgs is not None:
msgs.extend(new_msgs)
else:
worker_msgs[w] = new_msgs
for c, new_msgs in a_cmsgs.items():
msgs = client_msgs.get(c)
if msgs is not None:
msgs.extend(new_msgs)
else:
client_msgs[c] = new_msgs

recommendations.update(b_recs)
for w, new_msgs in b_wmsgs.items():
msgs = worker_msgs.get(w)
if msgs is not None:
msgs.extend(new_msgs)
else:
worker_msgs[w] = new_msgs
for c, new_msgs in b_cmsgs.items():
msgs = client_msgs.get(c)
if msgs is not None:
msgs.extend(new_msgs)
else:
client_msgs[c] = new_msgs

start = "released"
else:
raise RuntimeError(
"Impossible transition from %r to %r" % (start, finish)
)

for worker, msg in worker_msgs.items():
self.worker_send(worker, msg)
for client, msg in client_msgs.items():
self.client_send(client, msg)
raise RuntimeError("Impossible transition from %r to %r" % start_finish)

finish2 = ts._state
self.transition_log.append((key, start, finish2, recommendations, time()))
Expand All @@ -5888,11 +5954,8 @@ def transition(self, key, finish, *args, **kwargs):
if self.plugins:
# Temporarily put back forgotten key for plugin to retrieve it
if ts._state == "forgotten":
try:
ts._dependents = dependents
ts._dependencies = dependencies
except KeyError:
pass
ts._dependents = dependents
ts._dependencies = dependencies
parent._tasks[ts._key] = ts
for plugin in list(self.plugins):
try:
Expand All @@ -5905,11 +5968,16 @@ def transition(self, key, finish, *args, **kwargs):
tg: TaskGroup = ts._group
if ts._state == "forgotten" and tg._name in parent._task_groups:
# Remove TaskGroup if all tasks are in the forgotten state
if not any([tg._states.get(s) for s in ALL_TASK_STATES]):
all_forgotten: bint = True
for s in ALL_TASK_STATES:
if tg._states.get(s):
all_forgotten = False
break
if all_forgotten:
ts._prefix._groups.remove(tg)
del parent._task_groups[tg._name]

return recommendations
return recommendations, worker_msgs, client_msgs
except Exception as e:
logger.exception("Error transitioning %r from %r to %r", key, start, finish)
if LOG_PDB:
Expand All @@ -5918,20 +5986,69 @@ def transition(self, key, finish, *args, **kwargs):
pdb.set_trace()
raise

def transition(self, key, finish: str, *args, **kwargs):
"""Transition a key from its current state to the finish state
Examples
--------
>>> self.transition('x', 'waiting')
{'x': 'processing'}
Returns
-------
Dictionary of recommendations for future transitions
See Also
--------
Scheduler.transitions: transitive version of this function
"""
recommendations: dict
worker_msgs: dict
client_msgs: dict
a: tuple = self._transition(key, finish, *args, **kwargs)
recommendations, worker_msgs, client_msgs = a
self.send_all(client_msgs, worker_msgs)
return recommendations

def transitions(self, recommendations: dict):
"""Process transitions until none are left
This includes feedback from previous transitions and continues until we
reach a steady state
"""
parent: SchedulerState = cast(SchedulerState, self)
keys = set()
keys: set = set()
recommendations = recommendations.copy()
worker_msgs: dict = {}
client_msgs: dict = {}
msgs: list
new_msgs: list
new: tuple
new_recs: dict
new_wmsgs: dict
new_cmsgs: dict
while recommendations:
key, finish = recommendations.popitem()
keys.add(key)
new = self.transition(key, finish)
recommendations.update(new)

new = self._transition(key, finish)
new_recs, new_wmsgs, new_cmsgs = new

recommendations.update(new_recs)
for w, new_msgs in new_wmsgs.items():
msgs = worker_msgs.get(w)
if msgs is not None:
msgs.extend(new_msgs)
else:
worker_msgs[w] = new_msgs
for c, new_msgs in new_cmsgs.items():
msgs = client_msgs.get(c)
if msgs is not None:
msgs.extend(new_msgs)
else:
client_msgs[c] = new_msgs

self.send_all(client_msgs, worker_msgs)

if parent._validate:
for key in keys:
Expand Down Expand Up @@ -6513,7 +6630,7 @@ def _add_to_memory(
report_msg["type"] = type

for cs in ts._who_wants:
client_msgs[cs._client_key] = report_msg
client_msgs[cs._client_key] = [report_msg]

ts.state = "memory"
ts._type = typename
Expand Down Expand Up @@ -6567,7 +6684,7 @@ def _propagate_forgotten(
ws._nbytes -= ts.get_nbytes()
w: str = ws._address
if w in state._workers_dv: # in case worker has died
worker_msgs[w] = {"op": "delete-data", "keys": [key], "report": False}
worker_msgs[w] = [{"op": "delete-data", "keys": [key], "report": False}]
ts._who_has.clear()


Expand Down Expand Up @@ -6674,7 +6791,7 @@ def _task_to_client_msgs(state: SchedulerState, ts: TaskState) -> dict:

client_msgs: dict = {}
for k in client_keys:
client_msgs[k] = report_msg
client_msgs[k] = [report_msg]

return client_msgs

Expand Down

0 comments on commit 98570fb

Please sign in to comment.