Skip to content

Commit

Permalink
feat: char rec(单字识别) (#254)
Browse files Browse the repository at this point in the history
* feat: add word split for ocr

* chore: rm unnecessary change

* test: fix test error
  • Loading branch information
Joker1212 authored Nov 13, 2024
1 parent 2be4fef commit 2ac3c35
Show file tree
Hide file tree
Showing 6 changed files with 341 additions and 38 deletions.
13 changes: 9 additions & 4 deletions python/rapidocr_onnxruntime/ch_ppocr_rec/text_recognize.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(self, config: Dict[str, Any]):
self.rec_image_shape = config["rec_img_shape"]

def __call__(
self, img_list: Union[np.ndarray, List[np.ndarray]]
self, img_list: Union[np.ndarray, List[np.ndarray]], rec_word_box=False
) -> Tuple[List[Tuple[str, float]], float]:
if isinstance(img_list, np.ndarray):
img_list = [img_list]
Expand All @@ -58,12 +58,15 @@ def __call__(
elapse = 0
for beg_img_no in range(0, img_num, batch_num):
end_img_no = min(img_num, beg_img_no + batch_num)
max_wh_ratio = 0
# Parameter Alignment for PaddleOCR
imgC, imgH, imgW = self.rec_image_shape[:3]
max_wh_ratio = imgW / imgH
wh_ratio_list = []
for ino in range(beg_img_no, end_img_no):
h, w = img_list[indices[ino]].shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)

wh_ratio_list.append(wh_ratio)
norm_img_batch = []
for ino in range(beg_img_no, end_img_no):
norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio)
Expand All @@ -72,7 +75,9 @@ def __call__(

starttime = time.time()
preds = self.session(norm_img_batch)[0]
rec_result = self.postprocess_op(preds)
rec_result = self.postprocess_op(preds, rec_word_box,
wh_ratio_list=wh_ratio_list,
max_wh_ratio=max_wh_ratio,)

for rno, one_res in enumerate(rec_result):
rec_res[indices[beg_img_no + rno]] = one_res
Expand Down
117 changes: 93 additions & 24 deletions python/rapidocr_onnxruntime/ch_ppocr_rec/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,15 @@ def __init__(
self.character = self.get_character(character, character_path)
self.dict = {char: i for i, char in enumerate(self.character)}

def __call__(
    self, preds: np.ndarray, rec_word_box: bool = False, **kwargs
) -> List[Tuple[str, float]]:
    """Decode raw recognition output into text results.

    Args:
        preds: model output; the argmax / max over axis 2 give the
            predicted character index and its probability per time step.
        rec_word_box: when True, ``decode`` appends a third, list-valued
            element with word-grouping information to every result, and
            the first entry of that list is rescaled here.
        **kwargs: only read when ``rec_word_box`` is True. Must then
            contain ``wh_ratio_list`` (one width/height ratio per item
            in the batch) and ``max_wh_ratio`` (the ratio the batch was
            resized with).

    Returns:
        One entry per batch item: ``(text, score)``, or
        ``(text, score, word_info)`` when ``rec_word_box`` is True.

    Raises:
        KeyError: if ``rec_word_box`` is True, the batch is non-empty and
            one of the required keyword arguments is missing.
    """
    preds_idx = preds.argmax(axis=2)
    preds_prob = preds.max(axis=2)
    text = self.decode(
        preds_idx, preds_prob, rec_word_box, is_remove_duplicate=True
    )
    if rec_word_box and text:
        # Both values are constant for the whole batch; read them once.
        wh_ratio_list = kwargs["wh_ratio_list"]
        max_wh_ratio = kwargs["max_wh_ratio"]
        for rec_idx, rec in enumerate(text):
            # rec[2][0] is the number of decoded time steps. Scale it by
            # this item's share of the batch-wide ratio.
            # NOTE(review): presumably this discounts the columns that are
            # padding added when resizing to max_wh_ratio -- confirm.
            rec[2][0] = rec[2][0] * (wh_ratio_list[rec_idx] / max_wh_ratio)
    return text

def get_character(
Expand Down Expand Up @@ -64,38 +69,102 @@ def insert_special_char(
return character_list

def decode(
    self,
    text_index: np.ndarray,
    text_prob: Optional[np.ndarray] = None,
    rec_word_box: bool = False,
    is_remove_duplicate: bool = False,
) -> List[Tuple[str, float]]:
    """Convert per-time-step character indices into text labels.

    Args:
        text_index: 2-D array, one row of predicted character indices
            per batch item.
        text_prob: optional array of the same layout holding the
            probability of each predicted index. When None, every kept
            character gets confidence 1.
        rec_word_box: when True, each result carries a third element
            ``[num_time_steps, word_list, word_col_list, state_list]``
            built from ``get_word_info``.
        is_remove_duplicate: drop a time step whose index equals the
            previous time step's index (CTC-style collapse).

    Returns:
        One tuple per batch item: ``(text, score)`` or, with
        ``rec_word_box``, ``(text, score, word_info)``. ``score`` is the
        mean confidence of the kept characters, or 0 when nothing was
        kept.
    """
    result_list = []
    ignored_tokens = self.get_ignored_tokens()
    for batch_idx in range(len(text_index)):
        token_ids = text_index[batch_idx]

        # Boolean mask over time steps: True where a character is emitted.
        selection = np.ones(len(token_ids), dtype=bool)
        if is_remove_duplicate:
            selection[1:] = token_ids[1:] != token_ids[:-1]
        for ignored_token in ignored_tokens:
            selection &= token_ids != ignored_token

        char_list = [
            self.character[int(text_id)] for text_id in token_ids[selection]
        ]
        if text_prob is not None:
            conf_list = text_prob[batch_idx][selection]
        else:
            # One confidence per *kept* character, so that an empty
            # decode falls through to the zero score below.
            conf_list = [1] * len(char_list)
        if len(conf_list) == 0:
            conf_list = [0]

        text = "".join(char_list)
        score = np.mean(conf_list).tolist()
        if rec_word_box:
            word_list, word_col_list, state_list = self.get_word_info(
                text, selection
            )
            result_list.append(
                (
                    text,
                    score,
                    [len(token_ids), word_list, word_col_list, state_list],
                )
            )
        else:
            result_list.append((text, score))
    return result_list

@staticmethod
def get_word_info(
    text: str, selection: np.ndarray
) -> Tuple[List[List[str]], List[List[int]], List[str]]:
    """Group decoded characters into runs and record their positions.

    Consecutive characters of the same kind form one group. A character
    in the range U+4E00..U+9FFF is kind ``'cn'``; every other character
    (letters, digits, punctuation, spaces, ...) is kind ``'en&num'``.

    Args:
        text: the decoded text.
        selection: boolean array marking which time steps (feature
            columns) were decoded into a character; its i-th True entry
            corresponds to ``text[i]``.

    Returns:
        word_list: the groups, each a list of single characters.
        word_col_list: for each group, the time-step index of every
            character in it.
        state_list: the kind (``'cn'`` or ``'en&num'``) of each group.
    """
    # Time-step index of every emitted character, in order.
    valid_col = np.where(selection)[0]

    word_list, word_col_list, state_list = [], [], []
    word_content, word_col_content = [], []
    state = None

    for c_i, char in enumerate(text):
        c_state = "cn" if "\u4e00" <= char <= "\u9fff" else "en&num"
        if state is None:
            state = c_state

        if state != c_state:
            # Kind changed: close the current group and start a new one.
            if word_content:
                word_list.append(word_content)
                word_col_list.append(word_col_content)
                state_list.append(state)
            word_content, word_col_content = [], []
            state = c_state

        word_content.append(char)
        word_col_content.append(int(valid_col[c_i]))

    # Flush the trailing group.
    if word_content:
        word_list.append(word_content)
        word_col_list.append(word_col_content)
        state_list.append(state)

    return word_list, word_col_list, state_list

@staticmethod
def get_ignored_tokens() -> List[int]:
    """Return the character indices that decoding must skip."""
    return [0]  # for ctc blank
23 changes: 14 additions & 9 deletions python/rapidocr_onnxruntime/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
read_yaml,
reduce_max_side,
update_model_path,
word_box
)

root_dir = Path(__file__).resolve().parent
Expand Down Expand Up @@ -71,12 +72,12 @@ def __call__(
use_det = self.use_det if use_det is None else use_det
use_cls = self.use_cls if use_cls is None else use_cls
use_rec = self.use_rec if use_rec is None else use_rec

rec_word_box = False
if kwargs:
box_thresh = kwargs.get("box_thresh", 0.5)
unclip_ratio = kwargs.get("unclip_ratio", 1.6)
text_score = kwargs.get("text_score", 0.5)

rec_word_box = kwargs.get("rec_word_box", False)
self.text_det.postprocess_op.box_thresh = box_thresh
self.text_det.postprocess_op.unclip_ratio = unclip_ratio
self.text_score = text_score
Expand All @@ -103,11 +104,15 @@ def __call__(
img, cls_res, cls_elapse = self.text_cls(img)

if use_rec:
rec_res, rec_elapse = self.text_rec(img)

rec_res, rec_elapse = self.text_rec(img, rec_word_box)
# fix word box by fix rotate and perspective
if dt_boxes is not None and rec_res is not None and rec_word_box:
rec_res = word_box.cal_rec_boxes(dt_boxes, img, rec_res)
for i, rec_res_i in enumerate(rec_res):
if rec_res_i[3]:
rec_res_i[3] = self._get_origin_points(rec_res_i[3], op_record, raw_h, raw_w).astype(np.int32).tolist()
if dt_boxes is not None and rec_res is not None:
dt_boxes = self._get_origin_points(dt_boxes, op_record, raw_h, raw_w)

ocr_res = self.get_final_res(
dt_boxes, cls_res, rec_res, det_elapse, cls_elapse, rec_elapse
)
Expand Down Expand Up @@ -237,7 +242,7 @@ def _get_origin_points(
raw_h: int,
raw_w: int,
) -> np.ndarray:
dt_boxes_array = np.array(dt_boxes)
dt_boxes_array = np.array(dt_boxes).astype(np.float32)
for op in reversed(list(op_record.keys())):
v = op_record[op]
if "padding" in op:
Expand All @@ -263,7 +268,7 @@ def get_final_res(
self,
dt_boxes: Optional[List[np.ndarray]],
cls_res: Optional[List[List[Union[str, float]]]],
rec_res: Optional[List[Tuple[str, float]]],
rec_res: Optional[List[Tuple[str, float, List[Union[str, float]]]]],
det_elapse: float,
cls_elapse: float,
rec_elapse: float,
Expand All @@ -285,7 +290,7 @@ def get_final_res(
return None, None

ocr_res = [
[box.tolist(), res[0], res[1]] for box, res in zip(dt_boxes, rec_res)
[box.tolist(), *res] for box, res in zip(dt_boxes, rec_res)
], [det_elapse, cls_elapse, rec_elapse]
return ocr_res

Expand All @@ -299,7 +304,7 @@ def filter_result(

filter_boxes, filter_rec_res = [], []
for box, rec_reuslt in zip(dt_boxes, rec_res):
text, score = rec_reuslt
text, score = rec_reuslt[0], rec_reuslt[1]
if float(score) >= self.text_score:
filter_boxes.append(box)
filter_rec_res.append(rec_reuslt)
Expand Down
Loading

0 comments on commit 2ac3c35

Please sign in to comment.