diff --git a/README.md b/README.md
index d2c4d275..6a667be9 100644
--- a/README.md
+++ b/README.md
@@ -190,6 +190,7 @@ This functionality is implemented with the [Microsoft NNI](https://github.com/mi
| **Type** | **Abbr.** | **Full name of the algorithm/model** | **Year** |
| Neural Net | SAITS | Self-Attention-based Imputation for Time Series [^1] | 2023 |
| Neural Net | Transformer | Attention is All you Need [^2]; Self-Attention-based Imputation for Time Series [^1]; Note: proposed in [^2], and re-implemented as an imputation model in [^1]. | 2017 |
+| Neural Net | TimesNet | Temporal 2D-Variation Modeling for General Time Series Analysis [^14] | 2023 |
| Neural Net | CSDI | Conditional Score-based Diffusion Models for Probabilistic Time Series Imputation [^12] | 2021 |
| Neural Net | US-GAN | Unsupervised GAN for Multivariate Time Series Imputation [^10] | 2021 |
| Neural Net | GP-VAE | Gaussian Process Variational Autoencoder [^11] | 2020 |
@@ -302,6 +303,7 @@ PyPOTS community is open, transparent, and surely friendly. Let's work together
[^11]: Fortuin, V., Baranchuk, D., Raetsch, G. & Mandt, S. (2020). [GP-VAE: Deep Probabilistic Time Series Imputation](https://proceedings.mlr.press/v108/fortuin20a.html). *AISTATS 2020*.
[^12]: Tashiro, Y., Song, J., Song, Y., & Ermon, S. (2021). [CSDI: Conditional Score-based Diffusion Models for Probabilistic Time Series Imputation](https://proceedings.neurips.cc/paper/2021/hash/cfe8504bda37b575c70ee1a8276f3486-Abstract.html). *NeurIPS 2021*.
[^13]: Rubin, D. B. (1976). [Inference and missing data](https://academic.oup.com/biomet/article-abstract/63/3/581/270932). *Biometrika*.
+[^14]: Wu, H., Hu, T., Liu, Y., Zhou, H., Wang, J., & Long, M. (2023). [TimesNet: Temporal 2D-Variation Modeling for General Time Series Analysis](https://openreview.net/forum?id=ju_Uqw384Oq). *ICLR 2023*.
diff --git a/docs/index.rst b/docs/index.rst
index aa4c925a..79c73441 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -167,8 +167,9 @@ This functionality is implemented with the `Microsoft NNI
diff --git a/pypots/classification/brits/modules/core.py b/pypots/classification/brits/modules/core.py
--- a/pypots/classification/brits/modules/core.py
+++ b/pypots/classification/brits/modules/core.py
@@ ... @@ def forward(self, inputs: dict, training: bool = True) -> dict:
ret_b = self._reverse(self.rits_b(inputs, "backward"))
classification_pred = (ret_f["prediction"] + ret_b["prediction"]) / 2
- if not training:
- # if not in training mode, return the classification result only
- return {"classification_pred": classification_pred}
-
- ret_f["classification_loss"] = F.nll_loss(
- torch.log(ret_f["prediction"]), inputs["label"]
- )
- ret_b["classification_loss"] = F.nll_loss(
- torch.log(ret_b["prediction"]), inputs["label"]
- )
- consistency_loss = self._get_consistency_loss(
- ret_f["imputed_data"], ret_b["imputed_data"]
- )
- classification_loss = (
- ret_f["classification_loss"] + ret_b["classification_loss"]
- ) / 2
- reconstruction_loss = (
- ret_f["reconstruction_loss"] + ret_b["reconstruction_loss"]
- ) / 2
-
- loss = (
- consistency_loss
- + reconstruction_loss * self.reconstruction_weight
- + classification_loss * self.classification_weight
- )
-
- results = {
- "classification_pred": classification_pred,
- "consistency_loss": consistency_loss,
- "classification_loss": classification_loss,
- "reconstruction_loss": reconstruction_loss,
- "loss": loss,
- }
+ results = {"classification_pred": classification_pred}
+
+ # if in training mode, return results with losses
+ if training:
+ ret_f["classification_loss"] = F.nll_loss(
+ torch.log(ret_f["prediction"]), inputs["label"]
+ )
+ ret_b["classification_loss"] = F.nll_loss(
+ torch.log(ret_b["prediction"]), inputs["label"]
+ )
+ consistency_loss = self._get_consistency_loss(
+ ret_f["imputed_data"], ret_b["imputed_data"]
+ )
+ classification_loss = (
+ ret_f["classification_loss"] + ret_b["classification_loss"]
+ ) / 2
+ reconstruction_loss = (
+ ret_f["reconstruction_loss"] + ret_b["reconstruction_loss"]
+ ) / 2
+
+ results["consistency_loss"] = consistency_loss
+ results["classification_loss"] = classification_loss
+ results["reconstruction_loss"] = reconstruction_loss
+
+ # `loss` is always the item for backward propagating to update the model
+ loss = (
+ consistency_loss
+ + reconstruction_loss * self.reconstruction_weight
+ + classification_loss * self.classification_weight
+ )
+ results["loss"] = loss
+
return results
diff --git a/pypots/classification/grud/modules/core.py b/pypots/classification/grud/modules/core.py
index fe528655..c1b873f4 100644
--- a/pypots/classification/grud/modules/core.py
+++ b/pypots/classification/grud/modules/core.py
@@ -91,18 +91,14 @@ def forward(self, inputs: dict, training: bool = True) -> dict:
logits = self.classifier(hidden_state)
classification_pred = torch.softmax(logits, dim=1)
+ results = {"classification_pred": classification_pred}
- if not training:
- # if not in training mode, return the classification result only
- return {"classification_pred": classification_pred}
+ # if in training mode, return results with losses
+ if training:
+ classification_loss = F.nll_loss(
+ torch.log(classification_pred), inputs["label"]
+ )
+ results["loss"] = classification_loss
- torch.log(classification_pred)
- classification_loss = F.nll_loss(
- torch.log(classification_pred), inputs["label"]
- )
-
- results = {
- "classification_pred": classification_pred,
- "loss": classification_loss,
- }
return results
diff --git a/pypots/classification/raindrop/modules/core.py b/pypots/classification/raindrop/modules/core.py
index 798ff5e6..68ba1f81 100644
--- a/pypots/classification/raindrop/modules/core.py
+++ b/pypots/classification/raindrop/modules/core.py
@@ -262,18 +262,13 @@ def classify(self, inputs: dict) -> torch.Tensor:
def forward(self, inputs, training=True):
classification_pred = self.classify(inputs)
- if not training:
- # if not in training mode, return the classification result only
- return {"classification_pred": classification_pred}
+ results = {"classification_pred": classification_pred}
- classification_loss = F.nll_loss(
- torch.log(classification_pred), inputs["label"]
- )
-
- results = {
- "prediction": classification_pred,
- "loss": classification_loss
- # 'distance': distance,
- }
+ # if in training mode, return results with losses
+ if training:
+ classification_loss = F.nll_loss(
+ torch.log(classification_pred), inputs["label"]
+ )
+ results["loss"] = classification_loss
return results
diff --git a/pypots/clustering/crli/model.py b/pypots/clustering/crli/model.py
index c59d840e..2aa77239 100644
--- a/pypots/clustering/crli/model.py
+++ b/pypots/clustering/crli/model.py
@@ -270,7 +270,7 @@ def _train_model(
with torch.no_grad():
for idx, data in enumerate(val_loader):
inputs = self._assemble_input_for_validating(data)
- results = self.model.forward(inputs, return_loss=True)
+ results = self.model.forward(inputs, training=True)
epoch_val_loss_G_collector.append(
results["generation_loss"].sum().item()
)
@@ -424,7 +424,7 @@ def predict(
with torch.no_grad():
for idx, data in enumerate(test_loader):
inputs = self._assemble_input_for_testing(data)
- inputs = self.model.forward(inputs, return_loss=False)
+ inputs = self.model.forward(inputs, training=False)
clustering_latent_collector.append(inputs["fcn_latent"])
if return_latent_vars:
imputation_collector.append(inputs["imputation_latent"])
diff --git a/pypots/clustering/crli/modules/core.py b/pypots/clustering/crli/modules/core.py
index 3d1cddd4..a4c16a2a 100644
--- a/pypots/clustering/crli/modules/core.py
+++ b/pypots/clustering/crli/modules/core.py
@@ -58,7 +58,7 @@ def forward(
self,
inputs: dict,
training_object: str = "generator",
- return_loss: bool = True,
+ training: bool = True,
) -> dict:
X = inputs["X"]
missing_mask = inputs["missing_mask"]
@@ -76,7 +76,7 @@ def forward(
inputs["fcn_latent"] = fcn_latent
# return results directly, skip loss calculation to reduce inference time
- if not return_loss:
+ if not training:
return inputs
if training_object == "discriminator":
@@ -106,4 +106,5 @@ def forward(
l_kmeans = torch.trace(HTH) - torch.trace(FTHTHF) # k-means loss
loss_gene = l_G + l_pre + l_rec + l_kmeans * self.lambda_kmeans
losses["generation_loss"] = loss_gene
+
return losses
diff --git a/pypots/clustering/vader/modules/core.py b/pypots/clustering/vader/modules/core.py
index 4ffea516..8ff2f4ac 100644
--- a/pypots/clustering/vader/modules/core.py
+++ b/pypots/clustering/vader/modules/core.py
@@ -173,18 +173,15 @@ def forward(
stddev_tilde,
) = self.get_results(X, missing_mask)
- if not training and not pretrain:
- results = {
- "mu_tilde": mu_tilde,
- "stddev_tilde": stddev_tilde,
- "mu": mu_c,
- "var": var_c,
- "phi": phi_c,
- "z": z,
- "imputation_latent": X_reconstructed,
- }
- # if only run clustering, then no need to calculate loss
- return results
+ results = {
+ "mu_tilde": mu_tilde,
+ "stddev_tilde": stddev_tilde,
+ "mu": mu_c,
+ "var": var_c,
+ "phi": phi_c,
+ "z": z,
+ "imputation_latent": X_reconstructed,
+ }
# calculate the reconstruction loss
unscaled_reconstruction_loss = cal_mse(X_reconstructed, X, missing_mask)
@@ -194,66 +191,68 @@ def forward(
* self.d_input
/ missing_mask.sum()
)
+
if pretrain:
- results = {"loss": reconstruction_loss, "z": z}
+ results["loss"] = reconstruction_loss
return results
- # calculate the latent loss
- var_tilde = torch.exp(stddev_tilde)
- stddev_c = torch.log(var_c + self.eps)
- log_2pi = torch.log(torch.tensor([2 * torch.pi], device=device))
- log_phi_c = torch.log(phi_c + self.eps)
-
- batch_size = z.shape[0]
-
- ii, jj = torch.meshgrid(
- torch.arange(self.n_clusters, dtype=torch.int64, device=device),
- torch.arange(batch_size, dtype=torch.int64, device=device),
- indexing="ij",
- )
- ii = ii.flatten()
- jj = jj.flatten()
-
- lsc_b = stddev_c.index_select(dim=0, index=ii)
- mc_b = mu_c.index_select(dim=0, index=ii)
- sc_b = var_c.index_select(dim=0, index=ii)
- z_b = z.index_select(dim=0, index=jj)
- log_pdf_z = -0.5 * (lsc_b + log_2pi + torch.square(z_b - mc_b) / sc_b)
- log_pdf_z = log_pdf_z.reshape([batch_size, self.n_clusters, self.d_mu_stddev])
-
- log_p = log_phi_c + log_pdf_z.sum(dim=2)
- lse_p = log_p.logsumexp(dim=1, keepdim=True)
- log_gamma_c = log_p - lse_p
- gamma_c = torch.exp(log_gamma_c)
-
- term1 = torch.log(var_c + self.eps)
- st_b = var_tilde.index_select(dim=0, index=jj)
- sc_b = var_c.index_select(dim=0, index=ii)
- term2 = torch.reshape(
- st_b / (sc_b + self.eps), [batch_size, self.n_clusters, self.d_mu_stddev]
- )
- mt_b = mu_tilde.index_select(dim=0, index=jj)
- mc_b = mu_c.index_select(dim=0, index=ii)
- term3 = torch.reshape(
- torch.square(mt_b - mc_b) / (sc_b + self.eps),
- [batch_size, self.n_clusters, self.d_mu_stddev],
- )
-
- latent_loss1 = 0.5 * torch.sum(
- gamma_c * torch.sum(term1 + term2 + term3, dim=2), dim=1
- )
- latent_loss2 = -torch.sum(gamma_c * (log_phi_c - log_gamma_c), dim=1)
- latent_loss3 = -0.5 * torch.sum(1 + stddev_tilde, dim=1)
-
- latent_loss1 = latent_loss1.mean()
- latent_loss2 = latent_loss2.mean()
- latent_loss3 = latent_loss3.mean()
- latent_loss = latent_loss1 + latent_loss2 + latent_loss3
-
- results = {
- "loss": reconstruction_loss + self.alpha * latent_loss,
- "z": z,
- "imputation_latent": X_reconstructed,
- }
+ # if in training mode, return results with losses
+ if training:
+ # calculate the latent loss for model training
+ var_tilde = torch.exp(stddev_tilde)
+ stddev_c = torch.log(var_c + self.eps)
+ log_2pi = torch.log(torch.tensor([2 * torch.pi], device=device))
+ log_phi_c = torch.log(phi_c + self.eps)
+
+ batch_size = z.shape[0]
+
+ ii, jj = torch.meshgrid(
+ torch.arange(self.n_clusters, dtype=torch.int64, device=device),
+ torch.arange(batch_size, dtype=torch.int64, device=device),
+ indexing="ij",
+ )
+ ii = ii.flatten()
+ jj = jj.flatten()
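+ # ii/jj enumerate every (cluster, sample) pair so the per-cluster Gaussian log-densities of z
+ # can be computed in one vectorized pass below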
+
+ lsc_b = stddev_c.index_select(dim=0, index=ii)
+ mc_b = mu_c.index_select(dim=0, index=ii)
+ sc_b = var_c.index_select(dim=0, index=ii)
+ z_b = z.index_select(dim=0, index=jj)
+ log_pdf_z = -0.5 * (lsc_b + log_2pi + torch.square(z_b - mc_b) / sc_b)
+ log_pdf_z = log_pdf_z.reshape(
+ [batch_size, self.n_clusters, self.d_mu_stddev]
+ )
+
+ log_p = log_phi_c + log_pdf_z.sum(dim=2)
+ lse_p = log_p.logsumexp(dim=1, keepdim=True)
+ log_gamma_c = log_p - lse_p
+ gamma_c = torch.exp(log_gamma_c)
+
+ term1 = torch.log(var_c + self.eps)
+ st_b = var_tilde.index_select(dim=0, index=jj)
+ sc_b = var_c.index_select(dim=0, index=ii)
+ term2 = torch.reshape(
+ st_b / (sc_b + self.eps),
+ [batch_size, self.n_clusters, self.d_mu_stddev],
+ )
+ mt_b = mu_tilde.index_select(dim=0, index=jj)
+ mc_b = mu_c.index_select(dim=0, index=ii)
+ term3 = torch.reshape(
+ torch.square(mt_b - mc_b) / (sc_b + self.eps),
+ [batch_size, self.n_clusters, self.d_mu_stddev],
+ )
+
+ latent_loss1 = 0.5 * torch.sum(
+ gamma_c * torch.sum(term1 + term2 + term3, dim=2), dim=1
+ )
+ latent_loss2 = -torch.sum(gamma_c * (log_phi_c - log_gamma_c), dim=1)
+ latent_loss3 = -0.5 * torch.sum(1 + stddev_tilde, dim=1)
+
+ latent_loss1 = latent_loss1.mean()
+ latent_loss2 = latent_loss2.mean()
+ latent_loss3 = latent_loss3.mean()
+ latent_loss = latent_loss1 + latent_loss2 + latent_loss3
+
+ results["loss"] = reconstruction_loss + self.alpha * latent_loss
return results
diff --git a/pypots/imputation/__init__.py b/pypots/imputation/__init__.py
index 44bb1e54..f065f0f9 100644
--- a/pypots/imputation/__init__.py
+++ b/pypots/imputation/__init__.py
@@ -6,17 +6,19 @@
# License: BSD-3-Clause
from .brits import BRITS
+from .csdi import CSDI
from .gpvae import GPVAE
from .locf import LOCF
from .mrnn import MRNN
from .saits import SAITS
+from .timesnet import TimesNet
from .transformer import Transformer
from .usgan import USGAN
-from .csdi import CSDI
__all__ = [
"SAITS",
"Transformer",
+ "TimesNet",
"BRITS",
"MRNN",
"LOCF",
diff --git a/pypots/imputation/brits/modules/core.py b/pypots/imputation/brits/modules/core.py
index 31f3fc43..83b48f95 100644
--- a/pypots/imputation/brits/modules/core.py
+++ b/pypots/imputation/brits/modules/core.py
@@ -316,27 +316,23 @@ def forward(self, inputs: dict, training: bool = True) -> dict:
imputed_data = (ret_f["imputed_data"] + ret_b["imputed_data"]) / 2
- if not training:
- # if not in training mode, return the classification result only
- return {
- "imputed_data": imputed_data,
- }
-
- consistency_loss = self._get_consistency_loss(
- ret_f["imputed_data"], ret_b["imputed_data"]
- )
-
- # `loss` is always the item for backward propagating to update the model
- loss = (
- consistency_loss
- + ret_f["reconstruction_loss"]
- + ret_b["reconstruction_loss"]
- )
-
results = {
"imputed_data": imputed_data,
- "consistency_loss": consistency_loss,
- "loss": loss, # will be used for backward propagating to update the model
}
+ # if in training mode, return results with losses
+ if training:
+ consistency_loss = self._get_consistency_loss(
+ ret_f["imputed_data"], ret_b["imputed_data"]
+ )
+
+ # `loss` is always the item for backward propagating to update the model
+ loss = (
+ consistency_loss
+ + ret_f["reconstruction_loss"]
+ + ret_b["reconstruction_loss"]
+ )
+ results["consistency_loss"] = consistency_loss
+ results["loss"] = loss
+
return results
diff --git a/pypots/imputation/csdi/modules/core.py b/pypots/imputation/csdi/modules/core.py
index 12a03cf0..b25fd190 100644
--- a/pypots/imputation/csdi/modules/core.py
+++ b/pypots/imputation/csdi/modules/core.py
@@ -65,6 +65,10 @@ def __init__(
)
elif schedule == "linear":
self.beta = np.linspace(beta_start, beta_end, self.n_diffusion_steps)
+ else:
+ raise ValueError(
+ f"The argument schedule should be 'quad' or 'linear', but got {schedule}"
+ )
self.alpha_hat = 1 - self.beta
self.alpha = np.cumprod(self.alpha_hat)
@@ -72,7 +76,8 @@ def __init__(
"alpha_torch", torch.tensor(self.alpha).float().unsqueeze(1).unsqueeze(1)
)
- def time_embedding(self, pos, d_model=128):
+ @staticmethod
+ def time_embedding(pos, d_model=128):
pe = torch.zeros(pos.shape[0], pos.shape[1], d_model).to(pos.device)
position = pos.unsqueeze(2)
div_term = 1 / torch.pow(
@@ -82,7 +87,8 @@ def time_embedding(self, pos, d_model=128):
pe[:, :, 1::2] = torch.cos(position * div_term)
return pe
- def get_randmask(self, observed_mask):
+ @staticmethod
+ def get_rand_mask(observed_mask):
rand_for_mask = torch.rand_like(observed_mask) * observed_mask
rand_for_mask = rand_for_mask.reshape(len(rand_for_mask), -1)
for i in range(len(observed_mask)):
@@ -97,14 +103,14 @@ def get_hist_mask(self, observed_mask, for_pattern_mask=None):
if for_pattern_mask is None:
for_pattern_mask = observed_mask
if self.target_strategy == "mix":
- rand_mask = self.get_randmask(observed_mask)
+ rand_mask = self.get_rand_mask(observed_mask)
cond_mask = observed_mask.clone()
for i in range(len(cond_mask)):
mask_choice = np.random.rand()
if self.target_strategy == "mix" and mask_choice > 0.5:
cond_mask[i] = rand_mask[i]
- else: # draw another sample for histmask (i-1 corresponds to another sample)
+ else: # draw another sample for hist mask (i-1 corresponds to another sample)
cond_mask[i] = cond_mask[i] * for_pattern_mask[i - 1]
return cond_mask
@@ -241,23 +247,22 @@ def forward(self, inputs, training=True, n_sampling_times=1):
observed_mask, for_pattern_mask=for_pattern_mask
)
else:
- cond_mask = self.get_randmask(observed_mask)
+ cond_mask = self.get_rand_mask(observed_mask)
side_info = self.get_side_info(observed_tp, cond_mask)
- loss_func = self.calc_loss if training == 1 else self.calc_loss_valid
+ loss_func = self.calc_loss if training else self.calc_loss_valid
# `loss` is always the item for backward propagating to update the model
loss = loss_func(observed_data, cond_mask, observed_mask, side_info, training)
+ results = {"loss": loss}
- results = {
- "loss": loss, # will be used for backward propagating to update the model
- }
if not training:
samples = self.impute(
observed_data, cond_mask, side_info, n_sampling_times
- ) # (B,nsample,K,L)
+ ) # (B, n_sampling_times, K, L)
imputation = samples.mean(dim=1) # (B,K,L)
imputed_data = observed_data + imputation * (1 - gt_mask)
results["imputed_data"] = imputed_data.permute(0, 2, 1) # (B,L,K)
+
return results
diff --git a/pypots/imputation/csdi/modules/submodules.py b/pypots/imputation/csdi/modules/submodules.py
index be197643..31e71fff 100644
--- a/pypots/imputation/csdi/modules/submodules.py
+++ b/pypots/imputation/csdi/modules/submodules.py
@@ -19,7 +19,7 @@ def get_torch_trans(heads=8, layers=1, channels=64):
return nn.TransformerEncoder(encoder_layer, num_layers=layers)
-def Conv1d_with_init(in_channels, out_channels, kernel_size):
+def conv1d_with_init(in_channels, out_channels, kernel_size):
layer = nn.Conv1d(in_channels, out_channels, kernel_size)
nn.init.kaiming_normal_(layer.weight)
return layer
@@ -46,7 +46,8 @@ def forward(self, diffusion_step):
x = F.silu(x)
return x
- def _build_embedding(self, n_steps, d_embedding=64):
+ @staticmethod
+ def _build_embedding(n_steps, d_embedding=64):
steps = torch.arange(n_steps).unsqueeze(1) # (T,1)
frequencies = 10.0 ** (
torch.arange(d_embedding) / (d_embedding - 1) * 4.0
@@ -62,9 +63,9 @@ class ResidualBlock(nn.Module):
def __init__(self, d_side, n_channels, diffusion_embedding_dim, nheads):
super().__init__()
self.diffusion_projection = nn.Linear(diffusion_embedding_dim, n_channels)
- self.cond_projection = Conv1d_with_init(d_side, 2 * n_channels, 1)
- self.mid_projection = Conv1d_with_init(n_channels, 2 * n_channels, 1)
- self.output_projection = Conv1d_with_init(n_channels, 2 * n_channels, 1)
+ self.cond_projection = conv1d_with_init(d_side, 2 * n_channels, 1)
+ self.mid_projection = conv1d_with_init(n_channels, 2 * n_channels, 1)
+ self.output_projection = conv1d_with_init(n_channels, 2 * n_channels, 1)
self.time_layer = get_torch_trans(heads=nheads, layers=1, channels=n_channels)
self.feature_layer = get_torch_trans(
@@ -135,9 +136,9 @@ def __init__(
n_diffusion_steps=n_diffusion_steps,
d_embedding=d_diffusion_embedding,
)
- self.input_projection = Conv1d_with_init(d_input, n_channels, 1)
- self.output_projection1 = Conv1d_with_init(n_channels, n_channels, 1)
- self.output_projection2 = Conv1d_with_init(n_channels, 1, 1)
+ self.input_projection = conv1d_with_init(d_input, n_channels, 1)
+ self.output_projection1 = conv1d_with_init(n_channels, n_channels, 1)
+ self.output_projection2 = conv1d_with_init(n_channels, 1, 1)
nn.init.zeros_(self.output_projection2.weight)
self.residual_layers = nn.ModuleList(
diff --git a/pypots/imputation/gpvae/modules/core.py b/pypots/imputation/gpvae/modules/core.py
index 935daca0..7a37ffde 100644
--- a/pypots/imputation/gpvae/modules/core.py
+++ b/pypots/imputation/gpvae/modules/core.py
@@ -109,6 +109,53 @@ def decode(self, z):
assert num_dim > 2
return self.decoder(torch.transpose(z, num_dim - 1, num_dim - 2))
+ @staticmethod
+ def kl_divergence(a, b):
+ return torch.distributions.kl.kl_divergence(a, b)
+
+ def _init_prior(self, device="cpu"):
+ # Compute kernel matrices for each latent dimension
+ kernel_matrices = []
+ for i in range(self.kernel_scales):
+ if self.kernel == "rbf":
+ kernel_matrices.append(
+ rbf_kernel(self.time_length, self.length_scale / 2**i)
+ )
+ elif self.kernel == "diffusion":
+ kernel_matrices.append(
+ diffusion_kernel(self.time_length, self.length_scale / 2**i)
+ )
+ elif self.kernel == "matern":
+ kernel_matrices.append(
+ matern_kernel(self.time_length, self.length_scale / 2**i)
+ )
+ elif self.kernel == "cauchy":
+ kernel_matrices.append(
+ cauchy_kernel(
+ self.time_length, self.sigma, self.length_scale / 2**i
+ )
+ )
+
+ # Combine kernel matrices for each latent dimension
+ tiled_matrices = []
+ total = 0
+ for i in range(self.kernel_scales):
+ if i == self.kernel_scales - 1:
+ multiplier = self.latent_dim - total
+ else:
+ multiplier = int(np.ceil(self.latent_dim / self.kernel_scales))
+ total += multiplier
+ tiled_matrices.append(
+ torch.unsqueeze(kernel_matrices[i], 0).repeat(multiplier, 1, 1)
+ )
+ kernel_matrix_tiled = torch.cat(tiled_matrices)
+ assert len(kernel_matrix_tiled) == self.latent_dim
+ prior = torch.distributions.MultivariateNormal(
+ loc=torch.zeros(self.latent_dim, self.time_length, device=device),
+ covariance_matrix=kernel_matrix_tiled.to(device),
+ )
+ return prior
+
def forward(self, inputs, training=True):
x = inputs["X"]
m_mask = inputs["missing_mask"]
@@ -151,63 +198,12 @@ def forward(self, inputs, training=True):
elbo = elbo.mean()
imputed_data = self.decode(self.encode(x).mean).mean * ~m_mask + x * m_mask
-
- if not training:
- # if not in training mode, return the classification result only
- return {
- "imputed_data": imputed_data,
- }
-
results = {
- "loss": -elbo.mean(),
"imputed_data": imputed_data,
}
- return results
-
- @staticmethod
- def kl_divergence(a, b):
- return torch.distributions.kl.kl_divergence(a, b)
-
- def _init_prior(self, device="cpu"):
- # Compute kernel matrices for each latent dimension
- kernel_matrices = []
- for i in range(self.kernel_scales):
- if self.kernel == "rbf":
- kernel_matrices.append(
- rbf_kernel(self.time_length, self.length_scale / 2**i)
- )
- elif self.kernel == "diffusion":
- kernel_matrices.append(
- diffusion_kernel(self.time_length, self.length_scale / 2**i)
- )
- elif self.kernel == "matern":
- kernel_matrices.append(
- matern_kernel(self.time_length, self.length_scale / 2**i)
- )
- elif self.kernel == "cauchy":
- kernel_matrices.append(
- cauchy_kernel(
- self.time_length, self.sigma, self.length_scale / 2**i
- )
- )
- # Combine kernel matrices for each latent dimension
- tiled_matrices = []
- total = 0
- for i in range(self.kernel_scales):
- if i == self.kernel_scales - 1:
- multiplier = self.latent_dim - total
- else:
- multiplier = int(np.ceil(self.latent_dim / self.kernel_scales))
- total += multiplier
- tiled_matrices.append(
- torch.unsqueeze(kernel_matrices[i], 0).repeat(multiplier, 1, 1)
- )
- kernel_matrix_tiled = torch.cat(tiled_matrices)
- assert len(kernel_matrix_tiled) == self.latent_dim
- prior = torch.distributions.MultivariateNormal(
- loc=torch.zeros(self.latent_dim, self.time_length, device=device),
- covariance_matrix=kernel_matrix_tiled.to(device),
- )
+ # if in training mode, return results with losses
+ if training:
+ results["loss"] = -elbo.mean()
- return prior
+ return results
diff --git a/pypots/imputation/mrnn/modules/core.py b/pypots/imputation/mrnn/modules/core.py
index 865b60b2..e4936ec8 100644
--- a/pypots/imputation/mrnn/modules/core.py
+++ b/pypots/imputation/mrnn/modules/core.py
@@ -52,7 +52,7 @@ def gene_hidden_states(self, inputs, direction):
hidden_states_collector.append(hidden_state)
return hidden_states_collector
- def forward(self, inputs, training=True):
+ def forward(self, inputs: dict, training: bool = True) -> dict:
hidden_states_f = self.gene_hidden_states(inputs, "forward")
hidden_states_b = self.gene_hidden_states(inputs, "backward")[::-1]
@@ -82,16 +82,13 @@ def forward(self, inputs, training=True):
estimations = torch.cat(estimations, dim=1)
imputed_data = masks * X + (1 - masks) * estimations
- if not training:
- # if not in training mode, return the classification result only
- return {
- "imputed_data": imputed_data,
- }
-
- reconstruction_loss /= self.seq_len
-
- ret_dict = {
- "loss": reconstruction_loss,
+ results = {
"imputed_data": imputed_data,
}
- return ret_dict
+
+ # if in training mode, return results with losses
+ if training:
+ reconstruction_loss /= self.seq_len
+ results["loss"] = reconstruction_loss
+
+ return results
diff --git a/pypots/imputation/mrnn/modules/submodules.py b/pypots/imputation/mrnn/modules/submodules.py
index f157d3fb..b61eb2f2 100644
--- a/pypots/imputation/mrnn/modules/submodules.py
+++ b/pypots/imputation/mrnn/modules/submodules.py
@@ -38,7 +38,7 @@ def reset_parameters(self):
self.beta.data.uniform_(-stdv, stdv)
def forward(self, x_t, m_t, target):
- h_t = F.tanh(
+ h_t = torch.tanh(
F.linear(x_t, self.U * self.m)
+ F.linear(target, self.V1 * self.m)
+ F.linear(m_t, self.V2)
diff --git a/pypots/imputation/saits/modules/core.py b/pypots/imputation/saits/modules/core.py
index 1f7ef721..eb062709 100644
--- a/pypots/imputation/saits/modules/core.py
+++ b/pypots/imputation/saits/modules/core.py
@@ -19,7 +19,7 @@
import torch.nn as nn
import torch.nn.functional as F
-from ....modules.self_attention import EncoderLayer, PositionalEncoding
+from ....modules.transformer import EncoderLayer, PositionalEncoding
from ....utils.metrics import cal_mae
@@ -81,7 +81,7 @@ def __init__(
)
self.dropout = nn.Dropout(p=dropout)
- self.position_enc = PositionalEncoding(d_model, n_position=n_steps)
+ self.position_enc = PositionalEncoding(d_model, n_positions=n_steps)
# for the 1st block
self.embedding_1 = nn.Linear(actual_n_features, d_model)
self.reduce_dim_z = nn.Linear(d_model, n_features)
@@ -180,27 +180,23 @@ def forward(
"combining_weights": combining_weights,
"imputed_data": imputed_data,
}
- if not training:
- # if not in training mode, return the classification result only
- return results
-
- ORT_loss = 0
- ORT_loss += self.customized_loss_func(X_tilde_1, X, masks)
- ORT_loss += self.customized_loss_func(X_tilde_2, X, masks)
- ORT_loss += self.customized_loss_func(X_tilde_3, X, masks)
- ORT_loss /= 3
-
- MIT_loss = self.customized_loss_func(
- X_tilde_3, inputs["X_intact"], inputs["indicating_mask"]
- )
- # `loss` is always the item for backward propagating to update the model
- loss = self.ORT_weight * ORT_loss + self.MIT_weight * MIT_loss
+ # if in training mode, return results with losses
+ if training:
+ ORT_loss = 0
+ ORT_loss += self.customized_loss_func(X_tilde_1, X, masks)
+ ORT_loss += self.customized_loss_func(X_tilde_2, X, masks)
+ ORT_loss += self.customized_loss_func(X_tilde_3, X, masks)
+ ORT_loss /= 3
- results["ORT_loss"] = ORT_loss
- results["MIT_loss"] = MIT_loss
+ MIT_loss = self.customized_loss_func(
+ X_tilde_3, inputs["X_intact"], inputs["indicating_mask"]
+ )
- # will be used for backward propagating to update the model
- results["loss"] = loss
+ results["ORT_loss"] = ORT_loss
+ results["MIT_loss"] = MIT_loss
+ # `loss` is always the item for backward propagating to update the model
+ loss = self.ORT_weight * ORT_loss + self.MIT_weight * MIT_loss
+ results["loss"] = loss
return results
diff --git a/pypots/imputation/timesnet/__init__.py b/pypots/imputation/timesnet/__init__.py
new file mode 100644
index 00000000..dce839ef
--- /dev/null
+++ b/pypots/imputation/timesnet/__init__.py
@@ -0,0 +1,17 @@
+"""
+The package of the partially-observed time-series imputation model Transformer.
+
+Refer to the paper "Du, W., Cote, D., & Liu, Y. (2023). SAITS: Self-Attention-based Imputation for Time Series.
+Expert systems with applications."
+
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+
+from .model import TimesNet
+
+__all__ = [
+ "TimesNet",
+]
diff --git a/pypots/imputation/timesnet/data.py b/pypots/imputation/timesnet/data.py
new file mode 100644
index 00000000..5d95170d
--- /dev/null
+++ b/pypots/imputation/timesnet/data.py
@@ -0,0 +1,23 @@
+"""
+Dataset class for TimesNet.
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+from typing import Union
+
+from ..saits.data import DatasetForSAITS
+
+
+class DatasetForTimesNet(DatasetForSAITS):
+ """Actually TimesNet uses the same data strategy as SAITS, needs MIT for training."""
+
+ def __init__(
+ self,
+ data: Union[dict, str],
+ return_labels: bool = True,
+ file_type: str = "h5py",
+ rate: float = 0.2,
+ ):
+ super().__init__(data, return_labels, file_type, rate)
diff --git a/pypots/imputation/timesnet/model.py b/pypots/imputation/timesnet/model.py
new file mode 100644
index 00000000..ec00835a
--- /dev/null
+++ b/pypots/imputation/timesnet/model.py
@@ -0,0 +1,337 @@
+"""
+The implementation of Transformer for the partially-observed time-series imputation task.
+
+Refer to the paper "Du, W., Cote, D., & Liu, Y. (2023). SAITS: Self-Attention-based Imputation for Time Series.
+Expert systems with applications."
+
+Notes
+-----
+Partial implementation uses code from https://github.com/WenjieDu/SAITS.
+
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+from typing import Union, Optional
+
+import h5py
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+
+from .data import DatasetForTimesNet
+from ...utils.logging import logger
+from .modules.core import _TimesNet
+from ..base import BaseNNImputer
+from ...data.base import BaseDataset
+from ...optim.adam import Adam
+from ...optim.base import Optimizer
+
+
+class TimesNet(BaseNNImputer):
+ """The PyTorch implementation of the TimesNet model.
+ TimesNet is originally proposed by Wu et al. in :cite:`wu2023timesnet`.
+
+ Parameters
+ ----------
+ n_steps :
+ The number of time steps in the time-series data sample.
+
+ n_features :
+ The number of features in the time-series data sample.
+
+ n_layers :
+ The number of stacked TimesBlock layers in the TimesNet model.
+
+ top_k :
+ The number of top-k amplitude values to be selected to obtain the most significant frequencies.
+
+ d_model :
+ The dimension of the model.
+
+ d_ffn :
+ The dimension of the feed-forward network.
+
+ n_kernels :
+ The number of 2D kernels (2D convolutional layers) to use in the submodule InceptionBlockV1.
+
+ dropout :
+ The dropout rate for the model.
+
+ batch_size :
+ The batch size for training and evaluating the model.
+
+ epochs :
+ The number of epochs for training the model.
+
+ patience :
+ The patience for the early-stopping mechanism. Given a positive integer, the training process will be
+ stopped when the model does not perform better after that number of epochs.
+ Leaving it default as None will disable the early-stopping.
+
+ optimizer :
+ The optimizer for model training.
+ If not given, will use a default Adam optimizer.
+
+ num_workers :
+ The number of subprocesses to use for data loading.
+ `0` means data loading will be in the main process, i.e. there won't be subprocesses.
+
+ device :
+ The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them.
+ If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple),
+ then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models.
+ If given a list of devices, e.g. ['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')], the
+ model will be trained in parallel on the multiple devices (so far only parallel training on CUDA devices is supported).
+ Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future.
+
+ saving_path :
+ The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during
+ training into a tensorboard file). Will not save if not given.
+
+ model_saving_strategy :
+ The strategy to save model checkpoints. It has to be one of [None, "best", "better"].
+ No model will be saved when it is set as None.
+ The "best" strategy will only automatically save the best model after the training finished.
+ The "better" strategy will automatically save the model during training whenever the model performs
+ better than in previous epochs.
+
+ Attributes
+ ----------
+ model : :class:`torch.nn.Module`
+ The underlying TimesNet model.
+
+ optimizer : :class:`pypots.optim.Optimizer`
+ The optimizer for model training.
+
+ """
+
+ def __init__(
+ self,
+ n_steps: int,
+ n_features: int,
+ n_layers: int,
+ top_k: int,
+ d_model: int,
+ d_ffn: int,
+ n_kernels: int,
+ dropout: float = 0,
+ batch_size: int = 32,
+ epochs: int = 100,
+ patience: int = None,
+ optimizer: Optional[Optimizer] = Adam(),
+ num_workers: int = 0,
+ device: Optional[Union[str, torch.device, list]] = None,
+ saving_path: str = None,
+ model_saving_strategy: Optional[str] = "best",
+ ):
+ super().__init__(
+ batch_size,
+ epochs,
+ patience,
+ num_workers,
+ device,
+ saving_path,
+ model_saving_strategy,
+ )
+
+ self.n_steps = n_steps
+ self.n_features = n_features
+ # model hyper-parameters
+ self.n_layers = n_layers
+ self.top_k = top_k
+ self.d_model = d_model
+ self.d_ffn = d_ffn
+ self.n_kernels = n_kernels
+ self.dropout = dropout
+
+ # set up the model
+ self.model = _TimesNet(
+ self.n_layers,
+ self.n_steps,
+ self.n_features,
+ self.top_k,
+ self.d_model,
+ self.d_ffn,
+ self.n_kernels,
+ self.dropout,
+ )
+ self._send_model_to_given_device()
+ self._print_model_size()
+
+ # set up the optimizer
+ self.optimizer = optimizer
+ self.optimizer.init_optimizer(self.model.parameters())
+
+ def _assemble_input_for_training(self, data: dict) -> dict:
+ (
+ indices,
+ X_intact,
+ X,
+ missing_mask,
+ indicating_mask,
+ ) = self._send_data_to_given_device(data)
+
+ inputs = {
+ "X": X,
+ "X_intact": X_intact,
+ "missing_mask": missing_mask,
+ "indicating_mask": indicating_mask,
+ }
+
+ return inputs
+
+ def _assemble_input_for_validating(self, data: list) -> dict:
+ indices, X, missing_mask = self._send_data_to_given_device(data)
+
+ inputs = {
+ "X": X,
+ "missing_mask": missing_mask,
+ }
+
+ return inputs
+
+ def _assemble_input_for_testing(self, data: list) -> dict:
+ return self._assemble_input_for_validating(data)
+
+ def fit(
+ self,
+ train_set: Union[dict, str],
+ val_set: Optional[Union[dict, str]] = None,
+ file_type: str = "h5py",
+ ) -> None:
+ # Step 1: wrap the input data with classes Dataset and DataLoader
+ training_set = DatasetForTimesNet(
+ train_set, return_labels=False, file_type=file_type
+ )
+ training_loader = DataLoader(
+ training_set,
+ batch_size=self.batch_size,
+ shuffle=True,
+ num_workers=self.num_workers,
+ )
+ val_loader = None
+ if val_set is not None:
+ if isinstance(val_set, str):
+ with h5py.File(val_set, "r") as hf:
+ # Here we read the whole validation set from the file to mask a portion for validation.
+ # In PyPOTS, a file is usually used because the data is too big. However, the validation set
+ # generally shouldn't be too large. For example, if we have 1 billion samples for model training,
+ # we won't take 20% of them as the validation set, because we want as much data as possible for the
+ # training stage to enhance the model's generalization ability. In that case, 100,000 representative
+ # samples are enough to validate the model.
+ val_set = {
+ "X": hf["X"][:],
+ "X_intact": hf["X_intact"][:],
+ "indicating_mask": hf["indicating_mask"][:],
+ }
+
+ # check if X_intact contains missing values
+ if np.isnan(val_set["X_intact"]).any():
+ val_set["X_intact"] = np.nan_to_num(val_set["X_intact"], nan=0)
+ logger.warning(
+ "X_intact shouldn't contain missing data but has NaN values. "
+ "PyPOTS has imputed them with zeros by default to start the training for now. "
+ "Please double-check your data if you have concerns over this operation."
+ )
+
+ val_set = BaseDataset(val_set, return_labels=False, file_type=file_type)
+ val_loader = DataLoader(
+ val_set,
+ batch_size=self.batch_size,
+ shuffle=False,
+ num_workers=self.num_workers,
+ )
+
+ # Step 2: train the model and freeze it
+ self._train_model(training_loader, val_loader)
+ self.model.load_state_dict(self.best_model_dict)
+ self.model.eval() # set the model as eval status to freeze it.
+
+ # Step 3: save the model if necessary
+ self._auto_save_model_if_necessary(training_finished=True)
+
+ def predict(
+ self,
+ test_set: Union[dict, str],
+ file_type: str = "h5py",
+ ) -> dict:
+ """Make predictions for the input data with the trained model.
+
+ Parameters
+ ----------
+ test_set : dict or str
+ The dataset for model testing, should be a dictionary including the key 'X',
+ or a path string locating a data file supported by PyPOTS (e.g. h5 file).
+ If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features],
+ which is time-series data for testing and can contain missing values.
+ If it is a path string, the path should point to a data file, e.g. a h5 file, which contains
+ key-value pairs like a dict, and it has to include the key 'X'.
+
+ file_type : str
+ The type of the given file if test_set is a path string.
+
+ Returns
+ -------
+ result_dict : dict,
+ The dictionary containing the imputation results under the key 'imputation'.
+
+ """
+ # Step 1: wrap the input data with classes Dataset and DataLoader
+ self.model.eval() # set the model as eval status to freeze it.
+ test_set = BaseDataset(test_set, return_labels=False, file_type=file_type)
+ test_loader = DataLoader(
+ test_set,
+ batch_size=self.batch_size,
+ shuffle=False,
+ num_workers=self.num_workers,
+ )
+ imputation_collector = []
+
+ # Step 2: process the data with the model
+ with torch.no_grad():
+ for idx, data in enumerate(test_loader):
+ inputs = self._assemble_input_for_testing(data)
+ results = self.model.forward(inputs, training=False)
+ imputation_collector.append(results["imputed_data"])
+
+ # Step 3: output collection and return
+ imputation = torch.cat(imputation_collector).cpu().detach().numpy()
+ result_dict = {
+ "imputation": imputation,
+ }
+ return result_dict
+
+ def impute(
+ self,
+ X: Union[dict, str],
+ file_type="h5py",
+ ) -> np.ndarray:
+ """Impute missing values in the given data with the trained model.
+
+ Warnings
+ --------
+ The method impute is deprecated. Please use `predict()` instead.
+
+ Parameters
+ ----------
+ X :
+ The data samples for testing, should be array-like of shape [n_samples, sequence length (time steps),
+ n_features], or a path string locating a data file, e.g. h5 file.
+
+ file_type :
+ The type of the given file if X is a path string.
+
+ Returns
+ -------
+ array-like, shape [n_samples, sequence length (time steps), n_features],
+ Imputed data.
+ """
+ logger.warning(
+ "🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead."
+ )
+
+ results_dict = self.predict(X, file_type=file_type)
+ return results_dict["imputation"]
diff --git a/pypots/imputation/timesnet/modules/__init__.py b/pypots/imputation/timesnet/modules/__init__.py
new file mode 100644
index 00000000..ceaa7ee3
--- /dev/null
+++ b/pypots/imputation/timesnet/modules/__init__.py
@@ -0,0 +1,6 @@
+"""
+
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
diff --git a/pypots/imputation/timesnet/modules/core.py b/pypots/imputation/timesnet/modules/core.py
new file mode 100644
index 00000000..9dd4bf5a
--- /dev/null
+++ b/pypots/imputation/timesnet/modules/core.py
@@ -0,0 +1,94 @@
+"""
+
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+import torch
+import torch.fft
+import torch.nn as nn
+
+from .embedding import DataEmbedding
+from .layer import TimesBlock
+from ....utils.metrics import cal_mse
+
+
+class _TimesNet(nn.Module):
+ def __init__(
+ self,
+ n_layers,
+ n_steps,
+ n_features,
+ top_k,
+ d_model,
+ d_ffn,
+ n_kernels,
+ dropout,
+ ):
+ super().__init__()
+
+ self.seq_len = n_steps
+ self.n_layers = n_layers
+
+ self.pred_len = 0 # for the imputation task, the pred_len is always 0
+ self.model = nn.ModuleList(
+ [
+ TimesBlock(n_steps, self.pred_len, top_k, d_model, d_ffn, n_kernels)
+ for _ in range(n_layers)
+ ]
+ )
+ self.enc_embedding = DataEmbedding(
+ n_features,
+ d_model,
+ dropout=dropout,
+ )
+ self.layer_norm = nn.LayerNorm(d_model)
+
+ # for the imputation task, the output dim is the same as input dim
+ c_out = n_features
+ self.projection = nn.Linear(d_model, c_out)
+
+ def forward(self, inputs: dict, training: bool = True) -> dict:
+ X, masks = inputs["X"], inputs["missing_mask"]
+
+ # Normalization from Non-stationary Transformer
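+ # the mean and std below are computed over observed entries only (masks == 1),
+ # so missing values do not distort the normalization statistics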
+ means = torch.sum(X, dim=1) / torch.sum(masks == 1, dim=1)
+ means = means.unsqueeze(1).detach()
+ x_enc = X - means
+ x_enc = x_enc.masked_fill(masks == 0, 0)
+ stdev = torch.sqrt(
+ torch.sum(x_enc * x_enc, dim=1) / torch.sum(masks == 1, dim=1) + 1e-5
+ )
+ stdev = stdev.unsqueeze(1).detach()
+ x_enc /= stdev
+
+ # embedding
+ enc_out = self.enc_embedding(x_enc) # [B,T,C]
+ # TimesNet
+ for i in range(self.n_layers):
+ enc_out = self.layer_norm(self.model[i](enc_out))
+
+ # project back the original data space
+ dec_out = self.projection(enc_out)
+
+ # De-Normalization from Non-stationary Transformer
+ dec_out = dec_out * (
+ stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len + self.seq_len, 1)
+ )
+ dec_out = dec_out + (
+ means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len + self.seq_len, 1)
+ )
+
+ imputed_data = masks * X + (1 - masks) * dec_out
+
+ results = {
+ "imputed_data": imputed_data,
+ }
+
+ if training:
+ # `loss` is always the item for backward propagating to update the model
+ loss = cal_mse(dec_out, inputs["X_intact"], inputs["indicating_mask"])
+ results["loss"] = loss
+
+ return results
diff --git a/pypots/imputation/timesnet/modules/embedding.py b/pypots/imputation/timesnet/modules/embedding.py
new file mode 100644
index 00000000..70bd739e
--- /dev/null
+++ b/pypots/imputation/timesnet/modules/embedding.py
@@ -0,0 +1,129 @@
+"""
+
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+import math
+
+import torch
+import torch.fft
+import torch.nn as nn
+
+from ....modules.transformer import PositionalEncoding
+
+
+class TokenEmbedding(nn.Module):
+ def __init__(self, c_in, d_model):
+ super().__init__()
+ padding = 1 if torch.__version__ >= "1.5.0" else 2
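+ # a 1D convolution over the time axis maps each step's c_in features to d_model;
+ # circular padding keeps the sequence length unchanged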
+ self.tokenConv = nn.Conv1d(
+ in_channels=c_in,
+ out_channels=d_model,
+ kernel_size=3,
+ padding=padding,
+ padding_mode="circular",
+ bias=False,
+ )
+ for m in self.modules():
+ if isinstance(m, nn.Conv1d):
+ nn.init.kaiming_normal_(
+ m.weight, mode="fan_in", nonlinearity="leaky_relu"
+ )
+
+ def forward(self, x):
+ x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2)
+ return x
+
+
+class FixedEmbedding(nn.Module):
+ def __init__(self, c_in, d_model):
+ super().__init__()
+
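+ # build a fixed sinusoidal table and register it as a frozen embedding (no gradient updates)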
+ w = torch.zeros(c_in, d_model).float()
+ w.requires_grad = False
+
+ position = torch.arange(0, c_in).float().unsqueeze(1)
+ div_term = (
+ torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
+ ).exp()
+
+ w[:, 0::2] = torch.sin(position * div_term)
+ w[:, 1::2] = torch.cos(position * div_term)
+
+ self.emb = nn.Embedding(c_in, d_model)
+ self.emb.weight = nn.Parameter(w, requires_grad=False)
+
+ def forward(self, x):
+ return self.emb(x).detach()
+
+
+class TemporalEmbedding(nn.Module):
+ def __init__(self, d_model, embed_type="fixed", freq="h"):
+ super().__init__()
+
+ minute_size = 4
+ hour_size = 24
+ weekday_size = 7
+ day_size = 32
+ month_size = 13
+
+ Embed = FixedEmbedding if embed_type == "fixed" else nn.Embedding
+ if freq == "t":
+ self.minute_embed = Embed(minute_size, d_model)
+ self.hour_embed = Embed(hour_size, d_model)
+ self.weekday_embed = Embed(weekday_size, d_model)
+ self.day_embed = Embed(day_size, d_model)
+ self.month_embed = Embed(month_size, d_model)
+
+ def forward(self, x):
+ x = x.long()
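+ # x is expected to hold integer time features per step, ordered [month, day, weekday, hour, minute]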
+ minute_x = (
+ self.minute_embed(x[:, :, 4]) if hasattr(self, "minute_embed") else 0.0
+ )
+ hour_x = self.hour_embed(x[:, :, 3])
+ weekday_x = self.weekday_embed(x[:, :, 2])
+ day_x = self.day_embed(x[:, :, 1])
+ month_x = self.month_embed(x[:, :, 0])
+
+ return hour_x + weekday_x + day_x + month_x + minute_x
+
+
+class TimeFeatureEmbedding(nn.Module):
+ def __init__(self, d_model, freq="h"):
+ super().__init__()
+
+ freq_map = {"h": 4, "t": 5, "s": 6, "m": 1, "a": 1, "w": 2, "d": 3, "b": 3}
+ d_inp = freq_map[freq]
+ self.embed = nn.Linear(d_inp, d_model, bias=False)
+
+ def forward(self, x):
+ return self.embed(x)
+
+
+class DataEmbedding(nn.Module):
+ def __init__(self, c_in, d_model, embed_type="fixed", freq="h", dropout=0.1):
+ super().__init__()
+
+ self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
+ self.position_embedding = PositionalEncoding(d_hid=d_model)
+ self.temporal_embedding = (
+ TemporalEmbedding(d_model=d_model, embed_type=embed_type, freq=freq)
+ if embed_type != "timeF"
+ else TimeFeatureEmbedding(d_model=d_model, freq=freq)
+ )
+ self.dropout = nn.Dropout(p=dropout)
+
+ def forward(self, x, x_timestamp=None):
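+ # without timestamp features, only value + positional embeddings are summed;
+ # otherwise the temporal embedding of the timestamps is added as well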
+ if x_timestamp is None:
+ x = self.value_embedding(x) + self.position_embedding(
+ x, return_only_pos=True
+ )
+ else:
+ x = (
+ self.value_embedding(x)
+ + self.temporal_embedding(x_timestamp)
+ + self.position_embedding(x, return_only_pos=True)
+ )
+ return self.dropout(x)
diff --git a/pypots/imputation/timesnet/modules/layer.py b/pypots/imputation/timesnet/modules/layer.py
new file mode 100644
index 00000000..a1130910
--- /dev/null
+++ b/pypots/imputation/timesnet/modules/layer.py
@@ -0,0 +1,105 @@
+"""
+
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+import torch
+import torch.fft
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def FFT_for_Period(x, k=2):
+ # [B, T, C]
+ xf = torch.fft.rfft(x, dim=1)
+ # find period by amplitudes
+ frequency_list = abs(xf).mean(0).mean(-1)
+ frequency_list[0] = 0
+ _, top_list = torch.topk(frequency_list, k)
+ top_list = top_list.detach().cpu().numpy()
+ period = x.shape[1] // top_list
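+ # period holds the candidate period lengths (T // dominant frequency); the mean amplitude at the
+ # selected frequencies is returned per sample and later used as aggregation weights in TimesBlock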
+ return period, abs(xf).mean(-1)[:, top_list]
+
+
+class InceptionBlockV1(nn.Module):
+ def __init__(self, in_channels, out_channels, num_kernels=6, init_weight=True):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.num_kernels = num_kernels
+ kernels = []
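+ # parallel Conv2d branches with odd kernel sizes 1, 3, 5, ...; padding=i keeps the spatial size,
+ # and forward() averages the branch outputs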
+ for i in range(self.num_kernels):
+ kernels.append(
+ nn.Conv2d(in_channels, out_channels, kernel_size=2 * i + 1, padding=i)
+ )
+ self.kernels = nn.ModuleList(kernels)
+ if init_weight:
+ self._initialize_weights()
+
+ def _initialize_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x):
+ res_list = []
+ for i in range(self.num_kernels):
+ res_list.append(self.kernels[i](x))
+ res = torch.stack(res_list, dim=-1).mean(-1)
+ return res
+
+
+class TimesBlock(nn.Module):
+ def __init__(self, seq_len, pred_len, top_k, d_model, d_ffn, num_kernels):
+ super().__init__()
+ self.seq_len = seq_len
+ self.pred_len = pred_len
+ self.top_k = top_k
+
+ # parameter-efficient design
+ self.conv = nn.Sequential(
+ InceptionBlockV1(d_model, d_ffn, num_kernels=num_kernels),
+ nn.GELU(),
+ InceptionBlockV1(d_ffn, d_model, num_kernels=num_kernels),
+ )
+
+ def forward(self, x):
+ B, T, N = x.size()
+ period_list, period_weight = FFT_for_Period(x, self.top_k)
+
+ res = []
+ for i in range(self.top_k):
+ period = period_list[i]
+ # padding
+ if (self.seq_len + self.pred_len) % period != 0:
+ length = (((self.seq_len + self.pred_len) // period) + 1) * period
+ padding = torch.zeros(
+ [x.shape[0], (length - (self.seq_len + self.pred_len)), x.shape[2]]
+ ).to(x.device)
+ out = torch.cat([x, padding], dim=1)
+ else:
+ length = self.seq_len + self.pred_len
+ out = x
+ # reshape
+ out = (
+ out.reshape(B, length // period, period, N)
+ .permute(0, 3, 1, 2)
+ .contiguous()
+ )
+ # 2D conv: from 1d Variation to 2d Variation
+ out = self.conv(out)
+ # reshape back
+ out = out.permute(0, 2, 3, 1).reshape(B, -1, N)
+ res.append(out[:, : (self.seq_len + self.pred_len), :])
+ res = torch.stack(res, dim=-1)
+ # adaptive aggregation
+ period_weight = F.softmax(period_weight, dim=1)
+ period_weight = period_weight.unsqueeze(1).unsqueeze(1).repeat(1, T, N, 1)
+ res = torch.sum(res * period_weight, -1)
+ # residual connection
+ res = res + x
+ return res
diff --git a/pypots/imputation/transformer/modules/core.py b/pypots/imputation/transformer/modules/core.py
index c09e5d4d..34750da8 100644
--- a/pypots/imputation/transformer/modules/core.py
+++ b/pypots/imputation/transformer/modules/core.py
@@ -18,7 +18,7 @@
import torch
import torch.nn as nn
-from ....modules.self_attention import EncoderLayer, PositionalEncoding
+from ....modules.transformer import EncoderLayer, PositionalEncoding
from ....utils.metrics import cal_mae
@@ -60,7 +60,7 @@ def __init__(
)
self.embedding = nn.Linear(actual_d_feature, d_model)
- self.position_enc = PositionalEncoding(d_model, n_position=d_time)
+ self.position_enc = PositionalEncoding(d_model, n_positions=d_time)
self.dropout = nn.Dropout(p=dropout)
self.reduce_dim = nn.Linear(d_model, d_feature)
@@ -83,24 +83,20 @@ def forward(self, inputs: dict, training: bool = True) -> dict:
X, masks = inputs["X"], inputs["missing_mask"]
imputed_data, learned_presentation = self._process(inputs)
- if not training:
- # if not in training mode, return the classification result only
- return {
- "imputed_data": imputed_data,
- }
-
- ORT_loss = cal_mae(learned_presentation, X, masks)
- MIT_loss = cal_mae(
- learned_presentation, inputs["X_intact"], inputs["indicating_mask"]
- )
-
- # `loss` is always the item for backward propagating to update the model
- loss = self.ORT_weight * ORT_loss + self.MIT_weight * MIT_loss
-
results = {
"imputed_data": imputed_data,
- "ORT_loss": ORT_loss,
- "MIT_loss": MIT_loss,
- "loss": loss,
}
+
+ # if in training mode, return results with losses
+ if training:
+ ORT_loss = cal_mae(learned_presentation, X, masks)
+ MIT_loss = cal_mae(
+ learned_presentation, inputs["X_intact"], inputs["indicating_mask"]
+ )
+ results["ORT_loss"] = ORT_loss
+ results["MIT_loss"] = MIT_loss
+ # `loss` is always the item for backward propagating to update the model
+ loss = self.ORT_weight * ORT_loss + self.MIT_weight * MIT_loss
+ results["loss"] = loss
+
return results
diff --git a/pypots/imputation/usgan/modules/core.py b/pypots/imputation/usgan/modules/core.py
index 71c43f84..16504d6b 100644
--- a/pypots/imputation/usgan/modules/core.py
+++ b/pypots/imputation/usgan/modules/core.py
@@ -56,29 +56,30 @@ def forward(
"discriminator",
], 'training_object should be "generator" or "discriminator"'
- forward_X = inputs["forward"]["X"]
- forward_missing_mask = inputs["forward"]["missing_mask"]
- losses = {}
results = self.generator(inputs, training=training)
- inputs["discrimination"] = self.discriminator(forward_X, forward_missing_mask)
- if not training:
- # if only run imputation operation, then no need to calculate loss
- return results
-
- if training_object == "discriminator":
- l_D = F.binary_cross_entropy_with_logits(
- inputs["discrimination"], forward_missing_mask
- )
- losses["discrimination_loss"] = l_D
- else:
- inputs["discrimination"] = inputs["discrimination"].detach()
- l_G = F.binary_cross_entropy_with_logits(
- inputs["discrimination"],
- 1 - forward_missing_mask,
- weight=1 - forward_missing_mask,
+
+ # if in training mode, return results with losses
+ if training:
+ forward_X = inputs["forward"]["X"]
+ forward_missing_mask = inputs["forward"]["missing_mask"]
+
+ inputs["discrimination"] = self.discriminator(
+ forward_X, forward_missing_mask
)
- loss_gene = l_G + self.lambda_mse * results["loss"]
- losses["generation_loss"] = loss_gene
- losses["imputed_data"] = results["imputed_data"]
- return losses
+ if training_object == "discriminator":
+ l_D = F.binary_cross_entropy_with_logits(
+ inputs["discrimination"], forward_missing_mask
+ )
+ results["discrimination_loss"] = l_D
+ else:
+ inputs["discrimination"] = inputs["discrimination"].detach()
+ l_G = F.binary_cross_entropy_with_logits(
+ inputs["discrimination"],
+ 1 - forward_missing_mask,
+ weight=1 - forward_missing_mask,
+ )
+ loss_gene = l_G + self.lambda_mse * results["loss"]
+ results["generation_loss"] = loss_gene
+
+ return results
diff --git a/pypots/modules/self_attention.py b/pypots/modules/self_attention.py
deleted file mode 100644
index d44fe63a..00000000
--- a/pypots/modules/self_attention.py
+++ /dev/null
@@ -1,729 +0,0 @@
-"""
-The implementation of the modules for Transformer :cite:`vaswani2017Transformer`
-
-Notes
------
-Partial implementation uses code from https://github.com/WenjieDu/SAITS,
-and https://github.com/jadore801120/attention-is-all-you-need-pytorch.
-
-"""
-
-# Created by Wenjie Du
-# License: BSD-3-Clause
-
-from typing import Tuple, Optional, Union
-
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-class ScaledDotProductAttention(nn.Module):
- """Scaled dot-product attention.
-
- Parameters
- ----------
- temperature:
- The temperature for scaling.
-
- attn_dropout:
- The dropout rate for the attention map.
-
- """
-
- def __init__(self, temperature: float, attn_dropout: float = 0.1):
- super().__init__()
- assert temperature > 0, "temperature should be positive"
- assert attn_dropout >= 0, "dropout rate should be non-negative"
- self.temperature = temperature
- self.dropout = nn.Dropout(attn_dropout) if attn_dropout > 0 else None
-
- def forward(
- self,
- q: torch.Tensor,
- k: torch.Tensor,
- v: torch.Tensor,
- attn_mask: Optional[torch.Tensor] = None,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Forward processing of the scaled dot-product attention.
-
- Parameters
- ----------
- q:
- Query tensor.
- k:
- Key tensor.
- v:
- Value tensor.
-
- attn_mask:
- Masking tensor for the attention map. The shape should be [batch_size, n_heads, n_steps, n_steps].
- 0 in attn_mask means values at the according position in the attention map will be masked out.
-
- Returns
- -------
- output:
- The result of Value multiplied with the scaled dot-product attention map.
-
- attn:
- The scaled dot-product attention map.
-
- """
- # q, k, v all have 4 dimensions [batch_size, n_heads, n_steps, d_tensor]
- # d_tensor could be d_q, d_k, d_v
-
- # dot product q with k.T to obtain similarity
- attn = torch.matmul(q / self.temperature, k.transpose(2, 3))
-
- # apply masking on the attention map, this is optional
- if attn_mask is not None:
- attn = attn.masked_fill(attn_mask == 0, -1e9)
-
- # compute attention score [0, 1], then apply dropout
- attn = F.softmax(attn, dim=-1)
- if self.dropout is not None:
- attn = self.dropout(attn)
-
- # multiply the score with v
- output = torch.matmul(attn, v)
- return output, attn
-
-
-class MultiHeadAttention(nn.Module):
- """Transformer multi-head attention module.
-
- Parameters
- ----------
- n_heads:
- The number of heads in multi-head attention.
-
- d_model:
- The dimension of the input tensor.
-
- d_k:
- The dimension of the key and query tensor.
-
- d_v:
- The dimension of the value tensor.
-
- dropout:
- The dropout rate.
-
- attn_dropout:
- The dropout rate for the attention map.
-
- """
-
- def __init__(
- self,
- n_heads: int,
- d_model: int,
- d_k: int,
- d_v: int,
- dropout: float,
- attn_dropout: float,
- ):
- super().__init__()
-
- self.n_heads = n_heads
- self.d_k = d_k
- self.d_v = d_v
-
- self.w_qs = nn.Linear(d_model, n_heads * d_k, bias=False)
- self.w_ks = nn.Linear(d_model, n_heads * d_k, bias=False)
- self.w_vs = nn.Linear(d_model, n_heads * d_v, bias=False)
-
- self.attention = ScaledDotProductAttention(d_k**0.5, attn_dropout)
- self.fc = nn.Linear(n_heads * d_v, d_model, bias=False)
-
- self.dropout = nn.Dropout(dropout)
- self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
-
- def forward(
- self,
- q: torch.Tensor,
- k: torch.Tensor,
- v: torch.Tensor,
- attn_mask: Optional[torch.Tensor],
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Forward processing of the multi-head attention module.
-
- Parameters
- ----------
- q:
- Query tensor.
-
- k:
- Key tensor.
-
- v:
- Value tensor.
-
- attn_mask:
- Masking tensor for the attention map. The shape should be [batch_size, n_heads, n_steps, n_steps].
- 0 in attn_mask means values at the according position in the attention map will be masked out.
-
- Returns
- -------
- v:
- The output of the multi-head attention layer.
-
- attn_weights:
- The attention map.
-
- """
- # the input q, k, v currently have 3 dimensions [batch_size, n_steps, d_tensor]
- # d_tensor could be n_heads*d_k, n_heads*d_v
-
- # keep useful variables
- batch_size, n_steps = q.size(0), q.size(1)
- residual = q
-
- # now separate the last dimension of q, k, v into different heads -> [batch_size, n_steps, n_heads, d_k or d_v]
- q = self.w_qs(q).view(batch_size, n_steps, self.n_heads, self.d_k)
- k = self.w_ks(k).view(batch_size, n_steps, self.n_heads, self.d_k)
- v = self.w_vs(v).view(batch_size, n_steps, self.n_heads, self.d_v)
-
- # transpose for self-attention calculation -> [batch_size, n_steps, d_k or d_v, n_heads]
- q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
-
- if attn_mask is not None:
- # broadcasting on the head axis
- attn_mask = attn_mask.unsqueeze(1)
-
- v, attn_weights = self.attention(q, k, v, attn_mask)
-
- # transpose back -> [batch_size, n_steps, n_heads, d_v]
- # then merge the last two dimensions to combine all the heads -> [batch_size, n_steps, n_heads*d_v]
- v = v.transpose(1, 2).contiguous().view(batch_size, n_steps, -1)
- v = self.fc(v)
-
- # apply dropout and residual connection
- v = self.dropout(v)
- v += residual
-
- # apply layer-norm
- v = self.layer_norm(v)
-
- return v, attn_weights
-
-
-class PositionWiseFeedForward(nn.Module):
- """Position-wise feed forward network (FFN) in Transformer.
-
- Parameters
- ----------
- d_in:
- The dimension of the input tensor.
-
- d_hid:
- The dimension of the hidden layer.
-
- dropout:
- The dropout rate.
-
- """
-
- def __init__(self, d_in: int, d_hid: int, dropout: float = 0.1):
- super().__init__()
- self.linear_1 = nn.Linear(d_in, d_hid)
- self.linear_2 = nn.Linear(d_hid, d_in)
- self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
- self.dropout = nn.Dropout(dropout)
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- """Forward processing of the position-wise feed forward network.
-
- Parameters
- ----------
- x:
- Input tensor.
-
- Returns
- -------
- x:
- Output tensor.
- """
- # save the original input for the later residual connection
- residual = x
- # the 1st linear processing and ReLU non-linear projection
- x = F.relu(self.linear_1(x))
- # the 2nd linear processing
- x = self.linear_2(x)
- # apply dropout
- x = self.dropout(x)
- # apply residual connection
- x += residual
- # apply layer-norm
- x = self.layer_norm(x)
- return x
-
-
-class EncoderLayer(nn.Module):
- """Transformer encoder layer.
-
- Parameters
- ----------
- d_model:
- The dimension of the input tensor.
-
- d_inner:
- The dimension of the hidden layer.
-
- n_heads:
- The number of heads in multi-head attention.
-
- d_k:
- The dimension of the key and query tensor.
-
- d_v:
- The dimension of the value tensor.
-
- dropout:
- The dropout rate.
-
- attn_dropout:
- The dropout rate for the attention map.
- """
-
- def __init__(
- self,
- d_model: int,
- d_inner: int,
- n_heads: int,
- d_k: int,
- d_v: int,
- dropout: float = 0.1,
- attn_dropout: float = 0.1,
- ):
- super().__init__()
- self.slf_attn = MultiHeadAttention(
- n_heads, d_model, d_k, d_v, dropout, attn_dropout
- )
- self.pos_ffn = PositionWiseFeedForward(d_model, d_inner, dropout)
-
- def forward(
- self,
- enc_input: torch.Tensor,
- src_mask: Optional[torch.Tensor] = None,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Forward processing of the encoder layer.
-
- Parameters
- ----------
- enc_input:
- Input tensor.
-
- src_mask:
- Masking tensor for the attention map. The shape should be [batch_size, n_heads, n_steps, n_steps].
-
- Returns
- -------
- enc_output:
- Output tensor.
-
- attn_weights:
- The attention map.
-
- """
- enc_output, attn_weights = self.slf_attn(
- enc_input,
- enc_input,
- enc_input,
- attn_mask=src_mask,
- )
- enc_output = self.pos_ffn(enc_output)
- return enc_output, attn_weights
-
-
-class DecoderLayer(nn.Module):
- """Transformer decoder layer.
-
- Parameters
- ----------
- d_model:
- The dimension of the input tensor.
-
- d_inner:
- The dimension of the hidden layer.
-
- n_heads:
- The number of heads in multi-head attention.
-
- d_k:
- The dimension of the key and query tensor.
-
- d_v:
- The dimension of the value tensor.
-
- dropout:
- The dropout rate.
-
- attn_dropout:
- The dropout rate for the attention map.
-
- """
-
- def __init__(
- self,
- d_model: int,
- d_inner: int,
- n_heads: int,
- d_k: int,
- d_v: int,
- dropout: float = 0.1,
- attn_dropout: float = 0.1,
- ):
- super().__init__()
- self.slf_attn = MultiHeadAttention(
- n_heads, d_model, d_k, d_v, dropout, attn_dropout
- )
- self.enc_attn = MultiHeadAttention(
- n_heads, d_model, d_k, d_v, dropout, attn_dropout
- )
- self.pos_ffn = PositionWiseFeedForward(d_model, d_inner, dropout)
-
- def forward(
- self,
- dec_input: torch.Tensor,
- enc_output: torch.Tensor,
- slf_attn_mask: Optional[torch.Tensor] = None,
- dec_enc_attn_mask: Optional[torch.Tensor] = None,
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- """Forward processing of the decoder layer.
-
- Parameters
- ----------
- dec_input:
- Input tensor.
-
- enc_output:
- Output tensor from the encoder.
-
- slf_attn_mask:
- Masking tensor for the self-attention module.
- The shape should be [batch_size, n_heads, n_steps, n_steps].
-
- dec_enc_attn_mask:
- Masking tensor for the encoding attention module.
- The shape should be [batch_size, n_heads, n_steps, n_steps].
-
- Returns
- -------
- dec_output:
- Output tensor.
-
- dec_slf_attn:
- The self-attention map.
-
- dec_enc_attn:
- The encoding attention map.
-
- """
- dec_output, dec_slf_attn = self.slf_attn(
- dec_input, dec_input, dec_input, attn_mask=slf_attn_mask
- )
- dec_output, dec_enc_attn = self.enc_attn(
- dec_output, enc_output, enc_output, attn_mask=dec_enc_attn_mask
- )
- dec_output = self.pos_ffn(dec_output)
- return dec_output, dec_slf_attn, dec_enc_attn
-
-
-class Encoder(nn.Module):
- """Transformer encoder.
-
- Parameters
- ----------
- n_layers:
- The number of layers in the encoder.
-
- n_steps:
- The number of time steps in the input tensor.
-
- n_features:
- The number of features in the input tensor.
-
- d_model:
- The dimension of the module manipulation space.
- The input tensor will be projected to a space with d_model dimensions.
-
- d_inner:
- The dimension of the hidden layer in the feed-forward network.
-
- n_heads:
- The number of heads in multi-head attention.
-
- d_k:
- The dimension of the key and query tensor.
-
- d_v:
- The dimension of the value tensor.
-
- dropout:
- The dropout rate.
-
- attn_dropout:
- The dropout rate for the attention map.
-
- """
-
- def __init__(
- self,
- n_layers: int,
- n_steps: int,
- n_features: int,
- d_model: int,
- d_inner: int,
- n_heads: int,
- d_k: int,
- d_v: int,
- dropout: float,
- attn_dropout: float,
- ):
- super().__init__()
-
- self.embedding = nn.Linear(n_features, d_model)
- self.dropout = nn.Dropout(dropout)
- self.position_enc = PositionalEncoding(d_model, n_position=n_steps)
- self.enc_layer_stack = nn.ModuleList(
- [
- EncoderLayer(
- d_model,
- d_inner,
- n_heads,
- d_k,
- d_v,
- dropout,
- attn_dropout,
- )
- for _ in range(n_layers)
- ]
- )
-
- def forward(
- self,
- x: torch.Tensor,
- src_mask: Optional[torch.Tensor] = None,
- return_attn_weights: bool = False,
- ) -> Union[torch.Tensor, Tuple[torch.Tensor, list]]:
- """Forward processing of the encoder.
-
- Parameters
- ----------
- x:
- Input tensor.
-
- src_mask:
- Masking tensor for the attention map. The shape should be [batch_size, n_heads, n_steps, n_steps].
-
- return_attn_weights:
- Whether to return the attention map.
-
- Returns
- -------
- enc_output:
- Output tensor.
-
- attn_weights_collector:
- A list containing the attention map from each encoder layer.
-
- """
- x = self.embedding(x)
- enc_output = self.dropout(self.position_enc(x))
- attn_weights_collector = []
-
- for layer in self.enc_layer_stack:
- enc_output, attn_weights = layer(enc_output, src_mask)
- attn_weights_collector.append(attn_weights)
-
- if return_attn_weights:
- return enc_output, attn_weights_collector
-
- return enc_output
-
-
-class Decoder(nn.Module):
- """Transformer decoder.
-
- Parameters
- ----------
- n_layers:
- The number of layers in the decoder.
-
- n_steps:
- The number of time steps in the input tensor.
-
- n_features:
- The number of features in the input tensor.
-
- d_model:
- The dimension of the module manipulation space.
- The input tensor will be projected to a space with d_model dimensions.
-
- d_inner:
- The dimension of the hidden layer in the feed-forward network.
-
- n_heads:
- The number of heads in multi-head attention.
-
- d_k:
- The dimension of the key and query tensor.
-
- d_v:
- The dimension of the value tensor.
-
- dropout:
- The dropout rate.
-
- attn_dropout:
- The dropout rate for the attention map.
-
- """
-
- def __init__(
- self,
- n_layers: int,
- n_steps: int,
- n_features: int,
- d_model: int,
- d_inner: int,
- n_heads: int,
- d_k: int,
- d_v: int,
- dropout: float,
- attn_dropout: float,
- ):
- super().__init__()
- self.embedding = nn.Linear(n_features, d_model)
- self.dropout = nn.Dropout(dropout)
- self.position_enc = PositionalEncoding(d_model, n_position=n_steps)
- self.layer_stack = nn.ModuleList(
- [
- DecoderLayer(
- d_model,
- d_inner,
- n_heads,
- d_k,
- d_v,
- dropout,
- attn_dropout,
- )
- for _ in range(n_layers)
- ]
- )
-
- def forward(
- self,
- trg_seq: torch.Tensor,
- enc_output: torch.Tensor,
- trg_mask: Optional[torch.Tensor] = None,
- src_mask: Optional[torch.Tensor] = None,
- return_attn_weights: bool = False,
- ) -> Union[torch.Tensor, Tuple[torch.Tensor, list, list]]:
- """Forward processing of the decoder.
-
- Parameters
- ----------
- trg_seq:
- Input tensor.
-
- enc_output:
- Output tensor from the encoder.
-
- trg_mask:
- Masking tensor for the self-attention module.
-
- src_mask:
- Masking tensor for the encoding attention module.
-
- return_attn_weights:
- Whether to return the attention map.
-
- Returns
- -------
- dec_output:
- Output tensor.
-
- dec_slf_attn_collector:
- A list containing the self-attention map from each decoder layer.
-
- dec_enc_attn_collector:
- A list containing the encoding attention map from each decoder layer.
-
- """
- trg_seq = self.embedding(trg_seq)
- dec_output = self.dropout(self.position_enc(trg_seq))
-
- dec_slf_attn_collector = []
- dec_enc_attn_collector = []
-
- for layer in self.layer_stack:
- dec_output, dec_slf_attn, dec_enc_attn = layer(
- dec_output,
- enc_output,
- slf_attn_mask=trg_mask,
- dec_enc_attn_mask=src_mask,
- )
- dec_slf_attn_collector.append(dec_slf_attn)
- dec_enc_attn_collector.append(dec_enc_attn)
-
- if return_attn_weights:
- return dec_output, dec_slf_attn_collector, dec_enc_attn_collector
-
- return dec_output
-
-
-class PositionalEncoding(nn.Module):
- """Positional-encoding module for Transformer.
-
- Parameters
- ----------
- d_hid:
- The dimension of the hidden layer.
-
- n_position:
- The number of positions.
-
- """
-
- def __init__(self, d_hid: int, n_position: int = 200):
- super().__init__()
- # Not a parameter
- self.register_buffer(
- "pos_table", self._get_sinusoid_encoding_table(n_position, d_hid)
- )
-
- @staticmethod
- def _get_sinusoid_encoding_table(n_position: int, d_hid: int) -> torch.Tensor:
- """Sinusoid position encoding table"""
-
- def get_position_angle_vec(position):
- return [
- position / np.power(10000, 2 * (hid_j // 2) / d_hid)
- for hid_j in range(d_hid)
- ]
-
- sinusoid_table = np.array(
- [get_position_angle_vec(pos_i) for pos_i in range(n_position)]
- )
- sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
- sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
- return torch.FloatTensor(sinusoid_table).unsqueeze(0)
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- """Forward processing of the positional encoding module.
-
- Parameters
- ----------
- x:
- Input tensor.
-
- Returns
- -------
- x:
- Output tensor, the input tensor with the positional encoding added.
-
- """
- return x + self.pos_table[:, : x.size(1)].clone().detach()
diff --git a/pypots/modules/transformer/__init__.py b/pypots/modules/transformer/__init__.py
new file mode 100644
index 00000000..65b02d45
--- /dev/null
+++ b/pypots/modules/transformer/__init__.py
@@ -0,0 +1,22 @@
+"""
+
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+from .attention import ScaledDotProductAttention, MultiHeadAttention
+from .auto_encoder import Encoder, Decoder
+from .layers import EncoderLayer, DecoderLayer, PositionWiseFeedForward
+from .pos_enc import PositionalEncoding
+
+__all__ = [
+ "ScaledDotProductAttention",
+ "MultiHeadAttention",
+ "PositionalEncoding",
+ "EncoderLayer",
+ "DecoderLayer",
+ "PositionWiseFeedForward",
+ "Encoder",
+ "Decoder",
+]
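
For reference, a minimal sketch of how the relocated Transformer building blocks can be imported after this refactor, using only the names re-exported by the new `pypots/modules/transformer/__init__.py`:

```python
# illustrative only: import the Transformer modules from their new package location
from pypots.modules.transformer import (
    ScaledDotProductAttention,
    MultiHeadAttention,
    PositionalEncoding,
    EncoderLayer,
    DecoderLayer,
    PositionWiseFeedForward,
    Encoder,
    Decoder,
)
```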
diff --git a/pypots/modules/transformer/attention.py b/pypots/modules/transformer/attention.py
new file mode 100644
index 00000000..f4b6cf29
--- /dev/null
+++ b/pypots/modules/transformer/attention.py
@@ -0,0 +1,208 @@
+"""
+The implementation of the modules for Transformer :cite:`vaswani2017Transformer`
+
+Notes
+-----
+Partial implementation uses code from https://github.com/WenjieDu/SAITS,
+and https://github.com/jadore801120/attention-is-all-you-need-pytorch.
+
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+from typing import Tuple, Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class ScaledDotProductAttention(nn.Module):
+ """Scaled dot-product attention.
+
+ Parameters
+ ----------
+ temperature:
+ The temperature for scaling.
+
+ attn_dropout:
+ The dropout rate for the attention map.
+
+ """
+
+ def __init__(self, temperature: float, attn_dropout: float = 0.1):
+ super().__init__()
+ assert temperature > 0, "temperature should be positive"
+ assert attn_dropout >= 0, "dropout rate should be non-negative"
+ self.temperature = temperature
+ self.dropout = nn.Dropout(attn_dropout) if attn_dropout > 0 else None
+
+ def forward(
+ self,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Forward processing of the scaled dot-product attention.
+
+ Parameters
+ ----------
+ q:
+ Query tensor.
+ k:
+ Key tensor.
+ v:
+ Value tensor.
+
+ attn_mask:
+ Masking tensor for the attention map. The shape should be [batch_size, n_heads, n_steps, n_steps].
+            0 in attn_mask means values at the corresponding positions in the attention map will be masked out.
+
+ Returns
+ -------
+ output:
+            The value tensor weighted by the scaled dot-product attention map.
+
+ attn:
+ The scaled dot-product attention map.
+
+ """
+ # q, k, v all have 4 dimensions [batch_size, n_heads, n_steps, d_tensor]
+ # d_tensor could be d_q, d_k, d_v
+
+ # dot product q with k.T to obtain similarity
+ attn = torch.matmul(q / self.temperature, k.transpose(2, 3))
+
+ # apply masking on the attention map, this is optional
+ if attn_mask is not None:
+ attn = attn.masked_fill(attn_mask == 0, -1e9)
+
+ # compute attention score [0, 1], then apply dropout
+ attn = F.softmax(attn, dim=-1)
+ if self.dropout is not None:
+ attn = self.dropout(attn)
+
+ # multiply the score with v
+ output = torch.matmul(attn, v)
+ return output, attn
+
+
+class MultiHeadAttention(nn.Module):
+ """Transformer multi-head attention module.
+
+ Parameters
+ ----------
+ n_heads:
+ The number of heads in multi-head attention.
+
+ d_model:
+ The dimension of the input tensor.
+
+ d_k:
+ The dimension of the key and query tensor.
+
+ d_v:
+ The dimension of the value tensor.
+
+ dropout:
+ The dropout rate.
+
+ attn_dropout:
+ The dropout rate for the attention map.
+
+ """
+
+ def __init__(
+ self,
+ n_heads: int,
+ d_model: int,
+ d_k: int,
+ d_v: int,
+ dropout: float,
+ attn_dropout: float,
+ ):
+ super().__init__()
+
+ self.n_heads = n_heads
+ self.d_k = d_k
+ self.d_v = d_v
+
+ self.w_qs = nn.Linear(d_model, n_heads * d_k, bias=False)
+ self.w_ks = nn.Linear(d_model, n_heads * d_k, bias=False)
+ self.w_vs = nn.Linear(d_model, n_heads * d_v, bias=False)
+
+ self.attention = ScaledDotProductAttention(d_k**0.5, attn_dropout)
+ self.fc = nn.Linear(n_heads * d_v, d_model, bias=False)
+
+ self.dropout = nn.Dropout(dropout)
+ self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
+
+ def forward(
+ self,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ attn_mask: Optional[torch.Tensor],
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Forward processing of the multi-head attention module.
+
+ Parameters
+ ----------
+ q:
+ Query tensor.
+
+ k:
+ Key tensor.
+
+ v:
+ Value tensor.
+
+ attn_mask:
+            Masking tensor for the attention map. The shape should be [batch_size, n_steps, n_steps];
+            the head axis is broadcast inside. 0 in attn_mask means the corresponding positions in the attention map will be masked out.
+
+ Returns
+ -------
+ v:
+ The output of the multi-head attention layer.
+
+ attn_weights:
+ The attention map.
+
+ """
+        # the input q, k, v all have 3 dimensions [batch_size, n_steps, d_model]
+        # below they are projected to n_heads*d_k (or n_heads*d_v) and split into heads
+
+ # keep useful variables
+ batch_size, n_steps = q.size(0), q.size(1)
+ residual = q
+
+ # now separate the last dimension of q, k, v into different heads -> [batch_size, n_steps, n_heads, d_k or d_v]
+ q = self.w_qs(q).view(batch_size, n_steps, self.n_heads, self.d_k)
+ k = self.w_ks(k).view(batch_size, n_steps, self.n_heads, self.d_k)
+ v = self.w_vs(v).view(batch_size, n_steps, self.n_heads, self.d_v)
+
+        # transpose for attention calculation -> [batch_size, n_heads, n_steps, d_k or d_v]
+ q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
+
+ if attn_mask is not None:
+ # broadcasting on the head axis
+ attn_mask = attn_mask.unsqueeze(1)
+
+ v, attn_weights = self.attention(q, k, v, attn_mask)
+
+ # transpose back -> [batch_size, n_steps, n_heads, d_v]
+ # then merge the last two dimensions to combine all the heads -> [batch_size, n_steps, n_heads*d_v]
+ v = v.transpose(1, 2).contiguous().view(batch_size, n_steps, -1)
+ v = self.fc(v)
+
+ # apply dropout and residual connection
+ v = self.dropout(v)
+ v += residual
+
+ # apply layer-norm
+ v = self.layer_norm(v)
+
+ return v, attn_weights
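
As a quick sanity check of the attention module's shapes, here is a minimal usage sketch; the toy dimensions are assumptions for illustration, not values used anywhere in PyPOTS:

```python
import torch

from pypots.modules.transformer.attention import MultiHeadAttention

# toy setup: batch of 8 series, 24 steps, d_model=64, 4 heads of size 16
mha = MultiHeadAttention(n_heads=4, d_model=64, d_k=16, d_v=16, dropout=0.1, attn_dropout=0.1)
x = torch.randn(8, 24, 64)

# self-attention: q = k = v = x; attn_mask is positional and may be None
output, attn_weights = mha(x, x, x, attn_mask=None)
print(output.shape)        # torch.Size([8, 24, 64])
print(attn_weights.shape)  # torch.Size([8, 4, 24, 24])
```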
diff --git a/pypots/modules/transformer/auto_encoder.py b/pypots/modules/transformer/auto_encoder.py
new file mode 100644
index 00000000..212bbc68
--- /dev/null
+++ b/pypots/modules/transformer/auto_encoder.py
@@ -0,0 +1,258 @@
+"""
+
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from .layers import EncoderLayer, DecoderLayer
+from .pos_enc import PositionalEncoding
+
+
+class Encoder(nn.Module):
+ """Transformer encoder.
+
+ Parameters
+ ----------
+ n_layers:
+ The number of layers in the encoder.
+
+ n_steps:
+ The number of time steps in the input tensor.
+
+ n_features:
+ The number of features in the input tensor.
+
+ d_model:
+        The dimension of the model's hidden space.
+ The input tensor will be projected to a space with d_model dimensions.
+
+ d_inner:
+ The dimension of the hidden layer in the feed-forward network.
+
+ n_heads:
+ The number of heads in multi-head attention.
+
+ d_k:
+ The dimension of the key and query tensor.
+
+ d_v:
+ The dimension of the value tensor.
+
+ dropout:
+ The dropout rate.
+
+ attn_dropout:
+ The dropout rate for the attention map.
+
+ """
+
+ def __init__(
+ self,
+ n_layers: int,
+ n_steps: int,
+ n_features: int,
+ d_model: int,
+ d_inner: int,
+ n_heads: int,
+ d_k: int,
+ d_v: int,
+ dropout: float,
+ attn_dropout: float,
+ ):
+ super().__init__()
+
+ self.embedding = nn.Linear(n_features, d_model)
+ self.dropout = nn.Dropout(dropout)
+ self.position_enc = PositionalEncoding(d_model, n_positions=n_steps)
+ self.enc_layer_stack = nn.ModuleList(
+ [
+ EncoderLayer(
+ d_model,
+ d_inner,
+ n_heads,
+ d_k,
+ d_v,
+ dropout,
+ attn_dropout,
+ )
+ for _ in range(n_layers)
+ ]
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ src_mask: Optional[torch.Tensor] = None,
+ return_attn_weights: bool = False,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, list]]:
+ """Forward processing of the encoder.
+
+ Parameters
+ ----------
+ x:
+ Input tensor.
+
+ src_mask:
+            Masking tensor for the attention map. The shape should be [batch_size, n_steps, n_steps].
+
+ return_attn_weights:
+ Whether to return the attention map.
+
+ Returns
+ -------
+ enc_output:
+ Output tensor.
+
+ attn_weights_collector:
+ A list containing the attention map from each encoder layer.
+
+ """
+ x = self.embedding(x)
+ enc_output = self.dropout(self.position_enc(x))
+ attn_weights_collector = []
+
+ for layer in self.enc_layer_stack:
+ enc_output, attn_weights = layer(enc_output, src_mask)
+ attn_weights_collector.append(attn_weights)
+
+ if return_attn_weights:
+ return enc_output, attn_weights_collector
+
+ return enc_output
+
+
+class Decoder(nn.Module):
+ """Transformer decoder.
+
+ Parameters
+ ----------
+ n_layers:
+ The number of layers in the decoder.
+
+ n_steps:
+ The number of time steps in the input tensor.
+
+ n_features:
+ The number of features in the input tensor.
+
+ d_model:
+        The dimension of the model's hidden space.
+ The input tensor will be projected to a space with d_model dimensions.
+
+ d_inner:
+ The dimension of the hidden layer in the feed-forward network.
+
+ n_heads:
+ The number of heads in multi-head attention.
+
+ d_k:
+ The dimension of the key and query tensor.
+
+ d_v:
+ The dimension of the value tensor.
+
+ dropout:
+ The dropout rate.
+
+ attn_dropout:
+ The dropout rate for the attention map.
+
+ """
+
+ def __init__(
+ self,
+ n_layers: int,
+ n_steps: int,
+ n_features: int,
+ d_model: int,
+ d_inner: int,
+ n_heads: int,
+ d_k: int,
+ d_v: int,
+ dropout: float,
+ attn_dropout: float,
+ ):
+ super().__init__()
+ self.embedding = nn.Linear(n_features, d_model)
+ self.dropout = nn.Dropout(dropout)
+ self.position_enc = PositionalEncoding(d_model, n_positions=n_steps)
+ self.layer_stack = nn.ModuleList(
+ [
+ DecoderLayer(
+ d_model,
+ d_inner,
+ n_heads,
+ d_k,
+ d_v,
+ dropout,
+ attn_dropout,
+ )
+ for _ in range(n_layers)
+ ]
+ )
+
+ def forward(
+ self,
+ trg_seq: torch.Tensor,
+ enc_output: torch.Tensor,
+ trg_mask: Optional[torch.Tensor] = None,
+ src_mask: Optional[torch.Tensor] = None,
+ return_attn_weights: bool = False,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, list, list]]:
+ """Forward processing of the decoder.
+
+ Parameters
+ ----------
+ trg_seq:
+            The target sequence fed to the decoder.
+
+ enc_output:
+ Output tensor from the encoder.
+
+ trg_mask:
+ Masking tensor for the self-attention module.
+
+ src_mask:
+ Masking tensor for the encoding attention module.
+
+ return_attn_weights:
+ Whether to return the attention map.
+
+ Returns
+ -------
+ dec_output:
+ Output tensor.
+
+ dec_slf_attn_collector:
+ A list containing the self-attention map from each decoder layer.
+
+ dec_enc_attn_collector:
+ A list containing the encoding attention map from each decoder layer.
+
+ """
+ trg_seq = self.embedding(trg_seq)
+ dec_output = self.dropout(self.position_enc(trg_seq))
+
+ dec_slf_attn_collector = []
+ dec_enc_attn_collector = []
+
+ for layer in self.layer_stack:
+ dec_output, dec_slf_attn, dec_enc_attn = layer(
+ dec_output,
+ enc_output,
+ slf_attn_mask=trg_mask,
+ dec_enc_attn_mask=src_mask,
+ )
+ dec_slf_attn_collector.append(dec_slf_attn)
+ dec_enc_attn_collector.append(dec_enc_attn)
+
+ if return_attn_weights:
+ return dec_output, dec_slf_attn_collector, dec_enc_attn_collector
+
+ return dec_output
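
The encoder above wires the linear embedding, positional encoding, and stacked `EncoderLayer`s together. A minimal usage sketch with made-up dimensions (assumptions for illustration, not PyPOTS defaults):

```python
import torch

from pypots.modules.transformer.auto_encoder import Encoder

encoder = Encoder(
    n_layers=2, n_steps=24, n_features=10,
    d_model=64, d_inner=128, n_heads=4, d_k=16, d_v=16,
    dropout=0.1, attn_dropout=0.1,
)
x = torch.randn(8, 24, 10)  # [batch_size, n_steps, n_features]

enc_output = encoder(x)  # [8, 24, 64]
enc_output, attn_maps = encoder(x, return_attn_weights=True)  # attn_maps: 2 tensors of [8, 4, 24, 24]
```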
diff --git a/pypots/modules/transformer/layers.py b/pypots/modules/transformer/layers.py
new file mode 100644
index 00000000..6fd1efd2
--- /dev/null
+++ b/pypots/modules/transformer/layers.py
@@ -0,0 +1,236 @@
+"""
+
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+from typing import Tuple, Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .attention import MultiHeadAttention
+
+
+class PositionWiseFeedForward(nn.Module):
+ """Position-wise feed forward network (FFN) in Transformer.
+
+ Parameters
+ ----------
+ d_in:
+ The dimension of the input tensor.
+
+ d_hid:
+ The dimension of the hidden layer.
+
+ dropout:
+ The dropout rate.
+
+ """
+
+ def __init__(self, d_in: int, d_hid: int, dropout: float = 0.1):
+ super().__init__()
+ self.linear_1 = nn.Linear(d_in, d_hid)
+ self.linear_2 = nn.Linear(d_hid, d_in)
+ self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Forward processing of the position-wise feed forward network.
+
+ Parameters
+ ----------
+ x:
+ Input tensor.
+
+ Returns
+ -------
+ x:
+ Output tensor.
+ """
+ # save the original input for the later residual connection
+ residual = x
+        # the 1st linear transformation followed by a ReLU activation
+ x = F.relu(self.linear_1(x))
+        # the 2nd linear transformation
+ x = self.linear_2(x)
+ # apply dropout
+ x = self.dropout(x)
+ # apply residual connection
+ x += residual
+ # apply layer-norm
+ x = self.layer_norm(x)
+ return x
+
+
+class EncoderLayer(nn.Module):
+ """Transformer encoder layer.
+
+ Parameters
+ ----------
+ d_model:
+ The dimension of the input tensor.
+
+ d_inner:
+ The dimension of the hidden layer.
+
+ n_heads:
+ The number of heads in multi-head attention.
+
+ d_k:
+ The dimension of the key and query tensor.
+
+ d_v:
+ The dimension of the value tensor.
+
+ dropout:
+ The dropout rate.
+
+ attn_dropout:
+ The dropout rate for the attention map.
+ """
+
+ def __init__(
+ self,
+ d_model: int,
+ d_inner: int,
+ n_heads: int,
+ d_k: int,
+ d_v: int,
+ dropout: float = 0.1,
+ attn_dropout: float = 0.1,
+ ):
+ super().__init__()
+ self.slf_attn = MultiHeadAttention(
+ n_heads, d_model, d_k, d_v, dropout, attn_dropout
+ )
+ self.pos_ffn = PositionWiseFeedForward(d_model, d_inner, dropout)
+
+ def forward(
+ self,
+ enc_input: torch.Tensor,
+ src_mask: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Forward processing of the encoder layer.
+
+ Parameters
+ ----------
+ enc_input:
+ Input tensor.
+
+ src_mask:
+            Masking tensor for the attention map. The shape should be [batch_size, n_steps, n_steps].
+
+ Returns
+ -------
+ enc_output:
+ Output tensor.
+
+ attn_weights:
+ The attention map.
+
+ """
+ enc_output, attn_weights = self.slf_attn(
+ enc_input,
+ enc_input,
+ enc_input,
+ attn_mask=src_mask,
+ )
+ enc_output = self.pos_ffn(enc_output)
+ return enc_output, attn_weights
+
+
+class DecoderLayer(nn.Module):
+ """Transformer decoder layer.
+
+ Parameters
+ ----------
+ d_model:
+ The dimension of the input tensor.
+
+ d_inner:
+ The dimension of the hidden layer.
+
+ n_heads:
+ The number of heads in multi-head attention.
+
+ d_k:
+ The dimension of the key and query tensor.
+
+ d_v:
+ The dimension of the value tensor.
+
+ dropout:
+ The dropout rate.
+
+ attn_dropout:
+ The dropout rate for the attention map.
+
+ """
+
+ def __init__(
+ self,
+ d_model: int,
+ d_inner: int,
+ n_heads: int,
+ d_k: int,
+ d_v: int,
+ dropout: float = 0.1,
+ attn_dropout: float = 0.1,
+ ):
+ super().__init__()
+ self.slf_attn = MultiHeadAttention(
+ n_heads, d_model, d_k, d_v, dropout, attn_dropout
+ )
+ self.enc_attn = MultiHeadAttention(
+ n_heads, d_model, d_k, d_v, dropout, attn_dropout
+ )
+ self.pos_ffn = PositionWiseFeedForward(d_model, d_inner, dropout)
+
+ def forward(
+ self,
+ dec_input: torch.Tensor,
+ enc_output: torch.Tensor,
+ slf_attn_mask: Optional[torch.Tensor] = None,
+ dec_enc_attn_mask: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Forward processing of the decoder layer.
+
+ Parameters
+ ----------
+ dec_input:
+ Input tensor.
+
+ enc_output:
+ Output tensor from the encoder.
+
+ slf_attn_mask:
+ Masking tensor for the self-attention module.
+            The shape should be [batch_size, n_steps, n_steps].
+
+ dec_enc_attn_mask:
+ Masking tensor for the encoding attention module.
+            The shape should be [batch_size, n_steps, n_steps].
+
+ Returns
+ -------
+ dec_output:
+ Output tensor.
+
+ dec_slf_attn:
+ The self-attention map.
+
+ dec_enc_attn:
+ The encoding attention map.
+
+ """
+ dec_output, dec_slf_attn = self.slf_attn(
+ dec_input, dec_input, dec_input, attn_mask=slf_attn_mask
+ )
+ dec_output, dec_enc_attn = self.enc_attn(
+ dec_output, enc_output, enc_output, attn_mask=dec_enc_attn_mask
+ )
+ dec_output = self.pos_ffn(dec_output)
+ return dec_output, dec_slf_attn, dec_enc_attn
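
To illustrate how a source mask flows through an encoder layer (mask semantics: 1 keeps a position, 0 masks it out; the head axis is broadcast inside `MultiHeadAttention`), a small hypothetical example:

```python
import torch

from pypots.modules.transformer.layers import EncoderLayer

layer = EncoderLayer(d_model=64, d_inner=128, n_heads=4, d_k=16, d_v=16)
x = torch.randn(8, 24, 64)

# mask out the last 4 time steps for every sample
src_mask = torch.ones(8, 24, 24)
src_mask[:, :, -4:] = 0

out, attn = layer(x, src_mask=src_mask)  # out: [8, 24, 64], attn: [8, 4, 24, 24]
```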
diff --git a/pypots/modules/transformer/pos_enc.py b/pypots/modules/transformer/pos_enc.py
new file mode 100644
index 00000000..1697bf96
--- /dev/null
+++ b/pypots/modules/transformer/pos_enc.py
@@ -0,0 +1,67 @@
+"""
+
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+
+import torch
+import torch.nn as nn
+
+
+class PositionalEncoding(nn.Module):
+ """The original positional-encoding module for Transformer.
+
+ Parameters
+ ----------
+ d_hid:
+ The dimension of the hidden layer.
+
+ n_positions:
+ The max number of positions.
+
+ """
+
+ def __init__(self, d_hid: int, n_positions: int = 1000):
+ super().__init__()
+ pe = torch.zeros(n_positions, d_hid, requires_grad=False).float()
+ position = torch.arange(0, n_positions).float().unsqueeze(1)
+ div_term = (
+ torch.arange(0, d_hid, 2).float()
+ * -(torch.log(torch.tensor(10000)) / d_hid)
+ ).exp()
+
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+
+ pe = pe.unsqueeze(0)
+ self.register_buffer("pos_table", pe)
+
+ def forward(self, x: torch.Tensor, return_only_pos: bool = False) -> torch.Tensor:
+ """Forward processing of the positional encoding module.
+
+ Parameters
+ ----------
+ x:
+ Input tensor.
+
+ return_only_pos:
+ Whether to return only the positional encoding.
+
+ Returns
+ -------
+ If return_only_pos is True:
+ pos_enc:
+ The positional encoding.
+ else:
+ x_with_pos:
+ Output tensor, the input tensor with the positional encoding added.
+ """
+ pos_enc = self.pos_table[:, : x.size(1)].clone().detach()
+
+ if return_only_pos:
+ return pos_enc
+
+ x_with_pos = x + pos_enc
+ return x_with_pos
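
The buffer built above is the standard sinusoidal table, PE(pos, 2i) = sin(pos / 10000^(2i/d_hid)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_hid)). A small illustrative check of its behavior, with toy sizes assumed for the example:

```python
import torch

from pypots.modules.transformer.pos_enc import PositionalEncoding

pos_enc = PositionalEncoding(d_hid=64, n_positions=100)
x = torch.randn(2, 24, 64)

pe_only = pos_enc(x, return_only_pos=True)  # [1, 24, 64]: the first 24 rows of the table
x_with_pos = pos_enc(x)                     # [2, 24, 64]: x plus the broadcast encoding
```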
diff --git a/tests/imputation/timesnet.py b/tests/imputation/timesnet.py
new file mode 100644
index 00000000..52e33ae4
--- /dev/null
+++ b/tests/imputation/timesnet.py
@@ -0,0 +1,112 @@
+"""
+Test cases for TimesNet imputation model.
+"""
+
+# Created by Wenjie Du
+# License: BSD-3-Clause
+
+
+import os.path
+import unittest
+
+import numpy as np
+import pytest
+
+from pypots.imputation import TimesNet
+from pypots.optim import Adam
+from pypots.utils.logging import logger
+from pypots.utils.metrics import cal_mae
+from tests.global_test_config import (
+ DATA,
+ DEVICE,
+ check_tb_and_model_checkpoints_existence,
+)
+from tests.imputation.config import (
+ TRAIN_SET,
+ VAL_SET,
+ TEST_SET,
+ RESULT_SAVING_DIR_FOR_IMPUTATION,
+ EPOCHS,
+)
+
+
+class TestTimesNet(unittest.TestCase):
+ logger.info("Running tests for an imputation model TimesNet...")
+
+ # set the log and model saving path
+ saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "TimesNet")
+ model_save_name = "saved_timesnet_model.pypots"
+
+ # initialize an Adam optimizer
+ optimizer = Adam(lr=0.001, weight_decay=1e-5)
+
+ # initialize a TimesNet model
+ timesnet = TimesNet(
+ DATA["n_steps"],
+ DATA["n_features"],
+ n_layers=2,
+ top_k=3,
+ d_model=128,
+ d_ffn=256,
+ n_kernels=3,
+ dropout=0.1,
+ epochs=EPOCHS,
+ saving_path=saving_path,
+ optimizer=optimizer,
+ device=DEVICE,
+ )
+
+ @pytest.mark.xdist_group(name="imputation-timesnet")
+ def test_0_fit(self):
+ self.timesnet.fit(TRAIN_SET, VAL_SET)
+
+ @pytest.mark.xdist_group(name="imputation-timesnet")
+ def test_1_impute(self):
+ imputation_results = self.timesnet.predict(TEST_SET)
+ assert not np.isnan(
+ imputation_results["imputation"]
+        ).any(), "Output still has missing values after running predict()."
+
+ test_MAE = cal_mae(
+ imputation_results["imputation"],
+ DATA["test_X_intact"],
+ DATA["test_X_indicating_mask"],
+ )
+ logger.info(f"TimesNet test_MAE: {test_MAE}")
+
+ @pytest.mark.xdist_group(name="imputation-timesnet")
+ def test_2_parameters(self):
+ assert hasattr(self.timesnet, "model") and self.timesnet.model is not None
+
+ assert (
+ hasattr(self.timesnet, "optimizer") and self.timesnet.optimizer is not None
+ )
+
+ assert hasattr(self.timesnet, "best_loss")
+ self.assertNotEqual(self.timesnet.best_loss, float("inf"))
+
+ assert (
+ hasattr(self.timesnet, "best_model_dict")
+ and self.timesnet.best_model_dict is not None
+ )
+
+ @pytest.mark.xdist_group(name="imputation-timesnet")
+ def test_3_saving_path(self):
+ # whether the root saving dir exists, which should be created by save_log_into_tb_file
+ assert os.path.exists(
+ self.saving_path
+ ), f"file {self.saving_path} does not exist"
+
+ # check if the tensorboard file and model checkpoints exist
+ check_tb_and_model_checkpoints_existence(self.timesnet)
+
+ # save the trained model into file, and check if the path exists
+ saved_model_path = os.path.join(self.saving_path, self.model_save_name)
+ self.timesnet.save(saved_model_path)
+
+        # load the saved model back to verify that saving and loading work
+ self.timesnet.load(saved_model_path)
+
+
+if __name__ == "__main__":
+ unittest.main()
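
Beyond the unit tests above, a minimal end-to-end sketch of the new TimesNet imputer. The toy data, hyperparameters, and the dataset dict with key "X" are assumptions for illustration; calling fit without a validation set is assumed to be supported.

```python
import numpy as np

from pypots.imputation import TimesNet
from pypots.optim import Adam

# hypothetical toy data: 32 samples, 48 steps, 37 features, ~10% values missing (NaN marks missingness)
X = np.random.randn(32, 48, 37)
X[np.random.rand(*X.shape) < 0.1] = np.nan
dataset = {"X": X}

timesnet = TimesNet(
    n_steps=48,
    n_features=37,
    n_layers=2,
    top_k=3,
    d_model=128,
    d_ffn=256,
    n_kernels=3,
    dropout=0.1,
    epochs=5,
    optimizer=Adam(lr=1e-3, weight_decay=1e-5),
)
timesnet.fit(dataset)                # a validation set may also be passed, as in the test above
results = timesnet.predict(dataset)  # returns a dict with key "imputation"
imputation = results["imputation"]   # [32, 48, 37], with missing values filled in
```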