diff --git a/pl_bolts/callbacks/ssl_online.py b/pl_bolts/callbacks/ssl_online.py index 788b0f252f..572bfb680a 100644 --- a/pl_bolts/callbacks/ssl_online.py +++ b/pl_bolts/callbacks/ssl_online.py @@ -102,8 +102,8 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data # log metrics train_acc = accuracy(mlp_preds, y) - pl_module.log('train_acc', train_acc, on_step=True, on_epoch=False) - pl_module.log('train_mlp_loss', mlp_loss, on_step=True, on_epoch=False) + pl_module.log('online_train_acc', train_acc, on_step=True, on_epoch=False) + pl_module.log('online_train_loss', mlp_loss, on_step=True, on_epoch=False) def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): x, y = self.to_device(batch, pl_module.device) @@ -119,5 +119,5 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, # log metrics val_acc = accuracy(mlp_preds, y) - pl_module.log('val_acc', val_acc, on_step=False, on_epoch=True, sync_dist=True) - pl_module.log('val_mlp_loss', mlp_loss, on_step=False, on_epoch=True, sync_dist=True) + pl_module.log('online_val_acc', val_acc, on_step=False, on_epoch=True, sync_dist=True) + pl_module.log('online_val_loss', mlp_loss, on_step=False, on_epoch=True, sync_dist=True) diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py index 63ae86d1e0..3dcd43e279 100644 --- a/pl_bolts/datamodules/imagenet_datamodule.py +++ b/pl_bolts/datamodules/imagenet_datamodule.py @@ -90,6 +90,7 @@ def __init__( self.meta_dir = meta_dir self.num_imgs_per_val_class = num_imgs_per_val_class self.batch_size = batch_size + self.num_samples = 1281167 - self.num_imgs_per_val_class * self.num_classes @property def num_classes(self): @@ -144,6 +145,7 @@ def train_dataloader(self): dataset = UnlabeledImagenet(self.data_dir, num_imgs_per_class=-1, + num_imgs_per_class_val_split=self.num_imgs_per_val_class, meta_dir=self.meta_dir, split='train', transform=transforms) diff --git a/pl_bolts/models/self_supervised/swav/swav_finetuner.py b/pl_bolts/models/self_supervised/swav/swav_finetuner.py index cdf7367e1b..a9aaa5fa4c 100644 --- a/pl_bolts/models/self_supervised/swav/swav_finetuner.py +++ b/pl_bolts/models/self_supervised/swav/swav_finetuner.py @@ -92,12 +92,12 @@ def cli_main(): # pragma: no-cover backbone = SwAV( gpus=args.gpus, + nodes=1, num_samples=args.num_samples, batch_size=args.batch_size, - datamodule=dm, maxpool1=args.maxpool1, first_conv=args.first_conv, - dataset='imagenet', + dataset=args.dataset, ).load_from_checkpoint(args.ckpt_path, strict=False) tuner = SSLFineTuner( @@ -117,6 +117,7 @@ def cli_main(): # pragma: no-cover trainer = pl.Trainer( gpus=args.gpus, + num_nodes=1, precision=16, max_epochs=args.num_epochs, distributed_backend='ddp', diff --git a/pl_bolts/models/self_supervised/swav/swav_module.py b/pl_bolts/models/self_supervised/swav/swav_module.py index accaf62b1b..2f0e138513 100644 --- a/pl_bolts/models/self_supervised/swav/swav_module.py +++ b/pl_bolts/models/self_supervised/swav/swav_module.py @@ -16,13 +16,18 @@ from pl_bolts.models.self_supervised.swav.swav_resnet import resnet18, resnet50 from pl_bolts.optimizers.lars_scheduling import LARSWrapper -from pl_bolts.transforms.dataset_normalizations import cifar10_normalization, stl10_normalization +from pl_bolts.transforms.dataset_normalizations import ( + stl10_normalization, + cifar10_normalization, + imagenet_normalization +) class SwAV(pl.LightningModule): def __init__( self, gpus: int, + nodes: int, num_samples: int, batch_size: int, dataset: str, @@ -54,8 +59,9 @@ def __init__( ): """ Args: - gpus: number of gpus used in training, passed to SwAV module + gpus: number of gpus per node used in training, passed to SwAV module to manage the queue and select distributed sinkhorn + nodes: number of nodes to train on num_samples: number of image samples used for training batch_size: batch size per GPU in ddp dataset: dataset being used for train/val @@ -94,6 +100,7 @@ def __init__( self.save_hyperparameters() self.gpus = gpus + self.nodes = nodes self.arch = arch self.dataset = dataset self.num_samples = num_samples @@ -127,7 +134,7 @@ def __init__( self.warmup_epochs = warmup_epochs self.max_epochs = max_epochs - if self.gpus > 1: + if self.gpus * self.nodes > 1: self.get_assignments = self.distributed_sinkhorn else: self.get_assignments = self.sinkhorn @@ -135,7 +142,7 @@ def __init__( self.model = self.init_model() # compute iters per epoch - global_batch_size = self.gpus * self.batch_size if self.gpus > 0 else self.batch_size + global_batch_size = self.nodes * self.gpus * self.batch_size if self.gpus > 0 else self.batch_size self.train_iters_per_epoch = self.num_samples // global_batch_size # define LR schedule @@ -435,6 +442,7 @@ def add_model_specific_args(parent_parser): # training params parser.add_argument("--fast_dev_run", action='store_true') + parser.add_argument("--nodes", default=1, type=int, help="number of nodes for training") parser.add_argument("--gpus", default=1, type=int, help="number of gpus to train on") parser.add_argument("--num_workers", default=16, type=int, help="num of workers per GPU") parser.add_argument("--optimizer", default="adam", type=str, help="choose between adam/sgd") @@ -471,8 +479,8 @@ def add_model_specific_args(parent_parser): def cli_main(): from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator - from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule - from pl_bolts.models.self_supervised.swav.transforms import SwAVEvalDataTransform, SwAVTrainDataTransform + from pl_bolts.models.self_supervised.swav.transforms import SwAVTrainDataTransform, SwAVEvalDataTransform + from pl_bolts.datamodules import STL10DataModule, CIFAR10DataModule, ImagenetDataModule parser = ArgumentParser() @@ -515,11 +523,44 @@ def cli_main(): args.size_crops = [32, 16] args.nmb_crops = [2, 1] args.gaussian_blur = False + elif args.dataset == 'imagenet': + args.maxpool1 = True + args.first_conv = True + normalization = imagenet_normalization() + + args.size_crops = [224, 96] + args.min_scale_crops = [0.14, 0.05] + args.max_scale_crops = [1., 0.14] + args.gaussian_blur = True + args.jitter_strength = 1. + + args.batch_size = 64 + args.nodes = 8 + args.gpus = 8 # per-node + args.max_epochs = 800 + + args.optimizer = 'sgd' + args.lars_wrapper = True + args.learning_rate = 4.8 + args.final_lr = 0.0048 + args.start_lr = 0.3 + + args.nmb_prototypes = 3000 + args.online_ft = True + + dm = ImagenetDataModule( + data_dir=args.data_path, + batch_size=args.batch_size, + num_workers=args.num_workers + ) + + args.num_samples = dm.num_samples + args.input_height = dm.size()[-1] else: raise NotImplementedError("other datasets have not been implemented till now") dm.train_transforms = SwAVTrainDataTransform( - normalize=stl10_normalization(), + normalize=normalization, size_crops=args.size_crops, nmb_crops=args.nmb_crops, min_scale_crops=args.min_scale_crops, @@ -529,7 +570,7 @@ def cli_main(): ) dm.val_transforms = SwAVEvalDataTransform( - normalize=stl10_normalization(), + normalize=normalization, size_crops=args.size_crops, nmb_crops=args.nmb_crops, min_scale_crops=args.min_scale_crops, @@ -556,6 +597,7 @@ def cli_main(): max_epochs=args.max_epochs, max_steps=None if args.max_steps == -1 else args.max_steps, gpus=args.gpus, + num_nodes=args.nodes, distributed_backend='ddp' if args.gpus > 1 else None, sync_batchnorm=True if args.gpus > 1 else False, precision=32 if args.fp32 else 16, diff --git a/pl_bolts/models/self_supervised/swav/transforms.py b/pl_bolts/models/self_supervised/swav/transforms.py index de5b72e620..f5e98a82c4 100644 --- a/pl_bolts/models/self_supervised/swav/transforms.py +++ b/pl_bolts/models/self_supervised/swav/transforms.py @@ -55,8 +55,12 @@ def __init__( ] if self.gaussian_blur: + kernel_size = int(0.1 * self.size_crops[0]) + if kernel_size % 2 == 0: + kernel_size += 1 + color_transform.append( - GaussianBlur(kernel_size=int(0.1 * self.size_crops[0]), p=0.5) + GaussianBlur(kernel_size=kernel_size, p=0.5) ) self.color_transform = transforms.Compose(color_transform) diff --git a/tests/models/self_supervised/test_models.py b/tests/models/self_supervised/test_models.py index 2c37a13274..95175b7bb6 100644 --- a/tests/models/self_supervised/test_models.py +++ b/tests/models/self_supervised/test_models.py @@ -117,6 +117,7 @@ def test_swav(tmpdir, datadir): arch='resnet18', hidden_mlp=512, gpus=0, + nodes=1, num_samples=datamodule.num_samples, batch_size=batch_size, nmb_crops=[2, 1],