-
Notifications
You must be signed in to change notification settings - Fork 8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
CV套件建设专项活动 - 文字识别返回单字识别坐标 #10515
CV套件建设专项活动 - 文字识别返回单字识别坐标 #10515
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -64,10 +64,55 @@ def pred_reverse(self, pred): | |
|
||
return ''.join(pred_re[::-1]) | ||
|
||
def add_special_char(self, dict_character): | ||
def add_special_char(self, text, dict_character): | ||
return dict_character | ||
|
||
def decode(self, text_index, text_prob=None, is_remove_duplicate=False): | ||
def get_word_info(self, text, selection):
    """Group decoded characters into words and record their columns.

    Each character of ``text`` is classified as ``'cn'`` (code point in
    U+4E00..U+9FFF), ``'en&num'`` (ASCII letter or digit) or
    ``'splitter'`` (anything else). Consecutive characters of the same
    class form one word; splitter characters only separate words and are
    never emitted themselves.

    Two splitter characters are absorbed into a running ``'en&num'`` word:
    a ``'.'`` that is directly followed by a digit (so ``3.14`` stays one
    word) and a ``'-'`` (so ``state-of-the-art`` stays one word).

    Args:
        text: The decoded string.
        selection: Mask over the prediction columns; it is interpreted as
            a boolean mask, and the i-th truthy position is taken as the
            column of ``text[i]``.
            NOTE(review): assumes the caller passes a 1-D boolean array
            with at least ``len(text)`` truthy entries -- confirm against
            ``decode``.

    Returns:
        A tuple ``(word_list, word_col_list, state_list)`` where
        ``word_list[k]`` is the list of characters of the k-th word,
        ``word_col_list[k]`` the matching column indices, and
        ``state_list[k]`` its class (``'cn'`` or ``'en&num'``).

    Raises:
        IndexError: If a non-splitter character has no matching truthy
            entry in ``selection``.
    """

    def is_digit(ch):
        return "0" <= ch <= "9"

    def is_ascii_alnum(ch):
        return is_digit(ch) or "a" <= ch <= "z" or "A" <= ch <= "Z"

    # Column index of every kept position, in order.
    valid_col = np.where(np.asarray(selection, dtype=bool))[0]
    text_len = len(text)

    state = None
    word_content = []
    word_col_content = []
    word_list = []
    word_col_list = []
    state_list = []

    for c_i, char in enumerate(text):
        if "\u4e00" <= char <= "\u9fff":
            c_state = "cn"
        elif is_ascii_alnum(char):
            c_state = "en&num"
        else:
            c_state = "splitter"

        if state == "en&num":
            # Keep floating point numbers such as '3.14' in one word.
            if char == "." and c_i + 1 < text_len and is_digit(text[c_i + 1]):
                c_state = "en&num"
            # Keep hyphenated words such as 'state-of-the-art' in one word.
            if char == "-":
                c_state = "en&num"

        if state is None:
            state = c_state

        if state != c_state:
            # Class changed: flush the word collected so far (if any).
            if word_content:
                word_list.append(word_content)
                word_col_list.append(word_col_content)
                state_list.append(state)
                word_content = []
                word_col_content = []
            state = c_state

        if state != "splitter":
            word_content.append(char)
            word_col_content.append(valid_col[c_i])

    # Flush the trailing word.
    if word_content:
        word_list.append(word_content)
        word_col_list.append(word_col_content)
        state_list.append(state)

    return word_list, word_col_list, state_list
