diff --git a/backend/fastrtc/reply_on_pause.py b/backend/fastrtc/reply_on_pause.py
index 23dc0b7..2591391 100644
--- a/backend/fastrtc/reply_on_pause.py
+++ b/backend/fastrtc/reply_on_pause.py
@@ -1,6 +1,6 @@
 import asyncio
 import inspect
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from functools import lru_cache
 from logging import getLogger
 from threading import Event
@@ -59,6 +59,10 @@ class AppState:
     stopped: bool = False
     buffer: np.ndarray | None = None
     responded_audio: bool = False
+    interrupted: asyncio.Event = field(default_factory=asyncio.Event)
+
+    def new(self):
+        return AppState()
 
 
 ReplyFnGenerator = (
@@ -91,6 +95,7 @@ def __init__(
         fn: ReplyFnGenerator,
         algo_options: AlgoOptions | None = None,
         model_options: SileroVadOptions | None = None,
+        can_interrupt: bool = True,
         expected_layout: Literal["mono", "stereo"] = "mono",
         output_sample_rate: int = 24000,
         output_frame_size: int = 480,
@@ -102,6 +107,7 @@ def __init__(
             output_frame_size,
             input_sample_rate=input_sample_rate,
         )
+        self.can_interrupt = can_interrupt
         self.expected_layout: Literal["mono", "stereo"] = expected_layout
         self.output_sample_rate = output_sample_rate
         self.output_frame_size = output_frame_size
@@ -123,6 +129,7 @@ def copy(self):
             self.fn,
             self.algo_options,
             self.model_options,
+            self.can_interrupt,
             self.expected_layout,
             self.output_sample_rate,
             self.output_frame_size,
@@ -170,11 +177,14 @@ def process_audio(self, audio: tuple[int, np.ndarray], state: AppState) -> None:
         state.pause_detected = pause_detected
 
     def receive(self, frame: tuple[int, np.ndarray]) -> None:
-        if self.state.responding:
+        if self.state.responding and not self.can_interrupt:
             return
         self.process_audio(frame, self.state)
         if self.state.pause_detected:
             self.event.set()
+            if self.can_interrupt:
+                self.clear_queue()
+                self.generator = None
 
     def reset(self):
         super().reset()
@@ -207,6 +217,7 @@ def emit(self):
             else:
                 self.generator = self.fn((self.state.sampling_rate, audio))  # type: ignore
             logger.debug("Latest args: %s", self.latest_args)
+        self.state = self.state.new()
         self.state.responding = True
         try:
             if self.is_async:
diff --git a/backend/fastrtc/reply_on_stopwords.py b/backend/fastrtc/reply_on_stopwords.py
index 6b5b4a7..d0063ac 100644
--- a/backend/fastrtc/reply_on_stopwords.py
+++ b/backend/fastrtc/reply_on_stopwords.py
@@ -23,6 +23,9 @@ class ReplyOnStopWordsState(AppState):
     post_stop_word_buffer: np.ndarray | None = None
     started_talking_pre_stop_word: bool = False
 
+    def new(self):
+        return ReplyOnStopWordsState()
+
 
 class ReplyOnStopWords(ReplyOnPause):
     def __init__(
@@ -31,6 +34,7 @@ def __init__(
         stop_words: list[str],
         algo_options: AlgoOptions | None = None,
         model_options: SileroVadOptions | None = None,
+        can_interrupt: bool = True,
         expected_layout: Literal["mono", "stereo"] = "mono",
         output_sample_rate: int = 24000,
         output_frame_size: int = 480,
@@ -40,6 +44,7 @@ def __init__(
             fn,
             algo_options=algo_options,
             model_options=model_options,
+            can_interrupt=can_interrupt,
             expected_layout=expected_layout,
             output_sample_rate=output_sample_rate,
             output_frame_size=output_frame_size,
@@ -144,6 +149,7 @@ def copy(self):
             self.stop_words,
             self.algo_options,
             self.model_options,
+            self.can_interrupt,
             self.expected_layout,
             self.output_sample_rate,
             self.output_frame_size,
diff --git a/backend/fastrtc/stream.py b/backend/fastrtc/stream.py
index f3b4c8e..8162c1a 100644
--- a/backend/fastrtc/stream.py
+++ b/backend/fastrtc/stream.py
@@ -360,7 +360,7 @@ def _generate_default_ui(
                 image = WebRTC(
                     label="Stream",
                     rtc_configuration=self.rtc_configuration,
-                    mode="send-receive",
+                    mode="send",
                     modality="audio",
                     icon=ui_args.get("icon"),
                     icon_button_color=ui_args.get("icon_button_color"),
@@ -505,7 +505,7 @@ async def handle_incoming_call(self, request: Request):
         return HTMLResponse(content=str(response), media_type="application/xml")
 
     async def telephone_handler(self, websocket: WebSocket):
-        handler = cast(StreamHandlerImpl, self.event_handler.copy())
+        handler = cast(StreamHandlerImpl, self.event_handler.copy())  # type: ignore
         handler.phone_mode = True
 
         async def set_handler(s: str, a: WebSocketHandler):
@@ -528,7 +528,7 @@ async def set_handler(s: str, a: WebSocketHandler):
         await ws.handle_websocket(websocket)
 
     async def websocket_offer(self, websocket: WebSocket):
-        handler = cast(StreamHandlerImpl, self.event_handler.copy())
+        handler = cast(StreamHandlerImpl, self.event_handler.copy())  # type: ignore
         handler.phone_mode = False
 
         async def set_handler(s: str, a: WebSocketHandler):
diff --git a/backend/fastrtc/tracks.py b/backend/fastrtc/tracks.py
index 5742e4f..4ef50bc 100644
--- a/backend/fastrtc/tracks.py
+++ b/backend/fastrtc/tracks.py
@@ -188,6 +188,11 @@ def __init__(
         self.args_set = asyncio.Event()
         self.channel_set = asyncio.Event()
         self._phone_mode = False
+        self._clear_queue: Callable | None = None
+
+    @property
+    def clear_queue(self) -> Callable:
+        return cast(Callable, self._clear_queue)
 
     @property
     def loop(self) -> asyncio.AbstractEventLoop:
@@ -237,8 +242,11 @@ async def send_message(self, msg: str):
         logger.debug("Sent msg %s", msg)
 
     def send_message_sync(self, msg: str):
-        asyncio.run_coroutine_threadsafe(self.send_message(msg), self.loop).result()
-        logger.debug("Sent msg %s", msg)
+        try:
+            asyncio.run_coroutine_threadsafe(self.send_message(msg), self.loop).result()
+            logger.debug("Sent msg %s", msg)
+        except Exception as e:
+            logger.debug("Exception sending msg %s", e)
 
     def set_args(self, args: list[Any]):
         logger.debug("setting args in audio callback %s", args)
@@ -411,6 +419,7 @@ def __init__(
         super().__init__()
         self.track = track
         self.event_handler = cast(StreamHandlerImpl, event_handler)
+        self.event_handler._clear_queue = self.clear_queue
         self.current_timestamp = 0
         self.latest_args: str | list[Any] = "not_set"
         self.queue = asyncio.Queue()
@@ -421,6 +430,12 @@ def __init__(
         self.channel = channel
         self.set_additional_outputs = set_additional_outputs
 
+    def clear_queue(self):
+        if self.queue:
+            while not self.queue.empty():
+                self.queue.get_nowait()
+        self._start = None
+
     def set_channel(self, channel: DataChannel):
         self.channel = channel
         self.event_handler.set_channel(channel)
@@ -608,6 +623,7 @@ def __init__(
     ) -> None:
         self.generator: Generator[Any, None, Any] | None = None
         self.event_handler = event_handler
+        self.event_handler._clear_queue = self.clear_queue
         self.current_timestamp = 0
         self.latest_args: str | list[Any] = "not_set"
         self.args_set = threading.Event()
@@ -619,6 +635,11 @@ def __init__(
         self._start: float | None = None
         super().__init__()
 
+    def clear_queue(self):
+        while not self.queue.empty():
+            self.queue.get_nowait()
+        self._start = None
+
     def set_channel(self, channel: DataChannel):
         self.channel = channel
diff --git a/backend/fastrtc/utils.py b/backend/fastrtc/utils.py
index 52e0de2..dafc882 100644
--- a/backend/fastrtc/utils.py
+++ b/backend/fastrtc/utils.py
@@ -320,7 +320,7 @@ def audio_to_int16(
         >>> audio_int16 = audio_to_int16(audio_tuple)
     """
     if audio[1].dtype == np.int16:
-        return audio[1]
+        return audio[1]  # type: ignore
     elif audio[1].dtype == np.float32:
         # Convert float32 to int16 by scaling to the int16 range
         return (audio[1] * 32767.0).astype(np.int16)
diff --git a/backend/fastrtc/websocket.py b/backend/fastrtc/websocket.py
index 88ab003..738e2c0 100644
--- a/backend/fastrtc/websocket.py
+++ b/backend/fastrtc/websocket.py
@@ -55,6 +55,7 @@ def __init__(
         ],
     ):
         self.stream_handler = stream_handler
+        self.stream_handler._clear_queue = lambda: None
         self.websocket: Optional[WebSocket] = None
         self._emit_task: Optional[asyncio.Task] = None
         self.stream_id: Optional[str] = None
diff --git a/demo/llm_voice_chat/app.py b/demo/llm_voice_chat/app.py
index 0c304d6..bbf8647 100644
--- a/demo/llm_voice_chat/app.py
+++ b/demo/llm_voice_chat/app.py
@@ -41,7 +41,7 @@ def response(
     response_text = (
         groq_client.chat.completions.create(
             model="llama-3.1-8b-instant",
-            max_tokens=512,
+            max_tokens=200,
             messages=messages,  # type: ignore
         )
         .choices[0]
@@ -49,6 +49,7 @@ def response(
     )
 
     chatbot.append({"role": "assistant", "content": response_text})
+    yield AdditionalOutputs(chatbot)
 
     for chunk in tts_client.text_to_speech.convert_as_stream(
         text=response_text,  # type: ignore
@@ -58,7 +59,6 @@ def response(
     ):
         audio_array = np.frombuffer(chunk, dtype=np.int16).reshape(1, -1)
         yield (24000, audio_array)
-    yield AdditionalOutputs(chatbot)
 
 
 chatbot = gr.Chatbot(type="messages")
diff --git a/demo/moonshine_live/app.py b/demo/moonshine_live/app.py
index 9832394..f6db735 100644
--- a/demo/moonshine_live/app.py
+++ b/demo/moonshine_live/app.py
@@ -3,6 +3,7 @@
 
 import gradio as gr
 import numpy as np
+from dotenv import load_dotenv
 from fastrtc import (
     AdditionalOutputs,
     ReplyOnPause,
@@ -13,6 +14,8 @@
 from moonshine_onnx import MoonshineOnnxModel, load_tokenizer
 from numpy.typing import NDArray
 
+load_dotenv()
+
 
 @lru_cache(maxsize=None)
 def load_moonshine(
@@ -27,6 +30,7 @@
 def stt(
     audio: tuple[int, NDArray[np.int16 | np.float32]],
     model_name: Literal["moonshine/base", "moonshine/tiny"],
+    captions: str,
 ) -> Generator[AdditionalOutputs, None, None]:
     moonshine = load_moonshine(model_name)
     sr, audio_np = audio  # type: ignore
@@ -35,9 +39,12 @@ def stt(
     if audio_np.ndim == 1:
         audio_np = audio_np.reshape(1, -1)
     tokens = moonshine.generate(audio_np)
-    yield AdditionalOutputs(tokenizer.decode_batch(tokens)[0])
+    yield AdditionalOutputs(
+        (captions + "\n" + tokenizer.decode_batch(tokens)[0]).strip()
+    )
 
 
+captions = gr.Textbox(label="Captions")
 stream = Stream(
     ReplyOnPause(stt, input_sample_rate=16000),
     modality="audio",
@@ -55,9 +62,10 @@ def stt(
             choices=["moonshine/base", "moonshine/tiny"],
             value="moonshine/base",
             label="Model",
-        )
+        ),
+        captions,
     ],
-    additional_outputs=[gr.Textbox(label="Captions")],
+    additional_outputs=[captions],
     additional_outputs_handler=lambda prev, current: (prev + "\n" + current).strip(),
 )
diff --git a/demo/whisper_realtime/app.py b/demo/whisper_realtime/app.py
index 74b5c83..a364830 100644
--- a/demo/whisper_realtime/app.py
+++ b/demo/whisper_realtime/app.py
@@ -15,6 +15,7 @@
 )
 from gradio.utils import get_space
 from groq import AsyncClient
+from pydantic import BaseModel
 
 cur_dir = Path(__file__).parent
 
@@ -24,23 +25,23 @@
 groq_client = AsyncClient()
 
 
-async def transcribe(audio: tuple[int, np.ndarray]):
-    transcript = await groq_client.audio.transcriptions.create(
+async def transcribe(audio: tuple[int, np.ndarray], transcript: str):
+    response = await groq_client.audio.transcriptions.create(
         file=("audio-file.mp3", audio_to_bytes(audio)),
         model="whisper-large-v3-turbo",
         response_format="verbose_json",
     )
-    yield AdditionalOutputs(transcript.text)
+    yield AdditionalOutputs(transcript + "\n" + response.text)
 
 
+transcript = gr.Textbox(label="Transcript")
 stream = Stream(
     ReplyOnPause(transcribe),
     modality="audio",
     mode="send",
-    additional_outputs=[
-        gr.Textbox(label="Transcript"),
-    ],
-    additional_outputs_handler=lambda a, b: a + " " + b,
+    additional_inputs=[transcript],
+    additional_outputs=[transcript],
+    additional_outputs_handler=lambda a, b: b,
     rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
     concurrency_limit=5 if get_space() else None,
     time_limit=90 if get_space() else None,
@@ -51,11 +52,21 @@ async def transcribe(audio: tuple[int, np.ndarray]):
 stream.mount(app)
 
 
+class SendInput(BaseModel):
+    webrtc_id: str
+    transcript: str
+
+
+@app.post("/send_input")
+def send_input(body: SendInput):
+    stream.set_input(body.webrtc_id, body.transcript)
+
+
 @app.get("/transcript")
 def _(webrtc_id: str):
     async def output_stream():
         async for output in stream.output_stream(webrtc_id):
-            transcript = output.args[0]
+            transcript = output.args[0].split("\n")[-1]
             yield f"event: output\ndata: {transcript}\n\n"
 
     return StreamingResponse(output_stream(), media_type="text/event-stream")
@@ -73,7 +84,7 @@ def index():
     import os
 
     if (mode := os.getenv("MODE")) == "UI":
-        stream.ui.launch(server_port=7860, server_name="0.0.0.0")
+        stream.ui.launch(server_port=7860)
     elif mode == "PHONE":
         stream.fastphone(host="0.0.0.0", port=7860)
     else:
diff --git a/demo/whisper_realtime/index.html b/demo/whisper_realtime/index.html
index 72789df..d757040 100644
--- a/demo/whisper_realtime/index.html
+++ b/demo/whisper_realtime/index.html
@@ -193,7 +193,8 @@

Real-time Transcription

-
+
+
@@ -220,13 +221,23 @@

Real-time Transcription

             }, 5000);
         }
 
-        function handleMessage(event) {
+        async function handleMessage(event) {
             // Handle any WebRTC data channel messages if needed
             const eventJson = JSON.parse(event.data);
             if (eventJson.type === "error") {
                 showError(eventJson.message);
+            } else if (eventJson.type === "send_input") {
+                const response = await fetch('/send_input', {
+                    method: 'POST',
+                    headers: { 'Content-Type': 'application/json' },
+                    body: JSON.stringify({
+                        webrtc_id: webrtc_id,
+                        transcript: ""
+                    })
+                });
             }
             console.log('Received message:', event.data);
+
         }
 
         function updateButtonState() {
diff --git a/docs/userguide/audio.md b/docs/userguide/audio.md
index e426e77..49357ca 100644
--- a/docs/userguide/audio.md
+++ b/docs/userguide/audio.md
@@ -3,6 +3,8 @@
 
 Typically, you want to run a python function whenever a user has stopped speaking. This can be done by wrapping a python generator with the `ReplyOnPause` class and passing it to the `handler` argument of the `Stream` object. The `ReplyOnPause` class will handle the voice detection and turn taking logic automatically!
 
+By default, the `ReplyOnPause` handler will allow you to interrupt the response at any time by speaking again. If you do not want to allow interruption, you can set the `can_interrupt` parameter to `False`.
+
 === "Code"
 
     ```python
     from fastrtc import ReplyOnPause, Stream
@@ -33,13 +35,14 @@ Typically, you want to run a python function whenever a user has stopped speakin
     You can also use an async generator with `ReplyOnPause`.
 
 !!! tip "Parameters"
-    You can customize the voice detection parameters by passing in `algo_options` and `model_options` to the `ReplyOnPause` class.
+    You can customize the voice detection parameters by passing in `algo_options` and `model_options` to the `ReplyOnPause` class. Also, you can set the `can_interrupt` parameter to `False` to prevent the user from interrupting the response. By default, `can_interrupt` is `True`.
 
     ```python
     from fastrtc import AlgoOptions, SileroVadOptions
     stream = Stream(
         handler=ReplyOnPause(
             response,
+            can_interrupt=True,
             algo_options=AlgoOptions(
                 audio_chunk_duration=0.6,
                 started_talking_threshold=0.2,
diff --git a/frontend/shared/InteractiveAudio.svelte b/frontend/shared/InteractiveAudio.svelte
index bd9332b..38861a8 100644
--- a/frontend/shared/InteractiveAudio.svelte
+++ b/frontend/shared/InteractiveAudio.svelte
@@ -7,10 +7,11 @@
   import { StreamingBar } from "@gradio/statustracker";
   import {
     Circle,
-    Square,
     Spinner,
     Music,
     DropdownArrow,
+    VolumeMuted,
+    VolumeHigh,
     Microphone,
   } from "@gradio/icons";
 
@@ -77,6 +78,7 @@
   let available_audio_devices: MediaDeviceInfo[];
   let selected_device: MediaDeviceInfo | null = null;
   let mic_accessed = false;
+  let is_muted = false;
 
   const audio_source_callback = () => {
     if (mode === "send") return stream;
@@ -261,6 +263,13 @@
     options_open = false;
   };
 
+  function toggleMute(): void {
+    if (audio_player) {
+      audio_player.muted = !audio_player.muted;
+      is_muted = audio_player.muted;
+    }
+  }
+
   $: if (stopword_recognized) {
     notification_sound.play();
   }
@@ -314,19 +323,28 @@
{:else if stream_state === "open"}
-
- stream} - stream_state={"open"} - icon={Circle} - {icon_button_color} - {pulse_color} - /> -
+ {#if mode === "send-receive"} +
+ stream} + stream_state={"open"} + icon={Circle} + {icon_button_color} + {pulse_color} + /> +
+ {:else} +
+ +
+ {/if} {button_labels.stop || i18n("audio.stop")}
{:else} @@ -347,6 +365,24 @@ {/if} + {#if stream_state === "open" && mode === "send-receive"} + + {/if} {#if options_open && selected_device}
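
Usage note (not part of the diff): a minimal sketch of the interruption behaviour introduced above, assuming the `fastrtc` API shown in this changeset (`ReplyOnPause`, `Stream`, `can_interrupt`); the `echo` handler is illustrative only.

```python
# Hypothetical sketch, not part of the changeset above. Assumes the fastrtc
# API shown in this diff: ReplyOnPause(fn, can_interrupt=...) wrapped in a Stream.
import numpy as np
from fastrtc import ReplyOnPause, Stream


def echo(audio: tuple[int, np.ndarray]):
    # Runs each time the caller pauses; yield audio chunks to play back.
    sample_rate, frames = audio
    yield (sample_rate, frames)


stream = Stream(
    handler=ReplyOnPause(
        echo,
        can_interrupt=True,  # speaking over the reply clears the queued output (default)
    ),
    modality="audio",
    mode="send-receive",
)

if __name__ == "__main__":
    stream.ui.launch()
```

With `can_interrupt=False`, incoming frames are ignored while the handler is still responding (see the `receive` change in `reply_on_pause.py` above).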