diff --git a/.github/workflows/ci-windows.yaml b/.github/workflows/ci-windows.yaml
index 2e536a79663..7cff9083064 100644
--- a/.github/workflows/ci-windows.yaml
+++ b/.github/workflows/ci-windows.yaml
@@ -23,6 +23,13 @@ jobs:
           activate-environment: dask-distributed
           auto-activate-base: false

+      - name: Install contextvars
+        shell: bash -l {0}
+        run: |
+          if [[ "${{ matrix.python-version }}" = "3.6" ]]; then
+            conda install -c conda-forge contextvars
+          fi
+
       - name: Install tornado
         shell: bash -l {0}
         run: |
diff --git a/continuous_integration/travis/install.sh b/continuous_integration/travis/install.sh
index e362ea7f079..8cc862cfadd 100644
--- a/continuous_integration/travis/install.sh
+++ b/continuous_integration/travis/install.sh
@@ -67,6 +67,10 @@ conda create -n dask-distributed -c conda-forge -c defaults \

 source activate dask-distributed

+if [[ $PYTHON == 3.6 ]]; then
+    conda install -c conda-forge -c defaults contextvars
+fi
+
 # stacktrace is not currently available for Python 3.8.
 # Remove the version check block below when it is available.
 if [[ $PYTHON != 3.8 ]]; then
diff --git a/distributed/client.py b/distributed/client.py
index 52c0e2b420e..18d77ffab9b 100644
--- a/distributed/client.py
+++ b/distributed/client.py
@@ -5,6 +5,7 @@
 from concurrent.futures import ThreadPoolExecutor
 from concurrent.futures._base import DoneAndNotDoneFutures
 from contextlib import contextmanager
+from contextvars import ContextVar
 import copy
 import errno
 from functools import partial
@@ -89,6 +90,7 @@

 _global_clients = weakref.WeakValueDictionary()
 _global_client_index = [0]
+_current_client = ContextVar("_current_client", default=None)

 DEFAULT_EXTENSIONS = [PubSubClientExtension]

@@ -162,7 +164,7 @@ def __init__(self, key, client=None, inform=True, state=None):
         self.key = key
         self._cleared = False
         tkey = tokey(key)
-        self.client = client or _get_global_client()
+        self.client = client or Client.current()
         self.client._inc_ref(tkey)
         self._generation = self.client.generation

@@ -353,11 +355,14 @@ def release(self, _in_destructor=False):
             pass  # Shutting down, add_callback may be None

     def __getstate__(self):
-        return (self.key, self.client.scheduler.address)
+        return self.key, self.client.scheduler.address

     def __setstate__(self, state):
         key, address = state
-        c = get_client(address)
+        try:
+            c = Client.current(allow_global=False)
+        except ValueError:
+            c = get_client(address)
         Future.__init__(self, key, c)
         c._send_to_scheduler(
             {
@@ -727,10 +732,41 @@ def __init__(

         ReplayExceptionClient(self)

+    @contextmanager
+    def as_current(self):
+        """Thread-local, Task-local context manager that causes the Client.current
+        class method to return self. Any Future objects deserialized inside this
+        context manager will be automatically attached to this Client.
+        """
+        # In Python 3.6, contextvars are thread-local but not Task-local.
+        # We can still detect a race condition though.
+        if sys.version_info < (3, 7) and _current_client.get() not in (self, None):
+            raise RuntimeError(
+                "Detected race condition where multiple asynchronous clients tried "
+                "entering the as_current() context manager at the same time. "
+                "Please upgrade to Python 3.7+."
+            )
+
+        tok = _current_client.set(self)
+        try:
+            yield
+        finally:
+            _current_client.reset(tok)
+
     @classmethod
-    def current(cls):
-        """ Return global client if one exists, otherwise raise ValueError """
-        return default_client()
+    def current(cls, allow_global=True):
+        """When running within the context of `as_current`, return the context-local
+        current client.
+        Otherwise, return the most recently initialised Client.
+        If no Client instances exist, raise ValueError.
+        If allow_global is set to False, raise ValueError if running outside of the
+        `as_current` context manager.
+        """
+        out = _current_client.get()
+        if out:
+            return out
+        if allow_global:
+            return default_client()
+        raise ValueError("Not running inside the `as_current` context manager")

     @property
     def asynchronous(self):
@@ -2178,8 +2214,7 @@ def retry(self, futures, asynchronous=None):
         """
         return self.sync(self._retry, futures, asynchronous=asynchronous)

-    @gen.coroutine
-    def _publish_dataset(self, *args, name=None, **kwargs):
+    async def _publish_dataset(self, *args, name=None, **kwargs):
         with log_errors():
             coroutines = []

@@ -2205,7 +2240,7 @@ def add_coro(name, data):
             for name, data in kwargs.items():
                 add_coro(name, data)

-            yield coroutines
+            await asyncio.gather(*coroutines)

     def publish_dataset(self, *args, **kwargs):
         """
@@ -2285,13 +2320,12 @@ def list_datasets(self, **kwargs):
         return self.sync(self.scheduler.publish_list, **kwargs)

     async def _get_dataset(self, name):
-        out = await self.scheduler.publish_get(name=name, client=self.id)
-        if out is None:
-            raise KeyError("Dataset '%s' not found" % name)
+        with self.as_current():
+            out = await self.scheduler.publish_get(name=name, client=self.id)

-        with temp_default_client(self):
-            data = out["data"]
-            return data
+        if out is None:
+            raise KeyError(f"Dataset '{name}' not found")
+        return out["data"]

     def get_dataset(self, name, **kwargs):
         """
@@ -4697,6 +4731,14 @@ def __exit__(self, typ, value, traceback):
 def temp_default_client(c):
     """ Set the default client for the duration of the context

+    .. note::
+        This function should be used exclusively for unit testing the default client
+        functionality. In all other cases, please use ``Client.as_current`` instead.
+
+    .. note::
+        Unlike ``Client.as_current``, this context manager is neither thread-local nor
+        task-local.
+ Parameters ---------- c : Client diff --git a/distributed/lock.py b/distributed/lock.py index 3c893a419c2..7a55ccb4413 100644 --- a/distributed/lock.py +++ b/distributed/lock.py @@ -3,7 +3,7 @@ import logging import uuid -from .client import _get_global_client +from .client import Client from .utils import log_errors, TimeoutError from .worker import get_worker @@ -93,7 +93,11 @@ class Lock: """ def __init__(self, name=None, client=None): - self.client = client or _get_global_client() or get_worker().client + try: + self.client = client or Client.current() + except ValueError: + # Initialise new client + self.client = get_worker().client self.name = name or "lock-" + uuid.uuid4().hex self.id = uuid.uuid4().hex self._locked = False diff --git a/distributed/publish.py b/distributed/publish.py index 758e5ccc34b..4b30ebde042 100644 --- a/distributed/publish.py +++ b/distributed/publish.py @@ -6,10 +6,10 @@ class PublishExtension: """ An extension for the scheduler to manage collections - * publish-list - * publish-put - * publish-get - * publish-delete + * publish_list + * publish_put + * publish_get + * publish_delete """ def __init__(self, scheduler): @@ -59,21 +59,60 @@ class Datasets(MutableMapping): """ + __slots__ = ("_client",) + def __init__(self, client): - self.__client = client + self._client = client def __getitem__(self, key): - return self.__client.get_dataset(key) + # When client is asynchronous, it returns a coroutine + return self._client.get_dataset(key) def __setitem__(self, key, value): - self.__client.publish_dataset(value, name=key) + if self._client.asynchronous: + # 'await obj[key] = value' is not supported by Python as of 3.8 + raise TypeError( + "Can't use 'client.datasets[name] = value' when client is " + "asynchronous; please use 'client.publish_dataset(name=value)' instead" + ) + self._client.publish_dataset(value, name=key) def __delitem__(self, key): - self.__client.unpublish_dataset(key) + if self._client.asynchronous: + # 'await del obj[key]' is not supported by Python as of 3.8 + raise TypeError( + "Can't use 'del client.datasets[name]' when client is asynchronous; " + "please use 'client.unpublish_dataset(name)' instead" + ) + return self._client.unpublish_dataset(key) def __iter__(self): - for key in self.__client.list_datasets(): + if self._client.asynchronous: + raise TypeError( + "Can't invoke iter() or 'for' on client.datasets when client is " + "asynchronous; use 'async for' instead" + ) + for key in self._client.list_datasets(): yield key + def __aiter__(self): + if not self._client.asynchronous: + raise TypeError( + "Can't invoke 'async for' on client.datasets when client is " + "synchronous; use iter() or 'for' instead" + ) + + async def _(): + for key in await self._client.list_datasets(): + yield key + + return _() + def __len__(self): - return len(self.__client.list_datasets()) + if self._client.asynchronous: + # 'await len(obj)' is not supported by Python as of 3.8 + raise TypeError( + "Can't use 'len(client.datasets)' when client is asynchronous; " + "please use 'len(await client.list_datasets())' instead" + ) + return len(self._client.list_datasets()) diff --git a/distributed/queues.py b/distributed/queues.py index 81262703ad4..324fb46c40b 100644 --- a/distributed/queues.py +++ b/distributed/queues.py @@ -3,7 +3,7 @@ import logging import uuid -from .client import Future, _get_global_client, Client +from .client import Future, Client from .utils import tokey, sync, thread_state from .worker import get_client @@ -148,7 +148,7 @@ class Queue: 
not given, a random name will be generated. client: Client (optional) Client used for communication with the scheduler. Defaults to the - value of ``_get_global_client()``. + value of ``Client.current()``. maxsize: int (optional) Number of items allowed in the queue. If 0 (the default), the queue size is unbounded. @@ -167,7 +167,7 @@ class Queue: """ def __init__(self, name=None, client=None, maxsize=0): - self.client = client or _get_global_client() + self.client = client or Client.current() self.name = name or "queue-" + uuid.uuid4().hex self._event_started = asyncio.Event() if self.client.asynchronous or getattr( diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index fd95895c84e..634194bbae3 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -1201,7 +1201,7 @@ async def test_get_releases_data(c, s, a, b): assert time() < start + 2 -def test_Current(s, a, b): +def test_current(s, a, b): with Client(s["address"]) as c: assert Client.current() is c with pytest.raises(ValueError): @@ -3876,38 +3876,148 @@ async def test_scatter_compute_store_lose_processing(c, s, a, b): @gen_cluster(client=False) async def test_serialize_future(s, a, b): - c = await Client(s.address, asynchronous=True) - f = await Client(s.address, asynchronous=True) + c1 = await Client(s.address, asynchronous=True) + c2 = await Client(s.address, asynchronous=True) - future = c.submit(lambda: 1) + future = c1.submit(lambda: 1) result = await future - with temp_default_client(f): - future2 = pickle.loads(pickle.dumps(future)) - assert future2.client is f - assert tokey(future2.key) in f.futures - result2 = await future2 - assert result == result2 + for ci in (c1, c2): + for ctxman in ci.as_current, lambda: temp_default_client(ci): + with ctxman(): + future2 = pickle.loads(pickle.dumps(future)) + assert future2.client is ci + assert tokey(future2.key) in ci.futures + result2 = await future2 + assert result == result2 - await c.close() - await f.close() + await c1.close() + await c2.close() @gen_cluster(client=False) -async def test_temp_client(s, a, b): - c = await Client(s.address, asynchronous=True) - f = await Client(s.address, asynchronous=True) +async def test_temp_default_client(s, a, b): + c1 = await Client(s.address, asynchronous=True) + c2 = await Client(s.address, asynchronous=True) + + with temp_default_client(c1): + assert default_client() is c1 + assert default_client(c2) is c2 + + with temp_default_client(c2): + assert default_client() is c2 + assert default_client(c1) is c1 + + await c1.close() + await c2.close() + + +@gen_cluster(client=True) +async def test_as_current(c, s, a, b): + c1 = await Client(s.address, asynchronous=True) + c2 = await Client(s.address, asynchronous=True) with temp_default_client(c): - assert default_client() is c - assert default_client(f) is f + assert Client.current() is c + with pytest.raises(ValueError): + Client.current(allow_global=False) + with c1.as_current(): + assert Client.current() is c1 + assert Client.current(allow_global=True) is c1 + with c2.as_current(): + assert Client.current() is c2 + assert Client.current(allow_global=True) is c2 + + await c1.close() + await c2.close() - with temp_default_client(f): - assert default_client() is f - assert default_client(c) is c - await c.close() - await f.close() +def test_as_current_is_thread_local(s): + l1 = threading.Lock() + l2 = threading.Lock() + l3 = threading.Lock() + l4 = threading.Lock() + l1.acquire() + l2.acquire() + l3.acquire() + l4.acquire() + + 
def run1(): + with Client(s.address) as c: + with c.as_current(): + l1.acquire() + l2.release() + try: + # This line runs only when both run1 and run2 are inside the + # context manager + assert Client.current(allow_global=False) is c + finally: + l3.acquire() + l4.release() + + def run2(): + with Client(s.address) as c: + with c.as_current(): + l1.release() + l2.acquire() + try: + # This line runs only when both run1 and run2 are inside the + # context manager + assert Client.current(allow_global=False) is c + finally: + l3.release() + l4.acquire() + + t1 = threading.Thread(target=run1) + t2 = threading.Thread(target=run2) + t1.start() + t2.start() + t1.join() + t2.join() + + +@pytest.mark.xfail( + sys.version_info < (3, 7), + reason="Python 3.6 contextvars are not copied on Task creation", +) +@gen_cluster(client=False) +async def test_as_current_is_task_local(s, a, b): + l1 = asyncio.Lock() + l2 = asyncio.Lock() + l3 = asyncio.Lock() + l4 = asyncio.Lock() + await l1.acquire() + await l2.acquire() + await l3.acquire() + await l4.acquire() + + async def run1(): + async with Client(s.address, asynchronous=True) as c: + with c.as_current(): + await l1.acquire() + l2.release() + try: + # This line runs only when both run1 and run2 are inside the + # context manager + assert Client.current(allow_global=False) is c + finally: + await l3.acquire() + l4.release() + + async def run2(): + async with Client(s.address, asynchronous=True) as c: + with c.as_current(): + l1.release() + await l2.acquire() + try: + # This line runs only when both run1 and run2 are inside the + # context manager + assert Client.current(allow_global=False) is c + finally: + l3.release() + await l4.acquire() + + await asyncio.gather(run1(), run2()) @nodebug # test timing is fragile diff --git a/distributed/tests/test_publish.py b/distributed/tests/test_publish.py index ab32d52a112..a789f5a47f9 100644 --- a/distributed/tests/test_publish.py +++ b/distributed/tests/test_publish.py @@ -213,6 +213,23 @@ def test_datasets_iter(client): client.publish_dataset(**{str(key): key for key in keys}) for n, key in enumerate(client.datasets): assert key == str(n) + with pytest.raises(TypeError): + client.datasets.__aiter__() + + +@gen_cluster(client=True) +async def test_datasets_async(c, s, a, b): + await c.publish_dataset(foo=1, bar=2) + assert await c.datasets["foo"] == 1 + assert {k async for k in c.datasets} == {"foo", "bar"} + with pytest.raises(TypeError): + c.datasets["baz"] = 3 + with pytest.raises(TypeError): + del c.datasets["foo"] + with pytest.raises(TypeError): + next(iter(c.datasets)) + with pytest.raises(TypeError): + len(c.datasets) @gen_cluster(client=True) @@ -229,3 +246,35 @@ async def test_pickle_safe(c, s, a, b): with pytest.raises(TypeError): await c2.get_dataset("z") + + +@gen_cluster(client=True) +async def test_deserialize_client(c, s, a, b): + """Test that the client attached to Futures returned by Client.get_dataset is always + the instance of the client that invoked the method. + Specifically: + + - when the client is defined by hostname, test that it is not accidentally + reinitialised by IP; + - when multiple clients are connected to the same scheduler, test that they don't + interfere with each other. 
+ + See: test_client.test_serialize_future + See: https://github.com/dask/distributed/issues/3227 + """ + future = await c.scatter("123") + await c.publish_dataset(foo=future) + future = await c.get_dataset("foo") + assert future.client is c + + for addr in (s.address, "localhost:" + s.address.split(":")[-1]): + async with Client(addr, asynchronous=True) as c2: + future = await c.get_dataset("foo") + assert future.client is c + future = await c2.get_dataset("foo") + assert future.client is c2 + + # Ensure cleanup + from distributed.client import _current_client + + assert _current_client.get() is None diff --git a/distributed/variable.py b/distributed/variable.py index a47064b1397..dc717533a28 100644 --- a/distributed/variable.py +++ b/distributed/variable.py @@ -5,7 +5,7 @@ from tlz import merge -from .client import Future, _get_global_client, Client +from .client import Future, Client from .utils import tokey, log_errors, TimeoutError, ignoring from .worker import get_client @@ -142,7 +142,7 @@ class Variable: If not given, a random name will be generated. client: Client (optional) Client used for communication with the scheduler. Defaults to the - value of ``_get_global_client()``. + value of ``Client.current()``. Examples -------- @@ -161,7 +161,7 @@ class Variable: """ def __init__(self, name=None, client=None, maxsize=0): - self.client = client or _get_global_client() + self.client = client or Client.current() self.name = name or "variable-" + uuid.uuid4().hex async def _set(self, value): diff --git a/distributed/worker.py b/distributed/worker.py index 63fd5c6daa4..cfdfdb95256 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -3118,14 +3118,15 @@ def get_client(address=None, timeout=3, resolve_address=True): if not address or worker.scheduler.address == address: return worker._get_client(timeout=timeout) - from .client import _get_global_client + from .client import Client - client = _get_global_client() # TODO: assumes the same scheduler + try: + client = Client.current() # TODO: assumes the same scheduler + except ValueError: + client = None if client and (not address or client.scheduler.address == address): return client elif address: - from .client import Client - return Client(address, timeout=timeout) else: raise ValueError("No global client found and no address provided") diff --git a/requirements.txt b/requirements.txt index 4cb3ba60ae7..b0d20cdb1eb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ click >= 6.6 cloudpickle >= 0.2.2 +contextvars;python_version<'3.7' dask >= 2.9.0 msgpack >= 0.6.0 psutil >= 5.0
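
Reviewer's note: below is a minimal usage sketch of the `Client.as_current()` /
`Client.current()` API added by this patch. It is illustrative only, not part of the
diff: it assumes a scheduler is already reachable at `scheduler_addr`, and Python 3.7+
for task-local (rather than merely thread-local) behaviour.

import asyncio
from distributed import Client

async def main(scheduler_addr):
    async with Client(scheduler_addr, asynchronous=True) as c1:
        async with Client(scheduler_addr, asynchronous=True) as c2:
            # c2 is the most recently initialised client, so the global
            # fallback of Client.current() resolves to it...
            assert Client.current() is c2
            # ...but inside as_current(), c1 wins, and Futures unpickled
            # here attach to c1 (see Future.__setstate__ above).
            with c1.as_current():
                assert Client.current() is c1
                assert Client.current(allow_global=False) is c1
            # Outside the context manager, the strict lookup raises ValueError.
            try:
                Client.current(allow_global=False)
            except ValueError:
                print("no context-local client")

asyncio.get_event_loop().run_until_complete(main("tcp://127.0.0.1:8786"))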
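
Similarly, a sketch of driving the asynchronous `client.datasets` mapping from
`distributed/publish.py`. The dataset name "foo" is hypothetical; this mirrors the
new `test_datasets_async` test above.

async def datasets_demo(c):
    # c is an asynchronous Client
    await c.publish_dataset(foo=1)

    # __getitem__ returns a coroutine when the client is asynchronous
    assert await c.datasets["foo"] == 1

    # __aiter__ supports 'async for'; plain iter()/'for' raises TypeError
    names = {name async for name in c.datasets}
    assert "foo" in names

    # Operators that can't be awaited raise TypeError and point to the
    # coroutine-based alternative:
    #   c.datasets["bar"] = 2   ->  await c.publish_dataset(bar=2)
    #   del c.datasets["foo"]   ->  await c.unpublish_dataset("foo")
    #   len(c.datasets)         ->  len(await c.list_datasets())
    await c.unpublish_dataset("foo")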