Skip to content

Commit

Permalink
feat: combine audio and video streams
Browse files Browse the repository at this point in the history
  • Loading branch information
varshith15 committed Jan 24, 2025
1 parent 49deb2f commit 415c387
Show file tree
Hide file tree
Showing 12 changed files with 260 additions and 168 deletions.
26 changes: 22 additions & 4 deletions nodes/audio_utils/load_audio_tensor.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,36 @@
import numpy as np

from comfystream import tensor_cache

class LoadAudioTensor:
    """Accumulate incoming audio from the tensor cache and emit fixed-size chunks.

    Audio arrives in small frames via ``tensor_cache.audio_inputs``; this node
    buffers them across executions until at least ``buffer_size`` milliseconds
    of samples are available, then returns exactly that many samples.
    """

    CATEGORY = "audio_utils"
    RETURN_TYPES = ("AUDIO",)
    FUNCTION = "execute"

    def __init__(self):
        # Carries leftover samples between execute() calls so chunk boundaries
        # do not have to align with incoming frame boundaries.
        self.audio_buffer = np.array([])

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                # Chunk length in milliseconds; see the *48 conversion below.
                "buffer_size": ("FLOAT", {"default": 500.0})
            }
        }

    @classmethod
    def IS_CHANGED(s):
        # Bug fix: the original `def IS_CHANGED():` accepted no parameters, so
        # invoking it as a classmethod (which passes the class) raised
        # TypeError. NaN compares unequal to itself, which makes ComfyUI treat
        # the node as always changed and re-execute it on every prompt.
        return float("nan")

    def execute(self, buffer_size):
        # Block until the producer pushes the next audio frame, then append it.
        audio = tensor_cache.audio_inputs.get(block=True)
        self.audio_buffer = np.concatenate((self.audio_buffer, audio))

        # buffer_size is in ms; 48 samples/ms assumes 48 kHz mono audio.
        # NOTE(review): confirm the sample rate against the producer side.
        buffer_size_samples = int(buffer_size * 48)

        if self.audio_buffer.size >= buffer_size_samples:
            buffered_audio = self.audio_buffer[:buffer_size_samples]
            self.audio_buffer = self.audio_buffer[buffer_size_samples:]
            return (buffered_audio,)
        else:
            # Not enough samples yet; downstream nodes must tolerate None.
            return (None,)
9 changes: 4 additions & 5 deletions nodes/audio_utils/save_audio_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,14 @@ def INPUT_TYPES(s):
return {
"required": {
"audio": ("AUDIO",),
"text": ("TEXT",)
}
}

@classmethod
def IS_CHANGED(s):
    # NaN never compares equal to itself, so ComfyUI's change detection always
    # sees a "different" value and re-runs this node on every prompt.
    return float("nan")

def execute(self, audio):
    """Resolve the pending output future with *audio* and pass it through.

    The consumer side registered a Future in ``tensor_cache.audio_outputs``;
    fulfilling it unblocks whoever is awaiting the processed audio.
    """
    fut = tensor_cache.audio_outputs.get()
    # Bug fix (clarity): the original wrote set_result((audio)) — the extra
    # parentheses are redundant grouping, NOT a one-element tuple, so the
    # future resolves to the bare audio object. Write it unambiguously.
    fut.set_result(audio)
    return (audio,)
24 changes: 0 additions & 24 deletions nodes/audio_utils/save_result.py

This file was deleted.

2 changes: 1 addition & 1 deletion nodes/tensor_utils/load_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ def IS_CHANGED():
return float("nan")

def execute(self):
    """Block until the producer pushes the next image tensor, then return it."""
    # get(block=True) replaces the old non-blocking pop() so the graph waits
    # for input instead of racing the producer. Renamed the local to avoid
    # shadowing the builtin `input`.
    frame = tensor_cache.image_inputs.get(block=True)
    return (frame,)
2 changes: 1 addition & 1 deletion nodes/tensor_utils/save_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,6 @@ def IS_CHANGED(s):
return float("nan")

def execute(self, images: torch.Tensor):
    """Resolve the pending output future with the processed *images*.

    The consumer registered a Future in ``tensor_cache.image_outputs``;
    setting its result hands the tensor back and unblocks the awaiter.
    """
    fut = tensor_cache.image_outputs.get()
    fut.set_result(images)
    return images
