From 415c387649671c4eecf85784e3f323d194a07fac Mon Sep 17 00:00:00 2001
From: Varshith B
Date: Fri, 24 Jan 2025 23:09:21 +0530
Subject: [PATCH] feat: combine audio and video streams

---
 nodes/audio_utils/load_audio_tensor.py      |  26 +++-
 nodes/audio_utils/save_audio_tensor.py      |   9 +-
 nodes/audio_utils/save_result.py            |  24 ----
 nodes/tensor_utils/load_tensor.py           |   2 +-
 nodes/tensor_utils/save_tensor.py           |   2 +-
 server/app.py                               |  45 ++----
 server/pipeline.py                          | 148 ++++++++++----------
 src/comfystream/client.py                   |  34 ++---
 src/comfystream/tensor_cache.py             |  10 +-
 workflows/audio-video-tensor-utils.json     |  46 ++++++
 workflows/audio-voice-changer-workflow.json |  33 +++++
 workflows/voice-changer-and-whisper.json    |  49 +++++++
 12 files changed, 260 insertions(+), 168 deletions(-)
 delete mode 100644 nodes/audio_utils/save_result.py
 create mode 100644 workflows/audio-video-tensor-utils.json
 create mode 100644 workflows/audio-voice-changer-workflow.json
 create mode 100644 workflows/voice-changer-and-whisper.json

diff --git a/nodes/audio_utils/load_audio_tensor.py b/nodes/audio_utils/load_audio_tensor.py
index 97b92b35..62cce3a6 100644
--- a/nodes/audio_utils/load_audio_tensor.py
+++ b/nodes/audio_utils/load_audio_tensor.py
@@ -1,3 +1,5 @@
+import numpy as np
+
 from comfystream import tensor_cache
 
 class LoadAudioTensor:
@@ -5,14 +7,30 @@ class LoadAudioTensor:
     RETURN_TYPES = ("AUDIO",)
     FUNCTION = "execute"
 
+    def __init__(self):
+        self.audio_buffer = np.array([])
+
     @classmethod
     def INPUT_TYPES(s):
-        return {}
+        return {
+            "required": {
+                "buffer_size": ("FLOAT", {"default": 500.0})
+            }
+        }
 
     @classmethod
     def IS_CHANGED():
         return float("nan")
 
-    def execute(self):
-        audio = tensor_cache.audio_inputs.pop()
-        return (audio,)
\ No newline at end of file
+    def execute(self, buffer_size):
+        audio = tensor_cache.audio_inputs.get(block=True)
+        self.audio_buffer = np.concatenate((self.audio_buffer, audio))
+
+        buffer_size_samples = int(buffer_size * 48)
+
+        if self.audio_buffer.size >= buffer_size_samples:
+            buffered_audio = self.audio_buffer[:buffer_size_samples]
+            self.audio_buffer = self.audio_buffer[buffer_size_samples:]
+            return (buffered_audio,)
+        else:
+            return (None,)
\ No newline at end of file
diff --git a/nodes/audio_utils/save_audio_tensor.py b/nodes/audio_utils/save_audio_tensor.py
index 36eec545..a2f54dcc 100644
--- a/nodes/audio_utils/save_audio_tensor.py
+++ b/nodes/audio_utils/save_audio_tensor.py
@@ -12,7 +12,6 @@ def INPUT_TYPES(s):
         return {
             "required": {
                 "audio": ("AUDIO",),
-                "text": ("TEXT",)
             }
         }
 
@@ -20,7 +19,7 @@ def IS_CHANGED(s):
         return float("nan")
 
-    def execute(self, audio, text):
-        fut = tensor_cache.audio_outputs.pop()
-        fut.set_result((audio, text))
-        return (audio, text)
+    def execute(self, audio):
+        fut = tensor_cache.audio_outputs.get()
+        fut.set_result((audio))
+        return (audio,)
 
