-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathagc.py
24 lines (21 loc) · 1023 Bytes
/
agc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
def unitwise_norm(x, norm_type=2.0):
if x.ndim <= 1:
return x.norm(norm_type)
else:
# works for nn.ConvNd and nn,Linear where output dim is first in the kernel/weight tensor
# might need special cases for other weights (possibly MHA) where this may not be true
return x.norm(norm_type, dim=tuple(range(1, x.ndim)), keepdim=True)
def adaptive_clip_grad(parameters, clip_factor=0.01, eps=1e-3, norm_type=2.0):
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
for p in parameters:
if p.grad is None:
continue
p_data = p.detach()
g_data = p.grad.detach()
max_norm = unitwise_norm(p_data, norm_type=norm_type).clamp_(min=eps).mul_(clip_factor)
grad_norm = unitwise_norm(g_data, norm_type=norm_type)
clipped_grad = g_data * (max_norm / grad_norm.clamp(min=1e-6))
new_grads = torch.where(grad_norm < max_norm, g_data, clipped_grad)
p.grad.detach().copy_(new_grads)