Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Decouple DataModules from Models - GAN #206

Merged
merged 6 commits into from
Sep 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 33 additions & 49 deletions pl_bolts/models/gans/basic/basic_gan_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,20 @@
import pytorch_lightning as pl
from torch.nn import functional as F

from pl_bolts.datamodules import MNISTDataModule
from pl_bolts.models.gans.basic.components import Generator, Discriminator


class GAN(pl.LightningModule):

def __init__(self,
datamodule: pl.LightningDataModule = None,
latent_dim: int = 32,
batch_size: int = 100,
learning_rate: float = 0.0002,
data_dir: str = '',
num_workers: int = 8,
**kwargs):
def __init__(
self,
input_channels: int,
input_height: int,
input_width: int,
latent_dim: int = 32,
learning_rate: float = 0.0002,
**kwargs
):
"""
Vanilla GAN implementation.

Expand Down Expand Up @@ -53,24 +53,12 @@ def __init__(self,

# makes self.hparams under the hood and saves to ckpt
self.save_hyperparameters()

self._set_default_datamodule(datamodule)
self.img_dim = (input_channels, input_height, input_width)

# networks
self.generator = self.init_generator(self.img_dim)
self.discriminator = self.init_discriminator(self.img_dim)

def _set_default_datamodule(self, datamodule):
# link default data
if datamodule is None:
datamodule = MNISTDataModule(
data_dir=self.hparams.data_dir,
num_workers=self.hparams.num_workers,
normalize=True
)
self.datamodule = datamodule
self.img_dim = self.datamodule.size()

def init_generator(self, img_dim):
generator = Generator(latent_dim=self.hparams.latent_dim, img_shape=img_dim)
return generator
Expand Down Expand Up @@ -179,44 +167,40 @@ def add_model_specific_args(parent_parser):
help="adam: decay of first order momentum of gradient")
parser.add_argument('--latent_dim', type=int, default=100,
help="generator embedding dim")
parser.add_argument('--batch_size', type=int, default=64, help="size of the batches")
parser.add_argument('--num_workers', type=int, default=8, help="num dataloader workers")
parser.add_argument('--data_dir', type=str, default=os.getcwd())

return parser


def cli_main():
def cli_main(args=None):
from pl_bolts.callbacks import LatentDimInterpolator, TensorboardGenerativeModelImageSampler
from pl_bolts.datamodules import STL10DataModule, ImagenetDataModule
from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, MNISTDataModule, STL10DataModule

pl.seed_everything(1234)

parser = ArgumentParser()
parser.add_argument('--dataset', type=str, default='mnist', help='mnist, stl10, imagenet2012')

parser.add_argument("--dataset", default="mnist", type=str, help="mnist, cifar10, stl10, imagenet")
script_args, _ = parser.parse_known_args(args)

if script_args.dataset == "mnist":
dm_cls = MNISTDataModule
elif script_args.dataset == "cifar10":
dm_cls = CIFAR10DataModule
elif script_args.dataset == "stl10":
dm_cls = STL10DataModule
elif script_args.dataset == "imagenet":
dm_cls = ImagenetDataModule

parser = dm_cls.add_argparse_args(parser)
parser = pl.Trainer.add_argparse_args(parser)
parser = GAN.add_model_specific_args(parser)
parser = ImagenetDataModule.add_argparse_args(parser)
args = parser.parse_args()

# default is mnist
datamodule = None
if args.dataset == 'imagenet2012':
datamodule = ImagenetDataModule.from_argparse_args(args)
elif args.dataset == 'stl10':
datamodule = STL10DataModule.from_argparse_args(args)

gan = GAN(**vars(args), datamodule=datamodule)
callbacks = [TensorboardGenerativeModelImageSampler(), LatentDimInterpolator()]
args = parser.parse_args(args)

trainer = pl.Trainer.from_argparse_args(
args,
callbacks=callbacks,
progress_bar_refresh_rate=10
)
trainer.fit(gan)
dm = dm_cls.from_argparse_args(args)
model = GAN(*dm.size(), **vars(args))
callbacks = [TensorboardGenerativeModelImageSampler(), LatentDimInterpolator(interpolate_epoch_interval=5)]
trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, progress_bar_refresh_rate=20)
trainer.fit(model, dm)
return dm, model, trainer


if __name__ == '__main__':
cli_main()
dm, model, trainer = cli_main()
5 changes: 2 additions & 3 deletions pl_bolts/models/self_supervised/moco/moco2_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
This implementation is: Copyright (c) PyTorch Lightning, Inc. and its affiliates. All Rights Reserved
"""

from argparse import ArgumentParser
from typing import Union

import pytorch_lightning as pl
Expand Down Expand Up @@ -318,8 +319,7 @@ def configure_optimizers(self):

@staticmethod
def add_model_specific_args(parent_parser):
from test_tube import HyperOptArgumentParser
parser = HyperOptArgumentParser(parents=[parent_parser], add_help=False)
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument('--base_encoder', type=str, default='resnet18')
parser.add_argument('--emb_dim', type=int, default=128)
parser.add_argument('--num_workers', type=int, default=8)
Expand Down Expand Up @@ -354,7 +354,6 @@ def concat_all_gather(tensor):


def cli_main():
from argparse import ArgumentParser

parser = ArgumentParser()

Expand Down
2 changes: 1 addition & 1 deletion tests/callbacks/test_variational_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def __init__(self):
self.global_step = 1
self.logger = DummyLogger()

model = GAN()
model = GAN(3, 28, 28)
cb = LatentDimInterpolator(interpolate_epoch_interval=2)

cb.on_epoch_end(FakeTrainer(), model)
21 changes: 15 additions & 6 deletions tests/models/test_executable_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,23 @@
import pytest


@pytest.mark.parametrize('cli_args', ['--max_epochs 1'
' --limit_train_batches 3'
' --limit_val_batches 3'
' --batch_size 3'])
def test_cli_basic_gan(cli_args):
@pytest.mark.parametrize(
"dataset_name", [
pytest.param('mnist', id="mnist"),
pytest.param('cifar10', id="cifar10")
]
)
def test_cli_basic_gan(dataset_name):
from pl_bolts.models.gans.basic.basic_gan_module import cli_main

cli_args = cli_args.split(' ') if cli_args else []
cli_args = f"""
--dataset {dataset_name}
--max_epochs 1
--limit_train_batches 3
--limit_val_batches 3
--batch_size 3
""".strip().split()

with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
cli_main()

Expand Down
14 changes: 10 additions & 4 deletions tests/models/test_gans.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
import pytest
import pytorch_lightning as pl
from pytorch_lightning import seed_everything

from pl_bolts.datamodules import MNISTDataModule, CIFAR10DataModule
from pl_bolts.models.gans import GAN


def test_gan(tmpdir):
@pytest.mark.parametrize(
"dm_cls", [pytest.param(MNISTDataModule, id="mnist"), pytest.param(CIFAR10DataModule, id="cifar10")]
)
def test_gan(tmpdir, dm_cls):
seed_everything()

model = GAN(data_dir=tmpdir)
dm = dm_cls()
model = GAN(*dm.size())
trainer = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir)
trainer.fit(model)
trainer.test(model)
trainer.fit(model, dm)
trainer.test(datamodule=dm)