From 29f6bb7d1a60bf146a1f41e05e1a7adcef9cbc59 Mon Sep 17 00:00:00 2001
From: Varshith B
Date: Sat, 28 Dec 2024 01:18:51 +0530
Subject: [PATCH] feat: audio pipeline

---
 nodes/audio_utils/__init__.py                 |   3 +-
 nodes/audio_utils/apply_whisper.py            |   5 +-
 nodes/audio_utils/load_audio_tensor.py        |   2 +-
 nodes/audio_utils/save_asr_response.py        |   2 +-
 nodes/audio_utils/save_audio_tensor.py        |  25 ++++
 nodes/tensor_utils/load_tensor.py             |   2 +-
 nodes/tensor_utils/save_tensor.py             |   2 +-
 server/app.py                                 | 108 +++--------------
 server/pipeline.py                            |  43 ++++++-
 src/comfystream/client.py                     |  15 ++-
 src/comfystream/tensor_cache.py               |   7 +-
 src/comfystream/utils.py                      |   2 +-
 ui/src/app/api/offer/route.ts                 |   4 +-
 ui/src/components/peer.tsx                    |   3 +-
 ui/src/components/room.tsx                    |   6 +-
 ui/src/components/settings.tsx                |  41 +++++--
 ui/src/components/webcam.tsx                  |   6 +-
 ui/src/hooks/use-peer.ts                      |   5 +-
 .../audio-tensor-utils-example-workflow.json  |  21 ++++
 19 files changed, 179 insertions(+), 123 deletions(-)
 create mode 100644 nodes/audio_utils/save_audio_tensor.py
 create mode 100644 workflows/audio-tensor-utils-example-workflow.json

diff --git a/nodes/audio_utils/__init__.py b/nodes/audio_utils/__init__.py
index ab395c38..162e8244 100644
--- a/nodes/audio_utils/__init__.py
+++ b/nodes/audio_utils/__init__.py
@@ -1,7 +1,8 @@
 from .apply_whisper import ApplyWhisper
 from .load_audio_tensor import LoadAudioTensor
 from .save_asr_response import SaveASRResponse
+from .save_audio_tensor import SaveAudioTensor
 
-NODE_CLASS_MAPPINGS = {"LoadAudioTensor": LoadAudioTensor, "SaveASRResponse": SaveASRResponse, "ApplyWhisper": ApplyWhisper}
+NODE_CLASS_MAPPINGS = {"LoadAudioTensor": LoadAudioTensor, "SaveASRResponse": SaveASRResponse, "ApplyWhisper": ApplyWhisper, "SaveAudioTensor": SaveAudioTensor}
 
 __all__ = ["NODE_CLASS_MAPPINGS"]
diff --git a/nodes/audio_utils/apply_whisper.py b/nodes/audio_utils/apply_whisper.py
index 87bffb34..c10ac884 100644
--- a/nodes/audio_utils/apply_whisper.py
+++ b/nodes/audio_utils/apply_whisper.py
@@ -21,17 +21,18 @@ def __init__(self):
         # TO:DO to get them as params
         self.sample_rate = 16000
         self.min_duration = 1.0
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
 
     def apply_whisper(self, audio, model):
         if self.model is None:
-            self.model = whisper.load_model(model).cuda()
+            self.model = whisper.load_model(model).to(self.device)
 
         self.audio_buffer.append(audio)
         total_duration = sum(chunk.shape[0] / self.sample_rate for chunk in self.audio_buffer)
         if total_duration < self.min_duration:
             return {"text": "", "segments_alignment": [], "words_alignment": []}
 
-        concatenated_audio = torch.cat(self.audio_buffer, dim=0).cuda()
+        concatenated_audio = torch.cat(self.audio_buffer, dim=0).to(self.device)
         self.audio_buffer = []
         result = self.model.transcribe(concatenated_audio.float(), fp16=True, word_timestamps=True)
         segments = result["segments"]
diff --git a/nodes/audio_utils/load_audio_tensor.py b/nodes/audio_utils/load_audio_tensor.py
index ef89a119..97b92b35 100644
--- a/nodes/audio_utils/load_audio_tensor.py
+++ b/nodes/audio_utils/load_audio_tensor.py
@@ -14,5 +14,5 @@ def IS_CHANGED():
         return float("nan")
 
     def execute(self):
