Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
BBC-Esq authored Oct 16, 2024
1 parent e5f7593 commit 7462f0c
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 110 deletions.
22 changes: 8 additions & 14 deletions ct2_gui.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,7 @@
from PySide6.QtWidgets import (
QApplication, QWidget, QVBoxLayout, QPushButton, QLabel,
QComboBox, QHBoxLayout, QGroupBox, QMessageBox
)
from PySide6.QtWidgets import QApplication, QWidget, QVBoxLayout, QPushButton, QLabel, QComboBox, QHBoxLayout, QGroupBox
from PySide6.QtCore import Qt
from ct2_logic import VoiceRecorder
import yaml
import logging

logger = logging.getLogger(__name__)

class MyWindow(QWidget):
def __init__(self, cuda_available=False):
Expand All @@ -25,10 +19,9 @@ def __init__(self, cuda_available=False):
config = yaml.safe_load(f)
model = config.get("model_name", "base.en")
quantization = config.get("quantization_type", "int8")
device = config.get("device_type", "cpu")
device = config.get("device_type", "auto")
self.supported_quantizations = config.get("supported_quantizations", {"cpu": [], "cuda": []})
except FileNotFoundError:
logger.warning("config.yaml not found. Using default settings.")
model, quantization, device = "base.en", "int8", "cpu"
self.supported_quantizations = {"cpu": [], "cuda": []}

Expand All @@ -37,7 +30,7 @@ def __init__(self, cuda_available=False):
layout.addWidget(self.record_button)

self.stop_button = QPushButton("Stop and Transcribe", self)
self.stop_button.clicked.connect(self.recorder.save_audio)
self.stop_button.clicked.connect(self.recorder.stop_recording)
layout.addWidget(self.stop_button)

settings_group = QGroupBox("Settings")
Expand Down Expand Up @@ -128,7 +121,8 @@ def set_widgets_enabled(self, enabled):
self.quantization_dropdown.setEnabled(enabled)
self.device_dropdown.setEnabled(enabled)
self.update_model_btn.setEnabled(enabled)
if not enabled:
QApplication.setOverrideCursor(Qt.WaitCursor)
else:
QApplication.restoreOverrideCursor()

def closeEvent(self, event):
    """Shut down the recorder's worker threads, then let Qt close the window.

    The recorder is only touched when the attribute exists, so a window
    whose construction did not get as far as creating it still closes.
    """
    recorder_present = hasattr(self, 'recorder')
    if recorder_present:
        self.recorder.stop_all_threads()
    super().closeEvent(event)
183 changes: 136 additions & 47 deletions ct2_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@
import numpy as np
import wave
import os
import tempfile
import psutil
from PySide6.QtWidgets import QApplication
from PySide6.QtCore import QObject, Signal, Slot, QThread
from PySide6.QtCore import QObject, Signal, Slot, QThread, QMutex, QWaitCondition
from faster_whisper import WhisperModel
import yaml
import threading
import logging
import tempfile
from contextlib import contextmanager
from pathlib import Path
import queue

logger = logging.getLogger(__name__)

Expand All @@ -24,17 +27,27 @@ def __init__(self, model_name, quantization_type, device_type):

def run(self):
try:
if self.isInterruptionRequested():
return

if self.model_name.startswith("distil-whisper"):
model_str = f"ctranslate2-4you/{self.model_name}-ct2-{self.quantization_type}"
else:
model_str = f"ctranslate2-4you/whisper-{self.model_name}-ct2-{self.quantization_type}"

if self.isInterruptionRequested():
return

model = WhisperModel(
model_str,
device=self.device_type,
compute_type=self.quantization_type,
cpu_threads=26
cpu_threads=psutil.cpu_count(logical=False)
)

if self.isInterruptionRequested():
return

self.model_loaded.emit(model, self.model_name)
except Exception as e:
error_message = f"Error loading model: {str(e)}"
Expand All @@ -52,7 +65,14 @@ def __init__(self, model, audio_file):

def run(self):
try:
if self.isInterruptionRequested():
return

segments, _ = self.model.transcribe(self.audio_file)

if self.isInterruptionRequested():
return

