Skip to content

Commit

Permalink
Remove NEGATIVE_MODE_REORDER, redesign NEGATIVE_MODE_FILTER
Browse files Browse the repository at this point in the history
Removed NEGATIVE_MODE_REORDER as it didn't make much sense.

NEGATIVE_MODE_FILTER now filters at query time instead of filtering the result.
This means that when you want 10 results, you can get 10 results,
instead of those 10 results then being filtered and leaving you with fewer.

Removed default negative mode. You now always have to pick one when
providing negative text or embeddings.
  • Loading branch information
erikdubbelboer committed Jun 23, 2024
1 parent f2bfe40 commit 6237db1
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 137 deletions.
123 changes: 23 additions & 100 deletions collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"fmt"
"path/filepath"
"slices"
"sort"
"sync"
)

Expand All @@ -33,24 +32,15 @@ type Collection struct {
type NegativeMode string

const (
// NEGATIVE_MODE_SUBTRACT subtracts the negative embedding from the query embedding.
// This is the default behavior.
NEGATIVE_MODE_SUBTRACT NegativeMode = "subtract"

// NEGATIVE_MODE_REORDER reorders the results based on the similarity between the
// negative embedding and the document embeddings.
// NegativeReorderStrength controls the strength of the reordering. Lower values
// will reorder the results less aggressively.
NEGATIVE_MODE_REORDER NegativeMode = "reorder"

// NEGATIVE_MODE_FILTER filters out results based on the similarity between the
// negative embedding and the document embeddings.
// NegativeFilterThreshold controls the threshold for filtering. Documents with
// similarity above the threshold will be removed from the results.
NEGATIVE_MODE_FILTER NegativeMode = "filter"

// Default values for negative reordering and filtering.
DEFAULT_NEGATIVE_REORDER_STRENGTH = 1
// NEGATIVE_MODE_SUBTRACT subtracts the negative embedding from the query embedding.
// There is no default mode: a mode must be set explicitly whenever negative text or embeddings are provided.
NEGATIVE_MODE_SUBTRACT NegativeMode = "subtract"

// The default threshold for the negative filter.
DEFAULT_NEGATIVE_FILTER_THRESHOLD = 0.5
Expand Down Expand Up @@ -82,6 +72,9 @@ type QueryOptions struct {
}

type NegativeQueryOptions struct {
// Mode is the mode to use for the negative text.
Mode NegativeMode

// Text is the text to exclude from the results.
Text string

Expand All @@ -91,12 +84,6 @@ type NegativeQueryOptions struct {
// If both Text and Embedding are set, Embedding will be used.
Embedding []float32

// Mode is the mode to use for the negative text.
Mode NegativeMode

// ReorderStrength is the strength of the negative reordering. Used when Mode is NEGATIVE_MODE_REORDER.
ReorderStrength float32

// FilterThreshold is the threshold for the negative filter. Used when Mode is NEGATIVE_MODE_FILTER.
FilterThreshold float32
}
Expand Down Expand Up @@ -435,11 +422,7 @@ func (c *Collection) QueryWithOptions(ctx context.Context, options QueryOptions)
}
}

negativeMode := options.Negative.Mode
if negativeMode == "" {
negativeMode = NEGATIVE_MODE_SUBTRACT
}

negativeFilterThreshold := options.Negative.FilterThreshold
negativeVector := options.Negative.Embedding
if len(negativeVector) == 0 && options.Negative.Text != "" {
negativeVector, err = c.embed(ctx, options.Negative.Text)
Expand All @@ -453,41 +436,23 @@ func (c *Collection) QueryWithOptions(ctx context.Context, options QueryOptions)
negativeVector = normalizeVector(negativeVector)
}

if negativeMode == NEGATIVE_MODE_SUBTRACT {
if options.Negative.Mode == NEGATIVE_MODE_SUBTRACT {
queryVector = subtractVector(queryVector, negativeVector)
queryVector = normalizeVector(queryVector)
} else if options.Negative.Mode == NEGATIVE_MODE_FILTER {
if negativeFilterThreshold == 0 {
negativeFilterThreshold = DEFAULT_NEGATIVE_FILTER_THRESHOLD
}
} else {
return nil, fmt.Errorf("unsupported negative mode: %q", options.Negative.Mode)
}
}

result, err := c.QueryEmbedding(ctx, queryVector, options.NResults, options.Where, options.WhereDocument)
result, err := c.queryEmbedding(ctx, queryVector, negativeVector, negativeFilterThreshold, options.NResults, options.Where, options.WhereDocument)
if err != nil {
return nil, err
}

if len(negativeVector) != 0 {
if negativeMode == NEGATIVE_MODE_REORDER {
negativeReorderStrength := options.Negative.ReorderStrength
if negativeReorderStrength == 0 {
negativeReorderStrength = DEFAULT_NEGATIVE_REORDER_STRENGTH
}

result, err = reorderResults(result, negativeVector, negativeReorderStrength)
if err != nil {
return nil, fmt.Errorf("couldn't reorder results: %w", err)
}
} else if negativeMode == NEGATIVE_MODE_FILTER {
negativeFilterThreshold := options.Negative.FilterThreshold
if negativeFilterThreshold == 0 {
negativeFilterThreshold = DEFAULT_NEGATIVE_FILTER_THRESHOLD
}

result, err = filterResults(result, negativeVector, negativeFilterThreshold)
if err != nil {
return nil, fmt.Errorf("couldn't filter results: %w", err)
}
}
}

return result, nil
}