-        audio = tensor_cache.inputs.pop()
+        audio = tensor_cache.audio_inputs.pop()
         return (audio,)
\ No newline at end of file
diff --git a/nodes/audio_utils/save_asr_response.py b/nodes/audio_utils/save_asr_response.py
index b402931f..816ca1bc 100644
--- a/nodes/audio_utils/save_asr_response.py
+++ b/nodes/audio_utils/save_asr_response.py
@@ -19,6 +19,6 @@ def IS_CHANGED(s):
         return float("nan")
 
     def execute(self, data: dict):
-        fut = tensor_cache.outputs.pop()
+        fut = tensor_cache.audio_outputs.pop()
         fut.set_result(data)
         return data
\ No newline at end of file
diff --git a/nodes/audio_utils/save_audio_tensor.py b/nodes/audio_utils/save_audio_tensor.py
new file mode 100644
index 00000000..5135285c
--- /dev/null
+++ b/nodes/audio_utils/save_audio_tensor.py
@@ -0,0 +1,25 @@
+from comfystream import tensor_cache
+
+
+class SaveAudioTensor:
+    CATEGORY = "audio_utils"
+    RETURN_TYPES = ()
+    FUNCTION = "execute"
+    OUTPUT_NODE = True
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "audio": ("AUDIO",),
+            }
+        }
+
+    @classmethod
+    def IS_CHANGED(s):
+        return float("nan")
+
+    def execute(self, audio):
+        fut = tensor_cache.audio_outputs.pop()
+        fut.set_result(audio)
+        return audio
diff --git a/nodes/tensor_utils/load_tensor.py b/nodes/tensor_utils/load_tensor.py
index 622f176e..117078ea 100644
--- a/nodes/tensor_utils/load_tensor.py
+++ b/nodes/tensor_utils/load_tensor.py
@@ -15,5 +15,5 @@ def IS_CHANGED():
         return float("nan")
 
     def execute(self):
-        input = tensor_cache.inputs.pop()
+        input = tensor_cache.image_inputs.pop()
         return (input,)
diff --git a/nodes/tensor_utils/save_tensor.py b/nodes/tensor_utils/save_tensor.py
index 131604f7..6c5d869a 100644
--- a/nodes/tensor_utils/save_tensor.py
+++ b/nodes/tensor_utils/save_tensor.py
@@ -22,6 +22,6 @@ def IS_CHANGED(s):
         return float("nan")
 
     def execute(self, images: torch.Tensor):
-        fut = tensor_cache.outputs.pop()
+        fut = tensor_cache.image_outputs.pop()
         fut.set_result(images)
         return images
diff --git a/server/app.py b/server/app.py
index ec42d831..86a98b18 100644
--- a/server/app.py
+++ b/server/app.py
@@ -3,8 +3,6 @@
 import os
 import json
 import logging
-import wave
-import numpy as np
 
 from twilio.rest import Client
 from aiohttp import web
@@ -17,7 +15,7 @@
 )
 from aiortc.rtcrtpsender import RTCRtpSender
 from aiortc.codecs import h264
-from pipeline import Pipeline
+from pipeline import VideoPipeline, AudioPipeline
 from utils import patch_loop_datagram
 
 logger = logging.getLogger(__name__)
@@ -39,93 +37,16 @@ async def recv(self):
         return await self.pipeline(frame)
 
 class AudioStreamTrack(MediaStreamTrack):
-    """
-    This custom audio track wraps an incoming audio MediaStreamTrack.
-    It continuously records frames in 10-second chunks and saves each chunk
-    as a separate WAV file with an incrementing index.
-    """
-
     kind = "audio"
 
-    def __init__(self, track: MediaStreamTrack):
+    def __init__(self, track: MediaStreamTrack, pipeline):
         super().__init__()
         self.track = track
-        self.start_time = None
-        self.frames = []
-        self._recording_duration = 10.0 # in seconds
-        self._chunk_index = 0
-        self._saving = False
-        self._lock = asyncio.Lock()
+        self.pipeline = pipeline
 
     async def recv(self):
         frame = await self.track.recv()
