diff --git a/dht.go b/dht.go index ca2209495..39e62a17c 100644 --- a/dht.go +++ b/dht.go @@ -111,12 +111,6 @@ func New(ctx context.Context, h host.Host, options ...opts.Option) (*IpfsDHT, er // register for network notifs. dht.host.Network().Notify((*netNotifiee)(dht)) - dht.proc = goprocessctx.WithContextAndTeardown(ctx, func() error { - // remove ourselves from network notifs. - dht.host.Network().StopNotify((*netNotifiee)(dht)) - return nil - }) - dht.proc.AddChild(dht.providers.Process()) dht.Validator = cfg.Validator @@ -172,8 +166,6 @@ func makeDHT(ctx context.Context, h host.Host, cfg *opts.Options) *IpfsDHT { peerstore: h.Peerstore(), host: h, strmap: make(map[peer.ID]*messageSender), - ctx: ctx, - providers: providers.NewProviderManager(ctx, h.ID(), cfg.Datastore), birth: time.Now(), routingTable: rt, protocols: cfg.Protocols, @@ -181,7 +173,19 @@ func makeDHT(ctx context.Context, h host.Host, cfg *opts.Options) *IpfsDHT { triggerRtRefresh: make(chan chan<- error), } - dht.ctx = dht.newContextWithLocalTags(ctx) + // create a DHT proc with the given teardown + dht.proc = goprocess.WithTeardown(func() error { + // remove ourselves from network notifs. + dht.host.Network().StopNotify((*netNotifiee)(dht)) + return nil + }) + + // create a tagged context derived from the original context + ctxTags := dht.newContextWithLocalTags(ctx) + // the DHT context should be done when the process is closed + dht.ctx = goprocessctx.WithProcessClosing(ctxTags, dht.proc) + + dht.providers = providers.NewProviderManager(dht.ctx, h.ID(), cfg.Datastore) return dht } diff --git a/dht_test.go b/dht_test.go index 7f447b194..0beaafa74 100644 --- a/dht_test.go +++ b/dht_test.go @@ -402,6 +402,30 @@ func TestValueSetInvalid(t *testing.T) { testSetGet("valid", true, "newer", nil) } +func TestContextShutDown(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + dht := setupDHT(ctx, t, false) + + // context is alive + select { + case <-dht.Context().Done(): + t.Fatal("context should not be done") + default: + } + + // shut down dht + require.NoError(t, dht.Close()) + + // now context should be done + select { + case <-dht.Context().Done(): + default: + t.Fatal("context should be done") + } +} + func TestSearchValue(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel()