Skip to content

Commit

Permalink
Client.get_dataset to always create Futures attached to itself (#3729)
Browse files Browse the repository at this point in the history
Adds a `contextvar` for storing the current client, which allows
calls that load a global client (like deserializing futures) to always
get the correct client for the context.
  • Loading branch information
crusaderky authored May 7, 2020
1 parent 6e90128 commit 77f6c55
Show file tree
Hide file tree
Showing 11 changed files with 316 additions and 59 deletions.
7 changes: 7 additions & 0 deletions .github/workflows/ci-windows.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@ jobs:
activate-environment: dask-distributed
auto-activate-base: false

- name: Install contextvars
shell: bash -l {0}
run: |
if [[ "${{ matrix.python-version }}" = "3.6" ]]; then
conda install -c conda-forge contextvars
fi
- name: Install tornado
shell: bash -l {0}
run: |
Expand Down
4 changes: 4 additions & 0 deletions continuous_integration/travis/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ conda create -n dask-distributed -c conda-forge -c defaults \

source activate dask-distributed

if [[ $PYTHON == 3.6 ]]; then
conda install -c conda-forge -c defaults contextvars
fi

# stacktrace is not currently avaiable for Python 3.8.
# Remove the version check block below when it is avaiable.
if [[ $PYTHON != 3.8 ]]; then
Expand Down
72 changes: 57 additions & 15 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures._base import DoneAndNotDoneFutures
from contextlib import contextmanager
from contextvars import ContextVar
import copy
import errno
from functools import partial
Expand Down Expand Up @@ -89,6 +90,7 @@
_global_clients = weakref.WeakValueDictionary()
_global_client_index = [0]

_current_client = ContextVar("_current_client", default=None)

DEFAULT_EXTENSIONS = [PubSubClientExtension]

Expand Down Expand Up @@ -162,7 +164,7 @@ def __init__(self, key, client=None, inform=True, state=None):
self.key = key
self._cleared = False
tkey = tokey(key)
self.client = client or _get_global_client()
self.client = client or Client.current()
self.client._inc_ref(tkey)
self._generation = self.client.generation

Expand Down Expand Up @@ -353,11 +355,14 @@ def release(self, _in_destructor=False):
pass # Shutting down, add_callback may be None

def __getstate__(self):
return (self.key, self.client.scheduler.address)
return self.key, self.client.scheduler.address

def __setstate__(self, state):
key, address = state
c = get_client(address)
try:
c = Client.current(allow_global=False)
except ValueError:
c = get_client(address)
Future.__init__(self, key, c)
c._send_to_scheduler(
{
Expand Down Expand Up @@ -727,10 +732,41 @@ def __init__(

ReplayExceptionClient(self)

@contextmanager
def as_current(self):
"""Thread-local, Task-local context manager that causes the Client.current class
method to return self. Any Future objects deserialized inside this context
manager will be automatically attached to this Client.
"""
# In Python 3.6, contextvars are thread-local but not Task-local.
# We can still detect a race condition though.
if sys.version_info < (3, 7) and _current_client.get() not in (self, None):
raise RuntimeError(
"Detected race condition where multiple asynchronous clients tried "
"entering the as_current() context manager at the same time. "
"Please upgrade to Python 3.7+."
)

tok = _current_client.set(self)
try:
yield
finally:
_current_client.reset(tok)

@classmethod
def current(cls):
""" Return global client if one exists, otherwise raise ValueError """
return default_client()
def current(cls, allow_global=True):
"""When running within the context of `as_client`, return the context-local
current client. Otherwise, return the latest initialised Client.
If no Client instances exist, raise ValueError.
If allow_global is set to False, raise ValueError if running outside of the
`as_client` context manager.
"""
out = _current_client.get()
if out:
return out
if allow_global:
return default_client()
raise ValueError("Not running inside the `as_current` context manager")

@property
def asynchronous(self):
Expand Down Expand Up @@ -2178,8 +2214,7 @@ def retry(self, futures, asynchronous=None):
"""
return self.sync(self._retry, futures, asynchronous=asynchronous)

@gen.coroutine
def _publish_dataset(self, *args, name=None, **kwargs):
async def _publish_dataset(self, *args, name=None, **kwargs):
with log_errors():
coroutines = []

Expand All @@ -2205,7 +2240,7 @@ def add_coro(name, data):
for name, data in kwargs.items():
add_coro(name, data)

yield coroutines
await asyncio.gather(*coroutines)

def publish_dataset(self, *args, **kwargs):
"""
Expand Down Expand Up @@ -2285,13 +2320,12 @@ def list_datasets(self, **kwargs):
return self.sync(self.scheduler.publish_list, **kwargs)

async def _get_dataset(self, name):
out = await self.scheduler.publish_get(name=name, client=self.id)
if out is None:
raise KeyError("Dataset '%s' not found" % name)
with self.as_current():
out = await self.scheduler.publish_get(name=name, client=self.id)

with temp_default_client(self):
data = out["data"]
return data
if out is None:
raise KeyError(f"Dataset '{name}' not found")
return out["data"]

def get_dataset(self, name, **kwargs):
"""
Expand Down Expand Up @@ -4697,6 +4731,14 @@ def __exit__(self, typ, value, traceback):
def temp_default_client(c):
""" Set the default client for the duration of the context
.. note::
This function should be used exclusively for unit testing the default client
functionality. In all other cases, please use ``Client.as_current`` instead.
.. note::
Unlike ``Client.as_current``, this context manager is neither thread-local nor
task-local.
Parameters
----------
c : Client
Expand Down
8 changes: 6 additions & 2 deletions distributed/lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
import uuid

from .client import _get_global_client
from .client import Client
from .utils import log_errors, TimeoutError
from .worker import get_worker

Expand Down Expand Up @@ -93,7 +93,11 @@ class Lock:
"""

def __init__(self, name=None, client=None):
self.client = client or _get_global_client() or get_worker().client
try:
self.client = client or Client.current()
except ValueError:
# Initialise new client
self.client = get_worker().client
self.name = name or "lock-" + uuid.uuid4().hex
self.id = uuid.uuid4().hex
self._locked = False
Expand Down
59 changes: 49 additions & 10 deletions distributed/publish.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
class PublishExtension:
""" An extension for the scheduler to manage collections
* publish-list
* publish-put
* publish-get
* publish-delete
* publish_list
* publish_put
* publish_get
* publish_delete
"""

def __init__(self, scheduler):
Expand Down Expand Up @@ -59,21 +59,60 @@ class Datasets(MutableMapping):
"""

__slots__ = ("_client",)

def __init__(self, client):
self.__client = client
self._client = client

def __getitem__(self, key):
return self.__client.get_dataset(key)
# When client is asynchronous, it returns a coroutine
return self._client.get_dataset(key)

def __setitem__(self, key, value):
self.__client.publish_dataset(value, name=key)
if self._client.asynchronous:
# 'await obj[key] = value' is not supported by Python as of 3.8
raise TypeError(
"Can't use 'client.datasets[name] = value' when client is "
"asynchronous; please use 'client.publish_dataset(name=value)' instead"
)
self._client.publish_dataset(value, name=key)

def __delitem__(self, key):
self.__client.unpublish_dataset(key)
if self._client.asynchronous:
# 'await del obj[key]' is not supported by Python as of 3.8
raise TypeError(
"Can't use 'del client.datasets[name]' when client is asynchronous; "
"please use 'client.unpublish_dataset(name)' instead"
)
return self._client.unpublish_dataset(key)

def __iter__(self):
for key in self.__client.list_datasets():
if self._client.asynchronous:
raise TypeError(
"Can't invoke iter() or 'for' on client.datasets when client is "
"asynchronous; use 'async for' instead"
)
for key in self._client.list_datasets():
yield key

def __aiter__(self):
if not self._client.asynchronous:
raise TypeError(
"Can't invoke 'async for' on client.datasets when client is "
"synchronous; use iter() or 'for' instead"
)

async def _():
for key in await self._client.list_datasets():
yield key

return _()

def __len__(self):
return len(self.__client.list_datasets())
if self._client.asynchronous:
# 'await len(obj)' is not supported by Python as of 3.8
raise TypeError(
"Can't use 'len(client.datasets)' when client is asynchronous; "
"please use 'len(await client.list_datasets())' instead"
)
return len(self._client.list_datasets())
6 changes: 3 additions & 3 deletions distributed/queues.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
import uuid

from .client import Future, _get_global_client, Client
from .client import Future, Client
from .utils import tokey, sync, thread_state
from .worker import get_client

Expand Down Expand Up @@ -148,7 +148,7 @@ class Queue:
not given, a random name will be generated.
client: Client (optional)
Client used for communication with the scheduler. Defaults to the
value of ``_get_global_client()``.
value of ``Client.current()``.
maxsize: int (optional)
Number of items allowed in the queue. If 0 (the default), the queue
size is unbounded.
Expand All @@ -167,7 +167,7 @@ class Queue:
"""

def __init__(self, name=None, client=None, maxsize=0):
self.client = client or _get_global_client()
self.client = client or Client.current()
self.name = name or "queue-" + uuid.uuid4().hex
self._event_started = asyncio.Event()
if self.client.asynchronous or getattr(
Expand Down
Loading

0 comments on commit 77f6c55

Please sign in to comment.