Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding Filtered Index support to Python bindings #482

Merged
merged 11 commits into from
Nov 7, 2023
Merged
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "diskannpy"
version = "0.6.1"
version = "0.7.0rc1"

description = "DiskANN Python extension module"
readme = "python/README.md"
Expand Down
5 changes: 3 additions & 2 deletions python/include/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ template <typename DT, typename TagT = DynamicIdType, typename LabelT = filterT>
void build_memory_index(diskann::Metric metric, const std::string &vector_bin_path,
const std::string &index_output_path, uint32_t graph_degree, uint32_t complexity,
float alpha, uint32_t num_threads, bool use_pq_build,
size_t num_pq_bytes, bool use_opq, uint32_t filter_complexity,
bool use_tags = false);
size_t num_pq_bytes, bool use_opq, bool use_tags = false,
const std::string& filter_labels_file = "", const std::string& universal_label = "",
uint32_t filter_complexity = 0);

}
4 changes: 4 additions & 0 deletions python/include/static_memory_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ template <typename DT> class StaticMemoryIndex
NeighborsAndDistances<StaticIdType> search(py::array_t<DT, py::array::c_style | py::array::forcecast> &query,
uint64_t knn, uint64_t complexity);

NeighborsAndDistances<StaticIdType> search_with_filter(
py::array_t<DT, py::array::c_style | py::array::forcecast> &query, uint64_t knn, uint64_t complexity,
filterT filter);

NeighborsAndDistances<StaticIdType> batch_search(
py::array_t<DT, py::array::c_style | py::array::forcecast> &queries, uint64_t num_queries, uint64_t knn,
uint64_t complexity, uint32_t num_threads);
Expand Down
53 changes: 46 additions & 7 deletions python/src/_builder.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import json
import os
import shutil
from pathlib import Path
from typing import Optional, Tuple, Union
from typing import Hashable, Optional, Tuple, Union

import numpy as np

