@numba.njit(cache=True, parallel=True)
def _compute_mean_var_dense(
    X: _SupportedArray, axis: Literal[0, 1] = 0, n_threads=1
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
    """
    Compute the mean and sample variance of a dense 2D array along one axis.

    Both statistics are derived in a single pass over ``X`` from the running
    sum and the running sum of squares, accumulated in ``float64``::

        mean = sum / n
        var  = (sum_sq - sum * sum / n) / (n - 1)

    where ``n`` is the length of the reduced axis.

    Parameters
    ----------
    X
        2D array, indexed as ``X[row, column]``.
    axis
        ``0`` reduces over rows (one result per column); any other value
        reduces over columns (one result per row).
    n_threads
        Number of row partitions used when ``axis == 0``. Partition ``i``
        handles rows ``i, i + n_threads, i + 2 * n_threads, ...`` and owns
        row ``i`` of the scratch buffers, so partitions never write to the
        same memory. Unused for the other axis.

    Returns
    -------
    ``(mean, var)``: two 1D ``float64`` arrays whose length is the size of
    the axis that is *not* reduced.

    Notes
    -----
    The variance divisor is ``n - 1``, so the reduced axis needs at least two
    entries for the result to be defined.
    """
    if axis == 0:
        n = X.shape[0]
        n_cols = X.shape[1]
        # One scratch row per partition: parallel iterations must not
        # accumulate into shared per-column slots.
        sums = np.zeros((n_threads, n_cols), dtype=np.float64)
        sums_squared = np.zeros((n_threads, n_cols), dtype=np.float64)
        mean = np.zeros(n_cols, dtype=np.float64)
        var = np.zeros(n_cols, dtype=np.float64)
        for i in numba.prange(n_threads):
            # Strided row assignment: partition i takes rows i, i+n_threads, ...
            for r in range(i, n, n_threads):
                for c in range(n_cols):
                    value = X[r, c]
                    sums[i, c] += value
                    sums_squared[i, c] += value * value
        # Combine the per-partition partial sums column by column.
        for c in numba.prange(n_cols):
            sum_ = sums[:, c].sum()
            mean[c] = sum_ / n
            var[c] = (sums_squared[:, c].sum() - sum_ * sum_ / n) / (n - 1)
    else:
        n_rows = X.shape[0]
        n = X.shape[1]
        mean = np.zeros(n_rows, dtype=np.float64)
        var = np.zeros(n_rows, dtype=np.float64)
        # Each row is independent, so rows are distributed directly and the
        # accumulators can be plain loop-local scalars.
        for r in numba.prange(n_rows):
            row_sum = 0.0
            row_sum_squared = 0.0
            for c in range(n):
                value = X[r, c]
                row_sum += value
                row_sum_squared += value * value
            mean[r] = row_sum / n
            # Same formula as the axis == 0 branch. Subtracting mean ** 2
            # here (instead of sum ** 2 / n) would be off by a factor of n
            # in the correction term and overestimate the variance.
            var[r] = (row_sum_squared - row_sum * row_sum / n) / (n - 1)

    return mean, var