Commit 6ea5477

ReplyOnPause and ReplyOnStopWords can be interrupted (#119)
* Add all this code

* add code

* Fix demo

---------

Co-authored-by: Freddy Boulton <[email protected]>
freddyaboulton and Freddy Boulton authored Mar 4, 2025
1 parent 87954a6 commit 6ea5477
Showing 13 changed files with 155 additions and 40 deletions.
15 changes: 13 additions & 2 deletions backend/fastrtc/reply_on_pause.py
@@ -1,6 +1,6 @@
 import asyncio
 import inspect
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from functools import lru_cache
 from logging import getLogger
 from threading import Event
@@ -59,6 +59,10 @@ class AppState:
     stopped: bool = False
     buffer: np.ndarray | None = None
     responded_audio: bool = False
+    interrupted: asyncio.Event = field(default_factory=asyncio.Event)
+
+    def new(self):
+        return AppState()


 ReplyFnGenerator = (
@@ -91,6 +95,7 @@ def __init__(
         fn: ReplyFnGenerator,
         algo_options: AlgoOptions | None = None,
         model_options: SileroVadOptions | None = None,
+        can_interrupt: bool = True,
         expected_layout: Literal["mono", "stereo"] = "mono",
         output_sample_rate: int = 24000,
         output_frame_size: int = 480,
@@ -102,6 +107,7 @@ def __init__(
             output_frame_size,
             input_sample_rate=input_sample_rate,
         )
+        self.can_interrupt = can_interrupt
         self.expected_layout: Literal["mono", "stereo"] = expected_layout
         self.output_sample_rate = output_sample_rate
         self.output_frame_size = output_frame_size
@@ -123,6 +129,7 @@ def copy(self):
             self.fn,
             self.algo_options,
             self.model_options,
+            self.can_interrupt,
             self.expected_layout,
             self.output_sample_rate,
             self.output_frame_size,
@@ -170,11 +177,14 @@ def process_audio(self, audio: tuple[int, np.ndarray], state: AppState) -> None:
         state.pause_detected = pause_detected

     def receive(self, frame: tuple[int, np.ndarray]) -> None:
-        if self.state.responding:
+        if self.state.responding and not self.can_interrupt:
             return
         self.process_audio(frame, self.state)
         if self.state.pause_detected:
             self.event.set()
+            if self.can_interrupt:
+                self.clear_queue()
+                self.generator = None

     def reset(self):
         super().reset()
@@ -207,6 +217,7 @@ def emit(self):
             else:
                 self.generator = self.fn((self.state.sampling_rate, audio))  # type: ignore
             logger.debug("Latest args: %s", self.latest_args)
+            self.state = self.state.new()
             self.state.responding = True
             try:
                 if self.is_async:
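
The new can_interrupt flag defaults to True. With it set, receive() keeps running pause detection on incoming frames even while a reply is streaming, and a detected pause clears the buffered output and drops the stale generator, so the user can talk over the bot. A minimal usage sketch (the echo handler and Stream wiring are illustrative, not part of this diff):

import numpy as np
from fastrtc import ReplyOnPause, Stream

def echo(audio: tuple[int, np.ndarray]):
    # Illustrative handler: replay the caller's audio once a pause is detected.
    sample_rate, frames = audio
    yield (sample_rate, frames)

# can_interrupt=True (the default) lets new speech cut off an in-progress reply.
stream = Stream(
    ReplyOnPause(echo, can_interrupt=True),
    modality="audio",
    mode="send-receive",
)
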
6 changes: 6 additions & 0 deletions backend/fastrtc/reply_on_stopwords.py
@@ -23,6 +23,9 @@ class ReplyOnStopWordsState(AppState):
     post_stop_word_buffer: np.ndarray | None = None
     started_talking_pre_stop_word: bool = False
+
+    def new(self):
+        return ReplyOnStopWordsState()


 class ReplyOnStopWords(ReplyOnPause):
     def __init__(
@@ -31,6 +34,7 @@ def __init__(
         stop_words: list[str],
         algo_options: AlgoOptions | None = None,
         model_options: SileroVadOptions | None = None,
+        can_interrupt: bool = True,
         expected_layout: Literal["mono", "stereo"] = "mono",
         output_sample_rate: int = 24000,
         output_frame_size: int = 480,
@@ -40,6 +44,7 @@ def __init__(
             fn,
             algo_options=algo_options,
             model_options=model_options,
+            can_interrupt=can_interrupt,
             expected_layout=expected_layout,
             output_sample_rate=output_sample_rate,
             output_frame_size=output_frame_size,
@@ -144,6 +149,7 @@ def copy(self):
             self.stop_words,
             self.algo_options,
             self.model_options,
+            self.can_interrupt,
             self.expected_layout,
             self.output_sample_rate,
             self.output_frame_size,
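
ReplyOnStopWords just forwards the flag to the ReplyOnPause base class, so interruption can be disabled per handler. A hedged sketch (the handler and stop word are illustrative):

import numpy as np
from fastrtc import ReplyOnStopWords

def respond(audio: tuple[int, np.ndarray]):
    # Illustrative handler body.
    sample_rate, frames = audio
    yield (sample_rate, frames)

handler = ReplyOnStopWords(
    respond,
    stop_words=["computer"],  # illustrative trigger word
    can_interrupt=False,      # opt out: frames arriving mid-reply are ignored
)
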
6 changes: 3 additions & 3 deletions backend/fastrtc/stream.py
@@ -360,7 +360,7 @@ def _generate_default_ui(
             image = WebRTC(
                 label="Stream",
                 rtc_configuration=self.rtc_configuration,
-                mode="send-receive",
+                mode="send",
                 modality="audio",
                 icon=ui_args.get("icon"),
                 icon_button_color=ui_args.get("icon_button_color"),
@@ -505,7 +505,7 @@ async def handle_incoming_call(self, request: Request):
         return HTMLResponse(content=str(response), media_type="application/xml")

     async def telephone_handler(self, websocket: WebSocket):
-        handler = cast(StreamHandlerImpl, self.event_handler.copy())
+        handler = cast(StreamHandlerImpl, self.event_handler.copy())  # type: ignore
         handler.phone_mode = True

         async def set_handler(s: str, a: WebSocketHandler):
@@ -528,7 +528,7 @@ async def set_handler(s: str, a: WebSocketHandler):
         await ws.handle_websocket(websocket)

     async def websocket_offer(self, websocket: WebSocket):
-        handler = cast(StreamHandlerImpl, self.event_handler.copy())
+        handler = cast(StreamHandlerImpl, self.event_handler.copy())  # type: ignore
         handler.phone_mode = False

         async def set_handler(s: str, a: WebSocketHandler):
25 changes: 23 additions & 2 deletions backend/fastrtc/tracks.py
@@ -188,6 +188,11 @@ def __init__(
         self.args_set = asyncio.Event()
         self.channel_set = asyncio.Event()
         self._phone_mode = False
+        self._clear_queue: Callable | None = None
+
+    @property
+    def clear_queue(self) -> Callable:
+        return cast(Callable, self._clear_queue)

     @property
     def loop(self) -> asyncio.AbstractEventLoop:
@@ -237,8 +242,11 @@ async def send_message(self, msg: str):
         logger.debug("Sent msg %s", msg)

     def send_message_sync(self, msg: str):
-        asyncio.run_coroutine_threadsafe(self.send_message(msg), self.loop).result()
-        logger.debug("Sent msg %s", msg)
+        try:
+            asyncio.run_coroutine_threadsafe(self.send_message(msg), self.loop).result()
+            logger.debug("Sent msg %s", msg)
+        except Exception as e:
+            logger.debug("Exception sending msg %s", e)

     def set_args(self, args: list[Any]):
         logger.debug("setting args in audio callback %s", args)
@@ -411,6 +419,7 @@ def __init__(
         super().__init__()
         self.track = track
         self.event_handler = cast(StreamHandlerImpl, event_handler)
+        self.event_handler._clear_queue = self.clear_queue
         self.current_timestamp = 0
         self.latest_args: str | list[Any] = "not_set"
         self.queue = asyncio.Queue()
@@ -421,6 +430,12 @@ def __init__(
         self.channel = channel
         self.set_additional_outputs = set_additional_outputs

+    def clear_queue(self):
+        if self.queue:
+            while not self.queue.empty():
+                self.queue.get_nowait()
+        self._start = None
+
     def set_channel(self, channel: DataChannel):
         self.channel = channel
         self.event_handler.set_channel(channel)
@@ -608,6 +623,7 @@ def __init__(
     ) -> None:
         self.generator: Generator[Any, None, Any] | None = None
         self.event_handler = event_handler
+        self.event_handler._clear_queue = self.clear_queue
         self.current_timestamp = 0
         self.latest_args: str | list[Any] = "not_set"
         self.args_set = threading.Event()
@@ -619,6 +635,11 @@ def __init__(
         self._start: float | None = None
         super().__init__()

+    def clear_queue(self):
+        while not self.queue.empty():
+            self.queue.get_nowait()
+        self._start = None
+
     def set_channel(self, channel: DataChannel):
         self.channel = channel

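
Both callback classes expose the same clear_queue shape, and the handler reaches it through the _clear_queue attribute wired up in __init__. The drain loop itself is plain asyncio; a standalone sketch of the pattern:

import asyncio

async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    for i in range(3):
        queue.put_nowait(i)
    # Same drain loop as clear_queue above: discard whatever audio is
    # already buffered so the interrupting turn starts from silence.
    while not queue.empty():
        queue.get_nowait()
    assert queue.empty()

asyncio.run(main())
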
2 changes: 1 addition & 1 deletion backend/fastrtc/utils.py
@@ -320,7 +320,7 @@ def audio_to_int16(
     >>> audio_int16 = audio_to_int16(audio_tuple)
     """
     if audio[1].dtype == np.int16:
-        return audio[1]
+        return audio[1]  # type: ignore
     elif audio[1].dtype == np.float32:
         # Convert float32 to int16 by scaling to the int16 range
         return (audio[1] * 32767.0).astype(np.int16)
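
The change only silences the type checker; behavior is unchanged. For reference, the float32 branch scales into the int16 range, which can be checked standalone:

import numpy as np

audio_float32 = np.array([0.0, 0.5, -1.0], dtype=np.float32)
audio_int16 = (audio_float32 * 32767.0).astype(np.int16)
print(audio_int16)  # [     0  16383 -32767]
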
1 change: 1 addition & 0 deletions backend/fastrtc/websocket.py
@@ -55,6 +55,7 @@ def __init__(
         ],
     ):
         self.stream_handler = stream_handler
+        self.stream_handler._clear_queue = lambda: None
         self.websocket: Optional[WebSocket] = None
         self._emit_task: Optional[asyncio.Task] = None
         self.stream_id: Optional[str] = None
4 changes: 2 additions & 2 deletions demo/llm_voice_chat/app.py
@@ -41,14 +41,15 @@ def response(
     response_text = (
         groq_client.chat.completions.create(
             model="llama-3.1-8b-instant",
-            max_tokens=512,
+            max_tokens=200,
             messages=messages,  # type: ignore
         )
         .choices[0]
         .message.content
     )

     chatbot.append({"role": "assistant", "content": response_text})
+    yield AdditionalOutputs(chatbot)

     for chunk in tts_client.text_to_speech.convert_as_stream(
         text=response_text,  # type: ignore
@@ -58,7 +59,6 @@ def response(
     ):
         audio_array = np.frombuffer(chunk, dtype=np.int16).reshape(1, -1)
         yield (24000, audio_array)
-    yield AdditionalOutputs(chatbot)


 chatbot = gr.Chatbot(type="messages")
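
The demo fix moves the AdditionalOutputs yield ahead of the TTS loop, so the chat UI updates as soon as the text is ready rather than after all audio has streamed; with interruption enabled, a reply cut off mid-stream would otherwise never deliver its chat update. The resulting generator shape, with a hypothetical stand-in for the ElevenLabs call:

import numpy as np
from fastrtc import AdditionalOutputs

def synthesize(text: str):
    # Hypothetical stand-in for tts_client.text_to_speech.convert_as_stream.
    yield np.zeros((1, 480), dtype=np.int16)

def response(audio: tuple[int, np.ndarray], chatbot: list[dict]):
    chatbot.append({"role": "assistant", "content": "..."})
    yield AdditionalOutputs(chatbot)  # UI update first
    for chunk in synthesize("..."):   # then the audio frames
        yield (24000, chunk)
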
14 changes: 11 additions & 3 deletions demo/moonshine_live/app.py
@@ -3,6 +3,7 @@

 import gradio as gr
 import numpy as np
+from dotenv import load_dotenv
 from fastrtc import (
     AdditionalOutputs,
     ReplyOnPause,
@@ -13,6 +14,8 @@
 from moonshine_onnx import MoonshineOnnxModel, load_tokenizer
 from numpy.typing import NDArray

+load_dotenv()
+

 @lru_cache(maxsize=None)
 def load_moonshine(
@@ -27,6 +30,7 @@ def load_moonshine(
 def stt(
     audio: tuple[int, NDArray[np.int16 | np.float32]],
     model_name: Literal["moonshine/base", "moonshine/tiny"],
+    captions: str,
 ) -> Generator[AdditionalOutputs, None, None]:
     moonshine = load_moonshine(model_name)
     sr, audio_np = audio  # type: ignore
@@ -35,9 +39,12 @@ def stt(
     if audio_np.ndim == 1:
         audio_np = audio_np.reshape(1, -1)
     tokens = moonshine.generate(audio_np)
-    yield AdditionalOutputs(tokenizer.decode_batch(tokens)[0])
+    yield AdditionalOutputs(
+        (captions + "\n" + tokenizer.decode_batch(tokens)[0]).strip()
+    )


+captions = gr.Textbox(label="Captions")
 stream = Stream(
     ReplyOnPause(stt, input_sample_rate=16000),
     modality="audio",
@@ -55,9 +62,10 @@ def stt(
             choices=["moonshine/base", "moonshine/tiny"],
             value="moonshine/base",
             label="Model",
-        )
+        ),
+        captions,
     ],
-    additional_outputs=[gr.Textbox(label="Captions")],
+    additional_outputs=[captions],
     additional_outputs_handler=lambda prev, current: (prev + "\n" + current).strip(),
 )

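
Because the captions textbox is now both an additional input and an additional output, each stt call receives the running transcript and appends the new line itself; the additional_outputs_handler then stitches consecutive outputs together. Its behavior on plain strings:

def handler(prev: str, current: str) -> str:
    # Mirrors the lambda passed as additional_outputs_handler above.
    return (prev + "\n" + current).strip()

print(handler("", "hello"))       # hello
print(handler("hello", "world"))  # hello\nworld (newline-joined)
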
29 changes: 20 additions & 9 deletions demo/whisper_realtime/app.py
@@ -15,6 +15,7 @@
 )
 from gradio.utils import get_space
 from groq import AsyncClient
+from pydantic import BaseModel

 cur_dir = Path(__file__).parent

@@ -24,23 +25,23 @@
 groq_client = AsyncClient()


-async def transcribe(audio: tuple[int, np.ndarray]):
-    transcript = await groq_client.audio.transcriptions.create(
+async def transcribe(audio: tuple[int, np.ndarray], transcript: str):
+    response = await groq_client.audio.transcriptions.create(
         file=("audio-file.mp3", audio_to_bytes(audio)),
         model="whisper-large-v3-turbo",
         response_format="verbose_json",
     )
-    yield AdditionalOutputs(transcript.text)
+    yield AdditionalOutputs(transcript + "\n" + response.text)


+transcript = gr.Textbox(label="Transcript")
 stream = Stream(
     ReplyOnPause(transcribe),
     modality="audio",
     mode="send",
-    additional_outputs=[
-        gr.Textbox(label="Transcript"),
-    ],
-    additional_outputs_handler=lambda a, b: a + " " + b,
+    additional_inputs=[transcript],
+    additional_outputs=[transcript],
+    additional_outputs_handler=lambda a, b: b,
     rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
     concurrency_limit=5 if get_space() else None,
     time_limit=90 if get_space() else None,
@@ -51,11 +52,21 @@ async def transcribe(audio: tuple[int, np.ndarray]):
 stream.mount(app)


+class SendInput(BaseModel):
+    webrtc_id: str
+    transcript: str
+
+
+@app.post("/send_input")
+def send_input(body: SendInput):
+    stream.set_input(body.webrtc_id, body.transcript)
+
+
 @app.get("/transcript")
 def _(webrtc_id: str):
     async def output_stream():
         async for output in stream.output_stream(webrtc_id):
-            transcript = output.args[0]
+            transcript = output.args[0].split("\n")[-1]
             yield f"event: output\ndata: {transcript}\n\n"

     return StreamingResponse(output_stream(), media_type="text/event-stream")
@@ -73,7 +84,7 @@ def index():
     import os

     if (mode := os.getenv("MODE")) == "UI":
-        stream.ui.launch(server_port=7860, server_name="0.0.0.0")
+        stream.ui.launch(server_port=7860)
     elif mode == "PHONE":
         stream.fastphone(host="0.0.0.0", port=7860)
     else:
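
The new /send_input route lets a client overwrite the handler's additional inputs mid-session; the browser uses it to reset the accumulated transcript when the server asks. A hedged Python client sketch (host, port, and webrtc_id are assumptions for a local run; the demo's own client is the JavaScript below):

import requests

BASE = "http://localhost:8000"    # assumed host/port
webrtc_id = "some-connection-id"  # obtained when the WebRTC offer is created

# Reset the stored transcript for this connection.
requests.post(f"{BASE}/send_input", json={"webrtc_id": webrtc_id, "transcript": ""})

# Follow the live captions over SSE; each event carries only the newest line.
with requests.get(f"{BASE}/transcript", params={"webrtc_id": webrtc_id}, stream=True) as resp:
    for line in resp.iter_lines():
        if line.startswith(b"data:"):
            print(line.decode()[len("data:"):].strip())
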
15 changes: 13 additions & 2 deletions demo/whisper_realtime/index.html
@@ -193,7 +193,8 @@ <h1>Real-time Transcription</h1>
     </div>

     <div class="container">
-        <div class="transcript-container" id="transcript"></div>
+        <div class="transcript-container" id="transcript">
+        </div>
         <div class="controls">
             <button id="start-button">Start Recording</button>
         </div>
@@ -220,13 +221,23 @@ <h1>Real-time Transcription</h1>
             }, 5000);
         }

-        function handleMessage(event) {
+        async function handleMessage(event) {
             // Handle any WebRTC data channel messages if needed
             const eventJson = JSON.parse(event.data);
             if (eventJson.type === "error") {
                 showError(eventJson.message);
+            } else if (eventJson.type === "send_input") {
+                const response = await fetch('/send_input', {
+                    method: 'POST',
+                    headers: { 'Content-Type': 'application/json' },
+                    body: JSON.stringify({
+                        webrtc_id: webrtc_id,
+                        transcript: ""
+                    })
+                });
             }
             console.log('Received message:', event.data);
         }

         function updateButtonState() {