Commit 0ed224d 1 parent 4e44d5d commit 0ed224d Copy full SHA for 0ed224d
File tree 2 files changed +65
-0
lines changed
python/dask_cudf/dask_cudf
2 files changed +65
-0
lines changed Original file line number Diff line number Diff line change @@ -108,3 +108,34 @@ class Index(DXIndex):
108
108
get_collection_type .register (cudf .DataFrame , lambda _ : DataFrame )
109
109
get_collection_type .register (cudf .Series , lambda _ : Series )
110
110
get_collection_type .register (cudf .BaseIndex , lambda _ : Index )
111
+
112
+
113
+ ##
114
+ ## Support conversion to GPU-backed Array collections
115
+ ##
116
+
117
+
118
+ try :
119
+ from dask_expr ._backends import create_array_collection
120
+
121
+ @get_collection_type .register_lazy ("cupy" )
122
+ def _register_cupy ():
123
+ import cupy
124
+
125
+ @get_collection_type .register (cupy .ndarray )
126
+ def get_collection_type_cupy_array (_ ):
127
+ return create_array_collection
128
+
129
+ @get_collection_type .register_lazy ("cupyx" )
130
+ def _register_cupyx ():
131
+ # Needed for cuml
132
+ from cupyx .scipy .sparse import spmatrix
133
+
134
+ @get_collection_type .register (spmatrix )
135
+ def get_collection_type_csr_matrix (_ ):
136
+ return create_array_collection
137
+
138
+ except ImportError :
139
+ # Older version of dask-expr.
140
+ # Implicit conversion to array wont work.
141
+ pass
Original file line number Diff line number Diff line change @@ -913,3 +913,37 @@ def test_categorical_dtype_round_trip():
913
913
actual = ds .compute ()
914
914
expected = pds .compute ()
915
915
assert actual .dtype .ordered == expected .dtype .ordered
916
+
917
+
918
+ def test_implicit_array_conversion_cupy ():
919
+ s = cudf .Series (range (10 ))
920
+ ds = dask_cudf .from_cudf (s , npartitions = 2 )
921
+
922
+ def func (x ):
923
+ return x .values
924
+
925
+ # Need to compute the dask collection for now.
926
+ # See: https://github.com/dask/dask/issues/11017
927
+ result = ds .map_partitions (func , meta = s .values ).compute ()
928
+ expect = func (s )
929
+
930
+ dask .array .assert_eq (result , expect )
931
+
932
+
933
+ def test_implicit_array_conversion_cupy_sparse ():
934
+ cupyx = pytest .importorskip ("cupyx" )
935
+
936
+ s = cudf .Series (range (10 ), dtype = "float32" )
937
+ ds = dask_cudf .from_cudf (s , npartitions = 2 )
938
+
939
+ def func (x ):
940
+ return cupyx .scipy .sparse .csr_matrix (x .values )
941
+
942
+ # Need to compute the dask collection for now.
943
+ # See: https://github.com/dask/dask/issues/11017
944
+ result = ds .map_partitions (func , meta = s .values ).compute ()
945
+ expect = func (s )
946
+
947
+ # NOTE: The calculation here doesn't need to make sense.
948
+ # We just need to make sure we get the right type back.
949
+ assert type (result ) == type (expect )
You can’t perform that action at this time.
0 commit comments