diff --git a/docker/Dockerfile.base b/docker/Dockerfile.base
index 647f4978..c6b70692 100644
--- a/docker/Dockerfile.base
+++ b/docker/Dockerfile.base
@@ -9,6 +9,7 @@ RUN export DEBIAN_FRONTEND=noninteractive && apt-get update && apt-get install -
     wget \
     nano \
     socat \
+    libsndfile1 \
     build-essential llvm tk-dev \
     && rm -rf /var/lib/apt/lists/*
 
diff --git a/nodes/audio_utils/__init__.py b/nodes/audio_utils/__init__.py
new file mode 100644
index 00000000..0251eac8
--- /dev/null
+++ b/nodes/audio_utils/__init__.py
@@ -0,0 +1,7 @@
+from .load_audio_tensor import LoadAudioTensor
+from .save_audio_tensor import SaveAudioTensor
+from .pitch_shift import PitchShifter
+
+NODE_CLASS_MAPPINGS = {"LoadAudioTensor": LoadAudioTensor, "SaveAudioTensor": SaveAudioTensor, "PitchShifter": PitchShifter}
+
+__all__ = ["NODE_CLASS_MAPPINGS"]
diff --git a/nodes/audio_utils/load_audio_tensor.py b/nodes/audio_utils/load_audio_tensor.py
new file mode 100644
index 00000000..643c7d50
--- /dev/null
+++ b/nodes/audio_utils/load_audio_tensor.py
@@ -0,0 +1,52 @@
+import numpy as np
+
+from comfystream import tensor_cache
+
+class LoadAudioTensor:
+    CATEGORY = "audio_utils"
+    RETURN_TYPES = ("WAVEFORM", "INT")
+    FUNCTION = "execute"
+
+    def __init__(self):
+        self.audio_buffer = np.empty(0, dtype=np.int16)
+        self.buffer_samples = None
+        self.sample_rate = None
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "buffer_size": ("FLOAT", {"default": 500.0}),
+            }
+        }
+
+    @classmethod
+    def IS_CHANGED():
+        return float("nan")
+
+    def execute(self, buffer_size):
+        if self.sample_rate is None or self.buffer_samples is None:
+            frame = tensor_cache.audio_inputs.get(block=True)
+            self.sample_rate = frame.sample_rate
+            self.buffer_samples = int(self.sample_rate * buffer_size / 1000)
+            self.leftover = frame.side_data.input
+
+        if self.leftover.shape[0] < self.buffer_samples:
+            chunks = [self.leftover] if self.leftover.size > 0 else []
+            total_samples = self.leftover.shape[0]
+
+            while total_samples < self.buffer_samples:
+                frame = tensor_cache.audio_inputs.get(block=True)
+                if frame.sample_rate != self.sample_rate:
+                    raise ValueError("Sample rate mismatch")
+                chunks.append(frame.side_data.input)
+                total_samples += frame.side_data.input.shape[0]
+
+            merged_audio = np.concatenate(chunks, dtype=np.int16)
+            buffered_audio = merged_audio[:self.buffer_samples]
+            self.leftover = merged_audio[self.buffer_samples:]
+        else:
+            buffered_audio = self.leftover[:self.buffer_samples]
+            self.leftover = self.leftover[self.buffer_samples:]
+
+        return buffered_audio, self.sample_rate
diff --git a/nodes/audio_utils/pitch_shift.py b/nodes/audio_utils/pitch_shift.py
new file mode 100644
index 00000000..ed2b2b38
--- /dev/null
+++ b/nodes/audio_utils/pitch_shift.py
@@ -0,0 +1,32 @@
+import numpy as np
+import librosa
+
+class PitchShifter:
+    CATEGORY = "audio_utils"
+    RETURN_TYPES = ("WAVEFORM", "INT")
+    FUNCTION = "execute"
+
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "audio": ("WAVEFORM",),
+                "sample_rate": ("INT",),
+                "pitch_shift": ("FLOAT", {
+                    "default": 4.0,
+                    "min": 0.0,
+                    "max": 12.0,
+                    "step": 0.5
+                }),
+            }
+        }
+
+    @classmethod
+    def IS_CHANGED(cls):
+        return float("nan")
+
+    def execute(self, audio, sample_rate, pitch_shift):
+        audio_float = audio.astype(np.float32) / 32768.0
+        shifted_audio = librosa.effects.pitch_shift(y=audio_float, sr=sample_rate, n_steps=pitch_shift)
+        shifted_int16 = np.clip(shifted_audio * 32768.0, -32768, 32767).astype(np.int16)
+        return shifted_int16, sample_rate
diff --git a/nodes/audio_utils/save_audio_tensor.py b/nodes/audio_utils/save_audio_tensor.py
new file mode 100644
index 00000000..1fa56678
--- /dev/null
+++ b/nodes/audio_utils/save_audio_tensor.py
@@ -0,0 +1,25 @@
+from comfystream import tensor_cache
+
+class SaveAudioTensor:
+    CATEGORY = "audio_utils"
+    RETURN_TYPES = ()
+    FUNCTION = "execute"
+    OUTPUT_NODE = True
+
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "audio": ("WAVEFORM",)
+            }
+        }
+
+    @classmethod
+    def IS_CHANGED(s):
+        return float("nan")
+
+    def execute(self, audio):
+        tensor_cache.audio_outputs.put_nowait(audio)
+        return (audio,)
+
diff --git a/nodes/tensor_utils/load_tensor.py b/nodes/tensor_utils/load_tensor.py
index 622f176e..c39fe8a1 100644
--- a/nodes/tensor_utils/load_tensor.py
+++ b/nodes/tensor_utils/load_tensor.py
@@ -15,5 +15,6 @@ def IS_CHANGED():
         return float("nan")
 
     def execute(self):
-        input = tensor_cache.inputs.pop()
-        return (input,)
+        frame = tensor_cache.image_inputs.get(block=True)
+        frame.side_data.skipped = False
+        return (frame.side_data.input,)
diff --git a/nodes/tensor_utils/save_tensor.py b/nodes/tensor_utils/save_tensor.py
index 131604f7..3a021aa5 100644
--- a/nodes/tensor_utils/save_tensor.py
+++ b/nodes/tensor_utils/save_tensor.py
@@ -22,6 +22,5 @@ def IS_CHANGED(s):
         return float("nan")
 
     def execute(self, images: torch.Tensor):
-        fut = tensor_cache.outputs.pop()
-        fut.set_result(images)
+        tensor_cache.image_outputs.put_nowait(images)
         return images
diff --git a/server/app.py b/server/app.py
index a65fa8de..8c83ab32 100644
--- a/server/app.py
+++ b/server/app.py
@@ -30,15 +30,44 @@ class VideoStreamTrack(MediaStreamTrack):
     kind = "video"
 
+    def __init__(self, track: MediaStreamTrack, pipeline):
+        super().__init__()
+        self.track = track
+        self.pipeline = pipeline
+        asyncio.create_task(self.collect_frames())
+
+    async def collect_frames(self):
+        while True:
+            try:
+                frame = await self.track.recv()
+                await self.pipeline.put_video_frame(frame)
+            except Exception as e:
+                await self.pipeline.cleanup()
+                raise Exception(f"Error collecting video frames: {str(e)}")
+    async def recv(self):
+        return await self.pipeline.get_processed_video_frame()
+
+
+class AudioStreamTrack(MediaStreamTrack):
+    kind = "audio"
 
     def __init__(self, track: MediaStreamTrack, pipeline):
         super().__init__()
         self.track = track
         self.pipeline = pipeline
+        asyncio.create_task(self.collect_frames())
+
+    async def collect_frames(self):
+        while True:
+            try:
+                frame = await self.track.recv()
+                await self.pipeline.put_audio_frame(frame)
+            except Exception as e:
+                await self.pipeline.cleanup()
+                raise Exception(f"Error collecting audio frames: {str(e)}")
 
     async def recv(self):
-        frame = await self.track.recv()
-        return await self.pipeline(frame)
+        return await self.pipeline.get_processed_audio_frame()
 
 
 def force_codec(pc, sender, forced_codec):
@@ -87,8 +116,7 @@ async def offer(request):
     params = await request.json()
 
-    pipeline.set_prompt(params["prompt"])
-    await pipeline.warm()
+    await pipeline.set_prompts(params["prompts"])
 
     offer_params = params["offer"]
     offer = RTCSessionDescription(sdp=offer_params["sdp"], type=offer_params["type"])
@@ -103,17 +131,19 @@
     pcs.add(pc)
 
-    tracks = {"video": None}
+    tracks = {"video": None, "audio": None}
 
-    # Prefer h264
-    transceiver = pc.addTransceiver("video")
-    caps = RTCRtpSender.getCapabilities("video")
-    prefs = list(filter(lambda x: x.name == "H264", caps.codecs))
-    transceiver.setCodecPreferences(prefs)
+    # Only add video transceiver if video is present in the offer
if "m=video" in offer.sdp: + # Prefer h264 + transceiver = pc.addTransceiver("video") + caps = RTCRtpSender.getCapabilities("video") + prefs = list(filter(lambda x: x.name == "H264", caps.codecs)) + transceiver.setCodecPreferences(prefs) - # Monkey patch max and min bitrate to ensure constant bitrate - h264.MAX_BITRATE = MAX_BITRATE - h264.MIN_BITRATE = MIN_BITRATE + # Monkey patch max and min bitrate to ensure constant bitrate + h264.MAX_BITRATE = MAX_BITRATE + h264.MIN_BITRATE = MIN_BITRATE # Handle control channel from client @pc.on("datachannel") @@ -131,13 +161,13 @@ async def on_message(message): "nodes": nodes_info } channel.send(json.dumps(response)) - elif params.get("type") == "update_prompt": - if "prompt" not in params: + elif params.get("type") == "update_prompts": + if "prompts" not in params: logger.warning("[Control] Missing prompt in update_prompt message") return - pipeline.set_prompt(params["prompt"]) + await pipeline.update_prompts(params["prompts"]) response = { - "type": "prompt_updated", + "type": "prompts_updated", "success": True } channel.send(json.dumps(response)) @@ -158,6 +188,10 @@ def on_track(track): codec = "video/H264" force_codec(pc, sender, codec) + elif track.kind == "audio": + audioTrack = AudioStreamTrack(track, pipeline) + tracks["audio"] = audioTrack + pc.addTrack(audioTrack) @track.on("ended") async def on_ended(): @@ -175,6 +209,11 @@ async def on_connectionstatechange(): await pc.setRemoteDescription(offer) + if "m=audio" in pc.remoteDescription.sdp: + await pipeline.warm_audio() + if "m=video" in pc.remoteDescription.sdp: + await pipeline.warm_video() + answer = await pc.createAnswer() await pc.setLocalDescription(answer) @@ -190,7 +229,7 @@ async def set_prompt(request): pipeline = request.app["pipeline"] prompt = await request.json() - pipeline.set_prompt(prompt) + await pipeline.set_prompts(prompt) return web.Response(content_type="application/json", text="OK") diff --git a/server/pipeline.py b/server/pipeline.py index cb335609..11163447 100644 --- a/server/pipeline.py +++ b/server/pipeline.py @@ -1,50 +1,112 @@ -import torch import av +import torch import numpy as np +import asyncio -from typing import Any, Dict +from typing import Any, Dict, Union, List from comfystream.client import ComfyStreamClient WARMUP_RUNS = 5 - class Pipeline: def __init__(self, **kwargs): - self.client = ComfyStreamClient(**kwargs) + self.client = ComfyStreamClient(**kwargs, max_workers=5) # TODO: hardcoded max workers, should it be configurable? + + self.video_incoming_frames = asyncio.Queue() + self.audio_incoming_frames = asyncio.Queue() - def set_prompt(self, prompt: Dict[Any, Any]): - self.client.set_prompt(prompt) + self.processed_audio_buffer = np.array([], dtype=np.int16) - async def warm(self): - frame = torch.randn(1, 512, 512, 3) + async def warm_video(self): + dummy_frame = av.VideoFrame() + dummy_frame.side_data.input = torch.randn(1, 512, 512, 3) for _ in range(WARMUP_RUNS): - await self.predict(frame) + self.client.put_video_input(dummy_frame) + await self.client.get_video_output() - def preprocess(self, frame: av.VideoFrame) -> torch.Tensor: - frame_np = frame.to_ndarray(format="rgb24").astype(np.float32) / 255.0 - return torch.from_numpy(frame_np).unsqueeze(0) + async def warm_audio(self): + dummy_frame = av.AudioFrame() + dummy_frame.side_data.input = np.random.randint(-32768, 32767, int(48000 * 0.5), dtype=np.int16) # TODO: adds a lot of delay if it doesn't match the buffer size, is warmup needed? 
+        dummy_frame.sample_rate = 48000
 
-    async def predict(self, frame: torch.Tensor) -> torch.Tensor:
-        return await self.client.queue_prompt(frame)
+        for _ in range(WARMUP_RUNS):
+            self.client.put_audio_input(dummy_frame)
+            await self.client.get_audio_output()
 
-    def postprocess(self, frame: torch.Tensor) -> av.VideoFrame:
+    async def set_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]):
+        if isinstance(prompts, list):
+            await self.client.set_prompts(prompts)
+        else:
+            await self.client.set_prompts([prompts])
+
+    async def update_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]):
+        if isinstance(prompts, list):
+            await self.client.update_prompts(prompts)
+        else:
+            await self.client.update_prompts([prompts])
+
+    async def put_video_frame(self, frame: av.VideoFrame):
+        frame.side_data.input = self.video_preprocess(frame)
+        frame.side_data.skipped = True
+        self.client.put_video_input(frame)
+        await self.video_incoming_frames.put(frame)
+
+    async def put_audio_frame(self, frame: av.AudioFrame):
+        frame.side_data.input = self.audio_preprocess(frame)
+        frame.side_data.skipped = True
+        self.client.put_audio_input(frame)
+        await self.audio_incoming_frames.put(frame)
+
+    def video_preprocess(self, frame: av.VideoFrame) -> Union[torch.Tensor, np.ndarray]:
+        frame_np = frame.to_ndarray(format="rgb24").astype(np.float32) / 255.0
+        return torch.from_numpy(frame_np).unsqueeze(0)
+
+    def audio_preprocess(self, frame: av.AudioFrame) -> Union[torch.Tensor, np.ndarray]:
+        return frame.to_ndarray().ravel().reshape(-1, 2).mean(axis=1).astype(np.int16)
+
+    def video_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.VideoFrame:
         return av.VideoFrame.from_ndarray(
-            (frame * 255.0).clamp(0, 255).to(dtype=torch.uint8).squeeze(0).cpu().numpy()
+            (output * 255.0).clamp(0, 255).to(dtype=torch.uint8).squeeze(0).cpu().numpy()
         )
 
-    async def __call__(self, frame: av.VideoFrame) -> av.VideoFrame:
-        pre_output = self.preprocess(frame)
-        pred_output = await self.predict(pre_output)
-        post_output = self.postprocess(pred_output)
+    def audio_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.AudioFrame:
+        return av.AudioFrame.from_ndarray(np.repeat(output, 2).reshape(1, -1))
+
+    async def get_processed_video_frame(self):
+        # TODO: make it generic to support purely generative video cases
+        out_tensor = await self.client.get_video_output()
+        frame = await self.video_incoming_frames.get()
+        while frame.side_data.skipped:
+            frame = await self.video_incoming_frames.get()
 
-        post_output.pts = frame.pts
-        post_output.time_base = frame.time_base
+        processed_frame = self.video_postprocess(out_tensor)
+        processed_frame.pts = frame.pts
+        processed_frame.time_base = frame.time_base
+
+        return processed_frame
 
-        return post_output
+    async def get_processed_audio_frame(self):
+        # TODO: make it generic to support purely generative audio cases and also add frame skipping
+        frame = await self.audio_incoming_frames.get()
+        if frame.samples > len(self.processed_audio_buffer):
+            out_tensor = await self.client.get_audio_output()
+            self.processed_audio_buffer = np.concatenate([self.processed_audio_buffer, out_tensor])
+        out_data = self.processed_audio_buffer[:frame.samples]
+        self.processed_audio_buffer = self.processed_audio_buffer[frame.samples:]
+        processed_frame = self.audio_postprocess(out_data)
+        processed_frame.pts = frame.pts
+        processed_frame.time_base = frame.time_base
+        processed_frame.sample_rate = frame.sample_rate
+
+        return processed_frame
+
     async def get_nodes_info(self) -> Dict[str, Any]:
         """Get information about all nodes in the current prompt including metadata."""
         nodes_info = await self.client.get_available_nodes()
-        return nodes_info
\ No newline at end of file
+        return nodes_info
+
+    async def cleanup(self):
+        await self.client.cleanup()
\ No newline at end of file
diff --git a/src/comfystream/client.py b/src/comfystream/client.py
index ef3a1532..68eefda6 100644
--- a/src/comfystream/client.py
+++ b/src/comfystream/client.py
@@ -1,60 +1,114 @@
-import torch
 import asyncio
-from typing import Any
-import json
+from typing import List
 import logging
 
+from comfystream import tensor_cache
+from comfystream.utils import convert_prompt
+
 from comfy.api.components.schema.prompt import PromptDictInput
 from comfy.cli_args_types import Configuration
 from comfy.client.embedded_comfy_client import EmbeddedComfyClient
 
-from comfystream import tensor_cache
-from comfystream.utils import convert_prompt
 
 logger = logging.getLogger(__name__)
 
 
 class ComfyStreamClient:
-    def __init__(self, **kwargs):
+    def __init__(self, max_workers: int = 1, **kwargs):
         config = Configuration(**kwargs)
-        self.comfy_client = EmbeddedComfyClient(config)
-        self.prompt = None
-        self._lock = asyncio.Lock()
-
-    def set_prompt(self, prompt: PromptDictInput):
-        self.prompt = convert_prompt(prompt)
-
-    async def queue_prompt(self, input: torch.Tensor) -> torch.Tensor:
-        async with self._lock:
-            tensor_cache.inputs.append(input)
-            output_fut = asyncio.Future()
-            tensor_cache.outputs.append(output_fut)
+        self.comfy_client = EmbeddedComfyClient(config, max_workers=max_workers)
+        self.running_prompts = {} # To be used for cancelling tasks
+        self.current_prompts = []
+        self.cleanup_lock = asyncio.Lock()
+
+    async def set_prompts(self, prompts: List[PromptDictInput]):
+        self.current_prompts = [convert_prompt(prompt) for prompt in prompts]
+        for idx in range(len(self.current_prompts)):
+            task = asyncio.create_task(self.run_prompt(idx))
+            self.running_prompts[idx] = task
+
+    async def update_prompts(self, prompts: List[PromptDictInput]):
+        # TODO: currently under the assumption that only already running prompts are updated
+        if len(prompts) != len(self.current_prompts):
+            raise ValueError(
+                "Number of updated prompts must match the number of currently running prompts."
+            )
+        self.current_prompts = [convert_prompt(prompt) for prompt in prompts]
+
+    async def run_prompt(self, prompt_index: int):
+        while True:
             try:
-                await self.comfy_client.queue_prompt(self.prompt)
+                await self.comfy_client.queue_prompt(self.current_prompts[prompt_index])
             except Exception as e:
-                logger.error(f"Error queueing prompt: {str(e)}")
-                logger.error(f"Error type: {type(e)}")
+                await self.cleanup()
+                logger.error(f"Error running prompt: {str(e)}")
                 raise
-        return await output_fut
+
+    async def cleanup(self):
+        async with self.cleanup_lock:
+            for task in self.running_prompts.values():
+                task.cancel()
+                try:
+                    await task
+                except asyncio.CancelledError:
+                    pass
+            self.running_prompts.clear()
+
+            if self.comfy_client.is_running:
+                await self.comfy_client.__aexit__()
+
+            await self.cleanup_queues()
+            logger.info("Client cleanup complete")
+
+
+    async def cleanup_queues(self):
+        while not tensor_cache.image_inputs.empty():
+            tensor_cache.image_inputs.get()
+
+        while not tensor_cache.audio_inputs.empty():
+            tensor_cache.audio_inputs.get()
+
+        while not tensor_cache.image_outputs.empty():
+            await tensor_cache.image_outputs.get()
+
+        while not tensor_cache.audio_outputs.empty():
+            await tensor_cache.audio_outputs.get()
+
+    def put_video_input(self, frame):
+        if tensor_cache.image_inputs.full():
+            tensor_cache.image_inputs.get(block=True)
+        tensor_cache.image_inputs.put(frame)
+
+    def put_audio_input(self, frame):
+        tensor_cache.audio_inputs.put(frame)
+
+    async def get_video_output(self):
+        return await tensor_cache.image_outputs.get()
+
+    async def get_audio_output(self):
+        return await tensor_cache.audio_outputs.get()
 
     async def get_available_nodes(self):
         """Get metadata and available nodes info in a single pass"""
-        async with self._lock:
-            if not self.prompt:
-                return {}
+        # TODO: make it work for multiple prompts
+        if not self.running_prompts:
+            return {}
 
-            try:
-                from comfy.nodes.package import import_all_nodes_in_workspace
-                nodes = import_all_nodes_in_workspace()
-
+        try:
+            from comfy.nodes.package import import_all_nodes_in_workspace
+            nodes = import_all_nodes_in_workspace()
+
+            all_prompts_nodes_info = {}
+
+            for prompt_index, prompt in enumerate(self.current_prompts):
                 # Get set of class types we need metadata for, excluding LoadTensor and SaveTensor
                 needed_class_types = {
                     node.get('class_type')
-                    for node in self.prompt.values()
+                    for node in prompt.values()
                     if node.get('class_type') not in ('LoadTensor', 'SaveTensor')
                 }
                 remaining_nodes = {
                     node_id
-                    for node_id, node in self.prompt.items()
+                    for node_id, node in prompt.items()
                     if node.get('class_type') not in ('LoadTensor', 'SaveTensor')
                 }
                 nodes_info = {}
@@ -103,7 +157,7 @@ async def get_available_nodes(self):
                 # Now process any nodes in our prompt that use this class_type
                 for node_id in list(remaining_nodes):
-                    node = self.prompt[node_id]
+                    node = prompt[node_id]
                     if node.get('class_type') != class_type:
                         continue
 
@@ -124,9 +178,11 @@ async def get_available_nodes(self):
                     nodes_info[node_id] = node_info
                     remaining_nodes.remove(node_id)
-
-            return nodes_info
-
-        except Exception as e:
-            logger.error(f"Error getting node info: {str(e)}")
-            return {}
+
+                all_prompts_nodes_info[prompt_index] = nodes_info
+
+            return all_prompts_nodes_info[0] # TODO: make it work for multiple prompts
+
+        except Exception as e:
+            logger.error(f"Error getting node info: {str(e)}")
+            return {}
diff --git a/src/comfystream/tensor_cache.py b/src/comfystream/tensor_cache.py
index 44d7c03b..0216f73b 100644
--- a/src/comfystream/tensor_cache.py
+++ b/src/comfystream/tensor_cache.py
@@ -1,6 +1,14 @@
-import asyncio
 import torch
-from typing import List
+import numpy as np
 
-inputs: List[torch.Tensor] = []
-outputs: List[asyncio.Future] = []
+from queue import Queue
+from asyncio import Queue as AsyncQueue
+
+from typing import Union
+
+# TODO: improve eviction policy; FIFO might not be the best, skip alternate frames instead
+image_inputs: Queue[Union[torch.Tensor, np.ndarray]] = Queue(maxsize=1)
+image_outputs: AsyncQueue[Union[torch.Tensor, np.ndarray]] = AsyncQueue()
+
+audio_inputs: Queue[Union[torch.Tensor, np.ndarray]] = Queue()
+audio_outputs: AsyncQueue[Union[torch.Tensor, np.ndarray]] = AsyncQueue()
diff --git a/src/comfystream/utils.py b/src/comfystream/utils.py
index 720692b5..17916f6d 100644
--- a/src/comfystream/utils.py
+++ b/src/comfystream/utils.py
@@ -36,7 +36,7 @@ def convert_prompt(prompt: PromptDictInput) -> Prompt:
         "PreviewImage": [],
         "SaveImage": [],
     }
-    
+
     for key, node in prompt.items():
         class_type = node.get("class_type")
 
@@ -47,9 +47,9 @@ def convert_prompt(prompt: PromptDictInput) -> Prompt:
         # Count inputs and outputs
         if class_type == "PrimaryInputLoadImage":
             num_primary_inputs += 1
-        elif class_type in ["LoadImage", "LoadTensor"]:
+        elif class_type in ["LoadImage", "LoadTensor", "LoadAudioTensor"]:
             num_inputs += 1
-        elif class_type in ["PreviewImage", "SaveImage", "SaveTensor"]:
+        elif class_type in ["PreviewImage", "SaveImage", "SaveTensor", "SaveAudioTensor"]:
             num_outputs += 1
 
         # Only handle single primary input
diff --git a/ui/src/app/api/offer/route.ts b/ui/src/app/api/offer/route.ts
index b42428af..e69d8113 100644
--- a/ui/src/app/api/offer/route.ts
+++ b/ui/src/app/api/offer/route.ts
@@ -1,14 +1,14 @@
 import { NextRequest, NextResponse } from "next/server";
 
 export const POST = async function POST(req: NextRequest) {
-  const { endpoint, prompt, offer } = await req.json();
+  const { endpoint, prompts, offer } = await req.json();
 
   const res = await fetch(endpoint + "/offer", {
     method: "POST",
     headers: {
       "Content-Type": "application/json",
     },
-    body: JSON.stringify({ prompt, offer }),
+    body: JSON.stringify({ prompts, offer }),
   });
 
   return NextResponse.json(await res.json(), { status: res.status });
diff --git a/ui/src/app/page.tsx b/ui/src/app/page.tsx
index 4b17e45e..6d7e24fe 100644
--- a/ui/src/app/page.tsx
+++ b/ui/src/app/page.tsx
@@ -5,22 +5,22 @@ import { PromptContext } from "@/components/settings";
 import { useState, useEffect } from "react";
 
 export default function Page() {
-  const [originalPrompt, setOriginalPrompt] = useState(null);
-  const [currentPrompt, setCurrentPrompt] = useState(null);
+  const [originalPrompts, setOriginalPrompts] = useState(null);
+  const [currentPrompts, setCurrentPrompts] = useState(null);
 
   // Update currentPrompt whenever originalPrompt changes
   useEffect(() => {
-    if (originalPrompt) {
-      setCurrentPrompt(JSON.parse(JSON.stringify(originalPrompt)));
+    if (originalPrompts) {
+      setCurrentPrompts(JSON.parse(JSON.stringify(originalPrompts)));
     }
-  }, [originalPrompt]);
+  }, [originalPrompts]);
 
   return (
diff --git a/ui/src/components/control-panel.tsx b/ui/src/components/control-panel.tsx
index 2afff7da..ebda6894 100644
--- a/ui/src/components/control-panel.tsx
+++ b/ui/src/components/control-panel.tsx
@@ -109,7 +109,7 @@ const InputControl = ({
 
 export const ControlPanel = ({ panelState, onStateChange }: ControlPanelProps) => {
   const { controlChannel } = usePeerContext();
-  const { currentPrompt, setCurrentPrompt } = usePrompt();
+  const { currentPrompts, setCurrentPrompts } = usePrompt();
   const [availableNodes, setAvailableNodes] = useState>({});
 
   // Add ref to track last sent value and timeout
@@ -139,7 +139,7 @@ export const ControlPanel = ({ panelState, onStateChange }: ControlPanelProps) =
       const data = JSON.parse(event.data);
       if (data.type === "nodes_info") {
         setAvailableNodes(data.nodes);
-      } else if (data.type === "prompt_updated") {
+      } else if (data.type === "prompts_updated") {
         if (!data.success) {
           console.error("[ControlPanel] Failed to update prompt");
         }
@@ -171,7 +171,7 @@ export const ControlPanel = ({ panelState, onStateChange }: ControlPanelProps) =
   // Modify the effect that sends updates with debouncing
   useEffect(() => {
     const currentInput = panelState.nodeId && panelState.fieldName ? availableNodes[panelState.nodeId]?.inputs[panelState.fieldName] : null;
-    if (!currentInput || !currentPrompt) return;
+    if (!currentInput || !currentPrompts) return;
 
     let isValidValue = true;
     let processedValue: any = panelState.value;
@@ -218,6 +218,7 @@ export const ControlPanel = ({ panelState, onStateChange }: ControlPanelProps) =
       // Set a new timeout for the update
       updateTimeoutRef.current = setTimeout(() => {
         // Create updated prompt while maintaining current structure
+        const currentPrompt = currentPrompts[0];
         const updatedPrompt = JSON.parse(JSON.stringify(currentPrompt)); // Deep clone
         if (updatedPrompt[panelState.nodeId] && updatedPrompt[panelState.nodeId].inputs) {
           updatedPrompt[panelState.nodeId].inputs[panelState.fieldName] = processedValue;
@@ -231,17 +232,17 @@ export const ControlPanel = ({ panelState, onStateChange }: ControlPanelProps) =
 
           // Send the full prompt update
           const message = JSON.stringify({
-            type: "update_prompt",
-            prompt: updatedPrompt
+            type: "update_prompts",
+            prompts: [updatedPrompt]
           });
           controlChannel.send(message);
 
           // Only update current prompt after sending
-          setCurrentPrompt(updatedPrompt);
+          setCurrentPrompts([updatedPrompt]);
         }
       }, currentInput.type.toLowerCase() === 'number' ? 100 : 300); // Shorter delay for numbers, longer for text
     }
-  }, [panelState.value, panelState.nodeId, panelState.fieldName, panelState.isAutoUpdateEnabled, controlChannel, availableNodes, currentPrompt, setCurrentPrompt]);
+  }, [panelState.value, panelState.nodeId, panelState.fieldName, panelState.isAutoUpdateEnabled, controlChannel, availableNodes, currentPrompts, setCurrentPrompts]);
 
   const toggleAutoUpdate = () => {
     onStateChange({ isAutoUpdateEnabled: !panelState.isAutoUpdateEnabled });
diff --git a/ui/src/components/peer.tsx b/ui/src/components/peer.tsx
index dc7573ac..d2ccee74 100644
--- a/ui/src/components/peer.tsx
+++ b/ui/src/components/peer.tsx
@@ -4,7 +4,7 @@ import { PeerContext } from "@/context/peer-context";
 
 export interface PeerProps extends React.HTMLAttributes {
   url: string;
-  prompt: any;
+  prompts: any;
   connect: boolean;
   onConnected: () => void;
   onDisconnected: () => void;
diff --git a/ui/src/components/room.tsx b/ui/src/components/room.tsx
index 1cbad4e0..6737cea5 100644
--- a/ui/src/components/room.tsx
+++ b/ui/src/components/room.tsx
@@ -19,21 +19,70 @@ interface MediaStreamPlayerProps {
 
 function MediaStreamPlayer({ stream }: MediaStreamPlayerProps) {
   const videoRef = useRef(null);
+  const [needsPlayButton, setNeedsPlayButton] = useState(false);
+  const hasVideo = stream.getVideoTracks().length > 0;
 
   useEffect(() => {
-    if (videoRef.current && stream) {
-      videoRef.current.srcObject = stream;
-    }
+    if (!videoRef.current || !stream) return;
+
+    const video = videoRef.current;
+    video.srcObject = stream;
+    setNeedsPlayButton(false);
+
+    // Handle autoplay
+    const playStream = async () => {
+      try {
+        // Only attempt to play if the video element exists and has a valid srcObject
+        if (video && video.srcObject) {
+          await video.play();
+          setNeedsPlayButton(false);
+        }
+      } catch (error) {
+        // Log error but don't throw - this is likely due to browser autoplay policy
+        console.warn("Autoplay prevented:", error);
+        setNeedsPlayButton(true);
+      }
+    };
+    playStream();
+
+    return () => {
+      if (video) {
+        video.srcObject = null;
+        video.pause();
+      }
+    };
   }, [stream]);
 
+  const handlePlayClick = async () => {
+    try {
+      if (videoRef.current) {
+        await videoRef.current.play();
+        setNeedsPlayButton(false);
+      }
+    } catch (error) {
+      console.warn("Manual play failed:", error);
+    }
+  };
+
   return (
@@ -205,6 +257,7 @@ export const Room = () => {
                 onStreamReady={onStreamReady}
                 deviceId={config.selectedDeviceId}
                 frameRate={config.frameRate}
+                selectedAudioDeviceId={config.selectedAudioDeviceId}
               />
diff --git a/ui/src/components/settings.tsx b/ui/src/components/settings.tsx
index f34320a8..0fbe8f3c 100644
--- a/ui/src/components/settings.tsx
+++ b/ui/src/components/settings.tsx
@@ -40,8 +40,9 @@ import { Select } from "./ui/select";
 export interface StreamConfig {
   streamUrl: string;
   frameRate: number;
-  prompt?: any;
+  prompts?: any;
   selectedDeviceId: string | undefined;
+  selectedAudioDeviceId: string | undefined;
 }
 
 interface VideoDevice {
@@ -54,6 +55,7 @@ export const DEFAULT_CONFIG: StreamConfig = {
     process.env.NEXT_PUBLIC_DEFAULT_STREAM_URL || "http://127.0.0.1:8889",
   frameRate: 30,
   selectedDeviceId: undefined,
+  selectedAudioDeviceId: undefined,
 };
 
 interface StreamSettingsProps {
@@ -117,28 +119,28 @@ interface ConfigFormProps {
 }
 
 interface PromptContextType {
-  originalPrompt: any;
-  currentPrompt: any;
-  setOriginalPrompt: (prompt: any) => void;
-  setCurrentPrompt: (prompt: any) => void;
+  originalPrompts: any;
+  currentPrompts: any;
+  setOriginalPrompts: (prompts: any) => void;
+  setCurrentPrompts: (prompts: any) => void;
 }
 
 export const PromptContext = createContext({
-  originalPrompt: null,
-  currentPrompt: null,
-  setOriginalPrompt: () => {},
-  setCurrentPrompt: () => {},
+  originalPrompts: null,
+  currentPrompts: null,
+  setOriginalPrompts: () => {},
+  setCurrentPrompts: () => {},
 });
 
 export const usePrompt = () => useContext(PromptContext);
 
 function ConfigForm({ config, onSubmit }: ConfigFormProps) {
-  const [prompt, setPrompt] = useState(null);
-  const { setOriginalPrompt } = usePrompt();
+  const [prompts, setPrompts] = useState([]);
+  const { setOriginalPrompts } = usePrompt();
   const [videoDevices, setVideoDevices] = useState([]);
-  const [selectedDevice, setSelectedDevice] = useState(
-    config.selectedDeviceId
-  );
+  const [audioDevices, setAudioDevices] = useState([]);
+  const [selectedDevice, setSelectedDevice] = useState(config.selectedDeviceId);
+  const [selectedAudioDevice, setSelectedAudioDevice] = useState(config.selectedAudioDeviceId);
 
   const form = useForm>({
     resolver: zodResolver(formSchema),
@@ -150,38 +152,78 @@ function ConfigForm({ config, onSubmit }: ConfigFormProps) {
    */
   const getVideoDevices = useCallback(async () => {
     try {
-      // Get Available Video Devices.
-      await navigator.mediaDevices.getUserMedia({ video: true });
+      await navigator.mediaDevices.getUserMedia({ video: true, audio: true });
+
       const devices = await navigator.mediaDevices.enumerateDevices();
-      const videoDevices = devices
-        .filter((device) => device.kind === "videoinput")
-        .map((device) => ({
-          deviceId: device.deviceId,
-          label: device.label || `Camera ${device.deviceId.slice(0, 5)}...`,
-        }));
-      setVideoDevices(videoDevices);
+      const videoDevices = [
+        { deviceId: "none", label: "No Video" },
+        ...devices
+          .filter((device) => device.kind === "videoinput")
+          .map((device) => ({
+            deviceId: device.deviceId,
+            label: device.label || `Camera ${device.deviceId.slice(0, 5)}...`,
+          }))
+      ];
 
-      // Use first device as default and remove selected device if unavailable.
-      if (!videoDevices.some((device) => device.deviceId === selectedDevice)) {
-        setSelectedDevice(videoDevices.length > 0 ? videoDevices[0].deviceId : undefined);
+      setVideoDevices(videoDevices);
+      // Set default to first available camera if no selection yet
+      if (!selectedDevice && videoDevices.length > 1) {
+        setSelectedDevice(videoDevices[1].deviceId); // Index 1 because 0 is "No Video"
       }
-    } catch (err){
-      console.log(`Failed to get video devices: ${err}`);
+    } catch (err) {
+      console.error("Failed to get video devices");
+      // If we can't access video devices, still provide the None option
+      const videoDevices = [{ deviceId: "none", label: "No Video" }];
+      setVideoDevices(videoDevices);
+      setSelectedDevice("none");
     }
   }, [selectedDevice]);
 
+  const getAudioDevices = useCallback(async () => {
+    try {
+      const devices = await navigator.mediaDevices.enumerateDevices();
+      const audioDevices = [
+        { deviceId: "none", label: "No Audio" },
+        ...devices
+          .filter((device) => device.kind === "audioinput")
+          .map((device) => ({
+            deviceId: device.deviceId,
+            label: device.label || `Microphone ${device.deviceId.slice(0, 5)}...`,
+          }))
+      ];
+
+      setAudioDevices(audioDevices);
+      // Set default to first available microphone if no selection yet
+      if (!selectedAudioDevice && audioDevices.length > 1) {
+        setSelectedAudioDevice(audioDevices[0].deviceId); // Default to "No Audio" for now
+      }
+    } catch (err) {
+      console.error("Failed to get audio devices");
+      // If we can't access audio devices, still provide the None option
+      const audioDevices = [{ deviceId: "none", label: "No Audio" }];
+      setAudioDevices(audioDevices);
+      setSelectedAudioDevice("none");
+    }
+  }, [selectedAudioDevice]);
+
   // Handle device change events.
   useEffect(() => {
     getVideoDevices();
+    getAudioDevices();
     navigator.mediaDevices.addEventListener("devicechange", getVideoDevices);
+    navigator.mediaDevices.addEventListener("devicechange", getAudioDevices);
     return () => {
       navigator.mediaDevices.removeEventListener(
         "devicechange",
         getVideoDevices
       );
+      navigator.mediaDevices.removeEventListener(
+        "devicechange",
+        getAudioDevices
+      );
     };
-  }, [getVideoDevices]);
+  }, [getVideoDevices, getAudioDevices]);
 
   const handleSubmit = (values: z.infer) => {
     onSubmit({
@@ -189,22 +231,27 @@ function ConfigForm({ config, onSubmit }: ConfigFormProps) {
       streamUrl: values.streamUrl
         ? values.streamUrl.replace(/\/+$/, "")
        : values.streamUrl,
-      prompt,
+      prompts: prompts,
       selectedDeviceId: selectedDevice,
+      selectedAudioDeviceId: selectedAudioDevice,
     });
   };
 
-  const handlePromptChange = async (e: any) => {
-    const file = e.target.files[0];
-    if (!file) return;
+  const handlePromptsChange = async (e: React.ChangeEvent) => {
+    if (!e.target.files?.length) return;
 
     try {
-      const text = await file.text();
-      const parsedPrompt = JSON.parse(text);
-      setPrompt(parsedPrompt);
-      setOriginalPrompt(parsedPrompt);
+      const files = Array.from(e.target.files);
+      const fileReads = files.map(async (file) => {
+        const text = await file.text();
+        return JSON.parse(text);
+      });
+
+      const allPrompts = await Promise.all(fileReads);
+      setPrompts(allPrompts);
+      setOriginalPrompts(allPrompts);
     } catch (err) {
-      console.error(err);
+      console.error("Failed to parse one or more JSON files.", err);
     }
   };
 
@@ -257,10 +304,7 @@ function ConfigForm({ config, onSubmit }: ConfigFormProps) {
             onValueChange={handleCameraSelect}
           >
-              {videoDevices.length === 0
-                ? "No camera devices found"
-                : videoDevices.find((d) => d.deviceId === selectedDevice)
-                    ?.label || "Select camera"}
+              {selectedDevice ? (videoDevices.find((d) => d.deviceId === selectedDevice)?.label || "None") : "None"}
 
             {videoDevices.length === 0 ? (
@@ -278,14 +322,30 @@ function ConfigForm({ config, onSubmit }: ConfigFormProps) {
diff --git a/ui/src/components/webcam.tsx b/ui/src/components/webcam.tsx
index 9f078223..61a326b4 100644
--- a/ui/src/components/webcam.tsx
+++ b/ui/src/components/webcam.tsx
@@ -23,39 +23,27 @@ function StreamCanvas({
   const canvasRef = useRef(null);
   const videoRef = useRef(null);
 
+  // Only set up canvas animation if we have video
   useEffect(() => {
-    const canvas = canvasRef.current!;
-    const outputStream = canvas.captureStream(frameRate);
-    onStreamReady(outputStream);
-
-    return () => {
-      outputStream.getTracks().forEach((track) => track.stop());
-    };
-  }, [onStreamReady, frameRate]);
+    if (!stream || stream.getVideoTracks().length === 0) return;
 
-  // Set up canvas animation loop
-  useEffect(() => {
     const canvas = canvasRef.current!;
     const ctx = canvas.getContext("2d")!;
     let isActive = true;
 
     const drawFrame = () => {
-      if (!isActive) {
-        // return without scheduling another frame
+      if (!isActive || !videoRef.current) {
         return;
       }
 
-      const video = videoRef.current!;
+      const video = videoRef.current;
       if (!video?.videoWidth) {
-        // video is not ready yet
        requestAnimationFrame(drawFrame);
        return;
      }
 
       const scale = Math.max(512 / video.videoWidth, 512 / video.videoHeight);
-
       const scaledWidth = video.videoWidth * scale;
       const scaledHeight = video.videoHeight * scale;
-
       const offsetX = (512 - scaledWidth) / 2;
       const offsetY = (512 - scaledHeight) / 2;
 
@@ -69,10 +57,12 @@ function StreamCanvas({
     return () => {
       isActive = false;
     };
-  }, []);
+  }, [stream]);
 
+  // Only set up video element if we have video
   useEffect(() => {
-    if (!stream) return;
+    if (!stream || stream.getVideoTracks().length === 0) return;
+
     if (!videoRef.current) {
       videoRef.current = document.createElement("video");
       videoRef.current.muted = true;
@@ -82,16 +72,23 @@ function StreamCanvas({
     video.srcObject = stream;
     video.onloadedmetadata = () => {
       video.play().catch((error) => {
-        console.log("Video play failed:", error);
+        console.error("Video play failed:", error);
       });
     };
 
     return () => {
-      video.pause();
-      video.srcObject = null;
+      if (video) {
+        video.pause();
+        video.srcObject = null;
+      }
     };
   }, [stream]);
 
+  // Only render canvas if we have video
+  if (!stream || stream.getVideoTracks().length === 0) {
+    return null;
+  }
+
   return (
     <>
@@ -120,9 +117,10 @@ interface WebcamProps {
   onStreamReady: (stream: MediaStream) => void;
   deviceId: string;
   frameRate: number;
+  selectedAudioDeviceId: string;
 }
 
-export function Webcam({ onStreamReady, deviceId, frameRate }: WebcamProps) {
+export function Webcam({ onStreamReady, deviceId, frameRate, selectedAudioDeviceId }: WebcamProps) {
   const [stream, setStream] = useState(null);
 
   const replaceStream = useCallback((newStream: MediaStream | null) => {
@@ -131,16 +129,12 @@ export function Webcam({ onStreamReady, deviceId, frameRate }: WebcamProps) {
       if (oldStream) {
         oldStream.getTracks().forEach((track) => track.stop());
       }
-      if (newStream) {
-        const videoTrack = newStream.getVideoTracks()[0];
-        const settings = videoTrack.getSettings();
-      }
       return newStream;
     });
   }, []);
 
   const startWebcam = useCallback(async () => {
-    if (!deviceId) {
+    if (deviceId === "none" && selectedAudioDeviceId === "none") {
       return null;
     }
     if (frameRate == 0) {
@@ -148,40 +142,73 @@
     }
 
     try {
-      const newStream = await navigator.mediaDevices.getUserMedia({
-        video: {
-          ...(deviceId ? { deviceId: { exact: deviceId } } : {}),
+      const constraints: MediaStreamConstraints = {
+        video: deviceId === "none" ? false : {
+          deviceId: { exact: deviceId },
           width: { ideal: 512 },
           height: { ideal: 512 },
           aspectRatio: { ideal: 1 },
           frameRate: { ideal: frameRate, max: frameRate },
         },
-      });
+        audio: selectedAudioDeviceId === "none" ? false : {
+          deviceId: { exact: selectedAudioDeviceId },
+          sampleRate: 48000,
+          channelCount: 2,
+          sampleSize: 16,
+          echoCancellation: false,
+          noiseSuppression: false,
+          autoGainControl: false,
+        },
+      };
+
+      const newStream = await navigator.mediaDevices.getUserMedia(constraints);
       return newStream;
     } catch (error) {
+      console.error("Error accessing media devices.", error);
       return null;
     }
-  }, [deviceId, frameRate]);
+  }, [deviceId, frameRate, selectedAudioDeviceId]);
 
   useEffect(() => {
-    if (!deviceId) return;
+    if (deviceId === "none" && selectedAudioDeviceId === "none") return;
     if (frameRate == 0) return;
 
     startWebcam().then((newStream) => {
-      replaceStream(newStream);
+      if (newStream) {
+        replaceStream(newStream);
+        setStream(newStream);
+        onStreamReady(newStream);
+      }
     });
 
     return () => {
       replaceStream(null);
     };
-  }, [deviceId, frameRate, startWebcam, replaceStream]);
+  }, [deviceId, frameRate, selectedAudioDeviceId, startWebcam, replaceStream, onStreamReady]);
+
+  const hasVideo = stream && stream.getVideoTracks().length > 0;
+  const hasAudio = stream && stream.getAudioTracks().length > 0;
+
+  // Return audio-only placeholder if we have audio but no video
+  if (!hasVideo && hasAudio) {
+    return (
+        Audio Only
+    );
+  }
+
+  // Return null if we have neither video nor audio
+  if (!stream || (!hasVideo && !hasAudio)) {
+    return null;
+  }
 
   return (
         onStreamReady={() => {}} // We handle stream ready in the parent component
       />
   );
diff --git a/ui/src/hooks/use-peer.ts b/ui/src/hooks/use-peer.ts
index 44132673..04c059f6 100644
--- a/ui/src/hooks/use-peer.ts
+++ b/ui/src/hooks/use-peer.ts
@@ -6,7 +6,7 @@
 const MAX_OFFER_RETRIES = 5;
 const OFFER_RETRY_INTERVAL = 500;
 
 export function usePeer(props: PeerProps): Peer {
-  const { url, prompt, connect, onConnected, onDisconnected, localStream } =
+  const { url, prompts, connect, onConnected, onDisconnected, localStream } =
     props;
 
   const [peerConnection, setPeerConnection] =
@@ -31,7 +31,7 @@ export function usePeer(props: PeerProps): Peer {
       },
       body: JSON.stringify({
         endpoint: url,
-        prompt,
+        prompts: prompts,
        offer,
      }),
    });
@@ -81,7 +81,9 @@ export function usePeer(props: PeerProps): Peer {
     const pc = new RTCPeerConnection(configuration);
     setPeerConnection(pc);
 
-    pc.addTransceiver("video");
+    if (localStream.getVideoTracks().length > 0) {
+      pc.addTransceiver("video");
+    }
 
     localStream.getTracks().forEach((track) => {
       pc.addTrack(track, localStream);
@@ -100,6 +102,12 @@
       setControlChannel(null);
     };
 
+    pc.ontrack = (event) => {
+      if (event.streams && event.streams[0]) {
+        setRemoteStream(event.streams[0]);
+      }
+    };
+
     channel.onerror = (error) => {
       console.error("Control channel error:", error);
     };
@@ -108,12 +116,6 @@
       console.log("Received message on control channel:", event.data);
     };
 
-    pc.ontrack = (event) => {
-      if (event.track.kind == "video") {
-        setRemoteStream(event.streams[0]);
-      }
-    };
-
     pc.onicecandidate = async (event) => {
       if (!event.candidate) {
         const answer = await sendOffer(url, pc.localDescription!);
diff --git a/workflows/audio-tensor-utils-example-workflow.json b/workflows/audio-tensor-utils-example-workflow.json
new file mode 100644
index 00000000..37609fe9
--- /dev/null
+++ b/workflows/audio-tensor-utils-example-workflow.json
@@ -0,0 +1,40 @@
+{
+  "1": {
+    "inputs": {
+      "buffer_size": 500.0
+    },
+    "class_type": "LoadAudioTensor",
+    "_meta": {
+      "title": "Load Audio Tensor"
+    }
+  },
+  "2": {
+    "inputs": {
+      "audio": [
+        "1",
+        0
+      ],
+      "sample_rate": [
+        "1",
+        1
+      ],
+      "pitch_shift": 4.0
+    },
+    "class_type": "PitchShifter",
+    "_meta": {
+      "title": "Pitch Shift"
+    }
+  },
+  "3": {
+    "inputs": {
+      "audio": [
+        "2",
+        0
+      ]
+    },
+    "class_type": "SaveAudioTensor",
+    "_meta": {
+      "title": "Save Audio Tensor"
+    }
+  }
+}
\ No newline at end of file
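For reference, a minimal sketch of the int16 ⇄ float32 round trip that the new `PitchShifter.execute` applies to the mono buffers emitted by `LoadAudioTensor`. The synthetic 48 kHz sine wave and the 4-semitone shift are illustrative assumptions rather than part of the diff; only `numpy` and `librosa` are required (librosa's `soundfile` backend is presumably the reason `libsndfile1` is added to the base image):

```python
import numpy as np
import librosa

# Synthetic 0.5 s, 48 kHz mono int16 buffer, standing in for the array that
# LoadAudioTensor returns alongside its sample rate.
sample_rate = 48000
t = np.linspace(0, 0.5, int(sample_rate * 0.5), endpoint=False)
audio_int16 = (np.sin(2 * np.pi * 440.0 * t) * 32767).astype(np.int16)

# Same chain as PitchShifter.execute: scale to float32 in [-1, 1],
# pitch-shift by 4 semitones, then clip back down to int16 for SaveAudioTensor.
audio_float = audio_int16.astype(np.float32) / 32768.0
shifted = librosa.effects.pitch_shift(y=audio_float, sr=sample_rate, n_steps=4.0)
shifted_int16 = np.clip(shifted * 32768.0, -32768, 32767).astype(np.int16)

print(shifted_int16.shape, shifted_int16.dtype)  # (24000,) int16
```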