From 32f642d91fe7d3a2b67fb7f4554f72718622c52f Mon Sep 17 00:00:00 2001 From: Chidi Williams Date: Thu, 4 Jan 2024 01:08:56 +0000 Subject: [PATCH] fix: execstack --- tests/transcriber_benchmarks_test.py | 45 ++++++++++------------------ 1 file changed, 15 insertions(+), 30 deletions(-) diff --git a/tests/transcriber_benchmarks_test.py b/tests/transcriber_benchmarks_test.py index 53112d37f8..ec23b59d51 100644 --- a/tests/transcriber_benchmarks_test.py +++ b/tests/transcriber_benchmarks_test.py @@ -43,44 +43,29 @@ def transcribe(qtbot, transcriber: FileTranscriber): @pytest.mark.parametrize( - "transcriber", + "transcriber, model", [ pytest.param( - WhisperCppFileTranscriber( - task=( - get_task( - TranscriptionModel( - model_type=ModelType.WHISPER_CPP, - whisper_model_size=WhisperModelSize.TINY, - ) - ) - ) + WhisperCppFileTranscriber, + TranscriptionModel( + model_type=ModelType.WHISPER_CPP, + whisper_model_size=WhisperModelSize.TINY, ), id="Whisper.cpp - Tiny", ), pytest.param( - WhisperFileTranscriber( - task=( - get_task( - TranscriptionModel( - model_type=ModelType.WHISPER, - whisper_model_size=WhisperModelSize.TINY, - ) - ) - ) + WhisperFileTranscriber, + TranscriptionModel( + model_type=ModelType.WHISPER, + whisper_model_size=WhisperModelSize.TINY, ), id="Whisper - Tiny", ), pytest.param( - WhisperFileTranscriber( - task=( - get_task( - TranscriptionModel( - model_type=ModelType.FASTER_WHISPER, - whisper_model_size=WhisperModelSize.TINY, - ) - ) - ) + WhisperFileTranscriber, + TranscriptionModel( + model_type=ModelType.FASTER_WHISPER, + whisper_model_size=WhisperModelSize.TINY, ), id="Faster Whisper - Tiny", marks=pytest.mark.skipif( @@ -95,6 +80,6 @@ def transcribe(qtbot, transcriber: FileTranscriber): @pytest.mark.skipif( platform.system() == "Linux", reason="Avoid execstack errors on Snap" ) -def test_should_transcribe_and_benchmark(qtbot, benchmark, transcriber): - segments = benchmark(transcribe, qtbot, transcriber) +def test_should_transcribe_and_benchmark(qtbot, benchmark, transcriber, model): + segments = benchmark(transcribe, qtbot, transcriber(task=get_task(model))) assert len(segments) > 0