-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathSBD.m
140 lines (119 loc) · 4.06 KB
/
SBD.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
function [ Aout, Xout, bout, extras ] = SBD( Y, k, params, dispfun )
%SBD Two-phase blind deconvolution of Y into a kernel A and a map X.
%   [Aout, Xout, bout, extras] = SBD(Y, k, params, dispfun)
%
%   Phase I starts from a random unit-norm k(1)-by-k(2)-by-n kernel and
%   calls Asolve_Manopt with regularization lambda1. Phase II (optional)
%   embeds that kernel in the centre of a window padded by kplus on each
%   side, re-solves while moving lambda geometrically from lambda1 to
%   lambda2, and re-centres the kernel after each solve.
%
% INPUTS:
%   Y       : observation; n = size(Y,3) sets the third dimension of the
%             kernel.
%   k       : 1x2 kernel size [k(1) k(2)].
%   params  : options struct (see below).
%   dispfun : optional display callback, called as
%             dispfun(Y, A, X, k, kplus, idx). Defaults to a no-op.
%
% OUTPUTS:
%   Aout    : k(1)-by-k(2)-by-n kernel scaled to unit Frobenius norm.
%   Xout    : map X multiplied by the kernel's norm before normalization,
%             so the scale removed from Aout is moved into Xout.
%   bout    : field b of the solver's X solution (bias).
%   extras  : .phase1 with fields A, X, b, info; and, when phase2 is true,
%             .phase2(i) with the same fields for solve i = 1..nrefine+1.
%
% PARAMS STRUCT:
% ===============
% Required fields:
%   lambda1, float > 0 : regularization parameter for Phase I
%   phase2, bool       : whether to do Phase II (refinement) or not
%
% If phase2 == true, the following fields are also required:
%   kplus, int > 0     : border padding (pixels) for sphere lifting; either
%                        a scalar (same padding on both dims) or [rows cols]
%   lambda2, float > 0 : FINAL reg. param. value for Phase II
%   nrefine, int >= 0  : number of lambda updates in Phase II. Phase II runs
%                        nrefine+1 solves: the first lifts the sphere and
%                        uses lambda1; each later one multiplies lambda by
%                        (lambda2/lambda1)^(1/nrefine), so the last solve
%                        uses lambda2. nrefine == 0 gives a single Phase II
%                        solve at lambda1.
%
% Optional fields (defaults used when the field is absent or empty):
%   Xsolve, string : which Xsolve to use, 'FISTA' or 'pdNCG'
%                    (default 'FISTA')
%   xpos, bool     : constrain X to have nonnegative entries when running
%                    Xsolve (default false)
%   getbias, bool  : extract a constant bias from the observation
%                    (default false)

%% Process input arguments
starttime = tic;
n = size(Y,3);
if nargin < 4 || isempty(dispfun)
    dispfun = @(Y,A,X,k,kplus,idx) 0;
end
lambda1 = params.lambda1;
if params.phase2
    % Row vector so k + 2*kplus and [..]-kplus-1 stay 1x2 (no broadcasting).
    kplus = reshape(params.kplus, 1, []);
    if isscalar(kplus)
        % The code indexes kplus(1) and kplus(2); accept a single padding.
        kplus = [kplus kplus];
    end
    lambda2 = params.lambda2;
    nrefine = params.nrefine;
end
xpos    = optfield(params, 'xpos',    false);
getbias = optfield(params, 'getbias', false);
Xsolve  = optfield(params, 'Xsolve',  'FISTA');

%% PHASE I: First pass at BD
dispfun1 = @(A, X) dispfun(Y, A, X, k, [], 1);
fprintf('PHASE I: \n=========\n');
A = randn([k n]); A = A/norm(A(:));
[A, Xsol, info] = Asolve_Manopt( Y, A, lambda1, Xsolve, [], xpos, getbias, dispfun1);
extras.phase1.A = A;
extras.phase1.X = Xsol.X;
extras.phase1.b = Xsol.b;
extras.phase1.info = info;

