diff --git a/p2p/net/swarm/dial_sync.go b/p2p/net/swarm/dial_sync.go index 48b77899f2..69739e54ec 100644 --- a/p2p/net/swarm/dial_sync.go +++ b/p2p/net/swarm/dial_sync.go @@ -53,22 +53,38 @@ func (ad *activeDial) incref() { func (ad *activeDial) decref() { ad.refCntLk.Lock() - defer ad.refCntLk.Unlock() ad.refCnt-- - if ad.refCnt <= 0 { - ad.cancel() + maybeZero := (ad.refCnt <= 0) + ad.refCntLk.Unlock() + + // make sure to always take locks in correct order. + if maybeZero { ad.ds.dialsLk.Lock() - delete(ad.ds.dials, ad.id) + ad.refCntLk.Lock() + // check again after lock swap drop to make sure nobody else called incref + // in between locks + if ad.refCnt <= 0 { + ad.cancel() + delete(ad.ds.dials, ad.id) + } + ad.refCntLk.Unlock() ad.ds.dialsLk.Unlock() } } -func (ds *DialSync) DialLock(ctx context.Context, p peer.ID) (*Conn, error) { +func (ad *activeDial) start(ctx context.Context) { + ad.conn, ad.err = ad.ds.dialFunc(ctx, ad.id) + close(ad.waitch) + ad.cancel() +} + +func (ds *DialSync) getActiveDial(p peer.ID) *activeDial { ds.dialsLk.Lock() + defer ds.dialsLk.Unlock() actd, ok := ds.dials[p] if !ok { - ctx, cancel := context.WithCancel(context.Background()) + adctx, cancel := context.WithCancel(context.Background()) actd = &activeDial{ id: p, cancel: cancel, @@ -77,15 +93,15 @@ func (ds *DialSync) DialLock(ctx context.Context, p peer.ID) (*Conn, error) { } ds.dials[p] = actd - go func(ctx context.Context, p peer.ID, ad *activeDial) { - ad.conn, ad.err = ds.dialFunc(ctx, p) - close(ad.waitch) - ad.cancel() - }(ctx, p, actd) + go actd.start(adctx) } + // increase ref count before dropping dialsLk actd.incref() - ds.dialsLk.Unlock() - return actd.wait(ctx) + return actd +} + +func (ds *DialSync) DialLock(ctx context.Context, p peer.ID) (*Conn, error) { + return ds.getActiveDial(p).wait(ctx) } diff --git a/p2p/net/swarm/dial_sync_test.go b/p2p/net/swarm/dial_sync_test.go index ca81a9c872..0d70e226af 100644 --- a/p2p/net/swarm/dial_sync_test.go +++ b/p2p/net/swarm/dial_sync_test.go @@ -201,3 +201,27 @@ func TestFailFirst(t *testing.T) { t.Fatal("should have gotten a 'real' conn back") } } + +func TestStressActiveDial(t *testing.T) { + ds := NewDialSync(func(ctx context.Context, p peer.ID) (*Conn, error) { + return nil, nil + }) + + wg := sync.WaitGroup{} + + pid := peer.ID("foo") + + makeDials := func() { + for i := 0; i < 10000; i++ { + ds.DialLock(context.Background(), pid) + } + wg.Done() + } + + for i := 0; i < 100; i++ { + wg.Add(1) + go makeDials() + } + + wg.Wait() +}