Skip to content

Commit

Permalink
temp: working state
Browse files Browse the repository at this point in the history
  • Loading branch information
varshith15 committed Dec 29, 2024
1 parent 21e4310 commit 2a3d086
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 9 deletions.
2 changes: 2 additions & 0 deletions nodes/whisper_utils/apply_whisper.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# TODO: move it to a separate repo

from .whisper_online import FasterWhisperASR, VACOnlineASRProcessor

class ApplyWhisper:
Expand Down
37 changes: 29 additions & 8 deletions server/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,16 @@

WARMUP_RUNS = 5

# TODO: remove, was just for temp UI
import logging

# Dedicated logger used to mirror transcription text into a flat file that
# the temporary UI tails. Messages are written verbatim (no level/time prefix).
display_logger = logging.getLogger('display_logger')
display_logger.setLevel(logging.INFO)
# Guard against duplicate handlers: logging.getLogger returns the same logger
# on every call, so re-importing/reloading this module would otherwise attach
# a second FileHandler and duplicate every log line.
if not display_logger.handlers:
    handler = logging.FileHandler('display_logs.txt')
    formatter = logging.Formatter('%(message)s')
    handler.setFormatter(formatter)
    display_logger.addHandler(handler)


class VideoPipeline:
def __init__(self, **kwargs):
Expand Down Expand Up @@ -47,6 +57,7 @@ async def __call__(self, frame: av.VideoFrame) -> av.VideoFrame:
class AudioPipeline:
    # Wraps a ComfyStreamClient configured for audio and normalizes incoming
    # frames before they are queued as prompts.
    def __init__(self, **kwargs):
        # kwargs are forwarded untouched to the client; only the stream type
        # is fixed here.
        self.client = ComfyStreamClient(**kwargs, type="audio")
        # Resample every incoming frame to 16 kHz mono signed 16-bit —
        # presumably the rate/layout the downstream whisper node expects
        # (matches the UI's `sampleRate: { ideal: 16000 }` constraint);
        # TODO confirm against the audio node's requirements.
        self.resampler = av.audio.resampler.AudioResampler(format='s16', layout='mono', rate=16000)

async def warm(self):
dummy_audio = torch.randn(16000)
Expand All @@ -57,10 +68,10 @@ def set_prompt(self, prompt: Dict[Any, Any]):
self.client.set_prompt(prompt)

def preprocess(self, frame: av.AudioFrame) -> torch.Tensor:
    """Resample *frame* to 16 kHz mono s16 and return it as a normalized
    float32 tensor in [-1.0, 1.0).

    Uses the resampler configured in __init__; resample() may return
    several frames, only the first is used here.
    """
    resampled_frame = self.resampler.resample(frame)[0]
    samples = resampled_frame.to_ndarray()
    # s16 PCM -> float32, scaled by 2**15 so full-scale maps to [-1, 1).
    samples = samples.astype(np.float32) / 32768.0
    # Convert to a tensor: the annotated return type and predict()'s
    # `frame: torch.Tensor` parameter both expect a tensor, not an ndarray.
    return torch.from_numpy(samples)

def postprocess(self, output: torch.Tensor) -> Optional[Union[av.AudioFrame, str]]:
out_np = output.cpu().numpy()
Expand All @@ -72,10 +83,20 @@ async def predict(self, frame: torch.Tensor) -> torch.Tensor:
return await self.client.queue_prompt(frame)

async def __call__(self, frame: av.AudioFrame):
    """Run one audio frame through the pipeline.

    Returns the original input frame when the prompt produced a text
    result (audio-to-text path, which is logged for the temp UI), or a
    new processed audio frame stamped with the input frame's timing
    (audio-to-audio path).
    """
    # TODO: clean this up later for audio-to-text and audio-to-audio
    pre_output = self.preprocess(frame)
    pred_output = await self.predict(pre_output)
    # A tuple result marks the text path — presumably (text, begin, end)
    # from the whisper node; TODO confirm element order against the node.
    if isinstance(pred_output, tuple):
        if pred_output[0] is not None:
            await self.log_text(f"{pred_output[0]} {pred_output[1]} {pred_output[2]}")
        # Pass the unmodified input audio through so playback continues.
        return frame
    else:
        post_output = self.postprocess(pred_output)
        # Re-stamp the output with the input frame's rate and timing so it
        # slots into the outgoing stream at the right position.
        post_output.sample_rate = frame.sample_rate
        post_output.pts = frame.pts
        post_output.time_base = frame.time_base
        return post_output

async def log_text(self, text: str):
    """Append *text* to the temporary display log file.

    TODO: remove, was just for temp UI.
    """
    display_logger.log(logging.INFO, text)
4 changes: 3 additions & 1 deletion ui/src/components/webcam.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,9 @@ export function Webcam({ onStreamReady, deviceId, frameRate, selectedAudioDevice
...(selectedAudioDeviceId ? { deviceId: { exact: selectedAudioDeviceId } } : {}),
sampleRate: { ideal: 16000 },
sampleSize: { ideal: 16 },
channelCount: { exact: 1 },
channelCount: { ideal: 1 },
echoCancellation: true,
noiseSuppression: true,
},
});
return newStream;
Expand Down

0 comments on commit 2a3d086

Please sign in to comment.