# pipeline.py (forked from yondonfu/comfystream)
"""Pipeline: glue between PyAV frames and a ComfyStreamClient.

Handles warmup, frame queueing, and pre/post-processing for video and audio.
"""

import asyncio
from typing import Any, Dict, List, Union

import av
import numpy as np
import torch

from comfystream.client import ComfyStreamClient

# Number of dummy frames pushed through each graph before real traffic,
# so model loading and other lazy initialization happen up front.
WARMUP_RUNS = 5


class Pipeline:
    def __init__(self, **kwargs):
        # TODO: hardcoded max workers, should it be configurable?
        self.client = ComfyStreamClient(**kwargs, max_workers=5)

        # Raw incoming frames, queued so processed outputs can reuse their
        # timing metadata (pts / time_base).
        self.video_incoming_frames = asyncio.Queue()
        self.audio_incoming_frames = asyncio.Queue()

        # Processed audio samples carried over between output frames.
        self.processed_audio_buffer = np.array([], dtype=np.int16)

    async def warm_video(self):
        # Push a few dummy frames through the video graph so model loading
        # and any lazy compilation happen before real frames arrive.
        dummy_frame = av.VideoFrame()
        dummy_frame.side_data.input = torch.randn(1, 512, 512, 3)

        for _ in range(WARMUP_RUNS):
            self.client.put_video_input(dummy_frame)
            await self.client.get_video_output()

    async def warm_audio(self):
        # 0.5 s of random int16 noise at 48 kHz.
        # TODO: adds a lot of delay if it doesn't match the buffer size, is warmup needed?
        dummy_audio_inp = np.random.randint(-32768, 32767, int(48000 * 0.5), dtype=np.int16)

        for _ in range(WARMUP_RUNS):
            self.client.put_audio_input((dummy_audio_inp, 48000))
            await self.client.get_audio_output()

    async def set_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]):
        if isinstance(prompts, list):
            await self.client.set_prompts(prompts)
        else:
            await self.client.set_prompts([prompts])

    async def update_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]):
        if isinstance(prompts, list):
            await self.client.update_prompts(prompts)
        else:
            await self.client.update_prompts([prompts])

    async def put_video_frame(self, frame: av.VideoFrame):
        frame.side_data.input = self.video_preprocess(frame)
        # Mark as skipped; the client is expected to clear this flag on frames
        # it actually processes (see get_processed_video_frame).
        frame.side_data.skipped = True
        self.client.put_video_input(frame)
        await self.video_incoming_frames.put(frame)

    async def put_audio_frame(self, frame: av.AudioFrame):
        inp_np = self.audio_preprocess(frame)
        self.client.put_audio_input((inp_np, frame.sample_rate))
        # Only timing metadata is needed to rebuild the output frame.
        await self.audio_incoming_frames.put((frame.pts, frame.time_base, frame.samples, frame.sample_rate))

    def video_preprocess(self, frame: av.VideoFrame) -> Union[torch.Tensor, np.ndarray]:
        # HWC uint8 RGB -> float32 in [0, 1] with a leading batch dimension.
        frame_np = frame.to_ndarray(format="rgb24").astype(np.float32) / 255.0
        return torch.from_numpy(frame_np).unsqueeze(0)

    def audio_preprocess(self, frame: av.AudioFrame) -> Union[torch.Tensor, np.ndarray]:
        # Downmix interleaved stereo to mono by averaging the two channels.
        return frame.to_ndarray().ravel().reshape(-1, 2).mean(axis=1).astype(np.int16)

    def video_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.VideoFrame:
        # Float tensor in [0, 1] -> HWC uint8 RGB frame.
        return av.VideoFrame.from_ndarray(
            (output * 255.0).clamp(0, 255).to(dtype=torch.uint8).squeeze(0).cpu().numpy()
        )

    def audio_postprocess(self, output: Union[torch.Tensor, np.ndarray]) -> av.AudioFrame:
        # Duplicate the mono signal into two interleaved channels.
        return av.AudioFrame.from_ndarray(np.repeat(output, 2).reshape(1, -1))

    async def get_processed_video_frame(self):
        # TODO: make it generic to support purely generative video cases
        out_tensor = await self.client.get_video_output()

        # Discard frames the client skipped so the output keeps the timing of
        # the frame that was actually processed.
        frame = await self.video_incoming_frames.get()
        while frame.side_data.skipped:
            frame = await self.video_incoming_frames.get()

        processed_frame = self.video_postprocess(out_tensor)
        processed_frame.pts = frame.pts
        processed_frame.time_base = frame.time_base
        return processed_frame

    async def get_processed_audio_frame(self):
        # TODO: make it generic to support purely generative audio cases
        pts, time_base, samples, sample_rate = await self.audio_incoming_frames.get()

        # Accumulate processed output until there are enough samples for this
        # frame; a single output chunk may be smaller than the frame size.
        while samples > len(self.processed_audio_buffer):
            out_tensor = await self.client.get_audio_output()
            self.processed_audio_buffer = np.concatenate([self.processed_audio_buffer, out_tensor])

        out_data = self.processed_audio_buffer[:samples]
        self.processed_audio_buffer = self.processed_audio_buffer[samples:]

        processed_frame = self.audio_postprocess(out_data)
        processed_frame.pts = pts
        processed_frame.time_base = time_base
        processed_frame.sample_rate = sample_rate
        return processed_frame

    async def get_nodes_info(self) -> Dict[str, Any]:
        """Get information about all nodes in the current prompt including metadata."""
        return await self.client.get_available_nodes()
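

# Usage sketch (illustrative, not part of the upstream module). Assumptions:
# a ComfyUI workflow exported in API/prompt JSON format sits at ./workflow.json,
# and ComfyStreamClient forwards keyword arguments such as `cwd` to the
# embedded ComfyUI instance; both the path and the kwarg are hypothetical here.
if __name__ == "__main__":
    import json

    async def _demo():
        pipeline = Pipeline(cwd="/path/to/ComfyUI")  # hypothetical install path
        with open("workflow.json") as f:
            await pipeline.set_prompts(json.load(f))
        await pipeline.warm_video()
        # A real caller would now push decoded av.VideoFrame objects with
        # put_video_frame() and read results via get_processed_video_frame().

    asyncio.run(_demo())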