diff --git a/memgpt/agent.py b/memgpt/agent.py index 8c5ea60428..7e58741973 100644 --- a/memgpt/agent.py +++ b/memgpt/agent.py @@ -771,11 +771,8 @@ def edit_memory(self, name, content): return None def edit_memory_append(self, name, content): - print("edit append") new_len = self.memory.edit_append(name, content) - print("rebuild memory") self.rebuild_memory() - print("done") return None def edit_memory_replace(self, name, old_content, new_content): diff --git a/memgpt/cli/cli.py b/memgpt/cli/cli.py index 09f5d26e6d..e545986d25 100644 --- a/memgpt/cli/cli.py +++ b/memgpt/cli/cli.py @@ -165,21 +165,26 @@ def attach( ): # loads the data contained in data source into the agent's memory from memgpt.connectors.storage import StorageConnector + from tqdm import tqdm agent_config = AgentConfig.load(agent) - config = MemGPTConfig.load() # get storage connectors source_storage = StorageConnector.get_storage_connector(name=data_source) dest_storage = StorageConnector.get_storage_connector(agent_config=agent_config) - passages = source_storage.get_all() - for p in passages: - len(p.embedding) == config.embedding_dim, f"Mismatched embedding sizes {len(p.embedding)} != {config.embedding_dim}" - dest_storage.insert_many(passages) + size = source_storage.size() + typer.secho(f"Ingesting {size} passages into {agent_config.name}", fg=typer.colors.GREEN) + page_size = 100 + generator = source_storage.get_all_paginated(page_size=page_size) # yields List[Passage] + for i in tqdm(range(0, size, page_size)): + passages = next(generator) + dest_storage.insert_many(passages, show_progress=False) + + # save destination storage dest_storage.save() - total_agent_passages = len(dest_storage.get_all()) + total_agent_passages = dest_storage.size() typer.secho( f"Attached data source {data_source} to agent {agent}, consisting of {len(passages)}. Agent now has {total_agent_passages} embeddings in archival memory.", diff --git a/memgpt/connectors/db.py b/memgpt/connectors/db.py index 2bbaebc4f2..4ea68c511b 100644 --- a/memgpt/connectors/db.py +++ b/memgpt/connectors/db.py @@ -10,7 +10,7 @@ import re from tqdm import tqdm -from typing import Optional, List +from typing import Optional, List, Iterator import numpy as np from tqdm import tqdm @@ -76,9 +76,26 @@ def __init__(self, name: Optional[str] = None, agent_config: Optional[AgentConfi self.Session = sessionmaker(bind=self.engine) self.Session().execute(text("CREATE EXTENSION IF NOT EXISTS vector")) # Enables the vector extension - def get_all(self) -> List[Passage]: + def get_all_paginated(self, page_size: int) -> Iterator[List[Passage]]: session = self.Session() - db_passages = session.query(self.db_model).all() + offset = 0 + while True: + # Retrieve a chunk of records with the given page_size + db_passages_chunk = session.query(self.db_model).offset(offset).limit(page_size).all() + + # If the chunk is empty, we've retrieved all records + if not db_passages_chunk: + break + + # Yield a list of Passage objects converted from the chunk + yield [Passage(text=p.text, embedding=p.embedding, doc_id=p.doc_id, passage_id=p.id) for p in db_passages_chunk] + + # Increment the offset to get the next chunk in the next iteration + offset += page_size + + def get_all(self, limit=10) -> List[Passage]: + session = self.Session() + db_passages = session.query(self.db_model).limit(limit).all() return [Passage(text=p.text, embedding=p.embedding, doc_id=p.doc_id, passage_id=p.id) for p in db_passages] def get(self, id: str) -> Optional[Passage]: @@ -88,6 +105,11 @@ def get(self, id: str) -> Optional[Passage]: return None return Passage(text=db_passage.text, embedding=db_passage.embedding, doc_id=db_passage.doc_id, passage_id=db_passage.passage_id) + def size(self) -> int: + # return size of table + session = self.Session() + return session.query(self.db_model).count() + def insert(self, passage: Passage): session = self.Session() db_passage = self.db_model(doc_id=passage.doc_id, text=passage.text, embedding=passage.embedding) diff --git a/memgpt/connectors/local.py b/memgpt/connectors/local.py index dd77637914..233856c8da 100644 --- a/memgpt/connectors/local.py +++ b/memgpt/connectors/local.py @@ -1,4 +1,4 @@ -from typing import Optional, List +from typing import Optional, List, Iterator from memgpt.config import AgentConfig, MemGPTConfig from tqdm import tqdm import re @@ -72,11 +72,19 @@ def add_nodes(self, nodes: List[TextNode]): self.nodes += nodes self.index = VectorStoreIndex(self.nodes) - def get_all(self) -> List[Passage]: + def get_all_paginated(self, page_size: int = 100) -> Iterator[List[Passage]]: + """Get all passages in the index""" + nodes = self.get_nodes() + for i in tqdm(range(0, len(nodes), page_size)): + yield [Passage(text=node.text, embedding=node.embedding) for node in nodes[i : i + page_size]] + + def get_all(self, limit: int) -> List[Passage]: passages = [] for node in self.get_nodes(): assert node.embedding is not None, f"Node embedding is None" passages.append(Passage(text=node.text, embedding=node.embedding)) + if len(passages) >= limit: + break return passages def get(self, id: str) -> Passage: @@ -126,3 +134,6 @@ def list_loaded_data(): name = os.path.basename(data_source_file) sources.append(name) return sources + + def size(self): + return len(self.get_nodes()) diff --git a/memgpt/connectors/storage.py b/memgpt/connectors/storage.py index 6fd485a79f..21a69cb8e1 100644 --- a/memgpt/connectors/storage.py +++ b/memgpt/connectors/storage.py @@ -2,7 +2,7 @@ We originally tried to use Llama Index VectorIndex, but their limited API was extremely problematic. """ -from typing import Optional, List +from typing import Optional, List, Iterator import re import pickle import os @@ -66,7 +66,11 @@ def list_loaded_data(): raise NotImplementedError(f"Storage type {storage_type} not implemented") @abstractmethod - def get_all(self) -> List[Passage]: + def get_all_paginated(self, page_size: int) -> Iterator[List[Passage]]: + pass + + @abstractmethod + def get_all(self, limit: int) -> List[Passage]: pass @abstractmethod @@ -89,3 +93,8 @@ def query(self, query: str, query_vec: List[float], top_k: int = 10) -> List[Pas def save(self): """Save state of storage connector""" pass + + @abstractmethod + def size(self): + """Get number of passages (text/embedding pairs) in storage""" + pass diff --git a/memgpt/memory.py b/memgpt/memory.py index 126a2c3143..83cf3a6e10 100644 --- a/memgpt/memory.py +++ b/memgpt/memory.py @@ -818,11 +818,10 @@ async def a_insert(self, memory_string): def __repr__(self) -> str: limit = 10 passages = [] - for passage in list(self.storage.get_all())[:limit]: # TODO: only get first 10 + for passage in list(self.storage.get_all(limit)): # TODO: only get first 10 passages.append(str(passage.text)) memory_str = "\n".join(passages) return f"\n### ARCHIVAL MEMORY ###" + f"\n{memory_str}" def __len__(self): - print("get archival storage size") - return len(self.storage.get_all()) + return self.storage.size()