
Feature/azuresearch vector support #598

Merged
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -33,7 +33,7 @@ bigquery = ["google-cloud-bigquery"]
snowflake = ["snowflake-connector-python"]
duckdb = ["duckdb"]
google = ["google-generativeai", "google-cloud-aiplatform"]
all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "qianfan", "mistralai>=1.0.0", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone-client", "pymilvus[model]","weaviate-client"]
all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "qianfan", "mistralai>=1.0.0", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone-client", "pymilvus[model]","weaviate-client", "azure-search-documents", "azure-identity", "azure-common"]
test = ["tox"]
chromadb = ["chromadb"]
openai = ["openai"]
@@ -52,3 +52,4 @@ hf = ["transformers"]
milvus = ["pymilvus[model]"]
bedrock = ["boto3", "botocore"]
weaviate = ["weaviate-client"]
azuresearch = ["azure-search-documents", "azure-identity", "azure-common", "fastembed"]
1 change: 1 addition & 0 deletions src/vanna/azuresearch/__init__.py
@@ -0,0 +1 @@
from .azuresearch_vector import AzureAISearch_VectorStore
236 changes: 236 additions & 0 deletions src/vanna/azuresearch/azuresearch_vector.py
@@ -0,0 +1,236 @@
import json
from typing import List

import pandas as pd
from azure.core.credentials import AzureKeyCredential
from azure.search.documents import SearchClient
from azure.search.documents.indexes import SearchIndexClient
from azure.search.documents.indexes.models import (
ExhaustiveKnnAlgorithmConfiguration,
ExhaustiveKnnParameters,
SearchableField,
SearchField,
SearchFieldDataType,
SearchIndex,
VectorSearch,
VectorSearchAlgorithmKind,
VectorSearchAlgorithmMetric,
VectorSearchProfile,
)
from azure.search.documents.models import VectorFilterMode, VectorizedQuery
from fastembed import TextEmbedding

from ..base import VannaBase
from ..utils import deterministic_uuid


class AzureAISearch_VectorStore(VannaBase):
"""
AzureAISearch_VectorStore provides a Vanna vector store backed by Azure AI Search.

Args:
config (dict): Configuration dictionary. Required; must include 'azure_search_api_key'.
- azure_search_endpoint (str, optional): Azure Search endpoint. Defaults to "https://azcognetive.search.windows.net".
- azure_search_api_key (str): Azure Search API key.
- dimensions (int, optional): Dimensions of the embeddings. Defaults to 384, the dimensionality of BAAI/bge-small-en-v1.5.
- fastembed_model (str, optional): Fastembed model to use. Defaults to "BAAI/bge-small-en-v1.5".
- index_name (str, optional): Name of the index. Defaults to "vanna-index".
- n_results (int, optional): Number of results to return. Defaults to 10.
- n_results_ddl (int, optional): Number of results to return for DDL queries. Defaults to the value of n_results.
- n_results_sql (int, optional): Number of results to return for SQL queries. Defaults to the value of n_results.
- n_results_documentation (int, optional): Number of results to return for documentation queries. Defaults to the value of n_results.

Raises:
ValueError: If config is None, or if 'azure_search_api_key' is not provided in the config.
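
Example:
    # A minimal usage sketch; the endpoint and key below are placeholders,
    # not values taken from this repository.
    vn = AzureAISearch_VectorStore(config={
        "azure_search_endpoint": "https://<your-service>.search.windows.net",
        "azure_search_api_key": "<your-api-key>",
    })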
"""
def __init__(self, config=None):
VannaBase.__init__(self, config=config)

self.config = config

if config is None:
raise ValueError(
"config is required, pass an API key, 'azure_search_api_key', in the config."
)

azure_search_endpoint = config.get("azure_search_endpoint", "https://azcognetive.search.windows.net")
azure_search_api_key = config.get("azure_search_api_key")

self.dimensions = config.get("dimensions", 384)
self.fastembed_model = config.get("fastembed_model", "BAAI/bge-small-en-v1.5")

self.index_name = config.get("index_name", "vanna-index")

self.n_results_ddl = config.get("n_results_ddl", config.get("n_results", 10))
self.n_results_sql = config.get("n_results_sql", config.get("n_results", 10))
self.n_results_documentation = config.get("n_results_documentation", config.get("n_results", 10))

if not azure_search_api_key:
raise ValueError(
"'azure_search_api_key' is required in config to use AzureAISearch_VectorStore"
)

self.index_client = SearchIndexClient(
endpoint=azure_search_endpoint,
credential=AzureKeyCredential(azure_search_api_key)
)

self.search_client = SearchClient(
endpoint=azure_search_endpoint,
index_name=self.index_name,
credential=AzureKeyCredential(azure_search_api_key)
)

if self.index_name not in self._get_indexes():
self._create_index()

def _create_index(self) -> None:
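    # Index schema: a string key, the raw document text, a type tag
    # ('ddl' | 'sql' | 'doc'), and the embedding vector searched via exhaustive kNN.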
fields = [
SearchableField(name="id", type=SearchFieldDataType.String, key=True, filterable=True),
SearchableField(name="document", type=SearchFieldDataType.String, searchable=True, filterable=True),
SearchField(name="type", type=SearchFieldDataType.String, filterable=True, searchable=True),
SearchField(name="document_vector", type=SearchFieldDataType.Collection(SearchFieldDataType.Single), searchable=True, vector_search_dimensions=self.dimensions, vector_search_profile_name="ExhaustiveKnnProfile"),
]

vector_search = VectorSearch(
algorithms=[
ExhaustiveKnnAlgorithmConfiguration(
name="ExhaustiveKnn",
kind=VectorSearchAlgorithmKind.EXHAUSTIVE_KNN,
parameters=ExhaustiveKnnParameters(
metric=VectorSearchAlgorithmMetric.COSINE
)
)
],
profiles=[
VectorSearchProfile(
name="ExhaustiveKnnProfile",
algorithm_configuration_name="ExhaustiveKnn",
)
]
)

index = SearchIndex(name=self.index_name, fields=fields, vector_search=vector_search)
result = self.index_client.create_or_update_index(index)
print(f'{result.name} created')

def _get_indexes(self) -> list:
    return list(self.index_client.list_index_names())

def add_ddl(self, ddl: str) -> str:
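    # deterministic_uuid keys make training idempotent: uploading the same DDL
    # again replaces the existing document instead of duplicating it.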
id = deterministic_uuid(ddl) + "-ddl"
document = {
"id": id,
"document": ddl,
"type": "ddl",
"document_vector": self.generate_embedding(ddl)
}
self.search_client.upload_documents(documents=[document])
return id

def add_documentation(self, doc: str) -> str:
id = deterministic_uuid(doc) + "-doc"
document = {
"id": id,
"document": doc,
"type": "doc",
"document_vector": self.generate_embedding(doc)
}
self.search_client.upload_documents(documents=[document])
return id

def add_question_sql(self, question: str, sql: str) -> str:
question_sql_json = json.dumps({"question": question, "sql": sql}, ensure_ascii=False)
id = deterministic_uuid(question_sql_json) + "-sql"
document = {
"id": id,
"document": question_sql_json,
"type": "sql",
"document_vector": self.generate_embedding(question_sql_json)
}
self.search_client.upload_documents(documents=[document])
return id

def get_related_ddl(self, text: str) -> List[str]:
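    # Embed the question text and run a vector (kNN) search restricted to
    # documents tagged 'ddl'.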
result = []
vector_query = VectorizedQuery(vector=self.generate_embedding(text), fields="document_vector")
df = pd.DataFrame(
self.search_client.search(
top=self.n_results_ddl,
vector_queries=[vector_query],
select=["id", "document", "type"],
filter=f"type eq 'ddl'"
)
)

if len(df):
result = df["document"].tolist()
return result

def get_related_documentation(self, text: str) -> List[str]:
result = []
vector_query = VectorizedQuery(vector=self.generate_embedding(text), fields="document_vector")

df = pd.DataFrame(
self.search_client.search(
top=self.n_results_documentation,
vector_queries=[vector_query],
select=["id", "document", "type"],
filter=f"type eq 'doc'",
vector_filter_mode=VectorFilterMode.PRE_FILTER
)
)

if len(df):
result = df["document"].tolist()
return result

def get_similar_question_sql(self, text: str) -> List[str]:
result = []
# Vectorize the text
vector_query = VectorizedQuery(vector=self.generate_embedding(text), fields="document_vector")
df = pd.DataFrame(
self.search_client.search(
top=self.n_results_sql,
vector_queries=[vector_query],
select=["id", "document", "type"],
filter=f"type eq 'sql'"
)
)

if len(df):  # Only parse when similar question/SQL pairs were found
    # Entries were serialized with json.dumps, so json.loads is the exact inverse
    result = [json.loads(element) for element in df["document"].tolist()]

return result

def get_training_data(self) -> pd.DataFrame:
search = self.search_client.search(
search_text="*",
select=['id', 'document', 'type'],
filter=f"(type eq 'sql') or (type eq 'ddl') or (type eq 'doc')"
).by_page()

df = pd.DataFrame([item for page in search for item in page])

if len(df):
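    # 'sql' entries hold a JSON blob containing both the question and the SQL;
    # unpack them into separate columns.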
df.loc[df["type"] == "sql", "question"] = df.loc[df["type"] == "sql"]["document"].apply(lambda x: json.loads(x)["question"])
df.loc[df["type"] == "sql", "content"] = df.loc[df["type"] == "sql"]["document"].apply(lambda x: json.loads(x)["sql"])
df.loc[df["type"] != "sql", "content"] = df.loc[df["type"] != "sql"]["document"]

return df[["id", "question", "content", "type"]]

return pd.DataFrame()

def remove_training_data(self, id: str) -> bool:
result = self.search_client.delete_documents(documents=[{"id": id}])
return result[0].succeeded

def remove_index(self):
self.index_client.delete_index(self.index_name)

def generate_embedding(self, data: str, **kwargs) -> List[float]:
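    # Note: this loads the fastembed model on every call; frequent callers may
    # prefer to cache the TextEmbedding instance in __init__.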
embedding_model = TextEmbedding(model_name=self.fastembed_model)
embedding = next(embedding_model.embed(data))
return embedding.tolist()
2 changes: 2 additions & 0 deletions tests/test_imports.py
@@ -19,6 +19,7 @@ def test_regular_imports():
from vanna.weaviate.weaviate_vector import WeaviateDatabase
from vanna.ZhipuAI.ZhipuAI_Chat import ZhipuAI_Chat
from vanna.ZhipuAI.ZhipuAI_embeddings import ZhipuAI_Embeddings
from vanna.azuresearch.azuresearch_vector import AzureAISearch_VectorStore

def test_shortcut_imports():
from vanna.anthropic import Anthropic_Chat
@@ -36,3 +37,4 @@ def test_shortcut_imports():
from vanna.vllm import Vllm
from vanna.weaviate import WeaviateDatabase
from vanna.ZhipuAI import ZhipuAI_Chat, ZhipuAI_Embeddings
from vanna.azuresearch import AzureAISearch_VectorStore
26 changes: 26 additions & 0 deletions tests/test_vanna.py
@@ -24,6 +24,7 @@
SNOWFLAKE_ACCOUNT = os.environ['SNOWFLAKE_ACCOUNT']
SNOWFLAKE_USERNAME = os.environ['SNOWFLAKE_USERNAME']
SNOWFLAKE_PASSWORD = os.environ['SNOWFLAKE_PASSWORD']
# AZURE_SEARCH_API_KEY = os.environ['AZURE_SEARCH_API_KEY']

class VannaOpenAI(VannaDB_VectorStore, OpenAI_Chat):
def __init__(self, config=None):
@@ -111,6 +112,31 @@ def test_vn_chroma():
df = vn_chroma.run_sql(sql)
assert len(df) == 7

# from vanna.azuresearch.azuresearch_vector import AzureAISearch_VectorStore


# class VannaAzureSearch(AzureAISearch_VectorStore, OpenAI_Chat):
# def __init__(self, config=None):
# AzureAISearch_VectorStore.__init__(self, config=config)
# OpenAI_Chat.__init__(self, config=config)

# vn_azure_search = VannaAzureSearch(config={'azure_search_api_key': AZURE_SEARCH_API_KEY,'api_key': OPENAI_API_KEY, 'model': 'gpt-3.5-turbo'})
# vn_azure_search.connect_to_sqlite('https://vanna.ai/Chinook.sqlite')

# def test_vn_azure_search():
# existing_training_data = vn_azure_search.get_training_data()
# print(existing_training_data)
# if len(existing_training_data) > 0:
# for _, training_data in existing_training_data.iterrows():
# vn_azure_search.remove_training_data(training_data['id'])

# df_ddl = vn_azure_search.run_sql("SELECT type, sql FROM sqlite_master WHERE sql is not null")
# for ddl in df_ddl['sql'].to_list():
# vn_azure_search.train(ddl=ddl)

# sql = vn_azure_search.generate_sql("What are the top 7 customers by sales?")
# df = vn_azure_search.run_sql(sql)
# assert len(df) == 7

from vanna.milvus import Milvus_VectorStore
