Skip to content

Commit 85ec553

Browse files
flying-sheep authored and meeseeksmachine committed
Backport PR scverse#3162: Fix tests for dask PCA
1 parent d894167 commit 85ec553

File tree

4 files changed

+51
-24
lines changed

4 files changed

+51
-24
lines changed

src/testing/scanpy/_helpers/__init__.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,16 @@
66

77
import warnings
88
from itertools import permutations
9+
from typing import TYPE_CHECKING
910

1011
import numpy as np
1112
from anndata.tests.helpers import asarray, assert_equal
1213

1314
import scanpy as sc
1415

16+
if TYPE_CHECKING:
17+
from scanpy._compat import DaskArray
18+
1519
# TODO: Report more context on the fields being compared on error
1620
# TODO: Allow specifying paths to ignore on comparison
1721

@@ -124,13 +128,13 @@ def _check_check_values_warnings(function, adata, expected_warning, kwargs={}):
124128

125129

126130
# Delayed imports for case where we aren't using dask
127-
def as_dense_dask_array(*args, **kwargs):
131+
def as_dense_dask_array(*args, **kwargs) -> DaskArray:
    """Forward all arguments to ``anndata.tests.helpers.as_dense_dask_array``.

    The import is performed inside the function body (delayed import) so that
    importing this module does not pull in the dask helpers when dask is not
    being used.
    """
    from anndata.tests.helpers import as_dense_dask_array

    return as_dense_dask_array(*args, **kwargs)
131135

132136

133-
def as_sparse_dask_array(*args, **kwargs):
137+
def as_sparse_dask_array(*args, **kwargs) -> DaskArray:
    """Forward all arguments to ``anndata.tests.helpers.as_sparse_dask_array``.

    The import is performed inside the function body (delayed import) so that
    importing this module does not pull in the dask helpers when dask is not
    being used.
    """
    from anndata.tests.helpers import as_sparse_dask_array

    return as_sparse_dask_array(*args, **kwargs)

tests/test_highly_variable_genes.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -352,8 +352,8 @@ def test_compare_to_upstream( # noqa: PLR0917
352352
array_type: Callable,
353353
):
354354
if func == "fgd" and flavor == "cell_ranger":
355-
msg = "The deprecated filter_genes_dispersion behaves differently with cell_ranger"
356-
request.node.add_marker(pytest.mark.xfail(reason=msg))
355+
reason = "The deprecated filter_genes_dispersion behaves differently with cell_ranger"
356+
request.applymarker(pytest.mark.xfail(reason=reason))
357357
hvg_info = pd.read_csv(ref_path)
358358

359359
pbmc = pbmc68k_reduced()

tests/test_pca.py

+41-18
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import warnings
4+
from functools import wraps
45
from typing import TYPE_CHECKING
56

67
import anndata as ad
@@ -16,7 +17,7 @@
1617
from scipy.sparse import issparse
1718

1819
import scanpy as sc
19-
from testing.scanpy._helpers import as_dense_dask_array, as_sparse_dask_array
20+
from testing.scanpy import _helpers
2021
from testing.scanpy._helpers.data import pbmc3k_normalized
2122
from testing.scanpy._pytest.marks import needs
2223
from testing.scanpy._pytest.params import (
@@ -26,8 +27,11 @@
2627
)
2728

2829
if TYPE_CHECKING:
30+
from collections.abc import Callable
2931
from typing import Literal
3032

