refactor(decomposer): allow truncated SVD based on variance (#184)
nicrie authored Aug 19, 2024
1 parent 33e5bbb commit 5973c41
Showing 4 changed files with 200 additions and 32 deletions.
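In short, `Decomposer.n_modes` now also accepts a float in (0, 1], interpreted as the target fraction of explained variance: an initial randomized SVD is computed at a rank reduced by `init_rank_reduction`, then truncated to the smallest number of modes that reaches the target. A minimal editorial sketch of the new call pattern (not part of the commit; the data is hypothetical, only `Decomposer` and its arguments come from this change):

import numpy as np
import xarray as xr

from xeofs.models.decomposer import Decomposer

# Hypothetical centered 2-D data in the (sample, feature) layout the tests use
rng = np.random.default_rng(42)
data = xr.DataArray(rng.standard_normal((100, 30)), dims=("sample", "feature"))
data = data - data.mean("sample")

# n_modes=0.9 requests enough modes to explain at least 90% of the variance;
# init_rank_reduction=0.9 precomputes 90% of the full rank before truncating.
decomposer = Decomposer(n_modes=0.9, solver="full", init_rank_reduction=0.9)
decomposer.fit(data)
print(decomposer.s_.mode.size)  # number of modes actually retained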
9 changes: 5 additions & 4 deletions tests/conftest.py
@@ -1,10 +1,11 @@
+import warnings
+
import numpy as np
+import pandas as pd
import pytest
-import warnings
import xarray as xr
-import pandas as pd

-from xeofs.utils.data_types import DataArray, DataSet, DataList
+from xeofs.utils.data_types import DataArray, DataList, DataSet

warnings.filterwarnings("ignore", message="numpy.dtype size changed")
warnings.filterwarnings("ignore", message="numpy.ufunc size changed")
@@ -49,7 +50,7 @@ def generate_synthetic_dataarray(
# Create dimensions
sample_dims = [f"sample{i}" for i in range(n_sample)]
feature_dims = [f"feature{i}" for i in range(n_feature)]
-all_dims = feature_dims + sample_dims
+all_dims = sample_dims + feature_dims

# Create coordinates/indices
coords = {}
91 changes: 90 additions & 1 deletion tests/models/test_decomposer.py
@@ -1,20 +1,33 @@
import numpy as np
import pytest
from dask.array import Array as DaskArray # type: ignore

from xeofs.models.decomposer import Decomposer

from ..utilities import data_is_dask


def compute_max_exp_var(singular_values, data):
"""Compute the maximal cumulative explained variance by all components."""

total_variance = data.var("sample", ddof=1).sum("feature")
explained_variance = singular_values**2 / (data.sample.size - 1)
explained_variance_ratio = explained_variance / total_variance
explained_variance_ratio_cumsum = explained_variance_ratio.cumsum("mode")
return explained_variance_ratio_cumsum.isel(mode=-1).item()


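# Editorial note (not part of the commit): for centered data, the squared
# singular values of a full SVD recover the total sample variance, i.e.
# (singular_values**2).sum() / (n_samples - 1) equals
# data.var("sample", ddof=1).sum("feature"), so the cumulative ratio computed
# above approaches 1 as all modes are included.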
@pytest.fixture
def decomposer():
return Decomposer(n_modes=2, random_state=42)


@pytest.fixture
def mock_data_array(mock_data_array):
-return mock_data_array.stack(sample=("time",), feature=("lat", "lon")).dropna(
+data2d = mock_data_array.stack(sample=("time",), feature=("lat", "lon")).dropna(
"feature"
)
+return data2d - data2d.mean("sample")


@pytest.fixture
@@ -177,3 +190,79 @@ def test_random_state(

# Check that the results are the same
assert np.all(U1 == U2)


@pytest.mark.parametrize(
"target_variance, solver",
[
(0.1, "randomized"),
(0.5, "randomized"),
(0.9, "randomized"),
(0.99, "randomized"),
(0.1, "full"),
(0.5, "full"),
(0.9, "full"),
(0.99, "full"),
],
)
def test_decompose_via_variance_threshold(mock_data_array, target_variance, solver):
"""Test that the decomposer returns the correct number of modes to explain the target variance."""
decomposer = Decomposer(
n_modes=target_variance, solver=solver, init_rank_reduction=0.9
)
decomposer.fit(mock_data_array)
s = decomposer.s_

# Compute total variance and test whether variance threshold is reached
max_explained_variance_ratio = compute_max_exp_var(s, mock_data_array)
assert (
max_explained_variance_ratio >= target_variance
), f"Expected >= {target_variance:.2f}, got {max_explained_variance_ratio:2f}"

# We still get a truncated version of the SVD
assert s.mode.size < min(mock_data_array.shape)


def test_raise_warning_for_low_init_rank_reduction(mock_data_array):
target_variance = 0.5
init_rank_reduction = 0.1
decomposer = Decomposer(
n_modes=target_variance, init_rank_reduction=init_rank_reduction
)
warn_msg = ".*components were computed which explain.*of the variance but.*of explained variance was requested. Consider increasing the `init_rank_reduction`"
with pytest.warns(UserWarning, match=warn_msg):
decomposer.fit(mock_data_array)


def test_compute_at_least_one_component(mock_data_array):
""""""
target_variance = 0.5
init_rank_reduction = 0.01
decomposer = Decomposer(
n_modes=target_variance, init_rank_reduction=init_rank_reduction
)

# Warning is raised to indicate that the value of init_rank_reduction is too low
warn_msg = "`init_rank_reduction=.*` is too low and results in zero components. One component will be computed instead."
with pytest.warns(UserWarning, match=warn_msg):
decomposer.fit(mock_data_array)

# At least one mode is computed
s = decomposer.s_
assert s.mode.size >= 1


@pytest.mark.parametrize(
"solver",
["full", "randomized"],
)
def test_dask_array_based_on_target_variance(mock_dask_data_array, solver):
target_variance = 0.5
decomposer = Decomposer(
n_modes=target_variance, init_rank_reduction=0.9, solver=solver, compute=False
)

err_msg = "Estimating the number of modes to keep based on variance is not supported with dask arrays.*"
with pytest.raises(ValueError, match=err_msg):
assert data_is_dask(mock_dask_data_array)
decomposer.fit(mock_dask_data_array)
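
The last test above pins down the dask constraint: truncating by variance needs the singular values eagerly, so lazy inputs are rejected. An editorial sketch of the failing call (not part of the commit; the dask-backed data is hypothetical):

import dask.array as da
import xarray as xr

from xeofs.models.decomposer import Decomposer

# Lazy, dask-backed input: a float n_modes would force computing the array,
# so the decomposer raises instead of silently triggering computation.
lazy = xr.DataArray(da.random.random((100, 30)), dims=("sample", "feature"))
try:
    Decomposer(n_modes=0.5, compute=False).fit(lazy)
except ValueError as err:
    print(err)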
110 changes: 87 additions & 23 deletions xeofs/models/decomposer.py
@@ -1,13 +1,16 @@
+import warnings
+from typing import Optional
+
+import dask
import numpy as np
import xarray as xr
-import dask
from dask.array import Array as DaskArray # type: ignore
+from dask.array.linalg import svd_compressed as dask_svd
from dask.diagnostics.progress import ProgressBar
-from sklearn.utils.extmath import randomized_svd
from scipy.sparse.linalg import svds as complex_svd # type: ignore
-from dask.array.linalg import svd_compressed as dask_svd
-from typing import Optional
+from sklearn.utils.extmath import randomized_svd

+from ..utils.sanity_checks import sanity_check_n_modes
from ..utils.xarray_utils import get_deterministic_sign_multiplier


@@ -20,8 +23,10 @@ class Decomposer:
Parameters
----------
-n_modes : int
-    Number of components to be computed.
+n_modes : int | float
+    Number of components to be computed. If float, it represents the fraction of variance to keep.
+init_rank_reduction : float, default=0.3
+    Used only when `n_modes` is given as a float. Specifies the initial rank to compute via randomized SVD, as a fraction of the full rank, before the solution is truncated to the desired fraction of explained variance. Must be in the half-open interval (0, 1]. Lower values speed up the computation. If the initial rank is too low to reach the requested fraction of explained variance, a warning is raised.
flip_signs : bool, default=True
Whether to flip the sign of the components to ensure deterministic output.
compute : bool, default=True
@@ -36,26 +41,42 @@ class Decomposer:
Seed for the random number generator.
verbose: bool, default=False
Whether to show a progress bar when computing the decomposition.
component_dim_name : str, default='mode'
Name of the component dimension in the output DataArrays.
solver_kwargs : dict, default={}
Additional keyword arguments passed to the SVD solver.
"""

def __init__(
self,
-n_modes: int,
+n_modes: int | float,
+init_rank_reduction: float = 0.3,
flip_signs: bool = True,
compute: bool = True,
solver: str = "auto",
random_state: Optional[int] = None,
verbose: bool = False,
component_dim_name: str = "mode",
solver_kwargs: dict = {},
):
sanity_check_n_modes(n_modes)
self.is_based_on_variance = False if isinstance(n_modes, int) else True

if self.is_based_on_variance:
if not (0 < init_rank_reduction <= 1.0):
raise ValueError(
"init_rank_reduction must be in the half open interval (0, 1]."
)

self.n_modes = n_modes
+self.n_modes_precompute = n_modes
+self.init_rank_reduction = init_rank_reduction
self.flip_signs = flip_signs
self.compute = compute
self.verbose = verbose
self.solver = solver
self.random_state = random_state
self.component_dim_name = component_dim_name
self.solver_kwargs = solver_kwargs

def fit(self, X, dims=("sample", "feature")):
@@ -68,11 +89,18 @@ def fit(self, X, dims=("sample", "feature")):
dims : tuple of str
Dimensions of the data object.
"""
-n_coords1 = len(X.coords[dims[0]])
-n_coords2 = len(X.coords[dims[1]])
-rank = min(n_coords1, n_coords2)
+rank = min(X.shape)

-if self.n_modes > rank:
+if self.is_based_on_variance:
+    self.n_modes_precompute = int(rank * self.init_rank_reduction)
+    if self.n_modes_precompute < 1:
+        warnings.warn(
+            f"`init_rank_reduction={self.init_rank_reduction}` is too low and results in zero components. One component will be computed instead."
+        )
+        self.n_modes_precompute = 1
+
+# TODO(nicrie): perhaps we can just set n_modes to rank if it is larger than rank (possible solution for #158)
+if self.n_modes_precompute > rank:
raise ValueError(
f"n_modes must be smaller or equal to the rank of the data object (rank={rank})"
)
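# Editorial worked example (not part of the commit): for a 1000 x 200 data
# matrix, rank = 200; init_rank_reduction=0.3 makes the initial randomized
# SVD compute int(200 * 0.3) = 60 modes, which are truncated further below
# once the explained variance of each mode is known.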
@@ -87,12 +115,16 @@ def fit(self, X, dims=("sample", "feature")):
use_dask = True if isinstance(X.data, DaskArray) else False
use_complex = True if np.iscomplexobj(X.data) else False

-is_small_data = max(n_coords1, n_coords2) < 500
+is_small_data = max(X.shape) < 500

match self.solver:
# TODO(nicrie): implement more performant case for tall and skinny problems which are best handled by precomputing the covariance matrix.
# if X.shape[1] <= 1_000 and X.shape[0] >= 10 * X.shape[1]: -> covariance_eigh" (see sklearn PCA implementation: https://github.com/scikit-learn/scikit-learn/blob/e87b32a81c70abed8f2e97483758eb64df8255e9/sklearn/decomposition/_pca.py#L526)
case "auto":
use_exact = (
-True if is_small_data and self.n_modes > int(0.8 * rank) else False
+True
+if is_small_data and self.n_modes_precompute > int(0.8 * rank)
+else False
)
case "full":
use_exact = True
@@ -107,14 +139,14 @@ def fit(self, X, dims=("sample", "feature")):
# Use exact SVD for small data sets
if use_exact:
U, s, VT = self._svd(X, dims, np.linalg.svd, self.solver_kwargs)
-U = U[:, : self.n_modes]
-s = s[: self.n_modes]
-VT = VT[: self.n_modes, :]
+U = U[:, : self.n_modes_precompute]
+s = s[: self.n_modes_precompute]
+VT = VT[: self.n_modes_precompute, :]

# Use randomized SVD for large, real-valued data sets
elif (not use_complex) and (not use_dask):
solver_kwargs = self.solver_kwargs | {
"n_components": self.n_modes,
"n_components": self.n_modes_precompute,
"random_state": self.random_state,
}
U, s, VT = self._svd(X, dims, randomized_svd, solver_kwargs)
@@ -123,7 +155,7 @@ def fit(self, X, dims=("sample", "feature")):
elif use_complex and (not use_dask):
# Scipy sparse version
solver_kwargs = self.solver_kwargs | {
"k": self.n_modes,
"k": self.n_modes_precompute,
"solver": "lobpcg",
"random_state": self.random_state,
}
@@ -136,7 +168,7 @@ def fit(self, X, dims=("sample", "feature")):
# Use dask SVD for large, real-valued, delayed data sets
elif (not use_complex) and use_dask:
solver_kwargs = self.solver_kwargs | {
"k": self.n_modes,
"k": self.n_modes_precompute,
"seed": self.random_state,
}
solver_kwargs.setdefault("compute", self.compute)
@@ -158,14 +190,46 @@ def fit(self, X, dims=("sample", "feature")):
s.name = "s"
VT.name = "V"

# Truncate the decomposition to the desired number of modes
if self.is_based_on_variance:
# Truncating based on variance requires computing the dask array,
# which we prefer to avoid
if use_dask:
err_msg = "Estimating the number of modes to keep based on variance is not supported with dask arrays. Please explicitly specify the number of modes to keep by using an integer for the number of modes."
raise ValueError(err_msg)

# Compute the fraction of explained variance per mode
N = X.shape[0] - 1
total_variance = X.var(X.dims[0], ddof=1).sum(X.dims[1])
explained_variance = s**2 / N / total_variance
cum_expvar = explained_variance.cumsum(self.component_dim_name)
total_explained_variance = cum_expvar[-1].item()

n_modes_required = (
self.n_modes_precompute
- (cum_expvar >= self.n_modes).sum(self.component_dim_name)
+ 1
)
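# Editorial worked example (not part of the commit): with
# n_modes_precompute = 5, cum_expvar = [0.30, 0.50, 0.65, 0.72, 0.78] and a
# target of self.n_modes = 0.60, (cum_expvar >= 0.60).sum() == 3, so
# n_modes_required = 5 - 3 + 1 = 3, i.e. the first mode at which the
# cumulative explained variance reaches the target.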
if n_modes_required > self.n_modes_precompute:
warnings.warn(
f"{self.n_modes_precompute} components were computed which explain {total_explained_variance:.2%} of the variance but {self.n_modes:.2%} of explained variance was requested. Consider increasing the `init_rank_reduction`."
)
n_modes_required = self.n_modes_precompute

# Truncate solution to the desired fraction of explained variance
U = U.sel(mode=slice(1, n_modes_required))
s = s.sel(mode=slice(1, n_modes_required))
VT = VT.sel(mode=slice(1, n_modes_required))

# Flip signs of components to ensure deterministic output
if self.flip_signs:
sign_multiplier = get_deterministic_sign_multiplier(VT, dims[1])
VT *= sign_multiplier
U *= sign_multiplier

self.U_ = U
self.s_ = s
-self.V_ = VT.conj().transpose(dims[1], "mode")
+self.V_ = VT.conj().transpose(dims[1], self.component_dim_name)

def _svd(self, X, dims, func, kwargs):
"""Performs SVD on the data
@@ -197,9 +261,9 @@ def _svd(self, X, dims, func, kwargs):
kwargs=kwargs,
input_core_dims=[dims],
output_core_dims=[
[dims[0], "mode"],
["mode"],
["mode", dims[1]],
[dims[0], self.component_dim_name],
[self.component_dim_name],
[self.component_dim_name, dims[1]],
],
dask="allowed",
)
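# Editorial note (not part of the commit): with dims = ("sample", "feature")
# the call above returns U with dims (sample, mode), s with dims (mode,) and
# VT with dims (mode, feature), where "mode" is self.component_dim_name.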
