Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Truncate context window based on session settings #381

Merged
merged 1 commit into from
Jun 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion agents-api/agents_api/autogen/openapi_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# generated by datamodel-codegen:
# filename: openapi.yaml
# timestamp: 2024-05-28T03:07:50+00:00
# timestamp: 2024-06-01T12:26:54+00:00

from __future__ import annotations

Expand Down Expand Up @@ -124,6 +124,14 @@ class Session(BaseModel):
"""
Render system and assistant message content as jinja templates
"""
token_budget: int | None = None
"""
Threshold value for the adaptive context functionality
"""
context_overflow: str | None = None
"""
Action to start on context window overflow
"""


class CreateSessionRequest(BaseModel):
Expand Down Expand Up @@ -151,6 +159,14 @@ class CreateSessionRequest(BaseModel):
"""
Render system and assistant message content as jinja templates
"""
token_budget: int | None = None
"""
Threshold value for the adaptive context functionality
"""
context_overflow: str | None = None
"""
Action to start on context window overflow
"""


class UpdateSessionRequest(BaseModel):
Expand All @@ -166,6 +182,14 @@ class UpdateSessionRequest(BaseModel):
"""
Optional metadata
"""
token_budget: int | None = None
"""
Threshold value for the adaptive context functionality
"""
context_overflow: str | None = None
"""
Action to start on context window overflow
"""


class UpdateUserRequest(BaseModel):
Expand Down Expand Up @@ -753,6 +777,14 @@ class PatchSessionRequest(BaseModel):
"""
Optional metadata
"""
token_budget: int | None = None
"""
Threshold value for the adaptive context functionality
"""
context_overflow: str | None = None
"""
Action to start on context window overflow
"""


