-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain_d1.py
44 lines (37 loc) · 1.94 KB
/
main_d1.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from bib import *
# Video sequences whose depth-1 CSV files make up each split.  Every QP uses
# the same sequences, so the per-QP file lists are generated from these
# tuples instead of being written out by hand.
_TRAIN_SEQUENCES = ('ChristmasTree', 'CrowdRun', 'DucksTakeOff',
                    'PedestrianArea', 'RushHour', 'Sunflower')
_VALID_SEQUENCES = ('Tractor', 'Wisley')


def _depth1_files(sequences, qp):
    """Return ``<Sequence>_QP_<qp>_depth1.csv`` names, one per sequence, in order."""
    return ['{}_QP_{}_depth1.csv'.format(seq, qp) for seq in sequences]


# QP values to process and the tree max_depth paired with each (by position).
qps = [22, 27, 32, 37]
max_depths = [10, 7, 1, 3]

# One train list and one validation list per QP, in the same order as `qps`.
trains = [_depth1_files(_TRAIN_SEQUENCES, q) for q in qps]
valids = [_depth1_files(_VALID_SEQUENCES, q) for q in qps]

# Per-QP names kept for direct access; they alias the lists held in
# `trains` / `valids`.
train_QP_22, train_QP_27, train_QP_32, train_QP_37 = trains
valid_QP_22, valid_QP_27, valid_QP_32, valid_QP_37 = valids
# For each QP: load its train/validation CSVs, fit and prune a tree with the
# QP-specific max_depth, print accuracy/cost statistics, and export the tree
# via ExportTree.write_tree_cpp.
#
# NOTE(review): earlier experiments (previously left as commented-out code)
# loaded QP 27 and 37 with `data.load_data(train, valid, ftk=[5, 6, 8])`;
# every QP currently uses the default feature set.
for train, valid, max_depth, qp in zip(trains, valids, max_depths, qps):
    data = Data()
    data.load_data(train, valid)
    clf = Classifier(data, max_depth=max_depth, hack=True)
    clf.fit_tree()
    # presumably clf.clf is the underlying fitted tree estimator — TODO confirm
    clf.prune_duplicate_leaves(clf.clf)
    clf.get_stats()
    print('Acc: ' + str(clf.acc))
    print('Cost:' + str(clf.total_cost))
    # Compute once; the value is needed for both the printout and the ratio.
    min_cost = clf.calculate_minimal_cost()
    print('Min Cost: ' + str(min_cost))
    if min_cost:
        print('Ratio: ' + str(clf.total_cost / min_cost))
    else:
        # Dividing by zero would raise here, skipping this QP's export and
        # every remaining QP.
        print('Ratio: undefined (minimal cost is 0)')
    et = ExportTree(clf)
    # NOTE(review): 1 presumably matches the depth level of the *_depth1.csv
    # inputs; the meaning of 3 is not visible here — TODO confirm.
    et.write_tree_cpp(1, qp, 3)
    print('\n')