-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcnntrain.m
85 lines (83 loc) · 2.49 KB
/
cnntrain.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
function [netBest,mode] = cnntrain(net, x, y,testx,testy, input,opts,fid1)
%CNNTRAIN Mini-batch training loop that keeps the best network seen.
%   [netBest, mode] = cnntrain(net, x, y, testx, testy, input, opts, fid1)
%   runs up to opts.numepochs epochs. Each epoch shuffles the samples
%   (randperm) and, for every mini-batch, calls cnnff, cnnbp and
%   cnnapplygrads in turn. After each epoch the scalar returned by
%   cnntest1(net, x, y) is treated as the training error (lower is better).
%   The network with the lowest such value observed so far is returned as
%   netBest. This includes the network as passed in, which is evaluated
%   before training starts.
%
%   Inputs:
%     net          - network struct consumed by cnnff/cnnbp/cnnapplygrads.
%     x            - training samples, indexed along dimension 3.
%     y            - training targets, indexed along dimension 2.
%     testx, testy - held-out data. cnntest1 is called on them after every
%                    epoch, but the result is discarded and does not affect
%                    which network is returned.
%     input, fid1  - unused. Kept for interface compatibility; they were
%                    arguments of a cnntest call that is no longer made.
%     opts         - options struct:
%                      .batchsize - mini-batch size; must divide size(x, 3).
%                      .numepochs - maximum number of epochs.
%                      .patience  - optional, default 100. Training stops
%                                   once this many consecutive epochs have
%                                   produced a training error equal to the
%                                   previous epoch's (the first epoch is
%                                   compared against 1).
%
%   Outputs:
%     netBest - network with the lowest training error observed.
%     mode    - selection mode for netBest; always 1 in this version.
%
%   Raises an error if size(x, 3) is not a multiple of opts.batchsize.

m = size(x, 3);
numbatches = m / opts.batchsize;
if rem(numbatches, 1) ~= 0
    error('numbatches not integer');
end

% Early-stopping patience. Optional, so existing callers keep the
% original hard-coded value of 100.
patience = 100;
if isfield(opts, 'patience')
    patience = opts.patience;
end

net.rL = [];          % smoothed loss history, appended once per batch
ermin.id = 1;         % epoch at which netBest was last updated
ermin.advance = 1;    % training error of the previous epoch (starts at 1)
flag = 0;             % consecutive epochs with an unchanged training error

% Baseline: the network as passed in is the initial "best" candidate.
ermin.value = cnntest1(net, x, y);
netBest = net;
mode = 1;

for i = 1 : opts.numepochs
    disp(['epoch ' num2str(i) '/' num2str(opts.numepochs)]);
    tic;
    kk = randperm(m);
    for l = 1 : numbatches
        % Sample indices of the l-th mini-batch of this epoch's shuffle.
        idx = kk((l - 1) * opts.batchsize + 1 : l * opts.batchsize);
        batch_x = x(:, :, idx);
        batch_y = y(:, idx);
        net = cnnff(net, batch_x);
        net = cnnbp(net, batch_y);
        net = cnnapplygrads(net, opts);
        % net.L is presumably the batch loss set by cnnbp (TODO confirm).
        % Seed the history with it on the first batch, then append an
        % exponential moving average with decay 0.99.
        if isempty(net.rL)
            net.rL(1) = net.L;
        end
        net.rL(end + 1) = 0.99 * net.rL(end) + 0.01 * net.L;
    end
    toc;

    er1 = cnntest1(net, x, y);

    % Plateau detection: count consecutive epochs whose error exactly
    % matches the previous epoch's, and reset the count on any change.
    if er1 == ermin.advance
        flag = flag + 1;
    else
        ermin.advance = er1;
        flag = 0;
    end
    % Breaking before the best-model check loses nothing here. An
    % unchanged error was already compared in an earlier epoch.
    if flag >= patience
        break;
    end

    if er1 < ermin.value
        ermin.value = er1;
        ermin.id = i;
        netBest = net;
        mode = 1;
    end

    % Evaluate on the held-out set; the return value is not used here.
    cnntest1(net, testx, testy);
end
end