Skip to content

Commit 518eb6b

Browse files
authored
Merge pull request #147 from DrDub/master
Leaky ReLUs with test case
2 parents b591110 + 23987ef commit 518eb6b

File tree

7 files changed

+80
-8
lines changed

7 files changed

+80
-8
lines changed

src/fann.c

+12-5
Original file line numberDiff line numberDiff line change
@@ -683,13 +683,20 @@ FANN_EXTERNAL fann_type *FANN_API fann_run(struct fann *ann, fann_type *input) {
683683
neuron_it->value = neuron_sum;
684684
break;
685685
case FANN_LINEAR_PIECE:
686-
neuron_it->value = (fann_type)(
687-
(neuron_sum < 0) ? 0 : (neuron_sum > multiplier) ? multiplier : neuron_sum);
686+
neuron_it->value = (fann_type)((neuron_sum < 0) ? 0
687+
: (neuron_sum > multiplier) ? multiplier
688+
: neuron_sum);
688689
break;
689690
case FANN_LINEAR_PIECE_SYMMETRIC:
690-
neuron_it->value = (fann_type)((neuron_sum < -multiplier)
691-
? -multiplier
692-
: (neuron_sum > multiplier) ? multiplier : neuron_sum);
691+
neuron_it->value = (fann_type)((neuron_sum < -multiplier) ? -multiplier
692+
: (neuron_sum > multiplier) ? multiplier
693+
: neuron_sum);
694+
break;
695+
case FANN_LINEAR_PIECE_RECT:
696+
neuron_it->value = (fann_type)((neuron_sum < 0) ? 0 : neuron_sum);
697+
break;
698+
case FANN_LINEAR_PIECE_RECT_LEAKY:
699+
neuron_it->value = (fann_type)((neuron_sum < 0) ? 0.01 * neuron_sum : neuron_sum);
693700
break;
694701
case FANN_ELLIOT:
695702
case FANN_ELLIOT_SYMMETRIC:

src/fann_cascade.c

+2
Original file line numberDiff line numberDiff line change
@@ -680,6 +680,8 @@ fann_type fann_train_candidates_epoch(struct fann *ann, struct fann_train_data *
680680
case FANN_GAUSSIAN_STEPWISE:
681681
case FANN_ELLIOT:
682682
case FANN_LINEAR_PIECE:
683+
case FANN_LINEAR_PIECE_RECT:
684+
case FANN_LINEAR_PIECE_RECT_LEAKY:
683685
case FANN_SIN:
684686
case FANN_COS:
685687
break;

src/fann_train.c

+6
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@ fann_type fann_activation_derived(unsigned int activation_function, fann_type st
4040
case FANN_LINEAR_PIECE:
4141
case FANN_LINEAR_PIECE_SYMMETRIC:
4242
return (fann_type)fann_linear_derive(steepness, value);
43+
case FANN_LINEAR_PIECE_RECT:
44+
return (fann_type)((value < 0) ? 0 : steepness);
45+
case FANN_LINEAR_PIECE_RECT_LEAKY:
46+
return (fann_type)((value < 0) ? steepness * 0.01 : steepness);
4347
case FANN_SIGMOID:
4448
case FANN_SIGMOID_STEPWISE:
4549
value = fann_clip(value, 0.01f, 0.99f);
@@ -125,6 +129,8 @@ fann_type fann_update_MSE(struct fann *ann, struct fann_neuron *neuron, fann_typ
125129
case FANN_GAUSSIAN_STEPWISE:
126130
case FANN_ELLIOT:
127131
case FANN_LINEAR_PIECE:
132+
case FANN_LINEAR_PIECE_RECT:
133+
case FANN_LINEAR_PIECE_RECT_LEAKY:
128134
case FANN_SIN:
129135
case FANN_COS:
130136
break;

src/include/fann_activation.h

+6
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,12 @@ __doublefann_h__ is not defined
178178
case FANN_GAUSSIAN_STEPWISE: \
179179
result = 0; \
180180
break; \
181+
case FANN_LINEAR_PIECE_RECT: \
182+
result = (fann_type)((value < 0) ? 0 : value); \
183+
break; \
184+
case FANN_LINEAR_PIECE_RECT_LEAKY: \
185+
result = (fann_type)((value < 0) ? value * 0.01 : value); \
186+
break; \
181187
}
182188

183189
#endif

src/include/fann_data.h

+16-2
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,16 @@ static char const *const FANN_TRAIN_NAMES[] = {"FANN_TRAIN_INCREMENTAL", "FANN_T
196196
* y = cos(x*s)/2+0.5
197197
* d = s*-sin(x*s)/2
198198
199+
FANN_LINEAR_PIECE_RECT - ReLU
200+
* span: -inf < y < inf
201+
* y = x<0? 0: x
202+
* d = x<0? 0: 1
203+
204+
FANN_LINEAR_PIECE_RECT_LEAKY - leaky ReLU
205+
* span: -inf < y < inf
206+
* y = x<0? 0.01*x: x
207+
* d = x<0? 0.01: 1
208+
199209
See also:
200210
<fann_set_activation_function_layer>, <fann_set_activation_function_hidden>,
201211
<fann_set_activation_function_output>, <fann_set_activation_steepness>,
@@ -223,7 +233,9 @@ enum fann_activationfunc_enum {
223233
FANN_SIN_SYMMETRIC,
224234
FANN_COS_SYMMETRIC,
225235
FANN_SIN,
226-
FANN_COS
236+
FANN_COS,
237+
FANN_LINEAR_PIECE_RECT,
238+
FANN_LINEAR_PIECE_RECT_LEAKY
227239
};
228240

229241
/* Constant: FANN_ACTIVATIONFUNC_NAMES
@@ -254,7 +266,9 @@ static char const *const FANN_ACTIVATIONFUNC_NAMES[] = {"FANN_LINEAR",
254266
"FANN_SIN_SYMMETRIC",
255267
"FANN_COS_SYMMETRIC",
256268
"FANN_SIN",
257-
"FANN_COS"};
269+
"FANN_COS",
270+
"FANN_LINEAR_PIECE_RECT",
271+
"FANN_LINEAR_PIECE_RECT_LEAKY"};
258272

259273
/* Enum: fann_errorfunc_enum
260274
Error function used during training.

src/include/fann_data_cpp.h

+13-1
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,16 @@ enum training_algorithm_enum {
200200
* y = cos(x*s)
201201
* d = s*-sin(x*s)
202202
203+
FANN_LINEAR_PIECE_RECT - ReLU
204+
* span: -inf < y < inf
205+
* y = x<0? 0: x
206+
* d = x<0? 0: 1
207+
208+
FANN_LINEAR_PIECE_RECT_LEAKY - leaky ReLU
209+
* span: -inf < y < inf
210+
* y = x<0? 0.01*x: x
211+
* d = x<0? 0.01: 1
212+
203213
See also:
204214
<neural_net::set_activation_function_hidden>,
205215
<neural_net::set_activation_function_output>
@@ -220,7 +230,9 @@ enum activation_function_enum {
220230
LINEAR_PIECE,
221231
LINEAR_PIECE_SYMMETRIC,
222232
SIN_SYMMETRIC,
223-
COS_SYMMETRIC
233+
COS_SYMMETRIC,
234+
LINEAR_PIECE_RECT,
235+
LINEAR_PIECE_RECT_LEAKY
224236
};
225237

226238
/* Enum: network_type_enum

tests/fann_test_train.cpp

+25
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,30 @@ TEST_F(FannTestTrain, TrainOnDateSimpleXor) {
2121
EXPECT_LT(net.test_data(data), 0.001);
2222
}
2323

24+
TEST_F(FannTestTrain, TrainOnReLUSimpleXor) {
25+
neural_net net(LAYER, 3, 2, 3, 1);
26+
27+
data.set_train_data(4, 2, xorInput, 1, xorOutput);
28+
net.set_activation_function_hidden(FANN::LINEAR_PIECE_RECT);
29+
net.set_activation_steepness_hidden(1.0);
30+
net.train_on_data(data, 100, 100, 0.001);
31+
32+
EXPECT_LT(net.get_MSE(), 0.001);
33+
EXPECT_LT(net.test_data(data), 0.001);
34+
}
35+
36+
TEST_F(FannTestTrain, TrainOnReLULeakySimpleXor) {
37+
neural_net net(LAYER, 3, 2, 3, 1);
38+
39+
data.set_train_data(4, 2, xorInput, 1, xorOutput);
40+
net.set_activation_function_hidden(FANN::LINEAR_PIECE_RECT_LEAKY);
41+
net.set_activation_steepness_hidden(1.0);
42+
net.train_on_data(data, 100, 100, 0.001);
43+
44+
EXPECT_LT(net.get_MSE(), 0.001);
45+
EXPECT_LT(net.test_data(data), 0.001);
46+
}
47+
2448
TEST_F(FannTestTrain, TrainSimpleIncrementalXor) {
2549
neural_net net(LAYER, 3, 2, 3, 1);
2650

@@ -41,3 +65,4 @@ TEST_F(FannTestTrain, TrainSimpleIncrementalXor) {
4165

4266
EXPECT_LT(net.get_MSE(), 0.01);
4367
}
68+

0 commit comments

Comments
 (0)