Skip to content

Commit

Permalink
Merge branch 'develop' of github.com:agiletechnologist/MemGPT into de…
Browse files Browse the repository at this point in the history
…velop

* 'develop' of github.com:agiletechnologist/MemGPT:
  Docs: Fix typos (letta-ai#477)
  Lancedb storage integration (letta-ai#455)
  updated websocket protocol and server (letta-ai#473)
  • Loading branch information
agiletechnologist committed Nov 18, 2023
2 parents 859df95 + 22c8a7a commit b39b3f8
Show file tree
Hide file tree
Showing 13 changed files with 468 additions and 12 deletions.
17 changes: 17 additions & 0 deletions docs/storage.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,22 @@ pip install 'pymemgpt[postgres]'
You will need to have a URI to a Postgres database which support [pgvector](https://github.com/pgvector/pgvector). You can either use a [hosted provider](https://github.com/pgvector/pgvector/issues/54) or [install pgvector](https://github.com/pgvector/pgvector#installation).


## LanceDB
To use the LanceDB backend, first enable it by running

```
memgpt configure
```
and selecting `lancedb` for archival storage, then providing a database URI (e.g. `./.lancedb`). If the archival URI is left empty, the default URI `./.lancedb` is used.

To enable the LanceDB backend, make sure to install the required dependencies with:
```
pip install 'pymemgpt[lancedb]'
```
For more information, check out the [LanceDB docs](https://lancedb.github.io/lancedb/).


## Chroma
(Coming soon)
3 changes: 2 additions & 1 deletion memgpt/autogen/examples/memgpt_coder_autogen.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@
"outputs": [],
"source": [
"import openai\n",
"openai.api_key=\"YOUR_API_KEY\""
"\n",
"openai.api_key = \"YOUR_API_KEY\""
]
},
{
Expand Down
11 changes: 10 additions & 1 deletion memgpt/cli/cli_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def configure_cli(config: MemGPTConfig):

def configure_archival_storage(config: MemGPTConfig):
# Configure archival storage backend
archival_storage_options = ["local", "postgres"]
archival_storage_options = ["local", "lancedb", "postgres"]
archival_storage_type = questionary.select(
"Select storage backend for archival data:", archival_storage_options, default=config.archival_storage_type
).ask()
Expand All @@ -220,8 +220,17 @@ def configure_archival_storage(config: MemGPTConfig):
"Enter postgres connection string (e.g. postgresql+pg8000://{user}:{password}@{ip}:5432/{database}):",
default=config.archival_storage_uri if config.archival_storage_uri else "",
).ask()

if archival_storage_type == "lancedb":
archival_storage_uri = questionary.text(
"Enter lanncedb connection string (e.g. ./.lancedb",
default=config.archival_storage_uri if config.archival_storage_uri else "./.lancedb",
).ask()

return archival_storage_type, archival_storage_uri

# TODO: allow configuring embedding model


@app.command()
def configure():
Expand Down
137 changes: 137 additions & 0 deletions memgpt/connectors/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from typing import Optional, List, Iterator
import numpy as np
from tqdm import tqdm
import pandas as pd

from memgpt.config import MemGPTConfig
from memgpt.connectors.storage import StorageConnector, Passage
Expand Down Expand Up @@ -181,3 +182,139 @@ def generate_table_name_agent(self, agent_config: AgentConfig):

def generate_table_name(self, name: str):
    """Build the storage table name by prefixing the sanitized name with "memgpt_"."""
    sanitized = self.sanitize_table_name(name)
    return "memgpt_" + sanitized


class LanceDBConnector(StorageConnector):
    """Archival storage backed by LanceDB.

    Passages are stored as rows with columns ``doc_id``, ``text``,
    ``passage_id`` and ``vector`` (the embedding). The table is created
    lazily on the first insert.
    """

    # TODO: this should probably eventually be moved into a parent DB class

    def __init__(self, name: Optional[str] = None):
        """Connect to the LanceDB database configured in MemGPTConfig.

        :param name: optional logical name; sanitized and prefixed to form the table name
        :raises ValueError: if no archival_storage_uri is configured
        """
        config = MemGPTConfig.load()

        # determine table name
        if name:
            self.table_name = self.generate_table_name(name)
        else:
            self.table_name = "lancedb_tbl"

        printd(f"Using table name {self.table_name}")

        # connect to the database; the table itself is created on first insert
        self.uri = config.archival_storage_uri
        if config.archival_storage_uri is None:
            raise ValueError(f"Must specify archival_storage_uri in config {config.config_path}")
        import lancedb

        self.db = lancedb.connect(self.uri)
        self.table = None  # created lazily by insert()/insert_many()

    def get_all_paginated(self, page_size: int) -> Iterator[List[Passage]]:
        """Yield all stored passages in chunks of ``page_size``.

        The original implementation referenced a nonexistent ``self.Session``
        (copied from the Postgres connector) and never applied the offset, so
        it re-fetched the same first page forever; fetch once and slice instead.
        """
        if self.table is None:
            # no table yet -> nothing stored
            return
        rows = self.table.search().to_list()
        for start in range(0, len(rows), page_size):
            chunk = rows[start : start + page_size]
            yield [
                Passage(text=p["text"], embedding=p["vector"], doc_id=p["doc_id"], passage_id=p["passage_id"]) for p in chunk
            ]

    def get_all(self, limit=10) -> List[Passage]:
        """Return up to ``limit`` passages from the table."""
        db_passages = self.table.search().limit(limit).to_list()
        return [Passage(text=p["text"], embedding=p["vector"], doc_id=p["doc_id"], passage_id=p["passage_id"]) for p in db_passages]

    def get(self, id: str) -> Optional[Passage]:
        """Return the passage with the given ``passage_id``, or None if not found.

        The original indexed the result *list* with string keys and read a
        nonexistent ``embedding`` column; take the first row and read the
        ``vector`` column (the schema used by insert()).
        """
        matches = self.table.search().where(f"passage_id={id}").to_list()
        if len(matches) == 0:
            return None
        p = matches[0]
        return Passage(text=p["text"], embedding=p["vector"], doc_id=p["doc_id"], passage_id=p["passage_id"])

    def size(self) -> int:
        """Return the number of rows in the table (0 if the table doesn't exist yet)."""
        if self.table:
            return len(self.table.search().to_list())
        else:
            print(f"Table with name {self.table_name} not present")
            return 0

    def insert(self, passage: Passage):
        """Insert a single passage, creating the table on first use."""
        data = [{"doc_id": passage.doc_id, "text": passage.text, "passage_id": passage.passage_id, "vector": passage.embedding}]

        if self.table:
            self.table.add(data)
        else:
            self.table = self.db.create_table(self.table_name, data=data, mode="overwrite")

    def insert_many(self, passages: List[Passage], show_progress=True):
        """Bulk-insert passages; shows a tqdm progress bar by default."""
        data = []
        iterable = tqdm(passages) if show_progress else passages
        for passage in iterable:
            temp_dict = {"doc_id": passage.doc_id, "text": passage.text, "passage_id": passage.passage_id, "vector": passage.embedding}
            data.append(temp_dict)

        if self.table:
            self.table.add(data)
        else:
            self.table = self.db.create_table(self.table_name, data=data, mode="overwrite")

    def query(self, query: str, query_vec: List[float], top_k: int = 10) -> List[Passage]:
        """Return the ``top_k`` passages nearest to ``query_vec``.

        :param query: original query text (unused; kept for interface parity)
        :param query_vec: embedding of the query, same dimension as stored vectors

        The original iterated the query builder without materializing it and
        read the nonexistent ``embedding`` column; call ``to_list()`` and read
        ``vector``.
        """
        results = self.table.search(query_vec).limit(top_k).to_list()
        return [
            Passage(text=r["text"], embedding=r["vector"], doc_id=r["doc_id"], passage_id=r["passage_id"]) for r in results
        ]

    def delete(self):
        """Drop the passage table from the database."""
        self.db.drop_table(self.table_name)

    def save(self):
        # LanceDB persists writes immediately; nothing to flush.
        return

    @staticmethod
    def list_loaded_data():
        """List the names of all memgpt-created tables, with the prefix stripped."""
        config = MemGPTConfig.load()
        import lancedb

        db = lancedb.connect(config.archival_storage_uri)

        tables = db.table_names()
        tables = [table for table in tables if table.startswith("memgpt_")]
        tables = [table.replace("memgpt_", "") for table in tables]
        return tables

    def sanitize_table_name(self, name: str) -> str:
        """Normalize ``name`` into a safe, lowercase table identifier."""
        # Remove leading and trailing whitespace
        name = name.strip()

        # Replace spaces and invalid characters with underscores
        name = re.sub(r"\s+|\W+", "_", name)

        # Truncate to the maximum identifier length
        max_length = 63
        if len(name) > max_length:
            name = name[:max_length].rstrip("_")

        # Convert to lowercase
        name = name.lower()

        return name

    def generate_table_name(self, name: str):
        """Prefix the sanitized name so memgpt tables are identifiable."""
        return f"memgpt_{self.sanitize_table_name(name)}"
10 changes: 10 additions & 0 deletions memgpt/connectors/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ def get_storage_connector(name: Optional[str] = None, agent_config: Optional[Age

return PostgresStorageConnector(name=name, agent_config=agent_config)

elif storage_type == "lancedb":
from memgpt.connectors.db import LanceDBConnector

return LanceDBConnector(name=name)

else:
raise NotImplementedError(f"Storage type {storage_type} not implemented")

Expand All @@ -62,6 +67,11 @@ def list_loaded_data():
from memgpt.connectors.db import PostgresStorageConnector

return PostgresStorageConnector.list_loaded_data()

elif storage_type == "lancedb":
from memgpt.connectors.db import LanceDBConnector

return LanceDBConnector.list_loaded_data()
else:
raise NotImplementedError(f"Storage type {storage_type} not implemented")

Expand Down
2 changes: 1 addition & 1 deletion memgpt/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,7 @@ class EmbeddingArchivalMemory(ArchivalMemory):
def __init__(self, agent_config, top_k: Optional[int] = 100):
"""Init function for archival memory
:param archiva_memory_database: name of dataset to pre-fill archival with
:param archival_memory_database: name of dataset to pre-fill archival with
:type archival_memory_database: str
"""
from memgpt.connectors.storage import StorageConnector
Expand Down
5 changes: 4 additions & 1 deletion memgpt/server/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
def condition_to_stop_receiving(response):
    """Determines when to stop listening to the server.

    Returns True for any terminal message type: end-of-response, agent/server
    errors, or a command reply. Missing/unknown types keep the client listening.
    """
    # Return the membership test directly instead of if/else returning True/False.
    return response.get("type") in ("agent_response_end", "agent_response_error", "command_response", "server_error")


def print_server_response(response):
Expand Down
3 changes: 2 additions & 1 deletion memgpt/server/websocket_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,12 @@ def server_agent_function_message(msg):
# Client -> server


def client_user_message(msg, agent_name=None):
    """Serialize a client->server user message as a JSON string.

    :param msg: the user's message text
    :param agent_name: optional name of the agent that should handle the message
    """
    payload = {
        "type": "user_message",
        "message": msg,
        "agent_name": agent_name,
    }
    return json.dumps(payload)

Expand Down
43 changes: 40 additions & 3 deletions memgpt/server/websocket_server.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import json
import traceback

import websockets

Expand All @@ -15,7 +16,9 @@ def __init__(self, host="localhost", port=DEFAULT_PORT):
self.host = host
self.port = port
self.interface = SyncWebSocketInterface()

self.agent = None
self.agent_name = None

def run_step(self, user_message, first_message=False, no_verify=False):
while True:
Expand All @@ -41,16 +44,27 @@ async def handle_client(self, websocket, path):
message = await websocket.recv()

# Assuming the message is a JSON string
data = json.loads(message)

if data["type"] == "command":
try:
data = json.loads(message)
except:
print(f"[server] bad data from client:\n{data}")
await websocket.send(protocol.server_command_response(f"Error: bad data from client - {str(data)}"))
continue

if "type" not in data:
print(f"[server] bad data from client (JSON but no type):\n{data}")
await websocket.send(protocol.server_command_response(f"Error: bad data from client - {str(data)}"))

elif data["type"] == "command":
# Create a new agent
if data["command"] == "create_agent":
try:
self.agent = self.create_new_agent(data["config"])
await websocket.send(protocol.server_command_response("OK: Agent initialized"))
except Exception as e:
self.agent = None
print(f"[server] self.create_new_agent failed with:\n{e}")
print(f"{traceback.format_exc()}")
await websocket.send(protocol.server_command_response(f"Error: Failed to init agent - {str(e)}"))

# Load an existing agent
Expand All @@ -59,9 +73,11 @@ async def handle_client(self, websocket, path):
if agent_name is not None:
try:
self.agent = self.load_agent(agent_name)
self.agent_name = agent_name
await websocket.send(protocol.server_command_response(f"OK: Agent '{agent_name}' loaded"))
except Exception as e:
print(f"[server] self.load_agent failed with:\n{e}")
print(f"{traceback.format_exc()}")
self.agent = None
await websocket.send(
protocol.server_command_response(f"Error: Failed to load agent '{agent_name}' - {str(e)}")
Expand All @@ -76,6 +92,26 @@ async def handle_client(self, websocket, path):
elif data["type"] == "user_message":
user_message = data["message"]

if "agent_name" in data:
agent_name = data["agent_name"]
# If the agent requested the same one that's already loading?
if self.agent_name is None or self.agent_name != data["agent_name"]:
try:
print(f"[server] loading agent {agent_name}")
self.agent = self.load_agent(agent_name)
self.agent_name = agent_name
# await websocket.send(protocol.server_command_response(f"OK: Agent '{agent_name}' loaded"))
except Exception as e:
print(f"[server] self.load_agent failed with:\n{e}")
print(f"{traceback.format_exc()}")
self.agent = None
await websocket.send(
protocol.server_command_response(f"Error: Failed to load agent '{agent_name}' - {str(e)}")
)
else:
await websocket.send(protocol.server_agent_response_error("agent_name was not specified in the request"))
continue

if self.agent is None:
await websocket.send(protocol.server_agent_response_error("No agent has been initialized"))
else:
Expand All @@ -84,6 +120,7 @@ async def handle_client(self, websocket, path):
self.run_step(user_message)
except Exception as e:
print(f"[server] self.run_step failed with:\n{e}")
print(f"{traceback.format_exc()}")
await websocket.send(protocol.server_agent_response_error(f"self.run_step failed with: {e}"))

await asyncio.sleep(1) # pause before sending the terminating message, w/o this messages may be missed
Expand Down
Loading

0 comments on commit b39b3f8

Please sign in to comment.