diff --git a/python/demo.py b/python/demo.py index cd5cf1687..9b886ca8b 100644 --- a/python/demo.py +++ b/python/demo.py @@ -8,20 +8,25 @@ # from rapidocr_paddle import RapidOCR, VisRes # from rapidocr_openvino import RapidOCR, VisRes - engine = RapidOCR() vis = VisRes() -image_path = "tests/test_files/black_font_color_transparent.png" +image_path = "tests/test_files/ch_en_num.jpg" with open(image_path, "rb") as f: img = f.read() -result, elapse_list = engine(img) +result, elapse_list = engine(img, return_word_box=True) print(result) print(elapse_list) -boxes, txts, scores = list(zip(*result)) +(boxes, txts, scores, words_boxes, words) = list(zip(*result)) font_path = "resources/fonts/FZYTK.TTF" vis_img = vis(img, boxes, txts, scores, font_path) cv2.imwrite("vis.png", vis_img) + +words_boxes = sum(words_boxes, []) +words_all = sum(words, []) +words_scores = [1.0] * len(words_boxes) +vis_img = vis(img, words_boxes, words_all, words_scores, font_path) +cv2.imwrite("vis_single.png", vis_img) diff --git a/python/rapidocr_onnxruntime/cal_rec_boxes/__init__.py b/python/rapidocr_onnxruntime/cal_rec_boxes/__init__.py new file mode 100644 index 000000000..5127715f9 --- /dev/null +++ b/python/rapidocr_onnxruntime/cal_rec_boxes/__init__.py @@ -0,0 +1,4 @@ +# -*- encoding: utf-8 -*- +# @Author: SWHL +# @Contact: liekkaskono@163.com +from .main import CalRecBoxes diff --git a/python/rapidocr_onnxruntime/cal_rec_boxes/main.py b/python/rapidocr_onnxruntime/cal_rec_boxes/main.py new file mode 100644 index 000000000..5e0d71b7e --- /dev/null +++ b/python/rapidocr_onnxruntime/cal_rec_boxes/main.py @@ -0,0 +1,260 @@ +# -*- encoding: utf-8 -*- +# @Author: SWHL / Joker1212 +# @Contact: liekkaskono@163.com +import copy +import math +from typing import Any, List, Optional, Tuple + +import cv2 +import numpy as np + + +class CalRecBoxes: + """计算识别文字的汉字单字和英文单词的坐标框。代码借鉴自PaddlePaddle/PaddleOCR和fanqie03/char-detection""" + + def __init__(self): + pass + + def __call__( + self, + imgs: Optional[List[np.ndarray]], + dt_boxes: Optional[List[np.ndarray]], + rec_res: Optional[List[Any]], + ): + res = [] + for img, box, rec_res in zip(imgs, dt_boxes, rec_res): + direction = self.get_box_direction(box) + + rec_txt, rec_conf, rec_word_info = rec_res[0], rec_res[1], rec_res[2] + h, w = img.shape[:2] + img_box = np.array([[0, 0], [w, 0], [w, h], [0, h]]) + word_box_content_list, word_box_list = self.cal_ocr_word_box( + rec_txt, img_box, rec_word_info + ) + word_box_list = self.adjust_box_overlap(copy.deepcopy(word_box_list)) + word_box_list = self.reverse_rotate_crop_image( + copy.deepcopy(box), word_box_list, direction + ) + res.append([rec_txt, rec_conf, word_box_list, word_box_content_list]) + return res + + @staticmethod + def get_box_direction(box: np.ndarray) -> str: + direction = "w" + img_crop_width = int( + max( + np.linalg.norm(box[0] - box[1]), + np.linalg.norm(box[2] - box[3]), + ) + ) + img_crop_height = int( + max( + np.linalg.norm(box[0] - box[3]), + np.linalg.norm(box[1] - box[2]), + ) + ) + if img_crop_height * 1.0 / img_crop_width >= 1.5: + direction = "h" + return direction + + @staticmethod + def cal_ocr_word_box( + rec_txt: str, box: np.ndarray, rec_word_info: List[Tuple[str, List[int]]] + ) -> Tuple[List[str], List[List[int]]]: + """Calculate the detection frame for each word based on the results of recognition and detection of ocr + 汉字坐标是单字的 + 英语坐标是单词级别的 + """ + + col_num, word_list, word_col_list, state_list = rec_word_info + box = box.tolist() + bbox_x_start = box[0][0] + bbox_x_end = box[1][0] + bbox_y_start = box[0][1] + bbox_y_end = box[2][1] + + cell_width = (bbox_x_end - bbox_x_start) / col_num + word_box_list = [] + word_box_content_list = [] + cn_width_list = [] + cn_col_list = [] + for word, word_col, state in zip(word_list, word_col_list, state_list): + if state == "cn": + if len(word_col) != 1: + char_seq_length = (word_col[-1] - word_col[0] + 1) * cell_width + char_width = char_seq_length / (len(word_col) - 1) + cn_width_list.append(char_width) + cn_col_list += word_col + word_box_content_list += word + else: + cell_x_start = bbox_x_start + int(word_col[0] * cell_width) + cell_x_end = bbox_x_start + int((word_col[-1] + 1) * cell_width) + cell = [ + [cell_x_start, bbox_y_start], + [cell_x_end, bbox_y_start], + [cell_x_end, bbox_y_end], + [cell_x_start, bbox_y_end], + ] + word_box_list.append(cell) + word_box_content_list.append("".join(word)) + + if len(cn_col_list) != 0: + if len(cn_width_list) != 0: + avg_char_width = np.mean(cn_width_list) + else: + avg_char_width = (bbox_x_end - bbox_x_start) / len(rec_txt) + + for center_idx in cn_col_list: + center_x = (center_idx + 0.5) * cell_width + cell_x_start = max(int(center_x - avg_char_width / 2), 0) + bbox_x_start + cell_x_end = ( + min(int(center_x + avg_char_width / 2), bbox_x_end - bbox_x_start) + + bbox_x_start + ) + cell = [ + [cell_x_start, bbox_y_start], + [cell_x_end, bbox_y_start], + [cell_x_end, bbox_y_end], + [cell_x_start, bbox_y_end], + ] + word_box_list.append(cell) + sorted_word_box_list = sorted(word_box_list, key=lambda box: box[0][0]) + return word_box_content_list, sorted_word_box_list + + @staticmethod + def adjust_box_overlap( + word_box_list: List[List[List[int]]], + ) -> List[List[List[int]]]: + # 调整bbox有重叠的地方 + for i in range(len(word_box_list) - 1): + cur, nxt = word_box_list[i], word_box_list[i + 1] + if cur[1][0] > nxt[0][0]: # 有交集 + distance = abs(cur[1][0] - nxt[0][0]) + cur[1][0] -= distance / 2 + cur[2][0] -= distance / 2 + nxt[0][0] += distance / 2 + nxt[3][0] += distance / 2 + return word_box_list + + def reverse_rotate_crop_image( + self, + bbox_points: np.ndarray, + word_points_list: List[List[List[int]]], + direction: str = "w", + ) -> List[List[List[int]]]: + """ + get_rotate_crop_image的逆操作 + img为原图 + part_img为crop后的图 + bbox_points为part_img中对应在原图的bbox, 四个点,左上,右上,右下,左下 + part_points为在part_img中的点[(x, y), (x, y)] + """ + bbox_points = np.float32(bbox_points) + + left = int(np.min(bbox_points[:, 0])) + top = int(np.min(bbox_points[:, 1])) + bbox_points[:, 0] = bbox_points[:, 0] - left + bbox_points[:, 1] = bbox_points[:, 1] - top + + img_crop_width = int(np.linalg.norm(bbox_points[0] - bbox_points[1])) + img_crop_height = int(np.linalg.norm(bbox_points[0] - bbox_points[3])) + + pts_std = np.array( + [ + [0, 0], + [img_crop_width, 0], + [img_crop_width, img_crop_height], + [0, img_crop_height], + ] + ).astype(np.float32) + M = cv2.getPerspectiveTransform(bbox_points, pts_std) + _, IM = cv2.invert(M) + + new_word_points_list = [] + for word_points in word_points_list: + new_word_points = [] + for point in word_points: + new_point = point + if direction == "h": + new_point = self.s_rotate( + math.radians(-90), new_point[0], new_point[1], 0, 0 + ) + new_point[0] = new_point[0] + img_crop_width + + p = np.float32(new_point + [1]) + x, y, z = np.dot(IM, p) + new_point = [x / z, y / z] + + new_point = [int(new_point[0] + left), int(new_point[1] + top)] + new_word_points.append(new_point) + new_word_points = self.order_points(new_word_points) + new_word_points_list.append(new_word_points) + return new_word_points_list + + @staticmethod + def s_rotate(angle, valuex, valuey, pointx, pointy): + """绕pointx,pointy顺时针旋转 + https://blog.csdn.net/qq_38826019/article/details/84233397 + """ + valuex = np.array(valuex) + valuey = np.array(valuey) + sRotatex = ( + (valuex - pointx) * math.cos(angle) + + (valuey - pointy) * math.sin(angle) + + pointx + ) + sRotatey = ( + (valuey - pointy) * math.cos(angle) + - (valuex - pointx) * math.sin(angle) + + pointy + ) + return [sRotatex, sRotatey] + + @staticmethod + def order_points(box: List[List[int]]) -> List[List[int]]: + """矩形框顺序排列""" + box = np.array(box).reshape((-1, 2)) + center_x, center_y = np.mean(box[:, 0]), np.mean(box[:, 1]) + if np.any(box[:, 0] == center_x) and np.any( + box[:, 1] == center_y + ): # 有两点横坐标相等,有两点纵坐标相等,菱形 + p1 = box[np.where(box[:, 0] == np.min(box[:, 0]))] + p2 = box[np.where(box[:, 1] == np.min(box[:, 1]))] + p3 = box[np.where(box[:, 0] == np.max(box[:, 0]))] + p4 = box[np.where(box[:, 1] == np.max(box[:, 1]))] + elif np.all(box[:, 0] == center_x): # 四个点的横坐标都相同 + y_sort = np.argsort(box[:, 1]) + p1 = box[y_sort[0]] + p2 = box[y_sort[1]] + p3 = box[y_sort[2]] + p4 = box[y_sort[3]] + elif np.any(box[:, 0] == center_x) and np.all( + box[:, 1] != center_y + ): # 只有两点横坐标相等,先上下再左右 + p12, p34 = ( + box[np.where(box[:, 1] < center_y)], + box[np.where(box[:, 1] > center_y)], + ) + p1, p2 = ( + p12[np.where(p12[:, 0] == np.min(p12[:, 0]))], + p12[np.where(p12[:, 0] == np.max(p12[:, 0]))], + ) + p3, p4 = ( + p34[np.where(p34[:, 0] == np.max(p34[:, 0]))], + p34[np.where(p34[:, 0] == np.min(p34[:, 0]))], + ) + else: # 只有两点纵坐标相等,或者是没有相等的,先左右再上下 + p14, p23 = ( + box[np.where(box[:, 0] < center_x)], + box[np.where(box[:, 0] > center_x)], + ) + p1, p4 = ( + p14[np.where(p14[:, 1] == np.min(p14[:, 1]))], + p14[np.where(p14[:, 1] == np.max(p14[:, 1]))], + ) + p2, p3 = ( + p23[np.where(p23[:, 1] == np.min(p23[:, 1]))], + p23[np.where(p23[:, 1] == np.max(p23[:, 1]))], + ) + + return np.array([p1, p2, p3, p4]).reshape((-1, 2)).tolist() diff --git a/python/rapidocr_onnxruntime/ch_ppocr_rec/text_recognize.py b/python/rapidocr_onnxruntime/ch_ppocr_rec/text_recognize.py index c6b3e3af7..e823ea655 100644 --- a/python/rapidocr_onnxruntime/ch_ppocr_rec/text_recognize.py +++ b/python/rapidocr_onnxruntime/ch_ppocr_rec/text_recognize.py @@ -18,6 +18,7 @@ import cv2 import numpy as np + from rapidocr_onnxruntime.utils import OrtInferSession, read_yaml from .utils import CTCLabelDecode @@ -40,7 +41,9 @@ def __init__(self, config: Dict[str, Any]): self.rec_image_shape = config["rec_img_shape"] def __call__( - self, img_list: Union[np.ndarray, List[np.ndarray]], rec_word_box=False + self, + img_list: Union[np.ndarray, List[np.ndarray]], + return_word_box: bool = False, ) -> Tuple[List[Tuple[str, float]], float]: if isinstance(img_list, np.ndarray): img_list = [img_list] @@ -58,6 +61,7 @@ def __call__( elapse = 0 for beg_img_no in range(0, img_num, batch_num): end_img_no = min(img_num, beg_img_no + batch_num) + # Parameter Alignment for PaddleOCR imgC, imgH, imgW = self.rec_image_shape[:3] max_wh_ratio = imgW / imgH @@ -67,6 +71,7 @@ def __call__( wh_ratio = w * 1.0 / h max_wh_ratio = max(max_wh_ratio, wh_ratio) wh_ratio_list.append(wh_ratio) + norm_img_batch = [] for ino in range(beg_img_no, end_img_no): norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio) @@ -75,9 +80,12 @@ def __call__( starttime = time.time() preds = self.session(norm_img_batch)[0] - rec_result = self.postprocess_op(preds, rec_word_box, - wh_ratio_list=wh_ratio_list, - max_wh_ratio=max_wh_ratio,) + rec_result = self.postprocess_op( + preds, + return_word_box, + wh_ratio_list=wh_ratio_list, + max_wh_ratio=max_wh_ratio, + ) for rno, one_res in enumerate(rec_result): rec_res[indices[beg_img_no + rno]] = one_res diff --git a/python/rapidocr_onnxruntime/ch_ppocr_rec/utils.py b/python/rapidocr_onnxruntime/ch_ppocr_rec/utils.py index 00e3ae0be..83b89518d 100644 --- a/python/rapidocr_onnxruntime/ch_ppocr_rec/utils.py +++ b/python/rapidocr_onnxruntime/ch_ppocr_rec/utils.py @@ -16,11 +16,15 @@ def __init__( self.character = self.get_character(character, character_path) self.dict = {char: i for i, char in enumerate(self.character)} - def __call__(self, preds: np.ndarray, rec_word_box: bool = False, **kwargs) -> List[Tuple[str, float]]: + def __call__( + self, preds: np.ndarray, return_word_box: bool = False, **kwargs + ) -> List[Tuple[str, float]]: preds_idx = preds.argmax(axis=2) preds_prob = preds.max(axis=2) - text = self.decode(preds_idx, preds_prob, rec_word_box, is_remove_duplicate=True) - if rec_word_box: + text = self.decode( + preds_idx, preds_prob, return_word_box, is_remove_duplicate=True + ) + if return_word_box: for rec_idx, rec in enumerate(text): wh_ratio = kwargs["wh_ratio_list"][rec_idx] max_wh_ratio = kwargs["max_wh_ratio"] @@ -69,11 +73,11 @@ def insert_special_char( return character_list def decode( - self, - text_index: np.ndarray, - text_prob: Optional[np.ndarray] = None, - rec_word_box: bool = False, - is_remove_duplicate: bool = False, + self, + text_index: np.ndarray, + text_prob: Optional[np.ndarray] = None, + return_word_box: bool = False, + is_remove_duplicate: bool = False, ) -> List[Tuple[str, float]]: """convert text-index into text-label.""" result_list = [] @@ -83,20 +87,23 @@ def decode( selection = np.ones(len(text_index[batch_idx]), dtype=bool) if is_remove_duplicate: selection[1:] = text_index[batch_idx][1:] != text_index[batch_idx][:-1] + for ignored_token in ignored_tokens: selection &= text_index[batch_idx] != ignored_token - char_list = [ - self.character[text_id] for text_id in text_index[batch_idx][selection] - ] + if text_prob is not None: conf_list = text_prob[batch_idx][selection] else: conf_list = [1] * len(selection) + if len(conf_list) == 0: conf_list = [0] + char_list = [ + self.character[text_id] for text_id in text_index[batch_idx][selection] + ] text = "".join(char_list) - if rec_word_box: + if return_word_box: word_list, word_col_list, state_list = self.get_word_info( text, selection ) @@ -117,10 +124,12 @@ def decode( return result_list @staticmethod - def get_word_info(text: str, - selection: np.ndarray) -> Tuple[List[List[str]], List[List[int]], List[str]]: + def get_word_info( + text: str, selection: np.ndarray + ) -> Tuple[List[List[str]], List[List[int]], List[str]]: """ Group the decoded characters and record the corresponding decoded positions. + from https://github.com/PaddlePaddle/PaddleOCR/blob/fbba2178d7093f1dffca65a5b963ec277f1a6125/ppocr/postprocess/rec_postprocess.py#L70 Args: text: the decoded text @@ -145,6 +154,7 @@ def get_word_info(text: str, c_state = "cn" else: c_state = "en&num" + if state == None: state = c_state @@ -156,8 +166,10 @@ def get_word_info(text: str, word_content = [] word_col_content = [] state = c_state + word_content.append(char) word_col_content.append(int(valid_col[c_i])) + if len(word_content) != 0: word_list.append(word_content) word_col_list.append(word_col_content) diff --git a/python/rapidocr_onnxruntime/config.yaml b/python/rapidocr_onnxruntime/config.yaml index f92d3302e..0ee6fbef3 100644 --- a/python/rapidocr_onnxruntime/config.yaml +++ b/python/rapidocr_onnxruntime/config.yaml @@ -8,6 +8,7 @@ Global: width_height_ratio: 8 max_side_len: 2000 min_side_len: 30 + return_word_box: false intra_op_num_threads: &intra_nums -1 inter_op_num_threads: &inter_nums -1 diff --git a/python/rapidocr_onnxruntime/main.py b/python/rapidocr_onnxruntime/main.py index f82f21eea..e897b892f 100644 --- a/python/rapidocr_onnxruntime/main.py +++ b/python/rapidocr_onnxruntime/main.py @@ -8,6 +8,7 @@ import cv2 import numpy as np +from .cal_rec_boxes import CalRecBoxes from .ch_ppocr_cls import TextClassifier from .ch_ppocr_det import TextDetector from .ch_ppocr_rec import TextRecognizer @@ -22,7 +23,6 @@ read_yaml, reduce_max_side, update_model_path, - word_box ) root_dir = Path(__file__).resolve().parent @@ -61,6 +61,8 @@ def __init__(self, config_path: Optional[str] = None, **kwargs): self.max_side_len = global_config["max_side_len"] self.min_side_len = global_config["min_side_len"] + self.cal_rec_boxes = CalRecBoxes() + def __call__( self, img_content: Union[str, np.ndarray, bytes, Path], @@ -72,12 +74,12 @@ def __call__( use_det = self.use_det if use_det is None else use_det use_cls = self.use_cls if use_cls is None else use_cls use_rec = self.use_rec if use_rec is None else use_rec - rec_word_box = False + return_word_box = False if kwargs: box_thresh = kwargs.get("box_thresh", 0.5) unclip_ratio = kwargs.get("unclip_ratio", 1.6) text_score = kwargs.get("text_score", 0.5) - rec_word_box = kwargs.get("rec_word_box", False) + return_word_box = kwargs.get("return_word_box", False) self.text_det.postprocess_op.box_thresh = box_thresh self.text_det.postprocess_op.unclip_ratio = unclip_ratio self.text_score = text_score @@ -104,15 +106,21 @@ def __call__( img, cls_res, cls_elapse = self.text_cls(img) if use_rec: - rec_res, rec_elapse = self.text_rec(img, rec_word_box) - # fix word box by fix rotate and perspective - if dt_boxes is not None and rec_res is not None and rec_word_box: - rec_res = word_box.cal_rec_boxes(dt_boxes, img, rec_res) - for i, rec_res_i in enumerate(rec_res): - if rec_res_i[3]: - rec_res_i[3] = self._get_origin_points(rec_res_i[3], op_record, raw_h, raw_w).astype(np.int32).tolist() + rec_res, rec_elapse = self.text_rec(img, return_word_box) + + if dt_boxes is not None and rec_res is not None and return_word_box: + rec_res = self.cal_rec_boxes(img, dt_boxes, rec_res) + for rec_res_i in rec_res: + if rec_res_i[2]: + rec_res_i[2] = ( + self._get_origin_points(rec_res_i[2], op_record, raw_h, raw_w) + .astype(np.int32) + .tolist() + ) + if dt_boxes is not None and rec_res is not None: dt_boxes = self._get_origin_points(dt_boxes, op_record, raw_h, raw_w) + ocr_res = self.get_final_res( dt_boxes, cls_res, rec_res, det_elapse, cls_elapse, rec_elapse ) @@ -289,9 +297,11 @@ def get_final_res( if not dt_boxes or not rec_res or len(dt_boxes) <= 0: return None, None - ocr_res = [ - [box.tolist(), *res] for box, res in zip(dt_boxes, rec_res) - ], [det_elapse, cls_elapse, rec_elapse] + ocr_res = [[box.tolist(), *res] for box, res in zip(dt_boxes, rec_res)], [ + det_elapse, + cls_elapse, + rec_elapse, + ] return ocr_res def filter_result( diff --git a/python/rapidocr_onnxruntime/utils/parse_parameters.py b/python/rapidocr_onnxruntime/utils/parse_parameters.py index 527a5c983..4d0fa9092 100644 --- a/python/rapidocr_onnxruntime/utils/parse_parameters.py +++ b/python/rapidocr_onnxruntime/utils/parse_parameters.py @@ -37,6 +37,7 @@ def init_args(): global_group.add_argument("--width_height_ratio", type=int, default=8) global_group.add_argument("--max_side_len", type=int, default=2000) global_group.add_argument("--min_side_len", type=int, default=30) + global_group.add_argument("--return_word_box", action="store_true", default=False) global_group.add_argument("--intra_op_num_threads", type=int, default=-1) global_group.add_argument("--inter_op_num_threads", type=int, default=-1) diff --git a/python/rapidocr_onnxruntime/utils/word_box.py b/python/rapidocr_onnxruntime/utils/word_box.py deleted file mode 100644 index 4779a6e98..000000000 --- a/python/rapidocr_onnxruntime/utils/word_box.py +++ /dev/null @@ -1,204 +0,0 @@ -import copy -import math -from typing import Optional, List, Tuple, Union, Any -import numpy as np -import cv2 - - -def cal_rec_boxes( - dt_boxes: Optional[List[np.ndarray]], - crop_imgs: Optional[List[np.ndarray]], - rec_res: Optional[List[Any]]): - res = [] - for i, (box, rec_res) in enumerate(zip(dt_boxes, rec_res)): - direction = "w" - img_crop_width = int( - max( - np.linalg.norm(box[0] - box[1]), - np.linalg.norm(box[2] - box[3]), - ) - ) - img_crop_height = int( - max( - np.linalg.norm(box[0] - box[3]), - np.linalg.norm(box[1] - box[2]), - ) - ) - if img_crop_height * 1.0 / img_crop_width >= 1.5: - direction = "h" - - rec_str, rec_conf, rec_word_info = rec_res[0], rec_res[1], rec_res[2] - crop_img = crop_imgs[i] - h, w = crop_img.shape[:2] - crop_img_box = np.array([[0, 0], [w, 0], [w, h], [0, h]]) - word_box_content_list, word_box_list = cal_ocr_word_box( - rec_str, crop_img_box, rec_word_info - ) - # fix word box overlap - adjust_box_overlap(word_box_list) - word_box_list = reverse_rotate_crop_image(copy.deepcopy(box), word_box_list, direction) - - res.append([rec_res[0], rec_res[1], word_box_content_list, word_box_list]) - return res - - -def adjust_box_overlap(word_box_list): - # 调整bbox有重叠的地方 - for i in range(len(word_box_list) - 1): - cur, nxt = word_box_list[i], word_box_list[i + 1] - if cur[1][0] > nxt[0][0]: # 有交集 - distance = abs(cur[1][0] - nxt[0][0]) - cur[1][0] -= distance / 2 - cur[2][0] -= distance / 2 - nxt[0][0] += distance / 2 - nxt[3][0] += distance / 2 - - -# 绕pointx,pointy顺时针旋转 -def s_rotate(angle, valuex, valuey, pointx, pointy): - valuex = np.array(valuex) - valuey = np.array(valuey) - sRotatex = (valuex - pointx) * math.cos(angle) + (valuey - pointy) * math.sin(angle) + pointx - sRotatey = (valuey - pointy) * math.cos(angle) - (valuex - pointx) * math.sin(angle) + pointy - return [sRotatex, sRotatey] - - -# ———————————————— -# 版权声明:本文为CSDN博主「星夜孤帆」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。 -# 原文链接:https://blog.csdn.net/qq_38826019/article/details/84233397 - -def reverse_rotate_crop_image(bbox_points, word_points_list, direction='w'): - """ - get_rotate_crop_image的逆操作 - img为原图 - part_img为crop后的图 - bbox_points为part_img中对应在原图的bbox, 四个点,左上,右上,右下,左下 - part_points为在part_img中的点[(x, y), (x, y)] - """ - # np.rot - bbox_points = np.float32(bbox_points) - left = int(np.min(bbox_points[:, 0])) - top = int(np.min(bbox_points[:, 1])) - bbox_points[:, 0] = bbox_points[:, 0] - left - bbox_points[:, 1] = bbox_points[:, 1] - top - img_crop_width = int(np.linalg.norm(bbox_points[0] - bbox_points[1])) - img_crop_height = int(np.linalg.norm(bbox_points[0] - bbox_points[3])) - pts_std = np.array( - [ - [0, 0], - [img_crop_width, 0], - [img_crop_width, img_crop_height], - [0, img_crop_height], - ] - ).astype(np.float32) - M = cv2.getPerspectiveTransform(bbox_points, pts_std) - _, IM = cv2.invert(M) - - new_word_points_list = [] - - for word_points in word_points_list: - new_word_points = [] - for point in word_points: - new_point = point - if direction == 'h': - new_point = s_rotate(math.radians(-90), new_point[0], new_point[1], 0, 0) - new_point[0] = new_point[0] + img_crop_width - - p = np.float32(new_point + [1]) - x, y, z = np.dot(IM, p) - new_point = [x / z, y / z] - - new_point = [int(new_point[0] + left), int(new_point[1] + top)] - new_word_points.append(new_point) - new_word_points = order_points(new_word_points) - new_word_points_list.append(new_word_points) - return new_word_points_list - - -def cal_ocr_word_box(rec_str: str, - box: np.ndarray, - rec_word_info: List[Tuple[str, List[int]]]): - """Calculate the detection frame for each word based on the results of recognition and detection of ocr""" - - col_num, word_list, word_col_list, state_list = rec_word_info - box = box.tolist() - bbox_x_start = box[0][0] - bbox_x_end = box[1][0] - bbox_y_start = box[0][1] - bbox_y_end = box[2][1] - - cell_width = (bbox_x_end - bbox_x_start) / col_num - word_box_list = [] - word_box_content_list = [] - cn_width_list = [] - cn_col_list = [] - for word, word_col, state in zip(word_list, word_col_list, state_list): - if state == "cn": - if len(word_col) != 1: - char_seq_length = (word_col[-1] - word_col[0] + 1) * cell_width - char_width = char_seq_length / (len(word_col) - 1) - cn_width_list.append(char_width) - cn_col_list += word_col - word_box_content_list += word - else: - cell_x_start = bbox_x_start + int(word_col[0] * cell_width) - cell_x_end = bbox_x_start + int((word_col[-1] + 1) * cell_width) - cell = [ - [cell_x_start, bbox_y_start], - [cell_x_end, bbox_y_start], - [cell_x_end, bbox_y_end], - [cell_x_start, bbox_y_end], - ] - word_box_list.append(cell) - word_box_content_list.append("".join(word)) - if len(cn_col_list) != 0: - if len(cn_width_list) != 0: - avg_char_width = np.mean(cn_width_list) - else: - avg_char_width = (bbox_x_end - bbox_x_start) / len(rec_str) - for center_idx in cn_col_list: - center_x = (center_idx + 0.5) * cell_width - cell_x_start = max(int(center_x - avg_char_width / 2), 0) + bbox_x_start - cell_x_end = ( - min(int(center_x + avg_char_width / 2), bbox_x_end - bbox_x_start) - + bbox_x_start - ) - cell = [ - [cell_x_start, bbox_y_start], - [cell_x_end, bbox_y_start], - [cell_x_end, bbox_y_end], - [cell_x_start, bbox_y_end], - ] - word_box_list.append(cell) - sorted_word_box_list = sorted(word_box_list, key=lambda box: box[0][0]) - return word_box_content_list, sorted_word_box_list - - -def order_points(box): - '''矩形框顺序排列 - :param box: numpy.array, shape=(4, 2) - :return: - ''' - box = np.array(box).reshape((-1, 2)) - center_x, center_y = np.mean(box[:, 0]), np.mean(box[:, 1]) - if np.any(box[:, 0] == center_x) and np.any(box[:, 1] == center_y): # 有两点横坐标相等,有两点纵坐标相等,菱形 - p1 = box[np.where(box[:, 0] == np.min(box[:, 0]))] - p2 = box[np.where(box[:, 1] == np.min(box[:, 1]))] - p3 = box[np.where(box[:, 0] == np.max(box[:, 0]))] - p4 = box[np.where(box[:, 1] == np.max(box[:, 1]))] - elif np.all(box[:, 0] == center_x): # 四个点的横坐标都相同 - y_sort = np.argsort(box[:, 1]) - p1 = box[y_sort[0]] - p2 = box[y_sort[1]] - p3 = box[y_sort[2]] - p4 = box[y_sort[3]] - elif np.any(box[:, 0] == center_x) and np.all(box[:, 1] != center_y): # 只有两点横坐标相等,先上下再左右 - p12, p34 = box[np.where(box[:, 1] < center_y)], box[np.where(box[:, 1] > center_y)] - p1, p2 = p12[np.where(p12[:, 0] == np.min(p12[:, 0]))], p12[np.where(p12[:, 0] == np.max(p12[:, 0]))] - p3, p4 = p34[np.where(p34[:, 0] == np.max(p34[:, 0]))], p34[np.where(p34[:, 0] == np.min(p34[:, 0]))] - else: # 只有两点纵坐标相等,或者是没有相等的,先左右再上下 - p14, p23 = box[np.where(box[:, 0] < center_x)], box[np.where(box[:, 0] > center_x)] - p1, p4 = p14[np.where(p14[:, 1] == np.min(p14[:, 1]))], p14[np.where(p14[:, 1] == np.max(p14[:, 1]))] - p2, p3 = p23[np.where(p23[:, 1] == np.min(p23[:, 1]))], p23[np.where(p23[:, 1] == np.max(p23[:, 1]))] - - return np.array([p1, p2, p3, p4]).reshape((-1, 2)).tolist() diff --git a/python/tests/base_module.py b/python/tests/base_module.py index 1a5153f83..e53849415 100644 --- a/python/tests/base_module.py +++ b/python/tests/base_module.py @@ -3,6 +3,7 @@ # @Contact: liekkaskono@163.com import importlib import sys +from dataclasses import dataclass from pathlib import Path from typing import Optional, Union @@ -53,3 +54,10 @@ def download_file(url: str, save_path: Union[str, Path]): class DownloadModelError(Exception): pass + + +@dataclass +class Platform: + mac: str = "Darwin" + windows: str = "Windows" + linux: str = "Linux" diff --git a/python/tests/test_ort.py b/python/tests/test_ort.py index 8bbf7ed9e..637b7eea8 100644 --- a/python/tests/test_ort.py +++ b/python/tests/test_ort.py @@ -2,6 +2,7 @@ # @Author: SWHL # @Contact: liekkaskono@163.com import logging +import platform import sys from pathlib import Path from typing import List @@ -14,12 +15,14 @@ sys.path.append(str(root_dir)) from rapidocr_onnxruntime import LoadImageError, RapidOCR -from tests.base_module import download_file + +from .base_module import Platform, download_file engine = RapidOCR() tests_dir = root_dir / "tests" / "test_files" img_path = tests_dir / "ch_en_num.jpg" package_name = "rapidocr_onnxruntime" +cur_platform = platform.system() def test_long_img(): @@ -28,7 +31,11 @@ def test_long_img(): download_file(img_url, save_path=img_path) result, _ = engine(img_path) assert result is not None - assert len(result) == 55 + if cur_platform == Platform.mac: + assert len(result) == 53 + elif cur_platform == Platform.linux: + assert len(result) == 55 + img_path.unlink() @@ -96,7 +103,6 @@ def test_only_det(): result, _ = engine(img_path, use_det=True, use_cls=False, use_rec=False) assert len(result) == 18 - assert result[0][0] == [5.0, 2.0] def test_only_cls(): @@ -220,25 +226,28 @@ def test_input_three_ndim_one_channel(): result, _ = engine(img) - assert result is not None - assert result[0][1] == "正品促销" - assert len(result) == 17 + if cur_platform == Platform.mac: + assert len(result) == 17 + else: + assert result is not None + assert result[0][1] == "正品促销" + assert len(result) == 17 + -# @pytest.mark.parametrize( "img_name,words", [ ( "black_font_color_transparent.png", - ['我', '是', '中', '国', '人'], + ["我", "是", "中", "国", "人"], ), ( "text_vertical_words.png", - ['已', '取', '之', '時', '不', '參', '一', '人', '見', '而'], + ["已", "取", "之", "時", "不", "參", "一", "人", "見", "而"], ), ], ) def test_word_ocr(img_name: str, words: List[str]): img_path = tests_dir / img_name - result, _ = engine(img_path, rec_word_box=True) - assert result[0][3] == words \ No newline at end of file + result, _ = engine(img_path, return_word_box=True) + assert result[0][4] == words