clipboard_text = "\n".join([segment.text for segment in segments])
self.transcription_done.emit(clipboard_text)
except Exception as e:
Expand All @@ -61,10 +81,58 @@ def run(self):
self.error_occurred.emit(error_message)
finally:
try:
os.remove(self.audio_file)
Path(self.audio_file).unlink(missing_ok=True)
except OSError as e:
logger.warning(f"Error deleting temporary file: {e}")

class RecordingThread(QThread):
    """Worker thread that captures audio from a sounddevice input stream.

    Every block delivered by the stream callback is copied into ``buffer``
    (a ``queue.Queue``). The thread keeps the stream open until ``stop()``
    requests interruption.

    Signals:
        update_status_signal(str): status text, emitted when ``run`` starts.
        recording_error(str): emitted with a message if the stream fails.
        recording_finished(): emitted when ``run`` ends and there is
            something for the receiver to work with (see ``run``).
    """

    update_status_signal = Signal(str)
    recording_error = Signal(str)
    recording_finished = Signal()

    # Upper bound (ms) on a single wait, so a wake-up that arrives between
    # the interruption check and the wait call cannot block the thread forever.
    _POLL_INTERVAL_MS = 100

    def __init__(self, samplerate, channels, dtype):
        """Store the stream parameters and create the sync primitives.

        Args:
            samplerate: passed to ``sd.InputStream(samplerate=...)``.
            channels: passed to ``sd.InputStream(channels=...)``.
            dtype: passed to ``sd.InputStream(dtype=...)``.
        """
        super().__init__()
        self.samplerate = samplerate
        self.channels = channels
        self.dtype = dtype
        # Despite its name this is the wait condition run() parks on.
        self.is_recording = QWaitCondition()
        self.mutex = QMutex()
        self.buffer = queue.Queue()

    @contextmanager
    def audio_stream(self):
        """Context manager that opens the input stream for the ``with`` body."""
        stream = sd.InputStream(
            samplerate=self.samplerate,
            channels=self.channels,
            dtype=self.dtype,
            callback=self.audio_callback,
        )
        try:
            with stream:
                yield
        finally:
            # Explicit close also covers the case where entering the
            # ``with stream`` block itself raised.
            stream.close()

    def audio_callback(self, indata, frames, time, status):
        """Stream callback: log any status flag and queue a copy of the block."""
        if status:
            logger.warning(status)
        # Copy because the callback's buffer is only valid during the call.
        self.buffer.put(indata.copy())

    def run(self):
        """Record until interruption is requested, then report the outcome.

        ``recording_finished`` is emitted on a normal stop. After an error it
        is emitted only if some audio was captured, so receivers are not
        asked to process an empty buffer for a recording that never started.
        """
        self.mutex.lock()
        self.update_status_signal.emit("Recording...")
        failed = False

        try:
            with self.audio_stream():
                while not self.isInterruptionRequested():
                    # Timed wait: stop() does not hold the mutex, so its
                    # wakeAll() can fire just before we start waiting.
                    # Re-checking the flag periodically makes that harmless.
                    self.is_recording.wait(self.mutex, self._POLL_INTERVAL_MS)
        except Exception as e:
            failed = True
            error_message = f"Recording error: {e}"
            logger.error(error_message)
            self.recording_error.emit(error_message)
        finally:
            self.mutex.unlock()
            if not failed or not self.buffer.empty():
                self.recording_finished.emit()

    def stop(self):
        """Ask ``run`` to finish and wake it if it is currently waiting."""
        self.requestInterruption()
        self.is_recording.wakeAll()

class VoiceRecorder(QObject):
update_status_signal = Signal(str)
enable_widgets_signal = Signal(bool)
Expand All @@ -75,15 +143,14 @@ def __init__(self, window, samplerate=44100, channels=1, dtype='int16'):
self.channels = channels
self.dtype = dtype
self.window = window
self.is_recording = False
self.frames = []
self.model = None
self.model_lock = threading.Lock()
self.model_mutex = QMutex()
self.load_settings()

def load_settings(self):
config_path = Path("config.yaml")
try:
with open("config.yaml", "r") as f:
with config_path.open("r") as f:
config = yaml.safe_load(f)
model_name = config.get("model_name", "base.en")
quantization_type = config.get("quantization_type", "int8")
Expand All @@ -99,22 +166,24 @@ def save_settings(self, model_name, quantization_type, device_type):
"quantization_type": quantization_type,
"device_type": device_type
}
with open("config.yaml", "w") as f:
config_path = Path("config.yaml")
with config_path.open("w") as f:
yaml.safe_dump(config, f)

