diff --git a/pkg/storage/cmdq/interval_btree.go b/pkg/storage/cmdq/interval_btree.go
index 0f144948eb0a..770bbc90c28d 100644
--- a/pkg/storage/cmdq/interval_btree.go
+++ b/pkg/storage/cmdq/interval_btree.go
@@ -18,6 +18,8 @@ import (
 	"bytes"
 	"sort"
 	"strings"
+	"sync"
+	"sync/atomic"
 	"unsafe"
 
 	"github.com/cockroachdb/cockroach/pkg/roachpb"
@@ -107,21 +109,126 @@ func upperBound(c *cmd) keyBound {
 }
 
 type leafNode struct {
-	max   keyBound
+	ref   int32
 	count int16
 	leaf  bool
+	max   keyBound
 	cmds  [maxCmds]*cmd
 }
 
-func newLeafNode() *node {
-	return (*node)(unsafe.Pointer(&leafNode{leaf: true}))
-}
-
 type node struct {
 	leafNode
 	children [maxCmds + 1]*node
 }
 
+func leafToNode(ln *leafNode) *node {
+	return (*node)(unsafe.Pointer(ln))
+}
+
+func nodeToLeaf(n *node) *leafNode {
+	return (*leafNode)(unsafe.Pointer(n))
+}
+
+var leafPool = sync.Pool{
+	New: func() interface{} {
+		return new(leafNode)
+	},
+}
+
+var nodePool = sync.Pool{
+	New: func() interface{} {
+		return new(node)
+	},
+}
+
+func newLeafNode() *node {
+	n := leafToNode(leafPool.Get().(*leafNode))
+	n.leaf = true
+	n.ref = 1
+	return n
+}
+
+func newNode() *node {
+	n := nodePool.Get().(*node)
+	n.ref = 1
+	return n
+}
+
+// mut creates and returns a mutable node reference. If the node is not shared
+// with any other trees then it can be modified in place. Otherwise, it must be
+// cloned to ensure unique ownership. In this way, we enforce a copy-on-write
+// policy which transparently incorporates the idea of local mutations, like
+// Clojure's transients or Haskell's ST monad, where nodes are only copied
+// the first time that they are modified between Clone operations.
+//
+// When a node is cloned, the provided pointer will be redirected to the new
+// mutable node.
+func mut(n **node) *node {
+	if atomic.LoadInt32(&(*n).ref) == 1 {
+		// Exclusive ownership. Can mutate in place.
+		return *n
+	}
+	// If we do not have unique ownership over the node then we
+	// clone it to gain unique ownership. After doing so, we can
+	// release our reference to the old node.
+	c := (*n).clone()
+	(*n).decRef(true /* recursive */)
+	*n = c
+	return *n
+}
+
+// incRef acquires a reference to the node.
+func (n *node) incRef() {
+	atomic.AddInt32(&n.ref, 1)
+}
+
+// decRef releases a reference to the node. If requested, the method
+// will recurse into child nodes and decrease their refcounts as well.
+func (n *node) decRef(recursive bool) {
+	if atomic.AddInt32(&n.ref, -1) > 0 {
+		// Other references remain. Can't free.
+		return
+	}
+	// Clear and release node into memory pool.
+	if n.leaf {
+		ln := nodeToLeaf(n)
+		*ln = leafNode{}
+		leafPool.Put(ln)
+	} else {
+		// Release child references first, if requested.
+		if recursive {
+			for i := int16(0); i <= n.count; i++ {
+				n.children[i].decRef(true /* recursive */)
+			}
+		}
+		*n = node{}
+		nodePool.Put(n)
+	}
+}
+
+// clone creates a clone of the receiver with a single reference count.
+func (n *node) clone() *node {
+	var c *node
+	if n.leaf {
+		c = newLeafNode()
+	} else {
+		c = newNode()
+	}
+	// NB: copy field-by-field without touching n.ref to avoid
+	// triggering the race detector and looking like a data race.
+	c.count = n.count
+	c.max = n.max
+	c.cmds = n.cmds
+	if !c.leaf {
+		// Copy children and increase each refcount.
+		c.children = n.children
+		for i := int16(0); i <= c.count; i++ {
+			c.children[i].incRef()
+		}
+	}
+	return c
+}
+
 func (n *node) insertAt(index int, c *cmd, nd *node) {
 	if index < int(n.count) {
 		copy(n.cmds[index+1:n.count+1], n.cmds[index:n.count])
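Taken together, the pools, the atomic refcounts, and mut define the copy-on-write contract that the rest of the patch builds on: a node with ref == 1 is mutated in place, while a shared node is cloned on first write. A minimal sketch of the resulting behavior, outside the patch itself (newCmd and spanWithEnd are the helpers from the test file; Clone is added further down):

	var a btree
	a.Set(newCmd(spanWithEnd(1, 2))) // root.ref == 1: mutated in place
	b := a.Clone()                   // root shared, ref == 2: no nodes copied yet
	b.Set(newCmd(spanWithEnd(2, 3))) // mut(&b.root) sees ref > 1: clones the root,
	                                 // decRefs the shared copy, writes to the clone
	b.Reset()                        // recursively decRefs b's nodes into the pools
	a.Set(newCmd(spanWithEnd(3, 4))) // a's root is unshared again: mutated in place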
@@ -247,7 +354,7 @@ func (n *node) split(i int) (*cmd, *node) {
 	if n.leaf {
 		next = newLeafNode()
 	} else {
-		next = &node{}
+		next = newNode()
 	}
 	next.count = n.count - int16(i+1)
 	copy(next.cmds[:], n.cmds[i+1:n.count])
@@ -287,7 +394,7 @@ func (n *node) insert(c *cmd) (replaced, newBound bool) {
 		return false, n.adjustUpperBoundOnInsertion(c, nil)
 	}
 	if n.children[i].count >= maxCmds {
-		splitcmd, splitNode := n.children[i].split(maxCmds / 2)
+		splitcmd, splitNode := mut(&n.children[i]).split(maxCmds / 2)
 		n.insertAt(i, splitcmd, splitNode)
 
 		switch cmp := cmp(c, n.cmds[i]); {
@@ -300,7 +407,7 @@ func (n *node) insert(c *cmd) (replaced, newBound bool) {
 			return true, false
 		}
 	}
-	replaced, newBound = n.children[i].insert(c)
+	replaced, newBound = mut(&n.children[i]).insert(c)
 	if newBound {
 		newBound = n.adjustUpperBoundOnInsertion(c, nil)
 	}
@@ -317,7 +424,7 @@ func (n *node) removeMax() *cmd {
 		n.adjustUpperBoundOnRemoval(out, nil)
 		return out
 	}
-	child := n.children[n.count]
+	child := mut(&n.children[n.count])
 	if child.count <= minCmds {
 		n.rebalanceOrMerge(int(n.count))
 		return n.removeMax()
@@ -337,12 +444,12 @@ func (n *node) remove(c *cmd) (out *cmd, newBound bool) {
 		}
 		return nil, false
 	}
-	child := n.children[i]
-	if child.count <= minCmds {
+	if n.children[i].count <= minCmds {
 		// Child not large enough to remove from.
 		n.rebalanceOrMerge(i)
 		return n.remove(c)
 	}
+	child := mut(&n.children[i])
 	if found {
 		// Replace the cmd being removed with the max cmd in our left child.
 		out = n.cmds[i]
@@ -390,8 +497,8 @@ func (n *node) rebalanceOrMerge(i int) {
 		//          v
 		//          a
 		//
-		left := n.children[i-1]
-		child := n.children[i]
+		left := mut(&n.children[i-1])
+		child := mut(&n.children[i])
 		xCmd, grandChild := left.popBack()
 		yCmd := n.cmds[i-1]
 		child.pushFront(yCmd, grandChild)
@@ -429,8 +536,8 @@ func (n *node) rebalanceOrMerge(i int) {
 		//          v
 		//          a
 		//
-		right := n.children[i+1]
-		child := n.children[i]
+		right := mut(&n.children[i+1])
+		child := mut(&n.children[i])
 		xCmd, grandChild := right.popFront()
 		yCmd := n.cmds[i]
 		child.pushBack(yCmd, grandChild)
@@ -465,7 +572,9 @@ func (n *node) rebalanceOrMerge(i int) {
 		if i >= int(n.count) {
 			i = int(n.count - 1)
 		}
-		child := n.children[i]
+		child := mut(&n.children[i])
+		// Make mergeChild mutable, bumping the refcounts on its children if necessary.
+		_ = mut(&n.children[i+1])
 		mergeCmd, mergeChild := n.removeAt(i)
 		child.cmds[child.count] = mergeCmd
 		copy(child.cmds[child.count+1:], mergeChild.cmds[:mergeChild.count])
@@ -475,6 +584,7 @@ func (n *node) rebalanceOrMerge(i int) {
 		child.count += mergeChild.count + 1
 
 		child.adjustUpperBoundOnInsertion(mergeCmd, mergeChild)
+		mergeChild.decRef(false /* recursive */)
 	}
 }
 
@@ -548,25 +658,39 @@ type btree struct {
 	length int
 }
 
-// Reset removes all cmds from the btree.
+// Reset removes all cmds from the btree. In doing so, it allows memory
+// held by the btree to be recycled. Failure to call this method before
+// letting a btree be GCed is safe in that it won't cause a memory leak,
+// but it will prevent btree nodes from being efficiently re-used.
 func (t *btree) Reset() {
-	t.root = nil
+	if t.root != nil {
+		t.root.decRef(true /* recursive */)
+		t.root = nil
+	}
 	t.length = 0
 }
 
-// Silent unused warning.
-var _ = (*btree).Reset
+// Clone clones the btree, lazily.
+func (t *btree) Clone() btree {
+	c := *t
+	if c.root != nil {
+		c.root.incRef()
+	}
+	return c
+}
 
 // Delete removes a cmd equal to the passed in cmd from the tree.
 func (t *btree) Delete(c *cmd) {
 	if t.root == nil || t.root.count == 0 {
 		return
 	}
-	if out, _ := t.root.remove(c); out != nil {
+	if out, _ := mut(&t.root).remove(c); out != nil {
 		t.length--
 	}
 	if t.root.count == 0 && !t.root.leaf {
+		old := t.root
 		t.root = t.root.children[0]
+		old.decRef(false /* recursive */)
 	}
 }
 
@@ -576,8 +700,8 @@ func (t *btree) Set(c *cmd) {
 	if t.root == nil {
 		t.root = newLeafNode()
 	} else if t.root.count >= maxCmds {
-		splitcmd, splitNode := t.root.split(maxCmds / 2)
-		newRoot := &node{}
+		splitcmd, splitNode := mut(&t.root).split(maxCmds / 2)
+		newRoot := newNode()
 		newRoot.count = 1
 		newRoot.cmds[0] = splitcmd
 		newRoot.children[0] = t.root
@@ -585,7 +709,7 @@ func (t *btree) Set(c *cmd) {
 		newRoot.max = newRoot.findUpperBound()
 		t.root = newRoot
 	}
-	if replaced, _ := t.root.insert(c); !replaced {
+	if replaced, _ := mut(&t.root).insert(c); !replaced {
 		t.length++
 	}
 }
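Two details worth noting above: Clone is constant-time (it copies the btree struct and increments the root's refcount), and the root collapse in Delete uses decRef(false /* recursive */) because the old root's only child is handed to t.root, transferring that reference rather than releasing it. Together with Reset, this makes cheap read-only snapshots possible; a sketch (not part of the patch), using the iterator API that the test file's all helper exercises:

	snap := t.Clone()  // O(1): bumps root.ref
	defer snap.Reset() // returns the snapshot's references to the pools
	it := snap.MakeIter()
	for it.First(); it.Valid(); it.Next() {
		_ = it.Cmd() // reads a frozen view even if t is mutated concurrently
	}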
diff --git a/pkg/storage/cmdq/interval_btree_test.go b/pkg/storage/cmdq/interval_btree_test.go
index 693787c5d86c..30fb6daa90cb 100644
--- a/pkg/storage/cmdq/interval_btree_test.go
+++ b/pkg/storage/cmdq/interval_btree_test.go
@@ -17,6 +17,8 @@ import (
 	"fmt"
 	"math/rand"
+	"reflect"
+	"sync"
 	"testing"
 
 	"github.com/stretchr/testify/require"
 
@@ -443,6 +445,78 @@ func TestBTreeSeekOverlapRandom(t *testing.T) {
 	}
 }
 
+func TestBTreeCloneConcurrentOperations(t *testing.T) {
+	const cloneTestSize = 10000
+	p := perm(cloneTestSize)
+
+	var trees []*btree
+	treeC, treeDone := make(chan *btree), make(chan struct{})
+	go func() {
+		for b := range treeC {
+			trees = append(trees, b)
+		}
+		close(treeDone)
+	}()
+
+	var wg sync.WaitGroup
+	var populate func(tr *btree, start int)
+	populate = func(tr *btree, start int) {
+		t.Logf("Starting new clone at %v", start)
+		treeC <- tr
+		for i := start; i < cloneTestSize; i++ {
+			tr.Set(p[i])
+			if i%(cloneTestSize/5) == 0 {
+				wg.Add(1)
+				c := tr.Clone()
+				go populate(&c, i+1)
+			}
+		}
+		wg.Done()
+	}
+
+	wg.Add(1)
+	var tr btree
+	go populate(&tr, 0)
+	wg.Wait()
+	close(treeC)
+	<-treeDone
+
+	t.Logf("Starting equality checks on %d trees", len(trees))
+	want := rang(0, cloneTestSize-1)
+	for i, tree := range trees {
+		if !reflect.DeepEqual(want, all(tree)) {
+			t.Errorf("tree %v mismatch", i)
+		}
+	}
+
+	t.Log("Removing half of cmds from first half")
+	toRemove := want[cloneTestSize/2:]
+	for i := 0; i < len(trees)/2; i++ {
+		tree := trees[i]
+		wg.Add(1)
+		go func() {
+			for _, cmd := range toRemove {
+				tree.Delete(cmd)
+			}
+			wg.Done()
+		}()
+	}
+	wg.Wait()
+
+	t.Log("Checking all values again")
+	for i, tree := range trees {
+		var wantpart []*cmd
+		if i < len(trees)/2 {
+			wantpart = want[:cloneTestSize/2]
+		} else {
+			wantpart = want
+		}
+		if got := all(tree); !reflect.DeepEqual(wantpart, got) {
+			t.Errorf("tree %v mismatch, want %v got %v", i, len(wantpart), len(got))
+		}
+	}
+}
+
 func TestBTreeCmp(t *testing.T) {
 	testCases := []struct {
 		spanA, spanB roachpb.Span
@@ -544,6 +618,25 @@ func perm(n int) (out []*cmd) {
 	return out
 }
 
+// rang returns an ordered list of cmds with spans in the range [m, n].
+func rang(m, n int) (out []*cmd) {
+	for i := m; i <= n; i++ {
+		out = append(out, newCmd(spanWithEnd(i, i+1)))
+	}
+	return out
+}
+
+// all extracts all cmds from a tree in order as a slice.
+func all(tr *btree) (out []*cmd) {
+	it := tr.MakeIter()
+	it.First()
+	for it.Valid() {
+		out = append(out, it.Cmd())
+		it.Next()
+	}
+	return out
+}
+
 func forBenchmarkSizes(b *testing.B, f func(b *testing.B, count int)) {
 	for _, count := range []int{16, 128, 1024, 8192, 65536} {
 		b.Run(fmt.Sprintf("count=%d", count), func(b *testing.B) {
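The test above leans on the concurrency contract the atomic refcounts provide: distinct clones may be mutated from different goroutines, because shared nodes are only ever read and every write goes through mut, which clones anything shared. Reduced to a sketch (x and y standing for cmds built with the helpers above):

	c := tr.Clone()
	var wg sync.WaitGroup
	wg.Add(2)
	go func() { defer wg.Done(); c.Set(x) }()     // writes only nodes c exclusively owns
	go func() { defer wg.Done(); tr.Delete(y) }() // writes only nodes tr exclusively owns
	wg.Wait()
	// A single btree value must still not be mutated by two goroutines at once.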
@@ -610,6 +703,48 @@ func BenchmarkBTreeDeleteInsert(b *testing.B) {
 	})
 }
 
+func BenchmarkBTreeDeleteInsertCloneOnce(b *testing.B) {
+	forBenchmarkSizes(b, func(b *testing.B, count int) {
+		insertP := perm(count)
+		var tr btree
+		for _, cmd := range insertP {
+			tr.Set(cmd)
+		}
+		tr = tr.Clone()
+		b.ResetTimer()
+		for i := 0; i < b.N; i++ {
+			cmd := insertP[i%count]
+			tr.Delete(cmd)
+			tr.Set(cmd)
+		}
+	})
+}
+
+func BenchmarkBTreeDeleteInsertCloneEachTime(b *testing.B) {
+	for _, reset := range []bool{false, true} {
+		b.Run(fmt.Sprintf("reset=%t", reset), func(b *testing.B) {
+			forBenchmarkSizes(b, func(b *testing.B, count int) {
+				insertP := perm(count)
+				var tr, trReset btree
+				for _, cmd := range insertP {
+					tr.Set(cmd)
+				}
+				b.ResetTimer()
+				for i := 0; i < b.N; i++ {
+					cmd := insertP[i%count]
+					if reset {
+						trReset.Reset()
+						trReset = tr
+					}
+					tr = tr.Clone()
+					tr.Delete(cmd)
+					tr.Set(cmd)
+				}
+			})
+		})
+	}
+}
+
 func BenchmarkBTreeMakeIter(b *testing.B) {
 	var tr btree
 	for i := 0; i < b.N; i++ {
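In BenchmarkBTreeDeleteInsertCloneEachTime, the reset=true variant keeps the prior tree in trReset and Resets it one iteration later, so each generation's nodes flow back into the sync.Pools; reset=false drops old clones without releasing them, isolating the cost of cloning when nodes cannot be recycled. BenchmarkBTreeDeleteInsertCloneOnce instead measures the steady state after a single clone, once the mutated paths have been copied. Assuming the standard Go tooling, something like the following runs the new test and benchmarks:

	go test ./pkg/storage/cmdq -run TestBTreeCloneConcurrentOperations
	go test ./pkg/storage/cmdq -run '^$' -bench 'BTreeDeleteInsertClone'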