Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(3/n): agent resume_turn #1194

Merged
merged 19 commits into from
Feb 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 135 additions & 9 deletions llama_stack/providers/inline/agents/meta_reference/agent_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
AgentTurnResponseTurnAwaitingInputPayload,
AgentTurnResponseTurnCompletePayload,
AgentTurnResponseTurnStartPayload,
AgentTurnResumeRequest,
Attachment,
Document,
InferenceStep,
Expand Down Expand Up @@ -156,6 +157,15 @@ def turn_to_messages(self, turn: Turn) -> List[Message]:
async def create_session(self, name: str) -> str:
return await self.storage.create_session(name)

async def get_messages_from_turns(self, turns: List[Turn]) -> List[Message]:
messages = []
if self.agent_config.instructions != "":
messages.append(SystemMessage(content=self.agent_config.instructions))

for turn in turns:
messages.extend(self.turn_to_messages(turn))
return messages

async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
with tracing.span("create_and_execute_turn") as span:
span.set_attribute("session_id", request.session_id)
Expand All @@ -168,14 +178,7 @@ async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> Asyn
raise ValueError(f"Session {request.session_id} not found")

turns = await self.storage.get_session_turns(request.session_id)

messages = []
if self.agent_config.instructions != "":
messages.append(SystemMessage(content=self.agent_config.instructions))

for i, turn in enumerate(turns):
messages.extend(self.turn_to_messages(turn))

messages = await self.get_messages_from_turns(turns)
messages.extend(request.messages)

turn_id = str(uuid.uuid4())
Expand Down Expand Up @@ -246,6 +249,119 @@ async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> Asyn

yield chunk

async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator:
with tracing.span("resume_turn") as span:
span.set_attribute("agent_id", self.agent_id)
span.set_attribute("session_id", request.session_id)
span.set_attribute("turn_id", request.turn_id)
span.set_attribute("request", request.model_dump_json())
assert request.stream is True, "Non-streaming not supported"

session_info = await self.storage.get_session_info(request.session_id)
if session_info is None:
raise ValueError(f"Session {request.session_id} not found")

turns = await self.storage.get_session_turns(request.session_id)
messages = await self.get_messages_from_turns(turns)
messages.extend(request.tool_responses)

last_turn_messages = [
x for x in messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
]

# get the steps from the turn id
steps = []
if len(turns) > 0:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't this be true always?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, but just in case kvstore gets destroyed.

steps = turns[-1].steps

# mark tool execution step as complete
# if there's no tool execution in progress step (due to storage, or tool call parsing on client),
# we'll create a new tool execution step with current time
in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
request.session_id, request.turn_id
)
now = datetime.now()
tool_execution_step = ToolExecutionStep(
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't you error out if there is no step found?

Copy link
Contributor Author

@yanxi0830 yanxi0830 Feb 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree its a bit confusing (had to play around with the react_agent app). Let me add a comment here.

We do not error out here b/c in the case of ReActAgent (with a custom tool parser), server do not output a tool_execution step_start, and don't have the step. However, we should still allow the turn to be resumed with the ToolCallResponse in this case because server outputs message (no ToolCall) --> parser parse into ToolCall --> client execute ToolCall --> resume turn.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yanxi0830 hm yeah this is confusing and broke my mental model.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in that case, the server wouldn't have sent a step_start event either

Copy link
Contributor Author

@yanxi0830 yanxi0830 Feb 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, in this case (of custom tool parsers), the server wouldn't have sent a step_start event, but we would still like to log the ToolExecutionStep in resume as a client tool call did indeed happen.

turn_id=request.turn_id,
tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
tool_responses=[
ToolResponse(
call_id=x.call_id,
tool_name=x.tool_name,
content=x.content,
)
for x in request.tool_responses
],
completed_at=now,
started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now),
)
steps.append(tool_execution_step)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.tool_execution.value,
step_id=tool_execution_step.step_id,
step_details=tool_execution_step,
)
)
)

