From ac23408a0fce2f4db52ad043d0638515384c2770 Mon Sep 17 00:00:00 2001 From: Mark Backman Date: Mon, 3 Feb 2025 21:10:32 -0500 Subject: [PATCH] Add result arg to transition_callback, add a dynamic restaurant reservation demo --- CHANGELOG.md | 12 + README.md | 6 +- examples/dynamic/insurance_anthropic.py | 20 +- examples/dynamic/insurance_gemini.py | 20 +- examples/dynamic/insurance_openai.py | 20 +- examples/dynamic/restaurant_reservation.py | 352 +++++++++++++++++++++ examples/static/restaurant_reservation.py | 248 --------------- src/pipecat_flows/manager.py | 15 +- 8 files changed, 413 insertions(+), 280 deletions(-) create mode 100644 examples/dynamic/restaurant_reservation.py delete mode 100644 examples/static/restaurant_reservation.py diff --git a/CHANGELOG.md b/CHANGELOG.md index fc3ba0f..2aa712b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -50,6 +50,18 @@ Example usage: context = flow_manager.get_current_context() ``` +- Added a new dynamic example called `restaurant_reservation.py`. + +### Changed + +- Transition callbacks now receive function results directly as a second argument: + `async def handle_transition(args: Dict, result: FlowResult, flow_manager: FlowManager)`. + This enables direct access to typed function results for making routing decisions. + For backwards compatibility, the two-argument signature + `(args: Dict, flow_manager: FlowManager)` is still supported. + +- Updated dynamic examples to use the new result argument. + ### Deprecated - The `tts` parameter in `FlowManager.__init__()` is now deprecated and will diff --git a/README.md b/README.md index 919457f..91dee0c 100644 --- a/README.md +++ b/README.md @@ -404,8 +404,9 @@ In the `examples/static` directory, you'll find these examples: - `movie_explorer_openai.py` - Movie information bot demonstrating real API integration with TMDB - `movie_explorer_anthropic.py` - The same movie information demo adapted for Anthropic's format - `movie_explorer_gemini.py` - The same movie explorer demo adapted for Google Gemini's format -- `patient_intake.py` - A medical intake system showing complex state management -- `restaurant_reservation.py` - A reservation system with availability checking +- `patient_intake_openai.py` - A medical intake system showing complex state management +- `patient_intake_anthropic.py` - The same medical intake demo adapted for Anthropic's format +- `patient_intake_gemini.py` - The same medical intake demo adapted for Gemini's format - `travel_planner.py` - A vacation planning assistant with parallel paths ### Dynamic @@ -415,6 +416,7 @@ In the `examples/dynamic` directory, you'll find these examples: - `insurance_openai.py` - An insurance quote system using OpenAI's format - `insurance_anthropic.py` - The same insurance system adapted for Anthropic's format - `insurance_gemini.py` - The insurance system implemented with Google's format +- `restaurant_reservation.py` - A reservation system with availability checking Each LLM provider (OpenAI, Anthropic, Google) has slightly different function calling formats, but Pipecat Flows handles these differences internally while maintaining a consistent API for developers. diff --git a/examples/dynamic/insurance_anthropic.py b/examples/dynamic/insurance_anthropic.py index bacb735..202c4e9 100644 --- a/examples/dynamic/insurance_anthropic.py +++ b/examples/dynamic/insurance_anthropic.py @@ -149,13 +149,15 @@ async def end_quote() -> FlowResult: # Transition callbacks and handlers -async def handle_age_collection(args: Dict, flow_manager: FlowManager): - flow_manager.state["age"] = args["age"] +async def handle_age_collection(args: Dict, result: AgeCollectionResult, flow_manager: FlowManager): + flow_manager.state["age"] = result["age"] await flow_manager.set_node("marital_status", create_marital_status_node()) -async def handle_marital_status_collection(args: Dict, flow_manager: FlowManager): - flow_manager.state["marital_status"] = args["marital_status"] +async def handle_marital_status_collection( + args: Dict, result: MaritalStatusResult, flow_manager: FlowManager +): + flow_manager.state["marital_status"] = result["marital_status"] await flow_manager.set_node( "quote_calculation", create_quote_calculation_node( @@ -164,13 +166,13 @@ async def handle_marital_status_collection(args: Dict, flow_manager: FlowManager ) -async def handle_quote_calculation(args: Dict, flow_manager: FlowManager): - quote = await calculate_quote(args) - flow_manager.state["quote"] = quote - await flow_manager.set_node("quote_results", create_quote_results_node(quote)) +async def handle_quote_calculation( + args: Dict, result: QuoteCalculationResult, flow_manager: FlowManager +): + await flow_manager.set_node("quote_results", create_quote_results_node(result)) -async def handle_end_quote(_: Dict, flow_manager: FlowManager): +async def handle_end_quote(_: Dict, result: FlowResult, flow_manager: FlowManager): await flow_manager.set_node("end", create_end_node()) diff --git a/examples/dynamic/insurance_gemini.py b/examples/dynamic/insurance_gemini.py index bc01576..1ba22e9 100644 --- a/examples/dynamic/insurance_gemini.py +++ b/examples/dynamic/insurance_gemini.py @@ -149,13 +149,15 @@ async def end_quote() -> FlowResult: # Transition callbacks and handlers -async def handle_age_collection(args: Dict, flow_manager: FlowManager): - flow_manager.state["age"] = args["age"] +async def handle_age_collection(args: Dict, result: AgeCollectionResult, flow_manager: FlowManager): + flow_manager.state["age"] = result["age"] await flow_manager.set_node("marital_status", create_marital_status_node()) -async def handle_marital_status_collection(args: Dict, flow_manager: FlowManager): - flow_manager.state["marital_status"] = args["marital_status"] +async def handle_marital_status_collection( + args: Dict, result: MaritalStatusResult, flow_manager: FlowManager +): + flow_manager.state["marital_status"] = result["marital_status"] await flow_manager.set_node( "quote_calculation", create_quote_calculation_node( @@ -164,13 +166,13 @@ async def handle_marital_status_collection(args: Dict, flow_manager: FlowManager ) -async def handle_quote_calculation(args: Dict, flow_manager: FlowManager): - quote = await calculate_quote(args) - flow_manager.state["quote"] = quote - await flow_manager.set_node("quote_results", create_quote_results_node(quote)) +async def handle_quote_calculation( + args: Dict, result: QuoteCalculationResult, flow_manager: FlowManager +): + await flow_manager.set_node("quote_results", create_quote_results_node(result)) -async def handle_end_quote(_: Dict, flow_manager: FlowManager): +async def handle_end_quote(_: Dict, result: FlowResult, flow_manager: FlowManager): await flow_manager.set_node("end", create_end_node()) diff --git a/examples/dynamic/insurance_openai.py b/examples/dynamic/insurance_openai.py index a1891e8..8e6bac8 100644 --- a/examples/dynamic/insurance_openai.py +++ b/examples/dynamic/insurance_openai.py @@ -149,13 +149,15 @@ async def end_quote() -> FlowResult: # Transition callbacks and handlers -async def handle_age_collection(args: Dict, flow_manager: FlowManager): - flow_manager.state["age"] = args["age"] +async def handle_age_collection(args: Dict, result: AgeCollectionResult, flow_manager: FlowManager): + flow_manager.state["age"] = result["age"] await flow_manager.set_node("marital_status", create_marital_status_node()) -async def handle_marital_status_collection(args: Dict, flow_manager: FlowManager): - flow_manager.state["marital_status"] = args["marital_status"] +async def handle_marital_status_collection( + args: Dict, result: MaritalStatusResult, flow_manager: FlowManager +): + flow_manager.state["marital_status"] = result["marital_status"] await flow_manager.set_node( "quote_calculation", create_quote_calculation_node( @@ -164,13 +166,13 @@ async def handle_marital_status_collection(args: Dict, flow_manager: FlowManager ) -async def handle_quote_calculation(args: Dict, flow_manager: FlowManager): - quote = await calculate_quote(args) - flow_manager.state["quote"] = quote - await flow_manager.set_node("quote_results", create_quote_results_node(quote)) +async def handle_quote_calculation( + args: Dict, result: QuoteCalculationResult, flow_manager: FlowManager +): + await flow_manager.set_node("quote_results", create_quote_results_node(result)) -async def handle_end_quote(_: Dict, flow_manager: FlowManager): +async def handle_end_quote(_: Dict, result: FlowResult, flow_manager: FlowManager): await flow_manager.set_node("end", create_end_node()) diff --git a/examples/dynamic/restaurant_reservation.py b/examples/dynamic/restaurant_reservation.py new file mode 100644 index 0000000..1bf9901 --- /dev/null +++ b/examples/dynamic/restaurant_reservation.py @@ -0,0 +1,352 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import asyncio +import os +import sys +from datetime import datetime, time +from pathlib import Path +from typing import Dict, TypedDict + +import aiohttp +from dotenv import load_dotenv +from loguru import logger +from pipecat.audio.vad.silero import SileroVADAnalyzer +from pipecat.pipeline.pipeline import Pipeline +from pipecat.pipeline.runner import PipelineRunner +from pipecat.pipeline.task import PipelineParams, PipelineTask +from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext +from pipecat.services.cartesia import CartesiaTTSService +from pipecat.services.deepgram import DeepgramSTTService +from pipecat.services.openai import OpenAILLMService +from pipecat.transports.services.daily import DailyParams, DailyTransport + +from pipecat_flows import FlowArgs, FlowManager, FlowResult, NodeConfig + +sys.path.append(str(Path(__file__).parent.parent)) +from runner import configure + +load_dotenv(override=True) + +logger.remove(0) +logger.add(sys.stderr, level="DEBUG") + + +# Mock reservation system +class MockReservationSystem: + """Simulates a restaurant reservation system API.""" + + def __init__(self): + # Mock data: Times that are "fully booked" + self.booked_times = {"7:00 PM", "8:00 PM"} # Changed to AM/PM format + + async def check_availability( + self, party_size: int, requested_time: str + ) -> tuple[bool, list[str]]: + """Check if a table is available for the given party size and time.""" + # Simulate API call delay + await asyncio.sleep(0.5) + + # Check if time is booked + is_available = requested_time not in self.booked_times + + # If not available, suggest alternative times + alternatives = [] + if not is_available: + base_times = ["5:00 PM", "6:00 PM", "7:00 PM", "8:00 PM", "9:00 PM", "10:00 PM"] + alternatives = [t for t in base_times if t not in self.booked_times] + + return is_available, alternatives + + +# Initialize mock system +reservation_system = MockReservationSystem() + + +# Type definitions for function results +class PartySizeResult(FlowResult): + size: int + status: str + + +class TimeResult(FlowResult): + status: str + time: str + available: bool + alternative_times: list[str] + + +# Function handlers +async def collect_party_size(args: FlowArgs) -> PartySizeResult: + """Process party size collection.""" + size = args["size"] + return PartySizeResult(size=size, status="success") + + +async def check_availability(args: FlowArgs) -> TimeResult: + """Check reservation availability and return result.""" + time = args["time"] + party_size = args["party_size"] + + # Check availability with mock API + is_available, alternative_times = await reservation_system.check_availability(party_size, time) + + result = TimeResult( + status="success", time=time, available=is_available, alternative_times=alternative_times + ) + return result + + +# Transition handlers +async def handle_party_size_collection( + args: Dict, result: PartySizeResult, flow_manager: FlowManager +): + """Handle party size collection and transition to time selection.""" + # Store party size in flow state + flow_manager.state["party_size"] = result["size"] + await flow_manager.set_node("get_time", create_time_selection_node()) + + +async def handle_availability_check(args: Dict, result: TimeResult, flow_manager: FlowManager): + """Handle availability check result and transition based on availability.""" + # Store reservation details in flow state + flow_manager.state["requested_time"] = args["time"] + + # Use result directly instead of accessing state + if result["available"]: + logger.debug("Time is available, transitioning to confirmation node") + await flow_manager.set_node("confirm", create_confirmation_node()) + else: + logger.debug(f"Time not available, storing alternatives: {result['alternative_times']}") + await flow_manager.set_node( + "no_availability", create_no_availability_node(result["alternative_times"]) + ) + + +async def handle_end(_: Dict, result: FlowResult, flow_manager: FlowManager): + """Handle conversation end.""" + await flow_manager.set_node("end", create_end_node()) + + +# Node configurations +def create_initial_node() -> NodeConfig: + """Create initial node for party size collection.""" + return { + "role_messages": [ + { + "role": "system", + "content": "You are a restaurant reservation assistant for La Maison, an upscale French restaurant. Be casual and friendly. This is a voice conversation, so avoid special characters and emojis.", + } + ], + "task_messages": [ + { + "role": "system", + "content": "Warmly greet the customer and ask how many people are in their party.", + } + ], + "functions": [ + { + "type": "function", + "function": { + "name": "collect_party_size", + "handler": collect_party_size, + "description": "Record the number of people in the party", + "parameters": { + "type": "object", + "properties": {"size": {"type": "integer", "minimum": 1, "maximum": 12}}, + "required": ["size"], + }, + "transition_callback": handle_party_size_collection, + }, + } + ], + } + + +def create_time_selection_node() -> NodeConfig: + """Create node for time selection and availability check.""" + logger.debug("Creating time selection node") + return { + "task_messages": [ + { + "role": "system", + "content": "Ask what time they'd like to dine. Restaurant is open 5 PM to 10 PM.", + } + ], + "functions": [ + { + "type": "function", + "function": { + "name": "check_availability", + "handler": check_availability, + "description": "Check availability for requested time", + "parameters": { + "type": "object", + "properties": { + "time": { + "type": "string", + "pattern": "^([5-9]|10):00 PM$", # Matches "5:00 PM" through "10:00 PM" + "description": "Reservation time (e.g., '6:00 PM')", + }, + "party_size": {"type": "integer"}, + }, + "required": ["time", "party_size"], + }, + "transition_callback": handle_availability_check, + }, + } + ], + } + + +def create_confirmation_node() -> NodeConfig: + """Create confirmation node for successful reservations.""" + return { + "task_messages": [ + { + "role": "system", + "content": "Confirm the reservation details and ask if they need anything else.", + } + ], + "functions": [ + { + "type": "function", + "function": { + "name": "end_conversation", + "description": "End the conversation", + "parameters": {"type": "object", "properties": {}}, + "transition_callback": handle_end, + }, + } + ], + } + + +def create_no_availability_node(alternative_times: list[str]) -> NodeConfig: + """Create node for handling no availability.""" + times_list = ", ".join(alternative_times) + return { + "task_messages": [ + { + "role": "system", + "content": ( + f"Apologize that the requested time is not available. " + f"Suggest these alternative times: {times_list}. " + "Ask if they'd like to try one of these times." + ), + } + ], + "functions": [ + { + "type": "function", + "function": { + "name": "check_availability", + "handler": check_availability, + "description": "Check availability for new time", + "parameters": { + "type": "object", + "properties": { + "time": { + "type": "string", + "pattern": "^([5-9]|10):00 PM$", + "description": "Reservation time (e.g., '6:00 PM')", + }, + "party_size": {"type": "integer"}, + }, + "required": ["time", "party_size"], + }, + "transition_callback": handle_availability_check, + }, + }, + { + "type": "function", + "function": { + "name": "end_conversation", + "description": "End the conversation", + "parameters": {"type": "object", "properties": {}}, + "transition_callback": handle_end, + }, + }, + ], + } + + +def create_end_node() -> NodeConfig: + """Create the final node.""" + return { + "task_messages": [ + { + "role": "system", + "content": "Thank them and end the conversation.", + } + ], + "functions": [], + "post_actions": [{"type": "end_conversation"}], + } + + +# Main setup +async def main(): + async with aiohttp.ClientSession() as session: + (room_url, _) = await configure(session) + + transport = DailyTransport( + room_url, + None, + "Reservation bot", + DailyParams( + audio_out_enabled=True, + vad_enabled=True, + vad_analyzer=SileroVADAnalyzer(), + vad_audio_passthrough=True, + ), + ) + + stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY")) + tts = CartesiaTTSService( + api_key=os.getenv("CARTESIA_API_KEY"), + voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22", + ) + llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o") + + context = OpenAILLMContext() + context_aggregator = llm.create_context_aggregator(context) + + pipeline = Pipeline( + [ + transport.input(), + stt, + context_aggregator.user(), + llm, + tts, + transport.output(), + context_aggregator.assistant(), + ] + ) + + task = PipelineTask(pipeline, PipelineParams(allow_interruptions=True)) + + # Initialize flow manager + flow_manager = FlowManager( + task=task, + llm=llm, + context_aggregator=context_aggregator, + ) + + @transport.event_handler("on_first_participant_joined") + async def on_first_participant_joined(transport, participant): + await transport.capture_participant_transcription(participant["id"]) + logger.debug("Initializing flow manager") + await flow_manager.initialize() + logger.debug("Setting initial node") + await flow_manager.set_node("initial", create_initial_node()) + + runner = PipelineRunner() + await runner.run(task) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/static/restaurant_reservation.py b/examples/static/restaurant_reservation.py deleted file mode 100644 index c12cdf5..0000000 --- a/examples/static/restaurant_reservation.py +++ /dev/null @@ -1,248 +0,0 @@ -# -# Copyright (c) 2024, Daily -# -# SPDX-License-Identifier: BSD 2-Clause License -# - -import asyncio -import os -import sys -from pathlib import Path - -import aiohttp -from dotenv import load_dotenv -from loguru import logger -from pipecat.audio.vad.silero import SileroVADAnalyzer -from pipecat.pipeline.pipeline import Pipeline -from pipecat.pipeline.runner import PipelineRunner -from pipecat.pipeline.task import PipelineParams, PipelineTask -from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext -from pipecat.services.cartesia import CartesiaTTSService -from pipecat.services.deepgram import DeepgramSTTService -from pipecat.services.openai import OpenAILLMService -from pipecat.transports.services.daily import DailyParams, DailyTransport - -from pipecat_flows import FlowArgs, FlowConfig, FlowManager, FlowResult - -sys.path.append(str(Path(__file__).parent.parent)) -from runner import configure - -load_dotenv(override=True) - -logger.remove(0) -logger.add(sys.stderr, level="DEBUG") - -# Flow Configuration - Restaurant Reservation System -# -# This configuration defines a streamlined restaurant reservation system with the following states: -# -# 1. start -# - Initial state collecting party size information -# - Functions: -# * record_party_size (node function, validates 1-12 people) -# * get_time (edge function, transitions to time selection) -# - Expected flow: Greet -> Ask party size -> Record -> Transition to time -# -# 2. get_time -# - Collects preferred reservation time -# - Operating hours: 5 PM - 10 PM -# - Functions: -# * record_time (node function, collects time in HH:MM format) -# * confirm (edge function, transitions to confirmation) -# - Expected flow: Ask preferred time -> Record time -> Proceed to confirmation -# -# 3. confirm -# - Reviews reservation details with guest -# - Functions: -# * end (edge function, transitions to end) -# - Expected flow: Review details -> Confirm -> End conversation -# -# 4. end -# - Final state that closes the conversation -# - No functions available -# - Post-action: Ends conversation -# -# This simplified flow demonstrates both node functions (which perform operations within -# a state) and edge functions (which transition between states), while maintaining a -# clear and efficient reservation process. - - -# Type definitions -class PartySizeResult(FlowResult): - size: int - - -class TimeResult(FlowResult): - time: str - - -# Function handlers -async def record_party_size(args: FlowArgs) -> FlowResult: - """Handler for recording party size.""" - size = args["size"] - # In a real app, this would store the reservation details - return PartySizeResult(size=size) - - -async def record_time(args: FlowArgs) -> FlowResult: - """Handler for recording reservation time.""" - time = args["time"] - # In a real app, this would validate availability and store the time - return TimeResult(time=time) - - -flow_config: FlowConfig = { - "initial_node": "start", - "nodes": { - "start": { - "role_messages": [ - { - "role": "system", - "content": "You are a restaurant reservation assistant for La Maison, an upscale French restaurant. You must ALWAYS use one of the available functions to progress the conversation. This is a phone conversations and your responses will be converted to audio. Avoid outputting special characters and emojis. Be causal and friendly.", - } - ], - "task_messages": [ - { - "role": "system", - "content": "Warmly greet the customer and ask how many people are in their party.", - } - ], - "functions": [ - { - "type": "function", - "function": { - "name": "record_party_size", - "handler": record_party_size, - "description": "Record the number of people in the party", - "parameters": { - "type": "object", - "properties": { - "size": {"type": "integer", "minimum": 1, "maximum": 12} - }, - "required": ["size"], - }, - "transition_to": "get_time", - }, - }, - ], - }, - "get_time": { - "task_messages": [ - { - "role": "system", - "content": "Ask what time they'd like to dine. Restaurant is open 5 PM to 10 PM. After they provide a time, confirm it's within operating hours before recording. Use 24-hour format for internal recording (e.g., 17:00 for 5 PM).", - } - ], - "functions": [ - { - "type": "function", - "function": { - "name": "record_time", - "handler": record_time, - "description": "Record the requested time", - "parameters": { - "type": "object", - "properties": { - "time": { - "type": "string", - "pattern": "^(17|18|19|20|21|22):([0-5][0-9])$", - "description": "Reservation time in 24-hour format (17:00-22:00)", - } - }, - "required": ["time"], - }, - "transition_to": "confirm", - }, - }, - ], - }, - "confirm": { - "task_messages": [ - { - "role": "system", - "content": "Confirm the reservation details and end the conversation.", - } - ], - "functions": [ - { - "type": "function", - "function": { - "name": "end", - "description": "End the conversation", - "parameters": {"type": "object", "properties": {}}, - "transition_to": "end", - }, - } - ], - }, - "end": { - "task_messages": [ - {"role": "system", "content": "Thank them and end the conversation."} - ], - "functions": [], - "post_actions": [{"type": "end_conversation"}], - }, - }, -} - - -async def main(): - async with aiohttp.ClientSession() as session: - (room_url, _) = await configure(session) - - transport = DailyTransport( - room_url, - None, - "Reservation bot", - DailyParams( - audio_out_enabled=True, - vad_enabled=True, - vad_analyzer=SileroVADAnalyzer(), - vad_audio_passthrough=True, - ), - ) - - stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY")) - tts = CartesiaTTSService( - api_key=os.getenv("CARTESIA_API_KEY"), - voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22", # British Lady - ) - llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o") - - context = OpenAILLMContext() - context_aggregator = llm.create_context_aggregator(context) - - pipeline = Pipeline( - [ - transport.input(), # Transport user input - stt, # STT - context_aggregator.user(), # User responses - llm, # LLM - tts, # TTS - transport.output(), # Transport bot output - context_aggregator.assistant(), # Assistant spoken responses - ] - ) - - task = PipelineTask(pipeline, PipelineParams(allow_interruptions=True)) - - # Initialize flow manager with LLM - flow_manager = FlowManager( - task=task, - llm=llm, - context_aggregator=context_aggregator, - flow_config=flow_config, - ) - - @transport.event_handler("on_first_participant_joined") - async def on_first_participant_joined(transport, participant): - await transport.capture_participant_transcription(participant["id"]) - # Initialize the flow processor - await flow_manager.initialize() - - runner = PipelineRunner() - await runner.run(task) - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/src/pipecat_flows/manager.py b/src/pipecat_flows/manager.py index 2655eda..852b662 100644 --- a/src/pipecat_flows/manager.py +++ b/src/pipecat_flows/manager.py @@ -304,7 +304,9 @@ def decrease_pending_function_calls() -> None: f"Function call completed: {name} (remaining: {self._pending_function_calls})" ) - async def on_context_updated_edge(args: Dict[str, Any], result_callback: Callable) -> None: + async def on_context_updated_edge( + args: Dict[str, Any], result: Any, result_callback: Callable + ) -> None: """Handle context updates for edge functions with transitions.""" try: decrease_pending_function_calls() @@ -316,7 +318,14 @@ async def on_context_updated_edge(args: Dict[str, Any], result_callback: Callabl await self.set_node(transition_to, self.nodes[transition_to]) elif transition_callback: # Dynamic flow logger.debug(f"Dynamic transition for: {name}") - await transition_callback(args, self) + # Check callback signature + sig = inspect.signature(transition_callback) + if len(sig.parameters) == 2: + # Old style: (args, flow_manager) + await transition_callback(args, self) + else: + # New style: (args, result, flow_manager) + await transition_callback(args, result, self) # Reset counter after transition completes self._pending_function_calls = 0 logger.debug("Reset pending function calls counter") @@ -365,7 +374,7 @@ async def transition_func( # For node functions, allow immediate completion (run_llm=True) async def on_context_updated() -> None: if is_edge_function: - await on_context_updated_edge(args, result_callback) + await on_context_updated_edge(args, result, result_callback) else: await on_context_updated_node()