Skip to content

Commit

Permalink
swarm: change maps with multiaddress keys to use strings
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed May 11, 2023
1 parent f7a45b6 commit 417b8cc
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 14 deletions.
28 changes: 14 additions & 14 deletions p2p/net/swarm/dial_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ type dialResponse struct {
}

type pendRequest struct {
req dialRequest // the original request
err *DialError // dial error accumulator
addrs map[ma.Multiaddr]struct{} // pending addr dials
req dialRequest // the original request
err *DialError // dial error accumulator
addrs map[string]struct{} // pending address to dial. The key is a multiaddr
}

type addrDial struct {
Expand All @@ -46,7 +46,7 @@ type dialWorker struct {
reqch <-chan dialRequest
reqno int
requests map[int]*pendRequest
pending map[ma.Multiaddr]*addrDial
pending map[string]*addrDial // pending addresses to dial. The key is a multiaddr
resch chan dialResult

connected bool // true when a connection has been successfully established
Expand All @@ -66,7 +66,7 @@ func newDialWorker(s *Swarm, p peer.ID, reqch <-chan dialRequest) *dialWorker {
peer: p,
reqch: reqch,
requests: make(map[int]*pendRequest),
pending: make(map[ma.Multiaddr]*addrDial),
pending: make(map[string]*addrDial),
resch: make(chan dialResult),
}
}
Expand Down Expand Up @@ -108,10 +108,10 @@ loop:
pr := &pendRequest{
req: req,
err: &DialError{Peer: w.peer},
addrs: make(map[ma.Multiaddr]struct{}),
addrs: make(map[string]struct{}),
}
for _, a := range addrs {
pr.addrs[a] = struct{}{}
pr.addrs[string(a.Bytes())] = struct{}{}
}

// check if any of the addrs has been successfully dialed and accumulate
Expand All @@ -120,7 +120,7 @@ loop:
var tojoin []*addrDial

for _, a := range addrs {
ad, ok := w.pending[a]
ad, ok := w.pending[string(a.Bytes())]
if !ok {
todial = append(todial, a)
continue
Expand All @@ -135,7 +135,7 @@ loop:
if ad.err != nil {
// dial to this addr errored, accumulate the error
pr.err.recordErr(a, ad.err)
delete(pr.addrs, a)
delete(pr.addrs, string(a.Bytes()))
continue
}

Expand Down Expand Up @@ -164,7 +164,7 @@ loop:

if len(todial) > 0 {
for _, a := range todial {
w.pending[a] = &addrDial{addr: a, ctx: req.ctx, requests: []int{w.reqno}}
w.pending[string(a.Bytes())] = &addrDial{addr: a, ctx: req.ctx, requests: []int{w.reqno}}
}

w.nextDial = append(w.nextDial, todial...)
Expand All @@ -177,7 +177,7 @@ loop:
case <-w.triggerDial:
for _, addr := range w.nextDial {
// spawn the dial
ad := w.pending[addr]
ad := w.pending[string(addr.Bytes())]
err := w.s.dialNextAddr(ad.ctx, w.peer, addr, w.resch)
if err != nil {
w.dispatchError(ad, err)
Expand All @@ -192,7 +192,7 @@ loop:
w.connected = true
}

ad := w.pending[res.Addr]
ad := w.pending[string(res.Addr.Bytes())]

if res.Conn != nil {
// we got a connection, add it to the swarm
Expand Down Expand Up @@ -247,7 +247,7 @@ func (w *dialWorker) dispatchError(ad *addrDial, err error) {
// accumulate the error
pr.err.recordErr(ad.addr, err)

delete(pr.addrs, ad.addr)
delete(pr.addrs, string(ad.addr.Bytes()))
if len(pr.addrs) == 0 {
// all addrs have erred, dispatch dial error
// but first do a last one check in case an acceptable connection has landed from
Expand All @@ -271,7 +271,7 @@ func (w *dialWorker) dispatchError(ad *addrDial, err error) {
// it is also necessary to preserve consisent behaviour with the old dialer -- TestDialBackoff
// regresses without this.
if err == ErrDialBackoff {
delete(w.pending, ad.addr)
delete(w.pending, string(ad.addr.Bytes()))
}
}

Expand Down
68 changes: 68 additions & 0 deletions p2p/net/swarm/dial_worker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"

Expand All @@ -24,6 +25,7 @@ import (
"github.com/libp2p/go-libp2p/p2p/transport/tcp"

ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -342,3 +344,69 @@ func TestDialWorkerLoopConcurrentFailureStress(t *testing.T) {
close(reqch)
worker.wg.Wait()
}

func TestDialWorkerLoopAddrDedup(t *testing.T) {
s1 := makeSwarm(t)
s2 := makeSwarm(t)
defer s1.Close()
defer s2.Close()

t1 := ma.StringCast(fmt.Sprintf("/ip4/127.0.0.1/tcp/%d", 10000))
t2 := ma.StringCast(fmt.Sprintf("/ip4/127.0.0.1/tcp/%d", 10000))

// connCount counts the number of connection attempts made
var connCount atomic.Int64

// acceptAndClose accepts two tcp connections and closes them
// we need to wait for the second connection before failing because otherwise the
// address would be placed on backoff
acceptAndClose := func(a ma.Multiaddr, closech chan struct{}) {
list, err := manet.Listen(a)
if err != nil {
t.Error(err)
return
}
go func() {
conn1, err := list.Accept()
if err != nil {
return
}
connCount.Add(1)
conn2, err := list.Accept()
if err != nil {
conn1.Close()
return
}
connCount.Add(2)
conn1.Close()
conn2.Close()
}()
<-closech
list.Close()
}
closeCh := make(chan struct{})
go acceptAndClose(t1, closeCh)
defer close(closeCh)

s1.Peerstore().AddAddrs(s2.LocalPeer(), []ma.Multiaddr{t1}, peerstore.PermanentAddrTTL)

reqch := make(chan dialRequest)
resch := make(chan dialResponse, 2)

worker := newDialWorker(s1, s2.LocalPeer(), reqch)
go worker.loop()
defer worker.wg.Wait()
defer close(reqch)

reqch <- dialRequest{ctx: context.Background(), resch: resch}

s1.Peerstore().ClearAddrs(s2.LocalPeer())
s1.Peerstore().AddAddrs(s2.LocalPeer(), []ma.Multiaddr{t2}, peerstore.PermanentAddrTTL)

reqch <- dialRequest{ctx: context.Background(), resch: resch}
require.Never(t, func() bool { return connCount.Load() > 1 }, 3*time.Second, 100*time.Millisecond)
if connCount.Load() != 1 {
t.Errorf("did expect one connection. got 0")
}

}

0 comments on commit 417b8cc

Please sign in to comment.