output_message = None
async for chunk in self.run(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This logic seems to be significantly overlapping with impl of create_and_execute_turn right ?
May be a follow up : but would be good to consolidate the code so that logic of running a turn is in one place.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, doing that for 0.1.5 (so that existing create_and_execute_turn are fully verified and we can depreacate the allow_resume_turn flag: #1212

session_id=request.session_id,
turn_id=request.turn_id,
input_messages=messages,
sampling_params=self.agent_config.sampling_params,
stream=request.stream,
):
if isinstance(chunk, CompletionMessage):
output_message = chunk
continue

assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
event = chunk.event
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
steps.append(event.payload.step_details)

yield chunk

assert output_message is not None

last_turn_start_time = datetime.now()
if len(turns) > 0:
last_turn_start_time = turns[-1].started_at

turn = Turn(
turn_id=request.turn_id,
session_id=request.session_id,
input_messages=last_turn_messages,
output_message=output_message,
started_at=last_turn_start_time,
completed_at=datetime.now(),
steps=steps,
)
await self.storage.add_turn_to_session(request.session_id, turn)

if output_message.tool_calls:
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnAwaitingInputPayload(
turn=turn,
)
)
)
else:
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnCompletePayload(
turn=turn,
)
)
)

yield chunk

async def run(
self,
session_id: str,
Expand Down Expand Up @@ -636,7 +752,6 @@ async def _run(
)
)
)

tool_call = message.tool_calls[0]
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
Expand All @@ -654,6 +769,17 @@ async def _run(

# If tool is a client tool, yield CompletionMessage and return
if tool_call.tool_name in client_tools:
await self.storage.set_in_progress_tool_call_step(
session_id,
turn_id,
ToolExecutionStep(
step_id=step_id,
turn_id=turn_id,
tool_calls=[tool_call],
tool_responses=[],
started_at=datetime.now(),
),
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we also need to save n_iter of inference so that we respect self.agent_config.max_infer_iters, which btw we don't currently respect when custom tool is used.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

meta-llama/llama-stack-client-python#158

The max_infer_iters is still used here to track number of times we call resume.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should make it so that the total number of inference doesn't exceed max_infer_iters? Currently we could have max_infer_iters^2 of inference in the worst case (each resume could have max_infer_iters inference)

Copy link
Contributor Author

@yanxi0830 yanxi0830 Feb 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yeah, this will be max_infer_iters^2, I think that's what the current behaviour now if creating a second turn instead of resume too.

Will need to think a bit on how we can keep track of the same max_infer_iters across the client & server boundary in a follow up. E.g. we need to introduce extra params to communicate b/w the number of iters, whether we should use this max_infer_iters on client side.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SG for following up on this. Thanks!

yield message
return

Expand Down
21 changes: 20 additions & 1 deletion llama_stack/providers/inline/agents/meta_reference/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
AgentStepResponse,
AgentToolGroup,
AgentTurnCreateRequest,
AgentTurnResumeRequest,
Document,
Session,
Turn,
Expand Down Expand Up @@ -179,7 +180,25 @@ async def resume_agent_turn(
tool_responses: List[ToolResponseMessage],
stream: Optional[bool] = False,
) -> AsyncGenerator:
pass
request = AgentTurnResumeRequest(
agent_id=agent_id,
session_id=session_id,
turn_id=turn_id,
tool_responses=tool_responses,
stream=stream,
)
if stream:
return self._continue_agent_turn_streaming(request)
else:
raise NotImplementedError("Non-streaming agent turns not yet implemented")

async def _continue_agent_turn_streaming(
self,
request: AgentTurnResumeRequest,
) -> AsyncGenerator:
agent = await self.get_agent(request.agent_id)
async for event in agent.resume_turn(request):
yield event

async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from pydantic import BaseModel

from llama_stack.apis.agents import Turn
from llama_stack.apis.agents import ToolExecutionStep, Turn
from llama_stack.providers.utils.kvstore import KVStore

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -84,3 +84,15 @@ async def get_session_turns(self, session_id: str) -> List[Turn]:
continue
turns.sort(key=lambda x: (x.completed_at or datetime.min))
return turns

async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep):
await self.kvstore.set(
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
value=step.model_dump_json(),
)

async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> Optional[ToolExecutionStep]:
value = await self.kvstore.get(
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
)
return ToolExecutionStep(**json.loads(value)) if value else None