diff --git a/nodes/audio_utils/save_result.py b/nodes/audio_utils/save_result.py
deleted file mode 100644
index 4b8ef892..00000000
--- a/nodes/audio_utils/save_result.py
+++ /dev/null
@@ -1,24 +0,0 @@
-from comfystream import tensor_cache
-
-class SaveResult:
-    CATEGORY = "audio_utils"
-    RETURN_TYPES = ()
-    FUNCTION = "execute"
-    OUTPUT_NODE = True
-
-    @classmethod
-    def INPUT_TYPES(s):
-        return {
-            "required": {
-                "result": ("RESULT",),
-            }
-        }
-
-    @classmethod
-    def IS_CHANGED(s):
-        return float("nan")
-
-    def execute(self, result):
-        fut = tensor_cache.audio_outputs.pop()
-        fut.set_result(result)
-        return result
\ No newline at end of file
diff --git a/nodes/tensor_utils/load_tensor.py b/nodes/tensor_utils/load_tensor.py
index 117078ea..6c13a8a7 100644
--- a/nodes/tensor_utils/load_tensor.py
+++ b/nodes/tensor_utils/load_tensor.py
@@ -15,5 +15,5 @@ def IS_CHANGED():
         return float("nan")
 
     def execute(self):
-        input = tensor_cache.image_inputs.pop()
+        input = tensor_cache.image_inputs.get(block=True)
         return (input,)
diff --git a/nodes/tensor_utils/save_tensor.py b/nodes/tensor_utils/save_tensor.py
index 6c5d869a..98d6cfff 100644
--- a/nodes/tensor_utils/save_tensor.py
+++ b/nodes/tensor_utils/save_tensor.py
@@ -22,6 +22,6 @@ def IS_CHANGED(s):
         return float("nan")
 
     def execute(self, images: torch.Tensor):
-        fut = tensor_cache.image_outputs.pop()
+        fut = tensor_cache.image_outputs.get()
         fut.set_result(images)
         return images
diff --git a/server/app.py b/server/app.py
index 8c884b43..3e320ef0 100644
--- a/server/app.py
+++ b/server/app.py
@@ -15,7 +15,7 @@
 )
 from aiortc.rtcrtpsender import RTCRtpSender
 from aiortc.codecs import h264
-from pipeline import VideoPipeline, AudioPipeline
+from pipeline import Pipeline
 from utils import patch_loop_datagram
 
 logger = logging.getLogger(__name__)
@@ -30,17 +30,15 @@ def __init__(self, track: MediaStreamTrack, pipeline):
         super().__init__()
         self.track = track
         self.pipeline = pipeline
-        self.processed_frames = asyncio.Queue()
         asyncio.create_task(self.collect_frames())
 
     async def collect_frames(self):
         while True:
             frame = await self.track.recv()
-            processed = await self.pipeline(frame)
-            await self.processed_frames.put(processed)
+            await self.pipeline.put_video_frame(frame)
 
     async def recv(self):
-        return await self.processed_frames.get()
+        return await self.pipeline.get_processed_video_frame()
 
 
 class AudioStreamTrack(MediaStreamTrack):
@@ -49,28 +47,15 @@ def __init__(self, track: MediaStreamTrack, pipeline):
         super().__init__()
         self.track = track
         self.pipeline = pipeline
-        self.incoming_frames = asyncio.Queue()
-        self.processed_frames = asyncio.Queue()
         asyncio.create_task(self.collect_frames())
-        asyncio.create_task(self.process_frames())
-        self.started = False
 
     async def collect_frames(self):
         while True:
             frame = await self.track.recv()
-            await self.incoming_frames.put(frame)
-
-    async def process_frames(self):
-        while True:
-            frames = []
-            while len(frames) < 25:
-                frames.append(await self.incoming_frames.get())
-            processed_frames = await self.pipeline(frames)
-            for processed_frame in processed_frames:
-                await self.processed_frames.put(processed_frame)
+            await self.pipeline.put_audio_frame(frame)
 
     async def recv(self):
-        return await self.processed_frames.get()
+        return await self.pipeline.get_processed_audio_frame()
 
 
 def force_codec(pc, sender, forced_codec):
@@ -114,16 +99,13 @@ def get_ice_servers():
 
 
 async def offer(request):
