Fix #487 (summarize call uses OpenAI even with local LLM config) #488

Merged: 2 commits, Nov 19, 2023

memgpt/agent.py (4 changes: 1 addition, 3 deletions)

```diff
@@ -716,9 +716,7 @@ def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True):
             if (self.model is not None and self.model in LLM_MAX_TOKENS)
             else str(LLM_MAX_TOKENS["DEFAULT"])
         )
-        summary = summarize_messages(
-            model=self.model, context_window=int(self.config.context_window), message_sequence_to_summarize=message_sequence_to_summarize
-        )
+        summary = summarize_messages(agent_config=self.config, message_sequence_to_summarize=message_sequence_to_summarize)
         printd(f"Got summary: {summary}")

         # Metadata that's useful for the agent to see
```
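
For context on what the new signature enables: with the full `AgentConfig` threaded through, a local-endpoint configuration now reaches the summarizer instead of being dropped. A minimal sketch, assuming the `AgentConfig` constructor accepts these fields as keyword arguments (the attribute names appear in this diff; the values are illustrative, not repo defaults):

```python
# Sketch only: constructor kwargs and field values are assumptions, not repo defaults.
from memgpt.config import AgentConfig
from memgpt.memory import summarize_messages

config = AgentConfig(
    model="mistral-7b-instruct",             # assumed local model name
    model_endpoint_type="webui",             # any non-"openai" backend
    model_endpoint="http://localhost:5000",  # assumed local server URL
    context_window=8192,
)

# Before this PR the summarizer received only model/context_window and always
# built an OpenAI completion; now it inherits the endpoint from the config.
summary = summarize_messages(
    agent_config=config,
    message_sequence_to_summarize=[{"role": "user", "content": "hello"}],
)
```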

memgpt/memory.py (24 changes: 11 additions, 13 deletions)

```diff
@@ -4,14 +4,11 @@
 import re
 from typing import Optional, List, Tuple

-from .constants import MESSAGE_SUMMARY_WARNING_FRAC, MEMGPT_DIR
-from .utils import cosine_similarity, get_local_time, printd, count_tokens
-from .prompts.gpt_summarize import SYSTEM as SUMMARY_PROMPT_SYSTEM
+from memgpt.constants import MESSAGE_SUMMARY_WARNING_FRAC, MEMGPT_DIR
+from memgpt.utils import cosine_similarity, get_local_time, printd, count_tokens
+from memgpt.prompts.gpt_summarize import SYSTEM as SUMMARY_PROMPT_SYSTEM
 from memgpt import utils
-from .openai_tools import (
-    get_embedding_with_backoff,
-    completions_with_backoff as create,
-)
+from memgpt.openai_tools import get_embedding_with_backoff, chat_completion_with_backoff
 from llama_index import (
     VectorStoreIndex,
     EmptyIndex,
```
```diff
@@ -119,11 +116,12 @@ def edit_replace(self, field, old_content, new_content):


 def summarize_messages(
-    model,
-    context_window,
+    agent_config,
     message_sequence_to_summarize,
 ):
     """Summarize a message sequence using GPT"""
+    # we need the context_window
+    context_window = agent_config.context_window

     summary_prompt = SUMMARY_PROMPT_SYSTEM
     summary_input = str(message_sequence_to_summarize)
```
```diff
@@ -132,17 +130,17 @@
         trunc_ratio = (MESSAGE_SUMMARY_WARNING_FRAC * context_window / summary_input_tkns) * 0.8  # For good measure...
         cutoff = int(len(message_sequence_to_summarize) * trunc_ratio)
         summary_input = str(
-            [summarize_messages(model, context_window, message_sequence_to_summarize[:cutoff])] + message_sequence_to_summarize[cutoff:]
+            [summarize_messages(agent_config, message_sequence_to_summarize[:cutoff])]
+            + message_sequence_to_summarize[cutoff:]
         )
     message_sequence = [
         {"role": "system", "content": summary_prompt},
         {"role": "user", "content": summary_input},
     ]

-    response = create(
-        model=model,
+    response = chat_completion_with_backoff(
+        agent_config=agent_config,
         messages=message_sequence,
         context_window=context_window,
     )

     printd(f"summarize_messages gpt reply: {response.choices[0]}")
```
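
The pre-truncation branch above is the subtle part: when the serialized messages exceed the warning fraction of the context window, the oldest slice is summarized recursively and only that summary is kept alongside the newer messages. A self-contained sketch of the cutoff arithmetic, assuming `MESSAGE_SUMMARY_WARNING_FRAC` is 0.75 (the real value lives in `memgpt/constants.py`):

```python
# Standalone sketch of the cutoff arithmetic in summarize_messages.
# MESSAGE_SUMMARY_WARNING_FRAC = 0.75 is an assumed value for illustration.
MESSAGE_SUMMARY_WARNING_FRAC = 0.75

def head_cutoff(num_messages: int, input_tokens: int, context_window: int) -> int:
    """messages[:cutoff] get recursively summarized; messages[cutoff:] are kept verbatim."""
    budget = MESSAGE_SUMMARY_WARNING_FRAC * context_window
    if input_tokens <= budget:
        return 0  # input already fits; no pre-summarization needed
    trunc_ratio = (budget / input_tokens) * 0.8  # 0.8 = the "for good measure" margin
    return int(num_messages * trunc_ratio)

# 100 messages totalling 12,000 tokens against an 8,192-token window:
# budget = 6,144 -> trunc_ratio = 0.4096 -> pre-summarize the oldest 40 messages.
print(head_cutoff(100, 12_000, 8_192))  # 40
```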

memgpt/openai_tools.py (5 changes: 5 additions, 0 deletions)

```diff
@@ -85,6 +85,11 @@ def chat_completion_with_backoff(agent_config, **kwargs):
     from memgpt.utils import printd
     from memgpt.config import AgentConfig, MemGPTConfig

+    # both "model" and "messages" are required for base OpenAI calls;
+    # "model" is also required by the Ollama local backend, but not by the other local backends
+    if "model" not in kwargs:
+        kwargs["model"] = agent_config.model
+
     printd(f"Using model {agent_config.model_endpoint_type}, endpoint: {agent_config.model_endpoint}")
     if agent_config.model_endpoint_type == "openai":
         # openai
```
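
The guard is small but load-bearing: a caller that passes only `agent_config` previously produced an OpenAI request with no `model` field. A minimal sketch of the same defaulting pattern in isolation, using a stand-in dataclass rather than the real `memgpt.config.AgentConfig`:

```python
# Generic sketch of the kwargs-defaulting guard added in this PR;
# AgentConfig here is a stand-in, not the real memgpt.config.AgentConfig.
from dataclasses import dataclass

@dataclass
class AgentConfig:
    model: str
    model_endpoint_type: str

def chat_completion_with_backoff(agent_config: AgentConfig, **kwargs) -> dict:
    # OpenAI (and the Ollama backend) require "model"; default it from the
    # config so callers that pass only agent_config still build a valid request.
    if "model" not in kwargs:
        kwargs["model"] = agent_config.model
    return kwargs  # the real function goes on to dispatch on model_endpoint_type

cfg = AgentConfig(model="gpt-4", model_endpoint_type="openai")
request = chat_completion_with_backoff(cfg, messages=[{"role": "user", "content": "hi"}])
assert request["model"] == "gpt-4"
```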