feat(3/n): agent resume_turn #1194
Changes from all commits
@@ -33,6 +33,7 @@
     AgentTurnResponseTurnAwaitingInputPayload,
     AgentTurnResponseTurnCompletePayload,
     AgentTurnResponseTurnStartPayload,
+    AgentTurnResumeRequest,
     Attachment,
     Document,
     InferenceStep,
@@ -156,6 +157,15 @@ def turn_to_messages(self, turn: Turn) -> List[Message]:
     async def create_session(self, name: str) -> str:
         return await self.storage.create_session(name)
 
+    async def get_messages_from_turns(self, turns: List[Turn]) -> List[Message]:
+        messages = []
+        if self.agent_config.instructions != "":
+            messages.append(SystemMessage(content=self.agent_config.instructions))
+
+        for turn in turns:
+            messages.extend(self.turn_to_messages(turn))
+        return messages
+
     async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
         with tracing.span("create_and_execute_turn") as span:
             span.set_attribute("session_id", request.session_id)
@@ -168,14 +178,7 @@ async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
                 raise ValueError(f"Session {request.session_id} not found")
 
             turns = await self.storage.get_session_turns(request.session_id)
-
-            messages = []
-            if self.agent_config.instructions != "":
-                messages.append(SystemMessage(content=self.agent_config.instructions))
-
-            for i, turn in enumerate(turns):
-                messages.extend(self.turn_to_messages(turn))
-
+            messages = await self.get_messages_from_turns(turns)
             messages.extend(request.messages)
 
             turn_id = str(uuid.uuid4())
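For readers skimming the diff, here is a minimal, self-contained sketch of what the new get_messages_from_turns helper produces: the agent's instructions become a leading SystemMessage, each stored turn contributes its messages, and the caller then appends the new request's messages (or tool responses, in resume_turn below). The classes here are simplified stand-ins for illustration, not the real llama-stack types.

```python
from dataclasses import dataclass, field
from typing import Any, List

# Simplified stand-ins for illustration only; the real types come from llama-stack.
@dataclass
class SystemMessage:
    content: str

@dataclass
class Turn:
    messages: List[Any] = field(default_factory=list)

def get_messages_from_turns(instructions: str, turns: List[Turn]) -> List[Any]:
    messages: List[Any] = []
    if instructions != "":
        messages.append(SystemMessage(content=instructions))
    for turn in turns:
        # the real helper expands each turn via turn_to_messages(turn)
        messages.extend(turn.messages)
    return messages

# One prior turn plus a new user message yields the full model context:
prior = Turn(messages=["user: hi", "assistant: hello"])
context = get_messages_from_turns("You are a helpful agent.", [prior])
context.extend(["user: what's the weather in SF?"])  # request.messages
print(context)
```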
@@ -246,6 +249,119 @@ async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
 
             yield chunk
 
+    async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
+        with tracing.span("resume_turn") as span:
+            span.set_attribute("agent_id", self.agent_id)
+            span.set_attribute("session_id", request.session_id)
+            span.set_attribute("turn_id", request.turn_id)
+            span.set_attribute("request", request.model_dump_json())
+            assert request.stream is True, "Non-streaming not supported"
+
+            session_info = await self.storage.get_session_info(request.session_id)
+            if session_info is None:
+                raise ValueError(f"Session {request.session_id} not found")
+
+            turns = await self.storage.get_session_turns(request.session_id)
+            messages = await self.get_messages_from_turns(turns)
+            messages.extend(request.tool_responses)
+
+            last_turn_messages = [
+                x for x in messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
+            ]
+
+            # get the steps from the turn id
+            steps = []
+            if len(turns) > 0:
+                steps = turns[-1].steps
+
+            # mark tool execution step as complete
+            # if there's no tool execution in progress step (due to storage, or tool call parsing on client),
+            # we'll create a new tool execution step with current time
+            in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
+                request.session_id, request.turn_id
+            )
+            now = datetime.now()
+            tool_execution_step = ToolExecutionStep(
+                step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),

Review comment: shouldn't you error out if there is no step found?

Reply: Agree, it's a bit confusing (had to play around with the react_agent app). Let me add a comment here. We do not error out here because, in the case of ReActAgent (with a custom tool parser), the server does not output a tool_execution step_start and does not have the step. However, we should still allow the turn to be resumed with the ToolCallResponse in this case, because the server outputs a message (no ToolCall) --> the parser parses it into a ToolCall --> the client executes the ToolCall --> resume turn.

Reply: @yanxi0830 hm yeah this is confusing and broke my mental model.

Reply: in that case, the server wouldn't have sent a

Reply: yeah, in this case (of custom tool parsers), the server wouldn't have sent a

+                turn_id=request.turn_id,
+                tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
+                tool_responses=[
+                    ToolResponse(
+                        call_id=x.call_id,
+                        tool_name=x.tool_name,
+                        content=x.content,
+                    )
+                    for x in request.tool_responses
+                ],
+                completed_at=now,
+                started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now),
+            )
+            steps.append(tool_execution_step)
+            yield AgentTurnResponseStreamChunk(
+                event=AgentTurnResponseEvent(
+                    payload=AgentTurnResponseStepCompletePayload(
+                        step_type=StepType.tool_execution.value,
+                        step_id=tool_execution_step.step_id,
+                        step_details=tool_execution_step,
+                    )
+                )
+            )
+
+            output_message = None
+            async for chunk in self.run(

Review comment: This logic seems to be significantly overlapping with impl of

Reply: Yes, doing that for 0.1.5 (so that existing

+                session_id=request.session_id,
+                turn_id=request.turn_id,
+                input_messages=messages,
+                sampling_params=self.agent_config.sampling_params,
+                stream=request.stream,
+            ):
+                if isinstance(chunk, CompletionMessage):
+                    output_message = chunk
+                    continue
+
+                assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
+                event = chunk.event
+                if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
+                    steps.append(event.payload.step_details)
+
+                yield chunk
+
+            assert output_message is not None
+
+            last_turn_start_time = datetime.now()
+            if len(turns) > 0:
+                last_turn_start_time = turns[-1].started_at
+
+            turn = Turn(
+                turn_id=request.turn_id,
+                session_id=request.session_id,
+                input_messages=last_turn_messages,
+                output_message=output_message,
+                started_at=last_turn_start_time,
+                completed_at=datetime.now(),
+                steps=steps,
+            )
+            await self.storage.add_turn_to_session(request.session_id, turn)
+
+            if output_message.tool_calls:
+                chunk = AgentTurnResponseStreamChunk(
+                    event=AgentTurnResponseEvent(
+                        payload=AgentTurnResponseTurnAwaitingInputPayload(
+                            turn=turn,
+                        )
+                    )
+                )
+            else:
+                chunk = AgentTurnResponseStreamChunk(
+                    event=AgentTurnResponseEvent(
+                        payload=AgentTurnResponseTurnCompletePayload(
+                            turn=turn,
+                        )
+                    )
+                )
+
+            yield chunk
+
     async def run(
         self,
         session_id: str,
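To make the resume flow concrete, including the custom-tool-parser case discussed in the review thread above, here is a hedged, self-contained sketch of how a caller might drive resume_turn. The request fields mirror the ones this diff reads (session_id, turn_id, tool_responses, stream); the classes and the fake_resume_turn generator are simplified stand-ins, not the actual llama-stack client API.

```python
import asyncio
from dataclasses import dataclass
from typing import Any, AsyncGenerator, List

# Simplified stand-ins; the real AgentTurnResumeRequest / ToolResponseMessage live in llama-stack.
@dataclass
class ToolResponseMessage:
    call_id: str
    tool_name: str
    content: str

@dataclass
class AgentTurnResumeRequest:
    session_id: str
    turn_id: str
    tool_responses: List[ToolResponseMessage]
    stream: bool = True

async def fake_resume_turn(request: AgentTurnResumeRequest) -> AsyncGenerator[Any, None]:
    # Stand-in for ChatAgent.resume_turn: the real method replays the session,
    # closes the in-progress tool_execution step, and continues inference.
    yield f"step_complete: tool_execution ({len(request.tool_responses)} tool responses)"
    yield f"turn_complete: {request.turn_id}"

async def main() -> None:
    # 1. The original turn ended with turn_awaiting_input because the tool call
    #    targets a client tool (or a custom parser extracted it client-side).
    # 2. The client executes the tool locally...
    tool_result = ToolResponseMessage(call_id="call-1", tool_name="get_weather", content="72F")
    # 3. ...and resumes the same turn with the tool responses attached.
    request = AgentTurnResumeRequest(
        session_id="sess-abc",
        turn_id="turn-xyz",
        tool_responses=[tool_result],
        stream=True,
    )
    async for chunk in fake_resume_turn(request):
        print(chunk)

asyncio.run(main())
```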
@@ -636,7 +752,6 @@ async def _run(
                     )
                 )
             )
-
             tool_call = message.tool_calls[0]
             yield AgentTurnResponseStreamChunk(
                 event=AgentTurnResponseEvent(
@@ -654,6 +769,17 @@ async def _run(
 
             # If tool is a client tool, yield CompletionMessage and return
             if tool_call.tool_name in client_tools:
+                await self.storage.set_in_progress_tool_call_step(
+                    session_id,
+                    turn_id,
+                    ToolExecutionStep(
+                        step_id=step_id,
+                        turn_id=turn_id,
+                        tool_calls=[tool_call],
+                        tool_responses=[],
+                        started_at=datetime.now(),
+                    ),
+                )

Review comment: I think we also need to save n_iter of inference so that we respect self.agent_config.max_infer_iters, which btw we don't currently respect when a custom tool is used.

Reply: meta-llama/llama-stack-client-python#158 The

Reply: We should make it so that the total number of inference calls doesn't exceed max_infer_iters? Currently we could have max_infer_iters^2 inference calls in the worst case (each resume could have

Reply: Ah yeah, this will be Will need to think a bit on how we can keep track of the same

Reply: SG for following up on this. Thanks!

                 yield message
                 return
 
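Regarding the review thread above about max_infer_iters: one way to keep the whole turn, including resumed continuations, within a single inference budget is to persist the running iteration count per turn. This is a hypothetical sketch of that idea; the get/set helpers and the in-memory store below do not exist in this PR.

```python
import asyncio
from typing import Dict

# Hypothetical illustration of the reviewer's suggestion; these helpers are not part of this PR.
class InMemoryTurnStore:
    def __init__(self) -> None:
        self._iters: Dict[str, int] = {}

    async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> int:
        return self._iters.get(f"{session_id}:{turn_id}", 0)

    async def set_num_infer_iters_in_turn(self, session_id: str, turn_id: str, n: int) -> None:
        self._iters[f"{session_id}:{turn_id}"] = n

async def run_with_budget(store: InMemoryTurnStore, session_id: str, turn_id: str, max_infer_iters: int) -> None:
    # Resume-aware loop: the count survives across resume_turn calls because it is
    # persisted per (session_id, turn_id), so the total never exceeds max_infer_iters.
    n_iter = await store.get_num_infer_iters_in_turn(session_id, turn_id)
    while n_iter < max_infer_iters:
        n_iter += 1  # ...one inference call would happen here...
        await store.set_num_infer_iters_in_turn(session_id, turn_id, n_iter)
    print(f"turn {turn_id} used {n_iter}/{max_infer_iters} inference iterations")

asyncio.run(run_with_budget(InMemoryTurnStore(), "sess-abc", "turn-xyz", max_infer_iters=5))
```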
Review comment: shouldn't this be true always?

Reply: yes, but just in case kvstore gets destroyed.
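For the kvstore question above: a minimal sketch, under assumptions, of how set_in_progress_tool_call_step / get_in_progress_tool_call_step could be backed by a key-value store keyed on (session_id, turn_id), and why the getter can legitimately return None, e.g. the step was never recorded (custom tool parser on the client) or the store was wiped. The dict-backed store and the step shape below are stand-ins, not the real llama-stack AgentPersistence implementation.

```python
import asyncio
import json
from typing import Dict, Optional

# Stand-in key-value store; the real implementation sits behind llama-stack's kvstore interface.
class KVStore:
    def __init__(self) -> None:
        self._data: Dict[str, str] = {}

    async def set(self, key: str, value: str) -> None:
        self._data[key] = value

    async def get(self, key: str) -> Optional[str]:
        return self._data.get(key)

class AgentPersistenceSketch:
    def __init__(self, kvstore: KVStore) -> None:
        self.kvstore = kvstore

    async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: dict) -> None:
        # Written when a client-tool call is about to leave the server.
        await self.kvstore.set(f"in_progress_tool_call_step:{session_id}:{turn_id}", json.dumps(step))

    async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> Optional[dict]:
        # Read on resume; None if the step was never recorded or the store was lost.
        value = await self.kvstore.get(f"in_progress_tool_call_step:{session_id}:{turn_id}")
        return json.loads(value) if value is not None else None

async def main() -> None:
    storage = AgentPersistenceSketch(KVStore())
    await storage.set_in_progress_tool_call_step("sess-abc", "turn-xyz", {"step_id": "step-1", "tool_calls": []})
    print(await storage.get_in_progress_tool_call_step("sess-abc", "turn-xyz"))
    print(await storage.get_in_progress_tool_call_step("sess-abc", "missing-turn"))  # -> None

asyncio.run(main())
```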