-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathcyclic_lr_scheduler.py
61 lines (53 loc) · 2.5 KB
/
cyclic_lr_scheduler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import math
import numpy as np
from torch.optim.optimizer import Optimizer
class _LRScheduler(object):
def __init__(self, optimizer, last_epoch=-1):
if not isinstance(optimizer, Optimizer):
raise TypeError('{} is not an Optimizer'.format(
type(optimizer).__name__))
self.optimizer = optimizer
if last_epoch == -1:
for group in optimizer.param_groups:
group.setdefault('initial_lr', group['lr'])
else:
for i, group in enumerate(optimizer.param_groups):
if 'initial_lr' not in group:
raise KeyError("param 'initial_lr' is not specified "
"in param_groups[{}] when resuming an optimizer".format(i))
self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
self.step(last_epoch + 1)
self.last_epoch = last_epoch
def get_lr(self):
raise NotImplementedError
def step(self, epoch=None):
if epoch is None:
epoch = self.last_epoch + 1
self.last_epoch = epoch
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
param_group['lr'] = lr
class CyclicLR(_LRScheduler):
    """Cyclical learning-rate schedule oscillating between ``base_lr`` and ``max_lr``.

    The lr rises linearly from ``base_lr`` to ``max_lr`` over ``step_size``
    scheduler steps and falls back over the next ``step_size`` steps, so one
    full cycle spans ``2 * step_size`` steps. ``mode`` scales the peak:

    * ``'triangular'``  -- constant amplitude ``max_lr - base_lr``.
    * ``'triangular2'`` -- amplitude halved after every completed cycle.
    * ``'exp_range'``   -- amplitude multiplied by ``gamma ** last_epoch``.

    Every param group receives the same lr; the per-group initial lrs recorded
    by the base class only determine how many values :meth:`get_lr` returns.
    """

    _MODES = ('triangular', 'triangular2', 'exp_range')

    def __init__(self, optimizer, base_lr, max_lr, step_size, gamma=0.99, mode='triangular', last_epoch=-1):
        """Configure the cycle and apply the first learning rate.

        Args:
            optimizer: the ``torch.optim.Optimizer`` to drive.
            base_lr: lower bound of the cycle.
            max_lr: upper bound of the cycle (before any ``mode`` scaling).
            step_size: number of scheduler steps in half a cycle; must be > 0.
            gamma: per-step amplitude decay factor used by ``'exp_range'``.
            mode: one of ``'triangular'``, ``'triangular2'``, ``'exp_range'``.
            last_epoch: index of the last step already run; ``-1`` starts fresh.

        Raises:
            ValueError: ``mode`` is unknown or ``step_size`` is not positive.
            TypeError / KeyError: propagated from the base-class constructor.
        """
        # Validate with real exceptions *before* touching any state: an
        # ``assert`` is stripped under ``python -O``, and an unknown mode would
        # then only surface later as an unbound local inside get_lr().
        if mode not in self._MODES:
            raise ValueError('mode must be one of {}, got {!r}'.format(self._MODES, mode))
        if step_size <= 0:
            raise ValueError('step_size must be positive, got {!r}'.format(step_size))
        self.base_lr = base_lr
        self.max_lr = max_lr
        self.step_size = step_size
        self.gamma = gamma
        self.mode = mode
        # The base constructor type-checks and stores the optimizer, records
        # base_lrs and calls step(), which needs the attributes set above.
        super(CyclicLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate for ``self.last_epoch``, once per param group.

        The value does not depend on the group, so it is computed once and
        repeated ``len(self.base_lrs)`` times.

        Raises:
            ValueError: ``self.mode`` was changed to an unknown value after
                construction.
        """
        # 1-based index of the cycle containing last_epoch.
        cycle = math.floor(1 + self.last_epoch / (2 * self.step_size))
        # x runs 1 -> 0 -> 1 across a cycle, so (1 - x) is the triangular ramp.
        x = abs(self.last_epoch / self.step_size - 2 * cycle + 1)
        amplitude = (self.max_lr - self.base_lr) * max(0.0, 1.0 - x)
        if self.mode == 'triangular2':
            # Same value as dividing by 2 ** (cycle - 1), but underflows to 0
            # instead of overflowing when the cycle count gets very large.
            amplitude *= 0.5 ** (cycle - 1)
        elif self.mode == 'exp_range':
            amplitude *= self.gamma ** self.last_epoch
        elif self.mode != 'triangular':
            raise ValueError('unknown mode {!r}'.format(self.mode))
        lr = self.base_lr + amplitude
        # One entry per param group; only the length of base_lrs matters here.
        return [lr for _ in self.base_lrs]