45 changes: 12 additions & 33 deletions server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
)
from aiortc.rtcrtpsender import RTCRtpSender
from aiortc.codecs import h264
from pipeline import VideoPipeline, AudioPipeline
from pipeline import Pipeline
from utils import patch_loop_datagram

logger = logging.getLogger(__name__)
Expand All @@ -30,17 +30,15 @@ def __init__(self, track: MediaStreamTrack, pipeline):
super().__init__()
self.track = track
self.pipeline = pipeline
self.processed_frames = asyncio.Queue()
asyncio.create_task(self.collect_frames())

async def collect_frames(self):
    """Continuously pull frames from the source track and feed the pipeline."""
    # NOTE(review): track.recv() raises MediaStreamError once the remote track
    # ends, which will terminate this task with an unhandled exception —
    # consider catching it and exiting the loop cleanly.
    while True:
        frame = await self.track.recv()
        await self.pipeline.put_video_frame(frame)

async def recv(self):
    """Return the next processed video frame produced by the pipeline."""
    return await self.pipeline.get_processed_video_frame()


class AudioStreamTrack(MediaStreamTrack):
Expand All @@ -49,28 +47,15 @@ def __init__(self, track: MediaStreamTrack, pipeline):
super().__init__()
self.track = track
self.pipeline = pipeline
self.incoming_frames = asyncio.Queue()
self.processed_frames = asyncio.Queue()
asyncio.create_task(self.collect_frames())
asyncio.create_task(self.process_frames())
self.started = False

async def collect_frames(self):
    """Continuously pull audio frames from the source track into the pipeline.

    Batching (the old 25-frame ``process_frames`` task) now lives inside the
    pipeline itself; this loop only forwards raw frames one at a time.
    """
    # NOTE(review): track.recv() raises MediaStreamError when the remote track
    # ends; consider handling it so the task exits cleanly.
    while True:
        frame = await self.track.recv()
        await self.pipeline.put_audio_frame(frame)

async def recv(self):
    """Return the next processed audio frame produced by the pipeline."""
    return await self.pipeline.get_processed_audio_frame()


def force_codec(pc, sender, forced_codec):
Expand Down Expand Up @@ -114,16 +99,13 @@ def get_ice_servers():


async def offer(request):
video_pipeline = request.app["video_pipeline"]
audio_pipeline = request.app["audio_pipeline"]
pipeline = request.app["pipeline"]
pcs = request.app["pcs"]

params = await request.json()

video_pipeline.set_prompt(params["video_prompt"])
await video_pipeline.warm()
audio_pipeline.set_prompt(params["audio_prompt"])
await audio_pipeline.warm()
pipeline.set_prompts(params["video_prompt"])
await pipeline.warm()

offer_params = params["offer"]
offer = RTCSessionDescription(sdp=offer_params["sdp"], type=offer_params["type"])
Expand Down Expand Up @@ -154,14 +136,14 @@ async def offer(request):
def on_track(track):
    """Wrap each incoming track with the shared processing pipeline.

    Both kinds now share the single combined ``pipeline`` (the separate
    video/audio pipelines were merged); video is additionally pinned to H.264.
    """
    logger.info(f"Track received: {track.kind}")
    if track.kind == "video":
        videoTrack = VideoStreamTrack(track, pipeline)
        tracks["video"] = videoTrack
        sender = pc.addTrack(videoTrack)

        # Force H.264 so the answer SDP does not negotiate a different codec.
        codec = "video/H264"
        force_codec(pc, sender, codec)
    elif track.kind == "audio":
        audioTrack = AudioStreamTrack(track, pipeline)
        tracks["audio"] = audioTrack
        pc.addTrack(audioTrack)

Expand Down Expand Up @@ -196,7 +178,7 @@ async def set_prompt(request):
pipeline = request.app["pipeline"]

prompt = await request.json()
pipeline.set_prompt(prompt)
pipeline.set_prompts(prompt)

return web.Response(content_type="application/json", text="OK")

Expand All @@ -209,10 +191,7 @@ async def on_startup(app: web.Application):
if app["media_ports"]:
patch_loop_datagram(app["media_ports"])

app["video_pipeline"] = VideoPipeline(
cwd=app["workspace"], disable_cuda_malloc=True, gpu_only=True
)
app["audio_pipeline"] = AudioPipeline(
app["pipeline"] = Pipeline(
cwd=app["workspace"], disable_cuda_malloc=True, gpu_only=True
)
app["pcs"] = set()
Expand Down
Loading

0 comments on commit 415c387

Please sign in to comment.