ReplyOnPause and ReplyOnStopWords can be interrupted #119

Merged: 3 commits, Mar 4, 2025
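
In short: `ReplyOnPause` and `ReplyOnStopWords` gain a `can_interrupt` flag (default `True`). While a response is playing, incoming audio keeps being run through the VAD; when a new pause is detected, the pending output queue is cleared and the reply generator is restarted, so the user can talk over the bot. A minimal usage sketch based only on names visible in this diff (the `echo` handler is a placeholder):

    import numpy as np
    from fastrtc import ReplyOnPause, Stream

    def echo(audio: tuple[int, np.ndarray]):
        # Placeholder handler: play the caller's audio straight back.
        yield audio

    # Pass can_interrupt=False to keep the old behaviour of ignoring
    # input while a response is in flight.
    stream = Stream(
        ReplyOnPause(echo, can_interrupt=True),
        modality="audio",
        mode="send-receive",
    )
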
backend/fastrtc/reply_on_pause.py: 13 additions & 2 deletions

@@ -1,6 +1,6 @@
 import asyncio
 import inspect
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from functools import lru_cache
 from logging import getLogger
 from threading import Event
@@ -59,6 +59,10 @@ class AppState:
     stopped: bool = False
     buffer: np.ndarray | None = None
     responded_audio: bool = False
+    interrupted: asyncio.Event = field(default_factory=asyncio.Event)
+
+    def new(self):
+        return AppState()
 
 
 ReplyFnGenerator = (
@@ -91,6 +95,7 @@ def __init__(
         fn: ReplyFnGenerator,
         algo_options: AlgoOptions | None = None,
         model_options: SileroVadOptions | None = None,
+        can_interrupt: bool = True,
         expected_layout: Literal["mono", "stereo"] = "mono",
         output_sample_rate: int = 24000,
         output_frame_size: int = 480,
@@ -102,6 +107,7 @@ def __init__(
             output_frame_size,
             input_sample_rate=input_sample_rate,
         )
+        self.can_interrupt = can_interrupt
         self.expected_layout: Literal["mono", "stereo"] = expected_layout
         self.output_sample_rate = output_sample_rate
         self.output_frame_size = output_frame_size
@@ -123,6 +129,7 @@ def copy(self):
             self.fn,
             self.algo_options,
             self.model_options,
+            self.can_interrupt,
             self.expected_layout,
             self.output_sample_rate,
             self.output_frame_size,
@@ -170,11 +177,14 @@ def process_audio(self, audio: tuple[int, np.ndarray], state: AppState) -> None:
         state.pause_detected = pause_detected
 
     def receive(self, frame: tuple[int, np.ndarray]) -> None:
-        if self.state.responding:
+        if self.state.responding and not self.can_interrupt:
             return
         self.process_audio(frame, self.state)
         if self.state.pause_detected:
             self.event.set()
+            if self.can_interrupt:
+                self.clear_queue()
+                self.generator = None
 
     def reset(self):
         super().reset()
@@ -207,6 +217,7 @@ def emit(self):
             else:
                 self.generator = self.fn((self.state.sampling_rate, audio))  # type: ignore
             logger.debug("Latest args: %s", self.latest_args)
+            self.state = self.state.new()
             self.state.responding = True
             try:
                 if self.is_async:

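The interrupt path above reads as: `receive()` no longer bails out while responding, the VAD keeps running, and a detected pause sets the event, drains unplayed frames through the `clear_queue` callback, and drops the stale generator; `emit()` then rebuilds the generator and swaps in a fresh state via `state.new()`. A simplified sketch of that control flow (illustrative only, not the actual class):

    # Condenses the receive() logic added above into one function.
    def on_frame(handler, frame):
        if handler.state.responding and not handler.can_interrupt:
            return                        # old behaviour: ignore barge-in
        handler.process_audio(frame, handler.state)
        if handler.state.pause_detected:
            handler.event.set()           # wake emit()
            if handler.can_interrupt:
                handler.clear_queue()     # discard unplayed audio frames
                handler.generator = None  # emit() will restart the reply fn
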
backend/fastrtc/reply_on_stopwords.py: 6 additions & 0 deletions

@@ -23,6 +23,9 @@ class ReplyOnStopWordsState(AppState):
     post_stop_word_buffer: np.ndarray | None = None
     started_talking_pre_stop_word: bool = False
 
+    def new(self):
+        return ReplyOnStopWordsState()
+
 
 class ReplyOnStopWords(ReplyOnPause):
     def __init__(
@@ -31,6 +34,7 @@ def __init__(
         stop_words: list[str],
         algo_options: AlgoOptions | None = None,
         model_options: SileroVadOptions | None = None,
+        can_interrupt: bool = True,
         expected_layout: Literal["mono", "stereo"] = "mono",
         output_sample_rate: int = 24000,
         output_frame_size: int = 480,
@@ -40,6 +44,7 @@ def __init__(
             fn,
             algo_options=algo_options,
             model_options=model_options,
+            can_interrupt=can_interrupt,
             expected_layout=expected_layout,
             output_sample_rate=output_sample_rate,
             output_frame_size=output_frame_size,
@@ -144,6 +149,7 @@ def copy(self):
             self.stop_words,
             self.algo_options,
             self.model_options,
+            self.can_interrupt,
             self.expected_layout,
             self.output_sample_rate,
             self.output_frame_size,

backend/fastrtc/stream.py: 3 additions & 3 deletions

@@ -360,7 +360,7 @@ def _generate_default_ui(
                 image = WebRTC(
                     label="Stream",
                     rtc_configuration=self.rtc_configuration,
-                    mode="send-receive",
+                    mode="send",
                     modality="audio",
                     icon=ui_args.get("icon"),
                     icon_button_color=ui_args.get("icon_button_color"),
@@ -505,7 +505,7 @@ async def handle_incoming_call(self, request: Request):
         return HTMLResponse(content=str(response), media_type="application/xml")
 
     async def telephone_handler(self, websocket: WebSocket):
-        handler = cast(StreamHandlerImpl, self.event_handler.copy())
+        handler = cast(StreamHandlerImpl, self.event_handler.copy())  # type: ignore
         handler.phone_mode = True
 
         async def set_handler(s: str, a: WebSocketHandler):
@@ -528,7 +528,7 @@ async def set_handler(s: str, a: WebSocketHandler):
             await ws.handle_websocket(websocket)
 
     async def websocket_offer(self, websocket: WebSocket):
-        handler = cast(StreamHandlerImpl, self.event_handler.copy())
+        handler = cast(StreamHandlerImpl, self.event_handler.copy())  # type: ignore
         handler.phone_mode = False
 
         async def set_handler(s: str, a: WebSocketHandler):

backend/fastrtc/tracks.py: 23 additions & 2 deletions

@@ -188,6 +188,11 @@ def __init__(
         self.args_set = asyncio.Event()
         self.channel_set = asyncio.Event()
         self._phone_mode = False
+        self._clear_queue: Callable | None = None
+
+    @property
+    def clear_queue(self) -> Callable:
+        return cast(Callable, self._clear_queue)
 
     @property
     def loop(self) -> asyncio.AbstractEventLoop:
@@ -237,8 +242,11 @@ async def send_message(self, msg: str):
         logger.debug("Sent msg %s", msg)
 
     def send_message_sync(self, msg: str):
-        asyncio.run_coroutine_threadsafe(self.send_message(msg), self.loop).result()
-        logger.debug("Sent msg %s", msg)
+        try:
+            asyncio.run_coroutine_threadsafe(self.send_message(msg), self.loop).result()
+            logger.debug("Sent msg %s", msg)
+        except Exception as e:
+            logger.debug("Exception sending msg %s", e)
 
     def set_args(self, args: list[Any]):
         logger.debug("setting args in audio callback %s", args)
@@ -411,6 +419,7 @@ def __init__(
         super().__init__()
         self.track = track
         self.event_handler = cast(StreamHandlerImpl, event_handler)
+        self.event_handler._clear_queue = self.clear_queue
         self.current_timestamp = 0
         self.latest_args: str | list[Any] = "not_set"
         self.queue = asyncio.Queue()
@@ -421,6 +430,12 @@ def __init__(
         self.channel = channel
         self.set_additional_outputs = set_additional_outputs
 
+    def clear_queue(self):
+        if self.queue:
+            while not self.queue.empty():
+                self.queue.get_nowait()
+            self._start = None
+
     def set_channel(self, channel: DataChannel):
         self.channel = channel
         self.event_handler.set_channel(channel)
@@ -608,6 +623,7 @@ def __init__(
     ) -> None:
         self.generator: Generator[Any, None, Any] | None = None
         self.event_handler = event_handler
+        self.event_handler._clear_queue = self.clear_queue
         self.current_timestamp = 0
         self.latest_args: str | list[Any] = "not_set"
         self.args_set = threading.Event()
@@ -619,6 +635,11 @@ def __init__(
         self._start: float | None = None
         super().__init__()
 
+    def clear_queue(self):
+        while not self.queue.empty():
+            self.queue.get_nowait()
+        self._start = None
+
     def set_channel(self, channel: DataChannel):
         self.channel = channel
 
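Both track classes wire `event_handler._clear_queue` to their own `clear_queue`, so the handler can reach back and drop frames that were queued but not yet sent; `send_message_sync` is also hardened because the event loop may already be tearing down when an interrupted turn tries to send. The queue-draining idiom on its own, as a runnable sketch:

    import asyncio

    def drain(queue: asyncio.Queue) -> int:
        # Same idiom as the clear_queue() methods added above: pop without
        # awaiting until the queue is empty, discarding pending frames.
        dropped = 0
        while not queue.empty():
            queue.get_nowait()
            dropped += 1
        return dropped

    q: asyncio.Queue = asyncio.Queue()
    q.put_nowait(b"frame-1")
    q.put_nowait(b"frame-2")
    assert drain(q) == 2 and q.empty()
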
backend/fastrtc/utils.py: 1 addition & 1 deletion

@@ -320,7 +320,7 @@ def audio_to_int16(
     >>> audio_int16 = audio_to_int16(audio_tuple)
     """
     if audio[1].dtype == np.int16:
-        return audio[1]
+        return audio[1]  # type: ignore
     elif audio[1].dtype == np.float32:
         # Convert float32 to int16 by scaling to the int16 range
         return (audio[1] * 32767.0).astype(np.int16)

backend/fastrtc/websocket.py: 1 addition & 0 deletions

@@ -55,6 +55,7 @@ def __init__(
         ],
     ):
         self.stream_handler = stream_handler
+        self.stream_handler._clear_queue = lambda: None
         self.websocket: Optional[WebSocket] = None
         self._emit_task: Optional[asyncio.Task] = None
         self.stream_id: Optional[str] = None

demo/llm_voice_chat/app.py: 2 additions & 2 deletions

@@ -41,14 +41,15 @@ def response(
     response_text = (
         groq_client.chat.completions.create(
             model="llama-3.1-8b-instant",
-            max_tokens=512,
+            max_tokens=200,
             messages=messages,  # type: ignore
         )
         .choices[0]
         .message.content
     )
 
     chatbot.append({"role": "assistant", "content": response_text})
+    yield AdditionalOutputs(chatbot)
 
     for chunk in tts_client.text_to_speech.convert_as_stream(
         text=response_text,  # type: ignore
@@ -58,7 +59,6 @@ def response(
     ):
         audio_array = np.frombuffer(chunk, dtype=np.int16).reshape(1, -1)
         yield (24000, audio_array)
-    yield AdditionalOutputs(chatbot)
 
 
 chatbot = gr.Chatbot(type="messages")

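Yielding `AdditionalOutputs(chatbot)` before the TTS loop matters now that replies can be interrupted: the generator may be abandoned mid-stream, and anything yielded after the audio would never reach the UI. The ordering pattern, sketched with a stand-in for the TTS client:

    from fastrtc import AdditionalOutputs

    def response_pattern(chatbot, response_text, synthesize_chunks):
        # synthesize_chunks stands in for any streaming TTS iterator.
        chatbot.append({"role": "assistant", "content": response_text})
        yield AdditionalOutputs(chatbot)  # UI update survives a barge-in
        for chunk in synthesize_chunks(response_text):
            yield (24000, chunk)          # may be cut short by an interrupt
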
demo/moonshine_live/app.py: 11 additions & 3 deletions

@@ -3,6 +3,7 @@
 
 import gradio as gr
 import numpy as np
+from dotenv import load_dotenv
 from fastrtc import (
     AdditionalOutputs,
     ReplyOnPause,
@@ -13,6 +14,8 @@
 from moonshine_onnx import MoonshineOnnxModel, load_tokenizer
 from numpy.typing import NDArray
 
+load_dotenv()
+
 
 @lru_cache(maxsize=None)
 def load_moonshine(
@@ -27,6 +30,7 @@ def load_moonshine(
 def stt(
     audio: tuple[int, NDArray[np.int16 | np.float32]],
     model_name: Literal["moonshine/base", "moonshine/tiny"],
+    captions: str,
 ) -> Generator[AdditionalOutputs, None, None]:
     moonshine = load_moonshine(model_name)
     sr, audio_np = audio  # type: ignore
@@ -35,9 +39,12 @@ def stt(
     if audio_np.ndim == 1:
         audio_np = audio_np.reshape(1, -1)
     tokens = moonshine.generate(audio_np)
-    yield AdditionalOutputs(tokenizer.decode_batch(tokens)[0])
+    yield AdditionalOutputs(
+        (captions + "\n" + tokenizer.decode_batch(tokens)[0]).strip()
+    )
 
 
+captions = gr.Textbox(label="Captions")
 stream = Stream(
     ReplyOnPause(stt, input_sample_rate=16000),
     modality="audio",
@@ -55,9 +62,10 @@ def stt(
             choices=["moonshine/base", "moonshine/tiny"],
             value="moonshine/base",
             label="Model",
-        )
+        ),
+        captions,
     ],
-    additional_outputs=[gr.Textbox(label="Captions")],
+    additional_outputs=[captions],
     additional_outputs_handler=lambda prev, current: (prev + "\n" + current).strip(),
 )

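Registering the same `captions` textbox as both an additional input and the additional output lets the handler itself accumulate the transcript: each call receives the current captions and yields the extended string, instead of leaving concatenation to the output handler alone. The round trip, stripped to plain Python:

    def append_caption(new_text: str, captions: str) -> str:
        # captions arrives as an additional input; the result is routed
        # back to the same textbox via additional_outputs.
        return (captions + "\n" + new_text).strip()

    assert append_caption("hello", "") == "hello"
    assert append_caption("world", "hello") == "hello\nworld"
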
demo/whisper_realtime/app.py: 20 additions & 9 deletions

@@ -15,6 +15,7 @@
 )
 from gradio.utils import get_space
 from groq import AsyncClient
+from pydantic import BaseModel
 
 cur_dir = Path(__file__).parent
 
@@ -24,23 +25,23 @@
 groq_client = AsyncClient()
 
 
-async def transcribe(audio: tuple[int, np.ndarray]):
-    transcript = await groq_client.audio.transcriptions.create(
+async def transcribe(audio: tuple[int, np.ndarray], transcript: str):
+    response = await groq_client.audio.transcriptions.create(
         file=("audio-file.mp3", audio_to_bytes(audio)),
         model="whisper-large-v3-turbo",
         response_format="verbose_json",
     )
-    yield AdditionalOutputs(transcript.text)
+    yield AdditionalOutputs(transcript + "\n" + response.text)
 
 
+transcript = gr.Textbox(label="Transcript")
 stream = Stream(
     ReplyOnPause(transcribe),
     modality="audio",
     mode="send",
-    additional_outputs=[
-        gr.Textbox(label="Transcript"),
-    ],
-    additional_outputs_handler=lambda a, b: a + " " + b,
+    additional_inputs=[transcript],
+    additional_outputs=[transcript],
+    additional_outputs_handler=lambda a, b: b,
     rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
     concurrency_limit=5 if get_space() else None,
     time_limit=90 if get_space() else None,
@@ -51,11 +52,21 @@ async def transcribe(audio: tuple[int, np.ndarray]):
 stream.mount(app)
 
 
+class SendInput(BaseModel):
+    webrtc_id: str
+    transcript: str
+
+
+@app.post("/send_input")
+def send_input(body: SendInput):
+    stream.set_input(body.webrtc_id, body.transcript)
+
+
 @app.get("/transcript")
 def _(webrtc_id: str):
     async def output_stream():
         async for output in stream.output_stream(webrtc_id):
-            transcript = output.args[0]
+            transcript = output.args[0].split("\n")[-1]
             yield f"event: output\ndata: {transcript}\n\n"
 
     return StreamingResponse(output_stream(), media_type="text/event-stream")
@@ -73,7 +84,7 @@ def index():
     import os
 
     if (mode := os.getenv("MODE")) == "UI":
-        stream.ui.launch(server_port=7860, server_name="0.0.0.0")
+        stream.ui.launch(server_port=7860)
     elif mode == "PHONE":
         stream.fastphone(host="0.0.0.0", port=7860)
     else:

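The new `/send_input` route pairs with `stream.set_input` so a client can overwrite the handler's `transcript` argument mid-session; the page uses it to blank the transcript when the server signals an interruption. Any HTTP client can exercise it; a sketch with placeholder URL and session id:

    import requests

    # URL and webrtc_id are placeholders; the payload shape comes from the
    # SendInput model defined above.
    requests.post(
        "http://localhost:7860/send_input",
        json={"webrtc_id": "<session-id>", "transcript": ""},
    )
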
demo/whisper_realtime/index.html: 13 additions & 2 deletions

@@ -193,7 +193,8 @@ <h1>Real-time Transcription</h1>
     </div>
 
     <div class="container">
-        <div class="transcript-container" id="transcript"></div>
+        <div class="transcript-container" id="transcript">
+        </div>
         <div class="controls">
             <button id="start-button">Start Recording</button>
         </div>
@@ -220,13 +221,23 @@ <h1>Real-time Transcription</h1>
             }, 5000);
         }
 
-        function handleMessage(event) {
+        async function handleMessage(event) {
             // Handle any WebRTC data channel messages if needed
             const eventJson = JSON.parse(event.data);
             if (eventJson.type === "error") {
                 showError(eventJson.message);
+            } else if (eventJson.type === "send_input") {
+                const response = await fetch('/send_input', {
+                    method: 'POST',
+                    headers: { 'Content-Type': 'application/json' },
+                    body: JSON.stringify({
+                        webrtc_id: webrtc_id,
+                        transcript: ""
+                    })
+                });
             }
             console.log('Received message:', event.data);
 
         }
 
         function updateButtonState() {