Skip to content

Commit

Permalink
Make image sampler callback work
Browse files Browse the repository at this point in the history
  • Loading branch information
Christoph Clement committed Nov 27, 2020
1 parent 0df4de5 commit 6b1e11a
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions pl_bolts/models/gans/dcgan/dcgan_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch
from torch import nn

from pl_bolts.callbacks import LatentDimInterpolator, TensorboardGenerativeModelImageSampler
from pl_bolts.utils.warnings import warn_missing_pkg

try:
Expand Down Expand Up @@ -62,7 +63,8 @@ def configure_optimizers(self):
opt_gen = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=betas)
return [opt_disc, opt_gen], []

def forward(self, noise):
def forward(self, noise: torch.Tensor) -> torch.Tensor:
noise = noise.view(*noise.shape, 1, 1)
return self.generator(noise)

def training_step(self, batch, batch_idx, optimizer_idx):
Expand Down Expand Up @@ -125,8 +127,8 @@ def _get_batch_size(real: torch.Tensor) -> int:
batch_size = len(real)
return batch_size

def _get_noise(self, n_samples, latent_dim):
noise = torch.randn(n_samples, latent_dim, 1, 1, device=self.device)
def _get_noise(self, n_samples: int, latent_dim: int) -> torch.Tensor:
noise = torch.randn(n_samples, latent_dim, device=self.device)
return noise

@staticmethod
Expand Down Expand Up @@ -169,7 +171,8 @@ def cli_main(args=None):
dm.test_transforms = transforms

model = DCGAN(**vars(args))
trainer = pl.Trainer.from_argparse_args(args)
callbacks = [TensorboardGenerativeModelImageSampler(), LatentDimInterpolator(interpolate_epoch_interval=5)]
trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks)
trainer.fit(model, dm)
return dm, model, trainer

Expand Down

0 comments on commit 6b1e11a

Please sign in to comment.