diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py
index db5cc866b5c..dc481aae17f 100755
--- a/distributed/tests/test_utils_test.py
+++ b/distributed/tests/test_utils_test.py
@@ -202,9 +202,8 @@ async def test_gen_test_double_parametrized(foo, bar):
 
 
 @gen_test()
-async def test_gen_test_pytest_fixture(tmp_path, c):
+async def test_gen_test_pytest_fixture(tmp_path):
     assert isinstance(tmp_path, pathlib.Path)
-    assert isinstance(c, Client)
 
 
 @contextmanager
diff --git a/distributed/utils_test.py b/distributed/utils_test.py
index 0001c5f9b54..dc3010cab90 100644
--- a/distributed/utils_test.py
+++ b/distributed/utils_test.py
@@ -2,6 +2,7 @@
 
 import asyncio
 import concurrent.futures
+import contextlib
 import copy
 import functools
 import gc
@@ -12,19 +13,16 @@
 import multiprocessing
 import os
 import re
-import shutil
 import signal
 import socket
 import subprocess
 import sys
 import tempfile
 import threading
-import uuid
 import weakref
 from collections import defaultdict
 from collections.abc import Callable
 from contextlib import contextmanager, nullcontext, suppress
-from glob import glob
 from itertools import count
 from time import sleep
 from typing import Any, Literal
@@ -491,12 +489,13 @@ def run_worker(q, scheduler_q, config, **kwargs):
             scheduler_addr = scheduler_q.get()
 
             async def _():
+                pid = os.getpid()
                 try:
                     worker = await Worker(scheduler_addr, validate=True, **kwargs)
                 except Exception as exc:
-                    q.put(exc)
+                    q.put((pid, exc))
                 else:
-                    q.put(worker.address)
+                    q.put((pid, worker.address))
                     await worker.finished()
 
             # Scheduler might've failed
@@ -514,12 +513,13 @@ def run_nanny(q, scheduler_q, config, **kwargs):
             scheduler_addr = scheduler_q.get()
 
             async def _():
+                pid = os.getpid()
                 try:
                     worker = await Nanny(scheduler_addr, validate=True, **kwargs)
                 except Exception as exc:
-                    q.put(exc)
+                    q.put((pid, exc))
                 else:
-                    q.put(worker.address)
+                    q.put((pid, worker.address))
                     await worker.finished()
 
             # Scheduler might've failed
@@ -630,6 +630,31 @@ def security():
     return tls_only_security()
 
 
