Updates all scripts to LARS #613

Merged · 13 commits · Apr 12, 2021
1 change: 1 addition & 0 deletions .github/workflows/ci_test-base.yml
@@ -30,6 +30,7 @@ jobs:
- name: Setup macOS
if: runner.os == 'macOS'
run: |
brew update # Todo: find a better way...
brew install libomp # https://github.com/pytorch/pytorch/issues/20030

# Note: This uses an internal pip API and may not always work
1 change: 1 addition & 0 deletions .github/workflows/ci_test-full.yml
@@ -35,6 +35,7 @@ jobs:
- name: Setup macOS
if: runner.os == 'macOS'
run: |
brew update # Todo: find a better way...
brew install libomp # https://github.com/pytorch/pytorch/issues/20030

- name: Set min. dependencies
10 changes: 4 additions & 6 deletions docs/source/self_supervised_models.rst
@@ -312,7 +312,7 @@ CIFAR-10 baseline
* - Ours
- 88.50
- `resnet50 <https://github.com/PyTorchLightning/lightning-bolts/blob/master/pl_bolts/models/self_supervised/resnets.py#L301-L309>`_
- `LARS-SGD <https://lightning-bolts.readthedocs.io/en/latest/api/pl_bolts.optimizers.lars_scheduling.html#pl_bolts.optimizers.lars_scheduling.LARSWrapper>`_
- LARS
- 2048
- 800 (4 hours)
- 8 V100 (16GB)
@@ -361,7 +361,6 @@ To reproduce::
--num_workers 16
--optimizer sgd
--learning_rate 1.5
--lars_wrapper
--exclude_bn_bias
--max_epochs 800
--online_ft
@@ -401,7 +400,7 @@ Imagenet baseline for SimCLR
* - Ours
- 68.4
- `resnet50 <https://github.com/PyTorchLightning/lightning-bolts/blob/master/pl_bolts/models/self_supervised/resnets.py#L301-L309>`_
- `LARS-SGD <https://lightning-bolts.readthedocs.io/en/latest/api/pl_bolts.optimizers.lars_scheduling.html#pl_bolts.optimizers.lars_scheduling.LARSWrapper>`_
- LARS
- 4096
- 800
- 64 V100 (16GB)
@@ -533,7 +532,7 @@ The original paper does not provide baselines on STL10.
* - Ours
- `86.72 <https://tensorboard.dev/experiment/w2pq3bPPSxC4VIm5udhA2g/>`_
- SwAV resnet50
- `LARS <https://lightning-bolts.readthedocs.io/en/latest/api/pl_bolts.optimizers.lars_scheduling.html#pl_bolts.optimizers.lars_scheduling.LARSWrapper>`_
- LARS
- 128
- No
- 100 (~9 hr)
@@ -585,7 +584,6 @@ To reproduce::
python swav_module.py
--online_ft
--gpus 1
--lars_wrapper
--batch_size 128
--learning_rate 1e-3
--gaussian_blur
@@ -630,7 +628,7 @@ Imagenet baseline for SwAV
* - Ours
- 74
- `resnet50 <https://github.com/PyTorchLightning/lightning-bolts/blob/master/pl_bolts/models/self_supervised/resnets.py#L301-L309>`_
- `LARS-SGD <https://lightning-bolts.readthedocs.io/en/latest/api/pl_bolts.optimizers.lars_scheduling.html#pl_bolts.optimizers.lars_scheduling.LARSWrapper>`_
- LARS
- 4096
- 800
- 64 V100 (16GB)
2 changes: 0 additions & 2 deletions pl_bolts/models/self_supervised/byol/byol_module.py
@@ -10,7 +10,6 @@

from pl_bolts.callbacks.byol_updates import BYOLMAWeightUpdate
from pl_bolts.models.self_supervised.byol.models import SiameseArm
from pl_bolts.optimizers.lars_scheduling import LARSWrapper
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR


@@ -141,7 +140,6 @@ def validation_step(self, batch, batch_idx):

def configure_optimizers(self):
optimizer = Adam(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay)
optimizer = LARSWrapper(optimizer)
scheduler = LinearWarmupCosineAnnealingLR(
optimizer, warmup_epochs=self.hparams.warmup_epochs, max_epochs=self.hparams.max_epochs
)
87 changes: 28 additions & 59 deletions pl_bolts/models/self_supervised/simclr/simclr_module.py
@@ -1,18 +1,16 @@
import math
from argparse import ArgumentParser
from typing import Callable, Optional

import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.core.optimizer import LightningOptimizer
from torch import nn
from torch.nn import functional as F
from torch.optim.optimizer import Optimizer

from pl_bolts.models.self_supervised.resnets import resnet18, resnet50
from pl_bolts.optimizers.lars_scheduling import LARSWrapper
from pl_bolts.optimizers.lars import LARS
from pl_bolts.optimizers.lr_scheduler import linear_warmup_decay
from pl_bolts.transforms.dataset_normalizations import (
cifar10_normalization,
imagenet_normalization,
@@ -79,7 +77,6 @@ def __init__(
first_conv: bool = True,
maxpool1: bool = True,
optimizer: str = 'adam',
lars_wrapper: bool = True,
exclude_bn_bias: bool = False,
start_lr: float = 0.,
learning_rate: float = 1e-3,
@@ -112,7 +109,6 @@ def __init__(
self.maxpool1 = maxpool1

self.optim = optimizer
self.lars_wrapper = lars_wrapper
self.exclude_bn_bias = exclude_bn_bias
self.weight_decay = weight_decay
self.temperature = temperature
@@ -131,19 +127,6 @@ def __init__(
global_batch_size = self.num_nodes * self.gpus * self.batch_size if self.gpus > 0 else self.batch_size
self.train_iters_per_epoch = self.num_samples // global_batch_size

# define LR schedule
warmup_lr_schedule = np.linspace(
self.start_lr, self.learning_rate, self.train_iters_per_epoch * self.warmup_epochs
)
iters = np.arange(self.train_iters_per_epoch * (self.max_epochs - self.warmup_epochs))
cosine_lr_schedule = np.array([
self.final_lr + 0.5 * (self.learning_rate - self.final_lr) *
(1 + math.cos(math.pi * t / (self.train_iters_per_epoch * (self.max_epochs - self.warmup_epochs))))
for t in iters
])

self.lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule))

def init_model(self):
if self.arch == 'resnet18':
backbone = resnet18
@@ -179,9 +162,6 @@ def shared_step(self, batch):
def training_step(self, batch, batch_idx):
loss = self.shared_step(batch)

# log LR (LearningRateLogger callback doesn't work with LARSWrapper)
self.log('learning_rate', self.lr_schedule[self.trainer.global_step], on_step=True, on_epoch=False)

self.log('train_loss', loss, on_step=True, on_epoch=False)
return loss

@@ -217,41 +197,30 @@ def configure_optimizers(self):
else:
params = self.parameters()

if self.optim == 'sgd':
optimizer = torch.optim.SGD(params, lr=self.learning_rate, momentum=0.9, weight_decay=self.weight_decay)
if self.optim == 'lars':
optimizer = LARS(
params,
lr=self.learning_rate,
momentum=0.9,
weight_decay=self.weight_decay,
trust_coefficient=0.001,
)
elif self.optim == 'adam':
optimizer = torch.optim.Adam(params, lr=self.learning_rate, weight_decay=self.weight_decay)

if self.lars_wrapper:
optimizer = LARSWrapper(
optimizer,
eta=0.001, # trust coefficient
clip=False
)
warmup_steps = self.train_iters_per_epoch * self.warmup_epochs
total_steps = self.train_iters_per_epoch * self.max_epochs

return optimizer
scheduler = {
"scheduler": torch.optim.lr_scheduler.LambdaLR(
optimizer,
linear_warmup_decay(warmup_steps, total_steps, cosine=True),
),
"interval": "step",
"frequency": 1,
}

def optimizer_step(
self,
epoch: int = None,
batch_idx: int = None,
optimizer: Optimizer = None,
optimizer_idx: int = None,
optimizer_closure: Optional[Callable] = None,
on_tpu: bool = None,
using_native_amp: bool = None,
using_lbfgs: bool = None,
) -> None:
# warm-up + decay schedule placed here since LARSWrapper is not optimizer class
# adjust LR of optim contained within LARSWrapper
for param_group in optimizer.param_groups:
param_group["lr"] = self.lr_schedule[self.trainer.global_step]

# from lightning
if not isinstance(optimizer, LightningOptimizer):
# wraps into LightingOptimizer only for running step
optimizer = LightningOptimizer.to_lightning_optimizer(optimizer, self.trainer)
optimizer.step(closure=optimizer_closure)
return [optimizer], [scheduler]
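
Two pieces of background on the replacement above; both are illustrative sketches (helper names are ours), not the exact pl_bolts implementations. First, LARS adapts the step size per parameter tensor with a trust ratio (You et al., 2017), roughly:

    import torch

    def lars_trust_ratio(param: torch.Tensor, grad: torch.Tensor,
                         weight_decay: float, trust_coefficient: float = 0.001) -> float:
        # local_lr = trust_coefficient * ||w|| / (||g|| + weight_decay * ||w||)
        w_norm = param.detach().norm()
        g_norm = grad.detach().norm()
        if w_norm == 0 or g_norm == 0:
            return 1.0  # fall back to the plain update when a norm is zero
        return (trust_coefficient * w_norm / (g_norm + weight_decay * w_norm)).item()

Second, the hand-rolled numpy schedule plus custom optimizer_step is replaced by a standard step-interval LambdaLR; assuming linear_warmup_decay follows the usual linear-warmup-then-cosine shape, the multiplier such a lambda computes looks like:

    import math

    def warmup_cosine_multiplier(step: int, warmup_steps: int, total_steps: int) -> float:
        # Linear warmup from 0 to 1 over warmup_steps, then cosine decay from 1 to 0.
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))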

def nt_xent_loss(self, out_1, out_2, temperature, eps=1e-6):
"""
@@ -280,8 +249,8 @@ def nt_xent_loss(self, out_1, out_2, temperature, eps=1e-6):
sim = torch.exp(cov / temperature)
neg = sim.sum(dim=-1)

# from each row, subtract e^1 to remove similarity measure for x1.x1
row_sub = torch.Tensor(neg.shape).fill_(math.e).to(neg.device)
# from each row, subtract e^(1/temp) to remove similarity measure for x1.x1
row_sub = torch.Tensor(neg.shape).fill_(math.e ** (1 / temperature)).to(neg.device)
neg = torch.clamp(neg - row_sub, min=eps) # clamp for numerical stability

# Positive similarity, pos becomes [2 * batch_size]
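
The changed constant above matters because sim is exp(cov / temperature) and, with L2-normalized projections, each row's self-similarity has cov = 1, so the diagonal term is exp(1 / temperature), not e. A quick self-contained check (illustrative, not part of the PR):

    import math
    import torch
    import torch.nn.functional as F

    temperature = 0.5
    z = F.normalize(torch.randn(4, 128), dim=1)       # L2-normalized projections
    sim = torch.exp(z @ z.t() / temperature)          # same form as exp(cov / temperature)
    expected = torch.full_like(sim.diagonal(), math.e ** (1 / temperature))
    assert torch.allclose(sim.diagonal(), expected)   # diagonal entries equal exp(1/temperature)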
@@ -317,8 +286,7 @@ def add_model_specific_args(parent_parser):
parser.add_argument("--num_nodes", default=1, type=int, help="number of nodes for training")
parser.add_argument("--gpus", default=1, type=int, help="number of gpus to train on")
parser.add_argument("--num_workers", default=8, type=int, help="num of workers per GPU")
parser.add_argument("--optimizer", default="adam", type=str, help="choose between adam/sgd")
parser.add_argument("--lars_wrapper", action='store_true', help="apple lars wrapper over optimizer used")
parser.add_argument("--optimizer", default="adam", type=str, help="choose between adam/lars")
parser.add_argument('--exclude_bn_bias', action='store_true', help="exclude bn/bias from weight decay")
parser.add_argument("--max_epochs", default=100, type=int, help="number of total epochs to run")
parser.add_argument("--max_steps", default=-1, type=int, help="max steps")
@@ -393,8 +361,7 @@ def cli_main():
args.gpus = 8 # per-node
args.max_epochs = 800

args.optimizer = 'sgd'
args.lars_wrapper = True
args.optimizer = 'lars'
args.learning_rate = 4.8
args.final_lr = 0.0048
args.start_lr = 0.3
@@ -434,8 +401,10 @@ def cli_main():
dataset=args.dataset,
)

lr_monitor = LearningRateMonitor(logging_interval="step")
model_checkpoint = ModelCheckpoint(save_last=True, save_top_k=1, monitor='val_loss')
callbacks = [model_checkpoint, online_evaluator] if args.online_ft else [model_checkpoint]
callbacks.append(lr_monitor)

trainer = pl.Trainer(
max_epochs=args.max_epochs,