diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index 4bf83d939e7..934e52c4fe7 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -264,29 +264,25 @@ async def test_nanny_timeout(c, s, a): @gen_cluster( - nthreads=[("127.0.0.1", 1)] * 8, + nthreads=[("", 1)] * 8, client=True, - Worker=Worker, clean_kwargs={"threads": False}, + config={"distributed.worker.memory.pause": False}, ) -async def test_throttle_outgoing_connections(c, s, a, *workers): - # But a bunch of small data on worker a - await c.run(lambda: logging.getLogger("distributed.worker").setLevel(logging.DEBUG)) +async def test_throttle_outgoing_connections(c, s, a, *other_workers): + # Put a bunch of small data on worker a + logging.getLogger("distributed.worker").setLevel(logging.DEBUG) remote_data = c.map( lambda x: b"0" * 10000, range(10), pure=False, workers=[a.address] ) await wait(remote_data) - def pause(dask_worker): - # Disable memory_monitor on the worker - dask_worker.extensions["memory_monitor"].stop() - dask_worker.status = Status.paused - dask_worker.outgoing_current_count = 2 + a.status = Status.paused + a.outgoing_current_count = 2 - await c.run(pause, workers=[a.address]) requests = [ await a.get_data(await w.rpc.connect(w.address), keys=[f.key], who=w.address) - for w in workers + for w in other_workers for f in remote_data ] await wait(requests) @@ -295,18 +291,13 @@ def pause(dask_worker): assert "throttling" in wlogs.lower() -@gen_cluster(nthreads=[], client=True) -async def test_scheduler_address_config(c, s): +@gen_cluster(nthreads=[]) +async def test_scheduler_address_config(s): with dask.config.set({"scheduler-address": s.address}): - nanny = await Nanny(loop=s.loop) - assert nanny.scheduler.address == s.address - - start = time() - while not s.workers: - await asyncio.sleep(0.1) - assert time() < start + 10 - - await nanny.close() + async with Nanny() as nanny: + assert nanny.scheduler.address == s.address + while 
not s.workers: + await asyncio.sleep(0.01) @pytest.mark.slow diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 5968ebc7d26..c0f9b84aad5 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -19,7 +19,15 @@ from dask import delayed from dask.utils import apply, parse_timedelta, stringify, tmpfile, typename -from distributed import Client, Nanny, Worker, fire_and_forget, wait +from distributed import ( + Client, + Lock, + Nanny, + SchedulerPlugin, + Worker, + fire_and_forget, + wait, +) from distributed.compatibility import LINUX, WINDOWS from distributed.core import ConnectionPool, Status, clean_exception, connect, rpc from distributed.metrics import time @@ -2178,7 +2186,6 @@ async def test_gather_allow_worker_reconnect( """ # GH3246 if reschedule_different_worker: - from distributed.diagnostics.plugin import SchedulerPlugin class SwitchRestrictions(SchedulerPlugin): def __init__(self, scheduler): @@ -2191,8 +2198,6 @@ def transition(self, key, start, finish, **kwargs): plugin = SwitchRestrictions(s) s.add_plugin(plugin) - from distributed import Lock - b_address = b.address def inc_slow(x, lock): @@ -2214,8 +2219,9 @@ def reducer(*args): def finalizer(addr): if swap_data_insert_order: w = get_worker() - new_data = {k: w.data[k] for k in list(w.data.keys())[::-1]} - w.data = new_data + new_data = dict(reversed(list(w.data.items()))) + w.data.clear() + w.data.update(new_data) return addr z = c.submit(reducer, x, key="reducer", workers=[a.address]) @@ -3384,9 +3390,6 @@ async def test_TaskState__to_dict(c, s): @gen_cluster(nthreads=[]) async def test_idempotent_plugins(s): - - from distributed.diagnostics.plugin import SchedulerPlugin - class IdempotentPlugin(SchedulerPlugin): def __init__(self, instance=None): self.name = "idempotentplugin" @@ -3410,9 +3413,6 @@ def start(self, scheduler): @gen_cluster(nthreads=[]) async def test_non_idempotent_plugins(s): - - from 
distributed.diagnostics.plugin import SchedulerPlugin - class NonIdempotentPlugin(SchedulerPlugin): def __init__(self, instance=None): self.name = "nonidempotentplugin" diff --git a/distributed/tests/test_worker_memory.py b/distributed/tests/test_worker_memory.py index 59bed9ed4e6..094cbe15a75 100644 --- a/distributed/tests/test_worker_memory.py +++ b/distributed/tests/test_worker_memory.py @@ -516,9 +516,10 @@ async def test_override_data_worker(s): async with Worker(s.address, data=dict) as w: assert type(w.data) is dict - data = dict() + data = {"x": 1} async with Worker(s.address, data=data) as w: assert w.data is data + assert w.data == {"x": 1} class Data(dict): def __init__(self, x, y): @@ -530,11 +531,15 @@ def __init__(self, x, y): assert w.data.y == 456 -@gen_cluster(nthreads=[], client=True) -async def test_override_data_nanny(c, s): - async with Nanny(s.address, data=dict) as n: - r = await c.run(lambda dask_worker: type(dask_worker.data)) - assert r[n.worker_address] is dict +@gen_cluster( + client=True, + nthreads=[("", 1)], + Worker=Nanny, + worker_kwargs={"data": dict}, +) +async def test_override_data_nanny(c, s, n): + r = await c.run(lambda dask_worker: type(dask_worker.data)) + assert r[n.worker_address] is dict @gen_cluster( diff --git a/distributed/worker_memory.py b/distributed/worker_memory.py index 20cf1db2b7b..0add5bfd78f 100644 --- a/distributed/worker_memory.py +++ b/distributed/worker_memory.py @@ -57,7 +57,7 @@ class WorkerMemoryManager: memory_spill_fraction: float | Literal[False] memory_pause_fraction: float | Literal[False] max_spill: int | Literal[False] - memory_monitor_interval: float | None + memory_monitor_interval: float _memory_monitoring: bool _throttled_gc: ThrottledGC @@ -217,11 +217,10 @@ async def _maybe_spill(self, worker: Worker, memory: int) -> None: "Worker is at %.0f%% memory usage. 
Start spilling data to disk.", frac * 100, ) - # Implement hysteresis cycle where spilling starts at the spill threshold - # and stops at the target threshold. Normally that here the target threshold - # defines process memory, whereas normally it defines reported managed - # memory (e.g. output of sizeof() ). - # If target=False, disable hysteresis. + # Implement hysteresis cycle where spilling starts at the spill threshold and + # stops at the target threshold. Note that here the target threshold defines + # process memory, whereas normally it defines reported managed memory (e.g. + # output of sizeof() ). If target=False, disable hysteresis. target = self.memory_limit * ( self.memory_target_fraction or self.memory_spill_fraction ) @@ -296,15 +295,15 @@ def __init__( def memory_monitor(self, nanny: Nanny) -> None: """Track worker's memory. Restart if it goes above terminate fraction.""" if nanny.status != Status.running: - return + return # pragma: nocover if nanny.process is None or nanny.process.process is None: - return + return # pragma: nocover process = nanny.process.process try: proc = nanny._psutil_process memory = proc.memory_info().rss except (ProcessLookupError, psutil.NoSuchProcess, psutil.AccessDenied): - return + return # pragma: nocover if memory / self.memory_limit > self.memory_terminate_fraction: logger.warning(