Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CV套件建设专项活动 - 文字识别返回单字识别坐标 #10515

Merged
merged 4 commits into from
Aug 2, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 55 additions & 6 deletions ppocr/postprocess/rec_postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,55 @@ def pred_reverse(self, pred):

return ''.join(pred_re[::-1])

def add_special_char(self, dict_character):
def add_special_char(self, text, dict_character):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个多传入的参数似乎没有使用?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

噢,这个应该是写错了,没有用到这个函数

return dict_character

def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
def get_word_info(self, text, selection):
state = None
word_content = []
word_col_content = []
word_list = []
word_col_list = []
state_list = []
valid_col = np.where(selection==True)[0]

for c_i, char in enumerate(text):
if '\u4e00' <= char <= '\u9fff':
c_state = 'cn'
elif bool(re.search('[a-zA-Z0-9]', char)):
c_state = 'en&num'
else:
c_state = 'splitter'

if char == '.' and state == 'en&num' and c_i + 1 < len(text) and bool(re.search('[0-9]', text[c_i+1])): # grouping float number
c_state = 'en&num'
if char == '-' and state == "en&num": # grouping word with '-', such as 'state-of-the-art'
c_state = 'en&num'

if state == None:
state = c_state

if state != c_state:
if len(word_content) != 0:
word_list.append(word_content)
word_col_list.append(word_col_content)
state_list.append(state)
word_content = []
word_col_content = []
state = c_state

if state != "splitter":
word_content.append(char)
word_col_content.append(valid_col[c_i])

if len(word_content) != 0:
word_list.append(word_content)
word_col_list.append(word_col_content)
state_list.append(state)

return word_list, word_col_list, state_list

def decode(self, text_index, text_prob=None, is_remove_duplicate=False, return_word_box=False):
""" convert text-index into text-label. """
result_list = []
ignored_tokens = self.get_ignored_tokens()
Expand Down Expand Up @@ -95,8 +140,12 @@ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):

if self.reverse: # for arabic rec
text = self.pred_reverse(text)

result_list.append((text, np.mean(conf_list).tolist()))

if return_word_box:
word_list, word_col_list, state_list = self.get_word_info(text, selection)
result_list.append((text, np.mean(conf_list).tolist(), [len(text_index[batch_idx]), word_list, word_col_list, state_list]))
else:
result_list.append((text, np.mean(conf_list).tolist()))
return result_list

def get_ignored_tokens(self):
Expand All @@ -111,14 +160,14 @@ def __init__(self, character_dict_path=None, use_space_char=False,
super(CTCLabelDecode, self).__init__(character_dict_path,
use_space_char)

def __call__(self, preds, label=None, *args, **kwargs):
def __call__(self, preds, label=None, return_word_box=False, *args, **kwargs):
if isinstance(preds, tuple) or isinstance(preds, list):
preds = preds[-1]
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True, return_word_box=return_word_box)
if label is None:
return text
label = self.decode(label)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

感觉在预测里面的代码可以放在这里,因为所有的CTCLabelDecode都是在rec后处理中。

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

加这里的话可能要把图像的宽高比也作为参数传进来哈哈

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个能不能利用kwargs传入呢?

Expand Down
63 changes: 57 additions & 6 deletions ppstructure/predict_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ def __init__(self, args):
from ppstructure.kie.predict_kie_token_ser_re import SerRePredictor
self.kie_predictor = SerRePredictor(args)

self.return_word_box = args.return_word_box

def __call__(self, img, return_ocr_result_in_table=False, img_idx=0):
time_dict = {
'image_orientation': 0,
Expand Down Expand Up @@ -156,17 +158,66 @@ def __call__(self, img, return_ocr_result_in_table=False, img_idx=0):
]
res = []
for box, rec_res in zip(filter_boxes, filter_rec_res):
rec_str, rec_conf = rec_res
rec_str, rec_conf = rec_res[0], rec_res[1]
for token in style_token:
if token in rec_str:
rec_str = rec_str.replace(token, '')
if not self.recovery:
box += [x1, y1]
res.append({
'text': rec_str,
'confidence': float(rec_conf),
'text_region': box.tolist()
})
if self.return_word_box:
rec_word_info = rec_res[2]
col_num, word_list, word_col_list, state_list = rec_word_info
box = box.tolist()
bbox_x_start = box[0][0]
bbox_x_end = box[1][0]
bbox_y_start = box[0][1]
bbox_y_end = box[2][1]

cell_width = (bbox_x_end - bbox_x_start)/col_num

