CV Suite Development Campaign - Text recognition returns per-character recognition coordinates #10515

Merged: 4 commits, Aug 2, 2023
78 changes: 73 additions & 5 deletions ppocr/postprocess/rec_postprocess.py
@@ -67,7 +67,66 @@ def pred_reverse(self, pred):
def add_special_char(self, dict_character):
return dict_character

def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
def get_word_info(self, text, selection):
"""
Group the decoded characters and record the corresponding decoded positions.

Args:
text: the decoded text
selection: the bool array that identifies which columns of features are decoded as non-separated characters
Returns:
word_list: list of the grouped words
word_col_list: list of decoding positions corresponding to each character in the grouped word
state_list: list of marker to identify the type of grouping words, including two types of grouping words:
- 'cn': continous chinese characters (e.g., 你好啊)
- 'en&num': continous english characters (e.g., hello), number (e.g., 123, 1.123), or mixed of them connected by '-' (e.g., VGG-16)
The remaining characters in text are treated as separators between groups (e.g., space, '(', ')', etc.).
"""
state = None
word_content = []
word_col_content = []
word_list = []
word_col_list = []
state_list = []
valid_col = np.where(selection==True)[0]

for c_i, char in enumerate(text):
if '\u4e00' <= char <= '\u9fff':
c_state = 'cn'
elif bool(re.search('[a-zA-Z0-9]', char)):
c_state = 'en&num'
else:
c_state = 'splitter'

if char == '.' and state == 'en&num' and c_i + 1 < len(text) and bool(re.search('[0-9]', text[c_i+1])): # keep floating-point numbers grouped
c_state = 'en&num'
if char == '-' and state == "en&num": # keep words containing '-' grouped, such as 'state-of-the-art'
c_state = 'en&num'

if state == None:
state = c_state

if state != c_state:
if len(word_content) != 0:
word_list.append(word_content)
word_col_list.append(word_col_content)
state_list.append(state)
word_content = []
word_col_content = []
state = c_state

if state != "splitter":
word_content.append(char)
word_col_content.append(valid_col[c_i])

if len(word_content) != 0:
word_list.append(word_content)
word_col_list.append(word_col_content)
state_list.append(state)

return word_list, word_col_list, state_list

def decode(self, text_index, text_prob=None, is_remove_duplicate=False, return_word_box=False):
""" convert text-index into text-label. """
result_list = []
ignored_tokens = self.get_ignored_tokens()
@@ -95,8 +154,12 @@ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):

if self.reverse: # for arabic rec
text = self.pred_reverse(text)

result_list.append((text, np.mean(conf_list).tolist()))

if return_word_box:
word_list, word_col_list, state_list = self.get_word_info(text, selection)
result_list.append((text, np.mean(conf_list).tolist(), [len(text_index[batch_idx]), word_list, word_col_list, state_list]))
else:
result_list.append((text, np.mean(conf_list).tolist()))
return result_list

def get_ignored_tokens(self):
@@ -111,14 +174,19 @@ def __init__(self, character_dict_path=None, use_space_char=False,
super(CTCLabelDecode, self).__init__(character_dict_path,
use_space_char)

def __call__(self, preds, label=None, *args, **kwargs):
def __call__(self, preds, label=None, return_word_box=False, *args, **kwargs):
if isinstance(preds, tuple) or isinstance(preds, list):
preds = preds[-1]
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True, return_word_box=return_word_box)
if return_word_box:
for rec_idx, rec in enumerate(text):
wh_ratio = kwargs['wh_ratio_list'][rec_idx]
max_wh_ratio = kwargs['max_wh_ratio']
rec[2][0] = rec[2][0]*(wh_ratio/max_wh_ratio)
if label is None:
return text
label = self.decode(label)
Collaborator

I feel the code that currently lives in the prediction script could be placed here instead, since all CTCLabelDecode post-processing happens in the rec post-process anyway.

Collaborator Author

If we add it here, we would also have to pass the image aspect ratio in as a parameter, haha.

Collaborator

Could that be passed in via kwargs?

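For anyone wanting to poke at the new grouping logic directly, here is a minimal sketch. The `text` and `selection` values are invented purely for illustration, and the import assumes this PR's branch of PaddleOCR (and its paddle dependency) is importable; `get_word_info` never touches `self`, so it can be exercised on its own.

```python
import numpy as np

from ppocr.postprocess.rec_postprocess import BaseRecLabelDecode

# Pretend the CTC head produced 12 feature columns and that 9 of them survived
# blank/duplicate removal in decode(); `selection` marks the surviving columns.
text = "VGG-16 你好"
selection = np.zeros(12, dtype=bool)
selection[[0, 1, 2, 4, 5, 6, 7, 9, 11]] = True  # one kept column per decoded char

word_list, word_col_list, state_list = BaseRecLabelDecode.get_word_info(
    None, text, selection)

print(word_list)      # [['V', 'G', 'G', '-', '1', '6'], ['你', '好']]
print(word_col_list)  # column indices per word, e.g. [[0, 1, 2, 4, 5, 6], [9, 11]] (the space column is dropped)
print(state_list)     # ['en&num', 'cn']
```

As the thread above discusses, the aspect-ratio information ends up being passed through `**kwargs`: `CTCLabelDecode.__call__` receives `wh_ratio_list` and `max_wh_ratio` and rescales the stored column count from the padded batch width back to each crop's own width.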
26 changes: 19 additions & 7 deletions ppstructure/predict_system.py
@@ -34,7 +34,7 @@
from tools.infer.predict_system import TextSystem
from ppstructure.layout.predict_layout import LayoutPredictor
from ppstructure.table.predict_table import TableSystem, to_excel
from ppstructure.utility import parse_args, draw_structure_result
from ppstructure.utility import parse_args, draw_structure_result, cal_ocr_word_box

logger = get_logger()

@@ -79,6 +79,8 @@ def __init__(self, args):
from ppstructure.kie.predict_kie_token_ser_re import SerRePredictor
self.kie_predictor = SerRePredictor(args)

self.return_word_box = args.return_word_box

def __call__(self, img, return_ocr_result_in_table=False, img_idx=0):
time_dict = {
'image_orientation': 0,
@@ -156,17 +158,27 @@ def __call__(self, img, return_ocr_result_in_table=False, img_idx=0):
]
res = []
for box, rec_res in zip(filter_boxes, filter_rec_res):
rec_str, rec_conf = rec_res
rec_str, rec_conf = rec_res[0], rec_res[1]
for token in style_token:
if token in rec_str:
rec_str = rec_str.replace(token, '')
if not self.recovery:
box += [x1, y1]
res.append({
'text': rec_str,
'confidence': float(rec_conf),
'text_region': box.tolist()
})
if self.return_word_box:
word_box_content_list, word_box_list = cal_ocr_word_box(rec_str, box, rec_res[2])
res.append({
'text': rec_str,
'confidence': float(rec_conf),
'text_region': box.tolist(),
'text_word': word_box_content_list,
'text_word_region': word_box_list
})
else:
res.append({
'text': rec_str,
'confidence': float(rec_conf),
'text_region': box.tolist()
})
res_list.append({
'type': region['label'].lower(),
'bbox': [x1, y1, x2, y2],
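To make the branching above concrete, this is roughly what one OCR entry in `res` looks like once `return_word_box` is enabled; the coordinates are invented, and only the keys and nesting mirror the code above.

```python
# Illustrative only: values are made up; the keys match the res.append(...) above.
entry = {
    'text': 'VGG-16 你好',
    'confidence': 0.98,
    'text_region': [[60, 30], [310, 30], [310, 62], [60, 62]],  # whole-line quad
    'text_word': ['VGG-16', '你', '好'],                          # per word / per CJK char
    'text_word_region': [                                         # one quad per item above
        [[60, 30], [150, 30], [150, 62], [60, 62]],
        [[226, 30], [289, 30], [289, 62], [226, 62]],
        [[268, 30], [310, 30], [310, 62], [268, 62]],
    ],
}
```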
59 changes: 58 additions & 1 deletion ppstructure/utility.py
@@ -16,7 +16,7 @@
from PIL import Image, ImageDraw, ImageFont
import numpy as np
from tools.infer.utility import draw_ocr_box_txt, str2bool, str2int_tuple, init_args as infer_args

import math

def init_args():
parser = infer_args()
@@ -166,6 +166,63 @@ def draw_structure_result(image, result, font_path):
txts.append(text_result['text'])
scores.append(text_result['confidence'])

if 'text_word_region' in text_result:
for word_region in text_result['text_word_region']:
char_box = word_region
box_height = int(
math.sqrt((char_box[0][0] - char_box[3][0])**2 + (char_box[0][1] - char_box[3][1])**2))
box_width = int(
math.sqrt((char_box[0][0] - char_box[1][0])**2 + (char_box[0][1] - char_box[1][1])**2))
if box_height == 0 or box_width == 0:
continue
boxes.append(word_region)
txts.append("")
scores.append(1.0)

im_show = draw_ocr_box_txt(
img_layout, boxes, txts, scores, font_path=font_path, drop_score=0)
return im_show

def cal_ocr_word_box(rec_str, box, rec_word_info):
'''Calculate the bounding box of each word based on the OCR detection and recognition results.'''

col_num, word_list, word_col_list, state_list = rec_word_info
box = box.tolist()
bbox_x_start = box[0][0]
bbox_x_end = box[1][0]
bbox_y_start = box[0][1]
bbox_y_end = box[2][1]

cell_width = (bbox_x_end - bbox_x_start)/col_num

word_box_list = []
word_box_content_list = []
cn_width_list = []
cn_col_list = []
for word, word_col, state in zip(word_list, word_col_list, state_list):
if state == 'cn':
if len(word_col) != 1:
char_seq_length = (word_col[-1] - word_col[0] + 1) * cell_width
char_width = char_seq_length/(len(word_col)-1)
cn_width_list.append(char_width)
cn_col_list += word_col
word_box_content_list += word
else:
cell_x_start = bbox_x_start + int(word_col[0] * cell_width)
cell_x_end = bbox_x_start + int((word_col[-1]+1) * cell_width)
cell = ((cell_x_start, bbox_y_start), (cell_x_end, bbox_y_start), (cell_x_end, bbox_y_end), (cell_x_start, bbox_y_end))
word_box_list.append(cell)
word_box_content_list.append("".join(word))
if len(cn_col_list) != 0:
if len(cn_width_list) != 0:
avg_char_width = np.mean(cn_width_list)
else:
avg_char_width = (bbox_x_end - bbox_x_start)/len(rec_str)
for center_idx in cn_col_list:
center_x = (center_idx+0.5)*cell_width
cell_x_start = max(int(center_x - avg_char_width/2), 0) + bbox_x_start
cell_x_end = min(int(center_x + avg_char_width/2), bbox_x_end-bbox_x_start) + bbox_x_start
cell = ((cell_x_start, bbox_y_start), (cell_x_end, bbox_y_start), (cell_x_end, bbox_y_end), (cell_x_start, bbox_y_end))
word_box_list.append(cell)

return word_box_content_list, word_box_list
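A hedged usage sketch of the helper above: the quad and the word info are fabricated here (in practice they come from the detector and from `CTCLabelDecode(..., return_word_box=True)`), and the import assumes the repo root of this PR's branch is on the path with its dependencies installed.

```python
import numpy as np

from ppstructure.utility import cal_ocr_word_box

# Hypothetical detection quad (top-left, top-right, bottom-right, bottom-left).
box = np.array([[60, 30], [310, 30], [310, 62], [60, 62]], dtype=float)

# rec_word_info = [col_num, word_list, word_col_list, state_list], as produced
# by the post-processor; col_num is the (aspect-ratio-corrected) column count.
rec_word_info = [
    12,
    [['V', 'G', 'G', '-', '1', '6'], ['你', '好']],
    [[0, 1, 2, 4, 5, 6], [9, 11]],
    ['en&num', 'cn'],
]

word_contents, word_boxes = cal_ocr_word_box("VGG-16 你好", box, rec_word_info)
print(word_contents)  # ['VGG-16', '你', '好']
print(word_boxes)     # one 4-point box per entry above
```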
10 changes: 8 additions & 2 deletions tools/infer/predict_rec.py
@@ -116,6 +116,7 @@ def __init__(self, args):
"use_space_char": args.use_space_char
}
self.postprocess_op = build_post_process(postprocess_params)
self.postprocess_params = postprocess_params
self.predictor, self.input_tensor, self.output_tensors, self.config = \
utility.create_predictor(args, 'rec', logger)
self.benchmark = args.benchmark
@@ -139,6 +140,7 @@ def __init__(self, args):
],
warmup=0,
logger=logger)
self.return_word_box = args.return_word_box

def resize_norm_img(self, img, max_wh_ratio):
imgC, imgH, imgW = self.rec_image_shape
@@ -407,11 +409,12 @@ def __call__(self, img_list):
valid_ratios = []
imgC, imgH, imgW = self.rec_image_shape[:3]
max_wh_ratio = imgW / imgH
# max_wh_ratio = 0
wh_ratio_list = []
for ino in range(beg_img_no, end_img_no):
h, w = img_list[indices[ino]].shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
wh_ratio_list.append(wh_ratio)
for ino in range(beg_img_no, end_img_no):
if self.rec_algorithm == "SAR":
norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
@@ -616,7 +619,10 @@ def __call__(self, img_list):
preds = outputs
else:
preds = outputs[0]
rec_result = self.postprocess_op(preds)
if self.postprocess_params['name'] == 'CTCLabelDecode':
rec_result = self.postprocess_op(preds, return_word_box=self.return_word_box, wh_ratio_list=wh_ratio_list, max_wh_ratio=max_wh_ratio)
else:
rec_result = self.postprocess_op(preds)
for rno in range(len(rec_result)):
rec_res[indices[beg_img_no + rno]] = rec_result[rno]
if self.benchmark:
2 changes: 1 addition & 1 deletion tools/infer/predict_system.py
@@ -111,7 +111,7 @@ def __call__(self, img, cls=True):
rec_res)
filter_boxes, filter_rec_res = [], []
for box, rec_result in zip(dt_boxes, rec_res):
text, score = rec_result
text, score = rec_result[0], rec_result[1]
if score >= self.drop_score:
filter_boxes.append(box)
filter_rec_res.append(rec_result)
4 changes: 4 additions & 0 deletions tools/infer/utility.py
@@ -150,6 +150,10 @@ def init_args():

parser.add_argument("--show_log", type=str2bool, default=True)
parser.add_argument("--use_onnx", type=str2bool, default=False)

# extended function
parser.add_argument("--return_word_box", type=str2bool, default=False, help='Whether return the bbox of each word (split by space) or chinese character. Only used in ppstructure for layout recovery')

return parser


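With the new flag in place, enabling per-word boxes during layout recovery should be as simple as adding `--return_word_box=True` to the usual ppstructure command, e.g. something along the lines of `python3 ppstructure/predict_system.py --image_dir=<your_image> --recovery=True --return_word_box=True` plus the model-dir arguments you normally pass; the extra `text_word` and `text_word_region` fields then appear in each OCR entry of the result.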