Skip to content

Commit

Permalink
Clear websocket queue on interrupt
Browse files Browse the repository at this point in the history
  • Loading branch information
freddyaboulton committed Mar 6, 2025
1 parent a0b46f4 commit 2f82469
Showing 1 changed file with 32 additions and 3 deletions.
35 changes: 32 additions & 3 deletions backend/fastrtc/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from fastapi import WebSocket

from .tracks import AsyncStreamHandler, StreamHandlerImpl
from .utils import AdditionalOutputs, DataChannel, split_output
from .utils import AdditionalOutputs, DataChannel, split_output, wait_for_item


class WebSocketDataChannel(DataChannel):
Expand Down Expand Up @@ -55,7 +55,7 @@ def __init__(
],
):
self.stream_handler = stream_handler
self.stream_handler._clear_queue = lambda: None
self.stream_handler._clear_queue = self.clear_queue
self.websocket: Optional[WebSocket] = None
self._emit_task: Optional[asyncio.Task] = None
self.stream_id: Optional[str] = None
Expand All @@ -64,6 +64,19 @@ def __init__(
self.set_handler = set_handler
self.quit = asyncio.Event()
self.clean_up = clean_up
self.queue = asyncio.Queue()

def clear_queue(self):
# Replace the queue with a new empty queue
old_queue = self.queue
self.queue = asyncio.Queue()

# Drain the old queue to ensure task_done() is called for all items
while not old_queue.empty():
try:
old_queue.get_nowait()
except asyncio.QueueEmpty:
break

def set_args(self, args: list[Any]):
self.stream_handler.set_args(args)
Expand All @@ -77,6 +90,7 @@ async def handle_websocket(self, websocket: WebSocket):
self.stream_handler._loop = loop
self.stream_handler.set_channel(self.data_channel)
self._emit_task = asyncio.create_task(self._emit_loop())
self._emit_to_queue_task = asyncio.create_task(self._emit_to_queue())
if isinstance(self.stream_handler, AsyncStreamHandler):
start_up = self.stream_handler.start_up()
else:
Expand Down Expand Up @@ -137,17 +151,32 @@ async def handle_websocket(self, websocket: WebSocket):
finally:
if self._emit_task:
self._emit_task.cancel()
if self._emit_to_queue_task:
self._emit_to_queue_task.cancel()
if self.start_up_task:
self.start_up_task.cancel()
await websocket.close()

async def _emit_loop(self):
async def _emit_to_queue(self):
try:
while not self.quit.is_set():
if isinstance(self.stream_handler, AsyncStreamHandler):
output = await self.stream_handler.emit()
else:
output = await run_sync(self.stream_handler.emit)
self.queue.put_nowait(output)
except asyncio.CancelledError:
logger.debug("Emit loop cancelled")
except Exception as e:
import traceback

traceback.print_exc()
logger.debug("Error in emit loop: %s", e)

async def _emit_loop(self):
try:
while not self.quit.is_set():
output = await self.queue.get()

if output is not None:
frame, output = split_output(output)
Expand Down

0 comments on commit 2f82469

Please sign in to comment.