Skip to content

Commit

Permalink
Merge pull request #44 from harishmohanraj/fixes
Browse files Browse the repository at this point in the history
Fixes
  • Loading branch information
harishmohanraj authored Dec 20, 2024
2 parents dfaabea + 07eb38c commit fc90407
Show file tree
Hide file tree
Showing 16 changed files with 336 additions and 169 deletions.
2 changes: 1 addition & 1 deletion autogen/agentchat/contrib/swarm_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def initiate_swarm_chat(
user_agent: Optional[UserProxyAgent] = None,
max_rounds: int = 20,
context_variables: Optional[dict[str, Any]] = None,
after_work: Optional[Union[AFTER_WORK, Callable]] = AFTER_WORK(AfterWorkOption.TERMINATE),
after_work: Optional[Union[AfterWorkOption, Callable]] = AFTER_WORK(AfterWorkOption.TERMINATE),
) -> tuple[ChatResult, dict[str, Any], "SwarmAgent"]:
"""Initialize and run a swarm chat
Expand Down
123 changes: 86 additions & 37 deletions autogen/agentchat/realtime_agent/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,63 +8,83 @@
# import asyncio
import json
import logging
from typing import Any, Optional
from typing import TYPE_CHECKING, Any, Optional

import anyio
import websockets
from asyncer import TaskGroup, asyncify, create_task_group, syncify
from asyncer import TaskGroup, asyncify, create_task_group
from websockets import connect
from websockets.asyncio.client import ClientConnection

from autogen.agentchat.contrib.swarm_agent import AfterWorkOption, initiate_swarm_chat
from ..contrib.swarm_agent import AfterWorkOption, SwarmAgent, initiate_swarm_chat

from .function_observer import FunctionObserver
if TYPE_CHECKING:
from .function_observer import FunctionObserver
from .realtime_agent import RealtimeAgent
from .realtime_observer import RealtimeObserver

logger = logging.getLogger(__name__)


class OpenAIRealtimeClient:
"""(Experimental) Client for OpenAI Realtime API."""

def __init__(self, agent, audio_adapter, function_observer: FunctionObserver):
def __init__(
self, agent: "RealtimeAgent", audio_adapter: "RealtimeObserver", function_observer: "FunctionObserver"
) -> None:
"""(Experimental) Client for OpenAI Realtime API.
args:
agent: Agent instance
the agent to be used for the conversation
audio_adapter: RealtimeObserver
adapter for streaming the audio from the client
function_observer: FunctionObserver
observer for handling function calls
Args:
agent (RealtimeAgent): The agent that the client is associated with.
audio_adapter (RealtimeObserver): The audio adapter for the client.
function_observer (FunctionObserver): The function observer for the client.
"""
self._agent = agent
self._observers = []
self._openai_ws = None # todo factor out to OpenAIClient
self._observers: list["RealtimeObserver"] = []
self._openai_ws: Optional[ClientConnection] = None # todo factor out to OpenAIClient
self.register(audio_adapter)
self.register(function_observer)

# LLM config
llm_config = self._agent.llm_config

config = llm_config["config_list"][0]
config: dict[str, Any] = llm_config["config_list"][0] # type: ignore[index]

self.model = config["model"]
self.temperature = llm_config["temperature"]
self.api_key = config["api_key"]
self.model: str = config["model"]
self.temperature: float = llm_config["temperature"] # type: ignore[index]
self.api_key: str = config["api_key"]

# create a task group to manage the tasks
self.tg: Optional[TaskGroup] = None

def register(self, observer):
@property
def openai_ws(self) -> ClientConnection:
"""Get the OpenAI WebSocket connection."""
if self._openai_ws is None:
raise RuntimeError("OpenAI WebSocket is not initialized")
return self._openai_ws

def register(self, observer: "RealtimeObserver") -> None:
"""Register an observer to the client."""
observer.register_client(self)
self._observers.append(observer)

async def notify_observers(self, message):
"""Notify all observers of a message from the OpenAI Realtime API."""
async def notify_observers(self, message: dict[str, Any]) -> None:
"""Notify all observers of a message from the OpenAI Realtime API.
Args:
message (dict[str, Any]): The message from the OpenAI Realtime API.
"""
for observer in self._observers:
await observer.update(message)

async def function_result(self, call_id, result):
"""Send the result of a function call to the OpenAI Realtime API."""
async def function_result(self, call_id: str, result: str) -> None:
"""Send the result of a function call to the OpenAI Realtime API.
Args:
call_id (str): The ID of the function call.
result (str): The result of the function call.
"""
result_item = {
"type": "conversation.item.create",
"item": {
Expand All @@ -73,11 +93,23 @@ async def function_result(self, call_id, result):
"output": result,
},
}
if self._openai_ws is None:
raise RuntimeError("OpenAI WebSocket is not initialized")

await self._openai_ws.send(json.dumps(result_item))
await self._openai_ws.send(json.dumps({"type": "response.create"}))

async def send_text(self, *, role: str, text: str):
"""Send a text message to the OpenAI Realtime API."""
async def send_text(self, *, role: str, text: str) -> None:
"""Send a text message to the OpenAI Realtime API.
Args:
role (str): The role of the message.
text (str): The text of the message.
"""

if self._openai_ws is None:
raise RuntimeError("OpenAI WebSocket is not initialized")

await self._openai_ws.send(json.dumps({"type": "response.cancel"}))
text_item = {
"type": "conversation.item.create",
Expand All @@ -87,7 +119,7 @@ async def send_text(self, *, role: str, text: str):
await self._openai_ws.send(json.dumps({"type": "response.create"}))

# todo override in specific clients
async def initialize_session(self):
async def initialize_session(self) -> None:
"""Control initial session with OpenAI."""
session_update = {
# todo: move to config
Expand All @@ -100,25 +132,35 @@ async def initialize_session(self):
await self.session_update(session_update)

# todo override in specific clients
async def session_update(self, session_options):
"""Send a session update to the OpenAI Realtime API."""
async def session_update(self, session_options: dict[str, Any]) -> None:
"""Send a session update to the OpenAI Realtime API.
Args:
session_options (dict[str, Any]): The session options to update.
"""
if self._openai_ws is None:
raise RuntimeError("OpenAI WebSocket is not initialized")

update = {"type": "session.update", "session": session_options}
logger.info("Sending session update:", json.dumps(update))
await self._openai_ws.send(json.dumps(update))
logger.info("Sending session update finished")

async def _read_from_client(self):
async def _read_from_client(self) -> None:
"""Read messages from the OpenAI Realtime API."""
if self._openai_ws is None:
raise RuntimeError("OpenAI WebSocket is not initialized")

try:
async for openai_message in self._openai_ws:
response = json.loads(openai_message)
await self.notify_observers(response)
except Exception as e:
logger.warning(f"Error in _read_from_client: {e}")

async def run(self):
async def run(self) -> None:
"""Run the client."""
async with websockets.connect(
async with connect(
f"wss://api.openai.com/v1/realtime?model={self.model}",
additional_headers={
"Authorization": f"Bearer {self.api_key}",
Expand All @@ -127,17 +169,24 @@ async def run(self):
) as openai_ws:
self._openai_ws = openai_ws
await self.initialize_session()
# await asyncio.gather(self._read_from_client(), *[observer.run() for observer in self._observers])
async with create_task_group() as tg:
self.tg = tg
self.tg.soonify(self._read_from_client)()
for observer in self._observers:
self.tg.soonify(observer.run)()

initial_agent = self._agent._initial_agent
agents = self._agent._agents
user_agent = self._agent

if not (initial_agent and agents):
raise RuntimeError("Swarm not registered.")

if self._agent._start_swarm_chat:
self.tg.soonify(asyncify(initiate_swarm_chat))(
initial_agent=self._agent._initial_agent,
agents=self._agent._agents,
user_agent=self._agent,
initial_agent=initial_agent,
agents=agents,
user_agent=user_agent, # type: ignore[arg-type]
messages="Find out what the user wants.",
after_work=AfterWorkOption.REVERT_TO_USER,
)
36 changes: 25 additions & 11 deletions autogen/agentchat/realtime_agent/function_observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,38 +8,52 @@
import asyncio
import json
import logging
from typing import TYPE_CHECKING, Any

from asyncer import asyncify
from pydantic import BaseModel

from .realtime_observer import RealtimeObserver

if TYPE_CHECKING:
from .realtime_agent import RealtimeAgent

logger = logging.getLogger(__name__)


class FunctionObserver(RealtimeObserver):
"""Observer for handling function calls from the OpenAI Realtime API."""

def __init__(self, agent):
def __init__(self, agent: "RealtimeAgent") -> None:
"""Observer for handling function calls from the OpenAI Realtime API.
Args:
agent: Agent instance
the agent to be used for the conversation
agent (RealtimeAgent): The realtime agent attached to the observer.
"""
super().__init__()
self._agent = agent

async def update(self, response):
"""Handle function call events from the OpenAI Realtime API."""
async def update(self, response: dict[str, Any]) -> None:
"""Handle function call events from the OpenAI Realtime API.
Args:
response (dict[str, Any]): The response from the OpenAI Realtime API.
"""
if response.get("type") == "response.function_call_arguments.done":
logger.info(f"Received event: {response['type']}", response)
await self.call_function(
call_id=response["call_id"], name=response["name"], kwargs=json.loads(response["arguments"])
)

async def call_function(self, call_id, name, kwargs):
"""Call a function registered with the agent."""
async def call_function(self, call_id: str, name: str, kwargs: dict[str, Any]) -> None:
"""Call a function registered with the agent.
Args:
call_id (str): The ID of the function call.
name (str): The name of the function to call.
kwargs (Any[str, Any]): The arguments to pass to the function.
"""

if name in self._agent.realtime_functions:
_, func = self._agent.realtime_functions[name]
func = func if asyncio.iscoroutinefunction(func) else asyncify(func)
Expand All @@ -54,19 +68,19 @@ async def call_function(self, call_id, name, kwargs):
elif not isinstance(result, str):
result = json.dumps(result)

await self._client.function_result(call_id, result)
await self.client.function_result(call_id, result)

async def run(self):
async def run(self) -> None:
"""Run the observer.
Initialize the session with the OpenAI Realtime API.
"""
await self.initialize_session()

async def initialize_session(self):
async def initialize_session(self) -> None:
"""Add registered tools to OpenAI with a session update."""
session_update = {
"tools": [schema for schema, _ in self._agent.realtime_functions.values()],
"tool_choice": "auto",
}
await self._client.session_update(session_update)
await self.client.session_update(session_update)
Loading

0 comments on commit fc90407

Please sign in to comment.