if params.phase2
    %% PHASE II: Lift the sphere and do lambda continuation
    k2 = k + 2*kplus;
    dispfun2 = @(A, X) dispfun(Y, A, X, k2, 0, 1);

    % Embed the Phase I kernel in the centre of the padded k2 window.
    A2 = zeros([k2 n]);
    A2(kplus(1)+(1:k(1)), kplus(2)+(1:k(2)), :) = A;
    % NOTE(review): the Phase I X solution is reused as the warm start
    % without shifting it by -kplus; these shifts were left disabled:
    %X2sol.X = circshift(Xsol.X,-kplus);
    %X2sol.W = circshift(Xsol.W,-kplus);
    X2sol = Xsol;

    fprintf('\n\nPHASE II: \n=========\n');
    lambda = lambda1;
    lam2fac = (lambda2/lambda1)^(1/nrefine);
    for iter = 1:(nrefine + 1)
        fprintf('lambda = %.1e: \n', lambda);
        [A2, X2sol, info] = Asolve_Manopt( Y, A2, lambda, Xsolve, X2sol, xpos, getbias, dispfun2 );
        fprintf('\n');

        % Move the highest-l1-mass k-window of A2 back to the centre.
        [A2, X2sol] = recenter_kernel(A2, X2sol, k, kplus);

        % Save phase 2 extras for this solve.
        extras.phase2(iter).A = A2;
        extras.phase2(iter).X = X2sol.X;
        extras.phase2(iter).b = X2sol.b;
        extras.phase2(iter).info = info;

        dispfun2(A2,X2sol.X);
        lambda = lambda*lam2fac;
    end

    %% Finished: crop the central k window and normalize
    Aout = A2(kplus(1)+(1:k(1)), kplus(2)+(1:k(2)), :);
    Xout = circshift(X2sol.X,kplus) * norm(Aout(:));
    Aout = Aout/norm(Aout(:));
    bout = X2sol.b;
else
    %% Finished after Phase I: A2/X2sol/kplus do not exist on this path,
    % so report the Phase I result with the same normalization convention.
    Xout = Xsol.X * norm(A(:));
    Aout = A/norm(A(:));
    bout = Xsol.b;
end
runtime = toc(starttime);
fprintf('\nDone! Runtime = %.2fs. \n\n', runtime);
end

function val = optfield(s, name, default)
%OPTFIELD Return s.(name), or DEFAULT when the field is absent or empty.
if isfield(s, name) && ~isempty(s.(name))
    val = s.(name);
else
    val = default;
end
end

function [A2, X2sol] = recenter_kernel(A2, X2sol, k, kplus)
%RECENTER_KERNEL Circularly shift A2 so its best k-window sits at offset kplus.
%   Scores every k(1)-by-k(2) window of A2 whose top-left corner is at
%   (ind1, ind2), ind = 1..2*kplus+1 (i.e. shifts tau in [-kplus, kplus]),
%   by its l1 norm. A2 is then shifted by -tau and X2sol.X / X2sol.W by
%   +tau for the best window (first maximum on ties) — opposite shifts,
%   presumably so the circular convolution of A2 and X is preserved.
score = zeros(2*kplus+1);
for ind1 = 1:(2*kplus(1)+1)
    for ind2 = 1:(2*kplus(2)+1)
        win = A2(ind1:(ind1+k(1)-1), ind2:(ind2+k(2)-1), :);
        score(ind1,ind2) = norm(win(:), 1);
    end
end
[colmax, rowidx] = max(score);
[~, bestcol] = max(colmax);
tau = [rowidx(bestcol) bestcol] - kplus - 1;
A2 = circshift(A2, -tau);
X2sol.X = circshift(X2sol.X, tau);
X2sol.W = circshift(X2sol.W, tau);
end