forked from yondonfu/comfystream
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
53fa913
commit fbe59f4
Showing
8 changed files
with
175 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import json | ||
import asyncio | ||
import torchaudio | ||
|
||
from comfystream.client import ComfyStreamClient | ||
|
||
async def main(): | ||
cwd = "/home/user/ComfyUI" | ||
client = ComfyStreamClient(cwd=cwd) | ||
|
||
with open("./workflows/audio-whsiper-example-workflow.json", "r") as f: | ||
prompt = json.load(f) | ||
|
||
client.set_prompt(prompt) | ||
|
||
waveform, _ = torchaudio.load("harvard.wav") | ||
if waveform.ndim > 1: | ||
audio_tensor = waveform.mean(dim=0) | ||
|
||
output = await client.queue_prompt(audio_tensor) | ||
print(output) | ||
|
||
if __name__ == "__main__": | ||
asyncio.run(main()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from .apply_whisper import ApplyWhisper | ||
from .load_audio_tensor import LoadAudioTensor | ||
from .save_asr_response import SaveASRResponse | ||
|
||
NODE_CLASS_MAPPINGS = {"LoadAudioTensor": LoadAudioTensor, "SaveASRResponse": SaveASRResponse, "ApplyWhisper": ApplyWhisper} | ||
|
||
__all__ = ["NODE_CLASS_MAPPINGS"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import torch | ||
import whisper | ||
|
||
class ApplyWhisper: | ||
@classmethod | ||
def INPUT_TYPES(s): | ||
return { | ||
"required": { | ||
"audio": ("AUDIO",), | ||
"model": (["base", "tiny", "small", "medium", "large"],), | ||
} | ||
} | ||
|
||
RETURN_TYPES = ("DICT",) | ||
FUNCTION = "apply_whisper" | ||
|
||
def __init__(self): | ||
self.model = None | ||
self.audio_buffer = [] | ||
# TO:DO to get them as params | ||
self.sample_rate = 16000 | ||
self.min_duration = 1.0 | ||
|
||
def apply_whisper(self, audio, model): | ||
if self.model is None: | ||
self.model = whisper.load_model(model).cuda() | ||
|
||
self.audio_buffer.append(audio) | ||
total_duration = sum(chunk.shape[0] / self.sample_rate for chunk in self.audio_buffer) | ||
if total_duration < self.min_duration: | ||
return {"text": "", "segments_alignment": [], "words_alignment": []} | ||
|
||
concatenated_audio = torch.cat(self.audio_buffer, dim=0).cuda() | ||
self.audio_buffer = [] | ||
result = self.model.transcribe(concatenated_audio.float(), fp16=True, word_timestamps=True) | ||
segments = result['segments'] | ||
segments_alignment = [] | ||
words_alignment = [] | ||
|
||
for segment in segments: | ||
segment_dict = { | ||
'value': segment['text'].strip(), | ||
'start': segment['start'], | ||
'end': segment['end'] | ||
} | ||
segments_alignment.append(segment_dict) | ||
|
||
for word in segment["words"]: | ||
word_dict = { | ||
'value': word["word"].strip(), | ||
'start': word["start"], | ||
'end': word['end'] | ||
} | ||
words_alignment.append(word_dict) | ||
|
||
return ({ | ||
"text": result["text"].strip(), | ||
"segments_alignment": segments_alignment, | ||
"words_alignment": words_alignment | ||
},) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
from comfystream import tensor_cache | ||
|
||
class LoadAudioTensor: | ||
CATEGORY = "tensor_utils" | ||
RETURN_TYPES = ("AUDIO",) | ||
FUNCTION = "execute" | ||
|
||
@classmethod | ||
def INPUT_TYPES(s): | ||
return {} | ||
|
||
@classmethod | ||
def IS_CHANGED(): | ||
return float("nan") | ||
|
||
def execute(self): | ||
audio = tensor_cache.inputs.pop() | ||
return (audio,) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
from comfystream import tensor_cache | ||
|
||
class SaveASRResponse: | ||
CATEGORY = "tensor_utils" | ||
RETURN_TYPES = () | ||
FUNCTION = "execute" | ||
OUTPUT_NODE = True | ||
|
||
@classmethod | ||
def INPUT_TYPES(s): | ||
return { | ||
"required": { | ||
"data": ("DICT",), | ||
} | ||
} | ||
|
||
@classmethod | ||
def IS_CHANGED(s): | ||
return float("nan") | ||
|
||
def execute(self, data: dict): | ||
fut = tensor_cache.outputs.pop() | ||
fut.set_result(data) | ||
return data |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
{ | ||
"1": { | ||
"inputs": {}, | ||
"class_type": "LoadAudioTensor", | ||
"_meta": { | ||
"title": "Load Audio Tensor" | ||
} | ||
}, | ||
"2": { | ||
"inputs": { | ||
"audio": [ | ||
"1", | ||
0 | ||
], | ||
"model": "large" | ||
}, | ||
"class_type": "ApplyWhisper", | ||
"_meta": { | ||
"title": "Apply Whisper" | ||
} | ||
}, | ||
"3": { | ||
"inputs": { | ||
"data": [ | ||
"2", | ||
0 | ||
] | ||
}, | ||
"class_type": "SaveASRResponse", | ||
"_meta": { | ||
"title": "Save ASR Response" | ||
} | ||
} | ||
} |