-        return frame
-
-    # async def recv(self):
-    #     frame = await self.source.recv()
-
-    #     # On the first frame, record the start time.
-    #     if self.start_time is None:
-    #         self.start_time = frame.time
-    #         logger.info(f"Audio recording started at time: {self.start_time:.3f}")
-
-    #     elapsed = frame.time - self.start_time
-    #     self.frames.append(frame)
-
-    #     logger.info(f"Received audio frame at time: {frame.time:.3f}, total frames: {len(self.frames)}")
-
-    #     # Check if we've hit 10 seconds and we're not currently saving.
-    #     if elapsed >= self._recording_duration and not self._saving:
-    #         logger.info(f"10 second chunk reached (elapsed: {elapsed:.3f}s). Preparing to save chunk {self._chunk_index}.")
-    #         self._saving = True
-    #         # Handle saving in a background task so we don't block the recv loop.
-    #         asyncio.create_task(self.save_audio())
-
-    #     return frame
-
-    async def save_audio(self):
-        logger.info(f"Starting to save audio chunk {self._chunk_index}...")
-        async with self._lock:
-            # Extract properties from the first frame
-            if not self.frames:
-                logger.warning("No frames to save, skipping.")
-                self._saving = False
-                return
-
-            sample_rate = self.frames[0].sample_rate
-            layout = self.frames[0].layout
-            channels = len(layout.channels)
-
-            logger.info(f"Audio chunk {self._chunk_index}: sample_rate={sample_rate}, channels={channels}, frames_count={len(self.frames)}")
-
-            # Convert all frames to ndarray and concatenate
-            data_arrays = [f.to_ndarray() for f in self.frames]
-            data = np.concatenate(data_arrays, axis=1) # shape: (channels, total_samples)
-
-            # Interleave channels (if multiple) since WAV expects interleaved samples.
-            interleaved = data.T.flatten()
-
-            # If needed, convert float frames to int16
-            # interleaved = (interleaved * 32767).astype(np.int16)
-
-            filename = f"output_{self._chunk_index}.wav"
-            logger.info(f"Writing audio chunk {self._chunk_index} to file: {filename}")
-            with wave.open(filename, 'wb') as wf:
-                wf.setnchannels(channels)
-                wf.setsampwidth(2) # 16-bit PCM
-                wf.setframerate(sample_rate)
-                wf.writeframes(interleaved.tobytes())
-
-            logger.info(f"Audio chunk {self._chunk_index} saved successfully as {filename}")
-
-            # Increment the chunk index for the next segment
-            self._chunk_index += 1
-
-            # Reset for next recording chunk
-            self.frames.clear()
-            self.start_time = None
-            self._saving = False
-            logger.info(f"Ready to record next 10-second chunk. Current chunk index: {self._chunk_index}")
+        return await self.pipeline(frame)
 
 
 def force_codec(pc, sender, forced_codec):
@@ -169,13 +90,19 @@ def get_ice_servers():
 
 
 async def offer(request):
-    pipeline = request.app["pipeline"]
+    video_pipeline = request.app["video_pipeline"]
+    audio_pipeline = request.app["audio_pipeline"]
     pcs = request.app["pcs"]
 
     params = await request.json()
-    pipeline.set_prompt(params["prompt"])
-    await pipeline.warm()
+
+    print("VIDEO PROMPT", params["video_prompt"])
+    print("AUDIO PROMPT", params["audio_prompt"])
+
+    video_pipeline.set_prompt(params["video_prompt"])
+    await video_pipeline.warm()
+    audio_pipeline.set_prompt(params["audio_prompt"])
+    await audio_pipeline.warm()
 
     offer_params = params["offer"]
     offer = RTCSessionDescription(sdp=offer_params["sdp"], type=offer_params["type"])
@@ -206,14 +133,14 @@ def on_track(track):
         logger.info(f"Track received: {track.kind}")
         if track.kind == "video":
-            videoTrack = VideoStreamTrack(track, pipeline)
+            videoTrack = VideoStreamTrack(track, video_pipeline)
             tracks["video"] = videoTrack
             sender = pc.addTrack(videoTrack)
 
             codec = "video/H264"
             force_codec(pc, sender, codec)
         elif track.kind == "audio":
