Final simclr (#170)
* cleared docs

* cleared docs

* cleared docs
williamFalcon authored Aug 23, 2020
1 parent fd1c350 commit cf38307
Showing 2 changed files with 6 additions and 6 deletions.
10 changes: 5 additions & 5 deletions docs/source/self_supervised_models.rst
@@ -165,19 +165,19 @@ CIFAR-10 baseline
      - Hardware
      - LR
    * - `Original <https://github.com/google-research/simclr#finetuning-the-linear-head-linear-eval>`_
-     - `82.00? <https://github.com/google-research/simclr#finetuning-the-linear-head-linear-eval>`_
-     - resnet (depth 18)
+     - `92.00? <https://github.com/google-research/simclr#finetuning-the-linear-head-linear-eval>`_
+     - resnet50
      - LARS
      - 512
      - 1000
      - 1 V100 (32GB)
      - 1.0
    * - Ours
-     - `86.75 <https://tensorboard.dev/experiment/mh3qnIdaQcWA9d4XkErNEA>`_
+     - `85.68 <https://tensorboard.dev/experiment/GlS1eLXMQsqh3T5DAec6UQ/#scalars>`_
      - `resnet50 <https://github.com/PyTorchLightning/PyTorch-Lightning-Bolts/blob/master/pl_bolts/models/self_supervised/resnets.py#L301-L309>`_
      - `LARS <https://pytorch-lightning-bolts.readthedocs.io/en/latest/api/pl_bolts.optimizers.lars_scheduling.html#pl_bolts.optimizers.lars_scheduling.LARSWrapper>`_
      - 512
-     - 698 (10 hr)
+     - 960 (12 hr)
      - 1 V100 (32GB)
      - 1e-6

@@ -187,7 +187,7 @@ CIFAR-10 pretrained model::

     from pl_bolts.models.self_supervised import SimCLR

-    weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/simclr-cifar10-v1-exp2_acc_867/epoch%3D698.ckpt'
+    weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/simclr-cifar10-v1-exp12_87_52/epoch%3D960.ckpt'
     simclr = SimCLR.load_from_checkpoint(weight_path, strict=False)

     simclr.freeze()
2 changes: 1 addition & 1 deletion pl_bolts/models/self_supervised/simclr/simclr_finetuner.py
@@ -44,7 +44,7 @@ def cli_main():  # pragma: no-cover
     dm.val_transforms = SimCLREvalDataTransform(h)

     # finetune
-    tuner = SSLFineTuner(backbone, in_features=2048 * 2 * 2, num_classes=dm.num_classes)
+    tuner = SSLFineTuner(backbone, in_features=2048 * 2 * 2, num_classes=dm.num_classes, hidden_dim=None)
     trainer = pl.Trainer.from_argparse_args(args, early_stop_callback=True)
     trainer.fit(tuner, dm)
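Passing `hidden_dim=None` attaches a plain linear head instead of an MLP, i.e. a linear evaluation of the frozen features. A minimal sketch of that setup in plain PyTorch, with toy dimensions standing in for `in_features=2048 * 2 * 2` (the module names here are illustrative, not the `SSLFineTuner` internals):

```python
import torch

# toy dimensions standing in for the real feature and class counts
in_features, num_classes = 16, 10

backbone = torch.nn.Linear(8, in_features)
for p in backbone.parameters():
    p.requires_grad = False  # frozen backbone: only the head is trained

head = torch.nn.Linear(in_features, num_classes)  # hidden_dim=None -> single linear layer
optimizer = torch.optim.SGD(head.parameters(), lr=1e-3)

x = torch.randn(4, 8)
y = torch.randint(0, num_classes, (4,))

logits = head(backbone(x))
loss = torch.nn.functional.cross_entropy(logits, y)
loss.backward()   # gradients flow into the head only
optimizer.step()
```

Because the backbone's parameters have `requires_grad=False`, `backward()` leaves them untouched and the optimizer updates only the linear head.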

