Skip to content

Commit

Permalink
perf(vector): Improve how vector is passed to hnsw index (#9287)
Browse files Browse the repository at this point in the history
  • Loading branch information
harshil-goel authored Feb 10, 2025
1 parent 15ba2a3 commit ff401c2
Show file tree
Hide file tree
Showing 8 changed files with 79 additions and 61 deletions.
14 changes: 14 additions & 0 deletions posting/list.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ type List struct {
mutationMap *MutableLayer
minTs uint64 // commit timestamp of immutable layer, reject reads before this ts.
maxTs uint64 // max commit timestamp seen for this list.

cache []byte
}

// MutableLayer is the structure that will store mutable layer of the posting list. Every posting list has an immutable
Expand All @@ -94,6 +96,7 @@ type MutableLayer struct {
committedUids map[uint64]*pb.Posting // Stores the uid to posting mapping in committedEntries.
committedUidsTime uint64 // Stores the latest commitTs in the committedEntries.
length int // Stores the length of the posting list until committedEntries.
lastEntry *pb.PostingList // Stores the last entry stored in committedUids

// We also cache some things required for us to update currentEntries faster
currentUids map[uint64]int // Stores the uid to index mapping in the currentEntries posting list
Expand Down Expand Up @@ -131,6 +134,7 @@ func (mm *MutableLayer) clone() *MutableLayer {
deleteAllMarker: mm.deleteAllMarker,
committedUids: mm.committedUids,
length: mm.length,
lastEntry: mm.lastEntry,
committedUidsTime: mm.committedUidsTime,
}
}
Expand Down Expand Up @@ -299,6 +303,9 @@ func (mm *MutableLayer) insertCommittedPostings(pl *pb.PostingList) {
mm.deleteAllMarker = 0
}

if pl.CommitTs > mm.committedUidsTime {
mm.lastEntry = pl
}
mm.committedUidsTime = x.Max(pl.CommitTs, mm.committedUidsTime)
mm.committedEntries[pl.CommitTs] = pl

Expand Down Expand Up @@ -890,6 +897,7 @@ func (l *List) SetTs(readTs uint64) {
}

func (l *List) addMutationInternal(ctx context.Context, txn *Txn, t *pb.DirectedEdge) error {
l.cache = nil
l.AssertLock()

if txn.ShouldAbort() {
Expand Down Expand Up @@ -991,6 +999,9 @@ func (l *List) setMutationAfterCommit(startTs, commitTs uint64, pl *pb.PostingLi
if l.mutationMap.committedUidsTime == math.MaxUint64 {
l.mutationMap.committedUidsTime = 0
}
if pl.CommitTs > l.mutationMap.committedUidsTime {
l.mutationMap.lastEntry = pl
}
l.mutationMap.committedUidsTime = x.Max(l.mutationMap.committedUidsTime, commitTs)

for _, mpost := range pl.Postings {
Expand Down Expand Up @@ -1898,6 +1909,9 @@ func (l *List) findStaticValue(readTs uint64) *pb.PostingList {

// If maxTs < readTs then we need to read maxTs
if l.maxTs <= readTs {
if l.mutationMap.lastEntry != nil {
return l.mutationMap.lastEntry
}
if mutation := l.mutationMap.get(l.maxTs); mutation != nil {
return mutation
}
Expand Down
15 changes: 9 additions & 6 deletions posting/lists.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (
"github.com/dgraph-io/dgo/v240/protos/api"
"github.com/dgraph-io/ristretto/v2/z"
"github.com/hypermodeinc/dgraph/v24/protos/pb"
"github.com/hypermodeinc/dgraph/v24/tok/index"
"github.com/hypermodeinc/dgraph/v24/x"
)

Expand Down Expand Up @@ -87,7 +86,7 @@ func (vc *viLocalCache) Find(prefix []byte, filter func([]byte) bool) (uint64, e
return vc.delegate.Find(prefix, filter)
}

func (vc *viLocalCache) Get(key []byte) (rval index.Value, rerr error) {
func (vc *viLocalCache) Get(key []byte) ([]byte, error) {
pl, err := vc.delegate.Get(key)
if err != nil {
return nil, err
Expand All @@ -97,26 +96,30 @@ func (vc *viLocalCache) Get(key []byte) (rval index.Value, rerr error) {
return vc.GetValueFromPostingList(pl)
}

func (vc *viLocalCache) GetWithLockHeld(key []byte) (rval index.Value, rerr error) {
func (vc *viLocalCache) GetWithLockHeld(key []byte) ([]byte, error) {
pl, err := vc.delegate.Get(key)
if err != nil {
return nil, err
}
return vc.GetValueFromPostingList(pl)
}

func (vc *viLocalCache) GetValueFromPostingList(pl *List) (rval index.Value, rerr error) {
func (vc *viLocalCache) GetValueFromPostingList(pl *List) ([]byte, error) {
if pl.cache != nil {
return pl.cache, nil
}
value := pl.findStaticValue(vc.delegate.startTs)

if value == nil || len(value.Postings) == 0 {
return nil, ErrNoValue
}

if hasDeleteAll(value.Postings[0]) || value.Postings[0].Op == Del {
if value.Postings[0].Op == Del {
return nil, ErrNoValue
}

return value.Postings[0].Value, nil
pl.cache = value.Postings[0].Value
return pl.cache, nil
}

func NewViLocalCache(delegate *LocalCache) *viLocalCache {
Expand Down
16 changes: 9 additions & 7 deletions posting/oracle.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func (vt *viTxn) StartTs() uint64 {
return vt.delegate.StartTs
}

func (vt *viTxn) Get(key []byte) (rval index.Value, rerr error) {
func (vt *viTxn) Get(key []byte) ([]byte, error) {
pl, err := vt.delegate.cache.Get(key)
if err != nil {
return nil, err
Expand All @@ -84,28 +84,30 @@ func (vt *viTxn) Get(key []byte) (rval index.Value, rerr error) {
return vt.GetValueFromPostingList(pl)
}

func (vt *viTxn) GetWithLockHeld(key []byte) (rval index.Value, rerr error) {
func (vt *viTxn) GetWithLockHeld(key []byte) ([]byte, error) {
pl, err := vt.delegate.cache.Get(key)
if err != nil {
return nil, err
}
return vt.GetValueFromPostingList(pl)
}

func (vt *viTxn) GetValueFromPostingList(pl *List) (rval index.Value, rerr error) {
func (vt *viTxn) GetValueFromPostingList(pl *List) ([]byte, error) {
if pl.cache != nil {
return pl.cache, nil
}
value := pl.findStaticValue(vt.delegate.StartTs)

// When the posting is deleted, we find the key in the badger, but no postings in it. This should also
// return ErrKeyNotFound as that is what we except in the later functions.
if value == nil || len(value.Postings) == 0 {
return nil, ErrNoValue
}

if hasDeleteAll(value.Postings[0]) || value.Postings[0].Op == Del {
if value.Postings[0].Op == Del {
return nil, ErrNoValue
}

return value.Postings[0].Value, nil
pl.cache = value.Postings[0].Value
return pl.cache, nil
}

func (vt *viTxn) AddMutation(ctx context.Context, key []byte, t *index.KeyValue) error {
Expand Down
22 changes: 12 additions & 10 deletions tok/hnsw/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ type TxnCache struct {
startTs uint64
}

func (tc *TxnCache) Get(key []byte) (rval index.Value, rerr error) {
func (tc *TxnCache) Get(key []byte) (rval []byte, rerr error) {
return tc.txn.Get(key)
}

Expand Down Expand Up @@ -265,7 +265,7 @@ func (qc *QueryCache) Find(prefix []byte, filter func([]byte) bool) (uint64, err
return qc.cache.Find(prefix, filter)
}

func (qc *QueryCache) Get(key []byte) (rval index.Value, rerr error) {
func (qc *QueryCache) Get(key []byte) (rval []byte, rerr error) {
return qc.cache.Get(key)
}

Expand All @@ -282,7 +282,7 @@ func NewQueryCache(cache index.LocalCache, readTs uint64) *QueryCache {

// getDataFromKeyWithCacheType(keyString, uid, c) looks up data in c
// associated with keyString and uid.
func getDataFromKeyWithCacheType(keyString string, uid uint64, c index.CacheType) (index.Value, error) {
func getDataFromKeyWithCacheType(keyString string, uid uint64, c index.CacheType) ([]byte, error) {
key := DataKey(keyString, uid)
data, err := c.Get(key)
if err != nil {
Expand Down Expand Up @@ -313,7 +313,7 @@ func populateEdgeDataFromKeyWithCacheType(
if data == nil {
return false, nil
}
err = decodeUint64MatrixUnsafe(data.([]byte), edgeData)
err = decodeUint64MatrixUnsafe(data, edgeData)
return true, err
}

Expand Down Expand Up @@ -355,23 +355,25 @@ func getInsertLayer(maxLevels int) int {
return level
}

var emptyVec = []byte{}

// adds the data corresponding to a uid to the given vec variable in the form of []T
// this does not allocate memory for vec, so it must be allocated before calling this function
func (ph *persistentHNSW[T]) getVecFromUid(uid uint64, c index.CacheType, vec *[]T) error {
data, err := getDataFromKeyWithCacheType(ph.pred, uid, c)
if err != nil {
if errors.Is(err, errFetchingPostingList) {
// no vector. Return empty array of floats
index.BytesAsFloatArray([]byte{}, vec, ph.floatBits)
index.BytesAsFloatArray(emptyVec, vec, ph.floatBits)
return fmt.Errorf("%w; %w", errNilVector, err)
}
return err
}
if data != nil {
index.BytesAsFloatArray(data.([]byte), vec, ph.floatBits)
index.BytesAsFloatArray(data, vec, ph.floatBits)
return nil
} else {
index.BytesAsFloatArray([]byte{}, vec, ph.floatBits)
index.BytesAsFloatArray(emptyVec, vec, ph.floatBits)
return errNilVector
}
}
Expand Down Expand Up @@ -406,7 +408,7 @@ func (ph *persistentHNSW[T]) createEntryAndStartNodes(
return create_edges(inUuid)
}

entry := BytesToUint64(data.([]byte)) // convert entry Uuid returned from Get to uint64
entry := BytesToUint64(data) // convert entry Uuid returned from Get to uint64
err := ph.getVecFromUid(entry, c, vec)
if err != nil || len(*vec) == 0 {
// The entry vector has been deleted. We have to create a new entry vector.
Expand Down Expand Up @@ -596,7 +598,7 @@ func (ph *persistentHNSW[T]) addNeighbors(ctx context.Context, tc *TxnCache,
allLayerEdges = allLayerNeighbors
} else {
// all edges of nearest neighbor
err := decodeUint64MatrixUnsafe(data.([]byte), &allLayerEdges)
err := decodeUint64MatrixUnsafe(data, &allLayerEdges)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -658,7 +660,7 @@ func (ph *persistentHNSW[T]) removeDeadNodes(nnEdges []uint64, tc *TxnCache) ([]

var deadNodes []uint64
if data != nil { // if dead nodes exist, convert to []uint64
deadNodes, err = ParseEdges(string(data.([]byte)))
deadNodes, err = ParseEdges(string(data))
if err != nil {
return []uint64{}, err
}
Expand Down
2 changes: 1 addition & 1 deletion tok/hnsw/persistent_hnsw.go
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ func (ph *persistentHNSW[T]) PickStartNode(
return 0, err
}

entry := BytesToUint64(data.([]byte))
entry := BytesToUint64(data)
if err = ph.getVecFromUid(entry, c, startVec); err != nil && !errors.Is(err, errNilVector) {
return 0, err
}
Expand Down
38 changes: 19 additions & 19 deletions tok/hnsw/persistent_hnsw_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,22 +158,22 @@ func flatInMemListWriteMutation(test flatInMemListAddMutationTest, t *testing.T)
}
}
// should not modify db [test.startTs, test.finishTs)
if tsDbs[test.finishTs-1].inMemTestDb[test.key] != tsDbs[test.startTs].inMemTestDb[test.key] {
if string(tsDbs[test.finishTs-1].inMemTestDb[test.key]) != string(tsDbs[test.startTs].inMemTestDb[test.key]) {
t.Errorf(
"Database at time %q not equal to expected database at time %q. Expected: %q, Got: %q",
test.finishTs-1, test.startTs,
tsDbs[test.startTs].inMemTestDb[test.key],
tsDbs[test.finishTs-1].inMemTestDb[test.key])
}
if string(tsDbs[test.finishTs].inMemTestDb[test.key].([]byte)[:]) != string(test.t.Value[:]) {
if string(tsDbs[test.finishTs].inMemTestDb[test.key][:]) != string(test.t.Value[:]) {
t.Errorf("The database at time %q for key %q gave value of %q instead of %q", test.finishTs,
test.key, string(tsDbs[test.finishTs].inMemTestDb[test.key].([]byte)[:]), string(test.t.Value[:]))
test.key, string(tsDbs[test.finishTs].inMemTestDb[test.key][:]), string(test.t.Value[:]))
}
if string(tsDbs[test.finishTs].inMemTestDb[test.key].([]byte)[:]) !=
string(tsDbs[99].inMemTestDb[test.key].([]byte)[:]) {
if string(tsDbs[test.finishTs].inMemTestDb[test.key][:]) !=
string(tsDbs[99].inMemTestDb[test.key][:]) {
t.Errorf("The database at time %q for key %q gave value of %q instead of %q", test.finishTs,
test.key, string(tsDbs[99].inMemTestDb[test.key].([]byte)[:]),
string(tsDbs[test.finishTs].inMemTestDb[test.key].([]byte)[:]))
test.key, string(tsDbs[99].inMemTestDb[test.key][:]),
string(tsDbs[test.finishTs].inMemTestDb[test.key][:]))
}
}

Expand Down Expand Up @@ -230,14 +230,14 @@ func TestFlatInMemListAddMultipleWritesMutation(t *testing.T) {
conv := flatInMemListAddMutationTest{test.key, test.startTs, test.finishTs, test.t, test.expectedErr}
flatInMemListWriteMutation(conv, t)
} else {
if string(tsDbs[test.finishTs-1].inMemTestDb[test.key].([]byte)[:]) !=
if string(tsDbs[test.finishTs-1].inMemTestDb[test.key][:]) !=
string(flatInMemListAddMultipleWritesMutationTests[test.currIteration-1].t.Value[:]) {
t.Errorf("The database at time %q for key %q gave value of %q instead of %q", test.finishTs,
test.key, string(tsDbs[test.finishTs].inMemTestDb[test.key].([]byte)[:]), string(test.t.Value[:]))
test.key, string(tsDbs[test.finishTs].inMemTestDb[test.key][:]), string(test.t.Value[:]))
}
if string(tsDbs[test.finishTs].inMemTestDb[test.key].([]byte)[:]) != string(test.t.Value[:]) {
if string(tsDbs[test.finishTs].inMemTestDb[test.key][:]) != string(test.t.Value[:]) {
t.Errorf("The database at time %q for key %q gave value of %q instead of %q", test.finishTs,
test.key, string(tsDbs[test.finishTs].inMemTestDb[test.key].([]byte)[:]), string(test.t.Value[:]))
test.key, string(tsDbs[test.finishTs].inMemTestDb[test.key][:]), string(test.t.Value[:]))
}
}
}
Expand Down Expand Up @@ -339,8 +339,8 @@ func TestFlatEntryInsertToPersistentFlatStorage(t *testing.T) {
}
var float1, float2 = []float64{}, []float64{}
skey := string(key[:])
index.BytesAsFloatArray(tsDbs[0].inMemTestDb[skey].([]byte), &float1, 64)
index.BytesAsFloatArray(tsDbs[99].inMemTestDb[skey].([]byte), &float2, 64)
index.BytesAsFloatArray(tsDbs[0].inMemTestDb[skey], &float1, 64)
index.BytesAsFloatArray(tsDbs[99].inMemTestDb[skey], &float2, 64)
if !equalFloat64Slice(float1, float2) {
t.Errorf("Vector value for predicate %q at beginning and end of database were "+
"not equivalent. Start Value: %v\n, End Value: %v\n %v\n %v", flatPh.pred, tsDbs[0].inMemTestDb[skey],
Expand All @@ -355,7 +355,7 @@ func TestFlatEntryInsertToPersistentFlatStorage(t *testing.T) {
t.Errorf("Edges created during insert is incorrect. Expected: %v, Got: %v", test.expectedEdgesList, edgesNameList)
}
entryKey := DataKey(ConcatStrings(flatPh.pred, VecEntry), 1)
entryVal := BytesToUint64(tsDbs[99].inMemTestDb[string(entryKey[:])].([]byte))
entryVal := BytesToUint64(tsDbs[99].inMemTestDb[string(entryKey[:])])
if entryVal != test.inUuid {
t.Errorf("entry value stored is incorrect. Expected: %q, Got: %q", test.inUuid, entryVal)
}
Expand Down Expand Up @@ -416,7 +416,7 @@ func TestNonflatEntryInsertToPersistentFlatStorage(t *testing.T) {
// fmt.Print(tsDbs[1].inMemTestDb[string(testKey[:])])
for _, test := range nonflatEntryInsertToPersistentFlatStorageTests {
entryKey := DataKey(ConcatStrings(flatPh.pred, VecEntry), 1)
entryVal := BytesToUint64(tsDbs[99].inMemTestDb[string(entryKey[:])].([]byte))
entryVal := BytesToUint64(tsDbs[99].inMemTestDb[string(entryKey[:])])
if entryVal != 5 {
t.Errorf("entry value stored is incorrect. Expected: %q, Got: %q", 5, entryVal)
}
Expand All @@ -435,12 +435,12 @@ func TestNonflatEntryInsertToPersistentFlatStorage(t *testing.T) {
}
}
var float1, float2 = []float64{}, []float64{}
index.BytesAsFloatArray(tsDbs[0].inMemTestDb[string(key[:])].([]byte), &float1, 64)
index.BytesAsFloatArray(tsDbs[99].inMemTestDb[string(key[:])].([]byte), &float2, 64)
index.BytesAsFloatArray(tsDbs[0].inMemTestDb[string(key[:])], &float1, 64)
index.BytesAsFloatArray(tsDbs[99].inMemTestDb[string(key[:])], &float2, 64)
if !equalFloat64Slice(float1, float2) {
t.Errorf("Vector value for predicate %q at beginning and end of database were "+
"not equivalent. Start Value: %v, End Value: %v", flatPh.pred, tsDbs[0].inMemTestDb[flatPh.pred].([]float64),
tsDbs[99].inMemTestDb[flatPh.pred].([]float64))
"not equivalent. Start Value: %v, End Value: %v", flatPh.pred, tsDbs[0].inMemTestDb[flatPh.pred],
tsDbs[99].inMemTestDb[flatPh.pred])
}
edgesNameList := []string{}
for _, edge := range edges {
Expand Down
Loading

0 comments on commit ff401c2

Please sign in to comment.