Skip to content

Commit

Permalink
fix(stores): Add tests for known results and triangle inequality
Browse files Browse the repository at this point in the history
This adds some more tests to check the cosine similarity function has
some expected mathematical properties.
  • Loading branch information
richiejp committed Jan 22, 2025
1 parent 6913b29 commit cfbb56d
Showing 1 changed file with 132 additions and 11 deletions.
143 changes: 132 additions & 11 deletions tests/integration/stores_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"embed"
"math"
"math/rand"
"os"
"path/filepath"

Expand All @@ -22,6 +23,19 @@ import (
//go:embed backend-assets/*
var backendAssets embed.FS

// normalize rescales every vector in vecs, in place, to unit Euclidean (L2)
// length.
//
// The squared components are accumulated in float64 so that the sum neither
// loses precision nor overflows the way a float32 product can. Vectors whose
// norm is zero (empty or all-zero) are left untouched, because dividing by a
// zero norm would fill them with NaN.
func normalize(vecs [][]float32) {
	for _, vec := range vecs {
		sumSquares := float64(0)
		for _, x := range vec {
			sumSquares += float64(x) * float64(x)
		}

		// Nothing sensible to scale to; keep the vector as it is.
		if sumSquares == 0 {
			continue
		}

		norm := float32(math.Sqrt(sumSquares))
		// vec shares its backing array with vecs[i], so this updates the caller's data.
		for j := range vec {
			vec[j] /= norm
		}
	}
}

var _ = Describe("Integration tests for the stores backend(s) and internal APIs", Label("stores"), func() {
Context("Embedded Store get,set and delete", func() {
var sl *model.ModelLoader
Expand Down Expand Up @@ -192,17 +206,8 @@ var _ = Describe("Integration tests for the stores backend(s) and internal APIs"
// set 3 vectors that are at varying angles to {0.5, 0.5, 0.5}
keys := [][]float32{{0.1, 0.3, 0.5}, {0.5, 0.5, 0.5}, {0.6, 0.6, -0.6}, {0.7, -0.7, -0.7}}
vals := [][]byte{[]byte("test0"), []byte("test1"), []byte("test2"), []byte("test3")}
// normalize the keys
for i, k := range keys {
norm := float64(0)
for _, x := range k {
norm += float64(x * x)
}
norm = math.Sqrt(norm)
for j, x := range k {
keys[i][j] = x / float32(norm)
}
}

normalize(keys)

err := store.SetCols(context.Background(), sc, keys, vals)
Expect(err).ToNot(HaveOccurred())
Expand All @@ -225,5 +230,121 @@ var _ = Describe("Integration tests for the stores backend(s) and internal APIs"
Expect(ks[1]).To(Equal(keys[1]))
Expect(vals[1]).To(Equal(vals[1]))
})

// Stores the three coordinate axes plus the negated x axis, then queries with
// the x axis. The expected similarities are 1 for the query itself, 0 for the
// two orthogonal axes and -1 for the opposite vector.
It("It produces the correct cosine similarities for orthogonal and opposite unit vectors", func() {
	keys := [][]float32{{1.0, 0.0, 0.0}, {0.0, 1.0, 0.0}, {0.0, 0.0, 1.0}, {-1.0, 0.0, 0.0}}
	// The last key is the negation of the x axis, so it is labelled "-x".
	vals := [][]byte{[]byte("x"), []byte("y"), []byte("z"), []byte("-x")}

	err := store.SetCols(context.Background(), sc, keys, vals)
	Expect(err).ToNot(HaveOccurred())

	// Ask for all four stored vectors, using the x axis as the query.
	_, _, sims, err := store.Find(context.Background(), sc, keys[0], 4)
	Expect(err).ToNot(HaveOccurred())
	Expect(sims).To(Equal([]float32{1.0, 0.0, 0.0, -1.0}))
})

// Same idea as the unit-vector case, but with keys of differing lengths: the
// query {1, 0, 1} is compared with itself, an orthogonal vector, a vector at an
// obtuse angle and its exact opposite.
It("It produces the correct cosine similarities for orthogonal and opposite vectors", func() {
	ctx := context.Background()

	keys := [][]float32{
		{1.0, 0.0, 1.0},
		{0.0, 2.0, 0.0},
		{0.0, 0.0, -1.0},
		{-1.0, 0.0, -1.0},
	}
	vals := [][]byte{[]byte("x"), []byte("y"), []byte("z"), []byte("-z")}

	setErr := store.SetCols(ctx, sc, keys, vals)
	Expect(setErr).ToNot(HaveOccurred())

	query := keys[0]
	_, _, sims, findErr := store.Find(ctx, sc, query, 4)
	Expect(findErr).ToNot(HaveOccurred())

	// Expected similarity for each returned position, checked to within 0.1.
	wanted := []interface{}{1, 0, -0.7, -1}
	for pos, want := range wanted {
		Expect(sims[pos]).To(BeNumerically("~", want, 0.1))
	}
})

expectTriangleEq := func(keys [][]float32, vals [][]byte) {
sims := map[string]map[string]float32{}

// compare every key vector pair and store the similarities in a lookup table
// that uses the values as keys
for i, k := range keys {
_, valsk, simsk, err := store.Find(context.Background(), sc, k, 9)
Expect(err).ToNot(HaveOccurred())

for j, v := range valsk {
p := string(vals[i])
q := string(v)

if sims[p] == nil {
sims[p] = map[string]float32{}
}

//log.Debug().Strs("vals", []string{p, q}).Float32("similarity", simsk[j]).Send()

sims[p][q] = simsk[j]
}
}

// Check that the triangle inequality holds for every combination of the triplet
// u, v and w
for _, simsu := range sims {
for w, simw := range simsu {
// acos(u,w) <= ...
uws := math.Acos(float64(simw))

// ... acos(u,v) + acos(v,w)
for v, _ := range simsu {
uvws := math.Acos(float64(simsu[v])) + math.Acos(float64(sims[v][w]))

//log.Debug().Str("u", u).Str("v", v).Str("w", w).Send()
//log.Debug().Float32("uw", simw).Float32("uv", simsu[v]).Float32("vw", sims[v][w]).Send()
Expect(uws).To(BeNumerically("<=", uvws))
}
}
}
}

// Runs the triangle inequality check over the six signed coordinate axes plus
// three arbitrary vectors that are scaled to unit length first.
It("It obeys the triangle inequality for normalized values", func() {
	ctx := context.Background()

	keys := [][]float32{
		// positive axes
		{1.0, 0.0, 0.0}, {0.0, 1.0, 0.0}, {0.0, 0.0, 1.0},
		// negative axes
		{-1.0, 0.0, 0.0}, {0.0, -1.0, 0.0}, {0.0, 0.0, -1.0},
		// arbitrary directions, normalized below
		{2.0, 3.0, 4.0}, {9.0, 7.0, 1.0}, {0.0, -1.2, 2.3},
	}

	names := []string{"x", "y", "z", "-x", "-y", "-z", "u", "v", "w"}
	vals := make([][]byte, len(names))
	for i, name := range names {
		vals[i] = []byte(name)
	}

	// Only the last three keys are passed to normalize; the six axis vectors
	// are written out with unit length already.
	normalize(keys[6:])

	setErr := store.SetCols(ctx, sc, keys, vals)
	Expect(setErr).ToNot(HaveOccurred())

	expectTriangleEq(keys, vals)
})

It("It obeys the triangle inequality", func() {
rnd := rand.New(rand.NewSource(151))
keys := make([][]float32, 20)
vals := make([][]byte, 20)

for i := range keys {
k := make([]float32, 768)

for j := range k {
k[j] = rnd.Float32()
}

keys[i] = k
}

c := byte('a')
for i := range vals {
vals[i] = []byte{c}
c += 1
}

err := store.SetCols(context.Background(), sc, keys, vals);
Expect(err).ToNot(HaveOccurred())

expectTriangleEq(keys, vals)
})
})
})

0 comments on commit cfbb56d

Please sign in to comment.