Merge pull request #748 from wwulfric/main
feat: support pgvecto.rs
zainhoda authored Feb 8, 2025
2 parents 36bdcde + 6453492 commit 856dfa9
Showing 3 changed files with 278 additions and 7 deletions.
1 change: 1 addition & 0 deletions src/vanna/pgvector/__init__.py
@@ -1 +1,2 @@
 from .pgvector import PG_VectorStore
+from .pgvecto_rs import PG_Vecto_rsStore
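
With the new export in place, both stores are importable from the same package. A minimal sketch of the import, using only the names defined above:

from vanna.pgvector import PG_VectorStore, PG_Vecto_rsStore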
269 changes: 269 additions & 0 deletions src/vanna/pgvector/pgvecto_rs.py
@@ -0,0 +1,269 @@
import ast
import json
import logging

import pandas as pd
from langchain_core.documents import Document
from langchain_community.vectorstores.pgvecto_rs import PGVecto_rs
from sqlalchemy import create_engine, text

from .. import ValidationError
from ..base import VannaBase
from ..types import TrainingPlan, TrainingPlanItem
from ..utils import deterministic_uuid


class PG_Vecto_rsStore(VannaBase):
    def __init__(self, config=None):
        if not config or "connection_string" not in config:
            raise ValueError(
                "A valid 'config' dictionary with a 'connection_string' is required.")

        VannaBase.__init__(self, config=config)

        self.connection_string = config.get("connection_string")
        self.n_results = config.get("n_results", 10)

        if "embedding_function" in config:
            # A custom embedding_function must be paired with its vector_dimension.
            self.embedding_function = config.get("embedding_function")
            self.vector_dimension = config.get("vector_dimension")
        else:
            from langchain_huggingface import HuggingFaceEmbeddings
            self.embedding_function = HuggingFaceEmbeddings(
                model_name="all-MiniLM-L6-v2")
            self.vector_dimension = 384

        self.sql_collection = PGVecto_rs(
            embedding=self.embedding_function,
            collection_name="sql",
            db_url=self.connection_string,
            dimension=self.vector_dimension,
        )
        self.ddl_collection = PGVecto_rs(
            embedding=self.embedding_function,
            collection_name="ddl",
            db_url=self.connection_string,
            dimension=self.vector_dimension,
        )
        self.documentation_collection = PGVecto_rs(
            embedding=self.embedding_function,
            collection_name="documentation",
            db_url=self.connection_string,
            dimension=self.vector_dimension,
        )

    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
        question_sql_json = json.dumps(
            {
                "question": question,
                "sql": sql,
            },
            ensure_ascii=False,
        )
        id = deterministic_uuid(question_sql_json) + "-sql"
        createdat = kwargs.get("createdat")
        doc = Document(
            page_content=question_sql_json,
            metadata={"id": id, "createdat": createdat},
        )
        self.sql_collection.add_documents([doc], ids=[doc.metadata["id"]])

        return id
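
    # Because the id above is derived from deterministic_uuid(question_sql_json),
    # adding the same question/SQL pair twice yields the same "-sql" id, so the
    # store can deduplicate repeated training. A hedged sketch (assuming a
    # store instance `vn` of this class):
    #
    #   id1 = vn.add_question_sql("How many users?", "SELECT COUNT(*) FROM users")
    #   id2 = vn.add_question_sql("How many users?", "SELECT COUNT(*) FROM users")
    #   assert id1 == id2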

    def add_ddl(self, ddl: str, **kwargs) -> str:
        _id = deterministic_uuid(ddl) + "-ddl"
        doc = Document(
            page_content=ddl,
            metadata={"id": _id},
        )
        self.ddl_collection.add_documents([doc], ids=[doc.metadata["id"]])
        return _id

    def add_documentation(self, documentation: str, **kwargs) -> str:
        _id = deterministic_uuid(documentation) + "-doc"
        doc = Document(
            page_content=documentation,
            metadata={"id": _id},
        )
        self.documentation_collection.add_documents(
            [doc], ids=[doc.metadata["id"]])
        return _id

    def get_collection(self, collection_name):
        match collection_name:
            case "sql":
                return self.sql_collection
            case "ddl":
                return self.ddl_collection
            case "documentation":
                return self.documentation_collection
            case _:
                raise ValueError("Specified collection does not exist.")

    def get_similar_question_sql(self, question: str, **kwargs) -> list:
        documents = self.sql_collection.similarity_search(
            query=question, k=self.n_results)
        return [ast.literal_eval(document.page_content) for document in documents]

    def get_related_ddl(self, question: str, **kwargs) -> list:
        documents = self.ddl_collection.similarity_search(
            query=question, k=self.n_results)
        return [document.page_content for document in documents]

    def get_related_documentation(self, question: str, **kwargs) -> list:
        documents = self.documentation_collection.similarity_search(
            query=question, k=self.n_results)
        return [document.page_content for document in documents]

    def train(
        self,
        question: str | None = None,
        sql: str | None = None,
        ddl: str | None = None,
        documentation: str | None = None,
        plan: TrainingPlan | None = None,
        createdat: str | None = None,
    ):
        if question and not sql:
            raise ValidationError("Please provide a SQL query.")

        if documentation:
            logging.info(f"Adding documentation: {documentation}")
            return self.add_documentation(documentation)

        if sql and question:
            return self.add_question_sql(question=question, sql=sql,
                                         createdat=createdat)

        if ddl:
            logging.info(f"Adding ddl: {ddl}")
            return self.add_ddl(ddl)

        if plan:
            for item in plan._plan:
                if item.item_type == TrainingPlanItem.ITEM_TYPE_DDL:
                    self.add_ddl(item.item_value)
                elif item.item_type == TrainingPlanItem.ITEM_TYPE_IS:
                    self.add_documentation(item.item_value)
                elif item.item_type == TrainingPlanItem.ITEM_TYPE_SQL and item.item_name:
                    self.add_question_sql(question=item.item_name, sql=item.item_value)

    def get_training_data(self, **kwargs) -> pd.DataFrame:
        # Establishing the connection
        engine = create_engine(self.connection_string)

        # Querying the 'langchain_pg_embedding' table
        query_embedding = "SELECT cmetadata, document FROM langchain_pg_embedding"
        df_embedding = pd.read_sql(query_embedding, engine)

        # List to accumulate the processed rows
        processed_rows = []

        # Process each row in the DataFrame
        for _, row in df_embedding.iterrows():
            custom_id = row["cmetadata"]["id"]
            document = row["document"]
            # IDs end in "-sql", "-ddl", or "-doc"; map the suffix to a type.
            training_data_type = "documentation" if custom_id[-3:] == "doc" else custom_id[-3:]

            if training_data_type == "sql":
                # Convert the document string to a dictionary
                try:
                    doc_dict = ast.literal_eval(document)
                    question = doc_dict.get("question")
                    content = doc_dict.get("sql")
                except (ValueError, SyntaxError):
                    logging.info(
                        f"Skipping row with custom_id {custom_id} due to parsing error.")
                    continue
            elif training_data_type in ["documentation", "ddl"]:
                question = None  # Default value for question
                content = document
            else:
                # If the suffix is not recognized, skip this row
                logging.info(
                    f"Skipping row with custom_id {custom_id} due to unrecognized training data type.")
                continue

            # Append the processed data to the list
            processed_rows.append(
                {"id": custom_id, "question": question, "content": content,
                 "training_data_type": training_data_type}
            )

        # Create a DataFrame from the list of processed rows
        df_processed = pd.DataFrame(processed_rows)

        return df_processed
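
    # Illustrative shape of the DataFrame returned above (values are examples,
    # not taken from this commit):
    #
    #   id           question          content                 training_data_type
    #   <uuid>-sql   How many users?   SELECT COUNT(*) ...     sql
    #   <uuid>-ddl   None              CREATE TABLE users ...  ddl
    #   <uuid>-doc   None              The users table ...     documentation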

    def remove_training_data(self, id: str, **kwargs) -> bool:
        # Create the database engine
        engine = create_engine(self.connection_string)

        # SQL DELETE statement
        delete_statement = text(
            """
            DELETE FROM langchain_pg_embedding
            WHERE cmetadata ->> 'id' = :id
            """
        )

        # Connect to the database and execute the delete statement
        with engine.connect() as connection:
            # Start a transaction
            with connection.begin() as transaction:
                try:
                    result = connection.execute(delete_statement, {"id": id})
                    # Commit the transaction if the delete was successful
                    transaction.commit()
                    # Check if any row was deleted; rowcount is a property on
                    # SQLAlchemy results, not a method.
                    return result.rowcount > 0
                except Exception as e:
                    # Rollback the transaction in case of error
                    logging.error(f"An error occurred: {e}")
                    transaction.rollback()
                    return False

    def remove_collection(self, collection_name: str) -> bool:
        engine = create_engine(self.connection_string)

        # Determine the suffix to look for based on the collection name
        suffix_map = {"ddl": "ddl", "sql": "sql", "documentation": "doc"}
        suffix = suffix_map.get(collection_name)

        if not suffix:
            logging.info(
                "Invalid collection name. Choose from 'ddl', 'sql', or 'documentation'.")
            return False

        # SQL query to delete rows based on the condition; the suffix comes
        # from the fixed map above, so interpolating it here is safe.
        query = text(
            f"""
            DELETE FROM langchain_pg_embedding
            WHERE cmetadata->>'id' LIKE '%{suffix}'
            """
        )

        # Execute the deletion within a transaction block
        with engine.connect() as connection:
            with connection.begin() as transaction:
                try:
                    result = connection.execute(query)
                    transaction.commit()  # Explicitly commit the transaction
                    # rowcount is a property, not a method
                    if result.rowcount > 0:
                        logging.info(
                            f"Deleted {result.rowcount} rows from "
                            f"langchain_pg_embedding where collection is {collection_name}."
                        )
                        return True
                    else:
                        logging.info(f"No rows deleted for collection {collection_name}.")
                        return False
                except Exception as e:
                    logging.error(f"An error occurred: {e}")
                    transaction.rollback()  # Rollback in case of error
                    return False

    def generate_embedding(self, *args, **kwargs):
        pass
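
For orientation, here is a minimal, hedged sketch of how the new store is typically wired up. Vanna vector stores are combined with an LLM class via multiple inheritance; the OpenAI_Chat mixin, the model name, and the connection details below are illustrative assumptions, not part of this commit.

from vanna.openai import OpenAI_Chat  # illustrative LLM mixin, not in this diff
from vanna.pgvector import PG_Vecto_rsStore


class MyVanna(PG_Vecto_rsStore, OpenAI_Chat):
    def __init__(self, config=None):
        PG_Vecto_rsStore.__init__(self, config=config)
        OpenAI_Chat.__init__(self, config=config)


vn = MyVanna(config={
    # pgvecto.rs speaks the Postgres wire protocol; credentials are placeholders
    "connection_string": "postgresql+psycopg://user:pass@localhost:5432/vanna",
    "api_key": "sk-...",      # assumed OpenAI_Chat setting
    "model": "gpt-4o-mini",   # assumed OpenAI_Chat setting
})

vn.train(ddl="CREATE TABLE users (id INT, name TEXT)")
vn.train(question="How many users are there?",
         sql="SELECT COUNT(*) FROM users")
print(vn.get_training_data())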
15 changes: 8 additions & 7 deletions src/vanna/pgvector/pgvector.py
@@ -11,6 +11,7 @@
 from .. import ValidationError
 from ..base import VannaBase
 from ..types import TrainingPlan, TrainingPlanItem
+from ..utils import deterministic_uuid


 class PG_VectorStore(VannaBase):
@@ -55,7 +56,7 @@ def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
             },
             ensure_ascii=False,
         )
-        id = str(uuid.uuid4()) + "-sql"
+        id = deterministic_uuid(question_sql_json) + "-sql"
         createdat = kwargs.get("createdat")
         doc = Document(
             page_content=question_sql_json,
@@ -66,7 +67,7 @@ def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
         return id

     def add_ddl(self, ddl: str, **kwargs) -> str:
-        _id = str(uuid.uuid4()) + "-ddl"
+        _id = deterministic_uuid(ddl) + "-ddl"
         doc = Document(
             page_content=ddl,
             metadata={"id": _id},
@@ -75,7 +76,7 @@ def add_ddl(self, ddl: str, **kwargs) -> str:
         return _id

     def add_documentation(self, documentation: str, **kwargs) -> str:
-        _id = str(uuid.uuid4()) + "-doc"
+        _id = deterministic_uuid(documentation) + "-doc"
         doc = Document(
             page_content=documentation,
             metadata={"id": _id},
@@ -94,7 +95,7 @@ def get_collection(self, collection_name):
             case _:
                 raise ValueError("Specified collection does not exist.")

-    def get_similar_question_sql(self, question: str) -> list:
+    def get_similar_question_sql(self, question: str, **kwargs) -> list:
         documents = self.sql_collection.similarity_search(query=question, k=self.n_results)
         return [ast.literal_eval(document.page_content) for document in documents]

@@ -203,7 +204,7 @@ def remove_training_data(self, id: str, **kwargs) -> bool:
                     # Commit the transaction if the delete was successful
                     transaction.commit()
                     # Check if any row was deleted and return True or False accordingly
-                    return result.rowcount > 0
+                    return result.rowcount() > 0
                 except Exception as e:
                     # Rollback the transaction in case of error
                     logging.error(f"An error occurred: {e}")
@@ -235,9 +236,9 @@ def remove_collection(self, collection_name: str) -> bool:
                 try:
                     result = connection.execute(query)
                     transaction.commit()  # Explicitly commit the transaction
-                    if result.rowcount > 0:
+                    if result.rowcount() > 0:
                         logging.info(
-                            f"Deleted {result.rowcount} rows from "
+                            f"Deleted {result.rowcount()} rows from "
                             f"langchain_pg_embedding where collection is {collection_name}."
                         )
                         return True
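The id-generation change above is the substantive part of this diff: a content-derived id means the same training item always maps to the same row. A hedged sketch of the difference, assuming deterministic_uuid hashes its input to a stable UUID string, consistent with its use above:

import uuid

from vanna.utils import deterministic_uuid

ddl = "CREATE TABLE users (id INT)"

# Before: a fresh random id on every call, so re-training duplicates rows.
print(str(uuid.uuid4()) + "-ddl")          # different every run

# After: a content-derived id, so the same DDL always maps to the same row.
print(deterministic_uuid(ddl) + "-ddl")    # identical every run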
