-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathEnsembleDemo.java
125 lines (101 loc) · 5.08 KB
/
EnsembleDemo.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import model.*;
import processor.BertProcess;
import processor.InputFeature;
import processor.textCNNProcess;
import utils.*;
import java.util.ArrayList;
import java.util.Scanner;
/**
 * Interactive demo that runs a TextCNN model and a BERT model on the same input text, prints each
 * model's per-label probabilities and predicted label, and combines both probability vectors with
 * a weighted vote.
 *
 * <p>Both models are loaded once in the constructor and kept in memory. The most recent
 * probabilities produced by {@link #doPredictWithModels(String, int, int)} are stored in
 * {@link #textcnn_prob} and {@link #bert_prob} and consumed by {@link #_doVotingWithWeight()}.
 */
public class EnsembleDemo {
    /** Serialized TextCNN graph. NOTE(review): presumably resolved relative to the working directory. */
    private static final String TEXTCNN_MODEL_PATH = "pb_model/text_cnn.pb";

    /** Serialized BERT graph. NOTE(review): presumably resolved relative to the working directory. */
    private static final String BERT_MODEL_PATH = "pb_model/bert_L4_FC3_Seq128.pb";

    /** Loaded TextCNN model. */
    protected TextCNN textcnn;
    /** Loaded BERT model. */
    protected Bert bert;
    /** Latest TextCNN probabilities; {@code null} until {@link #doPredictWithModels} has run. */
    protected float[] textcnn_prob;
    /**
     * Latest BERT probabilities, re-ordered via {@code PredictionUtils.getProbInLabelSequence} onto
     * {@code PredictionUtils.labelList}; {@code null} until {@link #doPredictWithModels} has run.
     */
    protected float[] bert_prob;

    /** Creates the demo and loads both models into memory so later predictions reuse them. */
    public EnsembleDemo() {
        this.loadModels();
    }

    /** Loads the TextCNN and BERT models and stores them in {@link #textcnn} and {@link #bert}. */
    public void loadModels() {
        System.out.println("====================== Start loading models ... ====================== ");
        this.textcnn = this._load_TextCNNs();
        this.bert = this._load_Bert();
        System.out.println("====================== Models loaded! ======================");
    }

    /**
     * Loads the TextCNN model from {@link #TEXTCNN_MODEL_PATH}.
     *
     * @return a newly constructed {@code TextCNN}
     */
    public TextCNN _load_TextCNNs() {
        return new TextCNN(TEXTCNN_MODEL_PATH);
    }

    /**
     * Loads the BERT model from {@link #BERT_MODEL_PATH}.
     *
     * @return a newly constructed {@code Bert}
     */
    public Bert _load_Bert() {
        return new Bert(BERT_MODEL_PATH);
    }

    /**
     * Runs both models on {@code text}, stores their probabilities in {@link #textcnn_prob} and
     * {@link #bert_prob}, and prints each model's per-label probabilities and predicted label.
     *
     * @param text         the input utterance
     * @param cnn_max_seq  maximum sequence length passed to the TextCNN preprocessor
     * @param bert_max_seq maximum sequence length passed to the BERT preprocessor
     */
    public void doPredictWithModels(String text, int cnn_max_seq, int bert_max_seq) {
        this.textcnn_prob = this.__pred_with_textCNNs(this.textcnn, text, cnn_max_seq);
        System.out.println("======================textCNN predictions ======================");
        printPrediction(this.textcnn_prob);

        this.bert_prob = this.__pred_with_Bert(this.bert, text, bert_max_seq);
        System.out.println("====================== Bert predictions ======================");
        printPrediction(this.bert_prob);
    }

    /**
     * Preprocesses {@code text} for TextCNN and runs a single prediction, printing the time spent
     * in preprocessing and in prediction.
     *
     * @param textcnn     the model to run
     * @param text        the input utterance
     * @param cnn_max_seq maximum sequence length passed to {@code textCNNProcess}
     * @return the probability vector returned by {@code predict_single_case}
     */
    public float[] __pred_with_textCNNs(TextCNN textcnn, String text, int cnn_max_seq) {
        System.out.println("====================== Start textCNN preprocessor! ======================");
        long preprocessStart = System.nanoTime();
        textCNNProcess processor = new textCNNProcess(text, cnn_max_seq);
        InputFeature inputfeat = processor.getInputFeatures();
        // Message: "TextCNN preprocessing time : <n> seconds"
        System.out.println("TextCNN 预处理耗时 : " + elapsedSeconds(preprocessStart) + " 秒 ");

        long predictStart = System.nanoTime();
        float[] prob = textcnn.predict_single_case(inputfeat);
        // Message: "TextCNN prediction time : <n> seconds"
        System.out.println("TextCNN 预测执行耗时 : " + elapsedSeconds(predictStart) + " 秒 ");
        return prob;
    }

