-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathSpectralNormGouk.py
278 lines (233 loc) · 10.6 KB
/
SpectralNormGouk.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
"""
Spectral Normalization from https://arxiv.org/abs/1802.05957
"""
import torch
from torch.nn.functional import normalize
from torch.nn.parameter import Parameter
import torch.nn.functional as F
class forward_function:
    """Callable that applies a Linear layer's weight to an input.

    SpectralNorm uses this to evaluate ``F.linear(inp, weight)`` when
    estimating the largest singular value of ``weight``.
    """

    def __init__(self, weight):
        # Default weight used when the caller does not supply one.
        self.weight = weight

    def __call__(self, inp, weight=None):
        """Return ``F.linear(inp, weight)`` (no bias); ``weight`` defaults to
        the stored tensor."""
        if weight is None:
            weight = self.weight
        # BUG FIX: the original passed the *builtin* `input` instead of the
        # `inp` parameter, so the actual argument was silently ignored and
        # F.linear received a function object.
        return F.linear(inp, weight)
class iteration_function:
    """One power-iteration step for a Linear weight.

    Computes ``F.linear(F.linear(inp, w), w.T)``, i.e. pushes the probe
    vector through the weight and back through its transpose.
    """

    def __init__(self, weight):
        # Fallback weight when none is given at call time.
        self.weight = weight

    def __call__(self, inp, weight=None):
        w = self.weight if weight is None else weight
        projected = F.linear(inp, w)
        return F.linear(projected, w.transpose(1, 0))
class forward_function2:
    """Callable applying a convolution (``functions[0]``) with stored
    defaults for weight, stride, groups, dilation and padding.

    Each call argument, when left as ``None``, falls back to the value
    captured at construction time.
    """

    def __init__(self, functions, weight, s, g, d, p):
        self.functions = functions
        self.weight = weight
        self.s = s
        self.g = g
        self.d = d
        self.p = p

    def __call__(self, inp, weight=None, s=None, g=None, d=None, p=None):
        weight = self.weight if weight is None else weight
        s = self.s if s is None else s
        g = self.g if g is None else g
        d = self.d if d is None else d
        p = self.p if p is None else p
        conv = self.functions[0]
        return conv(inp, weight, stride=s, padding=p, dilation=d, groups=g)
class iteration_function2:
    """One power-iteration step for a convolutional weight.

    Applies ``functions[0]`` (the forward conv) followed by ``functions[1]``
    (its transpose), with the same weight/stride/padding/dilation/groups for
    both passes. ``None`` arguments fall back to the values captured at
    construction time.
    """

    def __init__(self, functions, weight, s, g, d, p):
        self.functions = functions
        self.weight = weight
        self.s = s
        self.g = g
        self.d = d
        self.p = p

    def __call__(self, inp, weight=None, s=None, g=None, d=None, p=None):
        weight = self.weight if weight is None else weight
        s = self.s if s is None else s
        g = self.g if g is None else g
        d = self.d if d is None else d
        p = self.p if p is None else p
        fwd, bwd = self.functions
        mid = fwd(inp, weight, stride=s, padding=p, dilation=d, groups=g)
        return bwd(mid, weight, stride=s, padding=p, dilation=d, groups=g)
