Skip to content

Commit

Permalink
add prediction for parseq
Browse files Browse the repository at this point in the history
  • Loading branch information
ToddBear committed Sep 6, 2023
1 parent 473cc8b commit c37c989
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 5 deletions.
2 changes: 1 addition & 1 deletion ppocr/modeling/heads/rec_parseq_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def forward_test(self, memory, max_length=None):
tgt_in[:, (0)] = self.bos_id

logits = []
for i in range(num_steps):
for i in range(paddle.to_tensor(num_steps)):
j = i + 1
tgt_out = self.decode(tgt_in[:, :j], memory, tgt_mask[:j, :j], tgt_query=pos_queries[:, i:j], tgt_query_mask=query_mask[i:j, :j])
p_i = self.head(tgt_out)
Expand Down
5 changes: 4 additions & 1 deletion ppocr/postprocess/rec_postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,10 @@ def __init__(self, character_dict_path=None, use_space_char=False,
self.max_text_length = kwargs.get('max_text_length', 25)

def __call__(self, preds, label=None, *args, **kwargs):
pred = preds['predict']
if isinstance(preds, dict):
pred = preds['predict']
else:
pred = preds

char_num = len(self.character_str) + 1 # We don't predict <bos> nor <pad>, with only addition <eos>
if isinstance(pred, paddle.Tensor):
Expand Down
21 changes: 21 additions & 0 deletions tools/infer/predict_rec.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,12 @@ def __init__(self, args):
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
elif self.rec_algorithm == "ParseQ":
postprocess_params = {
'name': 'ParseQLabelDecode',
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors, self.config = \
utility.create_predictor(args, 'rec', logger)
Expand Down Expand Up @@ -348,6 +354,17 @@ def resize_norm_img_svtr(self, img, image_shape):
resized_image /= 0.5
return resized_image

def resize_norm_img_parseq(self, img, image_shape):
    """Resize an image to a fixed size and normalize it for ParseQ.

    Args:
        img: Input image array; it is transposed with axes (2, 0, 1),
            so it must be 3-dimensional (presumably HWC -- confirm
            against the caller).
        image_shape: Three-element sequence unpacked as
            (channels, height, width). Only height and width are used.

    Returns:
        A float32 array with axes reordered by (2, 0, 1), computed as
        (pixel / 255 - 0.5) / 0.5.
    """
    _, target_h, target_w = image_shape
    # cv2.resize takes the destination size as (width, height).
    scaled = cv2.resize(
        img, (target_w, target_h), interpolation=cv2.INTER_CUBIC)
    as_float = scaled.astype('float32')
    normalized = as_float.transpose((2, 0, 1)) / 255
    # Shift and rescale in place: x -> (x - 0.5) / 0.5.
    normalized -= 0.5
    normalized /= 0.5
    return normalized

def resize_norm_img_abinet(self, img, image_shape):

imgC, imgH, imgW = image_shape
Expand Down Expand Up @@ -480,6 +497,10 @@ def __call__(self, img_list):
word_label_list = []
norm_img_mask_batch.append(norm_image_mask)
word_label_list.append(word_label)
elif self.rec_algorithm == "ParseQ":
norm_img = self.resize_norm_img_parseq(img_list[indices[ino]], self.rec_image_shape)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
else:
norm_img = self.resize_norm_img(img_list[indices[ino]],
max_wh_ratio)
Expand Down
4 changes: 1 addition & 3 deletions train.sh
Original file line number Diff line number Diff line change
@@ -1,4 +1,2 @@
# recommended paddle.__version__ == 2.0.0
export CUDA_VISIBLE_DEVICES=0,2,6,7
python3 -m paddle.distributed.launch --log_dir=./debug/ --gpus '0,2,6,7' tools/train.py -c configs/rec/rec_vit_parseq.yml #rec_r31_sar
# python3 -m paddle.distributed.launch --log_dir=./debug/ --gpus '2,6' tools/train.py -c configs/rec/rec_r31_sar.yml
python3 -m paddle.distributed.launch --log_dir=./debug/ --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml

0 comments on commit c37c989

Please sign in to comment.