From d9640eae4c344a1ac8c4cfacccb5bef0033cd76b Mon Sep 17 00:00:00 2001 From: 01804223 <01804223@yto.net.cn> Date: Fri, 17 May 2024 10:31:13 +0800 Subject: [PATCH 1/2] add milvus vectorstore support, but need to deploy milvus and embedding model in advance --- src/vanna/milvus/milvus_vector.py | 326 ++++++++++++++++++++++++++++++ 1 file changed, 326 insertions(+) create mode 100644 src/vanna/milvus/milvus_vector.py diff --git a/src/vanna/milvus/milvus_vector.py b/src/vanna/milvus/milvus_vector.py new file mode 100644 index 00000000..4a943342 --- /dev/null +++ b/src/vanna/milvus/milvus_vector.py @@ -0,0 +1,326 @@ +import uuid +import random +import json +from abc import ABC +import pandas as pd +from langchain_community.embeddings.xinference import XinferenceEmbeddings +from ..base import VannaBase +from langchain_community.vectorstores.milvus import Milvus +from langchain_core.documents import Document +from pymilvus import MilvusClient, DataType + +ip = #Milvus server ip +port = #Milvus server port +client = MilvusClient("http://ip:port") + + +class Milvus_VectorStore(VannaBase): + def __init__(self, config=None): + VannaBase.__init__(self, config=config) + + if config is not None and "url" in config: + milvus_model_url = config["url"] + else: + milvus_model_ip = #Embedding model ip + milvus_model_port = #Embedding model port + milvus_model_url = "http://milvus_model_ip:milvus_model_port" + + if config is not None and "milvus_model" in config: + milvus_model = config["milvus_model"] + else: + milvus_model = "bge-large-zh-v1.5-yto" + + self.xinference_embeddings = XinferenceEmbeddings( + server_url=milvus_model_url, + model_uid=milvus_model + ) + + def create_sql_collection(self, name: str, **kwargs): + has = client.has_collection(collection_name=name) + # print(has) + if not has: + vannasql_schema = MilvusClient.create_schema( + auto_id=True, + enable_dynamic_field=False, + ) + vannasql_schema.add_field(field_name="id", datatype=DataType.VARCHAR, max_length=65535) + vannasql_schema.add_field(field_name="text", datatype=DataType.VARCHAR, max_length=65535) + vannasql_schema.add_field(field_name="sql", datatype=DataType.VARCHAR, max_length=65535) + vannasql_schema.add_field(field_name="my_id", datatype=DataType.INT64, is_primary=True) + vannasql_schema.add_field(field_name="my_vector", datatype=DataType.FLOAT_VECTOR, dim=1024) + + vannasql_index_params = client.prepare_index_params() + vannasql_index_params.add_index( + field_name="my_vector", + index_name="my_vector", + index_type="IVF_FLAT", + metric_type="L2", + params={"nlist": 1024} + ) + client.create_collection( + collection_name=name, + schema=vannasql_schema, + index_params=vannasql_index_params, + consistency_level="Strong" + ) + else: + pass + + def create_ddl_collection(self, name: str, **kwargs): + has = client.has_collection(collection_name=name) + # print(has) + if not has: + vannaddl_schema = MilvusClient.create_schema( + auto_id=True, + enable_dynamic_field=False, + ) + vannaddl_schema.add_field(field_name="id", datatype=DataType.VARCHAR, max_length=65535) + vannaddl_schema.add_field(field_name="ddl", datatype=DataType.VARCHAR, max_length=65535) + vannaddl_schema.add_field(field_name="my_id", datatype=DataType.INT64, is_primary=True) + vannaddl_schema.add_field(field_name="my_vector", datatype=DataType.FLOAT_VECTOR, dim=1024) + + vannaddl_index_params = client.prepare_index_params() + vannaddl_index_params.add_index( + field_name="my_vector", + index_name="my_vector", + index_type="IVF_FLAT", + metric_type="L2", + params={"nlist": 1024} + ) + client.create_collection( + collection_name=name, + schema=vannaddl_schema, + index_params=vannaddl_index_params, + consistency_level="Strong" + ) + else: + pass + + def create_doc_collection(self, name: str, **kwargs): + has = client.has_collection(collection_name=name) + # print(has) + if not has: + vannadoc_schema = MilvusClient.create_schema( + auto_id=True, + enable_dynamic_field=False, + ) + vannadoc_schema.add_field(field_name="id", datatype=DataType.VARCHAR, max_length=65535) + vannadoc_schema.add_field(field_name="doc", datatype=DataType.VARCHAR, max_length=65535) + vannadoc_schema.add_field(field_name="my_id", datatype=DataType.INT64, is_primary=True) + vannadoc_schema.add_field(field_name="my_vector", datatype=DataType.FLOAT_VECTOR, dim=1024) + + vannadoc_index_params = client.prepare_index_params() + vannadoc_index_params.add_index( + field_name="my_vector", + index_name="my_vector", + index_type="IVF_FLAT", + metric_type="L2", + params={"nlist": 1024} + ) + client.create_collection( + collection_name=name, + schema=vannadoc_schema, + index_params=vannadoc_index_params, + consistency_level="Strong" + ) + else: + pass + + def add_question_sql(self, question: str, sql: str, **kwargs) -> str: + if len(question) == 0 or len(sql) == 0: + raise Exception("pair of question and sql can not be null") + self.create_sql_collection("vannasql") + question_sql_json = json.dumps( + { + "question": question, + "sql": sql, + }, + ensure_ascii=False, + ) + # a = random.randint(1,2**63-1) + _id = str(uuid.uuid4()) + "-sql" + embeddings = self.xinference_embeddings.embed_query(question) + res = client.insert( + collection_name="vannasql", + data={ + 'id': _id, + 'text': question, + 'sql': sql, + # 'my_id':a, + 'my_vector': embeddings + } + ) + return _id + + def add_ddl(self, ddl: str, **kwargs) -> str: + if len(ddl) == 0: + raise Exception("ddl can not be null") + self.create_ddl_collection("vannaddl") + # b = random.randint(1,2**63-1) + _id = str(uuid.uuid4()) + "-ddl" + embeddings = self.xinference_embeddings.embed_query(ddl) + res = client.insert( + collection_name="vannaddl", + data={ + 'id': _id, + 'ddl': ddl, + # 'my_id':b, + 'my_vector': embeddings + } + ) + return _id + + def add_documentation(self, documentation: str, **kwargs) -> str: + if len(documentation) == 0: + raise Exception("documentation can not be null") + self.create_doc_collection("vannadoc") + # c = random.randint(1,2**63-1) + _id = str(uuid.uuid4()) + "-doc" + embeddings = self.xinference_embeddings.embed_query(documentation) + res = client.insert( + collection_name="vannadoc", + data={ + 'id': _id, + 'doc': documentation, + # 'my_id':c, + 'my_vector': embeddings + } + ) + return _id + + def get_training_data(self, **kwargs) -> pd.DataFrame: + sql_data = client.query( + collection_name="vannasql", + filter="my_id > 0", + output_fields=["*"], + ) + df = pd.DataFrame() + if sql_data is not None: + df_sql = pd.DataFrame( + { + "id": [doc["id"] for doc in sql_data], + "question": [doc["text"] for doc in sql_data], + "content": [doc["sql"] for doc in sql_data], + } + ) + df = pd.concat([df, df_sql]) + + ddl_data = client.query( + collection_name="vannaddl", + filter="my_id > 0", + output_fields=["*"], + ) + + if ddl_data is not None: + df_ddl = pd.DataFrame( + { + "id": [doc["id"] for doc in ddl_data], + "question": [None for doc in ddl_data], + "content": [doc["ddl"] for doc in ddl_data], + } + ) + df = pd.concat([df, df_ddl]) + + doc_data = client.query( + collection_name="vannadoc", + filter="my_id > 0", + output_fields=["*"], + ) + + if doc_data is not None: + df_doc = pd.DataFrame( + { + "id": [doc["id"] for doc in doc_data], + "question": [None for doc in doc_data], + "content": [doc["doc"] for doc in doc_data], + } + ) + df = pd.concat([df, df_doc]) + return df + + def get_similar_question_sql(self, question: str, **kwargs) -> list: + search_params = { + "metric_type": "L2", + "params": {"nprobe": 128}, + } + embeddings = [self.xinference_embeddings.embed_query(question)] + res = client.search( + collection_name="vannasql", + anns_field="my_vector", + data=embeddings, + limit=10, + filter='', + output_fields=["text", "sql"], + search_params=search_params + ) + res = res[0] + + list_sql = [] + for doc in res: + dict = {} + dict["question"] = doc["entity"]["text"] + dict["sql"] = doc["entity"]["sql"] + #print(dict) + list_sql.append(dict) + return list_sql + + + def get_related_ddl(self, question: str, **kwargs) -> list: + search_params = { + "metric_type": "L2", + "params": {"nprobe": 128}, + } + embeddings = [self.xinference_embeddings.embed_query(question)] + res = client.search( + collection_name="vannaddl", + anns_field="my_vector", + data=embeddings, + limit=1, + filter='', + output_fields=["ddl"], + search_params=search_params + ) + res = res[0] + + list_ddl = [] + for doc in res: + list_ddl.append(doc["entity"]["ddl"]) + return list_ddl + + + def get_related_documentation(self, question: str, **kwargs) -> list: + search_params = { + "metric_type": "L2", + "params": {"nprobe": 128}, + } + embeddings = [self.xinference_embeddings.embed_query(question)] + res = client.search( + collection_name="vannadoc", + anns_field="my_vector", + data=embeddings, + limit=1, + filter='', + output_fields=["doc"], + search_params=search_params + ) + res = res[0] + + list_doc = [] + for doc in res: + list_doc.append(doc["entity"]["doc"]) + return list_doc + + + def remove_training_data(self, id: str, **kwargs) -> bool: + pass + # if id.endswith("-sql"): + # self.mq.index("vanna-sql").delete_documents(ids=[id]) + # return True + # elif id.endswith("-ddl"): + # self.mq.index("vanna-ddl").delete_documents(ids=[id]) + # return True + # elif id.endswith("-doc"): + # self.mq.index("vanna-doc").delete_documents(ids=[id]) + # return True + # else: + # return False From d23bd5e672c5ed652268b0b60de44e7271ffdc3e Mon Sep 17 00:00:00 2001 From: ChengZi Date: Thu, 13 Jun 2024 19:33:56 +0800 Subject: [PATCH 2/2] refine milvus support, and use emb models from milvus. Signed-off-by: ChengZi --- pyproject.toml | 3 +- src/vanna/milvus/__init__.py | 1 + src/vanna/milvus/milvus_vector.py | 305 ++++++++++++++---------------- tests/test_imports.py | 2 + tests/test_vanna.py | 28 +++ 5 files changed, 175 insertions(+), 164 deletions(-) create mode 100644 src/vanna/milvus/__init__.py diff --git a/pyproject.toml b/pyproject.toml index afb42550..9413ac8c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,7 @@ bigquery = ["google-cloud-bigquery"] snowflake = ["snowflake-connector-python"] duckdb = ["duckdb"] google = ["google-generativeai", "google-cloud-aiplatform"] -all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone-client"] +all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone-client", "pymilvus[model]"] test = ["tox"] chromadb = ["chromadb"] openai = ["openai"] @@ -48,3 +48,4 @@ vllm = ["vllm"] pinecone = ["pinecone-client", "fastembed"] opensearch = ["opensearch-py", "opensearch-dsl"] hf = ["transformers"] +milvus = ["pymilvus[model]"] diff --git a/src/vanna/milvus/__init__.py b/src/vanna/milvus/__init__.py new file mode 100644 index 00000000..a64dd01a --- /dev/null +++ b/src/vanna/milvus/__init__.py @@ -0,0 +1 @@ +from .milvus_vector import Milvus_VectorStore diff --git a/src/vanna/milvus/milvus_vector.py b/src/vanna/milvus/milvus_vector.py index 4a943342..4ae210a9 100644 --- a/src/vanna/milvus/milvus_vector.py +++ b/src/vanna/milvus/milvus_vector.py @@ -1,153 +1,149 @@ import uuid -import random -import json -from abc import ABC +from typing import List + import pandas as pd -from langchain_community.embeddings.xinference import XinferenceEmbeddings +from pymilvus import DataType, MilvusClient, model + from ..base import VannaBase -from langchain_community.vectorstores.milvus import Milvus -from langchain_core.documents import Document -from pymilvus import MilvusClient, DataType -ip = #Milvus server ip -port = #Milvus server port -client = MilvusClient("http://ip:port") +# Setting the URI as a local file, e.g.`./milvus.db`, +# is the most convenient method, as it automatically utilizes Milvus Lite +# to store all data in this file. +# +# If you have large scale of data such as more than a million docs, we +# recommend setting up a more performant Milvus server on docker or kubernetes. +# When using this setup, please use the server URI, +# e.g.`http://localhost:19530`, as your URI. + +DEFAULT_MILVUS_URI = "./milvus.db" +# DEFAULT_MILVUS_URI = "http://localhost:19530" + +MAX_LIMIT_SIZE = 10_000 class Milvus_VectorStore(VannaBase): + """ + Vectorstore implementation using Milvus - https://milvus.io/docs/quickstart.md + + Args: + - config (dict, optional): Dictionary of `Milvus_VectorStore config` options. Defaults to `None`. + - milvus_client: A `pymilvus.MilvusClient` instance. + - embedding_function: + A `milvus_model.base.BaseEmbeddingFunction` instance. Defaults to `DefaultEmbeddingFunction()`. + For more models, please refer to: + https://milvus.io/docs/embeddings.md + """ def __init__(self, config=None): VannaBase.__init__(self, config=config) - if config is not None and "url" in config: - milvus_model_url = config["url"] + if "milvus_client" in config: + self.milvus_client = config["milvus_client"] else: - milvus_model_ip = #Embedding model ip - milvus_model_port = #Embedding model port - milvus_model_url = "http://milvus_model_ip:milvus_model_port" + self.milvus_client = MilvusClient(uri=DEFAULT_MILVUS_URI) - if config is not None and "milvus_model" in config: - milvus_model = config["milvus_model"] + if "embedding_function" in config: + self.embedding_function = config.get("embedding_function") else: - milvus_model = "bge-large-zh-v1.5-yto" + self.embedding_function = model.DefaultEmbeddingFunction() + self._embedding_dim = self.embedding_function.encode_documents(["foo"])[0].shape[0] + self._create_collections() + self.n_results = config.get("n_results", 10) - self.xinference_embeddings = XinferenceEmbeddings( - server_url=milvus_model_url, - model_uid=milvus_model - ) + def _create_collections(self): + self._create_sql_collection("vannasql") + self._create_ddl_collection("vannaddl") + self._create_doc_collection("vannadoc") + + + def generate_embedding(self, data: str, **kwargs) -> List[float]: + return self.embedding_function.encode_documents(data).tolist() - def create_sql_collection(self, name: str, **kwargs): - has = client.has_collection(collection_name=name) - # print(has) - if not has: + + def _create_sql_collection(self, name: str): + if not self.milvus_client.has_collection(collection_name=name): vannasql_schema = MilvusClient.create_schema( - auto_id=True, + auto_id=False, enable_dynamic_field=False, ) - vannasql_schema.add_field(field_name="id", datatype=DataType.VARCHAR, max_length=65535) + vannasql_schema.add_field(field_name="id", datatype=DataType.VARCHAR, max_length=65535, is_primary=True) vannasql_schema.add_field(field_name="text", datatype=DataType.VARCHAR, max_length=65535) vannasql_schema.add_field(field_name="sql", datatype=DataType.VARCHAR, max_length=65535) - vannasql_schema.add_field(field_name="my_id", datatype=DataType.INT64, is_primary=True) - vannasql_schema.add_field(field_name="my_vector", datatype=DataType.FLOAT_VECTOR, dim=1024) + vannasql_schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self._embedding_dim) - vannasql_index_params = client.prepare_index_params() + vannasql_index_params = self.milvus_client.prepare_index_params() vannasql_index_params.add_index( - field_name="my_vector", - index_name="my_vector", - index_type="IVF_FLAT", + field_name="vector", + index_name="vector", + index_type="AUTOINDEX", metric_type="L2", - params={"nlist": 1024} ) - client.create_collection( + self.milvus_client.create_collection( collection_name=name, schema=vannasql_schema, index_params=vannasql_index_params, consistency_level="Strong" ) - else: - pass - def create_ddl_collection(self, name: str, **kwargs): - has = client.has_collection(collection_name=name) - # print(has) - if not has: + def _create_ddl_collection(self, name: str): + if not self.milvus_client.has_collection(collection_name=name): vannaddl_schema = MilvusClient.create_schema( - auto_id=True, + auto_id=False, enable_dynamic_field=False, ) - vannaddl_schema.add_field(field_name="id", datatype=DataType.VARCHAR, max_length=65535) + vannaddl_schema.add_field(field_name="id", datatype=DataType.VARCHAR, max_length=65535, is_primary=True) vannaddl_schema.add_field(field_name="ddl", datatype=DataType.VARCHAR, max_length=65535) - vannaddl_schema.add_field(field_name="my_id", datatype=DataType.INT64, is_primary=True) - vannaddl_schema.add_field(field_name="my_vector", datatype=DataType.FLOAT_VECTOR, dim=1024) + vannaddl_schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self._embedding_dim) - vannaddl_index_params = client.prepare_index_params() + vannaddl_index_params = self.milvus_client.prepare_index_params() vannaddl_index_params.add_index( - field_name="my_vector", - index_name="my_vector", - index_type="IVF_FLAT", + field_name="vector", + index_name="vector", + index_type="AUTOINDEX", metric_type="L2", - params={"nlist": 1024} ) - client.create_collection( + self.milvus_client.create_collection( collection_name=name, schema=vannaddl_schema, index_params=vannaddl_index_params, consistency_level="Strong" ) - else: - pass - def create_doc_collection(self, name: str, **kwargs): - has = client.has_collection(collection_name=name) - # print(has) - if not has: + def _create_doc_collection(self, name: str): + if not self.milvus_client.has_collection(collection_name=name): vannadoc_schema = MilvusClient.create_schema( - auto_id=True, + auto_id=False, enable_dynamic_field=False, ) - vannadoc_schema.add_field(field_name="id", datatype=DataType.VARCHAR, max_length=65535) + vannadoc_schema.add_field(field_name="id", datatype=DataType.VARCHAR, max_length=65535, is_primary=True) vannadoc_schema.add_field(field_name="doc", datatype=DataType.VARCHAR, max_length=65535) - vannadoc_schema.add_field(field_name="my_id", datatype=DataType.INT64, is_primary=True) - vannadoc_schema.add_field(field_name="my_vector", datatype=DataType.FLOAT_VECTOR, dim=1024) + vannadoc_schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self._embedding_dim) - vannadoc_index_params = client.prepare_index_params() + vannadoc_index_params = self.milvus_client.prepare_index_params() vannadoc_index_params.add_index( - field_name="my_vector", - index_name="my_vector", - index_type="IVF_FLAT", + field_name="vector", + index_name="vector", + index_type="AUTOINDEX", metric_type="L2", - params={"nlist": 1024} ) - client.create_collection( + self.milvus_client.create_collection( collection_name=name, schema=vannadoc_schema, index_params=vannadoc_index_params, consistency_level="Strong" ) - else: - pass def add_question_sql(self, question: str, sql: str, **kwargs) -> str: if len(question) == 0 or len(sql) == 0: raise Exception("pair of question and sql can not be null") - self.create_sql_collection("vannasql") - question_sql_json = json.dumps( - { - "question": question, - "sql": sql, - }, - ensure_ascii=False, - ) - # a = random.randint(1,2**63-1) _id = str(uuid.uuid4()) + "-sql" - embeddings = self.xinference_embeddings.embed_query(question) - res = client.insert( + embedding = self.embedding_function.encode_documents([question])[0] + self.milvus_client.insert( collection_name="vannasql", data={ - 'id': _id, - 'text': question, - 'sql': sql, - # 'my_id':a, - 'my_vector': embeddings + "id": _id, + "text": question, + "sql": sql, + "vector": embedding } ) return _id @@ -155,17 +151,14 @@ def add_question_sql(self, question: str, sql: str, **kwargs) -> str: def add_ddl(self, ddl: str, **kwargs) -> str: if len(ddl) == 0: raise Exception("ddl can not be null") - self.create_ddl_collection("vannaddl") - # b = random.randint(1,2**63-1) _id = str(uuid.uuid4()) + "-ddl" - embeddings = self.xinference_embeddings.embed_query(ddl) - res = client.insert( + embedding = self.embedding_function.encode_documents([ddl])[0] + self.milvus_client.insert( collection_name="vannaddl", data={ - 'id': _id, - 'ddl': ddl, - # 'my_id':b, - 'my_vector': embeddings + "id": _id, + "ddl": ddl, + "vector": embedding } ) return _id @@ -173,68 +166,62 @@ def add_ddl(self, ddl: str, **kwargs) -> str: def add_documentation(self, documentation: str, **kwargs) -> str: if len(documentation) == 0: raise Exception("documentation can not be null") - self.create_doc_collection("vannadoc") - # c = random.randint(1,2**63-1) _id = str(uuid.uuid4()) + "-doc" - embeddings = self.xinference_embeddings.embed_query(documentation) - res = client.insert( + embedding = self.embedding_function.encode_documents([documentation])[0] + self.milvus_client.insert( collection_name="vannadoc", data={ - 'id': _id, - 'doc': documentation, - # 'my_id':c, - 'my_vector': embeddings + "id": _id, + "doc": documentation, + "vector": embedding } ) return _id def get_training_data(self, **kwargs) -> pd.DataFrame: - sql_data = client.query( + sql_data = self.milvus_client.query( collection_name="vannasql", - filter="my_id > 0", output_fields=["*"], + limit=MAX_LIMIT_SIZE, ) df = pd.DataFrame() - if sql_data is not None: - df_sql = pd.DataFrame( - { - "id": [doc["id"] for doc in sql_data], - "question": [doc["text"] for doc in sql_data], - "content": [doc["sql"] for doc in sql_data], - } - ) + df_sql = pd.DataFrame( + { + "id": [doc["id"] for doc in sql_data], + "question": [doc["text"] for doc in sql_data], + "content": [doc["sql"] for doc in sql_data], + } + ) df = pd.concat([df, df_sql]) - ddl_data = client.query( + ddl_data = self.milvus_client.query( collection_name="vannaddl", - filter="my_id > 0", output_fields=["*"], + limit=MAX_LIMIT_SIZE, ) - if ddl_data is not None: - df_ddl = pd.DataFrame( - { - "id": [doc["id"] for doc in ddl_data], - "question": [None for doc in ddl_data], - "content": [doc["ddl"] for doc in ddl_data], - } - ) + df_ddl = pd.DataFrame( + { + "id": [doc["id"] for doc in ddl_data], + "question": [None for doc in ddl_data], + "content": [doc["ddl"] for doc in ddl_data], + } + ) df = pd.concat([df, df_ddl]) - doc_data = client.query( + doc_data = self.milvus_client.query( collection_name="vannadoc", - filter="my_id > 0", output_fields=["*"], + limit=MAX_LIMIT_SIZE, ) - if doc_data is not None: - df_doc = pd.DataFrame( - { - "id": [doc["id"] for doc in doc_data], - "question": [None for doc in doc_data], - "content": [doc["doc"] for doc in doc_data], - } - ) + df_doc = pd.DataFrame( + { + "id": [doc["id"] for doc in doc_data], + "question": [None for doc in doc_data], + "content": [doc["doc"] for doc in doc_data], + } + ) df = pd.concat([df, df_doc]) return df @@ -243,13 +230,12 @@ def get_similar_question_sql(self, question: str, **kwargs) -> list: "metric_type": "L2", "params": {"nprobe": 128}, } - embeddings = [self.xinference_embeddings.embed_query(question)] - res = client.search( + embeddings = self.embedding_function.encode_queries([question]) + res = self.milvus_client.search( collection_name="vannasql", - anns_field="my_vector", + anns_field="vector", data=embeddings, - limit=10, - filter='', + limit=self.n_results, output_fields=["text", "sql"], search_params=search_params ) @@ -260,23 +246,20 @@ def get_similar_question_sql(self, question: str, **kwargs) -> list: dict = {} dict["question"] = doc["entity"]["text"] dict["sql"] = doc["entity"]["sql"] - #print(dict) list_sql.append(dict) return list_sql - def get_related_ddl(self, question: str, **kwargs) -> list: search_params = { "metric_type": "L2", "params": {"nprobe": 128}, } - embeddings = [self.xinference_embeddings.embed_query(question)] - res = client.search( + embeddings = self.embedding_function.encode_queries([question]) + res = self.milvus_client.search( collection_name="vannaddl", - anns_field="my_vector", + anns_field="vector", data=embeddings, - limit=1, - filter='', + limit=self.n_results, output_fields=["ddl"], search_params=search_params ) @@ -287,19 +270,17 @@ def get_related_ddl(self, question: str, **kwargs) -> list: list_ddl.append(doc["entity"]["ddl"]) return list_ddl - def get_related_documentation(self, question: str, **kwargs) -> list: search_params = { "metric_type": "L2", "params": {"nprobe": 128}, } - embeddings = [self.xinference_embeddings.embed_query(question)] - res = client.search( + embeddings = self.embedding_function.encode_queries([question]) + res = self.milvus_client.search( collection_name="vannadoc", - anns_field="my_vector", + anns_field="vector", data=embeddings, - limit=1, - filter='', + limit=self.n_results, output_fields=["doc"], search_params=search_params ) @@ -310,17 +291,15 @@ def get_related_documentation(self, question: str, **kwargs) -> list: list_doc.append(doc["entity"]["doc"]) return list_doc - def remove_training_data(self, id: str, **kwargs) -> bool: - pass - # if id.endswith("-sql"): - # self.mq.index("vanna-sql").delete_documents(ids=[id]) - # return True - # elif id.endswith("-ddl"): - # self.mq.index("vanna-ddl").delete_documents(ids=[id]) - # return True - # elif id.endswith("-doc"): - # self.mq.index("vanna-doc").delete_documents(ids=[id]) - # return True - # else: - # return False + if id.endswith("-sql"): + self.milvus_client.delete(collection_name="vannasql", ids=[id]) + return True + elif id.endswith("-ddl"): + self.milvus_client.delete(collection_name="vannaddl", ids=[id]) + return True + elif id.endswith("-doc"): + self.milvus_client.delete(collection_name="vannadoc", ids=[id]) + return True + else: + return False diff --git a/tests/test_imports.py b/tests/test_imports.py index 92db890d..15931e77 100644 --- a/tests/test_imports.py +++ b/tests/test_imports.py @@ -7,6 +7,7 @@ def test_regular_imports(): from vanna.hf.hf import Hf from vanna.local import LocalContext_OpenAI from vanna.marqo.marqo import Marqo_VectorStore + from vanna.milvus.milvus_vector import Milvus_VectorStore from vanna.mistral.mistral import Mistral from vanna.ollama.ollama import Ollama from vanna.openai.openai_chat import OpenAI_Chat @@ -24,6 +25,7 @@ def test_shortcut_imports(): from vanna.chromadb import ChromaDB_VectorStore from vanna.hf import Hf from vanna.marqo import Marqo_VectorStore + from vanna.milvus import Milvus_VectorStore from vanna.mistral import Mistral from vanna.ollama import Ollama from vanna.openai import OpenAI_Chat, OpenAI_Embeddings diff --git a/tests/test_vanna.py b/tests/test_vanna.py index 378f712b..67e34d7b 100644 --- a/tests/test_vanna.py +++ b/tests/test_vanna.py @@ -111,6 +111,34 @@ def test_vn_chroma(): df = vn_chroma.run_sql(sql) assert len(df) == 7 + +from vanna.milvus import Milvus_VectorStore + + +class VannaMilvus(Milvus_VectorStore, OpenAI_Chat): + def __init__(self, config=None): + Milvus_VectorStore.__init__(self, config=config) + OpenAI_Chat.__init__(self, config=config) + +vn_milvus = VannaMilvus(config={'api_key': OPENAI_API_KEY, 'model': 'gpt-3.5-turbo'}) +vn_milvus.connect_to_sqlite('https://vanna.ai/Chinook.sqlite') + +def test_vn_milvus(): + existing_training_data = vn_milvus.get_training_data() + if len(existing_training_data) > 0: + for _, training_data in existing_training_data.iterrows(): + vn_milvus.remove_training_data(training_data['id']) + + df_ddl = vn_milvus.run_sql("SELECT type, sql FROM sqlite_master WHERE sql is not null") + + for ddl in df_ddl['sql'].to_list(): + vn_milvus.train(ddl=ddl) + + sql = vn_milvus.generate_sql("What are the top 7 customers by sales?") + df = vn_milvus.run_sql(sql) + assert len(df) == 7 + + class VannaNumResults(ChromaDB_VectorStore, OpenAI_Chat): def __init__(self, config=None): ChromaDB_VectorStore.__init__(self, config=config)