Expand All @@ -501,6 +466,11 @@ func (c *Collection) QueryWithOptions(ctx context.Context, options QueryOptions)
// - where: Conditional filtering on metadata. Optional.
// - whereDocument: Conditional filtering on documents. Optional.
func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float32, nResults int, where, whereDocument map[string]string) ([]Result, error) {
// Delegate to the shared implementation with no negative embedding (nil) and a
// zero negative filter threshold, so no negative filtering is requested.
return c.queryEmbedding(ctx, queryEmbedding, nil, 0, nResults, where, whereDocument)
}

// queryEmbedding performs an exhaustive nearest neighbor search on the collection.
func (c *Collection) queryEmbedding(ctx context.Context, queryEmbedding, negativeEmbeddings []float32, negativeFilterThreshold float32, nResults int, where, whereDocument map[string]string) ([]Result, error) {
if len(queryEmbedding) == 0 {
return nil, errors.New("queryEmbedding is empty")
}
Expand Down Expand Up @@ -546,18 +516,13 @@ func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float3
}

// For the remaining documents, get the most similar docs.
nMaxDocs, err := getMostSimilarDocs(ctx, queryEmbedding, filteredDocs, resLen)
nMaxDocs, err := getMostSimilarDocs(ctx, queryEmbedding, negativeEmbeddings, negativeFilterThreshold, filteredDocs, resLen)
if err != nil {
return nil, fmt.Errorf("couldn't get most similar docs: %w", err)
}

// As long as we don't filter by threshold, resLen should match len(nMaxDocs).
if resLen != len(nMaxDocs) {
return nil, fmt.Errorf("internal error: expected %d results, got %d", resLen, len(nMaxDocs))
}

