Skip to content

normalize_total with numba #3571

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/release-notes/3571.performance.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Speed up `pp.normalize_total` with a numba kernel for `csr-matrices` {smaller}`S Dicks`
227 changes: 137 additions & 90 deletions src/scanpy/preprocessing/_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
from typing import TYPE_CHECKING
from warnings import warn

import numba
import numpy as np

from .. import logging as logg
from .._compat import CSBase, DaskArray, old_positionals
from .._compat import CSBase, CSCBase, DaskArray, njit, old_positionals
from .._utils import axis_mul_or_truediv, axis_sum, view_to_actual
from ..get import _get_obs_rep, _set_obs_rep

Expand All @@ -19,9 +20,6 @@
dask = None

if TYPE_CHECKING:
from collections.abc import Iterable
from typing import Literal

from anndata import AnnData


Expand All @@ -34,40 +32,129 @@ def _compute_nnz_median(counts: np.ndarray | DaskArray) -> np.floating:
return median


def _normalize_data(X, counts, after=None, *, copy: bool = False):
X = X.copy() if copy else X
if issubclass(X.dtype.type, int | np.integer):
X = X.astype(np.float32) # TODO: Check if float64 should be used
if after is None:
after = _compute_nnz_median(counts)
counts = counts / after
out = X if isinstance(X, np.ndarray | CSBase) else None
return axis_mul_or_truediv(
X, counts, op=truediv, out=out, allow_divide_by_zero=False, axis=0
@njit
def _normalize_csr(
    indptr,
    indices,
    data,
    *,
    rows,
    columns,
    exclude_highly_expressed: bool = False,
    max_fraction: float = 0.05,
    n_threads: int = 10,
):
    """Compute per-row totals (and flagged columns) for a CSR matrix.

    Parameters
    ----------
    indptr, indices, data
        The three arrays of a CSR matrix: row ``i`` owns the stored entries
        ``data[indptr[i]:indptr[i + 1]]`` whose column ids are the matching
        slice of ``indices``.
    rows, columns
        Number of rows and columns of the matrix.
    exclude_highly_expressed
        If true, columns holding at least one entry larger than
        ``max_fraction`` times its row total are flagged, and the row totals
        are recomputed without the flagged columns.
    max_fraction
        Fraction of a row total above which a stored entry flags its column.
    n_threads
        Number of row partitions used while flagging columns; one scratch
        counter row of length ``columns`` is allocated per partition.

    Returns
    -------
    counts_per_cell
        One total per row, with the dtype of ``data``. Flagged columns are
        left out of the total when ``exclude_highly_expressed`` is true.
    counts_per_cols
        ``int32`` array of length ``columns`` counting, per column, the
        entries that exceeded the threshold. All zeros when
        ``exclude_highly_expressed`` is false.
    """
    # NOTE(review): the ``prange`` loops only run in parallel if the ``njit``
    # wrapper this module imports enables numba's parallel mode - confirm.
    counts_per_cell = np.zeros(rows, dtype=data.dtype)
    # Allocated unconditionally so the return statement below always has a
    # defined array to hand back, whichever branch is taken.
    counts_per_cols = np.zeros(columns, dtype=np.int32)

    # Pass 1: plain row sums over the stored entries.
    for i in numba.prange(rows):
        count = 0.0
        for j in range(indptr[i], indptr[i + 1]):
            count += data[j]
        counts_per_cell[i] = count

    if exclude_highly_expressed:
        # Pass 2: partition ``i`` visits rows i, i + n_threads, ... and only
        # ever writes to scratch row ``i``, so no two iterations of the outer
        # loop update the same element.
        counts_per_cols_t = np.zeros((n_threads, columns), dtype=np.int32)
        for i in numba.prange(n_threads):
            for r in range(i, rows, n_threads):
                for j in range(indptr[r], indptr[r + 1]):
                    if data[j] > max_fraction * counts_per_cell[r]:
                        minor_index = indices[j]
                        counts_per_cols_t[i, minor_index] += 1

        # Reduce the per-partition scratch rows into a single counter.
        for c in numba.prange(columns):
            counts_per_cols[c] = counts_per_cols_t[:, c].sum()

        # Pass 3: redo the row sums, skipping every flagged column.
        for i in numba.prange(rows):
            count = 0.0
            for j in range(indptr[i], indptr[i + 1]):
                if counts_per_cols[indices[j]] == 0:
                    count += data[j]
            counts_per_cell[i] = count

    return counts_per_cell, counts_per_cols


def _normalize_total_helper(
    x: np.ndarray | CSBase | DaskArray,
    *,
    exclude_highly_expressed: bool,
    max_fraction: float,
    target_sum: float | None,
) -> tuple[np.ndarray | CSBase | DaskArray, np.ndarray, np.ndarray | None]:
    """Calculate the normalized data, counts per cell, and gene subset.

    Parameters
    ----------
    x
        Data matrix with one row per cell and one column per gene.
        NOTE(review): every sparse input is handed to the CSR kernel, which
        reads ``indptr``/``indices``/``data`` in CSR layout - presumably the
        caller converts CSC input beforehand; confirm.
    exclude_highly_expressed
        If true, genes that exceed ``max_fraction`` of the total counts of at
        least one cell are left out when computing the per-cell totals.
    max_fraction
        Fraction of a cell's total counts above which a gene is treated as
        highly expressed.
    target_sum
        Total each cell is scaled to. If `None`, `_compute_nnz_median` of the
        per-cell totals is used.

    Returns
    -------
    X
        The normalized data matrix.
    counts_per_cell
        The normalization factors used for each cell (counts / target_sum).
    gene_subset
        If `exclude_highly_expressed=True`, a boolean mask indicating which genes
        were not considered highly expressed. Otherwise, `None`.
    """
    gene_subset = None
    if isinstance(x, CSBase):
        counts_per_cell, counts_per_cols = _normalize_csr(
            x.indptr,
            x.indices,
            x.data,
            rows=x.shape[0],
            columns=x.shape[1],
            exclude_highly_expressed=exclude_highly_expressed,
            max_fraction=max_fraction,
            n_threads=numba.get_num_threads(),
        )
        if exclude_highly_expressed:
            # A column with a zero counter was never flagged by the kernel.
            # Comparing against zero yields the boolean "kept genes" mask,
            # the same shape and meaning as the mask built in the other branch.
            gene_subset = counts_per_cols == 0
    else:
        counts_per_cell = axis_sum(x, axis=1)
        if exclude_highly_expressed:
            counts_per_cell = np.ravel(counts_per_cell)
            # at least one cell has more than max_fraction of counts per cell
            gene_subset = axis_sum(
                (x > counts_per_cell[:, None] * max_fraction), axis=0
            )
            gene_subset = np.asarray(np.ravel(gene_subset) == 0)
            counts_per_cell = axis_sum(x[:, gene_subset], axis=1)
        counts_per_cell = np.ravel(counts_per_cell)

    if target_sum is None:
        # Use one median helper for every input type so that sparse and
        # dense input receive the same default target.
        target_sum = _compute_nnz_median(counts_per_cell)

    counts_per_cell = counts_per_cell / target_sum
    # Only numpy arrays and sparse matrices are passed as ``out``;
    # everything else gets ``out=None``.
    out = x if isinstance(x, np.ndarray | CSBase) else None
    X = axis_mul_or_truediv(
        x, counts_per_cell, op=truediv, out=out, allow_divide_by_zero=False, axis=0
    )

    return X, counts_per_cell, gene_subset


@old_positionals(
"target_sum",
"exclude_highly_expressed",
"max_fraction",
"key_added",
"layer",
"layers",
"layer_norm",
"inplace",
"copy",
)
def normalize_total( # noqa: PLR0912, PLR0915
def normalize_total( # noqa: PLR0912
adata: AnnData,
*,
target_sum: float | None = None,
exclude_highly_expressed: bool = False,
max_fraction: float = 0.05,
key_added: str | None = None,
layer: str | None = None,
layers: Literal["all"] | Iterable[str] | None = None,
layer_norm: str | None = None,
inplace: bool = True,
copy: bool = False,
) -> AnnData | dict[str, np.ndarray] | None:
Expand All @@ -90,8 +177,8 @@ def normalize_total( # noqa: PLR0912, PLR0915
call functions that trigger `.compute()` on the :class:`~dask.array.Array` if `exclude_highly_expressed`
is `True`, `layer_norm` is not `None`, or if `key_added` is not `None`.

Params
------
Parameters
----------
adata
The annotated data matrix of shape `n_obs` × `n_vars`.
Rows correspond to cells and columns to genes.
Expand Down Expand Up @@ -163,7 +250,8 @@ def normalize_total( # noqa: PLR0912, PLR0915
... max_fraction=0.2,
... inplace=False,
... )["X"]
normalizing counts per cell. The following highly-expressed genes are not considered during normalization factor computation:
normalizing counts per cell
The following highly-expressed genes are not considered during normalization factor computation:
['1', '3', '4']
finished (0:00:00)
>>> X_norm
Expand All @@ -182,87 +270,46 @@ def normalize_total( # noqa: PLR0912, PLR0915
msg = "Choose max_fraction between 0 and 1."
raise ValueError(msg)

# Deprecated features
if layers is not None:
warn(
"The `layers` argument is deprecated. Instead, specify individual "
"layers to normalize with `layer`.",
FutureWarning,
stacklevel=2,
)
if layer_norm is not None:
warn(
"The `layer_norm` argument is deprecated. Specify the target size "
"factor directly with `target_sum`.",
FutureWarning,
stacklevel=2,
)

if layers == "all":
layers = adata.layers.keys()
elif isinstance(layers, str):
msg = f"`layers` needs to be a list of strings or 'all', not {layers!r}"
raise ValueError(msg)

view_to_actual(adata)

x = _get_obs_rep(adata, layer=layer)
if x is None:
msg = f"Layer {layer!r} not found in adata."
raise ValueError(msg)
if isinstance(x, CSCBase):
x = x.tocsr()
if not inplace:
x = x.copy()
if issubclass(x.dtype.type, int | np.integer):
x = x.astype(np.float32) # TODO: Check if float64 should be used

start = logg.info("normalizing counts per cell")

X, counts_per_cell, gene_subset = _normalize_total_helper(
x,
exclude_highly_expressed=exclude_highly_expressed,
max_fraction=max_fraction,
target_sum=target_sum,
)

gene_subset = None
msg = "normalizing counts per cell"

counts_per_cell = axis_sum(x, axis=1)
if exclude_highly_expressed:
counts_per_cell = np.ravel(counts_per_cell)

# at least one cell has more than max_fraction of counts per cell

gene_subset = axis_sum((x > counts_per_cell[:, None] * max_fraction), axis=0)
gene_subset = np.asarray(np.ravel(gene_subset) == 0)

msg += (
". The following highly-expressed genes are not considered during "
f"normalization factor computation:\n{adata.var_names[~gene_subset].tolist()}"
logg.info(
"The following highly-expressed genes are not considered during normalization factor computation:\n"
f"{adata.var_names[~gene_subset].tolist()}"
)
counts_per_cell = axis_sum(x[:, gene_subset], axis=1)

start = logg.info(msg)
counts_per_cell = np.ravel(counts_per_cell)

cell_subset = counts_per_cell > 0
if not isinstance(cell_subset, DaskArray) and not np.all(cell_subset):
warn("Some cells have zero counts", UserWarning, stacklevel=2)

dat = dict(
X=X,
norm_factor=counts_per_cell,
)
if inplace:
if key_added is not None:
adata.obs[key_added] = counts_per_cell
_set_obs_rep(
adata, _normalize_data(x, counts_per_cell, target_sum), layer=layer
)
else:
# not recarray because need to support sparse
dat = dict(
X=_normalize_data(x, counts_per_cell, target_sum, copy=True),
norm_factor=counts_per_cell,
)

# Deprecated features
if layer_norm == "after":
after = target_sum
elif layer_norm == "X":
after = np.median(counts_per_cell[cell_subset])
elif layer_norm is None:
after = None
else:
msg = 'layer_norm should be "after", "X" or None'
raise ValueError(msg)

for layer_to_norm in layers if layers is not None else ():
res = normalize_total(
adata, layer=layer_to_norm, target_sum=after, inplace=inplace
)
if not inplace:
dat[layer_to_norm] = res["X"]
adata.obs[key_added] = dat["norm_factor"]
_set_obs_rep(adata, dat["X"], layer=layer)

logg.info(
" finished ({time_passed})",
Expand Down
10 changes: 0 additions & 10 deletions tests/test_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,6 @@ def test_normalize_total_rep(array_type, dtype):
check_rep_results(sc.pp.normalize_total, X, fields=["layer"])


@pytest.mark.parametrize("array_type", ARRAY_TYPES)
@pytest.mark.parametrize("dtype", ["float32", "int64"])
def test_normalize_total_layers(array_type, dtype):
adata = AnnData(array_type(X_total).astype(dtype))
adata.layers["layer"] = adata.X.copy()
with pytest.warns(FutureWarning, match=r".*layers.*deprecated"):
sc.pp.normalize_total(adata, layers=["layer"])
assert np.allclose(axis_sum(adata.layers["layer"], axis=1), [3.0, 3.0, 3.0])


@pytest.mark.parametrize("array_type", ARRAY_TYPES)
@pytest.mark.parametrize("dtype", ["float32", "int64"])
def test_normalize_total_view(array_type, dtype):
Expand Down
Loading