Skip to content

Commit

Permalink
Tidy up connection logic (#90)
Browse files Browse the repository at this point in the history
* Add code:

* code

* code

---------

Co-authored-by: Freddy Boulton <[email protected]>
  • Loading branch information
freddyaboulton and Freddy Boulton authored Feb 26, 2025
1 parent e44341d commit 43e42c1
Show file tree
Hide file tree
Showing 8 changed files with 56 additions and 12 deletions.
23 changes: 20 additions & 3 deletions backend/fastrtc/webrtc_connection_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class WebRTCConnectionMixin:
data_channels: dict[str, DataChannel] = {}
additional_outputs: dict[str, OutputQueue] = defaultdict(OutputQueue)
handlers: dict[str, HandlerType | Callable] = {}

connection_timeouts: dict[str, asyncio.Event] = defaultdict(asyncio.Event)
concurrency_limit: int | float
event_handler: HandlerType
time_limit: float | int | None
Expand All @@ -82,8 +82,24 @@ async def wait_for_time_limit(pc: RTCPeerConnection, time_limit: float):
await asyncio.sleep(time_limit)
await pc.close()

async def connection_timeout(
self,
pc: RTCPeerConnection,
webrtc_id: str,
time_limit: float,
):
try:
await asyncio.wait_for(
self.connection_timeouts[webrtc_id].wait(), time_limit
)
except (asyncio.TimeoutError, TimeoutError):
await pc.close()
self.connection_timeouts[webrtc_id].clear()
self.clean_up(webrtc_id)

def clean_up(self, webrtc_id: str):
self.handlers.pop(webrtc_id, None)
self.connection_timeouts.pop(webrtc_id, None)
connection = self.connections.pop(webrtc_id, [])
for conn in connection:
if isinstance(conn, AudioCallback):
Expand Down Expand Up @@ -132,7 +148,7 @@ async def handle_offer(self, body, set_outputs):
logger.debug("Offer body %s", body)
if len(self.connections) >= cast(int, self.concurrency_limit):
return JSONResponse(
status_code=429,
status_code=200,
content={
"status": "failed",
"meta": {
Expand Down Expand Up @@ -181,6 +197,7 @@ async def _():
conn.stop()
self.pcs.discard(pc)
if pc.connectionState == "connected":
self.connection_timeouts[body["webrtc_id"]].set()
if self.time_limit is not None:
asyncio.create_task(self.wait_for_time_limit(pc, self.time_limit))

Expand Down Expand Up @@ -269,7 +286,7 @@ def _(message):

# handle offer
await pc.setRemoteDescription(offer)

asyncio.create_task(self.connection_timeout(pc, body["webrtc_id"], 30))
# send answer
answer = await pc.createAnswer()
await pc.setLocalDescription(answer) # type: ignore
Expand Down
15 changes: 8 additions & 7 deletions demo/moonshine_live/app.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
from functools import lru_cache
from typing import Generator, Literal

import gradio as gr
import numpy as np
from fastrtc import (
Stream,
AdditionalOutputs,
audio_to_float32,
ReplyOnPause,
Stream,
audio_to_float32,
get_twilio_turn_credentials,
)
from functools import lru_cache
import gradio as gr
from typing import Generator, Literal
from numpy.typing import NDArray
import numpy as np
from moonshine_onnx import MoonshineOnnxModel, load_tokenizer
from numpy.typing import NDArray


@lru_cache(maxsize=None)
Expand Down
4 changes: 3 additions & 1 deletion docs/userguide/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ async def stream_updates(webrtc_id: str):

### Handling Errors

When connecting via `WebRTC`, the server will respond to the `/webrtc/offer` route with a JSON response. If there are too many connections, the server will respond with a 429 error.
When connecting via `WebRTC`, the server will respond to the `/webrtc/offer` route with a JSON response. If there are too many connections, the server will respond with a 200 error.

```json
{
Expand All @@ -122,6 +122,8 @@ When connecting via `WebRTC`, the server will respond to the `/webrtc/offer` rou

Over `WebSocket`, the server will send the same message before closing the connection.

!!! tip
The server will sends a 200 status code because otherwise the gradio client will not be able to process the json response and display the error.

<style>
.config-selector {
Expand Down
8 changes: 8 additions & 0 deletions frontend/shared/InteractiveAudio.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,13 @@
_time_limit = null;
stop(pc);
break;
case "failed":
console.info("failed");
stream_state = "closed";
_time_limit = null;
dispatch("error", "Connection failed!");
stop(pc);
break;
default:
break;
}
Expand Down Expand Up @@ -209,6 +216,7 @@
})
.catch(() => {
console.info("catching");
clearTimeout(timeoutId);
stream_state = "closed";
});
}
Expand Down
5 changes: 5 additions & 0 deletions frontend/shared/StaticAudio.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@
console.info("closed");
stop(pc);
break;
case "failed":
stream_state = "closed";
dispatch("error", "Connection failed!");
stop(pc);
break;
default:
break;
}
Expand Down
5 changes: 5 additions & 0 deletions frontend/shared/StaticVideo.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@
case "disconnected":
stop(pc);
break;
case "failed":
stream_state = "closed";
dispatch("error", "Connection failed!");
stop(pc);
break;
default:
break;
}
Expand Down
6 changes: 6 additions & 0 deletions frontend/shared/Webcam.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,12 @@
stop(pc);
await access_webcam();
break;
case "failed":
stream_state = "closed";
_time_limit = null;
dispatch("error", "Connection failed!");
stop(pc);
break;
default:
break;
}
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ build-backend = "hatchling.build"

[project]
name = "fastrtc"
version = "0.0.6"
version = "0.0.8post1"
description = "The realtime communication library for Python"
readme = "README.md"
license = "apache-2.0"
Expand Down

0 comments on commit 43e42c1

Please sign in to comment.