33+
from scanpy._compat import DaskArray
34+
3135
A_list = np.array(
3236
[
3337
[0, 0, 7, 0, 0],
@@ -62,14 +66,37 @@
6266
)
6367

6468

65-
# If one uses dask for PCA it will always require dask-ml
69+
def _chunked_1d(
    f: Callable[[np.ndarray], DaskArray],
) -> Callable[[np.ndarray], DaskArray]:
    """Wrap a dask converter so its result is chunked along the first axis only.

    Parameters
    ----------
    f
        Converter that turns an array into a dask array.

    Returns
    -------
    A function with the same metadata as ``f`` (via ``functools.wraps``) that
    calls ``f`` and rechunks the result.
    """

    @wraps(f)
    def wrapper(a: np.ndarray) -> DaskArray:
        converted = f(a)
        # Keep the existing chunk size on axis 0; ``-1`` asks ``rechunk`` for a
        # single chunk spanning the whole second axis.
        return converted.rechunk((converted.chunksize[0], -1))

    return wrapper
78+
79+
80+
# Maps each dask converter exposed by ``_helpers`` to its ``_chunked_1d``-wrapped
# counterpart; the wrapped values are what the ``array_type`` fixture hands out.
DASK_CONVERTERS = {
    f: _chunked_1d(f)
    for f in (_helpers.as_dense_dask_array, _helpers.as_sparse_dask_array)
}
84+
85+
6686
@pytest.fixture(
    params=[
        param_with(at, marks=[needs.dask_ml]) if "dask" in at.id else at
        for at in ARRAY_TYPES_SPARSE_DASK_UNSUPPORTED
    ]
)
def array_type(request: pytest.FixtureRequest):
    """Provide the array converter for the current parametrization.

    Parameters whose id contains ``"dask"`` carry the ``needs.dask_ml`` mark.
    If the parameter is a key of ``DASK_CONVERTERS``, the mapped converter is
    returned instead of the parameter itself; otherwise the parameter is
    returned unchanged.
    """
    # If one uses dask for PCA it will always require dask-ml.
    # dask-ml can’t do 2D-chunked arrays, so rechunk them.
    if as_dask_array := DASK_CONVERTERS.get(request.param):
        return as_dask_array

    # When not using dask, just return the array type.
    # The assertion guards against a dask converter that is missing from
    # ``DASK_CONVERTERS`` silently slipping through un-rechunked.
    assert "dask" not in request.param.__name__, "add more branches or refactor"
    return request.param
74101

75102

@@ -92,8 +119,7 @@ def pca_params(
92119
expected_warning = None
93120
svd_solver = None
94121
if svd_solver_type is not None:
95-
# TODO: are these right for sparse?
96-
if array_type in {as_dense_dask_array, as_sparse_dask_array}:
122+
if array_type in DASK_CONVERTERS.values():
97123
svd_solver = (
98124
{"auto", "full", "tsqr", "randomized"}
99125
if zero_center
@@ -350,19 +376,19 @@ def test_mask_var_argument_equivalence(float_dtype, array_type):
350376
)
351377

352378

353-
def test_mask(array_type, request):
354-
if array_type is as_dense_dask_array:
355-
pytest.xfail("TODO: Dask arrays are not supported")
379+
def test_mask(request: pytest.FixtureRequest, array_type):
380+
if array_type in DASK_CONVERTERS.values():
381+
reason = "TODO: Dask arrays are not supported"
382+
request.applymarker(pytest.mark.xfail(reason=reason))
356383
adata = sc.datasets.blobs(n_variables=10, n_centers=3, n_observations=100)
357384
adata.X = array_type(adata.X)
358385

359386
if isinstance(adata.X, np.ndarray) and Version(ad.__version__) < Version("0.9"):
360-
request.node.add_marker(
361-
pytest.mark.xfail(
362-
reason="TODO: Previous version of anndata would return an F ordered array for one"
363-
" case here, which suprisingly considerably changes the results of PCA. "
364-
)
387+
reason = (
388+
"TODO: Previous version of anndata would return an F ordered array for one"
389+
" case here, which surprisingly considerably changes the results of PCA."
365390
)
391+
request.applymarker(pytest.mark.xfail(reason=reason))
366392
mask_var = np.random.choice([True, False], adata.shape[1])
367393

368394
adata_masked = adata[:, mask_var].copy()
@@ -379,13 +405,10 @@ def test_mask(array_type, request):
379405
)
380406

381407

382-
def test_mask_order_warning(request):
408+
def test_mask_order_warning(request: pytest.FixtureRequest):
383409
if Version(ad.__version__) >= Version("0.9"):
384-
request.node.add_marker(
385-
pytest.mark.xfail(
386-
reason="Not expected to warn in later versions of anndata"
387-
)
388-
)
410+
reason = "Not expected to warn in later versions of anndata"
411+
request.applymarker(pytest.mark.xfail(reason=reason))
389412

390413
adata = ad.AnnData(X=np.random.randn(50, 5))
391414
mask = np.array([True, False, True, False, True])

tests/test_preprocessing_distributed.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,8 @@ def test_normalize_per_cell(
7979
request: pytest.FixtureRequest, adata: AnnData, adata_dist: AnnData
8080
):
8181
if isinstance(adata_dist.X, DaskArray):
82-
msg = "normalize_per_cell deprecated and broken for Dask"
83-
request.node.add_marker(pytest.mark.xfail(reason=msg))
82+
reason = "normalize_per_cell deprecated and broken for Dask"
83+
request.applymarker(pytest.mark.xfail(reason=reason))
8484
normalize_per_cell(adata_dist)
8585
assert isinstance(adata_dist.X, DIST_TYPES)
8686
result = materialize_as_ndarray(adata_dist.X)

0 commit comments

Comments
 (0)