Skip to content

Simplify scale #3351

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 30 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
20b5f14
Avoid parallel numba within dask
flying-sheep Oct 24, 2024
4ba3b21
restore zappy compat
flying-sheep Oct 24, 2024
7fdeda1
only do it in tests
flying-sheep Oct 25, 2024
726b28f
Merge branch 'main' into fix-clip-dask-sparse
flying-sheep Nov 7, 2024
f5eaa12
relnote
flying-sheep Nov 7, 2024
ebada2f
Merge branch 'main' into fix-clip-dask-sparse
flying-sheep Nov 8, 2024
ae8bd60
Merge branch 'main' into fix-clip-dask-sparse
flying-sheep Nov 11, 2024
bbbf3f4
Simplify scale implementation
flying-sheep Nov 11, 2024
4008a28
Merge branch 'main' into simplify-scale
flying-sheep Nov 11, 2024
f57a9b0
Fix merge
flying-sheep Nov 11, 2024
9e9f63f
Merge branch 'main' into simplify-scale
flying-sheep Feb 17, 2025
cdb1c87
Rename 3317.bugfix.md to 3351.bugfix.md
flying-sheep Feb 17, 2025
3f59f8b
reintroduce numba helper
flying-sheep Feb 17, 2025
07e8af2
oops
flying-sheep Feb 17, 2025
18a58d7
use typevar
flying-sheep Feb 17, 2025
208bec6
Merge branch 'main' into simplify-scale
flying-sheep Mar 14, 2025
e253f8a
Merge branch 'main' into simplify-scale
flying-sheep Mar 17, 2025
1d71b1b
add mask for csr
Intron7 Apr 8, 2025
d7c21a3
make this into a suggestion
Intron7 Apr 8, 2025
bd122d1
Merge branch 'main' into simplify-scale
flying-sheep Apr 10, 2025
04815d2
Merge branch 'main' into simplify-scale
flying-sheep Apr 10, 2025
08112b9
add default mask for CSR
flying-sheep Apr 10, 2025
a8d13d7
do tests properly
flying-sheep Apr 10, 2025
b0716e2
can only use shortcut when not zero-centering
flying-sheep Apr 10, 2025
66ca664
msg
flying-sheep Apr 10, 2025
efd1130
fix test
flying-sheep Apr 10, 2025
04ff943
Merge branch 'main' into simplify-scale
flying-sheep Apr 10, 2025
a428274
Merge branch 'main' into simplify-scale
flying-sheep Apr 14, 2025
9ccabe8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 14, 2025
233d8a3
oops
flying-sheep Apr 14, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/release-notes/3351.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix zappy compatibility for clip_array {smaller}`P Angerer`
240 changes: 118 additions & 122 deletions src/scanpy/preprocessing/_scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from anndata import AnnData

