-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmain_predictor.py
44 lines (35 loc) · 1.92 KB
/
main_predictor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from keras.utils import to_categorical
import pickle
# requests.post('http://0:4004/elmo_embed_words', json={"text":'is it?'}).json()
# requests.post('https://55898a32.eu.ngrok.io/elmo_embed_words', json={"text":'is it?\r\nokay got it.'}).json()
from models import model_attention_applied_after_bilstm, context_model_att
from src.utils import *
def main():
    """Evaluate the saved non-context and context attention models on the test set.

    Reads the train/test CSV files for their labels, loads the precomputed
    test features from ``features/``, pads them to a fixed length, restores
    the weights stored under ``params/`` and prints the accuracy that each
    model reports from ``evaluate``.

    All paths are relative, so the script has to be started from the
    directory that contains ``data/``, ``features/`` and ``params/``.
    """
    max_seq_len = 20  # length passed to the padding helper and both model builders

    train_file = 'data/swda-actags_train_speaker.csv'
    test_file = 'data/swda-actags_test_speaker.csv'

    # read_files returns four values; only the second (used for the size
    # report) and the fourth (fed to categorize_raw_data) are needed here.
    _, train_inputs, _, train_labels = read_files(train_file)
    _, test_inputs, _, test_labels = read_files(test_file)
    print(len(test_inputs), len(train_inputs))

    # `np` is not imported explicitly in this file; presumably it is
    # re-exported by `from src.utils import *` -- confirm.
    # NOTE(review): if these .npy files hold object arrays, recent NumPy
    # versions need allow_pickle=True here -- confirm against the feature dump.
    pad_token = np.load('features/pad_a_token.npy')
    test_features = np.load('features/X_test_elmo_features.npy')
    test_features = padSequencesKeras(test_features, max_seq_len, pad_token)

    # The training labels are passed in together with the test labels;
    # only the tag list and the encoded test labels are used afterwards.
    tags, _, _, encoded_test_labels = categorize_raw_data(train_labels, test_labels)
    num_tags = len(tags)
    test_targets = to_categorical(encoded_test_labels, num_tags)

    # NON-CONTEXT MODEL
    model = model_attention_applied_after_bilstm(
        max_seq_len, test_features.shape[2], num_tags)
    model.load_weights('params/weight_parameters')
    evaluation = model.evaluate(test_features, test_targets, verbose=2)
    # Index 1 is reported as accuracy; this assumes the model was compiled
    # with accuracy as its first metric (index 0 being the loss).
    print("Test results for non-context model - accuracy: {}".format(evaluation[1]))

    # Preparing data for contextual evaluation with seq_length
    seq_length = 3
    context_features, context_targets = prepare_data(
        test_features, test_targets, seq_length)

    # CONTEXT MODEL
    context_model = context_model_att(
        seq_length, max_seq_len, context_features.shape[3], num_tags)
    # The weight file name encodes the context length it was trained with.
    context_weights = 'params/context_model_att_{}'.format(seq_length)
    context_model.load_weights(context_weights)
    _, context_accuracy = context_model.evaluate(
        context_features, context_targets, verbose=2, batch_size=32)
    print('Context Score results:', context_accuracy)
# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()