Merge pull request #251 from WenjieDu/code_refactor
Code refactoring
WenjieDu authored Dec 1, 2023
2 parents 73d8ada + 5936e30 commit 191e777
Showing 22 changed files with 1,069 additions and 1,029 deletions.
2 changes: 1 addition & 1 deletion pypots/__init__.py
@@ -22,7 +22,7 @@
 #
 # Dev branch marker is: 'X.Y.dev' or 'X.Y.devN' where N is an integer.
 # 'X.Y.dev0' is the canonical version of 'X.Y.dev'
-__version__ = "0.2"
+__version__ = "0.2.1"
 
 
 from . import imputation, classification, clustering, forecasting, optim, data, utils
65 changes: 32 additions & 33 deletions pypots/classification/brits/modules/core.py
@@ -89,37 +89,36 @@ def forward(self, inputs: dict, training: bool = True) -> dict:
         ret_b = self._reverse(self.rits_b(inputs, "backward"))
 
         classification_pred = (ret_f["prediction"] + ret_b["prediction"]) / 2
-        if not training:
-            # if not in training mode, return the classification result only
-            return {"classification_pred": classification_pred}
-
-        ret_f["classification_loss"] = F.nll_loss(
-            torch.log(ret_f["prediction"]), inputs["label"]
-        )
-        ret_b["classification_loss"] = F.nll_loss(
-            torch.log(ret_b["prediction"]), inputs["label"]
-        )
-        consistency_loss = self._get_consistency_loss(
-            ret_f["imputed_data"], ret_b["imputed_data"]
-        )
-        classification_loss = (
-            ret_f["classification_loss"] + ret_b["classification_loss"]
-        ) / 2
-        reconstruction_loss = (
-            ret_f["reconstruction_loss"] + ret_b["reconstruction_loss"]
-        ) / 2
-
-        loss = (
-            consistency_loss
-            + reconstruction_loss * self.reconstruction_weight
-            + classification_loss * self.classification_weight
-        )
-
-        results = {
-            "classification_pred": classification_pred,
-            "consistency_loss": consistency_loss,
-            "classification_loss": classification_loss,
-            "reconstruction_loss": reconstruction_loss,
-            "loss": loss,
-        }
+        results = {"classification_pred": classification_pred}
+
+        # if in training mode, return results with losses
+        if training:
+            ret_f["classification_loss"] = F.nll_loss(
+                torch.log(ret_f["prediction"]), inputs["label"]
+            )
+            ret_b["classification_loss"] = F.nll_loss(
+                torch.log(ret_b["prediction"]), inputs["label"]
+            )
+            consistency_loss = self._get_consistency_loss(
+                ret_f["imputed_data"], ret_b["imputed_data"]
+            )
+            classification_loss = (
+                ret_f["classification_loss"] + ret_b["classification_loss"]
+            ) / 2
+            reconstruction_loss = (
+                ret_f["reconstruction_loss"] + ret_b["reconstruction_loss"]
+            ) / 2
+
+            results["consistency_loss"] = consistency_loss
+            results["classification_loss"] = classification_loss
+            results["reconstruction_loss"] = reconstruction_loss
+
+            # `loss` is always the item for backward propagating to update the model
+            loss = (
+                consistency_loss
+                + reconstruction_loss * self.reconstruction_weight
+                + classification_loss * self.classification_weight
+            )
+            results["loss"] = loss
 
         return results
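This results-dict pattern — the prediction always present, loss entries attached only under `training=True` — is the contract the refactor standardizes across models. Below is a minimal runnable sketch of that contract from the caller's side, using a toy stand-in module; all names here are illustrative, not PyPOTS APIs.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyClassifier(nn.Module):
    """Toy stand-in following the refactored forward() contract."""

    def __init__(self, d_in: int, n_classes: int):
        super().__init__()
        self.fc = nn.Linear(d_in, n_classes)

    def forward(self, inputs: dict, training: bool = True) -> dict:
        classification_pred = torch.softmax(self.fc(inputs["X"]), dim=1)
        results = {"classification_pred": classification_pred}
        if training:  # losses are attached only in training mode
            results["loss"] = F.nll_loss(
                torch.log(classification_pred), inputs["label"]
            )
        return results

model = TinyClassifier(d_in=8, n_classes=3)
batch = {"X": torch.randn(4, 8), "label": torch.randint(0, 3, (4,))}

train_out = model(batch, training=True)
train_out["loss"].backward()          # `loss` is the item used for backprop

with torch.no_grad():                 # inference: prediction only, no loss math
    infer_out = model(batch, training=False)
assert "loss" not in infer_out
```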
20 changes: 8 additions & 12 deletions pypots/classification/grud/modules/core.py
@@ -91,18 +91,14 @@ def forward(self, inputs: dict, training: bool = True) -> dict:

         logits = self.classifier(hidden_state)
         classification_pred = torch.softmax(logits, dim=1)
+        results = {"classification_pred": classification_pred}
 
-        if not training:
-            # if not in training mode, return the classification result only
-            return {"classification_pred": classification_pred}
-
-        torch.log(classification_pred)
-        classification_loss = F.nll_loss(
-            torch.log(classification_pred), inputs["label"]
-        )
-
-        results = {
-            "classification_pred": classification_pred,
-            "loss": classification_loss,
-        }
+        # if in training mode, return results with losses
+        if training:
+            torch.log(classification_pred)
+            classification_loss = F.nll_loss(
+                torch.log(classification_pred), inputs["label"]
+            )
+            results["loss"] = classification_loss
+
         return results
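`F.nll_loss` applied to `torch.log` of a softmax is mathematically the same as cross-entropy on the raw logits (note the bare `torch.log(classification_pred)` statement retained above discards its result, so it has no effect either way). A quick self-contained check of that equivalence, with illustrative tensors:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 3)                 # illustrative logits and labels
labels = torch.randint(0, 3, (4,))

pred = torch.softmax(logits, dim=1)
loss_via_log = F.nll_loss(torch.log(pred), labels)   # the pattern used above
loss_ce = F.cross_entropy(logits, labels)            # numerically safer equivalent

assert torch.allclose(loss_via_log, loss_ce, atol=1e-6)
```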
19 changes: 7 additions & 12 deletions pypots/classification/raindrop/modules/core.py
@@ -262,18 +262,13 @@ def classify(self, inputs: dict) -> torch.Tensor:

     def forward(self, inputs, training=True):
         classification_pred = self.classify(inputs)
-        if not training:
-            # if not in training mode, return the classification result only
-            return {"classification_pred": classification_pred}
+        results = {"classification_pred": classification_pred}
 
-        classification_loss = F.nll_loss(
-            torch.log(classification_pred), inputs["label"]
-        )
-
-        results = {
-            "prediction": classification_pred,
-            "loss": classification_loss
-            # 'distance': distance,
-        }
+        # if in training mode, return results with losses
+        if training:
+            classification_loss = F.nll_loss(
+                torch.log(classification_pred), inputs["label"]
+            )
+            results["loss"] = classification_loss
 
         return results
4 changes: 2 additions & 2 deletions pypots/clustering/crli/model.py
@@ -270,7 +270,7 @@ def _train_model(
             with torch.no_grad():
                 for idx, data in enumerate(val_loader):
                     inputs = self._assemble_input_for_validating(data)
-                    results = self.model.forward(inputs, return_loss=True)
+                    results = self.model.forward(inputs, training=True)
                     epoch_val_loss_G_collector.append(
                         results["generation_loss"].sum().item()
                     )
@@ -424,7 +424,7 @@ def predict(
         with torch.no_grad():
             for idx, data in enumerate(test_loader):
                 inputs = self._assemble_input_for_testing(data)
-                inputs = self.model.forward(inputs, return_loss=False)
+                inputs = self.model.forward(inputs, training=False)
                 clustering_latent_collector.append(inputs["fcn_latent"])
                 if return_latent_vars:
                     imputation_collector.append(inputs["imputation_latent"])
5 changes: 3 additions & 2 deletions pypots/clustering/crli/modules/core.py
@@ -58,7 +58,7 @@ def forward(
         self,
         inputs: dict,
         training_object: str = "generator",
-        return_loss: bool = True,
+        training: bool = True,
     ) -> dict:
         X = inputs["X"]
         missing_mask = inputs["missing_mask"]
@@ -76,7 +76,7 @@
inputs["fcn_latent"] = fcn_latent

# return results directly, skip loss calculation to reduce inference time
if not return_loss:
if not training:
return inputs

if training_object == "discriminator":
@@ -106,4 +106,5 @@
         l_kmeans = torch.trace(HTH) - torch.trace(FTHTHF)  # k-means loss
         loss_gene = l_G + l_pre + l_rec + l_kmeans * self.lambda_kmeans
         losses["generation_loss"] = loss_gene
+
         return losses
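For reference, the `l_kmeans` term above is the spectral relaxation of the k-means objective, Tr(HᵀH) − Tr(FᵀHᵀHF). A minimal sketch of that term under illustrative shapes — `H` and `F_mat` here are stand-ins (with an orthonormal surrogate for the cluster-indicator matrix), not the module's own tensors:

```python
import torch

n, d, k = 32, 16, 4                      # samples, latent dim, clusters (illustrative)
H = torch.randn(d, n)                    # latent matrix, one column per sample
F_mat, _ = torch.linalg.qr(torch.randn(n, k))   # orthonormal indicator surrogate

HTH = H.T @ H                            # n x n Gram matrix of the latents
FTHTHF = F_mat.T @ HTH @ F_mat           # Gram matrix projected onto the k indicators
l_kmeans = torch.trace(HTH) - torch.trace(FTHTHF)   # k-means loss, always >= 0

print(l_kmeans.item())                   # shrinks as F aligns with HTH's top eigenvectors
```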
139 changes: 69 additions & 70 deletions pypots/clustering/vader/modules/core.py
@@ -173,18 +173,15 @@ def forward(
             stddev_tilde,
         ) = self.get_results(X, missing_mask)
 
-        if not training and not pretrain:
-            results = {
-                "mu_tilde": mu_tilde,
-                "stddev_tilde": stddev_tilde,
-                "mu": mu_c,
-                "var": var_c,
-                "phi": phi_c,
-                "z": z,
-                "imputation_latent": X_reconstructed,
-            }
-            # if only run clustering, then no need to calculate loss
-            return results
+        results = {
+            "mu_tilde": mu_tilde,
+            "stddev_tilde": stddev_tilde,
+            "mu": mu_c,
+            "var": var_c,
+            "phi": phi_c,
+            "z": z,
+            "imputation_latent": X_reconstructed,
+        }
 
         # calculate the reconstruction loss
         unscaled_reconstruction_loss = cal_mse(X_reconstructed, X, missing_mask)
@@ -194,66 +194,68 @@
             * self.d_input
             / missing_mask.sum()
         )
 
         if pretrain:
-            results = {"loss": reconstruction_loss, "z": z}
+            results["loss"] = reconstruction_loss
             return results
 
-        # calculate the latent loss
-        var_tilde = torch.exp(stddev_tilde)
-        stddev_c = torch.log(var_c + self.eps)
-        log_2pi = torch.log(torch.tensor([2 * torch.pi], device=device))
-        log_phi_c = torch.log(phi_c + self.eps)
-
-        batch_size = z.shape[0]
-
-        ii, jj = torch.meshgrid(
-            torch.arange(self.n_clusters, dtype=torch.int64, device=device),
-            torch.arange(batch_size, dtype=torch.int64, device=device),
-            indexing="ij",
-        )
-        ii = ii.flatten()
-        jj = jj.flatten()
-
-        lsc_b = stddev_c.index_select(dim=0, index=ii)
-        mc_b = mu_c.index_select(dim=0, index=ii)
-        sc_b = var_c.index_select(dim=0, index=ii)
-        z_b = z.index_select(dim=0, index=jj)
-        log_pdf_z = -0.5 * (lsc_b + log_2pi + torch.square(z_b - mc_b) / sc_b)
-        log_pdf_z = log_pdf_z.reshape([batch_size, self.n_clusters, self.d_mu_stddev])
-
-        log_p = log_phi_c + log_pdf_z.sum(dim=2)
-        lse_p = log_p.logsumexp(dim=1, keepdim=True)
-        log_gamma_c = log_p - lse_p
-        gamma_c = torch.exp(log_gamma_c)
-
-        term1 = torch.log(var_c + self.eps)
-        st_b = var_tilde.index_select(dim=0, index=jj)
-        sc_b = var_c.index_select(dim=0, index=ii)
-        term2 = torch.reshape(
-            st_b / (sc_b + self.eps), [batch_size, self.n_clusters, self.d_mu_stddev]
-        )
-        mt_b = mu_tilde.index_select(dim=0, index=jj)
-        mc_b = mu_c.index_select(dim=0, index=ii)
-        term3 = torch.reshape(
-            torch.square(mt_b - mc_b) / (sc_b + self.eps),
-            [batch_size, self.n_clusters, self.d_mu_stddev],
-        )
-
-        latent_loss1 = 0.5 * torch.sum(
-            gamma_c * torch.sum(term1 + term2 + term3, dim=2), dim=1
-        )
-        latent_loss2 = -torch.sum(gamma_c * (log_phi_c - log_gamma_c), dim=1)
-        latent_loss3 = -0.5 * torch.sum(1 + stddev_tilde, dim=1)
-
-        latent_loss1 = latent_loss1.mean()
-        latent_loss2 = latent_loss2.mean()
-        latent_loss3 = latent_loss3.mean()
-        latent_loss = latent_loss1 + latent_loss2 + latent_loss3
-
-        results = {
-            "loss": reconstruction_loss + self.alpha * latent_loss,
-            "z": z,
-            "imputation_latent": X_reconstructed,
-        }
+        # if in training mode, return results with losses
+        if training:
+            # calculate the latent loss for model training
+            var_tilde = torch.exp(stddev_tilde)
+            stddev_c = torch.log(var_c + self.eps)
+            log_2pi = torch.log(torch.tensor([2 * torch.pi], device=device))
+            log_phi_c = torch.log(phi_c + self.eps)
+
+            batch_size = z.shape[0]
+
+            ii, jj = torch.meshgrid(
+                torch.arange(self.n_clusters, dtype=torch.int64, device=device),
+                torch.arange(batch_size, dtype=torch.int64, device=device),
+                indexing="ij",
+            )
+            ii = ii.flatten()
+            jj = jj.flatten()
+
+            lsc_b = stddev_c.index_select(dim=0, index=ii)
+            mc_b = mu_c.index_select(dim=0, index=ii)
+            sc_b = var_c.index_select(dim=0, index=ii)
+            z_b = z.index_select(dim=0, index=jj)
+            log_pdf_z = -0.5 * (lsc_b + log_2pi + torch.square(z_b - mc_b) / sc_b)
+            log_pdf_z = log_pdf_z.reshape(
+                [batch_size, self.n_clusters, self.d_mu_stddev]
+            )
+
+            log_p = log_phi_c + log_pdf_z.sum(dim=2)
+            lse_p = log_p.logsumexp(dim=1, keepdim=True)
+            log_gamma_c = log_p - lse_p
+            gamma_c = torch.exp(log_gamma_c)
+
+            term1 = torch.log(var_c + self.eps)
+            st_b = var_tilde.index_select(dim=0, index=jj)
+            sc_b = var_c.index_select(dim=0, index=ii)
+            term2 = torch.reshape(
+                st_b / (sc_b + self.eps),
+                [batch_size, self.n_clusters, self.d_mu_stddev],
+            )
+            mt_b = mu_tilde.index_select(dim=0, index=jj)
+            mc_b = mu_c.index_select(dim=0, index=ii)
+            term3 = torch.reshape(
+                torch.square(mt_b - mc_b) / (sc_b + self.eps),
+                [batch_size, self.n_clusters, self.d_mu_stddev],
+            )
+
+            latent_loss1 = 0.5 * torch.sum(
+                gamma_c * torch.sum(term1 + term2 + term3, dim=2), dim=1
+            )
+            latent_loss2 = -torch.sum(gamma_c * (log_phi_c - log_gamma_c), dim=1)
+            latent_loss3 = -0.5 * torch.sum(1 + stddev_tilde, dim=1)
+
+            latent_loss1 = latent_loss1.mean()
+            latent_loss2 = latent_loss2.mean()
+            latent_loss3 = latent_loss3.mean()
+            latent_loss = latent_loss1 + latent_loss2 + latent_loss3
+
+            results["loss"] = reconstruction_loss + self.alpha * latent_loss
 
         return results
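The latent-loss block above evaluates every (sample, cluster) pair in one batched expression by flattening a meshgrid of indices and gathering rows with `index_select`. A minimal standalone sketch of that pairing trick, reduced to squared distances and checked against the naive computation — all shapes and names here are illustrative, not the module's own:

```python
import torch

batch_size, n_clusters, d = 4, 3, 2      # illustrative sizes
z = torch.randn(batch_size, d)           # one latent code per sample
mu_c = torch.randn(n_clusters, d)        # one mean per cluster

# batch-major pair indices: pair p corresponds to (sample jj[p], cluster ii[p])
jj, ii = torch.meshgrid(
    torch.arange(batch_size), torch.arange(n_clusters), indexing="ij"
)
jj, ii = jj.flatten(), ii.flatten()      # length batch_size * n_clusters

z_b = z.index_select(0, jj)              # each sample repeated once per cluster
mc_b = mu_c.index_select(0, ii)          # the matching cluster mean per pair

# squared distance of every sample to every cluster mean, without a Python loop
sq_dist = torch.square(z_b - mc_b).sum(dim=1).reshape(batch_size, n_clusters)

assert torch.allclose(sq_dist, torch.cdist(z, mu_c) ** 2, atol=1e-5)
```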
34 changes: 15 additions & 19 deletions pypots/imputation/brits/modules/core.py
@@ -316,27 +316,23 @@ def forward(self, inputs: dict, training: bool = True) -> dict:

         imputed_data = (ret_f["imputed_data"] + ret_b["imputed_data"]) / 2
 
-        if not training:
-            # if not in training mode, return the classification result only
-            return {
-                "imputed_data": imputed_data,
-            }
-
-        consistency_loss = self._get_consistency_loss(
-            ret_f["imputed_data"], ret_b["imputed_data"]
-        )
-
-        # `loss` is always the item for backward propagating to update the model
-        loss = (
-            consistency_loss
-            + ret_f["reconstruction_loss"]
-            + ret_b["reconstruction_loss"]
-        )
-
-        results = {
-            "imputed_data": imputed_data,
-            "consistency_loss": consistency_loss,
-            "loss": loss,  # will be used for backward propagating to update the model
-        }
+        results = {
+            "imputed_data": imputed_data,
+        }
+
+        # if in training mode, return results with losses
+        if training:
+            consistency_loss = self._get_consistency_loss(
+                ret_f["imputed_data"], ret_b["imputed_data"]
+            )
+
+            # `loss` is always the item for backward propagating to update the model
+            loss = (
+                consistency_loss
+                + ret_f["reconstruction_loss"]
+                + ret_b["reconstruction_loss"]
+            )
+            results["consistency_loss"] = consistency_loss
+            results["loss"] = loss
 
         return results
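The consistency loss here penalizes disagreement between the forward and backward imputations. A minimal sketch of a plausible `_get_consistency_loss` — the exact form and the 0.1 weight are assumptions borrowed from common BRITS implementations, not read from this diff:

```python
import torch

def get_consistency_loss(pred_f: torch.Tensor, pred_b: torch.Tensor) -> torch.Tensor:
    # mean absolute disagreement between the two passes; the 0.1 weight is an
    # assumption, not confirmed by this diff
    return torch.abs(pred_f - pred_b).mean() * 1e-1

pred_f = torch.randn(2, 5, 3)   # illustrative forward-direction imputations
pred_b = torch.randn(2, 5, 3)   # illustrative backward-direction imputations
print(get_consistency_loss(pred_f, pred_b))
```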