-    video_pipeline = request.app["video_pipeline"]
-    audio_pipeline = request.app["audio_pipeline"]
+    pipeline = request.app["pipeline"]
    pcs = request.app["pcs"]
 
     params = await request.json()
 
-    video_pipeline.set_prompt(params["video_prompt"])
-    await video_pipeline.warm()
-    audio_pipeline.set_prompt(params["audio_prompt"])
-    await audio_pipeline.warm()
+    pipeline.set_prompts(params["video_prompt"])
+    await pipeline.warm()
 
     offer_params = params["offer"]
     offer = RTCSessionDescription(sdp=offer_params["sdp"], type=offer_params["type"])
@@ -154,14 +136,14 @@ def on_track(track):
         logger.info(f"Track received: {track.kind}")
         if track.kind == "video":
-            videoTrack = VideoStreamTrack(track, video_pipeline)
+            videoTrack = VideoStreamTrack(track, pipeline)
             tracks["video"] = videoTrack
             sender = pc.addTrack(videoTrack)
             codec = "video/H264"
             force_codec(pc, sender, codec)
         elif track.kind == "audio":
-            audioTrack = AudioStreamTrack(track, audio_pipeline)
+            audioTrack = AudioStreamTrack(track, pipeline)
             tracks["audio"] = audioTrack
             pc.addTrack(audioTrack)
 
@@ -196,7 +178,7 @@ async def set_prompt(request):
     pipeline = request.app["pipeline"]
 
     prompt = await request.json()
-    pipeline.set_prompt(prompt)
+    pipeline.set_prompts(prompt)
 
     return web.Response(content_type="application/json", text="OK")
 
@@ -209,10 +191,7 @@ async def on_startup(app: web.Application):
     if app["media_ports"]:
         patch_loop_datagram(app["media_ports"])
 