word_box_list = []
word_box_content_list = []
cn_width_list = []
cn_col_list = []
for word, word_col, state in zip(word_list, word_col_list, state_list):
if state == 'cn':
if len(word_col) != 1:
char_seq_length = (word_col[-1] - word_col[0] + 1) * cell_width
char_width = char_seq_length/(len(word_col)-1)
cn_width_list.append(char_width)
cn_col_list += word_col
word_box_content_list += word
else:
cell_x_start = bbox_x_start + int(word_col[0] * cell_width)
cell_x_end = bbox_x_start + int((word_col[-1]+1) * cell_width)
cell = ((cell_x_start, bbox_y_start), (cell_x_end, bbox_y_start), (cell_x_end, bbox_y_end), (cell_x_start, bbox_y_end))
word_box_list.append(cell)
word_box_content_list.append("".join(word))
if len(cn_col_list) != 0:
if len(cn_width_list) != 0:
avg_char_width = np.mean(cn_width_list)
else:
avg_char_width = (bbox_x_end - bbox_x_start)/len(rec_str)
for center_idx in cn_col_list:
center_x = (center_idx+0.5)*cell_width
cell_x_start = max(int(center_x - avg_char_width/2), 0) + bbox_x_start
cell_x_end = min(int(center_x + avg_char_width/2), bbox_x_end-bbox_x_start) + bbox_x_start
cell = ((cell_x_start, bbox_y_start), (cell_x_end, bbox_y_start), (cell_x_end, bbox_y_end), (cell_x_start, bbox_y_end))
word_box_list.append(cell)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

建议抽象这部分成函数到ppstructure 到utility,提升代码可读性。对了可以给代码增加一些注释,例如说明这一部分是将识别结果转化为基于字符的位置和内容信息

res.append({
'text': rec_str,
'confidence': float(rec_conf),
'text_region': box,
'text_word': word_box_content_list,
'text_word_region': word_box_list
})
else:
res.append({
'text': rec_str,
'confidence': float(rec_conf),
'text_region': box.tolist()
})
res_list.append({
'type': region['label'].lower(),
'bbox': [x1, y1, x2, y2],
Expand Down
15 changes: 14 additions & 1 deletion ppstructure/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from PIL import Image, ImageDraw, ImageFont
import numpy as np
from tools.infer.utility import draw_ocr_box_txt, str2bool, str2int_tuple, init_args as infer_args

import math

def init_args():
parser = infer_args()
Expand Down Expand Up @@ -166,6 +166,19 @@ def draw_structure_result(image, result, font_path):
txts.append(text_result['text'])
scores.append(text_result['confidence'])

if 'text_word_region' in text_result:
for word_region in text_result['text_word_region']:
char_box = word_region
box_height = int(
math.sqrt((char_box[0][0] - char_box[3][0])**2 + (char_box[0][1] - char_box[3][1])**2))
box_width = int(
math.sqrt((char_box[0][0] - char_box[1][0])**2 + (char_box[0][1] - char_box[1][1])**2))
if box_height == 0 or box_width == 0:
continue
boxes.append(word_region)
txts.append("")
scores.append(1.0)

im_show = draw_ocr_box_txt(
img_layout, boxes, txts, scores, font_path=font_path, drop_score=0)
return im_show
13 changes: 12 additions & 1 deletion tools/infer/predict_rec.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def __init__(self, args):
"use_space_char": args.use_space_char
}
self.postprocess_op = build_post_process(postprocess_params)
self.postprocess_params = postprocess_params
self.predictor, self.input_tensor, self.output_tensors, self.config = \
utility.create_predictor(args, 'rec', logger)
self.benchmark = args.benchmark
Expand All @@ -139,6 +140,7 @@ def __init__(self, args):
],
warmup=0,
logger=logger)
self.return_word_box = args.return_word_box

def resize_norm_img(self, img, max_wh_ratio):
imgC, imgH, imgW = self.rec_image_shape
Expand Down Expand Up @@ -616,7 +618,16 @@ def __call__(self, img_list):
preds = outputs
else:
preds = outputs[0]
rec_result = self.postprocess_op(preds)
if self.postprocess_params['name'] == 'CTCLabelDecode':
rec_result = self.postprocess_op(preds, return_word_box=self.return_word_box)
ino_list = list(range(beg_img_no, end_img_no))
for rec_idx, rec in enumerate(rec_result):
ino = ino_list[rec_idx]
h, w = img_list[indices[ino]].shape[0:2]
wh_ratio = w * 1.0 / h
rec[2][0] = rec[2][0]*(wh_ratio/max_wh_ratio)
else:
rec_result = self.postprocess_op(preds)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about add this in the call func of CTCLabelDecode postprocess?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这块好像不太好放进去,因为计算涉及到当前的图像的宽高比,宽高比的信息只有在predict_rec.py这个层级有

for rno in range(len(rec_result)):
rec_res[indices[beg_img_no + rno]] = rec_result[rno]
if self.benchmark:
Expand Down
2 changes: 1 addition & 1 deletion tools/infer/predict_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def __call__(self, img, cls=True):
rec_res)
filter_boxes, filter_rec_res = [], []
for box, rec_result in zip(dt_boxes, rec_res):
text, score = rec_result
text, score = rec_result[0], rec_result[1]
if score >= self.drop_score:
filter_boxes.append(box)
filter_rec_res.append(rec_result)
Expand Down
4 changes: 4 additions & 0 deletions tools/infer/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,10 @@ def init_args():

parser.add_argument("--show_log", type=str2bool, default=True)
parser.add_argument("--use_onnx", type=str2bool, default=False)

# extended function
parser.add_argument("--return_word_box", type=str2bool, default=False, help='Whether return the bbox of each word (split by space) or chinese character')

return parser


Expand Down