1
1
from __future__ import annotations
2
2
3
3
import warnings
4
+ from functools import wraps
4
5
from typing import TYPE_CHECKING
5
6
6
7
import anndata as ad
16
17
from scipy .sparse import issparse
17
18
18
19
import scanpy as sc
19
- from testing .scanpy . _helpers import as_dense_dask_array , as_sparse_dask_array
20
+ from testing .scanpy import _helpers
20
21
from testing .scanpy ._helpers .data import pbmc3k_normalized
21
22
from testing .scanpy ._pytest .marks import needs
22
23
from testing .scanpy ._pytest .params import (
26
27
)
27
28
28
29
if TYPE_CHECKING :
30
+ from collections .abc import Callable
29
31
from typing import Literal
30
32
33
+ from scanpy ._compat import DaskArray
34
+
31
35
A_list = np .array (
32
36
[
33
37
[0 , 0 , 7 , 0 , 0 ],
62
66
)
63
67
64
68
65
- # If one uses dask for PCA it will always require dask-ml
69
+ def _chunked_1d (
70
+ f : Callable [[np .ndarray ], DaskArray ],
71
+ ) -> Callable [[np .ndarray ], DaskArray ]:
72
+ @wraps (f )
73
+ def wrapper (a : np .ndarray ) -> DaskArray :
74
+ da = f (a )
75
+ return da .rechunk ((da .chunksize [0 ], - 1 ))
76
+
77
+ return wrapper
78
+
79
+
80
+ DASK_CONVERTERS = {
81
+ f : _chunked_1d (f )
82
+ for f in (_helpers .as_dense_dask_array , _helpers .as_sparse_dask_array )
83
+ }
84
+
85
+
66
86
@pytest .fixture (
67
87
params = [
68
88
param_with (at , marks = [needs .dask_ml ]) if "dask" in at .id else at
69
89
for at in ARRAY_TYPES_SPARSE_DASK_UNSUPPORTED
70
90
]
71
91
)
72
92
def array_type (request : pytest .FixtureRequest ):
93
+ # If one uses dask for PCA it will always require dask-ml.
94
+ # dask-ml can’t do 2D-chunked arrays, so rechunk them.
95
+ if as_dask_array := DASK_CONVERTERS .get (request .param ):
96
+ return as_dask_array
97
+
98
+ # When not using dask, just return the array type
99
+ assert "dask" not in request .param .__name__ , "add more branches or refactor"
73
100
return request .param
74
101
75
102
@@ -92,8 +119,7 @@ def pca_params(
92
119
expected_warning = None
93
120
svd_solver = None
94
121
if svd_solver_type is not None :
95
- # TODO: are these right for sparse?
96
- if array_type in {as_dense_dask_array , as_sparse_dask_array }:
122
+ if array_type in DASK_CONVERTERS .values ():
97
123
svd_solver = (
98
124
{"auto" , "full" , "tsqr" , "randomized" }
99
125
if zero_center
@@ -350,19 +376,19 @@ def test_mask_var_argument_equivalence(float_dtype, array_type):
350
376
)
351
377
352
378
353
- def test_mask (array_type , request ):
354
- if array_type is as_dense_dask_array :
355
- pytest .xfail ("TODO: Dask arrays are not supported" )
379
+ def test_mask (request : pytest .FixtureRequest , array_type ):
380
+ if array_type in DASK_CONVERTERS .values ():
381
+ reason = "TODO: Dask arrays are not supported"
382
+ request .applymarker (pytest .mark .xfail (reason = reason ))
356
383
adata = sc .datasets .blobs (n_variables = 10 , n_centers = 3 , n_observations = 100 )
357
384
adata .X = array_type (adata .X )
358
385
359
386
if isinstance (adata .X , np .ndarray ) and Version (ad .__version__ ) < Version ("0.9" ):
360
- request .node .add_marker (
361
- pytest .mark .xfail (
362
- reason = "TODO: Previous version of anndata would return an F ordered array for one"
363
- " case here, which suprisingly considerably changes the results of PCA. "
364
- )
387
+ reason = (
388
+ "TODO: Previous version of anndata would return an F ordered array for one"
389
+ " case here, which surprisingly considerably changes the results of PCA."
365
390
)
391
+ request .applymarker (pytest .mark .xfail (reason = reason ))
366
392
mask_var = np .random .choice ([True , False ], adata .shape [1 ])
367
393
368
394
adata_masked = adata [:, mask_var ].copy ()
@@ -379,13 +405,10 @@ def test_mask(array_type, request):
379
405
)
380
406
381
407
382
- def test_mask_order_warning (request ):
408
+ def test_mask_order_warning (request : pytest . FixtureRequest ):
383
409
if Version (ad .__version__ ) >= Version ("0.9" ):
384
- request .node .add_marker (
385
- pytest .mark .xfail (
386
- reason = "Not expected to warn in later versions of anndata"
387
- )
388
- )
410
+ reason = "Not expected to warn in later versions of anndata"
411
+ request .applymarker (pytest .mark .xfail (reason = reason ))
389
412
390
413
adata = ad .AnnData (X = np .random .randn (50 , 5 ))
391
414
mask = np .array ([True , False , True , False , True ])
0 commit comments