-    app["video_pipeline"] = VideoPipeline(
-        cwd=app["workspace"], disable_cuda_malloc=True, gpu_only=True
-    )
-    app["audio_pipeline"] = AudioPipeline(
+    app["pipeline"] = Pipeline(
         cwd=app["workspace"], disable_cuda_malloc=True, gpu_only=True
     )
     app["pcs"] = set()
diff --git a/server/pipeline.py b/server/pipeline.py
index acbcef1c..3fb6bbfe 100644
--- a/server/pipeline.py
+++ b/server/pipeline.py
@@ -2,85 +2,84 @@
 import av
 import numpy as np
 import fractions
+import asyncio
 
 from av import AudioFrame
 from typing import Any, Dict, Optional, Union, List
 from comfystream.client import ComfyStreamClient
+from comfystream import tensor_cache
 
 WARMUP_RUNS = 5
 
-import logging
-display_logger = logging.getLogger('display_logger')
-display_logger.setLevel(logging.INFO)
-handler = logging.FileHandler('display_logs.txt')
-formatter = logging.Formatter('%(message)s')
-handler.setFormatter(formatter)
-display_logger.addHandler(handler)
 
-class VideoPipeline:
+class Pipeline:
     def __init__(self, **kwargs):
-        self.client = ComfyStreamClient(**kwargs, type="image")
-
-    async def warm(self):
-        frame = torch.randn(1, 512, 512, 3)
-
-        for _ in range(WARMUP_RUNS):
-            await self.predict(frame)
-
-    def set_prompt(self, prompt: Dict[Any, Any]):
-        self.client.set_prompt(prompt)
-
-    def preprocess(self, frame: av.VideoFrame) -> torch.Tensor:
-        frame_np = frame.to_ndarray(format="rgb24").astype(np.float32) / 255.0
-        return torch.from_numpy(frame_np).unsqueeze(0)
-
-    async def predict(self, frame: torch.Tensor) -> torch.Tensor:
-        return await self.client.queue_prompt(frame)
-
-    def postprocess(self, frame: torch.Tensor) -> av.VideoFrame:
-        return av.VideoFrame.from_ndarray(
-            (frame * 255.0).clamp(0, 255).to(dtype=torch.uint8).squeeze(0).cpu().numpy()
-        )
+        self.client = ComfyStreamClient(**kwargs, max_workers=2)
 
-    async def __call__(self, frame: av.VideoFrame) -> av.VideoFrame:
-        pre_output = self.preprocess(frame)
-        pred_output = await self.predict(pre_output)
-        post_output = self.postprocess(pred_output)
+        self.video_futures = asyncio.Queue()
+        self.audio_futures = asyncio.Queue()
 
-        post_output.pts = frame.pts
-        post_output.time_base = frame.time_base
-
-        return post_output
-
-
-class AudioPipeline:
-    def __init__(self, **kwargs):
-        self.client = ComfyStreamClient(**kwargs, type="audio")
-        self.resampler = av.audio.resampler.AudioResampler(format='s16', layout='mono', rate=48000)
+        self.audio_output_frames = []
+
+        self.resampler = av.audio.resampler.AudioResampler(format='s16', layout='mono', rate=48000) # find a better way to convert to mono
 
         self.sample_rate = 48000
         self.frame_size = int(self.sample_rate * 0.02)
         self.time_base = fractions.Fraction(1, self.sample_rate)
         self.curr_pts = 0
 
     async def warm(self):
-        dummy_audio = np.random.randint(-32768, 32767, 48000 * 1, dtype=np.int16)
-        for _ in range(WARMUP_RUNS):
-            await self.predict(dummy_audio)
-
-    def set_prompt(self, prompt: Dict[Any, Any]):
-        self.client.set_prompt(prompt)
+        dummy_video_frame = torch.randn(1, 512, 512, 3)
+        dummy_audio_frame = np.random.randint(-32768, 32767, 48000 * 1, dtype=np.int16)
 
-    def preprocess(self, frames: List[av.AudioFrame]) -> torch.Tensor:
-        audio_arrays = []
-        for frame in frames:
-            audio_arrays.append(self.resampler.resample(frame)[0].to_ndarray())
-        return np.concatenate(audio_arrays, axis=1).flatten()
+        for _ in range(WARMUP_RUNS):
+            image_out_fut = asyncio.Future()
+            audio_out_fut = asyncio.Future()
+            tensor_cache.image_outputs.put(image_out_fut)
+            tensor_cache.audio_outputs.put(audio_out_fut)
+
+            tensor_cache.image_inputs.put(dummy_video_frame)
+            tensor_cache.audio_inputs.put(dummy_audio_frame)
+
+            await image_out_fut
+            await audio_out_fut
+
+    def set_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]):
+        if isinstance(prompts, dict):
+            self.client.set_prompts([prompts])
+        else:
+            self.client.set_prompts(prompts)
+
+    async def put_video_frame(self, frame: av.VideoFrame):
+        inp_tensor = self.video_preprocess(frame)
+        out_future = asyncio.Future()
+        tensor_cache.image_outputs.put(out_future)
+        tensor_cache.image_inputs.put(inp_tensor)
+        await self.video_futures.put((out_future, frame.pts, frame.time_base))
+
+    async def put_audio_frame(self, frame: av.AudioFrame):
+        inp_tensor = self.audio_preprocess(frame)
+        out_future = asyncio.Future()
+        tensor_cache.audio_outputs.put(out_future)
+        tensor_cache.audio_inputs.put(inp_tensor)
+        await self.audio_futures.put(out_future)
+
+    def video_preprocess(self, frame: av.VideoFrame) -> torch.Tensor:
+        frame_np = frame.to_ndarray(format="rgb24").astype(np.float32) / 255.0
+        return torch.from_numpy(frame_np).unsqueeze(0)
+
+    def audio_preprocess(self, frame: av.AudioFrame) -> torch.Tensor:
+        return self.resampler.resample(frame)[0].to_ndarray().flatten()
+
+    def video_postprocess(self, output: torch.Tensor) -> av.VideoFrame:
+        return av.VideoFrame.from_ndarray(
+            (output * 255.0).clamp(0, 255).to(dtype=torch.uint8).squeeze(0).cpu().numpy()
+        )
 
-    def postprocess(self, out_np) -> Optional[Union[av.AudioFrame, str]]:
+    def audio_postprocess(self, output: torch.Tensor) -> av.AudioFrame:
         frames = []
