Remove embeddings and NER, as we do not use them in sql-eval anymore #49

Merged
merged 2 commits on Jan 29, 2025
README.md (8 changes: 1 addition & 7 deletions)
@@ -75,21 +75,15 @@ md.academic

#### Supplementary

We also have column embeddings, joinable columns and special columns with named entities, split by database in [supplementary.py](defog_data/supplementary.py). To access them, use the following code:
We also have joinable columns, split by database in [supplementary.py](defog_data/supplementary.py). To access them, use the following code:

```python
import defog_data.supplementary as sup

# embeddings and accompanying column info in csv format
embeddings, csv_info = sup.load_embeddings("<your path of choice>")
# columns that can be joined on
sup.columns_join
# columns with named entities
sup.columns_ner
```

Note that the embeddings need to be regenerated should the underlying data get updated (eg new columns added, major version bumps). To regenerate the embeddings, the previous ones should be deleted first, which can be done automatically by setting the update parameter to `True` when running the `load_embeddings` function.

## Organization

### Databases
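If downstream code still needs column embeddings after this removal, the behaviour of the deleted helpers can be approximated locally. Below is a minimal sketch based on the deleted `generate_embeddings` code; `embed_columns` is a hypothetical helper name, and the sketch assumes `sentence-transformers` is installed separately now that it is dropped from requirements.txt.

```python
from sentence_transformers import SentenceTransformer  # no longer a dependency of this repo
from defog_data.metadata import dbs


def embed_columns(db_name: str):
    """Embed 'table.column: description' strings for one database's columns."""
    encoder = SentenceTransformer(
        "sentence-transformers/all-MiniLM-L6-v2", device="cpu"
    )
    metadata = dbs[db_name]["table_metadata"]
    descriptions = [
        f"{table}.{col['column_name']}: {col['column_description']}"
        for table, columns in metadata.items()
        for col in columns
    ]
    return encoder.encode(descriptions, convert_to_tensor=True), descriptions
```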
defog_data/supplementary.py (211 changes: 0 additions & 211 deletions)
@@ -1,69 +1,9 @@
from defog_data.metadata import dbs
import logging
import os
import pickle
from sentence_transformers import SentenceTransformer
import re

# get package root directory
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))


def generate_embeddings(emb_path: str, save_emb: bool = True) -> tuple[dict, dict]:
"""
For each db, generate embeddings for all of the column names and descriptions
"""
encoder = SentenceTransformer(
"sentence-transformers/all-MiniLM-L6-v2", device="cpu"
)
emb = {}
csv_descriptions = {}
glossary_emb = {}
for db_name, db in dbs.items():
metadata = db["table_metadata"]
glossary = clean_glossary(db["glossary"])
column_descriptions = []
column_descriptions_typed = []
for table in metadata:
for column in metadata[table]:
col_str = (
table
+ "."
+ column["column_name"]
+ ": "
+ column["column_description"]
)
col_str_typed = (
table
+ "."
+ column["column_name"]
+ ","
+ column["data_type"]
+ ","
+ column["column_description"]
)
column_descriptions.append(col_str)
column_descriptions_typed.append(col_str_typed)
column_emb = encoder.encode(column_descriptions, convert_to_tensor=True)
emb[db_name] = column_emb
csv_descriptions[db_name] = column_descriptions_typed
logging.info(f"Finished embedding {db_name} {len(column_descriptions)} columns")
if len(glossary) > 0:
glossary_embeddings = encoder.encode(glossary, convert_to_tensor=True)
else:
glossary_embeddings = []
glossary_emb[db_name] = glossary_embeddings
if save_emb:
# get directory of emb_path and create if it doesn't exist
emb_dir = os.path.dirname(emb_path)
if not os.path.exists(emb_dir):
os.makedirs(emb_dir)
with open(emb_path, "wb") as f:
pickle.dump((emb, csv_descriptions, glossary_emb), f)
logging.info(f"Saved embeddings to file {emb_path}")
return emb, csv_descriptions, glossary_emb


def clean_glossary(glossary: str) -> list[str]:
"""
Clean glossary by removing number bullets and periods, and making sure every line starts with a dash bullet.
@@ -83,157 +23,6 @@ def clean_glossary(glossary: str) -> list[str]:
glossary = cleaned
return glossary


def load_embeddings(emb_path: str, update: bool = False) -> tuple[dict, dict]:
"""
Load embeddings from file if they exist, otherwise generate them and save them.
"""
if os.path.isfile(emb_path) and update == False:
logging.info(f"Loading embeddings from file {emb_path}")
with open(emb_path, "rb") as f:
emb, csv_descriptions, glossary_emb = pickle.load(f)
return emb, csv_descriptions, glossary_emb
else:
logging.info(f"Embeddings file {emb_path} does not exist or it needs to be updated.")
emb, csv_descriptions, glossary_emb = generate_embeddings(emb_path)
return emb, csv_descriptions, glossary_emb


# entity types: list of (column, type, description) tuples
# note that these are spacy types https://spacy.io/usage/linguistic-features#named-entities
# we can add more types if we want, but PERSON, GPE, ORG should be
# sufficient for most use cases.
# also note that DATE and TIME are not included because they are usually
# retrievable from the top k embedding search due to the limited list of nouns
columns_ner = {
"academic": {
"PERSON": [
"author.name,text,Name of the author",
],
"ORG": [
"conference.name,text,The name of the conference",
"journal.name,text,The name of the journal",
"organization.name,text,Name of the organization",
],
},
"advising": {
"PERSON": [
"instructor.name,text,Name of the instructor",
"student.firstname,text,First name of the student",
"student.lastname,text,Last name of the student",
],
"ORG": [
"program.college,text,Name of the college offering the program",
"program.name,text,Name of the program",
],
},
"atis": {
"GPE": [
"airport_service.city_code,text,The code of the city where the airport is located",
"airport.airport_location,text,The location of the airport, eg 'Las Vegas', 'Chicago'",
"airport.country_name,text,The name of the country where the airport is located.",
"airport.state_code,text,The code assigned to the state where the airport is located.",
"city.city_code,text,The code assigned to the city",
"city.city_name,text,The name of the city",
"city.country_name,text,The name of the country where the city is located",
"city.state_code,text,The 2-letter code assigned to the state where the city is located. E.g. 'NY', 'CA', etc.",
"ground_service.city_code,text,The code for the city where the ground service is provided",
"state.country_name,text,The name of the country the state belongs to",
"state.state_code,text,The 2-letter code assigned to the state. E.g. 'NY', 'CA', etc.",
"state.state_name,text,The name of the state",
],
"ORG": [
"airline.airline_code,text,The code assigned to the airline",
"airline.airline_name,text,The name of the airline",
"airport_service.airport_code,text,The code of the airport",
"airport.airport_code,text,The code assigned to the airport.",
"airport.airport_name,text,The name of the airport",
"dual_carrier.main_airline,text,The name of the main airline operating the flight",
"fare.fare_airline,text,The airline code associated with this fare",
"fare.from_airport,text,The 3-letter airport code for the departure location",
"fare.to_airport,text,The 3-letter airport code for the arrival location",
"flight.airline_code,text,Code assigned to the airline",
"flight.from_airport,text,Code assigned to the departure airport",
"flight.to_airport,text,Code assigned to the arrival airport",
"ground_service.airport_code,text,The 3-letter code for the airport where the ground service is provided",
],
},
"yelp": {
"GPE": [
"business.city,text,The city where the business is located",
"business.state,text,The US state where the business is located, represented by two-letter abbreviations (eg. 'CA', 'NV', 'NY', etc.)",
"business.full_address,text,The full address of the business",
],
"ORG": [
"business.name,text,The name of the business. All apostrophes use ’ instead of ' to avoid SQL errors.",
"neighbourhood.neighbourhood_name,text,Name of the neighbourhood where the business is located",
],
"PER": [
"users.name,text,Name of the user",
],
},
"restaurants": {
"GPE": [
"location.city_name,text,The name of the city where the restaurant is located",
"location.street_name,text,The name of the street where the restaurant is located",
"geographic.city_name,text,The name of the city",
"geographic.county,text,The name of the county",
"geographic.region,text,The name of the region",
"restaurant.city_name,text,The city where the restaurant is located",
],
"ORG": [
"restaurant.name,text,The name of the restaurant",
"restaurant.id,bigint,Unique identifier for each restaurant",
],
"PER": [],
},
"geography": {
"GPE": [
"city.city_name,text,The name of the city",
"city.country_name,text,The name of the country where the city is located",
"city.state_name,text,The name of the state where the city is located",
"lake.country_name,text,The name of the country where the lake is located",
"lake.state_name,text,The name of the state where the lake is located (if applicable)",
"river.country_name,text,The name of the country the river flows through",
"river.traverse,text,The cities or landmarks the river passes through. Comma delimited and in title case, eg `New York,Albany,Boston`",
"state.state_name,text,The name of the state",
"state.country_name,text,The name of the country the state belongs to",
"state.capital,text,The name of the capital city of the state",
"highlow.state_name,text,The name of the state",
"mountain.country_name,text,The name of the country where the mountain is located",
"mountain.state_name,text,The name of the state or province where the mountain is located (if applicable)",
"border_info.state_name,text,The name of the state that shares a border with another state or country.",
"border_info.border,text,The name of the state that shares a border with the state specified in the state_name column.",
],
"LOC": [
"lake.lake_name,text,The name of the lake",
"river.river_name,text,The name of the river. Names exclude the word 'river' e.g. 'Mississippi' instead of 'Mississippi River'",
"mountain.mountain_name,text,The name of the mountain",
],
"ORG": [],
"PER": [],
},
"scholar": {
"GPE": [],
"EVENT": [
"venue.venuename,text,Name of the venue",
],
"ORG": [],
"PER": [
"author.authorname,text,Name of the author",
],
"WORK_OF_ART": [
"paper.title,text,The title of the paper, enclosed in double quotes if it contains commas.",
"dataset.datasetname,text,Name of the dataset",
"journal.journalname,text,Name or title of the journal",
],
},
"broker": {"GPE": []},
"car_dealership": {"GPE": []},
"derm_treatment": {"GPE": []},
"ewallet": {"GPE": []},
}

# (pair of tables): list of (column1, column2) tuples that can be joined
# pairs should be lexically ordered, ie (table1 < table2) and (column1 < column2)
columns_join = {
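With embeddings and NER gone, `columns_join` is the only structure left in supplementary.py. As a rough illustration of how a consumer might look it up, here is a sketch; the nested shape (database name, then a lexically ordered table pair, then `(column1, column2)` tuples) is inferred from the comments above and the README, not a documented API, and `joinable_pairs` is a hypothetical helper.

```python
import defog_data.supplementary as sup


def joinable_pairs(db_name: str, table1: str, table2: str) -> list[tuple[str, str]]:
    """Return the (column1, column2) tuples that can be joined across two tables, if any."""
    table_pair = tuple(sorted((table1, table2)))  # pairs are stored lexically ordered
    return sup.columns_join.get(db_name, {}).get(table_pair, [])


# Hypothetical usage against the academic database:
# joinable_pairs("academic", "author", "writes")
```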
requirements.txt (1 change: 0 additions & 1 deletion)
@@ -1 +0,0 @@
sentence_transformers==3.2.1
tests.py (14 changes: 0 additions & 14 deletions)
@@ -1,7 +1,6 @@
import os
import unittest
from defog_data.metadata import get_db, dbs
from defog_data.supplementary import columns_ner


class TestDB(unittest.TestCase):
@@ -261,19 +260,6 @@ def test_yelp(self):
num_columns = sum([len(db_schema[table]) for table in db_schema])
self.assertEqual(num_columns, 36)

def test_supplementary_columns_ner(self):
# for each db, go through each table and add column names to a set and make sure they are not repeated
for db_name, ner_mapping in columns_ner.items():
column_names = set()
for _, column_str_list in ner_mapping.items():
for column_str in column_str_list:
column_name = column_str.split(",")[0]
if column_name in column_names:
raise Exception(
f"Column name {column_name} is repeated in {db_name}"
)
column_names.add(column_name)


if __name__ == "__main__":
unittest.main()