class SpectralNorm(object):
    """Forward-pre-hook implementing Gouk et al. (2018) spectral
    normalization: the weight is divided by
    ``sigma = max(1, largest_singular_value / magnitude)``, so it is only
    rescaled when its spectral norm exceeds ``magnitude``.
    """
    # Invariant before and after each forward call:
    #   u = normalize(W @ v)
    # NB: At initialization, this invariant is not enforced
    _version = 2
    # At version 2:
    #   used Gouk 2018 method.
    #   will only normalize if largest singular value > magnitude

    def __init__(self, name='weight', n_power_iterations=1, magnitude=1.0, eps=1e-12):
        self.name = name
        self.magnitude = magnitude
        if n_power_iterations <= 0:
            raise ValueError('Expected n_power_iterations to be positive, but '
                             'got n_power_iterations={}'.format(n_power_iterations))
        self.n_power_iterations = n_power_iterations
        self.eps = eps

    def l2norm(self, t):
        # Euclidean norm over all elements of t.
        return torch.sqrt((t ** 2).sum())

    def compute_weight(self, module, do_power_iteration, num_iter=0):
        """Return the spectrally-normalized weight.

        When ``do_power_iteration`` is True, runs
        ``max(n_power_iterations, num_iter)`` power-iteration steps to refresh
        the singular-vector estimate and the ``<name>_sigma`` buffer;
        otherwise reuses the stored sigma.
        """
        weight = getattr(module, self.name + '_orig')
        u = getattr(module, self.name + '_u')
        if do_power_iteration:
            with torch.no_grad():
                for _ in range(max(self.n_power_iterations, num_iter)):
                    u = module.iteration_function(u, weight=weight)
                    # BUG FIX: renormalize u every step. The original left u
                    # unnormalized, so it scales like sigma**2 per step and
                    # can overflow/underflow after a few iterations.
                    u = u / (self.l2norm(u) + self.eps)
                if self.n_power_iterations > 0:
                    # Clone so the persisted vector is not an autograd view.
                    u = u.clone()
                # BUG FIX: persist the iterated vector. The original never
                # wrote u back, so every forward call restarted power
                # iteration from the same random vector and never converged.
                setattr(module, self.name + '_u', u)
            # Rayleigh-quotient estimate of the largest singular value,
            # computed with grad so gradients flow through the rescaling.
            sv = self.l2norm(module.forward_function(u, weight=weight)) / self.l2norm(u)
            # Gouk-style clamp: sigma = max(1, sv / magnitude), i.e. only
            # shrink the weight when its spectral norm exceeds `magnitude`.
            sigma = F.relu(sv / self.magnitude - 1.0) + 1.0
            # BUG FIX: store sigma in the registered `<name>_sigma` buffer.
            # The original wrote a plain `module.sigma` attribute, so the
            # buffer was never updated and an eval-mode call before any
            # training forward raised AttributeError.
            setattr(module, self.name + '_sigma', sigma.detach())
        else:
            sigma = getattr(module, self.name + '_sigma')
        return weight / sigma

    def remove(self, module):
        """Detach the normalized weight and strip all spectral-norm state."""
        with torch.no_grad():
            weight = self.compute_weight(module, do_power_iteration=False)
        delattr(module, self.name)
        delattr(module, self.name + '_u')
        delattr(module, self.name + '_sigma')
        delattr(module, self.name + '_orig')
        module.register_parameter(self.name, torch.nn.Parameter(weight.detach()))

    def __call__(self, module, inputs, n_power_iterations=0):
        # Forward pre-hook: refresh the normalized weight (power iteration
        # runs only while the module is in training mode).
        setattr(module, self.name,
                self.compute_weight(module, do_power_iteration=module.training,
                                    num_iter=n_power_iterations))

    @staticmethod
    def apply(module, name, n_power_iterations, magnitude, eps):
        """Attach a SpectralNorm hook to ``module`` and set up its state
        (``<name>_orig`` parameter, ``<name>_u`` / ``<name>_sigma`` buffers,
        and per-layer forward/iteration callables)."""
        for k, hook in module._forward_pre_hooks.items():
            if isinstance(hook, SpectralNorm) and hook.name == name:
                raise RuntimeError("Cannot register two spectral_norm hooks on "
                                   "the same parameter {}".format(name))
        fn = SpectralNorm(name, n_power_iterations, magnitude, eps)
        weight = module._parameters[name]
        # Forward conv paired with its transpose (and vice versa) so the
        # iteration function applies W followed by W^T.
        functions_dict = {torch.nn.Conv1d: (F.conv1d, F.conv_transpose1d),
                          torch.nn.Conv2d: (F.conv2d, F.conv_transpose2d),
                          torch.nn.Conv3d: (F.conv3d, F.conv_transpose3d),
                          torch.nn.ConvTranspose1d: (F.conv_transpose1d, F.conv1d),
                          torch.nn.ConvTranspose2d: (F.conv_transpose2d, F.conv2d),
                          torch.nn.ConvTranspose3d: (F.conv_transpose3d, F.conv3d),
                          }
        if isinstance(module, torch.nn.Linear):
            module.forward_function = forward_function(weight)
            module.iteration_function = iteration_function(weight)
        elif isinstance(module, (torch.nn.ConvTranspose1d,
                                 torch.nn.ConvTranspose2d,
                                 torch.nn.ConvTranspose3d,
                                 torch.nn.Conv1d,
                                 torch.nn.Conv2d,
                                 torch.nn.Conv3d,)):
            # Renamed from `k`, which shadowed the hook-loop variable above.
            kernel = weight.shape[2:]
            s = module.stride
            g = module.groups
            d = module.dilation
            p = module.padding
            functions = functions_dict[module.__class__]
            module.forward_function = forward_function2(functions, weight, s, g, d, p)
            module.iteration_function = iteration_function2(functions, weight, s, g, d, p)
        with torch.no_grad():
            # Random probe: channel dim from weight.shape[1], spatial extent
            # of the (dilated) kernel for conv layers (loop is empty for
            # Linear, whose weight is 2-D).
            shape = (1, weight.shape[1])
            for i in range(0, len(weight.shape) - 2):
                shape += (max(kernel[i] * d[i], 1),)
            u = torch.randn(shape).to(weight.device)
        delattr(module, fn.name)
        module.register_parameter(fn.name + "_orig", weight)
        # We still need to assign weight back as fn.name because all sorts of
        # things may assume that it exists, e.g., when initializing weights.
        # However, we can't directly assign as it could be an nn.Parameter and
        # gets added as a parameter. Instead, we register weight.data as a plain
        # attribute.
        setattr(module, fn.name, weight.data)
        module.register_buffer(fn.name + "_u", u)
        # BUG FIX: use a float tensor; torch.tensor(1) is int64, while the
        # buffer is later overwritten with float sigmas.
        sigma = torch.tensor(1.0).to(weight.device)
        module.register_buffer(fn.name + "_sigma", sigma)
        module.register_forward_pre_hook(fn)
        module._register_state_dict_hook(SpectralNormStateDictHook(fn))
        module._register_load_state_dict_pre_hook(SpectralNormLoadStateDictPreHook(fn))
        return fn
