Skip to content

Commit

Permalink
Make more improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
RussellLuo committed Apr 13, 2024
1 parent a96ba4f commit ebf52cd
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 40 deletions.
4 changes: 2 additions & 2 deletions paddleocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,7 +637,7 @@ def __init__(self, **kwargs):
def ocr(self, img, det=True, rec=True, cls=True, bin=False, inv=False, alpha_color=(255, 255, 255)):
"""
OCR with PaddleOCR
args:
img: image for OCR; supports ndarray, img_path, and list of ndarray
det: whether to use text detection. If False, only recognition is executed. Default is True
Expand Down Expand Up @@ -832,7 +832,7 @@ def main():
outfile = args.output + '/' + img_name + '.txt'
with open(outfile,'w',encoding='utf-8') as f:
f.writelines(lines)

elif args.type == 'structure':
img, flag_gif, flag_pdf = check_and_read(img_path)
if not flag_gif and not flag_pdf:
Expand Down
49 changes: 17 additions & 32 deletions ppstructure/predict_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,13 @@ def __call__(self, img, return_ocr_result_in_table=False, img_idx=0):
# has problems with OCR recognition accuracy.
#
# To enhance the OCR recognition accuracy, we implement a patch fix
# that first detect all text regions by using the text_detector
# and then recognize the texts from the text regions (intersecting
# with the layout regions) by using the text_recognizer.
dt_boxes = []
# that first uses text_system to detect and recognize all text in the image
# and then keeps only the texts whose regions intersect the layout regions.
text_res = None
if self.text_system is not None:
dt_boxes, elapse = self.text_system.text_detector(img)
time_dict['det'] = elapse
text_res, ocr_time_dict = self._predict_text(img)
time_dict['det'] += ocr_time_dict['det']
time_dict['rec'] += ocr_time_dict['rec']

res_list = []
for region in layout_res:
Expand All @@ -152,10 +152,9 @@ def __call__(self, img, return_ocr_result_in_table=False, img_idx=0):
time_dict['det'] += table_time_dict['det']
time_dict['rec'] += table_time_dict['rec']
else:
if self.text_system is not None:
res, ocr_time_dict = self._predict_text(ori_im, roi_img, bbox, dt_boxes)
time_dict['det'] += ocr_time_dict['det']
time_dict['rec'] += ocr_time_dict['rec']
if text_res is not None:
# Keep only the text results whose regions intersect with the current layout bbox.
res = self._filter_text_res(text_res, bbox)

res_list.append({
'type': region['label'].lower(),
Expand All @@ -177,20 +176,8 @@ def __call__(self, img, return_ocr_result_in_table=False, img_idx=0):

return None, None

def _predict_text(self, ori_im, roi_img, bbox, dt_boxes):
x1, y1, x2, y2 = bbox

if self.recovery:
wht_im = np.ones(ori_im.shape, dtype=ori_im.dtype)
wht_im[y1:y2, x1:x2, :] = roi_img
filter_boxes, filter_rec_res, ocr_time_dict = self.text_system(
wht_im)
else:
# Filter the text regions that intersect with the current bbox.
intersecting_dt_boxes = self._filter_boxes(dt_boxes, bbox)
# Recognize texts from these intersecting text regions.
filter_boxes, filter_rec_res, ocr_time_dict = self.text_system(
ori_im, dt_boxes=intersecting_dt_boxes)
def _predict_text(self, img):
filter_boxes, filter_rec_res, ocr_time_dict = self.text_system(img)

# remove style char,
# when using the recognition model trained on the PubtabNet dataset,
Expand Down Expand Up @@ -226,16 +213,14 @@ def _predict_text(self, ori_im, roi_img, bbox, dt_boxes):
})
return res, ocr_time_dict

def _filter_boxes(self, dt_boxes, bbox):
boxes = []

for idx in range(len(dt_boxes)):
box = dt_boxes[idx]
def _filter_text_res(self, text_res, bbox):
res = []
for r in text_res:
box = r['text_region']
rect = box[0][0], box[0][1], box[2][0], box[2][1]
if self._has_intersection(bbox, rect):
boxes.append(box.tolist())

return np.array(boxes, np.float32).reshape((len(boxes), 4, 2))
res.append(r)
return res

def _has_intersection(self, rect1, rect2):
x_min1, y_min1, x_max1, y_max1 = rect1
Expand Down
9 changes: 3 additions & 6 deletions tools/infer/predict_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def draw_crop_rec_res(self, output_dir, img_crop_list, rec_res):
logger.debug(f"{bno}, {rec_res[bno]}")
self.crop_image_res_index += bbox_num

def __call__(self, img, cls=True, dt_boxes=None):
def __call__(self, img, cls=True):
time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}

if img is None:
Expand All @@ -73,11 +73,8 @@ def __call__(self, img, cls=True, dt_boxes=None):

start = time.time()
ori_im = img.copy()

elapse = 0
if dt_boxes is None:
dt_boxes, elapse = self.text_detector(img)
time_dict['det'] = elapse
dt_boxes, elapse = self.text_detector(img)
time_dict['det'] = elapse

if dt_boxes is None:
logger.debug("no dt_boxes found, elapsed : {}".format(elapse))
Expand Down

0 comments on commit ebf52cd

Please sign in to comment.