diff --git a/docs/release-notes/1.10.0.md b/docs/release-notes/1.10.0.md index 323ad85d64..37a6adf97f 100644 --- a/docs/release-notes/1.10.0.md +++ b/docs/release-notes/1.10.0.md @@ -17,6 +17,7 @@ now support a `mask` parameter {pr}`2272` {smaller}`C Bright, T Marcella, & P Angerer` * {func}`scanpy.tl.rank_genes_groups` no longer warns that it's default was changed from t-test_overestim_var to t-test {pr}`2798` {smaller}`L Heumos` * {func}`scanpy.pp.highly_variable_genes` has new flavor `seurat_v3_paper` that is in its implementation consistent with the paper description in Stuart et al 2018. {pr}`2792` {smaller}`E Roellin` +* {func}`scanpy.pp.highly_variable_genes` supports dask for the default `seurat` and `cell_ranger` flavors {pr}`2809` {smaller}`P Angerer` ```{rubric} Docs ``` diff --git a/pyproject.toml b/pyproject.toml index a8afbac042..bd5e2b5215 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -120,6 +120,7 @@ doc = [ "ipython>=7.20", # for nbsphinx code highlighting "matplotlib!=3.6.1", # TODO: remove necessity for being able to import doc-linked classes + "dask", "scanpy[paga]", "sam-algorithm", ] diff --git a/scanpy/_compat.py b/scanpy/_compat.py index 08fd53bd03..a89e8d20cf 100644 --- a/scanpy/_compat.py +++ b/scanpy/_compat.py @@ -30,7 +30,14 @@ class ZappyArray: pass -__all__ = ["cache", "DaskArray", "fullname", "pkg_metadata", "pkg_version"] +__all__ = [ + "cache", + "DaskArray", + "ZappyArray", + "fullname", + "pkg_metadata", + "pkg_version", +] def fullname(typ: type) -> str: diff --git a/scanpy/preprocessing/_deprecated/highly_variable_genes.py b/scanpy/preprocessing/_deprecated/highly_variable_genes.py index da250c688e..eaa00c754a 100644 --- a/scanpy/preprocessing/_deprecated/highly_variable_genes.py +++ b/scanpy/preprocessing/_deprecated/highly_variable_genes.py @@ -111,7 +111,8 @@ def filter_genes_dispersion( # noqa: PLR0917 if n_top_genes is not None and not all( x is None for x in [min_disp, max_disp, min_mean, max_mean] ): - 
logg.info("If you pass `n_top_genes`, all cutoffs are ignored.") + msg = "If you pass `n_top_genes`, all cutoffs are ignored." + warnings.warn(msg, UserWarning) if min_disp is None: min_disp = 0.5 if min_mean is None: diff --git a/scanpy/preprocessing/_distributed.py b/scanpy/preprocessing/_distributed.py index 748ec3d671..91db7e7149 100644 --- a/scanpy/preprocessing/_distributed.py +++ b/scanpy/preprocessing/_distributed.py @@ -37,7 +37,7 @@ def materialize_as_ndarray( def materialize_as_ndarray( a: ArrayLike | tuple[ArrayLike | ZappyArray | DaskArray, ...], ) -> tuple[np.ndarray] | np.ndarray: - """Convert distributed arrays to ndarrays.""" + """Compute distributed arrays and convert them to numpy ndarrays.""" if not isinstance(a, tuple): return np.asarray(a) diff --git a/scanpy/preprocessing/_highly_variable_genes.py b/scanpy/preprocessing/_highly_variable_genes.py index 3366b93970..e6f925a168 100644 --- a/scanpy/preprocessing/_highly_variable_genes.py +++ b/scanpy/preprocessing/_highly_variable_genes.py @@ -1,8 +1,9 @@ from __future__ import annotations import warnings +from dataclasses import dataclass from inspect import signature -from typing import Literal, cast +from typing import TYPE_CHECKING, Literal, cast import numpy as np import pandas as pd @@ -10,13 +11,17 @@ from anndata import AnnData from .. 
import logging as logg -from .._compat import old_positionals +from .._compat import DaskArray, old_positionals from .._settings import Verbosity, settings from .._utils import check_nonnegative_integers, sanitize_anndata +from ..get import _get_obs_rep from ._distributed import materialize_as_ndarray from ._simple import filter_genes from ._utils import _get_mean_var +if TYPE_CHECKING: + from numpy.typing import NDArray + def _highly_variable_genes_seurat_v3( adata: AnnData, @@ -61,15 +66,15 @@ def _highly_variable_genes_seurat_v3( "Please install skmisc package via `pip install --user scikit-misc" ) df = pd.DataFrame(index=adata.var_names) - X = adata.layers[layer] if layer is not None else adata.X + data = _get_obs_rep(adata, layer=layer) - if check_values and not check_nonnegative_integers(X): + if check_values and not check_nonnegative_integers(data): warnings.warn( f"`flavor='{flavor}'` expects raw count data, but non-integers were found.", UserWarning, ) - df["means"], df["variances"] = _get_mean_var(X) + df["means"], df["variances"] = _get_mean_var(data) if batch_key is None: batch_info = pd.Categorical(np.zeros(adata.shape[0], dtype=int)) @@ -78,11 +83,11 @@ def _highly_variable_genes_seurat_v3( norm_gene_vars = [] for b in np.unique(batch_info): - X_batch = X[batch_info == b] + data_batch = data[batch_info == b] - mean, var = _get_mean_var(X_batch) + mean, var = _get_mean_var(data_batch) not_const = var > 0 - estimat_var = np.zeros(X.shape[1], dtype=np.float64) + estimat_var = np.zeros(data.shape[1], dtype=np.float64) y = np.log10(var[not_const]) x = np.log10(mean[not_const]) @@ -91,9 +96,9 @@ def _highly_variable_genes_seurat_v3( estimat_var[not_const] = model.outputs.fitted_values reg_std = np.sqrt(10**estimat_var) - batch_counts = X_batch.astype(np.float64).copy() + batch_counts = data_batch.astype(np.float64).copy() # clip large values as in Seurat - N = X_batch.shape[0] + N = data_batch.shape[0] vmax = np.sqrt(N) clip_val = reg_std * vmax + mean if 
sp_sparse.issparse(batch_counts): @@ -186,15 +191,55 @@ def _highly_variable_genes_seurat_v3( return df +@dataclass +class _Cutoffs: + min_disp: float + max_disp: float + min_mean: float + max_mean: float + + @classmethod + def validate( + cls, + *, + n_top_genes: int | None, + min_disp: float, + max_disp: float, + min_mean: float, + max_mean: float, + ) -> _Cutoffs | int: + if n_top_genes is None: + return cls(min_disp, max_disp, min_mean, max_mean) + + cutoffs = {"min_disp", "max_disp", "min_mean", "max_mean"} + defaults = { + p.name: p.default + for p in signature(highly_variable_genes).parameters.values() + if p.name in cutoffs + } + if {k: v for k, v in locals().items() if k in cutoffs} != defaults: + msg = "If you pass `n_top_genes`, all cutoffs are ignored." + warnings.warn(msg, UserWarning) + return n_top_genes + + def in_bounds( + self, + mean: NDArray[np.floating] | DaskArray, + dispersion_norm: NDArray[np.floating] | DaskArray, + ) -> NDArray[np.bool_] | DaskArray: + return ( + (mean > self.min_mean) + & (mean < self.max_mean) + & (dispersion_norm > self.min_disp) + & (dispersion_norm < self.max_disp) + ) + + def _highly_variable_genes_single_batch( adata: AnnData, *, layer: str | None = None, - min_disp: float | None = 0.5, - max_disp: float | None = np.inf, - min_mean: float | None = 0.0125, - max_mean: float | None = 3, - n_top_genes: int | None = None, + cutoff: _Cutoffs | int, n_bins: int = 20, flavor: Literal["seurat", "cell_ranger"] = "seurat", ) -> pd.DataFrame: @@ -206,18 +251,18 @@ def _highly_variable_genes_single_batch( A DataFrame that contains the columns `highly_variable`, `means`, `dispersions`, and `dispersions_norm`. """ - X = adata.layers[layer] if layer is not None else adata.X + X = _get_obs_rep(adata, layer=layer) if flavor == "seurat": X = X.copy() if "log1p" in adata.uns_keys() and adata.uns["log1p"].get("base") is not None: X *= np.log(adata.uns["log1p"]["base"]) - # use out if possible. 
only possible since we copy X + # use out if possible. only possible since we copy the data matrix if isinstance(X, np.ndarray): np.expm1(X, out=X) else: X = np.expm1(X) - mean, var = materialize_as_ndarray(_get_mean_var(X)) + mean, var = _get_mean_var(X) # now actually compute the dispersion mean[mean == 0] = 1e-12 # set entries equal to zero to small value dispersion = var / mean @@ -225,89 +270,197 @@ def _highly_variable_genes_single_batch( dispersion[dispersion == 0] = np.nan dispersion = np.log(dispersion) mean = np.log1p(mean) + # all of the following quantities are "per-gene" here - df = pd.DataFrame() - df["means"] = mean - df["dispersions"] = dispersion + df = pd.DataFrame( + dict(zip(["means", "dispersions"], materialize_as_ndarray((mean, dispersion)))) + ) + df["mean_bin"] = _get_mean_bins(df["means"], flavor, n_bins) + disp_stats = _get_disp_stats(df, flavor) + + # actually do the normalization + df["dispersions_norm"] = (df["dispersions"] - disp_stats["avg"]) / disp_stats["dev"] + df["highly_variable"] = _subset_genes( + adata, + mean=mean, + dispersion_norm=df["dispersions_norm"].to_numpy(), + cutoff=cutoff, + ) + + df.index = adata.var_names + return df + + +def _get_mean_bins( + means: pd.Series, flavor: Literal["seurat", "cell_ranger"], n_bins: int +) -> pd.Series: if flavor == "seurat": - df["mean_bin"] = pd.cut(df["means"], bins=n_bins) - disp_grouped = df.groupby("mean_bin", observed=True)["dispersions"] - disp_mean_bin = disp_grouped.mean() - disp_std_bin = disp_grouped.std(ddof=1) - # retrieve those genes that have nan std, these are the ones where - # only a single gene fell in the bin and implicitly set them to have - # a normalized disperion of 1 - one_gene_per_bin = disp_std_bin.isnull() - gen_indices = np.where(one_gene_per_bin[df["mean_bin"].to_numpy()])[0].tolist() - if len(gen_indices) > 0: - logg.debug( - f"Gene indices {gen_indices} fell into a single bin: their " - "normalized dispersion was set to 1.\n " - "Decreasing `n_bins` 
will likely avoid this effect." - ) - # Circumvent pandas 0.23 bug. Both sides of the assignment have dtype==float32, - # but there’s still a dtype error without “.value”. - disp_std_bin[one_gene_per_bin.to_numpy()] = disp_mean_bin[ - one_gene_per_bin.to_numpy() - ].to_numpy() - disp_mean_bin[one_gene_per_bin.to_numpy()] = 0 - # actually do the normalization - df["dispersions_norm"] = ( - df["dispersions"].to_numpy() # use values here as index differs - - disp_mean_bin[df["mean_bin"].to_numpy()].to_numpy() - ) / disp_std_bin[df["mean_bin"].to_numpy()].to_numpy() + bins = n_bins elif flavor == "cell_ranger": - from statsmodels import robust + bins = np.r_[-np.inf, np.percentile(means, np.arange(10, 105, 5)), np.inf] + else: + raise ValueError('`flavor` needs to be "seurat" or "cell_ranger"') - df["mean_bin"] = pd.cut( - df["means"], - np.r_[-np.inf, np.percentile(df["means"], np.arange(10, 105, 5)), np.inf], - ) - disp_grouped = df.groupby("mean_bin", observed=True)["dispersions"] - disp_median_bin = disp_grouped.median() - # the next line raises the warning: "Mean of empty slice" - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - disp_mad_bin = disp_grouped.apply(robust.mad) - df["dispersions_norm"] = ( - df["dispersions"].to_numpy() - - disp_median_bin[df["mean_bin"].to_numpy()].to_numpy() - ) / disp_mad_bin[df["mean_bin"].to_numpy()].to_numpy() + return pd.cut(means, bins=bins) + + +def _get_disp_stats( + df: pd.DataFrame, flavor: Literal["seurat", "cell_ranger"] +) -> pd.DataFrame: + disp_grouped = df.groupby("mean_bin", observed=True)["dispersions"] + if flavor == "seurat": + disp_bin_stats = disp_grouped.agg(avg="mean", dev="std") + _postprocess_dispersions_seurat(disp_bin_stats, df["mean_bin"]) + elif flavor == "cell_ranger": + disp_bin_stats = disp_grouped.agg(avg="median", dev=_mad) else: raise ValueError('`flavor` needs to be "seurat" or "cell_ranger"') - dispersion_norm = df["dispersions_norm"].to_numpy() - if n_top_genes is not None: - 
dispersion_norm = dispersion_norm[~np.isnan(dispersion_norm)] - dispersion_norm[ - ::-1 - ].sort() # interestingly, np.argpartition is slightly slower - if n_top_genes > adata.n_vars: - logg.info("`n_top_genes` > `adata.n_var`, returning all genes.") - n_top_genes = adata.n_vars - if n_top_genes > dispersion_norm.size: - warnings.warn( - "`n_top_genes` > number of normalized dispersions, returning all genes with normalized dispersions.", - UserWarning, + return disp_bin_stats.loc[df["mean_bin"]].set_index(df.index) + + +def _postprocess_dispersions_seurat( + disp_bin_stats: pd.DataFrame, mean_bin: pd.Series +) -> None: + # retrieve those genes that have nan std, these are the ones where + # only a single gene fell in the bin and implicitly set them to have + # a normalized dispersion of 1 + one_gene_per_bin = disp_bin_stats["dev"].isnull() + gen_indices = np.flatnonzero(one_gene_per_bin.loc[mean_bin]) + if len(gen_indices) == 0: + return + logg.debug( + f"Gene indices {gen_indices} fell into a single bin: their " + "normalized dispersion was set to 1.\n " + "Decreasing `n_bins` will likely avoid this effect." 
+ ) + disp_bin_stats.loc[one_gene_per_bin, "dev"] = disp_bin_stats.loc[ + one_gene_per_bin, "avg" + ] + disp_bin_stats.loc[one_gene_per_bin, "avg"] = 0 + + +def _mad(a): + from statsmodels.robust import mad + + with warnings.catch_warnings(): + # MAD calculation raises the warning: "Mean of empty slice" + warnings.simplefilter("ignore", category=RuntimeWarning) + return mad(a) + + +def _subset_genes( + adata: AnnData, + *, + mean: NDArray[np.float64] | DaskArray, + dispersion_norm: NDArray[np.float64] | DaskArray, + cutoff: _Cutoffs | int, +) -> NDArray[np.bool_] | DaskArray: + """Get boolean mask of genes with normalized dispersion in bounds.""" + if isinstance(cutoff, _Cutoffs): + dispersion_norm[np.isnan(dispersion_norm)] = 0 # similar to Seurat + return cutoff.in_bounds(mean, dispersion_norm) + n_top_genes = cutoff + del cutoff + + if n_top_genes > adata.n_vars: + logg.info("`n_top_genes` > `adata.n_var`, returning all genes.") + n_top_genes = adata.n_vars + disp_cut_off = _nth_highest(dispersion_norm, n_top_genes) + logg.debug( + f"the {n_top_genes} top genes correspond to a " + f"normalized dispersion cutoff of {disp_cut_off}" + ) + return np.nan_to_num(dispersion_norm) >= disp_cut_off + + +def _nth_highest(x: NDArray[np.float64] | DaskArray, n: int) -> float | DaskArray: + x = x[~np.isnan(x)] + if n > x.size: + msg = "`n_top_genes` > number of normalized dispersions, returning all genes with normalized dispersions." 
+ warnings.warn(msg, UserWarning) + n = x.size + if isinstance(x, DaskArray): + return x.topk(n)[-1] + # interestingly, np.argpartition is slightly slower + x[::-1].sort() + return x[n - 1] + + +def _highly_variable_genes_batched( + adata: AnnData, + batch_key: str, + *, + layer: str | None, + n_bins: int, + flavor: Literal["seurat", "cell_ranger"], + cutoff: _Cutoffs | int, +) -> pd.DataFrame: + sanitize_anndata(adata) + batches = adata.obs[batch_key].cat.categories + dfs = [] + gene_list = adata.var_names + for batch in batches: + adata_subset = adata[adata.obs[batch_key] == batch] + + # Filter to genes that are in the dataset + with settings.verbosity.override(Verbosity.error): + # TODO use groupby or so instead of materialize_as_ndarray + filt, _ = materialize_as_ndarray( + filter_genes( + _get_obs_rep(adata_subset, layer=layer), + min_cells=1, + inplace=False, + ) ) - n_top_genes = dispersion_norm.size - disp_cut_off = dispersion_norm[n_top_genes - 1] - gene_subset = np.nan_to_num(df["dispersions_norm"].to_numpy()) >= disp_cut_off - logg.debug( - f"the {n_top_genes} top genes correspond to a " - f"normalized dispersion cutoff of {disp_cut_off}" + + adata_subset = adata_subset[:, filt] + + hvg = _highly_variable_genes_single_batch( + adata_subset, layer=layer, cutoff=cutoff, n_bins=n_bins, flavor=flavor ) - else: - dispersion_norm[np.isnan(dispersion_norm)] = 0 # similar to Seurat - gene_subset = np.logical_and.reduce( - ( - mean > min_mean, - mean < max_mean, - dispersion_norm > min_disp, - dispersion_norm < max_disp, + hvg.reset_index(drop=False, inplace=True, names=["gene"]) + + if (n_removed := np.sum(~filt)) > 0: + # Add 0 values for genes that were filtered out + missing_hvg = pd.DataFrame( + np.zeros((n_removed, len(hvg.columns))), + columns=hvg.columns, ) + missing_hvg["highly_variable"] = missing_hvg["highly_variable"].astype(bool) + missing_hvg["gene"] = gene_list[~filt] + hvg = pd.concat([hvg, missing_hvg], ignore_index=True) + + dfs.append(hvg) + + 
df = pd.concat(dfs, axis=0) + + df["highly_variable"] = df["highly_variable"].astype(int) + df = df.groupby("gene", observed=True).agg( + dict( + means="mean", + dispersions="mean", + dispersions_norm="mean", + highly_variable="sum", ) + ) + df["highly_variable_nbatches"] = df["highly_variable"] + df["highly_variable_intersection"] = df["highly_variable_nbatches"] == len(batches) + + if isinstance(cutoff, int): + # sort genes by how often they are selected as hvg within each batch and + # break ties with normalized dispersion across batches + df.sort_values( + ["highly_variable_nbatches", "dispersions_norm"], + ascending=False, + na_position="last", + inplace=True, + ) + df["highly_variable"] = np.arange(df.shape[0]) < cutoff + else: + dispersion_norm = df["dispersions_norm"].to_numpy() + dispersion_norm[np.isnan(dispersion_norm)] = 0 # similar to Seurat + df["highly_variable"] = cutoff.in_bounds(df["means"], df["dispersions_norm"]) - df["highly_variable"] = gene_subset return df @@ -331,10 +484,10 @@ def highly_variable_genes( *, layer: str | None = None, n_top_genes: int | None = None, - min_disp: float | None = 0.5, - max_disp: float | None = np.inf, - min_mean: float | None = 0.0125, - max_mean: float | None = 3, + min_disp: float = 0.5, + max_disp: float = np.inf, + min_mean: float = 0.0125, + max_mean: float = 3, span: float = 0.3, n_bins: int = 20, flavor: Literal["seurat", "cell_ranger", "seurat_v3", "seurat_v3_paper"] = "seurat", @@ -447,20 +600,15 @@ def highly_variable_genes( For `flavor='seurat_v3'`/`'seurat_v3_paper'`, rank of the gene according to normalized variance, in case of multiple batches description above `adata.var['highly_variable_nbatches']` : :class:`pandas.Series` (dtype `int`) - If batch_key is given, this denotes in how many batches genes are detected as HVG + If `batch_key` is given, this denotes in how many batches genes are detected as HVG `adata.var['highly_variable_intersection']` : :class:`pandas.Series` (dtype `bool`) - If batch_key 
is given, this denotes the genes that are highly variable in all batches + If `batch_key` is given, this denotes the genes that are highly variable in all batches Notes ----- This function replaces :func:`~scanpy.pp.filter_genes_dispersion`. """ - if n_top_genes is not None and not all( - m is None for m in [min_disp, max_disp, min_mean, max_mean] - ): - logg.info("If you pass `n_top_genes`, all cutoffs are ignored.") - start = logg.info("extracting highly variable genes") if not isinstance(adata, AnnData): @@ -485,137 +633,49 @@ def highly_variable_genes( inplace=inplace, ) + cutoff = _Cutoffs.validate( + n_top_genes=n_top_genes, + min_disp=min_disp, + max_disp=max_disp, + min_mean=min_mean, + max_mean=max_mean, + ) + del min_disp, max_disp, min_mean, max_mean, n_top_genes + if batch_key is None: df = _highly_variable_genes_single_batch( - adata, - layer=layer, - min_disp=min_disp, - max_disp=max_disp, - min_mean=min_mean, - max_mean=max_mean, - n_top_genes=n_top_genes, - n_bins=n_bins, - flavor=flavor, + adata, layer=layer, cutoff=cutoff, n_bins=n_bins, flavor=flavor ) else: - sanitize_anndata(adata) - batches = adata.obs[batch_key].cat.categories - df = [] - gene_list = adata.var_names - for batch in batches: - adata_subset = adata[adata.obs[batch_key] == batch] - - # Filter to genes that are in the dataset - with settings.verbosity.override(Verbosity.error): - filt = filter_genes(adata_subset, min_cells=1, inplace=False)[0] - - adata_subset = adata_subset[:, filt] - - hvg = _highly_variable_genes_single_batch( - adata_subset, - layer=layer, - min_disp=min_disp, - max_disp=max_disp, - min_mean=min_mean, - max_mean=max_mean, - n_top_genes=n_top_genes, - n_bins=n_bins, - flavor=flavor, - ) - - hvg["gene"] = adata_subset.var_names.to_numpy() - if (n_removed := np.sum(~filt)) > 0: - # Add 0 values for genes that were filtered out - missing_hvg = pd.DataFrame( - np.zeros((n_removed, len(hvg.columns))), - columns=hvg.columns, - ) - missing_hvg["highly_variable"] = 
missing_hvg["highly_variable"].astype( - bool - ) - missing_hvg["gene"] = gene_list[~filt] - hvg = pd.concat([hvg, missing_hvg], ignore_index=True) - - # Order as before filtering - idxs = np.concatenate((np.where(filt)[0], np.where(~filt)[0])) - hvg = hvg.loc[np.argsort(idxs)] - - df.append(hvg) - - df = pd.concat(df, axis=0) - df["highly_variable"] = df["highly_variable"].astype(int) - df = df.groupby("gene", observed=True).agg( - dict( - means="mean", - dispersions="mean", - dispersions_norm="mean", - highly_variable="sum", - ) - ) - df.rename( - columns=dict(highly_variable="highly_variable_nbatches"), inplace=True - ) - df["highly_variable_intersection"] = df["highly_variable_nbatches"] == len( - batches + df = _highly_variable_genes_batched( + adata, batch_key, layer=layer, cutoff=cutoff, n_bins=n_bins, flavor=flavor ) - if n_top_genes is not None: - # sort genes by how often they selected as hvg within each batch and - # break ties with normalized dispersion across batches - df.sort_values( - ["highly_variable_nbatches", "dispersions_norm"], - ascending=False, - na_position="last", - inplace=True, - ) - high_var = np.zeros(df.shape[0]) - high_var[:n_top_genes] = True - df["highly_variable"] = high_var.astype(bool) - df = df.loc[adata.var_names, :] - else: - df = df.loc[adata.var_names] - dispersion_norm = df["dispersions_norm"].to_numpy() - dispersion_norm[np.isnan(dispersion_norm)] = 0 # similar to Seurat - gene_subset = np.logical_and.reduce( - ( - df["means"] > min_mean, - df["means"] < max_mean, - df["dispersions_norm"] > min_disp, - df["dispersions_norm"] < max_disp, - ) - ) - df["highly_variable"] = gene_subset - logg.info(" finished", time=start) - if inplace: - adata.uns["hvg"] = {"flavor": flavor} - logg.hint( - "added\n" - " 'highly_variable', boolean vector (adata.var)\n" - " 'means', float vector (adata.var)\n" - " 'dispersions', float vector (adata.var)\n" - " 'dispersions_norm', float vector (adata.var)" - ) - adata.var["highly_variable"] = 
df["highly_variable"].to_numpy() - adata.var["means"] = df["means"].to_numpy() - adata.var["dispersions"] = df["dispersions"].to_numpy() - adata.var["dispersions_norm"] = ( - df["dispersions_norm"].to_numpy().astype("float32", copy=False) - ) - - if batch_key is not None: - adata.var["highly_variable_nbatches"] = df[ - "highly_variable_nbatches" - ].to_numpy() - adata.var["highly_variable_intersection"] = df[ - "highly_variable_intersection" - ].to_numpy() - if subset: - adata._inplace_subset_var(df["highly_variable"].to_numpy()) - - else: + if not inplace: if subset: df = df.loc[df["highly_variable"]] return df + + adata.uns["hvg"] = {"flavor": flavor} + logg.hint( + "added\n" + " 'highly_variable', boolean vector (adata.var)\n" + " 'means', float vector (adata.var)\n" + " 'dispersions', float vector (adata.var)\n" + " 'dispersions_norm', float vector (adata.var)" + ) + adata.var["highly_variable"] = df["highly_variable"] + adata.var["means"] = df["means"] + adata.var["dispersions"] = df["dispersions"] + adata.var["dispersions_norm"] = df["dispersions_norm"].astype( + np.float32, copy=False + ) + + if batch_key is not None: + adata.var["highly_variable_nbatches"] = df["highly_variable_nbatches"] + adata.var["highly_variable_intersection"] = df["highly_variable_intersection"] + if subset: + adata._inplace_subset_var(df["highly_variable"]) diff --git a/scanpy/tests/test_highly_variable_genes.py b/scanpy/tests/test_highly_variable_genes.py index 598bc5617f..27b237ad5b 100644 --- a/scanpy/tests/test_highly_variable_genes.py +++ b/scanpy/tests/test_highly_variable_genes.py @@ -1,17 +1,21 @@ from __future__ import annotations from pathlib import Path +from string import ascii_letters from typing import Literal import numpy as np import pandas as pd import pytest +from anndata import AnnData +from pandas.testing import assert_frame_equal, assert_index_equal from scipy import sparse import scanpy as sc from scanpy.testing._helpers import _check_check_values_warnings from 
scanpy.testing._helpers.data import pbmc3k, pbmc68k_reduced from scanpy.testing._pytest.marks import needs +from scanpy.testing._pytest.params import ARRAY_TYPES_SUPPORTED FILE = Path(__file__).parent / Path("_scripts/seurat_hvg.csv") FILE_V3 = Path(__file__).parent / Path("_scripts/seurat_hvg_v3.csv.gz") @@ -19,13 +23,24 @@ FILE_CELL_RANGER = Path(__file__).parent / "_scripts/cell_ranger_hvg.csv" -def test_highly_variable_genes_runs(): +@pytest.fixture(scope="session") +def adata_sess() -> AnnData: adata = sc.datasets.blobs() + rng = np.random.default_rng(0) + adata.var_names = rng.choice(list(ascii_letters), adata.n_vars, replace=False) + return adata + + +@pytest.fixture +def adata(adata_sess: AnnData) -> AnnData: + return adata_sess.copy() + + +def test_runs(adata): sc.pp.highly_variable_genes(adata) -def test_highly_variable_genes_supports_batch(): - adata = sc.datasets.blobs() +def test_supports_batch(adata): gen = np.random.default_rng(0) adata.obs["batch"] = pd.array( gen.binomial(3, 0.5, size=adata.n_obs), dtype="category" @@ -35,39 +50,33 @@ def test_highly_variable_genes_supports_batch(): assert "highly_variable_intersection" in adata.var.columns -def test_highly_variable_genes_supports_layers(): - adata = sc.datasets.blobs() - gen = np.random.default_rng(0) - adata.obs["batch"] = pd.array( - gen.binomial(4, 0.5, size=adata.n_obs), dtype="category" - ) - sc.pp.highly_variable_genes(adata, batch_key="batch", n_top_genes=3) - assert "highly_variable_nbatches" in adata.var.columns - assert adata.var["highly_variable"].sum() == 3 - highly_var_first_layer = adata.var["highly_variable"] +def test_supports_layers(adata_sess): + def execute(layer: str | None) -> AnnData: + gen = np.random.default_rng(0) + adata = adata_sess.copy() + assert isinstance(adata.X, np.ndarray) + if layer: + adata.X, adata.layers[layer] = None, adata.X.copy() + gen.shuffle(adata.layers[layer]) + adata.obs["batch"] = pd.array( + gen.binomial(4, 0.5, size=adata.n_obs), dtype="category" + 
) + sc.pp.highly_variable_genes( + adata, batch_key="batch", n_top_genes=3, layer=layer + ) + assert "highly_variable_nbatches" in adata.var.columns + assert adata.var["highly_variable"].sum() == 3 + return adata - adata = sc.datasets.blobs() - assert isinstance(adata.X, np.ndarray) - new_layer = adata.X.copy() - gen.shuffle(new_layer) - adata.layers["test_layer"] = new_layer - adata.obs["batch"] = gen.binomial(4, 0.5, size=(adata.n_obs)) - adata.obs["batch"] = adata.obs["batch"].astype("category") - sc.pp.highly_variable_genes( - adata, batch_key="batch", n_top_genes=3, layer="test_layer" - ) - assert "highly_variable_nbatches" in adata.var.columns - assert adata.var["highly_variable"].sum() == 3 - assert (highly_var_first_layer != adata.var["highly_variable"]).any() + adata1, adata2 = map(execute, [None, "test_layer"]) + assert (adata1.var["highly_variable"] != adata2.var["highly_variable"]).any() -def test_highly_variable_genes_no_batch_matches_batch(): - adata = sc.datasets.blobs() +def test_no_batch_matches_batch(adata): sc.pp.highly_variable_genes(adata) no_batch_hvg = adata.var["highly_variable"].copy() assert no_batch_hvg.any() - adata.obs["batch"] = "batch" - adata.obs["batch"] = adata.obs["batch"].astype("category") + adata.obs["batch"] = pd.array(["batch"], dtype="category").repeat(len(adata)) sc.pp.highly_variable_genes(adata, batch_key="batch") assert np.all(no_batch_hvg == adata.var["highly_variable"]) assert np.all( @@ -75,28 +84,31 @@ def test_highly_variable_genes_no_batch_matches_batch(): ) -def test_highly_variable_genes_(): - adata = sc.datasets.blobs() - adata.obs["batch"] = np.tile(["a", "b"], adata.shape[0] // 2) - sc.pp.highly_variable_genes(adata, batch_key="batch") +@pytest.mark.parametrize("batch_key", [None, "batch"], ids=["single", "batched"]) +@pytest.mark.parametrize("array_type", ARRAY_TYPES_SUPPORTED) +def test_no_inplace(adata, array_type, batch_key): + """Tests that, with `n_top_genes=None` the returned dataframe has the expected 
columns.""" + adata.X = array_type(adata.X) + if batch_key: + adata.obs[batch_key] = np.tile(["a", "b"], adata.shape[0] // 2) + sc.pp.highly_variable_genes(adata, batch_key=batch_key, n_bins=3) assert adata.var["highly_variable"].any() - colnames = [ - "means", - "dispersions", - "dispersions_norm", - "highly_variable_nbatches", - "highly_variable_intersection", - "highly_variable", - ] - hvg_df = sc.pp.highly_variable_genes(adata, batch_key="batch", inplace=False) - assert hvg_df is not None - assert np.all(np.isin(colnames, hvg_df.columns)) + colnames = {"means", "dispersions", "dispersions_norm", "highly_variable"} | ( + {"mean_bin"} + if batch_key is None + else {"highly_variable_nbatches", "highly_variable_intersection"} + ) + hvg_df = sc.pp.highly_variable_genes( + adata, batch_key=batch_key, n_bins=3, inplace=False + ) + assert isinstance(hvg_df, pd.DataFrame) + assert colnames == set(hvg_df.columns) @pytest.mark.parametrize("base", [None, 10]) @pytest.mark.parametrize("flavor", ["seurat", "cell_ranger"]) -def test_highly_variable_genes_keep_layer(base, flavor): +def test_keep_layer(base, flavor): adata = pbmc3k() # cell_ranger flavor can raise error if many 0 genes sc.pp.filter_genes(adata, min_counts=1) @@ -124,7 +136,7 @@ def _check_pearson_hvg_columns(output_df: pd.DataFrame, n_top_genes: int): assert np.nanmax(output_df["highly_variable_rank"].to_numpy()) <= n_top_genes - 1 -def test_highly_variable_genes_pearson_residuals_inputchecks(pbmc3k_parametrized_small): +def test_pearson_residuals_inputchecks(pbmc3k_parametrized_small): adata = pbmc3k_parametrized_small() # depending on check_values, warnings should be raised for non-integer data @@ -164,7 +176,7 @@ def test_highly_variable_genes_pearson_residuals_inputchecks(pbmc3k_parametrized ) @pytest.mark.parametrize("theta", [100, np.Inf], ids=["100theta", "inftheta"]) @pytest.mark.parametrize("n_top_genes", [100, 200], ids=["100n", "200n"]) -def test_highly_variable_genes_pearson_residuals_general( +def 
test_pearson_residuals_general( pbmc3k_parametrized_small, subset, clip, theta, n_top_genes ): adata = pbmc3k_parametrized_small() @@ -248,9 +260,7 @@ def test_highly_variable_genes_pearson_residuals_general( @pytest.mark.parametrize("subset", [True, False], ids=["subset", "full"]) @pytest.mark.parametrize("n_top_genes", [100, 200], ids=["100n", "200n"]) -def test_highly_variable_genes_pearson_residuals_batch( - pbmc3k_parametrized_small, subset, n_top_genes -): +def test_pearson_residuals_batch(pbmc3k_parametrized_small, subset, n_top_genes): adata = pbmc3k_parametrized_small() # cleanup var del adata.var @@ -382,7 +392,7 @@ def test_compare_to_upstream( @needs.skmisc -def test_highly_variable_genes_compare_to_seurat_v3(): +def test_compare_to_seurat_v3(): ### test without batch seurat_hvg_info = pd.read_csv(FILE_V3) @@ -451,7 +461,7 @@ def test_highly_variable_genes_compare_to_seurat_v3(): @needs.skmisc -def test_highly_variable_genes_seurat_v3_warning(): +def test_seurat_v3_warning(): pbmc = pbmc3k()[:200].copy() sc.pp.log1p(pbmc) with pytest.warns( @@ -461,13 +471,13 @@ def test_highly_variable_genes_seurat_v3_warning(): sc.pp.highly_variable_genes(pbmc, flavor="seurat_v3") -def test_highly_variable_genes_batches(): +def test_batches(): adata = pbmc68k_reduced() adata[:100, :100].X = np.zeros((100, 100)) adata.obs["batch"] = ["0" if i < 100 else "1" for i in range(adata.n_obs)] - adata_1 = adata[adata.obs.batch.isin(["0"]), :] - adata_2 = adata[adata.obs.batch.isin(["1"]), :] + adata_1 = adata[adata.obs["batch"] == "0"].copy() + adata_2 = adata[adata.obs["batch"] == "1"].copy() sc.pp.highly_variable_genes( adata, @@ -538,7 +548,7 @@ def test_seurat_v3_mean_var_output_with_batchkey(): def test_cellranger_n_top_genes_warning(): X = np.random.poisson(2, (100, 30)) - adata = sc.AnnData(X) + adata = AnnData(X) sc.pp.normalize_total(adata) sc.pp.log1p(adata) @@ -549,12 +559,26 @@ def test_cellranger_n_top_genes_warning(): sc.pp.highly_variable_genes(adata, 
n_top_genes=1000, flavor="cell_ranger") +def test_cutoff_info(): + adata = pbmc3k()[:200].copy() + sc.pp.normalize_total(adata) + sc.pp.log1p(adata) + with pytest.warns(UserWarning, match="pass `n_top_genes`, all cutoffs are ignored"): + sc.pp.highly_variable_genes(adata, n_top_genes=10, max_mean=3.1) + + @pytest.mark.parametrize("flavor", ["seurat", "cell_ranger"]) +@pytest.mark.parametrize("array_type", ARRAY_TYPES_SUPPORTED) @pytest.mark.parametrize("subset", [True, False], ids=["subset", "full"]) @pytest.mark.parametrize("inplace", [True, False], ids=["inplace", "copy"]) -def test_highly_variable_genes_subset_inplace_consistency(flavor, subset, inplace): +def test_subset_inplace_consistency(flavor, array_type, subset, inplace): + """Tests that, with `n_top_genes=n` + - `inplace` and `subset` interact correctly + - for both the `seurat` and `cell_ranger` flavors + - for dask arrays and non-dask arrays + """ adata = sc.datasets.blobs(n_observations=20, n_variables=80, random_state=0) - adata.X = np.abs(adata.X).astype(int) + adata.X = array_type(np.abs(adata.X).astype(int)) if flavor == "seurat" or flavor == "cell_ranger": sc.pp.normalize_total(adata, target_sum=1e4) @@ -578,3 +602,34 @@ def test_highly_variable_genes_subset_inplace_consistency(flavor, subset, inplac assert (output_df is None) == inplace assert len(adata.var if inplace else output_df) == (15 if subset else n_genes) + if output_df is not None: + assert isinstance(output_df, pd.DataFrame) + + +@pytest.mark.parametrize("flavor", ["seurat", "cell_ranger"]) +@pytest.mark.parametrize("batch_key", [None, "batch"], ids=["single", "batched"]) +@pytest.mark.parametrize( + "to_dask", [p for p in ARRAY_TYPES_SUPPORTED if "dask" in p.values[0].__name__] +) +def test_dask_consistency(adata: AnnData, flavor, batch_key, to_dask): + adata.X = np.abs(adata.X).astype(int) + if batch_key is not None: + adata.obs[batch_key] = np.tile(["a", "b"], adata.shape[0] // 2) + sc.pp.normalize_total(adata, target_sum=1e4) + 
sc.pp.log1p(adata) + + adata_dask = adata.copy() + adata_dask.X = to_dask(adata_dask.X) + + output_mem, output_dask = ( + sc.pp.highly_variable_genes(ad, flavor=flavor, n_top_genes=15, inplace=False) + for ad in [adata, adata_dask] + ) + + assert isinstance(output_mem, pd.DataFrame) + assert isinstance(output_dask, pd.DataFrame) + + assert_index_equal(adata.var_names, output_mem.index, check_names=False) + assert_index_equal(adata.var_names, output_dask.index, check_names=False) + + assert_frame_equal(output_mem, output_dask)