from .. import logging as logg
from .._compat import CSBase, CSCBase, DaskArray, njit, old_positionals
from .._compat import CSBase, CSCBase, CSRBase, DaskArray, njit, old_positionals
from .._utils import (
_check_array_function_arguments,
axis_mul_or_truediv,
Expand All @@ -28,18 +28,29 @@
da = None

if TYPE_CHECKING:
from numpy.typing import NDArray
from typing import TypeVar

from numpy.typing import ArrayLike, NDArray

@njit
def _scale_sparse_numba(indptr, indices, data, *, std, mask_obs, clip):
    # Numba-parallel, in-place scaling of a CSR matrix's nonzero values.
    # Rows where ``mask_obs`` is False are left untouched.
    # NOTE(review): ``clip`` is a truthiness sentinel — a falsy value (0)
    # disables clipping entirely, so an actual clip threshold of 0 cannot
    # be expressed through this parameter.
    for i in numba.prange(len(indptr) - 1):
        if mask_obs[i]:
            for j in range(indptr[i], indptr[i + 1]):
                if clip:
                    # divide by the column's std, then cap at ``clip``
                    data[j] = min(clip, data[j] / std[indices[j]])
                else:
                    data[j] /= std[indices[j]]
_A = TypeVar("_A", bound=CSBase | np.ndarray | DaskArray)


@singledispatch
def clip(x: ArrayLike | _A, *, max_value: float, zero_center: bool = True) -> _A:
    """Clip *x* at ``max_value`` (dense fallback).

    Sparse and dask containers are handled by the registered overloads;
    everything else is delegated to :func:`clip_array`.
    """
    clipped = clip_array(x, max_value=max_value, zero_center=zero_center)
    return clipped


@clip.register(CSBase)
def _(x: CSBase, *, max_value: float, zero_center: bool = True) -> CSBase:
    """Sparse overload: clip the explicit nonzeros in place and return *x*."""
    clipped_data = clip(x.data, max_value=max_value, zero_center=zero_center)
    x.data = clipped_data
    return x


@clip.register(DaskArray)
def _(x: DaskArray, *, max_value: float, zero_center: bool = True) -> DaskArray:
    """Dask overload: apply :func:`clip` independently to every block."""
    return x.map_blocks(
        clip,
        max_value=max_value,
        zero_center=zero_center,
        dtype=x.dtype,
        meta=x._meta,
    )


@njit
Expand All @@ -62,27 +73,19 @@ def clip_array(
return X


def clip_set(x: CSBase, *, max_value: float, zero_center: bool = True) -> CSBase:
    """Return a clipped copy of *x* using masked assignment.

    Entries above ``max_value`` are lowered to ``max_value``; when
    ``zero_center`` is true, entries below ``-max_value`` are likewise
    raised to ``-max_value``. The input object itself is not modified.
    """
    clipped = x.copy()
    clipped[clipped > max_value] = max_value
    if zero_center:
        clipped[clipped < -max_value] = -max_value
    return clipped


@renamed_arg("X", "data", pos_0=True)
@old_positionals("zero_center", "max_value", "copy", "layer", "obsm")
@singledispatch
def scale(
data: AnnData | CSBase | np.ndarray | DaskArray,
data: AnnData | _A,
*,
zero_center: bool = True,
max_value: float | None = None,
copy: bool = False,
layer: str | None = None,
obsm: str | None = None,
mask_obs: NDArray[np.bool_] | str | None = None,
) -> AnnData | CSBase | np.ndarray | DaskArray | None:
) -> AnnData | _A | None:
"""Scale data to unit variance and zero mean.

.. note::
Expand Down Expand Up @@ -117,7 +120,7 @@ def scale(
-------
Returns `None` if `copy=False`, else returns an updated `AnnData` object. Sets the following fields:

`adata.X` | `adata.layers[layer]` : :class:`numpy.ndarray` | :class:`scipy.sparse._csr.csr_matrix` (dtype `float`)
`adata.X` | `adata.layers[layer]` : :class:`numpy.ndarray` | :class:`scipy.sparse.csr_matrix` (dtype `float`)
Scaled count data matrix.
`adata.var['mean']` : :class:`pandas.Series` (dtype `float`)
Means per gene before scaling.
Expand All @@ -141,150 +144,143 @@ def scale(

@scale.register(np.ndarray)
@scale.register(DaskArray)
def scale_array( # noqa: PLR0912
X: np.ndarray | DaskArray,
@scale.register(CSBase)
def scale_array(
x: _A,
*,
zero_center: bool = True,
max_value: float | None = None,
copy: bool = False,
return_mean_std: bool = False,
mask_obs: NDArray[np.bool_] | None = None,
) -> (
np.ndarray
| DaskArray
_A
| tuple[
np.ndarray | DaskArray, NDArray[np.float64] | DaskArray, NDArray[np.float64]
_A,
NDArray[np.float64] | DaskArray,
NDArray[np.float64],
]
):
if copy:
X = X.copy()
mask_obs = _check_mask(X, mask_obs, "obs")
if mask_obs is not None:
scale_rv = scale_array(
X[mask_obs, :],
zero_center=zero_center,
max_value=max_value,
copy=False,
return_mean_std=return_mean_std,
mask_obs=None,
)

if return_mean_std:
X[mask_obs, :], mean, std = scale_rv
return X, mean, std
else:
X[mask_obs, :] = scale_rv
return X
x = x.copy()

if not zero_center and max_value is not None:
logg.info( # Be careful of what? This should be more specific
"... be careful when using `max_value` without `zero_center`."
)

if np.issubdtype(X.dtype, np.integer):
if np.issubdtype(x.dtype, np.integer):
logg.info(
"... as scaling leads to float results, integer "
"input is cast to float, returning copy."
)
X = X.astype(float)
x = x.astype(np.float64)

mask_obs = (
# For CSR matrices, default to a set mask to take the `scale_array_masked` path.
# This is faster than the maskless `axis_mul_or_truediv` path.
np.ones(x.shape[0], dtype=np.bool_)
if isinstance(x, CSRBase) and mask_obs is None and not zero_center
else _check_mask(x, mask_obs, "obs")
)
if mask_obs is not None:
return scale_array_masked(
x,
mask_obs,
zero_center=zero_center,
max_value=max_value,
return_mean_std=return_mean_std,
)

mean, var = _get_mean_var(X)
mean, var = _get_mean_var(x)
std = np.sqrt(var)
std[std == 0] = 1
if zero_center:
if isinstance(X, DaskArray) and isinstance(X._meta, CSBase):
warnings.warn(
"zero-center being used with `DaskArray` sparse chunks. "
"This can be bad if you have large chunks or intend to eventually read the whole data into memory.",
UserWarning,
stacklevel=2,
)
X -= mean

out = X if isinstance(X, np.ndarray | CSBase) else None
X = axis_mul_or_truediv(X, std, op=truediv, out=out, axis=1)
if isinstance(x, CSBase) or (
isinstance(x, DaskArray) and isinstance(x._meta, CSBase)
):
msg = "zero-centering a sparse array/matrix densifies it."
warnings.warn(msg, UserWarning, stacklevel=2)
x -= mean

x = axis_mul_or_truediv(
x,
std,
op=truediv,
out=x if isinstance(x, np.ndarray | CSBase) else None,
axis=1,
)

# do the clipping
if max_value is not None:
logg.debug(f"... clipping at max_value {max_value}")
if isinstance(X, DaskArray):
clip = clip_set if isinstance(X._meta, CSBase) else clip_array
X = X.map_blocks(clip, max_value=max_value, zero_center=zero_center)
elif isinstance(X, CSBase):
X.data = clip_array(X.data, max_value=max_value, zero_center=False)
else:
X = clip_array(X, max_value=max_value, zero_center=zero_center)
x = clip(x, max_value=max_value, zero_center=zero_center)
if return_mean_std:
return X, mean, std
return x, mean, std
else:
return X
return x


@scale.register(CSBase)
def scale_sparse(
X: CSBase,
def scale_array_masked(
x: _A,
mask_obs: NDArray[np.bool_],
*,
zero_center: bool = True,
max_value: float | None = None,
copy: bool = False,
return_mean_std: bool = False,
mask_obs: NDArray[np.bool_] | None = None,
) -> np.ndarray | tuple[np.ndarray, NDArray[np.float64], NDArray[np.float64]]:
# need to add the following here to make inplace logic work
if zero_center:
logg.info(
"... as `zero_center=True`, sparse input is "
"densified and may lead to large memory consumption"
)
X = X.toarray()
copy = False # Since the data has been copied
return scale_array(
X,
zero_center=zero_center,
copy=copy,
max_value=max_value,
return_mean_std=return_mean_std,
) -> (
_A
| tuple[
_A,
NDArray[np.float64] | DaskArray,
NDArray[np.float64],
]
):
if isinstance(x, CSBase) and not zero_center:
if isinstance(x, CSCBase):
x = x.tocsr()
mean, var = _get_mean_var(x[mask_obs, :])
std = np.sqrt(var)
std[std == 0] = 1

scale_and_clip_csr(
x.indptr,
x.indices,
x.data,
std=std,
mask_obs=mask_obs,
max_value=max_value,
)
elif mask_obs is None:
return scale_array(
X,
else:
x[mask_obs, :], mean, std = scale_array(
x[mask_obs, :],
zero_center=zero_center,
copy=copy,
max_value=max_value,
return_mean_std=return_mean_std,
mask_obs=mask_obs,
return_mean_std=True,
)
else:
if isinstance(X, CSCBase):
X = X.tocsr()
elif copy:
X = X.copy()

if mask_obs is not None:
mask_obs = _check_mask(X, mask_obs, "obs")

mean, var = _get_mean_var(X[mask_obs, :])

std = np.sqrt(var)
std[std == 0] = 1

if max_value is None:
max_value = 0

_scale_sparse_numba(
X.indptr,
X.indices,
X.data,
std=std.astype(X.dtype),
mask_obs=mask_obs,
clip=max_value,
)

if return_mean_std:
return X, mean, std
return x, mean, std
else:
return X
return x


@njit
def scale_and_clip_csr(
    indptr: NDArray[np.integer],
    indices: NDArray[np.integer],
    data: NDArray[np.floating],
    *,
    std: NDArray[np.floating],
    mask_obs: NDArray[np.bool_],
    max_value: float | None,
) -> None:
    """Scale (and optionally clip) a CSR matrix's nonzeros in place.

    Each stored value in a row where ``mask_obs`` is True is divided by
    ``std`` of its column; when ``max_value`` is not None, the result is
    additionally capped at ``max_value`` (upper bound only — no lower
    clipping happens here). Rows with a False mask entry are skipped.
    Iterates rows in parallel via ``numba.prange``.
    """
    for i in numba.prange(len(indptr) - 1):
        if mask_obs[i]:
            for j in range(indptr[i], indptr[i + 1]):
                if max_value is not None:
                    # divide by the column std, then cap at max_value
                    data[j] = min(max_value, data[j] / std[indices[j]])
                else:
                    data[j] /= std[indices[j]]


@scale.register(AnnData)
Expand Down
11 changes: 3 additions & 8 deletions tests/test_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
maybe_dask_process_context,
)
from testing.scanpy._helpers.data import pbmc3k, pbmc68k_reduced
from testing.scanpy._pytest.params import ARRAY_TYPES
from testing.scanpy._pytest.params import ARRAY_TYPES, ARRAY_TYPES_SPARSE

if TYPE_CHECKING:
from collections.abc import Callable
Expand Down Expand Up @@ -318,18 +318,13 @@ def test_scale_matrix_types(array_type, zero_center, max_value):
assert_allclose(X, adata.X, rtol=1e-5, atol=1e-5)


ARRAY_TYPES_DASK_SPARSE = [
a for a in ARRAY_TYPES if "sparse" in a.id and "dask" in a.id
]


@pytest.mark.parametrize("array_type", ARRAY_TYPES_DASK_SPARSE)
@pytest.mark.parametrize("array_type", ARRAY_TYPES_SPARSE)
def test_scale_zero_center_warns_dask_sparse(array_type):
adata = pbmc68k_reduced()
adata.X = adata.raw.X
adata_casted = adata.copy()
adata_casted.X = array_type(adata_casted.raw.X)
with pytest.warns(UserWarning, match="zero-center being used with `DaskArray`*"):
with pytest.warns(UserWarning, match="zero-center.*sparse"):
sc.pp.scale(adata_casted)
sc.pp.scale(adata)
assert_allclose(adata_casted.X, adata.X, rtol=1e-5, atol=1e-5)
Expand Down
Loading
Loading