Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Configurable collection names in Qdrant_VectorStore #407

Merged
merged 1 commit into from
May 6, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 71 additions & 57 deletions src/vanna/qdrant/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,48 +7,51 @@
from ..base import VannaBase
from ..utils import deterministic_uuid

DOCUMENTATION_COLLECTION_NAME = "documentation"
DDL_COLLECTION_NAME = "ddl"
SQL_COLLECTION_NAME = "sql"
SCROLL_SIZE = 1000

ID_SUFFIXES = {
DDL_COLLECTION_NAME: "ddl",
DOCUMENTATION_COLLECTION_NAME: "doc",
SQL_COLLECTION_NAME: "sql",
}


class Qdrant_VectorStore(VannaBase):
"""Vectorstore implementation using Qdrant - https://qdrant.tech/"""
"""
Vectorstore implementation using Qdrant - https://qdrant.tech/

Args:
- config (dict, optional): Dictionary of `Qdrant_VectorStore config` options. Defaults to `{}`.
- client: A `qdrant_client.QdrantClient` instance. Overrides other config options.
- location: If `":memory:"` - use in-memory Qdrant instance. If `str` - use it as a `url` parameter.
- url: Either host or str of "Optional[scheme], host, Optional[port], Optional[prefix]". Eg. `"http://localhost:6333"`.
- prefer_grpc: If `true` - use gPRC interface whenever possible in custom methods.
- https: If `true` - use HTTPS(SSL) protocol. Default: `None`
- api_key: API key for authentication in Qdrant Cloud. Default: `None`
- timeout: Timeout for REST and gRPC API requests. Defaults to 5 seconds for REST and unlimited for gRPC.
- path: Persistence path for QdrantLocal. Default: `None`.
- prefix: Prefix to the REST URL paths. Example: `service/v1` will result in `http://localhost:6333/service/v1/{qdrant-endpoint}`.
- n_results: Number of results to return from similarity search. Defaults to 10.
- fastembed_model: [Model](https://qdrant.github.io/fastembed/examples/Supported_Models/#supported-text-embedding-models) to use for `fastembed.TextEmbedding`.
Defaults to `"BAAI/bge-small-en-v1.5"`.
- collection_params: Additional parameters to pass to `qdrant_client.QdrantClient#create_collection()` method.
- distance_metric: Distance metric to use when creating collections. Defaults to `qdrant_client.models.Distance.COSINE`.
- documentation_collection_name: Name of the collection to store documentation. Defaults to `"documentation"`.
- ddl_collection_name: Name of the collection to store DDL. Defaults to `"ddl"`.
- sql_collection_name: Name of the collection to store SQL. Defaults to `"sql"`.

Raises:
TypeError: If config["client"] is not a `qdrant_client.QdrantClient` instance
"""

documentation_collection_name = "documentation"
ddl_collection_name = "ddl"
sql_collection_name = "sql"

id_suffixes = {
ddl_collection_name: "ddl",
documentation_collection_name: "doc",
sql_collection_name: "sql",
}

def __init__(
self,
config={},
):
"""
Vectorstore implementation using Qdrant - https://qdrant.tech/

Args:
- config (dict, optional): Dictionary of `Qdrant_VectorStore config` options. Defaults to `{}`.
- client: A `qdrant_client.QdrantClient` instance. Overrides other config options.
- location: If `":memory:"` - use in-memory Qdrant instance. If `str` - use it as a `url` parameter.
- url: Either host or str of "Optional[scheme], host, Optional[port], Optional[prefix]". Eg. `"http://localhost:6333"`.
- prefer_grpc: If `true` - use gPRC interface whenever possible in custom methods.
- https: If `true` - use HTTPS(SSL) protocol. Default: `None`
- api_key: API key for authentication in Qdrant Cloud. Default: `None`
- timeout: Timeout for REST and gRPC API requests. Defaults to 5 seconds for REST and unlimited for gRPC.
- path: Persistence path for QdrantLocal. Default: `None`.
- prefix: Prefix to the REST URL paths. Example: `service/v1` will result in `http://localhost:6333/service/v1/{qdrant-endpoint}`.
- n_results: Number of results to return from similarity search. Defaults to 10.
- fastembed_model: [Model](https://qdrant.github.io/fastembed/examples/Supported_Models/#supported-text-embedding-models) to use for `fastembed.TextEmbedding`.
Defaults to `"BAAI/bge-small-en-v1.5"`.
- collection_params: Additional parameters to pass to `qdrant_client.QdrantClient#create_collection()` method.
- distance_metric: Distance metric to use when creating collections. Defaults to `qdrant_client.models.Distance.COSINE`.

Raises:
TypeError: If config["client"] is not a `qdrant_client.QdrantClient` instance
"""
VannaBase.__init__(self, config=config)
client = config.get("client")

