From f1c78806d4f5b1681f6fe91bb9d9a8530be93136 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Tue, 21 Jun 2022 10:28:55 +0800 Subject: [PATCH 01/10] feat: add MRNN for the imputation task; --- pypots/imputation/mrnn.py | 146 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 146 insertions(+) create mode 100644 pypots/imputation/mrnn.py diff --git a/pypots/imputation/mrnn.py b/pypots/imputation/mrnn.py new file mode 100644 index 00000000..3b350bd6 --- /dev/null +++ b/pypots/imputation/mrnn.py @@ -0,0 +1,146 @@ +""" +PyTorch MRNN model for the time-series imputation task. +Some part of the code is from https://github.com/WenjieDu/SAITS. + +""" + +# Created by Wenjie Du +# License: GLP-v3 + + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.parameter import Parameter + +from pypots.imputation.brits import FeatureRegression +from pypots.utils.metrics import cal_mae, cal_rmse + + +class FCN_Regression(nn.Module): + def __init__(self, feature_num, rnn_hid_size): + super(FCN_Regression, self).__init__() + self.feat_reg = FeatureRegression(rnn_hid_size * 2) + self.U = Parameter(torch.Tensor(feature_num, feature_num)) + self.V1 = Parameter(torch.Tensor(feature_num, feature_num)) + self.V2 = Parameter(torch.Tensor(feature_num, feature_num)) + self.beta = Parameter(torch.Tensor(feature_num)) # bias beta + self.final_linear = nn.Linear(feature_num, feature_num) + + m = torch.ones(feature_num, feature_num) - torch.eye(feature_num, feature_num) + self.register_buffer("m", m) + self.reset_parameters() + + def reset_parameters(self): + stdv = 1.0 / math.sqrt(self.U.size(0)) + self.U.data.uniform_(-stdv, stdv) + self.V1.data.uniform_(-stdv, stdv) + self.V2.data.uniform_(-stdv, stdv) + self.beta.data.uniform_(-stdv, stdv) + + def forward(self, x_t, m_t, target): + h_t = F.tanh( + F.linear(x_t, self.U * self.m) + + F.linear(target, self.V1 * self.m) + + F.linear(m_t, self.V2) + + self.beta + ) + x_hat_t = self.final_linear(h_t) + return x_hat_t + + +class MRNN(nn.Module): + def __init__(self, seq_len, feature_num, rnn_hidden_size, **kwargs): + super(MRNN, self).__init__() + # data settings + self.seq_len = seq_len + self.feature_num = feature_num + self.rnn_hidden_size = rnn_hidden_size + self.device = kwargs["device"] + + self.f_rnn = nn.GRUCell(self.feature_num * 3, self.rnn_hidden_size) + self.b_rnn = nn.GRUCell(self.feature_num * 3, self.rnn_hidden_size) + self.rnn_cells = {"forward": self.f_rnn, "backward": self.b_rnn} + self.concated_hidden_project = nn.Linear( + self.rnn_hidden_size * 2, self.feature_num + ) + self.fcn_regression = FCN_Regression(feature_num, rnn_hidden_size) + + def gene_hidden_states(self, data, direction): + values = data[direction]["X"] + masks = data[direction]["missing_mask"] + deltas = data[direction]["deltas"] + + hidden_states_collector = [] + hidden_state = torch.zeros( + (values.size()[0], self.rnn_hidden_size), device=self.device + ) + + for t in range(self.seq_len): + x = values[:, t, :] + m = masks[:, t, :] + d = deltas[:, t, :] + inputs = torch.cat([x, m, d], dim=1) + hidden_state = self.rnn_cells[direction](inputs, hidden_state) + hidden_states_collector.append(hidden_state) + return hidden_states_collector + + def impute(self, data): + hidden_states_f = self.gene_hidden_states(data, "forward") + hidden_states_b = self.gene_hidden_states(data, "backward")[::-1] + + values = data["forward"]["X"] + masks = data["forward"]["missing_mask"] + + reconstruction_loss = 0 + estimations = [] + for i in range( + self.seq_len + ): # calculating estimation loss for times can obtain better results than once + x = values[:, i, :] + m = masks[:, i, :] + h_f = hidden_states_f[i] + h_b = hidden_states_b[i] + h = torch.cat([h_f, h_b], dim=1) + RNN_estimation = self.concated_hidden_project(h) # xΜƒ_t + RNN_imputed_data = m * x + (1 - m) * RNN_estimation + FCN_estimation = self.fcn_regression( + x, m, RNN_imputed_data + ) # FCN estimation is output extimation + reconstruction_loss += cal_rmse(FCN_estimation, x, m) + cal_rmse( + RNN_estimation, x, m + ) + estimations.append(FCN_estimation.unsqueeze(dim=1)) + + estimations = torch.cat(estimations, dim=1) + imputed_data = masks * values + (1 - masks) * estimations + return imputed_data, [estimations, reconstruction_loss] + + def forward(self, data, stage): + values = data["forward"]["X"] + masks = data["forward"]["missing_mask"] + imputed_data, [estimations, reconstruction_loss] = self.impute(data) + reconstruction_loss /= self.seq_len + reconstruction_MAE = cal_mae(estimations.detach(), values, masks) + + if stage == "val": + # have to cal imputation loss in the val stage; no need to cal imputation loss here in the test stage + imputation_MAE = cal_mae( + imputed_data, data["X_holdout"], data["indicating_mask"] + ) + else: + imputation_MAE = torch.tensor(0.0) + + ret_dict = { + "reconstruction_loss": reconstruction_loss, + "reconstruction_MAE": reconstruction_MAE, + "imputation_loss": imputation_MAE, + "imputation_MAE": imputation_MAE, + "imputed_data": imputed_data, + } + if "X_holdout" in data: + ret_dict["X_holdout"] = data["X_holdout"] + ret_dict["indicating_mask"] = data["indicating_mask"] + return ret_dict From f87e31e087da6021f8d7f1b2a0b04c41ff85ccda Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Thu, 23 Jun 2022 13:56:03 +0800 Subject: [PATCH 02/10] feat: update MRNN; --- pypots/imputation/mrnn.py | 106 +++++++++++++++++++++++++++++++++++--- 1 file changed, 100 insertions(+), 6 deletions(-) diff --git a/pypots/imputation/mrnn.py b/pypots/imputation/mrnn.py index 3b350bd6..7e99b36c 100644 --- a/pypots/imputation/mrnn.py +++ b/pypots/imputation/mrnn.py @@ -14,9 +14,15 @@ import torch.nn as nn import torch.nn.functional as F from torch.nn.parameter import Parameter +from torch.utils.data import DataLoader +from pypots.data.base import BaseDataset +from pypots.data.dataset_for_brits import DatasetForBRITS +from pypots.data.integration import mcar, masked_fill +from pypots.imputation.base import BaseNNImputer from pypots.imputation.brits import FeatureRegression -from pypots.utils.metrics import cal_mae, cal_rmse +from pypots.utils.metrics import cal_mae +from pypots.utils.metrics import cal_rmse class FCN_Regression(nn.Module): @@ -51,14 +57,14 @@ def forward(self, x_t, m_t, target): return x_hat_t -class MRNN(nn.Module): - def __init__(self, seq_len, feature_num, rnn_hidden_size, **kwargs): - super(MRNN, self).__init__() +class _MRNN(nn.Module): + def __init__(self, seq_len, feature_num, rnn_hidden_size, device): + super().__init__() # data settings self.seq_len = seq_len self.feature_num = feature_num self.rnn_hidden_size = rnn_hidden_size - self.device = kwargs["device"] + self.device = device self.f_rnn = nn.GRUCell(self.feature_num * 3, self.rnn_hidden_size) self.b_rnn = nn.GRUCell(self.feature_num * 3, self.rnn_hidden_size) @@ -108,7 +114,7 @@ def impute(self, data): RNN_imputed_data = m * x + (1 - m) * RNN_estimation FCN_estimation = self.fcn_regression( x, m, RNN_imputed_data - ) # FCN estimation is output extimation + ) # FCN estimation is output estimation reconstruction_loss += cal_rmse(FCN_estimation, x, m) + cal_rmse( RNN_estimation, x, m ) @@ -144,3 +150,91 @@ def forward(self, data, stage): ret_dict["X_holdout"] = data["X_holdout"] ret_dict["indicating_mask"] = data["indicating_mask"] return ret_dict + + +class MRNN(BaseNNImputer): + def __init__( + self, + n_steps, + n_features, + rnn_hidden_size, + learning_rate=1e-3, + epochs=100, + patience=10, + batch_size=32, + weight_decay=1e-5, + device=None, + ): + super().__init__( + learning_rate, epochs, patience, batch_size, weight_decay, device + ) + + self.n_steps = n_steps + self.n_features = n_features + # model hype-parameters + self.rnn_hidden_size = rnn_hidden_size + + self.model = _MRNN( + self.n_steps, self.n_features, self.rnn_hidden_size, self.device + ) + self.model = self.model.to(self.device) + self._print_model_size() + + def fit(self, train_X, val_X=None): + train_X = self.check_input(self.n_steps, self.n_features, train_X) + if val_X is not None: + val_X = self.check_input(self.n_steps, self.n_features, val_X) + + training_set = DatasetForBRITS(train_X) + training_loader = DataLoader(training_set, batch_size=self.batch_size, shuffle=True) + if val_X is None: + self._train_model(training_loader) + else: + val_X_intact, val_X, val_X_missing_mask, val_X_indicating_mask = mcar(val_X, 0.2) + val_X = masked_fill(val_X, 1 - val_X_missing_mask, torch.nan) + val_set = DatasetForBRITS(val_X) + val_loader = DataLoader(val_set, batch_size=self.batch_size, shuffle=False) + self._train_model(training_loader, val_loader, val_X_intact, val_X_indicating_mask) + + self.model.load_state_dict(self.best_model_dict) + self.model.eval() # set the model as eval status to freeze it. + + def assemble_input_data(self, data): + """ Assemble the input data into a dictionary. + + Parameters + ---------- + data : list + A list containing data fetched from Dataset by Dataload. + + Returns + ------- + inputs : dict + A dictionary with data assembled. + """ + indices, X_intact, X, missing_mask, indicating_mask = data + + inputs = { + 'X': X, + 'X_intact': X_intact, + 'missing_mask': missing_mask, + 'indicating_mask': indicating_mask + } + + return inputs + + def impute(self, X): + X = self.check_input(self.n_steps, self.n_features, X) + self.model.eval() # set the model as eval status to freeze it. + test_set = BaseDataset(X) + test_loader = DataLoader(test_set, batch_size=self.batch_size, shuffle=False) + imputation_collector = [] + + with torch.no_grad(): + for idx, data in enumerate(test_loader): + inputs = {'X': data[1], 'missing_mask': data[2]} + imputed_data, _ = self.model.impute(inputs) + imputation_collector.append(imputed_data) + + imputation_collector = torch.cat(imputation_collector) + return imputation_collector.cpu().detach().numpy() From 56d5d378e081b3330bf5342fdc5e4a82befd37e2 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Sun, 26 Jun 2022 16:27:53 +0800 Subject: [PATCH 03/10] feat: update MRNN; --- pypots/imputation/mrnn.py | 118 ++++++++++++++++++-------------------- 1 file changed, 55 insertions(+), 63 deletions(-) diff --git a/pypots/imputation/mrnn.py b/pypots/imputation/mrnn.py index 7e99b36c..a7197058 100644 --- a/pypots/imputation/mrnn.py +++ b/pypots/imputation/mrnn.py @@ -124,31 +124,23 @@ def impute(self, data): imputed_data = masks * values + (1 - masks) * estimations return imputed_data, [estimations, reconstruction_loss] - def forward(self, data, stage): - values = data["forward"]["X"] - masks = data["forward"]["missing_mask"] - imputed_data, [estimations, reconstruction_loss] = self.impute(data) + def forward(self, inputs, stage): + imputed_data, [_, reconstruction_loss] = self.impute(inputs) reconstruction_loss /= self.seq_len - reconstruction_MAE = cal_mae(estimations.detach(), values, masks) if stage == "val": # have to cal imputation loss in the val stage; no need to cal imputation loss here in the test stage imputation_MAE = cal_mae( - imputed_data, data["X_holdout"], data["indicating_mask"] + imputed_data, inputs["X_holdout"], inputs["indicating_mask"] ) else: imputation_MAE = torch.tensor(0.0) ret_dict = { "reconstruction_loss": reconstruction_loss, - "reconstruction_MAE": reconstruction_MAE, "imputation_loss": imputation_MAE, - "imputation_MAE": imputation_MAE, "imputed_data": imputed_data, } - if "X_holdout" in data: - ret_dict["X_holdout"] = data["X_holdout"] - ret_dict["indicating_mask"] = data["indicating_mask"] return ret_dict @@ -181,60 +173,60 @@ def __init__( self._print_model_size() def fit(self, train_X, val_X=None): - train_X = self.check_input(self.n_steps, self.n_features, train_X) - if val_X is not None: - val_X = self.check_input(self.n_steps, self.n_features, val_X) - - training_set = DatasetForBRITS(train_X) - training_loader = DataLoader(training_set, batch_size=self.batch_size, shuffle=True) - if val_X is None: - self._train_model(training_loader) - else: - val_X_intact, val_X, val_X_missing_mask, val_X_indicating_mask = mcar(val_X, 0.2) - val_X = masked_fill(val_X, 1 - val_X_missing_mask, torch.nan) - val_set = DatasetForBRITS(val_X) - val_loader = DataLoader(val_set, batch_size=self.batch_size, shuffle=False) - self._train_model(training_loader, val_loader, val_X_intact, val_X_indicating_mask) - - self.model.load_state_dict(self.best_model_dict) - self.model.eval() # set the model as eval status to freeze it. - - def assemble_input_data(self, data): - """ Assemble the input data into a dictionary. - - Parameters - ---------- - data : list - A list containing data fetched from Dataset by Dataload. + train_X = self.check_input(self.n_steps, self.n_features, train_X) + if val_X is not None: + val_X = self.check_input(self.n_steps, self.n_features, val_X) + + training_set = DatasetForBRITS(train_X) + training_loader = DataLoader(training_set, batch_size=self.batch_size, shuffle=True) + if val_X is None: + self._train_model(training_loader) + else: + val_X_intact, val_X, val_X_missing_mask, val_X_indicating_mask = mcar(val_X, 0.2) + val_X = masked_fill(val_X, 1 - val_X_missing_mask, torch.nan) + val_set = DatasetForBRITS(val_X) + val_loader = DataLoader(val_set, batch_size=self.batch_size, shuffle=False) + self._train_model(training_loader, val_loader, val_X_intact, val_X_indicating_mask) - Returns - ------- - inputs : dict - A dictionary with data assembled. - """ - indices, X_intact, X, missing_mask, indicating_mask = data + self.model.load_state_dict(self.best_model_dict) + self.model.eval() # set the model as eval status to freeze it. - inputs = { - 'X': X, - 'X_intact': X_intact, - 'missing_mask': missing_mask, - 'indicating_mask': indicating_mask - } + def assemble_input_data(self, data): + """ Assemble the input data into a dictionary. + + Parameters + ---------- + data : list + A list containing data fetched from Dataset by Dataload. + + Returns + ------- + inputs : dict + A dictionary with data assembled. + """ + indices, X_intact, X, missing_mask, indicating_mask = data + + inputs = { + 'X': X, + 'X_intact': X_intact, + 'missing_mask': missing_mask, + 'indicating_mask': indicating_mask + } - return inputs + return inputs def impute(self, X): - X = self.check_input(self.n_steps, self.n_features, X) - self.model.eval() # set the model as eval status to freeze it. - test_set = BaseDataset(X) - test_loader = DataLoader(test_set, batch_size=self.batch_size, shuffle=False) - imputation_collector = [] - - with torch.no_grad(): - for idx, data in enumerate(test_loader): - inputs = {'X': data[1], 'missing_mask': data[2]} - imputed_data, _ = self.model.impute(inputs) - imputation_collector.append(imputed_data) - - imputation_collector = torch.cat(imputation_collector) - return imputation_collector.cpu().detach().numpy() + X = self.check_input(self.n_steps, self.n_features, X) + self.model.eval() # set the model as eval status to freeze it. + test_set = BaseDataset(X) + test_loader = DataLoader(test_set, batch_size=self.batch_size, shuffle=False) + imputation_collector = [] + + with torch.no_grad(): + for idx, data in enumerate(test_loader): + inputs = {'X': data[1], 'missing_mask': data[2]} + imputed_data, _ = self.model.impute(inputs) + imputation_collector.append(imputed_data) + + imputation_collector = torch.cat(imputation_collector) + return imputation_collector.cpu().detach().numpy() From 976606210e86287e201d85e87cf57f9a0a243ab9 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Mon, 27 Jun 2022 15:31:20 +0800 Subject: [PATCH 04/10] fix: update MRNN; --- pypots/imputation/mrnn.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/pypots/imputation/mrnn.py b/pypots/imputation/mrnn.py index a7197058..f89e6b31 100644 --- a/pypots/imputation/mrnn.py +++ b/pypots/imputation/mrnn.py @@ -172,27 +172,33 @@ def __init__( self.model = self.model.to(self.device) self._print_model_size() - def fit(self, train_X, val_X=None): + def fit(self, train_X, val_X=None): train_X = self.check_input(self.n_steps, self.n_features, train_X) if val_X is not None: val_X = self.check_input(self.n_steps, self.n_features, val_X) training_set = DatasetForBRITS(train_X) - training_loader = DataLoader(training_set, batch_size=self.batch_size, shuffle=True) + training_loader = DataLoader( + training_set, batch_size=self.batch_size, shuffle=True + ) if val_X is None: self._train_model(training_loader) else: - val_X_intact, val_X, val_X_missing_mask, val_X_indicating_mask = mcar(val_X, 0.2) + val_X_intact, val_X, val_X_missing_mask, val_X_indicating_mask = mcar( + val_X, 0.2 + ) val_X = masked_fill(val_X, 1 - val_X_missing_mask, torch.nan) val_set = DatasetForBRITS(val_X) val_loader = DataLoader(val_set, batch_size=self.batch_size, shuffle=False) - self._train_model(training_loader, val_loader, val_X_intact, val_X_indicating_mask) + self._train_model( + training_loader, val_loader, val_X_intact, val_X_indicating_mask + ) self.model.load_state_dict(self.best_model_dict) self.model.eval() # set the model as eval status to freeze it. def assemble_input_data(self, data): - """ Assemble the input data into a dictionary. + """Assemble the input data into a dictionary. Parameters ---------- @@ -207,10 +213,10 @@ def assemble_input_data(self, data): indices, X_intact, X, missing_mask, indicating_mask = data inputs = { - 'X': X, - 'X_intact': X_intact, - 'missing_mask': missing_mask, - 'indicating_mask': indicating_mask + "X": X, + "X_intact": X_intact, + "missing_mask": missing_mask, + "indicating_mask": indicating_mask, } return inputs @@ -224,7 +230,7 @@ def impute(self, X): with torch.no_grad(): for idx, data in enumerate(test_loader): - inputs = {'X': data[1], 'missing_mask': data[2]} + inputs = {"X": data[1], "missing_mask": data[2]} imputed_data, _ = self.model.impute(inputs) imputation_collector.append(imputed_data) From ebb84e1f3d2bfecf261a0721738f9761e1e0477d Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Thu, 30 Jun 2022 12:17:39 +0800 Subject: [PATCH 05/10] fix: update MRNN; --- pypots/imputation/mrnn.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/pypots/imputation/mrnn.py b/pypots/imputation/mrnn.py index f89e6b31..2e65ef71 100644 --- a/pypots/imputation/mrnn.py +++ b/pypots/imputation/mrnn.py @@ -124,17 +124,14 @@ def impute(self, data): imputed_data = masks * values + (1 - masks) * estimations return imputed_data, [estimations, reconstruction_loss] - def forward(self, inputs, stage): + def forward(self, inputs): imputed_data, [_, reconstruction_loss] = self.impute(inputs) reconstruction_loss /= self.seq_len - if stage == "val": - # have to cal imputation loss in the val stage; no need to cal imputation loss here in the test stage - imputation_MAE = cal_mae( - imputed_data, inputs["X_holdout"], inputs["indicating_mask"] - ) - else: - imputation_MAE = torch.tensor(0.0) + # have to cal imputation loss in the val stage; no need to cal imputation loss here in the test stage + imputation_MAE = cal_mae( + imputed_data, inputs["X_holdout"], inputs["indicating_mask"] + ) ret_dict = { "reconstruction_loss": reconstruction_loss, From 23c66f9e9639f752edbcec689b497cff2b42ed81 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Mon, 22 May 2023 00:10:31 +0800 Subject: [PATCH 06/10] feat: add MRNN to the current framework; --- pypots/imputation/__init__.py | 6 +- pypots/imputation/mrnn.py | 235 ---------------------- pypots/imputation/mrnn/__init__.py | 13 ++ pypots/imputation/mrnn/data.py | 46 +++++ pypots/imputation/mrnn/model.py | 310 +++++++++++++++++++++++++++++ pypots/imputation/mrnn/module.py | 48 +++++ tests/test_imputation.py | 70 +++++++ 7 files changed, 491 insertions(+), 237 deletions(-) delete mode 100644 pypots/imputation/mrnn.py create mode 100644 pypots/imputation/mrnn/__init__.py create mode 100644 pypots/imputation/mrnn/data.py create mode 100644 pypots/imputation/mrnn/model.py create mode 100644 pypots/imputation/mrnn/module.py diff --git a/pypots/imputation/__init__.py b/pypots/imputation/__init__.py index 7e6878d1..9de8d0bc 100644 --- a/pypots/imputation/__init__.py +++ b/pypots/imputation/__init__.py @@ -9,10 +9,12 @@ from .locf import LOCF from .saits import SAITS from .transformer import Transformer +from .mrnn import MRNN __all__ = [ - "BRITS", - "Transformer", "SAITS", + "Transformer", + "BRITS", + "MRNN", "LOCF", ] diff --git a/pypots/imputation/mrnn.py b/pypots/imputation/mrnn.py deleted file mode 100644 index 2e65ef71..00000000 --- a/pypots/imputation/mrnn.py +++ /dev/null @@ -1,235 +0,0 @@ -""" -PyTorch MRNN model for the time-series imputation task. -Some part of the code is from https://github.com/WenjieDu/SAITS. - -""" - -# Created by Wenjie Du -# License: GLP-v3 - - -import math - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.nn.parameter import Parameter -from torch.utils.data import DataLoader - -from pypots.data.base import BaseDataset -from pypots.data.dataset_for_brits import DatasetForBRITS -from pypots.data.integration import mcar, masked_fill -from pypots.imputation.base import BaseNNImputer -from pypots.imputation.brits import FeatureRegression -from pypots.utils.metrics import cal_mae -from pypots.utils.metrics import cal_rmse - - -class FCN_Regression(nn.Module): - def __init__(self, feature_num, rnn_hid_size): - super(FCN_Regression, self).__init__() - self.feat_reg = FeatureRegression(rnn_hid_size * 2) - self.U = Parameter(torch.Tensor(feature_num, feature_num)) - self.V1 = Parameter(torch.Tensor(feature_num, feature_num)) - self.V2 = Parameter(torch.Tensor(feature_num, feature_num)) - self.beta = Parameter(torch.Tensor(feature_num)) # bias beta - self.final_linear = nn.Linear(feature_num, feature_num) - - m = torch.ones(feature_num, feature_num) - torch.eye(feature_num, feature_num) - self.register_buffer("m", m) - self.reset_parameters() - - def reset_parameters(self): - stdv = 1.0 / math.sqrt(self.U.size(0)) - self.U.data.uniform_(-stdv, stdv) - self.V1.data.uniform_(-stdv, stdv) - self.V2.data.uniform_(-stdv, stdv) - self.beta.data.uniform_(-stdv, stdv) - - def forward(self, x_t, m_t, target): - h_t = F.tanh( - F.linear(x_t, self.U * self.m) - + F.linear(target, self.V1 * self.m) - + F.linear(m_t, self.V2) - + self.beta - ) - x_hat_t = self.final_linear(h_t) - return x_hat_t - - -class _MRNN(nn.Module): - def __init__(self, seq_len, feature_num, rnn_hidden_size, device): - super().__init__() - # data settings - self.seq_len = seq_len - self.feature_num = feature_num - self.rnn_hidden_size = rnn_hidden_size - self.device = device - - self.f_rnn = nn.GRUCell(self.feature_num * 3, self.rnn_hidden_size) - self.b_rnn = nn.GRUCell(self.feature_num * 3, self.rnn_hidden_size) - self.rnn_cells = {"forward": self.f_rnn, "backward": self.b_rnn} - self.concated_hidden_project = nn.Linear( - self.rnn_hidden_size * 2, self.feature_num - ) - self.fcn_regression = FCN_Regression(feature_num, rnn_hidden_size) - - def gene_hidden_states(self, data, direction): - values = data[direction]["X"] - masks = data[direction]["missing_mask"] - deltas = data[direction]["deltas"] - - hidden_states_collector = [] - hidden_state = torch.zeros( - (values.size()[0], self.rnn_hidden_size), device=self.device - ) - - for t in range(self.seq_len): - x = values[:, t, :] - m = masks[:, t, :] - d = deltas[:, t, :] - inputs = torch.cat([x, m, d], dim=1) - hidden_state = self.rnn_cells[direction](inputs, hidden_state) - hidden_states_collector.append(hidden_state) - return hidden_states_collector - - def impute(self, data): - hidden_states_f = self.gene_hidden_states(data, "forward") - hidden_states_b = self.gene_hidden_states(data, "backward")[::-1] - - values = data["forward"]["X"] - masks = data["forward"]["missing_mask"] - - reconstruction_loss = 0 - estimations = [] - for i in range( - self.seq_len - ): # calculating estimation loss for times can obtain better results than once - x = values[:, i, :] - m = masks[:, i, :] - h_f = hidden_states_f[i] - h_b = hidden_states_b[i] - h = torch.cat([h_f, h_b], dim=1) - RNN_estimation = self.concated_hidden_project(h) # xΜƒ_t - RNN_imputed_data = m * x + (1 - m) * RNN_estimation - FCN_estimation = self.fcn_regression( - x, m, RNN_imputed_data - ) # FCN estimation is output estimation - reconstruction_loss += cal_rmse(FCN_estimation, x, m) + cal_rmse( - RNN_estimation, x, m - ) - estimations.append(FCN_estimation.unsqueeze(dim=1)) - - estimations = torch.cat(estimations, dim=1) - imputed_data = masks * values + (1 - masks) * estimations - return imputed_data, [estimations, reconstruction_loss] - - def forward(self, inputs): - imputed_data, [_, reconstruction_loss] = self.impute(inputs) - reconstruction_loss /= self.seq_len - - # have to cal imputation loss in the val stage; no need to cal imputation loss here in the test stage - imputation_MAE = cal_mae( - imputed_data, inputs["X_holdout"], inputs["indicating_mask"] - ) - - ret_dict = { - "reconstruction_loss": reconstruction_loss, - "imputation_loss": imputation_MAE, - "imputed_data": imputed_data, - } - return ret_dict - - -class MRNN(BaseNNImputer): - def __init__( - self, - n_steps, - n_features, - rnn_hidden_size, - learning_rate=1e-3, - epochs=100, - patience=10, - batch_size=32, - weight_decay=1e-5, - device=None, - ): - super().__init__( - learning_rate, epochs, patience, batch_size, weight_decay, device - ) - - self.n_steps = n_steps - self.n_features = n_features - # model hype-parameters - self.rnn_hidden_size = rnn_hidden_size - - self.model = _MRNN( - self.n_steps, self.n_features, self.rnn_hidden_size, self.device - ) - self.model = self.model.to(self.device) - self._print_model_size() - - def fit(self, train_X, val_X=None): - train_X = self.check_input(self.n_steps, self.n_features, train_X) - if val_X is not None: - val_X = self.check_input(self.n_steps, self.n_features, val_X) - - training_set = DatasetForBRITS(train_X) - training_loader = DataLoader( - training_set, batch_size=self.batch_size, shuffle=True - ) - if val_X is None: - self._train_model(training_loader) - else: - val_X_intact, val_X, val_X_missing_mask, val_X_indicating_mask = mcar( - val_X, 0.2 - ) - val_X = masked_fill(val_X, 1 - val_X_missing_mask, torch.nan) - val_set = DatasetForBRITS(val_X) - val_loader = DataLoader(val_set, batch_size=self.batch_size, shuffle=False) - self._train_model( - training_loader, val_loader, val_X_intact, val_X_indicating_mask - ) - - self.model.load_state_dict(self.best_model_dict) - self.model.eval() # set the model as eval status to freeze it. - - def assemble_input_data(self, data): - """Assemble the input data into a dictionary. - - Parameters - ---------- - data : list - A list containing data fetched from Dataset by Dataload. - - Returns - ------- - inputs : dict - A dictionary with data assembled. - """ - indices, X_intact, X, missing_mask, indicating_mask = data - - inputs = { - "X": X, - "X_intact": X_intact, - "missing_mask": missing_mask, - "indicating_mask": indicating_mask, - } - - return inputs - - def impute(self, X): - X = self.check_input(self.n_steps, self.n_features, X) - self.model.eval() # set the model as eval status to freeze it. - test_set = BaseDataset(X) - test_loader = DataLoader(test_set, batch_size=self.batch_size, shuffle=False) - imputation_collector = [] - - with torch.no_grad(): - for idx, data in enumerate(test_loader): - inputs = {"X": data[1], "missing_mask": data[2]} - imputed_data, _ = self.model.impute(inputs) - imputation_collector.append(imputed_data) - - imputation_collector = torch.cat(imputation_collector) - return imputation_collector.cpu().detach().numpy() diff --git a/pypots/imputation/mrnn/__init__.py b/pypots/imputation/mrnn/__init__.py new file mode 100644 index 00000000..f5934cca --- /dev/null +++ b/pypots/imputation/mrnn/__init__.py @@ -0,0 +1,13 @@ +""" + +""" + +# Created by Wenjie Du +# License: GLP-v3 + +from .model import MRNN + + +__all__ = [ + "MRNN", +] diff --git a/pypots/imputation/mrnn/data.py b/pypots/imputation/mrnn/data.py new file mode 100644 index 00000000..1c8b3f1c --- /dev/null +++ b/pypots/imputation/mrnn/data.py @@ -0,0 +1,46 @@ +""" +Dataset class for model MRNN. +""" + +# Created by Wenjie Du +# License: GLP-v3 + +from typing import Union + +from ..brits.data import DatasetForBRITS + + +class DatasetForMRNN(DatasetForBRITS): + """Dataset class for BRITS. + + Parameters + ---------- + data : dict or str, + The dataset for model input, should be a dictionary including keys as 'X' and 'y', + or a path string locating a data file. + If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features], + which is time-series data for input, can contain missing values, and y should be array-like of shape + [n_samples], which is classification labels of X. + If it is a path string, the path should point to a data file, e.g. a h5 file, which contains + key-value pairs like a dict, and it has to include keys as 'X' and 'y'. + + return_labels : bool, default = True, + Whether to return labels in function __getitem__() if they exist in the given data. If `True`, for example, + during training of classification models, the Dataset class will return labels in __getitem__() for model input. + Otherwise, labels won't be included in the data returned by __getitem__(). This parameter exists because we + need the defined Dataset class for all training/validating/testing stages. For those big datasets stored in h5 + files, they already have both X and y saved. But we don't read labels from the file for validating and testing + with function _fetch_data_from_file(), which works for all three stages. Therefore, we need this parameter for + distinction. + + file_type : str, default = "h5py" + The type of the given file if train_set and val_set are path strings. + """ + + def __init__( + self, + data: Union[dict, str], + return_labels: bool = True, + file_type: str = "h5py", + ): + super().__init__(data, return_labels, file_type) diff --git a/pypots/imputation/mrnn/model.py b/pypots/imputation/mrnn/model.py new file mode 100644 index 00000000..459e5ab1 --- /dev/null +++ b/pypots/imputation/mrnn/model.py @@ -0,0 +1,310 @@ +""" +PyTorch MRNN model for the time-series imputation task. +Some part of the code is from https://github.com/WenjieDu/SAITS. + +""" + +# Created by Wenjie Du +# License: GLP-v3 + + +from typing import Union, Optional + +import h5py +import numpy as np +import torch +import torch.nn as nn +from torch.utils.data import DataLoader + +from .data import DatasetForMRNN +from .module import FCN_Regression +from ..base import BaseNNImputer +from ...optim.adam import Adam +from ...optim.base import Optimizer +from ...utils.metrics import cal_rmse + + +class _MRNN(nn.Module): + def __init__(self, seq_len, feature_num, rnn_hidden_size, device): + super().__init__() + # data settings + self.seq_len = seq_len + self.feature_num = feature_num + self.rnn_hidden_size = rnn_hidden_size + self.device = device + + self.f_rnn = nn.GRUCell(self.feature_num * 3, self.rnn_hidden_size) + self.b_rnn = nn.GRUCell(self.feature_num * 3, self.rnn_hidden_size) + self.rnn_cells = {"forward": self.f_rnn, "backward": self.b_rnn} + self.concated_hidden_project = nn.Linear( + self.rnn_hidden_size * 2, self.feature_num + ) + self.fcn_regression = FCN_Regression(feature_num, rnn_hidden_size) + + def gene_hidden_states(self, data, direction): + values = data[direction]["X"] + masks = data[direction]["missing_mask"] + deltas = data[direction]["deltas"] + + hidden_states_collector = [] + hidden_state = torch.zeros( + (values.size()[0], self.rnn_hidden_size), device=self.device + ) + + for t in range(self.seq_len): + x = values[:, t, :] + m = masks[:, t, :] + d = deltas[:, t, :] + inputs = torch.cat([x, m, d], dim=1) + hidden_state = self.rnn_cells[direction](inputs, hidden_state) + hidden_states_collector.append(hidden_state) + return hidden_states_collector + + def forward(self, inputs, training=True): + hidden_states_f = self.gene_hidden_states(inputs, "forward") + hidden_states_b = self.gene_hidden_states(inputs, "backward")[::-1] + + X = inputs["forward"]["X"] + masks = inputs["forward"]["missing_mask"] + + reconstruction_loss = 0 + estimations = [] + for i in range( + self.seq_len + ): # calculating estimation loss for times can obtain better results than once + x = X[:, i, :] + m = masks[:, i, :] + h_f = hidden_states_f[i] + h_b = hidden_states_b[i] + h = torch.cat([h_f, h_b], dim=1) + RNN_estimation = self.concated_hidden_project(h) # xΜƒ_t + RNN_imputed_data = m * x + (1 - m) * RNN_estimation + FCN_estimation = self.fcn_regression( + x, m, RNN_imputed_data + ) # FCN estimation is output estimation + reconstruction_loss += cal_rmse(FCN_estimation, x, m) + cal_rmse( + RNN_estimation, x, m + ) + estimations.append(FCN_estimation.unsqueeze(dim=1)) + + estimations = torch.cat(estimations, dim=1) + imputed_data = masks * X + (1 - masks) * estimations + + if not training: + # if not in training mode, return the classification result only + return { + "imputed_data": imputed_data, + } + + reconstruction_loss /= self.seq_len + + ret_dict = { + "loss": reconstruction_loss, + "imputed_data": imputed_data, + } + return ret_dict + + +class MRNN(BaseNNImputer): + """The PyTorch implementation of the MRNN model :cite:`yoon2019MRNN`. + + Parameters + ---------- + rnn_hidden_size : + The size of the RNN hidden state, also the number of hidden units in the RNN cell. + + batch_size : + The batch size for training and evaluating the model. + + epochs : + The number of epochs for training the model. + + patience : + The patience for the early-stopping mechanism. Given a positive integer, the training process will be + stopped when the model does not perform better after that number of epochs. + Leaving it default as None will disable the early-stopping. + + optimizer : + The optimizer for model training. + If not given, will use a default Adam optimizer. + + num_workers : + The number of subprocesses to use for data loading. + `0` means data loading will be in the main process, i.e. there won't be subprocesses. + + device : + The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them. + If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple), + then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models. + If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the + model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices). + Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future. + + saving_path : + The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during + training into a tensorboard file). Will not save if not given. + + model_saving_strategy : + The strategy to save model checkpoints. It has to be one of [None, "best", "better"]. + No model will be saved when it is set as None. + The "best" strategy will only automatically save the best model after the training finished. + The "better" strategy will automatically save the model during training whenever the model performs + better than in previous epochs. + + Attributes + ---------- + model : :class:`torch.nn.Module` + The underlying BRITS model. + + optimizer : :class:`pypots.optim.Optimizer` + The optimizer for model training. + + """ + + def __init__( + self, + n_steps: int, + n_features: int, + rnn_hidden_size: int, + batch_size: int = 32, + epochs: int = 100, + patience: int = None, + optimizer: Optional[Optimizer] = Adam(), + num_workers: int = 0, + device: Optional[Union[str, torch.device, list]] = None, + saving_path: str = None, + model_saving_strategy: Optional[str] = "best", + ): + super().__init__( + batch_size, + epochs, + patience, + num_workers, + device, + saving_path, + model_saving_strategy, + ) + + self.n_steps = n_steps + self.n_features = n_features + self.rnn_hidden_size = rnn_hidden_size + + self.model = _MRNN( + self.n_steps, + self.n_features, + self.rnn_hidden_size, + self.device, + ) + self._send_model_to_given_device() + self._print_model_size() + + # set up the optimizer + self.optimizer = optimizer + self.optimizer.init_optimizer(self.model.parameters()) + + def _assemble_input_for_training(self, data: list) -> dict: + # fetch data + ( + indices, + X, + missing_mask, + deltas, + back_X, + back_missing_mask, + back_deltas, + ) = self._send_data_to_given_device(data) + + # assemble input data + inputs = { + "indices": indices, + "forward": { + "X": X, + "missing_mask": missing_mask, + "deltas": deltas, + }, + "backward": { + "X": back_X, + "missing_mask": back_missing_mask, + "deltas": back_deltas, + }, + } + + return inputs + + def _assemble_input_for_validating(self, data: list) -> dict: + return self._assemble_input_for_training(data) + + def _assemble_input_for_testing(self, data: list) -> dict: + return self._assemble_input_for_validating(data) + + def fit( + self, + train_set: Union[dict, str], + val_set: Optional[Union[dict, str]] = None, + file_type: str = "h5py", + ) -> None: + # Step 1: wrap the input data with classes Dataset and DataLoader + training_set = DatasetForMRNN( + train_set, return_labels=False, file_type=file_type + ) + training_loader = DataLoader( + training_set, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers, + ) + val_loader = None + if val_set is not None: + if isinstance(val_set, str): + with h5py.File(val_set, "r") as hf: + # Here we read the whole validation set from the file to mask a portion for validation. + # In PyPOTS, using a file usually because the data is too big. However, the validation set is + # generally shouldn't be too large. For example, we have 1 billion samples for model training. + # We won't take 20% of them as the validation set because we want as much as possible data for the + # training stage to enhance the model's generalization ability. Therefore, 100,000 representative + # samples will be enough to validate the model. + val_set = { + "X": hf["X"][:], + "X_intact": hf["X_intact"][:], + "indicating_mask": hf["indicating_mask"][:], + } + val_set = DatasetForMRNN(val_set, return_labels=False, file_type=file_type) + val_loader = DataLoader( + val_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) + + # Step 2: train the model and freeze it + self._train_model(training_loader, val_loader) + self.model.load_state_dict(self.best_model_dict) + self.model.eval() # set the model as eval status to freeze it. + + # Step 3: save the model if necessary + self._auto_save_model_if_necessary(training_finished=True) + + def impute( + self, + X: Union[dict, str], + file_type="h5py", + ) -> np.ndarray: + self.model.eval() # set the model as eval status to freeze it. + test_set = DatasetForMRNN(X, return_labels=False, file_type=file_type) + test_loader = DataLoader( + test_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) + imputation_collector = [] + + with torch.no_grad(): + for idx, data in enumerate(test_loader): + inputs = self._assemble_input_for_testing(data) + results = self.model.forward(inputs, training=False) + imputed_data = results["imputed_data"] + imputation_collector.append(imputed_data) + + imputation_collector = torch.cat(imputation_collector) + return imputation_collector.cpu().detach().numpy() diff --git a/pypots/imputation/mrnn/module.py b/pypots/imputation/mrnn/module.py new file mode 100644 index 00000000..873d2d73 --- /dev/null +++ b/pypots/imputation/mrnn/module.py @@ -0,0 +1,48 @@ +""" + +""" + +# Created by Wenjie Du +# License: GLP-v3 + + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.parameter import Parameter + +from ...imputation.brits.modules import FeatureRegression + + +class FCN_Regression(nn.Module): + def __init__(self, feature_num, rnn_hid_size): + super(FCN_Regression, self).__init__() + self.feat_reg = FeatureRegression(rnn_hid_size * 2) + self.U = Parameter(torch.Tensor(feature_num, feature_num)) + self.V1 = Parameter(torch.Tensor(feature_num, feature_num)) + self.V2 = Parameter(torch.Tensor(feature_num, feature_num)) + self.beta = Parameter(torch.Tensor(feature_num)) # bias beta + self.final_linear = nn.Linear(feature_num, feature_num) + + m = torch.ones(feature_num, feature_num) - torch.eye(feature_num, feature_num) + self.register_buffer("m", m) + self.reset_parameters() + + def reset_parameters(self): + stdv = 1.0 / math.sqrt(self.U.size(0)) + self.U.data.uniform_(-stdv, stdv) + self.V1.data.uniform_(-stdv, stdv) + self.V2.data.uniform_(-stdv, stdv) + self.beta.data.uniform_(-stdv, stdv) + + def forward(self, x_t, m_t, target): + h_t = F.tanh( + F.linear(x_t, self.U * self.m) + + F.linear(target, self.V1 * self.m) + + F.linear(m_t, self.V2) + + self.beta + ) + x_hat_t = self.final_linear(h_t) + return x_hat_t diff --git a/tests/test_imputation.py b/tests/test_imputation.py index 24929b61..6094ce62 100644 --- a/tests/test_imputation.py +++ b/tests/test_imputation.py @@ -16,6 +16,7 @@ SAITS, Transformer, BRITS, + MRNN, LOCF, ) from pypots.optim import Adam @@ -262,6 +263,75 @@ def test_3_saving_path(self): self.brits.load_model(saved_model_path) +class TestMRNN(unittest.TestCase): + logger.info("Running tests for an imputation model MRNN...") + + # set the log and model saving path + saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "MRNN") + model_save_name = "saved_MRNN_model.pypots" + + # initialize an Adam optimizer + optimizer = Adam(lr=0.001, weight_decay=1e-5) + + # initialize a MRNN model + mrnn = MRNN( + DATA["n_steps"], + DATA["n_features"], + 256, + epochs=EPOCH, + saving_path=f"{RESULT_SAVING_DIR_FOR_IMPUTATION}/MRNN", + optimizer=optimizer, + ) + + @pytest.mark.xdist_group(name="imputation-mrnn") + def test_0_fit(self): + self.mrnn.fit(TRAIN_SET, VAL_SET) + + @pytest.mark.xdist_group(name="imputation-mrnn") + def test_1_impute(self): + imputed_X = self.mrnn.impute(TEST_SET) + assert not np.isnan( + imputed_X + ).any(), "Output still has missing values after running impute()." + test_MAE = cal_mae( + imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] + ) + logger.info(f"MRNN test_MAE: {test_MAE}") + + @pytest.mark.xdist_group(name="imputation-mrnn") + def test_2_parameters(self): + assert hasattr(self.mrnn, "model") and self.mrnn.model is not None + + assert hasattr(self.mrnn, "optimizer") and self.mrnn.optimizer is not None + + assert hasattr(self.mrnn, "best_loss") + self.assertNotEqual(self.mrnn.best_loss, float("inf")) + + assert ( + hasattr(self.mrnn, "best_model_dict") + and self.mrnn.best_model_dict is not None + ) + + @pytest.mark.xdist_group(name="imputation-mrnn") + def test_3_saving_path(self): + # whether the root saving dir exists, which should be created by save_log_into_tb_file + assert os.path.exists( + self.saving_path + ), f"file {self.saving_path} does not exist" + + # check if the tensorboard file and model checkpoints exist + check_tb_and_model_checkpoints_existence(self.mrnn) + + # save the trained model into file, and check if the path exists + self.mrnn.save_model( + saving_dir=self.saving_path, file_name=self.model_save_name + ) + + # test loading the saved model, not necessary, but need to test + saved_model_path = os.path.join(self.saving_path, self.model_save_name) + self.mrnn.load_model(saved_model_path) + + class TestLOCF(unittest.TestCase): logger.info("Running tests for an imputation model LOCF...") locf = LOCF(nan=0) From 35a21fed1b7026bcbeeddb86ad5229882d7d4039 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Mon, 22 May 2023 00:13:09 +0800 Subject: [PATCH 07/10] docs: add the citation of MRNN; --- docs/references.bib | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/docs/references.bib b/docs/references.bib index 050110c1..214b69b8 100644 --- a/docs/references.bib +++ b/docs/references.bib @@ -12,6 +12,18 @@ @article{cao2018BRITS keywords = {Computer Science - Machine Learning,Statistics - Machine Learning} } +@ARTICLE{yoon2019MRNN, +author={Yoon, Jinsung and Zame, William R. and van der Schaar, Mihaela}, +journal={IEEE Transactions on Biomedical Engineering}, +title={Estimating Missing Data in Temporal Data Streams Using Multi-Directional Recurrent Neural Networks}, +year={2019}, +volume={66}, +number={5}, +pages={1477-1490}, +doi={10.1109/TBME.2018.2874712} +} + + @article{che2018GRUD, title = {Recurrent {{Neural Networks}} for {{Multivariate Time Series}} with {{Missing Values}}}, author = {Che, Zhengping and Purushotham, Sanjay and Cho, Kyunghyun and Sontag, David and Liu, Yan}, From e321b9a3fb8d105e4426754e58ebb8267fb5451c Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Mon, 22 May 2023 00:38:46 +0800 Subject: [PATCH 08/10] docs: add docs for MRNN; --- docs/pypots.imputation.rst | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/docs/pypots.imputation.rst b/docs/pypots.imputation.rst index 50ff9d11..0e31f8c8 100644 --- a/docs/pypots.imputation.rst +++ b/docs/pypots.imputation.rst @@ -28,6 +28,15 @@ pypots.imputation.brits module :show-inheritance: :inherited-members: +pypots.imputation.mrnn module +------------------------------ + +.. automodule:: pypots.imputation.mrnn + :members: + :undoc-members: + :show-inheritance: + :inherited-members: + pypots.imputation.locf module ----------------------------- From 69ec426c3433dae3f3412f75caad420977e283f9 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Mon, 22 May 2023 00:39:18 +0800 Subject: [PATCH 09/10] release v0.1.1; --- pypots/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pypots/__init__.py b/pypots/__init__.py index 8cedd5b4..b7a737b6 100644 --- a/pypots/__init__.py +++ b/pypots/__init__.py @@ -22,7 +22,7 @@ # # Dev branch marker is: 'X.Y.dev' or 'X.Y.devN' where N is an integer. # 'X.Y.dev0' is the canonical version of 'X.Y.dev' -__version__ = "0.1.0" +__version__ = "0.1.1" __all__ = [ From f67ec08b35e29ae41c865eb4aba151ece9bacf82 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Mon, 22 May 2023 00:52:38 +0800 Subject: [PATCH 10/10] docs: update README; --- README.md | 41 ++++++++++++++++++++++------------------- docs/index.rst | 2 ++ 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index 13669944..69fdb951 100644 --- a/README.md +++ b/README.md @@ -146,25 +146,26 @@ mae = cal_mae(imputation, X_intact, indicating_mask) # calculate mean absolute ## ❖ Available Algorithms PyPOTS supports imputation, classification, clustering, and forecasting tasks on multivariate time series with missing values. The currently available algorithms of four tasks are cataloged in the following table with four partitions. The paper references are all listed at the bottom of this readme file. Please refer to them if you want more details. -| ***`Imputation`*** | πŸš₯ | πŸš₯ | πŸš₯ | -|:----------------------:|:------------:|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:--------:| -| **Type** | **Abbr.** | **Full name of the algorithm/model/paper** | **Year** | -| Neural Net | SAITS | Self-Attention-based Imputation for Time Series [^1] | 2023 | -| Neural Net | Transformer | Attention is All you Need [^2];
Self-Attention-based Imputation for Time Series [^1];
Note: proposed in [^2], and re-implemented as an imputation model in [^1]. | 2017 | -| Neural Net | BRITS | Bidirectional Recurrent Imputation for Time Series [^3] | 2018 | -| Naive | LOCF | Last Observation Carried Forward | - | -| ***`Classification`*** | πŸš₯ | πŸš₯ | πŸš₯ | -| **Type** | **Abbr.** | **Full name of the algorithm/model/paper** | **Year** | -| Neural Net | BRITS | Bidirectional Recurrent Imputation for Time Series [^3] | 2018 | -| Neural Net | GRU-D | Recurrent Neural Networks for Multivariate Time Series with Missing Values [^4] | 2018 | -| Neural Net | Raindrop | Graph-Guided Network for Irregularly Sampled Multivariate Time Series [^5] | 2022 | -| ***`Clustering`*** | πŸš₯ | πŸš₯ | πŸš₯ | -| **Type** | **Abbr.** | **Full name of the algorithm/model/paper** | **Year** | -| Neural Net | CRLI | Clustering Representation Learning on Incomplete time-series data [^6] | 2021 | -| Neural Net | VaDER | Variational Deep Embedding with Recurrence [^7] | 2019 | -| ***`Forecasting`*** | πŸš₯ | πŸš₯ | πŸš₯ | -| **Type** | **Abbr.** | **Full name of the algorithm/model/paper** | **Year** | -| Probabilistic | BTTF | Bayesian Temporal Tensor Factorization [^8] | 2021 | +| ***`Imputation`*** | πŸš₯ | πŸš₯ | πŸš₯ | +|:----------------------:|:-----------:|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:--------:| +| **Type** | **Abbr.** | **Full name of the algorithm/model/paper** | **Year** | +| Neural Net | SAITS | Self-Attention-based Imputation for Time Series [^1] | 2023 | +| Neural Net | Transformer | Attention is All you Need [^2];
Self-Attention-based Imputation for Time Series [^1];
Note: proposed in [^2], and re-implemented as an imputation model in [^1]. | 2017 | +| Neural Net | BRITS | Bidirectional Recurrent Imputation for Time Series [^3] | 2018 | +| Neural Net | M-RNN | Multi-directional Recurrent Neural Network [^9] | 2019 | +| Naive | LOCF | Last Observation Carried Forward | - | +| ***`Classification`*** | πŸš₯ | πŸš₯ | πŸš₯ | +| **Type** | **Abbr.** | **Full name of the algorithm/model/paper** | **Year** | +| Neural Net | BRITS | Bidirectional Recurrent Imputation for Time Series [^3] | 2018 | +| Neural Net | GRU-D | Recurrent Neural Networks for Multivariate Time Series with Missing Values [^4] | 2018 | +| Neural Net | Raindrop | Graph-Guided Network for Irregularly Sampled Multivariate Time Series [^5] | 2022 | +| ***`Clustering`*** | πŸš₯ | πŸš₯ | πŸš₯ | +| **Type** | **Abbr.** | **Full name of the algorithm/model/paper** | **Year** | +| Neural Net | CRLI | Clustering Representation Learning on Incomplete time-series data [^6] | 2021 | +| Neural Net | VaDER | Variational Deep Embedding with Recurrence [^7] | 2019 | +| ***`Forecasting`*** | πŸš₯ | πŸš₯ | πŸš₯ | +| **Type** | **Abbr.** | **Full name of the algorithm/model/paper** | **Year** | +| Probabilistic | BTTF | Bayesian Temporal Tensor Factorization [^8] | 2021 | ## ❖ Citing PyPOTS @@ -254,6 +255,8 @@ Thank you all for your attention! πŸ˜ƒ [^6]: Ma, Q., Chen, C., Li, S., & Cottrell, G. W. (2021). [Learning Representations for Incomplete Time Series Clustering](https://ojs.aaai.org/index.php/AAAI/article/view/17070). *AAAI 2021*. [^7]: Jong, J.D., Emon, M.A., Wu, P., Karki, R., Sood, M., Godard, P., Ahmad, A., Vrooman, H.A., Hofmann-Apitius, M., & FrΓΆhlich, H. (2019). [Deep learning for clustering of multivariate clinical patient trajectories with missing values](https://academic.oup.com/gigascience/article/8/11/giz134/5626377). *GigaScience*. [^8]: Chen, X., & Sun, L. (2021). [Bayesian Temporal Factorization for Multidimensional Time Series Prediction](https://arxiv.org/abs/1910.06366). *IEEE transactions on pattern analysis and machine intelligence*. +[^9]: Yoon, J., Zame, W. R., & van der Schaar, M. (2019). [Estimating Missing Data in Temporal Data Streams Using Multi-Directional Recurrent Neural Networks](https://ieeexplore.ieee.org/document/8485748). *IEEE Transactions on Biomedical Engineering*. +
🏠 Visits diff --git a/docs/index.rst b/docs/index.rst index 6c264ff9..29411c15 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -128,6 +128,7 @@ Task Type Algorithm Imputation Neural Network SAITS (Self-Attention-based Imputation for Time Series) 2022 :cite:`du2023SAITS` Imputation Neural Network Transformer 2017 :cite:`vaswani2017Transformer`, :cite:`du2023SAITS` Imputation, Classification Neural Network BRITS (Bidirectional Recurrent Imputation for Time Series) 2018 :cite:`cao2018BRITS` +Imputation Neural Network M-RNN (Multi-directional Recurrent Neural Network) 2019 :cite:`yoon2019MRNN` Imputation Naive LOCF (Last Observation Carried Forward) / / Classification Neural Network GRU-D 2018 :cite:`che2018GRUD` Classification Neural Network Raindrop 2022 :cite:`zhang2022Raindrop` @@ -136,6 +137,7 @@ Clustering Neural Network VaDER (Variational Deep Embeddin Forecasting Probabilistic BTTF (Bayesian Temporal Tensor Factorization) 2021 :cite:`chen2021BTMF` ============================== ================ ========================================================================= ====== ========= +[^9]: Yoon, J., Zame, W. R., & van der Schaar, M. (2019). [Estimating Missing Data in Temporal Data Streams Using Multi-Directional Recurrent Neural Networks](https://ieeexplore.ieee.org/document/8485748). *IEEE Transactions on Biomedical Engineering*. ❖ Citing PyPOTS ^^^^^^^^^^^^^^^^