Skip to content

Commit

Permalink
Merge pull request #105 from bolna-ai/first_message_non_stream
Browse files Browse the repository at this point in the history
forcing initial message in non-stream ws
  • Loading branch information
prateeksachan authored Jan 19, 2025
2 parents 84ed61f + 0cc7db5 commit 2554a26
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 12 deletions.
31 changes: 19 additions & 12 deletions bolna/agent_manager/task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,10 +510,12 @@ def __setup_output_handlers(self, turn_based_conversation, output_queue):
def __setup_input_handlers(self, turn_based_conversation, input_queue, should_record):
if self.task_config["tools_config"]["input"]["provider"] in SUPPORTED_INPUT_HANDLERS.keys():
logger.info(f"Connected through dashboard {turn_based_conversation}")
input_kwargs = {"queues": self.queues,
"websocket": self.websocket,
"input_types": get_required_input_types(self.task_config),
"mark_set": self.mark_set}
input_kwargs = {
"queues": self.queues,
"websocket": self.websocket,
"input_types": get_required_input_types(self.task_config),
"mark_set": self.mark_set
}

if self.task_config["tools_config"]["input"]["provider"] == "daily":
input_kwargs['room_url'] = self.room_url
Expand Down Expand Up @@ -581,7 +583,8 @@ def __setup_synthesizer(self, llm_config=None):
self.task_config["tools_config"]["synthesizer"]["stream"] = True if self.enforce_streaming else False #Hardcode stream to be False as we don't want to get blocked by a __listen_synthesizer co-routine

self.tools["synthesizer"] = synthesizer_class(**self.task_config["tools_config"]["synthesizer"], **provider_config, **self.kwargs, caching=caching)
self.synthesizer_monitor_task = asyncio.create_task(self.tools['synthesizer'].monitor_connection())
if not self.turn_based_conversation:
self.synthesizer_monitor_task = asyncio.create_task(self.tools['synthesizer'].monitor_connection())
if self.task_config["tools_config"]["llm_agent"] is not None and llm_config is not None:
llm_config["buffer_size"] = self.task_config["tools_config"]["synthesizer"].get('buffer_size')

Expand Down Expand Up @@ -928,7 +931,6 @@ def __update_preprocessed_tree_node(self):
# LLM task
##############################################################
async def _handle_llm_output(self, next_step, text_chunk, should_bypass_synth, meta_info, is_filler = False):

logger.info("received text from LLM for output processing: {} which belongs to sequence id {}".format(text_chunk, meta_info['sequence_id']))
if "request_id" not in meta_info:
meta_info["request_id"] = str(uuid.uuid4())
Expand Down Expand Up @@ -1192,7 +1194,7 @@ async def _process_conversation_task(self, message, sequence, meta_info):
logger.info("agent flow is not preprocessed")

start_time = time.time()
should_bypass_synth = 'bypass_synth' in meta_info and meta_info['bypass_synth'] == True
should_bypass_synth = 'bypass_synth' in meta_info and meta_info['bypass_synth'] is True
next_step = self._get_next_step(sequence, "llm")
meta_info['llm_start_time'] = time.time()
route = None
Expand Down Expand Up @@ -1926,7 +1928,15 @@ async def __first_message(self, timeout=10.0):
text = self.kwargs.get('agent_welcome_message', None)
logger.info(f"Generating {text}")
meta_info = {'io': self.tools["output"].get_provider(), 'message_category': 'agent_welcome_message', 'stream_sid': stream_sid, "request_id": str(uuid.uuid4()), "cached": True, "sequence_id": -1, 'format': self.task_config["tools_config"]["output"]["format"], 'text': text}
await self._synthesize(create_ws_data_packet(text, meta_info=meta_info))
if self.turn_based_conversation:
meta_info['type'] = 'text'
bos_packet = create_ws_data_packet("<beginning_of_stream>", meta_info)
await self.tools["output"].handle(bos_packet)
await self.tools["output"].handle(create_ws_data_packet(text, meta_info))
eos_packet = create_ws_data_packet("<end_of_stream>", meta_info)
await self.tools["output"].handle(eos_packet)
else:
await self._synthesize(create_ws_data_packet(text, meta_info=meta_info))
break
else:
logger.info(f"Stream id is still None, so not passing it")
Expand Down Expand Up @@ -1990,12 +2000,9 @@ async def run(self):

logger.info(f"Starting the first message task {self.enforce_streaming}")
self.output_task = asyncio.create_task(self.__process_output_loop())
self.first_message_task = asyncio.create_task(self.__first_message())
if not self.turn_based_conversation or self.enforce_streaming:
logger.info(f"Setting up other servers")
self.first_message_task = asyncio.create_task(self.__first_message())
#if not self.use_llm_to_determine_hangup :
# By default we will hang up after x amount of silence
# We still need to
self.hangup_task = asyncio.create_task(self.__check_for_completion())

if self.should_backchannel:
Expand Down
4 changes: 4 additions & 0 deletions bolna/input_handlers/default.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import base64
import time
import uuid
from dotenv import load_dotenv
from bolna.helpers.logger_config import configure_logger
from bolna.helpers.utils import create_ws_data_packet
Expand Down Expand Up @@ -28,6 +29,9 @@ async def stop_handler(self):
except Exception as e:
logger.error(f"Error closing WebSocket: {e}")

def get_stream_sid(self):
    """Return a newly generated random UUID (version 4) as a string.

    A new identifier is created on every call and nothing is stored on
    the instance, so successive calls return different values.
    """
    # NOTE(review): presumably this handler has no provider-issued stream
    # id to report, hence the synthetic one -- confirm against the other
    # input handlers.
    fresh_id = uuid.uuid4()
    return str(fresh_id)

def __process_audio(self, audio):
data = base64.b64decode(audio)
ws_data_packet = create_ws_data_packet(
Expand Down

0 comments on commit 2554a26

Please sign in to comment.