From 9747c969e4b67632bd76d0b94d17dd271e87517e Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Wed, 24 Mar 2021 17:13:32 +0100 Subject: [PATCH] Fix list to tuple conversion (#555) Fixes https://github.com/rapidsai/dask-cuda/issues/549 by converting received tuples to lists. Depend on https://github.com/dask/distributed/pull/4621, which fixes an unrelated bug also triggered by our explicit-comms tests. Authors: - Mads R. B. Kristensen (@madsbk) Approvers: - Peter Andreas Entschev (@pentschev) URL: https://github.com/rapidsai/dask-cuda/pull/555 --- dask_cuda/explicit_comms/dataframe/shuffle.py | 6 +++++- dask_cuda/tests/test_explicit_comms.py | 3 --- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/dask_cuda/explicit_comms/dataframe/shuffle.py b/dask_cuda/explicit_comms/dataframe/shuffle.py index 0e0f0d9a8..e9bcb8242 100644 --- a/dask_cuda/explicit_comms/dataframe/shuffle.py +++ b/dask_cuda/explicit_comms/dataframe/shuffle.py @@ -39,7 +39,11 @@ async def recv( for rank, ep in eps.items(): if rank in in_nparts: futures.append(ep.read()) - out_parts_list.extend(nested_deserialize(await asyncio.gather(*futures))) + + # Notice, since Dask may convert lists to tuples, we convert them back into lists + out_parts_list.extend( + [[y for y in x] for x in nested_deserialize(await asyncio.gather(*futures))] + ) def sort_in_parts( diff --git a/dask_cuda/tests/test_explicit_comms.py b/dask_cuda/tests/test_explicit_comms.py index dbff3b924..05edbfb8b 100644 --- a/dask_cuda/tests/test_explicit_comms.py +++ b/dask_cuda/tests/test_explicit_comms.py @@ -99,7 +99,6 @@ def _test_dataframe_merge(backend, protocol, n_workers): @pytest.mark.parametrize("nworkers", [1, 2, 4]) @pytest.mark.parametrize("backend", ["pandas", "cudf"]) @pytest.mark.parametrize("protocol", ["tcp", "ucx"]) -@pytest.mark.xfail(reason="https://github.com/rapidsai/dask-cuda/issues/549") def test_dataframe_merge(backend, protocol, nworkers): if backend == "cudf": pytest.importorskip("cudf") @@ -204,7 +203,6 @@ def _test_dataframe_shuffle(backend, protocol, n_workers): @pytest.mark.parametrize("nworkers", [1, 2, 3]) @pytest.mark.parametrize("backend", ["pandas", "cudf"]) @pytest.mark.parametrize("protocol", ["tcp", "ucx"]) -@pytest.mark.xfail(reason="https://github.com/rapidsai/dask-cuda/issues/549") def test_dataframe_shuffle(backend, protocol, nworkers): if backend == "cudf": pytest.importorskip("cudf") @@ -245,7 +243,6 @@ def check_shuffle(in_cluster): check_shuffle(False) -@pytest.mark.xfail(reason="https://github.com/rapidsai/dask-cuda/issues/549") def test_dask_use_explicit_comms(): p = mp.Process(target=_test_dask_use_explicit_comms) p.start()