diff --git a/README.md b/README.md index 597c02cd..af4e6d12 100644 --- a/README.md +++ b/README.md @@ -160,13 +160,14 @@ The paper references and links are all listed at the bottom of this file. | Neural Net | GRU-D[^4] | ✅ | | ✅ | | | `2018 - Sci. Rep.` | | Neural Net | TCN🧑‍🔧[^35] | ✅ | | | | | `2018 - arXiv` | | Neural Net | Transformer🧑‍🔧[^2] | ✅ | | | | | `2017 - NeurIPS` | +| MF | TRMF🧑‍🔧[^44] | ✅ | | | | | `2016 - NeurIPS` | | Naive | Lerp[^40] | ✅ | | | | | | | Naive | LOCF/NOCB | ✅ | | | | | | | Naive | Mean | ✅ | | | | | | | Naive | Median | ✅ | | | | | | 💯 Contribute your model right now to increase your research impact! PyPOTS downloads are increasing rapidly -(**[300K+ in total and 1K+ daily on PyPI so far](https://www.pepy.tech/projects/pypots)**), +(**[600K+ in total and 1K+ daily on PyPI so far](https://www.pepy.tech/projects/pypots)**), and your work will be widely used and cited by the community. Refer to the [contribution guide](https://github.com/WenjieDu/PyPOTS#-contribution) to see how to include your model in PyPOTS. @@ -517,3 +518,6 @@ Time-Series.AI [^43]: Lin, S., Lin, W., Wu, W., Zhao, F., Mo, R., & Zhang, H. (2023). [SegRNN: Segment Recurrent Neural Network for Long-Term Time Series Forecasting](https://arxiv.org/abs/2308.11200). *arXiv 2023*. +[^44]: Yu, H. F., Rao, N., & Dhillon, I. S. (2016). +[Temporal regularized matrix factorization for high-dimensional time series prediction](https://papers.nips.cc/paper_files/paper/2016/file/85422afb467e9456013a2a51d4dff702-Paper.pdf). +*NeurIPS 2016*. diff --git a/README_zh.md b/README_zh.md index dc64581c..2e58cd96 100644 --- a/README_zh.md +++ b/README_zh.md @@ -145,13 +145,14 @@ PyPOTS当前支持多变量POTS数据的插补, 预测, 分类, 聚类以及异 | Neural Net | GRU-D[^4] | ✅ | | ✅ | | | `2018 - Sci. Rep.` | | Neural Net | TCN🧑‍🔧[^35] | ✅ | | | | | `2018 - arXiv` | | Neural Net | Transformer🧑‍🔧[^2] | ✅ | | | | | `2017 - NeurIPS` | +| MF | TRMF🧑‍🔧[^44] | ✅ | | | | | `2016 - NeurIPS` | | Naive | Lerp[^40] | ✅ | | | | | | | Naive | LOCF/NOCB | ✅ | | | | | | | Naive | Mean | ✅ | | | | | | | Naive | Median | ✅ | | | | | | 💯 现在贡献你的模型来增加你的研究影响力!PyPOTS的下载量正在迅速增长 -(**[目前PyPI上总共超过30万次且每日超1000的下载](https://www.pepy.tech/projects/pypots)**), +(**[目前PyPI上总共超过60万次且每日超1000的下载](https://www.pepy.tech/projects/pypots)**), 你的工作将被社区广泛使用和引用. 请参阅[贡献指南](#-%E8%B4%A1%E7%8C%AE%E5%A3%B0%E6%98%8E) , 了解如何将模型包含在PyPOTS中. @@ -490,3 +491,6 @@ Time-Series.AI [^43]: Lin, S., Lin, W., Wu, W., Zhao, F., Mo, R., & Zhang, H. (2023). [SegRNN: Segment Recurrent Neural Network for Long-Term Time Series Forecasting](https://arxiv.org/abs/2308.11200). *arXiv 2023*. +[^44]: Yu, H. F., Rao, N., & Dhillon, I. S. (2016). +[Temporal regularized matrix factorization for high-dimensional time series prediction](https://papers.nips.cc/paper_files/paper/2016/file/85422afb467e9456013a2a51d4dff702-Paper.pdf). +*NeurIPS 2016*. diff --git a/docs/index.rst b/docs/index.rst index 76aee376..0b255b9b 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -213,6 +213,8 @@ The paper references are all listed at the bottom of this readme file. 
+----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+ | Neural Net | Transformer🧑‍🔧 :cite:`vaswani2017Transformer` | ✅ | | | | | ``2017 - NeurIPS`` | +----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+ +| MF | TRMF :cite:`yu2016trmf` | ✅ | | | | | ``2016 - NeurIPS`` | ++----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+ | Naive | Lerp (Linear Interpolation) | ✅ | | | | | | +----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+ | Naive | LOCF/NOCB | ✅ | | | | | | @@ -222,7 +224,7 @@ The paper references are all listed at the bottom of this readme file. | Naive | Mean | ✅ | | | | | | +----------------+-----------------------------------------------------------+------+------+------+------+------+-----------------------+ -💯 Contribute your model right now to increase your research impact! PyPOTS downloads are increasing rapidly (`300K+ in total and 1K+ daily on PyPI so far `_), +💯 Contribute your model right now to increase your research impact! PyPOTS downloads are increasing rapidly (`600K+ in total and 1K+ daily on PyPI so far `_), and your work will be widely used and cited by the community. Refer to the `contribution guide <#id44>`_ to see how to include your model in PyPOTS. diff --git a/docs/pypots.imputation.rst b/docs/pypots.imputation.rst index 73e55646..05d13392 100644 --- a/docs/pypots.imputation.rst +++ b/docs/pypots.imputation.rst @@ -324,6 +324,15 @@ pypots.imputation.tcn :show-inheritance: :inherited-members: +pypots.imputation.trmf +------------------------------ + +.. automodule:: pypots.imputation.trmf + :members: + :undoc-members: + :show-inheritance: + :inherited-members: + pypots.imputation.lerp ----------------------------- diff --git a/docs/references.bib b/docs/references.bib index 10a9632c..75f880d1 100644 --- a/docs/references.bib +++ b/docs/references.bib @@ -784,4 +784,12 @@ @article{qian2023csai author={Qian, Linglong and Ibrahim, Zina and Ellis, Hugh Logan and Zhang, Ao and Zhang, Yuezhou and Wang, Tao and Dobson, Richard}, journal={arXiv preprint arXiv:2312.16713}, year={2023} +} + +@article{yu2016trmf, +title={Temporal regularized matrix factorization for high-dimensional time series prediction}, +author={Yu, Hsiang-Fu and Rao, Nikhil and Dhillon, Inderjit S}, +journal={Advances in neural information processing systems}, +volume={29}, +year={2016} } \ No newline at end of file diff --git a/pypots/data/dataset/base.py b/pypots/data/dataset/base.py index a9b309c3..0e2df052 100644 --- a/pypots/data/dataset/base.py +++ b/pypots/data/dataset/base.py @@ -16,6 +16,7 @@ from torch.utils.data import Dataset from .config import SUPPORTED_DATASET_FILE_FORMATS +from ..saving import load_dict_from_h5 from ..utils import turn_data_into_specified_dtype @@ -435,6 +436,21 @@ def _fetch_data_from_file(self, idx: int) -> Iterable: return sample + def fetch_entire_dataset(self) -> dict: + """Fetch the entire dataset from the given data source. + + Returns + ------- + data : + The entire dataset in a dictionary fetched from the given data source. 
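+            If the data source is a path to an h5 file, the entire file will be loaded into memory at once.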
+ + """ + if isinstance(self.data, str): # data from file + data = load_dict_from_h5(self.data) + else: + data = self.data + return data + def __getitem__(self, idx: int) -> Iterable: """Fetch data according to index. diff --git a/pypots/imputation/__init__.py b/pypots/imputation/__init__.py index 9ce1d867..2ca28a71 100644 --- a/pypots/imputation/__init__.py +++ b/pypots/imputation/__init__.py @@ -39,13 +39,14 @@ from .timemixer import TimeMixer from .moderntcn import ModernTCN from .segrnn import SegRNN +from .tefn import TEFN +from .trmf import TRMF # naive imputation methods from .locf import LOCF from .mean import Mean from .median import Median from .lerp import Lerp -from .tefn import TEFN __all__ = [ # neural network imputation methods @@ -81,12 +82,13 @@ "ImputeFormer", "TimeMixer", "ModernTCN", + "TEFN", + "CSAI", + "SegRNN", + "TRMF", # naive imputation methods "LOCF", "Mean", "Median", "Lerp", - "TEFN", - "CSAI", - "SegRNN", ] diff --git a/pypots/imputation/trmf/__init__.py b/pypots/imputation/trmf/__init__.py new file mode 100644 index 00000000..44bac00d --- /dev/null +++ b/pypots/imputation/trmf/__init__.py @@ -0,0 +1,19 @@ +""" +The package including the modules of TRMF. + +Refer to the paper +`Hsiang-Fu Yu, Nikhil Rao, and Inderjit S. Dhillon. +"Temporal regularized matrix factorization for high-dimensional time series prediction." +In NeurIPS 2016. +`_ + +""" + +# Created by Jun Wang and Wenjie Du +# License: BSD-3-Clause + +from .model import TRMF + +__all__ = [ + "TRMF", +] diff --git a/pypots/imputation/trmf/core.py b/pypots/imputation/trmf/core.py new file mode 100644 index 00000000..298ac686 --- /dev/null +++ b/pypots/imputation/trmf/core.py @@ -0,0 +1,48 @@ +""" + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +import torch.nn as nn + +from ...nn.modules.trmf import BackboneTRMF + + +class _TRMF(nn.Module): + def __init__( + self, + lags, + K, + lambda_f, + lambda_x, + lambda_w, + alpha, + eta, + max_iter, + F_step=0.0001, + X_step=0.0001, + W_step=0.0001, + ): + super().__init__() + + self.backbone = BackboneTRMF( + lags, + K, + lambda_f, + lambda_x, + lambda_w, + alpha, + eta, + max_iter, + F_step, + X_step, + W_step, + ) + + def forward(self, inputs: dict) -> dict: + X, missing_mask = inputs["X"], inputs["missing_mask"] + self.backbone.forward(X, missing_mask) + results = {"loss": 0, "imputed_data": self.backbone.impute_missingness()} + return results diff --git a/pypots/imputation/trmf/data.py b/pypots/imputation/trmf/data.py new file mode 100644 index 00000000..39cf2f81 --- /dev/null +++ b/pypots/imputation/trmf/data.py @@ -0,0 +1,23 @@ +""" +Dataset class for the imputation model TRMF. 
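+It only fetches the observed time series 'X'; ground-truth 'X_ori', prediction targets, and labels 'y' are not used by TRMF.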
+""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +from typing import Union + +from ...data.dataset import BaseDataset + + +class DatasetForTRMF(BaseDataset): + def __init__( + self, + data: Union[dict, str], + ): + super().__init__( + data, + return_X_ori=False, + return_X_pred=False, + return_y=False, + ) diff --git a/pypots/imputation/trmf/model.py b/pypots/imputation/trmf/model.py new file mode 100644 index 00000000..be4df4e5 --- /dev/null +++ b/pypots/imputation/trmf/model.py @@ -0,0 +1,243 @@ +""" +The implementation of TRMF for the partially-observed time-series imputation task, +which is mainly based on the implementation of TRMF in https://github.com/SemenovAlex/trmf/blob/master/trmf.py + +""" + +# Created by Jun Wang and Wenjie Du +# License: BSD-3-Clause + + +from typing import Union, Optional + +import numpy as np +import torch +from pygrinder import fill_and_get_mask_numpy + +from .core import _TRMF +from .data import DatasetForTRMF +from ..base import BaseImputer +from ...data import inverse_sliding_window, sliding_window +from ...data.dataset import BaseDataset +from ...utils.logging import logger + + +class TRMF(BaseImputer): + """Temporal Regularized Matrix Factorization (TRMF) imputation method. + + Parameters + ---------- + + lags : array-like, shape (n_lags,) + Set of lag indices to use in model. + + K : int + Length of latent embedding dimension + + lambda_f : float + Regularization parameter used for matrix F. + + lambda_x : float + Regularization parameter used for matrix X. + + lambda_w : float + Regularization parameter used for matrix W. + + alpha : float + Regularization parameter used for make the sum of lag coefficient close to 1. + That helps to avoid big deviations when forecasting. + + eta : float + Regularization parameter used for X when undercovering autoregressive dependencies. + + max_iter : int + Number of iterations of updating matrices F, X and W. + + F_step : float + Step of gradient descent when updating matrix F. + + X_step : float + Step of gradient descent when updating matrix X. + + W_step : float + Step of gradient descent when updating matrix W. + + saving_path : + The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during + training into a tensorboard file). Will not save if not given. + + model_saving_strategy : + The strategy to save model checkpoints. It has to be one of [None, "best", "better", "all"]. + No model will be saved when it is set as None. + The "best" strategy will only automatically save the best model after the training finished. + The "better" strategy will automatically save the model during training whenever the model performs + better than in previous epochs. + The "all" strategy will save every model after each epoch training. + + verbose : + Whether to print out the training logs during the training process. + + """ + + def __init__( + self, + lags, + K, + lambda_f, + lambda_x, + lambda_w, + alpha, + eta, + max_iter=1000, + F_step=0.0001, + X_step=0.0001, + W_step=0.0001, + saving_path: Optional[str] = None, + model_saving_strategy: Optional[str] = "best", + verbose: bool = True, + ): + super().__init__( + saving_path=saving_path, + model_saving_strategy=model_saving_strategy, + verbose=verbose, + ) + + self.model = _TRMF( + lags, + K, + lambda_f, + lambda_x, + lambda_w, + alpha, + eta, + max_iter, + F_step, + X_step, + W_step, + ) + + logger.warning( + "‼️Note that, as a traditional matrix factorization function, TRMF does not support validation set. 
" + "Also, it only accepts 2-dim (time dim, feature dim) time series data, hence PyPOTS auto runs " + "inverse_sliding_window func for your input in the unified format with 3-dim (sample dim, time dim, " + "feature dim) and it assumes your samples window_len == sliding_len. If you generate samples " + "using sliding_window func with window_len != sliding_len, it may produce non-ideal results." + ) + + def fit( + self, + train_set: Union[dict, str], + val_set: Optional[Union[dict, str]] = None, + file_type: str = "hdf5", + ) -> None: + # Step 1: wrap the input data with classes Dataset and DataLoader + training_set = DatasetForTRMF(train_set) + if val_set is not None: + raise RuntimeError("TRMF does not support validation set.") + + # Step 2: train the model and freeze it + X = training_set.fetch_entire_dataset()["X"] + X = inverse_sliding_window(X, training_set.n_steps) + if isinstance(X, torch.Tensor): + X = X.numpy() + X, missing_mask = fill_and_get_mask_numpy(X) + + inputs = { + "X": X, + "missing_mask": missing_mask, + } + + self.model.forward(inputs) + + # Step 3: save the model if necessary + self._auto_save_model_if_necessary(confirm_saving=self.model_saving_strategy == "best") + + def predict( + self, + test_set: Union[dict, str], + file_type: str = "hdf5", + diagonal_attention_mask: bool = True, + return_latent_vars: bool = False, + ) -> dict: + """Make predictions for the input data with the trained model. + + Parameters + ---------- + test_set : + The dataset for model validating, should be a dictionary including keys as 'X', + or a path string locating a data file supported by PyPOTS (e.g. h5 file). + If it is a dict, X should be array-like of shape [n_samples, sequence length (n_steps), n_features], + which is time-series data for validating, can contain missing values, and y should be array-like of shape + [n_samples], which is classification labels of X. + If it is a path string, the path should point to a data file, e.g. a h5 file, which contains + key-value pairs like a dict, and it has to include keys as 'X' and 'y'. + + file_type : + The type of the given file if test_set is a path string. + + diagonal_attention_mask : + Whether to apply a diagonal attention mask to the self-attention mechanism in the testing stage. + + return_latent_vars : + Whether to return the latent variables in SAITS, e.g. attention weights of two DMSA blocks and + the weight matrix from the combination block, etc. + + Returns + ------- + file_type : + The dictionary containing the clustering results and latent variables if necessary. + + """ + # Step 1: wrap the input data with classes Dataset and DataLoader + # self.model.eval() # set the model as eval status to freeze it. + test_set = BaseDataset( + test_set, + return_X_ori=False, + return_X_pred=False, + return_y=False, + file_type=file_type, + ) + + X = test_set.fetch_entire_dataset()["X"] + X = inverse_sliding_window(X, test_set.n_steps) + if isinstance(X, torch.Tensor): + X = X.numpy() + X, missing_mask = fill_and_get_mask_numpy(X) + inputs = { + "X": X, + "missing_mask": missing_mask, + } + + # Step 3: output collection and return + results = self.model.forward(inputs) + imputation = sliding_window(results["imputed_data"], test_set.n_steps) + result_dict = { + "imputation": imputation, + } + + return result_dict + + def impute( + self, + test_set: Union[dict, str], + file_type: str = "hdf5", + ) -> np.ndarray: + """Impute missing values in the given data with the trained model. 
+
+        Parameters
+        ----------
+        test_set :
+            The data samples for testing, should be array-like with shape [n_samples, sequence length (n_steps),
+            n_features], or a path string locating a data file, e.g. h5 file.
+
+        file_type :
+            The type of the given file if test_set is a path string.
+
+        Returns
+        -------
+        array-like, shape [n_samples, sequence length (n_steps), n_features]
+            Imputed data.
+        """
+
+        result_dict = self.predict(test_set, file_type=file_type)
+        return result_dict["imputation"]
diff --git a/pypots/nn/modules/trmf/__init__.py b/pypots/nn/modules/trmf/__init__.py
new file mode 100644
index 00000000..838b3eb6
--- /dev/null
+++ b/pypots/nn/modules/trmf/__init__.py
@@ -0,0 +1,12 @@
+"""
+
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+from .backbone import BackboneTRMF
+
+__all__ = [
+    "BackboneTRMF",
+]
diff --git a/pypots/nn/modules/trmf/backbone.py b/pypots/nn/modules/trmf/backbone.py
new file mode 100644
index 00000000..c4cf86d5
--- /dev/null
+++ b/pypots/nn/modules/trmf/backbone.py
@@ -0,0 +1,286 @@
+"""
+
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+import numpy as np
+import torch.nn as nn
+
+
+class BackboneTRMF(nn.Module):
+    """The backbone of Temporal Regularized Matrix Factorization (TRMF).
+
+    Parameters
+    ----------
+    lags : array-like, shape (n_lags,)
+        Set of lag indices to use in the model.
+
+    K : int
+        Length of the latent embedding dimension.
+
+    lambda_f : float
+        Regularization parameter used for matrix F.
+
+    lambda_x : float
+        Regularization parameter used for matrix X.
+
+    lambda_w : float
+        Regularization parameter used for matrix W.
+
+    alpha : float
+        Regularization parameter used to push the sum of the lag coefficients close to 1,
+        which helps to avoid big deviations when forecasting.
+
+    eta : float
+        Regularization parameter used for X when uncovering autoregressive dependencies.
+
+    max_iter : int
+        Number of iterations of updating matrices F, X and W.
+
+    F_step : float
+        Step size of gradient descent when updating matrix F.
+
+    X_step : float
+        Step size of gradient descent when updating matrix X.
+
+    W_step : float
+        Step size of gradient descent when updating matrix W.
+
+    Attributes
+    ----------
+
+    F : ndarray, shape (n_timeseries, K)
+        Latent embedding of timeseries.
+
+    X : ndarray, shape (K, n_timepoints)
+        Latent embedding of timepoints.
+
+    W : ndarray, shape (K, n_lags)
+        Matrix of autoregressive coefficients.
+
+    """
+
+    def __init__(
+        self,
+        lags,
+        K,
+        lambda_f,
+        lambda_x,
+        lambda_w,
+        alpha,
+        eta,
+        max_iter,
+        F_step=0.0001,
+        X_step=0.0001,
+        W_step=0.0001,
+    ):
+        super().__init__()
+
+        self.lags = lags
+        self.L = len(lags)
+        self.K = K
+        self.lambda_f = lambda_f
+        self.lambda_x = lambda_x
+        self.lambda_w = lambda_w
+        self.alpha = alpha
+        self.eta = eta
+        self.max_iter = max_iter
+        self.F_step = F_step
+        self.X_step = X_step
+        self.W_step = W_step
+
+        self.W = None
+        self.F = None
+        self.X = None
+
+        self.Y = None
+        self.mask = None
+        self.N = None
+        self.T = None
+
+        # Flag to check if the model has been trained.
+        # It is useful when we want to continue training the model without re-initialization after it was trained.
+        self.trained = False
+
+    def forward(self, X, missing_mask):
+        # Transpose the input data to have the same dimensionality as in the original TRMF implementation
+        X = X.T
+        missing_mask = missing_mask.T
+
+        self.Y = X.copy()
+        self.mask = missing_mask.copy()
+        self.N, self.T = self.Y.shape
+
+        # if not trained yet, initialize the matrices
+        # (note that continuing training assumes the same number of timepoints T as before)
+        if not self.trained:
+            self.W = np.random.randn(self.K, self.L) / self.L
+            self.F = np.random.randn(self.N, self.K)
+            self.X = np.random.randn(self.K, self.T)
+            self.trained = True
+
+        # alternately update F, X, and W with gradient descent
+        for _ in range(self.max_iter):
+            self._update_F(step=self.F_step)
+            self._update_X(step=self.X_step)
+            self._update_W(step=self.W_step)
+
+    def impute_missingness(self):
+        """Impute every missing element in the timeseries.
+
+        The model uses matrices F and X to reconstruct all missing elements.
+
+        Returns
+        -------
+        data : ndarray, shape (n_timepoints, n_timeseries)
+            Data with the missing elements filled by F @ X, transposed back to the input orientation.
+        """
+        data = self.Y.copy()  # copy to avoid modifying the stored observations
+        data[self.mask == 0] = np.dot(self.F, self.X)[self.mask == 0]
+        return data.T
+
+    def _update_F(self, step, n_iter=1):
+        """Run n_iter steps of gradient descent on matrix F.
+
+        Parameters
+        ----------
+        step : float
+            Step size of the gradient descent update.
+
+        n_iter : int
+            Number of gradient steps to be made.
+        """
+        for _ in range(n_iter):
+            self.F -= step * self._grad_F()
+
+    def _grad_F(self):
+        """Evaluate the gradient of the objective with respect to matrix F.
+
+        Returns
+        -------
+        ndarray, shape (n_timeseries, K)
+            Gradient of F.
+        """
+        return -2 * np.dot((self.Y - np.dot(self.F, self.X)) * self.mask, self.X.T) + 2 * self.lambda_f * self.F
+
+    def _update_X(self, step, n_iter=1):
+        """Run n_iter steps of gradient descent on matrix X.
+
+        Parameters
+        ----------
+        step : float
+            Step size of the gradient descent update.
+
+        n_iter : int
+            Number of gradient steps to be made.
+        """
+        for _ in range(n_iter):
+            self.X -= step * self._grad_X()
+
+    def _grad_X(self):
+        """Evaluate the gradient of the objective with respect to matrix X.
+
+        Returns
+        -------
+        ndarray, shape (K, n_timepoints)
+            Gradient of X.
+        """
+        for i in range(self.L):
+            lag = self.lags[i]
+            W_i = self.W[:, i].repeat(self.T, axis=0).reshape(self.K, self.T)
+            X_i = self.X * W_i
+            z_1 = self.X - np.roll(X_i, lag, axis=1)
+            z_1[:, : max(self.lags)] = 0.0
+            z_2 = -(np.roll(self.X, -lag, axis=1) - X_i) * W_i
+            z_2[:, -lag:] = 0.0
+
+        # NOTE: following the reference implementation, only the last lag's terms
+        # are kept in the gradient of the temporal regularizer here.
+        grad_T_x = z_1 + z_2
+        return (
+            -2 * np.dot(self.F.T, self.mask * (self.Y - np.dot(self.F, self.X)))
+            + self.lambda_x * grad_T_x
+            + self.eta * self.X
+        )
+
+    def _update_W(self, step, n_iter=1):
+        """Run n_iter steps of gradient descent on matrix W.
+
+        Parameters
+        ----------
+        step : float
+            Step size of the gradient descent update.
+
+        n_iter : int
+            Number of gradient steps to be made.
+        """
+        for _ in range(n_iter):
+            self.W -= step * self._grad_W()
+
+    def _grad_W(self):
+        """Evaluate the gradient of the objective with respect to matrix W.
+
+        Returns
+        -------
+        ndarray, shape (K, n_lags)
+            Gradient of W.
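+
+        Notes
+        -----
+        The temporal regularizer treats each latent dimension k as an independent
+        autoregressive series, x[k, t] ~ sum_i W[k, i] * x[k, t - lags[i]].
+        Besides that AR-fit term, this gradient includes the ridge penalty
+        (lambda_w / lambda_x) * ||W||^2 and the alpha term that pulls each row sum
+        of W towards 1.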
+ """ + + grad = np.zeros((self.K, self.L)) + for i in range(self.L): + lag = self.lags[i] + W_i = self.W[:, i].repeat(self.T, axis=0).reshape(self.K, self.T) + X_i = self.X * W_i + z_1 = self.X - np.roll(X_i, lag, axis=1) + z_1[:, : max(self.lags)] = 0.0 + z_2 = -(z_1 * np.roll(self.X, lag, axis=1)).sum(axis=1) + grad[:, i] = z_2 + return ( + grad + + self.W * 2 * self.lambda_w / self.lambda_x + - self.alpha * 2 * (1 - self.W.sum(axis=1)).repeat(self.L).reshape(self.W.shape) + ) diff --git a/tests/imputation/trmf.py b/tests/imputation/trmf.py new file mode 100644 index 00000000..b5509da3 --- /dev/null +++ b/tests/imputation/trmf.py @@ -0,0 +1,107 @@ +""" +Test cases for TRMF imputation model. +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + + +import os.path +import unittest + +import numpy as np +import pytest + +from pypots.imputation import TRMF +from pypots.utils.logging import logger +from pypots.utils.metrics import calc_mse +from pypots.utils.visual.data import plot_data, plot_missingness +from tests.global_test_config import ( + DATA, + TRAIN_SET, + TEST_SET, + GENERAL_H5_TRAIN_SET_PATH, + RESULT_SAVING_DIR_FOR_IMPUTATION, + check_tb_and_model_checkpoints_existence, +) + + +class TestTRMF(unittest.TestCase): + logger.info("Running tests for an imputation model TRMF...") + + # set the log and model saving path + saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "TRMF") + model_save_name = "saved_trmf_model.pypots" + + # initialize a TRMF model + trmf = TRMF( + (1, 5), + K=8, + lambda_f=1, + lambda_x=1, + lambda_w=1, + alpha=1, + eta=1000, + max_iter=1000, + saving_path=saving_path, + ) + + @pytest.mark.xdist_group(name="imputation-trmf") + def test_0_fit(self): + self.trmf.fit(TRAIN_SET) + + @pytest.mark.xdist_group(name="imputation-trmf") + def test_1_impute(self): + imputation_results = self.trmf.predict(TRAIN_SET, return_latent_vars=True) + assert not np.isnan( + imputation_results["imputation"] + ).any(), "Output still has missing values after running impute()." 
+
+        # compute the error only on the artificially-masked positions of the train set
+        test_MSE = calc_mse(
+            imputation_results["imputation"],
+            DATA["train_X_ori"],
+            np.isnan(DATA["train_X"]) ^ np.isnan(DATA["train_X_ori"]),
+        )
+        logger.info(f"TRMF test_MSE: {test_MSE}")
+
+        # plot the missingness and the imputed data of the train set imputed above
+        plot_missingness(~np.isnan(DATA["train_X"]), 0, imputation_results["imputation"].shape[1])
+        plot_data(DATA["train_X"], DATA["train_X_ori"], imputation_results["imputation"])
+
+    @pytest.mark.xdist_group(name="imputation-trmf")
+    def test_2_parameters(self):
+        assert hasattr(self.trmf, "model") and self.trmf.model is not None
+
+    @pytest.mark.xdist_group(name="imputation-trmf")
+    def test_3_saving_path(self):
+        # whether the root saving dir exists, which should be created by save_log_into_tb_file
+        assert os.path.exists(self.saving_path), f"file {self.saving_path} does not exist"
+
+        # check if the tensorboard file and model checkpoints exist
+        check_tb_and_model_checkpoints_existence(self.trmf)
+
+        # save the trained model into file, and check if the path exists
+        saved_model_path = os.path.join(self.saving_path, self.model_save_name)
+        self.trmf.save(saved_model_path)
+
+        # test loading the saved model; not strictly necessary, but worth covering
+        self.trmf.load(saved_model_path)
+
+    @pytest.mark.xdist_group(name="imputation-trmf")
+    def test_4_lazy_loading(self):
+        self.trmf.fit(GENERAL_H5_TRAIN_SET_PATH)
+        imputation_results = self.trmf.predict(GENERAL_H5_TRAIN_SET_PATH)
+        assert not np.isnan(
+            imputation_results["imputation"]
+        ).any(), "Output still has missing values after running impute()."
+
+        test_MSE = calc_mse(
+            imputation_results["imputation"],
+            DATA["train_X_ori"],
+            np.isnan(DATA["train_X"]) ^ np.isnan(DATA["train_X_ori"]),
+        )
+        logger.info(f"Lazy-loading TRMF test_MSE: {test_MSE}")
+
+
+if __name__ == "__main__":
+    unittest.main()
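
For quick reference, below is a minimal usage sketch of the new `TRMF` imputer. The random data, shapes, and hyperparameter values are illustrative assumptions only, not part of this changeset:

```python
import numpy as np

from pypots.imputation import TRMF

# Toy data: 48 samples of 24 steps x 10 features, with NaNs marking missing values.
# TRMF flattens these windows back into one (time, feature) matrix internally,
# so the samples are assumed to come from non-overlapping sliding windows
# (window_len == sliding_len).
X = np.random.randn(48, 24, 10)
X[X < -1.5] = np.nan

trmf = TRMF(
    lags=[1, 2, 3],  # AR lag indices used by the temporal regularizer
    K=8,  # latent dimensionality of the factorization
    lambda_f=1.0,
    lambda_x=1.0,
    lambda_w=1.0,
    alpha=1.0,
    eta=1.0,
    max_iter=100,
)
trmf.fit({"X": X})  # no validation set: TRMF is a traditional MF method
imputation = trmf.impute({"X": X})  # ndarray of shape (48, 24, 10) with NaNs filled
```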