From 7f0dc9e173c5e7ce34075c544eec60d6d39d411b Mon Sep 17 00:00:00 2001 From: Alejandro Molina Date: Sun, 22 Oct 2023 17:37:17 +0200 Subject: [PATCH 1/7] ADD parser for new argument --max_words_count --- whisper/transcribe.py | 1 + 1 file changed, 1 insertion(+) diff --git a/whisper/transcribe.py b/whisper/transcribe.py index 6e43a22fa..372f99d63 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -412,6 +412,7 @@ def cli(): parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt") parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line") parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment") + parser.add_argument("--max_words_count", type=optional_int, default=None, help="(requires --word_timestamps True) (no effect with --max_line_width) the maximum number of words in a segment") parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS") # fmt: on From f11faf247e5666806bfa9a3f2a917b23a44cc308 Mon Sep 17 00:00:00 2001 From: Alejandro Molina Date: Sun, 22 Oct 2023 17:40:03 +0200 Subject: [PATCH 2/7] ADD max_words_count in words_options ADD warning for max_line_width compatibility --- whisper/transcribe.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/whisper/transcribe.py b/whisper/transcribe.py index 372f99d63..cd12c43f9 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -445,13 +445,15 @@ def cli(): model = load_model(model_name, device=device, download_root=model_dir) writer = get_writer(output_format, output_dir) - word_options = ["highlight_words", "max_line_count", "max_line_width"] + word_options = ["highlight_words", "max_line_count", "max_line_width", "max_words_count"] if not args["word_timestamps"]: for option in word_options: if args[option]: parser.error(f"--{option} requires --word_timestamps True") if args["max_line_count"] and not args["max_line_width"]: warnings.warn("--max_line_count has no effect without --max_line_width") + if args["max_words_count"] and args["max_line_width"]: + warnings.warn("--max_words_count has no effect with --max_line_width") writer_args = {arg: args.pop(arg) for arg in word_options} for audio_path in args.pop("audio"): result = transcribe(model, audio_path, temperature=temperature, **args) From 376acb29345d9c5d5863b71438180b722de7e504 Mon Sep 17 00:00:00 2001 From: Alejandro Molina Date: Sun, 22 Oct 2023 17:41:37 +0200 Subject: [PATCH 3/7] ADD logic for max_words_count --- whisper/utils.py | 65 +++++++++++++++++++++++++++--------------------- 1 file changed, 37 insertions(+), 28 deletions(-) diff --git a/whisper/utils.py b/whisper/utils.py index 22260d0d9..7192b7d97 100644 --- a/whisper/utils.py +++ b/whisper/utils.py @@ -104,7 +104,9 @@ def iterate_result(self, result: dict, options: dict): raw_max_line_width: Optional[int] = options["max_line_width"] max_line_count: Optional[int] = options["max_line_count"] highlight_words: bool = options["highlight_words"] + max_words_count: Optional[int] = options["max_words_count"] max_line_width = 1000 if raw_max_line_width is None else raw_max_line_width + max_words_count = 1000 if max_words_count is None else max_words_count preserve_segments = max_line_count is None or raw_max_line_width is None def iterate_subtitles(): @@ -114,34 +116,41 @@ def iterate_subtitles(): subtitle: list[dict] = [] last = result["segments"][0]["words"][0]["start"] for segment in result["segments"]: - for i, original_timing in enumerate(segment["words"]): - timing = original_timing.copy() - long_pause = not preserve_segments and timing["start"] - last > 3.0 - has_room = line_len + len(timing["word"]) <= max_line_width - seg_break = i == 0 and len(subtitle) > 0 and preserve_segments - if line_len > 0 and has_room and not long_pause and not seg_break: - # line continuation - line_len += len(timing["word"]) - else: - # new line - timing["word"] = timing["word"].strip() - if ( - len(subtitle) > 0 - and max_line_count is not None - and (long_pause or line_count >= max_line_count) - or seg_break - ): - # subtitle break - yield subtitle - subtitle = [] - line_count = 1 - elif line_len > 0: - # line break - line_count += 1 - timing["word"] = "\n" + timing["word"] - line_len = len(timing["word"].strip()) - subtitle.append(timing) - last = timing["start"] + chunk_index = 0 + words_count = max_words_count + while chunk_index < len(segment["words"]): + remaining_words = len(segment["words"]) - chunk_index + if max_words_count > len(segment["words"]) - chunk_index: + words_count = remaining_words + for i, original_timing in enumerate(segment["words"][chunk_index:chunk_index + words_count]): + timing = original_timing.copy() + long_pause = not preserve_segments and timing["start"] - last > 3.0 + has_room = line_len + len(timing["word"]) <= max_line_width + seg_break = i == 0 and len(subtitle) > 0 and preserve_segments + if line_len > 0 and has_room and not long_pause and not seg_break: + # line continuation + line_len += len(timing["word"]) + else: + # new line + timing["word"] = timing["word"].strip() + if ( + len(subtitle) > 0 + and max_line_count is not None + and (long_pause or line_count >= max_line_count) + or seg_break + ): + # subtitle break + yield subtitle + subtitle = [] + line_count = 1 + elif line_len > 0: + # line break + line_count += 1 + timing["word"] = "\n" + timing["word"] + line_len = len(timing["word"].strip()) + subtitle.append(timing) + last = timing["start"] + chunk_index += max_words_count if len(subtitle) > 0: yield subtitle From 8e5200bf9bed34477d08218e9d75dbda6b7a5a79 Mon Sep 17 00:00:00 2001 From: Jong Wook Kim Date: Mon, 6 Nov 2023 01:11:04 -0800 Subject: [PATCH 4/7] rename to max_words_per_line --- whisper/transcribe.py | 8 ++++---- whisper/utils.py | 10 +++++----- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/whisper/transcribe.py b/whisper/transcribe.py index cd12c43f9..450dcac16 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -412,7 +412,7 @@ def cli(): parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt") parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line") parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment") - parser.add_argument("--max_words_count", type=optional_int, default=None, help="(requires --word_timestamps True) (no effect with --max_line_width) the maximum number of words in a segment") + parser.add_argument("--max_words_per_line", type=optional_int, default=None, help="(requires --word_timestamps True, no effect with --max_line_width) the maximum number of words in a segment") parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS") # fmt: on @@ -445,15 +445,15 @@ def cli(): model = load_model(model_name, device=device, download_root=model_dir) writer = get_writer(output_format, output_dir) - word_options = ["highlight_words", "max_line_count", "max_line_width", "max_words_count"] + word_options = ["highlight_words", "max_line_count", "max_line_width", "max_words_per_line"] if not args["word_timestamps"]: for option in word_options: if args[option]: parser.error(f"--{option} requires --word_timestamps True") if args["max_line_count"] and not args["max_line_width"]: warnings.warn("--max_line_count has no effect without --max_line_width") - if args["max_words_count"] and args["max_line_width"]: - warnings.warn("--max_words_count has no effect with --max_line_width") + if args["max_words_per_line"] and args["max_line_width"]: + warnings.warn("--max_words_per_line has no effect with --max_line_width") writer_args = {arg: args.pop(arg) for arg in word_options} for audio_path in args.pop("audio"): result = transcribe(model, audio_path, temperature=temperature, **args) diff --git a/whisper/utils.py b/whisper/utils.py index 7192b7d97..8a31bd931 100644 --- a/whisper/utils.py +++ b/whisper/utils.py @@ -104,9 +104,9 @@ def iterate_result(self, result: dict, options: dict): raw_max_line_width: Optional[int] = options["max_line_width"] max_line_count: Optional[int] = options["max_line_count"] highlight_words: bool = options["highlight_words"] - max_words_count: Optional[int] = options["max_words_count"] + max_words_per_line: Optional[int] = options["max_words_per_line"] max_line_width = 1000 if raw_max_line_width is None else raw_max_line_width - max_words_count = 1000 if max_words_count is None else max_words_count + max_words_per_line = 1000 if max_words_per_line is None else max_words_per_line preserve_segments = max_line_count is None or raw_max_line_width is None def iterate_subtitles(): @@ -117,10 +117,10 @@ def iterate_subtitles(): last = result["segments"][0]["words"][0]["start"] for segment in result["segments"]: chunk_index = 0 - words_count = max_words_count + words_count = max_words_per_line while chunk_index < len(segment["words"]): remaining_words = len(segment["words"]) - chunk_index - if max_words_count > len(segment["words"]) - chunk_index: + if max_words_per_line > len(segment["words"]) - chunk_index: words_count = remaining_words for i, original_timing in enumerate(segment["words"][chunk_index:chunk_index + words_count]): timing = original_timing.copy() @@ -150,7 +150,7 @@ def iterate_subtitles(): line_len = len(timing["word"].strip()) subtitle.append(timing) last = timing["start"] - chunk_index += max_words_count + chunk_index += max_words_per_line if len(subtitle) > 0: yield subtitle From 541adb4f3896caa1e223cd88ae9a366ad7734b8b Mon Sep 17 00:00:00 2001 From: Jong Wook Kim Date: Mon, 6 Nov 2023 01:32:19 -0800 Subject: [PATCH 5/7] make them kwargs --- whisper/transcribe.py | 2 +- whisper/utils.py | 50 ++++++++++++++++++++++++++----------------- 2 files changed, 31 insertions(+), 21 deletions(-) diff --git a/whisper/transcribe.py b/whisper/transcribe.py index 450dcac16..6b3e0664b 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -457,7 +457,7 @@ def cli(): writer_args = {arg: args.pop(arg) for arg in word_options} for audio_path in args.pop("audio"): result = transcribe(model, audio_path, temperature=temperature, **args) - writer(result, audio_path, writer_args) + writer(result, audio_path, **writer_args) if __name__ == "__main__": diff --git a/whisper/utils.py b/whisper/utils.py index 8a31bd931..fd118e3f7 100644 --- a/whisper/utils.py +++ b/whisper/utils.py @@ -74,7 +74,7 @@ class ResultWriter: def __init__(self, output_dir: str): self.output_dir = output_dir - def __call__(self, result: dict, audio_path: str, options: dict): + def __call__(self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs): audio_basename = os.path.basename(audio_path) audio_basename = os.path.splitext(audio_basename)[0] output_path = os.path.join( @@ -82,16 +82,16 @@ def __call__(self, result: dict, audio_path: str, options: dict): ) with open(output_path, "w", encoding="utf-8") as f: - self.write_result(result, file=f, options=options) + self.write_result(result, file=f, options=options, **kwargs) - def write_result(self, result: dict, file: TextIO, options: dict): + def write_result(self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs): raise NotImplementedError class WriteTXT(ResultWriter): extension: str = "txt" - def write_result(self, result: dict, file: TextIO, options: dict): + def write_result(self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs): for segment in result["segments"]: print(segment["text"].strip(), file=file, flush=True) @@ -100,14 +100,24 @@ class SubtitlesWriter(ResultWriter): always_include_hours: bool decimal_marker: str - def iterate_result(self, result: dict, options: dict): - raw_max_line_width: Optional[int] = options["max_line_width"] - max_line_count: Optional[int] = options["max_line_count"] - highlight_words: bool = options["highlight_words"] - max_words_per_line: Optional[int] = options["max_words_per_line"] - max_line_width = 1000 if raw_max_line_width is None else raw_max_line_width - max_words_per_line = 1000 if max_words_per_line is None else max_words_per_line - preserve_segments = max_line_count is None or raw_max_line_width is None + def iterate_result( + self, + result: dict, + options: Optional[dict] = None, + *, + max_line_width: Optional[int] = None, + max_line_count: Optional[int] = None, + highlight_words: bool = False, + max_words_per_line: Optional[int] = None, + ): + options = options or {} + max_line_width = max_line_width or options.get("max_line_width") + max_line_count = max_line_count or options.get("max_line_count") + highlight_words = highlight_words or options.get("highlight_words", False) + max_words_per_line = max_words_per_line or options.get("max_words_per_line") + preserve_segments = max_line_count is None or max_line_width is None + max_line_width = max_line_width or 1000 + max_words_per_line = max_words_per_line or 1000 def iterate_subtitles(): line_len = 0 @@ -199,9 +209,9 @@ class WriteVTT(SubtitlesWriter): always_include_hours: bool = False decimal_marker: str = "." - def write_result(self, result: dict, file: TextIO, options: dict): + def write_result(self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs): print("WEBVTT\n", file=file) - for start, end, text in self.iterate_result(result, options): + for start, end, text in self.iterate_result(result, options, **kwargs): print(f"{start} --> {end}\n{text}\n", file=file, flush=True) @@ -210,9 +220,9 @@ class WriteSRT(SubtitlesWriter): always_include_hours: bool = True decimal_marker: str = "," - def write_result(self, result: dict, file: TextIO, options: dict): + def write_result(self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs): for i, (start, end, text) in enumerate( - self.iterate_result(result, options), start=1 + self.iterate_result(result, options, **kwargs), start=1 ): print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True) @@ -229,7 +239,7 @@ class WriteTSV(ResultWriter): extension: str = "tsv" - def write_result(self, result: dict, file: TextIO, options: dict): + def write_result(self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs): print("start", "end", "text", sep="\t", file=file) for segment in result["segments"]: print(round(1000 * segment["start"]), file=file, end="\t") @@ -240,7 +250,7 @@ def write_result(self, result: dict, file: TextIO, options: dict): class WriteJSON(ResultWriter): extension: str = "json" - def write_result(self, result: dict, file: TextIO, options: dict): + def write_result(self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs): json.dump(result, file) @@ -258,9 +268,9 @@ def get_writer( if output_format == "all": all_writers = [writer(output_dir) for writer in writers.values()] - def write_all(result: dict, file: TextIO, options: dict): + def write_all(result: dict, file: TextIO, options: Optional[dict] = None, **kwargs): for writer in all_writers: - writer(result, file, options) + writer(result, file, options, **kwargs) return write_all From 832c0e97281b272b4a003617114950a05bdbeee1 Mon Sep 17 00:00:00 2001 From: Jong Wook Kim Date: Mon, 6 Nov 2023 01:35:55 -0800 Subject: [PATCH 6/7] allow specifying file path by --model --- whisper/transcribe.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/whisper/transcribe.py b/whisper/transcribe.py index 6b3e0664b..6136ac123 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -378,10 +378,17 @@ def new_segment( def cli(): from . import available_models + def valid_model_name(name): + if name in available_models() or os.path.exists(name): + return name + raise ValueError( + f"model should be one of {available_models()} or path to a model checkpoint" + ) + # fmt: off parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe") - parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use") + parser.add_argument("--model", default="small", type=valid_model_name, help="name of the Whisper model to use") parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default") parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference") parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs") From 177588911abf70feab33d2f1d6debe5663193567 Mon Sep 17 00:00:00 2001 From: Jong Wook Kim Date: Mon, 6 Nov 2023 01:38:55 -0800 Subject: [PATCH 7/7] black formatting --- whisper/transcribe.py | 7 ++++++- whisper/utils.py | 47 +++++++++++++++++++++++++++++++++---------- 2 files changed, 42 insertions(+), 12 deletions(-) diff --git a/whisper/transcribe.py b/whisper/transcribe.py index 6136ac123..509e322e3 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -452,7 +452,12 @@ def valid_model_name(name): model = load_model(model_name, device=device, download_root=model_dir) writer = get_writer(output_format, output_dir) - word_options = ["highlight_words", "max_line_count", "max_line_width", "max_words_per_line"] + word_options = [ + "highlight_words", + "max_line_count", + "max_line_width", + "max_words_per_line", + ] if not args["word_timestamps"]: for option in word_options: if args[option]: diff --git a/whisper/utils.py b/whisper/utils.py index fd118e3f7..7a172c401 100644 --- a/whisper/utils.py +++ b/whisper/utils.py @@ -74,7 +74,9 @@ class ResultWriter: def __init__(self, output_dir: str): self.output_dir = output_dir - def __call__(self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs): + def __call__( + self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs + ): audio_basename = os.path.basename(audio_path) audio_basename = os.path.splitext(audio_basename)[0] output_path = os.path.join( @@ -84,14 +86,18 @@ def __call__(self, result: dict, audio_path: str, options: Optional[dict] = None with open(output_path, "w", encoding="utf-8") as f: self.write_result(result, file=f, options=options, **kwargs) - def write_result(self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs): + def write_result( + self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs + ): raise NotImplementedError class WriteTXT(ResultWriter): extension: str = "txt" - def write_result(self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs): + def write_result( + self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs + ): for segment in result["segments"]: print(segment["text"].strip(), file=file, flush=True) @@ -132,12 +138,21 @@ def iterate_subtitles(): remaining_words = len(segment["words"]) - chunk_index if max_words_per_line > len(segment["words"]) - chunk_index: words_count = remaining_words - for i, original_timing in enumerate(segment["words"][chunk_index:chunk_index + words_count]): + for i, original_timing in enumerate( + segment["words"][chunk_index : chunk_index + words_count] + ): timing = original_timing.copy() - long_pause = not preserve_segments and timing["start"] - last > 3.0 + long_pause = ( + not preserve_segments and timing["start"] - last > 3.0 + ) has_room = line_len + len(timing["word"]) <= max_line_width seg_break = i == 0 and len(subtitle) > 0 and preserve_segments - if line_len > 0 and has_room and not long_pause and not seg_break: + if ( + line_len > 0 + and has_room + and not long_pause + and not seg_break + ): # line continuation line_len += len(timing["word"]) else: @@ -209,7 +224,9 @@ class WriteVTT(SubtitlesWriter): always_include_hours: bool = False decimal_marker: str = "." - def write_result(self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs): + def write_result( + self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs + ): print("WEBVTT\n", file=file) for start, end, text in self.iterate_result(result, options, **kwargs): print(f"{start} --> {end}\n{text}\n", file=file, flush=True) @@ -220,7 +237,9 @@ class WriteSRT(SubtitlesWriter): always_include_hours: bool = True decimal_marker: str = "," - def write_result(self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs): + def write_result( + self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs + ): for i, (start, end, text) in enumerate( self.iterate_result(result, options, **kwargs), start=1 ): @@ -239,7 +258,9 @@ class WriteTSV(ResultWriter): extension: str = "tsv" - def write_result(self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs): + def write_result( + self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs + ): print("start", "end", "text", sep="\t", file=file) for segment in result["segments"]: print(round(1000 * segment["start"]), file=file, end="\t") @@ -250,7 +271,9 @@ def write_result(self, result: dict, file: TextIO, options: Optional[dict] = Non class WriteJSON(ResultWriter): extension: str = "json" - def write_result(self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs): + def write_result( + self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs + ): json.dump(result, file) @@ -268,7 +291,9 @@ def get_writer( if output_format == "all": all_writers = [writer(output_dir) for writer in writers.values()] - def write_all(result: dict, file: TextIO, options: Optional[dict] = None, **kwargs): + def write_all( + result: dict, file: TextIO, options: Optional[dict] = None, **kwargs + ): for writer in all_writers: writer(result, file, options, **kwargs)