Client.get_dataset to always create Futures attached to itself #3729

Merged May 7, 2020 · 24 commits

Commits
207c27b
Python 3.6+ syntax
crusaderky Apr 11, 2020
02b1079
Code polish
crusaderky Apr 14, 2020
026a602
Merge remote-tracking branch 'upstream/master' into get_dataset_async
crusaderky Apr 14, 2020
e07b9f8
Revert
crusaderky Apr 14, 2020
5d2566c
Polish
crusaderky Apr 14, 2020
a5dc1be
Revert "Polish"
crusaderky Apr 14, 2020
f2b6e1f
Merge remote-tracking branch 'upstream/master' into get_dataset_async
crusaderky Apr 15, 2020
f86126a
tests
crusaderky Apr 15, 2020
7ef5832
Merge branch 'master' into get_dataset_async
crusaderky Apr 16, 2020
77a0d8a
revert
crusaderky Apr 16, 2020
de6b457
Merge remote-tracking branch 'upstream/master' into get_dataset_async
crusaderky Apr 17, 2020
9d6c8c0
revert
crusaderky Apr 17, 2020
6515c1b
Merge remote-tracking branch 'upstream/master' into get_dataset_async
crusaderky Apr 18, 2020
d8fcdc9
Merge branch 'master' into get_dataset_async
crusaderky Apr 20, 2020
21522df
xfail
crusaderky Apr 20, 2020
4b55b31
Better async functions
crusaderky Apr 20, 2020
fb8f777
Use contextvars to deserialize Future
crusaderky Apr 20, 2020
804b537
Merge branch 'master' into get_dataset_async
crusaderky Apr 21, 2020
2d594f8
Merge remote-tracking branch 'upstream/master' into get_dataset_async
crusaderky Apr 27, 2020
042eddf
Redesign
crusaderky Apr 27, 2020
621573d
Tweaks
crusaderky Apr 27, 2020
77fa36b
Merge remote-tracking branch 'upstream/master' into get_dataset_async
crusaderky May 5, 2020
e70303a
docstrings
crusaderky May 5, 2020
f0f8c6e
Merge branch 'master' into get_dataset_async
crusaderky May 7, 2020
7 changes: 7 additions & 0 deletions .github/workflows/ci-windows.yaml
@@ -23,6 +23,13 @@ jobs:
          activate-environment: testenv
          auto-activate-base: false

      - name: Install contextvars
        shell: bash -l {0}
        run: |
          if [[ "${{ matrix.python-version }}" = "3.6" ]]; then
            conda install -c conda-forge contextvars
          fi

      - name: Install tornado
        shell: bash -l {0}
        run: |
4 changes: 4 additions & 0 deletions continuous_integration/travis/install.sh
@@ -55,6 +55,10 @@ conda install -c conda-forge -q \
    zstandard \
    $PACKAGES

if [[ $PYTHON == 3.6 ]]; then
    conda install -c conda-forge -c defaults contextvars
fi

# stacktrace is not currently available for Python 3.8.
# Remove the version check block below when it is available.
if [[ $PYTHON != 3.8 ]]; then
42 changes: 30 additions & 12 deletions distributed/client.py
@@ -5,6 +5,7 @@
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures._base import DoneAndNotDoneFutures
from contextlib import contextmanager
from contextvars import ContextVar
import copy
import errno
from functools import partial
@@ -90,6 +91,7 @@
_global_clients = weakref.WeakValueDictionary()
_global_client_index = [0]

_deserialize_client = ContextVar("_deserialize_client", default=None)

Member:
Do we need to use contextvars here? We currently use a thread-local for storing these kinds of things, but that doesn't always play nicely with async. Is there a way we could refactor part of distributed to avoid passing around this implicit state?

In general I think contextvars can be quite useful where required, but I want to ensure it's needed here before bringing it in. We also already have a fair bit of implicit state; I'm a bit hesitant to add more if we can avoid it.

Collaborator Author:
To avoid using contextvars you need to explicitly propagate the client until you stop calling async functions and only call sync functions. In this case, this means that

  • client.Client._get_dataset needs to pass self to
  • core.send_recv, which in turn needs to pass it to
  • comm.tcp.TCP.read (plus its variants comm.inproc.InProc.read and the 100% untested comm.ucx.UCX.read), which in turn needs to pass it to
  • comm.utils.from_frames, which can finally put a context manager (setting a thread-local variable) around

        return protocol.loads(
            frames, deserialize=deserialize, deserializers=deserializers
        )

  • which will alter the behaviour of client.Future.__setstate__

and it should be self-evident why I don't recommend doing that.

"We currently use a thread-local for storing these kinds of things"

For this one you were using a global, client._global_clients.
contextvars exactly fits the use case of thread-locals when you use coroutines.
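
A minimal sketch of that distinction (an editor's illustration, not part of this PR): two coroutines interleave on the same thread, yet each task spawned by asyncio.gather receives its own copy of the context, so a ContextVar set inside one task is invisible to the other; a threading.local on a single-threaded event loop cannot provide that isolation.

    # Illustrative only: ContextVar isolation between interleaved tasks
    # running on one thread (Python 3.7+).
    import asyncio
    from contextvars import ContextVar

    current_client = ContextVar("current_client", default=None)

    async def handler(name):
        current_client.set(name)  # set in this task's private copy of the context
        await asyncio.sleep(0)    # yield control to the other task
        # A threading.local would have been clobbered by the other task here
        assert current_client.get() == name

    async def main():
        await asyncio.gather(handler("client-1"), handler("client-2"))

    asyncio.run(main())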

Member:

Makes sense to me. Should this logic replace how we access implicit clients everywhere? If so, perhaps we should rename the context var key?

Collaborator Author:

Are there other cases where a method of Client loses self along the way and re-acquires it later from a global variable?

Also note the very big caveat of the race condition on Python 3.6. I'm not sure how much of the dask.distributed user base still uses Python 3.6 AND has more than one client running.

Member:

Sorry, let me clarify.

It feels a bit weird to handle setting the contextvar in the _get_dataset method. It might be cleaner to have a parent context over all coroutines started by a client that includes a reference back to the client. As it stands, this bakes functionality into Future that is only used by one method, when we likely want this behaviour everywhere.

Collaborator Author:

"It might be cleaner to have a parent context over all coroutines started by a client that includes a reference back to the client."

Are you talking about some sort of blanket decorator applied around all methods of Client? If so, it feels like overkill to me, as I do not know of another method that needs it.

"As it stands, this bakes functionality into Future that is only used by one method, when we likely want this behaviour everywhere."

You definitely want to reuse this same variable every time you have a self reference to a Client instance, invoke a method, lose the reference (because it would be too cumbersome to propagate it by hand), and then need to reacquire it deeper down the line before the end of the method. If you can point me to any other function that does this, I'll be happy to look into it.

Collaborator Author:

@jcrist ping

Collaborator Author:

I'm redesigning the change to make it more generic, hold on


DEFAULT_EXTENSIONS = [PubSubClientExtension]

@@ -163,7 +165,7 @@ def __init__(self, key, client=None, inform=True, state=None):
        self.key = key
        self._cleared = False
        tkey = tokey(key)
        self.client = client or _get_global_client()
        self.client = client or _deserialize_client.get() or _get_global_client()
        self.client._inc_ref(tkey)
        self._generation = self.client.generation

@@ -354,11 +356,11 @@ def release(self, _in_destructor=False):
            pass  # Shutting down, add_callback may be None

    def __getstate__(self):
        return (self.key, self.client.scheduler.address)
        return self.key, self.client.scheduler.address

    def __setstate__(self, state):
        key, address = state
        c = get_client(address)
        c = _deserialize_client.get() or get_client(address)
        Future.__init__(self, key, c)
        c._send_to_scheduler(
            {
@@ -2175,8 +2177,7 @@ def retry(self, futures, asynchronous=None):
        """
        return self.sync(self._retry, futures, asynchronous=asynchronous)

    @gen.coroutine
    def _publish_dataset(self, *args, name=None, **kwargs):
    async def _publish_dataset(self, *args, name=None, **kwargs):
        with log_errors():
            coroutines = []

@@ -2202,7 +2203,7 @@ def add_coro(name, data):
            for name, data in kwargs.items():
                add_coro(name, data)

            yield coroutines
            await asyncio.gather(*coroutines)

    def publish_dataset(self, *args, **kwargs):
        """
@@ -2282,13 +2283,30 @@ def list_datasets(self, **kwargs):
        return self.sync(self.scheduler.publish_list, **kwargs)

    async def _get_dataset(self, name):
        out = await self.scheduler.publish_get(name=name, client=self.id)
        if out is None:
            raise KeyError("Dataset '%s' not found" % name)
        if sys.version_info >= (3, 7):
            # Insulate contextvars change with a task
            async def _():
                _deserialize_client.set(self)
                return await self.scheduler.publish_get(name=name, client=self.id)

            out = await asyncio.create_task(_())
        else:
            # Python 3.6; creating a task doesn't copy the context.
            # We can still detect a race condition though.
            if _deserialize_client.get() not in (self, None):
                raise RuntimeError(  # pragma: nocover
                    "Detected race condition where get_dataset() is invoked in "
                    "parallel by multiple clients. Please upgrade to Python 3.7+."
                )
            tok = _deserialize_client.set(self)
            try:
                out = await self.scheduler.publish_get(name=name, client=self.id)
            finally:
                _deserialize_client.reset(tok)

        with temp_default_client(self):
            data = out["data"]
        return data
        if out is None:
            raise KeyError(f"Dataset '{name}' not found")
        return out["data"]

    def get_dataset(self, name, **kwargs):
        """
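
The Python 3.7 branch of _get_dataset above leans on a specific guarantee: asyncio.create_task() copies the caller's context at task creation, so setting _deserialize_client inside the helper coroutine cannot leak back into the calling coroutine. A standalone sketch of that insulation pattern (an editor's illustration, not part of the diff):

    # Illustrative only: create_task() snapshots the current context, so a
    # ContextVar set inside the task stays confined to it (Python 3.7+).
    import asyncio
    from contextvars import ContextVar

    var = ContextVar("var", default=None)

    async def inner():
        var.set("temporary")  # mutates only the task's copied context
        return var.get()

    async def outer():
        assert await asyncio.create_task(inner()) == "temporary"
        assert var.get() is None  # the set() inside the task did not leak out

    asyncio.run(outer())
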
59 changes: 49 additions & 10 deletions distributed/publish.py
@@ -6,10 +6,10 @@
class PublishExtension:
    """ An extension for the scheduler to manage collections

    * publish-list
    * publish-put
    * publish-get
    * publish-delete
    * publish_list
    * publish_put
    * publish_get
    * publish_delete
    """

    def __init__(self, scheduler):
@@ -59,21 +59,60 @@ class Datasets(MutableMapping):

"""

__slots__ = ("_client",)

def __init__(self, client):
self.__client = client
self._client = client

def __getitem__(self, key):
return self.__client.get_dataset(key)
# When client is asynchronous, it returns a coroutine
return self._client.get_dataset(key)

def __setitem__(self, key, value):
self.__client.publish_dataset(value, name=key)
if self._client.asynchronous:
# 'await obj[key] = value' is not supported by Python as of 3.8
raise TypeError(
"Can't use 'client.datasets[name] = value' when client is "
"asynchronous; please use 'client.publish_dataset(name=value)' instead"
)
self._client.publish_dataset(value, name=key)

def __delitem__(self, key):
self.__client.unpublish_dataset(key)
if self._client.asynchronous:
# 'await del obj[key]' is not supported by Python as of 3.8
raise TypeError(
"Can't use 'del client.datasets[name]' when client is asynchronous; "
"please use 'client.unpublish_dataset(name)' instead"
)
return self._client.unpublish_dataset(key)

def __iter__(self):
for key in self.__client.list_datasets():
if self._client.asynchronous:
raise TypeError(
"Can't invoke iter() or 'for' on client.datasets when client is "
"asynchronous; use 'async for' instead"
)
for key in self._client.list_datasets():
yield key

def __aiter__(self):
if not self._client.asynchronous:
raise TypeError(
"Can't invoke 'async for' on client.datasets when client is "
"synchronous; use iter() or 'for' instead"
)

async def _():
for key in await self._client.list_datasets():
yield key

return _()

def __len__(self):
return len(self.__client.list_datasets())
if self._client.asynchronous:
# 'await len(obj)' is not supported by Python as of 3.8
raise TypeError(
"Can't use 'len(client.datasets)' when client is asynchronous; "
"please use 'len(await client.list_datasets())' instead"
)
return len(self._client.list_datasets())
48 changes: 48 additions & 0 deletions distributed/tests/test_publish.py
@@ -213,6 +213,23 @@ def test_datasets_iter(client):
    client.publish_dataset(**{str(key): key for key in keys})
    for n, key in enumerate(client.datasets):
        assert key == str(n)
    with pytest.raises(TypeError):
        client.datasets.__aiter__()


@gen_cluster(client=True)
async def test_datasets_async(c, s, a, b):
    await c.publish_dataset(foo=1, bar=2)
    assert await c.datasets["foo"] == 1
    assert {k async for k in c.datasets} == {"foo", "bar"}
    with pytest.raises(TypeError):
        c.datasets["baz"] = 3
    with pytest.raises(TypeError):
        del c.datasets["foo"]
    with pytest.raises(TypeError):
        next(iter(c.datasets))
    with pytest.raises(TypeError):
        len(c.datasets)


@gen_cluster(client=True)
@@ -229,3 +246,34 @@ async def test_pickle_safe(c, s, a, b):

    with pytest.raises(TypeError):
        await c2.get_dataset("z")


@gen_cluster(client=True)
async def test_deserialize_client(c, s, a, b):
    # Review comment (Member): Nice test case.

"""Test that the client attached to Futures returned by Client.get_dataset is always
the instance of the client that invoked the method.
Specifically:

- when the client is defined by hostname, test that it is not accidentally
reinitialised by IP;
- when multiple clients are connected to the same scheduler, test that they don't
interfere with each other.

See: https://github.com/dask/distributed/issues/3227
"""
    future = await c.scatter("123")
    await c.publish_dataset(foo=future)
    future = await c.get_dataset("foo")
    assert future.client is c

    for addr in (s.address, "localhost:" + s.address.split(":")[-1]):
        async with Client(addr, asynchronous=True) as c2:
            future = await c.get_dataset("foo")
            assert future.client is c
            future = await c2.get_dataset("foo")
            assert future.client is c2

    # Ensure cleanup
    from distributed.client import _deserialize_client

    assert _deserialize_client.get() is None
1 change: 1 addition & 0 deletions distributed/tests/test_steal.py
@@ -106,6 +106,7 @@ async def test_worksteal_many_thieves(c, s, *workers):
    assert sum(map(len, s.has_what.values())) < 150


@pytest.mark.xfail(reason="GH#3574")
@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2)
async def test_dont_steal_unknown_functions(c, s, a, b):
    futures = c.map(inc, range(100), workers=a.address, allow_other_workers=True)
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,5 +1,6 @@
click >= 6.6
cloudpickle >= 0.2.2
contextvars;python_version<'3.7'
dask >= 2.9.0
msgpack >= 0.6.0
psutil >= 5.0