From 44df1701cfbbf44a3095ad65fd1802048ccce6a6 Mon Sep 17 00:00:00 2001
From: Varshith B
Date: Sun, 16 Feb 2025 22:18:23 +0530
Subject: [PATCH] fix: frame skipping

---
 nodes/tensor_utils/load_tensor.py |  6 ++++--
 server/pipeline.py                | 20 ++++++++++++--------
 src/comfystream/client.py         |  6 ++++--
 src/comfystream/tensor_cache.py   |  3 ++-
 4 files changed, 22 insertions(+), 13 deletions(-)

diff --git a/nodes/tensor_utils/load_tensor.py b/nodes/tensor_utils/load_tensor.py
index 6c13a8a7..519cbb77 100644
--- a/nodes/tensor_utils/load_tensor.py
+++ b/nodes/tensor_utils/load_tensor.py
@@ -1,3 +1,4 @@
+import time
 from comfystream import tensor_cache
 
 
@@ -15,5 +16,6 @@ def IS_CHANGED():
         return float("nan")
 
     def execute(self):
-        input = tensor_cache.image_inputs.get(block=True)
-        return (input,)
+        frame = tensor_cache.image_inputs.get(block=True)
+        frame.side_data.skipped = False
+        return (frame.side_data.input,)
diff --git a/server/pipeline.py b/server/pipeline.py
index 3eeafab0..1e96515f 100644
--- a/server/pipeline.py
+++ b/server/pipeline.py
@@ -19,10 +19,11 @@ def __init__(self, **kwargs):
         self.processed_audio_buffer = np.array([], dtype=np.int16)
 
     async def warm_video(self):
-        dummy_video_inp = torch.randn(1, 512, 512, 3)
+        dummy_frame = av.VideoFrame()
+        dummy_frame.side_data.input = torch.randn(1, 512, 512, 3)
 
         for _ in range(WARMUP_RUNS):
-            self.client.put_video_input(dummy_video_inp)
+            self.client.put_video_input(dummy_frame)
             await self.client.get_video_output()
 
     async def warm_audio(self):
@@ -45,9 +46,10 @@ async def update_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any
         await self.client.update_prompts([prompts])
 
     async def put_video_frame(self, frame: av.VideoFrame):
-        inp_tensor = self.video_preprocess(frame)
-        self.client.put_video_input(inp_tensor)
-        await self.video_incoming_frames.put((frame.pts, frame.time_base))
+        frame.side_data.input = self.video_preprocess(frame)
+        frame.side_data.skipped = True
+        self.client.put_video_input(frame)
+        await self.video_incoming_frames.put(frame)
 
     async def put_audio_frame(self, frame: av.AudioFrame):
         inp_np = self.audio_preprocess(frame)
@@ -71,12 +73,14 @@ def audio_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.Audio
 
     async def get_processed_video_frame(self):
         # TODO: make it generic to support purely generative video cases
-        pts, time_base = await self.video_incoming_frames.get()
         out_tensor = await self.client.get_video_output()
+        frame = await self.video_incoming_frames.get()
+        while frame.side_data.skipped:
+            frame = await self.video_incoming_frames.get()
 
         processed_frame = self.video_postprocess(out_tensor)
-        processed_frame.pts = pts
-        processed_frame.time_base = time_base
+        processed_frame.pts = frame.pts
+        processed_frame.time_base = frame.time_base
 
         return processed_frame
 
diff --git a/src/comfystream/client.py b/src/comfystream/client.py
index cc33409a..4b85b2ab 100644
--- a/src/comfystream/client.py
+++ b/src/comfystream/client.py
@@ -43,8 +43,10 @@ async def run_prompt(self, prompt_index: int):
                 logger.error(f"Error type: {type(e)}")
                 raise
 
-    def put_video_input(self, inp_tensor):
-        tensor_cache.image_inputs.put(inp_tensor)
+    def put_video_input(self, frame):
+        if tensor_cache.image_inputs.full():
+            tensor_cache.image_inputs.get(block=True)
+        tensor_cache.image_inputs.put(frame)
 
     def put_audio_input(self, inp_tensor):
         tensor_cache.audio_inputs.put(inp_tensor)
diff --git a/src/comfystream/tensor_cache.py b/src/comfystream/tensor_cache.py
index ab948744..0216f73b 100644
--- a/src/comfystream/tensor_cache.py
+++ b/src/comfystream/tensor_cache.py
@@ -6,7 +6,8 @@ from typing import Union
 
-image_inputs: Queue[Union[torch.Tensor, np.ndarray]] = Queue()
+# TODO: improve eviction policy: FIFO might not be the best; skip alternate frames instead
+image_inputs: Queue[Union[torch.Tensor, np.ndarray]] = Queue(maxsize=1)
 image_outputs: AsyncQueue[Union[torch.Tensor, np.ndarray]] = AsyncQueue()
 audio_inputs: Queue[Union[torch.Tensor, np.ndarray]] = Queue()