cleanup
crusaderky committed Mar 5, 2022
1 parent 434d614 commit 41d6e88
Showing 4 changed files with 45 additions and 50 deletions.
37 changes: 14 additions & 23 deletions distributed/tests/test_nanny.py
@@ -264,29 +264,25 @@ async def test_nanny_timeout(c, s, a):


 @gen_cluster(
-    nthreads=[("127.0.0.1", 1)] * 8,
+    nthreads=[("", 1)] * 8,
     client=True,
-    Worker=Worker,
-    clean_kwargs={"threads": False},
+    config={"distributed.worker.memory.pause": False},
 )
-async def test_throttle_outgoing_connections(c, s, a, *workers):
-    # But a bunch of small data on worker a
-    await c.run(lambda: logging.getLogger("distributed.worker").setLevel(logging.DEBUG))
+async def test_throttle_outgoing_connections(c, s, a, *other_workers):
+    # Put a bunch of small data on worker a
+    logging.getLogger("distributed.worker").setLevel(logging.DEBUG)
     remote_data = c.map(
         lambda x: b"0" * 10000, range(10), pure=False, workers=[a.address]
     )
     await wait(remote_data)

-    def pause(dask_worker):
-        # Disable memory_monitor on the worker
-        dask_worker.extensions["memory_monitor"].stop()
-        dask_worker.status = Status.paused
-        dask_worker.outgoing_current_count = 2
+    a.status = Status.paused
+    a.outgoing_current_count = 2

-    await c.run(pause, workers=[a.address])
     requests = [
         await a.get_data(await w.rpc.connect(w.address), keys=[f.key], who=w.address)
-        for w in workers
+        for w in other_workers
         for f in remote_data
     ]
     await wait(requests)
@@ -295,18 +291,13 @@ def pause(dask_worker):
     assert "throttling" in wlogs.lower()


-@gen_cluster(nthreads=[], client=True)
-async def test_scheduler_address_config(c, s):
+@gen_cluster(nthreads=[])
+async def test_scheduler_address_config(s):
     with dask.config.set({"scheduler-address": s.address}):
-        nanny = await Nanny(loop=s.loop)
-        assert nanny.scheduler.address == s.address
-
-    start = time()
-    while not s.workers:
-        await asyncio.sleep(0.1)
-    assert time() < start + 10
-
-    await nanny.close()
+        async with Nanny() as nanny:
+            assert nanny.scheduler.address == s.address
+            while not s.workers:
+                await asyncio.sleep(0.01)


 @pytest.mark.slow
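
A note on the test_throttle_outgoing_connections rewrite above: workers started by gen_cluster live in the test process itself, so the test can set state on the Worker object `a` directly rather than shipping a pause() helper through c.run. The added config entry appears to play the same role as the removed memory_monitor shutdown: with the pause threshold disabled, the memory monitor cannot overwrite the manually assigned status. A minimal sketch of that pattern (an illustrative test, not code from this commit):

    from distributed.core import Status
    from distributed.utils_test import gen_cluster

    @gen_cluster(
        client=True,
        nthreads=[("", 1)],
        # Assumption: disabling the pause threshold keeps the memory monitor
        # from flipping a manually assigned status back to "running".
        config={"distributed.worker.memory.pause": False},
    )
    async def test_manual_pause(c, s, a):
        a.status = Status.paused  # direct mutation; a is an in-process Worker
        assert a.status == Status.paused
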
24 changes: 12 additions & 12 deletions distributed/tests/test_scheduler.py
@@ -19,7 +19,15 @@
 from dask import delayed
 from dask.utils import apply, parse_timedelta, stringify, tmpfile, typename

-from distributed import Client, Nanny, Worker, fire_and_forget, wait
+from distributed import (
+    Client,
+    Lock,
+    Nanny,
+    SchedulerPlugin,
+    Worker,
+    fire_and_forget,
+    wait,
+)
 from distributed.compatibility import LINUX, WINDOWS
 from distributed.core import ConnectionPool, Status, clean_exception, connect, rpc
 from distributed.metrics import time
@@ -2178,7 +2186,6 @@ async def test_gather_allow_worker_reconnect(
     """
     # GH3246
     if reschedule_different_worker:
-        from distributed.diagnostics.plugin import SchedulerPlugin

         class SwitchRestrictions(SchedulerPlugin):
             def __init__(self, scheduler):
@@ -2191,8 +2198,6 @@ def transition(self, key, start, finish, **kwargs):
     plugin = SwitchRestrictions(s)
     s.add_plugin(plugin)

-    from distributed import Lock
-
     b_address = b.address

     def inc_slow(x, lock):
@@ -2214,8 +2219,9 @@ def reducer(*args):
     def finalizer(addr):
         if swap_data_insert_order:
             w = get_worker()
-            new_data = {k: w.data[k] for k in list(w.data.keys())[::-1]}
-            w.data = new_data
+            new_data = dict(reversed(list(w.data.items())))
+            w.data.clear()
+            w.data.update(new_data)
         return addr

     z = c.submit(reducer, x, key="reducer", workers=[a.address])
@@ -3384,9 +3390,6 @@ async def test_TaskState__to_dict(c, s):

 @gen_cluster(nthreads=[])
 async def test_idempotent_plugins(s):
-
-    from distributed.diagnostics.plugin import SchedulerPlugin
-
     class IdempotentPlugin(SchedulerPlugin):
         def __init__(self, instance=None):
             self.name = "idempotentplugin"
@@ -3410,9 +3413,6 @@ def start(self, scheduler):

 @gen_cluster(nthreads=[])
 async def test_non_idempotent_plugins(s):
-
-    from distributed.diagnostics.plugin import SchedulerPlugin
-
     class NonIdempotentPlugin(SchedulerPlugin):
         def __init__(self, instance=None):
             self.name = "nonidempotentplugin"
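
An aside on the finalizer change above: building the reversed mapping and writing it back through clear()/update() reverses the insertion order while keeping w.data the same object, whereas the old code rebound w.data to a fresh dict. A standalone sketch of the idiom:

    # Reverse a mapping's insertion order in place, preserving object identity.
    d = {"a": 1, "b": 2, "c": 3}
    new_data = dict(reversed(list(d.items())))
    d.clear()
    d.update(new_data)
    assert list(d) == ["c", "b", "a"]  # reversed order, same dict object
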
17 changes: 11 additions & 6 deletions distributed/tests/test_worker_memory.py
@@ -516,9 +516,10 @@ async def test_override_data_worker(s):
     async with Worker(s.address, data=dict) as w:
         assert type(w.data) is dict

-    data = dict()
+    data = {"x": 1}
     async with Worker(s.address, data=data) as w:
         assert w.data is data
+        assert w.data == {"x": 1}

     class Data(dict):
         def __init__(self, x, y):
@@ -530,11 +531,15 @@ def __init__(self, x, y):
         assert w.data.y == 456


-@gen_cluster(nthreads=[], client=True)
-async def test_override_data_nanny(c, s):
-    async with Nanny(s.address, data=dict) as n:
-        r = await c.run(lambda dask_worker: type(dask_worker.data))
-        assert r[n.worker_address] is dict
+@gen_cluster(
+    client=True,
+    nthreads=[("", 1)],
+    Worker=Nanny,
+    worker_kwargs={"data": dict},
+)
+async def test_override_data_nanny(c, s, n):
+    r = await c.run(lambda dask_worker: type(dask_worker.data))
+    assert r[n.worker_address] is dict


 @gen_cluster(
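
For context, the rewritten test_override_data_nanny leans on gen_cluster to start a Nanny-managed worker and forward keyword arguments to it, instead of constructing the Nanny by hand. Roughly (a sketch of the assumed harness behaviour, mirroring the diff above; the test name is made up):

    from distributed import Nanny
    from distributed.utils_test import gen_cluster

    @gen_cluster(
        client=True,
        nthreads=[("", 1)],
        Worker=Nanny,                  # start a Nanny instead of a bare Worker
        worker_kwargs={"data": dict},  # forwarded to the underlying Worker
    )
    async def test_sketch(c, s, n):
        # The actual worker runs in a subprocess, so it is inspected via c.run
        # rather than touched directly.
        types = await c.run(lambda dask_worker: type(dask_worker.data))
        assert types[n.worker_address] is dict
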
17 changes: 8 additions & 9 deletions distributed/worker_memory.py
@@ -57,7 +57,7 @@ class WorkerMemoryManager:
     memory_spill_fraction: float | Literal[False]
     memory_pause_fraction: float | Literal[False]
     max_spill: int | Literal[False]
-    memory_monitor_interval: float | None
+    memory_monitor_interval: float
     _memory_monitoring: bool
     _throttled_gc: ThrottledGC

@@ -217,11 +217,10 @@ async def _maybe_spill(self, worker: Worker, memory: int) -> None:
             "Worker is at %.0f%% memory usage. Start spilling data to disk.",
             frac * 100,
         )
-        # Implement hysteresis cycle where spilling starts at the spill threshold
-        # and stops at the target threshold. Normally that here the target threshold
-        # defines process memory, whereas normally it defines reported managed
-        # memory (e.g. output of sizeof() ).
-        # If target=False, disable hysteresis.
+        # Implement hysteresis cycle where spilling starts at the spill threshold and
+        # stops at the target threshold. Note that here the target threshold defines
+        # process memory, whereas normally it defines reported managed memory (e.g.
+        # output of sizeof() ). If target=False, disable hysteresis.
         target = self.memory_limit * (
             self.memory_target_fraction or self.memory_spill_fraction
         )
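
The hysteresis comment above may be easier to see with numbers. A self-contained sketch (the limit and fractions are illustrative assumptions, not values from this commit):

    # Start spilling above the spill threshold; keep evicting until process
    # memory falls below the lower target threshold.
    GiB = 2**30
    memory_limit = 10 * GiB
    memory_spill_fraction = 0.7   # start spilling above 7 GiB
    memory_target_fraction = 0.6  # stop spilling below 6 GiB

    def start_spilling(process_memory: int) -> bool:
        return process_memory > memory_limit * memory_spill_fraction

    def stop_spilling(process_memory: int) -> bool:
        # target=False disables hysteresis: spilling then stops as soon as
        # memory drops back below the spill threshold itself.
        fraction = memory_target_fraction or memory_spill_fraction
        return process_memory < memory_limit * fraction

    assert start_spilling(8 * GiB)            # 80% > 70%: begin evicting
    assert not stop_spilling(int(6.5 * GiB))  # 65% still above 60%: keep going
    assert stop_spilling(5 * GiB)             # 50% < 60%: stop
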
Expand Down Expand Up @@ -296,15 +295,15 @@ def __init__(
def memory_monitor(self, nanny: Nanny) -> None:
"""Track worker's memory. Restart if it goes above terminate fraction."""
if nanny.status != Status.running:
return
return # pragma: nocover
if nanny.process is None or nanny.process.process is None:
return
return # pragma: nocover
process = nanny.process.process
try:
proc = nanny._psutil_process
memory = proc.memory_info().rss
except (ProcessLookupError, psutil.NoSuchProcess, psutil.AccessDenied):
return
return # pragma: nocover

if memory / self.memory_limit > self.memory_terminate_fraction:
logger.warning(