-        for idx in range(0, len(out_np), self.frame_size):
-            frame_samples = out_np[idx:idx + self.frame_size]
-            frame_samples = frame_samples.reshape(1, -1)
+        print("OUTPUT SHAPE", output.shape)
+        for idx in range(0, len(output), self.frame_size):
+            frame_samples = output[idx:idx + self.frame_size]
+            frame_samples = frame_samples.reshape(1, -1).astype(np.int16)
             frame = AudioFrame.from_ndarray(frame_samples, layout="mono")
             frame.sample_rate = self.sample_rate
             frame.pts = self.curr_pts
@@ -89,14 +88,21 @@ def postprocess(self, out_np) -> Optional[Union[av.AudioFrame, str]]:
             frames.append(frame)
 
         return frames
-
-    async def predict(self, frame) -> torch.Tensor:
-        return await self.client.queue_prompt(frame)
-
-    async def __call__(self, frames: List[av.AudioFrame]):
-        pre_audio = self.preprocess(frames)
-        pred_audio, text = await self.predict(pre_audio)
-        if text[-1] != "":
-            display_logger.info(f"{text[0]} {text[1]} {text[2]}")
-        pred_audios = self.postprocess(pred_audio)
-        return pred_audios
\ No newline at end of file
+
+    async def get_processed_video_frame(self):
+        out_fut, pts, time_base = await self.video_futures.get()
+        frame = self.video_postprocess(await out_fut)
+        frame.pts = pts
+        frame.time_base = time_base
+        return frame
+
+
+    async def get_processed_audio_frame(self):
+        while not self.audio_output_frames:
+            out_fut = await self.audio_futures.get()
+            output = await out_fut
+            if output is None:
+                print("No Audio output")
+                continue
+            self.audio_output_frames.extend(self.audio_postprocess(output))
+        return self.audio_output_frames.pop(0)
\ No newline at end of file
diff --git a/src/comfystream/client.py b/src/comfystream/client.py
index 2d1946dc..809eeab9 100644
--- a/src/comfystream/client.py
+++ b/src/comfystream/client.py
@@ -1,38 +1,24 @@
 import torch
 import asyncio
+from typing import List
 
 from comfy.api.components.schema.prompt import PromptDictInput
 from comfy.cli_args_types import Configuration
 from comfy.client.embedded_comfy_client import EmbeddedComfyClient
-from comfystream import tensor_cache
 from comfystream.utils import convert_prompt
 
 
 class ComfyStreamClient:
-    def __init__(self, type: str = "image", **kwargs):
+    def __init__(self, max_workers: int = 1, **kwargs):
         config = Configuration(**kwargs)
         # TODO: Need to handle cleanup for EmbeddedComfyClient if not using async context manager?
-        self.comfy_client = EmbeddedComfyClient(config)
-        self.prompt = None
-        self.type = type.lower()
-        if self.type not in {"image", "audio"}:
-            raise ValueError(f"Unsupported type: {self.type}. Supported types are 'image' and 'audio'.")
-
-        self.input_cache = getattr(tensor_cache, f"{self.type}_inputs", None)
-        self.output_cache = getattr(tensor_cache, f"{self.type}_outputs", None)
-
-        if self.input_cache is None or self.output_cache is None:
-            raise AttributeError(f"tensor_cache does not have attributes for type '{self.type}'.")
+        self.comfy_client = EmbeddedComfyClient(config, max_workers=max_workers)
 
-    def set_prompt(self, prompt: PromptDictInput):
-        self.prompt = convert_prompt(prompt)
+    def set_prompts(self, prompts: List[PromptDictInput]):
+        for prompt in [convert_prompt(prompt) for prompt in prompts]:
+            asyncio.create_task(self.run_prompt(prompt))
 
