# Add `end_lr` in `WarmupCosineSchedule` (#6662)
Fixes #6527.

### Description

Add `end_lr` in `WarmupCosineSchedule`
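
For context, here is a minimal usage sketch (not part of this commit): it constructs the scheduler with the new `end_lr` argument so the cosine decay is floored at a non-zero learning rate. The toy parameter, optimizer, and step counts are illustrative only.

```python
import torch
from monai.optimizers import WarmupCosineSchedule

# Toy parameter/optimizer purely for illustration; any torch Optimizer works.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=1.0)

scheduler = WarmupCosineSchedule(
    optimizer,
    warmup_steps=2,          # linear warmup over the first 2 steps
    t_total=10,              # total number of scheduler steps
    warmup_multiplier=0.1,   # start warmup at 10% of the base lr
    end_lr=0.309,            # new: floor for the post-warmup learning rate
)

for _ in range(10):
    optimizer.step()
    scheduler.step()
    print(scheduler.get_last_lr())  # never falls below 0.309 after warmup
```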

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: KumoLiu <[email protected]>
KumoLiu authored Jun 27, 2023 · 1 parent 8bc25b9 · commit 60620e4
Showing 2 changed files with 14 additions and 0 deletions.
**monai/optimizers/lr_scheduler.py** (10 additions, 0 deletions)
```diff
@@ -68,6 +68,7 @@ def __init__(
         optimizer: Optimizer,
         warmup_steps: int,
         t_total: int,
+        end_lr: float = 0.0,
         cycles: float = 0.5,
         last_epoch: int = -1,
         warmup_multiplier: float = 0,
@@ -77,6 +78,7 @@ def __init__(
             optimizer: wrapped optimizer.
             warmup_steps: number of warmup iterations.
             t_total: total number of training iterations.
+            end_lr: the final learning rate. Defaults to 0.0.
             cycles: cosine cycles parameter.
             last_epoch: the index of last epoch.
             warmup_multiplier: if provided, starts the linear warmup from this fraction of the initial lr.
@@ -88,6 +90,7 @@ def __init__(
         self.warmup_multiplier = warmup_multiplier
         self.t_total = t_total
         self.cycles = cycles
+        self.end_lr = end_lr
         if warmup_multiplier < 0 or warmup_multiplier > 1:
             raise ValueError("warmup_multiplier must be in 0..1 range")
         super().__init__(optimizer, self.lr_lambda, last_epoch)
@@ -98,3 +101,10 @@ def lr_lambda(self, step):
             return self.warmup_multiplier + (1 - self.warmup_multiplier) * f
         progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps))
         return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.cycles) * 2.0 * progress)))
+
+    def get_lr(self):
+        current_lr = [base_lr * lmbda(self.last_epoch) for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)]
+        if self.last_epoch < self.warmup_steps:
+            return current_lr
+        else:
+            return [max(self.end_lr, _current_lr) for _current_lr in current_lr]
```
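
The new `get_lr` override leaves the warmup phase untouched and clamps every post-warmup value at `self.end_lr`. As a sanity check, the sketch below re-implements the same arithmetic outside MONAI (the helper name `cosine_lr` is ours, not the library's); its output matches the new test case added below.

```python
import math

def cosine_lr(step, warmup_steps=2, t_total=10, warmup_multiplier=0.1,
              cycles=0.5, base_lr=1.0, end_lr=0.309):
    """Mirrors WarmupCosineSchedule.lr_lambda plus the new end_lr clamp."""
    if step < warmup_steps:
        # linear warmup from warmup_multiplier * base_lr up to base_lr
        f = float(step) / float(max(1.0, warmup_steps))
        return base_lr * (warmup_multiplier + (1 - warmup_multiplier) * f)
    progress = float(step - warmup_steps) / float(max(1, t_total - warmup_steps))
    lr = base_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * cycles * 2.0 * progress)))
    return max(end_lr, lr)  # the floor introduced by this commit

print([round(cosine_lr(s), 3) for s in range(10)])
# [0.1, 0.55, 1.0, 0.962, 0.854, 0.691, 0.5, 0.309, 0.309, 0.309]
```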
**tests/test_lr_scheduler.py** (4 additions, 0 deletions)
```diff
@@ -35,6 +35,10 @@ def forward(self, x):
         {"warmup_steps": 2, "t_total": 10, "warmup_multiplier": 0.1},
         [0.1, 0.55, 1.00, 0.962, 0.854, 0.691, 0.500, 0.309, 0.146, 0.038],
     ],
+    [
+        {"warmup_steps": 2, "t_total": 10, "warmup_multiplier": 0.1, "end_lr": 0.309},
+        [0.1, 0.55, 1.00, 0.962, 0.854, 0.691, 0.500, 0.309, 0.309, 0.309],
+    ],
 ]
```
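
Comparing the two expected lists shows exactly what `end_lr` changes: in the first case the schedule decays to 0.146 and 0.038 at the last two steps; with `end_lr=0.309` those values fall below the floor and are clamped to 0.309, while every earlier step (all at or above 0.309) is unchanged.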


