From 9dee7e040faf82ec6360b48d93897497c6f96cd8 Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Tue, 6 Apr 2021 20:23:25 -0400 Subject: [PATCH 01/13] updates optimizer --- docs/source/self_supervised_models.rst | 10 +- .../self_supervised/byol/byol_module.py | 2 - .../self_supervised/simclr/simclr_module.py | 78 ++++------- .../self_supervised/simsiam/simsiam_module.py | 90 +++++-------- .../self_supervised/swav/swav_module.py | 79 ++++------- pl_bolts/optimizers/__init__.py | 10 +- pl_bolts/optimizers/lars.py | 123 ++++++++++++++++++ pl_bolts/optimizers/lars_scheduling.py | 98 -------------- pl_bolts/optimizers/lr_scheduler.py | 29 +++++ 9 files changed, 242 insertions(+), 277 deletions(-) create mode 100644 pl_bolts/optimizers/lars.py delete mode 100644 pl_bolts/optimizers/lars_scheduling.py diff --git a/docs/source/self_supervised_models.rst b/docs/source/self_supervised_models.rst index 5e89cde822..715380ef8c 100644 --- a/docs/source/self_supervised_models.rst +++ b/docs/source/self_supervised_models.rst @@ -312,7 +312,7 @@ CIFAR-10 baseline * - Ours - 88.50 - `resnet50 `_ - - `LARS-SGD `_ + - LARS - 2048 - 800 (4 hours) - 8 V100 (16GB) @@ -361,7 +361,6 @@ To reproduce:: -- num_workers 16 --optimizer sgd --learning_rate 1.5 - --lars_wrapper --exclude_bn_bias --max_epochs 800 --online_ft @@ -401,7 +400,7 @@ Imagenet baseline for SimCLR * - Ours - 68.4 - `resnet50 `_ - - `LARS-SGD `_ + - LARS - 4096 - 800 - 64 V100 (16GB) @@ -533,7 +532,7 @@ The original paper does not provide baselines on STL10. * - Ours - `86.72 `_ - SwAV resnet50 - - `LARS `_ + - LARS - 128 - No - 100 (~9 hr) @@ -585,7 +584,6 @@ To reproduce:: python swav_module.py --online_ft --gpus 1 - --lars_wrapper --batch_size 128 --learning_rate 1e-3 --gaussian_blur @@ -630,7 +628,7 @@ Imagenet baseline for SwAV * - Ours - 74 - `resnet50 `_ - - `LARS-SGD `_ + - LARS - 4096 - 800 - 64 V100 (16GB) diff --git a/pl_bolts/models/self_supervised/byol/byol_module.py b/pl_bolts/models/self_supervised/byol/byol_module.py index 0e883711fd..cc4a749a16 100644 --- a/pl_bolts/models/self_supervised/byol/byol_module.py +++ b/pl_bolts/models/self_supervised/byol/byol_module.py @@ -10,7 +10,6 @@ from pl_bolts.callbacks.byol_updates import BYOLMAWeightUpdate from pl_bolts.models.self_supervised.byol.models import SiameseArm -from pl_bolts.optimizers.lars_scheduling import LARSWrapper from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR @@ -141,7 +140,6 @@ def validation_step(self, batch, batch_idx): def configure_optimizers(self): optimizer = Adam(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay) - optimizer = LARSWrapper(optimizer) scheduler = LinearWarmupCosineAnnealingLR( optimizer, warmup_epochs=self.hparams.warmup_epochs, max_epochs=self.hparams.max_epochs ) diff --git a/pl_bolts/models/self_supervised/simclr/simclr_module.py b/pl_bolts/models/self_supervised/simclr/simclr_module.py index b8441b6011..0feb3842ed 100644 --- a/pl_bolts/models/self_supervised/simclr/simclr_module.py +++ b/pl_bolts/models/self_supervised/simclr/simclr_module.py @@ -12,7 +12,8 @@ from torch.optim.optimizer import Optimizer from pl_bolts.models.self_supervised.resnets import resnet18, resnet50 -from pl_bolts.optimizers.lars_scheduling import LARSWrapper +from pl_bolts.optimizers.lars import LARS +from pl_bolts.optimizers.lr_scheduler import linear_warmup_decay from pl_bolts.transforms.dataset_normalizations import ( cifar10_normalization, imagenet_normalization, @@ -79,7 +80,6 @@ def __init__( 
first_conv: bool = True, maxpool1: bool = True, optimizer: str = 'adam', - lars_wrapper: bool = True, exclude_bn_bias: bool = False, start_lr: float = 0., learning_rate: float = 1e-3, @@ -112,7 +112,6 @@ def __init__( self.maxpool1 = maxpool1 self.optim = optimizer - self.lars_wrapper = lars_wrapper self.exclude_bn_bias = exclude_bn_bias self.weight_decay = weight_decay self.temperature = temperature @@ -131,19 +130,6 @@ def __init__( global_batch_size = self.num_nodes * self.gpus * self.batch_size if self.gpus > 0 else self.batch_size self.train_iters_per_epoch = self.num_samples // global_batch_size - # define LR schedule - warmup_lr_schedule = np.linspace( - self.start_lr, self.learning_rate, self.train_iters_per_epoch * self.warmup_epochs - ) - iters = np.arange(self.train_iters_per_epoch * (self.max_epochs - self.warmup_epochs)) - cosine_lr_schedule = np.array([ - self.final_lr + 0.5 * (self.learning_rate - self.final_lr) * - (1 + math.cos(math.pi * t / (self.train_iters_per_epoch * (self.max_epochs - self.warmup_epochs)))) - for t in iters - ]) - - self.lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule)) - def init_model(self): if self.arch == 'resnet18': backbone = resnet18 @@ -179,9 +165,6 @@ def shared_step(self, batch): def training_step(self, batch, batch_idx): loss = self.shared_step(batch) - # log LR (LearningRateLogger callback doesn't work with LARSWrapper) - self.log('learning_rate', self.lr_schedule[self.trainer.global_step], on_step=True, on_epoch=False) - self.log('train_loss', loss, on_step=True, on_epoch=False) return loss @@ -217,41 +200,30 @@ def configure_optimizers(self): else: params = self.parameters() - if self.optim == 'sgd': - optimizer = torch.optim.SGD(params, lr=self.learning_rate, momentum=0.9, weight_decay=self.weight_decay) + if self.optim == 'lars': + optimizer = LARS( + params, + lr=self.learning_rate, + momentum=0.9, + weight_decay=self.weight_decay, + trust_coefficient=0.001, + ) elif self.optim == 'adam': optimizer = torch.optim.Adam(params, lr=self.learning_rate, weight_decay=self.weight_decay) - if self.lars_wrapper: - optimizer = LARSWrapper( - optimizer, - eta=0.001, # trust coefficient - clip=False - ) + warmup_steps = self.train_iters_per_epoch * self.warmup_epochs + total_steps = self.train_iters_per_epoch * self.max_epochs - return optimizer + scheduler = { + "scheduler": torch.optim.lr_scheduler.LambdaLR( + optimizer, + linear_warmup_decay(warmup_steps, total_steps, cosine=True), + ), + "interval": "step", + "frequency": 1, + } - def optimizer_step( - self, - epoch: int = None, - batch_idx: int = None, - optimizer: Optimizer = None, - optimizer_idx: int = None, - optimizer_closure: Optional[Callable] = None, - on_tpu: bool = None, - using_native_amp: bool = None, - using_lbfgs: bool = None, - ) -> None: - # warm-up + decay schedule placed here since LARSWrapper is not optimizer class - # adjust LR of optim contained within LARSWrapper - for param_group in optimizer.param_groups: - param_group["lr"] = self.lr_schedule[self.trainer.global_step] - - # from lightning - if not isinstance(optimizer, LightningOptimizer): - # wraps into LightingOptimizer only for running step - optimizer = LightningOptimizer.to_lightning_optimizer(optimizer, self.trainer) - optimizer.step(closure=optimizer_closure) + return [optimizer], [scheduler] def nt_xent_loss(self, out_1, out_2, temperature, eps=1e-6): """ @@ -317,8 +289,7 @@ def add_model_specific_args(parent_parser): parser.add_argument("--num_nodes", default=1, type=int, help="number of 
nodes for training") parser.add_argument("--gpus", default=1, type=int, help="number of gpus to train on") parser.add_argument("--num_workers", default=8, type=int, help="num of workers per GPU") - parser.add_argument("--optimizer", default="adam", type=str, help="choose between adam/sgd") - parser.add_argument("--lars_wrapper", action='store_true', help="apple lars wrapper over optimizer used") + parser.add_argument("--optimizer", default="adam", type=str, help="choose between adam/lars") parser.add_argument('--exclude_bn_bias', action='store_true', help="exclude bn/bias from weight decay") parser.add_argument("--max_epochs", default=100, type=int, help="number of total epochs to run") parser.add_argument("--max_steps", default=-1, type=int, help="max steps") @@ -393,8 +364,7 @@ def cli_main(): args.gpus = 8 # per-node args.max_epochs = 800 - args.optimizer = 'sgd' - args.lars_wrapper = True + args.optimizer = 'lars' args.learning_rate = 4.8 args.final_lr = 0.0048 args.start_lr = 0.3 @@ -434,8 +404,10 @@ def cli_main(): dataset=args.dataset, ) + lr_monitor = LearningRateMonitor(logging_interval="step") model_checkpoint = ModelCheckpoint(save_last=True, save_top_k=1, monitor='val_loss') callbacks = [model_checkpoint, online_evaluator] if args.online_ft else [model_checkpoint] + callbacks.append(lr_monitor) trainer = pl.Trainer( max_epochs=args.max_epochs, diff --git a/pl_bolts/models/self_supervised/simsiam/simsiam_module.py b/pl_bolts/models/self_supervised/simsiam/simsiam_module.py index 31b1c0fd60..f1fd4d701a 100644 --- a/pl_bolts/models/self_supervised/simsiam/simsiam_module.py +++ b/pl_bolts/models/self_supervised/simsiam/simsiam_module.py @@ -12,7 +12,8 @@ from pl_bolts.models.self_supervised.resnets import resnet18, resnet50 from pl_bolts.models.self_supervised.simsiam.models import SiameseArm -from pl_bolts.optimizers.lars_scheduling import LARSWrapper +from pl_bolts.optimizers.lars import LARS +from pl_bolts.optimizers.lr_scheduler import linear_warmup_decay from pl_bolts.transforms.dataset_normalizations import ( cifar10_normalization, imagenet_normalization, @@ -83,7 +84,6 @@ def __init__( first_conv: bool = True, maxpool1: bool = True, optimizer: str = 'adam', - lars_wrapper: bool = True, exclude_bn_bias: bool = False, start_lr: float = 0., learning_rate: float = 1e-3, @@ -118,7 +118,6 @@ def __init__( self.maxpool1 = maxpool1 self.optim = optimizer - self.lars_wrapper = lars_wrapper self.exclude_bn_bias = exclude_bn_bias self.weight_decay = weight_decay self.temperature = temperature @@ -137,19 +136,6 @@ def __init__( global_batch_size = self.num_nodes * nb_gpus * self.batch_size if nb_gpus > 0 else self.batch_size self.train_iters_per_epoch = self.num_samples // global_batch_size - # define LR schedule - warmup_lr_schedule = np.linspace( - self.start_lr, self.learning_rate, self.train_iters_per_epoch * self.warmup_epochs - ) - iters = np.arange(self.train_iters_per_epoch * (self.max_epochs - self.warmup_epochs)) - cosine_lr_schedule = np.array([ - self.final_lr + 0.5 * (self.learning_rate - self.final_lr) * - (1 + math.cos(math.pi * t / (self.train_iters_per_epoch * (self.max_epochs - self.warmup_epochs)))) - for t in iters - ]) - - self.lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule)) - def init_model(self): if self.arch == 'resnet18': backbone = resnet18 @@ -227,52 +213,30 @@ def configure_optimizers(self): else: params = self.parameters() - if self.optim == 'sgd': - optimizer = torch.optim.SGD(params, lr=self.learning_rate, momentum=0.9, 
weight_decay=self.weight_decay) + if self.optim == 'lars': + optimizer = LARS( + params, + lr=self.learning_rate, + momentum=0.9, + weight_decay=self.weight_decay, + trust_coefficient=0.001, + ) elif self.optim == 'adam': optimizer = torch.optim.Adam(params, lr=self.learning_rate, weight_decay=self.weight_decay) - if self.lars_wrapper: - optimizer = LARSWrapper( - optimizer, - eta=0.001, # trust coefficient - clip=False - ) + warmup_steps = self.train_iters_per_epoch * self.warmup_epochs + total_steps = self.train_iters_per_epoch * self.max_epochs - return optimizer + scheduler = { + "scheduler": torch.optim.lr_scheduler.LambdaLR( + optimizer, + linear_warmup_decay(warmup_steps, total_steps, cosine=True), + ), + "interval": "step", + "frequency": 1, + } - def optimizer_step( - self, - epoch: int, - batch_idx: int, - optimizer: Optimizer, - optimizer_idx: int, - optimizer_closure: Optional[Callable] = None, - on_tpu: bool = False, - using_native_amp: bool = False, - using_lbfgs: bool = False, - ) -> None: - # warm-up + decay schedule placed here since LARSWrapper is not optimizer class - # adjust LR of optim contained within LARSWrapper - if self.lars_wrapper: - for param_group in optimizer.optim.param_groups: - param_group["lr"] = self.lr_schedule[self.trainer.global_step] - else: - for param_group in optimizer.param_groups: - param_group["lr"] = self.lr_schedule[self.trainer.global_step] - - # log LR (LearningRateLogger callback doesn't work with LARSWrapper) - self.log('learning_rate', self.lr_schedule[self.trainer.global_step], on_step=True, on_epoch=False) - - # from lightning - if self.trainer.amp_backend == AMPType.NATIVE: - optimizer_closure() - self.trainer.scaler.step(optimizer) - elif self.trainer.amp_backend == AMPType.APEX: - optimizer_closure() - optimizer.step() - else: - optimizer.step(closure=optimizer_closure) + return [optimizer], [scheduler] @staticmethod def add_model_specific_args(parent_parser): @@ -295,8 +259,7 @@ def add_model_specific_args(parent_parser): # training params parser.add_argument("--num_workers", default=8, type=int, help="num of workers per GPU") - parser.add_argument("--optimizer", default="adam", type=str, help="choose between adam/sgd") - parser.add_argument("--lars_wrapper", action="store_true", help="apple lars wrapper over optimizer used") + parser.add_argument("--optimizer", default="adam", type=str, help="choose between adam/lars") parser.add_argument("--exclude_bn_bias", action="store_true", help="exclude bn/bias from weight decay") parser.add_argument("--warmup_epochs", default=10, type=int, help="number of warmup epochs") parser.add_argument("--batch_size", default=128, type=int, help="batch size per gpu") @@ -381,7 +344,7 @@ def cli_main(): args.gpus = 8 # per-node args.max_epochs = 800 - args.optimizer = "sgd" + args.optimizer = "lars" args.lars_wrapper = True args.learning_rate = 4.8 args.final_lr = 0.0048 @@ -423,10 +386,15 @@ def cli_main(): dataset=args.dataset, ) + lr_monitor = LearningRateMonitor(logging_interval="step") + model_checkpoint = ModelCheckpoint(save_last=True, save_top_k=1, monitor='val_loss') + callbacks = [model_checkpoint, online_evaluator] if args.online_ft else [model_checkpoint] + callbacks.append(lr_monitor) + trainer = pl.Trainer.from_argparse_args( args, sync_batchnorm=True if args.gpus > 1 else False, - callbacks=[online_evaluator] if args.online_ft else None, + callbacks=callbacks, ) trainer.fit(model, datamodule=dm) diff --git a/pl_bolts/models/self_supervised/swav/swav_module.py 
b/pl_bolts/models/self_supervised/swav/swav_module.py index 38a56bbaca..3f50a1ab3a 100644 --- a/pl_bolts/models/self_supervised/swav/swav_module.py +++ b/pl_bolts/models/self_supervised/swav/swav_module.py @@ -16,7 +16,8 @@ from torch.optim.optimizer import Optimizer from pl_bolts.models.self_supervised.swav.swav_resnet import resnet18, resnet50 -from pl_bolts.optimizers.lars_scheduling import LARSWrapper +from pl_bolts.optimizers.lars import LARS +from pl_bolts.optimizers.lr_scheduler import linear_warmup_decay from pl_bolts.transforms.dataset_normalizations import ( cifar10_normalization, imagenet_normalization, @@ -50,7 +51,6 @@ def __init__( first_conv: bool = True, maxpool1: bool = True, optimizer: str = 'adam', - lars_wrapper: bool = True, exclude_bn_bias: bool = False, start_lr: float = 0., learning_rate: float = 1e-3, @@ -90,7 +90,6 @@ def __init__( maxpool1: keep first maxpool layer same as the original resnet architecture, if set to false, first maxpool is turned off (cifar10, maybe stl10) optimizer: optimizer to use - lars_wrapper: use LARS wrapper over the optimizer exclude_bn_bias: exclude batchnorm and bias layers from weight decay in optimizers start_lr: starting lr for linear warmup learning_rate: learning rate @@ -124,7 +123,6 @@ def __init__( self.maxpool1 = maxpool1 self.optim = optimizer - self.lars_wrapper = lars_wrapper self.exclude_bn_bias = exclude_bn_bias self.weight_decay = weight_decay self.epsilon = epsilon @@ -147,19 +145,6 @@ def __init__( global_batch_size = self.num_nodes * self.gpus * self.batch_size if self.gpus > 0 else self.batch_size self.train_iters_per_epoch = self.num_samples // global_batch_size - # define LR schedule - warmup_lr_schedule = np.linspace( - self.start_lr, self.learning_rate, self.train_iters_per_epoch * self.warmup_epochs - ) - iters = np.arange(self.train_iters_per_epoch * (self.max_epochs - self.warmup_epochs)) - cosine_lr_schedule = np.array([ - self.final_lr + 0.5 * (self.learning_rate - self.final_lr) * - (1 + math.cos(math.pi * t / (self.train_iters_per_epoch * (self.max_epochs - self.warmup_epochs)))) - for t in iters - ]) - - self.lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule)) - self.queue = None self.softmax = nn.Softmax(dim=1) @@ -268,9 +253,6 @@ def shared_step(self, batch): def training_step(self, batch, batch_idx): loss = self.shared_step(batch) - # log LR (LearningRateLogger callback doesn't work with LARSWrapper) - self.log('learning_rate', self.lr_schedule[self.trainer.global_step], on_step=True, on_epoch=False) - self.log('train_loss', loss, on_step=True, on_epoch=False) return loss @@ -300,41 +282,30 @@ def configure_optimizers(self): else: params = self.parameters() - if self.optim == 'sgd': - optimizer = torch.optim.SGD(params, lr=self.learning_rate, momentum=0.9, weight_decay=self.weight_decay) + if self.optim == 'lars': + optimizer = LARS( + params, + lr=self.learning_rate, + momentum=0.9, + weight_decay=self.weight_decay, + trust_coefficient=0.001, + ) elif self.optim == 'adam': optimizer = torch.optim.Adam(params, lr=self.learning_rate, weight_decay=self.weight_decay) - if self.lars_wrapper: - optimizer = LARSWrapper( - optimizer, - eta=0.001, # trust coefficient - clip=False - ) + warmup_steps = self.train_iters_per_epoch * self.warmup_epochs + total_steps = self.train_iters_per_epoch * self.max_epochs - return optimizer + scheduler = { + "scheduler": torch.optim.lr_scheduler.LambdaLR( + optimizer, + linear_warmup_decay(warmup_steps, total_steps, cosine=True), + ), + "interval": "step", 
+ "frequency": 1, + } - def optimizer_step( - self, - epoch: int = None, - batch_idx: int = None, - optimizer: Optimizer = None, - optimizer_idx: int = None, - optimizer_closure: Optional[Callable] = None, - on_tpu: bool = None, - using_native_amp: bool = None, - using_lbfgs: bool = None, - ) -> None: - # warm-up + decay schedule placed here since LARSWrapper is not optimizer class - # adjust LR of optim contained within LARSWrapper - for param_group in optimizer.param_groups: - param_group["lr"] = self.lr_schedule[self.trainer.global_step] - - # from lightning - if not isinstance(optimizer, LightningOptimizer): - # wraps into LightingOptimizer only for running step - optimizer = LightningOptimizer.to_lightning_optimizer(optimizer, self.trainer) - optimizer.step(closure=optimizer_closure) + return [optimizer], [scheduler] def sinkhorn(self, Q, nmb_iters): with torch.no_grad(): @@ -433,8 +404,7 @@ def add_model_specific_args(parent_parser): parser.add_argument("--num_nodes", default=1, type=int, help="number of nodes for training") parser.add_argument("--gpus", default=1, type=int, help="number of gpus to train on") parser.add_argument("--num_workers", default=8, type=int, help="num of workers per GPU") - parser.add_argument("--optimizer", default="adam", type=str, help="choose between adam/sgd") - parser.add_argument("--lars_wrapper", action='store_true', help="apple lars wrapper over optimizer used") + parser.add_argument("--optimizer", default="adam", type=str, help="choose between adam/lars") parser.add_argument('--exclude_bn_bias', action='store_true', help="exclude bn/bias from weight decay") parser.add_argument("--max_epochs", default=100, type=int, help="number of total epochs to run") parser.add_argument("--max_steps", default=-1, type=int, help="max steps") @@ -536,8 +506,7 @@ def cli_main(): args.gpus = 8 # per-node args.max_epochs = 800 - args.optimizer = 'sgd' - args.lars_wrapper = True + args.optimizer = 'lars' args.learning_rate = 4.8 args.final_lr = 0.0048 args.start_lr = 0.3 @@ -586,8 +555,10 @@ def cli_main(): dataset=args.dataset, ) + lr_monitor = LearningRateMonitor(logging_interval="step") model_checkpoint = ModelCheckpoint(save_last=True, save_top_k=1, monitor='val_loss') callbacks = [model_checkpoint, online_evaluator] if args.online_ft else [model_checkpoint] + callbacks.append(lr_monitor) trainer = pl.Trainer( max_epochs=args.max_epochs, diff --git a/pl_bolts/optimizers/__init__.py b/pl_bolts/optimizers/__init__.py index 1d5ee21175..c6b2878ce2 100644 --- a/pl_bolts/optimizers/__init__.py +++ b/pl_bolts/optimizers/__init__.py @@ -1,7 +1,11 @@ -from pl_bolts.optimizers.lars_scheduling import LARSWrapper # noqa: F401 -from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR # noqa: F401 +from pl_bolts.optimizers.lars import LARS # noqa: F401 +from pl_bolts.optimizers.lr_scheduler import ( + LinearWarmupCosineAnnealingLR, # noqa: F401 + linear_warmup_decay, # noqa: F401 +) __all__ = [ - "LARSWrapper", + "LARS", "LinearWarmupCosineAnnealingLR", + "linear_warmup_decay", ] diff --git a/pl_bolts/optimizers/lars.py b/pl_bolts/optimizers/lars.py new file mode 100644 index 0000000000..d4a9729b3d --- /dev/null +++ b/pl_bolts/optimizers/lars.py @@ -0,0 +1,123 @@ +""" +References: + - https://arxiv.org/pdf/1708.03888.pdf + - https://github.com/pytorch/pytorch/blob/1.6/torch/optim/sgd.py +""" +import torch +from torch.optim.optimizer import Optimizer, required + + +class LARS(Optimizer): + r"""Extends SGD in PyTorch with LARS scaling from the paper + `Large batch 
training of Convolutional Networks `_. + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float): learning rate + momentum (float, optional): momentum factor (default: 0) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + dampening (float, optional): dampening for momentum (default: 0) + nesterov (bool, optional): enables Nesterov momentum (default: False) + trust_coefficient (float, optional): trust coefficient for computing LR (default: 0.001) + eps (float, optional): eps for division denominator (default: 1e-8) + Example: + >>> optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9) + >>> optimizer.zero_grad() + >>> loss_fn(model(input), target).backward() + >>> optimizer.step() + .. note:: + The application of momentum in the SGD part is modified according to + the PyTorch standards. LARS scaling fits into the equation in the + following fashion. + .. math:: + \begin{aligned} + g_{t+1} & = \text{lars_lr} * (\beta * p_{t} + g_{t+1}), \\ + v_{t+1} & = \mu * v_{t} + g_{t+1}, \\ + p_{t+1} & = p_{t} - \text{lr} * v_{t+1}, + \end{aligned} + where :math:`p`, :math:`g`, :math:`v`, :math:`\mu` and :math:`\beta` denote the + parameters, gradient, velocity, momentum, and weight decay respectively. + The :math:`lars_lr` is defined by Eq. 6 in the paper. + The Nesterov version is analogously modified. + + .. warning:: + Parameters with weight decay set to 0 will automatically be excluded from + layer-wise LR scaling. This is to ensure consistency with papers like SimCLR + and BYOL. + """ + + def __init__(self, params, lr=required, momentum=0, dampening=0, + weight_decay=0, nesterov=False, trust_coefficient=0.001, eps=1e-8): + if lr is not required and lr < 0.0: + raise ValueError("Invalid learning rate: {}".format(lr)) + if momentum < 0.0: + raise ValueError("Invalid momentum value: {}".format(momentum)) + if weight_decay < 0.0: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + + defaults = dict(lr=lr, momentum=momentum, dampening=dampening, + weight_decay=weight_decay, nesterov=nesterov, + trust_coefficient=trust_coefficient, eps=eps) + if nesterov and (momentum <= 0 or dampening != 0): + raise ValueError("Nesterov momentum requires a momentum and zero dampening") + + super(LARS, self).__init__(params, defaults) + + def __setstate__(self, state): + super(LARS, self).__setstate__(state) + + for group in self.param_groups: + group.setdefault('nesterov', False) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + # exclude scaling for params with 0 weight decay + for group in self.param_groups: + weight_decay = group['weight_decay'] + momentum = group['momentum'] + dampening = group['dampening'] + nesterov = group['nesterov'] + + for p in group['params']: + if p.grad is None: + continue + + d_p = p.grad + p_norm = torch.norm(p.data) + g_norm = torch.norm(p.grad.data) + + # lars scaling + weight decay part + if weight_decay != 0: + if p_norm != 0 and g_norm != 0: + lars_lr = p_norm / (g_norm + p_norm * weight_decay + group['eps']) + lars_lr *= group['trust_coefficient'] + + d_p = d_p.add(p, alpha=weight_decay) + d_p *= lars_lr + + # sgd part + if momentum != 0: + param_state = self.state[p] + if 'momentum_buffer' not in param_state: + buf = param_state['momentum_buffer'] = torch.clone(d_p).detach() + else: + buf = param_state['momentum_buffer'] + buf.mul_(momentum).add_(d_p, alpha=1 - dampening) + if nesterov: + d_p = d_p.add(buf, alpha=momentum) + else: + d_p = buf + + p.add_(d_p, alpha=-group['lr']) + + return loss diff --git a/pl_bolts/optimizers/lars_scheduling.py b/pl_bolts/optimizers/lars_scheduling.py deleted file mode 100644 index a54cc67ec5..0000000000 --- a/pl_bolts/optimizers/lars_scheduling.py +++ /dev/null @@ -1,98 +0,0 @@ -""" -References: - - https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py - - https://arxiv.org/pdf/1708.03888.pdf - - https://github.com/noahgolmant/pytorch-lars/blob/master/lars.py -""" -import torch -from torch.optim import Optimizer - - -class LARSWrapper(object): - """ - Wrapper that adds LARS scheduling to any optimizer. This helps stability with huge batch sizes. - """ - - def __init__(self, optimizer, eta=0.02, clip=True, eps=1e-8): - """ - Args: - optimizer: torch optimizer - eta: LARS coefficient (trust) - clip: True to clip LR - eps: adaptive_lr stability coefficient - """ - self.optim = optimizer - self.eta = eta - self.eps = eps - self.clip = clip - - # transfer optim methods - self.state_dict = self.optim.state_dict - self.load_state_dict = self.optim.load_state_dict - self.zero_grad = self.optim.zero_grad - self.add_param_group = self.optim.add_param_group - self.__setstate__ = self.optim.__setstate__ - self.__getstate__ = self.optim.__getstate__ - self.__repr__ = self.optim.__repr__ - - @property - def defaults(self): - return self.optim.defaults - - @defaults.setter - def defaults(self, defaults): - self.optim.defaults = defaults - - @property - def __class__(self): - return Optimizer - - @property - def state(self): - return self.optim.state - - @property - def param_groups(self): - return self.optim.param_groups - - @param_groups.setter - def param_groups(self, value): - self.optim.param_groups = value - - @torch.no_grad() - def step(self, closure=None): - weight_decays = [] - - for group in self.optim.param_groups: - weight_decay = group.get('weight_decay', 0) - weight_decays.append(weight_decay) - - # reset weight decay - group['weight_decay'] = 0 - - # update the parameters - [self.update_p(p, group, weight_decay) for p in group['params'] if p.grad is not None] - - # update the optimizer - self.optim.step(closure=closure) - - # return weight decay control to optimizer - for group_idx, group in enumerate(self.optim.param_groups): - group['weight_decay'] = weight_decays[group_idx] - - def update_p(self, p, group, weight_decay): - # calculate new norms - p_norm = torch.norm(p.data) - g_norm = torch.norm(p.grad.data) - - if p_norm != 0 and 
g_norm != 0: - # calculate new lr - new_lr = (self.eta * p_norm) / (g_norm + p_norm * weight_decay + self.eps) - - # clip lr - if self.clip: - new_lr = min(new_lr / group['lr'], 1) - - # update params with clipped lr - p.grad.data += weight_decay * p.data - p.grad.data *= new_lr diff --git a/pl_bolts/optimizers/lr_scheduler.py b/pl_bolts/optimizers/lr_scheduler.py index d12e15a0ee..4221e1b050 100644 --- a/pl_bolts/optimizers/lr_scheduler.py +++ b/pl_bolts/optimizers/lr_scheduler.py @@ -118,3 +118,32 @@ def _get_closed_form_lr(self) -> List[float]: (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs))) for base_lr in self.base_lrs ] + + +# warmup + decay as a function +def linear_warmup_decay(warmup_steps, total_steps, cosine=True, linear=False): + """ + Linear warmup for warmup_steps, optionally with cosine annealing or + linear decay to 0 at total_steps + """ + assert not (linear and cosine) + + def fn(step): + if step < warmup_steps: + return float(step) / float(max(1, warmup_steps)) + + if not (cosine or linear): + # no decay + return 1.0 + + progress = float(step - warmup_steps) / float( + max(1, total_steps - warmup_steps) + ) + if cosine: + # cosine decay + return 0.5 * (1.0 + math.cos(math.pi * progress)) + + # linear decay + return 1.0 - progress + + return fn From 294c1a7dd7175d067431154d9abac74196ef90aa Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Tue, 6 Apr 2021 21:27:33 -0400 Subject: [PATCH 02/13] updates simclr loss --- pl_bolts/models/self_supervised/simclr/simclr_module.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pl_bolts/models/self_supervised/simclr/simclr_module.py b/pl_bolts/models/self_supervised/simclr/simclr_module.py index 0feb3842ed..8f5a35cc71 100644 --- a/pl_bolts/models/self_supervised/simclr/simclr_module.py +++ b/pl_bolts/models/self_supervised/simclr/simclr_module.py @@ -252,8 +252,8 @@ def nt_xent_loss(self, out_1, out_2, temperature, eps=1e-6): sim = torch.exp(cov / temperature) neg = sim.sum(dim=-1) - # from each row, subtract e^1 to remove similarity measure for x1.x1 - row_sub = torch.Tensor(neg.shape).fill_(math.e).to(neg.device) + # from each row, subtract e^(1/temp) to remove similarity measure for x1.x1 + row_sub = torch.Tensor(neg.shape).fill_(math.e ** (1 / temperature)).to(neg.device) neg = torch.clamp(neg - row_sub, min=eps) # clamp for numerical stability # Positive similarity, pos becomes [2 * batch_size] From 03c4f04d7af5e1ee2c52df93fd3ad28129291849 Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Tue, 6 Apr 2021 21:32:18 -0400 Subject: [PATCH 03/13] fix imports --- pl_bolts/models/self_supervised/simclr/simclr_module.py | 2 +- pl_bolts/models/self_supervised/simsiam/simsiam_module.py | 2 ++ pl_bolts/models/self_supervised/swav/swav_module.py | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/pl_bolts/models/self_supervised/simclr/simclr_module.py b/pl_bolts/models/self_supervised/simclr/simclr_module.py index 8f5a35cc71..63fd66a0aa 100644 --- a/pl_bolts/models/self_supervised/simclr/simclr_module.py +++ b/pl_bolts/models/self_supervised/simclr/simclr_module.py @@ -5,8 +5,8 @@ import numpy as np import pytorch_lightning as pl import torch +from pytorch_lightning.callbacks import LearningRateMonitor from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.core.optimizer import LightningOptimizer from torch import nn from torch.nn import functional as F from torch.optim.optimizer import Optimizer diff --git 
a/pl_bolts/models/self_supervised/simsiam/simsiam_module.py b/pl_bolts/models/self_supervised/simsiam/simsiam_module.py index f1fd4d701a..15471e4795 100644 --- a/pl_bolts/models/self_supervised/simsiam/simsiam_module.py +++ b/pl_bolts/models/self_supervised/simsiam/simsiam_module.py @@ -9,6 +9,8 @@ from pytorch_lightning.utilities import AMPType from torch.nn import functional as F from torch.optim.optimizer import Optimizer +from pytorch_lightning.callbacks import LearningRateMonitor +from pytorch_lightning.callbacks import ModelCheckpoint from pl_bolts.models.self_supervised.resnets import resnet18, resnet50 from pl_bolts.models.self_supervised.simsiam.models import SiameseArm diff --git a/pl_bolts/models/self_supervised/swav/swav_module.py b/pl_bolts/models/self_supervised/swav/swav_module.py index 3f50a1ab3a..cb64aadb49 100644 --- a/pl_bolts/models/self_supervised/swav/swav_module.py +++ b/pl_bolts/models/self_supervised/swav/swav_module.py @@ -9,8 +9,8 @@ import numpy as np import pytorch_lightning as pl import torch +from pytorch_lightning.callbacks import LearningRateMonitor from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.core.optimizer import LightningOptimizer from torch import distributed as dist from torch import nn from torch.optim.optimizer import Optimizer From 5d8bd9084262ca69ac5c1e41c8d0582de932da93 Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Tue, 6 Apr 2021 21:35:01 -0400 Subject: [PATCH 04/13] fix imports --- pl_bolts/models/self_supervised/simclr/simclr_module.py | 3 --- pl_bolts/models/self_supervised/simsiam/simsiam_module.py | 5 ----- pl_bolts/models/self_supervised/swav/swav_module.py | 3 --- 3 files changed, 11 deletions(-) diff --git a/pl_bolts/models/self_supervised/simclr/simclr_module.py b/pl_bolts/models/self_supervised/simclr/simclr_module.py index 63fd66a0aa..d31cbcf5d6 100644 --- a/pl_bolts/models/self_supervised/simclr/simclr_module.py +++ b/pl_bolts/models/self_supervised/simclr/simclr_module.py @@ -1,15 +1,12 @@ import math from argparse import ArgumentParser -from typing import Callable, Optional -import numpy as np import pytorch_lightning as pl import torch from pytorch_lightning.callbacks import LearningRateMonitor from pytorch_lightning.callbacks import ModelCheckpoint from torch import nn from torch.nn import functional as F -from torch.optim.optimizer import Optimizer from pl_bolts.models.self_supervised.resnets import resnet18, resnet50 from pl_bolts.optimizers.lars import LARS diff --git a/pl_bolts/models/self_supervised/simsiam/simsiam_module.py b/pl_bolts/models/self_supervised/simsiam/simsiam_module.py index 15471e4795..2ea49b377a 100644 --- a/pl_bolts/models/self_supervised/simsiam/simsiam_module.py +++ b/pl_bolts/models/self_supervised/simsiam/simsiam_module.py @@ -1,14 +1,9 @@ -import math from argparse import ArgumentParser -from typing import Callable, Optional -import numpy as np import pytorch_lightning as pl import torch from pytorch_lightning import seed_everything -from pytorch_lightning.utilities import AMPType from torch.nn import functional as F -from torch.optim.optimizer import Optimizer from pytorch_lightning.callbacks import LearningRateMonitor from pytorch_lightning.callbacks import ModelCheckpoint diff --git a/pl_bolts/models/self_supervised/swav/swav_module.py b/pl_bolts/models/self_supervised/swav/swav_module.py index cb64aadb49..b4fa005503 100644 --- a/pl_bolts/models/self_supervised/swav/swav_module.py +++ b/pl_bolts/models/self_supervised/swav/swav_module.py @@ -1,10 +1,8 @@ """ 
Adapted from official swav implementation: https://github.com/facebookresearch/swav """ -import math import os from argparse import ArgumentParser -from typing import Callable, Optional import numpy as np import pytorch_lightning as pl @@ -13,7 +11,6 @@ from pytorch_lightning.callbacks import ModelCheckpoint from torch import distributed as dist from torch import nn -from torch.optim.optimizer import Optimizer from pl_bolts.models.self_supervised.swav.swav_resnet import resnet18, resnet50 from pl_bolts.optimizers.lars import LARS From ec9b94a7c3b2e5357325be37f7a2d1568bbc2bba Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Tue, 6 Apr 2021 21:48:25 -0400 Subject: [PATCH 05/13] update docs ex --- pl_bolts/optimizers/lars.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pl_bolts/optimizers/lars.py b/pl_bolts/optimizers/lars.py index d4a9729b3d..45921eb934 100644 --- a/pl_bolts/optimizers/lars.py +++ b/pl_bolts/optimizers/lars.py @@ -8,7 +8,7 @@ class LARS(Optimizer): - r"""Extends SGD in PyTorch with LARS scaling from the paper + """Extends SGD in PyTorch with LARS scaling from the paper `Large batch training of Convolutional Networks `_. Args: params (iterable): iterable of parameters to optimize or dicts defining @@ -21,6 +21,11 @@ class LARS(Optimizer): trust_coefficient (float, optional): trust coefficient for computing LR (default: 0.001) eps (float, optional): eps for division denominator (default: 1e-8) Example: + >>> model = torch.nn.Linear(10, 1) + >>> input = torch.Tensor(10) + >>> target = torch.Tensor([1.]) + >>> loss_fn = lambda input, target: (input - target) ** 2 + >>> # >>> optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9) >>> optimizer.zero_grad() >>> loss_fn(model(input), target).backward() From cd50eebe251d9483f04faa440a550c46f8e0a63a Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Tue, 6 Apr 2021 21:57:22 -0400 Subject: [PATCH 06/13] update docs ex --- pl_bolts/optimizers/lars.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pl_bolts/optimizers/lars.py b/pl_bolts/optimizers/lars.py index 45921eb934..9433fd952d 100644 --- a/pl_bolts/optimizers/lars.py +++ b/pl_bolts/optimizers/lars.py @@ -20,6 +20,7 @@ class LARS(Optimizer): nesterov (bool, optional): enables Nesterov momentum (default: False) trust_coefficient (float, optional): trust coefficient for computing LR (default: 0.001) eps (float, optional): eps for division denominator (default: 1e-8) + Example: >>> model = torch.nn.Linear(10, 1) >>> input = torch.Tensor(10) @@ -30,6 +31,7 @@ class LARS(Optimizer): >>> optimizer.zero_grad() >>> loss_fn(model(input), target).backward() >>> optimizer.step() + .. note:: The application of momentum in the SGD part is modified according to the PyTorch standards. LARS scaling fits into the equation in the From 72d92f4e0507bac9e2b460ab34c6ea2ff2a10f7f Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Wed, 7 Apr 2021 04:39:59 -0400 Subject: [PATCH 07/13] update docs ex --- pl_bolts/optimizers/lars.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pl_bolts/optimizers/lars.py b/pl_bolts/optimizers/lars.py index 9433fd952d..3b6c61f048 100644 --- a/pl_bolts/optimizers/lars.py +++ b/pl_bolts/optimizers/lars.py @@ -36,12 +36,14 @@ class LARS(Optimizer): The application of momentum in the SGD part is modified according to the PyTorch standards. LARS scaling fits into the equation in the following fashion. + .. 
math:: \begin{aligned} g_{t+1} & = \text{lars_lr} * (\beta * p_{t} + g_{t+1}), \\ v_{t+1} & = \mu * v_{t} + g_{t+1}, \\ p_{t+1} & = p_{t} - \text{lr} * v_{t+1}, \end{aligned} + where :math:`p`, :math:`g`, :math:`v`, :math:`\mu` and :math:`\beta` denote the parameters, gradient, velocity, momentum, and weight decay respectively. The :math:`lars_lr` is defined by Eq. 6 in the paper. From 3b5a47ba5b552090b122d54b7a403ed797ac2421 Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Wed, 7 Apr 2021 18:45:50 -0400 Subject: [PATCH 08/13] version based skipif for igpt --- tests/models/test_vision.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/models/test_vision.py b/tests/models/test_vision.py index e0c8995349..543f878693 100644 --- a/tests/models/test_vision.py +++ b/tests/models/test_vision.py @@ -1,6 +1,7 @@ import pytorch_lightning as pl import torch from torch.utils.data import DataLoader +from packaging import version from pl_bolts.datamodules import FashionMNISTDataModule, MNISTDataModule from pl_bolts.datasets import DummyDataset @@ -14,6 +15,9 @@ def train_dataloader(self): return DataLoader(train_ds, batch_size=1) +@pytest.mark.skipif( + version.parse(pl.__version__) > version.parse("1.1.0"), + reason="igpt code not updated for latest lightning") def test_igpt(tmpdir, datadir): pl.seed_everything(0) dm = MNISTDataModule(data_dir=datadir, normalize=False) From 4344aa9747197fcfacdc9170f06688770698865b Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Thu, 8 Apr 2021 02:17:06 -0400 Subject: [PATCH 09/13] version based skipif for igpt --- tests/models/test_vision.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/test_vision.py b/tests/models/test_vision.py index 543f878693..25ac43a70c 100644 --- a/tests/models/test_vision.py +++ b/tests/models/test_vision.py @@ -1,3 +1,4 @@ +import pytest import pytorch_lightning as pl import torch from torch.utils.data import DataLoader From 68d4b5af6084d7a9530a7263d6f2c05e0561386f Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Mon, 12 Apr 2021 14:09:28 -0400 Subject: [PATCH 10/13] fixed val loss in simsiam --- pl_bolts/models/self_supervised/simsiam/simsiam_module.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pl_bolts/models/self_supervised/simsiam/simsiam_module.py b/pl_bolts/models/self_supervised/simsiam/simsiam_module.py index 2ea49b377a..61bcc3fa13 100644 --- a/pl_bolts/models/self_supervised/simsiam/simsiam_module.py +++ b/pl_bolts/models/self_supervised/simsiam/simsiam_module.py @@ -164,7 +164,7 @@ def training_step(self, batch, batch_idx): loss = self.cosine_similarity(h1, z2) / 2 + self.cosine_similarity(h2, z1) / 2 # log results - self.log_dict({"loss": loss}) + self.log_dict({"train_loss": loss}) return loss @@ -177,7 +177,7 @@ def validation_step(self, batch, batch_idx): loss = self.cosine_similarity(h1, z2) / 2 + self.cosine_similarity(h2, z1) / 2 # log results - self.log_dict({"loss": loss}) + self.log_dict({"val_loss": loss}) return loss From 82070740281f8eb95e3b0ff91eb5bfa169c9ab37 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 12 Apr 2021 20:16:54 +0200 Subject: [PATCH 11/13] update --- .github/workflows/ci_test-full.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/ci_test-full.yml b/.github/workflows/ci_test-full.yml index e411e4e3bc..a250f24caa 100644 --- a/.github/workflows/ci_test-full.yml +++ b/.github/workflows/ci_test-full.yml @@ -35,6 +35,7 @@ jobs: - name: Setup macOS if: runner.os == 'macOS' run: | + brew update # Todo: find a better way... 
brew install libomp # https://github.com/pytorch/pytorch/issues/20030 - name: Set min. dependencies From a9f54db28d42ed0de61cfa91bda23aebe064aeb5 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 12 Apr 2021 20:17:39 +0200 Subject: [PATCH 12/13] update --- .github/workflows/ci_test-base.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/ci_test-base.yml b/.github/workflows/ci_test-base.yml index 53e8cec1e7..5749fd7f54 100644 --- a/.github/workflows/ci_test-base.yml +++ b/.github/workflows/ci_test-base.yml @@ -30,6 +30,7 @@ jobs: - name: Setup macOS if: runner.os == 'macOS' run: | + brew update # Todo: find a better way... brew install libomp # https://github.com/pytorch/pytorch/issues/20030 # Note: This uses an internal pip API and may not always work From 06373d8c88e18dbe5a2f0701d6ce4b04053dcd98 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 12 Apr 2021 20:21:54 +0200 Subject: [PATCH 13/13] formatting --- .../self_supervised/simclr/simclr_module.py | 5 ++-- .../self_supervised/simsiam/simsiam_module.py | 3 +-- .../self_supervised/swav/swav_module.py | 3 +-- pl_bolts/optimizers/__init__.py | 6 ++--- pl_bolts/optimizers/lars.py | 25 +++++++++++++++---- pl_bolts/optimizers/lr_scheduler.py | 4 +-- tests/models/test_vision.py | 6 ++--- 7 files changed, 30 insertions(+), 22 deletions(-) diff --git a/pl_bolts/models/self_supervised/simclr/simclr_module.py b/pl_bolts/models/self_supervised/simclr/simclr_module.py index d31cbcf5d6..92027618a1 100644 --- a/pl_bolts/models/self_supervised/simclr/simclr_module.py +++ b/pl_bolts/models/self_supervised/simclr/simclr_module.py @@ -3,8 +3,7 @@ import pytorch_lightning as pl import torch -from pytorch_lightning.callbacks import LearningRateMonitor -from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint from torch import nn from torch.nn import functional as F @@ -250,7 +249,7 @@ def nt_xent_loss(self, out_1, out_2, temperature, eps=1e-6): neg = sim.sum(dim=-1) # from each row, subtract e^(1/temp) to remove similarity measure for x1.x1 - row_sub = torch.Tensor(neg.shape).fill_(math.e ** (1 / temperature)).to(neg.device) + row_sub = torch.Tensor(neg.shape).fill_(math.e**(1 / temperature)).to(neg.device) neg = torch.clamp(neg - row_sub, min=eps) # clamp for numerical stability # Positive similarity, pos becomes [2 * batch_size] diff --git a/pl_bolts/models/self_supervised/simsiam/simsiam_module.py b/pl_bolts/models/self_supervised/simsiam/simsiam_module.py index 61bcc3fa13..8167b0f1d3 100644 --- a/pl_bolts/models/self_supervised/simsiam/simsiam_module.py +++ b/pl_bolts/models/self_supervised/simsiam/simsiam_module.py @@ -3,9 +3,8 @@ import pytorch_lightning as pl import torch from pytorch_lightning import seed_everything +from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint from torch.nn import functional as F -from pytorch_lightning.callbacks import LearningRateMonitor -from pytorch_lightning.callbacks import ModelCheckpoint from pl_bolts.models.self_supervised.resnets import resnet18, resnet50 from pl_bolts.models.self_supervised.simsiam.models import SiameseArm diff --git a/pl_bolts/models/self_supervised/swav/swav_module.py b/pl_bolts/models/self_supervised/swav/swav_module.py index b4fa005503..475f984a43 100644 --- a/pl_bolts/models/self_supervised/swav/swav_module.py +++ b/pl_bolts/models/self_supervised/swav/swav_module.py @@ -7,8 +7,7 @@ import numpy as np import pytorch_lightning as pl import torch -from 
pytorch_lightning.callbacks import LearningRateMonitor -from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint from torch import distributed as dist from torch import nn diff --git a/pl_bolts/optimizers/__init__.py b/pl_bolts/optimizers/__init__.py index c6b2878ce2..d5bcaad353 100644 --- a/pl_bolts/optimizers/__init__.py +++ b/pl_bolts/optimizers/__init__.py @@ -1,8 +1,6 @@ from pl_bolts.optimizers.lars import LARS # noqa: F401 -from pl_bolts.optimizers.lr_scheduler import ( - LinearWarmupCosineAnnealingLR, # noqa: F401 - linear_warmup_decay, # noqa: F401 -) +from pl_bolts.optimizers.lr_scheduler import linear_warmup_decay # noqa: F401 +from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR # noqa: F401 __all__ = [ "LARS", diff --git a/pl_bolts/optimizers/lars.py b/pl_bolts/optimizers/lars.py index 3b6c61f048..6de26f733b 100644 --- a/pl_bolts/optimizers/lars.py +++ b/pl_bolts/optimizers/lars.py @@ -55,8 +55,17 @@ class LARS(Optimizer): and BYOL. """ - def __init__(self, params, lr=required, momentum=0, dampening=0, - weight_decay=0, nesterov=False, trust_coefficient=0.001, eps=1e-8): + def __init__( + self, + params, + lr=required, + momentum=0, + dampening=0, + weight_decay=0, + nesterov=False, + trust_coefficient=0.001, + eps=1e-8 + ): if lr is not required and lr < 0.0: raise ValueError("Invalid learning rate: {}".format(lr)) if momentum < 0.0: @@ -64,9 +73,15 @@ def __init__(self, params, lr=required, momentum=0, dampening=0, if weight_decay < 0.0: raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) - defaults = dict(lr=lr, momentum=momentum, dampening=dampening, - weight_decay=weight_decay, nesterov=nesterov, - trust_coefficient=trust_coefficient, eps=eps) + defaults = dict( + lr=lr, + momentum=momentum, + dampening=dampening, + weight_decay=weight_decay, + nesterov=nesterov, + trust_coefficient=trust_coefficient, + eps=eps + ) if nesterov and (momentum <= 0 or dampening != 0): raise ValueError("Nesterov momentum requires a momentum and zero dampening") diff --git a/pl_bolts/optimizers/lr_scheduler.py b/pl_bolts/optimizers/lr_scheduler.py index 4221e1b050..120e4f1ccb 100644 --- a/pl_bolts/optimizers/lr_scheduler.py +++ b/pl_bolts/optimizers/lr_scheduler.py @@ -136,9 +136,7 @@ def fn(step): # no decay return 1.0 - progress = float(step - warmup_steps) / float( - max(1, total_steps - warmup_steps) - ) + progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps)) if cosine: # cosine decay return 0.5 * (1.0 + math.cos(math.pi * progress)) diff --git a/tests/models/test_vision.py b/tests/models/test_vision.py index 25ac43a70c..4b41e8620c 100644 --- a/tests/models/test_vision.py +++ b/tests/models/test_vision.py @@ -1,8 +1,8 @@ import pytest import pytorch_lightning as pl import torch -from torch.utils.data import DataLoader from packaging import version +from torch.utils.data import DataLoader from pl_bolts.datamodules import FashionMNISTDataModule, MNISTDataModule from pl_bolts.datasets import DummyDataset @@ -17,8 +17,8 @@ def train_dataloader(self): @pytest.mark.skipif( - version.parse(pl.__version__) > version.parse("1.1.0"), - reason="igpt code not updated for latest lightning") + version.parse(pl.__version__) > version.parse("1.1.0"), reason="igpt code not updated for latest lightning" +) def test_igpt(tmpdir, datadir): pl.seed_everything(0) dm = MNISTDataModule(data_dir=datadir, normalize=False)
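
The new pieces introduced by this series are intended to be used together: the LARS optimizer replaces the deleted LARSWrapper, and linear_warmup_decay drives a per-step LambdaLR schedule in place of the hand-rolled numpy lr_schedule and custom optimizer_step hooks. Below is a minimal standalone sketch of that wiring (not part of the patch); the linear model, dummy loss, and step counts are illustrative placeholders, while the LARS and linear_warmup_decay signatures follow the code added above.

    import torch

    from pl_bolts.optimizers import LARS, linear_warmup_decay

    model = torch.nn.Linear(10, 1)  # stand-in for a real backbone + projection head

    # LARS as added in pl_bolts/optimizers/lars.py
    optimizer = LARS(
        model.parameters(),
        lr=0.1,
        momentum=0.9,
        weight_decay=1e-6,
        trust_coefficient=0.001,
    )

    # placeholder step counts; the modules compute these as
    # train_iters_per_epoch * warmup_epochs / max_epochs
    warmup_steps = 100
    total_steps = 1000

    # linear warmup followed by cosine decay, stepped once per training batch
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        linear_warmup_decay(warmup_steps, total_steps, cosine=True),
    )

    for step in range(total_steps):
        optimizer.zero_grad()
        loss = ((model(torch.randn(10)) - 1.0) ** 2).sum()  # dummy objective
        loss.backward()
        optimizer.step()   # LARS rescales each parameter's update layer-wise
        scheduler.step()   # advance the warmup + cosine schedule per step

In the LightningModules this corresponds to returning [optimizer], [scheduler] from configure_optimizers with the scheduler dict set to interval "step", which is why the patches also add a LearningRateMonitor callback instead of logging the learning rate manually.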