diff --git a/pypots/__init__.py b/pypots/__init__.py index f37b0aff..5238b9f7 100644 --- a/pypots/__init__.py +++ b/pypots/__init__.py @@ -22,7 +22,7 @@ # # Dev branch marker is: 'X.Y.dev' or 'X.Y.devN' where N is an integer. # 'X.Y.dev0' is the canonical version of 'X.Y.dev' -__version__ = "0.2" +__version__ = "0.2.1" from . import imputation, classification, clustering, forecasting, optim, data, utils diff --git a/pypots/classification/brits/modules/core.py b/pypots/classification/brits/modules/core.py index a0abada6..c3cc9146 100644 --- a/pypots/classification/brits/modules/core.py +++ b/pypots/classification/brits/modules/core.py @@ -89,37 +89,36 @@ def forward(self, inputs: dict, training: bool = True) -> dict: ret_b = self._reverse(self.rits_b(inputs, "backward")) classification_pred = (ret_f["prediction"] + ret_b["prediction"]) / 2 - if not training: - # if not in training mode, return the classification result only - return {"classification_pred": classification_pred} - - ret_f["classification_loss"] = F.nll_loss( - torch.log(ret_f["prediction"]), inputs["label"] - ) - ret_b["classification_loss"] = F.nll_loss( - torch.log(ret_b["prediction"]), inputs["label"] - ) - consistency_loss = self._get_consistency_loss( - ret_f["imputed_data"], ret_b["imputed_data"] - ) - classification_loss = ( - ret_f["classification_loss"] + ret_b["classification_loss"] - ) / 2 - reconstruction_loss = ( - ret_f["reconstruction_loss"] + ret_b["reconstruction_loss"] - ) / 2 - - loss = ( - consistency_loss - + reconstruction_loss * self.reconstruction_weight - + classification_loss * self.classification_weight - ) - - results = { - "classification_pred": classification_pred, - "consistency_loss": consistency_loss, - "classification_loss": classification_loss, - "reconstruction_loss": reconstruction_loss, - "loss": loss, - } + results = {"classification_pred": classification_pred} + + # if in training mode, return results with losses + if training: + ret_f["classification_loss"] = F.nll_loss( + torch.log(ret_f["prediction"]), inputs["label"] + ) + ret_b["classification_loss"] = F.nll_loss( + torch.log(ret_b["prediction"]), inputs["label"] + ) + consistency_loss = self._get_consistency_loss( + ret_f["imputed_data"], ret_b["imputed_data"] + ) + classification_loss = ( + ret_f["classification_loss"] + ret_b["classification_loss"] + ) / 2 + reconstruction_loss = ( + ret_f["reconstruction_loss"] + ret_b["reconstruction_loss"] + ) / 2 + + results["consistency_loss"] = consistency_loss + results["classification_loss"] = classification_loss + results["reconstruction_loss"] = reconstruction_loss + + # `loss` is always the item for backward propagating to update the model + loss = ( + consistency_loss + + reconstruction_loss * self.reconstruction_weight + + classification_loss * self.classification_weight + ) + results["loss"] = loss + return results diff --git a/pypots/classification/grud/modules/core.py b/pypots/classification/grud/modules/core.py index fe528655..c1b873f4 100644 --- a/pypots/classification/grud/modules/core.py +++ b/pypots/classification/grud/modules/core.py @@ -91,18 +91,14 @@ def forward(self, inputs: dict, training: bool = True) -> dict: logits = self.classifier(hidden_state) classification_pred = torch.softmax(logits, dim=1) + results = {"classification_pred": classification_pred} - if not training: - # if not in training mode, return the classification result only - return {"classification_pred": classification_pred} + # if in training mode, return results with losses + if training: + 
torch.log(classification_pred) + classification_loss = F.nll_loss( + torch.log(classification_pred), inputs["label"] + ) + results["loss"] = classification_loss - torch.log(classification_pred) - classification_loss = F.nll_loss( - torch.log(classification_pred), inputs["label"] - ) - - results = { - "classification_pred": classification_pred, - "loss": classification_loss, - } return results diff --git a/pypots/classification/raindrop/modules/core.py b/pypots/classification/raindrop/modules/core.py index 798ff5e6..68ba1f81 100644 --- a/pypots/classification/raindrop/modules/core.py +++ b/pypots/classification/raindrop/modules/core.py @@ -262,18 +262,13 @@ def classify(self, inputs: dict) -> torch.Tensor: def forward(self, inputs, training=True): classification_pred = self.classify(inputs) - if not training: - # if not in training mode, return the classification result only - return {"classification_pred": classification_pred} + results = {"classification_pred": classification_pred} - classification_loss = F.nll_loss( - torch.log(classification_pred), inputs["label"] - ) - - results = { - "prediction": classification_pred, - "loss": classification_loss - # 'distance': distance, - } + # if in training mode, return results with losses + if training: + classification_loss = F.nll_loss( + torch.log(classification_pred), inputs["label"] + ) + results["loss"] = classification_loss return results diff --git a/pypots/clustering/crli/model.py b/pypots/clustering/crli/model.py index c59d840e..2aa77239 100644 --- a/pypots/clustering/crli/model.py +++ b/pypots/clustering/crli/model.py @@ -270,7 +270,7 @@ def _train_model( with torch.no_grad(): for idx, data in enumerate(val_loader): inputs = self._assemble_input_for_validating(data) - results = self.model.forward(inputs, return_loss=True) + results = self.model.forward(inputs, training=True) epoch_val_loss_G_collector.append( results["generation_loss"].sum().item() ) @@ -424,7 +424,7 @@ def predict( with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - inputs = self.model.forward(inputs, return_loss=False) + inputs = self.model.forward(inputs, training=False) clustering_latent_collector.append(inputs["fcn_latent"]) if return_latent_vars: imputation_collector.append(inputs["imputation_latent"]) diff --git a/pypots/clustering/crli/modules/core.py b/pypots/clustering/crli/modules/core.py index 3d1cddd4..a4c16a2a 100644 --- a/pypots/clustering/crli/modules/core.py +++ b/pypots/clustering/crli/modules/core.py @@ -58,7 +58,7 @@ def forward( self, inputs: dict, training_object: str = "generator", - return_loss: bool = True, + training: bool = True, ) -> dict: X = inputs["X"] missing_mask = inputs["missing_mask"] @@ -76,7 +76,7 @@ def forward( inputs["fcn_latent"] = fcn_latent # return results directly, skip loss calculation to reduce inference time - if not return_loss: + if not training: return inputs if training_object == "discriminator": @@ -106,4 +106,5 @@ def forward( l_kmeans = torch.trace(HTH) - torch.trace(FTHTHF) # k-means loss loss_gene = l_G + l_pre + l_rec + l_kmeans * self.lambda_kmeans losses["generation_loss"] = loss_gene + return losses diff --git a/pypots/clustering/vader/modules/core.py b/pypots/clustering/vader/modules/core.py index 4ffea516..8ff2f4ac 100644 --- a/pypots/clustering/vader/modules/core.py +++ b/pypots/clustering/vader/modules/core.py @@ -173,18 +173,15 @@ def forward( stddev_tilde, ) = self.get_results(X, missing_mask) - if not training and not pretrain: - results = 
{ - "mu_tilde": mu_tilde, - "stddev_tilde": stddev_tilde, - "mu": mu_c, - "var": var_c, - "phi": phi_c, - "z": z, - "imputation_latent": X_reconstructed, - } - # if only run clustering, then no need to calculate loss - return results + results = { + "mu_tilde": mu_tilde, + "stddev_tilde": stddev_tilde, + "mu": mu_c, + "var": var_c, + "phi": phi_c, + "z": z, + "imputation_latent": X_reconstructed, + } # calculate the reconstruction loss unscaled_reconstruction_loss = cal_mse(X_reconstructed, X, missing_mask) @@ -194,66 +191,68 @@ def forward( * self.d_input / missing_mask.sum() ) + if pretrain: - results = {"loss": reconstruction_loss, "z": z} + results["loss"] = reconstruction_loss return results - # calculate the latent loss - var_tilde = torch.exp(stddev_tilde) - stddev_c = torch.log(var_c + self.eps) - log_2pi = torch.log(torch.tensor([2 * torch.pi], device=device)) - log_phi_c = torch.log(phi_c + self.eps) - - batch_size = z.shape[0] - - ii, jj = torch.meshgrid( - torch.arange(self.n_clusters, dtype=torch.int64, device=device), - torch.arange(batch_size, dtype=torch.int64, device=device), - indexing="ij", - ) - ii = ii.flatten() - jj = jj.flatten() - - lsc_b = stddev_c.index_select(dim=0, index=ii) - mc_b = mu_c.index_select(dim=0, index=ii) - sc_b = var_c.index_select(dim=0, index=ii) - z_b = z.index_select(dim=0, index=jj) - log_pdf_z = -0.5 * (lsc_b + log_2pi + torch.square(z_b - mc_b) / sc_b) - log_pdf_z = log_pdf_z.reshape([batch_size, self.n_clusters, self.d_mu_stddev]) - - log_p = log_phi_c + log_pdf_z.sum(dim=2) - lse_p = log_p.logsumexp(dim=1, keepdim=True) - log_gamma_c = log_p - lse_p - gamma_c = torch.exp(log_gamma_c) - - term1 = torch.log(var_c + self.eps) - st_b = var_tilde.index_select(dim=0, index=jj) - sc_b = var_c.index_select(dim=0, index=ii) - term2 = torch.reshape( - st_b / (sc_b + self.eps), [batch_size, self.n_clusters, self.d_mu_stddev] - ) - mt_b = mu_tilde.index_select(dim=0, index=jj) - mc_b = mu_c.index_select(dim=0, index=ii) - term3 = torch.reshape( - torch.square(mt_b - mc_b) / (sc_b + self.eps), - [batch_size, self.n_clusters, self.d_mu_stddev], - ) - - latent_loss1 = 0.5 * torch.sum( - gamma_c * torch.sum(term1 + term2 + term3, dim=2), dim=1 - ) - latent_loss2 = -torch.sum(gamma_c * (log_phi_c - log_gamma_c), dim=1) - latent_loss3 = -0.5 * torch.sum(1 + stddev_tilde, dim=1) - - latent_loss1 = latent_loss1.mean() - latent_loss2 = latent_loss2.mean() - latent_loss3 = latent_loss3.mean() - latent_loss = latent_loss1 + latent_loss2 + latent_loss3 - - results = { - "loss": reconstruction_loss + self.alpha * latent_loss, - "z": z, - "imputation_latent": X_reconstructed, - } + # if in training mode, return results with losses + if training: + # calculate the latent loss for model training + var_tilde = torch.exp(stddev_tilde) + stddev_c = torch.log(var_c + self.eps) + log_2pi = torch.log(torch.tensor([2 * torch.pi], device=device)) + log_phi_c = torch.log(phi_c + self.eps) + + batch_size = z.shape[0] + + ii, jj = torch.meshgrid( + torch.arange(self.n_clusters, dtype=torch.int64, device=device), + torch.arange(batch_size, dtype=torch.int64, device=device), + indexing="ij", + ) + ii = ii.flatten() + jj = jj.flatten() + + lsc_b = stddev_c.index_select(dim=0, index=ii) + mc_b = mu_c.index_select(dim=0, index=ii) + sc_b = var_c.index_select(dim=0, index=ii) + z_b = z.index_select(dim=0, index=jj) + log_pdf_z = -0.5 * (lsc_b + log_2pi + torch.square(z_b - mc_b) / sc_b) + log_pdf_z = log_pdf_z.reshape( + [batch_size, self.n_clusters, self.d_mu_stddev] + ) + + log_p = 
log_phi_c + log_pdf_z.sum(dim=2) + lse_p = log_p.logsumexp(dim=1, keepdim=True) + log_gamma_c = log_p - lse_p + gamma_c = torch.exp(log_gamma_c) + + term1 = torch.log(var_c + self.eps) + st_b = var_tilde.index_select(dim=0, index=jj) + sc_b = var_c.index_select(dim=0, index=ii) + term2 = torch.reshape( + st_b / (sc_b + self.eps), + [batch_size, self.n_clusters, self.d_mu_stddev], + ) + mt_b = mu_tilde.index_select(dim=0, index=jj) + mc_b = mu_c.index_select(dim=0, index=ii) + term3 = torch.reshape( + torch.square(mt_b - mc_b) / (sc_b + self.eps), + [batch_size, self.n_clusters, self.d_mu_stddev], + ) + + latent_loss1 = 0.5 * torch.sum( + gamma_c * torch.sum(term1 + term2 + term3, dim=2), dim=1 + ) + latent_loss2 = -torch.sum(gamma_c * (log_phi_c - log_gamma_c), dim=1) + latent_loss3 = -0.5 * torch.sum(1 + stddev_tilde, dim=1) + + latent_loss1 = latent_loss1.mean() + latent_loss2 = latent_loss2.mean() + latent_loss3 = latent_loss3.mean() + latent_loss = latent_loss1 + latent_loss2 + latent_loss3 + + results["loss"] = reconstruction_loss + self.alpha * latent_loss return results diff --git a/pypots/imputation/brits/modules/core.py b/pypots/imputation/brits/modules/core.py index 31f3fc43..83b48f95 100644 --- a/pypots/imputation/brits/modules/core.py +++ b/pypots/imputation/brits/modules/core.py @@ -316,27 +316,23 @@ def forward(self, inputs: dict, training: bool = True) -> dict: imputed_data = (ret_f["imputed_data"] + ret_b["imputed_data"]) / 2 - if not training: - # if not in training mode, return the classification result only - return { - "imputed_data": imputed_data, - } - - consistency_loss = self._get_consistency_loss( - ret_f["imputed_data"], ret_b["imputed_data"] - ) - - # `loss` is always the item for backward propagating to update the model - loss = ( - consistency_loss - + ret_f["reconstruction_loss"] - + ret_b["reconstruction_loss"] - ) - results = { "imputed_data": imputed_data, - "consistency_loss": consistency_loss, - "loss": loss, # will be used for backward propagating to update the model } + # if in training mode, return results with losses + if training: + consistency_loss = self._get_consistency_loss( + ret_f["imputed_data"], ret_b["imputed_data"] + ) + + # `loss` is always the item for backward propagating to update the model + loss = ( + consistency_loss + + ret_f["reconstruction_loss"] + + ret_b["reconstruction_loss"] + ) + results["consistency_loss"] = consistency_loss + results["loss"] = loss + return results diff --git a/pypots/imputation/csdi/modules/core.py b/pypots/imputation/csdi/modules/core.py index 12a03cf0..b25fd190 100644 --- a/pypots/imputation/csdi/modules/core.py +++ b/pypots/imputation/csdi/modules/core.py @@ -65,6 +65,10 @@ def __init__( ) elif schedule == "linear": self.beta = np.linspace(beta_start, beta_end, self.n_diffusion_steps) + else: + raise ValueError( + f"The argument schedule should be 'quad' or 'linear', but got {schedule}" + ) self.alpha_hat = 1 - self.beta self.alpha = np.cumprod(self.alpha_hat) @@ -72,7 +76,8 @@ def __init__( "alpha_torch", torch.tensor(self.alpha).float().unsqueeze(1).unsqueeze(1) ) - def time_embedding(self, pos, d_model=128): + @staticmethod + def time_embedding(pos, d_model=128): pe = torch.zeros(pos.shape[0], pos.shape[1], d_model).to(pos.device) position = pos.unsqueeze(2) div_term = 1 / torch.pow( @@ -82,7 +87,8 @@ def time_embedding(self, pos, d_model=128): pe[:, :, 1::2] = torch.cos(position * div_term) return pe - def get_randmask(self, observed_mask): + @staticmethod + def get_rand_mask(observed_mask): 
rand_for_mask = torch.rand_like(observed_mask) * observed_mask rand_for_mask = rand_for_mask.reshape(len(rand_for_mask), -1) for i in range(len(observed_mask)): @@ -97,14 +103,14 @@ def get_hist_mask(self, observed_mask, for_pattern_mask=None): if for_pattern_mask is None: for_pattern_mask = observed_mask if self.target_strategy == "mix": - rand_mask = self.get_randmask(observed_mask) + rand_mask = self.get_rand_mask(observed_mask) cond_mask = observed_mask.clone() for i in range(len(cond_mask)): mask_choice = np.random.rand() if self.target_strategy == "mix" and mask_choice > 0.5: cond_mask[i] = rand_mask[i] - else: # draw another sample for histmask (i-1 corresponds to another sample) + else: # draw another sample for hist mask (i-1 corresponds to another sample) cond_mask[i] = cond_mask[i] * for_pattern_mask[i - 1] return cond_mask @@ -241,23 +247,22 @@ def forward(self, inputs, training=True, n_sampling_times=1): observed_mask, for_pattern_mask=for_pattern_mask ) else: - cond_mask = self.get_randmask(observed_mask) + cond_mask = self.get_rand_mask(observed_mask) side_info = self.get_side_info(observed_tp, cond_mask) - loss_func = self.calc_loss if training == 1 else self.calc_loss_valid + loss_func = self.calc_loss if training else self.calc_loss_valid # `loss` is always the item for backward propagating to update the model loss = loss_func(observed_data, cond_mask, observed_mask, side_info, training) + results = {"loss": loss} - results = { - "loss": loss, # will be used for backward propagating to update the model - } if not training: samples = self.impute( observed_data, cond_mask, side_info, n_sampling_times - ) # (B,nsample,K,L) + ) # (B,bz,K,L) imputation = samples.mean(dim=1) # (B,K,L) imputed_data = observed_data + imputation * (1 - gt_mask) results["imputed_data"] = imputed_data.permute(0, 2, 1) # (B,L,K) + return results diff --git a/pypots/imputation/csdi/modules/submodules.py b/pypots/imputation/csdi/modules/submodules.py index be197643..31e71fff 100644 --- a/pypots/imputation/csdi/modules/submodules.py +++ b/pypots/imputation/csdi/modules/submodules.py @@ -19,7 +19,7 @@ def get_torch_trans(heads=8, layers=1, channels=64): return nn.TransformerEncoder(encoder_layer, num_layers=layers) -def Conv1d_with_init(in_channels, out_channels, kernel_size): +def conv1d_with_init(in_channels, out_channels, kernel_size): layer = nn.Conv1d(in_channels, out_channels, kernel_size) nn.init.kaiming_normal_(layer.weight) return layer @@ -46,7 +46,8 @@ def forward(self, diffusion_step): x = F.silu(x) return x - def _build_embedding(self, n_steps, d_embedding=64): + @staticmethod + def _build_embedding(n_steps, d_embedding=64): steps = torch.arange(n_steps).unsqueeze(1) # (T,1) frequencies = 10.0 ** ( torch.arange(d_embedding) / (d_embedding - 1) * 4.0 @@ -62,9 +63,9 @@ class ResidualBlock(nn.Module): def __init__(self, d_side, n_channels, diffusion_embedding_dim, nheads): super().__init__() self.diffusion_projection = nn.Linear(diffusion_embedding_dim, n_channels) - self.cond_projection = Conv1d_with_init(d_side, 2 * n_channels, 1) - self.mid_projection = Conv1d_with_init(n_channels, 2 * n_channels, 1) - self.output_projection = Conv1d_with_init(n_channels, 2 * n_channels, 1) + self.cond_projection = conv1d_with_init(d_side, 2 * n_channels, 1) + self.mid_projection = conv1d_with_init(n_channels, 2 * n_channels, 1) + self.output_projection = conv1d_with_init(n_channels, 2 * n_channels, 1) self.time_layer = get_torch_trans(heads=nheads, layers=1, channels=n_channels) self.feature_layer = 
get_torch_trans( @@ -135,9 +136,9 @@ def __init__( n_diffusion_steps=n_diffusion_steps, d_embedding=d_diffusion_embedding, ) - self.input_projection = Conv1d_with_init(d_input, n_channels, 1) - self.output_projection1 = Conv1d_with_init(n_channels, n_channels, 1) - self.output_projection2 = Conv1d_with_init(n_channels, 1, 1) + self.input_projection = conv1d_with_init(d_input, n_channels, 1) + self.output_projection1 = conv1d_with_init(n_channels, n_channels, 1) + self.output_projection2 = conv1d_with_init(n_channels, 1, 1) nn.init.zeros_(self.output_projection2.weight) self.residual_layers = nn.ModuleList( diff --git a/pypots/imputation/gpvae/modules/core.py b/pypots/imputation/gpvae/modules/core.py index 935daca0..7a37ffde 100644 --- a/pypots/imputation/gpvae/modules/core.py +++ b/pypots/imputation/gpvae/modules/core.py @@ -109,6 +109,53 @@ def decode(self, z): assert num_dim > 2 return self.decoder(torch.transpose(z, num_dim - 1, num_dim - 2)) + @staticmethod + def kl_divergence(a, b): + return torch.distributions.kl.kl_divergence(a, b) + + def _init_prior(self, device="cpu"): + # Compute kernel matrices for each latent dimension + kernel_matrices = [] + for i in range(self.kernel_scales): + if self.kernel == "rbf": + kernel_matrices.append( + rbf_kernel(self.time_length, self.length_scale / 2**i) + ) + elif self.kernel == "diffusion": + kernel_matrices.append( + diffusion_kernel(self.time_length, self.length_scale / 2**i) + ) + elif self.kernel == "matern": + kernel_matrices.append( + matern_kernel(self.time_length, self.length_scale / 2**i) + ) + elif self.kernel == "cauchy": + kernel_matrices.append( + cauchy_kernel( + self.time_length, self.sigma, self.length_scale / 2**i + ) + ) + + # Combine kernel matrices for each latent dimension + tiled_matrices = [] + total = 0 + for i in range(self.kernel_scales): + if i == self.kernel_scales - 1: + multiplier = self.latent_dim - total + else: + multiplier = int(np.ceil(self.latent_dim / self.kernel_scales)) + total += multiplier + tiled_matrices.append( + torch.unsqueeze(kernel_matrices[i], 0).repeat(multiplier, 1, 1) + ) + kernel_matrix_tiled = torch.cat(tiled_matrices) + assert len(kernel_matrix_tiled) == self.latent_dim + prior = torch.distributions.MultivariateNormal( + loc=torch.zeros(self.latent_dim, self.time_length, device=device), + covariance_matrix=kernel_matrix_tiled.to(device), + ) + return prior + def forward(self, inputs, training=True): x = inputs["X"] m_mask = inputs["missing_mask"] @@ -151,63 +198,12 @@ def forward(self, inputs, training=True): elbo = elbo.mean() imputed_data = self.decode(self.encode(x).mean).mean * ~m_mask + x * m_mask - - if not training: - # if not in training mode, return the classification result only - return { - "imputed_data": imputed_data, - } - results = { - "loss": -elbo.mean(), "imputed_data": imputed_data, } - return results - - @staticmethod - def kl_divergence(a, b): - return torch.distributions.kl.kl_divergence(a, b) - - def _init_prior(self, device="cpu"): - # Compute kernel matrices for each latent dimension - kernel_matrices = [] - for i in range(self.kernel_scales): - if self.kernel == "rbf": - kernel_matrices.append( - rbf_kernel(self.time_length, self.length_scale / 2**i) - ) - elif self.kernel == "diffusion": - kernel_matrices.append( - diffusion_kernel(self.time_length, self.length_scale / 2**i) - ) - elif self.kernel == "matern": - kernel_matrices.append( - matern_kernel(self.time_length, self.length_scale / 2**i) - ) - elif self.kernel == "cauchy": - kernel_matrices.append( - 
cauchy_kernel( - self.time_length, self.sigma, self.length_scale / 2**i - ) - ) - # Combine kernel matrices for each latent dimension - tiled_matrices = [] - total = 0 - for i in range(self.kernel_scales): - if i == self.kernel_scales - 1: - multiplier = self.latent_dim - total - else: - multiplier = int(np.ceil(self.latent_dim / self.kernel_scales)) - total += multiplier - tiled_matrices.append( - torch.unsqueeze(kernel_matrices[i], 0).repeat(multiplier, 1, 1) - ) - kernel_matrix_tiled = torch.cat(tiled_matrices) - assert len(kernel_matrix_tiled) == self.latent_dim - prior = torch.distributions.MultivariateNormal( - loc=torch.zeros(self.latent_dim, self.time_length, device=device), - covariance_matrix=kernel_matrix_tiled.to(device), - ) + # if in training mode, return results with losses + if training: + results["loss"] = -elbo.mean() - return prior + return results diff --git a/pypots/imputation/mrnn/modules/core.py b/pypots/imputation/mrnn/modules/core.py index 865b60b2..e4936ec8 100644 --- a/pypots/imputation/mrnn/modules/core.py +++ b/pypots/imputation/mrnn/modules/core.py @@ -52,7 +52,7 @@ def gene_hidden_states(self, inputs, direction): hidden_states_collector.append(hidden_state) return hidden_states_collector - def forward(self, inputs, training=True): + def forward(self, inputs: dict, training: bool = True) -> dict: hidden_states_f = self.gene_hidden_states(inputs, "forward") hidden_states_b = self.gene_hidden_states(inputs, "backward")[::-1] @@ -82,16 +82,13 @@ def forward(self, inputs, training=True): estimations = torch.cat(estimations, dim=1) imputed_data = masks * X + (1 - masks) * estimations - if not training: - # if not in training mode, return the classification result only - return { - "imputed_data": imputed_data, - } - - reconstruction_loss /= self.seq_len - - ret_dict = { - "loss": reconstruction_loss, + results = { "imputed_data": imputed_data, } - return ret_dict + + # if in training mode, return results with losses + if training: + reconstruction_loss /= self.seq_len + results["loss"] = reconstruction_loss + + return results diff --git a/pypots/imputation/mrnn/modules/submodules.py b/pypots/imputation/mrnn/modules/submodules.py index f157d3fb..b61eb2f2 100644 --- a/pypots/imputation/mrnn/modules/submodules.py +++ b/pypots/imputation/mrnn/modules/submodules.py @@ -38,7 +38,7 @@ def reset_parameters(self): self.beta.data.uniform_(-stdv, stdv) def forward(self, x_t, m_t, target): - h_t = F.tanh( + h_t = torch.tanh( F.linear(x_t, self.U * self.m) + F.linear(target, self.V1 * self.m) + F.linear(m_t, self.V2) diff --git a/pypots/imputation/saits/modules/core.py b/pypots/imputation/saits/modules/core.py index 1f7ef721..eb062709 100644 --- a/pypots/imputation/saits/modules/core.py +++ b/pypots/imputation/saits/modules/core.py @@ -19,7 +19,7 @@ import torch.nn as nn import torch.nn.functional as F -from ....modules.self_attention import EncoderLayer, PositionalEncoding +from ....modules.transformer import EncoderLayer, PositionalEncoding from ....utils.metrics import cal_mae @@ -81,7 +81,7 @@ def __init__( ) self.dropout = nn.Dropout(p=dropout) - self.position_enc = PositionalEncoding(d_model, n_position=n_steps) + self.position_enc = PositionalEncoding(d_model, n_positions=n_steps) # for the 1st block self.embedding_1 = nn.Linear(actual_n_features, d_model) self.reduce_dim_z = nn.Linear(d_model, n_features) @@ -180,27 +180,23 @@ def forward( "combining_weights": combining_weights, "imputed_data": imputed_data, } - if not training: - # if not in training mode, return the 
classification result only - return results - - ORT_loss = 0 - ORT_loss += self.customized_loss_func(X_tilde_1, X, masks) - ORT_loss += self.customized_loss_func(X_tilde_2, X, masks) - ORT_loss += self.customized_loss_func(X_tilde_3, X, masks) - ORT_loss /= 3 - - MIT_loss = self.customized_loss_func( - X_tilde_3, inputs["X_intact"], inputs["indicating_mask"] - ) - # `loss` is always the item for backward propagating to update the model - loss = self.ORT_weight * ORT_loss + self.MIT_weight * MIT_loss + # if in training mode, return results with losses + if training: + ORT_loss = 0 + ORT_loss += self.customized_loss_func(X_tilde_1, X, masks) + ORT_loss += self.customized_loss_func(X_tilde_2, X, masks) + ORT_loss += self.customized_loss_func(X_tilde_3, X, masks) + ORT_loss /= 3 - results["ORT_loss"] = ORT_loss - results["MIT_loss"] = MIT_loss + MIT_loss = self.customized_loss_func( + X_tilde_3, inputs["X_intact"], inputs["indicating_mask"] + ) - # will be used for backward propagating to update the model - results["loss"] = loss + results["ORT_loss"] = ORT_loss + results["MIT_loss"] = MIT_loss + # `loss` is always the item for backward propagating to update the model + loss = self.ORT_weight * ORT_loss + self.MIT_weight * MIT_loss + results["loss"] = loss return results diff --git a/pypots/imputation/transformer/modules/core.py b/pypots/imputation/transformer/modules/core.py index c09e5d4d..34750da8 100644 --- a/pypots/imputation/transformer/modules/core.py +++ b/pypots/imputation/transformer/modules/core.py @@ -18,7 +18,7 @@ import torch import torch.nn as nn -from ....modules.self_attention import EncoderLayer, PositionalEncoding +from ....modules.transformer import EncoderLayer, PositionalEncoding from ....utils.metrics import cal_mae @@ -60,7 +60,7 @@ def __init__( ) self.embedding = nn.Linear(actual_d_feature, d_model) - self.position_enc = PositionalEncoding(d_model, n_position=d_time) + self.position_enc = PositionalEncoding(d_model, n_positions=d_time) self.dropout = nn.Dropout(p=dropout) self.reduce_dim = nn.Linear(d_model, d_feature) @@ -83,24 +83,20 @@ def forward(self, inputs: dict, training: bool = True) -> dict: X, masks = inputs["X"], inputs["missing_mask"] imputed_data, learned_presentation = self._process(inputs) - if not training: - # if not in training mode, return the classification result only - return { - "imputed_data": imputed_data, - } - - ORT_loss = cal_mae(learned_presentation, X, masks) - MIT_loss = cal_mae( - learned_presentation, inputs["X_intact"], inputs["indicating_mask"] - ) - - # `loss` is always the item for backward propagating to update the model - loss = self.ORT_weight * ORT_loss + self.MIT_weight * MIT_loss - results = { "imputed_data": imputed_data, - "ORT_loss": ORT_loss, - "MIT_loss": MIT_loss, - "loss": loss, } + + # if in training mode, return results with losses + if training: + ORT_loss = cal_mae(learned_presentation, X, masks) + MIT_loss = cal_mae( + learned_presentation, inputs["X_intact"], inputs["indicating_mask"] + ) + results["ORT_loss"] = ORT_loss + results["MIT_loss"] = MIT_loss + # `loss` is always the item for backward propagating to update the model + loss = self.ORT_weight * ORT_loss + self.MIT_weight * MIT_loss + results["loss"] = loss + return results diff --git a/pypots/imputation/usgan/modules/core.py b/pypots/imputation/usgan/modules/core.py index 71c43f84..16504d6b 100644 --- a/pypots/imputation/usgan/modules/core.py +++ b/pypots/imputation/usgan/modules/core.py @@ -56,29 +56,30 @@ def forward( "discriminator", ], 
'training_object should be "generator" or "discriminator"' - forward_X = inputs["forward"]["X"] - forward_missing_mask = inputs["forward"]["missing_mask"] - losses = {} results = self.generator(inputs, training=training) - inputs["discrimination"] = self.discriminator(forward_X, forward_missing_mask) - if not training: - # if only run imputation operation, then no need to calculate loss - return results - - if training_object == "discriminator": - l_D = F.binary_cross_entropy_with_logits( - inputs["discrimination"], forward_missing_mask - ) - losses["discrimination_loss"] = l_D - else: - inputs["discrimination"] = inputs["discrimination"].detach() - l_G = F.binary_cross_entropy_with_logits( - inputs["discrimination"], - 1 - forward_missing_mask, - weight=1 - forward_missing_mask, + + # if in training mode, return results with losses + if training: + forward_X = inputs["forward"]["X"] + forward_missing_mask = inputs["forward"]["missing_mask"] + + inputs["discrimination"] = self.discriminator( + forward_X, forward_missing_mask ) - loss_gene = l_G + self.lambda_mse * results["loss"] - losses["generation_loss"] = loss_gene - losses["imputed_data"] = results["imputed_data"] - return losses + if training_object == "discriminator": + l_D = F.binary_cross_entropy_with_logits( + inputs["discrimination"], forward_missing_mask + ) + results["discrimination_loss"] = l_D + else: + inputs["discrimination"] = inputs["discrimination"].detach() + l_G = F.binary_cross_entropy_with_logits( + inputs["discrimination"], + 1 - forward_missing_mask, + weight=1 - forward_missing_mask, + ) + loss_gene = l_G + self.lambda_mse * results["loss"] + results["generation_loss"] = loss_gene + + return results diff --git a/pypots/modules/self_attention.py b/pypots/modules/self_attention.py deleted file mode 100644 index d44fe63a..00000000 --- a/pypots/modules/self_attention.py +++ /dev/null @@ -1,729 +0,0 @@ -""" -The implementation of the modules for Transformer :cite:`vaswani2017Transformer` - -Notes ------ -Partial implementation uses code from https://github.com/WenjieDu/SAITS, -and https://github.com/jadore801120/attention-is-all-you-need-pytorch. - -""" - -# Created by Wenjie Du -# License: BSD-3-Clause - -from typing import Tuple, Optional, Union - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class ScaledDotProductAttention(nn.Module): - """Scaled dot-product attention. - - Parameters - ---------- - temperature: - The temperature for scaling. - - attn_dropout: - The dropout rate for the attention map. - - """ - - def __init__(self, temperature: float, attn_dropout: float = 0.1): - super().__init__() - assert temperature > 0, "temperature should be positive" - assert attn_dropout >= 0, "dropout rate should be non-negative" - self.temperature = temperature - self.dropout = nn.Dropout(attn_dropout) if attn_dropout > 0 else None - - def forward( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Forward processing of the scaled dot-product attention. - - Parameters - ---------- - q: - Query tensor. - k: - Key tensor. - v: - Value tensor. - - attn_mask: - Masking tensor for the attention map. The shape should be [batch_size, n_heads, n_steps, n_steps]. - 0 in attn_mask means values at the according position in the attention map will be masked out. - - Returns - ------- - output: - The result of Value multiplied with the scaled dot-product attention map. 
- - attn: - The scaled dot-product attention map. - - """ - # q, k, v all have 4 dimensions [batch_size, n_heads, n_steps, d_tensor] - # d_tensor could be d_q, d_k, d_v - - # dot product q with k.T to obtain similarity - attn = torch.matmul(q / self.temperature, k.transpose(2, 3)) - - # apply masking on the attention map, this is optional - if attn_mask is not None: - attn = attn.masked_fill(attn_mask == 0, -1e9) - - # compute attention score [0, 1], then apply dropout - attn = F.softmax(attn, dim=-1) - if self.dropout is not None: - attn = self.dropout(attn) - - # multiply the score with v - output = torch.matmul(attn, v) - return output, attn - - -class MultiHeadAttention(nn.Module): - """Transformer multi-head attention module. - - Parameters - ---------- - n_heads: - The number of heads in multi-head attention. - - d_model: - The dimension of the input tensor. - - d_k: - The dimension of the key and query tensor. - - d_v: - The dimension of the value tensor. - - dropout: - The dropout rate. - - attn_dropout: - The dropout rate for the attention map. - - """ - - def __init__( - self, - n_heads: int, - d_model: int, - d_k: int, - d_v: int, - dropout: float, - attn_dropout: float, - ): - super().__init__() - - self.n_heads = n_heads - self.d_k = d_k - self.d_v = d_v - - self.w_qs = nn.Linear(d_model, n_heads * d_k, bias=False) - self.w_ks = nn.Linear(d_model, n_heads * d_k, bias=False) - self.w_vs = nn.Linear(d_model, n_heads * d_v, bias=False) - - self.attention = ScaledDotProductAttention(d_k**0.5, attn_dropout) - self.fc = nn.Linear(n_heads * d_v, d_model, bias=False) - - self.dropout = nn.Dropout(dropout) - self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) - - def forward( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - attn_mask: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Forward processing of the multi-head attention module. - - Parameters - ---------- - q: - Query tensor. - - k: - Key tensor. - - v: - Value tensor. - - attn_mask: - Masking tensor for the attention map. The shape should be [batch_size, n_heads, n_steps, n_steps]. - 0 in attn_mask means values at the according position in the attention map will be masked out. - - Returns - ------- - v: - The output of the multi-head attention layer. - - attn_weights: - The attention map. 
- - """ - # the input q, k, v currently have 3 dimensions [batch_size, n_steps, d_tensor] - # d_tensor could be n_heads*d_k, n_heads*d_v - - # keep useful variables - batch_size, n_steps = q.size(0), q.size(1) - residual = q - - # now separate the last dimension of q, k, v into different heads -> [batch_size, n_steps, n_heads, d_k or d_v] - q = self.w_qs(q).view(batch_size, n_steps, self.n_heads, self.d_k) - k = self.w_ks(k).view(batch_size, n_steps, self.n_heads, self.d_k) - v = self.w_vs(v).view(batch_size, n_steps, self.n_heads, self.d_v) - - # transpose for self-attention calculation -> [batch_size, n_steps, d_k or d_v, n_heads] - q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) - - if attn_mask is not None: - # broadcasting on the head axis - attn_mask = attn_mask.unsqueeze(1) - - v, attn_weights = self.attention(q, k, v, attn_mask) - - # transpose back -> [batch_size, n_steps, n_heads, d_v] - # then merge the last two dimensions to combine all the heads -> [batch_size, n_steps, n_heads*d_v] - v = v.transpose(1, 2).contiguous().view(batch_size, n_steps, -1) - v = self.fc(v) - - # apply dropout and residual connection - v = self.dropout(v) - v += residual - - # apply layer-norm - v = self.layer_norm(v) - - return v, attn_weights - - -class PositionWiseFeedForward(nn.Module): - """Position-wise feed forward network (FFN) in Transformer. - - Parameters - ---------- - d_in: - The dimension of the input tensor. - - d_hid: - The dimension of the hidden layer. - - dropout: - The dropout rate. - - """ - - def __init__(self, d_in: int, d_hid: int, dropout: float = 0.1): - super().__init__() - self.linear_1 = nn.Linear(d_in, d_hid) - self.linear_2 = nn.Linear(d_hid, d_in) - self.layer_norm = nn.LayerNorm(d_in, eps=1e-6) - self.dropout = nn.Dropout(dropout) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Forward processing of the position-wise feed forward network. - - Parameters - ---------- - x: - Input tensor. - - Returns - ------- - x: - Output tensor. - """ - # save the original input for the later residual connection - residual = x - # the 1st linear processing and ReLU non-linear projection - x = F.relu(self.linear_1(x)) - # the 2nd linear processing - x = self.linear_2(x) - # apply dropout - x = self.dropout(x) - # apply residual connection - x += residual - # apply layer-norm - x = self.layer_norm(x) - return x - - -class EncoderLayer(nn.Module): - """Transformer encoder layer. - - Parameters - ---------- - d_model: - The dimension of the input tensor. - - d_inner: - The dimension of the hidden layer. - - n_heads: - The number of heads in multi-head attention. - - d_k: - The dimension of the key and query tensor. - - d_v: - The dimension of the value tensor. - - dropout: - The dropout rate. - - attn_dropout: - The dropout rate for the attention map. - """ - - def __init__( - self, - d_model: int, - d_inner: int, - n_heads: int, - d_k: int, - d_v: int, - dropout: float = 0.1, - attn_dropout: float = 0.1, - ): - super().__init__() - self.slf_attn = MultiHeadAttention( - n_heads, d_model, d_k, d_v, dropout, attn_dropout - ) - self.pos_ffn = PositionWiseFeedForward(d_model, d_inner, dropout) - - def forward( - self, - enc_input: torch.Tensor, - src_mask: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Forward processing of the encoder layer. - - Parameters - ---------- - enc_input: - Input tensor. - - src_mask: - Masking tensor for the attention map. The shape should be [batch_size, n_heads, n_steps, n_steps]. 
- - Returns - ------- - enc_output: - Output tensor. - - attn_weights: - The attention map. - - """ - enc_output, attn_weights = self.slf_attn( - enc_input, - enc_input, - enc_input, - attn_mask=src_mask, - ) - enc_output = self.pos_ffn(enc_output) - return enc_output, attn_weights - - -class DecoderLayer(nn.Module): - """Transformer decoder layer. - - Parameters - ---------- - d_model: - The dimension of the input tensor. - - d_inner: - The dimension of the hidden layer. - - n_heads: - The number of heads in multi-head attention. - - d_k: - The dimension of the key and query tensor. - - d_v: - The dimension of the value tensor. - - dropout: - The dropout rate. - - attn_dropout: - The dropout rate for the attention map. - - """ - - def __init__( - self, - d_model: int, - d_inner: int, - n_heads: int, - d_k: int, - d_v: int, - dropout: float = 0.1, - attn_dropout: float = 0.1, - ): - super().__init__() - self.slf_attn = MultiHeadAttention( - n_heads, d_model, d_k, d_v, dropout, attn_dropout - ) - self.enc_attn = MultiHeadAttention( - n_heads, d_model, d_k, d_v, dropout, attn_dropout - ) - self.pos_ffn = PositionWiseFeedForward(d_model, d_inner, dropout) - - def forward( - self, - dec_input: torch.Tensor, - enc_output: torch.Tensor, - slf_attn_mask: Optional[torch.Tensor] = None, - dec_enc_attn_mask: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Forward processing of the decoder layer. - - Parameters - ---------- - dec_input: - Input tensor. - - enc_output: - Output tensor from the encoder. - - slf_attn_mask: - Masking tensor for the self-attention module. - The shape should be [batch_size, n_heads, n_steps, n_steps]. - - dec_enc_attn_mask: - Masking tensor for the encoding attention module. - The shape should be [batch_size, n_heads, n_steps, n_steps]. - - Returns - ------- - dec_output: - Output tensor. - - dec_slf_attn: - The self-attention map. - - dec_enc_attn: - The encoding attention map. - - """ - dec_output, dec_slf_attn = self.slf_attn( - dec_input, dec_input, dec_input, attn_mask=slf_attn_mask - ) - dec_output, dec_enc_attn = self.enc_attn( - dec_output, enc_output, enc_output, attn_mask=dec_enc_attn_mask - ) - dec_output = self.pos_ffn(dec_output) - return dec_output, dec_slf_attn, dec_enc_attn - - -class Encoder(nn.Module): - """Transformer encoder. - - Parameters - ---------- - n_layers: - The number of layers in the encoder. - - n_steps: - The number of time steps in the input tensor. - - n_features: - The number of features in the input tensor. - - d_model: - The dimension of the module manipulation space. - The input tensor will be projected to a space with d_model dimensions. - - d_inner: - The dimension of the hidden layer in the feed-forward network. - - n_heads: - The number of heads in multi-head attention. - - d_k: - The dimension of the key and query tensor. - - d_v: - The dimension of the value tensor. - - dropout: - The dropout rate. - - attn_dropout: - The dropout rate for the attention map. 
- - """ - - def __init__( - self, - n_layers: int, - n_steps: int, - n_features: int, - d_model: int, - d_inner: int, - n_heads: int, - d_k: int, - d_v: int, - dropout: float, - attn_dropout: float, - ): - super().__init__() - - self.embedding = nn.Linear(n_features, d_model) - self.dropout = nn.Dropout(dropout) - self.position_enc = PositionalEncoding(d_model, n_position=n_steps) - self.enc_layer_stack = nn.ModuleList( - [ - EncoderLayer( - d_model, - d_inner, - n_heads, - d_k, - d_v, - dropout, - attn_dropout, - ) - for _ in range(n_layers) - ] - ) - - def forward( - self, - x: torch.Tensor, - src_mask: Optional[torch.Tensor] = None, - return_attn_weights: bool = False, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, list]]: - """Forward processing of the encoder. - - Parameters - ---------- - x: - Input tensor. - - src_mask: - Masking tensor for the attention map. The shape should be [batch_size, n_heads, n_steps, n_steps]. - - return_attn_weights: - Whether to return the attention map. - - Returns - ------- - enc_output: - Output tensor. - - attn_weights_collector: - A list containing the attention map from each encoder layer. - - """ - x = self.embedding(x) - enc_output = self.dropout(self.position_enc(x)) - attn_weights_collector = [] - - for layer in self.enc_layer_stack: - enc_output, attn_weights = layer(enc_output, src_mask) - attn_weights_collector.append(attn_weights) - - if return_attn_weights: - return enc_output, attn_weights_collector - - return enc_output - - -class Decoder(nn.Module): - """Transformer decoder. - - Parameters - ---------- - n_layers: - The number of layers in the decoder. - - n_steps: - The number of time steps in the input tensor. - - n_features: - The number of features in the input tensor. - - d_model: - The dimension of the module manipulation space. - The input tensor will be projected to a space with d_model dimensions. - - d_inner: - The dimension of the hidden layer in the feed-forward network. - - n_heads: - The number of heads in multi-head attention. - - d_k: - The dimension of the key and query tensor. - - d_v: - The dimension of the value tensor. - - dropout: - The dropout rate. - - attn_dropout: - The dropout rate for the attention map. - - """ - - def __init__( - self, - n_layers: int, - n_steps: int, - n_features: int, - d_model: int, - d_inner: int, - n_heads: int, - d_k: int, - d_v: int, - dropout: float, - attn_dropout: float, - ): - super().__init__() - self.embedding = nn.Linear(n_features, d_model) - self.dropout = nn.Dropout(dropout) - self.position_enc = PositionalEncoding(d_model, n_position=n_steps) - self.layer_stack = nn.ModuleList( - [ - DecoderLayer( - d_model, - d_inner, - n_heads, - d_k, - d_v, - dropout, - attn_dropout, - ) - for _ in range(n_layers) - ] - ) - - def forward( - self, - trg_seq: torch.Tensor, - enc_output: torch.Tensor, - trg_mask: Optional[torch.Tensor] = None, - src_mask: Optional[torch.Tensor] = None, - return_attn_weights: bool = False, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, list, list]]: - """Forward processing of the decoder. - - Parameters - ---------- - trg_seq: - Input tensor. - - enc_output: - Output tensor from the encoder. - - trg_mask: - Masking tensor for the self-attention module. - - src_mask: - Masking tensor for the encoding attention module. - - return_attn_weights: - Whether to return the attention map. - - Returns - ------- - dec_output: - Output tensor. - - dec_slf_attn_collector: - A list containing the self-attention map from each decoder layer. 
- - dec_enc_attn_collector: - A list containing the encoding attention map from each decoder layer. - - """ - trg_seq = self.embedding(trg_seq) - dec_output = self.dropout(self.position_enc(trg_seq)) - - dec_slf_attn_collector = [] - dec_enc_attn_collector = [] - - for layer in self.layer_stack: - dec_output, dec_slf_attn, dec_enc_attn = layer( - dec_output, - enc_output, - slf_attn_mask=trg_mask, - dec_enc_attn_mask=src_mask, - ) - dec_slf_attn_collector.append(dec_slf_attn) - dec_enc_attn_collector.append(dec_enc_attn) - - if return_attn_weights: - return dec_output, dec_slf_attn_collector, dec_enc_attn_collector - - return dec_output - - -class PositionalEncoding(nn.Module): - """Positional-encoding module for Transformer. - - Parameters - ---------- - d_hid: - The dimension of the hidden layer. - - n_position: - The number of positions. - - """ - - def __init__(self, d_hid: int, n_position: int = 200): - super().__init__() - # Not a parameter - self.register_buffer( - "pos_table", self._get_sinusoid_encoding_table(n_position, d_hid) - ) - - @staticmethod - def _get_sinusoid_encoding_table(n_position: int, d_hid: int) -> torch.Tensor: - """Sinusoid position encoding table""" - - def get_position_angle_vec(position): - return [ - position / np.power(10000, 2 * (hid_j // 2) / d_hid) - for hid_j in range(d_hid) - ] - - sinusoid_table = np.array( - [get_position_angle_vec(pos_i) for pos_i in range(n_position)] - ) - sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i - sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 - return torch.FloatTensor(sinusoid_table).unsqueeze(0) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Forward processing of the positional encoding module. - - Parameters - ---------- - x: - Input tensor. - - Returns - ------- - x: - Output tensor, the input tensor with the positional encoding added. - - """ - return x + self.pos_table[:, : x.size(1)].clone().detach() diff --git a/pypots/modules/transformer/__init__.py b/pypots/modules/transformer/__init__.py new file mode 100644 index 00000000..65b02d45 --- /dev/null +++ b/pypots/modules/transformer/__init__.py @@ -0,0 +1,22 @@ +""" + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +from .attention import ScaledDotProductAttention, MultiHeadAttention +from .auto_encoder import Encoder, Decoder +from .layers import EncoderLayer, DecoderLayer, PositionWiseFeedForward +from .pos_enc import PositionalEncoding + +__all__ = [ + "ScaledDotProductAttention", + "MultiHeadAttention", + "PositionalEncoding", + "EncoderLayer", + "DecoderLayer", + "PositionWiseFeedForward", + "Encoder", + "Decoder", +] diff --git a/pypots/modules/transformer/attention.py b/pypots/modules/transformer/attention.py new file mode 100644 index 00000000..f4b6cf29 --- /dev/null +++ b/pypots/modules/transformer/attention.py @@ -0,0 +1,208 @@ +""" +The implementation of the modules for Transformer :cite:`vaswani2017Transformer` + +Notes +----- +Partial implementation uses code from https://github.com/WenjieDu/SAITS, +and https://github.com/jadore801120/attention-is-all-you-need-pytorch. + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +from typing import Tuple, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ScaledDotProductAttention(nn.Module): + """Scaled dot-product attention. + + Parameters + ---------- + temperature: + The temperature for scaling. + + attn_dropout: + The dropout rate for the attention map. 
+ + """ + + def __init__(self, temperature: float, attn_dropout: float = 0.1): + super().__init__() + assert temperature > 0, "temperature should be positive" + assert attn_dropout >= 0, "dropout rate should be non-negative" + self.temperature = temperature + self.dropout = nn.Dropout(attn_dropout) if attn_dropout > 0 else None + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward processing of the scaled dot-product attention. + + Parameters + ---------- + q: + Query tensor. + k: + Key tensor. + v: + Value tensor. + + attn_mask: + Masking tensor for the attention map. The shape should be [batch_size, n_heads, n_steps, n_steps]. + 0 in attn_mask means values at the according position in the attention map will be masked out. + + Returns + ------- + output: + The result of Value multiplied with the scaled dot-product attention map. + + attn: + The scaled dot-product attention map. + + """ + # q, k, v all have 4 dimensions [batch_size, n_heads, n_steps, d_tensor] + # d_tensor could be d_q, d_k, d_v + + # dot product q with k.T to obtain similarity + attn = torch.matmul(q / self.temperature, k.transpose(2, 3)) + + # apply masking on the attention map, this is optional + if attn_mask is not None: + attn = attn.masked_fill(attn_mask == 0, -1e9) + + # compute attention score [0, 1], then apply dropout + attn = F.softmax(attn, dim=-1) + if self.dropout is not None: + attn = self.dropout(attn) + + # multiply the score with v + output = torch.matmul(attn, v) + return output, attn + + +class MultiHeadAttention(nn.Module): + """Transformer multi-head attention module. + + Parameters + ---------- + n_heads: + The number of heads in multi-head attention. + + d_model: + The dimension of the input tensor. + + d_k: + The dimension of the key and query tensor. + + d_v: + The dimension of the value tensor. + + dropout: + The dropout rate. + + attn_dropout: + The dropout rate for the attention map. + + """ + + def __init__( + self, + n_heads: int, + d_model: int, + d_k: int, + d_v: int, + dropout: float, + attn_dropout: float, + ): + super().__init__() + + self.n_heads = n_heads + self.d_k = d_k + self.d_v = d_v + + self.w_qs = nn.Linear(d_model, n_heads * d_k, bias=False) + self.w_ks = nn.Linear(d_model, n_heads * d_k, bias=False) + self.w_vs = nn.Linear(d_model, n_heads * d_v, bias=False) + + self.attention = ScaledDotProductAttention(d_k**0.5, attn_dropout) + self.fc = nn.Linear(n_heads * d_v, d_model, bias=False) + + self.dropout = nn.Dropout(dropout) + self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_mask: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward processing of the multi-head attention module. + + Parameters + ---------- + q: + Query tensor. + + k: + Key tensor. + + v: + Value tensor. + + attn_mask: + Masking tensor for the attention map. The shape should be [batch_size, n_heads, n_steps, n_steps]. + 0 in attn_mask means values at the according position in the attention map will be masked out. + + Returns + ------- + v: + The output of the multi-head attention layer. + + attn_weights: + The attention map. 
+ + """ + # the input q, k, v currently have 3 dimensions [batch_size, n_steps, d_tensor] + # d_tensor could be n_heads*d_k, n_heads*d_v + + # keep useful variables + batch_size, n_steps = q.size(0), q.size(1) + residual = q + + # now separate the last dimension of q, k, v into different heads -> [batch_size, n_steps, n_heads, d_k or d_v] + q = self.w_qs(q).view(batch_size, n_steps, self.n_heads, self.d_k) + k = self.w_ks(k).view(batch_size, n_steps, self.n_heads, self.d_k) + v = self.w_vs(v).view(batch_size, n_steps, self.n_heads, self.d_v) + + # transpose for self-attention calculation -> [batch_size, n_steps, d_k or d_v, n_heads] + q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) + + if attn_mask is not None: + # broadcasting on the head axis + attn_mask = attn_mask.unsqueeze(1) + + v, attn_weights = self.attention(q, k, v, attn_mask) + + # transpose back -> [batch_size, n_steps, n_heads, d_v] + # then merge the last two dimensions to combine all the heads -> [batch_size, n_steps, n_heads*d_v] + v = v.transpose(1, 2).contiguous().view(batch_size, n_steps, -1) + v = self.fc(v) + + # apply dropout and residual connection + v = self.dropout(v) + v += residual + + # apply layer-norm + v = self.layer_norm(v) + + return v, attn_weights diff --git a/pypots/modules/transformer/auto_encoder.py b/pypots/modules/transformer/auto_encoder.py new file mode 100644 index 00000000..212bbc68 --- /dev/null +++ b/pypots/modules/transformer/auto_encoder.py @@ -0,0 +1,258 @@ +""" + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from .layers import EncoderLayer, DecoderLayer +from .pos_enc import PositionalEncoding + + +class Encoder(nn.Module): + """Transformer encoder. + + Parameters + ---------- + n_layers: + The number of layers in the encoder. + + n_steps: + The number of time steps in the input tensor. + + n_features: + The number of features in the input tensor. + + d_model: + The dimension of the module manipulation space. + The input tensor will be projected to a space with d_model dimensions. + + d_inner: + The dimension of the hidden layer in the feed-forward network. + + n_heads: + The number of heads in multi-head attention. + + d_k: + The dimension of the key and query tensor. + + d_v: + The dimension of the value tensor. + + dropout: + The dropout rate. + + attn_dropout: + The dropout rate for the attention map. + + """ + + def __init__( + self, + n_layers: int, + n_steps: int, + n_features: int, + d_model: int, + d_inner: int, + n_heads: int, + d_k: int, + d_v: int, + dropout: float, + attn_dropout: float, + ): + super().__init__() + + self.embedding = nn.Linear(n_features, d_model) + self.dropout = nn.Dropout(dropout) + self.position_enc = PositionalEncoding(d_model, n_positions=n_steps) + self.enc_layer_stack = nn.ModuleList( + [ + EncoderLayer( + d_model, + d_inner, + n_heads, + d_k, + d_v, + dropout, + attn_dropout, + ) + for _ in range(n_layers) + ] + ) + + def forward( + self, + x: torch.Tensor, + src_mask: Optional[torch.Tensor] = None, + return_attn_weights: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, list]]: + """Forward processing of the encoder. + + Parameters + ---------- + x: + Input tensor. + + src_mask: + Masking tensor for the attention map. The shape should be [batch_size, n_heads, n_steps, n_steps]. + + return_attn_weights: + Whether to return the attention map. + + Returns + ------- + enc_output: + Output tensor. 
+ + attn_weights_collector: + A list containing the attention map from each encoder layer. + + """ + x = self.embedding(x) + enc_output = self.dropout(self.position_enc(x)) + attn_weights_collector = [] + + for layer in self.enc_layer_stack: + enc_output, attn_weights = layer(enc_output, src_mask) + attn_weights_collector.append(attn_weights) + + if return_attn_weights: + return enc_output, attn_weights_collector + + return enc_output + + +class Decoder(nn.Module): + """Transformer decoder. + + Parameters + ---------- + n_layers: + The number of layers in the decoder. + + n_steps: + The number of time steps in the input tensor. + + n_features: + The number of features in the input tensor. + + d_model: + The dimension of the module manipulation space. + The input tensor will be projected to a space with d_model dimensions. + + d_inner: + The dimension of the hidden layer in the feed-forward network. + + n_heads: + The number of heads in multi-head attention. + + d_k: + The dimension of the key and query tensor. + + d_v: + The dimension of the value tensor. + + dropout: + The dropout rate. + + attn_dropout: + The dropout rate for the attention map. + + """ + + def __init__( + self, + n_layers: int, + n_steps: int, + n_features: int, + d_model: int, + d_inner: int, + n_heads: int, + d_k: int, + d_v: int, + dropout: float, + attn_dropout: float, + ): + super().__init__() + self.embedding = nn.Linear(n_features, d_model) + self.dropout = nn.Dropout(dropout) + self.position_enc = PositionalEncoding(d_model, n_positions=n_steps) + self.layer_stack = nn.ModuleList( + [ + DecoderLayer( + d_model, + d_inner, + n_heads, + d_k, + d_v, + dropout, + attn_dropout, + ) + for _ in range(n_layers) + ] + ) + + def forward( + self, + trg_seq: torch.Tensor, + enc_output: torch.Tensor, + trg_mask: Optional[torch.Tensor] = None, + src_mask: Optional[torch.Tensor] = None, + return_attn_weights: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, list, list]]: + """Forward processing of the decoder. + + Parameters + ---------- + trg_seq: + Input tensor. + + enc_output: + Output tensor from the encoder. + + trg_mask: + Masking tensor for the self-attention module. + + src_mask: + Masking tensor for the encoding attention module. + + return_attn_weights: + Whether to return the attention map. + + Returns + ------- + dec_output: + Output tensor. + + dec_slf_attn_collector: + A list containing the self-attention map from each decoder layer. + + dec_enc_attn_collector: + A list containing the encoding attention map from each decoder layer. 
+ + """ + trg_seq = self.embedding(trg_seq) + dec_output = self.dropout(self.position_enc(trg_seq)) + + dec_slf_attn_collector = [] + dec_enc_attn_collector = [] + + for layer in self.layer_stack: + dec_output, dec_slf_attn, dec_enc_attn = layer( + dec_output, + enc_output, + slf_attn_mask=trg_mask, + dec_enc_attn_mask=src_mask, + ) + dec_slf_attn_collector.append(dec_slf_attn) + dec_enc_attn_collector.append(dec_enc_attn) + + if return_attn_weights: + return dec_output, dec_slf_attn_collector, dec_enc_attn_collector + + return dec_output diff --git a/pypots/modules/transformer/layers.py b/pypots/modules/transformer/layers.py new file mode 100644 index 00000000..6fd1efd2 --- /dev/null +++ b/pypots/modules/transformer/layers.py @@ -0,0 +1,236 @@ +""" + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +from typing import Tuple, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .attention import MultiHeadAttention + + +class PositionWiseFeedForward(nn.Module): + """Position-wise feed forward network (FFN) in Transformer. + + Parameters + ---------- + d_in: + The dimension of the input tensor. + + d_hid: + The dimension of the hidden layer. + + dropout: + The dropout rate. + + """ + + def __init__(self, d_in: int, d_hid: int, dropout: float = 0.1): + super().__init__() + self.linear_1 = nn.Linear(d_in, d_hid) + self.linear_2 = nn.Linear(d_hid, d_in) + self.layer_norm = nn.LayerNorm(d_in, eps=1e-6) + self.dropout = nn.Dropout(dropout) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward processing of the position-wise feed forward network. + + Parameters + ---------- + x: + Input tensor. + + Returns + ------- + x: + Output tensor. + """ + # save the original input for the later residual connection + residual = x + # the 1st linear processing and ReLU non-linear projection + x = F.relu(self.linear_1(x)) + # the 2nd linear processing + x = self.linear_2(x) + # apply dropout + x = self.dropout(x) + # apply residual connection + x += residual + # apply layer-norm + x = self.layer_norm(x) + return x + + +class EncoderLayer(nn.Module): + """Transformer encoder layer. + + Parameters + ---------- + d_model: + The dimension of the input tensor. + + d_inner: + The dimension of the hidden layer. + + n_heads: + The number of heads in multi-head attention. + + d_k: + The dimension of the key and query tensor. + + d_v: + The dimension of the value tensor. + + dropout: + The dropout rate. + + attn_dropout: + The dropout rate for the attention map. + """ + + def __init__( + self, + d_model: int, + d_inner: int, + n_heads: int, + d_k: int, + d_v: int, + dropout: float = 0.1, + attn_dropout: float = 0.1, + ): + super().__init__() + self.slf_attn = MultiHeadAttention( + n_heads, d_model, d_k, d_v, dropout, attn_dropout + ) + self.pos_ffn = PositionWiseFeedForward(d_model, d_inner, dropout) + + def forward( + self, + enc_input: torch.Tensor, + src_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward processing of the encoder layer. + + Parameters + ---------- + enc_input: + Input tensor. + + src_mask: + Masking tensor for the attention map. The shape should be [batch_size, n_heads, n_steps, n_steps]. + + Returns + ------- + enc_output: + Output tensor. + + attn_weights: + The attention map. 
+ + """ + enc_output, attn_weights = self.slf_attn( + enc_input, + enc_input, + enc_input, + attn_mask=src_mask, + ) + enc_output = self.pos_ffn(enc_output) + return enc_output, attn_weights + + +class DecoderLayer(nn.Module): + """Transformer decoder layer. + + Parameters + ---------- + d_model: + The dimension of the input tensor. + + d_inner: + The dimension of the hidden layer. + + n_heads: + The number of heads in multi-head attention. + + d_k: + The dimension of the key and query tensor. + + d_v: + The dimension of the value tensor. + + dropout: + The dropout rate. + + attn_dropout: + The dropout rate for the attention map. + + """ + + def __init__( + self, + d_model: int, + d_inner: int, + n_heads: int, + d_k: int, + d_v: int, + dropout: float = 0.1, + attn_dropout: float = 0.1, + ): + super().__init__() + self.slf_attn = MultiHeadAttention( + n_heads, d_model, d_k, d_v, dropout, attn_dropout + ) + self.enc_attn = MultiHeadAttention( + n_heads, d_model, d_k, d_v, dropout, attn_dropout + ) + self.pos_ffn = PositionWiseFeedForward(d_model, d_inner, dropout) + + def forward( + self, + dec_input: torch.Tensor, + enc_output: torch.Tensor, + slf_attn_mask: Optional[torch.Tensor] = None, + dec_enc_attn_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Forward processing of the decoder layer. + + Parameters + ---------- + dec_input: + Input tensor. + + enc_output: + Output tensor from the encoder. + + slf_attn_mask: + Masking tensor for the self-attention module. + The shape should be [batch_size, n_heads, n_steps, n_steps]. + + dec_enc_attn_mask: + Masking tensor for the encoding attention module. + The shape should be [batch_size, n_heads, n_steps, n_steps]. + + Returns + ------- + dec_output: + Output tensor. + + dec_slf_attn: + The self-attention map. + + dec_enc_attn: + The encoding attention map. + + """ + dec_output, dec_slf_attn = self.slf_attn( + dec_input, dec_input, dec_input, attn_mask=slf_attn_mask + ) + dec_output, dec_enc_attn = self.enc_attn( + dec_output, enc_output, enc_output, attn_mask=dec_enc_attn_mask + ) + dec_output = self.pos_ffn(dec_output) + return dec_output, dec_slf_attn, dec_enc_attn diff --git a/pypots/modules/transformer/pos_enc.py b/pypots/modules/transformer/pos_enc.py new file mode 100644 index 00000000..1697bf96 --- /dev/null +++ b/pypots/modules/transformer/pos_enc.py @@ -0,0 +1,67 @@ +""" + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + + +import torch +import torch.nn as nn + + +class PositionalEncoding(nn.Module): + """The original positional-encoding module for Transformer. + + Parameters + ---------- + d_hid: + The dimension of the hidden layer. + + n_positions: + The max number of positions. + + """ + + def __init__(self, d_hid: int, n_positions: int = 1000): + super().__init__() + pe = torch.zeros(n_positions, d_hid, requires_grad=False).float() + position = torch.arange(0, n_positions).float().unsqueeze(1) + div_term = ( + torch.arange(0, d_hid, 2).float() + * -(torch.log(torch.tensor(10000)) / d_hid) + ).exp() + + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + + pe = pe.unsqueeze(0) + self.register_buffer("pos_table", pe) + + def forward(self, x: torch.Tensor, return_only_pos: bool = False) -> torch.Tensor: + """Forward processing of the positional encoding module. + + Parameters + ---------- + x: + Input tensor. + + return_only_pos: + Whether to return only the positional encoding. 
+ + Returns + ------- + If return_only_pos is True: + pos_enc: + The positional encoding. + else: + x_with_pos: + Output tensor, the input tensor with the positional encoding added. + """ + pos_enc = self.pos_table[:, : x.size(1)].clone().detach() + + if return_only_pos: + return pos_enc + + x_with_pos = x + pos_enc + return x_with_pos
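For reference, a minimal usage sketch of the new Encoder in pypots/modules/transformer/auto_encoder.py. The hyper-parameter values are arbitrary, and the import path is an assumption based on the file location shown in the diff (the actual public path depends on the package's __init__ exports); the call signature follows the code above.

import torch
from pypots.modules.transformer.auto_encoder import Encoder  # assumed import path

batch_size, n_steps, n_features = 8, 24, 10
encoder = Encoder(
    n_layers=2,
    n_steps=n_steps,
    n_features=n_features,
    d_model=64,
    d_inner=128,
    n_heads=4,
    d_k=16,
    d_v=16,
    dropout=0.1,
    attn_dropout=0.1,
)
x = torch.randn(batch_size, n_steps, n_features)  # toy time-series batch
enc_output, attn_weights = encoder(x, return_attn_weights=True)
# enc_output: [batch_size, n_steps, d_model]
# attn_weights: list of n_layers tensors, each [batch_size, n_heads, n_steps, n_steps]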
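Continuing the sketch above, the new Decoder consumes the encoder output together with a target sequence that still has n_features dimensions (it is re-embedded internally). Masks are optional; if supplied, they should be [batch_size, n_steps, n_steps] and are broadcast over the head axis inside MultiHeadAttention. Again, the values and import path are illustrative assumptions.

from pypots.modules.transformer.auto_encoder import Decoder  # assumed import path

decoder = Decoder(
    n_layers=2,
    n_steps=n_steps,
    n_features=n_features,
    d_model=64,
    d_inner=128,
    n_heads=4,
    d_k=16,
    d_v=16,
    dropout=0.1,
    attn_dropout=0.1,
)
trg = torch.randn(batch_size, n_steps, n_features)  # toy target sequence
# no masks here; if needed, pass trg_mask/src_mask of shape [batch_size, n_steps, n_steps]
dec_output = decoder(trg, enc_output)
# dec_output: [batch_size, n_steps, d_model]
# with return_attn_weights=True the decoder also returns the per-layer
# self-attention and encoder-decoder attention maps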
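The sinusoidal table built by PositionalEncoding is registered as a buffer, so it moves with the module across devices and is never trained. A small sketch of both call modes, with arbitrary dimensions and the same assumed import convention as above:

import torch
from pypots.modules.transformer.pos_enc import PositionalEncoding  # assumed import path

pos_encoder = PositionalEncoding(d_hid=64, n_positions=24)
x = torch.randn(8, 24, 64)                        # [batch_size, n_steps, d_hid]
x_with_pos = pos_encoder(x)                       # x plus the sinusoidal encoding, same shape as x
pos_only = pos_encoder(x, return_only_pos=True)
# pos_only: [1, n_steps, d_hid], the encoding table truncated to x's sequence length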