    /**
     * Preprocesses {@code text} for BERT, runs a single prediction, and re-orders the resulting
     * probabilities from the model's own label order ({@code bert.getBertLabel()}) onto
     * {@code PredictionUtils.labelList}. Presumably this makes them align with the TextCNN output
     * for voting (TODO confirm). The re-ordering is included in the reported prediction time.
     *
     * @param bert         the model to run
     * @param text         the input utterance
     * @param bert_max_seq maximum sequence length passed to {@code BertProcess}
     * @return the re-ordered probability vector
     */
    public float[] __pred_with_Bert(Bert bert, String text, int bert_max_seq) {
        System.out.println("====================== Start Bert preprocessor! ======================");
        long preprocessStart = System.nanoTime();
        BertProcess bert_processor = new BertProcess(text, bert_max_seq);
        InputFeature bert_feat = bert_processor.getInputFeatures();
        // Message: "Bert preprocessing time : <n> seconds"
        System.out.println("Bert 预处理执行耗时 : " + elapsedSeconds(preprocessStart) + " 秒 ");

        long predictStart = System.nanoTime();
        float[] prob = bert.predict_single_case(bert_feat);
        String[] bert_labels = bert.getBertLabel();
        float[] std_prob_bert =
                PredictionUtils.getProbInLabelSequence(bert_labels, PredictionUtils.labelList, prob);
        // Message: "Bert prediction time : <n> seconds"
        System.out.println("Bert 预测执行耗时 : " + elapsedSeconds(predictStart) + " 秒 ");
        return std_prob_bert;
    }

    /**
     * Combines the latest TextCNN and BERT probabilities with a weighted vote, prints the combined
     * per-label probabilities and label, and returns the label.
     *
     * @return the label name chosen from the weighted probabilities
     * @throws IllegalStateException if {@link #doPredictWithModels} has not produced predictions yet
     */
    public String _doVotingWithWeight() {
        if (this.textcnn_prob == null || this.bert_prob == null) {
            throw new IllegalStateException(
                    "No model predictions available; call doPredictWithModels before _doVotingWithWeight");
        }
        // Weights are positional: index 0 = textCNN, index 1 = Bert, matching the add order below.
        // NOTE(review): the weights do not sum to 1; confirm doWeightedVote does not require that.
        float[] weightedVote = {.25f, .5f};
        ArrayList<float[]> prob_list = new ArrayList<>();
        prob_list.add(this.textcnn_prob);
        prob_list.add(this.bert_prob);
        System.out.println("====================== Weighted voting results ======================");
        float[] weight_prob = EnsembleUtils.doWeightedVote(weightedVote, prob_list);
        return printPrediction(weight_prob);
    }

    /** Prints per-label probabilities and the predicted label name, then returns that label. */
    private static String printPrediction(float[] prob) {
        PredictionUtils.printPerLabelProb(prob);
        String label = PredictionUtils.getLabelName(prob);
        System.out.println(label);
        return label;
    }

    /** Seconds elapsed since {@code startNanos}, measured with the monotonic {@link System#nanoTime()}. */
    private static float elapsedSeconds(long startNanos) {
        return (System.nanoTime() - startNanos) / 1e9f;
    }

    /**
     * Runs an interactive loop. Each line read from stdin is classified by both models and by the
     * weighted vote. The loop ends on "exit" or end of input. Blank lines are ignored.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        System.out.println("Starting to predict!");
        // Hyper-parameters. bert_max_seq presumably must match the "Seq128" BERT model (TODO confirm).
        int cnn_max_seq = 30;
        int bert_max_seq = 128;
        EnsembleDemo demo = new EnsembleDemo();

        // Example input: "我要听小猪佩奇的故事。" ("I want to hear a Peppa Pig story.")
        try (Scanner scn = new Scanner(System.in)) {
            while (true) {
                // Prompt: "Enter a test utterance (type exit to quit):"
                System.out.println("请输入测试话术 (输入 exit 结束):");
                if (!scn.hasNextLine()) {
                    break; // end of input (e.g. Ctrl-D or piped file exhausted)
                }
                String input = scn.nextLine().trim();
                if (input.equals("exit")) {
                    break;
                }
                if (input.isEmpty()) {
                    continue; // nothing to classify
                }
                demo.doPredictWithModels(input, cnn_max_seq, bert_max_seq);
                demo._doVotingWithWeight();
            }
        }
        System.out.println("Bye!");
    }
}