diff --git a/src/sood/algorithms/ce_vae.py b/src/sood/algorithms/ce_vae.py index d5ee1f3..0be2c66 100644 --- a/src/sood/algorithms/ce_vae.py +++ b/src/sood/algorithms/ce_vae.py @@ -3,6 +3,7 @@ import os from math import ceil from pathlib import Path +import json import numpy as np import torch @@ -23,25 +24,25 @@ class ceVAE: - # TODO store config as file/wandb def __init__( self, input_shape, - lr=1e-4, - n_epochs=20, - z_dim=512, - model_feature_map_sizes=(16, 64, 256, 1024), - use_geco=False, - beta=0.01, - ce_factor=0.5, - score_mode="combi", - load_path=None, - log_dir=None, - print_every_iter=100, - data_dir=None, - dataset=None + lr, + n_epochs, + z_dim, + model_feature_map_sizes, + use_geco, + beta, + ce_factor, + score_mode, + load_path, + log_dir, + print_every_iter, + data_dir, + dataset, + train_loader, + val_loader ): - self.score_mode = score_mode self.ce_factor = ce_factor self.beta = beta @@ -53,16 +54,20 @@ def __init__( self.input_shape = input_shape self.data_dir = data_dir self.dataset = dataset + self.train_loader = train_loader + self.val_loader = val_loader self.log_dir = log_dir - folder_format = "%Y%m%d-%H%M%S" + folder_time_format = "%Y%m%d-%H%M%S" self.work_dir = Path( - log_dir) / Path(f"{datetime.datetime.now().strftime(folder_format)}_ce_vae") + log_dir) / Path(f"{datetime.datetime.now().strftime(folder_time_format)}_cevae") if not os.path.exists(self.work_dir): os.makedirs(self.work_dir) + os.makedirs(self.work_dir / Path("config")) cuda_available = torch.cuda.is_available() - self.device = torch.device("cuda" if cuda_available else "cpu") + device = torch.device("cuda" if cuda_available else "cpu") + self.device = device self.model = VAE( input_size=input_shape[1:], z_dim=z_dim, fmap_sizes=model_feature_map_sizes).to(self.device) @@ -72,56 +77,41 @@ def __init__( self.vae_loss_ema = 1 self.theta = 1 + self.config = { + "score_mode": score_mode, + "ce_factor": ce_factor, + "beta": beta, + "print_every_iter": print_every_iter, + "n_epochs": n_epochs, + "batch_size": input_shape[0], + "z_dim": z_dim, + "use_geco": use_geco, + "input_shape": input_shape, + "data_dir": data_dir, + "dataset": dataset, + "log_dir": log_dir, + "load_path": load_path, + "cuda_available": cuda_available, + "lr": lr + } + with open(self.work_dir / Path("config") / Path("config.json"), "w") as cf: + json.dump(self.config, cf, skipkeys=True, indent=2) + if load_path is not None: load_model(self.model, os.path.join(load_path, "vae_final.pth")) def train(self): - if self.dataset == "CuratedImageParameterDataset": - train_loader = get_dataset( - base_dir=self.data_dir, - num_processes=16, - pin_memory=False, - batch_size=self.batch_size, - mode="train", - target_size=self.input_shape[2], - ) - val_loader = get_dataset( - base_dir=self.data_dir, - num_processes=8, - pin_memory=False, - batch_size=self.batch_size, - mode="val", - target_size=self.input_shape[2], - ) - elif self.dataset == "SDOMLDatasetV1": - # due to a bug on Mac, num processes needs to be 0: https://github.com/pyg-team/pytorch_geometric/issues/366 - train_loader = get_sdo_ml_v1_dataset( - base_dir=self.data_dir, - num_processes=0, - pin_memory=False, - batch_size=self.batch_size, - mode="train", - target_size=self.input_shape[2], - ) - val_loader = get_sdo_ml_v1_dataset( - base_dir=self.data_dir, - num_processes=0, - pin_memory=False, - batch_size=self.batch_size, - mode="val", - target_size=self.input_shape[2], - ) - - wandb.init(project='sdo-sood', entity='mariusgiger') + wandb.init(project='sdo-sood', + entity='mariusgiger', + config=self.config) wandb.watch(self.model, log_freq=100) for epoch in range(self.n_epochs): - self.model.train() train_loss = 0 print("Start epoch") - data_loader_ = tqdm(enumerate(train_loader)) + data_loader_ = tqdm(enumerate(self.train_loader)) for batch_idx, data in data_loader_: data = data[0] # only inputs no labels self.optimizer.zero_grad() @@ -132,13 +122,11 @@ def train(self): loss_vae = 0 if self.ce_factor < 1: x_rec_vae, z_dist, = self.model(inpt) - kl_loss = 0 if self.beta > 0: kl_loss = self.kl_loss_fn(z_dist) * self.beta rec_loss_vae = self.rec_loss_fn(x_rec_vae, inpt) loss_vae = kl_loss + rec_loss_vae * self.theta - # print(loss_vae) # CE Part loss_ce = 0 if self.ce_factor > 0: @@ -156,7 +144,6 @@ def train(self): x_rec_ce, _ = self.model(inpt_noisy) rec_loss_ce = self.rec_loss_fn(x_rec_ce, inpt) loss_ce = rec_loss_ce - # print(loss_ce) loss = (1.0 - self.ce_factor) * \ loss_vae + self.ce_factor * loss_ce @@ -171,65 +158,70 @@ def train(self): self.theta, self.vae_loss_ema, g_goal, g_lr, speedup=2) if torch.isnan(loss): - print("A wild NaN occurred") + print("Loss is NaN") continue loss.backward() self.optimizer.step() train_loss += loss.item() - status_str = ( - f"Train Epoch: {epoch} [{batch_idx}/{len(train_loader)} " - f" ({100.0 * batch_idx / len(train_loader):.0f}%)] Loss: " - f"{loss.item() / len(inpt):.6f}" - ) - data_loader_.set_description_str(status_str) - if batch_idx % self.print_every_iter == 0: - cnt = epoch * len(train_loader) + batch_idx - losses = {} + status_str = ( + f"Train Epoch: {epoch} [{batch_idx}/{len(self.train_loader)} " + f" ({100.0 * batch_idx / len(self.train_loader):.0f}%)] Loss: " + f"{loss.item() / len(inpt):.6f}" + ) + data_loader_.set_description_str(status_str) + cnt = epoch * len(self.train_loader) + batch_idx + log_dict = {} # VAE + image_table = wandb.Table() if self.ce_factor < 1: - # image_table = wandb.Table() - # image_table.add_column("Input-VAE", images_t) - # my_table.add_column("class_prediction", predictions_t) - # wandb.log({"mnist_predictions": my_table}) - - save_image_grid(inpt, name="Input-VAE", save_dir=self.work_dir, - image_args={"normalize": False}) - save_image_grid( - x_rec_vae, name="Output-VAE", save_dir=self.work_dir, image_args={"normalize": True}) + input_vae_path = save_image_grid(inpt, name="Input-VAE", save_dir=self.work_dir / Path("save/imgs"), + image_args={"normalize": False}, n_iter=cnt) + image_table.add_column( + "Input-VAE", wandb.Image(input_vae_path)) + output_vae_path = save_image_grid( + x_rec_vae, name="Output-VAE", save_dir=self.work_dir / Path("save/imgs"), image_args={"normalize": True}, n_iter=cnt) + image_table.add_column( + "Output-VAE", wandb.Image(output_vae_path)) if self.beta > 0: - losses["Kl-loss"] = torch.mean(kl_loss).item() + log_dict["Kl-loss"] = torch.mean(kl_loss).item() - losses["Rec-loss"] = torch.mean(rec_loss_vae).item() - losses["VAE-Train-loss"] = loss_vae.item() + log_dict["Rec-loss"] = torch.mean(rec_loss_vae).item() + log_dict["VAE-train-loss"] = loss_vae.item() # CE if self.ce_factor > 0: - save_image_grid( - inpt_noisy, name="Input-CE", save_dir=self.work_dir, image_args={"normalize": False}) - save_image_grid( - x_rec_ce, name="Output-CE", save_dir=self.work_dir, image_args={"normalize": True}) + input_ce_path = save_image_grid( + inpt_noisy, name="Input-CE", save_dir=self.work_dir / Path("save/imgs"), image_args={"normalize": False}) + image_table.add_column( + "Input-CE", wandb.Image(input_ce_path)) + output_ce_path = save_image_grid( + x_rec_ce, name="Output-CE", save_dir=self.work_dir / Path("save/imgs"), image_args={"normalize": True}) + image_table.add_column( + "Output-CE", wandb.Image(output_ce_path)) - losses["CE-Train-loss"] = loss_ce.item() + log_dict["CE-train-loss"] = loss_ce.item() # TODO why normalize by the length of the input (batch length)? - losses["loss"] = loss.item() / len(inpt) - losses["epoch"] = epoch - losses["counter"] = cnt - wandb.log(losses) + log_dict["CEVAE-train-loss"] = loss.item() / len(inpt) + log_dict["epoch"] = epoch + log_dict["counter"] = cnt + log_dict["images"] = image_table + + wandb.log(log_dict) print( - f"====> Epoch: {epoch} Average loss: {train_loss / len(train_loader):.4f}") + f"====> Epoch: {epoch} Average loss: {train_loss / len(self.train_loader):.4f}") self.model.eval() val_loss = 0 with torch.no_grad(): - data_loader_ = tqdm(enumerate(val_loader)) + data_loader_ = tqdm(enumerate(self.val_loader)) for i, data in data_loader_: data = data[0] inpt = data.to(self.device) @@ -244,13 +236,14 @@ def train(self): val_loss += loss.item() - wandb.log({"Val-Loss": val_loss / len(val_loader), - "counter": (epoch + 1) * len(train_loader)}) + wandb.log({"Val-Loss": val_loss / len(self.val_loader), + "counter": (epoch + 1) * len(self.train_loader)}) print( - f"====> Epoch: {epoch} Validation loss: {val_loss / len(val_loader):.4f}") + f"====> Epoch: {epoch} Validation loss: {val_loss / len(self.val_loader):.4f}") - save_model(self.model, "vae_final", model_dir=self.work_dir) + save_model(self.model, "vae_final", + model_dir=self.work_dir / Path("checkpoint")) # https://towardsdatascience.com/variational-autoencoder-demystified-with-pytorch-implementation-3a06bee395ed def generate(self, n_samples=16, mu=None, std=None): @@ -265,12 +258,11 @@ def generate(self, n_samples=16, mu=None, std=None): with torch.no_grad(): pred = self.model.decode(z.to(self.device)).cpu() - file_name = Path(self.log_dir) / \ + file_name = Path(self.work_dir) / Path("generated") / \ (datetime.datetime.now().isoformat() + "_generated.jpeg") save_image(pred, file_name, normalize=True) def score_sample(self, data): - orig_shape = data.shape to_transforms = torch.nn.Upsample( (self.input_shape[2], self.input_shape[3]), mode="bilinear") @@ -296,7 +288,6 @@ def score_sample(self, data): return np.max(slice_scores) def score_pixels(self, data, index, file_name): - orig_shape = data.shape to_transforms = torch.nn.Upsample( (self.input_shape[2], self.input_shape[3]), mode="bilinear") @@ -336,13 +327,11 @@ def __err_fn(x): normalize(loss_grad_kl), kernel_size=8) * rec elif self.score_mode == "rec": - rec = torch.pow((x_rec - inpt), 2).detach().cpu() rec = torch.mean(rec, dim=1, keepdim=True) pixel_scores = rec elif self.score_mode == "grad": - def __err_fn(x): x_r, z_d = self.model(x, sample=False) kl_loss_ = self.kl_loss_fn(z_d) @@ -362,11 +351,11 @@ def __err_fn(x): pixel_scores = smooth_tensor( normalize(loss_grad_kl), kernel_size=8) - save_image_grid(inpt, name="Input", image_args={ + save_image_grid(inpt, name="Input", save_dir=self.work_dir, image_args={ "normalize": True}, n_iter=index) - save_image_grid(x_rec, name="Output", image_args={ + save_image_grid(x_rec, name="Output", save_dir=self.work_dir, image_args={ "normalize": True}, n_iter=index) - save_image_grid(pixel_scores, name="Scores", image_args={ + save_image_grid(pixel_scores, name="Scores", save_dir=self.work_dir, image_args={ "normalize": True}, n_iter=index) target_tensor[i * self.batch_size: ( @@ -378,11 +367,7 @@ def __err_fn(x): return target_tensor.detach().numpy() - @ staticmethod - def load_trained_model(model, tx, path): - tx.elog.load_model_static(model=model, model_file=path) - - @ staticmethod + @staticmethod def kl_loss_fn(z_post, sum_samples=True, correct=False): z_prior = dist.Normal(0, 1.0) kl_div = dist.kl_divergence(z_post, z_prior) @@ -395,7 +380,7 @@ def kl_loss_fn(z_post, sum_samples=True, correct=False): else: return kl_div - @ staticmethod + @staticmethod def rec_loss_fn(recon_x, x, sum_samples=True, correct=False): if correct: x_dist = dist.Laplace(recon_x, 1.0) @@ -409,7 +394,7 @@ def rec_loss_fn(recon_x, x, sum_samples=True, correct=False): else: return -log_p_x_z - @ staticmethod + @staticmethod def get_inpt_grad(model, inpt, err_fn): model.zero_grad() inpt = inpt.detach() @@ -424,7 +409,7 @@ def get_inpt_grad(model, inpt, err_fn): return torch.abs(grad.detach()) - @ staticmethod + @staticmethod def geco_beta_update(beta, error_ema, goal, step_size, min_clamp=1e-10, max_clamp=1e4, speedup=None): constraint = (error_ema - goal).detach() if speedup is not None and constraint > 0.0: @@ -465,9 +450,47 @@ def main( data_dir=None, dataset="CuratedImageParameterDataset" ): - input_shape = (batch_size, 1, target_size, target_size) + train_loader = None + val_loader = None + if run == "train": + if dataset == "CuratedImageParameterDataset": + train_loader = get_dataset( + base_dir=data_dir, + num_processes=16, + pin_memory=False, + batch_size=batch_size, + mode="train", + target_size=input_shape[2], + ) + val_loader = get_dataset( + base_dir=data_dir, + num_processes=8, + pin_memory=False, + batch_size=batch_size, + mode="val", + target_size=input_shape[2], + ) + elif dataset == "SDOMLDatasetV1": + # due to a bug on Mac, num processes needs to be 0: https://github.com/pyg-team/pytorch_geometric/issues/366 + train_loader = get_sdo_ml_v1_dataset( + base_dir=data_dir, + num_processes=0, + pin_memory=False, + batch_size=batch_size, + mode="train", + target_size=input_shape[2], + ) + val_loader = get_sdo_ml_v1_dataset( + base_dir=data_dir, + num_processes=0, + pin_memory=False, + batch_size=batch_size, + mode="val", + target_size=input_shape[2], + ) + cevae_algo = ceVAE( input_shape, log_dir=log_dir, @@ -482,7 +505,9 @@ def main( print_every_iter=print_every_iter, load_path=load_path, data_dir=data_dir, - dataset=dataset + dataset=dataset, + train_loader=train_loader, + val_loader=val_loader ) if run == "train": @@ -496,9 +521,10 @@ def main( pred_dir = os.path.join(cevae_algo.work_dir, "predictions") os.makedirs(pred_dir, exist_ok=True) elif pred_dir is None and log_dir is None: - print("Please either give a log/ output dir or a prediction dir") + print("Please either provide a log/output dir or a prediction dir") exit(0) + # TODO use same transforms as during training transforms = Compose([Resize((target_size, target_size)), Grayscale(num_output_channels=1), ToTensor()]) data_set = ImageFolderWithPaths(test_dir, transforms)