fixes simclr finetuner (#165)
* added simclr finetuner
williamFalcon authored Aug 23, 2020
1 parent c180b70 commit e3057b1
Showing 4 changed files with 99 additions and 53 deletions.
92 changes: 92 additions & 0 deletions docs/source/self_supervised_models.rst
@@ -92,26 +92,118 @@ AMDIM
.. autoclass:: pl_bolts.models.self_supervised.AMDIM
:noindex:

---------

BYOL
^^^^

.. autoclass:: pl_bolts.models.self_supervised.BYOL
:noindex:

---------

CPC (V2)
^^^^^^^^

.. autoclass:: pl_bolts.models.self_supervised.CPCV2
:noindex:

---------

Moco (V2)
^^^^^^^^^

.. autoclass:: pl_bolts.models.self_supervised.MocoV2
:noindex:

---------

SimCLR
^^^^^^

PyTorch Lightning implementation of `SimCLR <https://arxiv.org/abs/2002.05709>`_

Paper authors: Ting Chen, Simon Kornblith, Mohammad Norouzi, Geoffrey Hinton.

Model implemented by:

- `William Falcon <https://github.com/williamFalcon>`_
- `Tullie Murrell <https://github.com/tullie>`_

Example::

import pytorch_lightning as pl
from pl_bolts.models.self_supervised import SimCLR
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.models.self_supervised.simclr.simclr_transforms import (
SimCLREvalDataTransform, SimCLRTrainDataTransform)

# data
dm = CIFAR10DataModule(num_workers=0)
dm.train_transforms = SimCLRTrainDataTransform(32)
dm.val_transforms = SimCLREvalDataTransform(32)

# model
model = SimCLR(num_samples=dm.num_samples, batch_size=dm.batch_size)

# fit
trainer = pl.Trainer()
trainer.fit(model, dm)
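
Once pretraining finishes, the checkpoint can be reused for downstream tasks. A minimal
sketch of feature extraction (the checkpoint path is a placeholder, and we assume
``forward`` returns the encoder representations)::

    import torch

    from pl_bolts.models.self_supervised import SimCLR

    # load pretrained weights (the path is a placeholder)
    model = SimCLR.load_from_checkpoint('path/to/epoch=xyz.ckpt')
    model.freeze()

    # compute representations for a batch of CIFAR-10-sized images
    x = torch.rand(4, 3, 32, 32)
    representations = model(x)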

CIFAR-10 baseline
*****************
.. list-table:: CIFAR-10 test accuracy
:widths: 50 50
:header-rows: 1

* - Model
- Test accuracy
* - Original repo
- `82.00 <https://github.com/google-research/simclr#finetuning-the-linear-head-linear-eval>`_
* - Our implementation
- `86.75 <https://tensorboard.dev/experiment/mh3qnIdaQcWA9d4XkErNEA>`_

.. note:: This experiment used a standard resnet50 (not the extra-wide 2x or 4x variants), but any resnet can be used.

|
Pre-training:

.. figure:: https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/simclr-cifar10-v1-exp2_acc_867/val_loss.png
:width: 200
:alt: pretraining validation loss

|
Fine-tuning (Single layer MLP, 1024 hidden units):

.. figure:: https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/simclr-cifar10-v1-exp2_acc_867/val_acc.png
:width: 200
:alt: finetuning validation accuracy

.. figure:: https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/simclr-cifar10-v1-exp2_acc_867/test_acc.png
:width: 200
:alt: finetuning test accuracy

|
To reproduce::

# pretrain
python simclr_module.py \
    --gpus 1 \
    --dataset cifar10 \
    --batch_size 512 \
    --learning_rate 1e-06 \
    --num_workers 8

# finetune
python simclr_finetuner.py \
    --ckpt_path path/to/epoch=xyz.ckpt \
    --gpus 1
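
The finetune script attaches a classification head to the pretrained encoder. A minimal
sketch of that idea (not the exact ``simclr_finetuner.py`` code; the checkpoint path is a
placeholder, freezing the backbone is an illustrative choice, and we assume the standard
resnet50 feature dimension of 2048)::

    from torch import nn

    from pl_bolts.models.self_supervised import SimCLR

    # pretrained backbone (the path is a placeholder)
    backbone = SimCLR.load_from_checkpoint('path/to/epoch=xyz.ckpt')
    backbone.freeze()

    # single hidden layer MLP head: 2048 resnet50 features -> 1024 hidden -> 10 classes
    mlp_head = nn.Sequential(
        nn.Linear(2048, 1024),
        nn.ReLU(),
        nn.Linear(1024, 10),
    )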

SimCLR API
**********

.. autoclass:: pl_bolts.models.self_supervised.SimCLR
:noindex:
2 changes: 1 addition & 1 deletion pl_bolts/models/self_supervised/cpc/cpc_finetuner.py
@@ -10,7 +10,7 @@
import os


def cli_main():
def cli_main(): # pragma: no-cover
pl.seed_everything(1234)

parser = ArgumentParser()
3 changes: 2 additions & 1 deletion pl_bolts/models/self_supervised/simclr/simclr_finetuner.py
@@ -6,7 +6,7 @@
from pl_bolts.models.self_supervised.simclr.simclr_transforms import SimCLREvalDataTransform, SimCLRTrainDataTransform


def cli_main():
def cli_main(): # pragma: no-cover
pl.seed_everything(1234)

parser = ArgumentParser()
@@ -24,6 +24,7 @@ def cli_main():
dm = CIFAR10DataModule.from_argparse_args(args)
dm.train_transforms = SimCLRTrainDataTransform(32)
dm.val_transforms = SimCLREvalDataTransform(32)
dm.test_transforms = SimCLREvalDataTransform(32)
args.num_samples = dm.num_samples

elif args.dataset == 'stl10':
55 changes: 4 additions & 51 deletions pl_bolts/models/self_supervised/simclr/simclr_module.py
@@ -57,60 +57,13 @@ def __init__(self,
loss_temperature=0.5,
**kwargs):
"""
PyTorch Lightning implementation of `SimCLR <https://arxiv.org/abs/2002.05709>`_
Paper authors: Ting Chen, Simon Kornblith, Mohammad Norouzi, Geoffrey Hinton.
Model implemented by:
- `William Falcon <https://github.com/williamFalcon>`_
- `Tullie Murrell <https://github.com/tullie>`_
Example::
import pytorch_lightning as pl
from pl_bolts.models.self_supervised import SimCLR
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.models.self_supervised.simclr.simclr_transforms import (
SimCLREvalDataTransform, SimCLRTrainDataTransform)
# data
dm = CIFAR10DataModule(num_workers=0)
dm.train_transforms = SimCLRTrainDataTransform(32)
dm.val_transforms = SimCLREvalDataTransform(32)
# model
model = SimCLR(num_samples=dm.num_samples, batch_size=dm.batch_size)
# fit
trainer = pl.Trainer()
trainer.fit(model, dm)
Train::
trainer = Trainer()
trainer.fit(model)
CLI command::
# cifar10
python simclr_module.py --gpus 1
# imagenet
python simclr_module.py
--gpus 8
--dataset imagenet2012
--data_dir /path/to/imagenet/
--meta_dir /path/to/folder/with/meta.bin/
--batch_size 32
Args:
batch_size: the batch size
num_samples: num samples in the dataset
warmup_epochs: epochs to warmup the lr for
lr: the optimizer learning rate
opt_weight_decay: the optimizer weight decay
loss_temperature: the loss temperature
"""
super().__init__()
self.save_hyperparameters()
Expand Down