+def _terminate_join(proc):
+    proc.terminate()
+    proc.join()
+    proc.close()
+
+
+def _close_queue(q):
+    q.close()
+    q.join_thread()
+    q._writer.close()  # https://bugs.python.org/issue42752
+
+
+class _SafeTemporaryDirectory(tempfile.TemporaryDirectory):
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        try:
+            return super().__exit__(exc_type, exc_val, exc_tb)
+        except PermissionError:
+            # It appears that we either have a process still interacting with
+            # the tmpdirs of the workers or that Windows processes are not
+            # releasing their locks in time. We are receiving PermissionErrors
+            # during teardown.
+            # See also https://github.com/dask/distributed/pull/5825
+            pass
+
+
 @contextmanager
 def cluster(
     nworkers=2,
@@ -649,115 +674,104 @@ def cluster(
         else:
             _run_worker = run_worker
 
-        # The scheduler queue will receive the scheduler's address
-        scheduler_q = mp_context.Queue()
-
-        # Launch scheduler
-        scheduler = mp_context.Process(
-            name="Dask cluster test: Scheduler",
-            target=run_scheduler,
-            args=(scheduler_q, nworkers + 1, config),
-            kwargs=scheduler_kwargs,
-        )
-        ws.add(scheduler)
-        scheduler.daemon = True
-        scheduler.start()
-
-        # Launch workers
-        workers = []
-        for i in range(nworkers):
-            q = mp_context.Queue()
-            fn = "_test_worker-%s" % uuid.uuid4()
-            kwargs = merge(
-                {
-                    "nthreads": 1,
-                    "local_directory": fn,
-                    "memory_limit": system.MEMORY_LIMIT,
-                },
-                worker_kwargs,
-            )
-            proc = mp_context.Process(
-                name="Dask cluster test: Worker",
-                target=_run_worker,
-                args=(q, scheduler_q, config),
-                kwargs=kwargs,
+        with contextlib.ExitStack() as stack:
+            # The scheduler queue will receive the scheduler's address
+            scheduler_q = mp_context.Queue()
+            stack.callback(_close_queue, scheduler_q)
+
+            # Launch scheduler
+            scheduler = mp_context.Process(
+                name="Dask cluster test: Scheduler",
+                target=run_scheduler,
+                args=(scheduler_q, nworkers + 1, config),
+                kwargs=scheduler_kwargs,
+                daemon=True,
             )
-            ws.add(proc)
-            workers.append({"proc": proc, "queue": q, "dir": fn})
-
-        for worker in workers:
-            worker["proc"].start()
-        saddr_or_exception = scheduler_q.get()
-        if isinstance(saddr_or_exception, Exception):
-            raise saddr_or_exception
-        saddr = saddr_or_exception
-
-        for worker in workers:
-            addr_or_exception = worker["queue"].get()
-            if isinstance(addr_or_exception, Exception):
-                raise addr_or_exception
-            worker["address"] = addr_or_exception
-
-        start = time()
-        try:
-            try:
-                security = scheduler_kwargs["security"]
-                rpc_kwargs = {"connection_args": security.get_connection_args("client")}
-            except KeyError:
-                rpc_kwargs = {}
-
-            with rpc(saddr, **rpc_kwargs) as s:
-                while True:
-                    nthreads = loop.run_sync(s.ncores)
-                    if len(nthreads) == nworkers:
-                        break
-                    if time() - start > 5:
-                        raise Exception("Timeout on cluster creation")
-
-            # avoid sending processes down to function
-            yield {"address": saddr}, [
-                {"address": w["address"], "proc": weakref.ref(w["proc"])}
-                for w in workers
-            ]
-        finally:
-            logger.debug("Closing out test cluster")
+            ws.add(scheduler)
+            scheduler.start()
+            stack.callback(_terminate_join, scheduler)
 
-            loop.run_sync(
-                lambda: disconnect_all(
-                    [w["address"] for w in workers],
-                    timeout=disconnect_timeout,
-                    rpc_kwargs=rpc_kwargs,
+            # Launch workers
+            workers_by_pid = {}
+            q = mp_context.Queue()
+            stack.callback(_close_queue, q)
+            for _ in range(nworkers):
+                tmpdirname = stack.enter_context(
+                    _SafeTemporaryDirectory(prefix="_dask_test_worker")
                 )
-            )
-            loop.run_sync(
-                lambda: disconnect(
-                    saddr, timeout=disconnect_timeout, rpc_kwargs=rpc_kwargs
+                kwargs = merge(
+                    {
+                        "nthreads": 1,
+                        "local_directory": tmpdirname,
+                        "memory_limit": system.MEMORY_LIMIT,
+                    },
+                    worker_kwargs,
                 )
-            )
-
-            scheduler.terminate()
-            scheduler_q.close()
-            scheduler_q._reader.close()
-            scheduler_q._writer.close()
-
-            for w in workers:
-                w["proc"].terminate()
-                w["queue"].close()
-                w["queue"]._reader.close()
-                w["queue"]._writer.close()
-
-            scheduler.join(2)
-            del scheduler
-            for proc in [w["proc"] for w in workers]:
-                proc.join(timeout=30)
-
-            with suppress(UnboundLocalError):
-                del worker, w, proc
-            del workers[:]
-
-            for fn in glob("_test_worker-*"):
-                with suppress(OSError):
-                    shutil.rmtree(fn)
+                proc = mp_context.Process(
+                    name="Dask cluster test: Worker",
+                    target=_run_worker,
+                    args=(q, scheduler_q, config),
+                    kwargs=kwargs,
+                )
+                ws.add(proc)
+                proc.start()
+                stack.callback(_terminate_join, proc)
+                workers_by_pid[proc.pid] = {"proc": proc}
+
+            saddr_or_exception = scheduler_q.get()
+            if isinstance(saddr_or_exception, Exception):
+                raise saddr_or_exception
+            saddr = saddr_or_exception
+
+            for _ in range(nworkers):
+                pid, addr_or_exception = q.get()
+                if isinstance(addr_or_exception, Exception):
+                    raise addr_or_exception
+                workers_by_pid[pid]["address"] = addr_or_exception
+
+            start = time()
+            try:
+                try:
+                    security = scheduler_kwargs["security"]
+                    rpc_kwargs = {
+                        "connection_args": security.get_connection_args("client")
+                    }
+                except KeyError:
+                    rpc_kwargs = {}
+
+                with rpc(saddr, **rpc_kwargs) as s:
+                    while True:
+                        nthreads = loop.run_sync(s.ncores)
+                        if len(nthreads) == nworkers:
+                            break
+                        if time() - start > 5:
+                            raise Exception("Timeout on cluster creation")
+
+                # avoid sending processes down to function
+                yield {"address": saddr}, [
+                    {"address": w["address"], "proc": weakref.ref(w["proc"])}
+                    for w in workers_by_pid.values()
+                ]
+            finally:
+                logger.debug("Closing out test cluster")
+                alive_workers = [
+                    w["address"]
+                    for w in workers_by_pid.values()
+                    if w["proc"].is_alive()
+                ]
+                loop.run_sync(
+                    lambda: disconnect_all(
+                        alive_workers,
+                        timeout=disconnect_timeout,
+                        rpc_kwargs=rpc_kwargs,
+                    )
+                )
+                if scheduler.is_alive():
+                    loop.run_sync(
+                        lambda: disconnect(
+                            saddr, timeout=disconnect_timeout, rpc_kwargs=rpc_kwargs
+                        )
+                    )
 
         try:
             client = default_client()
@@ -766,12 +780,6 @@ def cluster(
         else:
             client.close()
 
-    start = time()
-    while any(proc.is_alive() for proc in ws):
-        text = str(list(ws))
-        sleep(0.2)
-        assert time() < start + 5, ("Workers still around after five seconds", text)
-
 
 async def disconnect(addr, timeout=3, rpc_kwargs=None):
     rpc_kwargs = rpc_kwargs or {}