From dedd10132c8865ada9a5a1ffd0410b478e9e997b Mon Sep 17 00:00:00 2001
From: Annika Brundyn
Date: Wed, 5 Aug 2020 19:07:37 -0400
Subject: [PATCH] add scheduler

---
 pl_bolts/models/self_supervised/byol/byol_module.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/pl_bolts/models/self_supervised/byol/byol_module.py b/pl_bolts/models/self_supervised/byol/byol_module.py
index a918780e61..b54496ac63 100644
--- a/pl_bolts/models/self_supervised/byol/byol_module.py
+++ b/pl_bolts/models/self_supervised/byol/byol_module.py
@@ -6,6 +6,7 @@
 from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule, ImagenetDataModule
 from pl_bolts.models.self_supervised.simclr.simclr_transforms import SimCLREvalDataTransform, SimCLRTrainDataTransform
 from pl_bolts.optimizers.layer_adaptive_scaling import LARS
+from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
 from pl_bolts.models.self_supervised.byol.models import SiameseArm
 from pl_bolts.callbacks.self_supervised import BYOLMAWeightUpdate
 
@@ -14,7 +15,7 @@ class BYOL(pl.LightningModule):
     def __init__(self,
                  datamodule: pl.LightningDataModule = None,
                  data_dir: str = './',
-                 learning_rate: float = 0.00006,
+                 learning_rate: float = 0.2,
                  weight_decay: float = 0.0005,
                  input_height: int = 32,
                  batch_size: int = 32,
@@ -153,9 +154,9 @@ def validation_step(self, batch, batch_idx):
         return result
 
     def configure_optimizers(self):
-        optimizer = LARS(self.parameters(), lr=self.hparams.learning_rate)
-        # TODO: add scheduler - cosine decay
-        return optimizer
+        optimizer = LARS(self.parameters(), lr=self.hparams.learning_rate, weight_decay=0.000015)
+        scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=10, max_epochs=1000)
+        return [optimizer], [scheduler]
 
     @staticmethod
     def add_model_specific_args(parent_parser):
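
Note (outside the patch): returning `[optimizer], [scheduler]` from `configure_optimizers` gives Lightning one optimizer paired with one scheduler that is stepped once per epoch by default. The snippet below is a minimal sketch of how the added LinearWarmupCosineAnnealingLR behaves in isolation; the toy linear model, the plain SGD optimizer, and the 20-epoch loop are illustrative stand-ins only, not part of the patched BYOL module.

    # Sketch: warmup + cosine decay schedule from pl_bolts, shown outside Lightning.
    # Assumes only the import added in the diff; model/optimizer here are placeholders.
    import torch
    from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR

    model = torch.nn.Linear(2, 2)  # stand-in for the BYOL network
    optimizer = torch.optim.SGD(model.parameters(), lr=0.2)
    scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=10, max_epochs=1000)

    for epoch in range(20):
        # ... training steps for this epoch would go here ...
        optimizer.step()
        scheduler.step()  # LR ramps linearly for the first 10 epochs,
        print(epoch, scheduler.get_last_lr())  # then follows a cosine decay toward 0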