diff --git a/pl_bolts/models/gans/dcgan/dcgan_module.py b/pl_bolts/models/gans/dcgan/dcgan_module.py index 9ac298a5e6..5cb7b26145 100644 --- a/pl_bolts/models/gans/dcgan/dcgan_module.py +++ b/pl_bolts/models/gans/dcgan/dcgan_module.py @@ -4,6 +4,7 @@ import torch from torch import nn +from pl_bolts.callbacks import LatentDimInterpolator, TensorboardGenerativeModelImageSampler from pl_bolts.utils.warnings import warn_missing_pkg try: @@ -62,7 +63,8 @@ def configure_optimizers(self): opt_gen = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=betas) return [opt_disc, opt_gen], [] - def forward(self, noise): + def forward(self, noise: torch.Tensor) -> torch.Tensor: + noise = noise.view(*noise.shape, 1, 1) return self.generator(noise) def training_step(self, batch, batch_idx, optimizer_idx): @@ -125,8 +127,8 @@ def _get_batch_size(real: torch.Tensor) -> int: batch_size = len(real) return batch_size - def _get_noise(self, n_samples, latent_dim): - noise = torch.randn(n_samples, latent_dim, 1, 1, device=self.device) + def _get_noise(self, n_samples: int, latent_dim: int) -> torch.Tensor: + noise = torch.randn(n_samples, latent_dim, device=self.device) return noise @staticmethod @@ -169,7 +171,8 @@ def cli_main(args=None): dm.test_transforms = transforms model = DCGAN(**vars(args)) - trainer = pl.Trainer.from_argparse_args(args) + callbacks = [TensorboardGenerativeModelImageSampler(), LatentDimInterpolator(interpolate_epoch_interval=5)] + trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks) trainer.fit(model, dm) return dm, model, trainer