From 6e064cf8575c222b2a92d2f2db48e5acdb502cf3 Mon Sep 17 00:00:00 2001
From: Mddct
Date: Mon, 4 Mar 2024 13:26:52 +0800
Subject: [PATCH] [whisper] fix decoding maxlen

---
 wenet/transformer/search.py | 2 ++
 wenet/whisper/whisper.py    | 1 +
 2 files changed, 3 insertions(+)

diff --git a/wenet/transformer/search.py b/wenet/transformer/search.py
index 958442906..f431c023f 100644
--- a/wenet/transformer/search.py
+++ b/wenet/transformer/search.py
@@ -292,6 +292,8 @@ def attention_beam_search(
     cache: Optional[List[torch.Tensor]] = None
     if model.decoder.use_sdpa:
         encoder_mask = mask_to_bias(encoder_mask, encoder_out.dtype)
+    if hasattr(model, 'decode_maxlen'):
+        maxlen = model.decode_maxlen
     # 2. Decoder forward step by step
     for i in range(prefix_len, maxlen + 1):
         # Stop if all batch and all beam produce eos
diff --git a/wenet/whisper/whisper.py b/wenet/whisper/whisper.py
index eedb00178..cc95e7965 100644
--- a/wenet/whisper/whisper.py
+++ b/wenet/whisper/whisper.py
@@ -46,6 +46,7 @@ def __init__(
         assert reverse_weight == 0.0
         self.sos = special_tokens["sot"]
         self.eos = special_tokens["eot"]
+        self.decode_maxlen = self.decoder.embed[1].max_len
 
     # TODO(xcsong): time align
     def set_alignment_heads(self, dump: bytes):
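
For context, here is a minimal sketch of the capping behavior the two hunks add, under the assumption (not shown in the diff) that attention_beam_search otherwise derives maxlen from the encoder output length; the DummyWhisper class and the 1500-frame figure are illustrative only, not wenet code:

```python
# Illustrative sketch: Whisper's decoder uses learned positional embeddings
# with a fixed table size (decoder.embed[1].max_len, 448 in the released
# checkpoints), so decoding steps beyond that would index past the table.

class DummyWhisper:
    """Hypothetical stand-in for the Whisper model after the whisper.py change."""
    def __init__(self):
        # mirrors: self.decode_maxlen = self.decoder.embed[1].max_len
        self.decode_maxlen = 448

model = DummyWhisper()

# e.g. encoder_out.size(1) for a 30 s utterance: far more decoding steps
# than the decoder's positional table can index
maxlen = 1500

# the search.py change: defer to the model's own limit when it declares one
if hasattr(model, 'decode_maxlen'):
    maxlen = model.decode_maxlen

assert maxlen == 448
```

Guarding with hasattr keeps other ASR models, which do not define decode_maxlen, on their existing decoding behavior.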