Skip to content

Commit

Permalink
Rewrote to QueryWithOptions
Browse files Browse the repository at this point in the history
  • Loading branch information
erikdubbelboer committed May 26, 2024
1 parent 9ca7474 commit 2585ba5
Show file tree
Hide file tree
Showing 3 changed files with 316 additions and 20 deletions.
187 changes: 167 additions & 20 deletions collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"path/filepath"
"slices"
"sort"
"sync"
)

Expand All @@ -27,6 +28,69 @@ type Collection struct {
// versions in [DB.Export] and [DB.Import] as well!
}

// NegativeMode selects how a negative text or embedding influences the
// outcome of a query made with QueryOptions.
type NegativeMode string

// Supported negative modes.
const (
	// NEGATIVE_MODE_SUBTRACT subtracts the negative embedding from the query
	// embedding and normalizes the difference before the search is run.
	// It is the mode used when NegativeMode is left empty.
	NEGATIVE_MODE_SUBTRACT NegativeMode = "subtract"

	// NEGATIVE_MODE_REORDER runs the search with the unmodified query and then
	// re-ranks the results: each result's similarity is lowered by its
	// similarity to the negative embedding multiplied by
	// NegativeReorderStrength. Lower strengths reorder less aggressively.
	NEGATIVE_MODE_REORDER NegativeMode = "reorder"

	// NEGATIVE_MODE_FILTER runs the search with the unmodified query and then
	// drops results by their similarity to the negative embedding. Only
	// results whose similarity is below NegativeFilterThreshold are kept.
	NEGATIVE_MODE_FILTER NegativeMode = "filter"
)

// Fallbacks applied when the corresponding QueryOptions field is zero.
const (
	// DEFAULT_NEGATIVE_REORDER_STRENGTH replaces a zero NegativeReorderStrength.
	DEFAULT_NEGATIVE_REORDER_STRENGTH = 1

	// DEFAULT_NEGATIVE_FILTER_THRESHOLD replaces a zero NegativeFilterThreshold.
	DEFAULT_NEGATIVE_FILTER_THRESHOLD = 0.5
)

// QueryOptions represents the options for a query made with
// Collection.QueryWithOptions. At least one of QueryText and QueryEmbedding
// must be set, otherwise the query returns an error.
type QueryOptions struct {
// The text to search for. It is only embedded (with the collection's
// embedding function) when QueryEmbedding is empty.
QueryText string

// The embedding of the query to search for. It must be created
// with the same embedding model as the document embeddings in the collection.
// The embedding will be normalized if it's not the case yet.
// If both QueryText and QueryEmbedding are set, QueryEmbedding will be used.
QueryEmbedding []float32

// The text to exclude from the results. It is only embedded when
// NegativeEmbedding is empty. Leave both empty to run a plain query.
NegativeText string

// The embedding of the negative text. It must be created
// with the same embedding model as the document embeddings in the collection.
// The embedding will be normalized if it's not the case yet.
// If both NegativeText and NegativeEmbedding are set, NegativeEmbedding will be used.
NegativeEmbedding []float32

// The mode to use for the negative text. An empty value is treated as
// NEGATIVE_MODE_SUBTRACT. The mode has no effect when neither NegativeText
// nor NegativeEmbedding is set.
// NOTE(review): a value other than the three NEGATIVE_MODE_* constants is
// not rejected; the negative is then silently ignored.
NegativeMode NegativeMode

// The strength of the negative reordering. Used when NegativeMode is NEGATIVE_MODE_REORDER.
// A value of 0 is replaced by DEFAULT_NEGATIVE_REORDER_STRENGTH, so a
// strength of exactly zero cannot be requested.
NegativeReorderStrength float32

// The threshold for the negative filter. Used when NegativeMode is NEGATIVE_MODE_FILTER.
// A value of 0 is replaced by DEFAULT_NEGATIVE_FILTER_THRESHOLD, so a
// threshold of exactly zero cannot be requested. Filtering is applied to
// the results of the search, so fewer than NResults documents can be
// returned.
NegativeFilterThreshold float32

// The number of results to return. Passed on unchanged to QueryEmbedding.
NResults int

// Conditional filtering on metadata. Passed on unchanged to QueryEmbedding.
Where map[string]string

// Conditional filtering on documents. Passed on unchanged to QueryEmbedding.
WhereDocument map[string]string
}

// We don't export this yet to keep the API surface to the bare minimum.
// Users create collections via [Client.CreateCollection].
func newCollection(name string, metadata map[string]string, embed EmbeddingFunc, dbDir string, compress bool) (*Collection, error) {
Expand Down Expand Up @@ -336,44 +400,85 @@ func (c *Collection) Query(ctx context.Context, queryText string, nResults int,
return nil, errors.New("queryText is empty")
}

queryVectors, err := c.embed(ctx, queryText)
queryVector, err := c.embed(ctx, queryText)
if err != nil {
return nil, fmt.Errorf("couldn't create embedding of query: %w", err)
}

return c.QueryEmbedding(ctx, queryVectors, nResults, where, whereDocument)
return c.QueryEmbedding(ctx, queryVector, nResults, where, whereDocument)
}

