#!/usr/bin/env python
# coding: utf-8
# test-hnsw.py (forked from aponom84/navigable-graphs-python)
import argparse
import random

import numpy as np
from tqdm import tqdm
# from tqdm import tqdm_notebook as tqdm  # alternative for Jupyter notebooks

from hnsw import HNSW, l2_distance, heuristic

random.seed(108)
np.random.seed(108)


def brute_force_knn_search(distance_func, k, q, data):
    '''
    Return the list of (idx, dist) pairs for the k closest elements to the query {q} in {data}.
    '''
    return sorted(enumerate(distance_func(q, x) for x in data), key=lambda a: a[1])[:k]
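
# A usage sketch (hypothetical data, assuming l2_distance from hnsw): a call such as
#   brute_force_knn_search(l2_distance, 2, np.zeros(2), [np.ones(2), np.zeros(2), np.full(2, 0.5)])
# is expected to return [(1, 0.0), (2, ...)]: index 1 is the exact match,
# index 2 the next closest, each paired with its distance to the query.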


def calculate_recall(distance_func, kg, test, groundtruth, k, ef, m):
    # Note: m (the number of random entry points) is accepted here but not
    # forwarded to kg.search in this version of the script.
    if groundtruth is None:
        print("Ground truth not found. Calculating ground truth...")
        groundtruth = [
            [idx for idx, dist in brute_force_knn_search(distance_func, k, query, kg.data)]
            for query in tqdm(test)
        ]

    print("Calculating recall...")
    recalls = []
    total_calc = 0
    for query, true_neighbors in tqdm(zip(test, groundtruth), total=len(test)):
        true_neighbors = true_neighbors[:k]  # use only the top-k ground-truth neighbors
        # return_observed=True asks search() to also return the nodes it evaluated
        # (best first), so len(observed) counts distance computations for this query
        observed = [neighbor for neighbor, dist in kg.search(q=query, k=k, ef=ef, return_observed=True)]
        total_calc += len(observed)
        results = observed[:k]
        intersection = len(set(true_neighbors).intersection(set(results)))
        recall = intersection / k
        recalls.append(recall)

    return np.mean(recalls), total_calc / len(test)
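
# Recall@k as computed above is |true top-k ∩ returned top-k| / k, averaged over
# all queries. For example, with k=5, if 4 of the 5 returned ids appear in the
# true top-5 for a query, that query contributes a recall of 4/5 = 0.8.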


def read_fvecs(filename):
    # .fvecs format: each vector is stored as an int32 dimension d
    # followed by d float32 components (TEXMEX/SIFT dataset format).
    with open(filename, 'rb') as f:
        while True:
            vec_size = np.fromfile(f, dtype=np.int32, count=1)
            if vec_size.size == 0:  # end of file
                break
            vec = np.fromfile(f, dtype=np.float32, count=vec_size[0])
            yield vec


def read_ivecs(filename):
    # .ivecs format: same layout as .fvecs, but with int32 components
    # (used here for ground-truth neighbor ids).
    with open(filename, 'rb') as f:
        while True:
            vec_size = np.fromfile(f, dtype=np.int32, count=1)
            if vec_size.size == 0:  # end of file
                break
            vec = np.fromfile(f, dtype=np.int32, count=vec_size[0])
            yield vec


def load_sift_dataset():
    train_file = 'datasets/siftsmall/siftsmall_base.fvecs'
    test_file = 'datasets/siftsmall/siftsmall_query.fvecs'
    groundtruth_file = 'datasets/siftsmall/siftsmall_groundtruth.ivecs'

    train_data = np.array(list(read_fvecs(train_file)))
    test_data = np.array(list(read_fvecs(test_file)))
    groundtruth_data = np.array(list(read_ivecs(groundtruth_file)))

    return train_data, test_data, groundtruth_data
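
# Each row of groundtruth_data lists the ids of the true nearest neighbors for the
# corresponding query (ordered from closest to farthest in the standard SIFT
# distribution); calculate_recall keeps only the first k entries of each row.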


def generate_synthetic_data(dim, n, nq):
    train_data = np.random.random((n, dim)).astype(np.float32)
    test_data = np.random.random((nq, dim)).astype(np.float32)
    return train_data, test_data


def main():
    parser = argparse.ArgumentParser(description='Test the recall of beam search over an HNSW graph.')
    parser.add_argument('--dataset', choices=['synthetic', 'sift'], default='synthetic',
                        help="Choose the dataset to use: 'synthetic' or 'sift'.")
    parser.add_argument('--K', type=int, default=5, help='The size of the neighbourhood (currently unused).')
    parser.add_argument('--M', type=int, default=50, help='Average number of neighbors per node.')
    parser.add_argument('--M0', type=int, default=50,
                        help='Average number of neighbors per node on the bottom layer (layer 0).')
    parser.add_argument('--dim', type=int, default=2, help='Dimensionality of synthetic data (ignored for SIFT).')
    parser.add_argument('--n', type=int, default=200,
                        help='Number of training points for synthetic data (ignored for SIFT).')
    parser.add_argument('--nq', type=int, default=50,
                        help='Number of query points for synthetic data (ignored for SIFT).')
    parser.add_argument('--k', type=int, default=5, help='Number of nearest neighbors to search for in the test stage.')
    parser.add_argument('--ef', type=int, default=10, help='Size of the beam for beam search.')
    parser.add_argument('--m', type=int, default=3, help='Number of random entry points.')
    args = parser.parse_args()

    # Load dataset
    if args.dataset == 'sift':
        print("Loading SIFT dataset...")
        train_data, test_data, groundtruth_data = load_sift_dataset()
    else:
        print(f"Generating synthetic dataset in {args.dim}-dimensional space...")
        train_data, test_data = generate_synthetic_data(args.dim, args.n, args.nq)
        groundtruth_data = None  # computed by brute force inside calculate_recall

    # Build the HNSW index
    hnsw = HNSW(distance_func=l2_distance, m=args.M, m0=args.M0, ef=10, ef_construction=30,
                neighborhood_construction=heuristic)
    for x in tqdm(train_data):
        hnsw.add(x)

    # Calculate recall
    recall, avg_calc = calculate_recall(l2_distance, hnsw, test_data, groundtruth_data,
                                        k=args.k, ef=args.ef, m=args.m)
    print(f"Average recall: {recall}, average distance computations per query: {avg_calc}")


if __name__ == "__main__":
    main()
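
# Example invocations (a sketch; the SIFT case assumes datasets/siftsmall/ has been
# downloaded into the working directory, as hard-coded in load_sift_dataset):
#   python test-hnsw.py --dataset synthetic --dim 4 --n 2000 --nq 100 --k 5 --ef 32
#   python test-hnsw.py --dataset sift --M 16 --M0 32 --k 5 --ef 64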