diff --git a/docs/release-notes/3353.performance.md b/docs/release-notes/3353.performance.md new file mode 100644 index 0000000000..0c17717cec --- /dev/null +++ b/docs/release-notes/3353.performance.md @@ -0,0 +1 @@ +Speed up for a categorical regressor in {func}`~scanpy.pp.regress_out` {smaller}`S Dicks` diff --git a/src/scanpy/preprocessing/_simple.py b/src/scanpy/preprocessing/_simple.py index 83ed9a544f..86e3f3896e 100644 --- a/src/scanpy/preprocessing/_simple.py +++ b/src/scanpy/preprocessing/_simple.py @@ -633,6 +633,22 @@ def normalize_per_cell( DT = TypeVar("DT") +@njit +def _create_regressor_categorical( + X: np.ndarray, number_categories: int, cat_array: np.ndarray +) -> np.ndarray: + # create regressor matrix for categorical variables + regressors = np.zeros(X.shape, dtype=X.dtype) + # iterate over categories + for category in range(number_categories): + # iterate over genes and calculate mean expression + # for each gene per category + mask = category == cat_array + for ix in numba.prange(X.T.shape[0]): + regressors[mask, ix] = X.T[ix, mask].mean() + return regressors + + @njit def get_resid( data: np.ndarray, @@ -728,13 +744,15 @@ def regress_out( ) raise ValueError(msg) logg.debug("... regressing on per-gene means within categories") - regressors = np.zeros(X.shape, dtype="float32") + # set number of categories to the same dtype as the categories + cat_array = adata.obs[keys[0]].cat.codes.to_numpy() + number_categories = cat_array.dtype.type(len(adata.obs[keys[0]].cat.categories)) + X = _to_dense(X, order="F") if isinstance(X, CSBase) else X - # TODO figure out if we should use a numba kernel for this - for category in adata.obs[keys[0]].cat.categories: - mask = (category == adata.obs[keys[0]]).values - for ix, x in enumerate(X.T): - regressors[mask, ix] = x[mask].mean() + if np.issubdtype(X.dtype, np.integer): + target_dtype = np.float32 if X.dtype.itemsize < 4 else np.float64 + X = X.astype(target_dtype) + regressors = _create_regressor_categorical(X, number_categories, cat_array) variable_is_categorical = True # regress on one or several ordinal variables else: diff --git a/tests/_data/regress_test_small_cat.npy b/tests/_data/regress_test_small_cat.npy new file mode 100644 index 0000000000..c84f88ef23 Binary files /dev/null and b/tests/_data/regress_test_small_cat.npy differ diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index 75312ee01b..bb8fc7f880 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -456,14 +456,21 @@ def test_regress_out_constants(): assert_equal(adata, adata_copy) -def test_regress_out_reproducible(): - adata = pbmc68k_reduced() +@pytest.mark.parametrize( + ("keys", "test_file", "atol"), + [ + (["n_counts", "percent_mito"], "regress_test_small.npy", 0.0), + (["bulk_labels"], "regress_test_small_cat.npy", 1e-6), + ], +) +def test_regress_out_reproducible(keys, test_file, atol): + adata = sc.datasets.pbmc68k_reduced() adata = adata.raw.to_adata()[:200, :200].copy() - sc.pp.regress_out(adata, keys=["n_counts", "percent_mito"]) + sc.pp.regress_out(adata, keys=keys) # This file was generated from the original implementation in version 1.10.3 # Now we compare new implementation with the old one - tester = np.load(DATA_PATH / "regress_test_small.npy") - np.testing.assert_allclose(adata.X, tester) + tester = np.load(DATA_PATH / test_file) + np.testing.assert_allclose(adata.X, tester, atol=atol) def test_regress_out_constants_equivalent():