Skip to content

Commit

Permalink
refactor(whitener): add Whitener to whiten a 2D matrix (#185)
Browse files Browse the repository at this point in the history
  • Loading branch information
nicrie authored Aug 19, 2024
1 parent 5973c41 commit a6b61b3
Show file tree
Hide file tree
Showing 3 changed files with 302 additions and 4 deletions.
164 changes: 164 additions & 0 deletions tests/preprocessing/test_whitener.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
import math

import pytest
import xarray as xr

from xeofs.preprocessing import Whitener

from ..conftest import generate_synthetic_dataarray
from ..utilities import (
assert_expected_coords,
assert_expected_dims,
data_is_dask,
)

# =============================================================================
# GENERALLY VALID TEST CASES
# =============================================================================
N_SAMPLE_DIMS = [1]
N_FEATURE_DIMS = [1]
INDEX_POLICY = ["index"]
NAN_POLICY = ["no_nan"]
DASK_POLICY = ["no_dask", "dask"]
SEED = [0]

VALID_TEST_DATA = [
(ns, nf, index, nan, dask)
for ns in N_SAMPLE_DIMS
for nf in N_FEATURE_DIMS
for index in INDEX_POLICY
for nan in NAN_POLICY
for dask in DASK_POLICY
]


# TESTS
# =============================================================================
@pytest.mark.parametrize(
"synthetic_dataarray",
VALID_TEST_DATA,
indirect=["synthetic_dataarray"],
)
def test_fit(synthetic_dataarray):
data = synthetic_dataarray.rename({"sample0": "sample", "feature0": "feature"})

whitener = Whitener(n_modes=2)
whitener.fit(data)


@pytest.mark.parametrize(
"synthetic_dataarray",
VALID_TEST_DATA,
indirect=["synthetic_dataarray"],
)
def test_transform(synthetic_dataarray):
data = synthetic_dataarray.rename({"sample0": "sample", "feature0": "feature"})

whitener = Whitener(n_modes=2)
whitener.fit(data)

# Transform data
transformed_data = whitener.transform(data)
transformed_data2 = whitener.transform(data)
assert transformed_data.identical(transformed_data2)

assert isinstance(transformed_data, xr.DataArray)
assert transformed_data.ndim == 2
assert transformed_data.dims == ("sample", "feature")

# Consistent dask behaviour
is_dask_before = data_is_dask(data)
is_dask_after = data_is_dask(transformed_data)
assert is_dask_before == is_dask_after


@pytest.mark.parametrize(
"synthetic_dataarray",
VALID_TEST_DATA,
indirect=["synthetic_dataarray"],
)
def test_fit_transform(synthetic_dataarray):
data = synthetic_dataarray.rename({"sample0": "sample", "feature0": "feature"})

whitener = Whitener(n_modes=2)

# Transform data
transformed_data = whitener.fit_transform(data)
transformed_data2 = whitener.transform(data)
assert transformed_data.identical(transformed_data2)

assert isinstance(transformed_data, xr.DataArray)
assert transformed_data.ndim == 2
assert transformed_data.dims == ("sample", "feature")

# Consistent dask behaviour
is_dask_before = data_is_dask(data)
is_dask_after = data_is_dask(transformed_data)
assert is_dask_before == is_dask_after


@pytest.mark.parametrize(
"synthetic_dataarray",
VALID_TEST_DATA,
indirect=["synthetic_dataarray"],
)
def test_invserse_transform_data(synthetic_dataarray):
data = synthetic_dataarray.rename({"sample0": "sample", "feature0": "feature"})

whitener = Whitener(n_modes=2)
whitener.fit(data)

whitened_data = whitener.transform(data)
unwhitened_data = whitener.inverse_transform_data(whitened_data)

is_dask_before = data_is_dask(data)
is_dask_after = data_is_dask(unwhitened_data)

# Unstacked data has dimensions of original data
assert_expected_dims(data, unwhitened_data, policy="all")
# Unstacked data has coordinates of original data
assert_expected_coords(data, unwhitened_data, policy="all")
# inverse transform should not change dask-ness
assert is_dask_before == is_dask_after


@pytest.mark.parametrize(
"alpha",
[0.0, 0.5, 1.0, 1.5],
)
def test_transform_alpha(alpha):
data = generate_synthetic_dataarray(1, 1, "index", "no_nan", "no_dask")
data = data.rename({"sample0": "sample", "feature0": "feature"})

whitener = Whitener(n_modes=2, alpha=alpha)
data_whitened = whitener.fit_transform(data)

norm = (data_whitened**2).sum("sample")
ones = norm / norm
# Check that for alpha=0 full whitening is performed
if math.isclose(alpha, 0.0, abs_tol=1e-6):
xr.testing.assert_allclose(norm, ones, atol=1e-6)


@pytest.mark.parametrize(
"alpha",
[0.0, 0.5, 1.0, 1.5],
)
def test_invserse_transform_alpha(alpha):
data = generate_synthetic_dataarray(1, 1, "index", "no_nan", "no_dask")
data = data.rename({"sample0": "sample", "feature0": "feature"})

whitener = Whitener(n_modes=6, alpha=alpha)
data_whitened = whitener.fit_transform(data)
data_unwhitened = whitener.inverse_transform_data(data_whitened)

xr.testing.assert_allclose(data, data_unwhitened, atol=1e-6)


def test_invalid_alpha():
data = generate_synthetic_dataarray(1, 1, "index", "no_nan", "no_dask")
data = data.rename({"sample0": "sample", "feature0": "feature"})

err_msg = "`alpha` must be greater than or equal to 0"
with pytest.raises(ValueError, match=err_msg):
Whitener(n_modes=2, alpha=-1.0)
10 changes: 6 additions & 4 deletions xeofs/preprocessing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from .scaler import Scaler
from .sanitizer import Sanitizer
from .multi_index_converter import MultiIndexConverter
from .stacker import Stacker
from .concatenator import Concatenator
from .dimension_renamer import DimensionRenamer
from .multi_index_converter import MultiIndexConverter
from .sanitizer import Sanitizer
from .scaler import Scaler
from .stacker import Stacker
from .whitener import Whitener

__all__ = [
"Scaler",
Expand All @@ -12,4 +13,5 @@
"Stacker",
"Concatenator",
"DimensionRenamer",
"Whitener",
]
132 changes: 132 additions & 0 deletions xeofs/preprocessing/whitener.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
from typing import Dict, Optional

import xarray as xr
from typing_extensions import Self

from ..models.decomposer import Decomposer
from ..utils.data_types import (
DataArray,
Dims,
DimsList,
)
from ..utils.sanity_checks import assert_single_dataarray
from .transformer import Transformer


class Whitener(Transformer):
"""Whiten a 2D DataArray matrix using PCA.
Parameters
----------
n_modes: int | float
If int, number of components to keep. If float, fraction of variance to keep.
init_rank_reduction: float, default=0.3
Used only when `n_modes` is given as a float. Specifiy the initial PCA rank reduction before truncating the solution to the desired fraction of explained variance. Must be in the half open interval ]0, 1]. Lower values will speed up the computation.
alpha: float, default=0.0
Power parameter to perform fractional whitening, where 0 corresponds to full PCA whitening and 1 to PCA without whitening.
sample_name: str, default="sample"
Name of the sample dimension.
feature_name: str, default="feature"
Name of the feature dimension.
solver_kwargs: Dict
Additional keyword arguments for the SVD solver.
"""

def __init__(
self,
n_modes: int | float,
init_rank_reduction: float = 0.3,
alpha: float = 0.0,
sample_name: str = "sample",
feature_name: str = "feature",
solver_kwargs: Dict = {},
):
super().__init__(sample_name, feature_name)

# Verify that alpha has a lower bound of 0
if alpha < 0:
raise ValueError("`alpha` must be greater than or equal to 0")

self.n_modes = n_modes
self.init_rank_reduction = init_rank_reduction
self.alpha = alpha
self.solver_kwargs = solver_kwargs

def _sanity_check_input(self, X) -> None:
assert_single_dataarray(X)

if len(X.dims) != 2:
raise ValueError("Input DataArray must have shape 2")

if X.dims != (self.sample_name, self.feature_name):
raise ValueError(
"Input DataArray must have dimensions ({:}, {:})".format(
self.sample_name, self.feature_name
)
)

def get_serialization_attrs(self) -> Dict:
return dict(n_modes=self.n_modes, alpha=self.alpha)

def fit(
self,
X: xr.DataArray,
sample_dims: Optional[Dims] = None,
feature_dims: Optional[DimsList] = None,
) -> Self:
self._sanity_check_input(X)

decomposer = Decomposer(
n_modes=self.n_modes,
init_rank_reduction=self.init_rank_reduction,
**self.solver_kwargs,
)
decomposer.fit(X, dims=(self.sample_name, self.feature_name))

self.U = decomposer.U_
self.s = decomposer.s_
self.V = decomposer.V_

return self

def transform(self, X: xr.DataArray) -> DataArray:
"""Transform new data into the fractional whitened PC space."""

self._sanity_check_input(X)

scores = xr.dot(X, self.V, dims=self.feature_name) * self.s ** (self.alpha - 1)
return scores.rename({"mode": self.feature_name})

def fit_transform(
self,
X: xr.DataArray,
sample_dims: Optional[Dims] = None,
feature_dims: Optional[DimsList] = None,
) -> DataArray:
return self.fit(X, sample_dims, feature_dims).transform(X)

def inverse_transform_data(self, X: DataArray) -> DataArray:
"""Transform 2D data (sample x feature) from whitened PC space back into original space."""

X = X.rename({self.feature_name: "mode"})
X_unwhitened = X * self.s ** (1 - self.alpha)
return xr.dot(X_unwhitened, self.V.conj().T, dims="mode")

def inverse_transform_components(self, X: DataArray) -> DataArray:
"""Transform 2D components (feature x mode) from whitened PC space back into original space."""

dummy_dim = "dummy_dim"
comps_pc_space = X.rename({self.feature_name: dummy_dim})
V = self.V.rename({"mode": dummy_dim})
return xr.dot(comps_pc_space, V.conj().T, dims=dummy_dim)

def inverse_transform_scores(self, X: DataArray) -> DataArray:
"""Transform 2D scores (sample x mode) from whitened PC space back into original space."""

return X

def inverse_transform_scores_unseen(self, X: DataArray) -> DataArray:
"""Transform unseen 2D scores (sample x mode) from whitened PC space back into original space."""

return X

0 comments on commit a6b61b3

Please sign in to comment.