-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict.py
101 lines (86 loc) · 3.34 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import cv2
import numpy as np
import itertools, os, time
from Model import get_Model
from parameter import *
import json
from keras import backend as K
# Put the Keras backend into inference mode (learning phase 0 = test) before any
# model is built below, so layers that behave differently in training vs.
# inference (e.g. Dropout, BatchNormalization) use their inference behaviour.
# NOTE(review): set_learning_phase is deprecated in TF2-era Keras - confirm the
# installed Keras version still supports it.
K.set_learning_phase(0)
def load_json(file_path):
    """Load and return the parsed contents of the JSON file at ``file_path``.

    In this script the result is used as a dict mapping image file names to
    their label text (see ``predict.load_data``).
    """
    # JSON text is UTF-8 by specification (RFC 8259). Be explicit so that the
    # Japanese (kana) labels decode correctly whatever the platform's default
    # locale encoding is.
    with open(file_path, encoding="utf-8") as f:
        return json.load(f)
class predict():
    """Evaluate a trained OCR model on a labelled set of images.

    Reads ``image file name -> label text`` pairs from a JSON file, runs each
    grayscale image through the model returned by ``get_Model(training=False)``
    and greedily decodes the output into text. ``TrainCheck`` prints
    per-sample results plus exact-match and character-level accuracy.

    Args:
        json_path: Path of the JSON file mapping image file names to labels.
        model_path: Path of the weights file passed to ``load_weights``.
        img_dir: Directory the image file names in the JSON are relative to.
            Defaults to ``'./KANA_data100-1/'`` (relative to the current
            working directory), the location the script originally hard-coded.
    """

    def __init__(self, json_path, model_path, img_dir='./KANA_data100-1/'):
        # epoch / output_path / model_name are not read anywhere in this class;
        # they are kept so external code that touches them keeps working.
        self.epoch = 0
        self.output_path = 0
        self.model_name = 0
        self.json_path = json_path
        self.model_path = model_path
        self.img_dir = img_dir
        self.model = get_Model(training=False)

    def load_model(self):
        """Load the trained weights from ``self.model_path`` into ``self.model``."""
        print('load')
        self.model.load_weights(self.model_path)
        print("...Previous weight data...")

    def decode_label(self, out, charset=None):
        """Greedily decode a model output into text.

        Args:
            out: Model output; only ``out[0]`` is decoded. Presumably shaped
                (batch, time steps, classes) - TODO confirm against the model.
            charset: Sequence mapping class index -> character. Defaults to
                the module-level ``letters`` (from ``parameter``).

        Returns:
            The decoded string. The arg-max class is taken at each time step,
            consecutive repeats are collapsed, and any index
            ``>= len(charset)`` (e.g. a CTC blank class) is dropped.
        """
        if charset is None:
            charset = letters
        # The first two time steps are skipped, as in the original code.
        # NOTE(review): presumably these early RNN outputs are unreliable - confirm.
        best = np.argmax(out[0, 2:], axis=1)
        collapsed = [k for k, _ in itertools.groupby(best)]
        return ''.join(charset[i] for i in collapsed if i < len(charset))

    def load_data(self):
        """Return ``(image file names, labels)`` read from ``self.json_path``.

        The two lists are parallel: ``labels[i]`` is the label of ``img_path[i]``.
        """
        print("Json Loading start.....")
        data = load_json(self.json_path)
        print("Data Test Found " + str(len(data)))
        img_path = list(data.keys())
        labels = list(data.values())
        return img_path, labels

    def _read_image(self, name):
        """Read image ``name`` (relative to ``self.img_dir``) as grayscale."""
        path = os.path.join(self.img_dir, name)
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread reports failure by returning None rather than raising.
        if img is None:
            raise FileNotFoundError("Could not read image: %s" % path)
        return img

    @staticmethod
    def _preprocess(img):
        """Turn a grayscale image into a batch of one for the model.

        Resizes to width ``img_w`` and height ``img_h``, maps pixel values from
        [0, 255] to [-1, 1] and transposes so width is the first axis. It then
        adds batch and channel axes, giving shape ``(1, img_w, img_h, 1)``.
        """
        x = cv2.resize(img.astype(np.float32), (img_w, img_h))
        x = (x / 255.0) * 2.0 - 1.0
        x = np.expand_dims(x.T, axis=-1)
        return np.expand_dims(x, axis=0)

    @staticmethod
    def _char_match(pred, true):
        """Compare two strings position by position.

        Returns:
            ``(matches, length)``. ``matches`` counts the indices where both
            strings have the same character. ``length`` is the longer
            string's length, so missing or extra characters count as errors.
        """
        matches = sum(1 for p, t in zip(pred, true) if p == t)
        return matches, max(len(pred), len(true))

    def TrainCheck(self):
        """Run the model over every sample from ``load_data`` and report accuracy.

        Prints each prediction with its label and per-sample character
        accuracy. It then prints the mean time per sample (including image
        loading), the exact-match accuracy and the character-level accuracy,
        both as percentages.

        Returns:
            A dict with keys ``'time'`` (seconds per sample), ``'acc'`` and
            ``'letter_acc'`` (percentages), or ``None`` if there were no samples.

        Raises:
            FileNotFoundError: if an image listed in the JSON cannot be read.
        """
        self.load_model()
        img_path, labels = self.load_data()
        total = 0
        acc = 0
        letter_total = 0
        letter_acc = 0
        start = time.time()
        for name, true_texts in zip(img_path, labels):
            img = self._read_image(name)
            net_out_value = self.model.predict(self._preprocess(img))
            pred_texts = self.decode_label(net_out_value)
            match_char, letter_length = self._char_match(pred_texts, true_texts)
            letter_acc += match_char
            letter_total += letter_length
            if pred_texts == true_texts:
                acc += 1
            total += 1
            # Two empty strings are identical, so score them 100% instead of
            # dividing by zero.
            sample_acc = (match_char / letter_length) * 100 if letter_length else 100.0
            print('Predicted: %s / True: %s /Acc: %s' % (pred_texts, true_texts, sample_acc))
        total_time = time.time() - start
        if total == 0:
            print("No samples to evaluate")
            return None
        time_per_sample = total_time / total
        acc_pct = (acc / total) * 100
        # letter_total is 0 only if every prediction and label was empty.
        letter_acc_pct = (letter_acc / letter_total) * 100 if letter_total else 100.0
        print("Time : ", time_per_sample)
        print("ACC : ", acc_pct)
        print("letter ACC : ", letter_acc_pct)
        return {'time': time_per_sample, 'acc': acc_pct, 'letter_acc': letter_acc_pct}
if __name__ == "__main__":
    import sys

    # Run the evaluation only when executed as a script, not on import.
    # Usage: python predict.py [json_path] [model_path]
    # Either path falls back to the original hard-coded location when omitted.
    json_path = (sys.argv[1] if len(sys.argv) > 1
                 else '/home/hogwarts/Documents/OCR_Japan/KANA_data100-1/val.json')
    model_path = (sys.argv[2] if len(sys.argv) > 2
                  else '/home/hogwarts/Downloads/LSTM+BN5--01--24.476.hdf5')
    a = predict(json_path, model_path)
    a.TrainCheck()