diff --git a/tests/conftest.py b/tests/conftest.py
index 7aa1192..585d7c0 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,10 +1,11 @@
+import warnings
+
 import numpy as np
+import pandas as pd
 import pytest
-import warnings
 import xarray as xr
-import pandas as pd
-from xeofs.utils.data_types import DataArray, DataSet, DataList
+from xeofs.utils.data_types import DataArray, DataList, DataSet
 
 warnings.filterwarnings("ignore", message="numpy.dtype size changed")
 warnings.filterwarnings("ignore", message="numpy.ufunc size changed")
 
@@ -49,7 +50,7 @@ def generate_synthetic_dataarray(
     # Create dimensions
     sample_dims = [f"sample{i}" for i in range(n_sample)]
     feature_dims = [f"feature{i}" for i in range(n_feature)]
-    all_dims = feature_dims + sample_dims
+    all_dims = sample_dims + feature_dims
 
     # Create coordinates/indices
     coords = {}
diff --git a/tests/models/test_decomposer.py b/tests/models/test_decomposer.py
index d6f07f5..88812fb 100644
--- a/tests/models/test_decomposer.py
+++ b/tests/models/test_decomposer.py
@@ -1,10 +1,22 @@
 import numpy as np
 import pytest
 from dask.array import Array as DaskArray  # type: ignore
+
 from xeofs.models.decomposer import Decomposer
+
 from ..utilities import data_is_dask
 
 
+def compute_max_exp_var(singular_values, data):
+    """Compute the maximal cumulative explained variance by all components."""
+
+    total_variance = data.var("sample", ddof=1).sum("feature")
+    explained_variance = singular_values**2 / (data.sample.size - 1)
+    explained_variance_ratio = explained_variance / total_variance
+    explained_variance_ratio_cumsum = explained_variance_ratio.cumsum("mode")
+    return explained_variance_ratio_cumsum.isel(mode=-1).item()
+
+
 @pytest.fixture
 def decomposer():
     return Decomposer(n_modes=2, random_state=42)
@@ -12,9 +24,10 @@ def decomposer():
 
 @pytest.fixture
 def mock_data_array(mock_data_array):
-    return mock_data_array.stack(sample=("time",), feature=("lat", "lon")).dropna(
+    data2d = mock_data_array.stack(sample=("time",), feature=("lat", "lon")).dropna(
         "feature"
     )
+    return data2d - data2d.mean("sample")
 
 
 @pytest.fixture
@@ -177,3 +190,79 @@ def test_random_state(
 
     # Check that the results are the same
     assert np.all(U1 == U2)
+
+
+@pytest.mark.parametrize(
+    "target_variance, solver",
+    [
+        (0.1, "randomized"),
+        (0.5, "randomized"),
+        (0.9, "randomized"),
+        (0.99, "randomized"),
+        (0.1, "full"),
+        (0.5, "full"),
+        (0.9, "full"),
+        (0.99, "full"),
+    ],
+)
+def test_decompose_via_variance_threshold(mock_data_array, target_variance, solver):
+    """Test that the decomposer returns the correct number of modes to explain the target variance."""
+    decomposer = Decomposer(
+        n_modes=target_variance, solver=solver, init_rank_reduction=0.9
+    )
+    decomposer.fit(mock_data_array)
+    s = decomposer.s_
+
+    # Compute total variance and test whether the variance threshold is reached
+    max_explained_variance_ratio = compute_max_exp_var(s, mock_data_array)
+    assert (
+        max_explained_variance_ratio >= target_variance
+    ), f"Expected >= {target_variance:.2f}, got {max_explained_variance_ratio:.2f}"
+
+    # We still get a truncated version of the SVD
+    assert s.mode.size < min(mock_data_array.shape)
+
+
+def test_raise_warning_for_low_init_rank_reduction(mock_data_array):
+    target_variance = 0.5
+    init_rank_reduction = 0.1
+    decomposer = Decomposer(
+        n_modes=target_variance, init_rank_reduction=init_rank_reduction
+    )
+    warn_msg = ".*components were computed which explain.*of the variance but.*of explained variance was requested. Consider increasing the `init_rank_reduction`"
+    with pytest.warns(UserWarning, match=warn_msg):
+        decomposer.fit(mock_data_array)
+
+
+def test_compute_at_least_one_component(mock_data_array):
+    """Test that at least one mode is computed even if `init_rank_reduction` is too low."""
+    target_variance = 0.5
+    init_rank_reduction = 0.01
+    decomposer = Decomposer(
+        n_modes=target_variance, init_rank_reduction=init_rank_reduction
+    )
+
+    # Warning is raised to indicate that the value of init_rank_reduction is too low
+    warn_msg = "`init_rank_reduction=.*` is too low and results in zero components. One component will be computed instead."
+    with pytest.warns(UserWarning, match=warn_msg):
+        decomposer.fit(mock_data_array)
+
+    # At least one mode is computed
+    s = decomposer.s_
+    assert s.mode.size >= 1
+
+
+@pytest.mark.parametrize(
+    "solver",
+    ["full", "randomized"],
+)
+def test_dask_array_based_on_target_variance(mock_dask_data_array, solver):
+    target_variance = 0.5
+    decomposer = Decomposer(
+        n_modes=target_variance, init_rank_reduction=0.9, solver=solver, compute=False
+    )
+
+    assert data_is_dask(mock_dask_data_array)
+
+    err_msg = "Estimating the number of modes to keep based on variance is not supported with dask arrays.*"
+    with pytest.raises(ValueError, match=err_msg):
+        decomposer.fit(mock_dask_data_array)
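For orientation, the user-facing behavior that the tests above pin down is the new float form of `n_modes`. Below is a minimal sketch of how the API is exercised; the synthetic data, the dimension names, and the 0.9 thresholds are illustrative assumptions, not values taken from the test suite:

import numpy as np
import xarray as xr

from xeofs.models.decomposer import Decomposer

# Illustrative 2D input: 100 samples x 40 features, mean-centered along "sample"
rng = np.random.default_rng(0)
data = xr.DataArray(rng.standard_normal((100, 40)), dims=("sample", "feature"))
data = data - data.mean("sample")

# A float n_modes requests a fraction of explained variance instead of a mode count;
# init_rank_reduction caps the rank of the initial randomized SVD at 90% of full rank.
decomposer = Decomposer(n_modes=0.9, init_rank_reduction=0.9)
decomposer.fit(data)

# s_ now holds just enough singular values to reach >= 90% explained variance
print(decomposer.s_.mode.size)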
""" def __init__( self, - n_modes: int, + n_modes: int | float, + init_rank_reduction: float = 0.3, flip_signs: bool = True, compute: bool = True, solver: str = "auto", random_state: Optional[int] = None, verbose: bool = False, + component_dim_name: str = "mode", solver_kwargs: dict = {}, ): + sanity_check_n_modes(n_modes) + self.is_based_on_variance = False if isinstance(n_modes, int) else True + + if self.is_based_on_variance: + if not (0 < init_rank_reduction <= 1.0): + raise ValueError( + "init_rank_reduction must be in the half open interval (0, 1]." + ) + self.n_modes = n_modes + self.n_modes_precompute = n_modes + self.init_rank_reduction = init_rank_reduction self.flip_signs = flip_signs self.compute = compute self.verbose = verbose self.solver = solver self.random_state = random_state + self.component_dim_name = component_dim_name self.solver_kwargs = solver_kwargs def fit(self, X, dims=("sample", "feature")): @@ -68,11 +89,18 @@ def fit(self, X, dims=("sample", "feature")): dims : tuple of str Dimensions of the data object. """ - n_coords1 = len(X.coords[dims[0]]) - n_coords2 = len(X.coords[dims[1]]) - rank = min(n_coords1, n_coords2) + rank = min(X.shape) - if self.n_modes > rank: + if self.is_based_on_variance: + self.n_modes_precompute = int(rank * self.init_rank_reduction) + if self.n_modes_precompute < 1: + warnings.warn( + f"`init_rank_reduction={self.init_rank_reduction}` is too low and results in zero components. One component will be computed instead." + ) + self.n_modes_precompute = 1 + + # TODO(nicrie): perhaps we can just set n_modes to rank if it is larger than rank (possible solution for #158) + if self.n_modes_precompute > rank: raise ValueError( f"n_modes must be smaller or equal to the rank of the data object (rank={rank})" ) @@ -87,12 +115,16 @@ def fit(self, X, dims=("sample", "feature")): use_dask = True if isinstance(X.data, DaskArray) else False use_complex = True if np.iscomplexobj(X.data) else False - is_small_data = max(n_coords1, n_coords2) < 500 + is_small_data = max(X.shape) < 500 match self.solver: + # TODO(nicrie): implement more performant case for tall and skinny problems which are best handled by precomputing the covariance matrix. 
@@ -68,11 +89,18 @@ def fit(self, X, dims=("sample", "feature")):
         dims : tuple of str
             Dimensions of the data object.
         """
-        n_coords1 = len(X.coords[dims[0]])
-        n_coords2 = len(X.coords[dims[1]])
-        rank = min(n_coords1, n_coords2)
+        rank = min(X.shape)
 
-        if self.n_modes > rank:
+        if self.is_based_on_variance:
+            self.n_modes_precompute = int(rank * self.init_rank_reduction)
+            if self.n_modes_precompute < 1:
+                warnings.warn(
+                    f"`init_rank_reduction={self.init_rank_reduction}` is too low and results in zero components. One component will be computed instead."
+                )
+                self.n_modes_precompute = 1
+
+        # TODO(nicrie): perhaps we can just cap n_modes at rank when the requested value exceeds it (possible solution for #158)
+        if self.n_modes_precompute > rank:
             raise ValueError(
                 f"n_modes must be smaller or equal to the rank of the data object (rank={rank})"
             )
@@ -87,12 +115,16 @@
         use_dask = True if isinstance(X.data, DaskArray) else False
         use_complex = True if np.iscomplexobj(X.data) else False
 
-        is_small_data = max(n_coords1, n_coords2) < 500
+        is_small_data = max(X.shape) < 500
 
         match self.solver:
+            # TODO(nicrie): implement a more performant case for tall-and-skinny problems, which are best handled by precomputing the covariance matrix.
+            # if X.shape[1] <= 1_000 and X.shape[0] >= 10 * X.shape[1]: -> covariance_eigh (see sklearn PCA implementation: https://github.com/scikit-learn/scikit-learn/blob/e87b32a81c70abed8f2e97483758eb64df8255e9/sklearn/decomposition/_pca.py#L526)
             case "auto":
                 use_exact = (
-                    True if is_small_data and self.n_modes > int(0.8 * rank) else False
+                    True
+                    if is_small_data and self.n_modes_precompute > int(0.8 * rank)
+                    else False
                 )
             case "full":
                 use_exact = True
@@ -107,14 +139,14 @@
         # Use exact SVD for small data sets
         if use_exact:
             U, s, VT = self._svd(X, dims, np.linalg.svd, self.solver_kwargs)
-            U = U[:, : self.n_modes]
-            s = s[: self.n_modes]
-            VT = VT[: self.n_modes, :]
+            U = U[:, : self.n_modes_precompute]
+            s = s[: self.n_modes_precompute]
+            VT = VT[: self.n_modes_precompute, :]
 
         # Use randomized SVD for large, real-valued data sets
         elif (not use_complex) and (not use_dask):
             solver_kwargs = self.solver_kwargs | {
-                "n_components": self.n_modes,
+                "n_components": self.n_modes_precompute,
                 "random_state": self.random_state,
             }
             U, s, VT = self._svd(X, dims, randomized_svd, solver_kwargs)
@@ -123,7 +155,7 @@
         elif use_complex and (not use_dask):
             # Scipy sparse version
             solver_kwargs = self.solver_kwargs | {
-                "k": self.n_modes,
+                "k": self.n_modes_precompute,
                 "solver": "lobpcg",
                 "random_state": self.random_state,
             }
@@ -136,7 +168,7 @@
         # Use dask SVD for large, real-valued, delayed data sets
         elif (not use_complex) and use_dask:
             solver_kwargs = self.solver_kwargs | {
-                "k": self.n_modes,
+                "k": self.n_modes_precompute,
                 "seed": self.random_state,
             }
             solver_kwargs.setdefault("compute", self.compute)
@@ -158,6 +190,38 @@
         s.name = "s"
         VT.name = "V"
 
+        # Truncate the decomposition to the desired number of modes
+        if self.is_based_on_variance:
+            # Truncating based on variance requires computing the dask array,
+            # which we prefer to avoid
+            if use_dask:
+                err_msg = "Estimating the number of modes to keep based on variance is not supported with dask arrays. Please explicitly specify the number of modes to keep by using an integer for the number of modes."
+                raise ValueError(err_msg)
+
+            # Compute the fraction of explained variance per mode
+            N = X.shape[0] - 1
+            total_variance = X.var(X.dims[0], ddof=1).sum(X.dims[1])
+            explained_variance = s**2 / N / total_variance
+            cum_expvar = explained_variance.cumsum(self.component_dim_name)
+            total_explained_variance = cum_expvar[-1].item()
+
+            n_modes_required = (
+                self.n_modes_precompute
+                - (cum_expvar >= self.n_modes).sum(self.component_dim_name)
+                + 1
+            )
+            if n_modes_required > self.n_modes_precompute:
+                warnings.warn(
+                    f"{self.n_modes_precompute} components were computed which explain {total_explained_variance:.2%} of the variance but {self.n_modes:.2%} of explained variance was requested. Consider increasing the `init_rank_reduction`."
+                )
+                n_modes_required = self.n_modes_precompute
+
+            # Truncate the solution to the desired fraction of explained variance
+            U = U.sel({self.component_dim_name: slice(1, n_modes_required)})
+            s = s.sel({self.component_dim_name: slice(1, n_modes_required)})
+            VT = VT.sel({self.component_dim_name: slice(1, n_modes_required)})
+
+        # Flip signs of components to ensure deterministic output
         if self.flip_signs:
             sign_multiplier = get_deterministic_sign_multiplier(VT, dims[1])
             VT *= sign_multiplier
@@ -165,7 +229,7 @@
 
         self.U_ = U
         self.s_ = s
-        self.V_ = VT.conj().transpose(dims[1], "mode")
+        self.V_ = VT.conj().transpose(dims[1], self.component_dim_name)
 
     def _svd(self, X, dims, func, kwargs):
         """Performs SVD on the data
@@ -197,9 +261,9 @@
             kwargs=kwargs,
             input_core_dims=[dims],
             output_core_dims=[
-                [dims[0], "mode"],
-                ["mode"],
-                ["mode", dims[1]],
+                [dims[0], self.component_dim_name],
+                [self.component_dim_name],
+                [self.component_dim_name, dims[1]],
             ],
             dask="allowed",
         )
diff --git a/xeofs/utils/sanity_checks.py b/xeofs/utils/sanity_checks.py
index 0cdd2bf..49ea2f2 100644
--- a/xeofs/utils/sanity_checks.py
+++ b/xeofs/utils/sanity_checks.py
@@ -5,7 +5,7 @@
 from xeofs.utils.data_types import Dims
 
 
-def assert_single_dataarray(da, name):
+def assert_single_dataarray(da, name="object"):
     """Check if the given object is a DataArray.
 
     Args:
@@ -19,7 +19,7 @@
         raise TypeError(f"{name} must be a DataArray")
 
 
-def assert_list_dataarrays(da_list, name):
+def assert_list_dataarrays(da_list, name="object"):
     """Check if the given object is a list of DataArrays.
 
     Args:
@@ -35,7 +35,7 @@
         assert_single_dataarray(da, name)
 
 
-def assert_single_dataset(ds, name):
+def assert_single_dataset(ds, name="object"):
     """Check if the given object is a Dataset.
 
     Args:
@@ -49,7 +49,7 @@
         raise TypeError(f"{name} must be a Dataset")
 
 
-def assert_dataarray_or_dataset(da, name):
+def assert_dataarray_or_dataset(da, name="object"):
     """Check if the given object is a DataArray or Dataset.
 
     Args:
@@ -100,3 +100,17 @@ def assert_not_complex(da: xr.DataArray) -> None:
         raise TypeError(
             "Invalid input type. This method does not support complex data types."
         )
+
+
+def sanity_check_n_modes(n_modes: int | float) -> None:
+    """Check if the number of modes is valid."""
+
+    match n_modes:
+        case int():
+            if n_modes < 1:
+                raise ValueError("n_modes must be greater than 0")
+        case float():
+            if not (0 < n_modes <= 1.0):
+                raise ValueError("n_modes must be in the range (0, 1]")
+        case _:
+            raise TypeError("n_modes must be an integer or a float")
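The mode-truncation arithmetic added to `fit` can be checked in isolation. Here is a standalone NumPy sketch of the same computation; the singular values and variance total are made-up numbers, and the variable names simply mirror the diff:

import numpy as np

# Made-up singular values of a centered (100 x n_features) data matrix
s = np.array([9.0, 5.0, 4.0, 1.0])
n_samples = 100
total_variance = 1.3  # per-feature sample variances (ddof=1) summed over features

# Fraction of variance explained per mode and its cumulative sum
explained_variance = s**2 / (n_samples - 1) / total_variance
cum_expvar = np.cumsum(explained_variance)  # -> [0.63, 0.82, 0.95, 0.96]

# Keep everything up to the first mode whose cumulative value reaches the target:
# count the modes already past the threshold and subtract from the precomputed rank
target = 0.9
n_modes_required = len(s) - np.sum(cum_expvar >= target) + 1  # -> 3

If `cum_expvar` never reaches the target, the expression evaluates to one more than the number of precomputed modes, which is exactly the condition that triggers the `init_rank_reduction` warning and the cap in the diff above.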
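For completeness, the boundary behavior of the new `sanity_check_n_modes` guard, whose structural pattern match is fully visible above — a small illustrative exercise, not part of the submitted test suite:

from xeofs.utils.sanity_checks import sanity_check_n_modes

sanity_check_n_modes(3)     # valid: a positive integer is a mode count
sanity_check_n_modes(0.75)  # valid: a float in (0, 1] is a variance fraction
sanity_check_n_modes(1.0)   # valid: the upper bound of the interval is included

# Each invalid value raises ValueError (out of range) or TypeError (wrong type)
for bad in (0, -2, 0.0, 1.5, "3"):
    try:
        sanity_check_n_modes(bad)
    except (ValueError, TypeError) as err:
        print(f"{bad!r}: {err}")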