Expand Down Expand Up @@ -174,8 +175,10 @@ def build_memory_index(
num_pq_bytes: int = defaults.NUM_PQ_BYTES,
use_opq: bool = defaults.USE_OPQ,
vector_dtype: Optional[VectorDType] = None,
filter_complexity: int = defaults.FILTER_COMPLEXITY,
tags: Union[str, VectorIdentifierBatch] = "",
filter_labels: Optional[list[list[str]]] = None,
universal_label: str = "",
filter_complexity: int = defaults.FILTER_COMPLEXITY,
index_prefix: str = "ann",
) -> None:
"""
Expand Down Expand Up @@ -223,10 +226,20 @@ def build_memory_index(
Default is `0`.
- **use_opq**: Use optimized product quantization during build.
- **vector_dtype**: Required if the provided `data` is of type `str`, else we use the `data.dtype` if np array.
- **filter_complexity**: Complexity to use when using filters. Default is 0.
- **tags**: A `str` representing a path to a pre-built tags file on disk, or a `numpy.ndarray` of uint32 ids
corresponding to the ordinal position of the vectors provided to build the index. Defaults to "". **This value
must be provided if you want to build a memory index intended for use with `diskannpy.DynamicMemoryIndex`**.
- **tags**: Tags can be defined either as a path on disk to an existing .tags file, or provided as a np.array of
the same length as the number of vectors. Tags are used to identify vectors in the index via your *own*
numbering conventions, and is absolutely required for loading DynamicMemoryIndex indices `from_file`.
- **filter_labels**: An optional, but exhaustive list of categories for each vector. This is used to filter
search results by category. If provided, this must be a list of lists, where each inner list is a list of
categories for the corresponding vector. For example, if you have 3 vectors, and the first vector belongs to
categories "a" and "b", the second vector belongs to category "b", and the third vector belongs to no categories,
you would provide `filter_labels=[["a", "b"], ["b"], []]`. If you do not want to provide categories for a
particular vector, you can provide an empty list. If you do not want to provide categories for any vectors,
you can provide `None` for this parameter (which is the default)
- **universal_label**: An optional label that indicates that this vector should be included in *every* search
in which it also meets the knn search criteria.
- **filter_complexity**: Complexity to use when using filters. Default is 0. 0 is strictly invalid if you are
using filters.
- **index_prefix**: The prefix of the index files. Defaults to "ann".
"""
_assert(
Expand All @@ -245,6 +258,10 @@ def build_memory_index(
_assert_is_nonnegative_uint32(num_pq_bytes, "num_pq_bytes")
_assert_is_nonnegative_uint32(filter_complexity, "filter_complexity")
_assert(index_prefix != "", "index_prefix cannot be an empty string")
_assert(
filter_labels is None or filter_complexity > 0,
"if filter_labels is provided, filter_complexity must not be 0"
)

index_path = Path(index_directory)
_assert(
Expand All @@ -262,6 +279,11 @@ def build_memory_index(
)

num_points, dimensions = vectors_metadata_from_file(vector_bin_path)
if filter_labels is not None:
_assert(
len(filter_labels) == num_points,
"filter_labels must be the same length as the number of points"
)

if vector_dtype_actual == np.uint8:
_builder = _native_dap.build_memory_uint8_index
Expand All @@ -272,6 +294,21 @@ def build_memory_index(

index_prefix_path = os.path.join(index_directory, index_prefix)

filter_labels_file = ""
if filter_labels is not None:
label_counts = {}
filter_labels_file = f"{index_prefix_path}_pylabels.txt"
with open(filter_labels_file, "w") as labels_file:
for labels in filter_labels:
for label in labels:
label_counts[label] = 1 if label not in label_counts else label_counts[label] + 1
if len(labels) == 0:
print("default", file=labels_file)
else:
print(",".join(labels), file=labels_file)
with open(f"{index_prefix_path}_label_metadata.json", "w") as label_metadata_file:
json.dump(label_counts, label_metadata_file, indent=True)

if isinstance(tags, str) and tags != "":
use_tags = True
shutil.copy(tags, index_prefix_path + ".tags")
Expand Down Expand Up @@ -299,8 +336,10 @@ def build_memory_index(
use_pq_build=use_pq_build,
num_pq_bytes=num_pq_bytes,
use_opq=use_opq,
filter_complexity=filter_complexity,
use_tags=use_tags,
filter_labels_file=filter_labels_file,
universal_label=universal_label,
filter_complexity=filter_complexity,
)

_write_index_metadata(
Expand Down
14 changes: 7 additions & 7 deletions python/src/_builder.pyi
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

from typing import BinaryIO, Optional, overload
from typing import BinaryIO, Hashable, Optional, overload

import numpy as np

Expand Down Expand Up @@ -47,11 +47,11 @@ def build_memory_index(
use_pq_build: bool,
num_pq_bytes: int,
use_opq: bool,
label_file: str,
tags: Union[str, VectorIdentifierBatch],
filter_labels: Optional[list[list[str]]],
universal_label: str,
filter_complexity: int,
tags: Optional[VectorIdentifierBatch],
index_prefix: str,
index_prefix: str
) -> None: ...
@overload
def build_memory_index(
Expand All @@ -66,9 +66,9 @@ def build_memory_index(
num_pq_bytes: int,
use_opq: bool,
vector_dtype: VectorDType,
label_file: str,
tags: Union[str, VectorIdentifierBatch],
filter_labels_file: Optional[list[list[str]]],
universal_label: str,
filter_complexity: int,
tags: Optional[str],
index_prefix: str,
index_prefix: str
) -> None: ...
22 changes: 12 additions & 10 deletions python/src/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ def _ensure_index_metadata(
distance_metric: Optional[DistanceMetric],
max_vectors: int,
dimensions: Optional[int],
warn_size_exceeded: bool = False,
daxpryce marked this conversation as resolved.
Show resolved Hide resolved
) -> Tuple[VectorDType, str, np.uint64, np.uint64]:
possible_metadata = _read_index_metadata(index_path_and_prefix)
if possible_metadata is None:
Expand All @@ -226,16 +227,17 @@ def _ensure_index_metadata(
return vector_dtype, distance_metric, max_vectors, dimensions # type: ignore
else:
vector_dtype, distance_metric, num_vectors, dimensions = possible_metadata
if max_vectors is not None and num_vectors > max_vectors:
warnings.warn(
"The number of vectors in the saved index exceeds the max_vectors parameter. "
"max_vectors is being adjusted to accommodate the dataset, but any insertions will fail."
)
max_vectors = num_vectors
if num_vectors == max_vectors:
warnings.warn(
"The number of vectors in the saved index equals max_vectors parameter. Any insertions will fail."
)
if warn_size_exceeded:
if max_vectors is not None and num_vectors > max_vectors:
warnings.warn(
"The number of vectors in the saved index exceeds the max_vectors parameter. "
"max_vectors is being adjusted to accommodate the dataset, but any insertions will fail."
)
max_vectors = num_vectors
if num_vectors == max_vectors:
warnings.warn(
"The number of vectors in the saved index equals max_vectors parameter. Any insertions will fail."
)
return possible_metadata


Expand Down
2 changes: 1 addition & 1 deletion python/src/_dynamic_memory_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def from_file(
f"The file {tags_file} does not exist in {index_directory}",
)
vector_dtype, dap_metric, num_vectors, dimensions = _ensure_index_metadata(
index_prefix_path, vector_dtype, distance_metric, max_vectors, dimensions
index_prefix_path, vector_dtype, distance_metric, max_vectors, dimensions, warn_size_exceeded=True
)

index = cls(
Expand Down
47 changes: 43 additions & 4 deletions python/src/_static_memory_index.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import json
import os
import warnings
from typing import Optional
from typing import Hashable, Optional

import numpy as np

Expand Down Expand Up @@ -43,6 +44,7 @@ def __init__(
distance_metric: Optional[DistanceMetric] = None,
vector_dtype: Optional[VectorDType] = None,
dimensions: Optional[int] = None,
enable_filters: bool = False
):
"""
### Parameters
Expand Down Expand Up @@ -73,8 +75,22 @@ def __init__(
- **dimensions**: The vector dimensionality of this index. All new vectors inserted must be the same
dimensionality. **This value is only used if a `{index_prefix}_metadata.bin` file does not exist.** If it
does not exist, you are required to provide it.
- **enable_filters**: Indexes built with filters can also be used for filtered search.
"""
index_prefix = _valid_index_prefix(index_directory, index_prefix)
self._labels_map = {}
self._labels_metadata = {}
if enable_filters:
try:
with open(index_prefix + "_labels_map.txt", "r") as labels_map_if:
for line in labels_map_if:
(key, val) = line.split("\t")
self._labels_map[key] = int(val)
with open(f"{index_prefix}_label_metadata.json", "r") as labels_metadata_if:
self._labels_metadata = json.load(labels_metadata_if)
except: # noqa: E722
# exceptions are basically presumed to be either file not found or file not formatted correctly
raise RuntimeException("Filter labels file was unable to be processed.")
vector_dtype, metric, num_points, dims = _ensure_index_metadata(
index_prefix,
vector_dtype,
Expand Down Expand Up @@ -109,7 +125,7 @@ def __init__(
)

def search(
self, query: VectorLike, k_neighbors: int, complexity: int
self, query: VectorLike, k_neighbors: int, complexity: int, filter_label: str = ""
) -> QueryResponse:
"""
Searches the index by a single query vector.
Expand All @@ -121,13 +137,25 @@ def search(
- **complexity**: Size of distance ordered list of candidate neighbors to use while searching. List size
increases accuracy at the cost of latency. Must be at least k_neighbors in size.
"""
if filter_label != "":
if len(self._labels_map) == 0:
raise ValueError(
f"A filter label of {filter_label} was provided, but this class was not initialized with filters "
"enabled, e.g. StaticDiskMemory(..., enable_filters=True)"
)
if filter_label not in self._labels_map:
raise ValueError(
f"A filter label of {filter_label} was provided, but the external(str)->internal(np.uint32) labels map "
f"does not include that label."
)
k_neighbors = min(k_neighbors, self._labels_metadata[filter_label])
_query = _castable_dtype_or_raise(query, expected=self._vector_dtype)
_assert(len(_query.shape) == 1, "query vector must be 1-d")
_assert(
_query.shape[0] == self._dimensions,
f"query vector must have the same dimensionality as the index; index dimensionality: {self._dimensions}, "
f"query dimensionality: {_query.shape[0]}",
)
)
_assert_is_positive_uint32(k_neighbors, "k_neighbors")
_assert_is_nonnegative_uint32(complexity, "complexity")

Expand All @@ -136,9 +164,20 @@ def search(
f"k_neighbors={k_neighbors} asked for, but list_size={complexity} was smaller. Increasing {complexity} to {k_neighbors}"
)
complexity = k_neighbors
neighbors, distances = self._index.search(query=_query, knn=k_neighbors, complexity=complexity)

if filter_label == "":
neighbors, distances = self._index.search(query=_query, knn=k_neighbors, complexity=complexity)
else:
filter = self._labels_map[filter_label]
neighbors, distances = self._index.search_with_filter(
query=query,
knn=k_neighbors,
complexity=complexity,
filter=filter
)
return QueryResponse(identifiers=neighbors, distances=distances)


def batch_search(
self,
queries: VectorLikeBatch,
Expand Down
Loading