Commit 2d04794 (parent 6c529f5)
fix: audio nodes
File tree: 10 files changed, +199 -191 lines

10 files changed

+199
-191
lines changed

nodes/audio_utils/__init__.py (+1 -2)

@@ -1,7 +1,6 @@
 from .load_audio_tensor import LoadAudioTensor
-from .save_result import SaveResult
 from .save_audio_tensor import SaveAudioTensor

-NODE_CLASS_MAPPINGS = {"LoadAudioTensor": LoadAudioTensor, "SaveResult": SaveResult, "SaveAudioTensor": SaveAudioTensor}
+NODE_CLASS_MAPPINGS = {"LoadAudioTensor": LoadAudioTensor, "SaveAudioTensor": SaveAudioTensor}

 __all__ = ["NODE_CLASS_MAPPINGS"]
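Note: ComfyUI-style node packages are discovered through the NODE_CLASS_MAPPINGS dict they export, so removing SaveResult from this mapping is what actually unregisters the node. A rough sketch of how a loader might consume such a registry (an illustrative stand-in, not ComfyUI's actual loader code):

import importlib

# Hypothetical loader: import the node package and enumerate its registry.
pkg = importlib.import_module("nodes.audio_utils")
for name, cls in pkg.NODE_CLASS_MAPPINGS.items():
    print(f"registered node {name!r} -> {cls.__name__}")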

nodes/audio_utils/load_audio_tensor.py (+15 -14)

@@ -8,29 +8,30 @@ class LoadAudioTensor:
     FUNCTION = "execute"

     def __init__(self):
-        self.audio_buffer = np.array([])
+        self.audio_buffer = np.array([], dtype=np.int16)
+        self.buffer_samples = None

     @classmethod
     def INPUT_TYPES(s):
         return {
             "required": {
-                "buffer_size": ("FLOAT", {"default": 500.0})
+                "buffer_size": ("FLOAT", {"default": 500.0}),
+                "sample_rate": ("INT", {"default": 48000})
             }
         }

     @classmethod
     def IS_CHANGED():
         return float("nan")

-    def execute(self, buffer_size):
-        audio = tensor_cache.audio_inputs.get(block=True)
-        self.audio_buffer = np.concatenate((self.audio_buffer, audio))
-
-        buffer_size_samples = int(buffer_size * 48)
-
-        if self.audio_buffer.size >= buffer_size_samples:
-            buffered_audio = self.audio_buffer[:buffer_size_samples]
-            self.audio_buffer = self.audio_buffer[buffer_size_samples:]
-            return (buffered_audio,)
-        else:
-            return (None,)
+    def execute(self, buffer_size, sample_rate):
+        if not self.buffer_samples:
+            self.buffer_samples = int(buffer_size * sample_rate / 1000)
+
+        while self.audio_buffer.size < self.buffer_samples:
+            audio = tensor_cache.audio_inputs.get()
+            self.audio_buffer = np.concatenate((self.audio_buffer, audio))
+
+        buffered_audio = self.audio_buffer
+        self.audio_buffer = np.array([], dtype=np.int16)
+        return (buffered_audio,)
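The reworked execute now blocks until at least buffer_size milliseconds of audio (at the configured sample_rate) has accumulated, then flushes the whole buffer rather than trimming to an exact sample count and returning None on underflow. A minimal standalone sketch of that accumulate-and-flush loop, with a plain queue.Queue standing in for tensor_cache.audio_inputs (names here are illustrative, not the project's API):

import queue
import numpy as np

audio_inputs = queue.Queue()  # stand-in for tensor_cache.audio_inputs

def buffer_audio(buffer_size_ms: float, sample_rate: int) -> np.ndarray:
    """Block until buffer_size_ms worth of samples has arrived, then flush."""
    target = int(buffer_size_ms * sample_rate / 1000)  # 500 ms @ 48 kHz -> 24000 samples
    buf = np.array([], dtype=np.int16)
    while buf.size < target:
        buf = np.concatenate((buf, audio_inputs.get()))  # blocks until a chunk arrives
    return buf  # may exceed target: the whole buffer is flushed, not trimmed

# Feed three 20 ms chunks (960 samples each) and request a 50 ms buffer.
for _ in range(3):
    audio_inputs.put(np.zeros(960, dtype=np.int16))
print(buffer_audio(50.0, 48000).size)  # 2880 (three whole chunks, >= the 2400 target)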
nodes/audio_utils/save_audio_tensor.py (+17 -3)

@@ -1,25 +1,39 @@
+import queue
+import logging
 from comfystream import tensor_cache

+logger = logging.getLogger(__name__)

 class SaveAudioTensor:
     CATEGORY = "audio_utils"
     RETURN_TYPES = ()
     FUNCTION = "execute"
     OUTPUT_NODE = True

+    def __init__(self):
+        self.frame_samples = None
+
     @classmethod
     def INPUT_TYPES(s):
         return {
             "required": {
                 "audio": ("AUDIO",),
+                "frame_size": ("FLOAT", {"default": 20.0}),
+                "sample_rate": ("INT", {"default": 48000})
             }
         }

     @classmethod
     def IS_CHANGED(s):
         return float("nan")

-    def execute(self, audio):
-        fut = tensor_cache.audio_outputs.get()
-        fut.set_result((audio))
+    def execute(self, audio, frame_size, sample_rate):
+        if self.frame_samples is None:
+            self.frame_samples = int(frame_size * sample_rate / 1000)
+
+        for idx in range(0, len(audio), self.frame_samples):
+            frame = audio[idx:idx + self.frame_samples]
+            fut = tensor_cache.audio_outputs.get()
+            fut.set_result(frame)
         return (audio,)
+
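SaveAudioTensor now slices the buffered audio back into frame_size-millisecond frames and resolves one pending output future per frame, so a 500 ms buffer at the default 20 ms frame size completes 25 futures. A small sketch of just the slicing math (pure NumPy, independent of the tensor_cache futures):

import numpy as np

def split_into_frames(audio: np.ndarray, frame_size_ms: float, sample_rate: int):
    """Yield fixed-size frames; the last may be short if the buffer doesn't divide evenly."""
    frame_samples = int(frame_size_ms * sample_rate / 1000)  # 20 ms @ 48 kHz -> 960
    for idx in range(0, len(audio), frame_samples):
        yield audio[idx:idx + frame_samples]

buffered = np.zeros(24000, dtype=np.int16)            # 500 ms @ 48 kHz
frames = list(split_into_frames(buffered, 20.0, 48000))
print(len(frames), frames[0].size)                    # 25 960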

server/app.py (+2 -2)

@@ -106,7 +106,7 @@ async def offer(request):

     params = await request.json()

-    pipeline.set_prompts(params["prompts"])
+    await pipeline.set_prompts(params["prompts"])
     await pipeline.warm()

     offer_params = params["offer"]
@@ -213,7 +213,7 @@ async def set_prompt(request):
     pipeline = request.app["pipeline"]

     prompt = await request.json()
-    pipeline.set_prompts(prompt)
+    await pipeline.set_prompts(prompt)

     return web.Response(content_type="application/json", text="OK")
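Because set_prompts is now a coroutine (see pipeline.py below), every call site has to await it; a bare pipeline.set_prompts(...) would only create a coroutine object and the prompts would silently never be set. A toy illustration with a stand-in class (not the server's real Pipeline):

import asyncio

class ToyPipeline:
    async def set_prompts(self, prompts):  # coroutine, mirroring the new signature
        await asyncio.sleep(0)             # placeholder for the real async work

async def main():
    pipeline = ToyPipeline()
    pipeline.set_prompts([{}])        # bug: coroutine never awaited, nothing happens
    await pipeline.set_prompts([{}])  # fixed: awaited, as in the diff above

asyncio.run(main())  # Python warns "coroutine ... was never awaited" for the bug line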

server/pipeline.py (+25 -56)

@@ -1,16 +1,12 @@
-import torch
 import av
+import torch
 import numpy as np
-import fractions
 import asyncio

-from av import AudioFrame
-from typing import Any, Dict, Optional, Union, List
+from typing import Any, Dict, Union, List
 from comfystream.client import ComfyStreamClient
-from comfystream import tensor_cache
-
-WARMUP_RUNS = 5

+WARMUP_RUNS = 10


 class Pipeline:
@@ -19,53 +15,39 @@ def __init__(self, **kwargs):

         self.video_futures = asyncio.Queue()
         self.audio_futures = asyncio.Queue()
-
-        self.audio_output_frames = []

         self.resampler = av.audio.resampler.AudioResampler(format='s16', layout='mono', rate=48000)  # find a better way to convert to mono
-        self.sample_rate = 48000  # instead of hardcoding, find a clean way to set from audio frame
-        self.frame_size = int(self.sample_rate * 0.02)
-        self.time_base = fractions.Fraction(1, self.sample_rate)
-        self.curr_pts = 0  # figure out a better way to set back pts to processed audio frames
-
-    def set_prompt(self, prompt: Dict[Any, Any]):
-        self.client.set_prompt(prompt)

     async def warm(self):
-        dummy_video_frame = torch.randn(1, 512, 512, 3)
-        dummy_audio_frame = np.random.randint(-32768, 32767, 48000 * 1, dtype=np.int16)
+        dummy_video_inp = torch.randn(1, 512, 512, 3)
+        dummy_audio_inp = np.random.randint(-32768, 32767, 48 * 20, dtype=np.int16)  # has to be more than the buffer size in comfy workflow

         for _ in range(WARMUP_RUNS):
-            image_out_fut = asyncio.Future()
-            audio_out_fut = asyncio.Future()
-            tensor_cache.image_outputs.put(image_out_fut)
-            tensor_cache.audio_outputs.put(audio_out_fut)
+            image_out_fut = self.client.put_video_input(dummy_video_inp)
+            await image_out_fut

-            tensor_cache.image_inputs.put(dummy_video_frame)
-            tensor_cache.audio_inputs.put(dummy_audio_frame)
+        futs = []
+        for _ in range(WARMUP_RUNS):
+            audio_out_fut = self.client.put_audio_input(dummy_audio_inp)
+            futs.append(audio_out_fut)

-            await image_out_fut
-            await audio_out_fut
+        await asyncio.gather(*futs)

-    def set_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]):
+    async def set_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]):
         if isinstance(prompts, dict):
-            self.client.set_prompts([prompts])
+            await self.client.set_prompts([prompts])
         else:
-            self.client.set_prompts(prompts)
+            await self.client.set_prompts(prompts)

     async def put_video_frame(self, frame: av.VideoFrame):
         inp_tensor = self.video_preprocess(frame)
-        out_future = asyncio.Future()
-        tensor_cache.image_outputs.put(out_future)
-        tensor_cache.image_inputs.put(inp_tensor)
+        out_future = self.client.put_video_input(inp_tensor)
         await self.video_futures.put((out_future, frame.pts, frame.time_base))

     async def put_audio_frame(self, frame: av.AudioFrame):
         inp_tensor = self.audio_preprocess(frame)
-        out_future = asyncio.Future()
-        tensor_cache.audio_outputs.put(out_future)
-        tensor_cache.audio_inputs.put(inp_tensor)
-        await self.audio_futures.put(out_future)
+        out_future = self.client.put_audio_input(inp_tensor)
+        await self.audio_futures.put((out_future, frame.pts, frame.time_base, frame.sample_rate))

     def video_preprocess(self, frame: av.VideoFrame) -> torch.Tensor:
         frame_np = frame.to_ndarray(format="rgb24").astype(np.float32) / 255.0
@@ -80,18 +62,7 @@ def video_postprocess(self, output: torch.Tensor) -> av.VideoFrame:
         )

     def audio_postprocess(self, output: torch.Tensor) -> av.AudioFrame:
-        frames = []
-        for idx in range(0, len(output), self.frame_size):
-            frame_samples = output[idx:idx + self.frame_size]
-            frame_samples = frame_samples.reshape(1, -1).astype(np.int16)
-            frame = AudioFrame.from_ndarray(frame_samples, layout="mono")
-            frame.sample_rate = self.sample_rate
-            frame.pts = self.curr_pts
-            frame.time_base = self.time_base
-            self.curr_pts += 960
-
-            frames.append(frame)
-        return frames
+        return av.AudioFrame.from_ndarray(output.reshape(1, -1), layout="mono")

     async def get_processed_video_frame(self):
         out_fut, pts, time_base = await self.video_futures.get()
@@ -101,14 +72,12 @@ async def get_processed_video_frame(self):
         return frame

     async def get_processed_audio_frame(self):
-        while not self.audio_output_frames:
-            out_fut = await self.audio_futures.get()
-            output = await out_fut
-            if output is None:
-                print("No Audio output")
-                continue
-            self.audio_output_frames.extend(self.audio_postprocess(output))
-        return self.audio_output_frames.pop(0)
+        out_fut, pts, time_base, sample_rate = await self.audio_futures.get()
+        frame = self.audio_postprocess(await out_fut)
+        frame.pts = pts
+        frame.time_base = time_base
+        frame.sample_rate = sample_rate
+        return frame

     async def get_nodes_info(self) -> Dict[str, Any]:
         """Get information about all nodes in the current prompt including metadata."""
