From bbd28365c5d7d86b259403a379fd34a8279e4b4b Mon Sep 17 00:00:00 2001 From: Steven Allen Date: Fri, 19 Aug 2022 00:33:22 -0700 Subject: [PATCH] fix: don't prefer local ports from other addresses when dialing (#1673) * fix: don't prefer local ports from other addresses when dialing This address may already be in-use (on that other address) somewhere else. Thanks to @schomatis for figuring this out. fixes #1611 * chore: document reuseport dialer logic --- p2p/net/reuseport/dial.go | 57 +------------- p2p/net/reuseport/dialer.go | 114 ++++++++++++++++++++++++++++ p2p/net/reuseport/multidialer.go | 90 ---------------------- p2p/net/reuseport/singledialer.go | 16 ---- p2p/net/reuseport/transport.go | 2 +- p2p/net/reuseport/transport_test.go | 8 -- 6 files changed, 118 insertions(+), 169 deletions(-) create mode 100644 p2p/net/reuseport/dialer.go delete mode 100644 p2p/net/reuseport/multidialer.go delete mode 100644 p2p/net/reuseport/singledialer.go diff --git a/p2p/net/reuseport/dial.go b/p2p/net/reuseport/dial.go index b998be7d29..6a3d18ff21 100644 --- a/p2p/net/reuseport/dial.go +++ b/p2p/net/reuseport/dial.go @@ -2,18 +2,11 @@ package reuseport import ( "context" - "net" - "github.com/libp2p/go-reuseport" ma "github.com/multiformats/go-multiaddr" manet "github.com/multiformats/go-multiaddr/net" ) -type dialer interface { - Dial(network, addr string) (net.Conn, error) - DialContext(ctx context.Context, network, addr string) (net.Conn, error) -} - // Dial dials the given multiaddr, reusing ports we're currently listening on if // possible. // @@ -31,7 +24,7 @@ func (t *Transport) DialContext(ctx context.Context, raddr ma.Multiaddr) (manet. if err != nil { return nil, err } - var d dialer + var d *dialer switch network { case "tcp4": d = t.v4.getDialer(network) @@ -52,7 +45,7 @@ func (t *Transport) DialContext(ctx context.Context, raddr ma.Multiaddr) (manet. return maconn, nil } -func (n *network) getDialer(network string) dialer { +func (n *network) getDialer(network string) *dialer { n.mu.RLock() d := n.dialer n.mu.RUnlock() @@ -61,53 +54,9 @@ func (n *network) getDialer(network string) dialer { defer n.mu.Unlock() if n.dialer == nil { - n.dialer = n.makeDialer(network) + n.dialer = newDialer(n.listeners) } d = n.dialer } return d } - -func (n *network) makeDialer(network string) dialer { - if !reuseport.Available() { - log.Debug("reuseport not available") - return &net.Dialer{} - } - - var unspec net.IP - switch network { - case "tcp4": - unspec = net.IPv4zero - case "tcp6": - unspec = net.IPv6unspecified - default: - panic("invalid network: must be either tcp4 or tcp6") - } - - // How many ports are we listening on. - var port = 0 - for l := range n.listeners { - newPort := l.Addr().(*net.TCPAddr).Port - switch { - case newPort == 0: // Any port, ignore (really, we shouldn't get this case...). - case port == 0: // Haven't selected a port yet, choose this one. - port = newPort - case newPort == port: // Same as the selected port, continue... - default: // Multiple ports, use the multi dialer - return newMultiDialer(unspec, n.listeners) - } - } - - // None. - if port == 0 { - return &net.Dialer{} - } - - // One. Always dial from the single port we're listening on. - laddr := &net.TCPAddr{ - IP: unspec, - Port: port, - } - - return (*singleDialer)(laddr) -} diff --git a/p2p/net/reuseport/dialer.go b/p2p/net/reuseport/dialer.go new file mode 100644 index 0000000000..2efc02d393 --- /dev/null +++ b/p2p/net/reuseport/dialer.go @@ -0,0 +1,114 @@ +package reuseport + +import ( + "context" + "fmt" + "math/rand" + "net" + + "github.com/libp2p/go-netroute" +) + +type dialer struct { + // All address that are _not_ loopback or unspecified (0.0.0.0 or ::). + specific []*net.TCPAddr + // All loopback addresses (127.*.*.*, ::1). + loopback []*net.TCPAddr + // Unspecified addresses (0.0.0.0, ::) + unspecified []*net.TCPAddr +} + +func (d *dialer) Dial(network, addr string) (net.Conn, error) { + return d.DialContext(context.Background(), network, addr) +} + +func randAddr(addrs []*net.TCPAddr) *net.TCPAddr { + if len(addrs) > 0 { + return addrs[rand.Intn(len(addrs))] + } + return nil +} + +// DialContext dials a target addr. +// +// In-order: +// +// 1. If we're _explicitly_ listening on the prefered source address for the destination address +// (per the system's routes), we'll use that listener's port as the source port. +// 2. If we're listening on one or more _unspecified_ addresses (zero address), we'll pick a source +// port from one of these listener's. +// 3. Otherwise, we'll let the system pick the source port. +func (d *dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + // We only check this case if the user is listening on a specific address (loopback or + // otherwise). Generally, users will listen on the "unspecified" address (0.0.0.0 or ::) and + // we can skip this section. + // + // This lets us avoid resolving the address twice, in most cases. + if len(d.specific) > 0 || len(d.loopback) > 0 { + tcpAddr, err := net.ResolveTCPAddr(network, addr) + if err != nil { + return nil, err + } + ip := tcpAddr.IP + if !ip.IsLoopback() && !ip.IsGlobalUnicast() { + return nil, fmt.Errorf("undialable IP: %s", ip) + } + + // If we're listening on some specific address and that specific address happens to + // be the preferred source address for the target destination address, we try to + // dial with that address/port. + // + // We skip this check if we _aren't_ listening on any specific addresses, because + // checking routing tables can be expensive and users rarely listen on specific IP + // addresses. + if len(d.specific) > 0 { + if router, err := netroute.New(); err == nil { + if _, _, preferredSrc, err := router.Route(ip); err == nil { + for _, optAddr := range d.specific { + if optAddr.IP.Equal(preferredSrc) { + return reuseDial(ctx, optAddr, network, addr) + } + } + } + } + } + + // Otherwise, if we are listening on a loopback address and the destination is also + // a loopback address, use the port from our loopback listener. + if len(d.loopback) > 0 && ip.IsLoopback() { + return reuseDial(ctx, randAddr(d.loopback), network, addr) + } + } + + // If we're listening on any uspecified addresses, use a randomly chosen port from one of + // these listeners. + if len(d.unspecified) > 0 { + return reuseDial(ctx, randAddr(d.unspecified), network, addr) + } + + // Finally, just pick a random port. + var dialer net.Dialer + return dialer.DialContext(ctx, network, addr) +} + +func newDialer(listeners map[*listener]struct{}) *dialer { + specific := make([]*net.TCPAddr, 0) + loopback := make([]*net.TCPAddr, 0) + unspecified := make([]*net.TCPAddr, 0) + + for l := range listeners { + addr := l.Addr().(*net.TCPAddr) + if addr.IP.IsLoopback() { + loopback = append(loopback, addr) + } else if addr.IP.IsUnspecified() { + unspecified = append(unspecified, addr) + } else { + specific = append(specific, addr) + } + } + return &dialer{ + specific: specific, + loopback: loopback, + unspecified: unspecified, + } +} diff --git a/p2p/net/reuseport/multidialer.go b/p2p/net/reuseport/multidialer.go deleted file mode 100644 index a3b5e2e99f..0000000000 --- a/p2p/net/reuseport/multidialer.go +++ /dev/null @@ -1,90 +0,0 @@ -package reuseport - -import ( - "context" - "fmt" - "math/rand" - "net" - - "github.com/libp2p/go-netroute" -) - -type multiDialer struct { - listeningAddresses []*net.TCPAddr - loopback []*net.TCPAddr - unspecified []*net.TCPAddr - fallback net.TCPAddr -} - -func (d *multiDialer) Dial(network, addr string) (net.Conn, error) { - return d.DialContext(context.Background(), network, addr) -} - -func randAddr(addrs []*net.TCPAddr) *net.TCPAddr { - if len(addrs) > 0 { - return addrs[rand.Intn(len(addrs))] - } - return nil -} - -// DialContext dials a target addr. -// Dialing preference is -// * If there is a listener on the local interface the OS expects to use to route towards addr, use that. -// * If there is a listener on a loopback address, addr is loopback, use that. -// * If there is a listener on an undefined address (0.0.0.0 or ::), use that. -// * Use the fallback IP specified during construction, with a port that's already being listened on, if one exists. -func (d *multiDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { - tcpAddr, err := net.ResolveTCPAddr(network, addr) - if err != nil { - return nil, err - } - ip := tcpAddr.IP - if !ip.IsLoopback() && !ip.IsGlobalUnicast() { - return nil, fmt.Errorf("undialable IP: %s", ip) - } - - if router, err := netroute.New(); err == nil { - if _, _, preferredSrc, err := router.Route(ip); err == nil { - for _, optAddr := range d.listeningAddresses { - if optAddr.IP.Equal(preferredSrc) { - return reuseDial(ctx, optAddr, network, addr) - } - } - } - } - - if ip.IsLoopback() && len(d.loopback) > 0 { - return reuseDial(ctx, randAddr(d.loopback), network, addr) - } - if len(d.unspecified) == 0 { - return reuseDial(ctx, &d.fallback, network, addr) - } - - return reuseDial(ctx, randAddr(d.unspecified), network, addr) -} - -func newMultiDialer(unspec net.IP, listeners map[*listener]struct{}) (m dialer) { - addrs := make([]*net.TCPAddr, 0) - loopback := make([]*net.TCPAddr, 0) - unspecified := make([]*net.TCPAddr, 0) - existingPort := 0 - - for l := range listeners { - addr := l.Addr().(*net.TCPAddr) - addrs = append(addrs, addr) - if addr.IP.IsLoopback() { - loopback = append(loopback, addr) - } else if addr.IP.IsGlobalUnicast() && existingPort == 0 { - existingPort = addr.Port - } else if addr.IP.IsUnspecified() { - unspecified = append(unspecified, addr) - } - } - m = &multiDialer{ - listeningAddresses: addrs, - loopback: loopback, - unspecified: unspecified, - fallback: net.TCPAddr{IP: unspec, Port: existingPort}, - } - return -} diff --git a/p2p/net/reuseport/singledialer.go b/p2p/net/reuseport/singledialer.go deleted file mode 100644 index b15dae80b9..0000000000 --- a/p2p/net/reuseport/singledialer.go +++ /dev/null @@ -1,16 +0,0 @@ -package reuseport - -import ( - "context" - "net" -) - -type singleDialer net.TCPAddr - -func (d *singleDialer) Dial(network, address string) (net.Conn, error) { - return d.DialContext(context.Background(), network, address) -} - -func (d *singleDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - return reuseDial(ctx, (*net.TCPAddr)(d), network, address) -} diff --git a/p2p/net/reuseport/transport.go b/p2p/net/reuseport/transport.go index 37fb446cb7..ba7b9debf7 100644 --- a/p2p/net/reuseport/transport.go +++ b/p2p/net/reuseport/transport.go @@ -31,5 +31,5 @@ type Transport struct { type network struct { mu sync.RWMutex listeners map[*listener]struct{} - dialer dialer + dialer *dialer } diff --git a/p2p/net/reuseport/transport_test.go b/p2p/net/reuseport/transport_test.go index b99b583824..88f9cdb98f 100644 --- a/p2p/net/reuseport/transport_test.go +++ b/p2p/net/reuseport/transport_test.go @@ -141,7 +141,6 @@ func TestGlobalPreferenceV4(t *testing.T) { testPrefer(t, loopbackV4, loopbackV4, globalV4) t.Logf("when listening on %v, should prefer %v over %v", loopbackV4, unspecV4, globalV4) testPrefer(t, loopbackV4, unspecV4, globalV4) - t.Logf("when listening on %v, should prefer %v over %v", globalV4, unspecV4, loopbackV4) testPrefer(t, globalV4, unspecV4, loopbackV4) } @@ -177,8 +176,6 @@ func testPrefer(t *testing.T, listen, prefer, avoid ma.Multiaddr) { } defer listenerB1.Close() - dialOne(t, &trB, listenerA, listenerB1.Addr().(*net.TCPAddr).Port) - listenerB2, err := trB.Listen(prefer) if err != nil { t.Fatal(err) @@ -186,11 +183,6 @@ func testPrefer(t *testing.T, listen, prefer, avoid ma.Multiaddr) { defer listenerB2.Close() dialOne(t, &trB, listenerA, listenerB2.Addr().(*net.TCPAddr).Port) - - // Closing the listener should reset the dialer. - listenerB2.Close() - - dialOne(t, &trB, listenerA, listenerB1.Addr().(*net.TCPAddr).Port) } func TestV6V4(t *testing.T) {