Skip to content

Commit

Permalink
[DH-5238] Fix migration script (#341)
Browse files Browse the repository at this point in the history
  • Loading branch information
jcjc712 authored Jan 11, 2024
1 parent 61889ec commit d539be6
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 32 deletions.
Binary file removed .DS_Store
Binary file not shown.
37 changes: 14 additions & 23 deletions dataherald/scripts/migrate_v006_to_v100.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,25 @@
import os
from datetime import timedelta

from bson.objectid import ObjectId
from pymongo import ASCENDING
from pymongo.errors import DuplicateKeyError
from sql_metadata import Parser

import dataherald.config
from dataherald.config import System
from dataherald.db import DB
from dataherald.db_scanner.models.types import TableDescriptionStatus
from dataherald.types import GoldenSQL
from dataherald.vector_store import VectorStore


def update_object_id_fields(field_name: str, collection_name: str):
    """Rewrite ``ObjectId`` values stored under *field_name* as plain strings.

    Walks every document returned by ``storage.find_all(collection_name)``.
    When the value under *field_name* is a ``bson`` ``ObjectId`` it is
    replaced by ``str(value)`` and the document is written back with
    ``storage.update_or_create``, matched on its ``_id``.

    Documents that do not carry *field_name*, or whose value is not an
    ``ObjectId`` (``None``, ``""``, an already-converted string), are left
    untouched, so running the function again changes nothing for them.

    Uses the global ``storage`` handle, which must be assigned before this
    function is called.

    Args:
        field_name: Key of the field to convert in each document.
        collection_name: Name of the collection to scan and update.
    """
    for obj in storage.find_all(collection_name):
        # .get() rather than indexing: a document without the field must be
        # skipped, not abort the whole migration with a KeyError.
        # NOTE(review): assumes find_all yields dict-like documents — confirm.
        value = obj.get(field_name)
        # The isinstance check alone also excludes None and "", so no
        # separate truthiness / empty-string test is needed.
        if not isinstance(value, ObjectId):
            continue
        obj[field_name] = str(value)
        storage.update_or_create(collection_name, {"_id": obj["_id"]}, obj)

Expand All @@ -26,7 +31,7 @@ def update_object_id_fields(field_name: str, collection_name: str):
storage = system.instance(DB)
# Refresh vector stores
golden_sql_collection = os.environ.get(
"GOLDEN_SQL_COLLECTION", "dataherald-staging"
"GOLDEN_RECORD_COLLECTION", "dataherald-staging"
)
vector_store = system.instance(VectorStore)

Expand All @@ -50,8 +55,8 @@ def update_object_id_fields(field_name: str, collection_name: str):
update_object_id_fields("db_connection_id", "table_descriptions")
update_object_id_fields("db_connection_id", "golden_sqls")
update_object_id_fields("db_connection_id", "instructions")
update_object_id_fields("question_id", "responses")
update_object_id_fields("db_connection_id", "questions")
update_object_id_fields("question_id", "responses")
print("Data types changed...")

try:
Expand All @@ -61,28 +66,14 @@ def update_object_id_fields(field_name: str, collection_name: str):
pass
# Upload golden records
golden_sqls = storage.find_all("golden_sqls")
for golden_sql in golden_sqls:
tables = Parser(golden_sql["sql"]).tables
if len(tables) == 0:
tables = [""]
prompt_text = golden_sql["prompt_text"]
vector_store.add_record(
documents=prompt_text,
db_connection_id=str(golden_sql["db_connection_id"]),
collection=golden_sql_collection,
metadata=[
{
"tables_used": tables[0],
"db_connection_id": str(golden_sql["db_connection_id"]),
}
], # this should be updated for multiple tables
ids=[str(golden_sql["_id"])],
)
print("Updated...")
stored_golden_sqls = []
for golden_sql_dict in golden_sqls:
golden_sql = GoldenSQL(**golden_sql_dict, id=str(golden_sql_dict["_id"]))
stored_golden_sqls.append(golden_sql)
vector_store.add_records(stored_golden_sqls, golden_sql_collection)
print("Golden sqls uploaded...")

# Update the table_descriptions status

for table_description in storage.find_all("table_descriptions"):
if table_description["status"] == "SYNCHRONIZED":
table_description["status"] = TableDescriptionStatus.SCANNED.value
Expand Down
22 changes: 13 additions & 9 deletions dataherald/vector_store/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,16 +76,20 @@ def add_records(self, golden_sqls: List[GoldenSQL], collection: str):

records = []
for key in range(len(golden_sql_batch)):
records.append(
(
str(golden_sql_batch[key].id),
embeds[key],
{
"tables_used": Parser(golden_sql_batch[key].sql).tables[0],
"db_connection_id": golden_sql_batch[key].db_connection_id,
},
parsed_tables = Parser(golden_sql_batch[key].sql).tables
if len(parsed_tables) > 0:
records.append(
(
str(golden_sql_batch[key].id),
embeds[key],
{
"tables_used": parsed_tables[0],
"db_connection_id": golden_sql_batch[
key
].db_connection_id,
},
)
)
)
index.upsert(vectors=records)

@override
Expand Down

0 comments on commit d539be6

Please sign in to comment.