Harden decorators
Fix flaky tests
crusaderky committed Apr 16, 2020
1 parent 2981c31 commit faef1ba
Showing 13 changed files with 61 additions and 60 deletions.
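Taken together, the changes harden the test decorators in distributed/utils_test.py: @gen_cluster and @gen_test previously fell back to tornado's gen.coroutine when handed a plain function, and they now raise ValueError unless the test is an `async def` coroutine function, so every plain `def` test below is converted to match. A minimal sketch of the pattern the hardened decorators enforce (test names are illustrative, not from this diff):

import asyncio

from distributed.utils_test import gen_cluster, gen_test


@gen_cluster(client=True)
async def test_roundtrip(c, s, a, b):
    # c is an asynchronous Client, s the Scheduler, a and b Workers,
    # all started on the test's event loop by the decorator
    future = c.submit(lambda x: x + 1, 1)
    assert await future == 2


@gen_test()
async def test_coroutine_only():
    # a plain `def` here now raises ValueError instead of being
    # silently wrapped in gen.coroutine
    await asyncio.sleep(0.01)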
1 change: 0 additions & 1 deletion distributed/client.py
@@ -102,7 +102,6 @@ def _get_global_client():
             return c
         else:
             del _global_clients[k]
-    del L
     return None


2 changes: 1 addition & 1 deletion distributed/dashboard/tests/test_scheduler_bokeh.py
@@ -62,7 +62,7 @@ async def test_simple(c, s, a, b):


 @gen_cluster(client=True, worker_kwargs={"dashboard": True})
-def test_basic(c, s, a, b):
+async def test_basic(c, s, a, b):
     for component in [TaskStream, SystemMonitor, Occupancy, StealingTimeSeries]:
         ss = component(s)

4 changes: 2 additions & 2 deletions distributed/deploy/tests/test_adaptive.py
@@ -215,7 +215,7 @@ async def test_avoid_churn(cleanup):
     assert len(adapt.log) == 1


-@gen_test(timeout=None)
+@pytest.mark.asyncio
 async def test_adapt_quickly():
     """ We want to avoid creating and deleting workers frequently

@@ -332,7 +332,7 @@ def test_basic_no_loop(loop):
         loop.add_callback(loop.stop)


-@gen_test(timeout=None)
+@pytest.mark.asyncio
 async def test_target_duration():
     """ Ensure that redefining adapt with a lower maximum removes workers """
     with dask.config.set(
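These two adaptive tests move to pytest-asyncio instead of being converted: pytest awaits the coroutine on its own event loop, and, as with the old timeout=None, no timeout is imposed. The resulting shape, with the body elided:

import pytest


@pytest.mark.asyncio
async def test_adapt_quickly():
    ...  # body unchanged; pytest-asyncio drives the coroutine directly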
5 changes: 4 additions & 1 deletion distributed/tests/test_client.py
@@ -2293,7 +2293,10 @@ async def test_cancel_collection(c, s, a, b):
     await c.cancel(x)
     await c.cancel([x])
     assert all(f.cancelled() for f in L)
-    assert not s.tasks
+    start = time()
+    while s.tasks:
+        assert time() < start + 1
+        await asyncio.sleep(0.01)


def test_cancel(c):
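The hunk above is this commit's recurring deflaking idiom: rather than asserting that the scheduler has already forgotten the cancelled tasks, poll until it has, bounded by a deadline. The same idiom as a free-standing helper, sketched here (poll_until is a hypothetical name, not one from the codebase):

import asyncio
from time import time


async def poll_until(condition, timeout=1.0, interval=0.01):
    # re-check `condition` until it holds; fail the test if the deadline passes
    start = time()
    while not condition():
        assert time() < start + timeout
        await asyncio.sleep(interval)


# e.g. in the test above: await poll_until(lambda: not s.tasks)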
2 changes: 1 addition & 1 deletion distributed/tests/test_core.py
@@ -732,7 +732,7 @@ async def f():


 @gen_cluster()
-def test_thread_id(s, a, b):
+async def test_thread_id(s, a, b):
     assert s.thread_id == a.thread_id == b.thread_id == threading.get_ident()


5 changes: 3 additions & 2 deletions distributed/tests/test_nanny.py
@@ -28,7 +28,8 @@
 )


-@gen_cluster(nthreads=[])
+# FIXME why does this leave behind unclosed Comm objects?
+@gen_cluster(nthreads=[], allow_unclosed=True)
 async def test_nanny(s):
     async with Nanny(s.address, nthreads=2, loop=s.loop) as n:
         async with rpc(n.address) as nn:
@@ -68,7 +69,7 @@ async def test_many_kills(s):


 @gen_cluster(Worker=Nanny)
-def test_str(s, a, b):
+async def test_str(s, a, b):
     assert a.worker_address in str(a)
     assert a.worker_address in repr(a)
     assert str(a.nthreads) in str(a)
10 changes: 5 additions & 5 deletions distributed/tests/test_scheduler.py
@@ -49,7 +49,7 @@


 @gen_cluster()
-def test_administration(s, a, b):
+async def test_administration(s, a, b):
     assert isinstance(s.address, str)
     assert s.address in str(s)
     assert str(sum(s.nthreads.values())) in repr(s)
@@ -478,7 +478,7 @@ def test_dumps_task():


 @gen_cluster()
-def test_ready_remove_worker(s, a, b):
+async def test_ready_remove_worker(s, a, b):
     s.update_graph(
         tasks={"x-%d" % i: dumps_task((inc, i)) for i in range(20)},
         keys=["x-%d" % i for i in range(20)],
@@ -1279,7 +1279,7 @@ async def test_reschedule(c, s, a, b):


 @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2)
-def test_reschedule_warns(c, s, a, b):
+async def test_reschedule_warns(c, s, a, b):
     with captured_logger(logging.getLogger("distributed.scheduler")) as sched:
         s.reschedule(key="__this-key-does-not-exist__")

@@ -1515,7 +1515,7 @@ def qux(x):


 @gen_cluster(client=True)
-def test_collect_versions(c, s, a, b):
+async def test_collect_versions(c, s, a, b):
     cs = s.clients[c.id]
     (w1, w2) = s.workers.values()
     assert cs.versions
@@ -1584,7 +1584,7 @@ async def f(dask_worker):


 @gen_cluster()
-def test_workerstate_clean(s, a, b):
+async def test_workerstate_clean(s, a, b):
     ws = s.workers[a.address].clean()
     assert ws.address == a.address
     b = pickle.dumps(ws)
6 changes: 2 additions & 4 deletions distributed/tests/test_semaphore.py
@@ -325,17 +325,15 @@ async def test_oversubscribing_leases(c, s, a, b):
     accept new leases as long as the semaphore is oversubscribed.
     Oversubscription may occur if tasks hold the GIL for a longer time than the
-    lease-timeout is configured causing the lease refreshs to go stale and
-    timeout.
+    lease-timeout is configured causing the lease refresh to go stale and timeout.
     We cannot protect ourselves entirely from this but we can ensure that while
     a task with a timed out lease is still running, we block further
     acquisitions until we return to normal.
     An example would be a task which continuously locks the GIL for a longer
-    time than the lease timeout but this continous lock only makes up a
+    time than the lease timeout but this continuous lock only makes up a
     fraction of the tasks runtime.
     """
     # GH3705

4 changes: 3 additions & 1 deletion distributed/tests/test_steal.py
@@ -569,7 +569,9 @@ async def assert_balanced(inp, expected, c, s, *workers):
     ],
 )
 def test_balance(inp, expected):
-    test = lambda *args, **kwargs: assert_balanced(inp, expected, *args, **kwargs)
+    async def test(*args, **kwargs):
+        await assert_balanced(inp, expected, *args, **kwargs)

     test = gen_cluster(
         client=True,
         nthreads=[("127.0.0.1", 1)] * len(inp),
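A lambda cannot be declared async, so once the decorators stop wrapping plain callables in gen.coroutine, the thunk forwarding to assert_balanced must become a named coroutine function. A self-contained illustration of why the old one-liner now fails the decorator's check:

import asyncio


async def check(x):
    assert x + 1 == 2


bad = lambda x: check(x)  # returns a coroutine object, but is not a coroutine function
assert not asyncio.iscoroutinefunction(bad)


async def good(x):  # the form the diff switches to
    await check(x)


assert asyncio.iscoroutinefunction(good)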
2 changes: 1 addition & 1 deletion distributed/tests/test_tls_functional.py
@@ -12,7 +12,7 @@


 @gen_tls_cluster(client=True)
-def test_basic(c, s, a, b):
+async def test_basic(c, s, a, b):
     pass


30 changes: 8 additions & 22 deletions distributed/tests/test_utils_test.py
@@ -1,4 +1,3 @@
-import asyncio
 from contextlib import contextmanager
 import socket
 import threading
@@ -43,7 +42,7 @@ def test_cluster(loop):


 @gen_cluster(client=True)
-def test_gen_cluster(c, s, a, b):
+async def test_gen_cluster(c, s, a, b):
     assert isinstance(c, Client)
     assert isinstance(s, Scheduler)
     for w in [a, b]:
@@ -68,20 +67,25 @@ async def f(c, s, a, b):


 @gen_cluster(client=False)
-def test_gen_cluster_without_client(s, a, b):
+async def test_gen_cluster_without_client(s, a, b):
     assert isinstance(s, Scheduler)
     for w in [a, b]:
         assert isinstance(w, Worker)
     assert s.nthreads == {w.address: w.nthreads for w in [a, b]}
+
+    async with Client(s.address, asynchronous=True) as c:
+        future = c.submit(lambda x: x + 1, 1)
+        result = await future
+        assert result == 2


 @gen_cluster(
     client=True,
     scheduler="tls://127.0.0.1",
     nthreads=[("tls://127.0.0.1", 1), ("tls://127.0.0.1", 2)],
     security=tls_only_security(),
 )
-def test_gen_cluster_tls(e, s, a, b):
+async def test_gen_cluster_tls(e, s, a, b):
     assert isinstance(e, Client)
     assert isinstance(s, Scheduler)
     assert s.address.startswith("tls://")
@@ -91,11 +95,6 @@ def test_gen_cluster_tls(e, s, a, b):
     assert s.nthreads == {w.address: w.nthreads for w in [a, b]}


-@gen_test()
-async def test_gen_test():
-    await asyncio.sleep(0.01)
-
-
 @contextmanager
 def _listen(delay=0):
     serv = socket.socket()
@@ -177,16 +176,3 @@ def test_tls_cluster(tls_client):
 async def test_tls_scheduler(security, cleanup):
     async with Scheduler(security=security, host="localhost") as s:
         assert s.address.startswith("tls")
-
-
-@gen_cluster()
-async def test_gen_cluster_async(s, a, b):  # flake8: noqa
-    async with Client(s.address, asynchronous=True) as c:
-        future = c.submit(lambda x: x + 1, 1)
-        result = await future
-        assert result == 2
-
-
-@gen_test()
-async def test_gen_test_async():  # flake8: noqa
-    await asyncio.sleep(0.001)
4 changes: 2 additions & 2 deletions distributed/tests/test_worker.py
@@ -837,7 +837,7 @@ def test_worker_dir(worker):
     with tmpfile() as fn:

         @gen_cluster(client=True, worker_kwargs={"local_directory": fn})
-        def test_worker_dir(c, s, a, b):
+        async def test_worker_dir(c, s, a, b):
             directories = [w.local_directory for w in s.workers.values()]
             assert all(d.startswith(fn) for d in directories)
             assert len(set(directories)) == 2  # distinct
@@ -1244,7 +1244,7 @@ async def test_avoid_memory_monitor_if_zero_limit(c, s):
         "distributed.worker.memory.target": False,
     },
 )
-def test_dict_data_if_no_spill_to_disk(s, w):
+async def test_dict_data_if_no_spill_to_disk(s, w):
     assert type(w.data) is dict


46 changes: 29 additions & 17 deletions distributed/utils_test.py
@@ -1,5 +1,6 @@
 import asyncio
 import collections
+import gc
 from contextlib import contextmanager
 import copy
 import functools
@@ -33,7 +34,6 @@

 import dask
 from tlz import merge, memoize, assoc
-from tornado import gen
 from tornado.ioloop import IOLoop

 from . import system
@@ -768,11 +768,9 @@ async def test_foo():

     def _(func):
         def test_func():
             with clean() as loop:
-                if iscoroutinefunction(func):
-                    cor = func
-                else:
-                    cor = gen.coroutine(func)
-                loop.run_sync(cor, timeout=timeout)
+                if not iscoroutinefunction(func):
+                    raise ValueError("@gen_test should wrap async def functions")
+                loop.run_sync(func, timeout=timeout)

         return test_func
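gen_test is thus reduced to a strict runner: verify that it was handed a coroutine function, then drive it with Tornado's loop.run_sync. Usage after the change, a minimal sketch:

import asyncio

from distributed.utils_test import gen_test


@gen_test(timeout=10)
async def test_short_sleep():
    # run_sync executes this coroutine on a fresh loop; a plain `def`
    # would now raise ValueError instead of being wrapped
    await asyncio.sleep(0.01)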

@@ -856,6 +854,7 @@ def gen_cluster(
     active_rpc_timeout=1,
     config={},
     clean_kwargs={},
+    allow_unclosed=False,
 ):
     from distributed import Client

@@ -878,10 +877,10 @@ async def test_foo(scheduler, worker1, worker2):
     )

     def _(func):
-        if not iscoroutinefunction(func):
-            func = gen.coroutine(func)
-
         def test_func():
+            if not iscoroutinefunction(func):
+                raise ValueError("@gen_cluster should wrap async def functions")
+
             result = None
             workers = []
             with clean(timeout=active_rpc_timeout, **clean_kwargs) as loop:
@@ -905,6 +904,7 @@ async def coro():
                                     "Failed to start gen_cluster, retrying",
                                     exc_info=True,
                                 )
+                                await asyncio.sleep(1)
                             else:
                                 workers[:] = ws
                         args = [s] + workers
@@ -940,16 +940,28 @@ async def coro():
                     else:
                         await c._close(fast=True)

-                    for i in range(5):
-                        if all(c.closed() for c in Comm._instances):
-                            break
-                        else:
-                            await asyncio.sleep(0.05)
-                    else:
-                        L = [c for c in Comm._instances if not c.closed()]
-                        Comm._instances.clear()
-                        # raise ValueError("Unclosed Comms", L)
-                        print("Unclosed Comms", L)
-                    _global_clients.clear()
+                    def get_unclosed():
+                        return [c for c in Comm._instances if not c.closed()] + [
+                            c
+                            for c in _global_clients.values()
+                            if c.status != "closed"
+                        ]
+
+                    try:
+                        start = time()
+                        while time() < start + 5:
+                            gc.collect()
+                            if not get_unclosed():
+                                break
+                            await asyncio.sleep(0.05)
+                        else:
+                            if allow_unclosed:
+                                print(f"Unclosed Comms: {get_unclosed()}")
+                            else:
+                                raise RuntimeError("Unclosed Comms", get_unclosed())
+                    finally:
+                        Comm._instances.clear()
+                        _global_clients.clear()

                     return result

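The rewritten teardown loop now calls gc.collect() on every pass so comms kept alive only by reference cycles are still collected within the five-second deadline, counts lingering clients in _global_clients alongside unclosed Comm instances, and escalates from a print to RuntimeError unless the test opts out via the new allow_unclosed flag, as the test_nanny.py change above does. A sketch of opting out (the test body is hypothetical):

from distributed.utils_test import gen_cluster


# known Comm leak under investigation: report it instead of failing the test
@gen_cluster(nthreads=[], allow_unclosed=True)
async def test_with_known_comm_leak(s):
    ...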
