diff --git a/base/search/hnsw.go b/base/search/hnsw.go
index 597b17a43..94dfc2246 100644
--- a/base/search/hnsw.go
+++ b/base/search/hnsw.go
@@ -48,6 +48,7 @@ type HNSW struct {
 	levelFactor    float32
 	maxConnection  int // maximum number of connections for each element per layer
 	maxConnection0 int
+	ef             int // ef used by searches; values <= 0 fall back to efConstruction
 	efConstruction int
 	numJobs        int
 }
@@ -78,6 +79,14 @@ func SetEFConstruction(efConstruction int) HNSWConfig {
 	}
 }
 
+// SetEF sets the ef value used by Search and MultiSearch in place of efConstruction.
+// Setting it to 0 (or any non-positive value) restores the default of searching with efConstruction.
+func SetEF(ef int) HNSWConfig {
+	return func(h *HNSW) {
+		h.ef = ef
+	}
+}
+
 // NewHNSW builds a vector index based on Hierarchical Navigable Small Worlds.
 func NewHNSW(vectors []Vector, configs ...HNSWConfig) *HNSW {
 	h := &HNSW{
@@ -96,7 +105,7 @@ func NewHNSW(vectors []Vector, configs ...HNSWConfig) *HNSW {
 
 // Search a vector in Hierarchical Navigable Small Worlds.
 func (h *HNSW) Search(q Vector, n int, prune0 bool) (values []int32, scores []float32) {
-	w := h.knnSearch(q, n, mathutil.Max(h.efConstruction, n))
+	w := h.knnSearch(q, n, h.efSearchValue(n))
 	for w.Len() > 0 {
 		value, score := w.Pop()
 		if !prune0 || score < 0 {
@@ -418,7 +427,7 @@ func (h *HNSW) MultiSearch(q Vector, terms []string, n int, prune0 bool) (values
 		scores[term] = make([]float32, 0, n)
 	}
 
-	w := h.efSearch(q, mathutil.Max(h.efConstruction, n))
+	w := h.efSearch(q, h.efSearchValue(n))
 	for w.Len() > 0 {
 		value, score := w.Pop()
 		if !prune0 || score < 0 {
@@ -452,6 +461,14 @@ func (h *HNSW) efSearch(q Vector, ef int) *heap.PriorityQueue {
 	return w
 }
 
+// efSearchValue returns the ef to search with: h.ef if positive, else h.efConstruction, and never less than n.
+func (h *HNSW) efSearchValue(n int) int {
+	if h.ef > 0 {
+		return mathutil.Max(h.ef, n)
+	}
+	return mathutil.Max(h.efConstruction, n)
+}
+
 func EstimateHNSWBuilderComplexity(dataSize, trials int) int {
 	// build index
 	complexity := dataSize * dataSize
diff --git a/base/search/hnsw_test.go b/base/search/hnsw_test.go
index 94a789fe2..8aec93ba0 100644
--- a/base/search/hnsw_test.go
+++ b/base/search/hnsw_test.go
@@ -15,8 +15,9 @@
 package search
 
 import (
-	"github.com/stretchr/testify/assert"
 	"testing"
+
+	"github.com/stretchr/testify/assert"
 )
 
 func TestHNSWConfig(t *testing.T) {
@@ -30,4 +31,7 @@ func TestHNSWConfig(t *testing.T) {
 
 	SetEFConstruction(345)(hnsw)
 	assert.Equal(t, 345, hnsw.efConstruction)
+
+	SetEF(456)(hnsw)
+	assert.Equal(t, 456, hnsw.ef)
 }