def update_model(self, model_name, quantization_type, device_type):
self.enable_widgets_signal.emit(False)
self.update_status_signal.emit(f"Updating model to {model_name}...")

self.model_loader_thread = ModelLoaderThread(model_name, quantization_type, device_type)
self.model_loader_thread.model_loaded.connect(self.on_model_loaded)
self.model_loader_thread.error_occurred.connect(self.on_model_load_error)
self.model_loader_thread.start()

@Slot(object, str)
def on_model_loaded(self, model, model_name):
with self.model_lock:
self.model = model
self.model_mutex.lock()
self.model = model
self.model_mutex.unlock()
self.save_settings(model_name, self.model_loader_thread.quantization_type, self.model_loader_thread.device_type)
self.update_status_signal.emit(f"Model updated to {model_name} on {self.model_loader_thread.device_type} device")
self.enable_widgets_signal.emit(True)
Expand All @@ -126,12 +195,14 @@ def on_model_load_error(self, error_message):

def transcribe_audio(self, audio_file):
self.update_status_signal.emit("Transcribing audio...")
with self.model_lock:
if self.model is None:
self.update_status_signal.emit("No model loaded.")
self.enable_widgets_signal.emit(True)
return
model = self.model
self.model_mutex.lock()
if self.model is None:
self.model_mutex.unlock()
self.update_status_signal.emit("No model loaded.")
self.enable_widgets_signal.emit(True)
return
model = self.model
self.model_mutex.unlock()

self.transcription_thread = TranscriptionThread(model, audio_file)
self.transcription_thread.transcription_done.connect(self.on_transcription_done)
Expand All @@ -149,47 +220,65 @@ def on_transcription_error(self, error_message):
self.update_status_signal.emit(error_message)
self.enable_widgets_signal.emit(True)

def record_audio(self):
self.update_status_signal.emit("Recording...")
def callback(indata, frames, time, status):
if status:
logger.warning(status)
self.frames.append(indata.copy())
try:
with sd.InputStream(samplerate=self.samplerate, channels=self.channels, dtype=self.dtype, callback=callback):
while self.is_recording:
sd.sleep(100)
except Exception as e:
error_message = f"Recording error: {e}"
logger.error(error_message)
self.update_status_signal.emit(error_message)
self.enable_widgets_signal.emit(True)
@Slot(str)
def on_recording_error(self, error_message):
    """Slot for a recording failure: show the message and unlock the UI.

    Args:
        error_message: text forwarded unchanged to ``update_status_signal``.
    """
    self.update_status_signal.emit(error_message)
    # The recording did not succeed, so re-enable the widgets.
    self.enable_widgets_signal.emit(True)

@Slot()
def on_recording_finished(self):
    """Slot run when the recording thread reports completion.

    Delegates straight to ``save_audio``.
    """
    self.save_audio()

def start_recording(self):
    """Start a new ``RecordingThread`` unless one is already running.

    When a recording thread exists and is still running, only the
    "Already recording." status is emitted. Otherwise a fresh thread is
    created, its signals are wired to this object, and it is started.
    """
    already_running = (
        hasattr(self, 'recording_thread') and self.recording_thread.isRunning()
    )
    if already_running:
        self.update_status_signal.emit("Already recording.")
        return

    self.recording_thread = RecordingThread(self.samplerate, self.channels, self.dtype)
    worker = self.recording_thread
    # Status text from the worker is relayed through our own signal.
    worker.update_status_signal.connect(self.update_status_signal)
    worker.recording_error.connect(self.on_recording_error)
    worker.recording_finished.connect(self.on_recording_finished)
    worker.start()

def stop_recording(self):
    """Stop the running recording thread, or report that none is active."""
    is_active = (
        hasattr(self, 'recording_thread') and self.recording_thread.isRunning()
    )
    if not is_active:
        self.update_status_signal.emit("Not currently recording.")
        return

    self.recording_thread.stop()