res := make([]Result, 0, resLen)
for i := 0; i < resLen; i++ {
res := make([]Result, 0, len(nMaxDocs))
for i := 0; i < len(nMaxDocs); i++ {
res = append(res, Result{
ID: nMaxDocs[i].docID,
Metadata: c.documents[nMaxDocs[i].docID].Metadata,
Expand All @@ -580,45 +545,3 @@ func (c *Collection) getDocPath(docID string) string {
}
return docPath
}

// reorderResults penalizes every result by its dot product with negativeVector
// and then re-sorts the results by the adjusted similarity, highest first.
//
// For each result, the dot product of negativeVector and the result's embedding
// is multiplied by negativeReorderStrength and subtracted from the result's
// Similarity. The adjustment and the sort both happen in place on the slice
// that was passed in, and that same slice is returned.
//
// An empty input is returned unchanged. If a dot product cannot be computed,
// nil and a wrapped error are returned; entries processed before the failure
// have already been adjusted in the caller's slice.
func reorderResults(results []Result, negativeVector []float32, negativeReorderStrength float32) ([]Result, error) {
	if len(results) == 0 {
		return results, nil
	}

	// Lower each result's score by how similar it is to the negative vector.
	for idx := range results {
		entry := &results[idx]
		negSim, err := dotProduct(negativeVector, entry.Embedding)
		if err != nil {
			return nil, fmt.Errorf("couldn't calculate dot product: %w", err)
		}
		entry.Similarity -= negSim * negativeReorderStrength
	}

	// Descending order: an element sorts first when the other one is smaller.
	sort.Slice(results, func(a, b int) bool {
		return results[b].Similarity < results[a].Similarity
	})

	return results, nil
}

// filterResults returns the results whose embedding has a dot product with
// negativeVector that is strictly below negativeFilterThreshold. Results at or
// above the threshold are dropped; the relative order of the kept results is
// preserved.
//
// An empty input is returned unchanged. Otherwise a new slice is returned and
// the input slice is not modified. If a dot product cannot be computed, nil and
// a wrapped error are returned.
func filterResults(results []Result, negativeVector []float32, negativeFilterThreshold float32) ([]Result, error) {
	if len(results) == 0 {
		return results, nil
	}

	// At most every result survives, so reserve that much capacity up front.
	kept := make([]Result, 0, len(results))
	for idx := range results {
		negSim, err := dotProduct(negativeVector, results[idx].Embedding)
		if err != nil {
			return nil, fmt.Errorf("couldn't calculate dot product: %w", err)
		}
		if negSim < negativeFilterThreshold {
			kept = append(kept, results[idx])
		}
	}

	return kept, nil
}
14 changes: 13 additions & 1 deletion query.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ func documentMatchesFilters(document *Document, where, whereDocument map[string]
return true
}

func getMostSimilarDocs(ctx context.Context, queryVectors []float32, docs []*Document, n int) ([]docSim, error) {
func getMostSimilarDocs(ctx context.Context, queryVectors, negativeVector []float32, negativeFilterThreshold float32, docs []*Document, n int) ([]docSim, error) {
nMaxDocs := newMaxDocSims(n)

// Determine concurrency. Use number of docs or CPUs, whichever is smaller.
Expand Down Expand Up @@ -218,6 +218,18 @@ func getMostSimilarDocs(ctx context.Context, queryVectors []float32, docs []*Doc
return
}

if negativeFilterThreshold > 0 {
nsim, err := dotProduct(negativeVector, doc.Embedding)
if err != nil {
setSharedErr(fmt.Errorf("couldn't calculate negative similarity for document '%s': %w", doc.ID, err))
return
}

if nsim > negativeFilterThreshold {
continue
}
}

nMaxDocs.add(docSim{docID: doc.ID, similarity: sim})
}
}(docs[start:end])
Expand Down
37 changes: 1 addition & 36 deletions query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func TestNegative(t *testing.T) {
ctx := context.Background()
db := NewDB()

c, err := db.CreateCollection("knowledge-base", nil, nil)
c, err := db.CreateCollection("test", nil, nil)
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -169,41 +169,6 @@ func TestNegative(t *testing.T) {
}
})

t.Run("NEGATIVE_MODE_REORDER", func(t *testing.T) {
res, err := c.QueryWithOptions(ctx, QueryOptions{
QueryEmbedding: testEmbeddings["search_query: town"],
NResults: c.Count(),
Negative: NegativeQueryOptions{
Embedding: testEmbeddings["search_query: idle"],
Mode: NEGATIVE_MODE_REORDER,
},
})
if err != nil {
panic(err)
}

for _, r := range res {
t.Logf("%s: %v", r.ID, r.Similarity)
}

if len(res) != 3 {
t.Fatalf("expected 3 results, got %d", len(res))
}

// Village Builder Game
if res[0].ID != "1" {
t.Fatalf("expected document with ID 1, got %s", res[0].ID)
}
// Town Craft Idle Game
if res[1].ID != "2" {
t.Fatalf("expected document with ID 2, got %s", res[1].ID)
}
// Some Idle Game
if res[2].ID != "3" {
t.Fatalf("expected document with ID 3, got %s", res[2].ID)
}
})

t.Run("NEGATIVE_MODE_FILTER", func(t *testing.T) {
res, err := c.QueryWithOptions(ctx, QueryOptions{
QueryEmbedding: testEmbeddings["search_query: town"],
Expand Down

0 comments on commit 6237db1

Please sign in to comment.