# Kept at module level (not nested) because Py2 pickle cannot serialize
# inner classes or instance methods.
class SpectralNormLoadStateDictPreHook(object):
    """Load-state-dict pre-hook placeholder for SpectralNorm.

    See the docstring of ``SpectralNorm._version`` for the version history.
    For version-None state dicts (assuming at least one training forward),
    the invariant is::

        u = normalize(W_orig @ v)
        W = W_orig / sigma, where sigma = u @ W_orig @ v

    and ``v`` could be recovered by solving ``W_orig @ x = u`` with
    ``v = x / (u @ W_orig @ x) * (W / W_orig)``. This implementation
    performs no such migration.
    """

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, state_dict, prefix, local_metadata, strict,
                 missing_keys, unexpected_keys, error_msgs):
        # Intentionally a no-op: nothing to migrate for version-2 checkpoints.
        pass
# Kept at module level (not nested) because Py2 pickle cannot serialize
# inner classes or instance methods.
class SpectralNormStateDictHook(object):
    """State-dict hook placeholder for SpectralNorm.

    See the docstring of ``SpectralNorm._version`` for the version history.
    """

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, module, state_dict, prefix, local_metadata):
        # Intentionally a no-op: no extra metadata is written.
        pass
def spectral_norm(module, name='weight', n_power_iterations=1, magnitude=1.0,eps=1e-12):
    r"""Applies spectral normalization to a parameter in the given module.

    .. math::
        \mathbf{W} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})} \\
        \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}

    Spectral normalization stabilizes the training of discriminators (critics)
    in Generative Adversarial Networks (GANs) by rescaling the weight tensor
    with spectral norm :math:`\sigma` of the weight matrix calculated using
    the power iteration method. This variant (Gouk et al. 2018) only rescales
    the weight when its spectral norm exceeds ``magnitude``. This is
    implemented via a hook that calculates spectral norm and rescales weight
    before every :meth:`~Module.forward` call.

    See `Spectral Normalization for Generative Adversarial Networks`_ .

    .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957

    Args:
        module (nn.Module): containing module
        name (str, optional): name of weight parameter
        n_power_iterations (int, optional): number of power iterations to
            calculate spectral norm
        magnitude (float, optional): spectral-norm threshold; the weight is
            divided by ``max(1, sigma / magnitude)``, so it is left unchanged
            while its largest singular value stays below this value
        eps (float, optional): epsilon for numerical stability in
            calculating norms

    Returns:
        The original module with the spectral norm hook

    Example::
        >>> m = spectral_norm(nn.Linear(20, 40))
        Linear (20 -> 40)
        >>> m.weight_u.size()
        torch.Size([1, 20])
    """
    SpectralNorm.apply(module, name, n_power_iterations, magnitude, eps)
    return module
def remove_spectral_norm(module, name='weight'):
    r"""Strip the spectral-norm reparameterization previously added by
    :func:`spectral_norm`, restoring ``name`` as a plain parameter.

    Args:
        module (nn.Module): containing module
        name (str, optional): name of weight parameter

    Example:
        >>> m = spectral_norm(nn.Linear(40, 10))
        >>> remove_spectral_norm(m)
    """
    hooks = module._forward_pre_hooks
    # Snapshot the keys so the matching hook can be deleted safely.
    for key in list(hooks):
        hook = hooks[key]
        if isinstance(hook, SpectralNorm) and hook.name == name:
            hook.remove(module)
            del hooks[key]
            return module
    raise ValueError("spectral_norm of '{}' not found in {}".format(
        name, module))