@@ -21,6 +21,30 @@ TEST_F(FannTestTrain, TrainOnDateSimpleXor) {
21
21
EXPECT_LT (net.test_data (data), 0.001 );
22
22
}
23
23
24
+ TEST_F (FannTestTrain, TrainOnReLUSimpleXor) {
25
+ neural_net net (LAYER, 3 , 2 , 3 , 1 );
26
+
27
+ data.set_train_data (4 , 2 , xorInput, 1 , xorOutput);
28
+ net.set_activation_function_hidden (FANN::LINEAR_PIECE_RECT);
29
+ net.set_activation_steepness_hidden (1.0 );
30
+ net.train_on_data (data, 100 , 100 , 0.001 );
31
+
32
+ EXPECT_LT (net.get_MSE (), 0.001 );
33
+ EXPECT_LT (net.test_data (data), 0.001 );
34
+ }
35
+
36
+ TEST_F (FannTestTrain, TrainOnReLULeakySimpleXor) {
37
+ neural_net net (LAYER, 3 , 2 , 3 , 1 );
38
+
39
+ data.set_train_data (4 , 2 , xorInput, 1 , xorOutput);
40
+ net.set_activation_function_hidden (FANN::LINEAR_PIECE_RECT_LEAKY);
41
+ net.set_activation_steepness_hidden (1.0 );
42
+ net.train_on_data (data, 100 , 100 , 0.001 );
43
+
44
+ EXPECT_LT (net.get_MSE (), 0.001 );
45
+ EXPECT_LT (net.test_data (data), 0.001 );
46
+ }
47
+
24
48
TEST_F (FannTestTrain, TrainSimpleIncrementalXor) {
25
49
neural_net net (LAYER, 3 , 2 , 3 , 1 );
26
50
@@ -41,3 +65,4 @@ TEST_F(FannTestTrain, TrainSimpleIncrementalXor) {
41
65
42
66
EXPECT_LT (net.get_MSE (), 0.01 );
43
67
}
68
+
0 commit comments