def save_audio(self):
self.is_recording = False
self.enable_widgets_signal.emit(False)
temp_filename = tempfile.mktemp(suffix=".wav")
data = np.concatenate(self.frames, axis=0)
audio_data = []
while not self.recording_thread.buffer.empty():
audio_data.append(self.recording_thread.buffer.get())
data = np.concatenate(audio_data)

with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
temp_filename = temp_file.name

try:
with wave.open(temp_filename, "wb") as wf:
wf.setnchannels(self.channels)
wf.setsampwidth(2) # Always 2 for int16
wf.setsampwidth(2)
wf.setframerate(self.samplerate)
wf.writeframes(data.tobytes())

self.update_status_signal.emit("Audio saved, starting transcription...")
self.transcribe_audio(temp_filename)
except Exception as e:
error_message = f"Error saving audio: {e}"
logger.error(error_message)
self.update_status_signal.emit(error_message)
self.enable_widgets_signal.emit(True)
finally:
self.frames.clear()

def start_recording(self):
if not self.is_recording:
self.is_recording = True
threading.Thread(target=self.record_audio).start()
else:
self.update_status_signal.emit("Already recording.")
def stop_all_threads(self):
    """Ask every running worker thread to finish and wait briefly for each.

    The recording thread is stopped through its own ``stop()`` method; the
    model-loader and transcription threads get ``requestInterruption()``.
    Each running thread is then given up to five seconds to exit. Threads
    that were never created (attribute missing) are skipped.
    """
    wait_ms = 5000

    recording = getattr(self, 'recording_thread', None)
    if recording is not None and recording.isRunning():
        recording.stop()
        # Pass the timeout positionally: ``timeout`` is not a parameter
        # name of QThread.wait in the Qt bindings.
        recording.wait(wait_ms)

    for attr_name in ('model_loader_thread', 'transcription_thread'):
        worker = getattr(self, attr_name, None)
        if worker is not None and worker.isRunning():
            worker.requestInterruption()
            worker.wait(wait_ms)
40 changes: 17 additions & 23 deletions ct2_main.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,22 @@
import sys
import os
from pathlib import Path
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
import queue
from contextlib import contextmanager

def set_cuda_paths():
try:
venv_base = Path(sys.executable).parent.parent
nvidia_base_path = venv_base / 'Lib' / 'site-packages' / 'nvidia'
cuda_path = nvidia_base_path / 'cuda_runtime' / 'bin'
cublas_path = nvidia_base_path / 'cublas' / 'bin'
cudnn_path = nvidia_base_path / 'cudnn' / 'bin'
paths_to_add = [str(cuda_path), str(cublas_path), str(cudnn_path)]
env_vars = ['CUDA_PATH', 'CUDA_PATH_V12_1', 'PATH']

for env_var in env_vars:
current_value = os.environ.get(env_var, '')
new_value = os.pathsep.join(paths_to_add + [current_value] if current_value else paths_to_add)
os.environ[env_var] = new_value
logger.info("CUDA paths set successfully.")
except Exception as e:
logger.error(f"Failed to set CUDA paths: {e}")
venv_base = Path(sys.executable).parent.parent
nvidia_base_path = venv_base / 'Lib' / 'site-packages' / 'nvidia'
cuda_path = nvidia_base_path / 'cuda_runtime' / 'bin'
cublas_path = nvidia_base_path / 'cublas' / 'bin'
cudnn_path = nvidia_base_path / 'cudnn' / 'bin'
paths_to_add = [str(cuda_path), str(cublas_path), str(cudnn_path)]
env_vars = ['CUDA_PATH', 'CUDA_PATH_V12_1', 'PATH']

for env_var in env_vars:
current_value = os.environ.get(env_var, '')
new_value = os.pathsep.join(paths_to_add + [current_value] if current_value else paths_to_add)
os.environ[env_var] = new_value

set_cuda_paths()

Expand All @@ -35,9 +28,10 @@ def set_cuda_paths():
quantization_checker = CheckQuantizationSupport()
cuda_available = quantization_checker.has_cuda_device()
quantization_checker.update_supported_quantizations()

app = QApplication(sys.argv)
app.setStyle('Fusion')
window = MyWindow(cuda_available)
window.show()
sys.exit(app.exec())

sys.exit(app.exec())
Loading

0 comments on commit 7462f0c

Please sign in to comment.