|
||
def decode(self, text_index, text_prob=None, is_remove_duplicate=False, return_word_box=False): | ||
""" convert text-index into text-label. """ | ||
result_list = [] | ||
ignored_tokens = self.get_ignored_tokens() | ||
|
@@ -95,8 +140,12 @@ def decode(self, text_index, text_prob=None, is_remove_duplicate=False): | |
|
||
if self.reverse: # for arabic rec | ||
text = self.pred_reverse(text) | ||
|
||
result_list.append((text, np.mean(conf_list).tolist())) | ||
|
||
if return_word_box: | ||
word_list, word_col_list, state_list = self.get_word_info(text, selection) | ||
result_list.append((text, np.mean(conf_list).tolist(), [len(text_index[batch_idx]), word_list, word_col_list, state_list])) | ||
else: | ||
result_list.append((text, np.mean(conf_list).tolist())) | ||
return result_list | ||
|
||
def get_ignored_tokens(self): | ||
|
@@ -111,14 +160,14 @@ def __init__(self, character_dict_path=None, use_space_char=False, | |
super(CTCLabelDecode, self).__init__(character_dict_path, | ||
use_space_char) | ||
|
||
def __call__(self, preds, label=None, *args, **kwargs): | ||
def __call__(self, preds, label=None, return_word_box=False, *args, **kwargs): | ||
if isinstance(preds, tuple) or isinstance(preds, list): | ||
preds = preds[-1] | ||
if isinstance(preds, paddle.Tensor): | ||
preds = preds.numpy() | ||
preds_idx = preds.argmax(axis=2) | ||
preds_prob = preds.max(axis=2) | ||
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True) | ||
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True, return_word_box=return_word_box) | ||
if label is None: | ||
return text | ||
label = self.decode(label) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 感觉在预测里面的代码可以放在这里,因为所有的CTCLabelDecode都是在rec后处理中。 There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 加这里的话可能要把图像的宽高比也作为参数传进来哈哈 There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 这个能不能利用kwargs传入呢? |
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -79,6 +79,8 @@ def __init__(self, args): | |
from ppstructure.kie.predict_kie_token_ser_re import SerRePredictor | ||
self.kie_predictor = SerRePredictor(args) | ||
|
||
self.return_word_box = args.return_word_box | ||
|
||
def __call__(self, img, return_ocr_result_in_table=False, img_idx=0): | ||
time_dict = { | ||
'image_orientation': 0, | ||
|
@@ -156,17 +158,66 @@ def __call__(self, img, return_ocr_result_in_table=False, img_idx=0): | |
] | ||
res = [] | ||
for box, rec_res in zip(filter_boxes, filter_rec_res): | ||
rec_str, rec_conf = rec_res | ||
rec_str, rec_conf = rec_res[0], rec_res[1] | ||
for token in style_token: | ||
if token in rec_str: | ||
rec_str = rec_str.replace(token, '') | ||
if not self.recovery: | ||
box += [x1, y1] | ||
res.append({ | ||
'text': rec_str, | ||
'confidence': float(rec_conf), | ||
'text_region': box.tolist() | ||
}) | ||
if self.return_word_box: | ||
rec_word_info = rec_res[2] | ||
col_num, word_list, word_col_list, state_list = rec_word_info | ||
box = box.tolist() | ||
bbox_x_start = box[0][0] | ||
bbox_x_end = box[1][0] | ||
bbox_y_start = box[0][1] | ||
bbox_y_end = box[2][1] | ||
|
||
cell_width = (bbox_x_end - bbox_x_start)/col_num | ||
|
||
word_box_list = [] | ||
word_box_content_list = [] | ||
cn_width_list = [] | ||
cn_col_list = [] | ||
for word, word_col, state in zip(word_list, word_col_list, state_list): | ||
if state == 'cn': | ||
if len(word_col) != 1: | ||
char_seq_length = (word_col[-1] - word_col[0] + 1) * cell_width | ||
char_width = char_seq_length/(len(word_col)-1) | ||
cn_width_list.append(char_width) | ||
cn_col_list += word_col | ||
word_box_content_list += word | ||
else: | ||
cell_x_start = bbox_x_start + int(word_col[0] * cell_width) | ||
cell_x_end = bbox_x_start + int((word_col[-1]+1) * cell_width) | ||
cell = ((cell_x_start, bbox_y_start), (cell_x_end, bbox_y_start), (cell_x_end, bbox_y_end), (cell_x_start, bbox_y_end)) | ||
word_box_list.append(cell) | ||
word_box_content_list.append("".join(word)) | ||
if len(cn_col_list) != 0: | ||
if len(cn_width_list) != 0: | ||
avg_char_width = np.mean(cn_width_list) | ||
else: | ||
avg_char_width = (bbox_x_end - bbox_x_start)/len(rec_str) | ||
for center_idx in cn_col_list: | ||
center_x = (center_idx+0.5)*cell_width | ||
cell_x_start = max(int(center_x - avg_char_width/2), 0) + bbox_x_start | ||
cell_x_end = min(int(center_x + avg_char_width/2), bbox_x_end-bbox_x_start) + bbox_x_start | ||
cell = ((cell_x_start, bbox_y_start), (cell_x_end, bbox_y_start), (cell_x_end, bbox_y_end), (cell_x_start, bbox_y_end)) | ||
word_box_list.append(cell) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 建议抽象这部分成函数到ppstructure 到utility,提升代码可读性。对了可以给代码增加一些注释,例如说明这一部分是将识别结果转化为基于字符的位置和内容信息 |
||
res.append({ | ||
'text': rec_str, | ||
'confidence': float(rec_conf), | ||
'text_region': box, | ||
'text_word': word_box_content_list, | ||
'text_word_region': word_box_list | ||
}) | ||
else: | ||
res.append({ | ||
'text': rec_str, | ||
'confidence': float(rec_conf), | ||
'text_region': box.tolist() | ||
}) | ||
res_list.append({ | ||
'type': region['label'].lower(), | ||
'bbox': [x1, y1, x2, y2], | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -116,6 +116,7 @@ def __init__(self, args): | |
"use_space_char": args.use_space_char | ||
} | ||
self.postprocess_op = build_post_process(postprocess_params) | ||
self.postprocess_params = postprocess_params | ||
self.predictor, self.input_tensor, self.output_tensors, self.config = \ | ||
utility.create_predictor(args, 'rec', logger) | ||
self.benchmark = args.benchmark | ||
|
@@ -139,6 +140,7 @@ def __init__(self, args): | |
], | ||
warmup=0, | ||
logger=logger) | ||
self.return_word_box = args.return_word_box | ||
|
||
def resize_norm_img(self, img, max_wh_ratio): | ||
imgC, imgH, imgW = self.rec_image_shape | ||
|
@@ -616,7 +618,16 @@ def __call__(self, img_list): | |
preds = outputs | ||
else: | ||
preds = outputs[0] | ||
rec_result = self.postprocess_op(preds) | ||
if self.postprocess_params['name'] == 'CTCLabelDecode': | ||
rec_result = self.postprocess_op(preds, return_word_box=self.return_word_box) | ||
ino_list = list(range(beg_img_no, end_img_no)) | ||
for rec_idx, rec in enumerate(rec_result): | ||
ino = ino_list[rec_idx] | ||
h, w = img_list[indices[ino]].shape[0:2] | ||
wh_ratio = w * 1.0 / h | ||
rec[2][0] = rec[2][0]*(wh_ratio/max_wh_ratio) | ||
else: | ||
rec_result = self.postprocess_op(preds) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. How about add this in the call func of CTCLabelDecode postprocess? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 这块好像不太好放进去,因为计算涉及到当前的图像的宽高比,宽高比的信息只有在predict_rec.py这个层级有 |
||
for rno in range(len(rec_result)): | ||
rec_res[indices[beg_img_no + rno]] = rec_result[rno] | ||
if self.benchmark: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个多传入的参数似乎没有使用?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
噢,这个应该是写错了,没有用到这个函数