feat: support pgvecto.rs #748

Merged · 3 commits · Feb 8, 2025
1 change: 1 addition & 0 deletions src/vanna/pgvector/__init__.py
@@ -1 +1,2 @@
from .pgvector import PG_VectorStore
from .pgvecto_rs import PG_Vecto_rsStore
269 changes: 269 additions & 0 deletions src/vanna/pgvector/pgvecto_rs.py
@@ -0,0 +1,269 @@
import ast
import json
import logging
import uuid

import pandas as pd
from langchain_core.documents import Document
from langchain_community.vectorstores.pgvecto_rs import PGVecto_rs
from sqlalchemy import create_engine, text

from .. import ValidationError
from ..base import VannaBase
from ..types import TrainingPlan, TrainingPlanItem
from ..utils import deterministic_uuid


class PG_Vecto_rsStore(VannaBase):
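    """Vanna vector store backed by pgvecto.rs through LangChain's PGVecto_rs wrapper.

    Question/SQL pairs, DDL statements, and documentation are stored in three
    separate collections ("sql", "ddl", "documentation") and retrieved by
    similarity search.
    """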
def __init__(self, config=None):
if not config or "connection_string" not in config:
raise ValueError(
"A valid 'config' dictionary with a 'connection_string' is required.")

VannaBase.__init__(self, config=config)

if config and "connection_string" in config:
self.connection_string = config.get("connection_string")
self.n_results = config.get("n_results", 10)

if config and "embedding_function" in config:
self.embedding_function = config.get("embedding_function")
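            # A custom embedding function must be paired with its output dimension,
            # since the PGVecto_rs collections below require an explicit vector dimension.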
self.vector_dimension = config.get("vector_dimension")
else:
from langchain_huggingface import HuggingFaceEmbeddings
self.embedding_function = HuggingFaceEmbeddings(
model_name="all-MiniLM-L6-v2")
self.vector_dimension = 384
self.sql_collection = PGVecto_rs(
embedding=self.embedding_function,
collection_name="sql",
db_url=self.connection_string,
dimension=self.vector_dimension,
)
self.ddl_collection = PGVecto_rs(
embedding=self.embedding_function,
collection_name="ddl",
db_url=self.connection_string,
dimension=self.vector_dimension,
)
self.documentation_collection = PGVecto_rs(
embedding=self.embedding_function,
collection_name="documentation",
db_url=self.connection_string,
dimension=self.vector_dimension,
)

def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
question_sql_json = json.dumps(
{
"question": question,
"sql": sql,
},
ensure_ascii=False,
)
id = deterministic_uuid(question_sql_json) + "-sql"
createdat = kwargs.get("createdat")
doc = Document(
page_content=question_sql_json,
metadata={"id": id, "createdat": createdat},
)
self.sql_collection.add_documents([doc], ids=[doc.metadata["id"]])

return id

def add_ddl(self, ddl: str, **kwargs) -> str:
_id = deterministic_uuid(ddl) + "-ddl"
doc = Document(
page_content=ddl,
metadata={"id": _id},
)
self.ddl_collection.add_documents([doc], ids=[doc.metadata["id"]])
return _id

def add_documentation(self, documentation: str, **kwargs) -> str:
_id = deterministic_uuid(documentation) + "-doc"
doc = Document(
page_content=documentation,
metadata={"id": _id},
)
self.documentation_collection.add_documents([doc],
ids=[doc.metadata["id"]])
return _id

def get_collection(self, collection_name):
match collection_name:
case "sql":
return self.sql_collection
case "ddl":
return self.ddl_collection
case "documentation":
return self.documentation_collection
case _:
raise ValueError("Specified collection does not exist.")

def get_similar_question_sql(self, question: str, **kwargs) -> list:
documents = self.sql_collection.similarity_search(query=question,
k=self.n_results)
return [ast.literal_eval(document.page_content) for document in documents]

def get_related_ddl(self, question: str, **kwargs) -> list:
documents = self.ddl_collection.similarity_search(query=question,
k=self.n_results)
return [document.page_content for document in documents]

def get_related_documentation(self, question: str, **kwargs) -> list:
documents = self.documentation_collection.similarity_search(query=question,
k=self.n_results)
return [document.page_content for document in documents]

def train(
self,
question: str | None = None,
sql: str | None = None,
ddl: str | None = None,
documentation: str | None = None,
plan: TrainingPlan | None = None,
createdat: str | None = None,
):
if question and not sql:
raise ValidationError("Please provide a SQL query.")

if documentation:
logging.info(f"Adding documentation: {documentation}")
return self.add_documentation(documentation)

if sql and question:
return self.add_question_sql(question=question, sql=sql,
createdat=createdat)

if ddl:
logging.info(f"Adding ddl: {ddl}")
return self.add_ddl(ddl)

if plan:
for item in plan._plan:
if item.item_type == TrainingPlanItem.ITEM_TYPE_DDL:
self.add_ddl(item.item_value)
elif item.item_type == TrainingPlanItem.ITEM_TYPE_IS:
self.add_documentation(item.item_value)
elif item.item_type == TrainingPlanItem.ITEM_TYPE_SQL and item.item_name:
self.add_question_sql(question=item.item_name, sql=item.item_value)

def get_training_data(self, **kwargs) -> pd.DataFrame:
# Establishing the connection
engine = create_engine(self.connection_string)

# Querying the 'langchain_pg_embedding' table
query_embedding = "SELECT cmetadata, document FROM langchain_pg_embedding"
df_embedding = pd.read_sql(query_embedding, engine)

# List to accumulate the processed rows
processed_rows = []

# Process each row in the DataFrame
for _, row in df_embedding.iterrows():
custom_id = row["cmetadata"]["id"]
document = row["document"]
training_data_type = "documentation" if custom_id[
-3:] == "doc" else custom_id[-3:]

if training_data_type == "sql":
# Convert the document string to a dictionary
try:
doc_dict = ast.literal_eval(document)
question = doc_dict.get("question")
content = doc_dict.get("sql")
except (ValueError, SyntaxError):
logging.info(
f"Skipping row with custom_id {custom_id} due to parsing error.")
continue
elif training_data_type in ["documentation", "ddl"]:
question = None # Default value for question
content = document
else:
# If the suffix is not recognized, skip this row
logging.info(
f"Skipping row with custom_id {custom_id} due to unrecognized training data type.")
continue

# Append the processed data to the list
processed_rows.append(
{"id": custom_id, "question": question, "content": content,
"training_data_type": training_data_type}
)

# Create a DataFrame from the list of processed rows
df_processed = pd.DataFrame(processed_rows)

return df_processed

def remove_training_data(self, id: str, **kwargs) -> bool:
# Create the database engine
engine = create_engine(self.connection_string)

# SQL DELETE statement
delete_statement = text(
"""
DELETE FROM langchain_pg_embedding
WHERE cmetadata ->> 'id' = :id
"""
)

# Connect to the database and execute the delete statement
with engine.connect() as connection:
# Start a transaction
with connection.begin() as transaction:
try:
result = connection.execute(delete_statement, {"id": id})
# Commit the transaction if the delete was successful
transaction.commit()
# Check if any row was deleted and return True or False accordingly
                    return result.rowcount > 0
except Exception as e:
# Rollback the transaction in case of error
logging.error(f"An error occurred: {e}")
transaction.rollback()
return False

def remove_collection(self, collection_name: str) -> bool:
engine = create_engine(self.connection_string)

# Determine the suffix to look for based on the collection name
suffix_map = {"ddl": "ddl", "sql": "sql", "documentation": "doc"}
suffix = suffix_map.get(collection_name)

if not suffix:
logging.info(
"Invalid collection name. Choose from 'ddl', 'sql', or 'documentation'.")
return False

# SQL query to delete rows based on the condition
query = text(
f"""
DELETE FROM langchain_pg_embedding
WHERE cmetadata->>'id' LIKE '%{suffix}'
"""
)

# Execute the deletion within a transaction block
with engine.connect() as connection:
with connection.begin() as transaction:
try:
result = connection.execute(query)
transaction.commit() # Explicitly commit the transaction
                    if result.rowcount > 0:
                        logging.info(
                            f"Deleted {result.rowcount} rows from "
f"langchain_pg_embedding where collection is {collection_name}."
)
return True
else:
logging.info(f"No rows deleted for collection {collection_name}.")
return False
except Exception as e:
logging.error(f"An error occurred: {e}")
transaction.rollback() # Rollback in case of error
return False

def generate_embedding(self, *args, **kwargs):
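        # Embeddings are produced by the configured embedding_function inside the
        # PGVecto_rs collections, so this base-class hook is intentionally a no-op.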
pass
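For reviewers, here is a minimal usage sketch of the new store, not part of the diff. It follows the usual Vanna pattern of combining a vector store with an LLM class; OpenAI_Chat, the model name, the API key, and the connection string below are illustrative assumptions rather than something this PR adds.

from vanna.openai import OpenAI_Chat
from vanna.pgvector import PG_Vecto_rsStore


class MyVanna(PG_Vecto_rsStore, OpenAI_Chat):
    def __init__(self, config=None):
        PG_Vecto_rsStore.__init__(self, config=config)
        OpenAI_Chat.__init__(self, config=config)


vn = MyVanna(config={
    # pgvecto.rs-enabled Postgres; any SQLAlchemy-style URL should work here
    "connection_string": "postgresql+psycopg://user:pass@localhost:5432/vanna",
    # LLM settings consumed by OpenAI_Chat (illustrative placeholders)
    "api_key": "sk-...",
    "model": "gpt-4o-mini",
})

vn.train(ddl="CREATE TABLE users (id INT, name TEXT)")
vn.train(question="How many users are there?", sql="SELECT COUNT(*) FROM users")
print(vn.get_training_data())

The store can also be used on its own for retrieval (get_similar_question_sql, get_related_ddl, get_related_documentation) without attaching an LLM class.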
15 changes: 8 additions & 7 deletions src/vanna/pgvector/pgvector.py
@@ -11,6 +11,7 @@
from .. import ValidationError
from ..base import VannaBase
from ..types import TrainingPlan, TrainingPlanItem
+from ..utils import deterministic_uuid


class PG_VectorStore(VannaBase):
@@ -55,7 +56,7 @@ def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
},
ensure_ascii=False,
)
-        id = str(uuid.uuid4()) + "-sql"
+        id = deterministic_uuid(question_sql_json) + "-sql"
createdat = kwargs.get("createdat")
doc = Document(
page_content=question_sql_json,
@@ -66,7 +67,7 @@ def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
return id

def add_ddl(self, ddl: str, **kwargs) -> str:
-        _id = str(uuid.uuid4()) + "-ddl"
+        _id = deterministic_uuid(ddl) + "-ddl"
doc = Document(
page_content=ddl,
metadata={"id": _id},
@@ -75,7 +76,7 @@ def add_ddl(self, ddl: str, **kwargs) -> str:
return _id

def add_documentation(self, documentation: str, **kwargs) -> str:
-        _id = str(uuid.uuid4()) + "-doc"
+        _id = deterministic_uuid(documentation) + "-doc"
doc = Document(
page_content=documentation,
metadata={"id": _id},
@@ -94,7 +95,7 @@ def get_collection(self, collection_name):
case _:
raise ValueError("Specified collection does not exist.")

-    def get_similar_question_sql(self, question: str) -> list:
+    def get_similar_question_sql(self, question: str, **kwargs) -> list:
documents = self.sql_collection.similarity_search(query=question, k=self.n_results)
return [ast.literal_eval(document.page_content) for document in documents]

@@ -203,7 +204,7 @@ def remove_training_data(self, id: str, **kwargs) -> bool:
# Commit the transaction if the delete was successful
transaction.commit()
# Check if any row was deleted and return True or False accordingly
-                    return result.rowcount > 0
+                    return result.rowcount() > 0
except Exception as e:
# Rollback the transaction in case of error
logging.error(f"An error occurred: {e}")
@@ -235,9 +236,9 @@ def remove_collection(self, collection_name: str) -> bool:
try:
result = connection.execute(query)
transaction.commit() # Explicitly commit the transaction
-                    if result.rowcount > 0:
+                    if result.rowcount() > 0:
                        logging.info(
-                            f"Deleted {result.rowcount} rows from "
+                            f"Deleted {result.rowcount()} rows from "
f"langchain_pg_embedding where collection is {collection_name}."
)
return True
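One functional note on the pgvector.py changes: swapping uuid.uuid4() for deterministic_uuid means re-training on the same DDL, documentation, or question/SQL pair produces the same id, so repeats overwrite a single row instead of accumulating duplicates. The helper itself is imported from ..utils and is not shown in this diff; a minimal sketch of a content-derived UUID, assuming it simply hashes its input (the real implementation may differ), could look like this:

import hashlib
import uuid

def deterministic_uuid(content: str) -> str:
    # Hash the content so equal inputs always map to the same UUID.
    digest = hashlib.sha256(content.encode("utf-8")).hexdigest()
    return str(uuid.uuid5(uuid.NAMESPACE_DNS, digest))

# Equal inputs give equal ids, so calling add_ddl with the same statement
# twice reuses one id instead of inserting a duplicate row.
assert deterministic_uuid("CREATE TABLE t (id INT)") == deterministic_uuid("CREATE TABLE t (id INT)")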