diff --git a/python/rapidocr_onnxruntime/ch_ppocr_rec/text_recognize.py b/python/rapidocr_onnxruntime/ch_ppocr_rec/text_recognize.py
index 7461eae25..c6b3e3af7 100644
--- a/python/rapidocr_onnxruntime/ch_ppocr_rec/text_recognize.py
+++ b/python/rapidocr_onnxruntime/ch_ppocr_rec/text_recognize.py
@@ -40,7 +40,7 @@ def __init__(self, config: Dict[str, Any]):
         self.rec_image_shape = config["rec_img_shape"]

     def __call__(
-        self, img_list: Union[np.ndarray, List[np.ndarray]]
+        self, img_list: Union[np.ndarray, List[np.ndarray]], rec_word_box: bool = False
     ) -> Tuple[List[Tuple[str, float]], float]:
         if isinstance(img_list, np.ndarray):
             img_list = [img_list]
@@ -58,12 +58,15 @@ def __call__(
         elapse = 0
         for beg_img_no in range(0, img_num, batch_num):
             end_img_no = min(img_num, beg_img_no + batch_num)
-            max_wh_ratio = 0
+            # align the aspect-ratio handling with PaddleOCR
+            imgC, imgH, imgW = self.rec_image_shape[:3]
+            max_wh_ratio = imgW / imgH
+            wh_ratio_list = []
             for ino in range(beg_img_no, end_img_no):
                 h, w = img_list[indices[ino]].shape[0:2]
                 wh_ratio = w * 1.0 / h
                 max_wh_ratio = max(max_wh_ratio, wh_ratio)
-
+                wh_ratio_list.append(wh_ratio)
             norm_img_batch = []
             for ino in range(beg_img_no, end_img_no):
                 norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio)
@@ -72,7 +75,9 @@ def __call__(

             starttime = time.time()
             preds = self.session(norm_img_batch)[0]
-            rec_result = self.postprocess_op(preds)
+            rec_result = self.postprocess_op(
+                preds, rec_word_box, wh_ratio_list=wh_ratio_list, max_wh_ratio=max_wh_ratio
+            )

             for rno, one_res in enumerate(rec_result):
                 rec_res[indices[beg_img_no + rno]] = one_res
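For context on why `wh_ratio_list` is now threaded to the postprocessor: every crop in a batch is resized to the width implied by the batch-wide `max_wh_ratio`, so the model emits the same number of CTC columns for each crop, and narrower crops only occupy a left-hand fraction of them. A minimal sketch of the arithmetic (the shape values are assumed, not read from the config):

```python
# Illustrative only: the padded-width bookkeeping behind wh_ratio_list.
imgH, imgW = 48, 320                 # assumed rec_img_shape values
max_wh_ratio = imgW / imgH           # baseline ratio, ~6.67
batch_wh_ratios = [4.0, 9.5]         # width/height of two crops in one batch
max_wh_ratio = max(max_wh_ratio, *batch_wh_ratios)  # widest crop wins -> 9.5

for wh_ratio in batch_wh_ratios:
    # fraction of the CTC columns that cover real (un-padded) content; this is
    # exactly the rec[2][0] * (wh_ratio / max_wh_ratio) rescaling applied below
    print(f"{wh_ratio / max_wh_ratio:.2f}")
# 0.42
# 1.00
```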
diff --git a/python/rapidocr_onnxruntime/ch_ppocr_rec/utils.py b/python/rapidocr_onnxruntime/ch_ppocr_rec/utils.py
index 7d9be4836..00e3ae0be 100644
--- a/python/rapidocr_onnxruntime/ch_ppocr_rec/utils.py
+++ b/python/rapidocr_onnxruntime/ch_ppocr_rec/utils.py
@@ -16,10 +16,15 @@ def __init__(
         self.character = self.get_character(character, character_path)
         self.dict = {char: i for i, char in enumerate(self.character)}

-    def __call__(self, preds: np.ndarray) -> List[Tuple[str, float]]:
+    def __call__(self, preds: np.ndarray, rec_word_box: bool = False, **kwargs) -> List[Tuple[str, float]]:
         preds_idx = preds.argmax(axis=2)
         preds_prob = preds.max(axis=2)
-        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
+        text = self.decode(preds_idx, preds_prob, rec_word_box, is_remove_duplicate=True)
+        if rec_word_box:
+            for rec_idx, rec in enumerate(text):
+                wh_ratio = kwargs["wh_ratio_list"][rec_idx]
+                max_wh_ratio = kwargs["max_wh_ratio"]
+                rec[2][0] = rec[2][0] * (wh_ratio / max_wh_ratio)  # rescale to the un-padded width
         return text

     def get_character(
@@ -64,38 +69,102 @@ def insert_special_char(
         return character_list

     def decode(
-        self,
-        text_index: np.ndarray,
-        text_prob: Optional[np.ndarray] = None,
-        is_remove_duplicate: bool = False,
+        self,
+        text_index: np.ndarray,
+        text_prob: Optional[np.ndarray] = None,
+        rec_word_box: bool = False,
+        is_remove_duplicate: bool = False,
     ) -> List[Tuple[str, float]]:
         """convert text-index into text-label."""
         result_list = []
         ignored_tokens = self.get_ignored_tokens()
         batch_size = len(text_index)
         for batch_idx in range(batch_size):
-            char_list, conf_list = [], []
-            cur_pred_ids = text_index[batch_idx]
-            for idx, cur_idx in enumerate(cur_pred_ids):
-                if cur_idx in ignored_tokens:
-                    continue
-
-                if is_remove_duplicate:
-                    # only for predict
-                    if idx > 0 and cur_pred_ids[idx - 1] == cur_idx:
-                        continue
-
-                char_list.append(self.character[int(cur_idx)])
-
-                if text_prob is None:
-                    conf_list.append(1)
-                else:
-                    conf_list.append(text_prob[batch_idx][idx])
+            selection = np.ones(len(text_index[batch_idx]), dtype=bool)
+            if is_remove_duplicate:
+                selection[1:] = text_index[batch_idx][1:] != text_index[batch_idx][:-1]
+            for ignored_token in ignored_tokens:
+                selection &= text_index[batch_idx] != ignored_token
+            char_list = [
+                self.character[text_id] for text_id in text_index[batch_idx][selection]
+            ]
+            if text_prob is not None:
+                conf_list = text_prob[batch_idx][selection]
+            else:
+                conf_list = [1] * len(selection)
+            if len(conf_list) == 0:
+                conf_list = [0]
             text = "".join(char_list)
-            result_list.append((text, np.mean(conf_list if any(conf_list) else [0])))
+            if rec_word_box:
+                word_list, word_col_list, state_list = self.get_word_info(
+                    text, selection
+                )
+                result_list.append(
+                    (
+                        text,
+                        np.mean(conf_list).tolist(),
+                        [
+                            len(text_index[batch_idx]),
+                            word_list,
+                            word_col_list,
+                            state_list,
+                        ],
+                    )
+                )
+            else:
+                result_list.append((text, np.mean(conf_list).tolist()))
         return result_list

+    @staticmethod
+    def get_word_info(text: str,
+                      selection: np.ndarray) -> Tuple[List[List[str]], List[List[int]], List[str]]:
+        """
+        Group the decoded characters and record their decoding positions.
+
+        Args:
+            text: the decoded text
+            selection: the bool array marking which feature columns were decoded
+                as non-blank, non-duplicate characters
+        Returns:
+            word_list: list of the grouped words
+            word_col_list: list of the decoding column positions of each character
+                in the grouped word
+            state_list: list of markers identifying the type of each group:
+                - 'cn': continuous Chinese characters (e.g., 你好啊)
+                - 'en&num': continuous English characters (e.g., hello), numbers
+                  (e.g., 123, 1.123), or a mix of both joined by '-' (e.g., VGG-16)
+        """
+        state = None
+        word_content = []
+        word_col_content = []
+        word_list = []
+        word_col_list = []
+        state_list = []
+        valid_col = np.where(selection)[0]
+
+        for c_i, char in enumerate(text):
+            if "\u4e00" <= char <= "\u9fff":
+                c_state = "cn"
+            else:
+                c_state = "en&num"
+            if state is None:
+                state = c_state
+
+            if state != c_state:
+                if len(word_content) != 0:
+                    word_list.append(word_content)
+                    word_col_list.append(word_col_content)
+                    state_list.append(state)
+                    word_content = []
+                    word_col_content = []
+                state = c_state
+            word_content.append(char)
+            word_col_content.append(int(valid_col[c_i]))
+        if len(word_content) != 0:
+            word_list.append(word_content)
+            word_col_list.append(word_col_content)
+            state_list.append(state)
+
+        return word_list, word_col_list, state_list
+
     @staticmethod
     def get_ignored_tokens() -> List[int]:
         return [0]  # for ctc blank
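To see what `get_word_info` produces, here is a small check (not part of the patch; it assumes the decoder class in this file is named `CTCLabelDecode`):

```python
import numpy as np

from rapidocr_onnxruntime.ch_ppocr_rec.utils import CTCLabelDecode

text = "VGG-16你好"
# 12 CTC columns; True marks the 8 columns kept after removing blanks/duplicates
selection = np.array(
    [True, True, True, True, True, True, False, False, True, False, True, False]
)

word_list, word_col_list, state_list = CTCLabelDecode.get_word_info(text, selection)
print(word_list)      # [['V', 'G', 'G', '-', '1', '6'], ['你', '好']]
print(word_col_list)  # [[0, 1, 2, 3, 4, 5], [8, 10]]
print(state_list)     # ['en&num', 'cn']
```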
diff --git a/python/rapidocr_onnxruntime/main.py b/python/rapidocr_onnxruntime/main.py
index 4c0b0e15f..f82f21eea 100644
--- a/python/rapidocr_onnxruntime/main.py
+++ b/python/rapidocr_onnxruntime/main.py
@@ -22,6 +22,7 @@
     read_yaml,
     reduce_max_side,
     update_model_path,
+    word_box,
 )

 root_dir = Path(__file__).resolve().parent
@@ -71,12 +72,12 @@ def __call__(
         use_det = self.use_det if use_det is None else use_det
         use_cls = self.use_cls if use_cls is None else use_cls
         use_rec = self.use_rec if use_rec is None else use_rec
-
+        rec_word_box = False
         if kwargs:
             box_thresh = kwargs.get("box_thresh", 0.5)
             unclip_ratio = kwargs.get("unclip_ratio", 1.6)
             text_score = kwargs.get("text_score", 0.5)
-
+            rec_word_box = kwargs.get("rec_word_box", False)
             self.text_det.postprocess_op.box_thresh = box_thresh
             self.text_det.postprocess_op.unclip_ratio = unclip_ratio
             self.text_score = text_score
@@ -103,11 +104,15 @@ def __call__(
             img, cls_res, cls_elapse = self.text_cls(img)

         if use_rec:
-            rec_res, rec_elapse = self.text_rec(img)
-
+            rec_res, rec_elapse = self.text_rec(img, rec_word_box)
+        # map the word boxes back through the inverse rotation/perspective transforms
+        if dt_boxes is not None and rec_res is not None and rec_word_box:
+            rec_res = word_box.cal_rec_boxes(dt_boxes, img, rec_res)
+            for i, rec_res_i in enumerate(rec_res):
+                if rec_res_i[3]:
+                    rec_res_i[3] = self._get_origin_points(rec_res_i[3], op_record, raw_h, raw_w).astype(np.int32).tolist()
         if dt_boxes is not None and rec_res is not None:
             dt_boxes = self._get_origin_points(dt_boxes, op_record, raw_h, raw_w)
-
         ocr_res = self.get_final_res(
             dt_boxes, cls_res, rec_res, det_elapse, cls_elapse, rec_elapse
         )
@@ -237,7 +242,7 @@ def _get_origin_points(
         raw_h: int,
         raw_w: int,
     ) -> np.ndarray:
-        dt_boxes_array = np.array(dt_boxes)
+        dt_boxes_array = np.array(dt_boxes).astype(np.float32)
         for op in reversed(list(op_record.keys())):
             v = op_record[op]
             if "padding" in op:
@@ -263,7 +268,7 @@ def get_final_res(
         self,
         dt_boxes: Optional[List[np.ndarray]],
         cls_res: Optional[List[List[Union[str, float]]]],
-        rec_res: Optional[List[Tuple[str, float]]],
+        rec_res: Optional[List[Tuple[str, float, List[Union[str, float]]]]],
         det_elapse: float,
         cls_elapse: float,
         rec_elapse: float,
@@ -285,7 +290,7 @@ def get_final_res(
             return None, None

         ocr_res = [
-            [box.tolist(), res[0], res[1]] for box, res in zip(dt_boxes, rec_res)
+            [box.tolist(), *res] for box, res in zip(dt_boxes, rec_res)
         ], [det_elapse, cls_elapse, rec_elapse]
         return ocr_res
@@ -299,7 +304,7 @@ def filter_result(
         filter_boxes, filter_rec_res = [], []
         for box, rec_reuslt in zip(dt_boxes, rec_res):
-            text, score = rec_reuslt
+            text, score = rec_reuslt[0], rec_reuslt[1]
             if float(score) >= self.text_score:
                 filter_boxes.append(box)
                 filter_rec_res.append(rec_reuslt)
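A hedged usage sketch of the new flag (the image path is illustrative; the per-line layout follows `get_final_res` above, i.e. `[box, text, score, word_contents, word_boxes]` when `rec_word_box=True`):

```python
from rapidocr_onnxruntime import RapidOCR

engine = RapidOCR()
result, elapse_list = engine("path/to/image.png", rec_word_box=True)

for box, text, score, words, word_boxes in result:
    print(text, score)
    for word, word_box in zip(words, word_boxes):
        # each word_box holds four [x, y] corners in the original image
        print("  ", word, word_box)
```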
diff --git a/python/rapidocr_onnxruntime/utils/word_box.py b/python/rapidocr_onnxruntime/utils/word_box.py
new file mode 100644
index 000000000..4779a6e98
--- /dev/null
+++ b/python/rapidocr_onnxruntime/utils/word_box.py
@@ -0,0 +1,204 @@
+import copy
+import math
+from typing import Any, List, Optional, Tuple
+
+import cv2
+import numpy as np
+
+
+def cal_rec_boxes(
+    dt_boxes: Optional[List[np.ndarray]],
+    crop_imgs: Optional[List[np.ndarray]],
+    rec_res: Optional[List[Any]],
+) -> List[List[Any]]:
+    res = []
+    for i, (box, rec) in enumerate(zip(dt_boxes, rec_res)):
+        direction = "w"
+        img_crop_width = int(
+            max(
+                np.linalg.norm(box[0] - box[1]),
+                np.linalg.norm(box[2] - box[3]),
+            )
+        )
+        img_crop_height = int(
+            max(
+                np.linalg.norm(box[0] - box[3]),
+                np.linalg.norm(box[1] - box[2]),
+            )
+        )
+        # tall boxes were rotated to horizontal before recognition
+        if img_crop_height * 1.0 / img_crop_width >= 1.5:
+            direction = "h"
+
+        rec_str, rec_conf, rec_word_info = rec[0], rec[1], rec[2]
+        crop_img = crop_imgs[i]
+        h, w = crop_img.shape[:2]
+        crop_img_box = np.array([[0, 0], [w, 0], [w, h], [0, h]])
+        word_box_content_list, word_box_list = cal_ocr_word_box(
+            rec_str, crop_img_box, rec_word_info
+        )
+        # fix overlapping word boxes
+        adjust_box_overlap(word_box_list)
+        word_box_list = reverse_rotate_crop_image(
+            copy.deepcopy(box), word_box_list, direction
+        )
+        res.append([rec_str, rec_conf, word_box_content_list, word_box_list])
+    return res
+
+
+def adjust_box_overlap(word_box_list):
+    # split the horizontal overlap between neighboring boxes evenly
+    for i in range(len(word_box_list) - 1):
+        cur, nxt = word_box_list[i], word_box_list[i + 1]
+        if cur[1][0] > nxt[0][0]:  # boxes overlap
+            distance = abs(cur[1][0] - nxt[0][0])
+            cur[1][0] -= distance / 2
+            cur[2][0] -= distance / 2
+            nxt[0][0] += distance / 2
+            nxt[3][0] += distance / 2
+
+
+# Rotate (valuex, valuey) clockwise by `angle` radians around (pointx, pointy).
+# Adapted from a CSDN post by 星夜孤帆, licensed under CC 4.0 BY-SA; keep this
+# notice and the original link when redistributing:
+# https://blog.csdn.net/qq_38826019/article/details/84233397
+def s_rotate(angle, valuex, valuey, pointx, pointy):
+    valuex = np.array(valuex)
+    valuey = np.array(valuey)
+    sRotatex = (valuex - pointx) * math.cos(angle) + (valuey - pointy) * math.sin(angle) + pointx
+    sRotatey = (valuey - pointy) * math.cos(angle) - (valuex - pointx) * math.sin(angle) + pointy
+    return [sRotatex, sRotatey]
+
+
+def reverse_rotate_crop_image(bbox_points, word_points_list, direction="w"):
+    """
+    The inverse of get_rotate_crop_image.
+
+    bbox_points are the four corners of the crop within the original image
+    (top-left, top-right, bottom-right, bottom-left). word_points_list holds
+    points [(x, y), ...] expressed in the cropped image, which are mapped back
+    into the original image's coordinate system.
+    """
+    bbox_points = np.float32(bbox_points)
+    left = int(np.min(bbox_points[:, 0]))
+    top = int(np.min(bbox_points[:, 1]))
+    bbox_points[:, 0] = bbox_points[:, 0] - left
+    bbox_points[:, 1] = bbox_points[:, 1] - top
+    img_crop_width = int(np.linalg.norm(bbox_points[0] - bbox_points[1]))
+    img_crop_height = int(np.linalg.norm(bbox_points[0] - bbox_points[3]))
+    pts_std = np.array(
+        [
+            [0, 0],
+            [img_crop_width, 0],
+            [img_crop_width, img_crop_height],
+            [0, img_crop_height],
+        ]
+    ).astype(np.float32)
+    M = cv2.getPerspectiveTransform(bbox_points, pts_std)
+    _, IM = cv2.invert(M)
+
+    new_word_points_list = []
+    for word_points in word_points_list:
+        new_word_points = []
+        for point in word_points:
+            new_point = point
+            if direction == "h":
+                # undo the rotation applied to vertical text crops
+                new_point = s_rotate(math.radians(-90), new_point[0], new_point[1], 0, 0)
+                new_point[0] = new_point[0] + img_crop_width
+
+            p = np.float32(new_point + [1])
+            x, y, z = np.dot(IM, p)
+            new_point = [x / z, y / z]
+
+            new_point = [int(new_point[0] + left), int(new_point[1] + top)]
+            new_word_points.append(new_point)
+        new_word_points = order_points(new_word_points)
+        new_word_points_list.append(new_word_points)
+    return new_word_points_list
+
+
+def cal_ocr_word_box(rec_str: str,
+                     box: np.ndarray,
+                     rec_word_info: List[Tuple[str, List[int]]]):
+    """Calculate the box of each word from the OCR recognition and detection results."""
+    col_num, word_list, word_col_list, state_list = rec_word_info
+    box = box.tolist()
+    bbox_x_start = box[0][0]
+    bbox_x_end = box[1][0]
+    bbox_y_start = box[0][1]
+    bbox_y_end = box[2][1]
+
+    cell_width = (bbox_x_end - bbox_x_start) / col_num
+    word_box_list = []
+    word_box_content_list = []
+    cn_width_list = []
+    cn_col_list = []
+    for word, word_col, state in zip(word_list, word_col_list, state_list):
+        if state == "cn":
+            if len(word_col) != 1:
+                char_seq_length = (word_col[-1] - word_col[0] + 1) * cell_width
+                char_width = char_seq_length / (len(word_col) - 1)
+                cn_width_list.append(char_width)
+            cn_col_list += word_col
+            word_box_content_list += word
+        else:
+            cell_x_start = bbox_x_start + int(word_col[0] * cell_width)
+            cell_x_end = bbox_x_start + int((word_col[-1] + 1) * cell_width)
+            cell = [
+                [cell_x_start, bbox_y_start],
+                [cell_x_end, bbox_y_start],
+                [cell_x_end, bbox_y_end],
+                [cell_x_start, bbox_y_end],
+            ]
+            word_box_list.append(cell)
+            word_box_content_list.append("".join(word))
+    if len(cn_col_list) != 0:
+        if len(cn_width_list) != 0:
+            avg_char_width = np.mean(cn_width_list)
+        else:
+            avg_char_width = (bbox_x_end - bbox_x_start) / len(rec_str)
+        for center_idx in cn_col_list:
+            center_x = (center_idx + 0.5) * cell_width
+            cell_x_start = max(int(center_x - avg_char_width / 2), 0) + bbox_x_start
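A standalone sketch of the inverse perspective mapping that `reverse_rotate_crop_image` performs per point (all coordinates here are made up):

```python
import cv2
import numpy as np

# quadrilateral of a text crop in the original image, and its rectified target
src = np.float32([[10, 20], [110, 25], [108, 60], [8, 55]])
dst = np.float32([[0, 0], [100, 0], [100, 40], [0, 40]])

M = cv2.getPerspectiveTransform(src, dst)   # original quad -> rectified crop
_, IM = cv2.invert(M)                       # rectified crop -> original quad

pt = np.float32([50, 20, 1])                # homogeneous point in the rectified crop
x, y, z = IM @ pt
print([x / z, y / z])                       # the same point back in the original image
```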
+            cell_x_end = (
+                min(int(center_x + avg_char_width / 2), bbox_x_end - bbox_x_start)
+                + bbox_x_start
+            )
+            cell = [
+                [cell_x_start, bbox_y_start],
+                [cell_x_end, bbox_y_start],
+                [cell_x_end, bbox_y_end],
+                [cell_x_start, bbox_y_end],
+            ]
+            word_box_list.append(cell)
+    sorted_word_box_list = sorted(word_box_list, key=lambda box: box[0][0])
+    return word_box_content_list, sorted_word_box_list
+
+
+def order_points(box):
+    """Sort the four corners of a box into top-left, top-right, bottom-right,
+    bottom-left order.
+
+    :param box: numpy.array, shape=(4, 2)
+    :return: the ordered corners as a nested list
+    """
+    box = np.array(box).reshape((-1, 2))
+    center_x, center_y = np.mean(box[:, 0]), np.mean(box[:, 1])
+    if np.any(box[:, 0] == center_x) and np.any(box[:, 1] == center_y):
+        # two points share the center x and two share the center y: a rhombus
+        p1 = box[np.where(box[:, 0] == np.min(box[:, 0]))]
+        p2 = box[np.where(box[:, 1] == np.min(box[:, 1]))]
+        p3 = box[np.where(box[:, 0] == np.max(box[:, 0]))]
+        p4 = box[np.where(box[:, 1] == np.max(box[:, 1]))]
+    elif np.all(box[:, 0] == center_x):
+        # all four points share the same x coordinate: sort by y
+        y_sort = np.argsort(box[:, 1])
+        p1 = box[y_sort[0]]
+        p2 = box[y_sort[1]]
+        p3 = box[y_sort[2]]
+        p4 = box[y_sort[3]]
+    elif np.any(box[:, 0] == center_x) and np.all(box[:, 1] != center_y):
+        # only two points share the center x: split top/bottom first, then left/right
+        p12, p34 = box[np.where(box[:, 1] < center_y)], box[np.where(box[:, 1] > center_y)]
+        p1, p2 = p12[np.where(p12[:, 0] == np.min(p12[:, 0]))], p12[np.where(p12[:, 0] == np.max(p12[:, 0]))]
+        p3, p4 = p34[np.where(p34[:, 0] == np.max(p34[:, 0]))], p34[np.where(p34[:, 0] == np.min(p34[:, 0]))]
+    else:
+        # two points share the center y, or nothing coincides: split left/right first, then top/bottom
+        p14, p23 = box[np.where(box[:, 0] < center_x)], box[np.where(box[:, 0] > center_x)]
+        p1, p4 = p14[np.where(p14[:, 1] == np.min(p14[:, 1]))], p14[np.where(p14[:, 1] == np.max(p14[:, 1]))]
+        p2, p3 = p23[np.where(p23[:, 1] == np.min(p23[:, 1]))], p23[np.where(p23[:, 1] == np.max(p23[:, 1]))]
+
+    return np.array([p1, p2, p3, p4]).reshape((-1, 2)).tolist()
diff --git a/python/tests/test_files/text_vertical_words.png b/python/tests/test_files/text_vertical_words.png
new file mode 100644
index 000000000..6fcd51e18
Binary files /dev/null and b/python/tests/test_files/text_vertical_words.png differ
diff --git a/python/tests/test_ort.py b/python/tests/test_ort.py
index 969707502..8bbf7ed9e 100644
--- a/python/tests/test_ort.py
+++ b/python/tests/test_ort.py
@@ -4,6 +4,7 @@
 import logging
 import sys
 from pathlib import Path
+from typing import List

 import cv2
 import numpy as np
@@ -88,7 +89,7 @@ def test_letterbox_like(img_name, gt_len, gt_first_len):
     result, _ = engine(img_path)

     assert len(result) == gt_len
-    assert result[0][1] == gt_first_len
+    assert result[0][1].lower() == gt_first_len.lower()


 def test_only_det():
@@ -222,3 +223,22 @@ def test_input_three_ndim_one_channel():
     assert result is not None
     assert result[0][1] == "正品促销"
     assert len(result) == 17
+
+
+@pytest.mark.parametrize(
+    "img_name,words",
+    [
+        (
+            "black_font_color_transparent.png",
+            ["我", "是", "中", "国", "人"],
+        ),
+        (
+            "text_vertical_words.png",
+            ["已", "取", "之", "時", "不", "參", "一", "人", "見", "而"],
+        ),
+    ],
+)
+def test_word_ocr(img_name: str, words: List[str]):
+    img_path = tests_dir / img_name
+    result, _ = engine(img_path, rec_word_box=True)
+    assert result[0][3] == words
\ No newline at end of file
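A standalone sketch of what the new test exercises (run from the repo's python/ directory; the expected output is taken from the test's parametrization above):

```python
from rapidocr_onnxruntime import RapidOCR

engine = RapidOCR()
result, _ = engine(
    "tests/test_files/black_font_color_transparent.png", rec_word_box=True
)
print(result[0][3])  # expected per the test above: ['我', '是', '中', '国', '人']
```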