From 9a050c9e87a227652e1d597219b58f0c6dceb745 Mon Sep 17 00:00:00 2001
From: Marius Giger
Date: Tue, 19 Apr 2022 10:47:48 +0200
Subject: [PATCH] fixes an issue with wandb image logging, adds configuration for data loader workers

---
 src/sdo/cmd/sood/ce_vae/cmd_train.py |  7 +++++--
 src/sood/algorithms/ce_vae.py        | 30 +++++++++++++++---------------
 2 files changed, 20 insertions(+), 17 deletions(-)

diff --git a/src/sdo/cmd/sood/ce_vae/cmd_train.py b/src/sdo/cmd/sood/ce_vae/cmd_train.py
index a36b82c..fa598e4 100644
--- a/src/sdo/cmd/sood/ce_vae/cmd_train.py
+++ b/src/sdo/cmd/sood/ce_vae/cmd_train.py
@@ -24,6 +24,7 @@
 @click.option(
     "--dataset", type=click.Choice(["CuratedImageParameterDataset", "SDOMLDatasetV1"], case_sensitive=False), required=False, default="CuratedImageParameterDataset"
 )
+@click.option("--num-data-loader-workers", type=int, default=0)
 @pass_environment
 def train(ctx,
           target_size,
@@ -40,7 +41,8 @@ def train(ctx,
           load_path,
           log_dir,
           data_dir,
-          dataset):
+          dataset,
+          num_data_loader_workers):
 
     main(run="train",
          target_size=target_size,
@@ -56,4 +58,5 @@ def train(ctx,
          load_path=load_path,
          log_dir=log_dir,
          data_dir=data_dir,
-         dataset=dataset)
+         dataset=dataset,
+         num_data_loader_workers=num_data_loader_workers)
diff --git a/src/sood/algorithms/ce_vae.py b/src/sood/algorithms/ce_vae.py
index 0be2c66..3f13c85 100644
--- a/src/sood/algorithms/ce_vae.py
+++ b/src/sood/algorithms/ce_vae.py
@@ -176,16 +176,15 @@ def train(self):
                 log_dict = {}
 
                 # VAE
-                image_table = wandb.Table()
+
+                image_data = [cnt, None, None, None, None]
                 if self.ce_factor < 1:
                     input_vae_path = save_image_grid(inpt, name="Input-VAE", save_dir=self.work_dir / Path("save/imgs"), image_args={"normalize": False}, n_iter=cnt)
-                    image_table.add_column(
-                        "Input-VAE", wandb.Image(input_vae_path))
+                    image_data[1] = wandb.Image(input_vae_path)
 
                     output_vae_path = save_image_grid(
                         x_rec_vae, name="Output-VAE", save_dir=self.work_dir / Path("save/imgs"), image_args={"normalize": True}, n_iter=cnt)
-                    image_table.add_column(
-                        "Output-VAE", wandb.Image(output_vae_path))
+                    image_data[2] = wandb.Image(output_vae_path)
 
                 if self.beta > 0:
                     log_dict["Kl-loss"] = torch.mean(kl_loss).item()
@@ -197,16 +196,16 @@ def train(self):
                 if self.ce_factor > 0:
                     input_ce_path = save_image_grid(
                         inpt_noisy, name="Input-CE", save_dir=self.work_dir / Path("save/imgs"), image_args={"normalize": False})
-                    image_table.add_column(
-                        "Input-CE", wandb.Image(input_ce_path))
+                    image_data[3] = wandb.Image(input_ce_path)
 
                     output_ce_path = save_image_grid(
                         x_rec_ce, name="Output-CE", save_dir=self.work_dir / Path("save/imgs"), image_args={"normalize": True})
-                    image_table.add_column(
-                        "Output-CE", wandb.Image(output_ce_path))
+                    image_data[4] = wandb.Image(output_ce_path)
                     log_dict["CE-train-loss"] = loss_ce.item()
 
-                # TODO why normalize by the length of the input (batch length)?
+                image_table = wandb.Table(
+                    columns=["Step", "Input-VAE", "Output-VAE", "Input-CE", "Output-CE"], data=[image_data])
+                # TODO why normalize by the length of the input (batch length)?
                 log_dict["CEVAE-train-loss"] = loss.item() / len(inpt)
                 log_dict["epoch"] = epoch
                 log_dict["counter"] = cnt
@@ -448,7 +447,8 @@ def main(
         test_dir=None,
         pred_dir=None,
         data_dir=None,
-        dataset="CuratedImageParameterDataset"
+        dataset="CuratedImageParameterDataset",
+        num_data_loader_workers=0
 ):
 
     input_shape = (batch_size, 1, target_size, target_size)
@@ -458,7 +458,7 @@ def main(
     if dataset == "CuratedImageParameterDataset":
         train_loader = get_dataset(
             base_dir=data_dir,
-            num_processes=16,
+            num_processes=num_data_loader_workers,
             pin_memory=False,
             batch_size=batch_size,
             mode="train",
@@ -466,7 +466,7 @@ def main(
         )
         val_loader = get_dataset(
             base_dir=data_dir,
-            num_processes=8,
+            num_processes=num_data_loader_workers,
             pin_memory=False,
             batch_size=batch_size,
             mode="val",
@@ -476,7 +476,7 @@ def main(
         # due to a bug on Mac, num processes needs to be 0: https://github.com/pyg-team/pytorch_geometric/issues/366
         train_loader = get_sdo_ml_v1_dataset(
             base_dir=data_dir,
-            num_processes=0,
+            num_processes=num_data_loader_workers,
             pin_memory=False,
             batch_size=batch_size,
             mode="train",
@@ -484,7 +484,7 @@ def main(
         )
        val_loader = get_sdo_ml_v1_dataset(
             base_dir=data_dir,
-            num_processes=0,
+            num_processes=num_data_loader_workers,
             pin_memory=False,
             batch_size=batch_size,
             mode="val",