-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.cpp
executable file
·98 lines (78 loc) · 2.2 KB
/
main.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#include "classifier.h"

#include <fstream>
#include <iostream>
#include <math.h>
#include <sstream>
#include <string>
#include <vector>
using namespace std;
/**
 * Load state vectors from a text file, one sample per line.
 *
 * Each non-blank line is parsed as four numbers. After each number exactly
 * one character (the separator) is skipped with ignore(), so "a,b,c,d" and
 * "a b c d" both parse.
 *
 * @param file_name  Path of the file to read.
 * @return One 4-element vector per non-blank line, in file order. Blank or
 *         whitespace-only lines are skipped so they cannot create phantom
 *         samples that desynchronize states from labels. Values missing from
 *         a short/malformed line are left as 0.0. Returns an empty vector
 *         (and reports on cerr) if the file cannot be opened.
 */
vector<vector<double>> Load_State(string file_name)
{
    ifstream in_state_(file_name.c_str(), ifstream::in);
    if (!in_state_)
    {
        cerr << "Load_State: could not open " << file_name << endl;
        return {};
    }

    const size_t kStateDims = 4;
    vector< vector<double> > state_out;
    string line;
    while (getline(in_state_, line))
    {
        // Skip blank lines (e.g. a trailing empty line) instead of emitting
        // a bogus sample.
        if (line.find_first_not_of(" \t\r") == string::npos)
        {
            continue;
        }

        istringstream iss(line);
        // Zero-initialize: once an extraction fails, later `>>` calls leave
        // the target untouched, so uninitialized storage would be read.
        vector<double> x_coord(kStateDims, 0.0);
        for (double& value : x_coord)
        {
            iss >> value;
            iss.ignore();  // skip the single separator character
        }
        state_out.push_back(x_coord);
    }
    return state_out;
}
/**
 * Load class labels from a text file, one label per line.
 *
 * The label is the first whitespace-delimited token of each line; anything
 * after it on the same line is ignored.
 *
 * @param file_name  Path of the file to read.
 * @return Labels in file order. Blank or whitespace-only lines are skipped
 *         (rather than yielding an empty "" label that would inflate the
 *         label count relative to the states). Returns an empty vector (and
 *         reports on cerr) if the file cannot be opened.
 */
vector<string> Load_Label(string file_name)
{
    ifstream in_label_(file_name.c_str(), ifstream::in);
    if (!in_label_)
    {
        cerr << "Load_Label: could not open " << file_name << endl;
        return {};
    }

    vector< string > label_out;
    string line;
    while (getline(in_label_, line))
    {
        istringstream iss(line);
        string label;
        // Extraction fails only when the line has no token at all.
        if (iss >> label)
        {
            label_out.push_back(label);
        }
    }
    return label_out;
}
/**
 * Train the GNB classifier (declared in classifier.h) on the training
 * states/labels under ../data and report its accuracy on the test set.
 *
 * @return 0 on success; 1 if a state file yields no samples or the number
 *         of states and labels in a data set differ (pairing X[i] with Y[i]
 *         would otherwise index out of bounds).
 */
int main() {
    vector< vector<double> > X_train = Load_State("../data/train_states.txt");
    vector< vector<double> > X_test = Load_State("../data/test_states.txt");
    vector< string > Y_train = Load_Label("../data/train_labels.txt");
    vector< string > Y_test = Load_Label("../data/test_labels.txt");

    // Validate before X_*[0] is read and before states are paired with labels.
    if (X_train.empty() || X_test.empty())
    {
        cerr << "Error: no training or test states loaded (missing data files?)" << endl;
        return 1;
    }
    if (X_train.size() != Y_train.size() || X_test.size() != Y_test.size())
    {
        cerr << "Error: number of states and labels differ" << endl;
        return 1;
    }

    cout << "X_train number of elements " << X_train.size() << endl;
    cout << "X_train element size " << X_train[0].size() << endl;
    cout << "Y_train number of elements " << Y_train.size() << endl;

    GNB gnb = GNB();
    gnb.train(X_train, Y_train);

    cout << "X_test number of elements " << X_test.size() << endl;
    cout << "X_test element size " << X_test[0].size() << endl;
    cout << "Y_test number of elements " << Y_test.size() << endl;

    int score = 0;
    for (size_t i = 0; i < X_test.size(); ++i)
    {
        // Pass the element directly instead of copying it into a local.
        string predicted = gnb.predict(X_test[i]);
        if (predicted == Y_test[i])
        {
            score += 1;
        }
    }

    // Sizes were validated above, so the denominator is non-zero and matches
    // the number of predictions made.
    float fraction_correct = float(score) / X_test.size();
    cout << "You got " << (100*fraction_correct) << " correct" << endl;
    return 0;
}