diff --git a/ppocr/modeling/heads/rec_satrn_head.py b/ppocr/modeling/heads/rec_satrn_head.py index b969c89693..263c3b1bb3 100644 --- a/ppocr/modeling/heads/rec_satrn_head.py +++ b/ppocr/modeling/heads/rec_satrn_head.py @@ -283,13 +283,15 @@ def forward(self, feat, valid_ratios=None): Tensor: A tensor of shape :math:`(N, T, D_m)`. """ if valid_ratios is None: - valid_ratios = [1.0 for _ in range(feat.shape[0])] + bs = paddle.shape(feat)[0] + valid_ratios = paddle.full((bs, 1), 1., dtype=paddle.float32) + feat = self.position_enc(feat) n, c, h, w = feat.shape mask = paddle.zeros((n, h, w)) for i, valid_ratio in enumerate(valid_ratios): - valid_width = min(w, math.ceil(w * valid_ratio)) + valid_width = int(min(w, paddle.ceil(w * valid_ratio))) mask[i, :, :valid_width] = 1 mask = mask.reshape([n, h * w]) @@ -347,7 +349,6 @@ def _get_sinusoid_encoding_table(self, n_position, d_hid): return sinusoid_table.unsqueeze(0) def forward(self, x): - x = x + self.position_table[:, :x.shape[1]].clone().detach() return self.dropout(x) @@ -514,7 +515,6 @@ def forward_train(self, feat, out_enc, targets, valid_ratio): return outputs def forward_test(self, feat, out_enc, valid_ratio): - src_mask = self._get_mask(out_enc, valid_ratio) N = out_enc.shape[0] init_target_seq = paddle.full( @@ -556,13 +556,11 @@ def __init__(self, enc_cfg, dec_cfg, **kwargs): self.decoder = SATRNDecoder(**dec_cfg) def forward(self, feat, targets=None): - if targets is not None: targets, valid_ratio = targets else: targets, valid_ratio = None, None holistic_feat = self.encoder(feat, valid_ratio) # bsz c - final_out = self.decoder(feat, holistic_feat, targets, valid_ratio) return final_out