-            audioTrack = AudioStreamTrack(track)
+            audioTrack = AudioStreamTrack(track, audio_pipeline)
             tracks["audio"] = audioTrack
             pc.addTrack(audioTrack)
@@ -261,7 +188,10 @@ async def on_startup(app: web.Application):
     if app["media_ports"]:
         patch_loop_datagram(app["media_ports"])
 
-    app["pipeline"] = Pipeline(
+    app["video_pipeline"] = VideoPipeline(
+        cwd=app["workspace"], disable_cuda_malloc=True, gpu_only=True
+    )
+    app["audio_pipeline"] = AudioPipeline(
         cwd=app["workspace"], disable_cuda_malloc=True, gpu_only=True
     )
     app["pcs"] = set()
diff --git a/server/pipeline.py b/server/pipeline.py
index 02a66c27..17d73f67 100644
--- a/server/pipeline.py
+++ b/server/pipeline.py
@@ -2,15 +2,15 @@
 import av
 import numpy as np
-from typing import Any, Dict
+from typing import Any, Dict, Optional, Union
 
 from comfystream.client import ComfyStreamClient
 
 WARMUP_RUNS = 5
 
 
-class Pipeline:
+class VideoPipeline:
     def __init__(self, **kwargs):
-        self.client = ComfyStreamClient(**kwargs)
+        self.client = ComfyStreamClient(**kwargs, type="image")
 
     async def warm(self):
         frame = torch.randn(1, 512, 512, 3)
@@ -42,3 +42,40 @@ async def __call__(self, frame: av.VideoFrame) -> av.VideoFrame:
         post_output.time_base = frame.time_base
 
         return post_output
+
+
+class AudioPipeline:
+    def __init__(self, **kwargs):
+        self.client = ComfyStreamClient(**kwargs, type="audio")
+
+    async def warm(self):
+        dummy_audio = torch.randn(16000)
+        for _ in range(WARMUP_RUNS):
+            await self.predict(dummy_audio)
+
+    def set_prompt(self, prompt: Dict[Any, Any]):
+        self.client.set_prompt(prompt)
+
+    def preprocess(self, frame: av.AudioFrame) -> torch.Tensor:
+        self.sample_rate = frame.sample_rate
+        samples = frame.to_ndarray(format="s16", layout="stereo")
+        samples = samples.astype(np.float32) / 32768.0
+        return torch.from_numpy(samples)
+
+    def postprocess(self, output: torch.Tensor) -> Optional[Union[av.AudioFrame, str]]:
+        out_np = output.cpu().numpy()
+        out_np = np.clip(out_np * 32768.0, -32768, 32767).astype(np.int16)
+        audio_frame = av.AudioFrame.from_ndarray(out_np, format="s16", layout="stereo")
+        return audio_frame
+
+    async def predict(self, frame: torch.Tensor) -> torch.Tensor:
+        return await self.client.queue_prompt(frame)
+
+    async def __call__(self, frame: av.AudioFrame):
+        pre_output = self.preprocess(frame)
+        pred_output = await self.predict(pre_output)
+        post_output = self.postprocess(pred_output)
+        post_output.sample_rate = self.sample_rate
+        post_output.pts = frame.pts
+        post_output.time_base = frame.time_base
+        return post_output
diff --git a/src/comfystream/client.py b/src/comfystream/client.py
index e380b08d..2d1946dc 100644
--- a/src/comfystream/client.py
+++ b/src/comfystream/client.py
@@ -9,20 +9,29 @@
 
 
 class ComfyStreamClient:
-    def __init__(self, **kwargs):
+    def __init__(self, type: str = "image", **kwargs):
         config = Configuration(**kwargs)
         # TODO: Need to handle cleanup for EmbeddedComfyClient if not using async context manager?
         self.comfy_client = EmbeddedComfyClient(config)
         self.prompt = None
+        self.type = type.lower()
+        if self.type not in {"image", "audio"}:
+            raise ValueError(f"Unsupported type: {self.type}. Supported types are 'image' and 'audio'.")
+
+        self.input_cache = getattr(tensor_cache, f"{self.type}_inputs", None)
+        self.output_cache = getattr(tensor_cache, f"{self.type}_outputs", None)
+
+        if self.input_cache is None or self.output_cache is None:
+            raise AttributeError(f"tensor_cache does not have attributes for type '{self.type}'.")
 
     def set_prompt(self, prompt: PromptDictInput):
         self.prompt = convert_prompt(prompt)
 
     async def queue_prompt(self, input: torch.Tensor) -> torch.Tensor:
-        tensor_cache.inputs.append(input)
+        self.input_cache.append(input)
 
         output_fut = asyncio.Future()
-        tensor_cache.outputs.append(output_fut)
+        self.output_cache.append(output_fut)
 
         await self.comfy_client.queue_prompt(self.prompt)
diff --git a/src/comfystream/tensor_cache.py b/src/comfystream/tensor_cache.py
index 44d7c03b..40a58abd 100644
--- a/src/comfystream/tensor_cache.py
+++ b/src/comfystream/tensor_cache.py
@@ -2,5 +2,8 @@
 import torch
 from typing import List
 
-inputs: List[torch.Tensor] = []
-outputs: List[asyncio.Future] = []
+image_inputs: List[torch.Tensor] = []
+image_outputs: List[asyncio.Future] = []
+
+audio_inputs: List[torch.Tensor] = []
+audio_outputs: List[asyncio.Future] = []
diff --git a/src/comfystream/utils.py b/src/comfystream/utils.py
index 4376c403..0948e2b9 100644
--- a/src/comfystream/utils.py
+++ b/src/comfystream/utils.py
@@ -48,7 +48,7 @@ def convert_prompt(prompt: PromptDictInput) -> Prompt:
             num_primary_inputs += 1
         elif class_type in ["LoadImage", "LoadTensor", "LoadAudioTensor"]:
             num_inputs += 1
-        elif class_type in ["PreviewImage", "SaveImage", "SaveTensor", "SaveASRResponse"]:
+        elif class_type in ["PreviewImage", "SaveImage", "SaveTensor", "SaveASRResponse", "SaveAudioTensor"]:
             num_outputs += 1
 
     # Only handle single primary input
diff --git a/ui/src/app/api/offer/route.ts b/ui/src/app/api/offer/route.ts
index b42428af..44a32605 100644
--- a/ui/src/app/api/offer/route.ts
+++ b/ui/src/app/api/offer/route.ts
@@ -1,14 +1,14 @@
 import { NextRequest, NextResponse } from "next/server";
 
 export const POST = async function POST(req: NextRequest) {
-  const { endpoint, prompt, offer } = await req.json();
+  const { endpoint, video_prompt, audio_prompt, offer } = await req.json();
 
   const res = await fetch(endpoint + "/offer", {
     method: "POST",
     headers: {
       "Content-Type": "application/json",
     },
-    body: JSON.stringify({ prompt, offer }),
+    body: JSON.stringify({ video_prompt, audio_prompt, offer }),
   });
 
   return NextResponse.json(await res.json(), { status: res.status });
diff --git a/ui/src/components/peer.tsx b/ui/src/components/peer.tsx
index e3708f8d..7243e978 100644
--- a/ui/src/components/peer.tsx
+++ b/ui/src/components/peer.tsx
@@ -4,7 +4,8 @@ import { PeerContext } from "@/context/peer-context";
 
 export interface PeerProps extends React.HTMLAttributes {
   url: string;
-  prompt: any;
+  videoPrompt: any;
+  audioPrompt: any;
   connect: boolean;
   onConnected: () => void;
   onDisconnected: () => void;
diff --git a/ui/src/components/room.tsx b/ui/src/components/room.tsx
index ac0da782..8cc58e8e 100644
--- a/ui/src/components/room.tsx
+++ b/ui/src/components/room.tsx
@@ -79,7 +79,8 @@ export function Room() {
     frameRate: 0,
     selectedDeviceId: "",
     selectedAudioDeviceId: "", // New property for audio device
-    prompt: null,
+    videoPrompt: null,
+    audioPrompt: null,
   });
 
   const connectingRef = useRef(false);
@@ -135,7 +136,8 @@ export function Room() {
(null);
+  const [videoPrompt, setVideoPrompt] = useState(null);
+  const [audioPrompt, setAudioPrompt] = useState(null);
   const [videoDevices, setVideoDevices] = useState([]);
   const [audioDevices, setAudioDevices] = useState([]);
   const [selectedDevice, setSelectedDevice] = useState("");
@@ -185,19 +187,32 @@ function ConfigForm({ config, onSubmit }: ConfigFormProps) {
       streamUrl: values.streamUrl
         ? values.streamUrl.replace(/\/+$/, "")
         : values.streamUrl,
-      prompt,
+      videoPrompt: videoPrompt,
+      audioPrompt: audioPrompt,
       selectedDeviceId: selectedDevice,
       selectedAudioDeviceId: selectedAudioDevice,
     });
   };
 
-  const handlePromptChange = async (e: any) => {
+  const handleVideoPromptChange = async (e: any) => {
     const file = e.target.files[0];
     if (!file) return;
 
     try {
       const text = await file.text();
-      setPrompt(JSON.parse(text));
+      setVideoPrompt(JSON.parse(text));
     } catch (err) {
       console.error(err);
     }
   };
+
+  const handleAudioPromptChange = async (e: any) => {
+    const file = e.target.files[0];
+    if (!file) return;
+
+    try {
+      const text = await file.text();
+      setAudioPrompt(JSON.parse(text));
     } catch (err) {
       console.error(err);
     }
   };
@@ -269,12 +284,22 @@ function ConfigForm({ config, onSubmit }: ConfigFormProps) {
-
+
+
+
+
+
+
diff --git a/ui/src/components/webcam.tsx b/ui/src/components/webcam.tsx
index 66960039..cbe03280 100644
--- a/ui/src/components/webcam.tsx
+++ b/ui/src/components/webcam.tsx
@@ -149,9 +149,9 @@ export function Webcam({ onStreamReady, deviceId, frameRate, selectedAudioDevice
       },
       audio: {
         ...(selectedAudioDeviceId ? { deviceId: { exact: selectedAudioDeviceId } } : {}),
-        sampleRate: 16000,
-        sampleSize: 16,
-        channelCount: 1,
+        sampleRate: { ideal: 16000 },
+        sampleSize: { ideal: 16 },
+        channelCount: { exact: 1 },
       },
     });
     return newStream;
diff --git a/ui/src/hooks/use-peer.ts b/ui/src/hooks/use-peer.ts
index 138c3055..a64ab121 100644
--- a/ui/src/hooks/use-peer.ts
+++ b/ui/src/hooks/use-peer.ts
@@ -6,7 +6,7 @@ const MAX_OFFER_RETRIES = 5;
 const OFFER_RETRY_INTERVAL = 500;
 
 export function usePeer(props: PeerProps): Peer {
-  const { url, prompt, connect, onConnected, onDisconnected, localStream } =
+  const { url, videoPrompt, audioPrompt, connect, onConnected, onDisconnected, localStream } =
     props;
 
   const [peerConnection, setPeerConnection] =
@@ -33,7 +33,8 @@ export function usePeer(props: PeerProps): Peer {
       },
       body: JSON.stringify({
         endpoint: url,
-        prompt,
+        video_prompt: videoPrompt,
+        audio_prompt: audioPrompt,
         offer,
       }),
     });
diff --git a/workflows/audio-tensor-utils-example-workflow.json b/workflows/audio-tensor-utils-example-workflow.json
new file mode 100644
index 00000000..db2873d1
--- /dev/null
+++ b/workflows/audio-tensor-utils-example-workflow.json
@@ -0,0 +1,21 @@
+{
+  "1": {
+    "inputs": {},
+    "class_type": "LoadAudioTensor",
+    "_meta": {
+      "title": "Load Audio Tensor"
+    }
+  },
+  "2": {
+    "inputs": {
+      "audio": [
+        "1",
+        0
+      ]
+    },
+    "class_type": "SaveAudioTensor",
+    "_meta": {
+      "title": "Save Audio Tensor"
+    }
+  }
+}
\ No newline at end of file