diff --git a/merkledag.go b/merkledag.go index de2f61c..fe68ce4 100644 --- a/merkledag.go +++ b/merkledag.go @@ -158,25 +158,57 @@ func (n *dagService) Session(ctx context.Context) ipld.NodeGetter { // FetchGraph fetches all nodes that are children of the given node func FetchGraph(ctx context.Context, root *cid.Cid, serv ipld.DAGService) error { + return FetchGraphWithDepthLimit(ctx, root, -1, serv) +} + +// FetchGraphWithDepthLimit fetches all nodes that are children to the given +// node down to the given depth. maxDetph=0 means "only fetch root", +// maxDepth=1 means "fetch root and its direct children" and so on... +// maxDepth=-1 means unlimited. +func FetchGraphWithDepthLimit(ctx context.Context, root *cid.Cid, depthLim int, serv ipld.DAGService) error { var ng ipld.NodeGetter = serv ds, ok := serv.(*dagService) if ok { ng = &sesGetter{bserv.NewSession(ctx, ds.Blocks)} } + set := make(map[string]int) + + // Visit function returns true when: + // * The element is not in the set and we're not over depthLim + // * The element is in the set but recorded depth is deeper + // than currently seen (if we find it higher in the tree we'll need + // to explore deeper than before). + // depthLim = -1 means we only return true if the element is not in the + // set. + visit := func(c *cid.Cid, depth int) bool { + key := string(c.Bytes()) + oldDepth, ok := set[key] + + if (ok && depthLim < 0) || (depthLim >= 0 && depth > depthLim) { + return false + } + + if !ok || oldDepth > depth { + set[key] = depth + return true + } + return false + } + v, _ := ctx.Value(progressContextKey).(*ProgressTracker) if v == nil { - return EnumerateChildrenAsync(ctx, GetLinksDirect(ng), root, cid.NewSet().Visit) + return EnumerateChildrenAsyncDepth(ctx, GetLinksDirect(ng), root, 0, visit) } - set := cid.NewSet() - visit := func(c *cid.Cid) bool { - if set.Visit(c) { + + visitProgress := func(c *cid.Cid, depth int) bool { + if visit(c, depth) { v.Increment() return true } return false } - return EnumerateChildrenAsync(ctx, GetLinksDirect(ng), root, visit) + return EnumerateChildrenAsyncDepth(ctx, GetLinksDirect(ng), root, 0, visitProgress) } // GetMany gets many nodes from the DAG at once. @@ -254,14 +286,26 @@ func GetLinksWithDAG(ng ipld.NodeGetter) GetLinks { // unseen children to the passed in set. // TODO: parallelize to avoid disk latency perf hits? func EnumerateChildren(ctx context.Context, getLinks GetLinks, root *cid.Cid, visit func(*cid.Cid) bool) error { + visitDepth := func(c *cid.Cid, depth int) bool { + return visit(c) + } + + return EnumerateChildrenDepth(ctx, getLinks, root, 0, visitDepth) +} + +// EnumerateChildrenDepth walks the dag below the given root and passes the +// current depth to a given visit function. The visit function can be used to +// limit DAG exploration. +func EnumerateChildrenDepth(ctx context.Context, getLinks GetLinks, root *cid.Cid, depth int, visit func(*cid.Cid, int) bool) error { links, err := getLinks(ctx, root) if err != nil { return err } + for _, lnk := range links { c := lnk.Cid - if visit(c) { - err = EnumerateChildren(ctx, getLinks, c, visit) + if visit(c, depth+1) { + err = EnumerateChildrenDepth(ctx, getLinks, c, depth+1, visit) if err != nil { return err } @@ -305,8 +349,30 @@ var FetchGraphConcurrency = 8 // // NOTE: It *does not* make multiple concurrent calls to the passed `visit` function. func EnumerateChildrenAsync(ctx context.Context, getLinks GetLinks, c *cid.Cid, visit func(*cid.Cid) bool) error { - feed := make(chan *cid.Cid) - out := make(chan []*ipld.Link) + visitDepth := func(c *cid.Cid, depth int) bool { + return visit(c) + } + + return EnumerateChildrenAsyncDepth(ctx, getLinks, c, 0, visitDepth) +} + +// EnumerateChildrenAsyncDepth is equivalent to EnumerateChildrenDepth *except* +// that it fetches children in parallel (down to a maximum depth in the graph). +// +// NOTE: It *does not* make multiple concurrent calls to the passed `visit` function. +func EnumerateChildrenAsyncDepth(ctx context.Context, getLinks GetLinks, c *cid.Cid, startDepth int, visit func(*cid.Cid, int) bool) error { + type cidDepth struct { + cid *cid.Cid + depth int + } + + type linksDepth struct { + links []*ipld.Link + depth int + } + + feed := make(chan *cidDepth) + out := make(chan *linksDepth) done := make(chan struct{}) var setlk sync.Mutex @@ -318,20 +384,28 @@ func EnumerateChildrenAsync(ctx context.Context, getLinks GetLinks, c *cid.Cid, for i := 0; i < FetchGraphConcurrency; i++ { go func() { - for ic := range feed { + for cdepth := range feed { + ci := cdepth.cid + depth := cdepth.depth + setlk.Lock() - shouldVisit := visit(ic) + shouldVisit := visit(ci, depth) setlk.Unlock() if shouldVisit { - links, err := getLinks(ctx, ic) + links, err := getLinks(ctx, ci) if err != nil { errChan <- err return } + outLinks := &linksDepth{ + links: links, + depth: depth + 1, + } + select { - case out <- links: + case out <- outLinks: case <-fetchersCtx.Done(): return } @@ -346,10 +420,13 @@ func EnumerateChildrenAsync(ctx context.Context, getLinks GetLinks, c *cid.Cid, defer close(feed) send := feed - var todobuffer []*cid.Cid + var todobuffer []*cidDepth var inProgress int - next := c + next := &cidDepth{ + cid: c, + depth: startDepth, + } for { select { case send <- next: @@ -366,13 +443,18 @@ func EnumerateChildrenAsync(ctx context.Context, getLinks GetLinks, c *cid.Cid, if inProgress == 0 && next == nil { return nil } - case links := <-out: - for _, lnk := range links { + case linksDepth := <-out: + for _, lnk := range linksDepth.links { + cd := &cidDepth{ + cid: lnk.Cid, + depth: linksDepth.depth, + } + if next == nil { - next = lnk.Cid + next = cd send = feed } else { - todobuffer = append(todobuffer, lnk.Cid) + todobuffer = append(todobuffer, cd) } } case err := <-errChan: diff --git a/merkledag_test.go b/merkledag_test.go index 36c0548..cffaf20 100644 --- a/merkledag_test.go +++ b/merkledag_test.go @@ -26,6 +26,52 @@ import ( ipld "github.com/ipfs/go-ipld-format" ) +// makeDepthTestingGraph makes a small DAG with two levels. The level-two +// nodes are both children of the root and of one of the level 1 nodes. +// This is meant to test the EnumerateChildren*Depth functions. +func makeDepthTestingGraph(t *testing.T, ds ipld.DAGService) ipld.Node { + root := NodeWithData(nil) + l11 := NodeWithData([]byte("leve1_node1")) + l12 := NodeWithData([]byte("leve1_node2")) + l21 := NodeWithData([]byte("leve2_node1")) + l22 := NodeWithData([]byte("leve2_node2")) + l23 := NodeWithData([]byte("leve2_node3")) + + l11.AddNodeLink(l21.Cid().String(), l21) + l11.AddNodeLink(l22.Cid().String(), l22) + l11.AddNodeLink(l23.Cid().String(), l23) + + root.AddNodeLink(l11.Cid().String(), l11) + root.AddNodeLink(l12.Cid().String(), l12) + root.AddNodeLink(l23.Cid().String(), l23) + + ctx := context.Background() + for _, n := range []ipld.Node{l23, l22, l21, l12, l11, root} { + err := ds.Add(ctx, n) + if err != nil { + t.Fatal(err) + } + } + + return root +} + +// Check that all children of root are in the given set and in the datastore +func traverseAndCheck(t *testing.T, root ipld.Node, ds ipld.DAGService, hasF func(c *cid.Cid) bool) { + // traverse dag and check + for _, lnk := range root.Links() { + c := lnk.Cid + if !hasF(c) { + t.Fatal("missing key in set! ", lnk.Cid.String()) + } + child, err := ds.Get(context.Background(), c) + if err != nil { + t.Fatal(err) + } + traverseAndCheck(t, child, ds, hasF) + } +} + func TestNode(t *testing.T) { n1 := NodeWithData([]byte("beep")) @@ -293,6 +339,66 @@ func TestFetchGraph(t *testing.T) { } } +func TestFetchGraphWithDepthLimit(t *testing.T) { + type testcase struct { + depthLim int + setLen int + } + + tests := []testcase{ + testcase{1, 3}, + testcase{0, 0}, + testcase{-1, 5}, + testcase{2, 5}, + testcase{3, 5}, + } + + testF := func(t *testing.T, tc testcase) { + var dservs []ipld.DAGService + bsis := bstest.Mocks(2) + for _, bsi := range bsis { + dservs = append(dservs, NewDAGService(bsi)) + } + + root := makeDepthTestingGraph(t, dservs[0]) + + err := FetchGraphWithDepthLimit(context.TODO(), root.Cid(), tc.depthLim, dservs[1]) + if err != nil { + t.Fatal(err) + } + + // create an offline dagstore and ensure all blocks were fetched + bs := bserv.New(bsis[1].Blockstore(), offline.Exchange(bsis[1].Blockstore())) + + offlineDS := NewDAGService(bs) + + set := make(map[string]int) + visitF := func(c *cid.Cid, depth int) bool { + if tc.depthLim < 0 || depth <= tc.depthLim { + set[string(c.Bytes())] = depth + return true + } + return false + + } + + err = EnumerateChildrenDepth(context.Background(), offlineDS.GetLinks, root.Cid(), 0, visitF) + if err != nil { + t.Fatal(err) + } + + if len(set) != tc.setLen { + t.Fatalf("expected %d nodes but visited %d", tc.setLen, len(set)) + } + } + + for _, tc := range tests { + t.Run(fmt.Sprintf("depth limit %d", tc.depthLim), func(t *testing.T) { + testF(t, tc) + }) + } +} + func TestEnumerateChildren(t *testing.T) { bsi := bstest.Mocks(1) ds := NewDAGService(bsi[0]) @@ -307,23 +413,7 @@ func TestEnumerateChildren(t *testing.T) { t.Fatal(err) } - var traverse func(n ipld.Node) - traverse = func(n ipld.Node) { - // traverse dag and check - for _, lnk := range n.Links() { - c := lnk.Cid - if !set.Has(c) { - t.Fatal("missing key in set! ", lnk.Cid.String()) - } - child, err := ds.Get(context.Background(), c) - if err != nil { - t.Fatal(err) - } - traverse(child) - } - } - - traverse(root) + traverseAndCheck(t, root, ds, set.Has) } func TestFetchFailure(t *testing.T) {