diff --git a/buzz/model_loader.py b/buzz/model_loader.py index 2eaedac742..e31d2df087 100644 --- a/buzz/model_loader.py +++ b/buzz/model_loader.py @@ -124,6 +124,13 @@ def open_file_location(self): return self.open_path(path=os.path.dirname(model_path)) + @staticmethod + def default(): + model_type = next( + model_type for model_type in ModelType if model_type.is_available() + ) + return TranscriptionModel(model_type=model_type) + @staticmethod def open_path(path: str): if sys.platform == "win32": diff --git a/buzz/widgets/preferences_dialog/models/file_transcription_preferences.py b/buzz/widgets/preferences_dialog/models/file_transcription_preferences.py index f3749f3d06..077df80b7a 100644 --- a/buzz/widgets/preferences_dialog/models/file_transcription_preferences.py +++ b/buzz/widgets/preferences_dialog/models/file_transcription_preferences.py @@ -39,7 +39,9 @@ def save(self, settings: QSettings) -> None: def load(cls, settings: QSettings) -> "FileTranscriptionPreferences": language = settings.value("language", None) task = settings.value("task", Task.TRANSCRIBE) - model = settings.value("model", TranscriptionModel()) + model: TranscriptionModel = settings.value( + "model", TranscriptionModel.default() + ) word_level_timings = settings.value("word_level_timings", False) temperature = settings.value("temperature", DEFAULT_WHISPER_TEMPERATURE) initial_prompt = settings.value("initial_prompt", "") @@ -47,7 +49,9 @@ def load(cls, settings: QSettings) -> "FileTranscriptionPreferences": return FileTranscriptionPreferences( language=language, task=task, - model=model, + model=model + if model.model_type.is_available() + else TranscriptionModel.default(), word_level_timings=word_level_timings, temperature=temperature, initial_prompt=initial_prompt, diff --git a/buzz/widgets/transcriber/file_transcriber_widget.py b/buzz/widgets/transcriber/file_transcriber_widget.py index bef88e7ec2..29e749237d 100644 --- a/buzz/widgets/transcriber/file_transcriber_widget.py +++ b/buzz/widgets/transcriber/file_transcriber_widget.py @@ -56,6 +56,7 @@ def __init__( self.file_paths = file_paths preferences = self.load_preferences() + print(preferences) ( self.transcription_options, diff --git a/tests/widgets/preferences_dialog/folder_watch_preferences_widget_test.py b/tests/widgets/preferences_dialog/folder_watch_preferences_widget_test.py index d19d43fb48..212ba173a2 100644 --- a/tests/widgets/preferences_dialog/folder_watch_preferences_widget_test.py +++ b/tests/widgets/preferences_dialog/folder_watch_preferences_widget_test.py @@ -25,7 +25,7 @@ def test_edit_folder_watch_preferences(self, qtbot): file_transcription_options=FileTranscriptionPreferences( language=None, task=Task.TRANSCRIBE, - model=TranscriptionModel(), + model=TranscriptionModel.default(), word_level_timings=False, temperature=DEFAULT_WHISPER_TEMPERATURE, initial_prompt="", diff --git a/tests/widgets/preferences_dialog/models_preferences_widget_test.py b/tests/widgets/preferences_dialog/models_preferences_widget_test.py index 48e1f17c53..3571f984fb 100644 --- a/tests/widgets/preferences_dialog/models_preferences_widget_test.py +++ b/tests/widgets/preferences_dialog/models_preferences_widget_test.py @@ -89,11 +89,7 @@ def downloaded_model(): @pytest.fixture(scope="class") def default_model_path(self) -> str: - model_type = next( - model_type for model_type in ModelType if model_type.is_available() - ) - model = TranscriptionModel(model_type=model_type) - return get_model_path(transcription_model=model) + return get_model_path(transcription_model=(TranscriptionModel.default())) def test_should_show_downloaded_model(self, qtbot, default_model_path): widget = ModelsPreferencesWidget()