Expand All @@ -75,6 +78,15 @@ def __init__(
self.fastembed_model = config.get("fastembed_model", "BAAI/bge-small-en-v1.5")
self.collection_params = config.get("collection_params", {})
self.distance_metric = config.get("distance_metric", models.Distance.COSINE)
self.documentation_collection_name = config.get(
"documentation_collection_name", self.documentation_collection_name
)
self.ddl_collection_name = config.get(
"ddl_collection_name", self.ddl_collection_name
)
self.sql_collection_name = config.get(
"sql_collection_name", self.sql_collection_name
)

self._setup_collections()

Expand All @@ -83,7 +95,7 @@ def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
id = deterministic_uuid(question_answer)

self._client.upsert(
SQL_COLLECTION_NAME,
self.sql_collection_name,
points=[
models.PointStruct(
id=id,
Expand All @@ -96,12 +108,12 @@ def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
],
)

return self._format_point_id(id, SQL_COLLECTION_NAME)
return self._format_point_id(id, self.sql_collection_name)

def add_ddl(self, ddl: str, **kwargs) -> str:
id = deterministic_uuid(ddl)
self._client.upsert(
DDL_COLLECTION_NAME,
self.ddl_collection_name,
points=[
models.PointStruct(
id=id,
Expand All @@ -112,13 +124,13 @@ def add_ddl(self, ddl: str, **kwargs) -> str:
)
],
)
return self._format_point_id(id, DDL_COLLECTION_NAME)
return self._format_point_id(id, self.ddl_collection_name)

def add_documentation(self, documentation: str, **kwargs) -> str:
id = deterministic_uuid(documentation)

self._client.upsert(
DOCUMENTATION_COLLECTION_NAME,
self.documentation_collection_name,
points=[
models.PointStruct(
id=id,
Expand All @@ -130,16 +142,17 @@ def add_documentation(self, documentation: str, **kwargs) -> str:
],
)

return self._format_point_id(id, DOCUMENTATION_COLLECTION_NAME)
return self._format_point_id(id, self.documentation_collection_name)

def get_training_data(self, **kwargs) -> pd.DataFrame:
df = pd.DataFrame()

if sql_data := self._get_all_points(SQL_COLLECTION_NAME):
if sql_data := self._get_all_points(self.sql_collection_name):
question_list = [data.payload["question"] for data in sql_data]
sql_list = [data.payload["sql"] for data in sql_data]
id_list = [
self._format_point_id(data.id, SQL_COLLECTION_NAME) for data in sql_data
self._format_point_id(data.id, self.sql_collection_name)
for data in sql_data
]

df_sql = pd.DataFrame(
Expand All @@ -154,10 +167,11 @@ def get_training_data(self, **kwargs) -> pd.DataFrame:

df = pd.concat([df, df_sql])

if ddl_data := self._get_all_points(DDL_COLLECTION_NAME):
if ddl_data := self._get_all_points(self.ddl_collection_name):
ddl_list = [data.payload["ddl"] for data in ddl_data]
id_list = [
self._format_point_id(data.id, DDL_COLLECTION_NAME) for data in ddl_data
self._format_point_id(data.id, self.ddl_collection_name)
for data in ddl_data
]

df_ddl = pd.DataFrame(
Expand All @@ -172,10 +186,10 @@ def get_training_data(self, **kwargs) -> pd.DataFrame:

df = pd.concat([df, df_ddl])

if doc_data := self._get_all_points(DOCUMENTATION_COLLECTION_NAME):
if doc_data := self._get_all_points(self.documentation_collection_name):
document_list = [data.payload["documentation"] for data in doc_data]
id_list = [
self._format_point_id(data.id, DOCUMENTATION_COLLECTION_NAME)
self._format_point_id(data.id, self.documentation_collection_name)
for data in doc_data
]

Expand Down Expand Up @@ -210,7 +224,7 @@ def remove_collection(self, collection_name: str) -> bool:
Returns:
bool: True if collection is deleted, False otherwise
"""
if collection_name in ID_SUFFIXES.keys():
if collection_name in self.id_suffixes.keys():
self._client.delete_collection(collection_name)
self._setup_collections()
return True
Expand All @@ -223,7 +237,7 @@ def embeddings_dimension(self):

def get_similar_question_sql(self, question: str, **kwargs) -> list:
results = self._client.search(
SQL_COLLECTION_NAME,
self.sql_collection_name,
query_vector=self.generate_embedding(question),
limit=self.n_results,
with_payload=True,
Expand All @@ -233,7 +247,7 @@ def get_similar_question_sql(self, question: str, **kwargs) -> list:

def get_related_ddl(self, question: str, **kwargs) -> list:
results = self._client.search(
DDL_COLLECTION_NAME,
self.ddl_collection_name,
query_vector=self.generate_embedding(question),
limit=self.n_results,
with_payload=True,
Expand All @@ -243,7 +257,7 @@ def get_related_ddl(self, question: str, **kwargs) -> list:

def get_related_documentation(self, question: str, **kwargs) -> list:
results = self._client.search(
DOCUMENTATION_COLLECTION_NAME,
self.documentation_collection_name,
query_vector=self.generate_embedding(question),
limit=self.n_results,
with_payload=True,
Expand Down Expand Up @@ -282,28 +296,28 @@ def _get_all_points(self, collection_name: str):
return results

def _setup_collections(self):
if not self._client.collection_exists(SQL_COLLECTION_NAME):
if not self._client.collection_exists(self.sql_collection_name):
self._client.create_collection(
collection_name=SQL_COLLECTION_NAME,
collection_name=self.sql_collection_name,
vectors_config=models.VectorParams(
size=self.embeddings_dimension,
distance=self.distance_metric,
),
**self.collection_params,
)

if not self._client.collection_exists(DDL_COLLECTION_NAME):
if not self._client.collection_exists(self.ddl_collection_name):
self._client.create_collection(
collection_name=DDL_COLLECTION_NAME,
collection_name=self.ddl_collection_name,
vectors_config=models.VectorParams(
size=self.embeddings_dimension,
distance=self.distance_metric,
),
**self.collection_params,
)
if not self._client.collection_exists(DOCUMENTATION_COLLECTION_NAME):
if not self._client.collection_exists(self.documentation_collection_name):
self._client.create_collection(
collection_name=DOCUMENTATION_COLLECTION_NAME,
collection_name=self.documentation_collection_name,
vectors_config=models.VectorParams(
size=self.embeddings_dimension,
distance=self.distance_metric,
Expand All @@ -312,11 +326,11 @@ def _setup_collections(self):
)

def _format_point_id(self, id: str, collection_name: str) -> str:
return "{0}-{1}".format(id, ID_SUFFIXES[collection_name])
return "{0}-{1}".format(id, self.id_suffixes[collection_name])

def _parse_point_id(self, id: str) -> Tuple[str, str]:
id, suffix = id.rsplit("-", 1)
for collection_name, suffix in ID_SUFFIXES.items():
for collection_name, suffix in self.id_suffixes.items():
if type == suffix:
return id, collection_name
raise ValueError(f"Invalid id {id}")