Skip to content

Commit

Permalink
execute() to return StateMachineEvent
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Mar 28, 2022
1 parent b68f6f4 commit fdb59c5
Show file tree
Hide file tree
Showing 3 changed files with 194 additions and 80 deletions.
7 changes: 7 additions & 0 deletions distributed/tests/test_worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
Execute,
ReleaseWorkerDataMsg,
SendMessageToScheduler,
StateMachineEvent,
TaskState,
UniqueTaskHeap,
)
Expand Down Expand Up @@ -99,3 +100,9 @@ def test_sendmsg_to_dict():
# Arbitrary sample class
smsg = ReleaseWorkerDataMsg(key="x")
assert smsg.to_dict() == {"op": "release-worker-data", "key": "x"}


@pytest.mark.parametrize("cls", StateMachineEvent.__subclasses__())
def test_event_slots(cls):
smsg = cls(**dict.fromkeys(cls.__annotations__), stimulus_id="test")
assert not hasattr(smsg, "__dict__")
223 changes: 144 additions & 79 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import bisect
import builtins
import errno
import functools
import heapq
import logging
import os
Expand All @@ -15,7 +16,6 @@
import weakref
from collections import defaultdict, deque
from collections.abc import (
Awaitable,
Callable,
Collection,
Container,
Expand Down Expand Up @@ -109,13 +109,18 @@
PROCESSING,
READY,
AddKeysMsg,
CancelComputeEvent,
Execute,
ExecuteFailureEvent,
ExecuteSuccessEvent,
InvalidTransition,
LongRunningMsg,
ReleaseWorkerDataMsg,
RescheduleEvent,
RescheduleMsg,
SendMessageToScheduler,
SerializedTask,
StateMachineEvent,
TaskErredMsg,
TaskFinishedMsg,
TaskState,
Expand All @@ -135,7 +140,9 @@
# {TaskState -> finish: TaskStateState | (finish: TaskStateState, transition *args)}
Recs: TypeAlias = "dict[TaskState, TaskStateState | tuple]"
Instructions: TypeAlias = "list[Instruction]"

else:
Recs = dict
Instructions = list

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -1775,14 +1782,7 @@ def handle_cancel_compute(self, key: str, stimulus_id: str) -> None:
is in state `waiting` or `ready`.
Nothing will happen otherwise.
"""
ts = self.tasks.get(key)
if ts and ts.state in READY | {"waiting"}:
self.log.append((key, "cancel-compute", stimulus_id, time()))
# All possible dependents of TS should not be in state Processing on
# scheduler side and therefore should not be assigned to a worker,
# yet.
assert not ts.dependents
self.transition(ts, "released", stimulus_id=stimulus_id)
self.handle_stimulus(CancelComputeEvent(key=key, stimulus_id=stimulus_id))

def handle_acquire_replicas(
self,
Expand Down Expand Up @@ -2597,31 +2597,38 @@ def transitions(self, recommendations: Recs, *, stimulus_id: str) -> None:
else:
self._handle_instructions(instructions)

def handle_stimulus(self, stim: StateMachineEvent) -> None:
with log_errors():
# self.stimulus_history.append(stim) # TODO
recs, instructions = self.handle_event(stim)
self.transitions(recs, stimulus_id=stim.stimulus_id)
self._handle_instructions(instructions)
self.ensure_computing()
self.ensure_communicating()

def _handle_stimulus_from_future(
self, future: asyncio.Future[StateMachineEvent | None]
) -> None:
with log_errors():
# This *should* never raise
stim = future.result()
if stim:
self.handle_stimulus(stim)

def _handle_instructions(self, instructions: list[Instruction]) -> None:
# TODO this method is temporary.
# See final design: https://github.com/dask/distributed/issues/5894
for inst in instructions:
if isinstance(inst, SendMessageToScheduler):
self.batched_stream.send(inst.to_dict())
elif isinstance(inst, Execute):
self.loop.add_callback(
self._async_instruction_callback,
self.execute(inst.key, stimulus_id=inst.stimulus_id),
stimulus_id=inst.stimulus_id,
)
coro = self.execute(inst.key, stimulus_id=inst.stimulus_id)
task = asyncio.create_task(coro)
# TODO track task (at the moment it's fire-and-forget)
task.add_done_callback(self._handle_stimulus_from_future)
else:
raise TypeError(inst) # pragma: nocover

async def _async_instruction_callback(
self, coro: Awaitable[tuple[Recs, Instructions]], *, stimulus_id: str
) -> None:
with log_errors():
recs, instructions = await coro
self.transitions(recs, stimulus_id=stimulus_id)
self._handle_instructions(instructions)
self.ensure_computing()
self.ensure_communicating()

def maybe_transition_long_running(
self, ts: TaskState, *, compute_duration: float, stimulus_id: str
):
Expand Down Expand Up @@ -3410,23 +3417,16 @@ def ensure_computing(self) -> None:
pdb.set_trace()
raise

async def execute(self, key: str, *, stimulus_id: str) -> tuple[Recs, Instructions]:
async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | None:
if self.status in {Status.closing, Status.closed, Status.closing_gracefully}:
return {}, []
if key not in self.tasks:
return {}, []
ts = self.tasks[key]
return None
ts = self.tasks.get(key)
if not ts:
return None
if ts.state == "cancelled":
return CancelComputeEvent(key=ts.key, stimulus_id=stimulus_id)

try:
if ts.state == "cancelled":
# This might happen if keys are canceled
logger.debug(
"Trying to execute task %s which is not in executing state anymore",
ts,
)
ts.done = True
return {ts: "released"}, []

if self.validate:
assert not ts.waiting_for_data
assert ts.state == "executing"
Expand Down Expand Up @@ -3482,59 +3482,124 @@ async def execute(self, key: str, *, stimulus_id: str) -> tuple[Recs, Instructio
finally:
self.active_keys.discard(ts.key)

key = ts.key
# key *must* be still in tasks. Releasing it directly is forbidden
# without going through cancelled
ts = self.tasks.get(key) # type: ignore
assert ts, self.story(key)
ts.done = True
result["key"] = ts.key
value = result.pop("result", None)
ts.startstops.append(
{"action": "compute", "start": result["start"], "stop": result["stop"]}
)
self.threads[ts.key] = result["thread"]
recommendations: Recs = {}

if result["op"] == "task-finished":
ts.nbytes = result["nbytes"]
ts.type = result["type"]
recommendations[ts] = ("memory", value)
if self.digests is not None:
self.digests["task-duration"].add(result["stop"] - result["start"])
elif isinstance(result.pop("actual-exception"), Reschedule):
recommendations[ts] = "rescheduled"
else:
logger.warning(
"Compute Failed\n"
"Key: %s\n"
"Function: %s\n"
"args: %s\n"
"kwargs: %s\n"
"Exception: %r\n",
ts.key,
str(funcname(function))[:1000],
convert_args_to_str(args2, max_len=1000),
convert_kwargs_to_str(kwargs2, max_len=1000),
result["exception_text"],
)
recommendations[ts] = (
"error",
result["exception"],
result["traceback"],
result["exception_text"],
result["traceback_text"],
return ExecuteSuccessEvent(
key=ts.key,
value=result["result"],
start=result["start"],
stop=result["stop"],
nbytes=result["nbytes"],
type=result["type"],
stimulus_id=stimulus_id,
)

logger.debug("Send compute response to scheduler: %s, %s", ts.key, result)
if isinstance(result["actual-exception"], Reschedule):
return RescheduleEvent(key=ts.key, stimulus_id=stimulus_id)

logger.warning(
"Compute Failed\n"
"Key: %s\n"
"Function: %s\n"
"args: %s\n"
"kwargs: %s\n"
"Exception: %r\n",
ts.key,
str(funcname(function))[:1000],
convert_args_to_str(args2, max_len=1000),
convert_kwargs_to_str(kwargs2, max_len=1000),
result["exception_text"],
)
return ExecuteFailureEvent(
key=ts.key,
start=result["start"],
stop=result["stop"],
exception=result["exception"],
traceback=result["traceback"],
exception_text=result["exception_text"],
traceback_text=result["traceback_text"],
stimulus_id=stimulus_id,
)

except Exception as exc:
logger.error(
"Exception during execution of task %s.", ts.key, exc_info=True
)
msg = error_message(exc)
recommendations = {ts: tuple(msg.values())}
return ExecuteFailureEvent(
key=ts.key,
start=None,
stop=None,
exception=msg["exception"],
traceback=msg["traceback"],
exception_text=msg["exception_text"],
traceback_text=msg["traceback_text"],
stimulus_id=stimulus_id,
)

return recommendations, []
@functools.singledispatchmethod
def handle_event(self, ev: StateMachineEvent) -> tuple[Recs, Instructions]:
raise TypeError(ev) # pragma: nocover

@handle_event.register
def _(self, ev: CancelComputeEvent) -> tuple[Recs, Instructions]:
ts = self.tasks.get(ev.key)
if not ts or ts.state not in READY | {"waiting"}:
return {}, []

self.log.append((ev.key, "cancel-compute", ev.stimulus_id, time()))
# All possible dependents of ts should not be in state Processing on
# scheduler side and therefore should not be assigned to a worker, yet.
assert not ts.dependents
ts.done = True
return {ts: "released"}, []

@handle_event.register
def _(self, ev: ExecuteSuccessEvent) -> tuple[Recs, Instructions]:
# key *must* be still in tasks. Releasing it directly is forbidden
# without going through cancelled
ts = self.tasks.get(ev.key) # type: ignore
assert ts, self.story(ev.key)

ts.done = True
ts.startstops.append({"action": "compute", "start": ev.start, "stop": ev.stop})
ts.nbytes = ev.nbytes
ts.type = ev.type
return {ts: ("memory", ev.value)}, []

@handle_event.register
def _(self, ev: ExecuteFailureEvent) -> tuple[Recs, Instructions]:
# key *must* be still in tasks. Releasing it directly is forbidden
# without going through cancelled
ts = self.tasks.get(ev.key) # type: ignore
assert ts, self.story(ev.key)

ts.done = True
if ev.start is not None and ev.stop is not None:
ts.startstops.append(
{"action": "compute", "start": ev.start, "stop": ev.stop}
)

return {
ts: (
"error",
ev.exception,
ev.traceback,
ev.exception_text,
ev.traceback_text,
)
}, []

@handle_event.register
def _(self, ev: RescheduleEvent) -> tuple[Recs, Instructions]:
# key *must* be still in tasks. Releasing it directly is forbidden
# without going through cancelled
ts = self.tasks.get(ev.key) # type: ignore
assert ts, self.story(ev.key)
return {ts: "rescheduled"}, []

def _prepare_args_for_execution(
self, ts: TaskState, args: tuple, kwargs: dict[str, Any]
Expand Down
44 changes: 43 additions & 1 deletion distributed/worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,11 +321,11 @@ class ReleaseWorkerDataMsg(SendMessageToScheduler):
key: str


# Not to be confused with RescheduleEvent below or the distributed.Reschedule Exception
@dataclass
class RescheduleMsg(SendMessageToScheduler):
op = "reschedule"

# Not to be confused with the distributed.Reschedule Exception
__slots__ = ("key", "worker")
key: str
worker: str
Expand All @@ -347,3 +347,45 @@ class AddKeysMsg(SendMessageToScheduler):
__slots__ = ("keys", "stimulus_id")
keys: list[str]
stimulus_id: str


@dataclass
class StateMachineEvent:
__slots__ = ("stimulus_id",)
stimulus_id: str


@dataclass
class ExecuteSuccessEvent(StateMachineEvent):
key: str
value: object
start: float
stop: float
nbytes: int
type: type | None
__slots__ = tuple(__annotations__) # type: ignore


@dataclass
class ExecuteFailureEvent(StateMachineEvent):
key: str
start: float | None
stop: float | None
exception: bytes # serialized
traceback: bytes # serialized
exception_text: str
traceback_text: str
__slots__ = tuple(__annotations__) # type: ignore


@dataclass
class CancelComputeEvent(StateMachineEvent):
__slots__ = ("key",)
key: str


# Not to be confused with RescheduleMsg above or the distributed.Reschedule Exception
@dataclass
class RescheduleEvent(StateMachineEvent):
__slots__ = ("key",)
key: str

0 comments on commit fdb59c5

Please sign in to comment.