@@ -1,16 +1,12 @@
-import torch
 import av
+import torch
 import numpy as np
-import fractions
 import asyncio
 
-from av import AudioFrame
-from typing import Any, Dict, Optional, Union, List
+from typing import Any, Dict, Union, List
 
 from comfystream.client import ComfyStreamClient
-from comfystream import tensor_cache
-
-WARMUP_RUNS = 5
 
+WARMUP_RUNS = 10
 
 
 class Pipeline:
@@ -19,53 +15,39 @@ def __init__(self, **kwargs):
 
         self.video_futures = asyncio.Queue()
         self.audio_futures = asyncio.Queue()
-
-        self.audio_output_frames = []
 
         self.resampler = av.audio.resampler.AudioResampler(format='s16', layout='mono', rate=48000)  # find a better way to convert to mono
-        self.sample_rate = 48000  # instead of hardcoding, find a clean way to set from audio frame
-        self.frame_size = int(self.sample_rate * 0.02)
-        self.time_base = fractions.Fraction(1, self.sample_rate)
-        self.curr_pts = 0  # figure out a better way to set back pts to processed audio frames
-
-    def set_prompt(self, prompt: Dict[Any, Any]):
-        self.client.set_prompt(prompt)
 
     async def warm(self):
-        dummy_video_frame = torch.randn(1, 512, 512, 3)
-        dummy_audio_frame = np.random.randint(-32768, 32767, 48000 * 1, dtype=np.int16)
+        dummy_video_inp = torch.randn(1, 512, 512, 3)
+        dummy_audio_inp = np.random.randint(-32768, 32767, 48 * 20, dtype=np.int16)  # has to be more than the buffer size in the comfy workflow
 
         for _ in range(WARMUP_RUNS):
-            image_out_fut = asyncio.Future()
-            audio_out_fut = asyncio.Future()
-            tensor_cache.image_outputs.put(image_out_fut)
-            tensor_cache.audio_outputs.put(audio_out_fut)
+            image_out_fut = self.client.put_video_input(dummy_video_inp)
+            await image_out_fut
 
-            tensor_cache.image_inputs.put(dummy_video_frame)
-            tensor_cache.audio_inputs.put(dummy_audio_frame)
+        futs = []
+        for _ in range(WARMUP_RUNS):
+            audio_out_fut = self.client.put_audio_input(dummy_audio_inp)
+            futs.append(audio_out_fut)
 
-            await image_out_fut
-            await audio_out_fut
+        await asyncio.gather(*futs)
 
-    def set_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]):
+    async def set_prompts(self, prompts: Union[Dict[Any, Any], List[Dict[Any, Any]]]):
         if isinstance(prompts, dict):
-            self.client.set_prompts([prompts])
+            await self.client.set_prompts([prompts])
         else:
-            self.client.set_prompts(prompts)
+            await self.client.set_prompts(prompts)
 
     async def put_video_frame(self, frame: av.VideoFrame):
         inp_tensor = self.video_preprocess(frame)
-        out_future = asyncio.Future()
-        tensor_cache.image_outputs.put(out_future)
-        tensor_cache.image_inputs.put(inp_tensor)
+        out_future = self.client.put_video_input(inp_tensor)
         await self.video_futures.put((out_future, frame.pts, frame.time_base))
 
     async def put_audio_frame(self, frame: av.AudioFrame):
         inp_tensor = self.audio_preprocess(frame)
-        out_future = asyncio.Future()
-        tensor_cache.audio_outputs.put(out_future)
-        tensor_cache.audio_inputs.put(inp_tensor)
-        await self.audio_futures.put(out_future)
+        out_future = self.client.put_audio_input(inp_tensor)
+        await self.audio_futures.put((out_future, frame.pts, frame.time_base, frame.sample_rate))
 
     def video_preprocess(self, frame: av.VideoFrame) -> torch.Tensor:
         frame_np = frame.to_ndarray(format="rgb24").astype(np.float32) / 255.0
@@ -80,18 +62,7 @@ def video_postprocess(self, output: torch.Tensor) -> av.VideoFrame:
         )
 
     def audio_postprocess(self, output: torch.Tensor) -> av.AudioFrame:
-        frames = []
-        for idx in range(0, len(output), self.frame_size):
-            frame_samples = output[idx:idx + self.frame_size]
-            frame_samples = frame_samples.reshape(1, -1).astype(np.int16)
-            frame = AudioFrame.from_ndarray(frame_samples, layout="mono")
-            frame.sample_rate = self.sample_rate
-            frame.pts = self.curr_pts
-            frame.time_base = self.time_base
-            self.curr_pts += 960
-
-            frames.append(frame)
-        return frames
+        return av.AudioFrame.from_ndarray(output.reshape(1, -1), layout="mono")
 
     async def get_processed_video_frame(self):
         out_fut, pts, time_base = await self.video_futures.get()
@@ -101,14 +72,12 @@ async def get_processed_video_frame(self):
         return frame
 
     async def get_processed_audio_frame(self):
-        while not self.audio_output_frames:
-            out_fut = await self.audio_futures.get()
-            output = await out_fut
-            if output is None:
-                print("No Audio output")
-                continue
-            self.audio_output_frames.extend(self.audio_postprocess(output))
-        return self.audio_output_frames.pop(0)
+        out_fut, pts, time_base, sample_rate = await self.audio_futures.get()
+        frame = self.audio_postprocess(await out_fut)
+        frame.pts = pts
+        frame.time_base = time_base
+        frame.sample_rate = sample_rate
+        return frame
 
     async def get_nodes_info(self) -> Dict[str, Any]:
         """Get information about all nodes in the current prompt including metadata."""