From 2cf9a9d355a43dccb8b664cfeef838ae92fd6139 Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Thu, 2 Apr 2020 10:26:59 +0200 Subject: [PATCH] feat: add SPTAG indexer, fix #170 Signed-off-by: Han Xiao --- extra-requirements.txt | 2 ++ jina/executors/indexers/annoy.py | 14 ++++++-- jina/executors/indexers/faiss.py | 7 +++- jina/executors/indexers/nmslib.py | 12 ++++--- jina/executors/indexers/sptag.py | 60 +++++++++++++++++++++++++++++++ 5 files changed, 86 insertions(+), 9 deletions(-) create mode 100644 jina/executors/indexers/sptag.py diff --git a/extra-requirements.txt b/extra-requirements.txt index c5ff972b6a0eb..a6527705d8d99 100644 --- a/extra-requirements.txt +++ b/extra-requirements.txt @@ -17,6 +17,8 @@ scipy: numeric, index flask: http, sse +flask_cors: http, sse +nmslib: index docker: flow torch: framework transformers: nlp, encode diff --git a/jina/executors/indexers/annoy.py b/jina/executors/indexers/annoy.py index 51932b4242379..3df26edd491e2 100644 --- a/jina/executors/indexers/annoy.py +++ b/jina/executors/indexers/annoy.py @@ -6,6 +6,14 @@ class AnnoyIndexer(NumpyIndexer): + """Annoy powered vector indexer + + For more information about the Annoy supported parameters, please consult: + - https://github.com/spotify/annoy + + .. note:: + Annoy package dependency is only required at the query time. 
+ """ def __init__(self, metric: str = 'euclidean', n_trees: int = 10, *args, **kwargs): """ @@ -36,10 +44,10 @@ def get_query_handler(self): def query(self, keys: 'np.ndarray', top_k: int, *args, **kwargs) -> Tuple['np.ndarray', 'np.ndarray']: if keys.dtype != np.float32: raise ValueError('vectors should be ndarray of float32') - all_ret = [] + all_idx = [] all_dist = [] for k in keys: ret, dist = self.query_handler.get_nns_by_vector(k, top_k, include_distances=True) - all_ret.append(self.int2ext_key[ret]) + all_idx.append(self.int2ext_key[ret]) all_dist.append(dist) - return np.array(all_ret), np.array(all_dist) + return np.array(all_idx), np.array(all_dist) diff --git a/jina/executors/indexers/faiss.py b/jina/executors/indexers/faiss.py index e3176a86874e3..5a9ccb0c8608b 100644 --- a/jina/executors/indexers/faiss.py +++ b/jina/executors/indexers/faiss.py @@ -6,7 +6,11 @@ class FaissIndexer(NumpyIndexer): - """A Faiss indexers based on :class:`NumpyIndexer`. + """Faiss powered vector indexer + + For more information about the Faiss supported parameters and installation problems, please consult: + - https://github.com/facebookresearch/faiss/wiki + - https://github.com/facebookresearch/faiss .. note:: Faiss package dependency is only required at the query time. 
@@ -35,4 +39,5 @@ def query(self, keys: 'np.ndarray', top_k: int, *args, **kwargs) -> Tuple['np.nd dist, ids = self.query_handler.search(keys, top_k) + # ids is already a numpy array return self.int2ext_key[ids], dist diff --git a/jina/executors/indexers/nmslib.py b/jina/executors/indexers/nmslib.py index e3ef32218e9bc..962bdeb388644 100644 --- a/jina/executors/indexers/nmslib.py +++ b/jina/executors/indexers/nmslib.py @@ -6,12 +6,15 @@ class NmslibIndexer(NumpyIndexer): - """Indexer powered by nmslib + """nmslib powered vector indexer - For documentation and explaination of each parameter, please refer to + For documentation and explanation of each parameter, please refer to - https://nmslib.github.io/nmslib/quickstart.html - https://github.com/nmslib/nmslib/blob/master/manual/methods.md + + .. note:: + Nmslib package dependency is only required at the query time. """ def __init__(self, space: str = 'cosinesimil', method: str = 'hnsw', print_progress: bool = False, @@ -48,6 +51,5 @@ def query(self, keys: 'np.ndarray', top_k: int, *args, **kwargs) -> Tuple['np.nd if keys.dtype != np.float32: raise ValueError('vectors should be ndarray of float32') ret = self.query_handler.knnQueryBatch(keys, k=top_k, num_threads=self.num_threads) - idx = np.stack([self.int2ext_key[v[0]] for v in ret]) - dist = np.stack([v[1] for v in ret]) - return idx, dist + idx, dist = zip(*ret) + return self.int2ext_key[np.array(idx)], np.array(dist) diff --git a/jina/executors/indexers/sptag.py b/jina/executors/indexers/sptag.py new file mode 100644 index 0000000000000..c8cc84e4eca78 --- /dev/null +++ b/jina/executors/indexers/sptag.py @@ -0,0 +1,60 @@ +from typing import Tuple + +import numpy as np + +from .numpy import NumpyIndexer + + +class SptagIndexer(NumpyIndexer): + """SPTAG powered vector indexer + + For SPTAG installation and python API usage, please consult: + + - https://github.com/microsoft/SPTAG/blob/master/Dockerfile + - 
https://github.com/microsoft/SPTAG/blob/master/docs/Tutorial.ipynb + - https://github.com/microsoft/SPTAG + + .. note:: + sptag package dependency is only required at the query time. + """ + + def __init__(self, dist_calc_method: str = 'L2', method: str = 'BKT', + num_threads: int = 1, + *args, **kwargs): + """ + Initialize an SptagIndexer + + :param dist_calc_method: the distance type, currently SPTAG only supports Cosine and L2 distances. + :param method: The index method to use, index Algorithm type (e.g. BKT, KDT), required. + :param num_threads: The number of threads to use + :param args: positional arguments forwarded to the parent indexer's constructor + :param kwargs: keyword arguments forwarded to the parent indexer's constructor + """ + super().__init__(*args, **kwargs) + self.method = method + self.space = dist_calc_method + self.num_threads = num_threads + + def get_query_handler(self): + vecs = super().get_query_handler() + if vecs is not None: + import SPTAG + + _index = SPTAG.AnnIndex(self.method, 'Float', vecs.shape[1]) + + # Build parameters: thread count (parallel build) and the distance metric (not the index algorithm) + _index.SetBuildParam("NumberOfThreads", str(self.num_threads)) + _index.SetBuildParam("DistCalcMethod", self.space) + + if _index.Build(vecs, vecs.shape[0]): + return _index + else: + return None + + def query(self, keys: 'np.ndarray', top_k: int, *args, **kwargs) -> Tuple['np.ndarray', 'np.ndarray']: + if keys.dtype != np.float32: + raise ValueError('vectors should be ndarray of float32') + + ret = self.query_handler.Search(keys, top_k) + idx, dist = zip(*ret) + return self.int2ext_key[np.array(idx)], np.array(dist)