From 7ced391b43ca6a40af551b58a984f14356b78292 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Mon, 15 Jan 2024 15:31:53 +0100 Subject: [PATCH 01/50] add some failing tests --- scanpy/tests/test_highly_variable_genes.py | 95 +++++++++++++++------- 1 file changed, 67 insertions(+), 28 deletions(-) diff --git a/scanpy/tests/test_highly_variable_genes.py b/scanpy/tests/test_highly_variable_genes.py index 0ef36643cc..03537db532 100644 --- a/scanpy/tests/test_highly_variable_genes.py +++ b/scanpy/tests/test_highly_variable_genes.py @@ -1,29 +1,51 @@ from __future__ import annotations from pathlib import Path +from typing import TYPE_CHECKING import numpy as np import pandas as pd import pytest +from anndata import AnnData from scipy import sparse import scanpy as sc +from scanpy._compat import DaskArray from scanpy.testing._helpers import _check_check_values_warnings from scanpy.testing._helpers.data import pbmc3k, pbmc68k_reduced from scanpy.testing._pytest.marks import needs +from scanpy.testing._pytest.params import ARRAY_TYPES_SUPPORTED + +if TYPE_CHECKING: + from collections.abc import Callable FILE = Path(__file__).parent / Path("_scripts/seurat_hvg.csv") FILE_V3 = Path(__file__).parent / Path("_scripts/seurat_hvg_v3.csv.gz") FILE_V3_BATCH = Path(__file__).parent / Path("_scripts/seurat_hvg_v3_batch.csv") -def test_highly_variable_genes_runs(): +def validate_array_type(obj: object, at: Callable[[np.ndarray], object]) -> None: + if isinstance(at, type): + assert isinstance(obj, at) + elif at.__name__ == "asarray": + assert isinstance(obj, np.ndarray) + elif "dask" in at.__name__: + assert isinstance(obj, DaskArray) + else: + raise AssertionError(f"Unsupported array type {at}") + + +@pytest.mark.parametrize("array_type", ARRAY_TYPES_SUPPORTED) +def test_highly_variable_genes_runs(array_type): adata = sc.datasets.blobs() + adata.X = array_type(adata.X) sc.pp.highly_variable_genes(adata) -def test_highly_variable_genes_supports_batch(): +@pytest.mark.parametrize("array_type", ARRAY_TYPES_SUPPORTED) +def test_highly_variable_genes_supports_batch(array_type): adata = sc.datasets.blobs() + adata.X = array_type(adata.X) gen = np.random.default_rng(0) adata.obs["batch"] = pd.array( gen.binomial(3, 0.5, size=adata.n_obs), dtype="category" @@ -31,32 +53,49 @@ def test_highly_variable_genes_supports_batch(): sc.pp.highly_variable_genes(adata, batch_key="batch") assert "highly_variable_nbatches" in adata.var.columns assert "highly_variable_intersection" in adata.var.columns + validate_array_type(adata.var["highly_variable_nbatches"], array_type) + validate_array_type(adata.var["highly_variable_intersection"], array_type) -def test_highly_variable_genes_supports_layers(): - adata = sc.datasets.blobs() +@pytest.mark.parametrize("array_type", ARRAY_TYPES_SUPPORTED) +def test_highly_variable_genes_supports_layers(array_type): gen = np.random.default_rng(0) - adata.obs["batch"] = pd.array( - gen.binomial(4, 0.5, size=adata.n_obs), dtype="category" - ) - sc.pp.highly_variable_genes(adata, batch_key="batch", n_top_genes=3) - assert "highly_variable_nbatches" in adata.var.columns - assert adata.var["highly_variable"].sum() == 3 - highly_var_first_layer = adata.var["highly_variable"] - adata = sc.datasets.blobs() - assert isinstance(adata.X, np.ndarray) - new_layer = adata.X.copy() - gen.shuffle(new_layer) - adata.layers["test_layer"] = new_layer - adata.obs["batch"] = gen.binomial(4, 0.5, size=(adata.n_obs)) - adata.obs["batch"] = adata.obs["batch"].astype("category") - sc.pp.highly_variable_genes( - adata, batch_key="batch", n_top_genes=3, layer="test_layer" - ) - assert "highly_variable_nbatches" in adata.var.columns - assert adata.var["highly_variable"].sum() == 3 - assert (highly_var_first_layer != adata.var["highly_variable"]).any() + def ad1() -> AnnData: + adata = sc.datasets.blobs() + adata.X = array_type(adata.X) + adata.obs["batch"] = pd.array( + gen.binomial(4, 0.5, size=adata.n_obs), dtype="category" + ) + sc.pp.highly_variable_genes(adata, batch_key="batch", n_top_genes=3) + assert "highly_variable_nbatches" in adata.var.columns + validate_array_type(adata.var["highly_variable_nbatches"], array_type) + validate_array_type(adata.var["highly_variable"], array_type) + assert adata.var["highly_variable"].sum() == 3 + return adata + + adata1 = ad1() + + def ad2() -> AnnData: + adata = sc.datasets.blobs() + assert isinstance(adata.X, np.ndarray) + new_layer = adata.X.copy() + gen.shuffle(new_layer) + adata.layers["test_layer"] = array_type(new_layer) + del new_layer + adata.obs["batch"] = gen.binomial(4, 0.5, size=(adata.n_obs)) + adata.obs["batch"] = adata.obs["batch"].astype("category") + sc.pp.highly_variable_genes( + adata, batch_key="batch", n_top_genes=3, layer="test_layer" + ) + assert "highly_variable_nbatches" in adata.var.columns + validate_array_type(adata.var["highly_variable_nbatches"], array_type) + validate_array_type(adata.var["highly_variable"], array_type) + assert adata.var["highly_variable"].sum() == 3 + return adata + + adata2 = ad2() + assert (adata1.var["highly_variable"] != adata2.var["highly_variable"]).any() def test_highly_variable_genes_no_batch_matches_batch(): @@ -73,7 +112,7 @@ def test_highly_variable_genes_no_batch_matches_batch(): ) -def test_highly_variable_genes_(): +def test_highly_variable_genes_no_inplace(): adata = sc.datasets.blobs() adata.obs["batch"] = np.tile(["a", "b"], adata.shape[0] // 2) sc.pp.highly_variable_genes(adata, batch_key="batch") @@ -473,8 +512,8 @@ def test_highly_variable_genes_batches(): adata[:100, :100].X = np.zeros((100, 100)) adata.obs["batch"] = ["0" if i < 100 else "1" for i in range(adata.n_obs)] - adata_1 = adata[adata.obs.batch.isin(["0"]), :] - adata_2 = adata[adata.obs.batch.isin(["1"]), :] + adata_1 = adata[adata.obs["batch"] == "0"].copy() + adata_2 = adata[adata.obs["batch"] == "1"].copy() sc.pp.highly_variable_genes( adata, @@ -545,7 +584,7 @@ def test_seurat_v3_mean_var_output_with_batchkey(): def test_cellranger_n_top_genes_warning(): X = np.random.poisson(2, (100, 30)) - adata = sc.AnnData(X) + adata = AnnData(X) sc.pp.normalize_total(adata) sc.pp.log1p(adata) From 80818690ab4b89f73cf2a00b4e40a22409ee9815 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Mon, 15 Jan 2024 16:52:20 +0100 Subject: [PATCH 02/50] simplify test --- scanpy/tests/test_highly_variable_genes.py | 39 +++++++--------------- 1 file changed, 12 insertions(+), 27 deletions(-) diff --git a/scanpy/tests/test_highly_variable_genes.py b/scanpy/tests/test_highly_variable_genes.py index 03537db532..5ad644337f 100644 --- a/scanpy/tests/test_highly_variable_genes.py +++ b/scanpy/tests/test_highly_variable_genes.py @@ -59,34 +59,20 @@ def test_highly_variable_genes_supports_batch(array_type): @pytest.mark.parametrize("array_type", ARRAY_TYPES_SUPPORTED) def test_highly_variable_genes_supports_layers(array_type): - gen = np.random.default_rng(0) - - def ad1() -> AnnData: + def execute(layer: str | None) -> AnnData: + gen = np.random.default_rng(0) adata = sc.datasets.blobs() - adata.X = array_type(adata.X) + assert isinstance(adata.X, np.ndarray) + if layer: + new_layer = adata.X.copy() + gen.shuffle(new_layer) + adata.layers[layer] = array_type(new_layer) + del new_layer, adata.X adata.obs["batch"] = pd.array( - gen.binomial(4, 0.5, size=adata.n_obs), dtype="category" + gen.binomial(4, 0.5, size=(adata.n_obs)), dtype="category" ) - sc.pp.highly_variable_genes(adata, batch_key="batch", n_top_genes=3) - assert "highly_variable_nbatches" in adata.var.columns - validate_array_type(adata.var["highly_variable_nbatches"], array_type) - validate_array_type(adata.var["highly_variable"], array_type) - assert adata.var["highly_variable"].sum() == 3 - return adata - - adata1 = ad1() - - def ad2() -> AnnData: - adata = sc.datasets.blobs() - assert isinstance(adata.X, np.ndarray) - new_layer = adata.X.copy() - gen.shuffle(new_layer) - adata.layers["test_layer"] = array_type(new_layer) - del new_layer - adata.obs["batch"] = gen.binomial(4, 0.5, size=(adata.n_obs)) - adata.obs["batch"] = adata.obs["batch"].astype("category") sc.pp.highly_variable_genes( - adata, batch_key="batch", n_top_genes=3, layer="test_layer" + adata, batch_key="batch", n_top_genes=3, layer=layer ) assert "highly_variable_nbatches" in adata.var.columns validate_array_type(adata.var["highly_variable_nbatches"], array_type) @@ -94,7 +80,7 @@ def ad2() -> AnnData: assert adata.var["highly_variable"].sum() == 3 return adata - adata2 = ad2() + adata1, adata2 = map(execute, [None, "test_layer"]) assert (adata1.var["highly_variable"] != adata2.var["highly_variable"]).any() @@ -103,8 +89,7 @@ def test_highly_variable_genes_no_batch_matches_batch(): sc.pp.highly_variable_genes(adata) no_batch_hvg = adata.var["highly_variable"].copy() assert no_batch_hvg.any() - adata.obs["batch"] = "batch" - adata.obs["batch"] = adata.obs["batch"].astype("category") + adata.obs["batch"] = pd.array(["batch"], dtype="category").repeat(len(adata)) sc.pp.highly_variable_genes(adata, batch_key="batch") assert np.all(no_batch_hvg == adata.var["highly_variable"]) assert np.all( From cf50542ef1dc903cf0d5573029f6deff2f2e041d Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Mon, 15 Jan 2024 17:12:28 +0100 Subject: [PATCH 03/50] make tests make sense --- scanpy/tests/test_highly_variable_genes.py | 48 +++++++--------------- 1 file changed, 15 insertions(+), 33 deletions(-) diff --git a/scanpy/tests/test_highly_variable_genes.py b/scanpy/tests/test_highly_variable_genes.py index 5ad644337f..d12a9e84e1 100644 --- a/scanpy/tests/test_highly_variable_genes.py +++ b/scanpy/tests/test_highly_variable_genes.py @@ -1,7 +1,6 @@ from __future__ import annotations from pathlib import Path -from typing import TYPE_CHECKING import numpy as np import pandas as pd @@ -10,42 +9,23 @@ from scipy import sparse import scanpy as sc -from scanpy._compat import DaskArray from scanpy.testing._helpers import _check_check_values_warnings from scanpy.testing._helpers.data import pbmc3k, pbmc68k_reduced from scanpy.testing._pytest.marks import needs from scanpy.testing._pytest.params import ARRAY_TYPES_SUPPORTED -if TYPE_CHECKING: - from collections.abc import Callable - FILE = Path(__file__).parent / Path("_scripts/seurat_hvg.csv") FILE_V3 = Path(__file__).parent / Path("_scripts/seurat_hvg_v3.csv.gz") FILE_V3_BATCH = Path(__file__).parent / Path("_scripts/seurat_hvg_v3_batch.csv") -def validate_array_type(obj: object, at: Callable[[np.ndarray], object]) -> None: - if isinstance(at, type): - assert isinstance(obj, at) - elif at.__name__ == "asarray": - assert isinstance(obj, np.ndarray) - elif "dask" in at.__name__: - assert isinstance(obj, DaskArray) - else: - raise AssertionError(f"Unsupported array type {at}") - - -@pytest.mark.parametrize("array_type", ARRAY_TYPES_SUPPORTED) -def test_highly_variable_genes_runs(array_type): +def test_highly_variable_genes_runs(): adata = sc.datasets.blobs() - adata.X = array_type(adata.X) sc.pp.highly_variable_genes(adata) -@pytest.mark.parametrize("array_type", ARRAY_TYPES_SUPPORTED) -def test_highly_variable_genes_supports_batch(array_type): +def test_highly_variable_genes_supports_batch(): adata = sc.datasets.blobs() - adata.X = array_type(adata.X) gen = np.random.default_rng(0) adata.obs["batch"] = pd.array( gen.binomial(3, 0.5, size=adata.n_obs), dtype="category" @@ -53,12 +33,9 @@ def test_highly_variable_genes_supports_batch(array_type): sc.pp.highly_variable_genes(adata, batch_key="batch") assert "highly_variable_nbatches" in adata.var.columns assert "highly_variable_intersection" in adata.var.columns - validate_array_type(adata.var["highly_variable_nbatches"], array_type) - validate_array_type(adata.var["highly_variable_intersection"], array_type) -@pytest.mark.parametrize("array_type", ARRAY_TYPES_SUPPORTED) -def test_highly_variable_genes_supports_layers(array_type): +def test_highly_variable_genes_supports_layers(): def execute(layer: str | None) -> AnnData: gen = np.random.default_rng(0) adata = sc.datasets.blobs() @@ -66,7 +43,6 @@ def execute(layer: str | None) -> AnnData: if layer: new_layer = adata.X.copy() gen.shuffle(new_layer) - adata.layers[layer] = array_type(new_layer) del new_layer, adata.X adata.obs["batch"] = pd.array( gen.binomial(4, 0.5, size=(adata.n_obs)), dtype="category" @@ -75,8 +51,6 @@ def execute(layer: str | None) -> AnnData: adata, batch_key="batch", n_top_genes=3, layer=layer ) assert "highly_variable_nbatches" in adata.var.columns - validate_array_type(adata.var["highly_variable_nbatches"], array_type) - validate_array_type(adata.var["highly_variable"], array_type) assert adata.var["highly_variable"].sum() == 3 return adata @@ -97,23 +71,31 @@ def test_highly_variable_genes_no_batch_matches_batch(): ) -def test_highly_variable_genes_no_inplace(): +@pytest.mark.parametrize("array_type", ARRAY_TYPES_SUPPORTED) +def test_highly_variable_genes_no_inplace(array_type): adata = sc.datasets.blobs() + adata.X = array_type(adata.X) adata.obs["batch"] = np.tile(["a", "b"], adata.shape[0] // 2) sc.pp.highly_variable_genes(adata, batch_key="batch") assert adata.var["highly_variable"].any() - colnames = [ + colnames = { "means", "dispersions", "dispersions_norm", "highly_variable_nbatches", "highly_variable_intersection", "highly_variable", - ] + } hvg_df = sc.pp.highly_variable_genes(adata, batch_key="batch", inplace=False) assert hvg_df is not None - assert np.all(np.isin(colnames, hvg_df.columns)) + assert colnames == set(hvg_df.columns) + if "dask" in array_type.__name__: + import dask.dataframe as dd + + assert isinstance(hvg_df, dd.DataFrame) + else: + assert isinstance(hvg_df, pd.DataFrame) @pytest.mark.parametrize("base", [None, 10]) From b3ca8a84f4c85d52296d5bab2c8684e48ad916e3 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 16 Jan 2024 09:41:45 +0100 Subject: [PATCH 04/50] ddf --- scanpy/_compat.py | 13 ++++++++++++- .../preprocessing/_highly_variable_genes.py | 19 +++++++++++++------ 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/scanpy/_compat.py b/scanpy/_compat.py index 244f8588fa..81deadc226 100644 --- a/scanpy/_compat.py +++ b/scanpy/_compat.py @@ -16,13 +16,24 @@ try: from dask.array import Array as DaskArray + from dask.dataframe import DataFrame as DaskDataFrame except ImportError: class DaskArray: pass + class DaskDataFrame: + pass + -__all__ = ["cache", "DaskArray", "fullname", "pkg_metadata", "pkg_version"] +__all__ = [ + "cache", + "DaskArray", + "DaskDataFrame", + "fullname", + "pkg_metadata", + "pkg_version", +] def fullname(typ: type) -> str: diff --git a/scanpy/preprocessing/_highly_variable_genes.py b/scanpy/preprocessing/_highly_variable_genes.py index 5dd28c8748..4f1a89ffe1 100644 --- a/scanpy/preprocessing/_highly_variable_genes.py +++ b/scanpy/preprocessing/_highly_variable_genes.py @@ -10,7 +10,7 @@ from anndata import AnnData from .. import logging as logg -from .._compat import old_positionals +from .._compat import DaskDataFrame, old_positionals from .._settings import Verbosity, settings from .._utils import check_nonnegative_integers, sanitize_anndata from ._distributed import materialize_as_ndarray @@ -192,7 +192,7 @@ def _highly_variable_genes_single_batch( n_top_genes: int | None = None, n_bins: int = 20, flavor: Literal["seurat", "cell_ranger"] = "seurat", -) -> pd.DataFrame: +) -> pd.DataFrame | DaskDataFrame: """\ See `highly_variable_genes`. @@ -337,7 +337,7 @@ def highly_variable_genes( inplace: bool = True, batch_key: str | None = None, check_values: bool = True, -) -> pd.DataFrame | None: +) -> pd.DataFrame | DaskDataFrame | None: """\ Annotate highly variable genes [Satija15]_ [Zheng17]_ [Stuart19]_. @@ -486,7 +486,7 @@ def highly_variable_genes( else: sanitize_anndata(adata) batches = adata.obs[batch_key].cat.categories - df = [] + dfs = [] gene_list = adata.var_names for batch in batches: adata_subset = adata[adata.obs[batch_key] == batch] @@ -526,9 +526,16 @@ def highly_variable_genes( idxs = np.concatenate((np.where(filt)[0], np.where(~filt)[0])) hvg = hvg.loc[np.argsort(idxs)] - df.append(hvg) + dfs.append(hvg) + + df: DaskDataFrame | pd.DataFrame + if isinstance(dfs[0], DaskDataFrame): + import dask.dataframe as dd + + df = dd.concat(dfs, axis=0) + else: + df = pd.concat(dfs, axis=0) - df = pd.concat(df, axis=0) df["highly_variable"] = df["highly_variable"].astype(int) df = df.groupby("gene", observed=True).agg( dict( From de6064ef9f37bd3c9938734a64551a6f03f54a8b Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 16 Jan 2024 12:38:17 +0100 Subject: [PATCH 05/50] ddf --- scanpy/tests/test_highly_variable_genes.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/scanpy/tests/test_highly_variable_genes.py b/scanpy/tests/test_highly_variable_genes.py index d12a9e84e1..a5e0f434e3 100644 --- a/scanpy/tests/test_highly_variable_genes.py +++ b/scanpy/tests/test_highly_variable_genes.py @@ -9,6 +9,7 @@ from scipy import sparse import scanpy as sc +from scanpy._compat import DaskDataFrame from scanpy.testing._helpers import _check_check_values_warnings from scanpy.testing._helpers.data import pbmc3k, pbmc68k_reduced from scanpy.testing._pytest.marks import needs @@ -91,9 +92,7 @@ def test_highly_variable_genes_no_inplace(array_type): assert hvg_df is not None assert colnames == set(hvg_df.columns) if "dask" in array_type.__name__: - import dask.dataframe as dd - - assert isinstance(hvg_df, dd.DataFrame) + assert isinstance(hvg_df, DaskDataFrame) else: assert isinstance(hvg_df, pd.DataFrame) From 8d6142a1539faa61a739ba48e6951a3d6b5c7526 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 16 Jan 2024 16:52:28 +0100 Subject: [PATCH 06/50] Fix highly_variable_genes with layer specified --- scanpy/preprocessing/_highly_variable_genes.py | 5 ++++- scanpy/tests/test_highly_variable_genes.py | 7 +++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/scanpy/preprocessing/_highly_variable_genes.py b/scanpy/preprocessing/_highly_variable_genes.py index 4f1a89ffe1..9429259ac5 100644 --- a/scanpy/preprocessing/_highly_variable_genes.py +++ b/scanpy/preprocessing/_highly_variable_genes.py @@ -13,6 +13,7 @@ from .._compat import DaskDataFrame, old_positionals from .._settings import Verbosity, settings from .._utils import check_nonnegative_integers, sanitize_anndata +from ..get import _get_obs_rep from ._distributed import materialize_as_ndarray from ._simple import filter_genes from ._utils import _get_mean_var @@ -493,7 +494,9 @@ def highly_variable_genes( # Filter to genes that are in the dataset with settings.verbosity.override(Verbosity.error): - filt = filter_genes(adata_subset, min_cells=1, inplace=False)[0] + filt, _ = filter_genes( + _get_obs_rep(adata_subset, layer=layer), min_cells=1, inplace=False + ) adata_subset = adata_subset[:, filt] diff --git a/scanpy/tests/test_highly_variable_genes.py b/scanpy/tests/test_highly_variable_genes.py index a5e0f434e3..98cdc252c7 100644 --- a/scanpy/tests/test_highly_variable_genes.py +++ b/scanpy/tests/test_highly_variable_genes.py @@ -42,11 +42,10 @@ def execute(layer: str | None) -> AnnData: adata = sc.datasets.blobs() assert isinstance(adata.X, np.ndarray) if layer: - new_layer = adata.X.copy() - gen.shuffle(new_layer) - del new_layer, adata.X + adata.X, adata.layers[layer] = None, adata.X.copy() + gen.shuffle(adata.layers[layer]) adata.obs["batch"] = pd.array( - gen.binomial(4, 0.5, size=(adata.n_obs)), dtype="category" + gen.binomial(4, 0.5, size=adata.n_obs), dtype="category" ) sc.pp.highly_variable_genes( adata, batch_key="batch", n_top_genes=3, layer=layer From 16190c622be83816f50bee7488d5a36eb9937a49 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 16 Jan 2024 17:56:19 +0100 Subject: [PATCH 07/50] WIP --- scanpy/_compat.py | 4 ++ .../preprocessing/_highly_variable_genes.py | 46 +++++++++++++------ scanpy/tests/test_highly_variable_genes.py | 25 +++++----- 3 files changed, 50 insertions(+), 25 deletions(-) diff --git a/scanpy/_compat.py b/scanpy/_compat.py index e2902e965b..d14bc65f90 100644 --- a/scanpy/_compat.py +++ b/scanpy/_compat.py @@ -17,6 +17,7 @@ try: from dask.array import Array as DaskArray from dask.dataframe import DataFrame as DaskDataFrame + from dask.dataframe import Series as DaskSeries except ImportError: class DaskArray: @@ -25,6 +26,9 @@ class DaskArray: class DaskDataFrame: pass + class DaskSeries: + pass + try: from zappy.base import ZappyArray diff --git a/scanpy/preprocessing/_highly_variable_genes.py b/scanpy/preprocessing/_highly_variable_genes.py index 9429259ac5..f4a3386a57 100644 --- a/scanpy/preprocessing/_highly_variable_genes.py +++ b/scanpy/preprocessing/_highly_variable_genes.py @@ -10,7 +10,7 @@ from anndata import AnnData from .. import logging as logg -from .._compat import DaskDataFrame, old_positionals +from .._compat import DaskArray, DaskDataFrame, DaskSeries, old_positionals from .._settings import Verbosity, settings from .._utils import check_nonnegative_integers, sanitize_anndata from ..get import _get_obs_rep @@ -213,7 +213,7 @@ def _highly_variable_genes_single_batch( else: X = np.expm1(X) - mean, var = materialize_as_ndarray(_get_mean_var(X)) + mean, var = _get_mean_var(X) # now actually compute the dispersion mean[mean == 0] = 1e-12 # set entries equal to zero to small value dispersion = var / mean @@ -222,11 +222,17 @@ def _highly_variable_genes_single_batch( dispersion = np.log(dispersion) mean = np.log1p(mean) # all of the following quantities are "per-gene" here - df = pd.DataFrame() - df["means"] = mean - df["dispersions"] = dispersion + if isinstance(X, DaskArray): + import dask.array as da + import dask.dataframe as dd + + df = dd.from_dask_array( + da.vstack((mean, dispersion)).T, columns=["means", "dispersions"] + ) + else: + df = pd.DataFrame(dict(means=mean, dispersions=dispersion)) if flavor == "seurat": - df["mean_bin"] = pd.cut(df["means"], bins=n_bins) + df["mean_bin"] = _ser_cut(df["means"], bins=n_bins) disp_grouped = df.groupby("mean_bin", observed=True)["dispersions"] disp_mean_bin = disp_grouped.mean() disp_std_bin = disp_grouped.std(ddof=1) @@ -234,7 +240,7 @@ def _highly_variable_genes_single_batch( # only a single gene fell in the bin and implicitly set them to have # a normalized disperion of 1 one_gene_per_bin = disp_std_bin.isnull() - gen_indices = np.where(one_gene_per_bin[df["mean_bin"].to_numpy()])[0].tolist() + gen_indices = np.flatnonzero(one_gene_per_bin.loc[df["mean_bin"]]) if len(gen_indices) > 0: logg.debug( f"Gene indices {gen_indices} fell into a single bin: their " @@ -255,9 +261,11 @@ def _highly_variable_genes_single_batch( elif flavor == "cell_ranger": from statsmodels import robust - df["mean_bin"] = pd.cut( + df["mean_bin"] = _ser_cut( df["means"], - np.r_[-np.inf, np.percentile(df["means"], np.arange(10, 105, 5)), np.inf], + bins=np.r_[ + -np.inf, np.percentile(df["means"], np.arange(10, 105, 5)), np.inf + ], ) disp_grouped = df.groupby("mean_bin", observed=True)["dispersions"] disp_median_bin = disp_grouped.median() @@ -307,6 +315,13 @@ def _highly_variable_genes_single_batch( return df +def _ser_cut(df: pd.Series | DaskSeries, *, bins: int) -> pd.Series | DaskSeries: + if isinstance(df, DaskSeries): + # TODO: does this make sense? + return df.map_partitions(pd.cut, bins=bins) + return pd.cut(df, bins=bins) + + @old_positionals( "layer", "n_top_genes", @@ -435,9 +450,9 @@ def highly_variable_genes( For `flavor='seurat_v3'`, rank of the gene according to normalized variance, median rank in the case of multiple batches `adata.var['highly_variable_nbatches']` : :class:`pandas.Series` (dtype `int`) - If batch_key is given, this denotes in how many batches genes are detected as HVG + If `batch_key` is given, this denotes in how many batches genes are detected as HVG `adata.var['highly_variable_intersection']` : :class:`pandas.Series` (dtype `bool`) - If batch_key is given, this denotes the genes that are highly variable in all batches + If `batch_key` is given, this denotes the genes that are highly variable in all batches Notes ----- @@ -494,8 +509,13 @@ def highly_variable_genes( # Filter to genes that are in the dataset with settings.verbosity.override(Verbosity.error): - filt, _ = filter_genes( - _get_obs_rep(adata_subset, layer=layer), min_cells=1, inplace=False + # TODO use groupby or so instead of materialize_as_ndarray + filt, _ = materialize_as_ndarray( + filter_genes( + _get_obs_rep(adata_subset, layer=layer), + min_cells=1, + inplace=False, + ) ) adata_subset = adata_subset[:, filt] diff --git a/scanpy/tests/test_highly_variable_genes.py b/scanpy/tests/test_highly_variable_genes.py index 98cdc252c7..89ed97c95e 100644 --- a/scanpy/tests/test_highly_variable_genes.py +++ b/scanpy/tests/test_highly_variable_genes.py @@ -71,23 +71,24 @@ def test_highly_variable_genes_no_batch_matches_batch(): ) +@pytest.mark.parametrize("batch_key", [None, "batch"], ids=["single", "batched"]) @pytest.mark.parametrize("array_type", ARRAY_TYPES_SUPPORTED) -def test_highly_variable_genes_no_inplace(array_type): +def test_highly_variable_genes_no_inplace(array_type, batch_key): adata = sc.datasets.blobs() adata.X = array_type(adata.X) - adata.obs["batch"] = np.tile(["a", "b"], adata.shape[0] // 2) - sc.pp.highly_variable_genes(adata, batch_key="batch") + if batch_key: + adata.obs[batch_key] = np.tile(["a", "b"], adata.shape[0] // 2) + sc.pp.highly_variable_genes(adata, batch_key=batch_key, n_bins=3) assert adata.var["highly_variable"].any() - colnames = { - "means", - "dispersions", - "dispersions_norm", - "highly_variable_nbatches", - "highly_variable_intersection", - "highly_variable", - } - hvg_df = sc.pp.highly_variable_genes(adata, batch_key="batch", inplace=False) + colnames = {"means", "dispersions", "dispersions_norm", "highly_variable"} | ( + {"mean_bin"} + if batch_key is None + else {"highly_variable_nbatches", "highly_variable_intersection"} + ) + hvg_df = sc.pp.highly_variable_genes( + adata, batch_key=batch_key, n_bins=3, inplace=False + ) assert hvg_df is not None assert colnames == set(hvg_df.columns) if "dask" in array_type.__name__: From a0279085003f135f33245be9032f7446ddda8c55 Mon Sep 17 00:00:00 2001 From: Phil Schaf Date: Thu, 18 Jan 2024 14:39:55 +0100 Subject: [PATCH 08/50] WIP --- scanpy/_compat.py | 13 ++++- scanpy/preprocessing/_distributed.py | 57 ++++++++++++++++++- .../preprocessing/_highly_variable_genes.py | 34 +++++------ 3 files changed, 86 insertions(+), 18 deletions(-) diff --git a/scanpy/_compat.py b/scanpy/_compat.py index d14bc65f90..a7022dded2 100644 --- a/scanpy/_compat.py +++ b/scanpy/_compat.py @@ -18,6 +18,8 @@ from dask.array import Array as DaskArray from dask.dataframe import DataFrame as DaskDataFrame from dask.dataframe import Series as DaskSeries + from dask.dataframe.groupby import DataFrameGroupBy as DaskDataFrameGroupBy + from dask.dataframe.groupby import SeriesGroupBy as DaskSeriesGroupBy except ImportError: class DaskArray: @@ -29,6 +31,12 @@ class DaskDataFrame: class DaskSeries: pass + class DaskDataFrameGroupBy: + pass + + class DaskSeriesGroupBy: + pass + try: from zappy.base import ZappyArray @@ -41,8 +49,11 @@ class ZappyArray: __all__ = [ "cache", "DaskArray", - "DaskDataFrame", "ZappyArray", + "DaskDataFrame", + "DaskSeries", + "DaskDataFrameGroupBy", + "DaskSeriesGroupBy", "fullname", "pkg_metadata", "pkg_version", diff --git a/scanpy/preprocessing/_distributed.py b/scanpy/preprocessing/_distributed.py index 748ec3d671..bdfb625d53 100644 --- a/scanpy/preprocessing/_distributed.py +++ b/scanpy/preprocessing/_distributed.py @@ -1,13 +1,25 @@ from __future__ import annotations +from contextlib import contextmanager from typing import TYPE_CHECKING, overload import numpy as np -from scanpy._compat import DaskArray, ZappyArray +from scanpy._compat import ( + DaskArray, + DaskDataFrame, + DaskDataFrameGroupBy, + DaskSeries, + DaskSeriesGroupBy, + ZappyArray, +) if TYPE_CHECKING: + from collections.abc import Generator + + import pandas as pd from numpy.typing import ArrayLike + from pandas.core.groupby.generic import DataFrameGroupBy, SeriesGroupBy @overload @@ -47,3 +59,46 @@ def materialize_as_ndarray( import dask.array as da return da.compute(*a, sync=True) + + +@overload +def dask_compute(value: DaskDataFrame) -> pd.DataFrame: + ... + + +@overload +def dask_compute(value: DaskSeries) -> pd.Series: + ... + + +@overload +def dask_compute(value: DaskDataFrameGroupBy) -> DataFrameGroupBy: + ... + + +@overload +def dask_compute(value: DaskSeriesGroupBy) -> SeriesGroupBy: + ... + + +def dask_compute( + value: DaskDataFrame | DaskSeries | DaskDataFrameGroupBy | DaskSeriesGroupBy, +) -> pd.DataFrame | pd.Series | DataFrameGroupBy | SeriesGroupBy: + """Compute a dask array or series.""" + if isinstance( + value, (DaskDataFrame, DaskSeries, DaskDataFrameGroupBy, DaskSeriesGroupBy) + ): + with suppress_pandas_warning(): + return value.compute(sync=True) + return value + + +@contextmanager +def suppress_pandas_warning() -> Generator[None, None, None]: + import warnings + + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", r"The default of observed=False", category=FutureWarning + ) + yield diff --git a/scanpy/preprocessing/_highly_variable_genes.py b/scanpy/preprocessing/_highly_variable_genes.py index f4a3386a57..6685d78af0 100644 --- a/scanpy/preprocessing/_highly_variable_genes.py +++ b/scanpy/preprocessing/_highly_variable_genes.py @@ -1,6 +1,7 @@ from __future__ import annotations import warnings +from functools import partial from inspect import signature from typing import Literal, cast @@ -14,7 +15,7 @@ from .._settings import Verbosity, settings from .._utils import check_nonnegative_integers, sanitize_anndata from ..get import _get_obs_rep -from ._distributed import materialize_as_ndarray +from ._distributed import dask_compute, materialize_as_ndarray, suppress_pandas_warning from ._simple import filter_genes from ._utils import _get_mean_var @@ -222,6 +223,7 @@ def _highly_variable_genes_single_batch( dispersion = np.log(dispersion) mean = np.log1p(mean) # all of the following quantities are "per-gene" here + df: pd.DataFrame | DaskDataFrame if isinstance(X, DaskArray): import dask.array as da import dask.dataframe as dd @@ -234,12 +236,14 @@ def _highly_variable_genes_single_batch( if flavor == "seurat": df["mean_bin"] = _ser_cut(df["means"], bins=n_bins) disp_grouped = df.groupby("mean_bin", observed=True)["dispersions"] - disp_mean_bin = disp_grouped.mean() - disp_std_bin = disp_grouped.std(ddof=1) + with suppress_pandas_warning(): + disp_bin_stats: pd.DataFrame = dask_compute( + disp_grouped.agg(mean="mean", std=partial(np.std, ddof=1)) + ) # retrieve those genes that have nan std, these are the ones where # only a single gene fell in the bin and implicitly set them to have # a normalized disperion of 1 - one_gene_per_bin = disp_std_bin.isnull() + one_gene_per_bin = disp_bin_stats["std"].isnull() gen_indices = np.flatnonzero(one_gene_per_bin.loc[df["mean_bin"]]) if len(gen_indices) > 0: logg.debug( @@ -247,17 +251,15 @@ def _highly_variable_genes_single_batch( "normalized dispersion was set to 1.\n " "Decreasing `n_bins` will likely avoid this effect." ) - # Circumvent pandas 0.23 bug. Both sides of the assignment have dtype==float32, - # but there’s still a dtype error without “.value”. - disp_std_bin[one_gene_per_bin.to_numpy()] = disp_mean_bin[ - one_gene_per_bin.to_numpy() - ].to_numpy() - disp_mean_bin[one_gene_per_bin.to_numpy()] = 0 + disp_bin_stats["std"].loc[one_gene_per_bin] = disp_bin_stats["mean"].loc[ + one_gene_per_bin + ] + disp_bin_stats["mean"].loc[one_gene_per_bin] = 0 + # (use values here as index differs) + disp_mean = disp_bin_stats["mean"].loc[df["mean_bin"]].to_numpy() + disp_std = disp_bin_stats["std"].loc[df["mean_bin"]].to_numpy() # actually do the normalization - df["dispersions_norm"] = ( - df["dispersions"].to_numpy() # use values here as index differs - - disp_mean_bin[df["mean_bin"].to_numpy()].to_numpy() - ) / disp_std_bin[df["mean_bin"].to_numpy()].to_numpy() + df["dispersions_norm"] = (df["dispersions"] - disp_mean) / disp_std elif flavor == "cell_ranger": from statsmodels import robust @@ -315,9 +317,9 @@ def _highly_variable_genes_single_batch( return df -def _ser_cut(df: pd.Series | DaskSeries, *, bins: int) -> pd.Series | DaskSeries: +def _ser_cut(df: pd.Series | DaskSeries, *, bins: int) -> pd.Series: if isinstance(df, DaskSeries): - # TODO: does this make sense? + # TODO: does map_partitions make sense for bin? return df.map_partitions(pd.cut, bins=bins) return pd.cut(df, bins=bins) From 4049faa5a1889989d97329d10243e6d144cdcbfa Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Fri, 19 Jan 2024 09:46:03 +0100 Subject: [PATCH 09/50] WIP --- .../preprocessing/_highly_variable_genes.py | 97 ++++++++++--------- 1 file changed, 53 insertions(+), 44 deletions(-) diff --git a/scanpy/preprocessing/_highly_variable_genes.py b/scanpy/preprocessing/_highly_variable_genes.py index 6685d78af0..fe12c3f8b0 100644 --- a/scanpy/preprocessing/_highly_variable_genes.py +++ b/scanpy/preprocessing/_highly_variable_genes.py @@ -233,54 +233,17 @@ def _highly_variable_genes_single_batch( ) else: df = pd.DataFrame(dict(means=mean, dispersions=dispersion)) + # assign "mean_bin" column and compute dispersions_norm if flavor == "seurat": - df["mean_bin"] = _ser_cut(df["means"], bins=n_bins) - disp_grouped = df.groupby("mean_bin", observed=True)["dispersions"] - with suppress_pandas_warning(): - disp_bin_stats: pd.DataFrame = dask_compute( - disp_grouped.agg(mean="mean", std=partial(np.std, ddof=1)) - ) - # retrieve those genes that have nan std, these are the ones where - # only a single gene fell in the bin and implicitly set them to have - # a normalized disperion of 1 - one_gene_per_bin = disp_bin_stats["std"].isnull() - gen_indices = np.flatnonzero(one_gene_per_bin.loc[df["mean_bin"]]) - if len(gen_indices) > 0: - logg.debug( - f"Gene indices {gen_indices} fell into a single bin: their " - "normalized dispersion was set to 1.\n " - "Decreasing `n_bins` will likely avoid this effect." - ) - disp_bin_stats["std"].loc[one_gene_per_bin] = disp_bin_stats["mean"].loc[ - one_gene_per_bin - ] - disp_bin_stats["mean"].loc[one_gene_per_bin] = 0 - # (use values here as index differs) - disp_mean = disp_bin_stats["mean"].loc[df["mean_bin"]].to_numpy() - disp_std = disp_bin_stats["std"].loc[df["mean_bin"]].to_numpy() - # actually do the normalization - df["dispersions_norm"] = (df["dispersions"] - disp_mean) / disp_std + disp_avg, disp_dev = _stats_seurat(df, n_bins=n_bins) elif flavor == "cell_ranger": - from statsmodels import robust - - df["mean_bin"] = _ser_cut( - df["means"], - bins=np.r_[ - -np.inf, np.percentile(df["means"], np.arange(10, 105, 5)), np.inf - ], - ) - disp_grouped = df.groupby("mean_bin", observed=True)["dispersions"] - disp_median_bin = disp_grouped.median() - # the next line raises the warning: "Mean of empty slice" - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - disp_mad_bin = disp_grouped.apply(robust.mad) - df["dispersions_norm"] = ( - df["dispersions"].to_numpy() - - disp_median_bin[df["mean_bin"].to_numpy()].to_numpy() - ) / disp_mad_bin[df["mean_bin"].to_numpy()].to_numpy() + disp_avg, disp_dev = _stats_cell_ranger(df) else: raise ValueError('`flavor` needs to be "seurat" or "cell_ranger"') + + # actually do the normalization + df["dispersions_norm"] = (df["dispersions"] - disp_avg) / disp_dev + dispersion_norm = df["dispersions_norm"].to_numpy() if n_top_genes is not None: dispersion_norm = dispersion_norm[~np.isnan(dispersion_norm)] @@ -317,6 +280,52 @@ def _highly_variable_genes_single_batch( return df +def _stats_seurat(df: pd.DataFrame | DaskDataFrame, *, n_bins: int): + df["mean_bin"] = _ser_cut(df["means"], bins=n_bins) + disp_grouped = df.groupby("mean_bin", observed=True)["dispersions"] + with suppress_pandas_warning(): + disp_bin_stats: pd.DataFrame = dask_compute( + disp_grouped.agg(mean="mean", std=partial(np.std, ddof=1)) + ) + # retrieve those genes that have nan std, these are the ones where + # only a single gene fell in the bin and implicitly set them to have + # a normalized disperion of 1 + one_gene_per_bin = disp_bin_stats["std"].isnull() + gen_indices = np.flatnonzero(one_gene_per_bin.loc[df["mean_bin"]]) + if len(gen_indices) > 0: + logg.debug( + f"Gene indices {gen_indices} fell into a single bin: their " + "normalized dispersion was set to 1.\n " + "Decreasing `n_bins` will likely avoid this effect." + ) + disp_bin_stats["std"].loc[one_gene_per_bin] = disp_bin_stats["mean"].loc[ + one_gene_per_bin + ] + disp_bin_stats["mean"].loc[one_gene_per_bin] = 0 + # (use values here as index differs) + disp_avg = disp_bin_stats["mean"].loc[df["mean_bin"]].reset_index(drop=True) + disp_dev = disp_bin_stats["std"].loc[df["mean_bin"]].reset_index(drop=True) + return disp_avg, disp_dev + + +def _stats_cell_ranger(df: pd.DataFrame | DaskDataFrame): + from statsmodels import robust + + df["mean_bin"] = _ser_cut( + df["means"], + bins=np.r_[-np.inf, np.percentile(df["means"], np.arange(10, 105, 5)), np.inf], + ) + disp_grouped = df.groupby("mean_bin", observed=True)["dispersions"] + disp_median_bin = disp_grouped.median() + # the next line raises the warning: "Mean of empty slice" + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=RuntimeWarning) + disp_mad_bin = disp_grouped.apply(robust.mad) + disp_avg = disp_median_bin.loc[df["mean_bin"]].reset_index(drop=True) + disp_dev = disp_mad_bin.loc[df["mean_bin"]].reset_index(drop=True) + return disp_avg, disp_dev + + def _ser_cut(df: pd.Series | DaskSeries, *, bins: int) -> pd.Series: if isinstance(df, DaskSeries): # TODO: does map_partitions make sense for bin? From cff83e0e07711044b4cd26598d6e3c70c7dba9b5 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Fri, 19 Jan 2024 10:06:25 +0100 Subject: [PATCH 10/50] extract --- .../preprocessing/_highly_variable_genes.py | 111 +++++++++++------- 1 file changed, 68 insertions(+), 43 deletions(-) diff --git a/scanpy/preprocessing/_highly_variable_genes.py b/scanpy/preprocessing/_highly_variable_genes.py index fe12c3f8b0..a6c7cbab29 100644 --- a/scanpy/preprocessing/_highly_variable_genes.py +++ b/scanpy/preprocessing/_highly_variable_genes.py @@ -3,7 +3,7 @@ import warnings from functools import partial from inspect import signature -from typing import Literal, cast +from typing import TYPE_CHECKING, Literal, cast import numpy as np import pandas as pd @@ -19,6 +19,9 @@ from ._simple import filter_genes from ._utils import _get_mean_var +if TYPE_CHECKING: + from numpy.typing import NDArray + def _highly_variable_genes_seurat_v3( adata: AnnData, @@ -187,10 +190,10 @@ def _highly_variable_genes_single_batch( adata: AnnData, *, layer: str | None = None, - min_disp: float | None = 0.5, - max_disp: float | None = np.inf, - min_mean: float | None = 0.0125, - max_mean: float | None = 3, + min_disp: float = 0.5, + max_disp: float = np.inf, + min_mean: float = 0.0125, + max_mean: float = 3, n_top_genes: int | None = None, n_bins: int = 20, flavor: Literal["seurat", "cell_ranger"] = "seurat", @@ -243,40 +246,17 @@ def _highly_variable_genes_single_batch( # actually do the normalization df["dispersions_norm"] = (df["dispersions"] - disp_avg) / disp_dev + df["highly_variable"] = _subset_genes( + adata, + mean=mean, + dispersion_norm=df["dispersions_norm"], + min_disp=min_disp, + max_disp=max_disp, + min_mean=min_mean, + max_mean=max_mean, + n_top_genes=n_top_genes, + ) - dispersion_norm = df["dispersions_norm"].to_numpy() - if n_top_genes is not None: - dispersion_norm = dispersion_norm[~np.isnan(dispersion_norm)] - dispersion_norm[ - ::-1 - ].sort() # interestingly, np.argpartition is slightly slower - if n_top_genes > adata.n_vars: - logg.info("`n_top_genes` > `adata.n_var`, returning all genes.") - n_top_genes = adata.n_vars - if n_top_genes > dispersion_norm.size: - warnings.warn( - "`n_top_genes` > number of normalized dispersions, returning all genes with normalized dispersions.", - UserWarning, - ) - n_top_genes = dispersion_norm.size - disp_cut_off = dispersion_norm[n_top_genes - 1] - gene_subset = np.nan_to_num(df["dispersions_norm"].to_numpy()) >= disp_cut_off - logg.debug( - f"the {n_top_genes} top genes correspond to a " - f"normalized dispersion cutoff of {disp_cut_off}" - ) - else: - dispersion_norm[np.isnan(dispersion_norm)] = 0 # similar to Seurat - gene_subset = np.logical_and.reduce( - ( - mean > min_mean, - mean < max_mean, - dispersion_norm > min_disp, - dispersion_norm < max_disp, - ) - ) - - df["highly_variable"] = gene_subset return df @@ -333,6 +313,47 @@ def _ser_cut(df: pd.Series | DaskSeries, *, bins: int) -> pd.Series: return pd.cut(df, bins=bins) +def _subset_genes( + adata, + *, + mean: NDArray[np.float64] | DaskArray, + dispersion_norm: pd.Series[float] | DaskSeries, + min_disp: float, + max_disp: float, + min_mean: float, + max_mean: float, + n_top_genes: int | None, +) -> NDArray[np.float64] | DaskArray: + if n_top_genes is None: + dispersion_norm.loc[np.isnan(dispersion_norm)] = 0 # similar to Seurat + return np.logical_and.reduce( + ( + mean > min_mean, + mean < max_mean, + dispersion_norm > min_disp, + dispersion_norm < max_disp, + ) + ) + dispersion_norm = dispersion_norm.loc[~np.isnan(dispersion_norm)] + # interestingly, np.argpartition is slightly slower + dispersion_norm[::-1].sort() + if n_top_genes > adata.n_vars: + logg.info("`n_top_genes` > `adata.n_var`, returning all genes.") + n_top_genes = adata.n_vars + if n_top_genes > dispersion_norm.size: + warnings.warn( + "`n_top_genes` > number of normalized dispersions, returning all genes with normalized dispersions.", + UserWarning, + ) + n_top_genes = dispersion_norm.size + disp_cut_off = dispersion_norm[n_top_genes - 1] + logg.debug( + f"the {n_top_genes} top genes correspond to a " + f"normalized dispersion cutoff of {disp_cut_off}" + ) + return np.nan_to_num(dispersion_norm.to_numpy()) >= disp_cut_off + + @old_positionals( "layer", "n_top_genes", @@ -353,10 +374,10 @@ def highly_variable_genes( *, layer: str | None = None, n_top_genes: int | None = None, - min_disp: float | None = 0.5, - max_disp: float | None = np.inf, - min_mean: float | None = 0.0125, - max_mean: float | None = 3, + min_disp: float = 0.5, + max_disp: float = np.inf, + min_mean: float = 0.0125, + max_mean: float = 3, span: float = 0.3, n_bins: int = 20, flavor: Literal["seurat", "cell_ranger", "seurat_v3"] = "seurat", @@ -470,8 +491,12 @@ def highly_variable_genes( This function replaces :func:`~scanpy.pp.filter_genes_dispersion`. """ + defaults = { + p.name: p.default for p in signature(highly_variable_genes).parameters.values() + } if n_top_genes is not None and not all( - m is None for m in [min_disp, max_disp, min_mean, max_mean] + locals()[m] == defaults[m] + for m in ["min_disp", "max_disp", "min_mean", "max_mean"] ): logg.info("If you pass `n_top_genes`, all cutoffs are ignored.") From 826e0b03842bb01b9b306644874fcf0c387f3364 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Fri, 19 Jan 2024 10:16:42 +0100 Subject: [PATCH 11/50] almost --- scanpy/preprocessing/_distributed.py | 14 +++++++++ .../preprocessing/_highly_variable_genes.py | 31 ++++++++++--------- 2 files changed, 31 insertions(+), 14 deletions(-) diff --git a/scanpy/preprocessing/_distributed.py b/scanpy/preprocessing/_distributed.py index bdfb625d53..b0622c8248 100644 --- a/scanpy/preprocessing/_distributed.py +++ b/scanpy/preprocessing/_distributed.py @@ -22,6 +22,20 @@ from pandas.core.groupby.generic import DataFrameGroupBy, SeriesGroupBy +@overload +def series_to_array(s: pd.Series) -> np.ndarray: + ... + + +@overload +def series_to_array(s: DaskSeries) -> DaskArray: + ... + + +def series_to_array(s: pd.Series | DaskSeries) -> np.ndarray | DaskArray: + return s.to_dask_array() if isinstance(s, DaskSeries) else s.to_numpy() + + @overload def materialize_as_ndarray(a: ArrayLike) -> np.ndarray: ... diff --git a/scanpy/preprocessing/_highly_variable_genes.py b/scanpy/preprocessing/_highly_variable_genes.py index a6c7cbab29..ebbe4b5f43 100644 --- a/scanpy/preprocessing/_highly_variable_genes.py +++ b/scanpy/preprocessing/_highly_variable_genes.py @@ -15,7 +15,12 @@ from .._settings import Verbosity, settings from .._utils import check_nonnegative_integers, sanitize_anndata from ..get import _get_obs_rep -from ._distributed import dask_compute, materialize_as_ndarray, suppress_pandas_warning +from ._distributed import ( + dask_compute, + materialize_as_ndarray, + series_to_array, + suppress_pandas_warning, +) from ._simple import filter_genes from ._utils import _get_mean_var @@ -249,7 +254,7 @@ def _highly_variable_genes_single_batch( df["highly_variable"] = _subset_genes( adata, mean=mean, - dispersion_norm=df["dispersions_norm"], + dispersion_norm=series_to_array(df["dispersions_norm"]), min_disp=min_disp, max_disp=max_disp, min_mean=min_mean, @@ -314,10 +319,10 @@ def _ser_cut(df: pd.Series | DaskSeries, *, bins: int) -> pd.Series: def _subset_genes( - adata, + adata: AnnData, *, mean: NDArray[np.float64] | DaskArray, - dispersion_norm: pd.Series[float] | DaskSeries, + dispersion_norm: NDArray[np.float64] | DaskArray, min_disp: float, max_disp: float, min_mean: float, @@ -325,16 +330,14 @@ def _subset_genes( n_top_genes: int | None, ) -> NDArray[np.float64] | DaskArray: if n_top_genes is None: - dispersion_norm.loc[np.isnan(dispersion_norm)] = 0 # similar to Seurat - return np.logical_and.reduce( - ( - mean > min_mean, - mean < max_mean, - dispersion_norm > min_disp, - dispersion_norm < max_disp, - ) + dispersion_norm[np.isnan(dispersion_norm)] = 0 # similar to Seurat + return ( + (mean > min_mean) + & (mean < max_mean) + & (dispersion_norm > min_disp) + & (dispersion_norm < max_disp) ) - dispersion_norm = dispersion_norm.loc[~np.isnan(dispersion_norm)] + dispersion_norm = dispersion_norm[~np.isnan(dispersion_norm)] # interestingly, np.argpartition is slightly slower dispersion_norm[::-1].sort() if n_top_genes > adata.n_vars: @@ -351,7 +354,7 @@ def _subset_genes( f"the {n_top_genes} top genes correspond to a " f"normalized dispersion cutoff of {disp_cut_off}" ) - return np.nan_to_num(dispersion_norm.to_numpy()) >= disp_cut_off + return np.nan_to_num(dispersion_norm) >= disp_cut_off @old_positionals( From 55b9d4b56a3832f56f1d80b3ec33f6bad7b44bb3 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Fri, 19 Jan 2024 10:24:24 +0100 Subject: [PATCH 12/50] single works --- scanpy/preprocessing/_distributed.py | 16 ++++++-- .../preprocessing/_highly_variable_genes.py | 38 +++++++++---------- 2 files changed, 30 insertions(+), 24 deletions(-) diff --git a/scanpy/preprocessing/_distributed.py b/scanpy/preprocessing/_distributed.py index b0622c8248..45087c3f32 100644 --- a/scanpy/preprocessing/_distributed.py +++ b/scanpy/preprocessing/_distributed.py @@ -23,17 +23,25 @@ @overload -def series_to_array(s: pd.Series) -> np.ndarray: +def series_to_array(s: pd.Series, *, dtype: np.dtype | None = None) -> np.ndarray: ... @overload -def series_to_array(s: DaskSeries) -> DaskArray: +def series_to_array(s: DaskSeries, *, dtype: np.dtype | None = None) -> DaskArray: ... -def series_to_array(s: pd.Series | DaskSeries) -> np.ndarray | DaskArray: - return s.to_dask_array() if isinstance(s, DaskSeries) else s.to_numpy() +def series_to_array( + s: pd.Series | DaskSeries, *, dtype: np.dtype | None = None +) -> np.ndarray | DaskArray: + if isinstance(s, DaskSeries): + return ( + s.to_dask_array(True) + if dtype is None + else s.astype(dtype).to_dask_array(True) + ) + return s.to_numpy() if dtype is None else s.to_numpy().astype(dtype, copy=False) @overload diff --git a/scanpy/preprocessing/_highly_variable_genes.py b/scanpy/preprocessing/_highly_variable_genes.py index ebbe4b5f43..e3b3e2197c 100644 --- a/scanpy/preprocessing/_highly_variable_genes.py +++ b/scanpy/preprocessing/_highly_variable_genes.py @@ -629,15 +629,13 @@ def highly_variable_genes( df = df.loc[adata.var_names, :] else: df = df.loc[adata.var_names] - dispersion_norm = df["dispersions_norm"].to_numpy() + dispersion_norm = series_to_array(df["dispersions_norm"]) dispersion_norm[np.isnan(dispersion_norm)] = 0 # similar to Seurat - gene_subset = np.logical_and.reduce( - ( - df["means"] > min_mean, - df["means"] < max_mean, - df["dispersions_norm"] > min_disp, - df["dispersions_norm"] < max_disp, - ) + gene_subset = ( + (df["means"] > min_mean) + & (df["means"] < max_mean) + & (df["dispersions_norm"] > min_disp) + & (df["dispersions_norm"] < max_disp) ) df["highly_variable"] = gene_subset @@ -652,22 +650,22 @@ def highly_variable_genes( " 'dispersions', float vector (adata.var)\n" " 'dispersions_norm', float vector (adata.var)" ) - adata.var["highly_variable"] = df["highly_variable"].to_numpy() - adata.var["means"] = df["means"].to_numpy() - adata.var["dispersions"] = df["dispersions"].to_numpy() - adata.var["dispersions_norm"] = ( - df["dispersions_norm"].to_numpy().astype("float32", copy=False) + adata.var["highly_variable"] = series_to_array(df["highly_variable"]) + adata.var["means"] = series_to_array(df["means"]) + adata.var["dispersions"] = series_to_array(df["dispersions"]) + adata.var["dispersions_norm"] = series_to_array(df["dispersions_norm"]).astype( + "float32" ) if batch_key is not None: - adata.var["highly_variable_nbatches"] = df[ - "highly_variable_nbatches" - ].to_numpy() - adata.var["highly_variable_intersection"] = df[ - "highly_variable_intersection" - ].to_numpy() + adata.var["highly_variable_nbatches"] = series_to_array( + df["highly_variable_nbatches"] + ) + adata.var["highly_variable_intersection"] = series_to_array( + df["highly_variable_intersection"] + ) if subset: - adata._inplace_subset_var(df["highly_variable"].to_numpy()) + adata._inplace_subset_var(series_to_array(df["highly_variable"])) else: if subset: From 4f3b41087f0dcf2e2f5d8d1adf1dc081a3a88d96 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Fri, 19 Jan 2024 11:09:29 +0100 Subject: [PATCH 13/50] refactor cutoffs --- .../preprocessing/_highly_variable_genes.py | 145 ++++++++++-------- 1 file changed, 79 insertions(+), 66 deletions(-) diff --git a/scanpy/preprocessing/_highly_variable_genes.py b/scanpy/preprocessing/_highly_variable_genes.py index e3b3e2197c..b788a452ac 100644 --- a/scanpy/preprocessing/_highly_variable_genes.py +++ b/scanpy/preprocessing/_highly_variable_genes.py @@ -1,6 +1,7 @@ from __future__ import annotations import warnings +from dataclasses import dataclass from functools import partial from inspect import signature from typing import TYPE_CHECKING, Literal, cast @@ -191,15 +192,54 @@ def _highly_variable_genes_seurat_v3( return df +@dataclass +class _Cutoffs: + min_disp: float + max_disp: float + min_mean: float + max_mean: float + + @classmethod + def validate( + cls, + *, + min_disp: float, + max_disp: float, + min_mean: float, + max_mean: float, + n_top_genes: int | None, + ) -> _Cutoffs | int: + if n_top_genes is None: + return cls(min_disp, max_disp, min_mean, max_mean) + + cutoffs = {"min_disp", "max_disp", "min_mean", "max_mean"} + defaults = { + p.name: p.default + for p in signature(highly_variable_genes).parameters.values() + if p.name in cutoffs + } + if {k: v for k, v in locals().items() if k in cutoffs} != defaults: + logg.info("If you pass `n_top_genes`, all cutoffs are ignored.") + return n_top_genes + + def in_bounds( + self, + mean: NDArray[np.float64] | DaskArray, + dispersion_norm: NDArray[np.float64] | DaskArray, + ) -> NDArray[np.bool_] | DaskArray: + return ( + (mean > self.min_mean) + & (mean < self.max_mean) + & (dispersion_norm > self.min_disp) + & (dispersion_norm < self.max_disp) + ) + + def _highly_variable_genes_single_batch( adata: AnnData, *, layer: str | None = None, - min_disp: float = 0.5, - max_disp: float = np.inf, - min_mean: float = 0.0125, - max_mean: float = 3, - n_top_genes: int | None = None, + cutoff: _Cutoffs | int, n_bins: int = 20, flavor: Literal["seurat", "cell_ranger"] = "seurat", ) -> pd.DataFrame | DaskDataFrame: @@ -255,17 +295,16 @@ def _highly_variable_genes_single_batch( adata, mean=mean, dispersion_norm=series_to_array(df["dispersions_norm"]), - min_disp=min_disp, - max_disp=max_disp, - min_mean=min_mean, - max_mean=max_mean, - n_top_genes=n_top_genes, + cutoff=cutoff, ) return df -def _stats_seurat(df: pd.DataFrame | DaskDataFrame, *, n_bins: int): +def _stats_seurat( + df: pd.DataFrame | DaskDataFrame, *, n_bins: int +) -> tuple[pd.Series | DaskSeries, pd.Series | DaskSeries]: + """Assign "mean_bin" column and compute mean and std dev per bin.""" df["mean_bin"] = _ser_cut(df["means"], bins=n_bins) disp_grouped = df.groupby("mean_bin", observed=True)["dispersions"] with suppress_pandas_warning(): @@ -293,7 +332,10 @@ def _stats_seurat(df: pd.DataFrame | DaskDataFrame, *, n_bins: int): return disp_avg, disp_dev -def _stats_cell_ranger(df: pd.DataFrame | DaskDataFrame): +def _stats_cell_ranger( + df: pd.DataFrame | DaskDataFrame, +) -> tuple[pd.Series | DaskSeries, pd.Series | DaskSeries]: + """Assign "mean_bin" column and compute median and median absolute dev per bin.""" from statsmodels import robust df["mean_bin"] = _ser_cut( @@ -311,9 +353,9 @@ def _stats_cell_ranger(df: pd.DataFrame | DaskDataFrame): return disp_avg, disp_dev -def _ser_cut(df: pd.Series | DaskSeries, *, bins: int) -> pd.Series: +def _ser_cut(df: pd.Series | DaskSeries, *, bins: int) -> pd.Series | DaskSeries: if isinstance(df, DaskSeries): - # TODO: does map_partitions make sense for bin? + # TODO: does map_partitions make sense for bin? It would bin per chunk, not globally return df.map_partitions(pd.cut, bins=bins) return pd.cut(df, bins=bins) @@ -323,20 +365,13 @@ def _subset_genes( *, mean: NDArray[np.float64] | DaskArray, dispersion_norm: NDArray[np.float64] | DaskArray, - min_disp: float, - max_disp: float, - min_mean: float, - max_mean: float, - n_top_genes: int | None, -) -> NDArray[np.float64] | DaskArray: - if n_top_genes is None: + cutoff: _Cutoffs | int, +) -> NDArray[np.bool_] | DaskArray: + if isinstance(cutoff, _Cutoffs): dispersion_norm[np.isnan(dispersion_norm)] = 0 # similar to Seurat - return ( - (mean > min_mean) - & (mean < max_mean) - & (dispersion_norm > min_disp) - & (dispersion_norm < max_disp) - ) + return cutoff.in_bounds(mean, dispersion_norm) + n_top_genes = cutoff + del cutoff dispersion_norm = dispersion_norm[~np.isnan(dispersion_norm)] # interestingly, np.argpartition is slightly slower dispersion_norm[::-1].sort() @@ -494,15 +529,6 @@ def highly_variable_genes( This function replaces :func:`~scanpy.pp.filter_genes_dispersion`. """ - defaults = { - p.name: p.default for p in signature(highly_variable_genes).parameters.values() - } - if n_top_genes is not None and not all( - locals()[m] == defaults[m] - for m in ["min_disp", "max_disp", "min_mean", "max_mean"] - ): - logg.info("If you pass `n_top_genes`, all cutoffs are ignored.") - start = logg.info("extracting highly variable genes") if not isinstance(adata, AnnData): @@ -526,17 +552,18 @@ def highly_variable_genes( inplace=inplace, ) + cutoff = _Cutoffs.validate( + min_disp=min_disp, + max_disp=max_disp, + min_mean=min_mean, + max_mean=max_mean, + n_top_genes=n_top_genes, + ) + del min_disp, max_disp, min_mean, max_mean, n_top_genes + if batch_key is None: df = _highly_variable_genes_single_batch( - adata, - layer=layer, - min_disp=min_disp, - max_disp=max_disp, - min_mean=min_mean, - max_mean=max_mean, - n_top_genes=n_top_genes, - n_bins=n_bins, - flavor=flavor, + adata, layer=layer, cutoff=cutoff, n_bins=n_bins, flavor=flavor ) else: sanitize_anndata(adata) @@ -560,15 +587,7 @@ def highly_variable_genes( adata_subset = adata_subset[:, filt] hvg = _highly_variable_genes_single_batch( - adata_subset, - layer=layer, - min_disp=min_disp, - max_disp=max_disp, - min_mean=min_mean, - max_mean=max_mean, - n_top_genes=n_top_genes, - n_bins=n_bins, - flavor=flavor, + adata_subset, layer=layer, cutoff=cutoff, n_bins=n_bins, flavor=flavor ) hvg["gene"] = adata_subset.var_names.to_numpy() @@ -614,7 +633,7 @@ def highly_variable_genes( batches ) - if n_top_genes is not None: + if isinstance(cutoff, int): # sort genes by how often they selected as hvg within each batch and # break ties with normalized dispersion across batches df.sort_values( @@ -623,21 +642,15 @@ def highly_variable_genes( na_position="last", inplace=True, ) - high_var = np.zeros(df.shape[0]) - high_var[:n_top_genes] = True - df["highly_variable"] = high_var.astype(bool) + df["highly_variable"] = np.arange(df.shape[0]) < cutoff df = df.loc[adata.var_names, :] else: df = df.loc[adata.var_names] dispersion_norm = series_to_array(df["dispersions_norm"]) dispersion_norm[np.isnan(dispersion_norm)] = 0 # similar to Seurat - gene_subset = ( - (df["means"] > min_mean) - & (df["means"] < max_mean) - & (df["dispersions_norm"] > min_disp) - & (df["dispersions_norm"] < max_disp) + df["highly_variable"] = cutoff.in_bounds( + df["means"], df["dispersions_norm"] ) - df["highly_variable"] = gene_subset logg.info(" finished", time=start) @@ -653,8 +666,8 @@ def highly_variable_genes( adata.var["highly_variable"] = series_to_array(df["highly_variable"]) adata.var["means"] = series_to_array(df["means"]) adata.var["dispersions"] = series_to_array(df["dispersions"]) - adata.var["dispersions_norm"] = series_to_array(df["dispersions_norm"]).astype( - "float32" + adata.var["dispersions_norm"] = series_to_array( + df["dispersions_norm"], dtype=np.float32 ) if batch_key is not None: From 6af0e98762e9a9627dd76833ebda3fe1ea510136 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Fri, 19 Jan 2024 11:20:18 +0100 Subject: [PATCH 14/50] Fix regression --- scanpy/preprocessing/_highly_variable_genes.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/scanpy/preprocessing/_highly_variable_genes.py b/scanpy/preprocessing/_highly_variable_genes.py index b788a452ac..6c7605d6b0 100644 --- a/scanpy/preprocessing/_highly_variable_genes.py +++ b/scanpy/preprocessing/_highly_variable_genes.py @@ -367,11 +367,14 @@ def _subset_genes( dispersion_norm: NDArray[np.float64] | DaskArray, cutoff: _Cutoffs | int, ) -> NDArray[np.bool_] | DaskArray: + """Get boolean mask of genes with normalized dispersion in bounds.""" if isinstance(cutoff, _Cutoffs): dispersion_norm[np.isnan(dispersion_norm)] = 0 # similar to Seurat return cutoff.in_bounds(mean, dispersion_norm) n_top_genes = cutoff del cutoff + + dispersion_norm_orig = dispersion_norm # original length dispersion_norm = dispersion_norm[~np.isnan(dispersion_norm)] # interestingly, np.argpartition is slightly slower dispersion_norm[::-1].sort() @@ -389,7 +392,7 @@ def _subset_genes( f"the {n_top_genes} top genes correspond to a " f"normalized dispersion cutoff of {disp_cut_off}" ) - return np.nan_to_num(dispersion_norm) >= disp_cut_off + return np.nan_to_num(dispersion_norm_orig) >= disp_cut_off @old_positionals( From 350f9f56a535c69d2001ed2b5dc19c4d239336fe Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Fri, 19 Jan 2024 11:24:57 +0100 Subject: [PATCH 15/50] extract batch branch --- .../preprocessing/_highly_variable_genes.py | 176 +++++++++--------- 1 file changed, 92 insertions(+), 84 deletions(-) diff --git a/scanpy/preprocessing/_highly_variable_genes.py b/scanpy/preprocessing/_highly_variable_genes.py index 6c7605d6b0..87d4084b33 100644 --- a/scanpy/preprocessing/_highly_variable_genes.py +++ b/scanpy/preprocessing/_highly_variable_genes.py @@ -395,6 +395,96 @@ def _subset_genes( return np.nan_to_num(dispersion_norm_orig) >= disp_cut_off +def _highly_variable_genes_batched( + adata: AnnData, + batch_key: str, + *, + layer: str | None, + n_bins: int, + flavor: Literal["seurat", "cell_ranger"], + cutoff: _Cutoffs | int, +) -> pd.DataFrame | DaskDataFrame: + sanitize_anndata(adata) + batches = adata.obs[batch_key].cat.categories + dfs = [] + gene_list = adata.var_names + for batch in batches: + adata_subset = adata[adata.obs[batch_key] == batch] + + # Filter to genes that are in the dataset + with settings.verbosity.override(Verbosity.error): + # TODO use groupby or so instead of materialize_as_ndarray + filt, _ = materialize_as_ndarray( + filter_genes( + _get_obs_rep(adata_subset, layer=layer), + min_cells=1, + inplace=False, + ) + ) + + adata_subset = adata_subset[:, filt] + + hvg = _highly_variable_genes_single_batch( + adata_subset, layer=layer, cutoff=cutoff, n_bins=n_bins, flavor=flavor + ) + + hvg["gene"] = adata_subset.var_names.to_numpy() + if (n_removed := np.sum(~filt)) > 0: + # Add 0 values for genes that were filtered out + missing_hvg = pd.DataFrame( + np.zeros((n_removed, len(hvg.columns))), + columns=hvg.columns, + ) + missing_hvg["highly_variable"] = missing_hvg["highly_variable"].astype(bool) + missing_hvg["gene"] = gene_list[~filt] + hvg = pd.concat([hvg, missing_hvg], ignore_index=True) + + # Order as before filtering + idxs = np.concatenate((np.where(filt)[0], np.where(~filt)[0])) + hvg = hvg.loc[np.argsort(idxs)] + + dfs.append(hvg) + + df: DaskDataFrame | pd.DataFrame + if isinstance(dfs[0], DaskDataFrame): + import dask.dataframe as dd + + df = dd.concat(dfs, axis=0) + else: + df = pd.concat(dfs, axis=0) + + df["highly_variable"] = df["highly_variable"].astype(int) + df = df.groupby("gene", observed=True).agg( + dict( + means="mean", + dispersions="mean", + dispersions_norm="mean", + highly_variable="sum", + ) + ) + df.rename(columns=dict(highly_variable="highly_variable_nbatches"), inplace=True) + df["highly_variable_intersection"] = df["highly_variable_nbatches"] == len(batches) + + if isinstance(cutoff, int): + # sort genes by how often they selected as hvg within each batch and + # break ties with normalized dispersion across batches + df.sort_values( + ["highly_variable_nbatches", "dispersions_norm"], + ascending=False, + na_position="last", + inplace=True, + ) + df["highly_variable"] = np.arange(df.shape[0]) < cutoff + df = df.loc[adata.var_names, :] + else: + df = df.loc[adata.var_names] + dispersion_norm = series_to_array(df["dispersions_norm"]) + dispersion_norm[np.isnan(dispersion_norm)] = 0 # similar to Seurat + df["highly_variable"] = cutoff.in_bounds(df["means"], df["dispersions_norm"]) + + return df + + @old_positionals( "layer", "n_top_genes", @@ -569,91 +659,9 @@ def highly_variable_genes( adata, layer=layer, cutoff=cutoff, n_bins=n_bins, flavor=flavor ) else: - sanitize_anndata(adata) - batches = adata.obs[batch_key].cat.categories - dfs = [] - gene_list = adata.var_names - for batch in batches: - adata_subset = adata[adata.obs[batch_key] == batch] - - # Filter to genes that are in the dataset - with settings.verbosity.override(Verbosity.error): - # TODO use groupby or so instead of materialize_as_ndarray - filt, _ = materialize_as_ndarray( - filter_genes( - _get_obs_rep(adata_subset, layer=layer), - min_cells=1, - inplace=False, - ) - ) - - adata_subset = adata_subset[:, filt] - - hvg = _highly_variable_genes_single_batch( - adata_subset, layer=layer, cutoff=cutoff, n_bins=n_bins, flavor=flavor - ) - - hvg["gene"] = adata_subset.var_names.to_numpy() - if (n_removed := np.sum(~filt)) > 0: - # Add 0 values for genes that were filtered out - missing_hvg = pd.DataFrame( - np.zeros((n_removed, len(hvg.columns))), - columns=hvg.columns, - ) - missing_hvg["highly_variable"] = missing_hvg["highly_variable"].astype( - bool - ) - missing_hvg["gene"] = gene_list[~filt] - hvg = pd.concat([hvg, missing_hvg], ignore_index=True) - - # Order as before filtering - idxs = np.concatenate((np.where(filt)[0], np.where(~filt)[0])) - hvg = hvg.loc[np.argsort(idxs)] - - dfs.append(hvg) - - df: DaskDataFrame | pd.DataFrame - if isinstance(dfs[0], DaskDataFrame): - import dask.dataframe as dd - - df = dd.concat(dfs, axis=0) - else: - df = pd.concat(dfs, axis=0) - - df["highly_variable"] = df["highly_variable"].astype(int) - df = df.groupby("gene", observed=True).agg( - dict( - means="mean", - dispersions="mean", - dispersions_norm="mean", - highly_variable="sum", - ) + df = _highly_variable_genes_batched( + adata, batch_key, layer=layer, cutoff=cutoff, n_bins=n_bins, flavor=flavor ) - df.rename( - columns=dict(highly_variable="highly_variable_nbatches"), inplace=True - ) - df["highly_variable_intersection"] = df["highly_variable_nbatches"] == len( - batches - ) - - if isinstance(cutoff, int): - # sort genes by how often they selected as hvg within each batch and - # break ties with normalized dispersion across batches - df.sort_values( - ["highly_variable_nbatches", "dispersions_norm"], - ascending=False, - na_position="last", - inplace=True, - ) - df["highly_variable"] = np.arange(df.shape[0]) < cutoff - df = df.loc[adata.var_names, :] - else: - df = df.loc[adata.var_names] - dispersion_norm = series_to_array(df["dispersions_norm"]) - dispersion_norm[np.isnan(dispersion_norm)] = 0 # similar to Seurat - df["highly_variable"] = cutoff.in_bounds( - df["means"], df["dispersions_norm"] - ) logg.info(" finished", time=start) From 47d676c3109c7b142290dcb449e255155ed43702 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Fri, 19 Jan 2024 12:21:00 +0100 Subject: [PATCH 16/50] var names --- .../preprocessing/_highly_variable_genes.py | 10 +++---- scanpy/tests/test_highly_variable_genes.py | 27 ++++++++++++------- 2 files changed, 23 insertions(+), 14 deletions(-) diff --git a/scanpy/preprocessing/_highly_variable_genes.py b/scanpy/preprocessing/_highly_variable_genes.py index 87d4084b33..c6c82cebe6 100644 --- a/scanpy/preprocessing/_highly_variable_genes.py +++ b/scanpy/preprocessing/_highly_variable_genes.py @@ -203,11 +203,11 @@ class _Cutoffs: def validate( cls, *, + n_top_genes: int | None, min_disp: float, max_disp: float, min_mean: float, max_mean: float, - n_top_genes: int | None, ) -> _Cutoffs | int: if n_top_genes is None: return cls(min_disp, max_disp, min_mean, max_mean) @@ -428,7 +428,7 @@ def _highly_variable_genes_batched( adata_subset, layer=layer, cutoff=cutoff, n_bins=n_bins, flavor=flavor ) - hvg["gene"] = adata_subset.var_names.to_numpy() + hvg["gene"] = adata_subset.var_names if (n_removed := np.sum(~filt)) > 0: # Add 0 values for genes that were filtered out missing_hvg = pd.DataFrame( @@ -439,8 +439,8 @@ def _highly_variable_genes_batched( missing_hvg["gene"] = gene_list[~filt] hvg = pd.concat([hvg, missing_hvg], ignore_index=True) - # Order as before filtering - idxs = np.concatenate((np.where(filt)[0], np.where(~filt)[0])) + # Order as before filtering + idxs = np.concatenate((np.flatnonzero(filt), np.flatnonzero(~filt))) hvg = hvg.loc[np.argsort(idxs)] dfs.append(hvg) @@ -646,11 +646,11 @@ def highly_variable_genes( ) cutoff = _Cutoffs.validate( + n_top_genes=n_top_genes, min_disp=min_disp, max_disp=max_disp, min_mean=min_mean, max_mean=max_mean, - n_top_genes=n_top_genes, ) del min_disp, max_disp, min_mean, max_mean, n_top_genes diff --git a/scanpy/tests/test_highly_variable_genes.py b/scanpy/tests/test_highly_variable_genes.py index 89ed97c95e..b149ea1d5b 100644 --- a/scanpy/tests/test_highly_variable_genes.py +++ b/scanpy/tests/test_highly_variable_genes.py @@ -1,6 +1,7 @@ from __future__ import annotations from pathlib import Path +from string import ascii_letters import numpy as np import pandas as pd @@ -20,13 +21,23 @@ FILE_V3_BATCH = Path(__file__).parent / Path("_scripts/seurat_hvg_v3_batch.csv") -def test_highly_variable_genes_runs(): +@pytest.fixture(scope="session") +def adata_sess() -> AnnData: adata = sc.datasets.blobs() + adata.var_names = list(ascii_letters[: adata.n_vars]) + return adata + + +@pytest.fixture +def adata(adata_sess: AnnData) -> AnnData: + return adata_sess.copy() + + +def test_highly_variable_genes_runs(adata): sc.pp.highly_variable_genes(adata) -def test_highly_variable_genes_supports_batch(): - adata = sc.datasets.blobs() +def test_highly_variable_genes_supports_batch(adata): gen = np.random.default_rng(0) adata.obs["batch"] = pd.array( gen.binomial(3, 0.5, size=adata.n_obs), dtype="category" @@ -36,10 +47,10 @@ def test_highly_variable_genes_supports_batch(): assert "highly_variable_intersection" in adata.var.columns -def test_highly_variable_genes_supports_layers(): +def test_highly_variable_genes_supports_layers(adata_sess): def execute(layer: str | None) -> AnnData: gen = np.random.default_rng(0) - adata = sc.datasets.blobs() + adata = adata_sess.copy() assert isinstance(adata.X, np.ndarray) if layer: adata.X, adata.layers[layer] = None, adata.X.copy() @@ -58,8 +69,7 @@ def execute(layer: str | None) -> AnnData: assert (adata1.var["highly_variable"] != adata2.var["highly_variable"]).any() -def test_highly_variable_genes_no_batch_matches_batch(): - adata = sc.datasets.blobs() +def test_highly_variable_genes_no_batch_matches_batch(adata): sc.pp.highly_variable_genes(adata) no_batch_hvg = adata.var["highly_variable"].copy() assert no_batch_hvg.any() @@ -73,8 +83,7 @@ def test_highly_variable_genes_no_batch_matches_batch(): @pytest.mark.parametrize("batch_key", [None, "batch"], ids=["single", "batched"]) @pytest.mark.parametrize("array_type", ARRAY_TYPES_SUPPORTED) -def test_highly_variable_genes_no_inplace(array_type, batch_key): - adata = sc.datasets.blobs() +def test_highly_variable_genes_no_inplace(adata, array_type, batch_key): adata.X = array_type(adata.X) if batch_key: adata.obs[batch_key] = np.tile(["a", "b"], adata.shape[0] // 2) From 5a63881a006523aa45e681bcd843a946571fc8f0 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Fri, 19 Jan 2024 12:46:54 +0100 Subject: [PATCH 17/50] add tests for n_top_genes and cell_ranger --- scanpy/tests/test_highly_variable_genes.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/scanpy/tests/test_highly_variable_genes.py b/scanpy/tests/test_highly_variable_genes.py index b149ea1d5b..650b268e20 100644 --- a/scanpy/tests/test_highly_variable_genes.py +++ b/scanpy/tests/test_highly_variable_genes.py @@ -571,11 +571,14 @@ def test_cellranger_n_top_genes_warning(): @pytest.mark.parametrize("flavor", ["seurat", "cell_ranger"]) +@pytest.mark.parametrize("array_type", ARRAY_TYPES_SUPPORTED) @pytest.mark.parametrize("subset", [True, False], ids=["subset", "full"]) @pytest.mark.parametrize("inplace", [True, False], ids=["inplace", "copy"]) -def test_highly_variable_genes_subset_inplace_consistency(flavor, subset, inplace): +def test_highly_variable_genes_subset_inplace_consistency( + flavor, array_type, subset, inplace +): adata = sc.datasets.blobs(n_observations=20, n_variables=80, random_state=0) - adata.X = np.abs(adata.X).astype(int) + adata.X = array_type(np.abs(adata.X).astype(int)) if flavor == "seurat" or flavor == "cell_ranger": sc.pp.normalize_total(adata, target_sum=1e4) @@ -599,3 +602,8 @@ def test_highly_variable_genes_subset_inplace_consistency(flavor, subset, inplac assert (output_df is None) == inplace assert len(adata.var if inplace else output_df) == (15 if subset else n_genes) + if output_df is not None: + if "dask" in array_type.__name__: + assert isinstance(output_df, DaskDataFrame) + else: + assert isinstance(output_df, pd.DataFrame) From 2dc0ad865713f752b61c73c3f117114e202eea19 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Fri, 19 Jan 2024 14:19:17 +0100 Subject: [PATCH 18/50] n_top_genes --- .../preprocessing/_highly_variable_genes.py | 29 ++++++++++--------- scanpy/tests/test_highly_variable_genes.py | 6 ++++ 2 files changed, 22 insertions(+), 13 deletions(-) diff --git a/scanpy/preprocessing/_highly_variable_genes.py b/scanpy/preprocessing/_highly_variable_genes.py index c6c82cebe6..f34811d0d8 100644 --- a/scanpy/preprocessing/_highly_variable_genes.py +++ b/scanpy/preprocessing/_highly_variable_genes.py @@ -374,25 +374,28 @@ def _subset_genes( n_top_genes = cutoff del cutoff - dispersion_norm_orig = dispersion_norm # original length - dispersion_norm = dispersion_norm[~np.isnan(dispersion_norm)] - # interestingly, np.argpartition is slightly slower - dispersion_norm[::-1].sort() if n_top_genes > adata.n_vars: logg.info("`n_top_genes` > `adata.n_var`, returning all genes.") n_top_genes = adata.n_vars - if n_top_genes > dispersion_norm.size: - warnings.warn( - "`n_top_genes` > number of normalized dispersions, returning all genes with normalized dispersions.", - UserWarning, - ) - n_top_genes = dispersion_norm.size - disp_cut_off = dispersion_norm[n_top_genes - 1] + disp_cut_off = _nth_highest(dispersion_norm, n_top_genes) logg.debug( f"the {n_top_genes} top genes correspond to a " f"normalized dispersion cutoff of {disp_cut_off}" ) - return np.nan_to_num(dispersion_norm_orig) >= disp_cut_off + return np.nan_to_num(dispersion_norm) >= disp_cut_off + + +def _nth_highest(x: NDArray[np.float64] | DaskArray, n: int) -> float | DaskArray: + x = x[~np.isnan(x)] + if n > x.size: + msg = "`n_top_genes` > number of normalized dispersions, returning all genes with normalized dispersions." + warnings.warn(msg, UserWarning) + n = x.size + if isinstance(x, DaskArray): + return x.topk(n)[-1] + # interestingly, np.argpartition is slightly slower + x[::-1].sort() + return x[n - 1] def _highly_variable_genes_batched( @@ -689,7 +692,7 @@ def highly_variable_genes( df["highly_variable_intersection"] ) if subset: - adata._inplace_subset_var(series_to_array(df["highly_variable"])) + adata._inplace_subset_var(materialize_as_ndarray(df["highly_variable"])) else: if subset: diff --git a/scanpy/tests/test_highly_variable_genes.py b/scanpy/tests/test_highly_variable_genes.py index 650b268e20..f0ae529010 100644 --- a/scanpy/tests/test_highly_variable_genes.py +++ b/scanpy/tests/test_highly_variable_genes.py @@ -84,6 +84,7 @@ def test_highly_variable_genes_no_batch_matches_batch(adata): @pytest.mark.parametrize("batch_key", [None, "batch"], ids=["single", "batched"]) @pytest.mark.parametrize("array_type", ARRAY_TYPES_SUPPORTED) def test_highly_variable_genes_no_inplace(adata, array_type, batch_key): + """Tests that, with `n_top_genes=None` the returned dataframe has the expected columns.""" adata.X = array_type(adata.X) if batch_key: adata.obs[batch_key] = np.tile(["a", "b"], adata.shape[0] // 2) @@ -577,6 +578,11 @@ def test_cellranger_n_top_genes_warning(): def test_highly_variable_genes_subset_inplace_consistency( flavor, array_type, subset, inplace ): + """Tests that, with `n_top_genes=n` + - `inplace` and `subset` interact correctly + - for both the `seurat` and `cell_ranger` flavors + - for dask arrays and non-dask arrays + """ adata = sc.datasets.blobs(n_observations=20, n_variables=80, random_state=0) adata.X = array_type(np.abs(adata.X).astype(int)) From 5a43095ada80842c0e76496051250872ba42ac11 Mon Sep 17 00:00:00 2001 From: Phil Schaf Date: Mon, 22 Jan 2024 11:25:41 +0100 Subject: [PATCH 19/50] Support cell_ranger --- scanpy/preprocessing/_highly_variable_genes.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/scanpy/preprocessing/_highly_variable_genes.py b/scanpy/preprocessing/_highly_variable_genes.py index f34811d0d8..a19b231168 100644 --- a/scanpy/preprocessing/_highly_variable_genes.py +++ b/scanpy/preprocessing/_highly_variable_genes.py @@ -343,11 +343,12 @@ def _stats_cell_ranger( bins=np.r_[-np.inf, np.percentile(df["means"], np.arange(10, 105, 5)), np.inf], ) disp_grouped = df.groupby("mean_bin", observed=True)["dispersions"] - disp_median_bin = disp_grouped.median() + # using .agg here doesn’t work: https://github.com/dask/dask/issues/10836 + disp_median_bin = dask_compute(disp_grouped.median()) # the next line raises the warning: "Mean of empty slice" with warnings.catch_warnings(): warnings.simplefilter("ignore", category=RuntimeWarning) - disp_mad_bin = disp_grouped.apply(robust.mad) + disp_mad_bin = dask_compute(disp_grouped.apply(robust.mad)) disp_avg = disp_median_bin.loc[df["mean_bin"]].reset_index(drop=True) disp_dev = disp_mad_bin.loc[df["mean_bin"]].reset_index(drop=True) return disp_avg, disp_dev From 0448a226c47c08d73a196c93cdf2cfc7b7fb500b Mon Sep 17 00:00:00 2001 From: Phil Schaf Date: Mon, 22 Jan 2024 11:45:24 +0100 Subject: [PATCH 20/50] refactor --- .../preprocessing/_highly_variable_genes.py | 67 +++++++++++-------- 1 file changed, 40 insertions(+), 27 deletions(-) diff --git a/scanpy/preprocessing/_highly_variable_genes.py b/scanpy/preprocessing/_highly_variable_genes.py index a19b231168..fedd0b01d1 100644 --- a/scanpy/preprocessing/_highly_variable_genes.py +++ b/scanpy/preprocessing/_highly_variable_genes.py @@ -12,7 +12,13 @@ from anndata import AnnData from .. import logging as logg -from .._compat import DaskArray, DaskDataFrame, DaskSeries, old_positionals +from .._compat import ( + DaskArray, + DaskDataFrame, + DaskSeries, + DaskSeriesGroupBy, + old_positionals, +) from .._settings import Verbosity, settings from .._utils import check_nonnegative_integers, sanitize_anndata from ..get import _get_obs_rep @@ -27,6 +33,7 @@ if TYPE_CHECKING: from numpy.typing import NDArray + from pandas.core.groupby.generic import SeriesGroupBy def _highly_variable_genes_seurat_v3( @@ -270,6 +277,7 @@ def _highly_variable_genes_single_batch( dispersion[dispersion == 0] = np.nan dispersion = np.log(dispersion) mean = np.log1p(mean) + # all of the following quantities are "per-gene" here df: pd.DataFrame | DaskDataFrame if isinstance(X, DaskArray): @@ -281,11 +289,12 @@ def _highly_variable_genes_single_batch( ) else: df = pd.DataFrame(dict(means=mean, dispersions=dispersion)) - # assign "mean_bin" column and compute dispersions_norm + df["mean_bin"] = _get_mean_bins(df["means"], flavor, n_bins) + disp_grouped = df.groupby("mean_bin", observed=True)["dispersions"] if flavor == "seurat": - disp_avg, disp_dev = _stats_seurat(df, n_bins=n_bins) + disp_avg, disp_dev = _stats_seurat(df["mean_bin"], disp_grouped) elif flavor == "cell_ranger": - disp_avg, disp_dev = _stats_cell_ranger(df) + disp_avg, disp_dev = _stats_cell_ranger(df["mean_bin"], disp_grouped) else: raise ValueError('`flavor` needs to be "seurat" or "cell_ranger"') @@ -301,12 +310,27 @@ def _highly_variable_genes_single_batch( return df +def _get_mean_bins( + means: pd.Series | DaskSeries, flavor: Literal["seurat", "cell_ranger"], n_bins: int +) -> pd.Series | DaskSeries: + if flavor == "seurat": + bins = n_bins + elif flavor == "cell_ranger": + bins = np.r_[-np.inf, np.percentile(means, np.arange(10, 105, 5)), np.inf] + else: + raise ValueError('`flavor` needs to be "seurat" or "cell_ranger"') + + if isinstance(means, DaskSeries): + # TODO: does map_partitions make sense for bin? It would bin per chunk, not globally + return means.map_partitions(pd.cut, bins=bins) + return pd.cut(means, bins=bins) + + def _stats_seurat( - df: pd.DataFrame | DaskDataFrame, *, n_bins: int + mean_bins: pd.Series | DaskSeries, + disp_grouped: SeriesGroupBy | DaskSeriesGroupBy, ) -> tuple[pd.Series | DaskSeries, pd.Series | DaskSeries]: - """Assign "mean_bin" column and compute mean and std dev per bin.""" - df["mean_bin"] = _ser_cut(df["means"], bins=n_bins) - disp_grouped = df.groupby("mean_bin", observed=True)["dispersions"] + """Compute mean and std dev per bin.""" with suppress_pandas_warning(): disp_bin_stats: pd.DataFrame = dask_compute( disp_grouped.agg(mean="mean", std=partial(np.std, ddof=1)) @@ -315,7 +339,7 @@ def _stats_seurat( # only a single gene fell in the bin and implicitly set them to have # a normalized disperion of 1 one_gene_per_bin = disp_bin_stats["std"].isnull() - gen_indices = np.flatnonzero(one_gene_per_bin.loc[df["mean_bin"]]) + gen_indices = np.flatnonzero(one_gene_per_bin.loc[mean_bins]) if len(gen_indices) > 0: logg.debug( f"Gene indices {gen_indices} fell into a single bin: their " @@ -327,40 +351,29 @@ def _stats_seurat( ] disp_bin_stats["mean"].loc[one_gene_per_bin] = 0 # (use values here as index differs) - disp_avg = disp_bin_stats["mean"].loc[df["mean_bin"]].reset_index(drop=True) - disp_dev = disp_bin_stats["std"].loc[df["mean_bin"]].reset_index(drop=True) + disp_avg = disp_bin_stats["mean"].loc[mean_bins].reset_index(drop=True) + disp_dev = disp_bin_stats["std"].loc[mean_bins].reset_index(drop=True) return disp_avg, disp_dev def _stats_cell_ranger( - df: pd.DataFrame | DaskDataFrame, + mean_bins: pd.Series | DaskSeries, + disp_grouped: SeriesGroupBy | DaskSeriesGroupBy, ) -> tuple[pd.Series | DaskSeries, pd.Series | DaskSeries]: - """Assign "mean_bin" column and compute median and median absolute dev per bin.""" + """Compute median and median absolute dev per bin.""" from statsmodels import robust - df["mean_bin"] = _ser_cut( - df["means"], - bins=np.r_[-np.inf, np.percentile(df["means"], np.arange(10, 105, 5)), np.inf], - ) - disp_grouped = df.groupby("mean_bin", observed=True)["dispersions"] # using .agg here doesn’t work: https://github.com/dask/dask/issues/10836 disp_median_bin = dask_compute(disp_grouped.median()) # the next line raises the warning: "Mean of empty slice" with warnings.catch_warnings(): warnings.simplefilter("ignore", category=RuntimeWarning) disp_mad_bin = dask_compute(disp_grouped.apply(robust.mad)) - disp_avg = disp_median_bin.loc[df["mean_bin"]].reset_index(drop=True) - disp_dev = disp_mad_bin.loc[df["mean_bin"]].reset_index(drop=True) + disp_avg = disp_median_bin.loc[mean_bins].reset_index(drop=True) + disp_dev = disp_mad_bin.loc[mean_bins].reset_index(drop=True) return disp_avg, disp_dev -def _ser_cut(df: pd.Series | DaskSeries, *, bins: int) -> pd.Series | DaskSeries: - if isinstance(df, DaskSeries): - # TODO: does map_partitions make sense for bin? It would bin per chunk, not globally - return df.map_partitions(pd.cut, bins=bins) - return pd.cut(df, bins=bins) - - def _subset_genes( adata: AnnData, *, From 28d25bc7a984f2dd10aa098c8f02723878a02dea Mon Sep 17 00:00:00 2001 From: Phil Schaf Date: Mon, 22 Jan 2024 16:35:49 +0100 Subject: [PATCH 21/50] more WIP --- .../preprocessing/_highly_variable_genes.py | 66 +++++++++++-------- scanpy/tests/test_highly_variable_genes.py | 3 +- 2 files changed, 40 insertions(+), 29 deletions(-) diff --git a/scanpy/preprocessing/_highly_variable_genes.py b/scanpy/preprocessing/_highly_variable_genes.py index fedd0b01d1..79f7cd9f73 100644 --- a/scanpy/preprocessing/_highly_variable_genes.py +++ b/scanpy/preprocessing/_highly_variable_genes.py @@ -285,21 +285,27 @@ def _highly_variable_genes_single_batch( import dask.dataframe as dd df = dd.from_dask_array( - da.vstack((mean, dispersion)).T, columns=["means", "dispersions"] + da.vstack((mean, dispersion)).T, + columns=["means", "dispersions"], ) + df["gene"] = adata.var_names.to_series(index=df.index, name="gene") + df = df.set_index("gene") else: - df = pd.DataFrame(dict(means=mean, dispersions=dispersion)) + df = pd.DataFrame( + dict(means=mean, dispersions=dispersion), index=adata.var_names + ) + df.index.name = "gene" df["mean_bin"] = _get_mean_bins(df["means"], flavor, n_bins) disp_grouped = df.groupby("mean_bin", observed=True)["dispersions"] if flavor == "seurat": - disp_avg, disp_dev = _stats_seurat(df["mean_bin"], disp_grouped) + disp_stats = _stats_seurat(df["mean_bin"], disp_grouped) elif flavor == "cell_ranger": - disp_avg, disp_dev = _stats_cell_ranger(df["mean_bin"], disp_grouped) + disp_stats = _stats_cell_ranger(df["mean_bin"], disp_grouped) else: raise ValueError('`flavor` needs to be "seurat" or "cell_ranger"') # actually do the normalization - df["dispersions_norm"] = (df["dispersions"] - disp_avg) / disp_dev + df["dispersions_norm"] = (df["dispersions"] - disp_stats["avg"]) / disp_stats["dev"] df["highly_variable"] = _subset_genes( adata, mean=mean, @@ -329,16 +335,16 @@ def _get_mean_bins( def _stats_seurat( mean_bins: pd.Series | DaskSeries, disp_grouped: SeriesGroupBy | DaskSeriesGroupBy, -) -> tuple[pd.Series | DaskSeries, pd.Series | DaskSeries]: +) -> pd.DataFrame | DaskDataFrame: """Compute mean and std dev per bin.""" with suppress_pandas_warning(): disp_bin_stats: pd.DataFrame = dask_compute( - disp_grouped.agg(mean="mean", std=partial(np.std, ddof=1)) + disp_grouped.agg(avg="mean", dev=partial(np.std, ddof=1)) ) # retrieve those genes that have nan std, these are the ones where # only a single gene fell in the bin and implicitly set them to have # a normalized disperion of 1 - one_gene_per_bin = disp_bin_stats["std"].isnull() + one_gene_per_bin = disp_bin_stats["dev"].isnull() gen_indices = np.flatnonzero(one_gene_per_bin.loc[mean_bins]) if len(gen_indices) > 0: logg.debug( @@ -346,31 +352,31 @@ def _stats_seurat( "normalized dispersion was set to 1.\n " "Decreasing `n_bins` will likely avoid this effect." ) - disp_bin_stats["std"].loc[one_gene_per_bin] = disp_bin_stats["mean"].loc[ + disp_bin_stats["dev"].loc[one_gene_per_bin] = disp_bin_stats["avg"].loc[ one_gene_per_bin ] - disp_bin_stats["mean"].loc[one_gene_per_bin] = 0 + disp_bin_stats["avg"].loc[one_gene_per_bin] = 0 # (use values here as index differs) - disp_avg = disp_bin_stats["mean"].loc[mean_bins].reset_index(drop=True) - disp_dev = disp_bin_stats["std"].loc[mean_bins].reset_index(drop=True) - return disp_avg, disp_dev + return disp_bin_stats.loc[mean_bins].set_index(mean_bins.index) def _stats_cell_ranger( mean_bins: pd.Series | DaskSeries, disp_grouped: SeriesGroupBy | DaskSeriesGroupBy, -) -> tuple[pd.Series | DaskSeries, pd.Series | DaskSeries]: +) -> pd.DataFrame | DaskDataFrame: """Compute median and median absolute dev per bin.""" from statsmodels import robust + raise RuntimeError("TODO") # using .agg here doesn’t work: https://github.com/dask/dask/issues/10836 disp_median_bin = dask_compute(disp_grouped.median()) # the next line raises the warning: "Mean of empty slice" with warnings.catch_warnings(): warnings.simplefilter("ignore", category=RuntimeWarning) disp_mad_bin = dask_compute(disp_grouped.apply(robust.mad)) - disp_avg = disp_median_bin.loc[mean_bins].reset_index(drop=True) - disp_dev = disp_mad_bin.loc[mean_bins].reset_index(drop=True) + # TODO: df + disp_avg = disp_median_bin.loc[mean_bins].reindex(mean_bins.index) + disp_dev = disp_mad_bin.loc[mean_bins].reindex(mean_bins.index) return disp_avg, disp_dev @@ -444,8 +450,12 @@ def _highly_variable_genes_batched( hvg = _highly_variable_genes_single_batch( adata_subset, layer=layer, cutoff=cutoff, n_bins=n_bins, flavor=flavor ) + assert hvg.index.name == "gene" + if isinstance(hvg, DaskDataFrame): + hvg = hvg.reset_index(drop=False) + else: + hvg.reset_index(drop=False, inplace=True) - hvg["gene"] = adata_subset.var_names if (n_removed := np.sum(~filt)) > 0: # Add 0 values for genes that were filtered out missing_hvg = pd.DataFrame( @@ -458,7 +468,7 @@ def _highly_variable_genes_batched( # Order as before filtering idxs = np.concatenate((np.flatnonzero(filt), np.flatnonzero(~filt))) - hvg = hvg.loc[np.argsort(idxs)] + hvg = hvg.iloc[np.argsort(idxs)] dfs.append(hvg) @@ -479,7 +489,9 @@ def _highly_variable_genes_batched( highly_variable="sum", ) ) - df.rename(columns=dict(highly_variable="highly_variable_nbatches"), inplace=True) + if isinstance(df, DaskDataFrame): + df = df.set_index("gene") # happens automatically for pandas df + df["highly_variable_nbatches"] = df["highly_variable"] df["highly_variable_intersection"] = df["highly_variable_nbatches"] == len(batches) if isinstance(cutoff, int): @@ -492,9 +504,7 @@ def _highly_variable_genes_batched( inplace=True, ) df["highly_variable"] = np.arange(df.shape[0]) < cutoff - df = df.loc[adata.var_names, :] else: - df = df.loc[adata.var_names] dispersion_norm = series_to_array(df["dispersions_norm"]) dispersion_norm[np.isnan(dispersion_norm)] = 0 # similar to Seurat df["highly_variable"] = cutoff.in_bounds(df["means"], df["dispersions_norm"]) @@ -691,18 +701,18 @@ def highly_variable_genes( " 'dispersions', float vector (adata.var)\n" " 'dispersions_norm', float vector (adata.var)" ) - adata.var["highly_variable"] = series_to_array(df["highly_variable"]) - adata.var["means"] = series_to_array(df["means"]) - adata.var["dispersions"] = series_to_array(df["dispersions"]) - adata.var["dispersions_norm"] = series_to_array( - df["dispersions_norm"], dtype=np.float32 + adata.var["highly_variable"] = dask_compute(df["highly_variable"]) + adata.var["means"] = dask_compute(df["means"]) + adata.var["dispersions"] = dask_compute(df["dispersions"]) + adata.var["dispersions_norm"] = dask_compute(df["dispersions_norm"]).astype( + np.float32, copy=False ) if batch_key is not None: - adata.var["highly_variable_nbatches"] = series_to_array( + adata.var["highly_variable_nbatches"] = dask_compute( df["highly_variable_nbatches"] ) - adata.var["highly_variable_intersection"] = series_to_array( + adata.var["highly_variable_intersection"] = dask_compute( df["highly_variable_intersection"] ) if subset: diff --git a/scanpy/tests/test_highly_variable_genes.py b/scanpy/tests/test_highly_variable_genes.py index f0ae529010..72390f99e8 100644 --- a/scanpy/tests/test_highly_variable_genes.py +++ b/scanpy/tests/test_highly_variable_genes.py @@ -24,7 +24,8 @@ @pytest.fixture(scope="session") def adata_sess() -> AnnData: adata = sc.datasets.blobs() - adata.var_names = list(ascii_letters[: adata.n_vars]) + rng = np.random.default_rng(0) + adata.var_names = rng.choice(list(ascii_letters), adata.n_vars, replace=False) return adata From 1bedd5c4a7fdfc9909bcc2e08f82dfa2e804b45f Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 23 Jan 2024 10:09:19 +0100 Subject: [PATCH 22/50] cell ranger fix --- scanpy/preprocessing/_distributed.py | 42 ++++++++++++++++++- .../preprocessing/_highly_variable_genes.py | 36 +++++++++++----- 2 files changed, 65 insertions(+), 13 deletions(-) diff --git a/scanpy/preprocessing/_distributed.py b/scanpy/preprocessing/_distributed.py index 45087c3f32..db6d98b2d3 100644 --- a/scanpy/preprocessing/_distributed.py +++ b/scanpy/preprocessing/_distributed.py @@ -1,7 +1,8 @@ from __future__ import annotations from contextlib import contextmanager -from typing import TYPE_CHECKING, overload +from itertools import chain +from typing import TYPE_CHECKING, Literal, overload import numpy as np @@ -15,9 +16,10 @@ ) if TYPE_CHECKING: - from collections.abc import Generator + from collections.abc import Callable, Generator import pandas as pd + from dask.dataframe import Aggregation from numpy.typing import ArrayLike from pandas.core.groupby.generic import DataFrameGroupBy, SeriesGroupBy @@ -124,3 +126,39 @@ def suppress_pandas_warning() -> Generator[None, None, None]: "ignore", r"The default of observed=False", category=FutureWarning ) yield + + +try: + import dask.dataframe as dd +except ImportError: + + def get_mad(dask: Literal[False]) -> Callable[[np.ndarray], np.ndarray]: + from statsmodels.robust import mad + + return mad +else: + + def _mad1(chunks: DaskSeriesGroupBy): + return chunks.apply(list) + + def _mad2(grouped: DaskSeriesGroupBy): + def internal(c): + if (c != c).all(): + return [np.nan] + f = [_ for _ in c if _ == _] + f = [_ if isinstance(_, list) else [_] for _ in f] + return list(chain.from_iterable(f)) + + return grouped.apply(internal) + + def _mad3(grouped: DaskSeriesGroupBy): + from statsmodels.robust import mad + + return grouped.apply(lambda s: np.nan if len(s) == 0 else mad(s)) + + mad_dask = dd.Aggregation("mad", chunk=_mad1, agg=_mad2, finalize=_mad3) + + def get_mad(dask: bool) -> Callable[[np.ndarray], np.ndarray] | Aggregation: + from statsmodels.robust import mad + + return mad_dask if dask else mad diff --git a/scanpy/preprocessing/_highly_variable_genes.py b/scanpy/preprocessing/_highly_variable_genes.py index 79f7cd9f73..fe63463e99 100644 --- a/scanpy/preprocessing/_highly_variable_genes.py +++ b/scanpy/preprocessing/_highly_variable_genes.py @@ -15,6 +15,7 @@ from .._compat import ( DaskArray, DaskDataFrame, + DaskDataFrameGroupBy, DaskSeries, DaskSeriesGroupBy, old_positionals, @@ -24,6 +25,7 @@ from ..get import _get_obs_rep from ._distributed import ( dask_compute, + get_mad, materialize_as_ndarray, series_to_array, suppress_pandas_warning, @@ -33,7 +35,7 @@ if TYPE_CHECKING: from numpy.typing import NDArray - from pandas.core.groupby.generic import SeriesGroupBy + from pandas.core.groupby.generic import DataFrameGroupBy, SeriesGroupBy def _highly_variable_genes_seurat_v3( @@ -365,19 +367,31 @@ def _stats_cell_ranger( disp_grouped: SeriesGroupBy | DaskSeriesGroupBy, ) -> pd.DataFrame | DaskDataFrame: """Compute median and median absolute dev per bin.""" - from statsmodels import robust - raise RuntimeError("TODO") - # using .agg here doesn’t work: https://github.com/dask/dask/issues/10836 - disp_median_bin = dask_compute(disp_grouped.median()) - # the next line raises the warning: "Mean of empty slice" + is_dask = isinstance(disp_grouped, DaskSeriesGroupBy) with warnings.catch_warnings(): + # MAD calculation raises the warning: "Mean of empty slice" warnings.simplefilter("ignore", category=RuntimeWarning) - disp_mad_bin = dask_compute(disp_grouped.apply(robust.mad)) - # TODO: df - disp_avg = disp_median_bin.loc[mean_bins].reindex(mean_bins.index) - disp_dev = disp_mad_bin.loc[mean_bins].reindex(mean_bins.index) - return disp_avg, disp_dev + disp_bin_stats = _aggregate(disp_grouped, ["median", get_mad(dask=is_dask)]) + # Can’t use kwargs in `aggregate`: https://github.com/dask/dask/issues/10836 + disp_bin_stats = disp_bin_stats.rename(columns=dict(median="avg", mad="dev")) + return disp_bin_stats.loc[mean_bins].set_index(mean_bins.index) + + +def _aggregate( + grouped: ( + DataFrameGroupBy | DaskDataFrameGroupBy | SeriesGroupBy | DaskSeriesGroupBy + ), + arg=None, + **kw, +) -> pd.DataFrame | DaskDataFrame | pd.Series | DaskSeries: + # ValueError: In order to aggregate with 'median', + # you must use shuffling-based aggregation (e.g., shuffle='tasks') + if ((arg and "median" in arg) or "median" in kw) and isinstance( + grouped, (DaskSeriesGroupBy, DaskDataFrameGroupBy) + ): + kw["shuffle"] = True + return grouped.agg(arg, **kw) def _subset_genes( From 181a6c5ceffceffc49900fc614f6de67fe89d91c Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 23 Jan 2024 10:37:21 +0100 Subject: [PATCH 23/50] almost --- scanpy/preprocessing/_highly_variable_genes.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/scanpy/preprocessing/_highly_variable_genes.py b/scanpy/preprocessing/_highly_variable_genes.py index fe63463e99..3fe2a6a9b0 100644 --- a/scanpy/preprocessing/_highly_variable_genes.py +++ b/scanpy/preprocessing/_highly_variable_genes.py @@ -358,8 +358,7 @@ def _stats_seurat( one_gene_per_bin ] disp_bin_stats["avg"].loc[one_gene_per_bin] = 0 - # (use values here as index differs) - return disp_bin_stats.loc[mean_bins].set_index(mean_bins.index) + return _unbin(disp_bin_stats, mean_bins) def _stats_cell_ranger( @@ -375,7 +374,15 @@ def _stats_cell_ranger( disp_bin_stats = _aggregate(disp_grouped, ["median", get_mad(dask=is_dask)]) # Can’t use kwargs in `aggregate`: https://github.com/dask/dask/issues/10836 disp_bin_stats = disp_bin_stats.rename(columns=dict(median="avg", mad="dev")) - return disp_bin_stats.loc[mean_bins].set_index(mean_bins.index) + return _unbin(disp_bin_stats, mean_bins) + + +def _unbin( + df: pd.DataFrame | DaskDataFrame, mean_bins: pd.Series | DaskSeries +) -> pd.DataFrame | DaskDataFrame: + df = df.loc[mean_bins] + df["gene"] = mean_bins.index + return df.set_index("gene") def _aggregate( @@ -480,10 +487,6 @@ def _highly_variable_genes_batched( missing_hvg["gene"] = gene_list[~filt] hvg = pd.concat([hvg, missing_hvg], ignore_index=True) - # Order as before filtering - idxs = np.concatenate((np.flatnonzero(filt), np.flatnonzero(~filt))) - hvg = hvg.iloc[np.argsort(idxs)] - dfs.append(hvg) df: DaskDataFrame | pd.DataFrame From e28aefaff3f93565ab78770820d52ebce8e0975a Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 23 Jan 2024 11:22:17 +0100 Subject: [PATCH 24/50] add XFail --- scanpy/tests/test_highly_variable_genes.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/scanpy/tests/test_highly_variable_genes.py b/scanpy/tests/test_highly_variable_genes.py index 72390f99e8..198821bef6 100644 --- a/scanpy/tests/test_highly_variable_genes.py +++ b/scanpy/tests/test_highly_variable_genes.py @@ -577,13 +577,18 @@ def test_cellranger_n_top_genes_warning(): @pytest.mark.parametrize("subset", [True, False], ids=["subset", "full"]) @pytest.mark.parametrize("inplace", [True, False], ids=["inplace", "copy"]) def test_highly_variable_genes_subset_inplace_consistency( - flavor, array_type, subset, inplace + request: pytest.FixtureRequest, flavor, array_type, subset, inplace ): """Tests that, with `n_top_genes=n` - `inplace` and `subset` interact correctly - for both the `seurat` and `cell_ranger` flavors - for dask arrays and non-dask arrays """ + if flavor == "cell_ranger" and "dask" in array_type.__name__: + request.applymarker( + pytest.mark.xfail(reason="See https://github.com/dask/dask/issues/10853") + ) + adata = sc.datasets.blobs(n_observations=20, n_variables=80, random_state=0) adata.X = array_type(np.abs(adata.X).astype(int)) From e3beadd3eab837831e4ff7c2e4ef605a59618fe8 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 23 Jan 2024 12:52:34 +0100 Subject: [PATCH 25/50] Fix docs, add relnote --- docs/release-notes/1.10.0.md | 1 + pyproject.toml | 1 + scanpy/preprocessing/_distributed.py | 3 ++- 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/release-notes/1.10.0.md b/docs/release-notes/1.10.0.md index dbf9e57e35..f98e945f04 100644 --- a/docs/release-notes/1.10.0.md +++ b/docs/release-notes/1.10.0.md @@ -14,6 +14,7 @@ * {func}`scanpy.pp.pca`, {func}`scanpy.pp.scale`, {func}`scanpy.pl.embedding`, and {func}`scanpy.experimental.pp.normalize_pearson_residuals_pca` now support a `mask` parameter {pr}`2272` {smaller}`C Bright, T Marcella, & P Angerer` * {func}`scanpy.tl.rank_genes_groups` no longer warns that it's default was changed from t-test_overestim_var to t-test {pr}`2798` {smaller}`L Heumos` +* {func}`scanpy.pp.highly_variable_genes` supports dask for the default `seurat` flavor and partially for the `cell_ranger` flavor {pr}`2809` {smaller}`P Angerer` ```{rubric} Docs ``` diff --git a/pyproject.toml b/pyproject.toml index 65d3dfb8bc..06651f87a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -122,6 +122,7 @@ doc = [ "ipython>=7.20", # for nbsphinx code highlighting "matplotlib!=3.6.1", # TODO: remove necessity for being able to import doc-linked classes + "dask", "scanpy[paga]", "sam-algorithm", ] diff --git a/scanpy/preprocessing/_distributed.py b/scanpy/preprocessing/_distributed.py index db6d98b2d3..011101b0f7 100644 --- a/scanpy/preprocessing/_distributed.py +++ b/scanpy/preprocessing/_distributed.py @@ -37,6 +37,7 @@ def series_to_array(s: DaskSeries, *, dtype: np.dtype | None = None) -> DaskArra def series_to_array( s: pd.Series | DaskSeries, *, dtype: np.dtype | None = None ) -> np.ndarray | DaskArray: + """Convert Series to Array, keeping them in-memory or distributed.""" if isinstance(s, DaskSeries): return ( s.to_dask_array(True) @@ -73,7 +74,7 @@ def materialize_as_ndarray( def materialize_as_ndarray( a: ArrayLike | tuple[ArrayLike | ZappyArray | DaskArray, ...], ) -> tuple[np.ndarray] | np.ndarray: - """Convert distributed arrays to ndarrays.""" + """Compute distributed arrays and convert them to numpy ndarrays.""" if not isinstance(a, tuple): return np.asarray(a) From ce115c32d5d199c79b65b5c7b5efaa9f22f16426 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Thu, 25 Jan 2024 12:38:31 +0100 Subject: [PATCH 26/50] reuse --- .../preprocessing/_highly_variable_genes.py | 34 +++++++++---------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/scanpy/preprocessing/_highly_variable_genes.py b/scanpy/preprocessing/_highly_variable_genes.py index 3fe2a6a9b0..9a6e5f7fc1 100644 --- a/scanpy/preprocessing/_highly_variable_genes.py +++ b/scanpy/preprocessing/_highly_variable_genes.py @@ -80,15 +80,15 @@ def _highly_variable_genes_seurat_v3( "Please install skmisc package via `pip install --user scikit-misc" ) df = pd.DataFrame(index=adata.var_names) - X = adata.layers[layer] if layer is not None else adata.X + data = _get_obs_rep(layer=layer) - if check_values and not check_nonnegative_integers(X): + if check_values and not check_nonnegative_integers(data): warnings.warn( "`flavor='seurat_v3'` expects raw count data, but non-integers were found.", UserWarning, ) - df["means"], df["variances"] = _get_mean_var(X) + df["means"], df["variances"] = _get_mean_var(data) if batch_key is None: batch_info = pd.Categorical(np.zeros(adata.shape[0], dtype=int)) @@ -97,11 +97,11 @@ def _highly_variable_genes_seurat_v3( norm_gene_vars = [] for b in np.unique(batch_info): - X_batch = X[batch_info == b] + data_batch = data[batch_info == b] - mean, var = _get_mean_var(X_batch) + mean, var = _get_mean_var(data_batch) not_const = var > 0 - estimat_var = np.zeros(X.shape[1], dtype=np.float64) + estimat_var = np.zeros(data.shape[1], dtype=np.float64) y = np.log10(var[not_const]) x = np.log10(mean[not_const]) @@ -110,9 +110,9 @@ def _highly_variable_genes_seurat_v3( estimat_var[not_const] = model.outputs.fitted_values reg_std = np.sqrt(10**estimat_var) - batch_counts = X_batch.astype(np.float64).copy() + batch_counts = data_batch.astype(np.float64).copy() # clip large values as in Seurat - N = X_batch.shape[0] + N = data_batch.shape[0] vmax = np.sqrt(N) clip_val = reg_std * vmax + mean if sp_sparse.issparse(batch_counts): @@ -260,18 +260,18 @@ def _highly_variable_genes_single_batch( A DataFrame that contains the columns `highly_variable`, `means`, `dispersions`, and `dispersions_norm`. """ - X = adata.layers[layer] if layer is not None else adata.X + data = _get_obs_rep(layer=layer) if flavor == "seurat": - X = X.copy() + data = data.copy() if "log1p" in adata.uns_keys() and adata.uns["log1p"].get("base") is not None: - X *= np.log(adata.uns["log1p"]["base"]) - # use out if possible. only possible since we copy X - if isinstance(X, np.ndarray): - np.expm1(X, out=X) + data *= np.log(adata.uns["log1p"]["base"]) + # use out if possible. only possible since we copy the data matrix + if isinstance(data, np.ndarray): + np.expm1(data, out=data) else: - X = np.expm1(X) + data = np.expm1(data) - mean, var = _get_mean_var(X) + mean, var = _get_mean_var(data) # now actually compute the dispersion mean[mean == 0] = 1e-12 # set entries equal to zero to small value dispersion = var / mean @@ -282,7 +282,7 @@ def _highly_variable_genes_single_batch( # all of the following quantities are "per-gene" here df: pd.DataFrame | DaskDataFrame - if isinstance(X, DaskArray): + if isinstance(data, DaskArray): import dask.array as da import dask.dataframe as dd From d6ea0ddb4cbe329064a53003cb044de3437803d6 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Thu, 25 Jan 2024 12:47:30 +0100 Subject: [PATCH 27/50] oops --- scanpy/preprocessing/_highly_variable_genes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scanpy/preprocessing/_highly_variable_genes.py b/scanpy/preprocessing/_highly_variable_genes.py index 9a6e5f7fc1..b1adbe8848 100644 --- a/scanpy/preprocessing/_highly_variable_genes.py +++ b/scanpy/preprocessing/_highly_variable_genes.py @@ -80,7 +80,7 @@ def _highly_variable_genes_seurat_v3( "Please install skmisc package via `pip install --user scikit-misc" ) df = pd.DataFrame(index=adata.var_names) - data = _get_obs_rep(layer=layer) + data = _get_obs_rep(adata, layer=layer) if check_values and not check_nonnegative_integers(data): warnings.warn( @@ -260,7 +260,7 @@ def _highly_variable_genes_single_batch( A DataFrame that contains the columns `highly_variable`, `means`, `dispersions`, and `dispersions_norm`. """ - data = _get_obs_rep(layer=layer) + data = _get_obs_rep(adata, layer=layer) if flavor == "seurat": data = data.copy() if "log1p" in adata.uns_keys() and adata.uns["log1p"].get("base") is not None: From 414e6673534e7f764dd74753ccaceceeb12f9aa6 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Thu, 25 Jan 2024 14:45:40 +0100 Subject: [PATCH 28/50] Fix order --- .../preprocessing/_highly_variable_genes.py | 73 +++++++++---------- scanpy/tests/test_highly_variable_genes.py | 71 +++++++++++++----- 2 files changed, 88 insertions(+), 56 deletions(-) diff --git a/scanpy/preprocessing/_highly_variable_genes.py b/scanpy/preprocessing/_highly_variable_genes.py index b1adbe8848..940ebdbfd1 100644 --- a/scanpy/preprocessing/_highly_variable_genes.py +++ b/scanpy/preprocessing/_highly_variable_genes.py @@ -287,16 +287,10 @@ def _highly_variable_genes_single_batch( import dask.dataframe as dd df = dd.from_dask_array( - da.vstack((mean, dispersion)).T, - columns=["means", "dispersions"], + da.vstack((mean, dispersion)).T, columns=["means", "dispersions"] ) - df["gene"] = adata.var_names.to_series(index=df.index, name="gene") - df = df.set_index("gene") else: - df = pd.DataFrame( - dict(means=mean, dispersions=dispersion), index=adata.var_names - ) - df.index.name = "gene" + df = pd.DataFrame(dict(means=mean, dispersions=dispersion)) df["mean_bin"] = _get_mean_bins(df["means"], flavor, n_bins) disp_grouped = df.groupby("mean_bin", observed=True)["dispersions"] if flavor == "seurat": @@ -315,6 +309,12 @@ def _highly_variable_genes_single_batch( cutoff=cutoff, ) + if isinstance(df, DaskDataFrame): + df["gene"] = adata.var_names.to_series(index=df.index, name="gene") + df = df.set_index("gene", sort=False) + else: + df.set_index(adata.var_names, inplace=True) + df.index.name = "gene" return df @@ -382,7 +382,10 @@ def _unbin( ) -> pd.DataFrame | DaskDataFrame: df = df.loc[mean_bins] df["gene"] = mean_bins.index - return df.set_index("gene") + if isinstance(df, DaskDataFrame): + return df.set_index("gene", sort=False) + df.set_index("gene", inplace=True) + return df def _aggregate( @@ -507,7 +510,7 @@ def _highly_variable_genes_batched( ) ) if isinstance(df, DaskDataFrame): - df = df.set_index("gene") # happens automatically for pandas df + df = df.set_index("gene", sort=False) # happens automatically for pandas df df["highly_variable_nbatches"] = df["highly_variable"] df["highly_variable_intersection"] = df["highly_variable_nbatches"] == len(batches) @@ -709,34 +712,30 @@ def highly_variable_genes( logg.info(" finished", time=start) - if inplace: - adata.uns["hvg"] = {"flavor": flavor} - logg.hint( - "added\n" - " 'highly_variable', boolean vector (adata.var)\n" - " 'means', float vector (adata.var)\n" - " 'dispersions', float vector (adata.var)\n" - " 'dispersions_norm', float vector (adata.var)" - ) - adata.var["highly_variable"] = dask_compute(df["highly_variable"]) - adata.var["means"] = dask_compute(df["means"]) - adata.var["dispersions"] = dask_compute(df["dispersions"]) - adata.var["dispersions_norm"] = dask_compute(df["dispersions_norm"]).astype( - np.float32, copy=False - ) - - if batch_key is not None: - adata.var["highly_variable_nbatches"] = dask_compute( - df["highly_variable_nbatches"] - ) - adata.var["highly_variable_intersection"] = dask_compute( - df["highly_variable_intersection"] - ) - if subset: - adata._inplace_subset_var(materialize_as_ndarray(df["highly_variable"])) - - else: + if not inplace: if subset: df = df.loc[df["highly_variable"]] return df + + df = dask_compute(df) + adata.uns["hvg"] = {"flavor": flavor} + logg.hint( + "added\n" + " 'highly_variable', boolean vector (adata.var)\n" + " 'means', float vector (adata.var)\n" + " 'dispersions', float vector (adata.var)\n" + " 'dispersions_norm', float vector (adata.var)" + ) + adata.var["highly_variable"] = df["highly_variable"] + adata.var["means"] = df["means"] + adata.var["dispersions"] = df["dispersions"] + adata.var["dispersions_norm"] = df["dispersions_norm"].astype( + np.float32, copy=False + ) + + if batch_key is not None: + adata.var["highly_variable_nbatches"] = df["highly_variable_nbatches"] + adata.var["highly_variable_intersection"] = df["highly_variable_intersection"] + if subset: + adata._inplace_subset_var(df["highly_variable"]) diff --git a/scanpy/tests/test_highly_variable_genes.py b/scanpy/tests/test_highly_variable_genes.py index 198821bef6..a3ea199cb3 100644 --- a/scanpy/tests/test_highly_variable_genes.py +++ b/scanpy/tests/test_highly_variable_genes.py @@ -7,6 +7,7 @@ import pandas as pd import pytest from anndata import AnnData +from pandas.testing import assert_frame_equal, assert_index_equal from scipy import sparse import scanpy as sc @@ -34,11 +35,11 @@ def adata(adata_sess: AnnData) -> AnnData: return adata_sess.copy() -def test_highly_variable_genes_runs(adata): +def test_runs(adata): sc.pp.highly_variable_genes(adata) -def test_highly_variable_genes_supports_batch(adata): +def test_supports_batch(adata): gen = np.random.default_rng(0) adata.obs["batch"] = pd.array( gen.binomial(3, 0.5, size=adata.n_obs), dtype="category" @@ -48,7 +49,7 @@ def test_highly_variable_genes_supports_batch(adata): assert "highly_variable_intersection" in adata.var.columns -def test_highly_variable_genes_supports_layers(adata_sess): +def test_supports_layers(adata_sess): def execute(layer: str | None) -> AnnData: gen = np.random.default_rng(0) adata = adata_sess.copy() @@ -70,7 +71,7 @@ def execute(layer: str | None) -> AnnData: assert (adata1.var["highly_variable"] != adata2.var["highly_variable"]).any() -def test_highly_variable_genes_no_batch_matches_batch(adata): +def test_no_batch_matches_batch(adata): sc.pp.highly_variable_genes(adata) no_batch_hvg = adata.var["highly_variable"].copy() assert no_batch_hvg.any() @@ -84,7 +85,7 @@ def test_highly_variable_genes_no_batch_matches_batch(adata): @pytest.mark.parametrize("batch_key", [None, "batch"], ids=["single", "batched"]) @pytest.mark.parametrize("array_type", ARRAY_TYPES_SUPPORTED) -def test_highly_variable_genes_no_inplace(adata, array_type, batch_key): +def test_no_inplace(adata, array_type, batch_key): """Tests that, with `n_top_genes=None` the returned dataframe has the expected columns.""" adata.X = array_type(adata.X) if batch_key: @@ -110,7 +111,7 @@ def test_highly_variable_genes_no_inplace(adata, array_type, batch_key): @pytest.mark.parametrize("base", [None, 10]) @pytest.mark.parametrize("flavor", ["seurat", "cell_ranger"]) -def test_highly_variable_genes_keep_layer(base, flavor): +def test_keep_layer(base, flavor): adata = pbmc3k() # cell_ranger flavor can raise error if many 0 genes sc.pp.filter_genes(adata, min_counts=1) @@ -138,7 +139,7 @@ def _check_pearson_hvg_columns(output_df: pd.DataFrame, n_top_genes: int): assert np.nanmax(output_df["highly_variable_rank"].to_numpy()) <= n_top_genes - 1 -def test_highly_variable_genes_pearson_residuals_inputchecks(pbmc3k_parametrized_small): +def test_pearson_residuals_inputchecks(pbmc3k_parametrized_small): adata = pbmc3k_parametrized_small() # depending on check_values, warnings should be raised for non-integer data @@ -178,7 +179,7 @@ def test_highly_variable_genes_pearson_residuals_inputchecks(pbmc3k_parametrized ) @pytest.mark.parametrize("theta", [100, np.Inf], ids=["100theta", "inftheta"]) @pytest.mark.parametrize("n_top_genes", [100, 200], ids=["100n", "200n"]) -def test_highly_variable_genes_pearson_residuals_general( +def test_pearson_residuals_general( pbmc3k_parametrized_small, subset, clip, theta, n_top_genes ): adata = pbmc3k_parametrized_small() @@ -262,9 +263,7 @@ def test_highly_variable_genes_pearson_residuals_general( @pytest.mark.parametrize("subset", [True, False], ids=["subset", "full"]) @pytest.mark.parametrize("n_top_genes", [100, 200], ids=["100n", "200n"]) -def test_highly_variable_genes_pearson_residuals_batch( - pbmc3k_parametrized_small, subset, n_top_genes -): +def test_pearson_residuals_batch(pbmc3k_parametrized_small, subset, n_top_genes): adata = pbmc3k_parametrized_small() # cleanup var del adata.var @@ -330,7 +329,7 @@ def test_highly_variable_genes_pearson_residuals_batch( assert len(output_df) == n_genes -def test_highly_variable_genes_compare_to_seurat(): +def test_compare_to_seurat(): seurat_hvg_info = pd.read_csv(FILE, sep=" ") pbmc = pbmc68k_reduced() @@ -370,7 +369,7 @@ def test_highly_variable_genes_compare_to_seurat(): @needs.skmisc -def test_highly_variable_genes_compare_to_seurat_v3(): +def test_compare_to_seurat_v3(): seurat_hvg_info = pd.read_csv( FILE_V3, sep=" ", dtype={"variances_norm": np.float64} ) @@ -430,7 +429,7 @@ def test_highly_variable_genes_compare_to_seurat_v3(): @needs.skmisc -def test_highly_variable_genes_seurat_v3_warning(): +def test_seurat_v3_warning(): pbmc = pbmc3k()[:200].copy() sc.pp.log1p(pbmc) with pytest.warns( @@ -484,7 +483,7 @@ def test_filter_genes_dispersion_compare_to_seurat(): ) -def test_highly_variable_genes_batches(): +def test_batches(): adata = pbmc68k_reduced() adata[:100, :100].X = np.zeros((100, 100)) @@ -572,11 +571,16 @@ def test_cellranger_n_top_genes_warning(): sc.pp.highly_variable_genes(adata, n_top_genes=1000, flavor="cell_ranger") +mark_no_cell_ranger = pytest.mark.xfail( + reason="See https://github.com/dask/dask/issues/10853" +) + + @pytest.mark.parametrize("flavor", ["seurat", "cell_ranger"]) @pytest.mark.parametrize("array_type", ARRAY_TYPES_SUPPORTED) @pytest.mark.parametrize("subset", [True, False], ids=["subset", "full"]) @pytest.mark.parametrize("inplace", [True, False], ids=["inplace", "copy"]) -def test_highly_variable_genes_subset_inplace_consistency( +def test_subset_inplace_consistency( request: pytest.FixtureRequest, flavor, array_type, subset, inplace ): """Tests that, with `n_top_genes=n` @@ -585,9 +589,7 @@ def test_highly_variable_genes_subset_inplace_consistency( - for dask arrays and non-dask arrays """ if flavor == "cell_ranger" and "dask" in array_type.__name__: - request.applymarker( - pytest.mark.xfail(reason="See https://github.com/dask/dask/issues/10853") - ) + request.applymarker(mark_no_cell_ranger) adata = sc.datasets.blobs(n_observations=20, n_variables=80, random_state=0) adata.X = array_type(np.abs(adata.X).astype(int)) @@ -619,3 +621,34 @@ def test_highly_variable_genes_subset_inplace_consistency( assert isinstance(output_df, DaskDataFrame) else: assert isinstance(output_df, pd.DataFrame) + + +@pytest.mark.parametrize( + "flavor", ["seurat", pytest.param("cell_ranger", marks=mark_no_cell_ranger)] +) +@pytest.mark.parametrize("batch_key", [None, "batch"], ids=["single", "batched"]) +@pytest.mark.parametrize( + "to_dask", [p for p in ARRAY_TYPES_SUPPORTED if "dask" in p.values[0].__name__] +) +def test_dask_consistency(adata: AnnData, flavor, batch_key, to_dask): + adata.X = np.abs(adata.X).astype(int) + if batch_key is not None: + adata.obs[batch_key] = np.tile(["a", "b"], adata.shape[0] // 2) + sc.pp.normalize_total(adata, target_sum=1e4) + sc.pp.log1p(adata) + + adata_dask = adata.copy() + adata_dask.X = to_dask(adata_dask.X) + + output_mem, output_dask = ( + sc.pp.highly_variable_genes(ad, flavor=flavor, n_top_genes=15, inplace=False) + for ad in [adata, adata_dask] + ) + + assert isinstance(output_mem, pd.DataFrame) + assert isinstance(output_dask, DaskDataFrame) + + assert_index_equal(adata.var_names, output_mem.index, check_names=False) + assert_index_equal(adata.var_names, output_dask.index.compute(), check_names=False) + + assert_frame_equal(output_mem, output_dask.compute()) From 246f9b56a7fca8ef3ee0f2ee74113bf22dfaaabc Mon Sep 17 00:00:00 2001 From: Phil Schaf Date: Thu, 1 Feb 2024 11:39:46 +0100 Subject: [PATCH 29/50] remove helper --- scanpy/preprocessing/_distributed.py | 56 +------------------ .../preprocessing/_highly_variable_genes.py | 18 ++---- 2 files changed, 8 insertions(+), 66 deletions(-) diff --git a/scanpy/preprocessing/_distributed.py b/scanpy/preprocessing/_distributed.py index 011101b0f7..f9b5648122 100644 --- a/scanpy/preprocessing/_distributed.py +++ b/scanpy/preprocessing/_distributed.py @@ -1,27 +1,18 @@ from __future__ import annotations -from contextlib import contextmanager from itertools import chain from typing import TYPE_CHECKING, Literal, overload import numpy as np -from scanpy._compat import ( - DaskArray, - DaskDataFrame, - DaskDataFrameGroupBy, - DaskSeries, - DaskSeriesGroupBy, - ZappyArray, -) +from scanpy._compat import DaskArray, DaskSeries, DaskSeriesGroupBy, ZappyArray if TYPE_CHECKING: - from collections.abc import Callable, Generator + from collections.abc import Callable import pandas as pd from dask.dataframe import Aggregation from numpy.typing import ArrayLike - from pandas.core.groupby.generic import DataFrameGroupBy, SeriesGroupBy @overload @@ -86,49 +77,6 @@ def materialize_as_ndarray( return da.compute(*a, sync=True) -@overload -def dask_compute(value: DaskDataFrame) -> pd.DataFrame: - ... - - -@overload -def dask_compute(value: DaskSeries) -> pd.Series: - ... - - -@overload -def dask_compute(value: DaskDataFrameGroupBy) -> DataFrameGroupBy: - ... - - -@overload -def dask_compute(value: DaskSeriesGroupBy) -> SeriesGroupBy: - ... - - -def dask_compute( - value: DaskDataFrame | DaskSeries | DaskDataFrameGroupBy | DaskSeriesGroupBy, -) -> pd.DataFrame | pd.Series | DataFrameGroupBy | SeriesGroupBy: - """Compute a dask array or series.""" - if isinstance( - value, (DaskDataFrame, DaskSeries, DaskDataFrameGroupBy, DaskSeriesGroupBy) - ): - with suppress_pandas_warning(): - return value.compute(sync=True) - return value - - -@contextmanager -def suppress_pandas_warning() -> Generator[None, None, None]: - import warnings - - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", r"The default of observed=False", category=FutureWarning - ) - yield - - try: import dask.dataframe as dd except ImportError: diff --git a/scanpy/preprocessing/_highly_variable_genes.py b/scanpy/preprocessing/_highly_variable_genes.py index 940ebdbfd1..0b781895e5 100644 --- a/scanpy/preprocessing/_highly_variable_genes.py +++ b/scanpy/preprocessing/_highly_variable_genes.py @@ -23,13 +23,7 @@ from .._settings import Verbosity, settings from .._utils import check_nonnegative_integers, sanitize_anndata from ..get import _get_obs_rep -from ._distributed import ( - dask_compute, - get_mad, - materialize_as_ndarray, - series_to_array, - suppress_pandas_warning, -) +from ._distributed import get_mad, materialize_as_ndarray, series_to_array from ._simple import filter_genes from ._utils import _get_mean_var @@ -339,10 +333,9 @@ def _stats_seurat( disp_grouped: SeriesGroupBy | DaskSeriesGroupBy, ) -> pd.DataFrame | DaskDataFrame: """Compute mean and std dev per bin.""" - with suppress_pandas_warning(): - disp_bin_stats: pd.DataFrame = dask_compute( - disp_grouped.agg(avg="mean", dev=partial(np.std, ddof=1)) - ) + disp_bin_stats = disp_grouped.agg(avg="mean", dev=partial(np.std, ddof=1)) + if isinstance(disp_bin_stats, DaskDataFrame): + disp_bin_stats = disp_bin_stats.compute(sync=True) # retrieve those genes that have nan std, these are the ones where # only a single gene fell in the bin and implicitly set them to have # a normalized disperion of 1 @@ -718,7 +711,8 @@ def highly_variable_genes( return df - df = dask_compute(df) + if isinstance(df, DaskDataFrame): + df = df.compute(sync=True) adata.uns["hvg"] = {"flavor": flavor} logg.hint( "added\n" From f53e4b7bd68b317fd89cf71bca3f39e3343f6bb7 Mon Sep 17 00:00:00 2001 From: Phil Schaf Date: Thu, 1 Feb 2024 12:01:29 +0100 Subject: [PATCH 30/50] Use pandas --- scanpy/preprocessing/_distributed.py | 44 +------- .../preprocessing/_highly_variable_genes.py | 105 ++++-------------- scanpy/tests/test_highly_variable_genes.py | 29 ++--- 3 files changed, 31 insertions(+), 147 deletions(-) diff --git a/scanpy/preprocessing/_distributed.py b/scanpy/preprocessing/_distributed.py index f9b5648122..868db988f4 100644 --- a/scanpy/preprocessing/_distributed.py +++ b/scanpy/preprocessing/_distributed.py @@ -1,17 +1,13 @@ from __future__ import annotations -from itertools import chain -from typing import TYPE_CHECKING, Literal, overload +from typing import TYPE_CHECKING, overload import numpy as np -from scanpy._compat import DaskArray, DaskSeries, DaskSeriesGroupBy, ZappyArray +from scanpy._compat import DaskArray, DaskSeries, ZappyArray if TYPE_CHECKING: - from collections.abc import Callable - import pandas as pd - from dask.dataframe import Aggregation from numpy.typing import ArrayLike @@ -75,39 +71,3 @@ def materialize_as_ndarray( import dask.array as da return da.compute(*a, sync=True) - - -try: - import dask.dataframe as dd -except ImportError: - - def get_mad(dask: Literal[False]) -> Callable[[np.ndarray], np.ndarray]: - from statsmodels.robust import mad - - return mad -else: - - def _mad1(chunks: DaskSeriesGroupBy): - return chunks.apply(list) - - def _mad2(grouped: DaskSeriesGroupBy): - def internal(c): - if (c != c).all(): - return [np.nan] - f = [_ for _ in c if _ == _] - f = [_ if isinstance(_, list) else [_] for _ in f] - return list(chain.from_iterable(f)) - - return grouped.apply(internal) - - def _mad3(grouped: DaskSeriesGroupBy): - from statsmodels.robust import mad - - return grouped.apply(lambda s: np.nan if len(s) == 0 else mad(s)) - - mad_dask = dd.Aggregation("mad", chunk=_mad1, agg=_mad2, finalize=_mad3) - - def get_mad(dask: bool) -> Callable[[np.ndarray], np.ndarray] | Aggregation: - from statsmodels.robust import mad - - return mad_dask if dask else mad diff --git a/scanpy/preprocessing/_highly_variable_genes.py b/scanpy/preprocessing/_highly_variable_genes.py index 0b781895e5..a7703bb3cd 100644 --- a/scanpy/preprocessing/_highly_variable_genes.py +++ b/scanpy/preprocessing/_highly_variable_genes.py @@ -12,24 +12,17 @@ from anndata import AnnData from .. import logging as logg -from .._compat import ( - DaskArray, - DaskDataFrame, - DaskDataFrameGroupBy, - DaskSeries, - DaskSeriesGroupBy, - old_positionals, -) +from .._compat import DaskArray, old_positionals from .._settings import Verbosity, settings from .._utils import check_nonnegative_integers, sanitize_anndata from ..get import _get_obs_rep -from ._distributed import get_mad, materialize_as_ndarray, series_to_array +from ._distributed import materialize_as_ndarray, series_to_array from ._simple import filter_genes from ._utils import _get_mean_var if TYPE_CHECKING: from numpy.typing import NDArray - from pandas.core.groupby.generic import DataFrameGroupBy, SeriesGroupBy + from pandas.core.groupby.generic import SeriesGroupBy def _highly_variable_genes_seurat_v3( @@ -245,7 +238,7 @@ def _highly_variable_genes_single_batch( cutoff: _Cutoffs | int, n_bins: int = 20, flavor: Literal["seurat", "cell_ranger"] = "seurat", -) -> pd.DataFrame | DaskDataFrame: +) -> pd.DataFrame: """\ See `highly_variable_genes`. @@ -275,16 +268,9 @@ def _highly_variable_genes_single_batch( mean = np.log1p(mean) # all of the following quantities are "per-gene" here - df: pd.DataFrame | DaskDataFrame - if isinstance(data, DaskArray): - import dask.array as da - import dask.dataframe as dd - - df = dd.from_dask_array( - da.vstack((mean, dispersion)).T, columns=["means", "dispersions"] - ) - else: - df = pd.DataFrame(dict(means=mean, dispersions=dispersion)) + df = pd.DataFrame( + dict(zip(["means", "dispersions"], materialize_as_ndarray((mean, dispersion)))) + ) df["mean_bin"] = _get_mean_bins(df["means"], flavor, n_bins) disp_grouped = df.groupby("mean_bin", observed=True)["dispersions"] if flavor == "seurat": @@ -303,18 +289,14 @@ def _highly_variable_genes_single_batch( cutoff=cutoff, ) - if isinstance(df, DaskDataFrame): - df["gene"] = adata.var_names.to_series(index=df.index, name="gene") - df = df.set_index("gene", sort=False) - else: - df.set_index(adata.var_names, inplace=True) - df.index.name = "gene" + df.set_index(adata.var_names, inplace=True) + df.index.name = "gene" return df def _get_mean_bins( - means: pd.Series | DaskSeries, flavor: Literal["seurat", "cell_ranger"], n_bins: int -) -> pd.Series | DaskSeries: + means: pd.Series, flavor: Literal["seurat", "cell_ranger"], n_bins: int +) -> pd.Series: if flavor == "seurat": bins = n_bins elif flavor == "cell_ranger": @@ -322,20 +304,12 @@ def _get_mean_bins( else: raise ValueError('`flavor` needs to be "seurat" or "cell_ranger"') - if isinstance(means, DaskSeries): - # TODO: does map_partitions make sense for bin? It would bin per chunk, not globally - return means.map_partitions(pd.cut, bins=bins) return pd.cut(means, bins=bins) -def _stats_seurat( - mean_bins: pd.Series | DaskSeries, - disp_grouped: SeriesGroupBy | DaskSeriesGroupBy, -) -> pd.DataFrame | DaskDataFrame: +def _stats_seurat(mean_bins: pd.Series, disp_grouped: SeriesGroupBy) -> pd.DataFrame: """Compute mean and std dev per bin.""" disp_bin_stats = disp_grouped.agg(avg="mean", dev=partial(np.std, ddof=1)) - if isinstance(disp_bin_stats, DaskDataFrame): - disp_bin_stats = disp_bin_stats.compute(sync=True) # retrieve those genes that have nan std, these are the ones where # only a single gene fell in the bin and implicitly set them to have # a normalized disperion of 1 @@ -355,48 +329,26 @@ def _stats_seurat( def _stats_cell_ranger( - mean_bins: pd.Series | DaskSeries, - disp_grouped: SeriesGroupBy | DaskSeriesGroupBy, -) -> pd.DataFrame | DaskDataFrame: + mean_bins: pd.Series, + disp_grouped: SeriesGroupBy, +) -> pd.DataFrame: """Compute median and median absolute dev per bin.""" + from statsmodels.robust import mad - is_dask = isinstance(disp_grouped, DaskSeriesGroupBy) with warnings.catch_warnings(): # MAD calculation raises the warning: "Mean of empty slice" warnings.simplefilter("ignore", category=RuntimeWarning) - disp_bin_stats = _aggregate(disp_grouped, ["median", get_mad(dask=is_dask)]) - # Can’t use kwargs in `aggregate`: https://github.com/dask/dask/issues/10836 - disp_bin_stats = disp_bin_stats.rename(columns=dict(median="avg", mad="dev")) + disp_bin_stats = disp_grouped.agg(avg="median", dev=mad) return _unbin(disp_bin_stats, mean_bins) -def _unbin( - df: pd.DataFrame | DaskDataFrame, mean_bins: pd.Series | DaskSeries -) -> pd.DataFrame | DaskDataFrame: +def _unbin(df: pd.DataFrame, mean_bins: pd.Series) -> pd.DataFrame: df = df.loc[mean_bins] df["gene"] = mean_bins.index - if isinstance(df, DaskDataFrame): - return df.set_index("gene", sort=False) df.set_index("gene", inplace=True) return df -def _aggregate( - grouped: ( - DataFrameGroupBy | DaskDataFrameGroupBy | SeriesGroupBy | DaskSeriesGroupBy - ), - arg=None, - **kw, -) -> pd.DataFrame | DaskDataFrame | pd.Series | DaskSeries: - # ValueError: In order to aggregate with 'median', - # you must use shuffling-based aggregation (e.g., shuffle='tasks') - if ((arg and "median" in arg) or "median" in kw) and isinstance( - grouped, (DaskSeriesGroupBy, DaskDataFrameGroupBy) - ): - kw["shuffle"] = True - return grouped.agg(arg, **kw) - - def _subset_genes( adata: AnnData, *, @@ -443,7 +395,7 @@ def _highly_variable_genes_batched( n_bins: int, flavor: Literal["seurat", "cell_ranger"], cutoff: _Cutoffs | int, -) -> pd.DataFrame | DaskDataFrame: +) -> pd.DataFrame: sanitize_anndata(adata) batches = adata.obs[batch_key].cat.categories dfs = [] @@ -468,10 +420,7 @@ def _highly_variable_genes_batched( adata_subset, layer=layer, cutoff=cutoff, n_bins=n_bins, flavor=flavor ) assert hvg.index.name == "gene" - if isinstance(hvg, DaskDataFrame): - hvg = hvg.reset_index(drop=False) - else: - hvg.reset_index(drop=False, inplace=True) + hvg.reset_index(drop=False, inplace=True) if (n_removed := np.sum(~filt)) > 0: # Add 0 values for genes that were filtered out @@ -485,13 +434,7 @@ def _highly_variable_genes_batched( dfs.append(hvg) - df: DaskDataFrame | pd.DataFrame - if isinstance(dfs[0], DaskDataFrame): - import dask.dataframe as dd - - df = dd.concat(dfs, axis=0) - else: - df = pd.concat(dfs, axis=0) + df = pd.concat(dfs, axis=0) df["highly_variable"] = df["highly_variable"].astype(int) df = df.groupby("gene", observed=True).agg( @@ -502,8 +445,6 @@ def _highly_variable_genes_batched( highly_variable="sum", ) ) - if isinstance(df, DaskDataFrame): - df = df.set_index("gene", sort=False) # happens automatically for pandas df df["highly_variable_nbatches"] = df["highly_variable"] df["highly_variable_intersection"] = df["highly_variable_nbatches"] == len(batches) @@ -556,7 +497,7 @@ def highly_variable_genes( inplace: bool = True, batch_key: str | None = None, check_values: bool = True, -) -> pd.DataFrame | DaskDataFrame | None: +) -> pd.DataFrame | None: """\ Annotate highly variable genes [Satija15]_ [Zheng17]_ [Stuart19]_. @@ -711,8 +652,6 @@ def highly_variable_genes( return df - if isinstance(df, DaskDataFrame): - df = df.compute(sync=True) adata.uns["hvg"] = {"flavor": flavor} logg.hint( "added\n" diff --git a/scanpy/tests/test_highly_variable_genes.py b/scanpy/tests/test_highly_variable_genes.py index a3ea199cb3..b49714d633 100644 --- a/scanpy/tests/test_highly_variable_genes.py +++ b/scanpy/tests/test_highly_variable_genes.py @@ -11,7 +11,6 @@ from scipy import sparse import scanpy as sc -from scanpy._compat import DaskDataFrame from scanpy.testing._helpers import _check_check_values_warnings from scanpy.testing._helpers.data import pbmc3k, pbmc68k_reduced from scanpy.testing._pytest.marks import needs @@ -101,12 +100,8 @@ def test_no_inplace(adata, array_type, batch_key): hvg_df = sc.pp.highly_variable_genes( adata, batch_key=batch_key, n_bins=3, inplace=False ) - assert hvg_df is not None + assert isinstance(hvg_df, pd.DataFrame) assert colnames == set(hvg_df.columns) - if "dask" in array_type.__name__: - assert isinstance(hvg_df, DaskDataFrame) - else: - assert isinstance(hvg_df, pd.DataFrame) @pytest.mark.parametrize("base", [None, 10]) @@ -580,17 +575,12 @@ def test_cellranger_n_top_genes_warning(): @pytest.mark.parametrize("array_type", ARRAY_TYPES_SUPPORTED) @pytest.mark.parametrize("subset", [True, False], ids=["subset", "full"]) @pytest.mark.parametrize("inplace", [True, False], ids=["inplace", "copy"]) -def test_subset_inplace_consistency( - request: pytest.FixtureRequest, flavor, array_type, subset, inplace -): +def test_subset_inplace_consistency(flavor, array_type, subset, inplace): """Tests that, with `n_top_genes=n` - `inplace` and `subset` interact correctly - for both the `seurat` and `cell_ranger` flavors - for dask arrays and non-dask arrays """ - if flavor == "cell_ranger" and "dask" in array_type.__name__: - request.applymarker(mark_no_cell_ranger) - adata = sc.datasets.blobs(n_observations=20, n_variables=80, random_state=0) adata.X = array_type(np.abs(adata.X).astype(int)) @@ -617,15 +607,10 @@ def test_subset_inplace_consistency( assert (output_df is None) == inplace assert len(adata.var if inplace else output_df) == (15 if subset else n_genes) if output_df is not None: - if "dask" in array_type.__name__: - assert isinstance(output_df, DaskDataFrame) - else: - assert isinstance(output_df, pd.DataFrame) + assert isinstance(output_df, pd.DataFrame) -@pytest.mark.parametrize( - "flavor", ["seurat", pytest.param("cell_ranger", marks=mark_no_cell_ranger)] -) +@pytest.mark.parametrize("flavor", ["seurat", "cell_ranger"]) @pytest.mark.parametrize("batch_key", [None, "batch"], ids=["single", "batched"]) @pytest.mark.parametrize( "to_dask", [p for p in ARRAY_TYPES_SUPPORTED if "dask" in p.values[0].__name__] @@ -646,9 +631,9 @@ def test_dask_consistency(adata: AnnData, flavor, batch_key, to_dask): ) assert isinstance(output_mem, pd.DataFrame) - assert isinstance(output_dask, DaskDataFrame) + assert isinstance(output_dask, pd.DataFrame) assert_index_equal(adata.var_names, output_mem.index, check_names=False) - assert_index_equal(adata.var_names, output_dask.index.compute(), check_names=False) + assert_index_equal(adata.var_names, output_dask.index, check_names=False) - assert_frame_equal(output_mem, output_dask.compute()) + assert_frame_equal(output_mem, output_dask) From 4398114c51ad17468cc6e04bbfa1e6dd88134614 Mon Sep 17 00:00:00 2001 From: Phil Schaf Date: Thu, 1 Feb 2024 12:04:51 +0100 Subject: [PATCH 31/50] remove unneeded shims --- scanpy/_compat.py | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/scanpy/_compat.py b/scanpy/_compat.py index a7022dded2..a89e8d20cf 100644 --- a/scanpy/_compat.py +++ b/scanpy/_compat.py @@ -16,27 +16,11 @@ try: from dask.array import Array as DaskArray - from dask.dataframe import DataFrame as DaskDataFrame - from dask.dataframe import Series as DaskSeries - from dask.dataframe.groupby import DataFrameGroupBy as DaskDataFrameGroupBy - from dask.dataframe.groupby import SeriesGroupBy as DaskSeriesGroupBy except ImportError: class DaskArray: pass - class DaskDataFrame: - pass - - class DaskSeries: - pass - - class DaskDataFrameGroupBy: - pass - - class DaskSeriesGroupBy: - pass - try: from zappy.base import ZappyArray @@ -50,10 +34,6 @@ class ZappyArray: "cache", "DaskArray", "ZappyArray", - "DaskDataFrame", - "DaskSeries", - "DaskDataFrameGroupBy", - "DaskSeriesGroupBy", "fullname", "pkg_metadata", "pkg_version", From db5c191e10a1add3bc339c69d4d9627eb14cc061 Mon Sep 17 00:00:00 2001 From: Phil Schaf Date: Thu, 1 Feb 2024 12:58:39 +0100 Subject: [PATCH 32/50] remove series_to_array --- scanpy/preprocessing/_distributed.py | 26 +------------------ .../preprocessing/_highly_variable_genes.py | 9 +++---- 2 files changed, 5 insertions(+), 30 deletions(-) diff --git a/scanpy/preprocessing/_distributed.py b/scanpy/preprocessing/_distributed.py index 868db988f4..91db7e7149 100644 --- a/scanpy/preprocessing/_distributed.py +++ b/scanpy/preprocessing/_distributed.py @@ -4,36 +4,12 @@ import numpy as np -from scanpy._compat import DaskArray, DaskSeries, ZappyArray +from scanpy._compat import DaskArray, ZappyArray if TYPE_CHECKING: - import pandas as pd from numpy.typing import ArrayLike -@overload -def series_to_array(s: pd.Series, *, dtype: np.dtype | None = None) -> np.ndarray: - ... - - -@overload -def series_to_array(s: DaskSeries, *, dtype: np.dtype | None = None) -> DaskArray: - ... - - -def series_to_array( - s: pd.Series | DaskSeries, *, dtype: np.dtype | None = None -) -> np.ndarray | DaskArray: - """Convert Series to Array, keeping them in-memory or distributed.""" - if isinstance(s, DaskSeries): - return ( - s.to_dask_array(True) - if dtype is None - else s.astype(dtype).to_dask_array(True) - ) - return s.to_numpy() if dtype is None else s.to_numpy().astype(dtype, copy=False) - - @overload def materialize_as_ndarray(a: ArrayLike) -> np.ndarray: ... diff --git a/scanpy/preprocessing/_highly_variable_genes.py b/scanpy/preprocessing/_highly_variable_genes.py index a7703bb3cd..4358cf07ad 100644 --- a/scanpy/preprocessing/_highly_variable_genes.py +++ b/scanpy/preprocessing/_highly_variable_genes.py @@ -16,7 +16,7 @@ from .._settings import Verbosity, settings from .._utils import check_nonnegative_integers, sanitize_anndata from ..get import _get_obs_rep -from ._distributed import materialize_as_ndarray, series_to_array +from ._distributed import materialize_as_ndarray from ._simple import filter_genes from ._utils import _get_mean_var @@ -285,7 +285,7 @@ def _highly_variable_genes_single_batch( df["highly_variable"] = _subset_genes( adata, mean=mean, - dispersion_norm=series_to_array(df["dispersions_norm"]), + dispersion_norm=df["dispersions_norm"].to_numpy(), cutoff=cutoff, ) @@ -329,8 +329,7 @@ def _stats_seurat(mean_bins: pd.Series, disp_grouped: SeriesGroupBy) -> pd.DataF def _stats_cell_ranger( - mean_bins: pd.Series, - disp_grouped: SeriesGroupBy, + mean_bins: pd.Series, disp_grouped: SeriesGroupBy ) -> pd.DataFrame: """Compute median and median absolute dev per bin.""" from statsmodels.robust import mad @@ -459,7 +458,7 @@ def _highly_variable_genes_batched( ) df["highly_variable"] = np.arange(df.shape[0]) < cutoff else: - dispersion_norm = series_to_array(df["dispersions_norm"]) + dispersion_norm = df["dispersions_norm"].to_numpy() dispersion_norm[np.isnan(dispersion_norm)] = 0 # similar to Seurat df["highly_variable"] = cutoff.in_bounds(df["means"], df["dispersions_norm"]) From ab21f003a8696976487fb96266bd599214b7c42d Mon Sep 17 00:00:00 2001 From: Philipp A Date: Tue, 6 Feb 2024 16:46:19 +0100 Subject: [PATCH 33/50] Update scanpy/preprocessing/_highly_variable_genes.py Co-authored-by: Isaac Virshup --- scanpy/preprocessing/_highly_variable_genes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scanpy/preprocessing/_highly_variable_genes.py b/scanpy/preprocessing/_highly_variable_genes.py index 4358cf07ad..7e6735dad6 100644 --- a/scanpy/preprocessing/_highly_variable_genes.py +++ b/scanpy/preprocessing/_highly_variable_genes.py @@ -220,8 +220,8 @@ def validate( def in_bounds( self, - mean: NDArray[np.float64] | DaskArray, - dispersion_norm: NDArray[np.float64] | DaskArray, + mean: NDArray[np.floating] | DaskArray, + dispersion_norm: NDArray[np.floating] | DaskArray, ) -> NDArray[np.bool_] | DaskArray: return ( (mean > self.min_mean) From b1b53bdaf071182b22d18261dac4ba50d7cd540b Mon Sep 17 00:00:00 2001 From: Philipp A Date: Tue, 6 Feb 2024 17:01:18 +0100 Subject: [PATCH 34/50] remove unused mark Co-authored-by: Isaac Virshup --- scanpy/tests/test_highly_variable_genes.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/scanpy/tests/test_highly_variable_genes.py b/scanpy/tests/test_highly_variable_genes.py index b49714d633..6ad01703d8 100644 --- a/scanpy/tests/test_highly_variable_genes.py +++ b/scanpy/tests/test_highly_variable_genes.py @@ -566,9 +566,6 @@ def test_cellranger_n_top_genes_warning(): sc.pp.highly_variable_genes(adata, n_top_genes=1000, flavor="cell_ranger") -mark_no_cell_ranger = pytest.mark.xfail( - reason="See https://github.com/dask/dask/issues/10853" -) @pytest.mark.parametrize("flavor", ["seurat", "cell_ranger"]) From 4e099334d65cc390818fcf691bd8126eb5b40a27 Mon Sep 17 00:00:00 2001 From: Philipp A Date: Tue, 6 Feb 2024 17:03:16 +0100 Subject: [PATCH 35/50] use pandas ddof Co-authored-by: Isaac Virshup --- scanpy/preprocessing/_highly_variable_genes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scanpy/preprocessing/_highly_variable_genes.py b/scanpy/preprocessing/_highly_variable_genes.py index 7e6735dad6..a78fd1af03 100644 --- a/scanpy/preprocessing/_highly_variable_genes.py +++ b/scanpy/preprocessing/_highly_variable_genes.py @@ -309,7 +309,7 @@ def _get_mean_bins( def _stats_seurat(mean_bins: pd.Series, disp_grouped: SeriesGroupBy) -> pd.DataFrame: """Compute mean and std dev per bin.""" - disp_bin_stats = disp_grouped.agg(avg="mean", dev=partial(np.std, ddof=1)) + disp_bin_stats = disp_grouped.agg(avg="mean", dev="std") # retrieve those genes that have nan std, these are the ones where # only a single gene fell in the bin and implicitly set them to have # a normalized disperion of 1 From 698af7d1873a267460505f083272111982518e46 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 6 Feb 2024 17:12:12 +0100 Subject: [PATCH 36/50] Use COW compatible indexing --- scanpy/preprocessing/_highly_variable_genes.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/scanpy/preprocessing/_highly_variable_genes.py b/scanpy/preprocessing/_highly_variable_genes.py index a78fd1af03..4453d8f350 100644 --- a/scanpy/preprocessing/_highly_variable_genes.py +++ b/scanpy/preprocessing/_highly_variable_genes.py @@ -2,7 +2,6 @@ import warnings from dataclasses import dataclass -from functools import partial from inspect import signature from typing import TYPE_CHECKING, Literal, cast @@ -321,10 +320,10 @@ def _stats_seurat(mean_bins: pd.Series, disp_grouped: SeriesGroupBy) -> pd.DataF "normalized dispersion was set to 1.\n " "Decreasing `n_bins` will likely avoid this effect." ) - disp_bin_stats["dev"].loc[one_gene_per_bin] = disp_bin_stats["avg"].loc[ - one_gene_per_bin + disp_bin_stats.loc[one_gene_per_bin, "dev"] = disp_bin_stats.loc[ + one_gene_per_bin, "avg" ] - disp_bin_stats["avg"].loc[one_gene_per_bin] = 0 + disp_bin_stats.loc[one_gene_per_bin, "avg"] = 0 return _unbin(disp_bin_stats, mean_bins) From 850aae4f4e4f73d118f7870fc9a5e5e793a1a6af Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 6 Feb 2024 17:17:47 +0100 Subject: [PATCH 37/50] remove _unbin --- scanpy/preprocessing/_highly_variable_genes.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/scanpy/preprocessing/_highly_variable_genes.py b/scanpy/preprocessing/_highly_variable_genes.py index 4453d8f350..216d7c6117 100644 --- a/scanpy/preprocessing/_highly_variable_genes.py +++ b/scanpy/preprocessing/_highly_variable_genes.py @@ -324,7 +324,7 @@ def _stats_seurat(mean_bins: pd.Series, disp_grouped: SeriesGroupBy) -> pd.DataF one_gene_per_bin, "avg" ] disp_bin_stats.loc[one_gene_per_bin, "avg"] = 0 - return _unbin(disp_bin_stats, mean_bins) + return disp_bin_stats.loc[mean_bins].set_index(mean_bins.index) def _stats_cell_ranger( @@ -337,14 +337,7 @@ def _stats_cell_ranger( # MAD calculation raises the warning: "Mean of empty slice" warnings.simplefilter("ignore", category=RuntimeWarning) disp_bin_stats = disp_grouped.agg(avg="median", dev=mad) - return _unbin(disp_bin_stats, mean_bins) - - -def _unbin(df: pd.DataFrame, mean_bins: pd.Series) -> pd.DataFrame: - df = df.loc[mean_bins] - df["gene"] = mean_bins.index - df.set_index("gene", inplace=True) - return df + return disp_bin_stats.loc[mean_bins].set_index(mean_bins.index) def _subset_genes( From cdfd23148e97d7f89a53c24175142671e57fc5e7 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 6 Feb 2024 17:22:05 +0100 Subject: [PATCH 38/50] =?UTF-8?q?Don=E2=80=99t=20set=20index=20name?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scanpy/preprocessing/_highly_variable_genes.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/scanpy/preprocessing/_highly_variable_genes.py b/scanpy/preprocessing/_highly_variable_genes.py index 216d7c6117..ebd6521a23 100644 --- a/scanpy/preprocessing/_highly_variable_genes.py +++ b/scanpy/preprocessing/_highly_variable_genes.py @@ -289,7 +289,6 @@ def _highly_variable_genes_single_batch( ) df.set_index(adata.var_names, inplace=True) - df.index.name = "gene" return df @@ -410,8 +409,7 @@ def _highly_variable_genes_batched( hvg = _highly_variable_genes_single_batch( adata_subset, layer=layer, cutoff=cutoff, n_bins=n_bins, flavor=flavor ) - assert hvg.index.name == "gene" - hvg.reset_index(drop=False, inplace=True) + hvg.reset_index(drop=False, inplace=True, names=["gene"]) if (n_removed := np.sum(~filt)) > 0: # Add 0 values for genes that were filtered out From a949869ba443eb912b8e8ed97a4b0e660a687c6b Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 6 Feb 2024 17:50:52 +0100 Subject: [PATCH 39/50] restructure stats functions --- .../preprocessing/_highly_variable_genes.py | 59 ++++++++++--------- 1 file changed, 30 insertions(+), 29 deletions(-) diff --git a/scanpy/preprocessing/_highly_variable_genes.py b/scanpy/preprocessing/_highly_variable_genes.py index ebd6521a23..84d57aeb66 100644 --- a/scanpy/preprocessing/_highly_variable_genes.py +++ b/scanpy/preprocessing/_highly_variable_genes.py @@ -21,7 +21,6 @@ if TYPE_CHECKING: from numpy.typing import NDArray - from pandas.core.groupby.generic import SeriesGroupBy def _highly_variable_genes_seurat_v3( @@ -271,13 +270,7 @@ def _highly_variable_genes_single_batch( dict(zip(["means", "dispersions"], materialize_as_ndarray((mean, dispersion)))) ) df["mean_bin"] = _get_mean_bins(df["means"], flavor, n_bins) - disp_grouped = df.groupby("mean_bin", observed=True)["dispersions"] - if flavor == "seurat": - disp_stats = _stats_seurat(df["mean_bin"], disp_grouped) - elif flavor == "cell_ranger": - disp_stats = _stats_cell_ranger(df["mean_bin"], disp_grouped) - else: - raise ValueError('`flavor` needs to be "seurat" or "cell_ranger"') + disp_stats = _get_disp_stats(df, flavor) # actually do the normalization df["dispersions_norm"] = (df["dispersions"] - disp_stats["avg"]) / disp_stats["dev"] @@ -305,38 +298,46 @@ def _get_mean_bins( return pd.cut(means, bins=bins) -def _stats_seurat(mean_bins: pd.Series, disp_grouped: SeriesGroupBy) -> pd.DataFrame: - """Compute mean and std dev per bin.""" - disp_bin_stats = disp_grouped.agg(avg="mean", dev="std") +def _get_disp_stats( + df: pd.DataFrame, flavor: Literal["seurat", "cell_ranger"] +) -> pd.DataFrame: + disp_grouped = df.groupby("mean_bin", observed=True)["dispersions"] + if flavor == "seurat": + disp_bin_stats = disp_grouped.agg(avg="mean", dev="std") + _postprocess_seurat(disp_bin_stats, df["mean_bin"]) + elif flavor == "cell_ranger": + disp_bin_stats = disp_grouped.agg(avg="median", dev=_mad) + else: + raise ValueError('`flavor` needs to be "seurat" or "cell_ranger"') + return disp_bin_stats.loc[df["mean_bin"]].set_index(df.index) + + +def _postprocess_seurat(disp_bin_stats: pd.DataFrame, mean_bin: pd.Series) -> None: # retrieve those genes that have nan std, these are the ones where # only a single gene fell in the bin and implicitly set them to have # a normalized disperion of 1 one_gene_per_bin = disp_bin_stats["dev"].isnull() - gen_indices = np.flatnonzero(one_gene_per_bin.loc[mean_bins]) - if len(gen_indices) > 0: - logg.debug( - f"Gene indices {gen_indices} fell into a single bin: their " - "normalized dispersion was set to 1.\n " - "Decreasing `n_bins` will likely avoid this effect." - ) - disp_bin_stats.loc[one_gene_per_bin, "dev"] = disp_bin_stats.loc[ - one_gene_per_bin, "avg" - ] - disp_bin_stats.loc[one_gene_per_bin, "avg"] = 0 - return disp_bin_stats.loc[mean_bins].set_index(mean_bins.index) + gen_indices = np.flatnonzero(one_gene_per_bin.loc[mean_bin]) + if len(gen_indices) == 0: + return + logg.debug( + f"Gene indices {gen_indices} fell into a single bin: their " + "normalized dispersion was set to 1.\n " + "Decreasing `n_bins` will likely avoid this effect." + ) + disp_bin_stats.loc[one_gene_per_bin, "dev"] = disp_bin_stats.loc[ + one_gene_per_bin, "avg" + ] + disp_bin_stats.loc[one_gene_per_bin, "avg"] = 0 -def _stats_cell_ranger( - mean_bins: pd.Series, disp_grouped: SeriesGroupBy -) -> pd.DataFrame: - """Compute median and median absolute dev per bin.""" +def _mad(a): from statsmodels.robust import mad with warnings.catch_warnings(): # MAD calculation raises the warning: "Mean of empty slice" warnings.simplefilter("ignore", category=RuntimeWarning) - disp_bin_stats = disp_grouped.agg(avg="median", dev=mad) - return disp_bin_stats.loc[mean_bins].set_index(mean_bins.index) + return mad(a) def _subset_genes( From f1832b65872f0b65c6c4ce31b19c13491228d2b9 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 6 Feb 2024 17:53:36 +0100 Subject: [PATCH 40/50] relnotes --- docs/release-notes/1.10.0.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/release-notes/1.10.0.md b/docs/release-notes/1.10.0.md index f98e945f04..a303bfa06a 100644 --- a/docs/release-notes/1.10.0.md +++ b/docs/release-notes/1.10.0.md @@ -14,7 +14,7 @@ * {func}`scanpy.pp.pca`, {func}`scanpy.pp.scale`, {func}`scanpy.pl.embedding`, and {func}`scanpy.experimental.pp.normalize_pearson_residuals_pca` now support a `mask` parameter {pr}`2272` {smaller}`C Bright, T Marcella, & P Angerer` * {func}`scanpy.tl.rank_genes_groups` no longer warns that it's default was changed from t-test_overestim_var to t-test {pr}`2798` {smaller}`L Heumos` -* {func}`scanpy.pp.highly_variable_genes` supports dask for the default `seurat` flavor and partially for the `cell_ranger` flavor {pr}`2809` {smaller}`P Angerer` +* {func}`scanpy.pp.highly_variable_genes` supports dask for the default `seurat` and `cell_ranger` flavors {pr}`2809` {smaller}`P Angerer` ```{rubric} Docs ``` From 222f340e3f498e4ee456219b24208160018ac637 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 6 Feb 2024 17:58:09 +0100 Subject: [PATCH 41/50] fmt --- scanpy/tests/test_highly_variable_genes.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/scanpy/tests/test_highly_variable_genes.py b/scanpy/tests/test_highly_variable_genes.py index 6ad01703d8..6a840204db 100644 --- a/scanpy/tests/test_highly_variable_genes.py +++ b/scanpy/tests/test_highly_variable_genes.py @@ -566,8 +566,6 @@ def test_cellranger_n_top_genes_warning(): sc.pp.highly_variable_genes(adata, n_top_genes=1000, flavor="cell_ranger") - - @pytest.mark.parametrize("flavor", ["seurat", "cell_ranger"]) @pytest.mark.parametrize("array_type", ARRAY_TYPES_SUPPORTED) @pytest.mark.parametrize("subset", [True, False], ids=["subset", "full"]) From e1f3cca5f8ce42611374d51e802735b43ba8b64a Mon Sep 17 00:00:00 2001 From: Isaac Virshup Date: Mon, 12 Feb 2024 11:03:43 +0000 Subject: [PATCH 42/50] try showing patch coverage --- .codecov.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.codecov.yml b/.codecov.yml index 9dd8f244af..68cc92f2dd 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -8,7 +8,6 @@ coverage: default: # Require 1% coverage, i.e., always succeed target: 1 - patch: false changes: false comment: From 5ecc09e20e561911cec9d10a2a38bc4aa6b1ee68 Mon Sep 17 00:00:00 2001 From: Isaac Virshup Date: Mon, 12 Feb 2024 11:07:00 +0000 Subject: [PATCH 43/50] Revert "try showing patch coverage" This reverts commit e1f3cca5f8ce42611374d51e802735b43ba8b64a. --- .codecov.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.codecov.yml b/.codecov.yml index 68cc92f2dd..9dd8f244af 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -8,6 +8,7 @@ coverage: default: # Require 1% coverage, i.e., always succeed target: 1 + patch: false changes: false comment: From 92da20bf6c38d9b8b9beb9d27cf3bdc195ab9b9b Mon Sep 17 00:00:00 2001 From: Philipp A Date: Mon, 12 Feb 2024 14:34:32 +0100 Subject: [PATCH 44/50] index = Co-authored-by: Isaac Virshup --- scanpy/preprocessing/_highly_variable_genes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scanpy/preprocessing/_highly_variable_genes.py b/scanpy/preprocessing/_highly_variable_genes.py index 84d57aeb66..dab65282f8 100644 --- a/scanpy/preprocessing/_highly_variable_genes.py +++ b/scanpy/preprocessing/_highly_variable_genes.py @@ -281,7 +281,7 @@ def _highly_variable_genes_single_batch( cutoff=cutoff, ) - df.set_index(adata.var_names, inplace=True) + df.index = adata.var_names return df From 1cf81c4ae8429368479f955deec722238ea44db6 Mon Sep 17 00:00:00 2001 From: Phil Schaf Date: Mon, 12 Feb 2024 15:12:01 +0100 Subject: [PATCH 45/50] _postprocess_dispersions_seurat --- scanpy/preprocessing/_highly_variable_genes.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/scanpy/preprocessing/_highly_variable_genes.py b/scanpy/preprocessing/_highly_variable_genes.py index dab65282f8..c972f0ab50 100644 --- a/scanpy/preprocessing/_highly_variable_genes.py +++ b/scanpy/preprocessing/_highly_variable_genes.py @@ -304,7 +304,7 @@ def _get_disp_stats( disp_grouped = df.groupby("mean_bin", observed=True)["dispersions"] if flavor == "seurat": disp_bin_stats = disp_grouped.agg(avg="mean", dev="std") - _postprocess_seurat(disp_bin_stats, df["mean_bin"]) + _postprocess_dispersions_seurat(disp_bin_stats, df["mean_bin"]) elif flavor == "cell_ranger": disp_bin_stats = disp_grouped.agg(avg="median", dev=_mad) else: @@ -312,7 +312,9 @@ def _get_disp_stats( return disp_bin_stats.loc[df["mean_bin"]].set_index(df.index) -def _postprocess_seurat(disp_bin_stats: pd.DataFrame, mean_bin: pd.Series) -> None: +def _postprocess_dispersions_seurat( + disp_bin_stats: pd.DataFrame, mean_bin: pd.Series +) -> None: # retrieve those genes that have nan std, these are the ones where # only a single gene fell in the bin and implicitly set them to have # a normalized disperion of 1 From 6d213267f8c3dca14be9cc6498ad3ba32fb3bf5f Mon Sep 17 00:00:00 2001 From: Phil Schaf Date: Mon, 12 Feb 2024 15:54:32 +0100 Subject: [PATCH 46/50] Check log entries --- scanpy/logging.py | 4 ++-- scanpy/tests/conftest.py | 9 +++++++++ scanpy/tests/test_highly_variable_genes.py | 10 ++++++++++ 3 files changed, 21 insertions(+), 2 deletions(-) diff --git a/scanpy/logging.py b/scanpy/logging.py index f59d2c6831..8784c17275 100644 --- a/scanpy/logging.py +++ b/scanpy/logging.py @@ -86,8 +86,8 @@ def _set_log_file(settings: ScanpyConfig): def _set_log_level(settings: ScanpyConfig, level: int): root = settings._root_logger root.setLevel(level) - (h,) = root.handlers # may only be 1 - h.setLevel(level) + for h in root.handlers: + h.setLevel(level) class _LogFormatter(logging.Formatter): diff --git a/scanpy/tests/conftest.py b/scanpy/tests/conftest.py index 5e91634b6d..fed27da1f9 100644 --- a/scanpy/tests/conftest.py +++ b/scanpy/tests/conftest.py @@ -39,6 +39,15 @@ def close_logs_on_teardown(request): request.addfinalizer(clear_loggers) +@pytest.fixture(autouse=True) +def _caplog_adapter(caplog: pytest.LogCaptureFixture): + import scanpy as sc + + sc.settings._root_logger.addHandler(caplog.handler) + yield + sc.settings._root_logger.removeHandler(caplog.handler) + + @pytest.fixture def imported_modules(): return IMPORTED diff --git a/scanpy/tests/test_highly_variable_genes.py b/scanpy/tests/test_highly_variable_genes.py index 6a840204db..166a467ad1 100644 --- a/scanpy/tests/test_highly_variable_genes.py +++ b/scanpy/tests/test_highly_variable_genes.py @@ -566,6 +566,16 @@ def test_cellranger_n_top_genes_warning(): sc.pp.highly_variable_genes(adata, n_top_genes=1000, flavor="cell_ranger") +def test_cutoff_info(caplog: pytest.LogCaptureFixture): + adata = pbmc3k()[:200].copy() + sc.pp.normalize_total(adata) + sc.pp.log1p(adata) + caplog.clear() + with sc.settings.verbosity.override(sc.Verbosity.info): + sc.pp.highly_variable_genes(adata, n_top_genes=10, max_mean=3.1) + assert "If you pass `n_top_genes`, all cutoffs are ignored." in caplog.messages + + @pytest.mark.parametrize("flavor", ["seurat", "cell_ranger"]) @pytest.mark.parametrize("array_type", ARRAY_TYPES_SUPPORTED) @pytest.mark.parametrize("subset", [True, False], ids=["subset", "full"]) From 5a7c74da8c708ae15ab3c895b315c3c5f5041df7 Mon Sep 17 00:00:00 2001 From: Phil Schaf Date: Mon, 12 Feb 2024 15:59:56 +0100 Subject: [PATCH 47/50] back to X --- scanpy/preprocessing/_highly_variable_genes.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/scanpy/preprocessing/_highly_variable_genes.py b/scanpy/preprocessing/_highly_variable_genes.py index c972f0ab50..5484505125 100644 --- a/scanpy/preprocessing/_highly_variable_genes.py +++ b/scanpy/preprocessing/_highly_variable_genes.py @@ -245,18 +245,18 @@ def _highly_variable_genes_single_batch( A DataFrame that contains the columns `highly_variable`, `means`, `dispersions`, and `dispersions_norm`. """ - data = _get_obs_rep(adata, layer=layer) + X = _get_obs_rep(adata, layer=layer) if flavor == "seurat": - data = data.copy() + X = X.copy() if "log1p" in adata.uns_keys() and adata.uns["log1p"].get("base") is not None: - data *= np.log(adata.uns["log1p"]["base"]) + X *= np.log(adata.uns["log1p"]["base"]) # use out if possible. only possible since we copy the data matrix - if isinstance(data, np.ndarray): - np.expm1(data, out=data) + if isinstance(X, np.ndarray): + np.expm1(X, out=X) else: - data = np.expm1(data) + X = np.expm1(X) - mean, var = _get_mean_var(data) + mean, var = _get_mean_var(X) # now actually compute the dispersion mean[mean == 0] = 1e-12 # set entries equal to zero to small value dispersion = var / mean From 2ae87bd170f9380140e580b5fa6f271559f60013 Mon Sep 17 00:00:00 2001 From: Phil Schaf Date: Mon, 12 Feb 2024 16:30:31 +0100 Subject: [PATCH 48/50] warning instead of info --- scanpy/preprocessing/_deprecated/highly_variable_genes.py | 3 ++- scanpy/preprocessing/_highly_variable_genes.py | 3 ++- scanpy/tests/test_highly_variable_genes.py | 6 ++---- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/scanpy/preprocessing/_deprecated/highly_variable_genes.py b/scanpy/preprocessing/_deprecated/highly_variable_genes.py index da250c688e..eaa00c754a 100644 --- a/scanpy/preprocessing/_deprecated/highly_variable_genes.py +++ b/scanpy/preprocessing/_deprecated/highly_variable_genes.py @@ -111,7 +111,8 @@ def filter_genes_dispersion( # noqa: PLR0917 if n_top_genes is not None and not all( x is None for x in [min_disp, max_disp, min_mean, max_mean] ): - logg.info("If you pass `n_top_genes`, all cutoffs are ignored.") + msg = "If you pass `n_top_genes`, all cutoffs are ignored." + warnings.warn(msg, UserWarning) if min_disp is None: min_disp = 0.5 if min_mean is None: diff --git a/scanpy/preprocessing/_highly_variable_genes.py b/scanpy/preprocessing/_highly_variable_genes.py index 5484505125..063d845d92 100644 --- a/scanpy/preprocessing/_highly_variable_genes.py +++ b/scanpy/preprocessing/_highly_variable_genes.py @@ -213,7 +213,8 @@ def validate( if p.name in cutoffs } if {k: v for k, v in locals().items() if k in cutoffs} != defaults: - logg.info("If you pass `n_top_genes`, all cutoffs are ignored.") + msg = "If you pass `n_top_genes`, all cutoffs are ignored." + warnings.warn(msg, UserWarning) return n_top_genes def in_bounds( diff --git a/scanpy/tests/test_highly_variable_genes.py b/scanpy/tests/test_highly_variable_genes.py index 166a467ad1..1a94531a3c 100644 --- a/scanpy/tests/test_highly_variable_genes.py +++ b/scanpy/tests/test_highly_variable_genes.py @@ -566,14 +566,12 @@ def test_cellranger_n_top_genes_warning(): sc.pp.highly_variable_genes(adata, n_top_genes=1000, flavor="cell_ranger") -def test_cutoff_info(caplog: pytest.LogCaptureFixture): +def test_cutoff_info(): adata = pbmc3k()[:200].copy() sc.pp.normalize_total(adata) sc.pp.log1p(adata) - caplog.clear() - with sc.settings.verbosity.override(sc.Verbosity.info): + with pytest.warns(UserWarning, match="pass `n_top_genes`, all cutoffs are ignored"): sc.pp.highly_variable_genes(adata, n_top_genes=10, max_mean=3.1) - assert "If you pass `n_top_genes`, all cutoffs are ignored." in caplog.messages @pytest.mark.parametrize("flavor", ["seurat", "cell_ranger"]) From 3d18c0741deec22d36cf95a1629e447e67a2d94c Mon Sep 17 00:00:00 2001 From: Phil Schaf Date: Mon, 12 Feb 2024 17:13:15 +0100 Subject: [PATCH 49/50] Fix handler stuf --- scanpy/logging.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/scanpy/logging.py b/scanpy/logging.py index 88f501d34e..6111b06483 100644 --- a/scanpy/logging.py +++ b/scanpy/logging.py @@ -76,17 +76,15 @@ def _set_log_file(settings: ScanpyConfig): h = logging.StreamHandler(file) if name is None else logging.FileHandler(name) h.setFormatter(_LogFormatter()) h.setLevel(root.level) - if len(root.handlers) == 1: - root.removeHandler(root.handlers[0]) - elif len(root.handlers) > 1: - raise RuntimeError("Scanpy’s root logger somehow got more than one handler") + for handler in list(root.handlers): + root.removeHandler(handler) root.addHandler(h) def _set_log_level(settings: ScanpyConfig, level: int): root = settings._root_logger root.setLevel(level) - for h in root.handlers: + for h in list(root.handlers): h.setLevel(level) From c0c97569fd758de8c839d44ddcd2399371c3b52e Mon Sep 17 00:00:00 2001 From: Philipp A Date: Thu, 15 Feb 2024 09:57:21 +0100 Subject: [PATCH 50/50] Discard changes to scanpy/tests/conftest.py --- scanpy/tests/conftest.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/scanpy/tests/conftest.py b/scanpy/tests/conftest.py index 71c58d0981..4c92245f27 100644 --- a/scanpy/tests/conftest.py +++ b/scanpy/tests/conftest.py @@ -55,15 +55,6 @@ def _caplog_adapter(caplog: pytest.LogCaptureFixture) -> Generator[None, None, N sc.settings._root_logger.removeHandler(caplog.handler) -@pytest.fixture(autouse=True) -def _caplog_adapter(caplog: pytest.LogCaptureFixture): - import scanpy as sc - - sc.settings._root_logger.addHandler(caplog.handler) - yield - sc.settings._root_logger.removeHandler(caplog.handler) - - @pytest.fixture def imported_modules(): return IMPORTED