From 77e680ceb7dfd8c6c8bbc2b78985489ce5406989 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Wed, 20 Mar 2024 21:31:07 +0800 Subject: [PATCH 1/4] refactor: refactor TimesNet and fix some typos in its doc; --- pypots/imputation/timesnet/model.py | 8 ++++---- pypots/imputation/timesnet/modules/core.py | 2 +- .../timesnet/modules/{layer.py => submodules.py} | 0 3 files changed, 5 insertions(+), 5 deletions(-) rename pypots/imputation/timesnet/modules/{layer.py => submodules.py} (100%) diff --git a/pypots/imputation/timesnet/model.py b/pypots/imputation/timesnet/model.py index aee1da5d..195805ff 100644 --- a/pypots/imputation/timesnet/model.py +++ b/pypots/imputation/timesnet/model.py @@ -1,12 +1,12 @@ """ -The implementation of Transformer for the partially-observed time-series imputation task. +The implementation of TimesNet for the partially-observed time-series imputation task. -Refer to the paper "Du, W., Cote, D., & Liu, Y. (2023). SAITS: Self-Attention-based Imputation for Time Series. -Expert systems with applications." +Refer to the paper "Wu, H., Hu, T., Liu, Y., Zhou, H., Wang, J., & Long, M. (2023). +TimesNet: Temporal 2d-variation modeling for general time series analysis. ICLR 2023." Notes ----- -Partial implementation uses code from https://github.com/WenjieDu/SAITS. +Partial implementation uses code from https://github.com/thuml/Time-Series-Library. """ diff --git a/pypots/imputation/timesnet/modules/core.py b/pypots/imputation/timesnet/modules/core.py index fd2e4cf0..fe7dc0b2 100644 --- a/pypots/imputation/timesnet/modules/core.py +++ b/pypots/imputation/timesnet/modules/core.py @@ -8,7 +8,7 @@ import torch.nn as nn from .embedding import DataEmbedding -from .layer import TimesBlock +from .submodules import TimesBlock from ....nn.functional import nonstationary_norm, nonstationary_denorm from ....utils.metrics import calc_mse diff --git a/pypots/imputation/timesnet/modules/layer.py b/pypots/imputation/timesnet/modules/submodules.py similarity index 100% rename from pypots/imputation/timesnet/modules/layer.py rename to pypots/imputation/timesnet/modules/submodules.py From 2dae4e0f0a8e267808e4c91c1a5760ed92192246 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Thu, 21 Mar 2024 23:05:01 +0800 Subject: [PATCH 2/4] feat: add Autoformer as an imputation model; --- pypots/imputation/autoformer/__init__.py | 17 + pypots/imputation/autoformer/data.py | 24 + pypots/imputation/autoformer/model.py | 318 +++++++++++++ .../imputation/autoformer/modules/__init__.py | 6 + pypots/imputation/autoformer/modules/core.py | 91 ++++ .../autoformer/modules/submodules.py | 422 ++++++++++++++++++ 6 files changed, 878 insertions(+) create mode 100644 pypots/imputation/autoformer/__init__.py create mode 100644 pypots/imputation/autoformer/data.py create mode 100644 pypots/imputation/autoformer/model.py create mode 100644 pypots/imputation/autoformer/modules/__init__.py create mode 100644 pypots/imputation/autoformer/modules/core.py create mode 100644 pypots/imputation/autoformer/modules/submodules.py diff --git a/pypots/imputation/autoformer/__init__.py b/pypots/imputation/autoformer/__init__.py new file mode 100644 index 00000000..18efc4a5 --- /dev/null +++ b/pypots/imputation/autoformer/__init__.py @@ -0,0 +1,17 @@ +""" +The package of the partially-observed time-series imputation model Autoformer. + +Refer to the paper "Wu, H., Xu, J., Wang, J., & Long, M. (2021). +Autoformer: Decomposition transformers with auto-correlation for long-term series forecasting. NeurIPS 2021.". 
+ +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + + +from .model import Autoformer + +__all__ = [ + "Autoformer", +] diff --git a/pypots/imputation/autoformer/data.py b/pypots/imputation/autoformer/data.py new file mode 100644 index 00000000..8e1cfb99 --- /dev/null +++ b/pypots/imputation/autoformer/data.py @@ -0,0 +1,24 @@ +""" +Dataset class for TimesNet. +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +from typing import Union + +from ..saits.data import DatasetForSAITS + + +class DatasetForAutoformer(DatasetForSAITS): + """Actually Autoformer uses the same data strategy as SAITS, needs MIT for training.""" + + def __init__( + self, + data: Union[dict, str], + return_X_ori: bool, + return_labels: bool, + file_type: str = "h5py", + rate: float = 0.2, + ): + super().__init__(data, return_X_ori, return_labels, file_type, rate) diff --git a/pypots/imputation/autoformer/model.py b/pypots/imputation/autoformer/model.py new file mode 100644 index 00000000..c27e80cc --- /dev/null +++ b/pypots/imputation/autoformer/model.py @@ -0,0 +1,318 @@ +""" +The implementation of Transformer for the partially-observed time-series imputation task. + +Refer to the paper "Wu, H., Xu, J., Wang, J., & Long, M. (2021). +Autoformer: Decomposition transformers with auto-correlation for long-term series forecasting. NeurIPS 2021.". + +Notes +----- +Partial implementation uses code from https://github.com/thuml/Time-Series-Library + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +from typing import Union, Optional + +import numpy as np +import torch +from torch.utils.data import DataLoader + +from .data import DatasetForAutoformer +from .modules.core import _Autoformer +from ..base import BaseNNImputer +from ...data.base import BaseDataset +from ...data.checking import check_X_ori_in_val_set +from ...optim.adam import Adam +from ...optim.base import Optimizer +from ...utils.logging import logger + + +class Autoformer(BaseNNImputer): + """The PyTorch implementation of the Autoformer model. + TimesNet is originally proposed by Wu et al. in :cite:`wu2021autoformer`. + + Parameters + ---------- + n_steps : + The number of time steps in the time-series data sample. + + n_features : + The number of features in the time-series data sample. + + n_layers : + The number of layers in the 1st and 2nd DMSA blocks in the SAITS model. + + d_model : + The dimension of the model. + + d_ffn : + The dimension of the feed-forward network. + + dropout : + The dropout rate for the model. + + batch_size : + The batch size for training and evaluating the model. + + epochs : + The number of epochs for training the model. + + patience : + The patience for the early-stopping mechanism. Given a positive integer, the training process will be + stopped when the model does not perform better after that number of epochs. + Leaving it default as None will disable the early-stopping. + + optimizer : + The optimizer for model training. + If not given, will use a default Adam optimizer. + + num_workers : + The number of subprocesses to use for data loading. + `0` means data loading will be in the main process, i.e. there won't be subprocesses. + + device : + The device for the model to run on. It can be a string, a :class:`torch.device` object, or a list of them. + If not given, will try to use CUDA devices first (will use the default CUDA device if there are multiple), + then CPUs, considering CUDA and CPU are so far the main devices for people to train ML models. + If given a list of devices, e.g. 
['cuda:0', 'cuda:1'], or [torch.device('cuda:0'), torch.device('cuda:1')] , the + model will be parallely trained on the multiple devices (so far only support parallel training on CUDA devices). + Other devices like Google TPU and Apple Silicon accelerator MPS may be added in the future. + + saving_path : + The path for automatically saving model checkpoints and tensorboard files (i.e. loss values recorded during + training into a tensorboard file). Will not save if not given. + + model_saving_strategy : + The strategy to save model checkpoints. It has to be one of [None, "best", "better", "all"]. + No model will be saved when it is set as None. + The "best" strategy will only automatically save the best model after the training finished. + The "better" strategy will automatically save the model during training whenever the model performs + better than in previous epochs. + The "all" strategy will save every model after each epoch training. + + Attributes + ---------- + model : :class:`torch.nn.Module` + The underlying Transformer model. + + optimizer : :class:`pypots.optim.Optimizer` + The optimizer for model training. + + """ + + def __init__( + self, + n_steps: int, + n_features: int, + n_layers: int, + n_heads: int, + d_model: int, + d_ffn: int, + factor: int, + moving_avg_kernel_size: int, + dropout: float = 0, + batch_size: int = 32, + epochs: int = 100, + patience: int = None, + optimizer: Optional[Optimizer] = Adam(), + num_workers: int = 0, + device: Optional[Union[str, torch.device, list]] = None, + saving_path: str = None, + model_saving_strategy: Optional[str] = "best", + ): + super().__init__( + batch_size, + epochs, + patience, + num_workers, + device, + saving_path, + model_saving_strategy, + ) + + self.n_steps = n_steps + self.n_features = n_features + # model hype-parameters + self.n_heads = n_heads + self.n_layers = n_layers + self.d_model = d_model + self.d_ffn = d_ffn + self.factor = factor + self.moving_avg_kernel_size = moving_avg_kernel_size + self.dropout = dropout + + # set up the model + self.model = _Autoformer( + self.n_steps, + self.n_features, + self.n_layers, + self.n_heads, + self.d_model, + self.d_ffn, + self.factor, + self.moving_avg_kernel_size, + self.dropout, + ) + self._send_model_to_given_device() + self._print_model_size() + + # set up the optimizer + self.optimizer = optimizer + self.optimizer.init_optimizer(self.model.parameters()) + + def _assemble_input_for_training(self, data: list) -> dict: + ( + indices, + X, + missing_mask, + X_ori, + indicating_mask, + ) = self._send_data_to_given_device(data) + + inputs = { + "X": X, + "missing_mask": missing_mask, + "X_ori": X_ori, + "indicating_mask": indicating_mask, + } + + return inputs + + def _assemble_input_for_validating(self, data: list) -> dict: + return self._assemble_input_for_training(data) + + def _assemble_input_for_testing(self, data: list) -> dict: + indices, X, missing_mask = self._send_data_to_given_device(data) + + inputs = { + "X": X, + "missing_mask": missing_mask, + } + + return inputs + + def fit( + self, + train_set: Union[dict, str], + val_set: Optional[Union[dict, str]] = None, + file_type: str = "h5py", + ) -> None: + # Step 1: wrap the input data with classes Dataset and DataLoader + training_set = DatasetForAutoformer( + train_set, return_X_ori=False, return_labels=False, file_type=file_type + ) + training_loader = DataLoader( + training_set, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers, + ) + val_loader = None + if val_set is not None: + if not 
check_X_ori_in_val_set(val_set): + raise ValueError("val_set must contain 'X_ori' for model validation.") + val_set = DatasetForAutoformer( + val_set, return_X_ori=True, return_labels=False, file_type=file_type + ) + val_loader = DataLoader( + val_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) + + # Step 2: train the model and freeze it + self._train_model(training_loader, val_loader) + self.model.load_state_dict(self.best_model_dict) + self.model.eval() # set the model as eval status to freeze it. + + # Step 3: save the model if necessary + self._auto_save_model_if_necessary(confirm_saving=True) + + def predict( + self, + test_set: Union[dict, str], + file_type: str = "h5py", + ) -> dict: + """Make predictions for the input data with the trained model. + + Parameters + ---------- + test_set : dict or str + The dataset for model validating, should be a dictionary including keys as 'X', + or a path string locating a data file supported by PyPOTS (e.g. h5 file). + If it is a dict, X should be array-like of shape [n_samples, sequence length (time steps), n_features], + which is time-series data for validating, can contain missing values, and y should be array-like of shape + [n_samples], which is classification labels of X. + If it is a path string, the path should point to a data file, e.g. a h5 file, which contains + key-value pairs like a dict, and it has to include keys as 'X' and 'y'. + + file_type : str + The type of the given file if test_set is a path string. + + Returns + ------- + result_dict : dict, + The dictionary containing the clustering results and latent variables if necessary. + + """ + # Step 1: wrap the input data with classes Dataset and DataLoader + self.model.eval() # set the model as eval status to freeze it. + test_set = BaseDataset( + test_set, return_X_ori=False, return_labels=False, file_type=file_type + ) + test_loader = DataLoader( + test_set, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) + imputation_collector = [] + + # Step 2: process the data with the model + with torch.no_grad(): + for idx, data in enumerate(test_loader): + inputs = self._assemble_input_for_testing(data) + results = self.model.forward(inputs, training=False) + imputation_collector.append(results["imputed_data"]) + + # Step 3: output collection and return + imputation = torch.cat(imputation_collector).cpu().detach().numpy() + result_dict = { + "imputation": imputation, + } + return result_dict + + def impute( + self, + X: Union[dict, str], + file_type="h5py", + ) -> np.ndarray: + """Impute missing values in the given data with the trained model. + + Warnings + -------- + The method impute is deprecated. Please use `predict()` instead. + + Parameters + ---------- + X : + The data samples for testing, should be array-like of shape [n_samples, sequence length (time steps), + n_features], or a path string locating a data file, e.g. h5 file. + + file_type : + The type of the given file if X is a path string. + + Returns + ------- + array-like, shape [n_samples, sequence length (time steps), n_features], + Imputed data. + """ + logger.warning( + "🚨DeprecationWarning: The method impute is deprecated. Please use `predict` instead." 
+ ) + + results_dict = self.predict(X, file_type=file_type) + return results_dict["imputation"] diff --git a/pypots/imputation/autoformer/modules/__init__.py b/pypots/imputation/autoformer/modules/__init__.py new file mode 100644 index 00000000..ceaa7ee3 --- /dev/null +++ b/pypots/imputation/autoformer/modules/__init__.py @@ -0,0 +1,6 @@ +""" + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause diff --git a/pypots/imputation/autoformer/modules/core.py b/pypots/imputation/autoformer/modules/core.py new file mode 100644 index 00000000..c42eb3ab --- /dev/null +++ b/pypots/imputation/autoformer/modules/core.py @@ -0,0 +1,91 @@ +""" + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +import torch.nn as nn + +from ....utils.metrics import calc_mse +from .submodules import ( + DataEmbedding_wo_Pos, + SeriesDecompositionBlock, + SeasonalLayerNorm, + AutoformerEncoderLayer, + AutoformerEncoder, + AutoCorrelation, + AutoCorrelationLayer, +) + + +class _Autoformer(nn.Module): + def __init__( + self, + n_steps, + n_features, + n_layers, + n_heads, + d_model, + d_ffn, + factor, + moving_avg_kernel_size, + dropout, + activation="relu", + output_attention=False, + ): + super().__init__() + + self.seq_len = n_steps + self.n_layers = n_layers + self.series_decomp = SeriesDecompositionBlock(moving_avg_kernel_size) + self.enc_embedding = DataEmbedding_wo_Pos( + n_features, + d_model, + dropout=dropout, + ) + self.encoder = AutoformerEncoder( + [ + AutoformerEncoderLayer( + AutoCorrelationLayer( + AutoCorrelation(False, factor, dropout, output_attention), + d_model, + n_heads, + ), + d_model, + d_ffn, + moving_avg_kernel_size, + dropout, + activation, + ) + for i in range(n_layers) + ], + norm_layer=SeasonalLayerNorm(d_model), + ) + + # for the imputation task, the output dim is the same as input dim + self.projection = nn.Linear(d_model, n_features) + + def forward(self, inputs: dict, training: bool = True) -> dict: + X, masks = inputs["X"], inputs["missing_mask"] + + # embedding + enc_out = self.enc_embedding(X) # [B,T,C] + + # Autoformer encoder processing + enc_out, attns = self.encoder(enc_out) + + # project back the original data space + dec_out = self.projection(enc_out) + + imputed_data = masks * X + (1 - masks) * dec_out + results = { + "imputed_data": imputed_data, + } + + if training: + # `loss` is always the item for backward propagating to update the model + loss = calc_mse(dec_out, inputs["X_ori"], inputs["indicating_mask"]) + results["loss"] = loss + + return results diff --git a/pypots/imputation/autoformer/modules/submodules.py b/pypots/imputation/autoformer/modules/submodules.py new file mode 100644 index 00000000..c08ff81b --- /dev/null +++ b/pypots/imputation/autoformer/modules/submodules.py @@ -0,0 +1,422 @@ +""" + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +import math + +import torch +import torch.fft +import torch.nn as nn +import torch.nn.functional as F + +from ...timesnet.modules.embedding import ( + TokenEmbedding, + TemporalEmbedding, + TimeFeatureEmbedding, +) +from ....nn.modules.transformer import PositionalEncoding + + +class DataEmbedding_wo_Pos(nn.Module): + def __init__(self, c_in, d_model, embed_type="fixed", freq="h", dropout=0.1): + super().__init__() + + self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model) + self.position_embedding = PositionalEncoding(d_model) + self.temporal_embedding = ( + TemporalEmbedding(d_model=d_model, embed_type=embed_type, freq=freq) + if embed_type != "timeF" + else 
TimeFeatureEmbedding(d_model=d_model, freq=freq) + ) + self.dropout = nn.Dropout(p=dropout) + + def forward(self, x, x_timestamp=None): + if x_timestamp is None: + x = self.value_embedding(x) + else: + x = self.value_embedding(x) + self.temporal_embedding(x_timestamp) + return self.dropout(x) + + +class AutoCorrelation(nn.Module): + """ + AutoCorrelation Mechanism with the following two phases: + (1) period-based dependencies discovery + (2) time delay aggregation + + This block can replace the self-attention family mechanism seamlessly. + """ + + def __init__( + self, + mask_flag=True, + factor=1, + scale=None, + attention_dropout=0.1, + output_attention=False, + ): + super().__init__() + self.factor = factor + self.scale = scale + self.mask_flag = mask_flag + self.output_attention = output_attention + self.dropout = nn.Dropout(attention_dropout) + + def time_delay_agg_training(self, values, corr): + """ + SpeedUp version of Autocorrelation (a batch-normalization style design) + This is for the training phase. + """ + head = values.shape[1] + channel = values.shape[2] + length = values.shape[3] + # find top k + top_k = int(self.factor * math.log(length)) + mean_value = torch.mean(torch.mean(corr, dim=1), dim=1) + index = torch.topk(torch.mean(mean_value, dim=0), top_k, dim=-1)[1] + weights = torch.stack([mean_value[:, index[i]] for i in range(top_k)], dim=-1) + # update corr + tmp_corr = torch.softmax(weights, dim=-1) + # aggregation + tmp_values = values + delays_agg = torch.zeros_like(values).float() + for i in range(top_k): + pattern = torch.roll(tmp_values, -int(index[i]), -1) + delays_agg = delays_agg + pattern * ( + tmp_corr[:, i] + .unsqueeze(1) + .unsqueeze(1) + .unsqueeze(1) + .repeat(1, head, channel, length) + ) + return delays_agg + + def time_delay_agg_inference(self, values, corr): + """ + SpeedUp version of Autocorrelation (a batch-normalization style design) + This is for the inference phase. 
+ """ + batch = values.shape[0] + head = values.shape[1] + channel = values.shape[2] + length = values.shape[3] + # index init + init_index = ( + torch.arange(length) + .unsqueeze(0) + .unsqueeze(0) + .unsqueeze(0) + .repeat(batch, head, channel, 1) + .to(values.device) + ) + # find top k + top_k = int(self.factor * math.log(length)) + mean_value = torch.mean(torch.mean(corr, dim=1), dim=1) + weights, delay = torch.topk(mean_value, top_k, dim=-1) + # update corr + tmp_corr = torch.softmax(weights, dim=-1) + # aggregation + tmp_values = values.repeat(1, 1, 1, 2) + delays_agg = torch.zeros_like(values).float() + for i in range(top_k): + tmp_delay = init_index + delay[:, i].unsqueeze(1).unsqueeze(1).unsqueeze( + 1 + ).repeat(1, head, channel, length) + pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay) + delays_agg = delays_agg + pattern * ( + tmp_corr[:, i] + .unsqueeze(1) + .unsqueeze(1) + .unsqueeze(1) + .repeat(1, head, channel, length) + ) + return delays_agg + + def time_delay_agg_full(self, values, corr): + """ + Standard version of Autocorrelation + """ + batch = values.shape[0] + head = values.shape[1] + channel = values.shape[2] + length = values.shape[3] + # index init + init_index = ( + torch.arange(length) + .unsqueeze(0) + .unsqueeze(0) + .unsqueeze(0) + .repeat(batch, head, channel, 1) + .to(values.device) + ) + # find top k + top_k = int(self.factor * math.log(length)) + weights, delay = torch.topk(corr, top_k, dim=-1) + # update corr + tmp_corr = torch.softmax(weights, dim=-1) + # aggregation + tmp_values = values.repeat(1, 1, 1, 2) + delays_agg = torch.zeros_like(values).float() + for i in range(top_k): + tmp_delay = init_index + delay[..., i].unsqueeze(-1) + pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay) + delays_agg = delays_agg + pattern * (tmp_corr[..., i].unsqueeze(-1)) + return delays_agg + + def forward(self, queries, keys, values, attn_mask): + B, L, H, E = queries.shape + _, S, _, D = values.shape + if L > S: + zeros = torch.zeros_like(queries[:, : (L - S), :]).float() + values = torch.cat([values, zeros], dim=1) + keys = torch.cat([keys, zeros], dim=1) + else: + values = values[:, :L, :, :] + keys = keys[:, :L, :, :] + + # period-based dependencies + q_fft = torch.fft.rfft(queries.permute(0, 2, 3, 1).contiguous(), dim=-1) + k_fft = torch.fft.rfft(keys.permute(0, 2, 3, 1).contiguous(), dim=-1) + res = q_fft * torch.conj(k_fft) + corr = torch.fft.irfft(res, dim=-1) + + # time delay agg + if self.training: + V = self.time_delay_agg_training( + values.permute(0, 2, 3, 1).contiguous(), corr + ).permute(0, 3, 1, 2) + else: + V = self.time_delay_agg_inference( + values.permute(0, 2, 3, 1).contiguous(), corr + ).permute(0, 3, 1, 2) + + if self.output_attention: + return (V.contiguous(), corr.permute(0, 3, 1, 2)) + else: + return (V.contiguous(), None) + + +class AutoCorrelationLayer(nn.Module): + def __init__(self, correlation, d_model, n_heads, d_keys=None, d_values=None): + super().__init__() + + d_keys = d_keys or (d_model // n_heads) + d_values = d_values or (d_model // n_heads) + + self.inner_correlation = correlation + self.query_projection = nn.Linear(d_model, d_keys * n_heads) + self.key_projection = nn.Linear(d_model, d_keys * n_heads) + self.value_projection = nn.Linear(d_model, d_values * n_heads) + self.out_projection = nn.Linear(d_values * n_heads, d_model) + self.n_heads = n_heads + + def forward(self, queries, keys, values, attn_mask): + B, L, _ = queries.shape + _, S, _ = keys.shape + H = self.n_heads + + queries = 
self.query_projection(queries).view(B, L, H, -1) + keys = self.key_projection(keys).view(B, S, H, -1) + values = self.value_projection(values).view(B, S, H, -1) + + out, attn = self.inner_correlation(queries, keys, values, attn_mask) + out = out.view(B, L, -1) + + return self.out_projection(out), attn + + +class SeasonalLayerNorm(nn.Module): + """A special designed layer normalization for the seasonal part.""" + + def __init__(self, n_channels): + super().__init__() + self.layer_norm = nn.LayerNorm(n_channels) + + def forward(self, x): + x_hat = self.layer_norm(x) + bias = torch.mean(x_hat, dim=1).unsqueeze(1).repeat(1, x.shape[1], 1) + return x_hat - bias + + +class MovingAvgBlock(nn.Module): + """ + The moving average block to highlight the trend of time series. + """ + + def __init__(self, kernel_size, stride): + super().__init__() + self.kernel_size = kernel_size + self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0) + + def forward(self, x): + # padding on the both ends of time series + front = x[:, 0:1, :].repeat(1, (self.kernel_size - 1) // 2, 1) + end = x[:, -1:, :].repeat(1, (self.kernel_size - 1) // 2, 1) + x = torch.cat([front, x, end], dim=1) + x = self.avg(x.permute(0, 2, 1)) + x = x.permute(0, 2, 1) + return x + + +class SeriesDecompositionBlock(nn.Module): + """ + Series decomposition block + """ + + def __init__(self, kernel_size): + super().__init__() + self.moving_avg = MovingAvgBlock(kernel_size, stride=1) + + def forward(self, x): + moving_mean = self.moving_avg(x) + res = x - moving_mean + return res, moving_mean + + +class AutoformerEncoderLayer(nn.Module): + """Autoformer encoder layer with the progressive decomposition architecture.""" + + def __init__( + self, + attention, + d_model, + d_ff=None, + moving_avg=25, + dropout=0.1, + activation="relu", + ): + super().__init__() + d_ff = d_ff or 4 * d_model + self.attention = attention + self.conv1 = nn.Conv1d( + in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False + ) + self.conv2 = nn.Conv1d( + in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False + ) + self.series_decomp1 = SeriesDecompositionBlock(moving_avg) + self.series_decomp2 = SeriesDecompositionBlock(moving_avg) + self.dropout = nn.Dropout(dropout) + self.activation = F.relu if activation == "relu" else F.gelu + + def forward(self, x, attn_mask=None): + new_x, attn = self.attention(x, x, x, attn_mask=attn_mask) + x = x + self.dropout(new_x) + x, _ = self.series_decomp1(x) + y = x + y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) + y = self.dropout(self.conv2(y).transpose(-1, 1)) + res, _ = self.series_decomp2(x + y) + return res, attn + + +class AutoformerEncoder(nn.Module): + def __init__(self, attn_layers, conv_layers=None, norm_layer=None): + super().__init__() + self.attn_layers = nn.ModuleList(attn_layers) + self.conv_layers = ( + nn.ModuleList(conv_layers) if conv_layers is not None else None + ) + self.norm = norm_layer + + def forward(self, x, attn_mask=None): + attns = [] + if self.conv_layers is not None: + for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers): + x, attn = attn_layer(x, attn_mask=attn_mask) + x = conv_layer(x) + attns.append(attn) + x, attn = self.attn_layers[-1](x) + attns.append(attn) + else: + for attn_layer in self.attn_layers: + x, attn = attn_layer(x, attn_mask=attn_mask) + attns.append(attn) + + if self.norm is not None: + x = self.norm(x) + + return x, attns + + +class AutoformerDecoderLayer(nn.Module): + """ + Autoformer decoder layer with the 
progressive decomposition architecture + """ + + def __init__( + self, + self_attention, + cross_attention, + d_model, + c_out, + d_ff=None, + moving_avg=25, + dropout=0.1, + activation="relu", + ): + super().__init__() + d_ff = d_ff or 4 * d_model + self.self_attention = self_attention + self.cross_attention = cross_attention + self.conv1 = nn.Conv1d( + in_channels=d_model, out_channels=d_ff, kernel_size=1, bias=False + ) + self.conv2 = nn.Conv1d( + in_channels=d_ff, out_channels=d_model, kernel_size=1, bias=False + ) + self.series_decomp1 = SeriesDecompositionBlock(moving_avg) + self.series_decomp2 = SeriesDecompositionBlock(moving_avg) + self.series_decomp3 = SeriesDecompositionBlock(moving_avg) + self.dropout = nn.Dropout(dropout) + self.projection = nn.Conv1d( + in_channels=d_model, + out_channels=c_out, + kernel_size=3, + stride=1, + padding=1, + padding_mode="circular", + bias=False, + ) + self.activation = F.relu if activation == "relu" else F.gelu + + def forward(self, x, cross, x_mask=None, cross_mask=None): + x = x + self.dropout(self.self_attention(x, x, x, attn_mask=x_mask)[0]) + x, trend1 = self.series_decomp1(x) + x = x + self.dropout( + self.cross_attention(x, cross, cross, attn_mask=cross_mask)[0] + ) + x, trend2 = self.series_decomp2(x) + y = x + y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) + y = self.dropout(self.conv2(y).transpose(-1, 1)) + x, trend3 = self.series_decomp3(x + y) + + residual_trend = trend1 + trend2 + trend3 + residual_trend = self.projection(residual_trend.permute(0, 2, 1)).transpose( + 1, 2 + ) + return x, residual_trend + + +class AutoformerDecoder(nn.Module): + def __init__(self, layers, norm_layer=None, projection=None): + super().__init__() + self.layers = nn.ModuleList(layers) + self.norm = norm_layer + self.projection = projection + + def forward(self, x, cross, x_mask=None, cross_mask=None, trend=None): + for layer in self.layers: + x, residual_trend = layer(x, cross, x_mask=x_mask, cross_mask=cross_mask) + trend = trend + residual_trend + + if self.norm is not None: + x = self.norm(x) + + if self.projection is not None: + x = self.projection(x) + return x, trend From 78a77aca6e3b5217694ddeae77e2fb1af1fb0719 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Fri, 22 Mar 2024 01:13:22 +0800 Subject: [PATCH 3/4] test: add Autoformer unit test; --- tests/imputation/autoformer.py | 130 +++++++++++++++++++++++++++++++++ 1 file changed, 130 insertions(+) create mode 100644 tests/imputation/autoformer.py diff --git a/tests/imputation/autoformer.py b/tests/imputation/autoformer.py new file mode 100644 index 00000000..8fa93a78 --- /dev/null +++ b/tests/imputation/autoformer.py @@ -0,0 +1,130 @@ +""" +Test cases for Autoformer imputation model. 
+""" + +# Created by Wenjie Du +# License: BSD-3-Clause + + +import os.path +import unittest + +import numpy as np +import pytest + +from pypots.imputation import Autoformer +from pypots.optim import Adam +from pypots.utils.logging import logger +from pypots.utils.metrics import calc_mse +from tests.global_test_config import ( + DATA, + EPOCHS, + DEVICE, + TRAIN_SET, + VAL_SET, + TEST_SET, + H5_TRAIN_SET_PATH, + H5_VAL_SET_PATH, + H5_TEST_SET_PATH, + RESULT_SAVING_DIR_FOR_IMPUTATION, + check_tb_and_model_checkpoints_existence, +) + + +class TestAutoformer(unittest.TestCase): + logger.info("Running tests for an imputation model Autoformer...") + + # set the log and model saving path + saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "Autoformer") + model_save_name = "saved_autoformer_model.pypots" + + # initialize an Adam optimizer + optimizer = Adam(lr=0.001, weight_decay=1e-5) + + # initialize a Autoformer model + autoformer = Autoformer( + DATA["n_steps"], + DATA["n_features"], + n_layers=2, + n_heads=2, + d_model=128, + d_ffn=256, + factor=3, + moving_avg_kernel_size=3, + dropout=0, + epochs=EPOCHS, + saving_path=saving_path, + optimizer=optimizer, + device=DEVICE, + ) + + @pytest.mark.xdist_group(name="imputation-autoformer") + def test_0_fit(self): + self.autoformer.fit(TRAIN_SET, VAL_SET) + + @pytest.mark.xdist_group(name="imputation-autoformer") + def test_1_impute(self): + imputation_results = self.autoformer.predict(TEST_SET) + assert not np.isnan( + imputation_results["imputation"] + ).any(), "Output still has missing values after running impute()." + + test_MSE = calc_mse( + imputation_results["imputation"], + DATA["test_X_ori"], + DATA["test_X_indicating_mask"], + ) + logger.info(f"Autoformer test_MSE: {test_MSE}") + + @pytest.mark.xdist_group(name="imputation-autoformer") + def test_2_parameters(self): + assert hasattr(self.autoformer, "model") and self.autoformer.model is not None + + assert ( + hasattr(self.autoformer, "optimizer") + and self.autoformer.optimizer is not None + ) + + assert hasattr(self.autoformer, "best_loss") + self.assertNotEqual(self.autoformer.best_loss, float("inf")) + + assert ( + hasattr(self.autoformer, "best_model_dict") + and self.autoformer.best_model_dict is not None + ) + + @pytest.mark.xdist_group(name="imputation-autoformer") + def test_3_saving_path(self): + # whether the root saving dir exists, which should be created by save_log_into_tb_file + assert os.path.exists( + self.saving_path + ), f"file {self.saving_path} does not exist" + + # check if the tensorboard file and model checkpoints exist + check_tb_and_model_checkpoints_existence(self.autoformer) + + # save the trained model into file, and check if the path exists + saved_model_path = os.path.join(self.saving_path, self.model_save_name) + self.autoformer.save(saved_model_path) + + # test loading the saved model, not necessary, but need to test + self.autoformer.load(saved_model_path) + + @pytest.mark.xdist_group(name="imputation-autoformer") + def test_4_lazy_loading(self): + self.autoformer.fit(H5_TRAIN_SET_PATH, H5_VAL_SET_PATH) + imputation_results = self.autoformer.predict(H5_TEST_SET_PATH) + assert not np.isnan( + imputation_results["imputation"] + ).any(), "Output still has missing values after running impute()." 
+ + test_MSE = calc_mse( + imputation_results["imputation"], + DATA["test_X_ori"], + DATA["test_X_indicating_mask"], + ) + logger.info(f"Lazy-loading Autoformer test_MSE: {test_MSE}") + + +if __name__ == "__main__": + unittest.main() From 00adcf45b81cb414413f94446daa8211c29bdf5d Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Mon, 25 Mar 2024 16:39:18 +0800 Subject: [PATCH 4/4] docs: update the docs of Autoformer and fix some typos; --- pypots/imputation/__init__.py | 4 +++- pypots/imputation/autoformer/model.py | 11 ++++++++++- pypots/imputation/csdi/model.py | 2 +- pypots/imputation/timesnet/model.py | 2 +- 4 files changed, 15 insertions(+), 4 deletions(-) diff --git a/pypots/imputation/__init__.py b/pypots/imputation/__init__.py index 4e7605f6..ad174ca5 100644 --- a/pypots/imputation/__init__.py +++ b/pypots/imputation/__init__.py @@ -11,8 +11,9 @@ from .gpvae import GPVAE from .mrnn import MRNN from .saits import SAITS -from .timesnet import TimesNet from .transformer import Transformer +from .timesnet import TimesNet +from .autoformer import Autoformer from .usgan import USGAN # naive imputation methods @@ -25,6 +26,7 @@ "SAITS", "Transformer", "TimesNet", + "Autoformer", "BRITS", "MRNN", "GPVAE", diff --git a/pypots/imputation/autoformer/model.py b/pypots/imputation/autoformer/model.py index c27e80cc..b1851868 100644 --- a/pypots/imputation/autoformer/model.py +++ b/pypots/imputation/autoformer/model.py @@ -42,7 +42,10 @@ class Autoformer(BaseNNImputer): The number of features in the time-series data sample. n_layers : - The number of layers in the 1st and 2nd DMSA blocks in the SAITS model. + The number of layers in the Autoformer model. + + n_heads : + The number of heads in each layer of Autoformer. d_model : The dimension of the model. @@ -50,6 +53,12 @@ class Autoformer(BaseNNImputer): d_ffn : The dimension of the feed-forward network. + factor : + The factor of the auto correlation mechanism for the Autoformer model. + + moving_avg_kernel_size : + The window size of moving average. + dropout : The dropout rate for the model. diff --git a/pypots/imputation/csdi/model.py b/pypots/imputation/csdi/model.py index 7c48aaea..f0086d11 100644 --- a/pypots/imputation/csdi/model.py +++ b/pypots/imputation/csdi/model.py @@ -43,7 +43,7 @@ class CSDI(BaseNNImputer): The number of features in the time-series data sample. n_layers : - The number of layers in the 1st and 2nd DMSA blocks in the SAITS model. + The number of layers in the CSDI model. n_heads : The number of heads in the multi-head attention mechanism. diff --git a/pypots/imputation/timesnet/model.py b/pypots/imputation/timesnet/model.py index 195805ff..9e93d2f9 100644 --- a/pypots/imputation/timesnet/model.py +++ b/pypots/imputation/timesnet/model.py @@ -42,7 +42,7 @@ class TimesNet(BaseNNImputer): The number of features in the time-series data sample. n_layers : - The number of layers in the 1st and 2nd DMSA blocks in the SAITS model. + The number of layers in the TimesNet model. top_k : The number of top-k amplitude values to be selected to obtain the most significant frequencies.
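For anyone trying out the model added in this series, below is a minimal usage sketch of the new Autoformer imputer. It mirrors the unit test introduced in PATCH 3/4; the toy data arrays, shapes, hyperparameter values, and saving path are illustrative assumptions, not recommended settings.

import numpy as np

from pypots.imputation import Autoformer
from pypots.optim import Adam

# Toy data (assumed shapes): 64 samples, 24 time steps, 5 features.
# X_ori plays the role of the "complete" series; X is the same series with
# ~10% of values artificially masked as NaN to simulate missingness.
X_ori = np.random.randn(64, 24, 5)
X = X_ori.copy()
X[np.random.rand(*X.shape) < 0.1] = np.nan

train_set = {"X": X}                     # training only needs the partially observed "X"
val_set = {"X": X, "X_ori": X_ori}       # validation additionally needs "X_ori" (see fit())
test_set = {"X": X}

# Hyperparameters follow the values used in tests/imputation/autoformer.py; they are not tuned.
autoformer = Autoformer(
    n_steps=24,
    n_features=5,
    n_layers=2,
    n_heads=2,
    d_model=128,
    d_ffn=256,
    factor=3,
    moving_avg_kernel_size=3,
    dropout=0,
    epochs=5,
    optimizer=Adam(lr=0.001, weight_decay=1e-5),
    saving_path="run/autoformer",        # optional; omit to disable checkpoint/tensorboard saving
)

autoformer.fit(train_set, val_set)       # trains with MIT-style masking, then freezes the best model
results = autoformer.predict(test_set)   # returns a dict containing the key "imputation"
imputed = results["imputation"]          # array of shape [n_samples, n_steps, n_features]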