Skip to content

Commit

Permalink
Migrate ensure_executing to WorkerState event mechanism (dask#6003)
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Mar 30, 2022
1 parent cced80d commit 8d09d74
Show file tree
Hide file tree
Showing 4 changed files with 310 additions and 152 deletions.
2 changes: 1 addition & 1 deletion distributed/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,7 @@ async def handle_comm(self, comm):
"Failed while closing connection to %r: %s", address, e
)

async def handle_stream(self, comm, extra=None, every_cycle=[]):
async def handle_stream(self, comm, extra=None, every_cycle=()):
extra = extra or {}
logger.info("Starting established connection")

Expand Down
31 changes: 27 additions & 4 deletions distributed/tests/test_worker_state_machine.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from itertools import chain

import pytest

from distributed.utils import recursive_to_dict
from distributed.worker_state_machine import (
Instruction,
ReleaseWorkerDataMsg,
SendMessageToScheduler,
StateMachineEvent,
TaskState,
UniqueTaskHeap,
)
Expand Down Expand Up @@ -82,13 +86,32 @@ def test_unique_task_heap():
assert repr(heap) == "<UniqueTaskHeap: 0 items>"


@pytest.mark.parametrize("cls", SendMessageToScheduler.__subclasses__())
def test_sendmsg_slots(cls):
smsg = cls(**dict.fromkeys(cls.__annotations__))
assert not hasattr(smsg, "__dict__")
@pytest.mark.parametrize(
"cls",
chain(
[UniqueTaskHeap],
Instruction.__subclasses__(),
SendMessageToScheduler.__subclasses__(),
StateMachineEvent.__subclasses__(),
),
)
def test_slots(cls):
params = [
k
for k in dir(cls)
if not k.startswith("_") and k != "op" and not callable(getattr(cls, k))
]
inst = cls(**dict.fromkeys(params))
assert not hasattr(inst, "__dict__")


def test_sendmsg_to_dict():
# Arbitrary sample class
smsg = ReleaseWorkerDataMsg(key="x")
assert smsg.to_dict() == {"op": "release-worker-data", "key": "x"}


@pytest.mark.parametrize("cls", StateMachineEvent.__subclasses__())
def test_event_slots(cls):
smsg = cls(**dict.fromkeys(cls.__annotations__), stimulus_id="test")
assert not hasattr(smsg, "__dict__")
Loading

0 comments on commit 8d09d74

Please sign in to comment.