-
Notifications
You must be signed in to change notification settings - Fork 76
/
Copy pathrutil.lua
170 lines (150 loc) · 5.13 KB
/
rutil.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
require 'rnn';
require 'nn';
require 'cunn';
require 'MultiCrossEntropyCriterion'
-- Module table: training/evaluation utilities for recurrent captcha-style
-- sequence classifiers (Torch7 + rnn package).
local rutil = {}
--- Build a recurrent classification head on top of a feature network.
-- Pipeline: net -> Sequencer( Recurrent -> Linear -> LogSoftMax ).
-- @param net  feature-extraction network placed before the sequencer
-- @param bs   batch size (accepted for call-site compatibility; unused here)
-- @param rho  BPTT truncation length (default 1)
-- @param hs   hidden state size (default 36)
-- @param cls  number of output classes (default 36)
-- @return the assembled rnn, and a SequencerCriterion(ClassNLLCriterion)
function rutil.getNetCt(net,bs,rho,hs,cls)
local rho = rho or 1
local hiddenSize = hs or 36
local classes = cls or 36
-- BUGFIX: the feedback module was built with `nn:Sequential()`, a colon call
-- that passes the `nn` table as a spurious constructor argument; the correct
-- call is `nn.Sequential()`.
local mlp = nn.Sequential()
:add(nn.Recurrent(
hiddenSize, nn.Identity(),
nn.Sequential():add(nn.Linear(hiddenSize, hiddenSize)):add(nn.ReLU()), nn.ReLU(),
rho
))
:add(nn.Linear(hiddenSize, classes))
:add(nn.LogSoftMax())
local rnn = nn.Sequential():add(net):add(nn.Sequencer(mlp))
local criterion = nn.SequencerCriterion(nn.ClassNLLCriterion())
return rnn,criterion
end
--- Whole-sequence accuracy: count samples whose prediction is correct at
-- EVERY position of the sequence.
-- @param outputs table of (N x C) score tensors, one per sequence step
-- @param targets table of length-N class-index tensors, one per step
-- @return number of samples matched at all steps
function rutil.kfacc(outputs,targets)
local steps = #outputs
local N = outputs[1]:size(1)
local C = outputs[1]:size(2)
-- Stack per-step scores into (N x steps x C) and targets into (N x steps).
local Y,y
for step=1,steps do
local o = outputs[step]:reshape(N,1,C)
local t = targets[step]:reshape(N,1)
if Y then
Y = torch.cat(Y,o,2)
y = torch.cat(y,t,2)
else
Y,y = o,t
end
end
local _,idx = Y:max(3)
-- A sample counts only if argmax equals the target at all `steps` positions.
return idx:squeeze():eq(y):sum(2):eq(steps):sum()
end
--- Per-position accuracy over a table of per-step outputs.
-- @param outputs table of (N x C) score tensors, one per sequence step
-- @param targets table of length-N class-index tensors, one per step
-- @return total correct count across all steps, and a table of per-step counts
function rutil.facc(outputs,targets)
local total = 0
local perStep = {}
for step=1,#outputs do
local _,pred = outputs[step]:max(2)
local correct = targets[step]:eq(pred:squeeze()):sum()
total = total + correct
perStep[#perStep+1] = correct
end
return total,perStep
end
--- Evaluate `rnn` on a validation set, in mini-batches.
-- @param rnn       network to evaluate (forget() is called after each batch)
-- @param Xv,Yv     validation inputs/labels, indexed along dimension 1
-- @param batchSize mini-batch size (default 16)
-- @param tnet      target-transform net producing a per-step target table
-- @param f         optional accuracy function (defaults to rutil.facc);
--                  must return (total, per-step table) like rutil.facc
-- @return overall accuracy in percent, and a table of per-step accuracy %
function rutil.valid(rnn,Xv,Yv,batchSize,tnet,f)
-- BUGFIX: was `local batchSize = bs or 16` — `bs` is an undefined global,
-- so the caller-supplied batchSize was silently ignored and 16 always used.
batchSize = batchSize or 16
-- BUGFIX: was `f(outputs,targets) or rutil.facc(outputs,targets)` inside the
-- loop — that errors when f is nil (as in the call from rutil.train), and
-- Lua's `or` truncates multiple return values so `ai` was always dropped.
f = f or rutil.facc
local acc = 0
local acci = {}
local Nv = Xv:size(1)
rnn:evaluate()
for i=1,Nv,batchSize do
xlua.progress(i/batchSize, Nv/batchSize)
local j = math.min(Nv,i+batchSize-1)
local inputs = Xv[{{i,j}}]:cuda()
local targets = tnet:forward(Yv[{{i,j}}]:cuda())
local outputs = rnn:forward(inputs)
local aa,ai = f(outputs,targets)
for k=1,#ai do
acci[k] = (acci[k] or 0) + ai[k]
end
acc = acc + aa/#ai
-- Clear recurrent state so batches do not leak hidden state into each other.
rnn:forget()
end
for k=1,#acci do
acci[k] = acci[k] * 100/Nv
end
return (acc*100)/Nv,acci
end
--- Train `rnn` with per-batch SGD, validating after every epoch.
-- Prints loss, train accuracy, validation accuracy and the best validation
-- accuracy seen so far.
-- @param rnn       network (must support forget/backwardThroughTime)
-- @param criterion sequence criterion matching the rnn's table outputs
-- @param Xt,Yt     training inputs/labels, indexed along dimension 1
-- @param Xv,Yv     validation inputs/labels
-- @param T         number of epochs (default 2)
-- @param batchSize mini-batch size (default 16)
-- @param tnet      target-transform net producing a per-step target table
-- @param lr        learning rate (default 0.001)
function rutil.train(rnn,criterion,Xt,Yt,Xv,Yv,T,batchSize,tnet,lr)
batchSize = batchSize or 16
-- BUGFIX (idiom): was `local maxv = 0 or maxv`, which always evaluates to 0;
-- written plainly. Tracks the best validation accuracy across epochs.
local maxv = 0
T = T or 2
for t = 1,T do
print(os.date("%X", os.time()))
print(t,T)
rnn:forget()
local loss = 0
local acc = 0
local Nt = Xt:size(1)
rnn:training()
for i=1,Nt,batchSize do
xlua.progress(i/batchSize, Nt/batchSize)
local j = math.min(Nt,i+batchSize-1)
local inputs = Xt[{{i,j}}]:cuda()
local targets = tnet:forward(Yt[{{i,j}}]:cuda())
local outputs = rnn:forward(inputs)
-- Losses/accuracies are averaged over the #targets sequence positions.
loss = loss + criterion:forward(outputs, targets)/#targets
local gradOutputs = criterion:backward(outputs,targets)
-- rutil.facc returns (total, per-step table); arithmetic keeps the total.
acc = acc + rutil.facc(outputs,targets)/#targets
rnn:backward(inputs,gradOutputs)
rnn:backwardThroughTime()
rnn:updateParameters(lr or 0.001)
rnn:zeroGradParameters()
-- Reset recurrent state between batches.
rnn:forget()
end
print('loss',loss)
print('train',100*acc/Nt)
local v,acc = rutil.valid(rnn,Xv,Yv,batchSize,tnet)
if(v>maxv) then maxv = v end
print('v',v,'maxv',maxv)
print(acc)
print(os.date("%X", os.time()))
end
print(maxv)
end
--- Build a VGG-style convnet for 1x50x200 grayscale inputs that predicts `k`
-- positions over `c` classes each (captcha-style multi-character output).
-- Output is a table of k (batch x c) score tensors, via SplitTable.
-- @param k number of output positions (default 5; new optional parameter —
--          calling rutil.model() preserves the original behavior)
-- @param c number of classes per position (default 36)
-- @return the assembled network
function rutil.model(k,c)
-- Generalization: k and c were effectively hard-coded (read from undefined
-- globals and fixed literals); they are now optional parameters.
k = k or 5
c = c or 36
-- BUGFIX: `vgg` was an accidental global; declared local. Callers should use
-- the returned network, not a global.
local vgg = nn.Sequential()
vgg:add(nn.Reshape(1,50,200))
-- Conv(3x3, pad 1) -> BatchNorm -> ReLU building block appended to vgg.
local function ConvBNReLU(nInputPlane, nOutputPlane)
vgg:add(nn.SpatialConvolution(nInputPlane, nOutputPlane, 3,3, 1,1, 1,1))
vgg:add(nn.SpatialBatchNormalization(nOutputPlane,1e-3))
vgg:add(nn.ReLU(true))
return vgg
end
ConvBNReLU(1,64)--:add(nn.Dropout(0.3,nil,true))
ConvBNReLU(64,64)
ConvBNReLU(64,64)
vgg:add(nn.SpatialMaxPooling(2,2,2,2):ceil())
ConvBNReLU(64,128)--:add(nn.Dropout(0.4,nil,true))
ConvBNReLU(128,128)--:add(nn.Dropout(0.4,nil,true))
vgg:add(nn.SpatialMaxPooling(2,2,2,2):ceil())
ConvBNReLU(128,256)--:add(nn.Dropout(0.4,nil,true))
ConvBNReLU(256,256)--:add(nn.Dropout(0.4,nil,true))
vgg:add(nn.SpatialMaxPooling(2,2,2,2):ceil())
ConvBNReLU(256,256)--:add(nn.Dropout(0.4,nil,true))
ConvBNReLU(256,256)--:add(nn.Dropout(0.4,nil,true))
vgg:add(nn.SpatialMaxPooling(2,2,2,2):ceil())
ConvBNReLU(256,256)--:add(nn.Dropout(0.4,nil,true))
ConvBNReLU(256,256)--:add(nn.Dropout(0.4,nil,true))
ConvBNReLU(256,256)--:add(nn.Dropout(0.4,nil,true))
vgg:add(nn.SpatialMaxPooling(2,2,2,2):ceil())
-- After five 2x2 poolings, a 50x200 input becomes 2x7 spatially (ceil mode).
vgg:add(nn.View(256*2*7))
local classifier = nn.Sequential()
--classifier:add(nn.Dropout(0.5,nil,true))
classifier:add(nn.Linear(256*2*7,256))
classifier:add(nn.BatchNormalization(256))
classifier:add(nn.ReLU(true))
classifier:add(nn.Linear(256,k*c))
vgg:add(classifier)
-- CONSISTENCY FIX: the Reshape was hard-coded (5,36) while the Linear layer
-- used k*c; both now track the k and c parameters.
vgg:add(nn.Reshape(k,c))
vgg:add(nn.SplitTable(2,3))
return vgg
end
-- Export the module table.
return rutil