Updates all scripts to LARS (#613)
* updates optimizer

* updates simclr loss

* fix imports

* fix imports

* update docs ex

* update docs ex

* update docs ex

* version based skipif for igpt

* version based skipif for igpt

* fixed val loss in simsiam

* update

* update

* formatting

Co-authored-by: Jirka Borovec <[email protected]>
ananyahjha93 and Borda authored Apr 12, 2021
1 parent 2ba3f62 commit 51fe7d9
Showing 12 changed files with 275 additions and 295 deletions.
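At a high level, this commit drops the `LARSWrapper` that used to wrap a base SGD/Adam optimizer and switches the scripts to the standalone `LARS` optimizer from `pl_bolts.optimizers.lars`. A minimal before/after sketch; the constructor arguments mirror the `simclr_module.py` diff below, while the model and the concrete `lr`/`weight_decay` values are placeholders, not values from the repo:

```python
import torch

from pl_bolts.optimizers.lars import LARS  # new import used by this commit
# from pl_bolts.optimizers.lars_scheduling import LARSWrapper  # removed by this commit

model = torch.nn.Linear(128, 10)  # placeholder model

# Old pattern: wrap a base optimizer in LARSWrapper.
# base = torch.optim.SGD(model.parameters(), lr=1.5, momentum=0.9, weight_decay=1e-6)
# optimizer = LARSWrapper(base, eta=0.001, clip=False)

# New pattern: LARS is itself a torch optimizer.
optimizer = LARS(
    model.parameters(),
    lr=1.5,
    momentum=0.9,
    weight_decay=1e-6,
    trust_coefficient=0.001,
)
```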
1 change: 1 addition & 0 deletions .github/workflows/ci_test-base.yml
@@ -30,6 +30,7 @@ jobs:
- name: Setup macOS
if: runner.os == 'macOS'
run: |
+ brew update # Todo: find a better way...
brew install libomp # https://github.com/pytorch/pytorch/issues/20030
# Note: This uses an internal pip API and may not always work
1 change: 1 addition & 0 deletions .github/workflows/ci_test-full.yml
@@ -35,6 +35,7 @@ jobs:
- name: Setup macOS
if: runner.os == 'macOS'
run: |
+ brew update # Todo: find a better way...
brew install libomp # https://github.com/pytorch/pytorch/issues/20030
- name: Set min. dependencies
10 changes: 4 additions & 6 deletions docs/source/self_supervised_models.rst
@@ -312,7 +312,7 @@ CIFAR-10 baseline
* - Ours
- 88.50
- `resnet50 <https://github.com/PyTorchLightning/lightning-bolts/blob/master/pl_bolts/models/self_supervised/resnets.py#L301-L309>`_
- - `LARS-SGD <https://lightning-bolts.readthedocs.io/en/latest/api/pl_bolts.optimizers.lars_scheduling.html#pl_bolts.optimizers.lars_scheduling.LARSWrapper>`_
+ - LARS
- 2048
- 800 (4 hours)
- 8 V100 (16GB)
@@ -361,7 +361,6 @@ To reproduce::
--num_workers 16
--optimizer sgd
--learning_rate 1.5
- --lars_wrapper
--exclude_bn_bias
--max_epochs 800
--online_ft
@@ -401,7 +400,7 @@ Imagenet baseline for SimCLR
* - Ours
- 68.4
- `resnet50 <https://github.com/PyTorchLightning/lightning-bolts/blob/master/pl_bolts/models/self_supervised/resnets.py#L301-L309>`_
- - `LARS-SGD <https://lightning-bolts.readthedocs.io/en/latest/api/pl_bolts.optimizers.lars_scheduling.html#pl_bolts.optimizers.lars_scheduling.LARSWrapper>`_
+ - LARS
- 4096
- 800
- 64 V100 (16GB)
@@ -533,7 +532,7 @@ The original paper does not provide baselines on STL10.
* - Ours
- `86.72 <https://tensorboard.dev/experiment/w2pq3bPPSxC4VIm5udhA2g/>`_
- SwAV resnet50
- - `LARS <https://lightning-bolts.readthedocs.io/en/latest/api/pl_bolts.optimizers.lars_scheduling.html#pl_bolts.optimizers.lars_scheduling.LARSWrapper>`_
+ - LARS
- 128
- No
- 100 (~9 hr)
@@ -585,7 +584,6 @@ To reproduce::
python swav_module.py
--online_ft
--gpus 1
- --lars_wrapper
--batch_size 128
--learning_rate 1e-3
--gaussian_blur
@@ -630,7 +628,7 @@ Imagenet baseline for SwAV
* - Ours
- 74
- `resnet50 <https://github.com/PyTorchLightning/lightning-bolts/blob/master/pl_bolts/models/self_supervised/resnets.py#L301-L309>`_
- - `LARS-SGD <https://lightning-bolts.readthedocs.io/en/latest/api/pl_bolts.optimizers.lars_scheduling.html#pl_bolts.optimizers.lars_scheduling.LARSWrapper>`_
+ - LARS
- 4096
- 800
- 64 V100 (16GB)
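For readers of the baseline tables above, which now simply say "LARS": the name refers to layer-wise adaptive rate scaling. The sketch below shows the usual LARS trust-ratio formula as commonly described in the literature; it is illustrative only, is not copied from the `pl_bolts.optimizers.lars.LARS` implementation, and the function name `lars_trust_ratio` is invented for this example.

```python
import torch


def lars_trust_ratio(param: torch.Tensor, grad: torch.Tensor,
                     weight_decay: float, trust_coefficient: float,
                     eps: float = 1e-8) -> float:
    """Per-layer factor that LARS multiplies into the global learning rate (sketch)."""
    w_norm = param.detach().norm()
    g_norm = grad.detach().norm()
    if w_norm == 0 or g_norm == 0:
        return 1.0
    # local_lr = trust_coefficient * ||w|| / (||grad|| + weight_decay * ||w|| + eps)
    return float(trust_coefficient * w_norm / (g_norm + weight_decay * w_norm + eps))
```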
2 changes: 0 additions & 2 deletions pl_bolts/models/self_supervised/byol/byol_module.py
@@ -10,7 +10,6 @@

from pl_bolts.callbacks.byol_updates import BYOLMAWeightUpdate
from pl_bolts.models.self_supervised.byol.models import SiameseArm
- from pl_bolts.optimizers.lars_scheduling import LARSWrapper
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR


@@ -141,7 +140,6 @@ def validation_step(self, batch, batch_idx):

def configure_optimizers(self):
optimizer = Adam(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay)
- optimizer = LARSWrapper(optimizer)
scheduler = LinearWarmupCosineAnnealingLR(
optimizer, warmup_epochs=self.hparams.warmup_epochs, max_epochs=self.hparams.max_epochs
)
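With the two deletions above, BYOL's `configure_optimizers` is plain Adam plus the warmup-cosine scheduler, with no LARS wrapping. The sketch below reconstructs it from the lines visible in this hunk only; the trailing return statement falls outside the hunk and is assumed here.

```python
from torch.optim import Adam

from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR


# Method of the BYOL LightningModule, shown standalone for readability.
def configure_optimizers(self):
    optimizer = Adam(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay)
    scheduler = LinearWarmupCosineAnnealingLR(
        optimizer, warmup_epochs=self.hparams.warmup_epochs, max_epochs=self.hparams.max_epochs
    )
    return [optimizer], [scheduler]  # assumed; the return line is outside the visible hunk
```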
88 changes: 28 additions & 60 deletions pl_bolts/models/self_supervised/simclr/simclr_module.py
@@ -1,18 +1,15 @@
import math
from argparse import ArgumentParser
- from typing import Callable, Optional

- import numpy as np
import pytorch_lightning as pl
import torch
- from pytorch_lightning.callbacks import ModelCheckpoint
- from pytorch_lightning.core.optimizer import LightningOptimizer
+ from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from torch import nn
from torch.nn import functional as F
- from torch.optim.optimizer import Optimizer

from pl_bolts.models.self_supervised.resnets import resnet18, resnet50
- from pl_bolts.optimizers.lars_scheduling import LARSWrapper
+ from pl_bolts.optimizers.lars import LARS
+ from pl_bolts.optimizers.lr_scheduler import linear_warmup_decay
from pl_bolts.transforms.dataset_normalizations import (
cifar10_normalization,
imagenet_normalization,
@@ -79,7 +76,6 @@ def __init__(
first_conv: bool = True,
maxpool1: bool = True,
optimizer: str = 'adam',
- lars_wrapper: bool = True,
exclude_bn_bias: bool = False,
start_lr: float = 0.,
learning_rate: float = 1e-3,
@@ -112,7 +108,6 @@ def __init__(
self.maxpool1 = maxpool1

self.optim = optimizer
- self.lars_wrapper = lars_wrapper
self.exclude_bn_bias = exclude_bn_bias
self.weight_decay = weight_decay
self.temperature = temperature
@@ -131,19 +126,6 @@ def __init__(
global_batch_size = self.num_nodes * self.gpus * self.batch_size if self.gpus > 0 else self.batch_size
self.train_iters_per_epoch = self.num_samples // global_batch_size

- # define LR schedule
- warmup_lr_schedule = np.linspace(
- self.start_lr, self.learning_rate, self.train_iters_per_epoch * self.warmup_epochs
- )
- iters = np.arange(self.train_iters_per_epoch * (self.max_epochs - self.warmup_epochs))
- cosine_lr_schedule = np.array([
- self.final_lr + 0.5 * (self.learning_rate - self.final_lr) *
- (1 + math.cos(math.pi * t / (self.train_iters_per_epoch * (self.max_epochs - self.warmup_epochs))))
- for t in iters
- ])
-
- self.lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule))

def init_model(self):
if self.arch == 'resnet18':
backbone = resnet18
@@ -179,9 +161,6 @@ def shared_step(self, batch):
def training_step(self, batch, batch_idx):
loss = self.shared_step(batch)

- # log LR (LearningRateLogger callback doesn't work with LARSWrapper)
- self.log('learning_rate', self.lr_schedule[self.trainer.global_step], on_step=True, on_epoch=False)
-
self.log('train_loss', loss, on_step=True, on_epoch=False)
return loss

@@ -217,41 +196,30 @@ def configure_optimizers(self):
else:
params = self.parameters()

- if self.optim == 'sgd':
- optimizer = torch.optim.SGD(params, lr=self.learning_rate, momentum=0.9, weight_decay=self.weight_decay)
+ if self.optim == 'lars':
+ optimizer = LARS(
+ params,
+ lr=self.learning_rate,
+ momentum=0.9,
+ weight_decay=self.weight_decay,
+ trust_coefficient=0.001,
+ )
elif self.optim == 'adam':
optimizer = torch.optim.Adam(params, lr=self.learning_rate, weight_decay=self.weight_decay)

- if self.lars_wrapper:
- optimizer = LARSWrapper(
- optimizer,
- eta=0.001, # trust coefficient
- clip=False
- )
+ warmup_steps = self.train_iters_per_epoch * self.warmup_epochs
+ total_steps = self.train_iters_per_epoch * self.max_epochs

- return optimizer
+ scheduler = {
+ "scheduler": torch.optim.lr_scheduler.LambdaLR(
+ optimizer,
+ linear_warmup_decay(warmup_steps, total_steps, cosine=True),
+ ),
+ "interval": "step",
+ "frequency": 1,
+ }

- def optimizer_step(
- self,
- epoch: int = None,
- batch_idx: int = None,
- optimizer: Optimizer = None,
- optimizer_idx: int = None,
- optimizer_closure: Optional[Callable] = None,
- on_tpu: bool = None,
- using_native_amp: bool = None,
- using_lbfgs: bool = None,
- ) -> None:
- # warm-up + decay schedule placed here since LARSWrapper is not optimizer class
- # adjust LR of optim contained within LARSWrapper
- for param_group in optimizer.param_groups:
- param_group["lr"] = self.lr_schedule[self.trainer.global_step]
-
- # from lightning
- if not isinstance(optimizer, LightningOptimizer):
- # wraps into LightingOptimizer only for running step
- optimizer = LightningOptimizer.to_lightning_optimizer(optimizer, self.trainer)
- optimizer.step(closure=optimizer_closure)
+ return [optimizer], [scheduler]

def nt_xent_loss(self, out_1, out_2, temperature, eps=1e-6):
"""
@@ -280,8 +248,8 @@ def nt_xent_loss(self, out_1, out_2, temperature, eps=1e-6):
sim = torch.exp(cov / temperature)
neg = sim.sum(dim=-1)

- # from each row, subtract e^1 to remove similarity measure for x1.x1
- row_sub = torch.Tensor(neg.shape).fill_(math.e).to(neg.device)
+ # from each row, subtract e^(1/temp) to remove similarity measure for x1.x1
+ row_sub = torch.Tensor(neg.shape).fill_(math.e**(1 / temperature)).to(neg.device)
neg = torch.clamp(neg - row_sub, min=eps) # clamp for numerical stability

# Positive similarity, pos becomes [2 * batch_size]
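Why `e^(1/temperature)`: `neg` sums `exp(sim / temperature)` over every pair, including each embedding's similarity with itself. Because the projection outputs are L2-normalized, that self-similarity is 1, so the constant to remove per row is `exp(1 / temperature)`; subtracting plain `e` was only correct when `temperature == 1`. A quick standalone check of that constant (illustrative only, not the module code; shapes and names are made up):

```python
import math

import torch
import torch.nn.functional as F

temperature = 0.5
out = F.normalize(torch.randn(8, 16), dim=1)   # stand-in for the concatenated projections
cov = torch.mm(out, out.t().contiguous())      # cosine similarities; diagonal is 1
sim = torch.exp(cov / temperature)

# Each diagonal entry is exp(1 / temperature) -- exactly what row_sub now removes.
expected = torch.full((8,), math.e ** (1 / temperature))
assert torch.allclose(sim.diag(), expected, atol=1e-5)
```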
@@ -317,8 +285,7 @@ def add_model_specific_args(parent_parser):
parser.add_argument("--num_nodes", default=1, type=int, help="number of nodes for training")
parser.add_argument("--gpus", default=1, type=int, help="number of gpus to train on")
parser.add_argument("--num_workers", default=8, type=int, help="num of workers per GPU")
parser.add_argument("--optimizer", default="adam", type=str, help="choose between adam/sgd")
parser.add_argument("--lars_wrapper", action='store_true', help="apple lars wrapper over optimizer used")
parser.add_argument("--optimizer", default="adam", type=str, help="choose between adam/lars")
parser.add_argument('--exclude_bn_bias', action='store_true', help="exclude bn/bias from weight decay")
parser.add_argument("--max_epochs", default=100, type=int, help="number of total epochs to run")
parser.add_argument("--max_steps", default=-1, type=int, help="max steps")
@@ -393,8 +360,7 @@ def cli_main():
args.gpus = 8 # per-node
args.max_epochs = 800

- args.optimizer = 'sgd'
- args.lars_wrapper = True
+ args.optimizer = 'lars'
args.learning_rate = 4.8
args.final_lr = 0.0048
args.start_lr = 0.3
@@ -434,8 +400,10 @@ def cli_main():
dataset=args.dataset,
)

+ lr_monitor = LearningRateMonitor(logging_interval="step")
model_checkpoint = ModelCheckpoint(save_last=True, save_top_k=1, monitor='val_loss')
callbacks = [model_checkpoint, online_evaluator] if args.online_ft else [model_checkpoint]
+ callbacks.append(lr_monitor)

trainer = pl.Trainer(
max_epochs=args.max_epochs,
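Net effect in `simclr_module.py`: the hand-rolled numpy LR array and the `optimizer_step` override are gone. The schedule is now a standard `LambdaLR` driven by `linear_warmup_decay(warmup_steps, total_steps, cosine=True)` and stepped every batch (`"interval": "step"`), while the learning rate is logged by the `LearningRateMonitor` callback instead of a manual `self.log` call. The helper below sketches the shape such a lambda typically has (linear warmup to 1, then cosine decay to 0); it is a stand-in under that assumption, not the actual `pl_bolts.optimizers.lr_scheduler.linear_warmup_decay` code, and the name `warmup_cosine_lambda` is invented for this example.

```python
import math


def warmup_cosine_lambda(warmup_steps: int, total_steps: int):
    """Multiplicative LR factor for torch.optim.lr_scheduler.LambdaLR (sketch)."""

    def fn(step: int) -> float:
        if step < warmup_steps:
            return step / max(1, warmup_steps)              # linear warmup: 0 -> 1
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))   # cosine decay: 1 -> 0

    return fn
```

Returning `[optimizer], [scheduler]` with `"interval": "step"` makes Lightning call `scheduler.step()` after every optimizer step, which reproduces the per-iteration schedule the removed numpy array used to provide.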
(Diffs for the remaining 7 of the 12 changed files are not shown here.)
