Skip to content

Commit

Permalink
feat: whisper workflow
Browse files Browse the repository at this point in the history
  • Loading branch information
varshith15 committed Dec 14, 2024
1 parent 53fa913 commit fbe59f4
Show file tree
Hide file tree
Showing 8 changed files with 175 additions and 2 deletions.
24 changes: 24 additions & 0 deletions audio_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import json
import asyncio
import torchaudio

from comfystream.client import ComfyStreamClient

async def main():
    """Run the Whisper example workflow on a local WAV file and print the result.

    Loads the workflow JSON, registers it with a ``ComfyStreamClient``, then
    queues a mono, 16 kHz version of ``harvard.wav`` and prints whatever the
    client returns for that prompt.
    """
    # The ApplyWhisper node computes buffer durations assuming 16 kHz input,
    # so the example feeds it audio at that rate.
    target_sample_rate = 16000

    cwd = "/home/user/ComfyUI"
    client = ComfyStreamClient(cwd=cwd)

    with open("./workflows/audio-whsiper-example-workflow.json", "r") as f:
        prompt = json.load(f)

    client.set_prompt(prompt)

    waveform, sample_rate = torchaudio.load("harvard.wav")
    if sample_rate != target_sample_rate:
        waveform = torchaudio.functional.resample(
            waveform, sample_rate, target_sample_rate
        )

    # Down-mix multi-dimensional (channels, samples) audio to a single channel.
    # The fallback branch keeps `audio_tensor` bound for already-1-D input,
    # which previously raised NameError.
    audio_tensor = waveform.mean(dim=0) if waveform.ndim > 1 else waveform

    output = await client.queue_prompt(audio_tensor)
    print(output)


if __name__ == "__main__":
    asyncio.run(main())
7 changes: 7 additions & 0 deletions nodes/audio_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .apply_whisper import ApplyWhisper
from .load_audio_tensor import LoadAudioTensor
from .save_asr_response import SaveASRResponse

# Maps each node's name (the ``class_type`` string used in workflow JSON)
# to the class that implements it.
NODE_CLASS_MAPPINGS = {
    "LoadAudioTensor": LoadAudioTensor,
    "SaveASRResponse": SaveASRResponse,
    "ApplyWhisper": ApplyWhisper,
}

# Only the registry is part of this package's public surface.
__all__ = ["NODE_CLASS_MAPPINGS"]
60 changes: 60 additions & 0 deletions nodes/audio_utils/apply_whisper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import torch
import whisper

class ApplyWhisper:
    """Node that buffers incoming audio chunks and transcribes them with Whisper.

    Chunks are accumulated until at least ``min_duration`` seconds of audio
    (measured at ``sample_rate``) are available; the buffered audio is then
    concatenated, transcribed, and the buffer is cleared.
    """

    RETURN_TYPES = ("DICT",)
    FUNCTION = "apply_whisper"

    @classmethod
    def INPUT_TYPES(s):
        """Declare the node inputs: an audio chunk and the Whisper model size."""
        return {
            "required": {
                "audio": ("AUDIO",),
                "model": (["base", "tiny", "small", "medium", "large"],),
            }
        }

    def __init__(self):
        # Model is loaded lazily on first use; the name is tracked so a
        # different selection triggers a reload instead of being ignored.
        self.model = None
        self.model_name = None
        # Use the GPU when present, otherwise fall back to CPU instead of
        # crashing on a hard-coded .cuda() call.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.audio_buffer = []
        # TODO: accept these as node parameters
        self.sample_rate = 16000
        self.min_duration = 1.0

    @staticmethod
    def _empty_result():
        """Return the payload emitted while not enough audio is buffered."""
        return {"text": "", "segments_alignment": [], "words_alignment": []}

    @staticmethod
    def _span(value, start, end):
        """Build one alignment entry with whitespace-trimmed text."""
        return {"value": value.strip(), "start": start, "end": end}

    def apply_whisper(self, audio, model):
        """Buffer ``audio`` and transcribe once enough has accumulated.

        Args:
            audio: Audio chunk tensor; its first dimension is used as the
                sample count (assumes 1-D samples at ``self.sample_rate`` —
                TODO confirm against the caller).
            model: Name of the Whisper model to load (e.g. ``"base"``).

        Returns:
            A 1-tuple holding a dict with keys ``text``,
            ``segments_alignment`` and ``words_alignment``. The dict is empty
            (blank text, empty lists) while the buffer is shorter than
            ``self.min_duration`` seconds.
        """
        if self.model is None or self.model_name != model:
            self.model = whisper.load_model(model, device=self.device)
            self.model_name = model

        self.audio_buffer.append(audio)
        total_duration = sum(
            chunk.shape[0] / self.sample_rate for chunk in self.audio_buffer
        )
        if total_duration < self.min_duration:
            # Must be a tuple like the normal path so it matches RETURN_TYPES.
            return (self._empty_result(),)

        concatenated_audio = torch.cat(self.audio_buffer, dim=0).to(self.device)
        self.audio_buffer = []
        result = self.model.transcribe(
            concatenated_audio.float(),
            fp16=(self.device == "cuda"),  # half precision only on GPU
            word_timestamps=True,
        )

        segments_alignment = []
        words_alignment = []
        for segment in result["segments"]:
            segments_alignment.append(
                self._span(segment["text"], segment["start"], segment["end"])
            )
            words_alignment.extend(
                self._span(word["word"], word["start"], word["end"])
                for word in segment.get("words", [])
            )

        return (
            {
                "text": result["text"].strip(),
                "segments_alignment": segments_alignment,
                "words_alignment": words_alignment,
            },
        )
