From aead305f2e206b9b331f92da824fcad3671ecc4f Mon Sep 17 00:00:00 2001
From: Jong Wook Kim
Date: Tue, 24 Jan 2023 14:05:57 -0800
Subject: [PATCH] drop python 3.7 support (#889)

---
 whisper/decoding.py   |  19 +--
 whisper/tokenizer.py  |  38 ++----
 whisper/transcribe.py | 279 +++++++++++++++++++++++++++++++++---------
 3 files changed, 241 insertions(+), 95 deletions(-)

diff --git a/whisper/decoding.py b/whisper/decoding.py
index bb70cc024..983c898a3 100644
--- a/whisper/decoding.py
+++ b/whisper/decoding.py
@@ -252,11 +252,10 @@ def __init__(self, temperature: float, eot: int):
         self.eot = eot
 
     def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
-        temperature = self.temperature
-        if temperature == 0:
+        if self.temperature == 0:
             next_tokens = logits.argmax(dim=-1)
         else:
-            next_tokens = Categorical(logits=logits / temperature).sample()
+            next_tokens = Categorical(logits=logits / self.temperature).sample()
 
         logprobs = F.log_softmax(logits.float(), dim=-1)
         current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
@@ -511,10 +510,8 @@ def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
 
     def _get_initial_tokens(self) -> Tuple[int]:
         tokens = list(self.sot_sequence)
-        prefix = self.options.prefix
-        prompt = self.options.prompt
 
-        if prefix:
+        if prefix := self.options.prefix:
             prefix_tokens = (
                 self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix
             )
@@ -523,7 +520,7 @@ def _get_initial_tokens(self) -> Tuple[int]:
                 prefix_tokens = prefix_tokens[-max_prefix_len:]
             tokens = tokens + prefix_tokens
 
-        if prompt:
+        if prompt := self.options.prompt:
             prompt_tokens = (
                 self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt
             )
@@ -698,13 +695,9 @@ def decode(model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOpt
     result: Union[DecodingResult, List[DecodingResult]]
         The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
     """
-    single = mel.ndim == 2
-    if single:
+    if single := mel.ndim == 2:
         mel = mel.unsqueeze(0)
 
     result = DecodingTask(model, options).run(mel)
 
-    if single:
-        result = result[0]
-
-    return result
+    return result[0] if single else result
diff --git a/whisper/tokenizer.py b/whisper/tokenizer.py
index a27cb359e..7b4605f3c 100644
--- a/whisper/tokenizer.py
+++ b/whisper/tokenizer.py
@@ -1,6 +1,6 @@
 import os
 from dataclasses import dataclass
-from functools import lru_cache
+from functools import lru_cache, cached_property
 from typing import List, Optional, Tuple, Union
 
 import numpy as np
@@ -156,43 +156,35 @@ def decode_with_timestamps(self, tokens) -> str:
         outputs = [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
         return "".join(outputs)
 
-    @property
-    @lru_cache()
+    @cached_property
     def eot(self) -> int:
         return self.tokenizer.eos_token_id
 
-    @property
-    @lru_cache()
+    @cached_property
     def sot(self) -> int:
         return self._get_single_token_id("<|startoftranscript|>")
 
-    @property
-    @lru_cache()
+    @cached_property
     def sot_lm(self) -> int:
         return self._get_single_token_id("<|startoflm|>")
 
-    @property
-    @lru_cache()
+    @cached_property
     def sot_prev(self) -> int:
         return self._get_single_token_id("<|startofprev|>")
 
-    @property
-    @lru_cache()
+    @cached_property
     def no_speech(self) -> int:
         return self._get_single_token_id("<|nospeech|>")
 
-    @property
-    @lru_cache()
+    @cached_property
     def no_timestamps(self) -> int:
         return self._get_single_token_id("<|notimestamps|>")
 
-    @property
-    @lru_cache()
+    @cached_property
     def timestamp_begin(self) -> int:
         return self.tokenizer.all_special_ids[-1] + 1
 
-    @property
-    @lru_cache()
+    @cached_property
     def language_token(self) -> int:
         """Returns the token id corresponding to the value of the `language` field"""
         if self.language is None:
@@ -210,8 +202,7 @@ def language_token(self) -> int:
 
         raise KeyError(f"Language {self.language} not found in tokenizer.")
 
-    @property
-    @lru_cache()
+    @cached_property
     def all_language_tokens(self) -> Tuple[int]:
         result = []
         for token, token_id in zip(
@@ -222,18 +213,15 @@ def all_language_tokens(self) -> Tuple[int]:
                 result.append(token_id)
         return tuple(result)
 
-    @property
-    @lru_cache()
+    @cached_property
     def all_language_codes(self) -> Tuple[str]:
         return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens)
 
-    @property
-    @lru_cache()
+    @cached_property
     def sot_sequence_including_notimestamps(self) -> Tuple[int]:
         return tuple(list(self.sot_sequence) + [self.no_timestamps])
 
-    @property
-    @lru_cache()
+    @cached_property
     def non_speech_tokens(self) -> Tuple[int]:
         """
         Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
diff --git a/whisper/transcribe.py b/whisper/transcribe.py
index fd52c51a9..bd5261e57 100644
--- a/whisper/transcribe.py
+++ b/whisper/transcribe.py
@@ -10,7 +10,15 @@
 from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram
 from .decoding import DecodingOptions, DecodingResult
 from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
-from .utils import exact_div, format_timestamp, make_safe, optional_int, optional_float, str2bool, get_writer
+from .utils import (
+    exact_div,
+    format_timestamp,
+    make_safe,
+    optional_int,
+    optional_float,
+    str2bool,
+    get_writer,
+)
 
 if TYPE_CHECKING:
     from .model import Whisper
@@ -27,6 +35,7 @@ def transcribe(
     no_speech_threshold: Optional[float] = 0.6,
     condition_on_previous_text: bool = True,
     segment_callback: Optional[callable] = None,
+    initial_prompt: Optional[str] = None,
     **decode_options,
 ):
     """
@@ -89,19 +98,25 @@ def transcribe(
         decode_options["language"] = "en"
     else:
         if verbose:
-            print("Detecting language using up to the first 30 seconds. Use `--language` to specify the language")
+            print(
+                "Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
+            )
         segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
         _, probs = model.detect_language(segment)
         decode_options["language"] = max(probs, key=probs.get)
         if verbose is not None:
-            print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")
+            print(
+                f"Detected language: {LANGUAGES[decode_options['language']].title()}"
+            )
 
     language = decode_options["language"]
     task = decode_options.get("task", "transcribe")
     tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
 
     def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
-        temperatures = [temperature] if isinstance(temperature, (int, float)) else temperature
+        temperatures = (
+            [temperature] if isinstance(temperature, (int, float)) else temperature
+        )
         decode_result = None
 
         for t in temperatures:
@@ -118,9 +133,15 @@ def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
             decode_result = model.decode(segment, options)
 
             needs_fallback = False
-            if compression_ratio_threshold is not None and decode_result.compression_ratio > compression_ratio_threshold:
+            if (
+                compression_ratio_threshold is not None
+                and decode_result.compression_ratio > compression_ratio_threshold
+            ):
                 needs_fallback = True  # too repetitive
-            if logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold:
+            if (
+                logprob_threshold is not None
+                and decode_result.avg_logprob < logprob_threshold
+            ):
                 needs_fallback = True  # average log probability is too low
 
             if not needs_fallback:
@@ -139,15 +160,18 @@ def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
     all_segments = []
     prompt_reset_since = 0
 
-    initial_prompt = decode_options.pop("initial_prompt", None) or []
-    if initial_prompt:
-        initial_prompt = tokenizer.encode(" " + initial_prompt.strip())
-        all_tokens.extend(initial_prompt)
+    if initial_prompt is not None:
+        initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
+        all_tokens.extend(initial_prompt_tokens)
+    else:
+        initial_prompt_tokens = []
 
     def add_segment(
         *, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult
     ):
-        text = tokenizer.decode([token for token in text_tokens if token < tokenizer.eot])
+        text = tokenizer.decode(
+            [token for token in text_tokens if token < tokenizer.eot]
+        )
         if len(text.strip()) == 0:  # skip empty text output
             return
 
@@ -167,13 +191,19 @@ def add_segment(
         )
 
         if verbose:
-            print(make_safe(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"))
+            print(
+                make_safe(
+                    f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
+                )
+            )
 
     # show the progress bar when verbose is False (otherwise the transcribed text will be printed)
     num_frames = mel.shape[-1]
     previous_seek_value = seek
 
-    with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar:
+    with tqdm.tqdm(
+        total=num_frames, unit="frames", disable=verbose is not False
+    ) as pbar:
         while seek < num_frames:
             timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
             segment = pad_or_trim(mel[:, seek:], N_FRAMES).to(model.device).to(dtype)
@@ -186,17 +216,26 @@ def add_segment(
             if no_speech_threshold is not None:
                 # no voice activity check
                 should_skip = result.no_speech_prob > no_speech_threshold
-                if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
+                if (
+                    logprob_threshold is not None
+                    and result.avg_logprob > logprob_threshold
+                ):
                     # don't skip if the logprob is high enough, despite the no_speech_prob
                     should_skip = False
 
                 if should_skip:
-                    seek += segment.shape[-1]  # fast-forward to the next segment boundary
+                    seek += segment.shape[
+                        -1
+                    ]  # fast-forward to the next segment boundary
                     continue
 
             timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
-            consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)
-            if len(consecutive) > 0:  # if the output contains two consecutive timestamp tokens
+            consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[
+                0
+            ].add_(1)
+            if (
+                len(consecutive) > 0
+            ):  # if the output contains two consecutive timestamp tokens
                 last_slice = 0
                 for current_slice in consecutive:
                     sliced_tokens = tokens[last_slice:current_slice]
@@ -207,7 +246,8 @@ def add_segment(
                         sliced_tokens[-1].item() - tokenizer.timestamp_begin
                     )
                     add_segment(
-                        start=timestamp_offset + start_timestamp_position * time_precision,
+                        start=timestamp_offset
+                        + start_timestamp_position * time_precision,
                         end=timestamp_offset + end_timestamp_position * time_precision,
                         text_tokens=sliced_tokens[1:-1],
                         result=result,
@@ -221,10 +261,15 @@ def add_segment(
             else:
                 duration = segment_duration
                 timestamps = tokens[timestamp_tokens.nonzero().flatten()]
-                if len(timestamps) > 0 and timestamps[-1].item() != tokenizer.timestamp_begin:
+                if (
+                    len(timestamps) > 0
+                    and timestamps[-1].item() != tokenizer.timestamp_begin
+                ):
                     # no consecutive timestamps but it has a timestamp; use the last one.
                     # single timestamp at the end means no speech after the last timestamp.
-                    last_timestamp_position = timestamps[-1].item() - tokenizer.timestamp_begin
+                    last_timestamp_position = (
+                        timestamps[-1].item() - tokenizer.timestamp_begin
+                    )
                     duration = last_timestamp_position * time_precision
 
                 add_segment(
@@ -247,40 +292,160 @@ def add_segment(
                 pbar.update(min(num_frames, seek) - previous_seek_value)
                 previous_seek_value = seek
 
-    return dict(text=tokenizer.decode(all_tokens[len(initial_prompt):]), segments=all_segments, language=language)
+    return dict(
+        text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
+        segments=all_segments,
+        language=language,
+    )
 
 
 def cli():
     from . import available_models
 
-    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-    parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
-    parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
-    parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
-    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
-    parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
-    parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
-    parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
-
-    parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
-    parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
-
-    parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
-    parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
-    parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
-    parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
-    parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
-
-    parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
-    parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
-    parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
-    parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
-
-    parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
-    parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
-    parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
-    parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
-    parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+    parser.add_argument(
+        "audio", nargs="+", type=str, help="audio file(s) to transcribe"
+    )
+    parser.add_argument(
+        "--model",
+        default="small",
+        choices=available_models(),
+        help="name of the Whisper model to use",
+    )
+    parser.add_argument(
+        "--model_dir",
+        type=str,
+        default=None,
+        help="the path to save model files; uses ~/.cache/whisper by default",
+    )
+    parser.add_argument(
+        "--device",
+        default="cuda" if torch.cuda.is_available() else "cpu",
+        help="device to use for PyTorch inference",
+    )
+    parser.add_argument(
+        "--output_dir",
+        "-o",
+        type=str,
+        default=".",
+        help="directory to save the outputs",
+    )
+    parser.add_argument(
+        "--output_format",
+        "-f",
+        type=str,
+        default="all",
+        choices=["txt", "vtt", "srt", "tsv", "json", "all"],
+        help="format of the output file; if not specified, all available formats will be produced",
+    )
+    parser.add_argument(
+        "--verbose",
+        type=str2bool,
+        default=True,
+        help="whether to print out the progress and debug messages",
+    )
+
+    parser.add_argument(
+        "--task",
+        type=str,
+        default="transcribe",
+        choices=["transcribe", "translate"],
+        help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')",
+    )
+    parser.add_argument(
+        "--language",
+        type=str,
+        default=None,
+        choices=sorted(LANGUAGES.keys())
+        + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]),
+        help="language spoken in the audio, specify None to perform language detection",
+    )
+
+    parser.add_argument(
+        "--temperature", type=float, default=0, help="temperature to use for sampling"
+    )
+    parser.add_argument(
+        "--best_of",
+        type=optional_int,
+        default=5,
+        help="number of candidates when sampling with non-zero temperature",
+    )
+    parser.add_argument(
+        "--beam_size",
+        type=optional_int,
+        default=5,
+        help="number of beams in beam search, only applicable when temperature is zero",
+    )
+    parser.add_argument(
+        "--patience",
+        type=float,
+        default=None,
+        help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search",
+    )
+    parser.add_argument(
+        "--length_penalty",
+        type=float,
+        default=None,
+        help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default",
+    )
+
+    parser.add_argument(
+        "--suppress_tokens",
+        type=str,
+        default="-1",
+        help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations",
+    )
+    parser.add_argument(
+        "--initial_prompt",
+        type=str,
+        default=None,
+        help="optional text to provide as a prompt for the first window.",
+    )
+    parser.add_argument(
+        "--condition_on_previous_text",
+        type=str2bool,
+        default=True,
+        help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop",
+    )
+    parser.add_argument(
+        "--fp16",
+        type=str2bool,
+        default=True,
+        help="whether to perform inference in fp16; True by default",
+    )
+
+    parser.add_argument(
+        "--temperature_increment_on_fallback",
+        type=optional_float,
+        default=0.2,
+        help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below",
+    )
+    parser.add_argument(
+        "--compression_ratio_threshold",
+        type=optional_float,
+        default=2.4,
+        help="if the gzip compression ratio is higher than this value, treat the decoding as failed",
+    )
+    parser.add_argument(
+        "--logprob_threshold",
+        type=optional_float,
+        default=-1.0,
+        help="if the average log probability is lower than this value, treat the decoding as failed",
+    )
+    parser.add_argument(
+        "--no_speech_threshold",
+        type=optional_float,
+        default=0.6,
+        help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence",
+    )
+    parser.add_argument(
+        "--threads",
+        type=optional_int,
+        default=0,
+        help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS",
+    )
 
     args = parser.parse_args().__dict__
     model_name: str = args.pop("model")
@@ -292,29 +457,29 @@ def cli():
 
     if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
         if args["language"] is not None:
-            warnings.warn(f"{model_name} is an English-only model but received '{args['language']}'; using English instead.")
+            warnings.warn(
+                f"{model_name} is an English-only model but received '{args['language']}'; using English instead."
+            )
         args["language"] = "en"
 
     temperature = args.pop("temperature")
-    temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
-    if temperature_increment_on_fallback is not None:
-        temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
+    if (increment := args.pop("temperature_increment_on_fallback")) is not None:
+        temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
     else:
         temperature = [temperature]
 
-    threads = args.pop("threads")
-    if threads > 0:
+    if (threads := args.pop("threads")) > 0:
         torch.set_num_threads(threads)
 
     from . import load_model
+
     model = load_model(model_name, device=device, download_root=model_dir)
 
     writer = get_writer(output_format, output_dir)
-
    for audio_path in args.pop("audio"):
        result = transcribe(model, audio_path, temperature=temperature, **args)
        writer(result, audio_path)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
    cli()
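
--
Note, not part of the commit itself: the two pieces of syntax introduced above are
exactly what raises the minimum Python version. `functools.cached_property` and the
walrus operator (`:=`) were both added in Python 3.8, hence dropping 3.7. A minimal
sketch of both features, using a hypothetical `Example` class rather than Whisper's
own code:

    from functools import cached_property, lru_cache

    class Example:
        # old pattern being removed: the lru_cache lives on the class and is
        # keyed by `self`, so it holds a strong reference to every instance
        @property
        @lru_cache()
        def old_style(self) -> int:
            return 42

        # replacement (3.8+): the value is computed once and stored in the
        # instance's own __dict__, so instances stay garbage-collectable
        @cached_property
        def new_style(self) -> int:
            return 42

    # walrus operator (3.8+): bind and test in one expression, mirroring
    # `if prefix := self.options.prefix:` in decoding.py above
    if (value := Example().new_style) == 42:
        print(value)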