Skip to content

Commit

Permalink
feat: add SPTAG indexer, fix #170
Browse files Browse the repository at this point in the history
Signed-off-by: Han Xiao <[email protected]>
  • Loading branch information
hanxiao committed Apr 2, 2020
1 parent 58c3d70 commit 2cf9a9d
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 9 deletions.
2 changes: 2 additions & 0 deletions extra-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

scipy: numeric, index
flask: http, sse
flask_cors: http, sse
nmslib: index
docker: flow
torch: framework
transformers: nlp, encode
Expand Down
14 changes: 11 additions & 3 deletions jina/executors/indexers/annoy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,14 @@


class AnnoyIndexer(NumpyIndexer):
"""Annoy powered vector indexer
For more information about the Annoy supported parameters, please consult:
- https://github.com/spotify/annoy
.. note::
Annoy package dependency is only required at the query time.
"""

def __init__(self, metric: str = 'euclidean', n_trees: int = 10, *args, **kwargs):
"""
Expand Down Expand Up @@ -36,10 +44,10 @@ def get_query_handler(self):
def query(self, keys: 'np.ndarray', top_k: int, *args, **kwargs) -> Tuple['np.ndarray', 'np.ndarray']:
if keys.dtype != np.float32:
raise ValueError('vectors should be ndarray of float32')
all_ret = []
all_idx = []
all_dist = []
for k in keys:
ret, dist = self.query_handler.get_nns_by_vector(k, top_k, include_distances=True)
all_ret.append(self.int2ext_key[ret])
all_idx.append(self.int2ext_key[ret])
all_dist.append(dist)
return np.array(all_ret), np.array(all_dist)
return np.array(all_idx), np.array(all_dist)
7 changes: 6 additions & 1 deletion jina/executors/indexers/faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@


class FaissIndexer(NumpyIndexer):
"""A Faiss indexers based on :class:`NumpyIndexer`.
"""Faiss powered vector indexer
For more information about the Faiss supported parameters and installation problems, please consult:
- https://github.com/facebookresearch/faiss
.. note::
Faiss package dependency is only required at the query time.
Expand Down Expand Up @@ -35,4 +39,5 @@ def query(self, keys: 'np.ndarray', top_k: int, *args, **kwargs) -> Tuple['np.nd

dist, ids = self.query_handler.search(keys, top_k)

# ids is already a numpy array
return self.int2ext_key[ids], dist
12 changes: 7 additions & 5 deletions jina/executors/indexers/nmslib.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@


class NmslibIndexer(NumpyIndexer):
"""Indexer powered by nmslib
"""nmslib powered vector indexer
For documentation and explaination of each parameter, please refer to
For documentation and explanation of each parameter, please refer to
- https://nmslib.github.io/nmslib/quickstart.html
- https://github.com/nmslib/nmslib/blob/master/manual/methods.md
.. note::
Nmslib package dependency is only required at the query time.
"""

def __init__(self, space: str = 'cosinesimil', method: str = 'hnsw', print_progress: bool = False,
Expand Down Expand Up @@ -48,6 +51,5 @@ def query(self, keys: 'np.ndarray', top_k: int, *args, **kwargs) -> Tuple['np.nd
if keys.dtype != np.float32:
raise ValueError('vectors should be ndarray of float32')
ret = self.query_handler.knnQueryBatch(keys, k=top_k, num_threads=self.num_threads)
idx = np.stack([self.int2ext_key[v[0]] for v in ret])
dist = np.stack([v[1] for v in ret])
return idx, dist
idx, dist = zip(*ret)
return self.int2ext_key[np.array(idx)], np.array(dist)
60 changes: 60 additions & 0 deletions jina/executors/indexers/sptag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from typing import Tuple

import numpy as np

from .numpy import NumpyIndexer


class SptagIndexer(NumpyIndexer):
    """SPTAG powered vector indexer

    For SPTAG installation and python API usage, please consult:

        - https://github.com/microsoft/SPTAG/blob/master/Dockerfile
        - https://github.com/microsoft/SPTAG/blob/master/docs/Tutorial.ipynb
        - https://github.com/microsoft/SPTAG

    .. note::
        sptag package dependency is only required at the query time.
    """

    def __init__(self, dist_calc_method: str = 'L2', method: str = 'BKT',
                 num_threads: int = 1,
                 *args, **kwargs):
        """
        Initialize a SptagIndexer

        :param dist_calc_method: the distance type, currently SPTAG only support Cosine and L2 distances.
        :param method: The index method to use, index Algorithm type (e.g. BKT, KDT), required.
        :param num_threads: The number of threads to use
        :param args: extra positional arguments forwarded to the parent indexer
        :param kwargs: extra keyword arguments forwarded to the parent indexer
        """
        super().__init__(*args, **kwargs)
        self.method = method
        # Kept under the name ``space`` so existing attribute access keeps working.
        self.space = dist_calc_method
        self.num_threads = num_threads

    def get_query_handler(self):
        """Build an SPTAG index from the vectors provided by the parent indexer.

        :return: the built ``SPTAG.AnnIndex``, or ``None`` when the parent handler
            yields no vectors or when the SPTAG build reports failure.
        """
        vecs = super().get_query_handler()
        if vecs is None:
            return None

        # Imported lazily: the SPTAG package is only needed at query time.
        import SPTAG

        _index = SPTAG.AnnIndex(self.method, 'Float', vecs.shape[1])

        # Set the thread number to speed up the build procedure in parallel
        _index.SetBuildParam("NumberOfThreads", str(self.num_threads))
        # The distance metric lives in ``self.space``; ``self.method`` is the
        # index algorithm and has already been given to ``AnnIndex`` above.
        _index.SetBuildParam("DistCalcMethod", self.space)

        if _index.Build(vecs, vecs.shape[0]):
            return _index
        return None

    def query(self, keys: 'np.ndarray', top_k: int, *args, **kwargs) -> Tuple['np.ndarray', 'np.ndarray']:
        """Find the ``top_k`` nearest indexed vectors for the query vectors.

        :param keys: query vectors, must be an ndarray of dtype float32
        :param top_k: number of neighbours to retrieve
        :return: a tuple of (external keys of the hits, their distances)
        :raises ValueError: if ``keys`` is not of dtype float32
        """
        if keys.dtype != np.float32:
            raise ValueError('vectors should be ndarray of float32')

        # NOTE(review): this unpacking assumes ``Search`` returns an iterable of
        # (indices, distances) pairs -- confirm against the SPTAG python API.
        ret = self.query_handler.Search(keys, top_k)
        idx, dist = zip(*ret)
        return self.int2ext_key[np.array(idx)], np.array(dist)

0 comments on commit 2cf9a9d

Please sign in to comment.