Add support for larger archival memory stores #359

Merged · 41 commits · Nov 9, 2023

Commits
89cf976  mark depricated API section (sarahwooders, Oct 30, 2023)
be6212c  add readme (sarahwooders, Oct 31, 2023)
b011380  add readme (sarahwooders, Oct 31, 2023)
59f7b71  add readme (sarahwooders, Oct 31, 2023)
176538b  add readme (sarahwooders, Oct 31, 2023)
9905266  add readme (sarahwooders, Oct 31, 2023)
3606959  add readme (sarahwooders, Oct 31, 2023)
c48803c  add readme (sarahwooders, Oct 31, 2023)
40cdb23  add readme (sarahwooders, Oct 31, 2023)
ff43c98  add readme (sarahwooders, Oct 31, 2023)
01db319  CLI bug fixes for azure (sarahwooders, Oct 31, 2023)
a11cef9  check azure before running (sarahwooders, Oct 31, 2023)
a47d49e  Merge branch 'cpacker:main' into main (sarahwooders, Oct 31, 2023)
fbe2482  Update README.md (sarahwooders, Oct 31, 2023)
446a1a1  Update README.md (sarahwooders, Oct 31, 2023)
1541482  bug fix with persona loading (sarahwooders, Oct 31, 2023)
5776e30  Merge branch 'main' of github.com:sarahwooders/MemGPT (sarahwooders, Oct 31, 2023)
d48cf23  Merge branch 'cpacker:main' into main (sarahwooders, Oct 31, 2023)
7a8eb80  remove print (sarahwooders, Oct 31, 2023)
9a5ece0  Merge branch 'main' of github.com:sarahwooders/MemGPT (sarahwooders, Oct 31, 2023)
d3370b3  merge (sarahwooders, Nov 3, 2023)
c19c2ce  Merge branch 'cpacker:main' into main (sarahwooders, Nov 3, 2023)
aa6ee71  Merge branch 'cpacker:main' into main (sarahwooders, Nov 3, 2023)
36bb04d  make errors for cli flags more clear (sarahwooders, Nov 3, 2023)
6f50db1  format (sarahwooders, Nov 3, 2023)
4c91a41  Merge branch 'cpacker:main' into main (sarahwooders, Nov 3, 2023)
dbaf4a0  Merge branch 'cpacker:main' into main (sarahwooders, Nov 5, 2023)
c86e1c9  fix imports (sarahwooders, Nov 5, 2023)
e54e762  Merge branch 'cpacker:main' into main (sarahwooders, Nov 5, 2023)
524a974  Merge branch 'main' of github.com:sarahwooders/MemGPT (sarahwooders, Nov 5, 2023)
7baf3e7  fix imports (sarahwooders, Nov 5, 2023)
2fd8795  Merge branch 'main' of github.com:sarahwooders/MemGPT (sarahwooders, Nov 5, 2023)
4ab4f2d  add prints (sarahwooders, Nov 5, 2023)
cc94b4e  Merge branch 'main' of github.com:sarahwooders/MemGPT (sarahwooders, Nov 6, 2023)
9d1707d  Merge branch 'main' of github.com:sarahwooders/MemGPT (sarahwooders, Nov 7, 2023)
1782bb9  Merge branch 'main' of github.com:sarahwooders/MemGPT (sarahwooders, Nov 7, 2023)
caaf476  Merge branch 'main' of github.com:sarahwooders/MemGPT (sarahwooders, Nov 7, 2023)
6692bca  update lock (sarahwooders, Nov 7, 2023)
7cc1b9f  Merge branch 'cpacker:main' into main (sarahwooders, Nov 7, 2023)
5289cb0  update files (sarahwooders, Nov 7, 2023)
6de31e3  commit (sarahwooders, Nov 7, 2023)
Files changed
memgpt/agent.py (0 additions, 3 deletions)

@@ -771,11 +771,8 @@ def edit_memory(self, name, content):
         return None

     def edit_memory_append(self, name, content):
-        print("edit append")
         new_len = self.memory.edit_append(name, content)
-        print("rebuild memory")
         self.rebuild_memory()
-        print("done")
         return None

     def edit_memory_replace(self, name, old_content, new_content):
memgpt/cli/cli.py (11 additions, 6 deletions)

@@ -165,21 +165,26 @@ def attach(
 ):
     # loads the data contained in data source into the agent's memory
     from memgpt.connectors.storage import StorageConnector
+    from tqdm import tqdm

     agent_config = AgentConfig.load(agent)
     config = MemGPTConfig.load()

     # get storage connectors
     source_storage = StorageConnector.get_storage_connector(name=data_source)
     dest_storage = StorageConnector.get_storage_connector(agent_config=agent_config)

-    passages = source_storage.get_all()
-    for p in passages:
-        len(p.embedding) == config.embedding_dim, f"Mismatched embedding sizes {len(p.embedding)} != {config.embedding_dim}"
-    dest_storage.insert_many(passages)
+    size = source_storage.size()
+    typer.secho(f"Ingesting {size} passages into {agent_config.name}", fg=typer.colors.GREEN)
+    page_size = 100
+    generator = source_storage.get_all_paginated(page_size=page_size)  # yields List[Passage]
+    for i in tqdm(range(0, size, page_size)):
+        passages = next(generator)
+        dest_storage.insert_many(passages, show_progress=False)

     # save destination storage
     dest_storage.save()

-    total_agent_passages = len(dest_storage.get_all())
+    total_agent_passages = dest_storage.size()

     typer.secho(
         f"Attached data source {data_source} to agent {agent}, consisting of {len(passages)}. Agent now has {total_agent_passages} embeddings in archival memory.",
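For context, the new attach flow streams passages in fixed-size pages rather than loading the whole source store into memory at once. Below is a minimal, self-contained sketch of that pattern; `FakeConnector` and its string items are hypothetical stand-ins for a real `StorageConnector` and its `Passage` objects, not code from this PR.

```python
from typing import Iterator, List

class FakeConnector:
    """Hypothetical stand-in for a StorageConnector with the new paginated API."""

    def __init__(self, items: List[str]):
        self.items = items

    def size(self) -> int:
        return len(self.items)

    def get_all_paginated(self, page_size: int) -> Iterator[List[str]]:
        # Yield fixed-size pages instead of returning one giant list
        for i in range(0, len(self.items), page_size):
            yield self.items[i : i + page_size]

source = FakeConnector([f"passage-{i}" for i in range(250)])
page_size = 100
generator = source.get_all_paginated(page_size=page_size)

copied: List[str] = []
# Same shape as the attach loop above: one next() per page,
# ceil(size / page_size) iterations in total
for _ in range(0, source.size(), page_size):
    copied.extend(next(generator))

assert len(copied) == source.size()
```

At any moment the loop holds at most one page of passages, which is what makes the attach command workable for large archival stores.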
memgpt/connectors/db.py (25 additions, 3 deletions)

@@ -10,7 +10,7 @@

 import re
 from tqdm import tqdm
-from typing import Optional, List
+from typing import Optional, List, Iterator
 import numpy as np
 from tqdm import tqdm

@@ -76,9 +76,26 @@ def __init__(self, name: Optional[str] = None, agent_config: Optional[AgentConfig] = None):
         self.Session = sessionmaker(bind=self.engine)
         self.Session().execute(text("CREATE EXTENSION IF NOT EXISTS vector"))  # Enables the vector extension

-    def get_all(self) -> List[Passage]:
+    def get_all_paginated(self, page_size: int) -> Iterator[List[Passage]]:
         session = self.Session()
-        db_passages = session.query(self.db_model).all()
+        offset = 0
+        while True:
+            # Retrieve a chunk of records with the given page_size
+            db_passages_chunk = session.query(self.db_model).offset(offset).limit(page_size).all()
+
+            # If the chunk is empty, we've retrieved all records
+            if not db_passages_chunk:
+                break
+
+            # Yield a list of Passage objects converted from the chunk
+            yield [Passage(text=p.text, embedding=p.embedding, doc_id=p.doc_id, passage_id=p.id) for p in db_passages_chunk]
+
+            # Increment the offset to get the next chunk in the next iteration
+            offset += page_size
+
+    def get_all(self, limit=10) -> List[Passage]:
+        session = self.Session()
+        db_passages = session.query(self.db_model).limit(limit).all()
         return [Passage(text=p.text, embedding=p.embedding, doc_id=p.doc_id, passage_id=p.id) for p in db_passages]

@@ -88,6 +105,11 @@ def get(self, id: str) -> Optional[Passage]:
             return None
         return Passage(text=db_passage.text, embedding=db_passage.embedding, doc_id=db_passage.doc_id, passage_id=db_passage.passage_id)

+    def size(self) -> int:
+        # return size of table
+        session = self.Session()
+        return session.query(self.db_model).count()
+
     def insert(self, passage: Passage):
         session = self.Session()
         db_passage = self.db_model(doc_id=passage.doc_id, text=passage.text, embedding=passage.embedding)
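The database connector pages with OFFSET/LIMIT queries and counts rows in the database via `session.query(...).count()`. Here is a standalone sketch of the same two patterns against an in-memory SQLite table; this is illustrative only (the PR targets a pgvector-backed Postgres table), the `Row` model is made up, and it assumes SQLAlchemy 1.4+ for the `declarative_base` import location.

```python
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import declarative_base, sessionmaker

Base = declarative_base()

class Row(Base):
    __tablename__ = "passages"
    id = Column(Integer, primary_key=True)
    text = Column(String)

engine = create_engine("sqlite://")  # in-memory database
Base.metadata.create_all(engine)
session = sessionmaker(bind=engine)()

session.add_all([Row(text=f"passage {i}") for i in range(250)])
session.commit()

# size(): a COUNT(*) executed in the database, no rows transferred
assert session.query(Row).count() == 250

# get_all_paginated(): OFFSET/LIMIT chunks, at most page_size rows held at once
offset, page_size, seen = 0, 100, 0
while True:
    chunk = session.query(Row).offset(offset).limit(page_size).all()
    if not chunk:
        break
    seen += len(chunk)
    offset += page_size
assert seen == 250
```

One caveat: OFFSET pagination re-scans skipped rows on every page, so very large tables might eventually prefer keyset pagination; the offset form keeps the connector code simple.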
memgpt/connectors/local.py (13 additions, 2 deletions)

@@ -1,4 +1,4 @@
-from typing import Optional, List
+from typing import Optional, List, Iterator
 from memgpt.config import AgentConfig, MemGPTConfig
 from tqdm import tqdm
 import re

@@ -72,11 +72,19 @@ def add_nodes(self, nodes: List[TextNode]):
         self.nodes += nodes
         self.index = VectorStoreIndex(self.nodes)

-    def get_all(self) -> List[Passage]:
+    def get_all_paginated(self, page_size: int = 100) -> Iterator[List[Passage]]:
         """Get all passages in the index"""
+        nodes = self.get_nodes()
+        for i in tqdm(range(0, len(nodes), page_size)):
+            yield [Passage(text=node.text, embedding=node.embedding) for node in nodes[i : i + page_size]]
+
+    def get_all(self, limit: int) -> List[Passage]:
         passages = []
         for node in self.get_nodes():
             assert node.embedding is not None, f"Node embedding is None"
             passages.append(Passage(text=node.text, embedding=node.embedding))
+            if len(passages) >= limit:
+                break
         return passages

     def get(self, id: str) -> Passage:

@@ -126,3 +134,6 @@ def list_loaded_data():
             name = os.path.basename(data_source_file)
             sources.append(name)
         return sources
+
+    def size(self):
+        return len(self.get_nodes())
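The local connector achieves the same paging by slicing its in-memory node list, with `tqdm` wrapping the range for a progress bar. A toy version of that generator, with plain integers standing in for LlamaIndex nodes:

```python
from tqdm import tqdm

def get_pages(nodes, page_size=100):
    # Slice the in-memory node list into fixed-size pages;
    # the tqdm-wrapped range advances as pages are consumed
    for i in tqdm(range(0, len(nodes), page_size)):
        yield nodes[i : i + page_size]

pages = list(get_pages(list(range(250)), page_size=100))
assert [len(p) for p in pages] == [100, 100, 50]
```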
memgpt/connectors/storage.py (11 additions, 2 deletions)

@@ -2,7 +2,7 @@

 We originally tried to use Llama Index VectorIndex, but their limited API was extremely problematic.
 """
-from typing import Optional, List
+from typing import Optional, List, Iterator
 import re
 import pickle
 import os

@@ -66,7 +66,11 @@ def list_loaded_data():
         raise NotImplementedError(f"Storage type {storage_type} not implemented")

     @abstractmethod
-    def get_all(self) -> List[Passage]:
+    def get_all_paginated(self, page_size: int) -> Iterator[List[Passage]]:
+        pass
+
+    @abstractmethod
+    def get_all(self, limit: int) -> List[Passage]:
         pass

     @abstractmethod

@@ -89,3 +93,8 @@ def query(self, query: str, query_vec: List[float], top_k: int = 10) -> List[Passage]:
     def save(self):
         """Save state of storage connector"""
         pass

+    @abstractmethod
+    def size(self):
+        """Get number of passages (text/embedding pairs) in storage"""
+        pass
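Taken together, the abstract base now obliges every backend to provide `get_all_paginated`, a bounded `get_all`, and `size`. The following hypothetical sketch shows the minimal surface a new connector must satisfy; `ConnectorBase` and `ListConnector` are illustrations, not MemGPT classes, and real implementations live in db.py and local.py.

```python
from abc import ABC, abstractmethod
from typing import Iterator, List

class ConnectorBase(ABC):
    # Mirrors the three methods this PR adds/changes on StorageConnector
    @abstractmethod
    def get_all_paginated(self, page_size: int) -> Iterator[List[str]]: ...

    @abstractmethod
    def get_all(self, limit: int) -> List[str]: ...

    @abstractmethod
    def size(self) -> int: ...

class ListConnector(ConnectorBase):
    def __init__(self, items: List[str]):
        self.items = items

    def get_all_paginated(self, page_size: int) -> Iterator[List[str]]:
        for i in range(0, len(self.items), page_size):
            yield self.items[i : i + page_size]

    def get_all(self, limit: int) -> List[str]:
        return self.items[:limit]

    def size(self) -> int:
        return len(self.items)

conn = ListConnector(["a", "b", "c"])
assert conn.size() == 3 and conn.get_all(limit=2) == ["a", "b"]
```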
memgpt/memory.py (2 additions, 3 deletions)

@@ -818,11 +818,10 @@ async def a_insert(self, memory_string):
     def __repr__(self) -> str:
         limit = 10
         passages = []
-        for passage in list(self.storage.get_all())[:limit]:  # TODO: only get first 10
+        for passage in list(self.storage.get_all(limit)):  # TODO: only get first 10
             passages.append(str(passage.text))
         memory_str = "\n".join(passages)
         return f"\n### ARCHIVAL MEMORY ###" + f"\n{memory_str}"

     def __len__(self):
-        print("get archival storage size")
-        return len(self.storage.get_all())
+        return self.storage.size()
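The effect of the `__len__` change: `len(archival_memory)` now issues a single count in the storage backend instead of fetching every passage just to measure the list. A toy illustration of the delegation (these classes are stand-ins, not MemGPT's):

```python
class CountingStorage:
    """Pretend backend where size() is cheap (e.g. a SQL COUNT(*))."""

    def __init__(self, n: int):
        self.n = n

    def size(self) -> int:
        return self.n

class ArchivalMemorySketch:
    def __init__(self, storage: CountingStorage):
        self.storage = storage

    def __len__(self) -> int:
        # Delegate to the backend count rather than len(get_all())
        return self.storage.size()

assert len(ArchivalMemorySketch(CountingStorage(1_000_000))) == 1_000_000
```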