-    async def queue_prompt(self, input: torch.Tensor) -> torch.Tensor:
-        self.input_cache.append(input)
-
-        output_fut = asyncio.Future()
-        self.output_cache.append(output_fut)
-
-        await self.comfy_client.queue_prompt(self.prompt)
-
-        return await output_fut
+    async def run_prompt(self, prompt: PromptDictInput):
+        # TODO: Handle workflow changing mid stream
+        while True:
+            await self.comfy_client.queue_prompt(prompt)
diff --git a/src/comfystream/tensor_cache.py b/src/comfystream/tensor_cache.py
index 40a58abd..5d8bc464 100644
--- a/src/comfystream/tensor_cache.py
+++ b/src/comfystream/tensor_cache.py
@@ -1,9 +1,9 @@
 import asyncio
 import torch
-from typing import List
+from queue import Queue
 
-image_inputs: List[torch.Tensor] = []
-image_outputs: List[asyncio.Future] = []
+image_inputs: Queue[torch.Tensor] = Queue()
+image_outputs: Queue[asyncio.Future] = Queue()
 
-audio_inputs: List[torch.Tensor] = []
-audio_outputs: List[asyncio.Future] = []
+audio_inputs: Queue[torch.Tensor] = Queue()
+audio_outputs: Queue[asyncio.Future] = Queue()
diff --git a/workflows/audio-video-tensor-utils.json b/workflows/audio-video-tensor-utils.json
new file mode 100644
index 00000000..8630324a
--- /dev/null
+++ b/workflows/audio-video-tensor-utils.json
@@ -0,0 +1,46 @@
+[
+  {
+    "1": {
+      "inputs": {
+        "images": [
+          "2",
+          0
+        ]
+      },
+      "class_type": "SaveTensor",
+      "_meta": {
+        "title": "SaveTensor"
+      }
+    },
+    "2": {
+      "inputs": {},
+      "class_type": "LoadTensor",
+      "_meta": {
+        "title": "LoadTensor"
+      }
+    }
+  },
+  {
+    "1": {
+      "inputs": {
+        "buffer_size": 500.0
+      },
+      "class_type": "LoadAudioTensor",
+      "_meta": {
+        "title": "Load Audio Tensor"
+      }
+    },
+    "2": {
+      "inputs": {
+        "audio": [
+          "1",
+          0
+        ]
+      },
"class_type": "SaveAudioTensor", + "_meta": { + "title": "Save Audio Tensor" + } + } + } +] \ No newline at end of file diff --git a/workflows/audio-voice-changer-workflow.json b/workflows/audio-voice-changer-workflow.json new file mode 100644 index 00000000..a0a7a75b --- /dev/null +++ b/workflows/audio-voice-changer-workflow.json @@ -0,0 +1,33 @@ +{ + "1": { + "inputs": {}, + "class_type": "LoadAudioTensor", + "_meta": { + "title": "Load Audio Tensor" + } + }, + "2": { + "inputs": { + "audio": [ + "1", + 0 + ] + }, + "class_type": "VoiceChanger", + "_meta": { + "title": "Voice Changer" + } + }, + "3": { + "inputs": { + "audio": [ + "2", + 0 + ] + }, + "class_type": "SaveAudioTensor", + "_meta": { + "title": "Save Audio Tensor" + } + } + } \ No newline at end of file diff --git a/workflows/voice-changer-and-whisper.json b/workflows/voice-changer-and-whisper.json new file mode 100644 index 00000000..f9945840 --- /dev/null +++ b/workflows/voice-changer-and-whisper.json @@ -0,0 +1,49 @@ +{ + "1": { + "inputs": {}, + "class_type": "LoadAudioTensor", + "_meta": { + "title": "Load Audio Tensor" + } + }, + "2": { + "inputs": { + "audio": [ + "1", + 0 + ] + }, + "class_type": "ApplyWhisper", + "_meta": { + "title": "Apply Whisper" + } + }, + "3": { + "inputs": { + "audio": [ + "1", + 0 + ] + }, + "class_type": "VoiceChanger", + "_meta": { + "title": "Voice Changer" + } + }, + "4": { + "inputs": { + "audio": [ + "3", + 0 + ], + "text": [ + "2", + 0 + ] + }, + "class_type": "SaveAudioTensor", + "_meta": { + "title": "Save Audio Tensor" + } + } +} \ No newline at end of file