Merge branch 'main' into patch-1
zainhoda authored Feb 8, 2025
2 parents bb03250 + 9388668 commit 20b0668
Showing 7 changed files with 598 additions and 195 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -59,6 +59,7 @@ These are some of the user interfaces that we've built using Vanna. You can use
- [Milvus](https://github.com/vanna-ai/vanna/tree/main/src/vanna/milvus)
- [Qdrant](https://github.com/vanna-ai/vanna/tree/main/src/vanna/qdrant)
- [Weaviate](https://github.com/vanna-ai/vanna/tree/main/src/vanna/weaviate)
- [Oracle](https://github.com/vanna-ai/vanna/tree/main/src/vanna/oracle)

## Supported Databases

1 change: 1 addition & 0 deletions pyproject.toml
@@ -57,3 +57,4 @@ pgvector = ["langchain-postgres>=0.0.12"]
faiss-cpu = ["faiss-cpu"]
faiss-gpu = ["faiss-gpu"]
xinference-client = ["xinference-client"]
oracle = ["oracledb", "chromadb"]
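With this extra in place, the Oracle integration's dependencies (`oracledb` and `chromadb`) can presumably be installed via the extras syntax, e.g. `pip install "vanna[oracle]"`, in the same way as the other optional groups above.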
86 changes: 4 additions & 82 deletions src/vanna/base/base.py
@@ -65,7 +65,7 @@
import sqlparse

from ..exceptions import DependencyError, ImproperlyConfigured, ValidationError
from ..types import TrainingPlan, TrainingPlanItem, TableMetadata
from ..types import TrainingPlan, TrainingPlanItem
from ..utils import validate_config_path


@@ -210,54 +210,6 @@ def extract_sql(self, llm_response: str) -> str:

return llm_response

def extract_table_metadata(ddl: str) -> TableMetadata:
"""
Example:
```python
vn.extract_table_metadata("CREATE TABLE hive.bi_ads.customers (id INT, name TEXT, sales DECIMAL)")
```
Extracts the table metadata from a DDL statement. This is useful in case the DDL statement contains other information besides the table metadata.
Override this function if your DDL statements need custom extraction logic.
Args:
ddl (str): The DDL statement.
Returns:
TableMetadata: The extracted table metadata.
"""
pattern_with_catalog_schema = re.compile(
r'CREATE TABLE\s+(\w+)\.(\w+)\.(\w+)\s*\(',
re.IGNORECASE
)
pattern_with_schema = re.compile(
r'CREATE TABLE\s+(\w+)\.(\w+)\s*\(',
re.IGNORECASE
)
pattern_with_table = re.compile(
r'CREATE TABLE\s+(\w+)\s*\(',
re.IGNORECASE
)

match_with_catalog_schema = pattern_with_catalog_schema.search(ddl)
match_with_schema = pattern_with_schema.search(ddl)
match_with_table = pattern_with_table.search(ddl)

if match_with_catalog_schema:
catalog = match_with_catalog_schema.group(1)
schema = match_with_catalog_schema.group(2)
table_name = match_with_catalog_schema.group(3)
return TableMetadata(catalog, schema, table_name)
elif match_with_schema:
schema = match_with_schema.group(1)
table_name = match_with_schema.group(2)
return TableMetadata(None, schema, table_name)
elif match_with_table:
table_name = match_with_table.group(1)
return TableMetadata(None, None, table_name)
else:
return TableMetadata()

def is_sql_valid(self, sql: str) -> bool:
"""
Example:
@@ -443,31 +395,6 @@ def get_related_ddl(self, question: str, **kwargs) -> list:
"""
pass

@abstractmethod
def search_tables_metadata(self,
engine: str = None,
catalog: str = None,
schema: str = None,
table_name: str = None,
ddl: str = None,
size: int = 10,
**kwargs) -> list:
"""
This method is used to get similar tables metadata.
Args:
engine (str): The database engine.
catalog (str): The catalog.
schema (str): The schema.
table_name (str): The table name.
ddl (str): The DDL statement.
size (int): The number of tables to return.
Returns:
list: A list of tables metadata.
"""
pass

@abstractmethod
def get_related_documentation(self, question: str, **kwargs) -> list:
"""
@@ -496,13 +423,12 @@ def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
pass

@abstractmethod
def add_ddl(self, ddl: str, engine: str = None, **kwargs) -> str:
def add_ddl(self, ddl: str, **kwargs) -> str:
"""
This method is used to add a DDL statement to the training data.
Args:
ddl (str): The DDL statement to add.
engine (str): The database engine that the DDL statement applies to.
Returns:
str: The ID of the training data that was added.
@@ -1852,7 +1778,6 @@ def train(
question: str = None,
sql: str = None,
ddl: str = None,
engine: str = None,
documentation: str = None,
plan: TrainingPlan = None,
) -> str:
@@ -1873,11 +1798,8 @@
question (str): The question to train on.
sql (str): The SQL query to train on.
ddl (str): The DDL statement.
engine (str): The database engine.
documentation (str): The documentation to train on.
plan (TrainingPlan): The training plan to train on.
Returns:
str: The training pl
"""

if question and not sql:
@@ -1895,12 +1817,12 @@

if ddl:
print("Adding ddl:", ddl)
return self.add_ddl(ddl=ddl, engine=engine)
return self.add_ddl(ddl)

if plan:
for item in plan._plan:
if item.item_type == TrainingPlanItem.ITEM_TYPE_DDL:
self.add_ddl(ddl=item.item_value, engine=engine)
self.add_ddl(item.item_value)
elif item.item_type == TrainingPlanItem.ITEM_TYPE_IS:
self.add_documentation(item.item_value)
elif item.item_type == TrainingPlanItem.ITEM_TYPE_SQL:
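After this change, `train` and `add_ddl` take only the DDL statement; the `engine` qualifier is gone from both the abstract interface and the training entry point. A minimal sketch of the simplified calls against any concrete implementation (the helper function and table definition are illustrative, not part of the library):

```python
from vanna.base import VannaBase


def train_example_schema(vn: VannaBase) -> str:
    # DDL is now added without naming a database engine.
    ddl_id = vn.train(ddl="CREATE TABLE customers (id INT, name TEXT, sales DECIMAL)")

    # Question/SQL pairs and documentation training are unchanged by this commit.
    vn.train(
        question="Who are the top 10 customers by sales?",
        sql="SELECT name FROM customers ORDER BY sales DESC LIMIT 10",
    )
    vn.train(documentation="The sales column is denominated in USD.")

    return ddl_id
```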
93 changes: 6 additions & 87 deletions src/vanna/opensearch/opensearch_vector.py
@@ -4,10 +4,8 @@

import pandas as pd
from opensearchpy import OpenSearch
from ..types import TableMetadata

from ..base import VannaBase
from ..utils import deterministic_uuid


class OpenSearch_VectorStore(VannaBase):
@@ -58,18 +56,6 @@ def __init__(self, config=None):
},
"mappings": {
"properties": {
"engine": {
"type": "keyword",
},
"catalog": {
"type": "keyword",
},
"schema": {
"type": "keyword",
},
"table_name": {
"type": "keyword",
},
"ddl": {
"type": "text",
},
@@ -106,8 +92,6 @@ def __init__(self, config=None):
if config is not None and "es_question_sql_index_settings" in config:
question_sql_index_settings = config["es_question_sql_index_settings"]

self.n_results = config.get("n_results", 10)

self.document_index_settings = document_index_settings
self.ddl_index_settings = ddl_index_settings
self.question_sql_index_settings = question_sql_index_settings
@@ -247,29 +231,10 @@ def create_index_if_not_exists(self, index_name: str,
print(f"Error creating index: {index_name} ", e)
return False

def calculate_md5(self, string: str) -> str:
# Encode the string to bytes
string_bytes = self.encode('utf-8')
# Compute the MD5 hash
md5_hash = hashlib.md5(string_bytes)
# Get the hex digest of the hash
md5_hex = md5_hash.hexdigest()
return md5_hex

def add_ddl(self, ddl: str, engine: str = None,
**kwargs) -> str:
def add_ddl(self, ddl: str, **kwargs) -> str:
# Assuming that you have a DDL index in your OpenSearch
table_metadata = VannaBase.extract_table_metadata(ddl)
full_table_name = table_metadata.get_full_table_name()
if full_table_name is not None and engine is not None:
id = deterministic_uuid(engine + "-" + full_table_name) + "-ddl"
else:
id = str(uuid.uuid4()) + "-ddl"
id = str(uuid.uuid4()) + "-ddl"
ddl_dict = {
"engine": engine,
"catalog": table_metadata.catalog,
"schema": table_metadata.schema,
"table_name": table_metadata.table_name,
"ddl": ddl
}
response = self.client.index(index=self.ddl_index, body=ddl_dict, id=id,
@@ -305,8 +270,7 @@ def get_related_ddl(self, question: str, **kwargs) -> List[str]:
"match": {
"ddl": question
}
},
"size": self.n_results
}
}
print(query)
response = self.client.search(index=self.ddl_index, body=query,
@@ -319,8 +283,7 @@ def get_related_documentation(self, question: str, **kwargs) -> List[str]:
"match": {
"doc": question
}
},
"size": self.n_results
}
}
print(query)
response = self.client.search(index=self.document_index,
@@ -334,8 +297,7 @@ def get_similar_question_sql(self, question: str, **kwargs) -> List[str]:
"match": {
"question": question
}
},
"size": self.n_results
}
}
print(query)
response = self.client.search(index=self.question_sql_index,
@@ -344,50 +306,6 @@ def get_similar_question_sql(self, question: str, **kwargs) -> List[str]:
return [(hit['_source']['question'], hit['_source']['sql']) for hit in
response['hits']['hits']]

def search_tables_metadata(self,
engine: str = None,
catalog: str = None,
schema: str = None,
table_name: str = None,
ddl: str = None,
size: int = 10,
**kwargs) -> list:
# Assume you have some vector search mechanism associated with your data
query = {}
if engine is None and catalog is None and schema is None and table_name is None and ddl is None:
query = {
"query": {
"match_all": {}
}
}
else:
query["query"] = {
"bool": {
"should": [
]
}
}
if engine is not None:
query["query"]["bool"]["should"].append({"match": {"engine": engine}})

if catalog is not None:
query["query"]["bool"]["should"].append({"match": {"catalog": catalog}})

if schema is not None:
query["query"]["bool"]["should"].append({"match": {"schema": schema}})
if table_name is not None:
query["query"]["bool"]["should"].append({"match": {"table_name": table_name}})

if ddl is not None:
query["query"]["bool"]["should"].append({"match": {"ddl": ddl}})

if size > 0:
query["size"] = size

print(query)
response = self.client.search(index=self.ddl_index, body=query, **kwargs)
return [hit['_source'] for hit in response['hits']['hits']]

def get_training_data(self, **kwargs) -> pd.DataFrame:
# This will be a simple example pulling all data from an index
# WARNING: Do not use this approach in production for large indices!
@@ -397,6 +315,7 @@ def get_training_data(self, **kwargs) -> pd.DataFrame:
body={"query": {"match_all": {}}},
size=1000
)
print(query)
# records = [hit['_source'] for hit in response['hits']['hits']]
for hit in response['hits']['hits']:
data.append(
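Across `get_related_ddl`, `get_related_documentation`, and `get_similar_question_sql`, the configurable `n_results` size was dropped, so the simplified queries fall back to OpenSearch's default result size (typically 10). A sketch of the body now built for DDL retrieval, with an illustrative question and the client call commented out since it needs a configured `OpenSearch` client and index name:

```python
# Simplified query body after this commit: a plain match query with no
# explicit "size", so the server-side default number of hits applies.
question = "Which tables store customer sales?"
query = {
    "query": {
        "match": {
            "ddl": question
        }
    }
}
# response = client.search(index=ddl_index, body=query)
# related_ddl = [hit["_source"]["ddl"] for hit in response["hits"]["hits"]]
```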
1 change: 1 addition & 0 deletions src/vanna/oracle/__init__.py
@@ -0,0 +1 @@
from .oracle_vector import Oracle_VectorStore
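The new module exposes the Oracle vector store at a stable import path. A minimal usage sketch; the constructor arguments are an assumption based on how the other vector stores are configured, since the body of `oracle_vector.py` is not shown here:

```python
# Import path introduced by this commit (src/vanna/oracle/__init__.py).
from vanna.oracle import Oracle_VectorStore

# Other Vanna vector stores accept an optional config dict; the Oracle store
# presumably follows the same convention (assumed, not confirmed by this diff):
# store = Oracle_VectorStore(config={"path": "./oracle_vectors"})
```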