optim.py

import torch


class OptimizerWrapper(torch.optim.Optimizer):
    """A wrapper for torch.optim.Optimizer that forwards all methods to the wrapped optimizer"""

    def __init__(self, optim: torch.optim.Optimizer):
        # Deliberately skip torch.optim.Optimizer.__init__, which expects params;
        # all real state lives in the wrapped optimizer.
        object.__init__(self)
        self.optim = optim

    @property
    def defaults(self):
        return self.optim.defaults

    @property
    def state(self):
        return self.optim.state

    def __getstate__(self):
        return self.optim.__getstate__()

    def __setstate__(self, state):
        self.optim.__setstate__(state)

    def __repr__(self):
        return f"{self.__class__.__name__}({repr(self.optim)})"

    def state_dict(self):
        return self.optim.state_dict()

    def load_state_dict(self, state_dict: dict) -> None:
        return self.optim.load_state_dict(state_dict)

    def step(self, *args, **kwargs):
        return self.optim.step(*args, **kwargs)

    def zero_grad(self, *args, **kwargs):
        return self.optim.zero_grad(*args, **kwargs)

    @property
    def param_groups(self):
        return self.optim.param_groups

    def add_param_group(self, param_group: dict) -> None:
        return self.optim.add_param_group(param_group)
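
# Usage sketch (an illustration, not part of the original file): the wrapper is
# transparent, so code written against torch.optim.Optimizer keeps working:
#
#     base = torch.optim.SGD(model.parameters(), lr=0.1)
#     opt = OptimizerWrapper(base)
#     opt.zero_grad(); loss.backward(); opt.step()  # identical to using `base`

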
class ClippingWrapper(OptimizerWrapper):
    """A wrapper for torch.optim.Optimizer that clips gradients by global norm before each step"""

    def __init__(self, optim: torch.optim.Optimizer, clip_grad_norm: float):
        super().__init__(optim)
        self.clip_grad_norm = clip_grad_norm

    def step(self, *args, **kwargs):
        # Gather every parameter across all param groups and clip their
        # combined (global) gradient norm before delegating the actual update.
        parameters = tuple(param for group in self.param_groups for param in group["params"])
        torch.nn.utils.clip_grad_norm_(parameters, self.clip_grad_norm)
        return super().step(*args, **kwargs)

    @classmethod
    def create(cls, optim_cls: type, *args, clip_grad_norm: float, **kwargs):
        """Create an optimizer of the given type and wrap it with gradient clipping"""
        return cls(optim=optim_cls(*args, **kwargs), clip_grad_norm=clip_grad_norm)
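

if __name__ == "__main__":
    # Minimal usage sketch (added for illustration, not part of the original
    # module): wrap SGD so gradients are clipped to a global norm of 1.0 before
    # every step. The toy model and data below are placeholders.
    model = torch.nn.Linear(4, 2)
    opt = ClippingWrapper.create(torch.optim.SGD, model.parameters(), lr=0.1, clip_grad_norm=1.0)

    loss = 100 * model(torch.randn(8, 4)).pow(2).sum()  # inflated loss -> large gradients
    loss.backward()
    opt.step()       # clips the global grad norm to <= 1.0, then runs the SGD update
    opt.zero_grad()
    print(opt)       # e.g. ClippingWrapper(SGD (...))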