From 85c201677d8bcead7975c8fff6c7df4c34330dc2 Mon Sep 17 00:00:00 2001 From: sukun Date: Sun, 9 Jul 2023 17:56:20 +0530 Subject: [PATCH 1/4] swarm: make externally used dialQueue method public --- p2p/net/swarm/dial_worker.go | 14 +++++++------- p2p/net/swarm/dial_worker_test.go | 4 ++-- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/p2p/net/swarm/dial_worker.go b/p2p/net/swarm/dial_worker.go index 78795f17b3..52e19911bc 100644 --- a/p2p/net/swarm/dial_worker.go +++ b/p2p/net/swarm/dial_worker.go @@ -123,7 +123,7 @@ func (w *dialWorker) loop() { <-dialTimer.Ch() } timerRunning = false - if dq.len() > 0 { + if dq.Len() > 0 { if dialsInFlight == 0 && !w.connected { // if there are no dials in flight, trigger the next dials immediately dialTimer.Reset(startTime) @@ -417,7 +417,7 @@ func newDialQueue() *dialQueue { // Add adds adelay to the queue. If another element exists in the queue with // the same address, it replaces that element. func (dq *dialQueue) Add(adelay network.AddrDelay) { - for i := 0; i < dq.len(); i++ { + for i := 0; i < dq.Len(); i++ { if dq.q[i].Addr.Equal(adelay.Addr) { if dq.q[i].Delay == adelay.Delay { // existing element is the same. nothing to do @@ -430,7 +430,7 @@ func (dq *dialQueue) Add(adelay network.AddrDelay) { } } - for i := 0; i < dq.len(); i++ { + for i := 0; i < dq.Len(); i++ { if dq.q[i].Delay > adelay.Delay { dq.q = append(dq.q, network.AddrDelay{}) // extend the slice copy(dq.q[i+1:], dq.q[i:]) @@ -443,13 +443,13 @@ func (dq *dialQueue) Add(adelay network.AddrDelay) { // NextBatch returns all the elements in the queue with the highest priority func (dq *dialQueue) NextBatch() []network.AddrDelay { - if dq.len() == 0 { + if dq.Len() == 0 { return nil } // i is the index of the second highest priority element var i int - for i = 0; i < dq.len(); i++ { + for i = 0; i < dq.Len(); i++ { if dq.q[i].Delay != dq.q[0].Delay { break } @@ -464,7 +464,7 @@ func (dq *dialQueue) top() network.AddrDelay { return dq.q[0] } -// len returns the number of elements in the queue -func (dq *dialQueue) len() int { +// Len returns the number of elements in the queue +func (dq *dialQueue) Len() int { return len(dq.q) } diff --git a/p2p/net/swarm/dial_worker_test.go b/p2p/net/swarm/dial_worker_test.go index 6679bd9587..9ea2180802 100644 --- a/p2p/net/swarm/dial_worker_test.go +++ b/p2p/net/swarm/dial_worker_test.go @@ -484,8 +484,8 @@ func TestDialQueueNextBatch(t *testing.T) { } } } - if q.len() != 0 { - t.Errorf("expected queue to be empty at end. got: %d", q.len()) + if q.Len() != 0 { + t.Errorf("expected queue to be empty at end. got: %d", q.Len()) } }) } From 473809b0d6480c48ead28a215170acaff05bdb6f Mon Sep 17 00:00:00 2001 From: sukun Date: Sun, 9 Jul 2023 19:32:28 +0530 Subject: [PATCH 2/4] swarm: move all connection post processing to worker loop --- p2p/net/swarm/dial_worker.go | 130 +++++++++++++++++++++--------- p2p/net/swarm/dial_worker_test.go | 9 ++- p2p/net/swarm/limiter.go | 21 +---- p2p/net/swarm/limiter_test.go | 33 ++++---- p2p/net/swarm/swarm_dial.go | 61 +++++++------- 5 files changed, 144 insertions(+), 110 deletions(-) diff --git a/p2p/net/swarm/dial_worker.go b/p2p/net/swarm/dial_worker.go index 52e19911bc..f3d6f160bd 100644 --- a/p2p/net/swarm/dial_worker.go +++ b/p2p/net/swarm/dial_worker.go @@ -2,10 +2,12 @@ package swarm import ( "context" + "fmt" "math" "sync" "time" + "github.com/libp2p/go-libp2p/core/canonicallog" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" @@ -62,6 +64,8 @@ type addrDial struct { createdAt time.Time // dialRankingDelay is the delay in dialing this address introduced by the ranking logic dialRankingDelay time.Duration + // startTime is the dialStartTime + startTime time.Time } // dialWorker synchronises concurrent dials to a peer. It ensures that we make at most one dial to a @@ -152,6 +156,17 @@ loop: if w.s.metricsTracer != nil { w.s.metricsTracer.DialCompleted(w.connected, totalDials) } + for dialsInFlight > 0 { + res := <-w.resch + // We're recording any error as a failure here. + // Notably, this also applies to cancelations (i.e. if another dial attempt was faster). + // This is ok since the black hole detector uses a very low threshold (5%). + w.s.bhd.RecordResult(res.Addr, res.Err == nil) + if res.Conn != nil { + res.Conn.Close() + } + dialsInFlight-- + } return } // We have received a new request. If we do not have a suitable connection, @@ -303,64 +318,101 @@ loop: // Update all requests waiting on this address. On success, complete the request. // On error, record the error - dialsInFlight-- ad, ok := w.trackedDials[string(res.Addr.Bytes())] if !ok { log.Errorf("SWARM BUG: no entry for address %s in trackedDials", res.Addr) if res.Conn != nil { res.Conn.Close() } + // It is better to decrement the dials in flight and schedule one extra dial + // than risking not closing the worker loop on cleanup + dialsInFlight-- + continue + } + + if res.Kind == DialStarted { + ad.startTime = w.cl.Now() + scheduleNextDial() continue } + dialsInFlight-- + // We're recording any error as a failure here. + // Notably, this also applies to cancelations (i.e. if another dial attempt was faster). + // This is ok since the black hole detector uses a very low threshold (5%). + w.s.bhd.RecordResult(ad.addr, res.Err == nil) + if res.Conn != nil { - // we got a connection, add it to the swarm - conn, err := w.s.addConn(res.Conn, network.DirOutbound) - if err != nil { - // oops no, we failed to add it to the swarm - res.Conn.Close() - w.dispatchError(ad, err) - continue loop - } + w.handleSuccess(ad, res) + } else { + w.handleError(ad, res) + } + scheduleNextDial() + } + } +} - for pr := range w.pendingRequests { - if _, ok := pr.addrs[string(ad.addr.Bytes())]; ok { - pr.req.resch <- dialResponse{conn: conn} - delete(w.pendingRequests, pr) - } - } +func (w *dialWorker) handleSuccess(ad *addrDial, res dialResult) { + // Ensure we connected to the correct peer. + // This was most likely already checked by the security protocol, but it doesn't hurt do it again here. + if res.Conn.RemotePeer() != w.peer { + res.Conn.Close() + tpt := w.s.TransportForDialing(res.Addr) + err := fmt.Errorf("BUG in transport %T: tried to dial %s, dialed %s", w.peer, res.Conn.RemotePeer(), tpt) + log.Error(err) + w.dispatchError(ad, err) + return + } - ad.conn = conn - if !w.connected { - w.connected = true - if w.s.metricsTracer != nil { - w.s.metricsTracer.DialRankingDelay(ad.dialRankingDelay) - } - } + canonicallog.LogPeerStatus(100, res.Conn.RemotePeer(), res.Conn.RemoteMultiaddr(), "connection_status", "established", "dir", "outbound") + if w.s.metricsTracer != nil { + connWithMetrics := wrapWithMetrics(res.Conn, w.s.metricsTracer, ad.startTime, network.DirOutbound) + connWithMetrics.completedHandshake() + res.Conn = connWithMetrics + } - continue loop - } + // we got a connection, add it to the swarm + conn, err := w.s.addConn(res.Conn, network.DirOutbound) + if err != nil { + // oops no, we failed to add it to the swarm + res.Conn.Close() + w.dispatchError(ad, err) + return + } + ad.conn = conn - // it must be an error -- add backoff if applicable and dispatch - // ErrDialRefusedBlackHole shouldn't end up here, just a safety check - if res.Err != ErrDialRefusedBlackHole && res.Err != context.Canceled && !w.connected { - // we only add backoff if there has not been a successful connection - // for consistency with the old dialer behavior. - w.s.backf.AddBackoff(w.peer, res.Addr) - } else if res.Err == ErrDialRefusedBlackHole { - log.Errorf("SWARM BUG: unexpected ErrDialRefusedBlackHole while dialing peer %s to addr %s", - w.peer, res.Addr) - } + for pr := range w.pendingRequests { + if _, ok := pr.addrs[string(ad.addr.Bytes())]; ok { + pr.req.resch <- dialResponse{conn: conn} + delete(w.pendingRequests, pr) + } + } - w.dispatchError(ad, res.Err) - // Only schedule next dial on error. - // If we scheduleNextDial on success, we will end up making one dial more than - // required because the final successful dial will spawn one more dial - scheduleNextDial() + if !w.connected { + w.connected = true + if w.s.metricsTracer != nil { + w.s.metricsTracer.DialRankingDelay(ad.dialRankingDelay) } } } +func (w *dialWorker) handleError(ad *addrDial, res dialResult) { + if res.Err != nil && w.s.metricsTracer != nil { + w.s.metricsTracer.FailedDialing(res.Addr, res.Err, context.Cause(ad.ctx)) + } + // it must be an error -- add backoff if applicable and dispatch + // ErrDialRefusedBlackHole shouldn't end up here, just a safety check + if res.Err != ErrDialRefusedBlackHole && res.Err != context.Canceled && !w.connected { + // we only add backoff if there has not been a successful connection + // for consistency with the old dialer behavior. + w.s.backf.AddBackoff(w.peer, res.Addr) + } else if res.Err == ErrDialRefusedBlackHole { + log.Errorf("SWARM BUG: unexpected ErrDialRefusedBlackHole while dialing peer %s to addr %s", + w.peer, res.Addr) + } + w.dispatchError(ad, res.Err) +} + // dispatches an error to a specific addr dial func (w *dialWorker) dispatchError(ad *addrDial, err error) { ad.err = err diff --git a/p2p/net/swarm/dial_worker_test.go b/p2p/net/swarm/dial_worker_test.go index 9ea2180802..427ecc0025 100644 --- a/p2p/net/swarm/dial_worker_test.go +++ b/p2p/net/swarm/dial_worker_test.go @@ -623,9 +623,11 @@ func checkDialWorkerLoopScheduling(t *testing.T, s1, s2 *Swarm, tc schedulingTes go worker1.loop() defer worker1.wg.Wait() defer close(reqch) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() // trigger the request - reqch <- dialRequest{ctx: context.Background(), resch: resch} + reqch <- dialRequest{ctx: ctx, resch: resch} connected := false @@ -989,8 +991,9 @@ func TestDialWorkerLoopHolePunching(t *testing.T) { go worker.loop() defer worker.wg.Wait() defer close(reqch) - - reqch <- dialRequest{ctx: context.Background(), resch: resch} + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + reqch <- dialRequest{ctx: ctx, resch: resch} <-recvCh // received connection on t1 select { diff --git a/p2p/net/swarm/limiter.go b/p2p/net/swarm/limiter.go index ccfe7d2371..0866788589 100644 --- a/p2p/net/swarm/limiter.go +++ b/p2p/net/swarm/limiter.go @@ -8,17 +8,10 @@ import ( "time" "github.com/libp2p/go-libp2p/core/peer" - "github.com/libp2p/go-libp2p/core/transport" ma "github.com/multiformats/go-multiaddr" ) -type dialResult struct { - Conn transport.CapableConn - Addr ma.Multiaddr - Err error -} - type dialJob struct { addr ma.Multiaddr peer peer.ID @@ -45,7 +38,7 @@ type dialLimiter struct { waitingOnPeerLimit map[peer.ID][]*dialJob } -type dialfunc func(context.Context, peer.ID, ma.Multiaddr) (transport.CapableConn, error) +type dialfunc func(context.Context, peer.ID, ma.Multiaddr, chan dialResult) func newDialLimiter(df dialfunc) *dialLimiter { fd := ConcurrentFdDials @@ -209,19 +202,9 @@ func (dl *dialLimiter) clearAllPeerDials(p peer.ID) { // it held during the dial. func (dl *dialLimiter) executeDial(j *dialJob) { defer dl.finishedDial(j) - if j.cancelled() { - return - } dctx, cancel := context.WithTimeout(j.ctx, j.timeout) defer cancel() - con, err := dl.dialFunc(dctx, j.peer, j.addr) - select { - case j.resp <- dialResult{Conn: con, Addr: j.addr, Err: err}: - case <-j.ctx.Done(): - if con != nil { - con.Close() - } - } + dl.dialFunc(dctx, j.peer, j.addr, j.resp) } diff --git a/p2p/net/swarm/limiter_test.go b/p2p/net/swarm/limiter_test.go index 046b4c754e..52fd40f9dd 100644 --- a/p2p/net/swarm/limiter_test.go +++ b/p2p/net/swarm/limiter_test.go @@ -12,7 +12,6 @@ import ( "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/test" - "github.com/libp2p/go-libp2p/core/transport" ma "github.com/multiformats/go-multiaddr" mafmt "github.com/multiformats/go-multiaddr-fmt" @@ -51,22 +50,25 @@ func tryDialAddrs(ctx context.Context, l *dialLimiter, p peer.ID, addrs []ma.Mul } func hangDialFunc(hang chan struct{}) dialfunc { - return func(ctx context.Context, p peer.ID, a ma.Multiaddr) (transport.CapableConn, error) { + return func(ctx context.Context, p peer.ID, a ma.Multiaddr, resCh chan dialResult) { if mafmt.UTP.Matches(a) { - return transport.CapableConn(nil), nil + resCh <- dialResult{} + return } _, err := a.ValueForProtocol(ma.P_CIRCUIT) if err == nil { - return transport.CapableConn(nil), nil + resCh <- dialResult{} + return } if tcpPortOver(a, 10) { - return transport.CapableConn(nil), nil + resCh <- dialResult{} + return } <-hang - return nil, fmt.Errorf("test bad dial") + resCh <- dialResult{Err: errors.New("bad dial")} } } @@ -188,16 +190,17 @@ func TestFDLimiting(t *testing.T) { func TestTokenRedistribution(t *testing.T) { var lk sync.Mutex hangchs := make(map[peer.ID]chan struct{}) - df := func(ctx context.Context, p peer.ID, a ma.Multiaddr) (transport.CapableConn, error) { + df := func(ctx context.Context, p peer.ID, a ma.Multiaddr, resCh chan dialResult) { if tcpPortOver(a, 10) { - return (transport.CapableConn)(nil), nil + resCh <- dialResult{} + return } lk.Lock() ch := hangchs[p] lk.Unlock() <-ch - return nil, fmt.Errorf("test bad dial") + resCh <- dialResult{Err: fmt.Errorf("test bad dial")} } l := newDialLimiterWithParams(df, 8, 4) @@ -281,13 +284,14 @@ func TestTokenRedistribution(t *testing.T) { } func TestStressLimiter(t *testing.T) { - df := func(ctx context.Context, p peer.ID, a ma.Multiaddr) (transport.CapableConn, error) { + df := func(ctx context.Context, p peer.ID, a ma.Multiaddr, resCh chan dialResult) { if tcpPortOver(a, 1000) { - return transport.CapableConn(nil), nil + resCh <- dialResult{} + return } time.Sleep(time.Millisecond * time.Duration(5+rand.Intn(100))) - return nil, fmt.Errorf("test bad dial") + resCh <- dialResult{Err: fmt.Errorf("test bad dial")} } l := newDialLimiterWithParams(df, 20, 5) @@ -319,7 +323,6 @@ func TestStressLimiter(t *testing.T) { for res := range resp { if res.Err == nil { success <- struct{}{} - return } } }(peer.ID(fmt.Sprintf("testpeer%d", i))) @@ -335,12 +338,12 @@ func TestStressLimiter(t *testing.T) { } func TestFDLimitUnderflow(t *testing.T) { - df := func(ctx context.Context, p peer.ID, addr ma.Multiaddr) (transport.CapableConn, error) { + df := func(ctx context.Context, p peer.ID, addr ma.Multiaddr, resCh chan dialResult) { select { case <-ctx.Done(): case <-time.After(5 * time.Second): } - return nil, fmt.Errorf("df timed out") + resCh <- dialResult{Err: fmt.Errorf("df timed out")} } const fdLimit = 20 diff --git a/p2p/net/swarm/swarm_dial.go b/p2p/net/swarm/swarm_dial.go index aa6e077f4b..056331718b 100644 --- a/p2p/net/swarm/swarm_dial.go +++ b/p2p/net/swarm/swarm_dial.go @@ -9,7 +9,6 @@ import ( "sync" "time" - "github.com/libp2p/go-libp2p/core/canonicallog" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/peerstore" @@ -511,55 +510,49 @@ func (s *Swarm) limitedDial(ctx context.Context, p peer.ID, a ma.Multiaddr, resp }) } +type dialResult struct { + Kind dialUpdateKind + Conn transport.CapableConn + Addr ma.Multiaddr + Err error +} + +type dialUpdateKind int + +const ( + DialStarted dialUpdateKind = iota + DialFailed + DialSuccessful +) + // dialAddr is the actual dial for an addr, indirectly invoked through the limiter -func (s *Swarm) dialAddr(ctx context.Context, p peer.ID, addr ma.Multiaddr) (transport.CapableConn, error) { +func (s *Swarm) dialAddr(ctx context.Context, p peer.ID, addr ma.Multiaddr, resCh chan dialResult) { // Just to double check. Costs nothing. if s.local == p { - return nil, ErrDialToSelf + resCh <- dialResult{Kind: DialFailed, Addr: addr, Err: ErrDialToSelf} + return } // Check before we start work if err := ctx.Err(); err != nil { log.Debugf("%s swarm not dialing. Context cancelled: %v. %s %s", s.local, err, p, addr) - return nil, err + resCh <- dialResult{Kind: DialFailed, Addr: addr, Err: err} + return } log.Debugf("%s swarm dialing %s %s", s.local, p, addr) tpt := s.TransportForDialing(addr) if tpt == nil { - return nil, ErrNoTransport + resCh <- dialResult{Kind: DialFailed, Addr: addr, Err: ErrNoTransport} + return } - start := time.Now() - connC, err := tpt.Dial(ctx, addr, p) - - // We're recording any error as a failure here. - // Notably, this also applies to cancelations (i.e. if another dial attempt was faster). - // This is ok since the black hole detector uses a very low threshold (5%). - s.bhd.RecordResult(addr, err == nil) - + resCh <- dialResult{Kind: DialStarted, Addr: addr} + conn, err := tpt.Dial(ctx, addr, p) if err != nil { - if s.metricsTracer != nil { - s.metricsTracer.FailedDialing(addr, err, context.Cause(ctx)) - } - return nil, err - } - canonicallog.LogPeerStatus(100, connC.RemotePeer(), connC.RemoteMultiaddr(), "connection_status", "established", "dir", "outbound") - if s.metricsTracer != nil { - connWithMetrics := wrapWithMetrics(connC, s.metricsTracer, start, network.DirOutbound) - connWithMetrics.completedHandshake() - connC = connWithMetrics - } - - // Trust the transport? Yeah... right. - if connC.RemotePeer() != p { - connC.Close() - err = fmt.Errorf("BUG in transport %T: tried to dial %s, dialed %s", p, connC.RemotePeer(), tpt) - log.Error(err) - return nil, err + resCh <- dialResult{Kind: DialFailed, Addr: addr, Conn: nil, Err: err} + return } - - // success! we got one! - return connC, nil + resCh <- dialResult{Kind: DialSuccessful, Addr: addr, Conn: conn, Err: err} } // TODO We should have a `IsFdConsuming() bool` method on the `Transport` interface in go-libp2p/core/transport. From 2dea4848d829411db6cbb6c0344b6f772a28f6b0 Mon Sep 17 00:00:00 2001 From: sukun Date: Mon, 7 Aug 2023 19:14:27 +0530 Subject: [PATCH 3/4] swarm: extract method to add new request in dial worker loop --- p2p/net/swarm/dial_worker.go | 247 ++++++++++++++++------------------- 1 file changed, 115 insertions(+), 132 deletions(-) diff --git a/p2p/net/swarm/dial_worker.go b/p2p/net/swarm/dial_worker.go index f3d6f160bd..33d5a30b4e 100644 --- a/p2p/net/swarm/dial_worker.go +++ b/p2p/net/swarm/dial_worker.go @@ -82,8 +82,14 @@ type dialWorker struct { trackedDials map[string]*addrDial // resch is used to receive response for dials to the peers addresses. resch chan dialResult - - connected bool // true when a connection has been successfully established + // connected is true when a connection has been successfully established + connected bool + // dq is used to pace dials to different addresses of the peer + dq *dialQueue + // dialsInFlight are the addresses with dials pending completion. + dialsInFlight int + // totalDials is used to track number of dials made by this worker for metrics + totalDials int // for testing wg sync.WaitGroup @@ -111,12 +117,9 @@ func (w *dialWorker) loop() { w.wg.Add(1) defer w.wg.Done() defer w.s.limiter.clearAllPeerDials(w.peer) + defer w.cleanup() - // dq is used to pace dials to different addresses of the peer - dq := newDialQueue() - // dialsInFlight is the number of dials in flight. - dialsInFlight := 0 - + w.dq = newDialQueue() startTime := w.cl.Now() // dialTimer is the dialTimer used to trigger dials dialTimer := w.cl.InstantTimer(startTime.Add(math.MaxInt64)) @@ -127,24 +130,22 @@ func (w *dialWorker) loop() { <-dialTimer.Ch() } timerRunning = false - if dq.Len() > 0 { - if dialsInFlight == 0 && !w.connected { + if w.dq.Len() > 0 { + if w.dialsInFlight == 0 && !w.connected { // if there are no dials in flight, trigger the next dials immediately dialTimer.Reset(startTime) } else { - dialTimer.Reset(startTime.Add(dq.top().Delay)) + dialTimer.Reset(startTime.Add(w.dq.top().Delay)) } timerRunning = true } } - // totalDials is used to track number of dials made by this worker for metrics - totalDials := 0 loop: for { // The loop has three parts // 1. Input requests are received on w.reqch. If a suitable connection is not available we create - // a pendRequest object to track the dialRequest and add the addresses to dq. + // a pendRequest object to track the dialRequest and add the addresses to w.dq. // 2. Addresses from the dialQueue are dialed at appropriate time intervals depending on delay logic. // We are notified of the completion of these dials on w.resch. // 3. Responses for dials are received on w.resch. On receiving a response, we updated the pendRequests @@ -153,20 +154,6 @@ loop: select { case req, ok := <-w.reqch: if !ok { - if w.s.metricsTracer != nil { - w.s.metricsTracer.DialCompleted(w.connected, totalDials) - } - for dialsInFlight > 0 { - res := <-w.resch - // We're recording any error as a failure here. - // Notably, this also applies to cancelations (i.e. if another dial attempt was faster). - // This is ok since the black hole detector uses a very low threshold (5%). - w.s.bhd.RecordResult(res.Addr, res.Err == nil) - if res.Conn != nil { - res.Conn.Close() - } - dialsInFlight-- - } return } // We have received a new request. If we do not have a suitable connection, @@ -182,104 +169,10 @@ loop: addrs, addrErrs, err := w.s.addrsForDial(req.ctx, w.peer) if err != nil { - req.resch <- dialResponse{ - err: &DialError{ - Peer: w.peer, - DialErrors: addrErrs, - Cause: err, - }} - continue loop - } - - // get the delays to dial these addrs from the swarms dialRanker - simConnect, _, _ := network.GetSimultaneousConnect(req.ctx) - addrRanking := w.rankAddrs(addrs, simConnect) - addrDelay := make(map[string]time.Duration, len(addrRanking)) - - // create the pending request object - pr := &pendRequest{ - req: req, - addrs: make(map[string]struct{}, len(addrRanking)), - err: &DialError{Peer: w.peer, DialErrors: addrErrs}, - } - for _, adelay := range addrRanking { - pr.addrs[string(adelay.Addr.Bytes())] = struct{}{} - addrDelay[string(adelay.Addr.Bytes())] = adelay.Delay - } - - // Check if dials to any of the addrs have completed already - // If they have errored, record the error in pr. If they have succeeded, - // respond with the connection. - // If they are pending, add them to tojoin. - // If we haven't seen any of the addresses before, add them to todial. - var todial []ma.Multiaddr - var tojoin []*addrDial - - for _, adelay := range addrRanking { - ad, ok := w.trackedDials[string(adelay.Addr.Bytes())] - if !ok { - todial = append(todial, adelay.Addr) - continue - } - - if ad.conn != nil { - // dial to this addr was successful, complete the request - req.resch <- dialResponse{conn: ad.conn} - continue loop - } - - if ad.err != nil { - // dial to this addr errored, accumulate the error - pr.err.recordErr(ad.addr, ad.err) - delete(pr.addrs, string(ad.addr.Bytes())) - continue - } - - // dial is still pending, add to the join list - tojoin = append(tojoin, ad) - } - - if len(todial) == 0 && len(tojoin) == 0 { - // all request applicable addrs have been dialed, we must have errored - pr.err.Cause = ErrAllDialsFailed - req.resch <- dialResponse{err: pr.err} + req.resch <- dialResponse{err: &DialError{Peer: w.peer, DialErrors: addrErrs, Cause: err}} continue loop } - - // The request has some pending or new dials - w.pendingRequests[pr] = struct{}{} - - for _, ad := range tojoin { - if !ad.dialed { - // we haven't dialed this address. update the ad.ctx to have simultaneous connect values - // set correctly - if simConnect, isClient, reason := network.GetSimultaneousConnect(req.ctx); simConnect { - if simConnect, _, _ := network.GetSimultaneousConnect(ad.ctx); !simConnect { - ad.ctx = network.WithSimultaneousConnect(ad.ctx, isClient, reason) - // update the element in dq to use the simultaneous connect delay. - dq.Add(network.AddrDelay{ - Addr: ad.addr, - Delay: addrDelay[string(ad.addr.Bytes())], - }) - } - } - } - // add the request to the addrDial - } - - if len(todial) > 0 { - now := time.Now() - // these are new addresses, track them and add them to dq - for _, a := range todial { - w.trackedDials[string(a.Bytes())] = &addrDial{ - addr: a, - ctx: req.ctx, - createdAt: now, - } - dq.Add(network.AddrDelay{Addr: a, Delay: addrDelay[string(a.Bytes())]}) - } - } - // setup dialTimer for updates to dq + w.addNewRequest(req, addrs, addrErrs) scheduleNextDial() case <-dialTimer.Ch(): @@ -289,7 +182,7 @@ loop: // the inflight dials have errored and we should dial the next batch of // addresses now := time.Now() - for _, adelay := range dq.NextBatch() { + for _, adelay := range w.dq.NextBatch() { // spawn the dial ad, ok := w.trackedDials[string(adelay.Addr.Bytes())] if !ok { @@ -300,13 +193,11 @@ loop: ad.dialRankingDelay = now.Sub(ad.createdAt) err := w.s.dialNextAddr(ad.ctx, w.peer, ad.addr, w.resch) if err != nil { - // Errored without attempting a dial. This happens in case of - // backoff or black hole. + // Errored without attempting a dial. This happens in case of backoff. w.dispatchError(ad, err) } else { - // the dial was successful. update inflight dials - dialsInFlight++ - totalDials++ + w.dialsInFlight++ + w.totalDials++ } } timerRunning = false @@ -326,7 +217,7 @@ loop: } // It is better to decrement the dials in flight and schedule one extra dial // than risking not closing the worker loop on cleanup - dialsInFlight-- + w.dialsInFlight-- continue } @@ -336,7 +227,7 @@ loop: continue } - dialsInFlight-- + w.dialsInFlight-- // We're recording any error as a failure here. // Notably, this also applies to cancelations (i.e. if another dial attempt was faster). // This is ok since the black hole detector uses a very low threshold (5%). @@ -352,6 +243,80 @@ loop: } } +// addNewRequest adds a new dial request to the worker loop. If the request has no pending dials, a response +// is sent immediately otherwise it is tracked in pendingRequests +func (w *dialWorker) addNewRequest(req dialRequest, addrs []ma.Multiaddr, addrErrs []TransportError) { + // check if a dial to any of the addrs has succeeded already + for _, addr := range addrs { + if ad, ok := w.trackedDials[string(addr.Bytes())]; ok { + if ad.conn != nil { + // dial to this addr was successful, complete the request + req.resch <- dialResponse{conn: ad.conn} + } + } + } + + // get the delays to dial these addrs from the swarms dialRanker + simConnect, _, _ := network.GetSimultaneousConnect(req.ctx) + addrRanking := w.rankAddrs(addrs, simConnect) + + // create the pending request object + pr := &pendRequest{ + req: req, + err: &DialError{Peer: w.peer, DialErrors: addrErrs}, + addrs: make(map[string]struct{}, len(addrRanking)), + } + for _, adelay := range addrRanking { + pr.addrs[string(adelay.Addr.Bytes())] = struct{}{} + } + + for _, adelay := range addrRanking { + ad, ok := w.trackedDials[string(adelay.Addr.Bytes())] + if !ok { + // new address, track and enqueue + now := time.Now() + w.trackedDials[string(adelay.Addr.Bytes())] = &addrDial{ + addr: adelay.Addr, + ctx: req.ctx, + createdAt: now, + } + w.dq.Add(network.AddrDelay{Addr: adelay.Addr, Delay: adelay.Delay}) + continue + } + + if ad.err != nil { + // dial to this addr errored, accumulate the error + pr.err.recordErr(ad.addr, ad.err) + delete(pr.addrs, string(ad.addr.Bytes())) + continue + } + + if !ad.dialed { + // we haven't dialed this address. update the ad.ctx to have simultaneous connect values + // set correctly + if isSimConnect, isClient, reason := network.GetSimultaneousConnect(req.ctx); isSimConnect { + if wasSimConnect, _, _ := network.GetSimultaneousConnect(ad.ctx); !wasSimConnect { + ad.ctx = network.WithSimultaneousConnect(ad.ctx, isClient, reason) + // update the element in dq to use the simultaneous connect delay. + w.dq.Add(network.AddrDelay{ + Addr: ad.addr, + Delay: adelay.Delay, + }) + } + } + } + } + + if len(pr.addrs) == 0 { + // all request applicable addrs have been dialed, we must have errored + pr.err.Cause = ErrAllDialsFailed + req.resch <- dialResponse{err: pr.err} + } else { + // The request has some pending or new dials + w.pendingRequests[pr] = struct{}{} + } +} + func (w *dialWorker) handleSuccess(ad *addrDial, res dialResult) { // Ensure we connected to the correct peer. // This was most likely already checked by the security protocol, but it doesn't hurt do it again here. @@ -400,7 +365,7 @@ func (w *dialWorker) handleError(ad *addrDial, res dialResult) { if res.Err != nil && w.s.metricsTracer != nil { w.s.metricsTracer.FailedDialing(res.Addr, res.Err, context.Cause(ad.ctx)) } - // it must be an error -- add backoff if applicable and dispatch + // add backoff if applicable and dispatch // ErrDialRefusedBlackHole shouldn't end up here, just a safety check if res.Err != ErrDialRefusedBlackHole && res.Err != context.Canceled && !w.connected { // we only add backoff if there has not been a successful connection @@ -455,6 +420,24 @@ func (w *dialWorker) rankAddrs(addrs []ma.Multiaddr, isSimConnect bool) []networ return w.s.dialRanker(addrs) } +// cleanup is called on workerloop close +func (w *dialWorker) cleanup() { + if w.s.metricsTracer != nil { + w.s.metricsTracer.DialCompleted(w.connected, w.totalDials) + } + for w.dialsInFlight > 0 { + res := <-w.resch + // We're recording any error as a failure here. + // Notably, this also applies to cancelations (i.e. if another dial attempt was faster). + // This is ok since the black hole detector uses a very low threshold (5%). + w.s.bhd.RecordResult(res.Addr, res.Err == nil) + if res.Conn != nil { + res.Conn.Close() + } + w.dialsInFlight-- + } +} + // dialQueue is a priority queue used to schedule dials type dialQueue struct { // q contains dials ordered by delay From b1c30cfe098dcf3461491fe319e2ea09a50dae25 Mon Sep 17 00:00:00 2001 From: sukun Date: Mon, 7 Aug 2023 20:00:06 +0530 Subject: [PATCH 4/4] use a map to accurately track dials in flight --- p2p/net/swarm/dial_worker.go | 64 ++++++++++++++++++------------------ 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/p2p/net/swarm/dial_worker.go b/p2p/net/swarm/dial_worker.go index 33d5a30b4e..1afe3ddc9c 100644 --- a/p2p/net/swarm/dial_worker.go +++ b/p2p/net/swarm/dial_worker.go @@ -86,8 +86,9 @@ type dialWorker struct { connected bool // dq is used to pace dials to different addresses of the peer dq *dialQueue - // dialsInFlight are the addresses with dials pending completion. - dialsInFlight int + // dialsInFlight are the addresses with dials pending completion. We use this to schedule new dials + // and to cleanup all pending dials when closing the loop + dialsInFlight map[string]bool // totalDials is used to track number of dials made by this worker for metrics totalDials int @@ -107,6 +108,7 @@ func newDialWorker(s *Swarm, p peer.ID, reqch <-chan dialRequest, cl Clock) *dia pendingRequests: make(map[*pendRequest]struct{}), trackedDials: make(map[string]*addrDial), resch: make(chan dialResult), + dialsInFlight: make(map[string]bool), cl: cl, } } @@ -131,7 +133,7 @@ func (w *dialWorker) loop() { } timerRunning = false if w.dq.Len() > 0 { - if w.dialsInFlight == 0 && !w.connected { + if len(w.dialsInFlight) == 0 && !w.connected { // if there are no dials in flight, trigger the next dials immediately dialTimer.Reset(startTime) } else { @@ -156,23 +158,14 @@ loop: if !ok { return } - // We have received a new request. If we do not have a suitable connection, - // track this dialRequest with a pendRequest. - // Enqueue the peer's addresses relevant to this request in dq and - // track dials to the addresses relevant to this request. - + // Check if we have a suitable connection already c, err := w.s.bestAcceptableConnToPeer(req.ctx, w.peer) if c != nil || err != nil { req.resch <- dialResponse{conn: c, err: err} continue loop } - - addrs, addrErrs, err := w.s.addrsForDial(req.ctx, w.peer) - if err != nil { - req.resch <- dialResponse{err: &DialError{Peer: w.peer, DialErrors: addrErrs, Cause: err}} - continue loop - } - w.addNewRequest(req, addrs, addrErrs) + // we don't have any suitable connection, add the request to the worker loop + w.addNewRequest(req) scheduleNextDial() case <-dialTimer.Ch(): @@ -196,7 +189,7 @@ loop: // Errored without attempting a dial. This happens in case of backoff. w.dispatchError(ad, err) } else { - w.dialsInFlight++ + w.dialsInFlight[string(ad.addr.Bytes())] = true w.totalDials++ } } @@ -215,9 +208,7 @@ loop: if res.Conn != nil { res.Conn.Close() } - // It is better to decrement the dials in flight and schedule one extra dial - // than risking not closing the worker loop on cleanup - w.dialsInFlight-- + delete(w.dialsInFlight, string(res.Addr.Bytes())) continue } @@ -227,7 +218,7 @@ loop: continue } - w.dialsInFlight-- + delete(w.dialsInFlight, string(res.Addr.Bytes())) // We're recording any error as a failure here. // Notably, this also applies to cancelations (i.e. if another dial attempt was faster). // This is ok since the black hole detector uses a very low threshold (5%). @@ -243,24 +234,29 @@ loop: } } -// addNewRequest adds a new dial request to the worker loop. If the request has no pending dials, a response -// is sent immediately otherwise it is tracked in pendingRequests -func (w *dialWorker) addNewRequest(req dialRequest, addrs []ma.Multiaddr, addrErrs []TransportError) { - // check if a dial to any of the addrs has succeeded already +// addNewRequest adds a new dial request to the worker loop. If the request has a valid connection or all relevant +// dials have failed, the request is handled immediately, otherwise it is added to pendingRequests. +func (w *dialWorker) addNewRequest(req dialRequest) { + addrs, addrErrs, err := w.s.addrsForDial(req.ctx, w.peer) + if err != nil { + req.resch <- dialResponse{err: &DialError{Peer: w.peer, DialErrors: addrErrs, Cause: err}} + return + } + + // check if a dial to any of the relevant address has succeeded already for _, addr := range addrs { if ad, ok := w.trackedDials[string(addr.Bytes())]; ok { if ad.conn != nil { - // dial to this addr was successful, complete the request req.resch <- dialResponse{conn: ad.conn} + return } } } - // get the delays to dial these addrs from the swarms dialRanker + // no dial has succeeded, get the delays to dial the addrs simConnect, _, _ := network.GetSimultaneousConnect(req.ctx) addrRanking := w.rankAddrs(addrs, simConnect) - // create the pending request object pr := &pendRequest{ req: req, err: &DialError{Peer: w.peer, DialErrors: addrErrs}, @@ -273,7 +269,7 @@ func (w *dialWorker) addNewRequest(req dialRequest, addrs []ma.Multiaddr, addrEr for _, adelay := range addrRanking { ad, ok := w.trackedDials[string(adelay.Addr.Bytes())] if !ok { - // new address, track and enqueue + // new address, track and enqueue for dialing now := time.Now() w.trackedDials[string(adelay.Addr.Bytes())] = &addrDial{ addr: adelay.Addr, @@ -292,8 +288,9 @@ func (w *dialWorker) addNewRequest(req dialRequest, addrs []ma.Multiaddr, addrEr } if !ad.dialed { - // we haven't dialed this address. update the ad.ctx to have simultaneous connect values - // set correctly + // We are tracking a dial to this address but we haven't dialled it already. + // If the new request is a holepunching request, update the context and the element in the + // dial queue if isSimConnect, isClient, reason := network.GetSimultaneousConnect(req.ctx); isSimConnect { if wasSimConnect, _, _ := network.GetSimultaneousConnect(ad.ctx); !wasSimConnect { ad.ctx = network.WithSimultaneousConnect(ad.ctx, isClient, reason) @@ -425,8 +422,11 @@ func (w *dialWorker) cleanup() { if w.s.metricsTracer != nil { w.s.metricsTracer.DialCompleted(w.connected, w.totalDials) } - for w.dialsInFlight > 0 { + for len(w.dialsInFlight) > 0 { res := <-w.resch + if res.Kind != DialFailed && res.Kind != DialSuccessful { + continue + } // We're recording any error as a failure here. // Notably, this also applies to cancelations (i.e. if another dial attempt was faster). // This is ok since the black hole detector uses a very low threshold (5%). @@ -434,7 +434,7 @@ func (w *dialWorker) cleanup() { if res.Conn != nil { res.Conn.Close() } - w.dialsInFlight-- + delete(w.dialsInFlight, string(res.Addr.Bytes())) } }