Skip to content

Commit

Permalink
[clean] format + rename model in test
Browse files Browse the repository at this point in the history
  • Loading branch information
lucaordronneau committed Aug 14, 2024
1 parent e56e945 commit 928e7dc
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 36 deletions.
63 changes: 29 additions & 34 deletions src/vanna/azuresearch/azuresearch_vector.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,29 @@
import ast
import json
import pandas as pd

from typing import List

from ..base import VannaBase
from ..utils import deterministic_uuid

from fastembed import TextEmbedding

from azure.search.documents.models import VectorizedQuery

from azure.search.documents import SearchClient
import pandas as pd
from azure.core.credentials import AzureKeyCredential
from azure.search.documents.indexes import SearchIndexClient

from azure.search.documents.models import VectorFilterMode

from azure.search.documents import SearchClient
from azure.search.documents.indexes import SearchIndexClient
from azure.search.documents.indexes.models import (
SearchableField,
SearchField,
SearchFieldDataType,
VectorSearch,
VectorSearchAlgorithmKind,
VectorSearchAlgorithmMetric,
VectorSearchProfile,
ExhaustiveKnnAlgorithmConfiguration,
ExhaustiveKnnParameters,
SearchIndex,
ExhaustiveKnnAlgorithmConfiguration,
ExhaustiveKnnParameters,
SearchableField,
SearchField,
SearchFieldDataType,
SearchIndex,
VectorSearch,
VectorSearchAlgorithmKind,
VectorSearchAlgorithmMetric,
VectorSearchProfile,
)
from azure.search.documents.models import VectorFilterMode, VectorizedQuery
from fastembed import TextEmbedding

from ..base import VannaBase
from ..utils import deterministic_uuid


class AzureAISearch_VectorStore(VannaBase):
"""
Expand Down Expand Up @@ -77,16 +72,16 @@ def __init__(self, config=None):
)

self.index_client = SearchIndexClient(
endpoint=azure_search_endpoint,
endpoint=azure_search_endpoint,
credential=AzureKeyCredential(azure_search_api_key)
)

self.search_client = SearchClient(
endpoint=azure_search_endpoint,
index_name=self.index_name,
credential=AzureKeyCredential(azure_search_api_key)
)

if self.index_name not in self._get_indexes():
self._create_index()

Expand Down Expand Up @@ -122,7 +117,7 @@ def _create_index(self) -> bool:

def _get_indexes(self) -> list:
return [index for index in self.index_client.list_index_names()]

def add_ddl(self, ddl: str) -> str:
id = deterministic_uuid(ddl) + "-ddl"
document = {
Expand Down Expand Up @@ -211,7 +206,7 @@ def get_similar_question_sql(self, text: str) -> List[str]:

def get_training_data(self) -> List[str]:

search = self.search_client.search(
search = self.search_client.search(
search_text="*",
select=['id', 'document', 'type'],
filter=f"(type eq 'sql') or (type eq 'ddl') or (type eq 'doc')"
Expand All @@ -223,19 +218,19 @@ def get_training_data(self) -> List[str]:
df.loc[df["type"] == "sql", "question"] = df.loc[df["type"] == "sql"]["document"].apply(lambda x: json.loads(x)["question"])
df.loc[df["type"] == "sql", "content"] = df.loc[df["type"] == "sql"]["document"].apply(lambda x: json.loads(x)["sql"])
df.loc[df["type"] != "sql", "content"] = df.loc[df["type"] != "sql"]["document"]

return df[["id", "question", "content", "type"]]

return pd.DataFrame()

def remove_training_data(self, id: str) -> bool:
result = self.search_client.delete_documents(documents=[{'id':id}])
return result[0].succeeded

def remove_index(self):
self.index_client.delete_index(self.index_name)

def generate_embedding(self, data: str, **kwargs) -> List[float]:
embedding_model = TextEmbedding(model_name=self.fastembed_model)
embedding = next(embedding_model.embed(data))
return embedding.tolist()
return embedding.tolist()
5 changes: 3 additions & 2 deletions tests/test_vanna.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,13 @@ def test_vn_chroma():

from vanna.azuresearch.azuresearch_vector import AzureAISearch_VectorStore


class VannaAzureSearch(AzureAISearch_VectorStore, OpenAI_Chat):
def __init__(self, config=None):
AzureAISearch_VectorStore.__init__(self, config=config)
OpenAI_Chat.__init__(self, config=config)

vn_azure_search = VannaAzureSearch(config={'azure_search_api_key': AZURE_SEARCH_API_KEY,'api_key': OPENAI_API_KEY, 'model': 'gpt35turbo-1106'})
vn_azure_search = VannaAzureSearch(config={'azure_search_api_key': AZURE_SEARCH_API_KEY,'api_key': OPENAI_API_KEY, 'model': 'gpt-3.5-turbo'})
vn_azure_search.connect_to_sqlite('https://vanna.ai/Chinook.sqlite')

def test_vn_azure_search():
Expand All @@ -128,7 +129,7 @@ def test_vn_azure_search():
if len(existing_training_data) > 0:
for _, training_data in existing_training_data.iterrows():
vn_azure_search.remove_training_data(training_data['id'])

df_ddl = vn_azure_search.run_sql("SELECT type, sql FROM sqlite_master WHERE sql is not null")
for ddl in df_ddl['sql'].to_list():
vn_azure_search.train(ddl=ddl)
Expand Down

0 comments on commit 928e7dc

Please sign in to comment.