Skip to content

Commit 0ed224d

Browse files
authored
Support implicit array conversion with query-planning enabled (#15378)
When query-planning is enabled, implicit conversion is not yet supported from a cudf-backed collection to a dask array. [Some cuml + dask CI failures are related to this limitation](rapidsai/cuml#5815 (comment)). This PR adds basic support for implicit conversion. Authors: - Richard (Rick) Zamora (https://github.com/rjzamora) Approvers: - Lawrence Mitchell (https://github.com/wence-) URL: #15378
1 parent 4e44d5d commit 0ed224d

File tree

2 files changed

+65
-0
lines changed

2 files changed

+65
-0
lines changed

python/dask_cudf/dask_cudf/expr/_collection.py

+31
Original file line numberDiff line numberDiff line change
@@ -108,3 +108,34 @@ class Index(DXIndex):
108108
# Register one callable per cudf type with ``get_collection_type``; each
# callable ignores its argument and returns the matching collection class
# (cudf.DataFrame -> DataFrame, cudf.Series -> Series, cudf.BaseIndex -> Index).
get_collection_type.register(cudf.DataFrame, lambda _: DataFrame)
109109
get_collection_type.register(cudf.Series, lambda _: Series)
110110
get_collection_type.register(cudf.BaseIndex, lambda _: Index)
111+
112+
113+
##
114+
## Support conversion to GPU-backed Array collections
115+
##
116+
117+
118+
try:
    # Keep the guarded region minimal: this import is the only statement
    # expected to fail when an older dask-expr is installed.
    from dask_expr._backends import create_array_collection
except ImportError:
    # Older version of dask-expr.
    # Implicit conversion to array won't work.
    pass
else:
    # NOTE(review): ``register_lazy`` presumably defers each callback until
    # the named top-level module is first needed by the dispatcher, so that
    # cupy/cupyx are not imported eagerly here -- confirm against dask docs.

    @get_collection_type.register_lazy("cupy")
    def _register_cupy():
        """Register ``create_array_collection`` for ``cupy.ndarray``."""
        import cupy

        @get_collection_type.register(cupy.ndarray)
        def get_collection_type_cupy_array(_):
            # The dispatched-on object is ignored; only its type matters.
            return create_array_collection

    @get_collection_type.register_lazy("cupyx")
    def _register_cupyx():
        """Register ``create_array_collection`` for cupyx sparse matrices."""
        # Needed for cuml
        from cupyx.scipy.sparse import spmatrix

        @get_collection_type.register(spmatrix)
        def get_collection_type_csr_matrix(_):
            # The dispatched-on object is ignored; only its type matters.
            return create_array_collection

python/dask_cudf/dask_cudf/tests/test_core.py

+34
Original file line numberDiff line numberDiff line change
@@ -913,3 +913,37 @@ def test_categorical_dtype_round_trip():
913913
actual = ds.compute()
914914
expected = pds.compute()
915915
assert actual.dtype.ordered == expected.dtype.ordered
916+
917+
918+
def test_implicit_array_conversion_cupy():
    """Check that mapping a Series to ``.values`` matches the cudf result.

    A function returning ``x.values`` is applied per partition of a
    two-partition dask_cudf collection, and the computed output is compared
    against the same function applied to the source ``cudf.Series``.
    """
    gser = cudf.Series(range(10))
    collection = dask_cudf.from_cudf(gser, npartitions=2)

    def to_values(part):
        return part.values

    # Need to compute the dask collection for now.
    # See: https://github.com/dask/dask/issues/11017
    meta = gser.values
    mapped = collection.map_partitions(to_values, meta=meta)
    result = mapped.compute()

    expect = to_values(gser)
    dask.array.assert_eq(result, expect)
931+
932+
933+
def test_implicit_array_conversion_cupy_sparse():
    """Check that a partition function returning a cupyx sparse matrix
    yields the same result type after ``compute()`` as calling it directly.

    The test is skipped when ``cupyx`` cannot be imported.
    """
    cupyx = pytest.importorskip("cupyx")

    s = cudf.Series(range(10), dtype="float32")
    ds = dask_cudf.from_cudf(s, npartitions=2)

    def func(x):
        # Wrap the partition's values in a CSR sparse matrix.
        return cupyx.scipy.sparse.csr_matrix(x.values)

    # Need to compute the dask collection for now.
    # See: https://github.com/dask/dask/issues/11017
    result = ds.map_partitions(func, meta=s.values).compute()
    expect = func(s)

    # NOTE: The calculation here doesn't need to make sense.
    # We just need to make sure we get the right type back.
    # Exact-type comparison uses identity (``is``) rather than ``==``.
    assert type(result) is type(expect)

0 commit comments

Comments
 (0)