18 changes: 18 additions & 0 deletions nodes/audio_utils/load_audio_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from comfystream import tensor_cache

class LoadAudioTensor:
    """Node that supplies an audio tensor taken from ``tensor_cache.inputs``."""

    CATEGORY = "tensor_utils"
    RETURN_TYPES = ("AUDIO",)
    FUNCTION = "execute"

    @classmethod
    def INPUT_TYPES(s):
        """This node takes no inputs."""
        return {}

    @classmethod
    def IS_CHANGED(s, **kwargs):
        """Return NaN, which never compares equal to any previous value.

        The classmethod must accept the implicit class argument; the original
        zero-parameter signature raised TypeError on every call. Extra keyword
        arguments are accepted and ignored.
        """
        # Presumably this forces the node to be re-run for each prompt —
        # confirm against the host's caching logic.
        return float("nan")

    def execute(self):
        """Pop one entry from ``tensor_cache.inputs`` and return it as a 1-tuple."""
        audio = tensor_cache.inputs.pop()
        return (audio,)
24 changes: 24 additions & 0 deletions nodes/audio_utils/save_asr_response.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from comfystream import tensor_cache

class SaveASRResponse:
    """Output node that hands an ASR result dict to a waiting result holder.

    ``execute`` pops one entry from ``tensor_cache.outputs`` and calls
    ``set_result`` on it with the incoming data.
    """

    FUNCTION = "execute"
    CATEGORY = "tensor_utils"
    OUTPUT_NODE = True
    RETURN_TYPES = ()

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the single required ``data`` input of type ``DICT``."""
        required_inputs = {"data": ("DICT",)}
        return {"required": required_inputs}

    @classmethod
    def IS_CHANGED(cls):
        """Return NaN, a value that never compares equal to itself."""
        return float("nan")

    def execute(self, data: dict):
        """Deliver ``data`` to the popped output entry and return ``data``."""
        pending = tensor_cache.outputs.pop()
        pending.set_result(data)
        return data
4 changes: 2 additions & 2 deletions src/comfystream/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ def convert_prompt(prompt: PromptDictInput) -> Prompt:
"class_type": "SaveTensor",
"_meta": {"title": "SaveTensor"},
}
elif node.get("class_type") == "LoadTensor":
elif node.get("class_type") in ["LoadTensor", "LoadAudioTensor"]:
num_inputs += 1
elif node.get("class_type") == "SaveTensor":
elif node.get("class_type") in ["SaveTensor", "SaveASRResponse"]:
num_outputs += 1

# Only handle single input for now
Expand Down
6 changes: 6 additions & 0 deletions ui/src/components/webcam.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ export function Webcam({ onStreamReady }: WebcamProps) {
width: { exact: 512 },
height: { exact: 512 },
},
audio: {
noiseSuppression: true,
echoCancellation: true,
sampleRate: 16000,
sampleSize: 16,
},
});

if (videoRef.current) videoRef.current.srcObject = stream;
Expand Down
34 changes: 34 additions & 0 deletions workflows/audio-whsiper-example-workflow.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
{
"1": {
"inputs": {},
"class_type": "LoadAudioTensor",
"_meta": {
"title": "Load Audio Tensor"
}
},
"2": {
"inputs": {
"audio": [
"1",
0
],
"model": "large"
},
"class_type": "ApplyWhisper",
"_meta": {
"title": "Apply Whisper"
}
},
"3": {
"inputs": {
"data": [
"2",
0
]
},
"class_type": "SaveASRResponse",
"_meta": {
"title": "Save ASR Response"
}
}
}

0 comments on commit fbe59f4

Please sign in to comment.