// Performs an exhaustive nearest neighbor search on the collection.
//
// - queryText: The text to search for. Its embedding will be created using the
// collection's embedding function.
// - negativeText: The text to subtract from the query embedding. Its embedding
// will be created using the collection's embedding function.
// - nResults: The number of results to return. Must be > 0.
// - where: Conditional filtering on metadata. Optional.
// - whereDocument: Conditional filtering on documents. Optional.
func (c *Collection) QueryWithNegative(ctx context.Context, queryText string, negativeText string, nResults int, where, whereDocument map[string]string) ([]Result, error) {
if queryText == "" {
return nil, errors.New("queryText is empty")
// - options: The options for the query. See QueryOptions for more information.
func (c *Collection) QueryWithOptions(ctx context.Context, options QueryOptions) ([]Result, error) {
if options.QueryText == "" && len(options.QueryEmbedding) == 0 {
return nil, errors.New("QueryText and QueryEmbedding options are empty")
}

queryVectors, err := c.embed(ctx, queryText)
if err != nil {
return nil, fmt.Errorf("couldn't create embedding of query: %w", err)
var err error
queryVector := options.QueryEmbedding
if len(queryVector) == 0 {
queryVector, err = c.embed(ctx, options.QueryText)
if err != nil {
return nil, fmt.Errorf("couldn't create embedding of query: %w", err)
}
}

negativeMode := options.NegativeMode
if negativeMode == "" {
negativeMode = NEGATIVE_MODE_SUBTRACT
}

if negativeText != "" {
negativeVectors, err := c.embed(ctx, negativeText)
negativeVector := options.NegativeEmbedding
if len(negativeVector) == 0 && options.NegativeText != "" {
negativeVector, err = c.embed(ctx, options.NegativeText)
if err != nil {
return nil, fmt.Errorf("couldn't create embedding of negative: %w", err)
}
}

if len(negativeVector) != 0 {
if !isNormalized(negativeVector) {
negativeVector = normalizeVector(negativeVector)
}

queryVectors = subtractVector(queryVectors, negativeVectors)
queryVectors = normalizeVector(queryVectors)
if negativeMode == NEGATIVE_MODE_SUBTRACT {
queryVector = subtractVector(queryVector, negativeVector)
queryVector = normalizeVector(queryVector)
}
}

return c.QueryEmbedding(ctx, queryVectors, nResults, where, whereDocument)
result, err := c.QueryEmbedding(ctx, queryVector, options.NResults, options.Where, options.WhereDocument)
if err != nil {
return nil, err
}

if len(negativeVector) != 0 {
if negativeMode == NEGATIVE_MODE_REORDER {
negativeReorderStrength := options.NegativeReorderStrength
if negativeReorderStrength == 0 {
negativeReorderStrength = DEFAULT_NEGATIVE_REORDER_STRENGTH
}

result, err = reorderResults(result, negativeVector, negativeReorderStrength)
if err != nil {
return nil, fmt.Errorf("couldn't reorder results: %w", err)
}
} else if negativeMode == NEGATIVE_MODE_FILTER {
negativeFilterThreshold := options.NegativeFilterThreshold
if negativeFilterThreshold == 0 {
negativeFilterThreshold = DEFAULT_NEGATIVE_FILTER_THRESHOLD
}

result, err = filterResults(result, negativeVector, negativeFilterThreshold)
if err != nil {
return nil, fmt.Errorf("couldn't filter results: %w", err)
}
}
}

return result, nil
}

// Performs an exhaustive nearest neighbor search on the collection.
Expand Down Expand Up @@ -465,3 +570,45 @@ func (c *Collection) getDocPath(docID string) string {
}
return docPath
}

// reorderResults re-ranks results by penalising similarity to negativeVector.
//
// For every result the dot product between negativeVector and the result's
// embedding is multiplied by negativeReorderStrength and subtracted from the
// result's Similarity. The slice is then sorted by the adjusted Similarity in
// descending order. The dot product only equals the cosine similarity when
// both vectors are normalized.
//
// The adjustment and the sort happen in place: the returned slice is the
// input slice. An empty input is returned as is. If any dot product fails,
// nil and a wrapped error are returned and the input is left unmodified.
func reorderResults(results []Result, negativeVector []float32, negativeReorderStrength float32) ([]Result, error) {
	if len(results) == 0 {
		return results, nil
	}

	// Compute every penalty before touching results, so that a failing dot
	// product cannot leave the caller's slice half-adjusted.
	penalties := make([]float32, len(results))
	for i := range results {
		sim, err := dotProduct(negativeVector, results[i].Embedding)
		if err != nil {
			return nil, fmt.Errorf("couldn't calculate dot product: %w", err)
		}
		penalties[i] = sim * negativeReorderStrength
	}
	for i, penalty := range penalties {
		results[i].Similarity -= penalty
	}

	// A stable sort keeps the original ranking for results that end up with
	// the same adjusted similarity, which makes the output deterministic.
	sort.SliceStable(results, func(i, j int) bool {
		return results[i].Similarity > results[j].Similarity
	})

	return results, nil
}

// filterResults returns the results whose dot product with negativeVector is
// strictly below negativeFilterThreshold, preserving their order.
//
// An empty input is returned as is. Otherwise a newly allocated slice is
// returned and the input slice is not modified. If a dot product fails, nil
// and a wrapped error are returned.
func filterResults(results []Result, negativeVector []float32, negativeFilterThreshold float32) ([]Result, error) {
	if len(results) == 0 {
		return results, nil
	}

	kept := make([]Result, 0, len(results))
	for idx := range results {
		negSim, err := dotProduct(negativeVector, results[idx].Embedding)
		if err != nil {
			return nil, fmt.Errorf("couldn't calculate dot product: %w", err)
		}

		// Written as a negated "<" rather than ">=" on purpose: a NaN
		// similarity fails every comparison and must be dropped as well.
		if !(negSim < negativeFilterThreshold) {
			continue
		}
		kept = append(kept, results[idx])
	}

	return kept, nil
}
Loading

0 comments on commit 2585ba5

Please sign in to comment.