Skip to content

Commit

Permalink
add code
Browse files Browse the repository at this point in the history
  • Loading branch information
freddyaboulton committed Mar 6, 2025
1 parent 2f82469 commit 0d0dc12
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 14 deletions.
5 changes: 3 additions & 2 deletions backend/fastrtc/reply_on_pause.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,10 +184,11 @@ def receive(self, frame: tuple[int, np.ndarray]) -> None:
self.process_audio(frame, self.state)
if self.state.pause_detected:
self.event.set()
if self.can_interrupt:
self.clear_queue()
if self.can_interrupt and self.state.responding:
self._close_generator()
self.generator = None
if self.can_interrupt:
self.clear_queue()

def _close_generator(self):
"""Properly close the generator to ensure resources are released."""
Expand Down
4 changes: 2 additions & 2 deletions backend/fastrtc/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,7 @@ async def telephone_handler(self, websocket: WebSocket):
handler.phone_mode = True

async def set_handler(s: str, a: WebSocketHandler):
if len(self.connections) >= self.concurrency_limit:
if len(self.connections) >= self.concurrency_limit: # type: ignore
await cast(WebSocket, a.websocket).send_json(
{
"status": "failed",
Expand All @@ -532,7 +532,7 @@ async def websocket_offer(self, websocket: WebSocket):
handler.phone_mode = False

async def set_handler(s: str, a: WebSocketHandler):
if len(self.connections) >= self.concurrency_limit:
if len(self.connections) >= self.concurrency_limit: # type: ignore
await cast(WebSocket, a.websocket).send_json(
{
"status": "failed",
Expand Down
12 changes: 8 additions & 4 deletions backend/fastrtc/tracks.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,10 +431,14 @@ def __init__(
self.set_additional_outputs = set_additional_outputs

def clear_queue(self):
if self.queue:
while not self.queue.empty():
self.queue.get_nowait()
self._start = None
logger.debug("clearing queue")
logger.debug("queue size: %d", self.queue.qsize())
i = 0
while not self.queue.empty():
self.queue.get_nowait()
i += 1
logger.debug("popped %d items from queue", i)
self._start = None

def set_channel(self, channel: DataChannel):
self.channel = channel
Expand Down
13 changes: 7 additions & 6 deletions backend/fastrtc/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from fastapi import WebSocket

from .tracks import AsyncStreamHandler, StreamHandlerImpl
from .utils import AdditionalOutputs, DataChannel, split_output, wait_for_item
from .utils import AdditionalOutputs, DataChannel, split_output


class WebSocketDataChannel(DataChannel):
Expand Down Expand Up @@ -55,7 +55,7 @@ def __init__(
],
):
self.stream_handler = stream_handler
self.stream_handler._clear_queue = self.clear_queue
self.stream_handler._clear_queue = self._clear_queue
self.websocket: Optional[WebSocket] = None
self._emit_task: Optional[asyncio.Task] = None
self.stream_id: Optional[str] = None
Expand All @@ -66,17 +66,18 @@ def __init__(
self.clean_up = clean_up
self.queue = asyncio.Queue()

def clear_queue(self):
# Replace the queue with a new empty queue
def _clear_queue(self):
old_queue = self.queue
self.queue = asyncio.Queue()

# Drain the old queue to ensure task_done() is called for all items
logger.debug("clearing queue")
i = 0
while not old_queue.empty():
try:
old_queue.get_nowait()
i += 1
except asyncio.QueueEmpty:
break
logger.debug("popped %d items from queue", i)

def set_args(self, args: list[Any]):
self.stream_handler.set_args(args)
Expand Down

0 comments on commit 0d0dc12

Please sign in to comment.