-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain_d0.py
47 lines (39 loc) · 2.05 KB
/
main_d0.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from bib import *
# Depth-0 CSV inputs. Every file is named '<name>_QP_<qp>_depth0.csv', so the
# per-QP lists are generated from the two name tuples below instead of being
# spelled out by hand. The first six names feed the training lists and the
# last two feed the validation lists.
_TRAIN_NAMES = (
    'ChristmasTree',
    'CrowdRun',
    'DucksTakeOff',
    'PedestrianArea',
    'RushHour',
    'Sunflower',
)
_VALID_NAMES = ('Tractor', 'Wisley')


def _depth0_csvs(names, qp):
    """Return a new list with '<name>_QP_<qp>_depth0.csv' for each name, in order."""
    return ['{}_QP_{}_depth0.csv'.format(name, qp) for name in names]


train_QP_22 = _depth0_csvs(_TRAIN_NAMES, 22)
valid_QP_22 = _depth0_csvs(_VALID_NAMES, 22)
train_QP_27 = _depth0_csvs(_TRAIN_NAMES, 27)
valid_QP_27 = _depth0_csvs(_VALID_NAMES, 27)
train_QP_32 = _depth0_csvs(_TRAIN_NAMES, 32)
valid_QP_32 = _depth0_csvs(_VALID_NAMES, 32)
train_QP_37 = _depth0_csvs(_TRAIN_NAMES, 37)
valid_QP_37 = _depth0_csvs(_VALID_NAMES, 37)
# Remove any existing 'tree.cpp' before the loop runs.
# NOTE(review): presumably ExportTree.write_tree_cpp adds to this file on every
# call, so a stale copy would mix old and new output -- confirm in bib.
if os.path.isfile('tree.cpp'):
    os.remove('tree.cpp')

# Per-QP settings, index-aligned: qps[i] is trained on trains[i], validated on
# valids[i], and its tree is built with max_depth=max_depths[i].
max_depths = [5, 3, 5, 7]
trains = [train_QP_22, train_QP_27, train_QP_32, train_QP_37]
valids = [valid_QP_22, valid_QP_27, valid_QP_32, valid_QP_37]
qps = [22, 27, 32, 37]

for train, valid, max_depth, qp in zip(trains, valids, max_depths, qps):
    # Load this QP's training and validation CSVs with load_data's defaults.
    # An earlier variant passed ftk=[1,3,5,6,7,8] for QP 27 and 37; every QP
    # now goes through the same call.
    data = Data()
    data.load_data(train, valid)

    # Every QP uses hack=True; only max_depth varies between iterations.
    clf = Classifier(data, max_depth=max_depth, hack=True)
    clf.fit_tree()
    clf.prune_duplicate_leaves(clf.clf)
    clf.get_stats()

    print('Acc: ' + str(clf.acc))
    print('Cost:' + str(clf.total_cost))
    # Computed once and reused for both lines below, so the printed ratio is
    # guaranteed to match the printed minimum.
    # NOTE(review): assumes calculate_minimal_cost() has no side effects --
    # the original called it twice per iteration; confirm in bib.
    min_cost = clf.calculate_minimal_cost()
    print('Min Cost: ' + str(min_cost))
    print('Ratio: ' + str(clf.total_cost / min_cost))

    # Export the fitted tree for this QP.
    # NOTE(review): the first argument is presumably the depth index (this
    # script handles the depth0 CSVs); the meaning of the trailing 3 is not
    # visible here -- confirm against ExportTree.write_tree_cpp.
    et = ExportTree(clf)
    et.write_tree_cpp(0, qp, 3)
    print('\n')