Skip to content

Commit

Permalink
feat: Truncate context window based on session settings (#381)
Browse files Browse the repository at this point in the history
  • Loading branch information
whiterabbit1983 committed Jun 3, 2024
1 parent e9de8db commit 9998986
Show file tree
Hide file tree
Showing 13 changed files with 215 additions and 19 deletions.
34 changes: 33 additions & 1 deletion agents-api/agents_api/autogen/openapi_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# generated by datamodel-codegen:
# filename: openapi.yaml
# timestamp: 2024-05-28T03:07:50+00:00
# timestamp: 2024-06-01T12:26:54+00:00

from __future__ import annotations

Expand Down Expand Up @@ -124,6 +124,14 @@ class Session(BaseModel):
"""
Render system and assistant message content as jinja templates
"""
token_budget: int | None = None
"""
Threshold value for the adaptive context functionality
"""
context_overflow: str | None = None
"""
Action to start on context window overflow
"""


class CreateSessionRequest(BaseModel):
Expand Down Expand Up @@ -151,6 +159,14 @@ class CreateSessionRequest(BaseModel):
"""
Render system and assistant message content as jinja templates
"""
token_budget: int | None = None
"""
Threshold value for the adaptive context functionality
"""
context_overflow: str | None = None
"""
Action to start on context window overflow
"""


class UpdateSessionRequest(BaseModel):
Expand All @@ -166,6 +182,14 @@ class UpdateSessionRequest(BaseModel):
"""
Optional metadata
"""
token_budget: int | None = None
"""
Threshold value for the adaptive context functionality
"""
context_overflow: str | None = None
"""
Action to start on context window overflow
"""


class UpdateUserRequest(BaseModel):
Expand Down Expand Up @@ -753,6 +777,14 @@ class PatchSessionRequest(BaseModel):
"""
Optional metadata
"""
token_budget: int | None = None
"""
Threshold value for the adaptive context functionality
"""
context_overflow: str | None = None
"""
Action to start on context window overflow
"""


class PartialFunctionDef(BaseModel):
Expand Down
2 changes: 2 additions & 0 deletions agents-api/agents_api/common/protocol/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,5 @@ class SessionData(BaseModel):
metadata: Dict = {}
user_metadata: Optional[Dict] = None
agent_metadata: Dict = {}
token_budget: int | None = None
context_overflow: str | None = None
8 changes: 0 additions & 8 deletions agents-api/agents_api/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,6 @@
model_api_key: str = env.str("MODEL_API_KEY", default=None)
model_inference_url: str = env.str("MODEL_INFERENCE_URL", default=None)
openai_api_key: str = env.str("OPENAI_API_KEY", default="")
summarization_ratio_threshold: float = env.float(
"MAX_TOKENS_RATIO_TO_SUMMARIZE", default=0.5
)
summarization_tokens_threshold: int = env.int(
"SUMMARIZATION_TOKENS_THRESHOLD", default=2048
)
summarization_model_name: str = env.str(
"SUMMARIZATION_MODEL_NAME", default="gpt-4-turbo"
)
Expand Down Expand Up @@ -78,8 +72,6 @@
debug=debug,
cozo_host=cozo_host,
cozo_auth=cozo_auth,
summarization_ratio_threshold=summarization_ratio_threshold,
summarization_tokens_threshold=summarization_tokens_threshold,
worker_url=worker_url,
sentry_dsn=sentry_dsn,
temporal_endpoint=temporal_endpoint,
Expand Down
12 changes: 11 additions & 1 deletion agents-api/agents_api/models/session/create_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ def create_session_query(
situation: str | None,
metadata: dict = {},
render_templates: bool = False,
token_budget: int | None = None,
context_overflow: str | None = None,
) -> tuple[str, dict]:
"""
Constructs and executes a datalog query to create a new session in the database.
Expand All @@ -33,6 +35,8 @@ def create_session_query(
- situation (str | None): The situation/context of the session.
- metadata (dict): Additional metadata for the session.
- render_templates (bool): Specifies whether to render templates.
- token_budget (int | None): Token count threshold to consider it as a context window overflow
- context_overflow (str | None): Action to take on context window overflow
Returns:
- pd.DataFrame: The result of the query execution.
Expand All @@ -57,12 +61,14 @@ def create_session_query(
}
} {
# Insert the new session data into the 'session' table with the specified columns.
?[session_id, developer_id, situation, metadata, render_templates] <- [[
?[session_id, developer_id, situation, metadata, render_templates, token_budget, context_overflow] <- [[
$session_id,
$developer_id,
$situation,
$metadata,
$render_templates,
$token_budget,
$context_overflow,
]]
:insert sessions {
Expand All @@ -71,6 +77,8 @@ def create_session_query(
situation,
metadata,
render_templates,
token_budget,
context_overflow,
}
# Specify the data to return after the query execution, typically the newly created session's ID.
:returning
Expand All @@ -87,5 +95,7 @@ def create_session_query(
"situation": situation,
"metadata": metadata,
"render_templates": render_templates,
"token_budget": token_budget,
"context_overflow": context_overflow,
},
)
4 changes: 4 additions & 0 deletions agents-api/agents_api/models/session/get_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ def get_session_query(
created_at,
metadata,
render_templates,
token_budget,
context_overflow,
] := input[developer_id, id],
*sessions{
developer_id,
Expand All @@ -54,6 +56,8 @@ def get_session_query(
updated_at: validity,
metadata,
render_templates,
token_budget,
context_overflow,
@ "NOW"
},
*session_lookup{
Expand Down
4 changes: 4 additions & 0 deletions agents-api/agents_api/models/session/list_sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ def list_sessions_query(
updated_at,
created_at,
metadata,
token_budget,
context_overflow,
] :=
input[developer_id],
*sessions{{
Expand All @@ -60,6 +62,8 @@ def list_sessions_query(
created_at,
updated_at: validity,
metadata,
token_budget,
context_overflow,
@ "NOW"
}},
*session_lookup{{
Expand Down
8 changes: 8 additions & 0 deletions agents-api/agents_api/models/session/session_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ def session_data_query(
default_settings,
metadata,
render_templates,
token_budget,
context_overflow,
user_metadata,
agent_metadata,
] := input[developer_id, session_id],
Expand All @@ -59,6 +61,8 @@ def session_data_query(
updated_at: validity,
metadata,
render_templates,
token_budget,
context_overflow,
@ "NOW"
},
*session_lookup{
Expand Down Expand Up @@ -116,6 +120,8 @@ def session_data_query(
default_settings,
metadata,
render_templates,
token_budget,
context_overflow,
user_metadata,
agent_metadata,
] := input[developer_id, session_id],
Expand All @@ -128,6 +134,8 @@ def session_data_query(
updated_at: validity,
metadata,
render_templates,
token_budget,
context_overflow,
@ "NOW"
},
*session_lookup{
Expand Down
7 changes: 5 additions & 2 deletions agents-api/agents_api/models/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from functools import wraps
from typing import Callable
from typing import Callable, ParamSpec

import pandas as pd

from ..clients.cozo import client as cozo_client


def cozo_query(func: Callable[..., tuple[str, dict]]):
P = ParamSpec("P")


def cozo_query(func: Callable[P, tuple[str, dict]]):
"""
Decorator that wraps a function that takes arbitrary arguments, and
returns a (query string, variables) tuple.
Expand Down
2 changes: 2 additions & 0 deletions agents-api/agents_api/routers/sessions/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ class Settings(BaseModel):
min_p: float | None = Field(default=0.01)
preset: Preset | None = Field(default=None)
tools: list[Tool] | None = Field(default=None)
token_budget: int | None = Field(default=None)
context_overflow: str | None = Field(default=None)

@field_validator("max_tokens")
def set_max_tokens(cls, max_tokens):
Expand Down
6 changes: 6 additions & 0 deletions agents-api/agents_api/routers/sessions/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ async def create_session(
situation=request.situation,
metadata=request.metadata or {},
render_templates=request.render_templates or False,
token_budget=request.token_budget,
context_overflow=request.context_overflow,
)

return ResourceCreatedResponse(
Expand Down Expand Up @@ -151,6 +153,8 @@ async def update_session(
developer_id=x_developer_id,
situation=request.situation,
metadata=request.metadata,
token_budget=request.token_budget,
context_overflow=request.context_overflow,
)

return ResourceUpdatedResponse(
Expand Down Expand Up @@ -182,6 +186,8 @@ async def patch_session(
developer_id=x_developer_id,
situation=request.situation,
metadata=request.metadata,
token_budget=request.token_budget,
context_overflow=request.context_overflow,
)

return ResourceUpdatedResponse(
Expand Down
56 changes: 49 additions & 7 deletions agents-api/agents_api/routers/sessions/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import litellm
from litellm import acompletion

from ...autogen.openapi_model import InputChatMLMessage, Tool, DocIds
from ...autogen.openapi_model import InputChatMLMessage, Tool, DocIds, Role
from ...clients.embed import embed
from ...clients.temporal import run_summarization_task
from ...clients.worker.types import ChatML
Expand All @@ -23,7 +23,6 @@
from ...common.utils.json import CustomJSONEncoder
from ...common.utils.messages import stringify_content
from ...env import (
summarization_tokens_threshold,
docs_embedding_service_url,
docs_embedding_model_id,
)
Expand Down Expand Up @@ -116,15 +115,18 @@ def _remove_messages(

return result, token_count

def truncate(
self, messages: list[Entry], summarization_tokens_threshold: int
def _truncate_context(
self, messages: list[Entry], summarization_tokens_threshold: int | None
) -> list[Entry]:
def rm_thoughts(m):
return m.role == "system" and m.name == "thought"

def rm_user_assistant(m):
return m.role in ("user", "assistant")

if summarization_tokens_threshold is None:
return messages

token_count = reduce(lambda c, e: (e.token_count or 0) + c, messages, 0)

if token_count <= summarization_tokens_threshold:
Expand Down Expand Up @@ -153,6 +155,31 @@ def rm_user_assistant(m):

raise InputTooBigError(token_count, summarization_tokens_threshold)

def _truncate_entries(
self, messages: list[Entry], token_count_threshold: int
) -> list[Entry]:
if not len(messages):
return messages

result: list[Entry] = []
token_cnt, offset = 0, 0
if messages[0].role == Role.system:
result: list[Entry] = messages[0]
token_cnt, offset = messages[0].token_count, 1

for m in reversed(messages[offset:]):
if token_cnt < token_count_threshold:
result.append(m)
else:
break

token_cnt += m.token_count

if offset:
result.append(messages[0])

return list(reversed(result))

async def run(
self, new_input, settings: Settings
) -> tuple[ChatCompletion, Entry, Callable | None, DocIds]:
Expand All @@ -170,7 +197,8 @@ async def run(

# Generate response
response = await self.generate(
self.truncate(init_context, summarization_tokens_threshold), final_settings
self._truncate_context(init_context, final_settings.token_budget),
final_settings,
)

# Save response to session
Expand Down Expand Up @@ -219,6 +247,10 @@ async def forward(
new_input: list[Entry],
settings: Settings,
) -> tuple[list[ChatML], Settings, DocIds]:
if session_data is not None:
settings.token_budget = session_data.token_budget
settings.context_overflow = session_data.context_overflow

stringified_input = []
for msg in new_input:
stringified_input.append(
Expand Down Expand Up @@ -452,10 +484,20 @@ async def backward(
)

entries.append(new_entry)
summarization_task = None

if (
final_settings.token_budget is not None
and total_tokens >= final_settings.token_budget
):
if final_settings.context_overflow == "truncate":
entries = self._truncate_entries(entries, final_settings.token_budget)
elif final_settings.context_overflow == "adaptive":
summarization_task = run_summarization_task

add_entries_query(entries)

if total_tokens >= summarization_tokens_threshold:
return run_summarization_task
return summarization_task


class PlainCompletionSession(BaseSession):
Expand Down
Loading

0 comments on commit 9998986

Please sign in to comment.