Skip to content

Commit

Permalink
Update transducer.py (#2337)
Browse files Browse the repository at this point in the history
fix 'th_accuracy' not in transducer
  • Loading branch information
DaobinZhu authored Feb 2, 2024
1 parent 220caf6 commit f605684
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions wenet/transducer/transducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,10 @@ def forward(
# optional attention decoder
loss_att: Optional[torch.Tensor] = None
if self.attention_decoder_weight != 0.0 and self.decoder is not None:
loss_att, _ = self._calc_att_loss(encoder_out, encoder_mask, text,
text_lengths)
loss_att, acc_att = self._calc_att_loss(encoder_out, encoder_mask,
text, text_lengths)
else:
acc_att = None

# optional ctc
loss_ctc: Optional[torch.Tensor] = None
Expand All @@ -145,6 +147,7 @@ def forward(
'loss_att': loss_att,
'loss_ctc': loss_ctc,
'loss_rnnt': loss_rnnt,
'th_accuracy': acc_att,
}

def init_bs(self):
Expand Down

0 comments on commit f605684

Please sign in to comment.