diff --git a/monai/optimizers/lr_scheduler.py b/monai/optimizers/lr_scheduler.py index b056e06a01..8f89352158 100644 --- a/monai/optimizers/lr_scheduler.py +++ b/monai/optimizers/lr_scheduler.py @@ -68,6 +68,7 @@ def __init__( optimizer: Optimizer, warmup_steps: int, t_total: int, + end_lr: float = 0.0, cycles: float = 0.5, last_epoch: int = -1, warmup_multiplier: float = 0, @@ -77,6 +78,7 @@ def __init__( optimizer: wrapped optimizer. warmup_steps: number of warmup iterations. t_total: total number of training iterations. + end_lr: the final learning rate. Defaults to 0.0. cycles: cosine cycles parameter. last_epoch: the index of last epoch. warmup_multiplier: if provided, starts the linear warmup from this fraction of the initial lr. @@ -88,6 +90,7 @@ def __init__( self.warmup_multiplier = warmup_multiplier self.t_total = t_total self.cycles = cycles + self.end_lr = end_lr if warmup_multiplier < 0 or warmup_multiplier > 1: raise ValueError("warmup_multiplier must be in 0..1 range") super().__init__(optimizer, self.lr_lambda, last_epoch) @@ -98,3 +101,10 @@ def lr_lambda(self, step): return self.warmup_multiplier + (1 - self.warmup_multiplier) * f progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps)) return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.cycles) * 2.0 * progress))) + + def get_lr(self): + current_lr = [base_lr * lmbda(self.last_epoch) for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)] + if self.last_epoch < self.warmup_steps: + return current_lr + else: + return [max(self.end_lr, _current_lr) for _current_lr in current_lr] diff --git a/tests/test_lr_scheduler.py b/tests/test_lr_scheduler.py index bcddf7627e..54092ba931 100644 --- a/tests/test_lr_scheduler.py +++ b/tests/test_lr_scheduler.py @@ -35,6 +35,10 @@ def forward(self, x): {"warmup_steps": 2, "t_total": 10, "warmup_multiplier": 0.1}, [0.1, 0.55, 1.00, 0.962, 0.854, 0.691, 0.500, 0.309, 0.146, 0.038], ], + [ + {"warmup_steps": 2, "t_total": 10, "warmup_multiplier": 0.1, "end_lr": 0.309}, + [0.1, 0.55, 1.00, 0.962, 0.854, 0.691, 0.500, 0.309, 0.309, 0.309], + ], ]