diff --git a/buzz/transcriber/transcriber.py b/buzz/transcriber/transcriber.py index 7e365ccb01..55a6ad2a6c 100644 --- a/buzz/transcriber/transcriber.py +++ b/buzz/transcriber/transcriber.py @@ -235,11 +235,19 @@ class Stopped(Exception): Video files (*.mp4 *.webm *.ogm *.mov);;All files (*.*)" -def get_output_file_path(task: FileTranscriptionTask, output_format: OutputFormat): +def get_output_file_path( + task: FileTranscriptionTask, + output_format: OutputFormat, + export_file_name_template: str | None = None, +): input_file_name = os.path.splitext(os.path.basename(task.file_path))[0] date_time_now = datetime.datetime.now().strftime("%d-%b-%Y %H-%M-%S") - export_file_name_template = Settings().get_default_export_file_template() + export_file_name_template = ( + export_file_name_template + if not None + else Settings().get_default_export_file_template() + ) output_file_name = ( export_file_name_template.replace("{{ input_file_name }}", input_file_name) diff --git a/tests/transcriber/whisper_file_transcriber_test.py b/tests/transcriber/whisper_file_transcriber_test.py index 7e6a571227..5e79100bf0 100644 --- a/tests/transcriber/whisper_file_transcriber_test.py +++ b/tests/transcriber/whisper_file_transcriber_test.py @@ -31,20 +31,18 @@ class TestWhisperFileTranscriber: @pytest.mark.parametrize( - "file_path,output_format,expected_file_path,default_output_file_name", + "file_path,output_format,expected_file_path", [ pytest.param( "/a/b/c.mp4", OutputFormat.SRT, "/a/b/c-translate--Whisper-tiny.srt", - "{{ input_file_name }}-{{ task }}-{{ language }}-{{ model_type }}-{{ model_size }}", marks=pytest.mark.skipif(platform.system() == "Windows", reason=""), ), pytest.param( "C:\\a\\b\\c.mp4", OutputFormat.SRT, "C:\\a\\b\\c-translate--Whisper-tiny.srt", - "{{ input_file_name }}-{{ task }}-{{ language }}-{{ model_type }}-{{ model_size }}", marks=pytest.mark.skipif(platform.system() != "Windows", reason=""), ), ], @@ -54,7 +52,6 @@ def test_default_output_file( file_path: str, output_format: OutputFormat, expected_file_path: str, - default_output_file_name: str, ): file_path = get_output_file_path( task=FileTranscriptionTask( @@ -64,56 +61,10 @@ def test_default_output_file( model_path="", ), output_format=output_format, + export_file_name_template="{{ input_file_name }}-{{ task }}-{{ language }}-{{ model_type }}-{{ model_size }}-{{ date_time }}", ) assert file_path == expected_file_path - @pytest.mark.parametrize( - "file_path,expected_starts_with", - [ - pytest.param( - "/a/b/c.mp4", - "/a/b/c (Translated on ", - marks=pytest.mark.skipif(platform.system() == "Windows", reason=""), - ), - pytest.param( - "C:\\a\\b\\c.mp4", - "C:\\a\\b\\c (Translated on ", - marks=pytest.mark.skipif(platform.system() != "Windows", reason=""), - ), - ], - ) - def test_default_output_file_with_date( - self, file_path: str, expected_starts_with: str - ): - srt = get_output_file_path( - task=FileTranscriptionTask( - file_path=file_path, - transcription_options=TranscriptionOptions(task=Task.TRANSLATE), - file_transcription_options=FileTranscriptionOptions( - file_paths=[], - ), - model_path="", - ), - output_format=OutputFormat.TXT, - ) - - assert srt.startswith(expected_starts_with) - assert srt.endswith(".txt") - - srt = get_output_file_path( - task=FileTranscriptionTask( - file_path=file_path, - transcription_options=TranscriptionOptions(task=Task.TRANSLATE), - file_transcription_options=FileTranscriptionOptions( - file_paths=[], - ), - model_path="", - ), - output_format=OutputFormat.SRT, - ) - assert srt.startswith(expected_starts_with) - assert srt.endswith(".srt") - @pytest.mark.parametrize( "word_level_timings,expected_segments,model", [