class PartialFunctionDef(BaseModel):
Expand Down
2 changes: 2 additions & 0 deletions agents-api/agents_api/common/protocol/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,5 @@ class SessionData(BaseModel):
metadata: Dict = {}
user_metadata: Optional[Dict] = None
agent_metadata: Dict = {}
token_budget: int | None = None
context_overflow: str | None = None
8 changes: 0 additions & 8 deletions agents-api/agents_api/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,6 @@
model_api_key: str = env.str("MODEL_API_KEY", default=None)
model_inference_url: str = env.str("MODEL_INFERENCE_URL", default=None)
openai_api_key: str = env.str("OPENAI_API_KEY", default="")
summarization_ratio_threshold: float = env.float(
"MAX_TOKENS_RATIO_TO_SUMMARIZE", default=0.5
)
summarization_tokens_threshold: int = env.int(
"SUMMARIZATION_TOKENS_THRESHOLD", default=2048
)
summarization_model_name: str = env.str(
"SUMMARIZATION_MODEL_NAME", default="gpt-4-turbo"
)
Expand Down Expand Up @@ -78,8 +72,6 @@
debug=debug,
cozo_host=cozo_host,
cozo_auth=cozo_auth,
summarization_ratio_threshold=summarization_ratio_threshold,
summarization_tokens_threshold=summarization_tokens_threshold,
worker_url=worker_url,
sentry_dsn=sentry_dsn,
temporal_endpoint=temporal_endpoint,
Expand Down
12 changes: 11 additions & 1 deletion agents-api/agents_api/models/session/create_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ def create_session_query(
situation: str | None,
metadata: dict = {},
render_templates: bool = False,
token_budget: int | None = None,
context_overflow: str | None = None,
) -> tuple[str, dict]:
"""
Constructs and executes a datalog query to create a new session in the database.
Expand All @@ -33,6 +35,8 @@ def create_session_query(
- situation (str | None): The situation/context of the session.
- metadata (dict): Additional metadata for the session.
- render_templates (bool): Specifies whether to render templates.
- token_budget (int | None): Token count threshold to consider it as a context window overflow
- context_overflow (str | None): Action to take on context window overflow

Returns:
- pd.DataFrame: The result of the query execution.
Expand All @@ -57,12 +61,14 @@ def create_session_query(
}
} {
# Insert the new session data into the 'session' table with the specified columns.
?[session_id, developer_id, situation, metadata, render_templates] <- [[
?[session_id, developer_id, situation, metadata, render_templates, token_budget, context_overflow] <- [[
$session_id,
$developer_id,
$situation,
$metadata,
$render_templates,
$token_budget,
$context_overflow,
]]

:insert sessions {
Expand All @@ -71,6 +77,8 @@ def create_session_query(
situation,
metadata,
render_templates,
token_budget,
context_overflow,
}
# Specify the data to return after the query execution, typically the newly created session's ID.
:returning
Expand All @@ -87,5 +95,7 @@ def create_session_query(
"situation": situation,
"metadata": metadata,
"render_templates": render_templates,
"token_budget": token_budget,
"context_overflow": context_overflow,
},
)
4 changes: 4 additions & 0 deletions agents-api/agents_api/models/session/get_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ def get_session_query(
created_at,
metadata,
render_templates,
token_budget,
context_overflow,
] := input[developer_id, id],
*sessions{
developer_id,
Expand All @@ -54,6 +56,8 @@ def get_session_query(
updated_at: validity,
metadata,
render_templates,
token_budget,
context_overflow,
@ "NOW"
},
*session_lookup{
Expand Down
4 changes: 4 additions & 0 deletions agents-api/agents_api/models/session/list_sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ def list_sessions_query(
updated_at,
created_at,
metadata,
token_budget,
context_overflow,
] :=
input[developer_id],
*sessions{{
Expand All @@ -60,6 +62,8 @@ def list_sessions_query(
created_at,
updated_at: validity,
metadata,
token_budget,
context_overflow,
@ "NOW"
}},
*session_lookup{{
Expand Down
8 changes: 8 additions & 0 deletions agents-api/agents_api/models/session/session_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ def session_data_query(
default_settings,
metadata,
render_templates,
token_budget,
context_overflow,
user_metadata,
agent_metadata,
] := input[developer_id, session_id],
Expand All @@ -59,6 +61,8 @@ def session_data_query(
updated_at: validity,
metadata,
render_templates,
token_budget,
context_overflow,
@ "NOW"
},
*session_lookup{
Expand Down Expand Up @@ -116,6 +120,8 @@ def session_data_query(
default_settings,
metadata,
render_templates,
token_budget,
context_overflow,
user_metadata,
agent_metadata,
] := input[developer_id, session_id],
Expand All @@ -128,6 +134,8 @@ def session_data_query(
updated_at: validity,
metadata,
render_templates,
token_budget,
context_overflow,
@ "NOW"
},
*session_lookup{
Expand Down
7 changes: 5 additions & 2 deletions agents-api/agents_api/models/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from functools import wraps
from typing import Callable
from typing import Callable, ParamSpec

import pandas as pd

from ..clients.cozo import client as cozo_client


def cozo_query(func: Callable[..., tuple[str, dict]]):
P = ParamSpec("P")


def cozo_query(func: Callable[P, tuple[str, dict]]):
"""
Decorator that wraps a function that takes arbitrary arguments, and
returns a (query string, variables) tuple.
Expand Down
2 changes: 2 additions & 0 deletions agents-api/agents_api/routers/sessions/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ class Settings(BaseModel):
min_p: float | None = Field(default=0.01)
preset: Preset | None = Field(default=None)
tools: list[Tool] | None = Field(default=None)
token_budget: int | None = Field(default=None)
context_overflow: str | None = Field(default=None)

@field_validator("max_tokens")
def set_max_tokens(cls, max_tokens):
Expand Down
6 changes: 6 additions & 0 deletions agents-api/agents_api/routers/sessions/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ async def create_session(
situation=request.situation,
metadata=request.metadata or {},
render_templates=request.render_templates or False,
token_budget=request.token_budget,
context_overflow=request.context_overflow,
)

