feat: audio pipeline
varshith15 committed Dec 27, 2024
1 parent 960aebe commit 29f6bb7
Showing 19 changed files with 179 additions and 123 deletions.
3 changes: 2 additions & 1 deletion nodes/audio_utils/__init__.py
@@ -1,7 +1,8 @@
from .apply_whisper import ApplyWhisper
from .load_audio_tensor import LoadAudioTensor
from .save_asr_response import SaveASRResponse
+from .save_audio_tensor import SaveAudioTensor

-NODE_CLASS_MAPPINGS = {"LoadAudioTensor": LoadAudioTensor, "SaveASRResponse": SaveASRResponse, "ApplyWhisper": ApplyWhisper}
+NODE_CLASS_MAPPINGS = {"LoadAudioTensor": LoadAudioTensor, "SaveASRResponse": SaveASRResponse, "ApplyWhisper": ApplyWhisper, "SaveAudioTensor": SaveAudioTensor}

__all__ = ["NODE_CLASS_MAPPINGS"]
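
With the node registered, an audio workflow is just another ComfyUI prompt graph. A minimal sketch of such a graph, not part of this commit — the node ids and input wiring are illustrative assumptions:

from comfystream.utils import convert_prompt

# Hypothetical ASR prompt: stream audio in, transcribe, return the result.
audio_prompt = {
    "1": {"class_type": "LoadAudioTensor", "inputs": {}},
    "2": {"class_type": "ApplyWhisper", "inputs": {"audio": ["1", 0], "model": "base"}},
    "3": {"class_type": "SaveASRResponse", "inputs": {"data": ["2", 0]}},
}
prompt = convert_prompt(audio_prompt)  # validate/normalize the graph as the server does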
5 changes: 3 additions & 2 deletions nodes/audio_utils/apply_whisper.py
@@ -21,17 +21,18 @@ def __init__(self):
# TODO: get these as params
self.sample_rate = 16000
self.min_duration = 1.0
+self.device = "cuda" if torch.cuda.is_available() else "cpu"

def apply_whisper(self, audio, model):
if self.model is None:
-self.model = whisper.load_model(model).cuda()
+self.model = whisper.load_model(model).to(self.device)

self.audio_buffer.append(audio)
total_duration = sum(chunk.shape[0] / self.sample_rate for chunk in self.audio_buffer)
if total_duration < self.min_duration:
return {"text": "", "segments_alignment": [], "words_alignment": []}

-concatenated_audio = torch.cat(self.audio_buffer, dim=0).cuda()
+concatenated_audio = torch.cat(self.audio_buffer, dim=0).to(self.device)
self.audio_buffer = []
result = self.model.transcribe(concatenated_audio.float(), fp16=True, word_timestamps=True)
segments = result["segments"]
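
The node buffers incoming chunks until at least min_duration seconds of audio have accumulated, then transcribes the concatenation and resets. The same pattern in isolation, as a rough sketch (constants mirror the node's hard-coded defaults):

import torch

SAMPLE_RATE = 16000  # Hz, the node's default
MIN_DURATION = 1.0   # seconds to accumulate before transcribing

buffer = []

def feed(chunk: torch.Tensor):
    # Collect 1-D audio chunks; return the concatenation once enough
    # audio is buffered, otherwise None to signal "keep waiting".
    buffer.append(chunk)
    buffered_secs = sum(c.shape[0] for c in buffer) / SAMPLE_RATE
    if buffered_secs < MIN_DURATION:
        return None
    out = torch.cat(buffer, dim=0)
    buffer.clear()
    return out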
2 changes: 1 addition & 1 deletion nodes/audio_utils/load_audio_tensor.py
@@ -14,5 +14,5 @@ def IS_CHANGED():
return float("nan")

def execute(self):
-audio = tensor_cache.inputs.pop()
+audio = tensor_cache.audio_inputs.pop()
return (audio,)
2 changes: 1 addition & 1 deletion nodes/audio_utils/save_asr_response.py
@@ -19,6 +19,6 @@ def IS_CHANGED(s):
return float("nan")

def execute(self, data: dict):
-fut = tensor_cache.outputs.pop()
+fut = tensor_cache.audio_outputs.pop()
fut.set_result(data)
return data
25 changes: 25 additions & 0 deletions nodes/audio_utils/save_audio_tensor.py
@@ -0,0 +1,25 @@
from comfystream import tensor_cache


class SaveAudioTensor:
CATEGORY = "audio_utils"
RETURN_TYPES = ()
FUNCTION = "execute"
OUTPUT_NODE = True

@classmethod
def INPUT_TYPES(s):
return {
"required": {
"audio": ("AUDIO",),
}
}

@classmethod
def IS_CHANGED(s):
return float("nan")

def execute(self, audio):
fut = tensor_cache.audio_outputs.pop()
fut.set_result(audio)
return audio
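
SaveAudioTensor is the producer half of a future-based handshake: ComfyStreamClient pushes an input tensor and a pending future into tensor_cache, queues the prompt, and this node resolves the future with the graph's output. A stripped-down sketch of that round trip, with run_graph standing in for ComfyUI executing the prompt:

import asyncio

audio_inputs = []   # filled by the client, drained by LoadAudioTensor
audio_outputs = []  # futures awaited by the client, resolved by SaveAudioTensor

def run_graph():
    # Stand-in for prompt execution: load the input and resolve the
    # client's future with the (here, unmodified) result.
    audio = audio_inputs.pop()
    audio_outputs.pop().set_result(audio)

async def queue_prompt(tensor):
    audio_inputs.append(tensor)
    fut = asyncio.Future()
    audio_outputs.append(fut)
    run_graph()  # comfystream instead awaits comfy_client.queue_prompt(...)
    return await fut

print(asyncio.run(queue_prompt("pcm chunk")))  # -> pcm chunk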
2 changes: 1 addition & 1 deletion nodes/tensor_utils/load_tensor.py
@@ -15,5 +15,5 @@ def IS_CHANGED():
return float("nan")

def execute(self):
-input = tensor_cache.inputs.pop()
+input = tensor_cache.image_inputs.pop()
return (input,)
2 changes: 1 addition & 1 deletion nodes/tensor_utils/save_tensor.py
@@ -22,6 +22,6 @@ def IS_CHANGED(s):
return float("nan")

def execute(self, images: torch.Tensor):
-fut = tensor_cache.outputs.pop()
+fut = tensor_cache.image_outputs.pop()
fut.set_result(images)
return images
108 changes: 19 additions & 89 deletions server/app.py
@@ -3,8 +3,6 @@
import os
import json
import logging
-import wave
-import numpy as np

from twilio.rest import Client
from aiohttp import web
@@ -17,7 +15,7 @@
)
from aiortc.rtcrtpsender import RTCRtpSender
from aiortc.codecs import h264
-from pipeline import Pipeline
+from pipeline import VideoPipeline, AudioPipeline
from utils import patch_loop_datagram

logger = logging.getLogger(__name__)
Expand All @@ -39,93 +37,16 @@ async def recv(self):
return await self.pipeline(frame)

class AudioStreamTrack(MediaStreamTrack):
"""
This custom audio track wraps an incoming audio MediaStreamTrack.
It continuously records frames in 10-second chunks and saves each chunk
as a separate WAV file with an incrementing index.
"""

kind = "audio"

-def __init__(self, track: MediaStreamTrack):
+def __init__(self, track: MediaStreamTrack, pipeline):
super().__init__()
self.track = track
-self.start_time = None
-self.frames = []
-self._recording_duration = 10.0 # in seconds
-self._chunk_index = 0
-self._saving = False
-self._lock = asyncio.Lock()
+self.pipeline = pipeline

async def recv(self):
frame = await self.track.recv()
-return frame

-# async def recv(self):
-# frame = await self.source.recv()

-# # On the first frame, record the start time.
-# if self.start_time is None:
-# self.start_time = frame.time
-# logger.info(f"Audio recording started at time: {self.start_time:.3f}")

-# elapsed = frame.time - self.start_time
-# self.frames.append(frame)

-# logger.info(f"Received audio frame at time: {frame.time:.3f}, total frames: {len(self.frames)}")

-# # Check if we've hit 10 seconds and we're not currently saving.
-# if elapsed >= self._recording_duration and not self._saving:
-# logger.info(f"10 second chunk reached (elapsed: {elapsed:.3f}s). Preparing to save chunk {self._chunk_index}.")
-# self._saving = True
-# # Handle saving in a background task so we don't block the recv loop.
-# asyncio.create_task(self.save_audio())

-# return frame

-async def save_audio(self):
-logger.info(f"Starting to save audio chunk {self._chunk_index}...")
-async with self._lock:
-# Extract properties from the first frame
-if not self.frames:
-logger.warning("No frames to save, skipping.")
-self._saving = False
-return

-sample_rate = self.frames[0].sample_rate
-layout = self.frames[0].layout
-channels = len(layout.channels)

-logger.info(f"Audio chunk {self._chunk_index}: sample_rate={sample_rate}, channels={channels}, frames_count={len(self.frames)}")

-# Convert all frames to ndarray and concatenate
-data_arrays = [f.to_ndarray() for f in self.frames]
-data = np.concatenate(data_arrays, axis=1) # shape: (channels, total_samples)

-# Interleave channels (if multiple) since WAV expects interleaved samples.
-interleaved = data.T.flatten()

-# If needed, convert float frames to int16
-# interleaved = (interleaved * 32767).astype(np.int16)

-filename = f"output_{self._chunk_index}.wav"
-logger.info(f"Writing audio chunk {self._chunk_index} to file: {filename}")
-with wave.open(filename, 'wb') as wf:
-wf.setnchannels(channels)
-wf.setsampwidth(2) # 16-bit PCM
-wf.setframerate(sample_rate)
-wf.writeframes(interleaved.tobytes())

-logger.info(f"Audio chunk {self._chunk_index} saved successfully as {filename}")

-# Increment the chunk index for the next segment
-self._chunk_index += 1

-# Reset for next recording chunk
-self.frames.clear()
-self.start_time = None
-self._saving = False
-logger.info(f"Ready to record next 10-second chunk. Current chunk index: {self._chunk_index}")
+return await self.pipeline(frame)


def force_codec(pc, sender, forced_codec):
@@ -169,13 +90,19 @@ def get_ice_servers():


async def offer(request):
-pipeline = request.app["pipeline"]
+video_pipeline = request.app["video_pipeline"]
+audio_pipeline = request.app["audio_pipeline"]
pcs = request.app["pcs"]

params = await request.json()

-pipeline.set_prompt(params["prompt"])
-await pipeline.warm()
+print("VIDEO PROMPT", params["video_prompt"])
+print("AUDIO PROMPT", params["audio_prompt"])

+video_pipeline.set_prompt(params["video_prompt"])
+await video_pipeline.warm()
+audio_pipeline.set_prompt(params["audio_prompt"])
+await audio_pipeline.warm()

offer_params = params["offer"]
offer = RTCSessionDescription(sdp=offer_params["sdp"], type=offer_params["type"])
@@ -206,14 +133,14 @@ async def offer(request):
def on_track(track):
logger.info(f"Track received: {track.kind}")
if track.kind == "video":
-videoTrack = VideoStreamTrack(track, pipeline)
+videoTrack = VideoStreamTrack(track, video_pipeline)
tracks["video"] = videoTrack
sender = pc.addTrack(videoTrack)

codec = "video/H264"
force_codec(pc, sender, codec)
elif track.kind == "audio":
-audioTrack = AudioStreamTrack(track)
+audioTrack = AudioStreamTrack(track, audio_pipeline)
tracks["audio"] = audioTrack
pc.addTrack(audioTrack)

@@ -261,7 +188,10 @@ async def on_startup(app: web.Application):
if app["media_ports"]:
patch_loop_datagram(app["media_ports"])

app["pipeline"] = Pipeline(
app["video_pipeline"] = VideoPipeline(
cwd=app["workspace"], disable_cuda_malloc=True, gpu_only=True
)
app["audio_pipeline"] = AudioPipeline(
cwd=app["workspace"], disable_cuda_malloc=True, gpu_only=True
)
app["pcs"] = set()
43 changes: 40 additions & 3 deletions server/pipeline.py
@@ -2,15 +2,15 @@
import av
import numpy as np

-from typing import Any, Dict
+from typing import Any, Dict, Optional, Union
from comfystream.client import ComfyStreamClient

WARMUP_RUNS = 5


-class Pipeline:
+class VideoPipeline:
def __init__(self, **kwargs):
-self.client = ComfyStreamClient(**kwargs)
+self.client = ComfyStreamClient(**kwargs, type="image")

async def warm(self):
frame = torch.randn(1, 512, 512, 3)
@@ -42,3 +42,40 @@ async def __call__(self, frame: av.VideoFrame) -> av.VideoFrame:
post_output.time_base = frame.time_base

return post_output


+class AudioPipeline:
+def __init__(self, **kwargs):
+self.client = ComfyStreamClient(**kwargs, type="audio")

+async def warm(self):
+dummy_audio = torch.randn(16000)
+for _ in range(WARMUP_RUNS):
+await self.predict(dummy_audio)

+def set_prompt(self, prompt: Dict[Any, Any]):
+self.client.set_prompt(prompt)

+def preprocess(self, frame: av.AudioFrame) -> torch.Tensor:
+self.sample_rate = frame.sample_rate
+samples = frame.to_ndarray(format="s16", layout="stereo")
+samples = samples.astype(np.float32) / 32768.0
+return torch.from_numpy(samples)

+def postprocess(self, output: torch.Tensor) -> Optional[Union[av.AudioFrame, str]]:
+out_np = output.cpu().numpy()
+out_np = np.clip(out_np * 32768.0, -32768, 32767).astype(np.int16)
+audio_frame = av.AudioFrame.from_ndarray(out_np, format="s16", layout="stereo")
+return audio_frame

+async def predict(self, frame: torch.Tensor) -> torch.Tensor:
+return await self.client.queue_prompt(frame)

+async def __call__(self, frame: av.AudioFrame):
+pre_output = self.preprocess(frame)
+pred_output = await self.predict(pre_output)
+post_output = self.postprocess(pred_output)
+post_output.sample_rate = self.sample_rate
+post_output.pts = frame.pts
+post_output.time_base = frame.time_base
+return post_output
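
preprocess maps interleaved s16 samples to float32 in [-1, 1) and postprocess clips and rescales back, so pass-through audio should survive unchanged. A quick numpy round-trip check of that conversion (the (1, n) shape mimics a packed stereo s16 frame from PyAV):

import numpy as np

s16 = np.array([[0, 16384, -16384, 32767, -32768]], dtype=np.int16)

floats = s16.astype(np.float32) / 32768.0                          # preprocess
back = np.clip(floats * 32768.0, -32768, 32767).astype(np.int16)   # postprocess

assert np.array_equal(s16, back)  # lossless for pass-through audio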
15 changes: 12 additions & 3 deletions src/comfystream/client.py
@@ -9,20 +9,29 @@


class ComfyStreamClient:
-def __init__(self, **kwargs):
+def __init__(self, type: str = "image", **kwargs):
config = Configuration(**kwargs)
# TODO: Need to handle cleanup for EmbeddedComfyClient if not using async context manager?
self.comfy_client = EmbeddedComfyClient(config)
self.prompt = None
+self.type = type.lower()
+if self.type not in {"image", "audio"}:
+raise ValueError(f"Unsupported type: {self.type}. Supported types are 'image' and 'audio'.")

+self.input_cache = getattr(tensor_cache, f"{self.type}_inputs", None)
+self.output_cache = getattr(tensor_cache, f"{self.type}_outputs", None)

+if self.input_cache is None or self.output_cache is None:
+raise AttributeError(f"tensor_cache does not have attributes for type '{self.type}'.")

def set_prompt(self, prompt: PromptDictInput):
self.prompt = convert_prompt(prompt)

async def queue_prompt(self, input: torch.Tensor) -> torch.Tensor:
-tensor_cache.inputs.append(input)
+self.input_cache.append(input)

output_fut = asyncio.Future()
-tensor_cache.outputs.append(output_fut)
+self.output_cache.append(output_fut)

await self.comfy_client.queue_prompt(self.prompt)
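
Since each client instance now binds to one cache pair, a server can run an image client and an audio client side by side without them competing for tensors. A usage sketch, assuming a real ComfyUI workspace path and a trivial pass-through graph:

import asyncio
import torch
from comfystream.client import ComfyStreamClient

audio_prompt = {
    "1": {"class_type": "LoadAudioTensor", "inputs": {}},
    "2": {"class_type": "SaveAudioTensor", "inputs": {"audio": ["1", 0]}},
}

async def main():
    client = ComfyStreamClient(type="audio", cwd="/path/to/ComfyUI")  # path is an assumption
    client.set_prompt(audio_prompt)
    chunk = torch.randn(16000)  # ~1 s of 16 kHz mono audio
    print(await client.queue_prompt(chunk))

asyncio.run(main())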

7 changes: 5 additions & 2 deletions src/comfystream/tensor_cache.py
@@ -2,5 +2,8 @@
import torch
from typing import List

-inputs: List[torch.Tensor] = []
-outputs: List[asyncio.Future] = []
+image_inputs: List[torch.Tensor] = []
+image_outputs: List[asyncio.Future] = []

+audio_inputs: List[torch.Tensor] = []
+audio_outputs: List[asyncio.Future] = []
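
Keeping separate queues per modality matters because video and audio prompts are now in flight concurrently; with the old shared lists, whichever graph finished first could pop a future belonging to the other stream. An illustrative sketch of that failure mode (not code from the repo):

import asyncio

async def main():
    outputs = []  # the old shared futures list
    video_fut = asyncio.Future()
    audio_fut = asyncio.Future()
    outputs.append(video_fut)  # video prompt queued
    outputs.append(audio_fut)  # audio prompt queued before video finishes
    # Video finishes first, but pop() returns the most recently queued
    # future -- the audio one -- so the wrong consumer would be woken.
    assert outputs.pop() is audio_fut

asyncio.run(main())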
2 changes: 1 addition & 1 deletion src/comfystream/utils.py
@@ -48,7 +48,7 @@ def convert_prompt(prompt: PromptDictInput) -> Prompt:
num_primary_inputs += 1
elif class_type in ["LoadImage", "LoadTensor", "LoadAudioTensor"]:
num_inputs += 1
elif class_type in ["PreviewImage", "SaveImage", "SaveTensor", "SaveASRResponse"]:
elif class_type in ["PreviewImage", "SaveImage", "SaveTensor", "SaveASRResponse", "SaveAudioTensor"]:
num_outputs += 1

# Only handle single primary input
4 changes: 2 additions & 2 deletions ui/src/app/api/offer/route.ts
@@ -1,14 +1,14 @@
import { NextRequest, NextResponse } from "next/server";

export const POST = async function POST(req: NextRequest) {
-const { endpoint, prompt, offer } = await req.json();
+const { endpoint, video_prompt, audio_prompt, offer } = await req.json();

const res = await fetch(endpoint + "/offer", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
-body: JSON.stringify({ prompt, offer }),
+body: JSON.stringify({ video_prompt, audio_prompt, offer }),
});

return NextResponse.json(await res.json(), { status: res.status });
3 changes: 2 additions & 1 deletion ui/src/components/peer.tsx
@@ -4,7 +4,8 @@ import { PeerContext } from "@/context/peer-context";

export interface PeerProps extends React.HTMLAttributes<HTMLDivElement> {
url: string;
-prompt: any;
+videoPrompt: any;
+audioPrompt: any;
connect: boolean;
onConnected: () => void;
onDisconnected: () => void;