-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathlr.py
31 lines (25 loc) · 925 Bytes
/
lr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import math
class LRScheduler():
    """Base class for step-counted learning-rate schedules.

    Stores the schedule configuration and a step counter; subclasses
    implement :meth:`get_lr` to map the current step to a learning rate.

    Args:
        optimizer: Optimizer the schedule is meant to drive. It is only
            stored here; nothing in this class reads or modifies it.
        total_steps: Step count at which the schedule ends.
        warmup_steps: Number of initial warmup steps.
        max_lr: Peak learning rate.
    """

    def __init__(self, optimizer, total_steps, warmup_steps, max_lr=1e-3):
        self.optimizer = optimizer
        self.total_steps = total_steps
        self.warmup_steps = warmup_steps
        self.max_lr = max_lr
        # Steps are counted from 1, so the very first get_lr() call during
        # warmup already yields a non-zero rate.
        self.steps = 1

    def step(self):
        """Advance the schedule by one step."""
        self.steps += 1

    def get_lr(self):
        """Return the learning rate for the current step (subclass hook)."""
        raise NotImplementedError(
            f"{type(self).__name__} must implement get_lr()"
        )


class CosineDecayWithWarmup(LRScheduler):
    """Linear warmup to ``max_lr`` followed by cosine decay to ``min_lr``.

    While ``steps < warmup_steps`` the rate is
    ``max_lr * steps / warmup_steps``. After that it follows half a cosine
    period from ``max_lr`` (at ``steps == warmup_steps``) down to ``min_lr``
    (at ``steps == total_steps``) and stays at ``min_lr`` afterwards.

    Args:
        optimizer: Stored on the instance; not used by the schedule itself.
        total_steps: Step at which the decay reaches ``min_lr``.
        warmup_steps: Length of the linear warmup phase.
        max_lr: Peak learning rate reached at the end of warmup.
        min_lr: Floor of the cosine decay (keyword-only, default 0.0, which
            matches a decay all the way to zero).
    """

    def __init__(self, optimizer, total_steps, warmup_steps, max_lr=1e-3,
                 *, min_lr=0.0):
        super().__init__(optimizer, total_steps, warmup_steps, max_lr)
        self.min_lr = min_lr

    def get_lr(self):
        """Return the learning rate for the current value of ``self.steps``."""
        if self.steps < self.warmup_steps:
            return self.max_lr * (self.steps / self.warmup_steps)
        # Guard the denominator: with total_steps <= warmup_steps there is
        # no decay phase, and dividing by zero would crash right after warmup.
        decay_steps = max(self.total_steps - self.warmup_steps, 1)
        # Clamp so that stepping past total_steps holds the rate at min_lr
        # instead of following the cosine back up toward max_lr.
        progress = min((self.steps - self.warmup_steps) / decay_steps, 1.0)
        cosine = (1 + math.cos(math.pi * progress)) / 2
        return self.min_lr + (self.max_lr - self.min_lr) * cosine