diff --git a/whisper/decoding.py b/whisper/decoding.py index ed8d900a156..bb70cc024bd 100644 --- a/whisper/decoding.py +++ b/whisper/decoding.py @@ -423,10 +423,14 @@ def apply(self, logits: Tensor, tokens: Tensor): else: # cannot be normal text tokens logits[k, : self.tokenizer.eot] = -np.inf - # apply the `max_initial_timestamp` option - if tokens.shape[1] == self.sample_begin and self.max_initial_timestamp_index is not None: - last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index - logits[:, last_allowed + 1 :] = -np.inf + if tokens.shape[1] == self.sample_begin: + # suppress generating non-timestamp tokens at the beginning + logits[:, : self.tokenizer.timestamp_begin] = -np.inf + + # apply the `max_initial_timestamp` option + if self.max_initial_timestamp_index is not None: + last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index + logits[:, last_allowed + 1 :] = -np.inf # if sum of probability over timestamps is above any other token, sample timestamp logprobs = F.log_softmax(logits.float(), dim=-1)