Skip to content

Commit

Permalink
fix: test loaded model vs transform not scores
Browse files Browse the repository at this point in the history
  • Loading branch information
slevang committed Nov 20, 2023
1 parent e7d6d90 commit a9d6bdf
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 23 deletions.
4 changes: 1 addition & 3 deletions tests/models/test_eof.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,13 +509,11 @@ def test_save_load(dim, mock_data_array, tmp_path):

# Test that the recreated model can be used to transform new data
assert np.allclose(
original.scores(), loaded.transform(mock_data_array), rtol=1e-3, atol=1e-3
original.transform(mock_data_array), loaded.transform(mock_data_array)
)

# The loaded model should also be able to inverse_transform new data
assert np.allclose(
original.inverse_transform(original.scores()),
loaded.inverse_transform(loaded.scores()),
rtol=1e-3,
atol=1e-3,
)
3 changes: 1 addition & 2 deletions tests/models/test_eof_rotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,12 +235,11 @@ def test_save_load(dim, mock_data_array, tmp_path):

# Test that the recreated model can be used to transform new data
assert np.allclose(
original.scores(), loaded.transform(mock_data_array), rtol=1e-3, atol=1e-3
original.transform(mock_data_array), loaded.transform(mock_data_array)
)

# The loaded model should also be able to inverse_transform new data
assert np.allclose(
original.inverse_transform(original.scores()),
loaded.inverse_transform(loaded.scores()),
rtol=1e-2,
)
6 changes: 1 addition & 5 deletions tests/models/test_mca.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,16 +422,12 @@ def test_save_load(dim, mock_data_array, tmp_path):

# Test that the recreated model can be used to transform new data
assert np.allclose(
original.scores(),
original.transform(mock_data_array, mock_data_array),
loaded.transform(mock_data_array, mock_data_array),
rtol=1e-3,
atol=1e-3,
)

# The loaded model should also be able to inverse_transform new data
assert np.allclose(
original.inverse_transform(*original.scores()),
loaded.inverse_transform(*loaded.scores()),
rtol=1e-3,
atol=1e-3,
)
8 changes: 2 additions & 6 deletions tests/models/test_mca_rotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,16 +294,12 @@ def test_save_load(dim, mock_data_array, tmp_path):

# Test that the recreated model can be used to transform new data
assert np.allclose(
original.scores(),
loaded.transform(data1=mock_data_array, data2=mock_data_array),
rtol=1e-3,
atol=1e-3,
original.transform(mock_data_array, mock_data_array),
loaded.transform(mock_data_array, mock_data_array),
)

# The loaded model should also be able to inverse_transform new data
assert np.allclose(
original.inverse_transform(*original.scores()),
loaded.inverse_transform(*loaded.scores()),
rtol=1e-3,
atol=1e-3,
)
14 changes: 7 additions & 7 deletions xeofs/models/mca_rotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .mca import MCA, ComplexMCA
from ..preprocessing.preprocessor import Preprocessor
from ..utils.rotation import promax
from ..utils.data_types import DataArray
from ..utils.data_types import DataArray, DataObject
from ..utils.xarray_utils import argsort_dask, get_deterministic_sign_multiplier
from ..data_container import DataContainer
from .._version import __version__
Expand Down Expand Up @@ -319,7 +319,9 @@ def _sort_by_variance(self):
)
self.sorted = True

def transform(self, **kwargs) -> DataArray | List[DataArray]:
def transform(
self, data1: DataObject | None = None, data2: DataObject | None = None
) -> DataArray | List[DataArray]:
"""Project new "unseen" data onto the rotated singular vectors.
Parameters
Expand All @@ -336,7 +338,7 @@ def transform(self, **kwargs) -> DataArray | List[DataArray]:
"""
# raise error if no data is provided
if not kwargs:
if data1 is None and data2 is None:
raise ValueError("No data provided. Please provide data1 and/or data2.")

n_modes = self._params["n_modes"]
Expand All @@ -348,8 +350,7 @@ def transform(self, **kwargs) -> DataArray | List[DataArray]:

results = []

if "data1" in kwargs.keys():
data1 = kwargs["data1"]
if data1 is not None:
# Select the (non-rotated) singular vectors of the first dataset
comps1 = self.model.data["components1"].sel(mode=slice(1, n_modes))
norm1 = self.model.data["norm1"].sel(mode=slice(1, n_modes))
Expand All @@ -375,8 +376,7 @@ def transform(self, **kwargs) -> DataArray | List[DataArray]:

results.append(projections1)

if "data2" in kwargs.keys():
data2 = kwargs["data2"]
if data2 is not None:
# Select the (non-rotated) singular vectors of the second dataset
comps2 = self.model.data["components2"].sel(mode=slice(1, n_modes))
norm2 = self.model.data["norm2"].sel(mode=slice(1, n_modes))
Expand Down

0 comments on commit a9d6bdf

Please sign in to comment.