Skip to content

Fix tests for dask PCA #3162

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/testing/scanpy/_helpers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,16 @@

import warnings
from itertools import permutations
from typing import TYPE_CHECKING

import numpy as np
from anndata.tests.helpers import asarray, assert_equal

import scanpy as sc

if TYPE_CHECKING:
from scanpy._compat import DaskArray

# TODO: Report more context on the fields being compared on error
# TODO: Allow specifying paths to ignore on comparison

Expand Down Expand Up @@ -124,13 +128,13 @@ def _check_check_values_warnings(function, adata, expected_warning, kwargs={}):


# Delayed imports for case where we aren't using dask
def as_dense_dask_array(*args, **kwargs):
def as_dense_dask_array(*args, **kwargs) -> DaskArray:
from anndata.tests.helpers import as_dense_dask_array

return as_dense_dask_array(*args, **kwargs)


def as_sparse_dask_array(*args, **kwargs):
def as_sparse_dask_array(*args, **kwargs) -> DaskArray:
from anndata.tests.helpers import as_sparse_dask_array

return as_sparse_dask_array(*args, **kwargs)
4 changes: 2 additions & 2 deletions tests/test_highly_variable_genes.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,8 +352,8 @@ def test_compare_to_upstream( # noqa: PLR0917
array_type: Callable,
):
if func == "fgd" and flavor == "cell_ranger":
msg = "The deprecated filter_genes_dispersion behaves differently with cell_ranger"
request.node.add_marker(pytest.mark.xfail(reason=msg))
reason = "The deprecated filter_genes_dispersion behaves differently with cell_ranger"
request.applymarker(pytest.mark.xfail(reason=reason))
hvg_info = pd.read_csv(ref_path)

pbmc = pbmc68k_reduced()
Expand Down
59 changes: 41 additions & 18 deletions tests/test_pca.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import warnings
from functools import wraps
from typing import TYPE_CHECKING

import anndata as ad
Expand All @@ -16,7 +17,7 @@
from scipy.sparse import issparse

import scanpy as sc
from testing.scanpy._helpers import as_dense_dask_array, as_sparse_dask_array
from testing.scanpy import _helpers
from testing.scanpy._helpers.data import pbmc3k_normalized
from testing.scanpy._pytest.marks import needs
from testing.scanpy._pytest.params import (
Expand All @@ -26,8 +27,11 @@
)

if TYPE_CHECKING:
from collections.abc import Callable
from typing import Literal

from scanpy._compat import DaskArray

A_list = np.array(
[
[0, 0, 7, 0, 0],
Expand Down Expand Up @@ -62,14 +66,37 @@
)


# If one uses dask for PCA it will always require dask-ml
def _chunked_1d(
f: Callable[[np.ndarray], DaskArray],
) -> Callable[[np.ndarray], DaskArray]:
@wraps(f)
def wrapper(a: np.ndarray) -> DaskArray:
da = f(a)
return da.rechunk((da.chunksize[0], -1))

return wrapper


DASK_CONVERTERS = {
f: _chunked_1d(f)
for f in (_helpers.as_dense_dask_array, _helpers.as_sparse_dask_array)
}


@pytest.fixture(
params=[
param_with(at, marks=[needs.dask_ml]) if "dask" in at.id else at
for at in ARRAY_TYPES_SPARSE_DASK_UNSUPPORTED
]
)
def array_type(request: pytest.FixtureRequest):
# If one uses dask for PCA it will always require dask-ml.
# dask-ml can’t do 2D-chunked arrays, so rechunk them.
if as_dask_array := DASK_CONVERTERS.get(request.param):
return as_dask_array

# When not using dask, just return the array type
assert "dask" not in request.param.__name__, "add more branches or refactor"
return request.param


Expand All @@ -92,8 +119,7 @@ def pca_params(
expected_warning = None
svd_solver = None
if svd_solver_type is not None:
# TODO: are these right for sparse?
if array_type in {as_dense_dask_array, as_sparse_dask_array}:
if array_type in DASK_CONVERTERS.values():
svd_solver = (
{"auto", "full", "tsqr", "randomized"}
if zero_center
Expand Down Expand Up @@ -350,19 +376,19 @@ def test_mask_var_argument_equivalence(float_dtype, array_type):
)


def test_mask(array_type, request):
if array_type is as_dense_dask_array:
pytest.xfail("TODO: Dask arrays are not supported")
def test_mask(request: pytest.FixtureRequest, array_type):
if array_type in DASK_CONVERTERS.values():
reason = "TODO: Dask arrays are not supported"
request.applymarker(pytest.mark.xfail(reason=reason))
adata = sc.datasets.blobs(n_variables=10, n_centers=3, n_observations=100)
adata.X = array_type(adata.X)

if isinstance(adata.X, np.ndarray) and Version(ad.__version__) < Version("0.9"):
request.node.add_marker(
pytest.mark.xfail(
reason="TODO: Previous version of anndata would return an F ordered array for one"
" case here, which suprisingly considerably changes the results of PCA. "
)
reason = (
"TODO: Previous version of anndata would return an F ordered array for one"
" case here, which surprisingly considerably changes the results of PCA."
)
request.applymarker(pytest.mark.xfail(reason=reason))
mask_var = np.random.choice([True, False], adata.shape[1])

adata_masked = adata[:, mask_var].copy()
Expand All @@ -379,13 +405,10 @@ def test_mask(array_type, request):
)


def test_mask_order_warning(request):
def test_mask_order_warning(request: pytest.FixtureRequest):
if Version(ad.__version__) >= Version("0.9"):
request.node.add_marker(
pytest.mark.xfail(
reason="Not expected to warn in later versions of anndata"
)
)
reason = "Not expected to warn in later versions of anndata"
request.applymarker(pytest.mark.xfail(reason=reason))

adata = ad.AnnData(X=np.random.randn(50, 5))
mask = np.array([True, False, True, False, True])
Expand Down
4 changes: 2 additions & 2 deletions tests/test_preprocessing_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ def test_normalize_per_cell(
request: pytest.FixtureRequest, adata: AnnData, adata_dist: AnnData
):
if isinstance(adata_dist.X, DaskArray):
msg = "normalize_per_cell deprecated and broken for Dask"
request.node.add_marker(pytest.mark.xfail(reason=msg))
reason = "normalize_per_cell deprecated and broken for Dask"
request.applymarker(pytest.mark.xfail(reason=reason))
normalize_per_cell(adata_dist)
assert isinstance(adata_dist.X, DIST_TYPES)
result = materialize_as_ndarray(adata_dist.X)
Expand Down
Loading