diff --git a/README.md b/README.md
index 13669944..69fdb951 100644
--- a/README.md
+++ b/README.md
@@ -146,25 +146,26 @@ mae = cal_mae(imputation, X_intact, indicating_mask) # calculate mean absolute
## β Available Algorithms
PyPOTS supports imputation, classification, clustering, and forecasting tasks on multivariate time series with missing values. The currently available algorithms of four tasks are cataloged in the following table with four partitions. The paper references are all listed at the bottom of this readme file. Please refer to them if you want more details.
-| ***`Imputation`*** | π₯ | π₯ | π₯ |
-|:----------------------:|:------------:|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:--------:|
-| **Type** | **Abbr.** | **Full name of the algorithm/model/paper** | **Year** |
-| Neural Net | SAITS | Self-Attention-based Imputation for Time Series [^1] | 2023 |
-| Neural Net | Transformer | Attention is All you Need [^2];
Self-Attention-based Imputation for Time Series [^1];
Note: proposed in [^2], and re-implemented as an imputation model in [^1]. | 2017 |
-| Neural Net | BRITS | Bidirectional Recurrent Imputation for Time Series [^3] | 2018 |
-| Naive | LOCF | Last Observation Carried Forward | - |
-| ***`Classification`*** | π₯ | π₯ | π₯ |
-| **Type** | **Abbr.** | **Full name of the algorithm/model/paper** | **Year** |
-| Neural Net | BRITS | Bidirectional Recurrent Imputation for Time Series [^3] | 2018 |
-| Neural Net | GRU-D | Recurrent Neural Networks for Multivariate Time Series with Missing Values [^4] | 2018 |
-| Neural Net | Raindrop | Graph-Guided Network for Irregularly Sampled Multivariate Time Series [^5] | 2022 |
-| ***`Clustering`*** | π₯ | π₯ | π₯ |
-| **Type** | **Abbr.** | **Full name of the algorithm/model/paper** | **Year** |
-| Neural Net | CRLI | Clustering Representation Learning on Incomplete time-series data [^6] | 2021 |
-| Neural Net | VaDER | Variational Deep Embedding with Recurrence [^7] | 2019 |
-| ***`Forecasting`*** | π₯ | π₯ | π₯ |
-| **Type** | **Abbr.** | **Full name of the algorithm/model/paper** | **Year** |
-| Probabilistic | BTTF | Bayesian Temporal Tensor Factorization [^8] | 2021 |
+| ***`Imputation`*** | π₯ | π₯ | π₯ |
+|:----------------------:|:-----------:|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:--------:|
+| **Type** | **Abbr.** | **Full name of the algorithm/model/paper** | **Year** |
+| Neural Net | SAITS | Self-Attention-based Imputation for Time Series [^1] | 2023 |
+| Neural Net | Transformer | Attention is All you Need [^2];
Self-Attention-based Imputation for Time Series [^1];
Note: proposed in [^2], and re-implemented as an imputation model in [^1]. | 2017 |
+| Neural Net | BRITS | Bidirectional Recurrent Imputation for Time Series [^3] | 2018 |
+| Neural Net | M-RNN | Multi-directional Recurrent Neural Network [^9] | 2019 |
+| Naive | LOCF | Last Observation Carried Forward | - |
+| ***`Classification`*** | π₯ | π₯ | π₯ |
+| **Type** | **Abbr.** | **Full name of the algorithm/model/paper** | **Year** |
+| Neural Net | BRITS | Bidirectional Recurrent Imputation for Time Series [^3] | 2018 |
+| Neural Net | GRU-D | Recurrent Neural Networks for Multivariate Time Series with Missing Values [^4] | 2018 |
+| Neural Net | Raindrop | Graph-Guided Network for Irregularly Sampled Multivariate Time Series [^5] | 2022 |
+| ***`Clustering`*** | π₯ | π₯ | π₯ |
+| **Type** | **Abbr.** | **Full name of the algorithm/model/paper** | **Year** |
+| Neural Net | CRLI | Clustering Representation Learning on Incomplete time-series data [^6] | 2021 |
+| Neural Net | VaDER | Variational Deep Embedding with Recurrence [^7] | 2019 |
+| ***`Forecasting`*** | π₯ | π₯ | π₯ |
+| **Type** | **Abbr.** | **Full name of the algorithm/model/paper** | **Year** |
+| Probabilistic | BTTF | Bayesian Temporal Tensor Factorization [^8] | 2021 |
## β Citing PyPOTS
@@ -254,6 +255,8 @@ Thank you all for your attention! π
[^6]: Ma, Q., Chen, C., Li, S., & Cottrell, G. W. (2021). [Learning Representations for Incomplete Time Series Clustering](https://ojs.aaai.org/index.php/AAAI/article/view/17070). *AAAI 2021*.
[^7]: Jong, J.D., Emon, M.A., Wu, P., Karki, R., Sood, M., Godard, P., Ahmad, A., Vrooman, H.A., Hofmann-Apitius, M., & FrΓΆhlich, H. (2019). [Deep learning for clustering of multivariate clinical patient trajectories with missing values](https://academic.oup.com/gigascience/article/8/11/giz134/5626377). *GigaScience*.
[^8]: Chen, X., & Sun, L. (2021). [Bayesian Temporal Factorization for Multidimensional Time Series Prediction](https://arxiv.org/abs/1910.06366). *IEEE transactions on pattern analysis and machine intelligence*.
+[^9]: Yoon, J., Zame, W. R., & van der Schaar, M. (2019). [Estimating Missing Data in Temporal Data Streams Using Multi-Directional Recurrent Neural Networks](https://ieeexplore.ieee.org/document/8485748). *IEEE Transactions on Biomedical Engineering*.
+
π Visits
diff --git a/docs/index.rst b/docs/index.rst
index 6c264ff9..29411c15 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -128,6 +128,7 @@ Task Type Algorithm
Imputation Neural Network SAITS (Self-Attention-based Imputation for Time Series) 2022 :cite:`du2023SAITS`
Imputation Neural Network Transformer 2017 :cite:`vaswani2017Transformer`, :cite:`du2023SAITS`
Imputation, Classification Neural Network BRITS (Bidirectional Recurrent Imputation for Time Series) 2018 :cite:`cao2018BRITS`
+Imputation Neural Network M-RNN (Multi-directional Recurrent Neural Network) 2019 :cite:`yoon2019MRNN`
Imputation Naive LOCF (Last Observation Carried Forward) / /
Classification Neural Network GRU-D 2018 :cite:`che2018GRUD`
Classification Neural Network Raindrop 2022 :cite:`zhang2022Raindrop`
@@ -136,6 +137,7 @@ Clustering Neural Network VaDER (Variational Deep Embeddin
Forecasting Probabilistic BTTF (Bayesian Temporal Tensor Factorization) 2021 :cite:`chen2021BTMF`
============================== ================ ========================================================================= ====== =========
+[^9]: Yoon, J., Zame, W. R., & van der Schaar, M. (2019). [Estimating Missing Data in Temporal Data Streams Using Multi-Directional Recurrent Neural Networks](https://ieeexplore.ieee.org/document/8485748). *IEEE Transactions on Biomedical Engineering*.
β Citing PyPOTS
^^^^^^^^^^^^^^^^
diff --git a/docs/pypots.imputation.rst b/docs/pypots.imputation.rst
index 50ff9d11..0e31f8c8 100644
--- a/docs/pypots.imputation.rst
+++ b/docs/pypots.imputation.rst
@@ -28,6 +28,15 @@ pypots.imputation.brits module
:show-inheritance:
:inherited-members:
+pypots.imputation.mrnn module
+------------------------------
+
+.. automodule:: pypots.imputation.mrnn
+ :members:
+ :undoc-members:
+ :show-inheritance:
+ :inherited-members:
+
pypots.imputation.locf module
-----------------------------
diff --git a/docs/references.bib b/docs/references.bib
index 050110c1..214b69b8 100644
--- a/docs/references.bib
+++ b/docs/references.bib
@@ -12,6 +12,18 @@ @article{cao2018BRITS
keywords = {Computer Science - Machine Learning,Statistics - Machine Learning}
}
+@ARTICLE{yoon2019MRNN,
+author={Yoon, Jinsung and Zame, William R. and van der Schaar, Mihaela},
+journal={IEEE Transactions on Biomedical Engineering},
+title={Estimating Missing Data in Temporal Data Streams Using Multi-Directional Recurrent Neural Networks},
+year={2019},
+volume={66},
+number={5},
+pages={1477-1490},
+doi={10.1109/TBME.2018.2874712}
+}
+
+
@article{che2018GRUD,
title = {Recurrent {{Neural Networks}} for {{Multivariate Time Series}} with {{Missing Values}}},
author = {Che, Zhengping and Purushotham, Sanjay and Cho, Kyunghyun and Sontag, David and Liu, Yan},
diff --git a/pypots/__init__.py b/pypots/__init__.py
index 8cedd5b4..b7a737b6 100644
--- a/pypots/__init__.py
+++ b/pypots/__init__.py
@@ -22,7 +22,7 @@
#
# Dev branch marker is: 'X.Y.dev' or 'X.Y.devN' where N is an integer.
# 'X.Y.dev0' is the canonical version of 'X.Y.dev'
-__version__ = "0.1.0"
+__version__ = "0.1.1"
__all__ = [
diff --git a/pypots/imputation/__init__.py b/pypots/imputation/__init__.py
index 7e6878d1..9de8d0bc 100644
--- a/pypots/imputation/__init__.py
+++ b/pypots/imputation/__init__.py
@@ -9,10 +9,12 @@
from .locf import LOCF
from .saits import SAITS
from .transformer import Transformer
+from .mrnn import MRNN
__all__ = [
- "BRITS",
- "Transformer",
"SAITS",
+ "Transformer",
+ "BRITS",
+ "MRNN",
"LOCF",
]
diff --git a/pypots/imputation/mrnn/__init__.py b/pypots/imputation/mrnn/__init__.py
new file mode 100644
index 00000000..f5934cca
--- /dev/null
+++ b/pypots/imputation/mrnn/__init__.py
@@ -0,0 +1,13 @@
+"""
+
+"""
+
+# Created by Wenjie Du
+# License: GLP-v3
+
+from .model import MRNN
+
+
+__all__ = [
+ "MRNN",
+]
diff --git a/pypots/imputation/mrnn/data.py b/pypots/imputation/mrnn/data.py
new file mode 100644
index 00000000..1c8b3f1c
--- /dev/null
+++ b/pypots/imputation/mrnn/data.py
@@ -0,0 +1,46 @@
+"""
+Dataset class for model MRNN.
+"""
+
+# Created by Wenjie Du
+# License: GLP-v3
+
+from typing import Union
+
+from ..brits.data import DatasetForBRITS
+
+
+class DatasetForMRNN(DatasetForBRITS):
+ """Dataset class for BRITS.
+
+ Parameters
+ ----------
+ data : dict or str,
+ The dataset for model input, should be a dictionary including keys as 'X' and 'y',
+ or a path string locating a data file.
+ If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
+ which is time-series data for input, can contain missing values, and y should be array-like of shape
+ [n_samples], which is classification labels of X.
+ If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
+ key-value pairs like a dict, and it has to include keys as 'X' and 'y'.
+
+ return_labels : bool, default = True,
+ Whether to return labels in function __getitem__() if they exist in the given data. If `True`, for example,
+ during training of classification models, the Dataset class will return labels in __getitem__() for model input.
+ Otherwise, labels won't be included in the data returned by __getitem__(). This parameter exists because we
+ need the defined Dataset class for all training/validating/testing stages. For those big datasets stored in h5
+ files, they already have both X and y saved. But we don't read labels from the file for validating and testing
+ with function _fetch_data_from_file(), which works for all three stages. Therefore, we need this parameter for
+ distinction.
+
+ file_type : str, default = "h5py"
+ The type of the given file if train_set and val_set are path strings.
+ """
+
+ def __init__(
+ self,
+ data: Union[dict, str],
+ return_labels: bool = True,
+ file_type: str = "h5py",
+ ):
+ super().__init__(data, return_labels, file_type)
diff --git a/pypots/imputation/mrnn/model.py b/pypots/imputation/mrnn/model.py
new file mode 100644
index 00000000..459e5ab1
--- /dev/null
+++ b/pypots/imputation/mrnn/model.py
@@ -0,0 +1,310 @@
+"""
+PyTorch MRNN model for the time-series imputation task.
+Some part of the code is from https://github.com/WenjieDu/SAITS.
+
+"""
+
+# Created by Wenjie Du
+# License: GLP-v3
+
+
+from typing import Union, Optional
+
+import h5py
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+
+from .data import DatasetForMRNN
+from .module import FCN_Regression
+from ..base import BaseNNImputer
+from ...optim.adam import Adam
+from ...optim.base import Optimizer
+from ...utils.metrics import cal_rmse
+
+
+class _MRNN(nn.Module):
+ def __init__(self, seq_len, feature_num, rnn_hidden_size, device):
+ super().__init__()
+ # data settings
+ self.seq_len = seq_len
+ self.feature_num = feature_num
+ self.rnn_hidden_size = rnn_hidden_size
+ self.device = device
+
+ self.f_rnn = nn.GRUCell(self.feature_num * 3, self.rnn_hidden_size)
+ self.b_rnn = nn.GRUCell(self.feature_num * 3, self.rnn_hidden_size)
+ self.rnn_cells = {"forward": self.f_rnn, "backward": self.b_rnn}
+ self.concated_hidden_project = nn.Linear(
+ self.rnn_hidden_size * 2, self.feature_num
+ )
+ self.fcn_regression = FCN_Regression(feature_num, rnn_hidden_size)
+
+ def gene_hidden_states(self, data, direction):
+ values = data[direction]["X"]
+ masks = data[direction]["missing_mask"]
+ deltas = data[direction]["deltas"]
+
+ hidden_states_collector = []
+ hidden_state = torch.zeros(
+ (values.size()[0], self.rnn_hidden_size), device=self.device
+ )
+
+ for t in range(self.seq_len):
+ x = values[:, t, :]
+ m = masks[:, t, :]
+ d = deltas[:, t, :]
+ inputs = torch.cat([x, m, d], dim=1)
+ hidden_state = self.rnn_cells[direction](inputs, hidden_state)
+ hidden_states_collector.append(hidden_state)
+ return hidden_states_collector
+
+ def forward(self, inputs, training=True):
+ hidden_states_f = self.gene_hidden_states(inputs, "forward")
+ hidden_states_b = self.gene_hidden_states(inputs, "backward")[::-1]
+
+ X = inputs["forward"]["X"]
+ masks = inputs["forward"]["missing_mask"]
+
+ reconstruction_loss = 0
+ estimations = []
+ for i in range(
+ self.seq_len
+ ): # calculating estimation loss for times can obtain better results than once
+ x = X[:, i, :]
+ m = masks[:, i, :]
+ h_f = hidden_states_f[i]
+ h_b = hidden_states_b[i]
+ h = torch.cat([h_f, h_b], dim=1)
+ RNN_estimation = self.concated_hidden_project(h) # xΜ_t
+ RNN_imputed_data = m * x + (1 - m) * RNN_estimation
+ FCN_estimation = self.fcn_regression(
+ x, m, RNN_imputed_data
+ ) # FCN estimation is output estimation
+ reconstruction_loss += cal_rmse(FCN_estimation, x, m) + cal_rmse(
+ RNN_estimation, x, m
+ )
+ estimations.append(FCN_estimation.unsqueeze(dim=1))
+
+ estimations = torch.cat(estimations, dim=1)
+ imputed_data = masks * X + (1 - masks) * estimations
+
+ if not training:
+ # if not in training mode, return the classification result only
+ return {
+ "imputed_data": imputed_data,
+ }
+
+ reconstruction_loss /= self.seq_len
+
+ ret_dict = {
+ "loss": reconstruction_loss,
+ "imputed_data": imputed_data,
+ }
+ return ret_dict
+
+
+class MRNN(BaseNNImputer):
+ """The PyTorch implementation of the MRNN model :cite:`yoon2019MRNN`.
+
+ Parameters
+ ----------
+ rnn_hidden_size :
+ The size of the RNN hidden state, also the number of hidden units in the RNN cell.
+
+ batch_size :
+ The batch size for training and evaluating the model.
+
+ epochs :
+ The number of epochs for training the model.
+
+ patience :
+ The patience for the early-stopping mechanism. Given a positive integer, the training process will be
+ stopped when the model does not perform better after that number of epochs.
+ Leaving it default as None will disable the early-stopping.
+
+ optimizer :
+ The optimizer for model training.
+ If not given, will use a default Adam optimizer.
+
+ num_workers :
+ The number of subprocesses to use for data loading.
+ `0` means data loading will be in the main process, i.e. there won't be subprocesses.
+
+ device :
+ The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them.
+ If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple),
+ then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models.
+ If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the
+ model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices).
+ Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future.
+
+ saving_path :
+ The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during
+ training into a tensorboard file). Will not save if not given.
+
+ model_saving_strategy :
+ The strategy to save model checkpoints. It has to be one of [None, "best", "better"].
+ No model will be saved when it is set as None.
+ The "best" strategy will only automatically save the best model after the training finished.
+ The "better" strategy will automatically save the model during training whenever the model performs
+ better than in previous epochs.
+
+ Attributes
+ ----------
+ model : :class:`torch.nn.Module`
+ The underlying BRITS model.
+
+ optimizer : :class:`pypots.optim.Optimizer`
+ The optimizer for model training.
+
+ """
+
+ def __init__(
+ self,
+ n_steps: int,
+ n_features: int,
+ rnn_hidden_size: int,
+ batch_size: int = 32,
+ epochs: int = 100,
+ patience: int = None,
+ optimizer: Optional[Optimizer] = Adam(),
+ num_workers: int = 0,
+ device: Optional[Union[str, torch.device, list]] = None,
+ saving_path: str = None,
+ model_saving_strategy: Optional[str] = "best",
+ ):
+ super().__init__(
+ batch_size,
+ epochs,
+ patience,
+ num_workers,
+ device,
+ saving_path,
+ model_saving_strategy,
+ )
+
+ self.n_steps = n_steps
+ self.n_features = n_features
+ self.rnn_hidden_size = rnn_hidden_size
+
+ self.model = _MRNN(
+ self.n_steps,
+ self.n_features,
+ self.rnn_hidden_size,
+ self.device,
+ )
+ self._send_model_to_given_device()
+ self._print_model_size()
+
+ # set up the optimizer
+ self.optimizer = optimizer
+ self.optimizer.init_optimizer(self.model.parameters())
+
+ def _assemble_input_for_training(self, data: list) -> dict:
+ # fetch data
+ (
+ indices,
+ X,
+ missing_mask,
+ deltas,
+ back_X,
+ back_missing_mask,
+ back_deltas,
+ ) = self._send_data_to_given_device(data)
+
+ # assemble input data
+ inputs = {
+ "indices": indices,
+ "forward": {
+ "X": X,
+ "missing_mask": missing_mask,
+ "deltas": deltas,
+ },
+ "backward": {
+ "X": back_X,
+ "missing_mask": back_missing_mask,
+ "deltas": back_deltas,
+ },
+ }
+
+ return inputs
+
+ def _assemble_input_for_validating(self, data: list) -> dict:
+ return self._assemble_input_for_training(data)
+
+ def _assemble_input_for_testing(self, data: list) -> dict:
+ return self._assemble_input_for_validating(data)
+
+ def fit(
+ self,
+ train_set: Union[dict, str],
+ val_set: Optional[Union[dict, str]] = None,
+ file_type: str = "h5py",
+ ) -> None:
+ # Step 1: wrap the input data with classes Dataset and DataLoader
+ training_set = DatasetForMRNN(
+ train_set, return_labels=False, file_type=file_type
+ )
+ training_loader = DataLoader(
+ training_set,
+ batch_size=self.batch_size,
+ shuffle=True,
+ num_workers=self.num_workers,
+ )
+ val_loader = None
+ if val_set is not None:
+ if isinstance(val_set, str):
+ with h5py.File(val_set, "r") as hf:
+ # Here we read the whole validation set from the file to mask a portion for validation.
+ # In PyPOTS, using a file usually because the data is too big. However, the validation set is
+ # generally shouldn't be too large. For example, we have 1 billion samples for model training.
+ # We won't take 20% of them as the validation set because we want as much as possible data for the
+ # training stage to enhance the model's generalization ability. Therefore, 100,000 representative
+ # samples will be enough to validate the model.
+ val_set = {
+ "X": hf["X"][:],
+ "X_intact": hf["X_intact"][:],
+ "indicating_mask": hf["indicating_mask"][:],
+ }
+ val_set = DatasetForMRNN(val_set, return_labels=False, file_type=file_type)
+ val_loader = DataLoader(
+ val_set,
+ batch_size=self.batch_size,
+ shuffle=False,
+ num_workers=self.num_workers,
+ )
+
+ # Step 2: train the model and freeze it
+ self._train_model(training_loader, val_loader)
+ self.model.load_state_dict(self.best_model_dict)
+ self.model.eval() # set the model as eval status to freeze it.
+
+ # Step 3: save the model if necessary
+ self._auto_save_model_if_necessary(training_finished=True)
+
+ def impute(
+ self,
+ X: Union[dict, str],
+ file_type="h5py",
+ ) -> np.ndarray:
+ self.model.eval() # set the model as eval status to freeze it.
+ test_set = DatasetForMRNN(X, return_labels=False, file_type=file_type)
+ test_loader = DataLoader(
+ test_set,
+ batch_size=self.batch_size,
+ shuffle=False,
+ num_workers=self.num_workers,
+ )
+ imputation_collector = []
+
+ with torch.no_grad():
+ for idx, data in enumerate(test_loader):
+ inputs = self._assemble_input_for_testing(data)
+ results = self.model.forward(inputs, training=False)
+ imputed_data = results["imputed_data"]
+ imputation_collector.append(imputed_data)
+
+ imputation_collector = torch.cat(imputation_collector)
+ return imputation_collector.cpu().detach().numpy()
diff --git a/pypots/imputation/mrnn/module.py b/pypots/imputation/mrnn/module.py
new file mode 100644
index 00000000..873d2d73
--- /dev/null
+++ b/pypots/imputation/mrnn/module.py
@@ -0,0 +1,48 @@
+"""
+
+"""
+
+# Created by Wenjie Du
+# License: GLP-v3
+
+
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.parameter import Parameter
+
+from ...imputation.brits.modules import FeatureRegression
+
+
+class FCN_Regression(nn.Module):
+ def __init__(self, feature_num, rnn_hid_size):
+ super(FCN_Regression, self).__init__()
+ self.feat_reg = FeatureRegression(rnn_hid_size * 2)
+ self.U = Parameter(torch.Tensor(feature_num, feature_num))
+ self.V1 = Parameter(torch.Tensor(feature_num, feature_num))
+ self.V2 = Parameter(torch.Tensor(feature_num, feature_num))
+ self.beta = Parameter(torch.Tensor(feature_num)) # bias beta
+ self.final_linear = nn.Linear(feature_num, feature_num)
+
+ m = torch.ones(feature_num, feature_num) - torch.eye(feature_num, feature_num)
+ self.register_buffer("m", m)
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ stdv = 1.0 / math.sqrt(self.U.size(0))
+ self.U.data.uniform_(-stdv, stdv)
+ self.V1.data.uniform_(-stdv, stdv)
+ self.V2.data.uniform_(-stdv, stdv)
+ self.beta.data.uniform_(-stdv, stdv)
+
+ def forward(self, x_t, m_t, target):
+ h_t = F.tanh(
+ F.linear(x_t, self.U * self.m)
+ + F.linear(target, self.V1 * self.m)
+ + F.linear(m_t, self.V2)
+ + self.beta
+ )
+ x_hat_t = self.final_linear(h_t)
+ return x_hat_t
diff --git a/tests/test_imputation.py b/tests/test_imputation.py
index 24929b61..6094ce62 100644
--- a/tests/test_imputation.py
+++ b/tests/test_imputation.py
@@ -16,6 +16,7 @@
SAITS,
Transformer,
BRITS,
+ MRNN,
LOCF,
)
from pypots.optim import Adam
@@ -262,6 +263,75 @@ def test_3_saving_path(self):
self.brits.load_model(saved_model_path)
+class TestMRNN(unittest.TestCase):
+ logger.info("Running tests for an imputation model MRNN...")
+
+ # set the log and model saving path
+ saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "MRNN")
+ model_save_name = "saved_MRNN_model.pypots"
+
+ # initialize an Adam optimizer
+ optimizer = Adam(lr=0.001, weight_decay=1e-5)
+
+ # initialize a MRNN model
+ mrnn = MRNN(
+ DATA["n_steps"],
+ DATA["n_features"],
+ 256,
+ epochs=EPOCH,
+ saving_path=f"{RESULT_SAVING_DIR_FOR_IMPUTATION}/MRNN",
+ optimizer=optimizer,
+ )
+
+ @pytest.mark.xdist_group(name="imputation-mrnn")
+ def test_0_fit(self):
+ self.mrnn.fit(TRAIN_SET, VAL_SET)
+
+ @pytest.mark.xdist_group(name="imputation-mrnn")
+ def test_1_impute(self):
+ imputed_X = self.mrnn.impute(TEST_SET)
+ assert not np.isnan(
+ imputed_X
+ ).any(), "Output still has missing values after running impute()."
+ test_MAE = cal_mae(
+ imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"]
+ )
+ logger.info(f"MRNN test_MAE: {test_MAE}")
+
+ @pytest.mark.xdist_group(name="imputation-mrnn")
+ def test_2_parameters(self):
+ assert hasattr(self.mrnn, "model") and self.mrnn.model is not None
+
+ assert hasattr(self.mrnn, "optimizer") and self.mrnn.optimizer is not None
+
+ assert hasattr(self.mrnn, "best_loss")
+ self.assertNotEqual(self.mrnn.best_loss, float("inf"))
+
+ assert (
+ hasattr(self.mrnn, "best_model_dict")
+ and self.mrnn.best_model_dict is not None
+ )
+
+ @pytest.mark.xdist_group(name="imputation-mrnn")
+ def test_3_saving_path(self):
+ # whether the root saving dir exists, which should be created by save_log_into_tb_file
+ assert os.path.exists(
+ self.saving_path
+ ), f"file {self.saving_path} does not exist"
+
+ # check if the tensorboard file and model checkpoints exist
+ check_tb_and_model_checkpoints_existence(self.mrnn)
+
+ # save the trained model into file, and check if the path exists
+ self.mrnn.save_model(
+ saving_dir=self.saving_path, file_name=self.model_save_name
+ )
+
+ # test loading the saved model, not necessary, but need to test
+ saved_model_path = os.path.join(self.saving_path, self.model_save_name)
+ self.mrnn.load_model(saved_model_path)
+
+
class TestLOCF(unittest.TestCase):
logger.info("Running tests for an imputation model LOCF...")
locf = LOCF(nan=0)