# Modified based on the DEQ repo.
import pickle

import numpy as np
import torch
from termcolor import colored


def _safe_norm(v):
if not torch.isfinite(v).all():
return np.inf
return torch.norm(v)


def scalar_search_armijo(phi, phi0, derphi0, c1=1e-4, alpha0=1, amin=0):
ite = 0
phi_a0 = phi(alpha0) # First do an update with step size 1
if phi_a0 <= phi0 + c1*alpha0*derphi0:
return alpha0, phi_a0, ite
# Otherwise, compute the minimizer of a quadratic interpolant
alpha1 = -(derphi0) * alpha0**2 / 2.0 / (phi_a0 - phi0 - derphi0 * alpha0)
phi_a1 = phi(alpha1)
    # Otherwise, loop with cubic interpolation until we find an alpha which
    # satisfies the first Wolfe condition (since we are backtracking, we
    # assume that alpha is not too small and the second condition holds).
while alpha1 > amin: # we are assuming alpha>0 is a descent direction
factor = alpha0**2 * alpha1**2 * (alpha1-alpha0)
a = alpha0**2 * (phi_a1 - phi0 - derphi0*alpha1) - \
alpha1**2 * (phi_a0 - phi0 - derphi0*alpha0)
a = a / factor
b = -alpha0**3 * (phi_a1 - phi0 - derphi0*alpha1) + \
alpha1**3 * (phi_a0 - phi0 - derphi0*alpha0)
b = b / factor
alpha2 = (-b + torch.sqrt(torch.abs(b**2 - 3 * a * derphi0))) / (3.0*a)
phi_a2 = phi(alpha2)
ite += 1
if (phi_a2 <= phi0 + c1*alpha2*derphi0):
return alpha2, phi_a2, ite
if (alpha1 - alpha2) > alpha1 / 2.0 or (1 - alpha2/alpha1) < 0.96:
alpha2 = alpha1 / 2.0
alpha0 = alpha1
alpha1 = alpha2
phi_a0 = phi_a1
phi_a1 = phi_a2
# Failed to find a suitable step length
return None, phi_a1, ite


def line_search(update, x0, g0, g, nstep=0, on=True):
    """
    `update` is the proposed update direction.
    Code adapted from scipy.
    """
tmp_s = [0]
tmp_g0 = [g0]
tmp_phi = [torch.norm(g0)**2]
s_norm = torch.norm(x0) / torch.norm(update)
def phi(s, store=True):
if s == tmp_s[0]:
            return tmp_phi[0]  # Step size unchanged from the cached one; reuse the stored value
x_est = x0 + s * update
g0_new = g(x_est)
phi_new = _safe_norm(g0_new)**2
if store:
tmp_s[0] = s
tmp_g0[0] = g0_new
tmp_phi[0] = phi_new
return phi_new
if on:
s, phi1, ite = scalar_search_armijo(phi, tmp_phi[0], -tmp_phi[0], amin=1e-2)
if (not on) or s is None:
s = 1.0
ite = 0
x_est = x0 + s * update
if s == tmp_s[0]:
g0_new = tmp_g0[0]
else:
g0_new = g(x_est)
return x_est, g0_new, x_est - x0, g0_new - g0, ite


def rmatvec(part_Us, part_VTs, x):
# Compute x^T(-I + UV^T)
# x: (N, 2d, L')
# part_Us: (N, 2d, L', threshold)
# part_VTs: (N, threshold, 2d, L')
if part_Us.nelement() == 0:
return -x
xTU = torch.einsum('bij, bijd -> bd', x, part_Us) # (N, threshold)
return -x + torch.einsum('bd, bdij -> bij', xTU, part_VTs) # (N, 2d, L'), but should really be (N, 1, (2d*L'))


def matvec(part_Us, part_VTs, x):
# Compute (-I + UV^T)x
# x: (N, 2d, L')
# part_Us: (N, 2d, L', threshold)
# part_VTs: (N, threshold, 2d, L')
if part_Us.nelement() == 0:
return -x
VTx = torch.einsum('bdij, bij -> bd', part_VTs, x) # (N, threshold)
return -x + torch.einsum('bijd, bd -> bij', part_Us, VTx) # (N, 2d, L'), but should really be (N, (2d*L'), 1)
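

# Note: `rmatvec` and `matvec` apply the low-rank estimate of the inverse
# Jacobian, J_inv ~ -I + U V^T, without ever materializing it.  As a rough
# sketch of the math (J_inv is only notation for the dense matrix that is
# never formed), each iteration of `broyden` below performs the "good
# Broyden" rank-one update
#
#     vT        = delta_x^T @ J_inv                                  (via rmatvec)
#     u         = (delta_x - J_inv @ delta_gx) / (vT @ delta_gx)     (via matvec)
#     J_inv_new = J_inv + u @ vT
#
# so U gains one column (u) and V^T gains one row (vT) per step, cycling
# through at most LBFGS_thres stored pairs.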


def broyden(g, x0, threshold, eps, ls=False, name="unknown"):
    """
    Find x such that g(x) is (approximately) zero with limited-memory Broyden's method.

    g:         maps a (bsz, total_hsize, n_elem) tensor to a residual of the same shape
    x0:        initial estimate, shape (bsz, total_hsize, n_elem)
    threshold: maximum number of Broyden iterations
    eps:       stop once the residual norm falls below this value
    ls:        if True, use an Armijo line search to pick the step size
    name:      label kept for debugging; not used by the solver itself
    """
bsz, total_hsize, n_elem = x0.size()
dev = x0.device
x_est = x0 # (bsz, 2d, L')
gx = g(x_est) # (bsz, 2d, L')
nstep = 0
tnstep = 0
LBFGS_thres = min(threshold, 27)
    # Low-rank factors of the approximate inverse Jacobian (J_inv ~ -I + U V^T)
    Us = torch.zeros(bsz, total_hsize, n_elem, LBFGS_thres, device=dev)
    VTs = torch.zeros(bsz, LBFGS_thres, total_hsize, n_elem, device=dev)
update = gx
new_objective = init_objective = torch.norm(gx).item()
prot_break = False
trace = [init_objective]
new_trace = [-1]
# To be used in protective breaks
protect_thres = 1e6 * n_elem
lowest = new_objective
lowest_xest, lowest_gx, lowest_step = x_est, gx, nstep
while new_objective >= eps and nstep < threshold:
x_est, gx, delta_x, delta_gx, ite = line_search(update, x_est, gx, g, nstep=nstep, on=ls)
nstep += 1
tnstep += (ite+1)
new_objective = torch.norm(gx).item()
trace.append(new_objective)
        try:
            new2_objective = torch.norm(delta_x).item() / (torch.norm(x_est - delta_x).item())  # Relative residual
        except ZeroDivisionError:
            new2_objective = torch.norm(delta_x).item() / (torch.norm(x_est - delta_x).item() + 1e-9)
new_trace.append(new2_objective)
if new_objective < lowest:
lowest_xest, lowest_gx = x_est.clone().detach(), gx.clone().detach()
lowest = new_objective
lowest_step = nstep
if new_objective < eps:
break
if new_objective < 3*eps and nstep > 30 and np.max(trace[-30:]) / np.min(trace[-30:]) < 1.3:
# if there's hardly been any progress in the last 30 steps
break
if new_objective > init_objective * protect_thres:
prot_break = True
break
part_Us, part_VTs = Us[:,:,:,:(nstep-1)], VTs[:,:(nstep-1)]
vT = rmatvec(part_Us, part_VTs, delta_x)
u = (delta_x - matvec(part_Us, part_VTs, delta_gx)) / torch.einsum('bij, bij -> b', vT, delta_gx)[:,None,None]
        vT[vT != vT] = 0  # Zero out NaNs produced by a (near-)zero denominator in the rank-one update
        u[u != u] = 0
VTs[:,(nstep-1) % LBFGS_thres] = vT
Us[:,:,:,(nstep-1) % LBFGS_thres] = u
update = -matvec(Us[:,:,:,:nstep], VTs[:,:nstep], gx)
Us, VTs = None, None
return {"result": lowest_xest,
"nstep": nstep,
"tnstep": tnstep,
"lowest_step": lowest_step,
"diff": torch.norm(lowest_gx).item(),
"diff_detail": torch.norm(lowest_gx, dim=1),
"prot_break": prot_break,
"trace": trace,
"new_trace": new_trace,
"eps": eps,
"threshold": threshold}


def analyze_broyden(res_info, err=None, judge=True, name='forward', training=True, save_err=True):
"""
For debugging use only :-)
"""
res_est = res_info['result']
nstep = res_info['nstep']
diff = res_info['diff']
diff_detail = res_info['diff_detail']
prot_break = res_info['prot_break']
trace = res_info['trace']
eps = res_info['eps']
threshold = res_info['threshold']
if judge:
return nstep >= threshold or (nstep == 0 and (diff != diff or diff > eps)) or prot_break or torch.isnan(res_est).any()
assert (err is not None), "Must provide err information when not in judgment mode"
prefix, color = ('', 'red') if name == 'forward' else ('back_', 'blue')
eval_prefix = '' if training else 'eval_'
# Case 1: A nan entry is produced in Broyden
if torch.isnan(res_est).any():
msg = colored(f"WARNING: nan found in Broyden's {name} result. Diff: {diff}", color)
print(msg)
if save_err: pickle.dump(err, open(f'{prefix}{eval_prefix}nan.pkl', 'wb'))
return (1, msg, res_info)
# Case 2: Unknown problem with Broyden's method (probably due to nan update(s) to the weights)
if nstep == 0 and (diff != diff or diff > eps):
msg = colored(f"WARNING: Bad Broyden's method {name}. Why?? Diff: {diff}. STOP.", color)
print(msg)
if save_err: pickle.dump(err, open(f'{prefix}{eval_prefix}badbroyden.pkl', 'wb'))
return (2, msg, res_info)
# Case 3: Protective break during Broyden (so that it does not diverge to infinity)
if prot_break and np.random.uniform(0,1) < 0.05:
msg = colored(f"WARNING: Hit Protective Break in {name}. Diff: {diff}. Total Iter: {len(trace)}", color)
print(msg)
if save_err: pickle.dump(err, open(f'{prefix}{eval_prefix}prot_break.pkl', 'wb'))
return (3, msg, res_info)
return (-1, '', res_info)
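

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): solve the toy
    # root-finding problem g(x) = cos(x) - x, i.e. the fixed point of cos,
    # with the solver above.  The shapes follow the (bsz, total_hsize, n_elem)
    # layout the solver assumes; the concrete sizes here are arbitrary.
    def toy_g(x):
        # Residual whose root is the fixed point of cos (~0.739 per entry)
        return torch.cos(x) - x

    x0 = torch.zeros(2, 4, 3)  # (bsz, total_hsize, n_elem)
    res = broyden(toy_g, x0, threshold=50, eps=1e-6, ls=False)
    print(f"steps: {res['nstep']}  residual: {res['diff']:.3e}")
    print("estimate (expect ~0.739):", res['result'][0, 0, 0].item())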