feat: audio support #10

Merged 4 commits on Feb 17, 2025
1 change: 1 addition & 0 deletions docker/Dockerfile.base
@@ -9,6 +9,7 @@ RUN export DEBIAN_FRONTEND=noninteractive && apt-get update && apt-get install -
    wget \
    nano \
    socat \
+   libsndfile1 \
    build-essential llvm tk-dev \
    && rm -rf /var/lib/apt/lists/*

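(Note: libsndfile1 is the native library behind the soundfile Python package, which librosa depends on for audio I/O, so the new audio nodes presumably need it in the base image.)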
7 changes: 7 additions & 0 deletions nodes/audio_utils/__init__.py
@@ -0,0 +1,7 @@
from .load_audio_tensor import LoadAudioTensor
from .save_audio_tensor import SaveAudioTensor
from .pitch_shift import PitchShifter

NODE_CLASS_MAPPINGS = {"LoadAudioTensor": LoadAudioTensor, "SaveAudioTensor": SaveAudioTensor, "PitchShifter": PitchShifter}

__all__ = ["NODE_CLASS_MAPPINGS"]
52 changes: 52 additions & 0 deletions nodes/audio_utils/load_audio_tensor.py
@@ -0,0 +1,52 @@
import numpy as np

from comfystream import tensor_cache

class LoadAudioTensor:
    CATEGORY = "audio_utils"
    RETURN_TYPES = ("WAVEFORM", "INT")
    FUNCTION = "execute"

    def __init__(self):
        self.audio_buffer = np.empty(0, dtype=np.int16)
        self.buffer_samples = None
        self.sample_rate = None

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "buffer_size": ("FLOAT", {"default": 500.0}),
            }
        }

    @classmethod
    def IS_CHANGED(s):
        return float("nan")

    def execute(self, buffer_size):
        # buffer_size is in milliseconds; derive the sample rate and the
        # target chunk length from the first frame that arrives.
        if self.sample_rate is None or self.buffer_samples is None:
            frame = tensor_cache.audio_inputs.get(block=True)
            self.sample_rate = frame.sample_rate
            self.buffer_samples = int(self.sample_rate * buffer_size / 1000)
            self.leftover = frame.side_data.input

        if self.leftover.shape[0] < self.buffer_samples:
            chunks = [self.leftover] if self.leftover.size > 0 else []
            total_samples = self.leftover.shape[0]

            # Block on the input queue until enough samples are buffered.
            while total_samples < self.buffer_samples:
                frame = tensor_cache.audio_inputs.get(block=True)
                if frame.sample_rate != self.sample_rate:
                    raise ValueError("Sample rate mismatch")
                chunks.append(frame.side_data.input)
                total_samples += frame.side_data.input.shape[0]

            merged_audio = np.concatenate(chunks, dtype=np.int16)
            buffered_audio = merged_audio[:self.buffer_samples]
            self.leftover = merged_audio[self.buffer_samples:]
        else:
            buffered_audio = self.leftover[:self.buffer_samples]
            self.leftover = self.leftover[self.buffer_samples:]

        return buffered_audio, self.sample_rate
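
Note on the buffering math: buffer_size is milliseconds, so the target chunk length is int(sample_rate * buffer_size / 1000) samples (24,000 samples for the 500 ms default at 48 kHz), and anything beyond the chunk is carried over in self.leftover for the next call. A minimal standalone sketch of the same accumulate-and-carry pattern (synthetic frames, not the tensor_cache path):

import numpy as np

def buffer_chunks(frames, sample_rate=48000, buffer_size_ms=500.0):
    # Same milliseconds -> samples conversion as LoadAudioTensor.
    buffer_samples = int(sample_rate * buffer_size_ms / 1000)  # 24000 at the defaults
    leftover = np.empty(0, dtype=np.int16)
    for frame in frames:
        merged = np.concatenate([leftover, frame])
        while merged.shape[0] >= buffer_samples:
            yield merged[:buffer_samples]       # emit one fixed-size chunk
            merged = merged[buffer_samples:]    # carry the remainder forward
        leftover = merged

# 20 ms WebRTC-style frames (960 samples at 48 kHz) accumulate into 500 ms chunks.
frames = (np.zeros(960, dtype=np.int16) for _ in range(100))
assert all(chunk.shape[0] == 24000 for chunk in buffer_chunks(frames))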
32 changes: 32 additions & 0 deletions nodes/audio_utils/pitch_shift.py
@@ -0,0 +1,32 @@
import numpy as np
import librosa

class PitchShifter:
    CATEGORY = "audio_utils"
    RETURN_TYPES = ("WAVEFORM", "INT")
    FUNCTION = "execute"

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "audio": ("WAVEFORM",),
                "sample_rate": ("INT",),
                "pitch_shift": ("FLOAT", {
                    "default": 4.0,
                    "min": 0.0,
                    "max": 12.0,
                    "step": 0.5
                }),
            }
        }

    @classmethod
    def IS_CHANGED(cls):
        return float("nan")

    def execute(self, audio, sample_rate, pitch_shift):
        audio_float = audio.astype(np.float32) / 32768.0
        shifted_audio = librosa.effects.pitch_shift(y=audio_float, sr=sample_rate, n_steps=pitch_shift)
        shifted_int16 = np.clip(shifted_audio * 32768.0, -32768, 32767).astype(np.int16)
        return shifted_int16, sample_rate
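
PitchShifter normalizes int16 PCM into float32 in [-1, 1), shifts by pitch_shift semitones, and clips back to int16 so overflow cannot wrap. A quick standalone round-trip check with a synthetic tone (illustrative, not part of the PR):

import numpy as np
import librosa

sample_rate = 48000
t = np.linspace(0, 1.0, sample_rate, endpoint=False)
audio = (np.sin(2 * np.pi * 440.0 * t) * 32767).astype(np.int16)  # 1 s, 440 Hz

# Same int16 -> float32 -> int16 round trip as PitchShifter.execute
audio_float = audio.astype(np.float32) / 32768.0
shifted = librosa.effects.pitch_shift(y=audio_float, sr=sample_rate, n_steps=4.0)
shifted_int16 = np.clip(shifted * 32768.0, -32768, 32767).astype(np.int16)
# 4 semitones up moves 440 Hz to roughly 440 * 2**(4/12) ≈ 554 Hz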
25 changes: 25 additions & 0 deletions nodes/audio_utils/save_audio_tensor.py
@@ -0,0 +1,25 @@
from comfystream import tensor_cache

class SaveAudioTensor:
    CATEGORY = "audio_utils"
    RETURN_TYPES = ()
    FUNCTION = "execute"
    OUTPUT_NODE = True

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "audio": ("WAVEFORM",)
            }
        }

    @classmethod
    def IS_CHANGED(s):
        return float("nan")

    def execute(self, audio):
        tensor_cache.audio_outputs.put_nowait(audio)
        return (audio,)

5 changes: 3 additions & 2 deletions nodes/tensor_utils/load_tensor.py
@@ -15,5 +15,6 @@ def IS_CHANGED():
        return float("nan")

    def execute(self):
-        input = tensor_cache.inputs.pop()
-        return (input,)
+        frame = tensor_cache.image_inputs.get(block=True)
+        frame.side_data.skipped = False
+        return (frame.side_data.input,)
3 changes: 1 addition & 2 deletions nodes/tensor_utils/save_tensor.py
@@ -22,6 +22,5 @@ def IS_CHANGED(s):
        return float("nan")

    def execute(self, images: torch.Tensor):
-        fut = tensor_cache.outputs.pop()
-        fut.set_result(images)
+        tensor_cache.image_outputs.put_nowait(images)
        return images
75 changes: 57 additions & 18 deletions server/app.py
@@ -30,15 +30,44 @@

class VideoStreamTrack(MediaStreamTrack):
    kind = "video"
    def __init__(self, track: MediaStreamTrack, pipeline):
        super().__init__()
        self.track = track
        self.pipeline = pipeline
        asyncio.create_task(self.collect_frames())

    async def collect_frames(self):
        while True:
            try:
                frame = await self.track.recv()
                await self.pipeline.put_video_frame(frame)
            except Exception as e:
                await self.pipeline.cleanup()
                raise Exception(f"Error collecting video frames: {str(e)}")

    async def recv(self):
        return await self.pipeline.get_processed_video_frame()


class AudioStreamTrack(MediaStreamTrack):
    kind = "audio"
    def __init__(self, track: MediaStreamTrack, pipeline):
        super().__init__()
        self.track = track
        self.pipeline = pipeline
        asyncio.create_task(self.collect_frames())

    async def collect_frames(self):
        while True:
            try:
                frame = await self.track.recv()
                await self.pipeline.put_audio_frame(frame)
            except Exception as e:
                await self.pipeline.cleanup()
                raise Exception(f"Error collecting audio frames: {str(e)}")

    async def recv(self):
-        frame = await self.track.recv()
-        return await self.pipeline(frame)
+        return await self.pipeline.get_processed_audio_frame()

def force_codec(pc, sender, forced_codec):
@@ -87,8 +116,7 @@ async def offer(request):

    params = await request.json()

-    pipeline.set_prompt(params["prompt"])
-    await pipeline.warm()
+    await pipeline.set_prompts(params["prompts"])

    offer_params = params["offer"]
    offer = RTCSessionDescription(sdp=offer_params["sdp"], type=offer_params["type"])
@@ -103,17 +131,19 @@ async def offer(request):

    pcs.add(pc)

-    tracks = {"video": None}
+    tracks = {"video": None, "audio": None}

-    # Prefer h264
-    transceiver = pc.addTransceiver("video")
-    caps = RTCRtpSender.getCapabilities("video")
-    prefs = list(filter(lambda x: x.name == "H264", caps.codecs))
-    transceiver.setCodecPreferences(prefs)
+    # Only add video transceiver if video is present in the offer
+    if "m=video" in offer.sdp:
+        # Prefer h264
+        transceiver = pc.addTransceiver("video")
+        caps = RTCRtpSender.getCapabilities("video")
+        prefs = list(filter(lambda x: x.name == "H264", caps.codecs))
+        transceiver.setCodecPreferences(prefs)

-    # Monkey patch max and min bitrate to ensure constant bitrate
-    h264.MAX_BITRATE = MAX_BITRATE
-    h264.MIN_BITRATE = MIN_BITRATE
+        # Monkey patch max and min bitrate to ensure constant bitrate
+        h264.MAX_BITRATE = MAX_BITRATE
+        h264.MIN_BITRATE = MIN_BITRATE

    # Handle control channel from client
    @pc.on("datachannel")
@@ -131,13 +161,13 @@ async def on_message(message):
                        "nodes": nodes_info
                    }
                    channel.send(json.dumps(response))
-                elif params.get("type") == "update_prompt":
-                    if "prompt" not in params:
+                elif params.get("type") == "update_prompts":
+                    if "prompts" not in params:
                        logger.warning("[Control] Missing prompts in update_prompts message")
                        return
-                    pipeline.set_prompt(params["prompt"])
+                    await pipeline.update_prompts(params["prompts"])
                    response = {
-                        "type": "prompt_updated",
+                        "type": "prompts_updated",
                        "success": True
                    }
                    channel.send(json.dumps(response))
@@ -158,6 +188,10 @@ def on_track(track):

            codec = "video/H264"
            force_codec(pc, sender, codec)
+        elif track.kind == "audio":
+            audioTrack = AudioStreamTrack(track, pipeline)
+            tracks["audio"] = audioTrack
+            pc.addTrack(audioTrack)

        @track.on("ended")
        async def on_ended():
@@ -175,6 +209,11 @@ async def on_connectionstatechange():

    await pc.setRemoteDescription(offer)

+    if "m=audio" in pc.remoteDescription.sdp:
+        await pipeline.warm_audio()
+    if "m=video" in pc.remoteDescription.sdp:
+        await pipeline.warm_video()
+
    answer = await pc.createAnswer()
    await pc.setLocalDescription(answer)
@@ -190,7 +229,7 @@ async def set_prompt(request):
    pipeline = request.app["pipeline"]

    prompt = await request.json()
-    pipeline.set_prompt(prompt)
+    await pipeline.set_prompts(prompt)

    return web.Response(content_type="application/json", text="OK")
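
For reference, the renamed control message a client would now send over the data channel looks like the sketch below; update_prompt/prompt become update_prompts/prompts, and the field carries a list of ComfyUI prompt graphs. A hedged sketch only: send_update_prompts, channel, and workflow are illustrative names, not part of the PR.

import json

def send_update_prompts(channel, workflow):
    # channel: the client's RTCDataChannel; workflow: a ComfyUI prompt-graph
    # dict (both placeholders here).
    channel.send(json.dumps({
        "type": "update_prompts",
        "prompts": [workflow],
    }))
    # Expected reply from the handler above:
    # {"type": "prompts_updated", "success": true}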
