Skip to content

Commit

Permalink
refactor(api): simplify run_whisper_mac_english call for MPS device
Browse files Browse the repository at this point in the history
  • Loading branch information
zackees committed Feb 1, 2025
1 parent f9fb85d commit d973f80
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 10 deletions.
4 changes: 2 additions & 2 deletions src/transcribe_anything/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,8 +251,8 @@ def transcribe(
hugging_face_token=hugging_face_token,
other_args=other_args,
)
elif device_enum == Device.MPS and (language_str == "" or language_str == "en" or language_str == "English"):
run_whisper_mac_english(input_wav=Path(tmp_wav), model=model_str, output_dir=Path(tmpdir), task=task_str)
elif device_enum == Device.MPS and (language_str == "" or language_str == "en" or language_str == "English") and (task_str == "transcribe"):
run_whisper_mac_english(input_wav=Path(tmp_wav), model=model_str, output_dir=Path(tmpdir))
else:
run_whisper(
input_wav=Path(tmp_wav),
Expand Down
4 changes: 2 additions & 2 deletions src/transcribe_anything/insanely_fast_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,8 @@ def run_insanely_fast_whisper(
if sys.platform == "darwin":
# Attempts fixed recommended for the mps machines. This seems
# to be necessary since a recent update.
env["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "7"
env["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
env["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0"
# env["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
device_id = get_device_id()
cmd_list = []
output_dir.mkdir(parents=True, exist_ok=True)
Expand Down
4 changes: 1 addition & 3 deletions src/transcribe_anything/whisper_mac.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ def run_whisper_mac_english( # pylint: disable=too-many-arguments
input_wav: Path,
model: str,
output_dir: Path,
task: str,
) -> None:
"""Runs whisper."""
input_wav_abs = input_wav.resolve()
Expand All @@ -59,8 +58,7 @@ def run_whisper_mac_english( # pylint: disable=too-many-arguments
cmd_list.append(model)
cmd_list.append("--output_dir")
cmd_list.append(str(output_dir))
cmd_list.append("--task")
cmd_list.append(task)

# Remove the empty strings.
cmd_list = [str(x).strip() for x in cmd_list if str(x).strip()]
# cmd = " ".join(cmd_list)
Expand Down
3 changes: 0 additions & 3 deletions tests/test_insanely_fast_whisper_mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import unittest
from pathlib import Path


from transcribe_anything.util import is_mac_arm
from transcribe_anything.whisper_mac import run_whisper_mac_english

Expand All @@ -36,7 +35,5 @@ def test_local_file(self) -> None:
)




if __name__ == "__main__":
unittest.main()

0 comments on commit d973f80

Please sign in to comment.