From f605684e870ff877cbe68bda44ffc3e955b7a0b9 Mon Sep 17 00:00:00 2001 From: Daobin Zhu <37605608+DaobinZhu@users.noreply.github.com> Date: Fri, 2 Feb 2024 20:21:37 +0800 Subject: [PATCH] Update transducer.py (#2337) fix 'th_accuracy' not in transducer --- wenet/transducer/transducer.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/wenet/transducer/transducer.py b/wenet/transducer/transducer.py index 746150141..224077337 100644 --- a/wenet/transducer/transducer.py +++ b/wenet/transducer/transducer.py @@ -124,8 +124,10 @@ def forward( # optional attention decoder loss_att: Optional[torch.Tensor] = None if self.attention_decoder_weight != 0.0 and self.decoder is not None: - loss_att, _ = self._calc_att_loss(encoder_out, encoder_mask, text, - text_lengths) + loss_att, acc_att = self._calc_att_loss(encoder_out, encoder_mask, + text, text_lengths) + else: + acc_att = None # optional ctc loss_ctc: Optional[torch.Tensor] = None @@ -145,6 +147,7 @@ def forward( 'loss_att': loss_att, 'loss_ctc': loss_ctc, 'loss_rnnt': loss_rnnt, + 'th_accuracy': acc_att, } def init_bs(self):