Skip to content

Commit

Permalink
Fix list to tuple conversion (#555)
Browse files Browse the repository at this point in the history
Fixes #549 by converting received tuples to lists.

Depend on dask/distributed#4621, which fixes an unrelated bug also triggered by our explicit-comms tests.

Authors:
  - Mads R. B. Kristensen (@madsbk)

Approvers:
  - Peter Andreas Entschev (@pentschev)

URL: #555
  • Loading branch information
madsbk authored Mar 24, 2021
1 parent b882bf2 commit 9747c96
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
6 changes: 5 additions & 1 deletion dask_cuda/explicit_comms/dataframe/shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,11 @@ async def recv(
for rank, ep in eps.items():
if rank in in_nparts:
futures.append(ep.read())
out_parts_list.extend(nested_deserialize(await asyncio.gather(*futures)))

# Notice, since Dask may convert lists to tuples, we convert them back into lists
out_parts_list.extend(
[[y for y in x] for x in nested_deserialize(await asyncio.gather(*futures))]
)


def sort_in_parts(
Expand Down
3 changes: 0 additions & 3 deletions dask_cuda/tests/test_explicit_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ def _test_dataframe_merge(backend, protocol, n_workers):
@pytest.mark.parametrize("nworkers", [1, 2, 4])
@pytest.mark.parametrize("backend", ["pandas", "cudf"])
@pytest.mark.parametrize("protocol", ["tcp", "ucx"])
@pytest.mark.xfail(reason="https://github.com/rapidsai/dask-cuda/issues/549")
def test_dataframe_merge(backend, protocol, nworkers):
if backend == "cudf":
pytest.importorskip("cudf")
Expand Down Expand Up @@ -204,7 +203,6 @@ def _test_dataframe_shuffle(backend, protocol, n_workers):
@pytest.mark.parametrize("nworkers", [1, 2, 3])
@pytest.mark.parametrize("backend", ["pandas", "cudf"])
@pytest.mark.parametrize("protocol", ["tcp", "ucx"])
@pytest.mark.xfail(reason="https://github.com/rapidsai/dask-cuda/issues/549")
def test_dataframe_shuffle(backend, protocol, nworkers):
if backend == "cudf":
pytest.importorskip("cudf")
Expand Down Expand Up @@ -245,7 +243,6 @@ def check_shuffle(in_cluster):
check_shuffle(False)


@pytest.mark.xfail(reason="https://github.com/rapidsai/dask-cuda/issues/549")
def test_dask_use_explicit_comms():
p = mp.Process(target=_test_dask_use_explicit_comms)
p.start()
Expand Down

0 comments on commit 9747c96

Please sign in to comment.