Skip to content

Commit

Permalink
Voice hooks in RealtimeAgent (ag2ai#1040)
Browse files Browse the repository at this point in the history
* input audio messages passed to RealtimeClient

* Added AudioObserver

* clear outputs from the notebook

* notebook formatting

* more notebook formatting

* even more notebook formatting

* exciting more notebook formatting

* catch CancelledError

* catch queue get exception

* see if timeout influences result of test

* check the number of calls in test

* add timeout in test in order to allow background tasks to run

* move away from asyncio

* remove asyncio.sleep from the test

* skip failing realtime tests
  • Loading branch information
davorinrusevljan authored Feb 19, 2025
1 parent 74902f3 commit 5d2f165
Show file tree
Hide file tree
Showing 13 changed files with 160 additions and 22 deletions.
2 changes: 2 additions & 0 deletions autogen/agentchat/realtime/experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
# SPDX-License-Identifier: Apache-2.0

from .audio_adapters import TwilioAudioAdapter, WebSocketAudioAdapter
from .audio_observer import AudioObserver
from .function_observer import FunctionObserver
from .realtime_agent import RealtimeAgent
from .realtime_observer import RealtimeObserver
from .realtime_swarm import register_swarm

__all__ = [
"AudioObserver",
"FunctionObserver",
"RealtimeAgent",
"RealtimeObserver",
Expand Down
42 changes: 42 additions & 0 deletions autogen/agentchat/realtime/experimental/audio_observer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0

from typing import TYPE_CHECKING, Optional

from ....doc_utils import export_module
from .realtime_events import InputAudioBufferDelta, RealtimeEvent
from .realtime_observer import RealtimeObserver

if TYPE_CHECKING:
from logging import Logger


@export_module("autogen.agentchat.realtime.experimental")
class AudioObserver(RealtimeObserver):
    """Realtime observer that watches for the user's incoming voice input."""

    def __init__(self, *, logger: Optional["Logger"] = None) -> None:
        """Create the observer.

        Args:
            logger (Optional[Logger]): Logger handed to the base observer.
        """
        super().__init__(logger=logger)

    async def on_event(self, event: RealtimeEvent) -> None:
        """Log every input-audio delta event; ignore all other events.

        Args:
            event (RealtimeEvent): The event received from the realtime client.
        """
        # Guard clause: only input-audio deltas are of interest here.
        if not isinstance(event, InputAudioBufferDelta):
            return
        self.logger.info("Received audio buffer delta")

    async def initialize_session(self) -> None:
        """Nothing to set up: this observer does not configure the session."""
        return None

    async def run_loop(self) -> None:
        """Nothing to run: this observer's loop body is empty."""
        return None


# Static-only check: TYPE_CHECKING is False at runtime, so this assignment is
# never executed. It lets a type checker confirm that AudioObserver can be
# instantiated and assigned to a RealtimeObserver.
if TYPE_CHECKING:
    function_observer: RealtimeObserver = AudioObserver()
13 changes: 10 additions & 3 deletions autogen/agentchat/realtime/experimental/clients/gemini/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from ......doc_utils import export_module
from ...realtime_events import AudioDelta, FunctionCall, RealtimeEvent, SessionCreated
from ..realtime_client import Role, register_realtime_client
from ..realtime_client import RealtimeClientBase, Role, register_realtime_client

if TYPE_CHECKING:
from websockets.asyncio.client import ClientConnection
Expand All @@ -30,7 +30,7 @@

@register_realtime_client()
@export_module("autogen.agentchat.realtime.experimental.clients")
class GeminiRealtimeClient:
class GeminiRealtimeClient(RealtimeClientBase):
"""(Experimental) Client for Gemini Realtime API."""

def __init__(
Expand All @@ -44,6 +44,7 @@ def __init__(
Args:
llm_config (dict[str, Any]): The config for the client.
"""
super().__init__()
self._llm_config = llm_config
self._logger = logger

Expand Down Expand Up @@ -123,6 +124,7 @@ async def send_audio(self, audio: str) -> None:
]
}
}
await self.queue_input_audio_buffer_delta(audio)
if self._is_reading_events:
await self.connection.send(json.dumps(msg))

Expand Down Expand Up @@ -185,13 +187,18 @@ async def connect(self) -> AsyncGenerator[None, None]:
self._connection = None

async def read_events(self) -> AsyncGenerator[RealtimeEvent, None]:
"""Read Events from the Gemini Realtime API."""
"""Read Events from the Gemini Realtime Client"""
if self._connection is None:
raise RuntimeError("Client is not connected, call connect() first.")
await self._initialize_session()

self._is_reading_events = True

async for event in self._read_events():
yield event

async def _read_from_connection(self) -> AsyncGenerator[RealtimeEvent, None]:
"""Read messages from the Gemini Realtime connection."""
async for raw_message in self.connection:
message = raw_message.decode("ascii") if isinstance(raw_message, bytes) else raw_message
events = self._parse_message(json.loads(message))
Expand Down
17 changes: 12 additions & 5 deletions autogen/agentchat/realtime/experimental/clients/oai/base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from ......doc_utils import export_module
from ...realtime_events import RealtimeEvent
from ..realtime_client import Role, register_realtime_client
from ..realtime_client import RealtimeClientBase, Role, register_realtime_client
from .utils import parse_oai_message

if TYPE_CHECKING:
Expand All @@ -25,7 +25,7 @@

@register_realtime_client()
@export_module("autogen.agentchat.realtime.experimental.clients")
class OpenAIRealtimeClient:
class OpenAIRealtimeClient(RealtimeClientBase):
"""(Experimental) Client for OpenAI Realtime API."""

def __init__(
Expand All @@ -39,6 +39,7 @@ def __init__(
Args:
llm_config (dict[str, Any]): The config for the client.
"""
super().__init__()
self._llm_config = llm_config
self._logger = logger

Expand Down Expand Up @@ -110,6 +111,7 @@ async def send_audio(self, audio: str) -> None:
Args:
audio (str): The audio to send.
"""
await self.queue_input_audio_buffer_delta(audio)
await self.connection.input_audio_buffer.append(audio=audio)

async def truncate_audio(self, audio_end_ms: int, content_index: int, item_id: str) -> None:
Expand Down Expand Up @@ -163,13 +165,18 @@ async def read_events(self) -> AsyncGenerator[RealtimeEvent, None]:
raise RuntimeError("Client is not connected, call connect() first.")

try:
async for message in self._connection:
for event in self._parse_message(message.model_dump()):
yield event
async for event in self._read_events():
yield event

finally:
self._connection = None

async def _read_from_connection(self) -> AsyncGenerator[RealtimeEvent, None]:
"""Read messages from the OpenAI Realtime API."""
async for message in self._connection:
for event in self._parse_message(message.model_dump()):
yield event

def _parse_message(self, message: dict[str, Any]) -> list[RealtimeEvent]:
"""Parse a message from the OpenAI Realtime API.
Expand Down
15 changes: 11 additions & 4 deletions autogen/agentchat/realtime/experimental/clients/oai/rtc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from ......doc_utils import export_module
from ...realtime_events import RealtimeEvent
from ..realtime_client import Role, register_realtime_client
from ..realtime_client import RealtimeClientBase, Role, register_realtime_client
from .utils import parse_oai_message

if TYPE_CHECKING:
Expand All @@ -26,7 +26,7 @@

@register_realtime_client()
@export_module("autogen.agentchat.realtime.experimental.clients.oai")
class OpenAIRealtimeWebRTCClient:
class OpenAIRealtimeWebRTCClient(RealtimeClientBase):
"""(Experimental) Client for OpenAI Realtime API that uses WebRTC protocol."""

def __init__(
Expand All @@ -41,6 +41,7 @@ def __init__(
Args:
llm_config (dict[str, Any]): The config for the client.
"""
super().__init__()
self._llm_config = llm_config
self._logger = logger
self._websocket = websocket
Expand Down Expand Up @@ -94,11 +95,12 @@ async def send_text(self, *, role: Role, text: str) -> None:

async def send_audio(self, audio: str) -> None:
"""Send audio to the OpenAI Realtime API.
in case of WebRTC, audio is already sent by js client, so we just queue it in order to be logged.
Args:
audio (str): The audio to send.
"""
await self._websocket.send_json({"type": "input_audio_buffer.append", "audio": audio})
await self.queue_input_audio_buffer_delta(audio)

async def truncate_audio(self, audio_end_ms: int, content_index: int, item_id: str) -> None:
"""Truncate audio in the OpenAI Realtime API.
Expand Down Expand Up @@ -176,7 +178,12 @@ async def connect(self) -> AsyncGenerator[None, None]:
pass

async def read_events(self) -> AsyncGenerator[RealtimeEvent, None]:
"""Read messages from the OpenAI Realtime API.
"""Read events from the OpenAI Realtime API."""
async for event in self._read_events():
yield event

async def _read_from_connection(self) -> AsyncGenerator[RealtimeEvent, None]:
"""Read messages from the OpenAI Realtime API connection.
Again, in case of WebRTC, we do not read OpenAI messages directly since we
do not hold connection to OpenAI. Instead we read messages from the websocket, and javascript
client on the other side of the websocket that is connected to OpenAI is relaying events to us.
Expand Down
12 changes: 11 additions & 1 deletion autogen/agentchat/realtime/experimental/clients/oai/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,15 @@
import json
from typing import Any

from ...realtime_events import AudioDelta, FunctionCall, RealtimeEvent, SessionCreated, SessionUpdated, SpeechStarted
from ...realtime_events import (
AudioDelta,
FunctionCall,
InputAudioBufferDelta,
RealtimeEvent,
SessionCreated,
SessionUpdated,
SpeechStarted,
)

__all__ = ["parse_oai_message"]

Expand All @@ -27,6 +35,8 @@ def parse_oai_message(message: dict[str, Any]) -> RealtimeEvent:
return AudioDelta(raw_message=message, delta=message["delta"], item_id=message["item_id"])
elif message.get("type") == "input_audio_buffer.speech_started":
return SpeechStarted(raw_message=message)
elif message.get("type") == "input_audio_buffer.delta":
return InputAudioBufferDelta(delta=message.delta, item_id=None, raw_message=message)
elif message.get("type") == "response.function_call_arguments.done":
return FunctionCall(
raw_message=message,
Expand Down
49 changes: 47 additions & 2 deletions autogen/agentchat/realtime/experimental/clients/realtime_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@
#
# SPDX-License-Identifier: Apache-2.0

import asyncio
from collections.abc import AsyncGenerator
from logging import Logger
from typing import Any, AsyncContextManager, Callable, Literal, Optional, Protocol, TypeVar, runtime_checkable

from asyncer import create_task_group

from .....doc_utils import export_module
from ..realtime_events import RealtimeEvent
from ..realtime_events import InputAudioBufferDelta, RealtimeEvent

__all__ = ["RealtimeClientProtocol", "Role", "get_client", "register_realtime_client"]

Expand Down Expand Up @@ -65,7 +68,11 @@ async def session_update(self, session_options: dict[str, Any]) -> None:
def connect(self) -> AsyncContextManager[None]: ...

def read_events(self) -> AsyncGenerator[RealtimeEvent, None]:
"""Read messages from a Realtime API."""
"""Read events from a Realtime Client."""
...

async def _read_from_connection(self) -> AsyncGenerator[RealtimeEvent, None]:
"""Read events from a Realtime connection."""
...

def _parse_message(self, message: dict[str, Any]) -> list[RealtimeEvent]:
Expand Down Expand Up @@ -95,6 +102,44 @@ def get_factory(
...


class RealtimeClientBase:
    """Shared event-queue plumbing for realtime clients.

    Events read from the connection and locally generated events (such as
    input-audio deltas) are funnelled through one queue, so ``_read_events``
    yields both kinds in a single stream. Subclasses provide
    ``_read_from_connection``.
    """

    def __init__(self) -> None:
        """Create the empty event queue."""
        # ``None`` on the queue is the end-of-stream sentinel.
        self._eventQueue: "asyncio.Queue[Optional[RealtimeEvent]]" = asyncio.Queue()

    async def add_event(self, event: Optional["RealtimeEvent"]) -> None:
        """Put an event on the queue.

        Args:
            event (Optional[RealtimeEvent]): The event to enqueue, or ``None``
                to signal that no more events will follow.
        """
        await self._eventQueue.put(event)

    async def get_event(self) -> Optional["RealtimeEvent"]:
        """Wait for and return the next queued event (``None`` marks the end)."""
        return await self._eventQueue.get()

    async def _read_from_connection_task(self) -> None:
        """Forward every connection event to the queue, then queue the sentinel."""
        async for event in self._read_from_connection():
            await self.add_event(event)
        # Must be awaited: calling add_event() without ``await`` only creates a
        # coroutine object, so the sentinel would never reach the queue and
        # _read_events() would wait forever once the connection is exhausted.
        await self.add_event(None)

    async def _read_events(self) -> AsyncGenerator["RealtimeEvent", None]:
        """Read events from a Realtime Client.

        Starts a background task that pumps connection events into the queue
        and yields queued events until the ``None`` sentinel arrives.
        """
        async with create_task_group() as tg:
            tg.start_soon(self._read_from_connection_task)
            while True:
                try:
                    event = await self._eventQueue.get()
                    if event is None:
                        # Sentinel: the connection reader has finished.
                        break
                    yield event
                except Exception:
                    # Best-effort (kept from the original): any failure while
                    # fetching or handing out an event ends the stream.
                    break

    async def queue_input_audio_buffer_delta(self, audio: str) -> None:
        """Queue an InputAudioBufferDelta event for the given audio.

        Args:
            audio (str): The audio.
        """
        await self.add_event(InputAudioBufferDelta(delta=audio, item_id=None, raw_message={}))


_realtime_client_classes: dict[str, type[RealtimeClientProtocol]] = {}

T = TypeVar("T", bound=RealtimeClientProtocol)
Expand Down
6 changes: 4 additions & 2 deletions autogen/agentchat/realtime/experimental/realtime_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(
system_message: str = "You are a helpful AI Assistant.",
llm_config: dict[str, Any],
logger: Optional[Logger] = None,
observers: Optional[list[RealtimeObserver]] = None,
**client_kwargs: Any,
):
"""(Experimental) Agent for interacting with the Realtime Clients.
Expand All @@ -48,6 +49,7 @@ def __init__(
system_message (str): The system message for the agent.
llm_config (dict[str, Any], bool): The config for the agent.
logger (Optional[Logger]): The logger for the agent.
observers (Optional[list[RealtimeObserver]]): The additional observers for the agent.
**client_kwargs (Any): The keyword arguments for the client.
"""
self._logger = logger
Expand All @@ -59,8 +61,8 @@ def __init__(
)

self._registered_realtime_tools: dict[str, Tool] = {}
self._observers: list[RealtimeObserver] = [FunctionObserver(logger=logger)]

self._observers: list[RealtimeObserver] = observers if observers else []
self._observers.append(FunctionObserver(logger=logger))
if audio_adapter:
self._observers.append(audio_adapter)

Expand Down
6 changes: 6 additions & 0 deletions autogen/agentchat/realtime/experimental/realtime_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ class AudioDelta(RealtimeEvent):
item_id: Any


class InputAudioBufferDelta(RealtimeEvent):
    # Event describing a chunk of input audio.
    type: Literal["input_audio_buffer.delta"] = "input_audio_buffer.delta"  # fixed event-type tag
    delta: str  # the audio chunk carried by this event
    item_id: Any  # untyped identifier; presumably mirrors AudioDelta.item_id — TODO confirm


class SpeechStarted(RealtimeEvent):
type: Literal["input_audio_buffer.speech_started"] = "input_audio_buffer.speech_started"

Expand Down
4 changes: 4 additions & 0 deletions autogen/agentchat/realtime/experimental/realtime_observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,7 @@ async def on_event(self, event: RealtimeEvent) -> None:
event (RealtimeServerEvent): The event from the OpenAI Realtime API.
"""
...

async def on_close(self) -> None:
    """Handle close of RealtimeClient.

    The stub body does nothing. Presumably concrete observers override this
    hook to clean up when the client closes — TODO confirm against callers.
    """
    ...
5 changes: 3 additions & 2 deletions notebook/agentchat_realtime_websocket.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
"from fastapi.templating import Jinja2Templates\n",
"\n",
"import autogen\n",
"from autogen.agentchat.realtime.experimental import RealtimeAgent, WebSocketAudioAdapter"
"from autogen.agentchat.realtime.experimental import AudioObserver, RealtimeAgent, WebSocketAudioAdapter"
]
},
{
Expand Down Expand Up @@ -268,6 +268,7 @@
" llm_config=realtime_llm_config,\n",
" audio_adapter=audio_adapter,\n",
" logger=logger,\n",
" observers=[AudioObserver(logger=logger)],\n",
" )\n",
"\n",
" @realtime_agent.register_realtime_function(name=\"get_weather\", description=\"Get the current weather\")\n",
Expand Down Expand Up @@ -317,7 +318,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.21"
"version": "3.11.3"
}
},
"nbformat": 4,
Expand Down
Loading

0 comments on commit 5d2f165

Please sign in to comment.