-
Notifications
You must be signed in to change notification settings - Fork 122
/
Copy pathlayerWholeL2Normalize.m
107 lines (88 loc) · 3.07 KB
/
layerWholeL2Normalize.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
classdef layerWholeL2Normalize
    % layerWholeL2Normalize  Custom layer that L2-normalizes each sample as a whole.
    %   Every sample (slice x(:,:,:,b) of the input, b indexing dimension 4)
    %   is flattened into one column vector and divided by its L2 norm.
    %   The output has size [1, 1, D, batchSize], with D = numel(x(:,:,:,1)).
    %
    %   l = layerWholeL2Normalize()      creates the layer named 'wholeL2'.
    %   l = layerWholeL2Normalize(name)  creates the layer with a custom name.
    %
    %   The static forward/backward methods read and write struct fields
    %   .x and .dzdx. NOTE(review): this presumably follows the MatConvNet-style
    %   custom-layer convention of the calling network code -- confirm.

    properties
        type= 'custom'
        name= 'wholeL2'
        precious= false
    end

    methods
        function l= layerWholeL2Normalize(name)
            % Optional NAME overrides the default layer name 'wholeL2'.
            if nargin>0, l.name= name; end
        end

        function y= forward_(p, x)
            % FORWARD_  L2-normalize every flattened sample of X.
            %   X: array whose 4th dimension indexes the batch.
            %   Y: [1, 1, D, batchSize] array holding the normalized columns.
            %   The normalization is delegated to relja_l2normalize_col
            %   (defined elsewhere in the project). It is applied to the
            %   [D, batchSize] reshape of X.
            batchSize= size(x, 4);
            y= relja_l2normalize_col( reshape(x, [], batchSize) );
            y= reshape(y, [1,1,size(y,1), batchSize]);
        end

        function dzdx= backward_(p, x, dzdy)
            % BACKWARD_  Back-propagate DZDY through the whole-sample L2 normalization.
            %   Consider one sample with flattened vector x and n = ||x|| + 1e-12.
            %   The Jacobian of y = x/n is taken as
            %       J = eye(D)/n - x*x'/n^3.
            %   J is symmetric, so
            %       dzdx = J'*dzdy = dzdy/n - x*(x'*dzdy)/n^3.
            %   This product is evaluated for the whole batch at once, without
            %   a per-sample loop and without ever forming the D-by-D matrix J.
            %   X and DZDY must contain the same number of elements.
            %   DZDX is returned with size(X).
            %   NOTE(review): the 1e-12 avoids division by zero for all-zero
            %   samples, which makes J an approximation for tiny norms. This
            %   assumes relja_l2normalize_col uses the same regularized norm
            %   -- confirm.
            batchSize= size(x, 4);
            xr= reshape(x, [], batchSize);        % D x batchSize
            dzdy= reshape(dzdy, [], batchSize);   % D x batchSize
            xNorm= sqrt(sum(xr.^2, 1)) + 1e-12;   % 1 x batchSize
            % Per-sample inner products x_b' * dzdy_b, as a 1 x batchSize row.
            proj= sum(xr .* dzdy, 1);
            % bsxfun broadcasts the per-sample row vectors over the D rows.
            % It is used instead of implicit expansion so older releases work.
            dzdx= bsxfun(@rdivide, bsxfun(@times, -xr, proj), xNorm.^3) ...
                + bsxfun(@rdivide, dzdy, xNorm);
            dzdx= reshape(dzdx, size(x));
        end
    end

    methods (Static)
        function res1= forward(p, res0, res1)
            % Normalize res0.x with layer P and store the result in res1.x.
            res1.x= p.forward_(res0.x);
        end

        function res0= backward(p, res0, res1)
            % Propagate res1.dzdx back to the input; store it in res0.dzdx.
            res0.dzdx= p.backward_(res0.x, res1.dzdx);
        end
    end
end
% This layer could also be implemented with pure MatConvNet (vl_nnnormalize) plus reshapes, as below, but the implementation above seems to be much faster (at least in my setup / use case).
% classdef layerWholeL2Normalize < handle
%
% properties
% type= 'custom'
% name= 'wholeL2'
% end
%
% methods
%
% function l= layerWholeL2Normalize(name)
% if nargin>0, l.name= name; end
% end
%
% end
%
% methods (Static)
%
% function res1= forward(l, res0, res1)
% D= relja_numel(res0.x, [1,2,3]);
% batchSize= size(res0.x, 4);
% res1.x= vl_nnnormalize( ...
% reshape(res0.x, [1, 1, D, batchSize]), ...
% [2*D, 1e-12, 1, 0.5]);
% end
%
% function res0= backward(l, res0, res1)
% D= relja_numel(res0.x, [1,2,3]);
% batchSize= size(res0.x, 4);
% res0.dzdx= reshape( vl_nnnormalize( ...
% reshape(res0.x, [1, 1, D, batchSize]), ...
% [2*D, 1e-12, 1, 0.5], ...
% reshape(res1.dzdx, [1, 1, D, batchSize])), ...
% size(res0.x) );
% end
%
% end
%
% end