From 59385090c73b7d562cc754a9a60034ea8b3e9f46 Mon Sep 17 00:00:00 2001 From: Matthias Bussonnier Date: Tue, 9 Jun 2020 15:32:59 -0700 Subject: [PATCH 1/6] Emit warning when assigning/comparing string with Status Enum. Turn this into an error in the test suite. We want to make sure of various things: 1) original behavior of comparing/assigning with strings still works but emits warnings. 1b) assigning strings converts to the proper enum variant. 2) assignment/comparison with invalid strings should fail. 3) warnings are errors in the test suite to make sure we don't re-introduce strings. This is the continuation of #3853. Maybe Cluster in cluster.py should also get status as a property. --- distributed/core.py | 20 ++++- distributed/deploy/cluster.py | 11 +-- distributed/deploy/spec.py | 55 ++++++++----- distributed/deploy/ssh.py | 3 +- distributed/deploy/tests/test_local.py | 5 +- distributed/deploy/tests/test_spec_cluster.py | 9 ++- distributed/nanny.py | 8 +- distributed/scheduler.py | 15 +--- distributed/tests/test_client.py | 13 ++-- distributed/tests/test_core.py | 77 ++++++++++++++++++- distributed/tests/test_nanny.py | 14 ++-- distributed/tests/test_scheduler.py | 22 +++--- distributed/tests/test_tls_functional.py | 3 +- distributed/tests/test_worker.py | 16 ++-- distributed/utils_test.py | 4 +- distributed/worker.py | 4 +- setup.cfg | 8 +- 17 files changed, 195 insertions(+), 92 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index 16c611142f4..8d5ac0af333 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -9,6 +9,7 @@ import traceback import uuid import weakref +import warnings import dask import tblib @@ -52,11 +53,13 @@ class Status(Enum): closing = "closing" closing_gracefully = "closing-gracefully" init = "init" + created = "created" running = "running" starting = "starting" stopped = "stopped" stopping = "stopping" undefined = None + dont_reply = "dont-reply" def __eq__(self, other): """ @@ -69,6 +72,11 @@ def __eq__(self, other): if 
isinstance(other, type(self)): return self.value == other.value elif isinstance(other, str) or (other is None): + warnings.warn( + f"Since distributed 2.19 `.status` is now an Enum, please compare with `Status.{other}`", + PendingDeprecationWarning, + stacklevel=1, + ) assert other in [ s.value for s in type(self) ], f"comparison with non-existing states {other}" @@ -261,9 +269,16 @@ def status(self, new_status): if isinstance(new_status, Status): self._status = new_status elif isinstance(new_status, str) or new_status is None: + warnings.warn( + f"Since distributed 2.19 `.status` is now an Enum, please assign `Status.{new_status}`", + PendingDeprecationWarning, + stacklevel=1, + ) corresponding_enum_variants = [s for s in Status if s.value == new_status] assert len(corresponding_enum_variants) == 1 self._status = corresponding_enum_variants[0] + else: + raise TypeError(f"expected Status or str, got {new_status}") async def finished(self): """ Wait until the server has finished """ @@ -519,7 +534,10 @@ async def handle_comm(self, comm, shutting_down=shutting_down): logger.exception(e) result = error_message(e, status="uncaught-error") - if reply and result != "dont-reply": + # result is not type stable: + # when LHS is not Status then RHS must not be Status or it raises. 
+ # when LHS is Status then RHS must be status or it raises in tests + if reply and isinstance(result, Status) and result != Status.dont_reply: try: await comm.write(result, serializers=serializers) except (EnvironmentError, TypeError) as e: diff --git a/distributed/deploy/cluster.py b/distributed/deploy/cluster.py index 3d23e62051b..511733476c6 100644 --- a/distributed/deploy/cluster.py +++ b/distributed/deploy/cluster.py @@ -10,6 +10,7 @@ from .adaptive import Adaptive +from ..core import Status from ..utils import ( log_errors, sync, @@ -57,7 +58,7 @@ def __init__(self, asynchronous): self._watch_worker_status_task = None self.scheduler_comm = None - self.status = "created" + self.status = Status.created async def _start(self): comm = await self.scheduler_comm.live_comm() @@ -67,10 +68,10 @@ async def _start(self): self._watch_worker_status_task = asyncio.ensure_future( self._watch_worker_status(comm) ) - self.status = "running" + self.status = Status.running async def _close(self): - if self.status == "closed": + if self.status == Status.closed: return if self._watch_worker_status_comm: @@ -84,14 +85,14 @@ async def _close(self): if self.scheduler_comm: await self.scheduler_comm.close_rpc() - self.status = "closed" + self.status = Status.closed def close(self, timeout=None): with suppress(RuntimeError): # loop closed during process shutdown return self.sync(self._close, callback_timeout=timeout) def __del__(self): - if self.status != "closed": + if self.status != Status.closed: with suppress(AttributeError, RuntimeError): # during closing self.loop.add_callback(self.close) diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index 6d9e1677a36..595a8057982 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -36,19 +36,39 @@ class ProcessInterface: It should implement the methods below, like ``start`` and ``close`` """ + @property + def status(self): + return self._status + + @status.setter + def status(self, new_status): + 
if isinstance(new_status, Status): + self._status = new_status + elif isinstance(new_status, str) or new_status is None: + warnings.warn( + f"Since distributed 2.19 `.status` is now an Enum, please assign `Status.{new_status}`", + PendingDeprecationWarning, + stacklevel=1, + ) + corresponding_enum_variants = [s for s in Status if s.value == new_status] + assert len(corresponding_enum_variants) == 1 + self._status = corresponding_enum_variants[0] + else: + raise TypeError(f"expected Status or str, got {new_status}") + def __init__(self, scheduler=None, name=None): self.address = getattr(self, "address", None) self.external_address = None self.lock = asyncio.Lock() - self.status = "created" + self.status = Status.created self._event_finished = asyncio.Event() def __await__(self): async def _(): async with self.lock: - if self.status == "created": + if self.status == Status.created: await self.start() - assert self.status == "running" + assert self.status == Status.running return self return _().__await__() @@ -63,7 +83,7 @@ async def start(self): For the scheduler we will expect the scheduler's ``.address`` attribute to be avaialble after this completes. 
""" - self.status = "running" + self.status = Status.running async def close(self): """ Close the process @@ -73,7 +93,7 @@ async def close(self): This method should kill the process a bit more forcefully and does not need to worry about shutting down gracefully """ - self.status = "closed" + self.status = Status.closed self._event_finished.set() async def finished(self): @@ -256,11 +276,11 @@ def __init__( self.sync(self._correct_state) async def _start(self): - while self.status == "starting": + while self.status == Status.starting: await asyncio.sleep(0.01) - if self.status == "running": + if self.status == Status.running: return - if self.status == "closed": + if self.status == Status.closed: raise ValueError("Cluster is closed") self._lock = asyncio.Lock() @@ -279,7 +299,7 @@ async def _start(self): cls = import_term(cls) self.scheduler = cls(**self.scheduler_spec.get("options", {})) - self.status = "starting" + self.status = Status.starting self.scheduler = await self.scheduler self.scheduler_comm = rpc( getattr(self.scheduler, "external_address", None) or self.scheduler.address, @@ -359,7 +379,7 @@ def f(): def __await__(self): async def _(): - if self.status == "created": + if self.status == Status.created: await self._start() await self.scheduler await self._correct_state() @@ -370,13 +390,12 @@ async def _(): return _().__await__() async def _close(self): - while self.status == "closing": + while self.status == Status.closing: await asyncio.sleep(0.1) - if self.status == "closed": + if self.status == Status.closed: return - if self.status == "running": - self.status = "closing" - + if self.status == Status.running: + self.status = Status.closing self.scale(0) await self._correct_state() for future in self._futures: @@ -402,7 +421,7 @@ async def _close(self): async def __aenter__(self): await self await self._correct_state() - assert self.status == "running" + assert self.status == Status.running return self def __exit__(self, typ, value, traceback): @@ 
-453,7 +472,7 @@ def scale(self, n=0, memory=None, cores=None): while len(self.worker_spec) > n: self.worker_spec.popitem() - if self.status not in ("closing", "closed"): + if self.status not in (Status.closing, Status.closed): while len(self.worker_spec) < n: self.worker_spec.update(self.new_worker_spec()) @@ -617,5 +636,5 @@ async def run_spec(spec: dict, *args): def close_clusters(): for cluster in list(SpecCluster._instances): with suppress(gen.TimeoutError, TimeoutError): - if cluster.status != "closed": + if cluster.status != Status.closed: cluster.close(timeout=10) diff --git a/distributed/deploy/ssh.py b/distributed/deploy/ssh.py index 595e21dbd7a..66afde1aa7c 100644 --- a/distributed/deploy/ssh.py +++ b/distributed/deploy/ssh.py @@ -7,6 +7,7 @@ import dask from .spec import SpecCluster, ProcessInterface +from ..core import Status from ..utils import cli_keywords from ..scheduler import Scheduler as _Scheduler from ..worker import Worker as _Worker @@ -130,7 +131,7 @@ async def start(self): logger.info(line.strip()) if "worker at" in line: self.address = line.split("worker at:")[1].strip() - self.status = "running" + self.status = Status.running break logger.debug("%s", line) await super().start() diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py index 59a2d0c7607..08b0bf26dd5 100644 --- a/distributed/deploy/tests/test_local.py +++ b/distributed/deploy/tests/test_local.py @@ -16,6 +16,7 @@ from dask.system import CPU_COUNT from distributed import Client, Worker, Nanny, get_client +from distributed.core import Status from distributed.deploy.local import LocalCluster, nprocesses_nthreads from distributed.metrics import time from distributed.system import MEMORY_LIMIT @@ -188,7 +189,7 @@ def test_Client_with_local(loop): def test_Client_solo(loop): with Client(loop=loop, silence_logs=False) as c: pass - assert c.cluster.status == "closed" + assert c.cluster.status == Status.closed @gen_test() @@ -223,7 +224,7 @@ def 
test_Client_kwargs(loop): with Client(loop=loop, processes=False, n_workers=2, silence_logs=False) as c: assert len(c.cluster.workers) == 2 assert all(isinstance(w, Worker) for w in c.cluster.workers.values()) - assert c.cluster.status == "closed" + assert c.cluster.status == Status.closed def test_Client_unused_kwargs_with_cluster(loop): diff --git a/distributed/deploy/tests/test_spec_cluster.py b/distributed/deploy/tests/test_spec_cluster.py index f4d6c69827b..6c573f5c7c1 100644 --- a/distributed/deploy/tests/test_spec_cluster.py +++ b/distributed/deploy/tests/test_spec_cluster.py @@ -5,6 +5,7 @@ import dask from dask.distributed import SpecCluster, Worker, Client, Scheduler, Nanny +from distributed.core import Status from distributed.compatibility import WINDOWS from distributed.deploy.spec import close_clusters, ProcessInterface, run_spec from distributed.metrics import time @@ -237,7 +238,7 @@ def test_spec_close_clusters(loop): cluster = SpecCluster(workers=workers, scheduler=scheduler, loop=loop) assert cluster in SpecCluster._instances close_clusters() - assert cluster.status == "closed" + assert cluster.status == Status.closed @pytest.mark.asyncio @@ -267,11 +268,11 @@ async def test_nanny_port(): @pytest.mark.asyncio async def test_spec_process(): proc = ProcessInterface() - assert proc.status == "created" + assert proc.status == Status.created await proc - assert proc.status == "running" + assert proc.status == Status.running await proc.close() - assert proc.status == "closed" + assert proc.status == Status.closed @pytest.mark.asyncio diff --git a/distributed/nanny.py b/distributed/nanny.py index 84ce01ffebb..db29431211b 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -293,7 +293,7 @@ async def start(self): logger.info(" Start Nanny at: %r", self.address) response = await self.instantiate() - if response == "running": + if response == Status.running: assert self.worker_address self.status = Status.running else: @@ -316,7 +316,7 @@ async 
def kill(self, comm=None, timeout=2): deadline = self.loop.time() + timeout await self.process.kill(timeout=0.8 * (deadline - self.loop.time())) - async def instantiate(self, comm=None): + async def instantiate(self, comm=None) -> Status: """ Start a local worker process Blocks until the process is up and the scheduler is properly informed @@ -535,7 +535,7 @@ def __init__( self.worker_dir = None self.worker_address = None - async def start(self): + async def start(self) -> Status: """ Ensure the worker process is started. """ @@ -584,7 +584,7 @@ async def start(self): self.worker_address = msg["address"] self.worker_dir = msg["dir"] assert self.worker_address - self.status = "running" + self.status = Status.running self.running.set() init_q.close() diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 8c469d2ed38..c24b3be447d 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -315,6 +315,8 @@ def status(self, new_status): corresponding_enum_variants = [s for s in Status if s.value == new_status] assert len(corresponding_enum_variants) == 1 self._status = corresponding_enum_variants[0] + else: + raise TypeError(f"expected Status or str, got {new_status}") @property def host(self): @@ -1404,19 +1406,6 @@ def __init__( self.rpc.allow_offload = False self.status = Status.undefined - @property - def status(self): - return self._status - - @status.setter - def status(self, new_status): - if isinstance(new_status, Status): - self._status = new_status - elif isinstance(new_status, str) or new_status is None: - corresponding_enum_variants = [s for s in Status if s.value == new_status] - assert len(corresponding_enum_variants) == 1 - self._status = corresponding_enum_variants[0] - ################## # Administration # ################## diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 98f731e14de..11dd56ab636 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ 
-40,6 +40,7 @@ TimeoutError, CancelledError, ) +from distributed.core import Status from distributed.comm import CommClosedError from distributed.client import ( Client, @@ -3690,7 +3691,7 @@ async def start_worker(sleep, duration, repeat=1): [c.sync(w.close) for w in list(workers)] for w in workers: - assert w.status == "closed" + assert w.status == Status.closed start = time() while proc.num_fds() > before: @@ -4625,7 +4626,7 @@ async def test_retire_workers(c, s, a, b): assert set(s.workers) == {b.address} start = time() - while a.status != "closed": + while a.status != Status.closed: await asyncio.sleep(0.01) assert time() < start + 5 @@ -5829,8 +5830,8 @@ async def test_shutdown(cleanup): async with Client(s.address, asynchronous=True) as c: await c.shutdown() - assert s.status == "closed" - assert w.status == "closed" + assert s.status == Status.closed + assert w.status == Status.closed @pytest.mark.asyncio @@ -5839,7 +5840,7 @@ async def test_shutdown_localcluster(cleanup): async with Client(lc, asynchronous=True) as c: await c.shutdown() - assert lc.scheduler.status == "closed" + assert lc.scheduler.status == Status.closed @pytest.mark.asyncio @@ -5978,7 +5979,7 @@ async def f(): assert result == 11 assert client.status == "closed" - assert cluster.status == "closed" + assert cluster.status == Status.closed def test_client_sync_with_async_def(loop): diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index 62b141645c7..b43b143721f 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -4,6 +4,7 @@ import sys import threading import weakref +import warnings import pytest @@ -11,6 +12,7 @@ from distributed.core import ( pingpong, Server, + Status, rpc, connect, send_recv, @@ -76,6 +78,79 @@ def echo_no_serialize(comm, x): return {"result": x} +def test_server_status_is_always_enum(): + """ + Assignments with strings get converted to corresponding Enum variant + """ + server = Server({}) + assert 
isinstance(server.status, Status) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("ignore") + assert server.status != Status.stopped + server.status = "stopped" + assert isinstance(server.status, Status) + assert server.status == Status.stopped + + +def test_server_status_assign_non_variant_raises(): + server = Server({}) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("ignore") + with pytest.raises(AssertionError): + server.status = "I do not exists" + + +def test_server_status_compare_non_variant_raises(): + server = Server({}) + # turn off warnings into error for assertion checking. + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("default") + with pytest.raises(AssertionError): + server.status == "You can't compare with me" + + +def test_server_status_assign_with_variant_warns(): + server = Server({}) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("default") + with pytest.warns(PendingDeprecationWarning): + server.status = "running" + + +def test_server_status_compare_with_variant_warns(): + server = Server({}) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("default") + with pytest.warns(PendingDeprecationWarning): + server.status == "running" + + +def test_server_status_assign_with_variant_raises_in_tests(): + """That would be the default in user code""" + server = Server({}) + with pytest.raises(PendingDeprecationWarning): + server.status = "running" + + +def test_server_status_compare_with_variant_raises_in_tests(): + """That would be the default in user code""" + server = Server({}) + with pytest.raises(PendingDeprecationWarning): + server.status == "running" + + +def test_server_assign_assign_enum_is_quiet(): + """That would be the default in user code""" + server = Server({}) + server.status = Status.running + + +def test_server_status_compare_enum_is_quiet(): + """That would be the default in user code""" + server = 
Server({}) + server.status == Status.running + + def test_server(loop): """ Simple Server test. @@ -269,7 +344,7 @@ async def check_rpc(listen_addr, rpc_addr=None, listen_args={}, connection_args= assert response == b"pong" assert not remote.comms - assert remote.status == "closed" + assert remote.status == Status.closed server.stop() await asyncio.sleep(0) diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index 88f401e2cf4..f6277aa980b 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -16,7 +16,7 @@ import dask from distributed.diagnostics import SchedulerPlugin from distributed import Nanny, rpc, Scheduler, Worker, Client, wait, worker -from distributed.core import CommClosedError +from distributed.core import CommClosedError, Status from distributed.metrics import time from distributed.protocol.pickle import dumps from distributed.utils import tmpfile, TimeoutError, parse_ports @@ -140,8 +140,8 @@ async def test_no_hang_when_scheduler_closes(s, a, b): with captured_logger("tornado.application", logging.ERROR) as logger: await s.close() await asyncio.sleep(1.2) - assert a.status == "closed" - assert b.status == "closed" + assert a.status == Status.closed + assert b.status == Status.closed out = logger.getvalue() assert "Timed out trying to connect" not in out @@ -155,7 +155,7 @@ async def test_close_on_disconnect(s, w): await s.close() start = time() - while w.status != "closed": + while w.status != Status.closed: await asyncio.sleep(0.05) assert time() < start + 9 @@ -187,7 +187,7 @@ async def test_nanny_death_timeout(s): with pytest.raises(TimeoutError): await w - assert w.status == "closed" + assert w.status == Status.closed @gen_cluster(client=True, Worker=Nanny) @@ -489,11 +489,11 @@ async def test_nanny_closes_cleanly(cleanup): with client.rpc(n.worker_address) as w: IOLoop.current().add_callback(w.terminate) start = time() - while n.status != "closed": + while n.status != Status.closed: await 
asyncio.sleep(0.01) assert time() < start + 5 - assert n.status == "closed" + assert n.status == Status.closed @pytest.mark.asyncio diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index a55a7213919..333c5099dd2 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -1127,7 +1127,7 @@ async def test_close_nanny(c, s, a, b): assert not a.is_alive() assert a.pid is None - while a.status != "closed": + while a.status != Status.closed: await asyncio.sleep(0.05) assert time() < start + 10 @@ -1136,7 +1136,7 @@ async def test_close_nanny(c, s, a, b): async def test_retire_workers_close(c, s, a, b): await s.retire_workers(close_workers=True) assert not s.workers - while a.status != "closed" and b.status != "closed": + while a.status != Status.closed and b.status != Status.closed: await asyncio.sleep(0.01) @@ -1148,7 +1148,7 @@ async def test_retire_nannies_close(c, s, a, b): start = time() - while any(n.status != "closed" for n in nannies): + while any(n.status != Status.closed for n in nannies): await asyncio.sleep(0.05) assert time() < start + 10 @@ -1543,7 +1543,7 @@ async def test_closing_scheduler_closes_workers(s, a, b): await s.close() start = time() - while a.status != "closed" or b.status != "closed": + while a.status != Status.closed or b.status != Status.closed: await asyncio.sleep(0.01) assert time() < start + 2 @@ -1613,16 +1613,16 @@ async def test_idle_timeout(c, s, a, b): await future assert s.idle_since is None or s.idle_since > beginning - assert s.status != "closed" + assert s.status != Status.closed with captured_logger("distributed.scheduler") as logs: start = time() - while s.status != "closed": + while s.status != Status.closed: await asyncio.sleep(0.01) assert time() < start + 3 start = time() - while not (a.status == "closed" and b.status == "closed"): + while not (a.status == Status.closed and b.status == Status.closed): await asyncio.sleep(0.01) assert time() < start + 
1 @@ -1686,8 +1686,8 @@ async def test_result_type(c, s, a, b): @gen_cluster() async def test_close_workers(s, a, b): await s.close(close_workers=True) - assert a.status == "closed" - assert b.status == "closed" + assert a.status == Status.closed + assert b.status == Status.closed @pytest.mark.skipif( @@ -1741,9 +1741,9 @@ async def test_adaptive_target(c, s, a, b): @pytest.mark.asyncio async def test_async_context_manager(cleanup): async with Scheduler(port=0) as s: - assert s.status == "running" + assert s.status == Status.running async with Worker(s.address) as w: - assert w.status == "running" + assert w.status == Status.running assert s.workers assert not s.workers diff --git a/distributed/tests/test_tls_functional.py b/distributed/tests/test_tls_functional.py index 7e74f74e09c..3002b0a2c43 100644 --- a/distributed/tests/test_tls_functional.py +++ b/distributed/tests/test_tls_functional.py @@ -6,6 +6,7 @@ import pytest from distributed import Scheduler, Worker, Client, Nanny, worker_client, Queue +from distributed.core import Status from distributed.client import wait from distributed.metrics import time from distributed.nanny import Nanny @@ -178,7 +179,7 @@ async def test_retire_workers(c, s, a, b): assert set(s.workers) == {b.worker_address} start = time() - while a.status != "closed": + while a.status != Status.closed: await asyncio.sleep(0.01) assert time() < start + 5 diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 268ade2602f..c13fed73ab5 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -27,7 +27,7 @@ wait, ) from distributed.compatibility import WINDOWS -from distributed.core import rpc, CommClosedError +from distributed.core import rpc, CommClosedError, Status from distributed.scheduler import Scheduler from distributed.metrics import time from distributed.worker import ( @@ -343,7 +343,7 @@ async def test_worker_waits_for_scheduler(cleanup): pass else: assert False - assert 
w.status not in ("closed", "running") + assert w.status not in (Status.closed, Status.running) await w.close(timeout=0.1) @@ -534,7 +534,7 @@ async def test_close_on_disconnect(s, w): await s.close() start = time() - while w.status != "closed": + while w.status != Status.closed: await asyncio.sleep(0.01) assert time() < start + 5 @@ -801,7 +801,7 @@ async def test_worker_death_timeout(s): assert "Worker" in str(info.value) assert "timed out" in str(info.value) or "failed to start" in str(info.value) - assert w.status == "closed" + assert w.status == Status.closed @gen_cluster(client=True) @@ -1569,7 +1569,7 @@ async def test_close_gracefully(c, s, a, b): await b.close_gracefully() - assert b.status == "closed" + assert b.status == Status.closed assert b.address not in s.workers assert mem.issubset(set(a.data)) for key in proc: @@ -1584,7 +1584,7 @@ async def test_lifetime(cleanup): async with Client(s.address, asynchronous=True) as c: futures = c.map(slowinc, range(200), delay=0.1) await asyncio.sleep(1.5) - assert b.status != "running" + assert b.status != Status.running await b.finished() assert set(b.data).issubset(a.data) # successfully moved data over @@ -1648,7 +1648,7 @@ def bad_heartbeat_worker(*args, **kwargs): await w.heartbeat() if reconnect: - assert w.status == "running" + assert w.status == Status.running else: - assert w.status == "closed" + assert w.status == Status.closed assert "Heartbeat to scheduler failed" in logger.getvalue() diff --git a/distributed/utils_test.py b/distributed/utils_test.py index ab1992b5ad9..4daf84e804c 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -914,7 +914,7 @@ async def coro(): s.validate_state() finally: if client and c.status not in ("closing", "closed"): - await c._close(fast=s.status == "closed") + await c._close(fast=s.status == Status.closed) await end_cluster(s, workers) await asyncio.wait_for(cleanup_global_workers(), 1) @@ -1505,7 +1505,7 @@ def check_instances(): ), {n: n.status for 
n in Nanny._instances} # assert not list(SpecCluster._instances) # TODO - assert all(c.status == "closed" for c in SpecCluster._instances), list( + assert all(c.status == Status.closed for c in SpecCluster._instances), list( SpecCluster._instances ) SpecCluster._instances.clear() diff --git a/distributed/worker.py b/distributed/worker.py index 00d69329c40..f8bf992157a 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1313,7 +1313,7 @@ async def get_data( } ) - return "dont-reply" + return Status.dont_reply ################### # Local Execution # @@ -3071,7 +3071,7 @@ def get_worker(): return thread_state.execution_state["worker"] except AttributeError: try: - return first(w for w in Worker._instances if w.status == "running") + return first(w for w in Worker._instances if w.status == Status.running) except StopIteration: raise ValueError("No workers found") diff --git a/setup.cfg b/setup.cfg index 764ac7ad02c..fbee38f9fea 100644 --- a/setup.cfg +++ b/setup.cfg @@ -37,14 +37,10 @@ parentdir_prefix = distributed- [tool:pytest] addopts = -rsx --durations=10 +filterwarnings = + error:Since distributed 2.19.*:PendingDeprecationWarning minversion = 3.2 markers = slow: marks tests as slow (deselect with '-m "not slow"') avoid_travis: marks tests as flaky on TravisCI. 
ipython: mark a test as exercising IPython - -# filterwarnings = -# error -# ignore::UserWarning -# ignore::ImportWarning -# ignore::PendingDeprecationWarning From 57e2fdbc757e535371b91fd30c008f7ee4480d65 Mon Sep 17 00:00:00 2001 From: Matthias Bussonnier Date: Fri, 19 Jun 2020 12:17:34 -0700 Subject: [PATCH 2/6] try to fix tests --- distributed/core.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/distributed/core.py b/distributed/core.py index 8d5ac0af333..b9668939d2e 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -537,7 +537,8 @@ async def handle_comm(self, comm, shutting_down=shutting_down): # result is not type stable: # when LHS is not Status then RHS must not be Status or it raises. # when LHS is Status then RHS must be status or it raises in tests - if reply and isinstance(result, Status) and result != Status.dont_reply: + + if reply and result != Status.dont_reply.value: try: await comm.write(result, serializers=serializers) except (EnvironmentError, TypeError) as e: From 03dbcea81cc3ab378c47c8b3ce92eaf1b71b0085 Mon Sep 17 00:00:00 2001 From: Matthias Bussonnier Date: Fri, 19 Jun 2020 13:12:45 -0700 Subject: [PATCH 3/6] more fixes --- distributed/core.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/distributed/core.py b/distributed/core.py index b9668939d2e..35f98e25a19 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -537,8 +537,11 @@ async def handle_comm(self, comm, shutting_down=shutting_down): # result is not type stable: # when LHS is not Status then RHS must not be Status or it raises. 
# when LHS is Status then RHS must be status or it raises in tests + is_dont_reply = False + if isinstance(result, Status) and (result == Status.dont_reply): + is_dont_reply = True - if reply and result != Status.dont_reply.value: + if reply and not is_dont_reply: try: await comm.write(result, serializers=serializers) except (EnvironmentError, TypeError) as e: From d8636bc0a544565c86dc4f80fbdf2845420cbba9 Mon Sep 17 00:00:00 2001 From: Matthias Bussonnier Date: Fri, 19 Jun 2020 20:50:44 -0700 Subject: [PATCH 4/6] flake8 --- distributed/deploy/spec.py | 1 + 1 file changed, 1 insertion(+) diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index 595a8057982..1d1c449f7f7 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -5,6 +5,7 @@ import logging import math import weakref +import warnings import dask from tornado import gen From bebc56855dad6938ca7a0a17042b77c44da2ef24 Mon Sep 17 00:00:00 2001 From: Matthias Bussonnier Date: Thu, 23 Jul 2020 18:37:08 -0700 Subject: [PATCH 5/6] allow packing status --- distributed/protocol/core.py | 16 +++++++++++++--- distributed/protocol/serialize.py | 24 ++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 3 deletions(-) diff --git a/distributed/protocol/core.py b/distributed/protocol/core.py index 15d9bd24e97..a997d979101 100644 --- a/distributed/protocol/core.py +++ b/distributed/protocol/core.py @@ -5,7 +5,15 @@ import msgpack from .compression import compressions, maybe_compress, decompress -from .serialize import serialize, deserialize, Serialize, Serialized, extract_serialize +from .serialize import ( + serialize, + deserialize, + Serialize, + Serialized, + extract_serialize, + msgpack_decode_default, + msgpack_encode_default, +) from .utils import frame_split_size, merge_frames, msgpack_opts from ..utils import nbytes @@ -161,7 +169,7 @@ def dumps_msgpack(msg): loads_msgpack """ header = {} - payload = msgpack.dumps(msg, use_bin_type=True) + payload = msgpack.dumps(msg, 
default=msgpack_encode_default, use_bin_type=True) fmt, payload = maybe_compress(payload) if fmt: @@ -183,7 +191,9 @@ def loads_msgpack(header, payload): """ header = bytes(header) if header: - header = msgpack.loads(header, use_list=False, **msgpack_opts) + header = msgpack.loads( + header, object_hook=msgpack_decode_default, use_list=False, **msgpack_opts + ) else: header = {} diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index fb285bd1f00..1be2b3b44d4 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -64,6 +64,30 @@ def pickle_loads(header, frames): return pickle.loads(x, buffers=buffers) +def msgpack_decode_default(obj): + """ + Custom packer/unpackerfor msgpack to support enums + """ + # avoid circular import + from ..core import Status + + if "__Status__" in obj: + obj = getattr(Status, obj["as_str"]) + return obj + + +def msgpack_encode_default(obj): + """ + Custom packer/unpackerfor msgpack to support enums + """ + # avoid circular import + from ..core import Status + + if isinstance(obj, Status): + return {"__Status__": True, "as_str": obj.name} + return obj + + def msgpack_dumps(x): try: frame = msgpack.dumps(x, use_bin_type=True) From 6f299516df229b5cac877463ae649fb66386508a Mon Sep 17 00:00:00 2001 From: Matthias Bussonnier Date: Thu, 23 Jul 2020 18:47:28 -0700 Subject: [PATCH 6/6] Generalise packer/unpacker to allow arbitrary Enum. 
--- distributed/protocol/serialize.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index 1be2b3b44d4..7d767514c6d 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -1,5 +1,7 @@ from functools import partial import traceback +import importlib +from enum import Enum import dask from dask.base import normalize_token @@ -66,25 +68,27 @@ def pickle_loads(header, frames): def msgpack_decode_default(obj): """ - Custom packer/unpackerfor msgpack to support enums + Custom packer/unpacker for msgpack to support Enums """ - # avoid circular import - from ..core import Status - - if "__Status__" in obj: - obj = getattr(Status, obj["as_str"]) + if "__Enum__" in obj: + mod = importlib.import_module(obj["__module__"]) + enum_type = getattr(mod, obj["__name__"]) + obj = getattr(enum_type, obj["name"]) return obj def msgpack_encode_default(obj): """ - Custom packer/unpackerfor msgpack to support enums + Custom packer/unpacker for msgpack to support Enums """ - # avoid circular import - from ..core import Status - if isinstance(obj, Status): - return {"__Status__": True, "as_str": obj.name} + if isinstance(obj, Enum): + return { + "__Enum__": True, + "name": obj.name, + "__module__": obj.__module__, + "__name__": type(obj).__name__, + } return obj