diff --git a/.github/workflows/contrib-openai.yml b/.github/workflows/contrib-openai.yml index 5d74c59022..a839af92ac 100644 --- a/.github/workflows/contrib-openai.yml +++ b/.github/workflows/contrib-openai.yml @@ -174,9 +174,8 @@ jobs: run: | docker --version python -m pip install --upgrade pip wheel - pip install -e .[teachable] + pip install -e .[teachable,test] python -c "import autogen" - pip install pytest-cov>=5 - name: Coverage env: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} @@ -211,7 +210,7 @@ jobs: run: | docker --version python -m pip install --upgrade pip wheel - pip install -e . + pip install -e ".[test]" python -c "import autogen" pip install pytest-cov>=5 pytest-asyncio - name: Install packages for test when needed @@ -290,9 +289,8 @@ jobs: run: | docker --version python -m pip install --upgrade pip wheel - pip install -e .[lmm] + pip install -e .[lmm,test] python -c "import autogen" - pip install pytest-cov>=5 - name: Coverage env: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} diff --git a/.github/workflows/openai.yml b/.github/workflows/openai.yml index d7dbbba492..fbe3ed6f3c 100644 --- a/.github/workflows/openai.yml +++ b/.github/workflows/openai.yml @@ -48,9 +48,8 @@ jobs: run: | docker --version python -m pip install --upgrade pip wheel - pip install -e. 
+ pip install -e ".[test]" python -c "import autogen" - pip install pytest-cov>=5 pytest-asyncio - name: Install packages for test when needed if: matrix.python-version == '3.9' run: | diff --git a/autogen/agentchat/realtime_agent/__init__.py b/autogen/agentchat/realtime_agent/__init__.py index a3d258b1c3..709d87b926 100644 --- a/autogen/agentchat/realtime_agent/__init__.py +++ b/autogen/agentchat/realtime_agent/__init__.py @@ -1,6 +1,11 @@ +# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai +# +# SPDX-License-Identifier: Apache-2.0 + from .function_observer import FunctionObserver from .realtime_agent import RealtimeAgent -from .twilio_observer import TwilioAudioAdapter -from .websocket_observer import WebsocketAudioAdapter +from .realtime_observer import RealtimeObserver +from .twilio_audio_adapter import TwilioAudioAdapter +from .websocket_audio_adapter import WebSocketAudioAdapter -__all__ = ["RealtimeAgent", "FunctionObserver", "TwilioAudioAdapter", "WebsocketAudioAdapter"] +__all__ = ["FunctionObserver", "RealtimeAgent", "RealtimeObserver", "TwilioAudioAdapter", "WebSocketAudioAdapter"] diff --git a/autogen/agentchat/realtime_agent/client.py b/autogen/agentchat/realtime_agent/client.py deleted file mode 100644 index ca337a5769..0000000000 --- a/autogen/agentchat/realtime_agent/client.py +++ /dev/null @@ -1,192 +0,0 @@ -# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai -# -# SPDX-License-Identifier: Apache-2.0 -# -# Portions derived from https://github.com/microsoft/autogen are under the MIT License. 
-# SPDX-License-Identifier: MIT - -# import asyncio -import json -import logging -from typing import TYPE_CHECKING, Any, Optional - -from asyncer import TaskGroup, asyncify, create_task_group -from websockets import connect -from websockets.asyncio.client import ClientConnection - -from ..contrib.swarm_agent import AfterWorkOption, SwarmAgent, initiate_swarm_chat - -if TYPE_CHECKING: - from .function_observer import FunctionObserver - from .realtime_agent import RealtimeAgent - from .realtime_observer import RealtimeObserver - -logger = logging.getLogger(__name__) - - -class OpenAIRealtimeClient: - """(Experimental) Client for OpenAI Realtime API.""" - - def __init__( - self, agent: "RealtimeAgent", audio_adapter: "RealtimeObserver", function_observer: "FunctionObserver" - ) -> None: - """(Experimental) Client for OpenAI Realtime API. - - Args: - agent (RealtimeAgent): The agent that the client is associated with. - audio_adapter (RealtimeObserver): The audio adapter for the client. - function_observer (FunctionObserver): The function observer for the client. 
- - """ - self._agent = agent - self._observers: list["RealtimeObserver"] = [] - self._openai_ws: Optional[ClientConnection] = None # todo factor out to OpenAIClient - self.register(audio_adapter) - self.register(function_observer) - - # LLM config - llm_config = self._agent.llm_config - - config: dict[str, Any] = llm_config["config_list"][0] # type: ignore[index] - - self.model: str = config["model"] - self.temperature: float = llm_config["temperature"] # type: ignore[index] - self.api_key: str = config["api_key"] - - # create a task group to manage the tasks - self.tg: Optional[TaskGroup] = None - - @property - def openai_ws(self) -> ClientConnection: - """Get the OpenAI WebSocket connection.""" - if self._openai_ws is None: - raise RuntimeError("OpenAI WebSocket is not initialized") - return self._openai_ws - - def register(self, observer: "RealtimeObserver") -> None: - """Register an observer to the client.""" - observer.register_client(self) - self._observers.append(observer) - - async def notify_observers(self, message: dict[str, Any]) -> None: - """Notify all observers of a message from the OpenAI Realtime API. - - Args: - message (dict[str, Any]): The message from the OpenAI Realtime API. - - """ - for observer in self._observers: - await observer.update(message) - - async def function_result(self, call_id: str, result: str) -> None: - """Send the result of a function call to the OpenAI Realtime API. - - Args: - call_id (str): The ID of the function call. - result (str): The result of the function call. 
- """ - result_item = { - "type": "conversation.item.create", - "item": { - "type": "function_call_output", - "call_id": call_id, - "output": result, - }, - } - if self._openai_ws is None: - raise RuntimeError("OpenAI WebSocket is not initialized") - - await self._openai_ws.send(json.dumps(result_item)) - await self._openai_ws.send(json.dumps({"type": "response.create"})) - - async def send_text(self, *, role: str, text: str) -> None: - """Send a text message to the OpenAI Realtime API. - - Args: - role (str): The role of the message. - text (str): The text of the message. - """ - - if self._openai_ws is None: - raise RuntimeError("OpenAI WebSocket is not initialized") - - await self._openai_ws.send(json.dumps({"type": "response.cancel"})) - text_item = { - "type": "conversation.item.create", - "item": {"type": "message", "role": role, "content": [{"type": "input_text", "text": text}]}, - } - await self._openai_ws.send(json.dumps(text_item)) - await self._openai_ws.send(json.dumps({"type": "response.create"})) - - # todo override in specific clients - async def initialize_session(self) -> None: - """Control initial session with OpenAI.""" - session_update = { - # todo: move to config - "turn_detection": {"type": "server_vad"}, - "voice": self._agent.voice, - "instructions": self._agent.system_message, - "modalities": ["audio", "text"], - "temperature": 0.8, - } - await self.session_update(session_update) - - # todo override in specific clients - async def session_update(self, session_options: dict[str, Any]) -> None: - """Send a session update to the OpenAI Realtime API. - - Args: - session_options (dict[str, Any]): The session options to update. 
- """ - if self._openai_ws is None: - raise RuntimeError("OpenAI WebSocket is not initialized") - - update = {"type": "session.update", "session": session_options} - logger.info("Sending session update:", json.dumps(update)) - await self._openai_ws.send(json.dumps(update)) - logger.info("Sending session update finished") - - async def _read_from_client(self) -> None: - """Read messages from the OpenAI Realtime API.""" - if self._openai_ws is None: - raise RuntimeError("OpenAI WebSocket is not initialized") - - try: - async for openai_message in self._openai_ws: - response = json.loads(openai_message) - await self.notify_observers(response) - except Exception as e: - logger.warning(f"Error in _read_from_client: {e}") - - async def run(self) -> None: - """Run the client.""" - async with connect( - f"wss://api.openai.com/v1/realtime?model={self.model}", - additional_headers={ - "Authorization": f"Bearer {self.api_key}", - "OpenAI-Beta": "realtime=v1", - }, - ) as openai_ws: - self._openai_ws = openai_ws - await self.initialize_session() - async with create_task_group() as tg: - self.tg = tg - self.tg.soonify(self._read_from_client)() - for observer in self._observers: - self.tg.soonify(observer.run)() - - initial_agent = self._agent._initial_agent - agents = self._agent._agents - user_agent = self._agent - - if not (initial_agent and agents): - raise RuntimeError("Swarm not registered.") - - if self._agent._start_swarm_chat: - self.tg.soonify(asyncify(initiate_swarm_chat))( - initial_agent=initial_agent, - agents=agents, - user_agent=user_agent, # type: ignore[arg-type] - messages="Find out what the user wants.", - after_work=AfterWorkOption.REVERT_TO_USER, - ) diff --git a/autogen/agentchat/realtime_agent/function_observer.py b/autogen/agentchat/realtime_agent/function_observer.py index 9e4c8d2649..89d49c792c 100644 --- a/autogen/agentchat/realtime_agent/function_observer.py +++ b/autogen/agentchat/realtime_agent/function_observer.py @@ -1,48 +1,37 @@ # Copyright (c) 
2023 - 2024, Owners of https://github.com/ag2ai # # SPDX-License-Identifier: Apache-2.0 -# -# Portions derived from https://github.com/microsoft/autogen are under the MIT License. -# SPDX-License-Identifier: MIT import asyncio import json -import logging -from typing import TYPE_CHECKING, Any +from logging import Logger, getLogger +from typing import TYPE_CHECKING, Any, Optional from asyncer import asyncify from pydantic import BaseModel from .realtime_observer import RealtimeObserver -if TYPE_CHECKING: - from .realtime_agent import RealtimeAgent - -logger = logging.getLogger(__name__) - class FunctionObserver(RealtimeObserver): """Observer for handling function calls from the OpenAI Realtime API.""" - def __init__(self, agent: "RealtimeAgent") -> None: - """Observer for handling function calls from the OpenAI Realtime API. - - Args: - agent (RealtimeAgent): The realtime agent attached to the observer. - """ - super().__init__() - self._agent = agent + def __init__(self, *, logger: Optional[Logger] = None) -> None: + """Observer for handling function calls from the OpenAI Realtime API.""" + super().__init__(logger=logger) - async def update(self, response: dict[str, Any]) -> None: + async def on_event(self, event: dict[str, Any]) -> None: """Handle function call events from the OpenAI Realtime API. Args: - response (dict[str, Any]): The response from the OpenAI Realtime API. + event (dict[str, Any]): The event from the OpenAI Realtime API. 
""" - if response.get("type") == "response.function_call_arguments.done": - logger.info(f"Received event: {response['type']}", response) + if event["type"] == "response.function_call_arguments.done": + self.logger.info(f"Received event: {event['type']}", event) await self.call_function( - call_id=response["call_id"], name=response["name"], kwargs=json.loads(response["arguments"]) + call_id=event["call_id"], + name=event["name"], + kwargs=json.loads(event["arguments"]), ) async def call_function(self, call_id: str, name: str, kwargs: dict[str, Any]) -> None: @@ -54,33 +43,37 @@ async def call_function(self, call_id: str, name: str, kwargs: dict[str, Any]) - kwargs (Any[str, Any]): The arguments to pass to the function. """ - if name in self._agent.realtime_functions: - _, func = self._agent.realtime_functions[name] + if name in self.agent._registred_realtime_functions: + _, func = self.agent._registred_realtime_functions[name] func = func if asyncio.iscoroutinefunction(func) else asyncify(func) try: result = await func(**kwargs) except Exception: result = "Function call failed" - logger.warning(f"Function call failed: {name}") + self.logger.info(f"Function call failed: {name=}, {kwargs=}", stack_info=True) if isinstance(result, BaseModel): result = result.model_dump_json() elif not isinstance(result, str): - result = json.dumps(result) + try: + result = json.dumps(result) + except Exception: + result = str(result) - await self.client.function_result(call_id, result) - - async def run(self) -> None: - """Run the observer. - - Initialize the session with the OpenAI Realtime API. 
- """ - await self.initialize_session() + await self.realtime_client.send_function_result(call_id, result) async def initialize_session(self) -> None: """Add registered tools to OpenAI with a session update.""" session_update = { - "tools": [schema for schema, _ in self._agent.realtime_functions.values()], + "tools": [schema for schema, _ in self.agent._registred_realtime_functions.values()], "tool_choice": "auto", } - await self.client.session_update(session_update) + await self.realtime_client.session_update(session_update) + + async def run_loop(self) -> None: + """Run the observer loop.""" + pass + + +if TYPE_CHECKING: + function_observer: RealtimeObserver = FunctionObserver() diff --git a/autogen/agentchat/realtime_agent/oai_realtime_client.py b/autogen/agentchat/realtime_agent/oai_realtime_client.py new file mode 100644 index 0000000000..d074c20eb8 --- /dev/null +++ b/autogen/agentchat/realtime_agent/oai_realtime_client.py @@ -0,0 +1,175 @@ +# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai +# +# SPDX-License-Identifier: Apache-2.0 + +from contextlib import asynccontextmanager +from logging import Logger, getLogger +from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional + +from asyncer import TaskGroup, create_task_group +from openai import DEFAULT_MAX_RETRIES, NOT_GIVEN, AsyncOpenAI +from openai.resources.beta.realtime.realtime import AsyncRealtimeConnection + +from .realtime_client import Role + +if TYPE_CHECKING: + from .realtime_client import RealtimeClientProtocol + +__all__ = ["OpenAIRealtimeClient", "Role"] + +global_logger = getLogger(__name__) + + +class OpenAIRealtimeClient: + """(Experimental) Client for OpenAI Realtime API.""" + + def __init__( + self, + *, + llm_config: dict[str, Any], + voice: str, + system_message: str, + logger: Optional[Logger] = None, + ) -> None: + """(Experimental) Client for OpenAI Realtime API. + + Args: + llm_config (dict[str, Any]): The config for the client. 
+ """ + self._llm_config = llm_config + self._voice = voice + self._system_message = system_message + self._logger = logger + + self._connection: Optional[AsyncRealtimeConnection] = None + + config = llm_config["config_list"][0] + self._model: str = config["model"] + self._temperature: float = llm_config.get("temperature", 0.8) # type: ignore[union-attr] + + self._client = AsyncOpenAI( + api_key=config.get("api_key", None), + organization=config.get("organization", None), + project=config.get("project", None), + base_url=config.get("base_url", None), + websocket_base_url=config.get("websocket_base_url", None), + timeout=config.get("timeout", NOT_GIVEN), + max_retries=config.get("max_retries", DEFAULT_MAX_RETRIES), + default_headers=config.get("default_headers", None), + default_query=config.get("default_query", None), + ) + + @property + def logger(self) -> Logger: + """Get the logger for the OpenAI Realtime API.""" + return self._logger or global_logger + + @property + def connection(self) -> AsyncRealtimeConnection: + """Get the OpenAI WebSocket connection.""" + if self._connection is None: + raise RuntimeError("OpenAI WebSocket is not initialized") + return self._connection + + async def send_function_result(self, call_id: str, result: str) -> None: + """Send the result of a function call to the OpenAI Realtime API. + + Args: + call_id (str): The ID of the function call. + result (str): The result of the function call. + """ + await self.connection.conversation.item.create( + item={ + "type": "function_call_output", + "call_id": call_id, + "output": result, + }, + ) + + await self.connection.response.create() + + async def send_text(self, *, role: Role, text: str) -> None: + """Send a text message to the OpenAI Realtime API. + + Args: + role (str): The role of the message. + text (str): The text of the message. 
+ """ + await self.connection.response.cancel() + await self.connection.conversation.item.create( + item={"type": "message", "role": role, "content": [{"type": "input_text", "text": text}]} + ) + await self.connection.response.create() + + async def send_audio(self, audio: str) -> None: + """Send audio to the OpenAI Realtime API. + + Args: + audio (str): The audio to send. + """ + await self.connection.input_audio_buffer.append(audio=audio) + + async def truncate_audio(self, audio_end_ms: int, content_index: int, item_id: str) -> None: + """Truncate audio in the OpenAI Realtime API. + + Args: + audio_end_ms (int): The end of the audio to truncate. + content_index (int): The index of the content to truncate. + item_id (str): The ID of the item to truncate. + """ + await self.connection.conversation.item.truncate( + audio_end_ms=audio_end_ms, content_index=content_index, item_id=item_id + ) + + async def _initialize_session(self) -> None: + """Control initial session with OpenAI.""" + session_update = { + "turn_detection": {"type": "server_vad"}, + "voice": self._voice, + "instructions": self._system_message, + "modalities": ["audio", "text"], + "temperature": self._temperature, + } + await self.session_update(session_options=session_update) + + async def session_update(self, session_options: dict[str, Any]) -> None: + """Send a session update to the OpenAI Realtime API. + + Args: + session_options (dict[str, Any]): The session options to update. 
+ """ + logger = self.logger + logger.info(f"Sending session update: {session_options}") + await self.connection.session.update(session=session_options) # type: ignore[arg-type] + logger.info("Sending session update finished") + + @asynccontextmanager + async def connect(self) -> AsyncGenerator[None, None]: + """Connect to the OpenAI Realtime API.""" + try: + async with self._client.beta.realtime.connect( + model=self._model, + ) as self._connection: + await self._initialize_session() + yield + finally: + self._connection = None + + async def read_events(self) -> AsyncGenerator[dict[str, Any], None]: + """Read messages from the OpenAI Realtime API.""" + if self._connection is None: + raise RuntimeError("Client is not connected, call connect() first.") + + try: + async for event in self._connection: + yield event.model_dump() + + finally: + self._connection = None + + +# needed for mypy to check if OpenAIRealtimeClient implements RealtimeClientProtocol +if TYPE_CHECKING: + _client: RealtimeClientProtocol = OpenAIRealtimeClient( + llm_config={}, voice="alloy", system_message="You are a helpful AI voice assistant." + ) diff --git a/autogen/agentchat/realtime_agent/realtime_agent.py b/autogen/agentchat/realtime_agent/realtime_agent.py index b4456715bb..c6150d5b6b 100644 --- a/autogen/agentchat/realtime_agent/realtime_agent.py +++ b/autogen/agentchat/realtime_agent/realtime_agent.py @@ -1,33 +1,25 @@ # Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai # # SPDX-License-Identifier: Apache-2.0 -# -# Portions derived from https://github.com/microsoft/autogen are under the MIT License. 
-# SPDX-License-Identifier: MIT -import asyncio -import json -import logging -from abc import ABC, abstractmethod -from collections.abc import Generator -from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, TypeVar, Union +from logging import Logger, getLogger +from typing import Any, Callable, Literal, Optional, TypeVar, Union import anyio -import websockets -from asyncer import TaskGroup, asyncify, create_task_group, syncify +from asyncer import create_task_group, syncify -from autogen import ON_CONDITION, AfterWorkOption, SwarmAgent, initiate_swarm_chat -from autogen.agentchat.agent import Agent, LLMAgent +from autogen import SwarmAgent +from autogen.agentchat.agent import Agent from autogen.agentchat.conversable_agent import ConversableAgent from autogen.function_utils import get_function_schema -from .client import OpenAIRealtimeClient from .function_observer import FunctionObserver +from .oai_realtime_client import OpenAIRealtimeClient, Role from .realtime_observer import RealtimeObserver F = TypeVar("F", bound=Callable[..., Any]) -logger = logging.getLogger(__name__) +global_logger = getLogger(__name__) SWARM_SYSTEM_MESSAGE = ( "You are a helpful voice assistant. Your task is to listen to user and to coordinate the tasks based on his/her inputs." @@ -35,7 +27,7 @@ "You can communicate and will communicate using audio output only." ) -QUESTION_ROLE = "user" +QUESTION_ROLE: Role = "user" QUESTION_MESSAGE = ( "I have a question/information for myself. DO NOT ANSWER YOURSELF, GET THE ANSWER FROM ME. 
" "repeat the question to me **WITH AUDIO OUTPUT** and then call 'answer_task_question' AFTER YOU GET THE ANSWER FROM ME\n\n" @@ -52,23 +44,19 @@ def __init__( *, name: str, audio_adapter: RealtimeObserver, - system_message: Optional[Union[str, list[str]]] = "You are a helpful AI Assistant.", - llm_config: Optional[Union[dict[str, Any], Literal[False]]] = None, + system_message: str = "You are a helpful AI Assistant.", + llm_config: dict[str, Any], voice: str = "alloy", + logger: Optional[Logger] = None, ): """(Experimental) Agent for interacting with the Realtime Clients. Args: - name: str - the name of the agent - audio_adapter: RealtimeObserver - adapter for streaming the audio from the client - system_message: str or list - the system message for the client - llm_config: dict or False - the config for the LLM - voice: str - the voice to be used for the agent + name (str): The name of the agent. + audio_adapter (RealtimeObserver): The audio adapter for the agent. + system_message (str): The system message for the agent. + llm_config (dict[str, Any], bool): The config for the agent. + voice (str): The voice for the agent. 
""" super().__init__( name=name, @@ -77,18 +65,28 @@ def __init__( human_input_mode="ALWAYS", function_map=None, code_execution_config=False, + # no LLM config is passed down to the ConversableAgent + llm_config=False, default_auto_reply="", description=None, chat_messages=None, silent=None, context_variables=None, ) - self.llm_config = llm_config # type: ignore[assignment] - self._client = OpenAIRealtimeClient(self, audio_adapter, FunctionObserver(self)) - self.voice = voice - self.realtime_functions: dict[str, tuple[dict[str, Any], Callable[..., Any]]] = {} + self._logger = logger + self._function_observer = FunctionObserver(logger=logger) + self._audio_adapter = audio_adapter + self._realtime_client = OpenAIRealtimeClient( + llm_config=llm_config, voice=voice, system_message=system_message, logger=logger + ) + self._voice = voice + + self._observers: list[RealtimeObserver] = [self._function_observer, self._audio_adapter] + + self._registred_realtime_functions: dict[str, tuple[dict[str, Any], Callable[..., Any]]] = {} - self._oai_system_message = [{"content": system_message, "role": "system"}] # todo still needed? + # is this all Swarm related? + self._oai_system_message = [{"content": system_message, "role": "system"}] # todo still needed? see below self.register_reply( [Agent, None], RealtimeAgent.check_termination_and_human_reply, remove_other_reply_funcs=True ) @@ -99,6 +97,24 @@ def __init__( self._initial_agent: Optional[SwarmAgent] = None self._agents: Optional[list[SwarmAgent]] = None + @property + def logger(self) -> Logger: + """Get the logger for the agent.""" + return self._logger or global_logger + + @property + def realtime_client(self) -> OpenAIRealtimeClient: + """Get the OpenAI Realtime Client.""" + return self._realtime_client + + def register_observer(self, observer: RealtimeObserver) -> None: + """Register an observer with the Realtime Agent. + + Args: + observer (RealtimeObserver): The observer to register. 
+ """ + self._observers.append(observer) + def register_swarm( self, *, @@ -109,13 +125,11 @@ def register_swarm( """Register a swarm of agents with the Realtime Agent. Args: - initial_agent: SwarmAgent - the initial agent in the swarm - agents: list of SwarmAgent - the agents in the swarm - system_message: str - the system message for the client + initial_agent (SwarmAgent): The initial agent. + agents (list[SwarmAgent]): The agents in the swarm. + system_message (str): The system message for the agent. """ + logger = self.logger if not system_message: if self.system_message != "You are a helpful AI Assistant.": logger.warning( @@ -135,7 +149,24 @@ def register_swarm( async def run(self) -> None: """Run the agent.""" - await self._client.run() + # everything is run in the same task group to enable easy cancellation using self._tg.cancel_scope.cancel() + async with create_task_group() as self._tg: + + # connect with the client first (establishes a connection and initializes a session) + async with self._realtime_client.connect(): + + # start the observers + for observer in self._observers: + self._tg.soonify(observer.run)(self) + + # wait for the observers to be ready + for observer in self._observers: + await observer.wait_for_ready() + + # iterate over the events + async for event in self.realtime_client.read_events(): + for observer in self._observers: + await observer.on_event(event) def register_realtime_function( self, @@ -163,7 +194,7 @@ def _decorator(func: F, name: Optional[str] = name) -> F: schema = get_function_schema(func, name=name, description=description)["function"] schema["type"] = "function" - self.realtime_functions[name] = (schema, func) + self._registred_realtime_functions[name] = (schema, func) return func @@ -195,7 +226,7 @@ async def ask_question(self, question: str, question_timeout: int) -> None: """ self.reset_answer() - await self._client.send_text(role=QUESTION_ROLE, text=question) + await 
self._realtime_client.send_text(role=QUESTION_ROLE, text=question) async def _check_event_set(timeout: int = question_timeout) -> bool: for _ in range(timeout): @@ -205,7 +236,7 @@ async def _check_event_set(timeout: int = question_timeout) -> bool: return False while not await _check_event_set(): - await self._client.send_text(role=QUESTION_ROLE, text=question) + await self._realtime_client.send_text(role=QUESTION_ROLE, text=question) def check_termination_and_human_reply( self, diff --git a/autogen/agentchat/realtime_agent/realtime_client.py b/autogen/agentchat/realtime_agent/realtime_client.py new file mode 100644 index 0000000000..2195a78f83 --- /dev/null +++ b/autogen/agentchat/realtime_agent/realtime_client.py @@ -0,0 +1,63 @@ +# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, AsyncContextManager, AsyncGenerator, Literal, Protocol, runtime_checkable + +__all__ = ["RealtimeClientProtocol", "Role"] + +# define role literal type for typing +Role = Literal["user", "assistant", "system"] + + +@runtime_checkable +class RealtimeClientProtocol(Protocol): + async def send_function_result(self, call_id: str, result: str) -> None: + """Send the result of a function call to a Realtime API. + + Args: + call_id (str): The ID of the function call. + result (str): The result of the function call. + """ + ... + + async def send_text(self, *, role: Role, text: str) -> None: + """Send a text message to a Realtime API. + + Args: + role (str): The role of the message. + text (str): The text of the message. + """ + ... + + async def send_audio(self, audio: str) -> None: + """Send audio to a Realtime API. + + Args: + audio (str): The audio to send. + """ + ... + + async def truncate_audio(self, audio_end_ms: int, content_index: int, item_id: str) -> None: + """Truncate audio in a Realtime API. + + Args: + audio_end_ms (int): The end of the audio to truncate. 
+ content_index (int): The index of the content to truncate. + item_id (str): The ID of the item to truncate. + """ + ... + + async def session_update(self, session_options: dict[str, Any]) -> None: + """Send a session update to a Realtime API. + + Args: + session_options (dict[str, Any]): The session options to update. + """ + ... + + def connect(self) -> AsyncContextManager[None]: ... + + def read_events(self) -> AsyncGenerator[dict[str, Any], None]: + """Read messages from a Realtime API.""" + ... diff --git a/autogen/agentchat/realtime_agent/realtime_observer.py b/autogen/agentchat/realtime_agent/realtime_observer.py index 6061efb230..8b4b9c5eb3 100644 --- a/autogen/agentchat/realtime_agent/realtime_observer.py +++ b/autogen/agentchat/realtime_agent/realtime_observer.py @@ -1,41 +1,93 @@ # Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai # # SPDX-License-Identifier: Apache-2.0 -# -# Portions derived from https://github.com/microsoft/autogen are under the MIT License. -# SPDX-License-Identifier: MIT from abc import ABC, abstractmethod +from logging import Logger, getLogger from typing import TYPE_CHECKING, Any, Optional +from anyio import Event + +from .realtime_client import RealtimeClientProtocol + if TYPE_CHECKING: - from .client import OpenAIRealtimeClient + from .realtime_agent import RealtimeAgent + +__all__ = ["RealtimeObserver"] + +global_logger = getLogger(__name__) class RealtimeObserver(ABC): """Observer for the OpenAI Realtime API.""" - def __init__(self) -> None: - self._client: Optional["OpenAIRealtimeClient"] = None + def __init__(self, *, logger: Optional[Logger] = None) -> None: + """Observer for the OpenAI Realtime API. + + Args: + logger (Logger): The logger for the observer. 
+ """ + self._ready_event = Event() + self._agent: Optional["RealtimeAgent"] = None + self._logger = logger + + @property + def logger(self) -> Logger: + return self._logger or global_logger + + @property + def agent(self) -> "RealtimeAgent": + if self._agent is None: + raise RuntimeError("Agent has not been set.") + return self._agent @property - def client(self) -> "OpenAIRealtimeClient": - """Get the client associated with the observer.""" - if self._client is None: - raise ValueError("Observer client is not registered.") + def realtime_client(self) -> RealtimeClientProtocol: + if self._agent is None: + raise RuntimeError("Agent has not been set.") + if self._agent.realtime_client is None: + raise RuntimeError("Realtime client has not been set.") - return self._client + return self._agent.realtime_client - def register_client(self, client: "OpenAIRealtimeClient") -> None: - """Register a client with the observer.""" - self._client = client + async def run(self, agent: "RealtimeAgent") -> None: + """Run the observer with the agent. + + When implementing, be sure to call `self._ready_event.set()` when the observer is ready to process events. + + Args: + agent (RealtimeAgent): The realtime agent attached to the observer. + """ + self._agent = agent + await self.initialize_session() + self._ready_event.set() + + await self.run_loop() @abstractmethod - async def run(self) -> None: - """Run the observer.""" + async def run_loop(self) -> None: + """Run the loop if needed. + + This method is called after the observer is ready to process events. + Events will be processed by the on_event method, this is just a hook for additional processing. + Use initialize_session to set up the session. + """ ... @abstractmethod - async def update(self, message: dict[str, Any]) -> None: - """Update the observer with a message from the OpenAI Realtime API.""" + async def initialize_session(self) -> None: + """Initialize the session for the observer.""" + ... 
+ + async def wait_for_ready(self) -> None: + """Get the event that is set when the observer is ready.""" + await self._ready_event.wait() + + @abstractmethod + async def on_event(self, event: dict[str, Any]) -> None: + """Handle an event from the OpenAI Realtime API. + + Args: + event (RealtimeServerEvent): The event from the OpenAI Realtime API. + """ ... diff --git a/autogen/agentchat/realtime_agent/twilio_observer.py b/autogen/agentchat/realtime_agent/twilio_audio_adapter.py similarity index 62% rename from autogen/agentchat/realtime_agent/twilio_observer.py rename to autogen/agentchat/realtime_agent/twilio_audio_adapter.py index 7dff6d4bdd..c8b275c5e7 100644 --- a/autogen/agentchat/realtime_agent/twilio_observer.py +++ b/autogen/agentchat/realtime_agent/twilio_audio_adapter.py @@ -1,20 +1,21 @@ # Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai # # SPDX-License-Identifier: Apache-2.0 -# -# Portions derived from https://github.com/microsoft/autogen are under the MIT License. -# SPDX-License-Identifier: MIT import base64 import json -import logging +from logging import Logger, getLogger from typing import TYPE_CHECKING, Any, Optional +from openai.types.beta.realtime.realtime_server_event import RealtimeServerEvent + from .realtime_observer import RealtimeObserver if TYPE_CHECKING: from fastapi.websockets import WebSocket + from .realtime_agent import RealtimeAgent + LOG_EVENT_TYPES = [ "error", "response.content.done", @@ -27,36 +28,36 @@ ] SHOW_TIMING_MATH = False -logger = logging.getLogger(__name__) - class TwilioAudioAdapter(RealtimeObserver): """Adapter for streaming audio from Twilio to OpenAI Realtime API and vice versa.""" - def __init__(self, websocket: "WebSocket"): + def __init__(self, websocket: "WebSocket", *, logger: Optional[Logger] = None): """Adapter for streaming audio from Twilio to OpenAI Realtime API and vice versa. 
Args: websocket: WebSocket the websocket connection to the Twilio service """ - super().__init__() + super().__init__(logger=logger) self.websocket = websocket # Connection specific state self.stream_sid = None self.latest_media_timestamp = 0 - self.last_assistant_item = None + self.last_assistant_item: Optional[str] = None self.mark_queue: list[str] = [] self.response_start_timestamp_twilio: Optional[int] = None - async def update(self, response: dict[str, Any]) -> None: + async def on_event(self, event: dict[str, Any]) -> None: """Receive events from the OpenAI Realtime API, send audio back to Twilio.""" - if response["type"] in LOG_EVENT_TYPES: - logger.info(f"Received event: {response['type']}", response) + logger = self.logger - if response.get("type") == "response.audio.delta" and "delta" in response: - audio_payload = base64.b64encode(base64.b64decode(response["delta"])).decode("utf-8") + if event["type"] in LOG_EVENT_TYPES: + logger.info(f"Received event: {event['type']}", event) + + if event["type"] == "response.audio.delta": + audio_payload = base64.b64encode(base64.b64decode(event["delta"])).decode("utf-8") audio_delta = {"event": "media", "streamSid": self.stream_sid, "media": {"payload": audio_payload}} await self.websocket.send_json(audio_delta) @@ -66,13 +67,13 @@ async def update(self, response: dict[str, Any]) -> None: logger.info(f"Setting start timestamp for new response: {self.response_start_timestamp_twilio}ms") # Update last_assistant_item safely - if response.get("item_id"): - self.last_assistant_item = response["item_id"] + if event["item_id"]: + self.last_assistant_item = event["item_id"] await self.send_mark() # Trigger an interruption. Your use case might work better using `input_audio_buffer.speech_stopped`, or combining the two. 
- if response.get("type") == "input_audio_buffer.speech_started": + if event["type"] == "input_audio_buffer.speech_started": logger.info("Speech started detected.") if self.last_assistant_item: logger.info(f"Interrupting response with id: {self.last_assistant_item}") @@ -80,6 +81,8 @@ async def update(self, response: dict[str, Any]) -> None: async def handle_speech_started_event(self) -> None: """Handle interruption when the caller's speech starts.""" + logger = self.logger + logger.info("Handling speech started event.") if self.mark_queue and self.response_start_timestamp_twilio is not None: elapsed_time = self.latest_media_timestamp - self.response_start_timestamp_twilio @@ -92,13 +95,11 @@ async def handle_speech_started_event(self) -> None: if SHOW_TIMING_MATH: logger.info(f"Truncating item with ID: {self.last_assistant_item}, Truncated at: {elapsed_time}ms") - truncate_event = { - "type": "conversation.item.truncate", - "item_id": self.last_assistant_item, - "content_index": 0, - "audio_end_ms": elapsed_time, - } - await self._client._openai_ws.send(json.dumps(truncate_event)) + await self.realtime_client.truncate_audio( + audio_end_ms=elapsed_time, + content_index=0, + item_id=self.last_assistant_item, + ) await self.websocket.send_json({"event": "clear", "streamSid": self.stream_sid}) @@ -113,29 +114,27 @@ async def send_mark(self) -> None: await self.websocket.send_json(mark_event) self.mark_queue.append("responsePart") - async def run(self) -> None: - """Run the adapter. - - Start reading messages from the Twilio websocket and send audio to OpenAI. 
- """ - openai_ws = self.client.openai_ws - await self.initialize_session() + async def run_loop(self) -> None: + """Run the adapter loop.""" + logger = self.logger async for message in self.websocket.iter_text(): - data = json.loads(message) - if data["event"] == "media": - self.latest_media_timestamp = int(data["media"]["timestamp"]) - audio_append = {"type": "input_audio_buffer.append", "audio": data["media"]["payload"]} - await openai_ws.send(json.dumps(audio_append)) - elif data["event"] == "start": - self.stream_sid = data["start"]["streamSid"] - logger.info(f"Incoming stream has started {self.stream_sid}") - self.response_start_timestamp_twilio = None - self.latest_media_timestamp = 0 - self.last_assistant_item = None - elif data["event"] == "mark": - if self.mark_queue: - self.mark_queue.pop(0) + try: + data = json.loads(message) + if data["event"] == "media": + self.latest_media_timestamp = int(data["media"]["timestamp"]) + await self.realtime_client.send_audio(audio=data["media"]["payload"]) + elif data["event"] == "start": + self.stream_sid = data["start"]["streamSid"] + logger.info(f"Incoming stream has started {self.stream_sid}") + self.response_start_timestamp_twilio = None + self.latest_media_timestamp = 0 + self.last_assistant_item = None + elif data["event"] == "mark": + if self.mark_queue: + self.mark_queue.pop(0) + except Exception as e: + logger.warning(f"Error processing Twilio message: {e}", stack_info=True) async def initialize_session(self) -> None: """Control initial session with OpenAI.""" @@ -143,4 +142,10 @@ async def initialize_session(self) -> None: "input_audio_format": "g711_ulaw", "output_audio_format": "g711_ulaw", } - await self.client.session_update(session_update) + await self.realtime_client.session_update(session_update) + + +if TYPE_CHECKING: + + def twilio_audio_adapter(websocket: WebSocket) -> RealtimeObserver: + return TwilioAudioAdapter(websocket) diff --git a/autogen/agentchat/realtime_agent/websocket_audio_adapter.py 
b/autogen/agentchat/realtime_agent/websocket_audio_adapter.py new file mode 100644 index 0000000000..6a7e8e4656 --- /dev/null +++ b/autogen/agentchat/realtime_agent/websocket_audio_adapter.py @@ -0,0 +1,142 @@ +# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai +# +# SPDX-License-Identifier: Apache-2.0 + +import base64 +import json +from logging import Logger, getLogger +from typing import TYPE_CHECKING, Any, Optional + +from openai.types.beta.realtime.realtime_server_event import RealtimeServerEvent + +if TYPE_CHECKING: + from fastapi.websockets import WebSocket + + from .realtime_agent import RealtimeAgent + +from .realtime_observer import RealtimeObserver + +LOG_EVENT_TYPES = [ + "error", + "response.content.done", + "rate_limits.updated", + "response.done", + "input_audio_buffer.committed", + "input_audio_buffer.speech_stopped", + "input_audio_buffer.speech_started", + "session.created", +] +SHOW_TIMING_MATH = False + + +class WebSocketAudioAdapter(RealtimeObserver): + def __init__(self, websocket: "WebSocket", *, logger: Optional[Logger] = None) -> None: + """Adapter for streaming audio between a websocket client and the OpenAI Realtime API. + + Args: + websocket (WebSocket): The websocket connection. + logger (Logger): The logger for the observer. 
+ """ + super().__init__(logger=logger) + self.websocket = websocket + + # Connection specific state + self.stream_sid = None + self.latest_media_timestamp = 0 + self.last_assistant_item: Optional[str] = None + self.mark_queue: list[str] = [] + self.response_start_timestamp_socket: Optional[int] = None + + async def on_event(self, event: dict[str, Any]) -> None: + """Receive events from the OpenAI Realtime API, send audio back to websocket.""" + logger = self.logger + if event["type"] in LOG_EVENT_TYPES: + logger.info(f"Received event: {event['type']}", event) + + if event["type"] == "response.audio.delta": + audio_payload = base64.b64encode(base64.b64decode(event["delta"])).decode("utf-8") + audio_delta = {"event": "media", "streamSid": self.stream_sid, "media": {"payload": audio_payload}} + await self.websocket.send_json(audio_delta) + + if self.response_start_timestamp_socket is None: + self.response_start_timestamp_socket = self.latest_media_timestamp + if SHOW_TIMING_MATH: + logger.info(f"Setting start timestamp for new response: {self.response_start_timestamp_socket}ms") + + # Update last_assistant_item safely + if event["item_id"]: + self.last_assistant_item = event["item_id"] + + await self.send_mark() + + # Trigger an interruption. Your use case might work better using `input_audio_buffer.speech_stopped`, or combining the two. 
+ if event["type"] == "input_audio_buffer.speech_started": + logger.info("Speech started detected.") + if self.last_assistant_item: + logger.info(f"Interrupting response with id: {self.last_assistant_item}") + await self.handle_speech_started_event() + + async def handle_speech_started_event(self) -> None: + """Handle interruption when the caller's speech starts.""" + logger = self.logger + logger.info("Handling speech started event.") + if self.mark_queue and self.response_start_timestamp_socket is not None: + elapsed_time = self.latest_media_timestamp - self.response_start_timestamp_socket + if SHOW_TIMING_MATH: + logger.info( + f"Calculating elapsed time for truncation: {self.latest_media_timestamp} - {self.response_start_timestamp_socket} = {elapsed_time}ms" + ) + + if self.last_assistant_item: + if SHOW_TIMING_MATH: + logger.info(f"Truncating item with ID: {self.last_assistant_item}, Truncated at: {elapsed_time}ms") + + await self.realtime_client.truncate_audio( + audio_end_ms=elapsed_time, + content_index=0, + item_id=self.last_assistant_item, + ) + + await self.websocket.send_json({"event": "clear", "streamSid": self.stream_sid}) + + self.mark_queue.clear() + self.last_assistant_item = None + self.response_start_timestamp_socket = None + + async def send_mark(self) -> None: + if self.stream_sid: + mark_event = {"event": "mark", "streamSid": self.stream_sid, "mark": {"name": "responsePart"}} + await self.websocket.send_json(mark_event) + self.mark_queue.append("responsePart") + + async def initialize_session(self) -> None: + """Control initial session with OpenAI.""" + session_update = {"input_audio_format": "pcm16", "output_audio_format": "pcm16"} + await self.realtime_client.session_update(session_update) + + async def run_loop(self) -> None: + """Reads data from websocket and sends it to the RealtimeClient.""" + logger = self.logger + async for message in self.websocket.iter_text(): + try: + data = json.loads(message) + if data["event"] == "media": + 
self.latest_media_timestamp = int(data["media"]["timestamp"]) + await self.realtime_client.send_audio(audio=data["media"]["payload"]) + elif data["event"] == "start": + self.stream_sid = data["start"]["streamSid"] + logger.info(f"Incoming stream has started {self.stream_sid}") + self.response_start_timestamp_socket = None + self.latest_media_timestamp = 0 + self.last_assistant_item = None + elif data["event"] == "mark": + if self.mark_queue: + self.mark_queue.pop(0) + except Exception as e: + logger.warning(f"Failed to process message: {e}", stack_info=True) + + +if TYPE_CHECKING: + + def websocket_audio_adapter(websocket: WebSocket) -> RealtimeObserver: + return WebSocketAudioAdapter(websocket) diff --git a/autogen/agentchat/realtime_agent/websocket_observer.py b/autogen/agentchat/realtime_agent/websocket_observer.py deleted file mode 100644 index dd0b67a87d..0000000000 --- a/autogen/agentchat/realtime_agent/websocket_observer.py +++ /dev/null @@ -1,127 +0,0 @@ -# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai -# -# SPDX-License-Identifier: Apache-2.0 -# -# Portions derived from https://github.com/microsoft/autogen are under the MIT License. 
-# SPDX-License-Identifier: MIT - -import base64 -import json -from typing import TYPE_CHECKING, Any, Optional - -if TYPE_CHECKING: - from fastapi.websockets import WebSocket - -from .realtime_observer import RealtimeObserver - -LOG_EVENT_TYPES = [ - "error", - "response.content.done", - "rate_limits.updated", - "response.done", - "input_audio_buffer.committed", - "input_audio_buffer.speech_stopped", - "input_audio_buffer.speech_started", - "session.created", -] -SHOW_TIMING_MATH = False - - -class WebsocketAudioAdapter(RealtimeObserver): - def __init__(self, websocket: "WebSocket"): - super().__init__() - self.websocket = websocket - - # Connection specific state - self.stream_sid = None - self.latest_media_timestamp = 0 - self.last_assistant_item = None - self.mark_queue: list[str] = [] - self.response_start_timestamp_socket: Optional[int] = None - - async def update(self, response: dict[str, Any]) -> None: - """Receive events from the OpenAI Realtime API, send audio back to websocket.""" - if response["type"] in LOG_EVENT_TYPES: - print(f"Received event: {response['type']}", response) - - if response.get("type") == "response.audio.delta" and "delta" in response: - audio_payload = base64.b64encode(base64.b64decode(response["delta"])).decode("utf-8") - audio_delta = {"event": "media", "streamSid": self.stream_sid, "media": {"payload": audio_payload}} - await self.websocket.send_json(audio_delta) - - if self.response_start_timestamp_socket is None: - self.response_start_timestamp_socket = self.latest_media_timestamp - if SHOW_TIMING_MATH: - print(f"Setting start timestamp for new response: {self.response_start_timestamp_socket}ms") - - # Update last_assistant_item safely - if response.get("item_id"): - self.last_assistant_item = response["item_id"] - - await self.send_mark() - - # Trigger an interruption. Your use case might work better using `input_audio_buffer.speech_stopped`, or combining the two. 
- if response.get("type") == "input_audio_buffer.speech_started": - print("Speech started detected.") - if self.last_assistant_item: - print(f"Interrupting response with id: {self.last_assistant_item}") - await self.handle_speech_started_event() - - async def handle_speech_started_event(self) -> None: - """Handle interruption when the caller's speech starts.""" - print("Handling speech started event.") - if self.mark_queue and self.response_start_timestamp_socket is not None: - elapsed_time = self.latest_media_timestamp - self.response_start_timestamp_socket - if SHOW_TIMING_MATH: - print( - f"Calculating elapsed time for truncation: {self.latest_media_timestamp} - {self.response_start_timestamp_socket} = {elapsed_time}ms" - ) - - if self.last_assistant_item: - if SHOW_TIMING_MATH: - print(f"Truncating item with ID: {self.last_assistant_item}, Truncated at: {elapsed_time}ms") - - truncate_event = { - "type": "conversation.item.truncate", - "item_id": self.last_assistant_item, - "content_index": 0, - "audio_end_ms": elapsed_time, - } - await self._client._openai_ws.send(json.dumps(truncate_event)) - - await self.websocket.send_json({"event": "clear", "streamSid": self.stream_sid}) - - self.mark_queue.clear() - self.last_assistant_item = None - self.response_start_timestamp_socket = None - - async def send_mark(self) -> None: - if self.stream_sid: - mark_event = {"event": "mark", "streamSid": self.stream_sid, "mark": {"name": "responsePart"}} - await self.websocket.send_json(mark_event) - self.mark_queue.append("responsePart") - - async def run(self) -> None: - openai_ws = self.client.openai_ws - await self.initialize_session() - - async for message in self.websocket.iter_text(): - data = json.loads(message) - if data["event"] == "media": - self.latest_media_timestamp = int(data["media"]["timestamp"]) - audio_append = {"type": "input_audio_buffer.append", "audio": data["media"]["payload"]} - await openai_ws.send(json.dumps(audio_append)) - elif data["event"] == 
"start": - self.stream_sid = data["start"]["streamSid"] - print(f"Incoming stream has started {self.stream_sid}") - self.response_start_timestamp_socket = None - self.latest_media_timestamp = 0 - self.last_assistant_item = None - elif data["event"] == "mark": - if self.mark_queue: - self.mark_queue.pop(0) - - async def initialize_session(self) -> None: - """Control initial session with OpenAI.""" - session_update = {"input_audio_format": "pcm16", "output_audio_format": "pcm16"} # g711_ulaw # "g711_ulaw", - await self.client.session_update(session_update) diff --git a/notebook/agentchat_nested_chats_chess.ipynb b/notebook/agentchat_nested_chats_chess.ipynb index 69a604c53b..1975509f55 100644 --- a/notebook/agentchat_nested_chats_chess.ipynb +++ b/notebook/agentchat_nested_chats_chess.ipynb @@ -289,7 +289,7 @@ "to make a move, before communicating with the other player agent.\n", "\n", "In the code below, in each nested chat, the board proxy agent starts\n", - "a conversation with the player agent using the message recieved from the other\n", + "a conversation with the player agent using the message received from the other\n", "player agent (e.g., \"Your move\"). 
The two agents continue the conversation\n", "until a legal move is made using the `make_move` tool.\n", "The last message in the nested chat is a message from the player agent about\n", diff --git a/notebook/agentchat_realtime_swarm.ipynb b/notebook/agentchat_realtime_swarm.ipynb index 088ffc3caa..b7ade2c755 100644 --- a/notebook/agentchat_realtime_swarm.ipynb +++ b/notebook/agentchat_realtime_swarm.ipynb @@ -43,6 +43,15 @@ "````" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install \"ag2[twilio]\"" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -54,7 +63,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -80,14 +89,14 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "realtime_config_list = autogen.config_list_from_json(\n", " \"OAI_CONFIG_LIST\",\n", " filter_dict={\n", - " \"tags\": [\"realtime\"],\n", + " \"tags\": [\"gpt-4o-mini-realtime\"],\n", " },\n", ")\n", "\n", @@ -103,7 +112,7 @@ " {\n", " \"model\": \"gpt-4o-realtime-preview\",\n", " \"api_key\": \"sk-***********************...*\",\n", - " \"tags\": [\"gpt-4o-realtime\", \"realtime\"]\n", + " \"tags\": [\"gpt-4o-mini-realtime\", \"realtime\"]\n", " }\"\"\"\n", ")" ] @@ -119,7 +128,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -261,7 +270,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -321,7 +330,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -367,7 +376,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ diff --git a/notebook/agentchat_realtime_tool.ipynb 
b/notebook/agentchat_realtime_tool.ipynb deleted file mode 100644 index 77e7bea86c..0000000000 --- a/notebook/agentchat_realtime_tool.ipynb +++ /dev/null @@ -1,160 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "import asyncio\n", - "import os\n", - "import time\n", - "from typing import Annotated, Union\n", - "\n", - "import nest_asyncio\n", - "import uvicorn\n", - "from fastapi import FastAPI, Request, WebSocket\n", - "from fastapi.responses import HTMLResponse, JSONResponse\n", - "from pydantic import BaseModel\n", - "from twilio.twiml.voice_response import Connect, VoiceResponse\n", - "\n", - "from autogen.agentchat.realtime_agent import FunctionObserver, RealtimeAgent, TwilioAudioAdapter\n", - "\n", - "# from autogen.agentchat.realtime_agent.swarm_observer import SwarmObserver" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "# Configuration\n", - "OPENAI_API_KEY = os.getenv(\"OPENAI_API_KEY\")\n", - "PORT = int(os.getenv(\"PORT\", 5050))\n", - "\n", - "if not OPENAI_API_KEY:\n", - " raise ValueError(\"Missing the OpenAI API key. 
Please set it in the .env file.\")\n", - "\n", - "llm_config = {\n", - " \"timeout\": 600,\n", - " \"cache_seed\": 45, # change the seed for different trials\n", - " \"config_list\": [\n", - " {\n", - " \"model\": \"gpt-4o-realtime-preview-2024-10-01\",\n", - " \"api_key\": OPENAI_API_KEY,\n", - " }\n", - " ],\n", - " \"temperature\": 0.8,\n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "nest_asyncio.apply()" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO: Started server process [3628527]\n", - "INFO: Waiting for application startup.\n", - "INFO: Application startup complete.\n", - "INFO: Uvicorn running on http://0.0.0.0:5050 (Press CTRL+C to quit)\n", - "INFO: Shutting down\n", - "INFO: Waiting for application shutdown.\n", - "INFO: Application shutdown complete.\n", - "INFO: Finished server process [3628527]\n" - ] - } - ], - "source": [ - "app = FastAPI()\n", - "\n", - "\n", - "@app.get(\"/\", response_class=JSONResponse)\n", - "async def index_page():\n", - " return {\"message\": \"Twilio Media Stream Server is running!\"}\n", - "\n", - "\n", - "@app.api_route(\"/incoming-call\", methods=[\"GET\", \"POST\"])\n", - "async def handle_incoming_call(request: Request):\n", - " \"\"\"Handle incoming call and return TwiML response to connect to Media Stream.\"\"\"\n", - " response = VoiceResponse()\n", - " # punctuation to improve text-to-speech flow\n", - " response.say(\n", - " \"Please wait while we connect your call to the A. I. voice assistant, powered by Twilio and the Open-A.I. Realtime API\"\n", - " )\n", - " response.pause(length=1)\n", - " response.say(\"O.K. 
you can start talking!\")\n", - " host = request.url.hostname\n", - " connect = Connect()\n", - " connect.stream(url=f\"wss://{host}/media-stream\")\n", - " response.append(connect)\n", - " return HTMLResponse(content=str(response), media_type=\"application/xml\")\n", - "\n", - "\n", - "@app.websocket(\"/media-stream\")\n", - "async def handle_media_stream(websocket: WebSocket):\n", - " \"\"\"Handle WebSocket connections between Twilio and OpenAI.\"\"\"\n", - " await websocket.accept()\n", - "\n", - " audio_adapter = TwilioAudioAdapter(websocket)\n", - " openai_client = RealtimeAgent(\n", - " name=\"Weather Bot\",\n", - " system_message=\"Hello there! I am an AI voice assistant powered by Twilio and the OpenAI Realtime API. You can ask me for facts, jokes, or anything you can imagine. How can I help you?\",\n", - " llm_config=llm_config,\n", - " audio_adapter=audio_adapter,\n", - " )\n", - "\n", - " @openai_client.register_realtime_function(name=\"get_weather\", description=\"Get the current weather\")\n", - " def get_weather(location: Annotated[str, \"city\"]) -> str:\n", - " ...\n", - " return \"The weather is cloudy.\" if location == \"Seattle\" else \"The weather is sunny.\"\n", - "\n", - " await openai_client.run()\n", - "\n", - "\n", - "uvicorn.run(app, host=\"0.0.0.0\", port=PORT)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.13" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/notebook/agentchat_realtime_websocket.ipynb b/notebook/agentchat_realtime_websocket.ipynb index 934901f447..fe7019ad00 100644 --- 
a/notebook/agentchat_realtime_websocket.ipynb +++ b/notebook/agentchat_realtime_websocket.ipynb @@ -1,49 +1,137 @@ { "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# RealtimeAgent with local websocket connection\n", + "\n", + "\n", + "AG2 supports **RealtimeAgent**, a powerful agent type that connects seamlessly to OpenAI's [Realtime API](https://openai.com/index/introducing-the-realtime-api). In this example we will start a local RealtimeAgent and register a mock get_weather function that the agent will be able to call.\n", + "\n", + "**Note**: This notebook cannot be run in Google Colab because it depends on local JavaScript files and HTML templates. To execute the notebook successfully, run it locally within the cloned project so that the `notebooks/agentchat_realtime_websocket/static` and `notebooks/agentchat_realtime_websocket/templates` folders are available in the correct relative paths.\n", + "\n", + "````{=mdx}\n", + ":::info Requirements\n", + "Install `ag2`:\n", + "```bash\n", + "git clone https://github.com/ag2ai/ag2.git\n", + "cd ag2\n", + "```\n", + ":::\n", + "````\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "## Install AG2 and dependencies\n", + "\n", + "To use the realtime agent we will connect it to a local websocket through the browser.\n", + "\n", + "We have prepared a `WebsocketAudioAdapter` to enable you to connect your realtime agent to a websocket service.\n", + "\n", + "To be able to run this notebook, you will need to install ag2, fastapi and uvicorn.\n", + "````{=mdx}\n", + ":::info Requirements\n", + "Install `ag2`:\n", + "```bash\n", + "pip install \"ag2\" \"fastapi>=0.115.0,<1\" \"uvicorn>=0.30.6,<1\"\n", + "```\n", + "For more information, please refer to the [installation guide](/docs/installation/Installation).\n", + ":::\n", + "````" + ] + }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], + "source": [ + "!pip install \"ag2\" 
\"fastapi>=0.115.0,<1\" \"uvicorn>=0.30.6,<1\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Import the dependencies\n", + "\n", + "After installing the necessary requirements, we can import the necessary dependencies for the example" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], "source": [ "import os\n", + "from logging import getLogger\n", "from pathlib import Path\n", - "from typing import Annotated, Union\n", + "from typing import Annotated\n", "\n", - "import nest_asyncio\n", "import uvicorn\n", "from fastapi import FastAPI, Request, WebSocket\n", "from fastapi.responses import HTMLResponse, JSONResponse\n", "from fastapi.staticfiles import StaticFiles\n", "from fastapi.templating import Jinja2Templates\n", "\n", - "from autogen.agentchat.realtime_agent import FunctionObserver, RealtimeAgent, WebsocketAudioAdapter" + "import autogen\n", + "from autogen.agentchat.realtime_agent import RealtimeAgent, WebsocketAudioAdapter" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prepare your `llm_config` and `realtime_llm_config`\n", + "\n", + "The [`config_list_from_json`](https://docs.ag2.ai/docs/reference/oai/openai_utils#config-list-from-json) function loads a list of configurations from an environment variable or a json file." ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ - "# Configuration\n", - "OPENAI_API_KEY = os.getenv(\"OPENAI_API_KEY\")\n", - "PORT = int(os.getenv(\"PORT\", 5050))\n", - "\n", - "if not OPENAI_API_KEY:\n", - " raise ValueError(\"Missing the OpenAI API key. 
Please set it in the .env file.\")\n", + "realtime_config_list = autogen.config_list_from_json(\n", + " \"OAI_CONFIG_LIST\",\n", + " filter_dict={\n", + " \"tags\": [\"gpt-4o-mini-realtime\"],\n", + " },\n", + ")\n", "\n", - "llm_config = {\n", + "realtime_llm_config = {\n", " \"timeout\": 600,\n", - " \"cache_seed\": 45, # change the seed for different trials\n", - " \"config_list\": [\n", - " {\n", - " \"model\": \"gpt-4o-realtime-preview-2024-10-01\",\n", - " \"api_key\": OPENAI_API_KEY,\n", - " }\n", - " ],\n", + " \"config_list\": realtime_config_list,\n", " \"temperature\": 0.8,\n", - "}" + "}\n", + "\n", + "assert realtime_config_list, (\n", + " \"No LLM found for the given model, please add the following lines to the OAI_CONFIG_LIST file:\"\n", + " \"\"\"\n", + " {\n", + " \"model\": \"gpt-4o-realtime-preview\",\n", + " \"api_key\": \"sk-***********************...*\",\n", + " \"tags\": [\"gpt-4o-mini-realtime\", \"realtime\"]\n", + " }\"\"\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Before you start the server\n", + "\n", + "To run uvicorn server inside the notebook, you will need to use nest_asyncio. This is because Jupyter uses the asyncio event loop, and uvicorn uses its own event loop. nest_asyncio will allow uvicorn to run in Jupyter.\n", + "\n", + "Please install nest_asyncio by running the following cell." ] }, { @@ -52,17 +140,77 @@ "metadata": {}, "outputs": [], "source": [ + "!pip install nest_asyncio" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "import nest_asyncio\n", + "\n", "nest_asyncio.apply()" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Implementing and Running a Basic App\n", + "\n", + "Let us set up and execute a FastAPI application that integrates real-time agent interactions." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Define basic FastAPI app\n", + "\n", + "1. 
**Define Port**: Sets the `PORT` variable to `5050`, which will be used for the server.\n", + "2. **Initialize FastAPI App**: Creates a `FastAPI` instance named `app`, which serves as the main application.\n", + "3. **Define Root Endpoint**: Adds a `GET` endpoint at the root URL (`/`). When accessed, it returns a JSON response with the message `\"Websocket Audio Stream Server is running!\"`.\n", + "\n", + "This sets up a basic FastAPI server and provides a simple health-check endpoint to confirm that the server is operational." + ] + }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ + "PORT = 5050\n", + "\n", "app = FastAPI()\n", "\n", + "@app.get(\"/\", response_class=JSONResponse)\n", + "async def index_page():\n", + " return {\"message\": \"Websocket Audio Stream Server is running!\"}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Prepare `start-chat` endpoint\n", + "\n", + "1. **Set the Working Directory**: Define `notebook_path` as the current working directory using `os.getcwd()`.\n", + "2. **Mount Static Files**: Mount the `static` directory (inside `agentchat_realtime_websocket`) to serve JavaScript, CSS, and other static assets under the `/static` path.\n", + "3. **Set Up Templates**: Configure Jinja2 to render HTML templates from the `templates` directory within `agentchat_realtime_websocket`.\n", + "4. **Create the `/start-chat/` Endpoint**: Define a `GET` route that serves the `chat.html` template. 
Pass the client's `request` and the `port` variable to the template for rendering a dynamic page for the audio chat interface.\n", + "\n", + "This code sets up static file handling, template rendering, and a dedicated endpoint to deliver the chat interface.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "notebook_path = os.getcwd()\n", "\n", "app.mount(\n", @@ -73,41 +221,67 @@ "\n", "templates = Jinja2Templates(directory=Path(notebook_path) / \"agentchat_realtime_websocket\" / \"templates\")\n", "\n", - "\n", - "@app.get(\"/\", response_class=JSONResponse)\n", - "async def index_page():\n", - " return {\"message\": \"Websocket Audio Stream Server is running!\"}\n", - "\n", - "\n", "@app.get(\"/start-chat/\", response_class=HTMLResponse)\n", "async def start_chat(request: Request):\n", " \"\"\"Endpoint to return the HTML page for audio chat.\"\"\"\n", " port = PORT # Extract the client's port\n", - " return templates.TemplateResponse(\"chat.html\", {\"request\": request, \"port\": port})\n", - "\n", + " return templates.TemplateResponse(\"chat.html\", {\"request\": request, \"port\": port})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ + "### Prepare endpoint for conversation audio stream\n", "\n", + "1. **Set Up the WebSocket Endpoint**: Define the `/media-stream` WebSocket route to handle audio streaming.\n", + "2. **Accept WebSocket Connections**: Accept incoming WebSocket connections from clients.\n", + "3. **Initialize Logger**: Retrieve a logger instance for logging purposes.\n", + "4. **Configure Audio Adapter**: Instantiate a `WebsocketAudioAdapter`, connecting the WebSocket to handle audio streaming with logging.\n", + "5. 
**Set Up Realtime Agent**: Create a `RealtimeAgent` with the following:\n", + " - **Name**: `Weather Bot`.\n", + " - **System Message**: Introduces the AI assistant and its capabilities.\n", + " - **LLM Configuration**: Uses `realtime_llm_config` for language model settings.\n", + " - **Audio Adapter**: Leverages the previously created `audio_adapter`.\n", + " - **Logger**: Logs activities for debugging and monitoring.\n", + "6. **Register a Realtime Function**: Add a function `get_weather` to the agent, allowing it to respond with basic weather information based on the provided `location`.\n", + "7. **Run the Agent**: Start the `realtime_agent` to handle interactions in real time.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ "@app.websocket(\"/media-stream\")\n", "async def handle_media_stream(websocket: WebSocket):\n", " \"\"\"Handle WebSocket connections providing audio stream and OpenAI.\"\"\"\n", " await websocket.accept()\n", "\n", - " audio_adapter = WebsocketAudioAdapter(websocket)\n", - " openai_client = RealtimeAgent(\n", + " logger = getLogger(\"uvicorn.error\")\n", + "\n", + " audio_adapter = WebsocketAudioAdapter(websocket, logger=logger)\n", + " realtime_agent = RealtimeAgent(\n", " name=\"Weather Bot\",\n", - " system_message=\"Hello there! I am an AI voice assistant powered by Autogen and the OpenAI Realtime API. You can ask me about weather, jokes, or anything you can imagine. Start by saying How can I help you?\",\n", - " llm_config=llm_config,\n", + " system_message=\"Hello there! I am an AI voice assistant powered by Autogen and the OpenAI Realtime API. You can ask me about weather, jokes, or anything you can imagine. 
Start by saying 'How can I help you'?\",\n", + " llm_config=realtime_llm_config,\n", " audio_adapter=audio_adapter,\n", + " logger=logger,\n", " )\n", "\n", - " @openai_client.register_handover(name=\"get_weather\", description=\"Get the current weather\")\n", + " @realtime_agent.register_realtime_function(name=\"get_weather\", description=\"Get the current weather\")\n", " def get_weather(location: Annotated[str, \"city\"]) -> str:\n", - " ...\n", " return \"The weather is cloudy.\" if location == \"Seattle\" else \"The weather is sunny.\"\n", "\n", - " await openai_client.run()\n", - "\n", - "\n", - "uvicorn.run(app, host=\"0.0.0.0\", port=PORT)" + " await realtime_agent.run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Run the app using uvicorn" ] }, { @@ -115,12 +289,14 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "uvicorn.run(app, host=\"0.0.0.0\", port=PORT)" + ] } ], "metadata": { "kernelspec": { - "display_name": ".venv", + "display_name": ".venv-3.9", "language": "python", "name": "python3" }, @@ -134,7 +310,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.3" + "version": "3.9.20" } }, "nbformat": 4, diff --git a/setup.py b/setup.py index 8683843068..6d53075d89 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,7 @@ current_os = platform.system() install_requires = [ - "openai>=1.57", + "openai>=1.58", "diskcache", "termcolor", "flaml", @@ -51,6 +51,7 @@ "pytest-asyncio", "pytest>=8,<9", "pandas", + "fastapi>=0.115.0,<1", ] jupyter_executor = [ @@ -90,7 +91,7 @@ interop_pydantic_ai = ["pydantic-ai==0.0.13"] interop = interop_crewai + interop_langchain + interop_pydantic_ai -types = ["mypy==1.9.0"] + test + jupyter_executor + interop + ["fastapi>=0.115.0,<1"] +types = ["mypy==1.9.0"] + test + jupyter_executor + interop if current_os in ["Windows", "Darwin"]: retrieve_chat_pgvector.extend(["psycopg[binary]>=3.1.18"]) diff 
--git a/test/agentchat/realtime_agent/__init__.py b/test/agentchat/realtime_agent/__init__.py index 87ec7612a0..bcd5401d54 100644 --- a/test/agentchat/realtime_agent/__init__.py +++ b/test/agentchat/realtime_agent/__init__.py @@ -1,6 +1,3 @@ # Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai # # SPDX-License-Identifier: Apache-2.0 -# -# Portions derived from https://github.com/microsoft/autogen are under the MIT License. -# SPDX-License-Identifier: MIT diff --git a/test/agentchat/realtime_agent/test_e2e.py b/test/agentchat/realtime_agent/test_e2e.py new file mode 100644 index 0000000000..0d8ca300bb --- /dev/null +++ b/test/agentchat/realtime_agent/test_e2e.py @@ -0,0 +1,81 @@ +# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Annotated, Any, Generator + +import pytest +from anyio import sleep +from asyncer import create_task_group +from conftest import MOCK_OPEN_AI_API_KEY, reason, skip_openai # noqa: E402 +from fastapi import FastAPI, WebSocket +from fastapi.testclient import TestClient +from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST + +import autogen +from autogen.agentchat.realtime_agent import RealtimeAgent, WebSocketAudioAdapter +from autogen.agentchat.realtime_agent.oai_realtime_client import OpenAIRealtimeClient + + +@pytest.mark.skipif(skip_openai, reason=reason) +class TestE2E: + @pytest.fixture + def llm_config(self) -> dict[str, Any]: + config_list = autogen.config_list_from_json( + OAI_CONFIG_LIST, + filter_dict={ + "tags": ["gpt-4o-realtime"], + }, + file_location=KEY_LOC, + ) + assert config_list, "No config list found" + return { + "config_list": config_list, + "temperature": 0.8, + } + + @pytest.mark.asyncio() + async def test_init(self, llm_config: dict[str, Any]) -> None: + + app = FastAPI() + + @app.websocket("/media-stream") + async def handle_media_stream(websocket: WebSocket) -> None: + """Handle WebSocket connections providing audio stream 
and OpenAI.""" + print("test_init() Waiting for connection to be accepted...", flush=True) + await websocket.accept() + print("test_init() Connection accepted.", flush=True) + + audio_adapter = WebSocketAudioAdapter(websocket) + agent = RealtimeAgent( + name="Weather Bot", + system_message="Hello there! I am an AI voice assistant powered by Autogen and the OpenAI Realtime API. You can ask me about weather, jokes, or anything you can imagine. Start by saying 'How can I help you?'", + llm_config=llm_config, + audio_adapter=audio_adapter, + ) + + @agent.register_realtime_function(name="get_weather", description="Get the current weather") + def get_weather(location: Annotated[str, "city"]) -> str: + return "The weather is cloudy." if location == "Seattle" else "The weather is sunny." + + print("test_init() Running agent...", flush=True) + async with create_task_group() as tg: + tg.soonify(agent.run)() + await sleep(3) + tg.cancel_scope.cancel() + + # todo: the rest of the scenario + ... + + await websocket.send_json({"msg": "Hello, World!"}) + + print("test_init() Running agent finished", flush=True) + await websocket.close() + + client = TestClient(app) + with client.websocket_connect("/media-stream") as websocket: + data = websocket.receive_json() + assert data == {"msg": "Hello, World!"} + print("test_init() client.websocket_connect() finished", flush=True) + + print("test_init() finished") diff --git a/test/agentchat/realtime_agent/test_oai_realtime_client.py b/test/agentchat/realtime_agent/test_oai_realtime_client.py new file mode 100644 index 0000000000..f5f64e2fa1 --- /dev/null +++ b/test/agentchat/realtime_agent/test_oai_realtime_client.py @@ -0,0 +1,141 @@ +# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any +from unittest.mock import MagicMock + +import pytest +from anyio import move_on_after +from asyncer import create_task_group +from conftest import MOCK_OPEN_AI_API_KEY, 
reason, skip_openai # noqa: E402 +from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST + +import autogen +from autogen.agentchat.realtime_agent.oai_realtime_client import OpenAIRealtimeClient +from autogen.agentchat.realtime_agent.realtime_client import RealtimeClientProtocol + + +class TestOAIRealtimeClient: + @pytest.fixture + def llm_config(self) -> dict[str, Any]: + config_list = autogen.config_list_from_json( + OAI_CONFIG_LIST, + filter_dict={ + "tags": ["gpt-4o-realtime"], + }, + file_location=KEY_LOC, + ) + assert config_list, "No config list found" + return { + "config_list": config_list, + "temperature": 0.8, + } + + @pytest.fixture + def client(self, llm_config: dict[str, Any]) -> RealtimeClientProtocol: + return OpenAIRealtimeClient( + llm_config=llm_config, + voice="alloy", + system_message="You are a helpful AI assistant with voice capabilities.", + ) + + def test_init(self) -> None: + llm_config = { + "config_list": [ + { + "model": "gpt-4o", + "api_key": MOCK_OPEN_AI_API_KEY, + }, + ], + "temperature": 0.8, + } + client = OpenAIRealtimeClient( + llm_config=llm_config, + voice="alloy", + system_message="You are a helpful AI assistant with voice capabilities.", + ) + assert isinstance(client, RealtimeClientProtocol) + + @pytest.mark.skipif(skip_openai, reason=reason) + @pytest.mark.asyncio() + async def test_not_connected(self, client: OpenAIRealtimeClient) -> None: + + with pytest.raises(RuntimeError, match=r"Client is not connected, call connect\(\) first."): + with move_on_after(1) as scope: + async for _ in client.read_events(): + pass + + assert not scope.cancelled_caught + + @pytest.mark.skipif(skip_openai, reason=reason) + @pytest.mark.asyncio() + async def test_start_read_events(self, client: OpenAIRealtimeClient) -> None: + + mock = MagicMock() + + async with client.connect(): + # read events for 3 seconds and then interrupt + with move_on_after(3) as scope: + print("Reading events...") + + async for event in client.read_events(): + 
print(f"-> Received event: {event}") + mock(**event) + + # checking if the scope was cancelled by move_on_after + assert scope.cancelled_caught + + # check that we received the expected two events + calls_kwargs = [arg_list.kwargs for arg_list in mock.call_args_list] + assert calls_kwargs[0]["type"] == "session.created" + assert calls_kwargs[1]["type"] == "session.updated" + + @pytest.mark.skipif(skip_openai, reason=reason) + @pytest.mark.asyncio() + async def test_send_text(self, client: OpenAIRealtimeClient) -> None: + + mock = MagicMock() + + async with client.connect(): + # read events for 3 seconds and then interrupt + with move_on_after(3) as scope: + print("Reading events...") + async for event in client.read_events(): + print(f"-> Received event: {event}") + mock(**event) + + if event["type"] == "session.updated": + await client.send_text(role="user", text="Hello, how are you?") + + # checking if the scope was cancelled by move_on_after + assert scope.cancelled_caught + + # check that we received the expected two events + calls_kwargs = [arg_list.kwargs for arg_list in mock.call_args_list] + assert calls_kwargs[0]["type"] == "session.created" + assert calls_kwargs[1]["type"] == "session.updated" + + assert calls_kwargs[2]["type"] == "error" + assert calls_kwargs[2]["error"]["message"] == "Cancellation failed: no active response found" + + assert calls_kwargs[3]["type"] == "conversation.item.created" + assert calls_kwargs[3]["item"]["content"][0]["text"] == "Hello, how are you?" 
+ + @pytest.mark.skip(reason="Not implemented") + @pytest.mark.skipif(skip_openai, reason=reason) + @pytest.mark.asyncio() + async def test_send_audio(self, client: OpenAIRealtimeClient) -> None: + raise NotImplementedError + + @pytest.mark.skip(reason="Not implemented") + @pytest.mark.skipif(skip_openai, reason=reason) + @pytest.mark.asyncio() + async def test_truncate_audio(self, client: OpenAIRealtimeClient) -> None: + raise NotImplementedError + + @pytest.mark.skip(reason="Not implemented") + @pytest.mark.skipif(skip_openai, reason=reason) + @pytest.mark.asyncio() + async def test_initialize_session(self, client: OpenAIRealtimeClient) -> None: + raise NotImplementedError diff --git a/test/agentchat/realtime_agent/test_realtime_observer.py b/test/agentchat/realtime_agent/test_realtime_observer.py new file mode 100644 index 0000000000..7bf4306daa --- /dev/null +++ b/test/agentchat/realtime_agent/test_realtime_observer.py @@ -0,0 +1,60 @@ +# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai +# +# SPDX-License-Identifier: Apache-2.0 + +from asyncio import sleep +from typing import Any +from unittest.mock import MagicMock + +import pytest +from asyncer import create_task_group +from openai.types.beta.realtime.realtime_server_event import RealtimeServerEvent + +from autogen.agentchat.realtime_agent import RealtimeAgent, RealtimeObserver + + +class MyObserver(RealtimeObserver): + def __init__(self, mock: MagicMock) -> None: + super().__init__() + self.mock = mock + + async def initialize_session(self) -> None: + pass + + async def run_loop(self) -> None: + self.mock("started") + try: + self.mock("running") + print("-> running", end="", flush=True) + while True: + await sleep(0.05) + print(".", end="", flush=True) + finally: + print("stopped", flush=True) + self.mock("stopped") + + async def on_event(self, event: dict[str, Any]) -> None: + pass + + +class TestRealtimeObserver: + @pytest.mark.asyncio() + async def test_shutdown(self) -> None: + + mock = 
MagicMock() + observer = MyObserver(mock) + + agent = MagicMock() + + try: + async with create_task_group() as tg: + tg.soonify(observer.run)(agent) + await sleep(1.0) + tg.cancel_scope.cancel() + + except Exception as e: + print(e) + + mock.assert_any_call("started") + mock.assert_any_call("running") + mock.assert_called_with("stopped") diff --git a/test/agentchat/realtime_agent/test_submodule.py b/test/agentchat/realtime_agent/test_submodule.py index eff9f04964..bafd8a649e 100644 --- a/test/agentchat/realtime_agent/test_submodule.py +++ b/test/agentchat/realtime_agent/test_submodule.py @@ -11,5 +11,5 @@ def test_import() -> None: FunctionObserver, RealtimeAgent, TwilioAudioAdapter, - WebsocketAudioAdapter, + WebSocketAudioAdapter, ) diff --git a/test/interop/crewai/__init__.py b/test/interop/crewai/__init__.py index bcd5401d54..4323612a70 100644 --- a/test/interop/crewai/__init__.py +++ b/test/interop/crewai/__init__.py @@ -1,3 +1,4 @@ # Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai # # SPDX-License-Identifier: Apache-2.0 +#