From 76cedfedf248b72cd2809c14f9857b5b47f40334 Mon Sep 17 00:00:00 2001 From: Ananya Harsh Jha Date: Fri, 2 Oct 2020 16:18:09 +0000 Subject: [PATCH] current --- .../self_supervised/swav/swav_finetuner.py | 27 ++++++++++++++++--- .../self_supervised/swav/weights_convert.py | 16 +++++++++++ 2 files changed, 40 insertions(+), 3 deletions(-) create mode 100644 pl_bolts/models/self_supervised/swav/weights_convert.py diff --git a/pl_bolts/models/self_supervised/swav/swav_finetuner.py b/pl_bolts/models/self_supervised/swav/swav_finetuner.py index 84b708dfb2..9bdb4a710d 100644 --- a/pl_bolts/models/self_supervised/swav/swav_finetuner.py +++ b/pl_bolts/models/self_supervised/swav/swav_finetuner.py @@ -6,7 +6,7 @@ from pl_bolts.models.self_supervised.ssl_finetuner import SSLFineTuner from pl_bolts.models.self_supervised.swav.swav_module import SwAV -from pl_bolts.transforms.dataset_normalizations import stl10_normalization +from pl_bolts.transforms.dataset_normalizations import stl10_normalization, imagenet_normalization from pl_bolts.models.self_supervised.swav.transforms import SwAVFinetuneTransform @@ -48,6 +48,26 @@ def cli_main(): # pragma: no-cover ) args.maxpool1 = False + elif args.dataset == 'imagenet': + dm = ImagenetDataModule( + data_dir=args.data_path, + batch_size=args.batch_size, + num_workers=args.num_workers + ) + + dm.train_transforms = SwAVFinetuneTransform( + normalize=imagenet_normalization(), + input_height=dm.size()[-1], + eval_transform=False + ) + dm.val_transforms = SwAVFinetuneTransform( + normalize=imagenet_normalization(), + input_height=dm.size()[-1], + eval_transform=True + ) + + args.num_samples = 0 + args.maxpool1 = True else: raise NotImplementedError("other datasets have not been implemented till now") @@ -55,12 +75,13 @@ def cli_main(): # pragma: no-cover gpus=1, num_samples=args.num_samples, batch_size=args.batch_size, - datamodule=dm + datamodule=dm, + maxpool1=args.maxpool1 ).load_from_checkpoint(args.ckpt_path, strict=False) tuner = SSLFineTuner(backbone, in_features=2048, num_classes=dm.num_classes, hidden_dim=None) trainer = pl.Trainer.from_argparse_args( - args, gpus=1, precision=16, early_stop_callback=True + args, gpus=4, precision=16, early_stop_callback=True ) trainer.fit(tuner, dm) diff --git a/pl_bolts/models/self_supervised/swav/weights_convert.py b/pl_bolts/models/self_supervised/swav/weights_convert.py new file mode 100644 index 0000000000..9a54fb4df6 --- /dev/null +++ b/pl_bolts/models/self_supervised/swav/weights_convert.py @@ -0,0 +1,16 @@ +import torch +from collections import OrderedDict + +swav_imagenet = torch.load('swav_imagenet.pth.tar') + +new_state_dict = OrderedDict() + +for key in swav_imagenet.keys(): + if 'prototype' in key: + continue + new_state_dict[key.replace('module.', 'model.')] = swav_imagenet[key] + +stl10_save = torch.load("epoch=96.ckpt") +stl10_save['state_dict'] = new_state_dict + +torch.save(stl10_save, 'swav_imagenet.ckpt') \ No newline at end of file