diff --git a/posting/index.go b/posting/index.go index 739b45898a1..562083d6546 100644 --- a/posting/index.go +++ b/posting/index.go @@ -262,13 +262,17 @@ func (txn *Txn) addReverseMutationHelper(ctx context.Context, plist *List, return emptyCountParams, errors.Wrapf(ErrTsTooOld, "Adding reverse mutation helper count") } } + if !(hasCountIndex && !shouldAddCountEdge(found, edge)) { if err := plist.addMutationInternal(ctx, txn, edge); err != nil { return emptyCountParams, err } } + if hasCountIndex { - countAfter = countAfterMutation(countBefore, found, edge.Op) + pk, _ := x.Parse(plist.key) + shouldCountOneUid := !schema.State().IsList(edge.Attr) && !pk.IsReverse() && false + countAfter = countAfterMutation(countBefore, found, edge.Op, shouldCountOneUid) return countParams{ attr: edge.Attr, countBefore: countBefore, @@ -361,7 +365,7 @@ func (txn *Txn) addReverseAndCountMutation(ctx context.Context, t *pb.DirectedEd Facets: t.Facets, } - cp, err := txn.addReverseMutationHelper(ctx, plist, hasCountIndex, edge) + cp, err := txn.addReverseMutationHelper(ctx, plist, true, edge) if err != nil { return err } @@ -475,7 +479,21 @@ func (txn *Txn) updateCount(ctx context.Context, params countParams) error { return nil } -func countAfterMutation(countBefore int, found bool, op pb.DirectedEdge_Op) int { +// Gives the count of the posting after the mutation has finished. Currently we use this to figure after the mutation +// what is the count. For non scalar predicate, we need to use found and the operation that the user did to figure out +// if the new node was inserted or not. However, for single uid predicates this information is not useful. For scalar +// predicate, delete only works if the value was found. Set would just result in 1 alaways. +func countAfterMutation(countBefore int, found bool, op pb.DirectedEdge_Op, shouldCountOneUid bool) int { + if shouldCountOneUid { + if op == pb.DirectedEdge_SET { + return 1 + } else if op == pb.DirectedEdge_DEL && found { + return 0 + } else { + return countBefore + } + } + if !found && op != pb.DirectedEdge_DEL { return countBefore + 1 } else if found && op == pb.DirectedEdge_DEL { @@ -516,7 +534,8 @@ func (txn *Txn) addMutationHelper(ctx context.Context, l *List, doUpdateIndex bo var found bool var err error - delNonListPredicate := !schema.State().IsList(t.Attr) && + isScalarPredicate := !schema.State().IsList(t.Attr) + delNonListPredicate := isScalarPredicate && t.Op == pb.DirectedEdge_DEL && string(t.Value) != x.Star switch { @@ -560,7 +579,9 @@ func (txn *Txn) addMutationHelper(ctx context.Context, l *List, doUpdateIndex bo } if hasCountIndex { - countAfter = countAfterMutation(countBefore, found, t.Op) + pk, _ := x.Parse(l.key) + shouldCountOneUid := isScalarPredicate && !pk.IsReverse() && false + countAfter = countAfterMutation(countBefore, found, t.Op, shouldCountOneUid) return val, found, countParams{ attr: t.Attr, countBefore: countBefore, diff --git a/posting/list.go b/posting/list.go index d61b6f20966..6f940d5210c 100644 --- a/posting/list.go +++ b/posting/list.go @@ -345,7 +345,7 @@ func (mm *MutableLayer) populateUidMap(pl *pb.PostingList) { } // insertPosting inserts a new posting in the mutable layers. It updates the currentUids map. -func (mm *MutableLayer) insertPosting(mpost *pb.Posting) { +func (mm *MutableLayer) insertPosting(mpost *pb.Posting, hasCountIndex bool) { if mm.readTs != 0 { x.AssertTrue(mpost.StartTs == mm.readTs) } @@ -359,8 +359,30 @@ func (mm *MutableLayer) insertPosting(mpost *pb.Posting) { } if mpost.Uid != 0 { + // If hasCountIndex, in that case while inserting uids, if there's a delete, we only delete from the + // current entries, we dont' insert the delete posting. If we insert the delete posting, there won't be + // any set posting in the list. This would mess up the count. We can do this for all types, however, + // there might be a performance hit becasue of it. mm.populateUidMap(mm.currentEntries) if postIndex, ok := mm.currentUids[mpost.Uid]; ok { + //if hasCountIndex && mpost.Op == Del { + // // If the posting was there before, just remove it from the map, and then remove it + // // from the array. + // post := mm.currentEntries.Postings[postIndex] + // if post.Op == Del { + // // No need to do anything + // mm.currentEntries.Postings[postIndex] = mpost + // return + // } + // res := mm.currentEntries.Postings[:postIndex] + // if postIndex+1 <= len(mm.currentEntries.Postings) { + // mm.currentEntries.Postings = append(res, + // mm.currentEntries.Postings[(postIndex+1):]...) + // } + // mm.currentUids = nil + // mm.currentEntries.Postings = res + // return + //} mm.currentEntries.Postings[postIndex] = mpost } else { mm.currentEntries.Postings = append(mm.currentEntries.Postings, mpost) @@ -382,7 +404,7 @@ func (mm *MutableLayer) print() string { mm.deleteAllMarker) } -func (l *List) print() string { +func (l *List) Print() string { return fmt.Sprintf("minTs: %d, plist: %+v, mutationMap: %s", l.minTs, l.plist, l.mutationMap.print()) } @@ -721,7 +743,7 @@ func hasDeleteAll(mpost *pb.Posting) bool { } // Ensure that you either abort the uncommitted postings or commit them before calling me. -func (l *List) updateMutationLayer(mpost *pb.Posting, singleUidUpdate bool) error { +func (l *List) updateMutationLayer(mpost *pb.Posting, singleUidUpdate, hasCountIndex bool) error { l.AssertLock() x.AssertTrue(mpost.Op == Set || mpost.Op == Del || mpost.Op == Ovr) @@ -752,6 +774,7 @@ func (l *List) updateMutationLayer(mpost *pb.Posting, singleUidUpdate bool) erro // Add the deletions in the existing plist because those postings are not picked // up by iterating. Not doing so would result in delete operations that are not // applied when the transaction is committed. + //l.mutationMap.currentEntries = &pb.PostingList{} for _, post := range l.mutationMap.currentEntries.Postings { if post.Op == Del && post.Uid != mpost.Uid { newPlist.Postings = append(newPlist.Postings, post) @@ -776,14 +799,11 @@ func (l *List) updateMutationLayer(mpost *pb.Posting, singleUidUpdate bool) erro if err != nil { return err } - - // Update the mutation map with the new plist. Return here since the code below - // does not apply for predicates of type uid. l.mutationMap.setCurrentEntries(mpost.StartTs, newPlist) return nil } - l.mutationMap.insertPosting(mpost) + l.mutationMap.insertPosting(mpost, hasCountIndex) return nil } @@ -907,7 +927,7 @@ func (l *List) addMutationInternal(ctx context.Context, txn *Txn, t *pb.Directed isSingleUidUpdate := ok && !pred.GetList() && pred.GetValueType() == pb.Posting_UID && pk.IsData() && mpost.Op != Del && mpost.PostingType == pb.Posting_REF - if err != l.updateMutationLayer(mpost, isSingleUidUpdate) { + if err != l.updateMutationLayer(mpost, isSingleUidUpdate, pred.GetCount() && (pk.IsData() || pk.IsReverse())) { return errors.Wrapf(err, "cannot update mutation layer of key %s with value %+v", hex.EncodeToString(l.key), mpost) } @@ -1123,7 +1143,7 @@ func (l *List) iterate(readTs uint64, afterUid uint64, f func(obj *pb.Posting) e // pitr iterates through immutable postings err = pitr.seek(l, afterUid, deleteBelowTs) if err != nil { - return errors.Wrapf(err, "cannot initialize iterator when calling List.iterate "+l.print()) + return errors.Wrapf(err, "cannot initialize iterator when calling List.iterate "+l.Print()) } loop: diff --git a/worker/sort_test.go b/worker/sort_test.go index 0c67f2a0689..5f6ac8f861f 100644 --- a/worker/sort_test.go +++ b/worker/sort_test.go @@ -66,6 +66,250 @@ func writePostingListToDisk(kvs []*bpb.KV) error { return writer.Flush() } +func TestMultipleTxnListCount(t *testing.T) { + dir, err := os.MkdirTemp("", "storetest_") + x.Check(err) + defer os.RemoveAll(dir) + + opt := badger.DefaultOptions(dir) + ps, err := badger.OpenManaged(opt) + x.Check(err) + pstore = ps + posting.Init(ps, 0, false) + Init(ps) + err = schema.ParseBytes([]byte("scalarPredicateCount3: [uid] @count ."), 1) + require.NoError(t, err) + + ctx := context.Background() + attr := x.GalaxyAttr("scalarPredicateCount3") + + runM := func(startTs, commitTs uint64, edges []*pb.DirectedEdge) { + txn := posting.Oracle().RegisterStartTs(startTs) + for _, edge := range edges { + x.Check(runMutation(ctx, edge, txn)) + } + txn.Update() + writer := posting.NewTxnWriter(pstore) + require.NoError(t, txn.CommitToDisk(writer, commitTs)) + require.NoError(t, writer.Flush()) + txn.UpdateCachedKeys(commitTs) + } + + runM(9, 11, []*pb.DirectedEdge{{ + ValueId: 3, + ValueType: pb.Posting_UID, + Attr: attr, + Entity: 1, + Op: pb.DirectedEdge_SET, + }, { + ValueId: 2, + ValueType: pb.Posting_UID, + Attr: attr, + Entity: 1, + Op: pb.DirectedEdge_SET, + }}) + + txn := posting.Oracle().RegisterStartTs(13) + key := x.CountKey(attr, 1, false) + l, err := txn.Get(key) + require.Nil(t, err) + uids, err := l.Uids(posting.ListOptions{ReadTs: 13}) + require.Nil(t, err) + require.Equal(t, 0, len(uids.Uids)) + + key = x.CountKey(attr, 2, false) + l, err = txn.Get(key) + require.Nil(t, err) + uids, err = l.Uids(posting.ListOptions{ReadTs: 13}) + require.Nil(t, err) + require.Equal(t, 1, len(uids.Uids)) +} + +func TestScalarPredicateRevCount(t *testing.T) { + dir, err := os.MkdirTemp("", "storetest_") + x.Check(err) + defer os.RemoveAll(dir) + + opt := badger.DefaultOptions(dir) + ps, err := badger.OpenManaged(opt) + x.Check(err) + pstore = ps + posting.Init(ps, 0, false) + Init(ps) + err = schema.ParseBytes([]byte("scalarPredicateCount2: uid @reverse @count ."), 1) + require.NoError(t, err) + + ctx := context.Background() + attr := x.GalaxyAttr("scalarPredicateCount2") + + runM := func(startTs, commitTs uint64, edges []*pb.DirectedEdge) { + txn := posting.Oracle().RegisterStartTs(startTs) + for _, edge := range edges { + x.Check(runMutation(ctx, edge, txn)) + } + txn.Update() + writer := posting.NewTxnWriter(pstore) + require.NoError(t, txn.CommitToDisk(writer, commitTs)) + require.NoError(t, writer.Flush()) + txn.UpdateCachedKeys(commitTs) + } + + runM(9, 11, []*pb.DirectedEdge{{ + ValueId: 3, + ValueType: pb.Posting_UID, + Attr: attr, + Entity: 1, + Op: pb.DirectedEdge_SET, + }, { + ValueId: 3, + ValueType: pb.Posting_UID, + Attr: attr, + Entity: 1, + Op: pb.DirectedEdge_DEL, + }}) + + txn := posting.Oracle().RegisterStartTs(13) + key := x.DataKey(attr, 1) + l, err := txn.Get(key) + l.RLock() + require.Equal(t, 0, l.GetLength(13)) + l.RUnlock() + + runM(15, 17, []*pb.DirectedEdge{{ + ValueId: 3, + ValueType: pb.Posting_UID, + Attr: attr, + Entity: 1, + Op: pb.DirectedEdge_SET, + }}) + + txn = posting.Oracle().RegisterStartTs(18) + l, err = txn.Get(key) + l.RLock() + require.Equal(t, 1, l.GetLength(18)) + l.RUnlock() + + runM(18, 19, []*pb.DirectedEdge{{ + ValueId: 3, + ValueType: pb.Posting_UID, + Attr: attr, + Entity: 1, + Op: pb.DirectedEdge_DEL, + }}) + + txn = posting.Oracle().RegisterStartTs(20) + l, err = txn.Get(key) + l.RLock() + require.Equal(t, 0, l.GetLength(20)) + l.RUnlock() +} + +func TestScalarPredicateIntCount(t *testing.T) { + dir, err := os.MkdirTemp("", "storetest_") + x.Check(err) + defer os.RemoveAll(dir) + + opt := badger.DefaultOptions(dir) + ps, err := badger.OpenManaged(opt) + x.Check(err) + pstore = ps + posting.Init(ps, 0, false) + Init(ps) + err = schema.ParseBytes([]byte("scalarPredicateCount1: string @count ."), 1) + require.NoError(t, err) + + ctx := context.Background() + attr := x.GalaxyAttr("scalarPredicateCount1") + + runM := func(startTs, commitTs uint64, edge *pb.DirectedEdge) { + txn := posting.Oracle().RegisterStartTs(startTs) + x.Check(runMutation(ctx, edge, txn)) + txn.Update() + writer := posting.NewTxnWriter(pstore) + require.NoError(t, txn.CommitToDisk(writer, commitTs)) + require.NoError(t, writer.Flush()) + txn.UpdateCachedKeys(commitTs) + } + + runM(5, 7, &pb.DirectedEdge{ + Value: []byte("a"), + ValueType: pb.Posting_STRING, + Attr: attr, + Entity: 1, + Op: pb.DirectedEdge_SET, + }) + + key := x.CountKey(attr, 1, false) + rollup(t, key, ps, 8) + + runM(9, 11, &pb.DirectedEdge{ + Value: []byte("a"), + ValueType: pb.Posting_STRING, + Attr: attr, + Entity: 1, + Op: pb.DirectedEdge_DEL, + }) + + txn := posting.Oracle().RegisterStartTs(20) + l, err := txn.Get(key) + l.RLock() + require.Equal(t, 0, l.GetLength(20)) + l.RUnlock() +} + +func TestScalarPredicateCount(t *testing.T) { + dir, err := os.MkdirTemp("", "storetest_") + x.Check(err) + defer os.RemoveAll(dir) + + opt := badger.DefaultOptions(dir) + ps, err := badger.OpenManaged(opt) + x.Check(err) + pstore = ps + posting.Init(ps, 0, false) + Init(ps) + err = schema.ParseBytes([]byte("scalarPredicateCount: uid @count ."), 1) + require.NoError(t, err) + + ctx := context.Background() + attr := x.GalaxyAttr("scalarPredicateCount") + + runM := func(startTs, commitTs uint64, edge *pb.DirectedEdge) { + txn := posting.Oracle().RegisterStartTs(startTs) + x.Check(runMutation(ctx, edge, txn)) + txn.Update() + writer := posting.NewTxnWriter(pstore) + require.NoError(t, txn.CommitToDisk(writer, commitTs)) + require.NoError(t, writer.Flush()) + txn.UpdateCachedKeys(commitTs) + } + + runM(5, 7, &pb.DirectedEdge{ + ValueId: 2, + ValueType: pb.Posting_UID, + Attr: attr, + Entity: 1, + Op: pb.DirectedEdge_SET, + }) + + key := x.CountKey(attr, 1, false) + rollup(t, key, ps, 8) + + runM(9, 11, &pb.DirectedEdge{ + ValueId: 3, + ValueType: pb.Posting_UID, + Attr: attr, + Entity: 1, + Op: pb.DirectedEdge_SET, + }) + + txn := posting.Oracle().RegisterStartTs(15) + l, err := txn.Get(key) + l.RLock() + require.Equal(t, 1, l.GetLength(15)) + l.RUnlock() +} + func TestSingleUid(t *testing.T) { dir, err := os.MkdirTemp("", "storetest_") x.Check(err)