diff --git a/backend/fastrtc/webrtc_connection_mixin.py b/backend/fastrtc/webrtc_connection_mixin.py index 7e03e46..bc72b3c 100644 --- a/backend/fastrtc/webrtc_connection_mixin.py +++ b/backend/fastrtc/webrtc_connection_mixin.py @@ -70,7 +70,7 @@ class WebRTCConnectionMixin: data_channels: dict[str, DataChannel] = {} additional_outputs: dict[str, OutputQueue] = defaultdict(OutputQueue) handlers: dict[str, HandlerType | Callable] = {} - + connection_timeouts: dict[str, asyncio.Event] = defaultdict(asyncio.Event) concurrency_limit: int | float event_handler: HandlerType time_limit: float | int | None @@ -82,8 +82,24 @@ async def wait_for_time_limit(pc: RTCPeerConnection, time_limit: float): await asyncio.sleep(time_limit) await pc.close() + async def connection_timeout( + self, + pc: RTCPeerConnection, + webrtc_id: str, + time_limit: float, + ): + try: + await asyncio.wait_for( + self.connection_timeouts[webrtc_id].wait(), time_limit + ) + except (asyncio.TimeoutError, TimeoutError): + await pc.close() + self.connection_timeouts[webrtc_id].clear() + self.clean_up(webrtc_id) + def clean_up(self, webrtc_id: str): self.handlers.pop(webrtc_id, None) + self.connection_timeouts.pop(webrtc_id, None) connection = self.connections.pop(webrtc_id, []) for conn in connection: if isinstance(conn, AudioCallback): @@ -132,7 +148,7 @@ async def handle_offer(self, body, set_outputs): logger.debug("Offer body %s", body) if len(self.connections) >= cast(int, self.concurrency_limit): return JSONResponse( - status_code=429, + status_code=200, content={ "status": "failed", "meta": { @@ -181,6 +197,7 @@ async def _(): conn.stop() self.pcs.discard(pc) if pc.connectionState == "connected": + self.connection_timeouts[body["webrtc_id"]].set() if self.time_limit is not None: asyncio.create_task(self.wait_for_time_limit(pc, self.time_limit)) @@ -269,7 +286,7 @@ def _(message): # handle offer await pc.setRemoteDescription(offer) - + asyncio.create_task(self.connection_timeout(pc, body["webrtc_id"], 30)) # send answer answer = await pc.createAnswer() await pc.setLocalDescription(answer) # type: ignore diff --git a/demo/moonshine_live/app.py b/demo/moonshine_live/app.py index bbb05cb..9832394 100644 --- a/demo/moonshine_live/app.py +++ b/demo/moonshine_live/app.py @@ -1,16 +1,17 @@ +from functools import lru_cache +from typing import Generator, Literal + +import gradio as gr +import numpy as np from fastrtc import ( - Stream, AdditionalOutputs, - audio_to_float32, ReplyOnPause, + Stream, + audio_to_float32, get_twilio_turn_credentials, ) -from functools import lru_cache -import gradio as gr -from typing import Generator, Literal -from numpy.typing import NDArray -import numpy as np from moonshine_onnx import MoonshineOnnxModel, load_tokenizer +from numpy.typing import NDArray @lru_cache(maxsize=None) diff --git a/docs/userguide/api.md b/docs/userguide/api.md index c7fe3aa..4ff8cdb 100644 --- a/docs/userguide/api.md +++ b/docs/userguide/api.md @@ -109,7 +109,7 @@ async def stream_updates(webrtc_id: str): ### Handling Errors -When connecting via `WebRTC`, the server will respond to the `/webrtc/offer` route with a JSON response. If there are too many connections, the server will respond with a 429 error. +When connecting via `WebRTC`, the server will respond to the `/webrtc/offer` route with a JSON response. If there are too many connections, the server will respond with a 200 error. ```json { @@ -122,6 +122,8 @@ When connecting via `WebRTC`, the server will respond to the `/webrtc/offer` rou Over `WebSocket`, the server will send the same message before closing the connection. +!!! tip + The server will sends a 200 status code because otherwise the gradio client will not be able to process the json response and display the error.