return ResourceCreatedResponse(
Expand Down Expand Up @@ -151,6 +153,8 @@ async def update_session(
developer_id=x_developer_id,
situation=request.situation,
metadata=request.metadata,
token_budget=request.token_budget,
context_overflow=request.context_overflow,
)

return ResourceUpdatedResponse(
Expand Down Expand Up @@ -182,6 +186,8 @@ async def patch_session(
developer_id=x_developer_id,
situation=request.situation,
metadata=request.metadata,
token_budget=request.token_budget,
context_overflow=request.context_overflow,
)

return ResourceUpdatedResponse(
Expand Down
56 changes: 49 additions & 7 deletions agents-api/agents_api/routers/sessions/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import litellm
from litellm import acompletion

from ...autogen.openapi_model import InputChatMLMessage, Tool, DocIds
from ...autogen.openapi_model import InputChatMLMessage, Tool, DocIds, Role
from ...clients.embed import embed
from ...clients.temporal import run_summarization_task
from ...clients.worker.types import ChatML
Expand All @@ -23,7 +23,6 @@
from ...common.utils.json import CustomJSONEncoder
from ...common.utils.messages import stringify_content
from ...env import (
summarization_tokens_threshold,
docs_embedding_service_url,
docs_embedding_model_id,
)
Expand Down Expand Up @@ -116,15 +115,18 @@ def _remove_messages(

return result, token_count

def truncate(
self, messages: list[Entry], summarization_tokens_threshold: int
def _truncate_context(
self, messages: list[Entry], summarization_tokens_threshold: int | None
) -> list[Entry]:
def rm_thoughts(m):
return m.role == "system" and m.name == "thought"

def rm_user_assistant(m):
return m.role in ("user", "assistant")

if summarization_tokens_threshold is None:
return messages

token_count = reduce(lambda c, e: (e.token_count or 0) + c, messages, 0)

if token_count <= summarization_tokens_threshold:
Expand Down Expand Up @@ -153,6 +155,31 @@ def rm_user_assistant(m):

raise InputTooBigError(token_count, summarization_tokens_threshold)

def _truncate_entries(
self, messages: list[Entry], token_count_threshold: int
) -> list[Entry]:
if not len(messages):
return messages

result: list[Entry] = []
token_cnt, offset = 0, 0
if messages[0].role == Role.system:
result: list[Entry] = messages[0]
token_cnt, offset = messages[0].token_count, 1

for m in reversed(messages[offset:]):
if token_cnt < token_count_threshold:
result.append(m)
else:
break

token_cnt += m.token_count

if offset:
result.append(messages[0])

return list(reversed(result))

async def run(
self, new_input, settings: Settings
) -> tuple[ChatCompletion, Entry, Callable | None, DocIds]:
Expand All @@ -170,7 +197,8 @@ async def run(

# Generate response
response = await self.generate(
self.truncate(init_context, summarization_tokens_threshold), final_settings
self._truncate_context(init_context, final_settings.token_budget),
final_settings,
)

# Save response to session
Expand Down Expand Up @@ -219,6 +247,10 @@ async def forward(
new_input: list[Entry],
settings: Settings,
) -> tuple[list[ChatML], Settings, DocIds]:
if session_data is not None:
settings.token_budget = session_data.token_budget
settings.context_overflow = session_data.context_overflow

stringified_input = []
for msg in new_input:
stringified_input.append(
Expand Down Expand Up @@ -452,10 +484,20 @@ async def backward(
)

entries.append(new_entry)
summarization_task = None

if (
final_settings.token_budget is not None
and total_tokens >= final_settings.token_budget
):
if final_settings.context_overflow == "truncate":
entries = self._truncate_entries(entries, final_settings.token_budget)
elif final_settings.context_overflow == "adaptive":
summarization_task = run_summarization_task

add_entries_query(entries)

if total_tokens >= summarization_tokens_threshold:
return run_summarization_task
return summarization_task


class PlainCompletionSession(BaseSession):
Expand Down
Loading
Loading