From 1c8eaabfd385346a7c41b988e2dbc2e20ddfa460 Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Thu, 1 Dec 2022 14:06:13 -0800 Subject: [PATCH] transport.Listener,quic: Support multiple QUIC versions with the same Listener. Only return a single multiaddr per listener. (#1923) * Revert "transport.Listener returns a list of multiaddrs" This reverts commit 8962b2ae336d94627f2f4361f96799ee3a5bd9e4. * Support multiple QUIC versions on the same listener * No long running accept loop * Don't use a goroutine * PR comments --- core/transport/transport.go | 2 +- p2p/host/autonat/dialpolicy_test.go | 6 +- p2p/net/swarm/swarm_addr.go | 2 +- p2p/net/swarm/swarm_listen.go | 23 +-- p2p/net/swarm/swarm_test.go | 2 +- p2p/net/upgrader/listener.go | 5 - p2p/net/upgrader/listener_test.go | 36 ++--- p2p/net/upgrader/upgrader_test.go | 10 +- p2p/transport/quic/cmd/server/main.go | 2 +- p2p/transport/quic/conn_test.go | 58 ++++--- p2p/transport/quic/listener.go | 15 +- p2p/transport/quic/listener_test.go | 40 ++++- p2p/transport/quic/transport.go | 83 +++++++++- p2p/transport/quic/virtuallistener.go | 158 +++++++++++++++++++ p2p/transport/tcp/tcp_test.go | 12 +- p2p/transport/testsuite/stream_suite.go | 6 +- p2p/transport/testsuite/transport_suite.go | 10 +- p2p/transport/websocket/listener.go | 4 - p2p/transport/websocket/websocket_test.go | 16 +- p2p/transport/webtransport/listener.go | 6 +- p2p/transport/webtransport/transport_test.go | 69 ++++---- 21 files changed, 404 insertions(+), 161 deletions(-) create mode 100644 p2p/transport/quic/virtuallistener.go diff --git a/core/transport/transport.go b/core/transport/transport.go index 2a7537738c..ad2ee66496 100644 --- a/core/transport/transport.go +++ b/core/transport/transport.go @@ -91,7 +91,7 @@ type Listener interface { Accept() (CapableConn, error) Close() error Addr() net.Addr - Multiaddrs() []ma.Multiaddr + Multiaddr() ma.Multiaddr } // TransportNetwork is an inet.Network with methods for managing transports. diff --git a/p2p/host/autonat/dialpolicy_test.go b/p2p/host/autonat/dialpolicy_test.go index e56a862b67..75731ae9cf 100644 --- a/p2p/host/autonat/dialpolicy_test.go +++ b/p2p/host/autonat/dialpolicy_test.go @@ -47,9 +47,9 @@ func (l *mockL) Accept() (transport.CapableConn, error) { <-l.ctx.Done() return nil, errors.New("expected in mocked test") } -func (l *mockL) Close() error { return nil } -func (l *mockL) Addr() net.Addr { return nil } -func (l *mockL) Multiaddrs() []multiaddr.Multiaddr { return []multiaddr.Multiaddr{l.addr} } +func (l *mockL) Close() error { return nil } +func (l *mockL) Addr() net.Addr { return nil } +func (l *mockL) Multiaddr() multiaddr.Multiaddr { return l.addr } func TestSkipDial(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) diff --git a/p2p/net/swarm/swarm_addr.go b/p2p/net/swarm/swarm_addr.go index f8bec53afa..b2e3e4e8aa 100644 --- a/p2p/net/swarm/swarm_addr.go +++ b/p2p/net/swarm/swarm_addr.go @@ -18,7 +18,7 @@ func (s *Swarm) ListenAddresses() []ma.Multiaddr { func (s *Swarm) listenAddressesNoLock() []ma.Multiaddr { addrs := make([]ma.Multiaddr, 0, len(s.listeners.m)+10) // A bit extra so we may avoid an extra allocation in the for loop below. for l := range s.listeners.m { - addrs = append(addrs, l.Multiaddrs()...) + addrs = append(addrs, l.Multiaddr()) } return addrs } diff --git a/p2p/net/swarm/swarm_listen.go b/p2p/net/swarm/swarm_listen.go index 2e4bb9c52b..9c5394d438 100644 --- a/p2p/net/swarm/swarm_listen.go +++ b/p2p/net/swarm/swarm_listen.go @@ -47,7 +47,7 @@ func (s *Swarm) ListenClose(addrs ...ma.Multiaddr) { s.listeners.Lock() for l := range s.listeners.m { - if !containsSomeMultiaddr(addrs, l.Multiaddrs()) { + if !containsMultiaddr(addrs, l.Multiaddr()) { continue } @@ -96,13 +96,11 @@ func (s *Swarm) AddListenAddr(a ma.Multiaddr) error { s.listeners.cacheEOL = time.Time{} s.listeners.Unlock() - maddrs := list.Multiaddrs() + maddr := list.Multiaddr() // signal to our notifiees on listen. s.notifyAll(func(n network.Notifiee) { - for _, maddr := range maddrs { - n.Listen(s, maddr) - } + n.Listen(s, maddr) }) go func() { @@ -122,9 +120,7 @@ func (s *Swarm) AddListenAddr(a ma.Multiaddr) error { // signal to our notifiees on listen close. s.notifyAll(func(n network.Notifiee) { - for _, maddr := range maddrs { - n.ListenClose(s, maddr) - } + n.ListenClose(s, maddr) }) s.refs.Done() }() @@ -155,16 +151,11 @@ func (s *Swarm) AddListenAddr(a ma.Multiaddr) error { return nil } -func containsSomeMultiaddr(hayStack []ma.Multiaddr, needles []ma.Multiaddr) bool { - seenSet := make(map[string]struct{}, len(needles)) - for _, a := range needles { - seenSet[string(a.Bytes())] = struct{}{} - } - for _, a := range hayStack { - if _, found := seenSet[string(a.Bytes())]; found { +func containsMultiaddr(addrs []ma.Multiaddr, addr ma.Multiaddr) bool { + for _, a := range addrs { + if addr == a { return true } } return false - } diff --git a/p2p/net/swarm/swarm_test.go b/p2p/net/swarm/swarm_test.go index 209914adb6..95aa71d756 100644 --- a/p2p/net/swarm/swarm_test.go +++ b/p2p/net/swarm/swarm_test.go @@ -552,7 +552,7 @@ func TestListenCloseCount(t *testing.T) { t.Fatal(err) } listenedAddrs := s.ListenAddresses() - require.Equal(t, 3, len(listenedAddrs)) + require.Equal(t, 2, len(listenedAddrs)) s.ListenClose(listenedAddrs...) diff --git a/p2p/net/upgrader/listener.go b/p2p/net/upgrader/listener.go index 0a3ab557e8..c07299c1a5 100644 --- a/p2p/net/upgrader/listener.go +++ b/p2p/net/upgrader/listener.go @@ -10,7 +10,6 @@ import ( logging "github.com/ipfs/go-log/v2" tec "github.com/jbenet/go-temp-err-catcher" - "github.com/multiformats/go-multiaddr" manet "github.com/multiformats/go-multiaddr/net" ) @@ -176,8 +175,4 @@ func (l *listener) String() string { return fmt.Sprintf("", l.Multiaddr()) } -func (l *listener) Multiaddrs() []multiaddr.Multiaddr { - return []multiaddr.Multiaddr{l.Multiaddr()} -} - var _ transport.Listener = (*listener)(nil) diff --git a/p2p/net/upgrader/listener_test.go b/p2p/net/upgrader/listener_test.go index 26b13f1b2c..ce11b606a5 100644 --- a/p2p/net/upgrader/listener_test.go +++ b/p2p/net/upgrader/listener_test.go @@ -40,7 +40,7 @@ func TestAcceptSingleConn(t *testing.T) { ln := createListener(t, u) defer ln.Close() - cconn, err := dial(t, u, ln.Multiaddrs()[0], id, &network.NullScope{}) + cconn, err := dial(t, u, ln.Multiaddr(), id, &network.NullScope{}) require.NoError(err) sconn, err := ln.Accept() @@ -64,7 +64,7 @@ func TestAcceptMultipleConns(t *testing.T) { }() for i := 0; i < 10; i++ { - cconn, err := dial(t, u, ln.Multiaddrs()[0], id, &network.NullScope{}) + cconn, err := dial(t, u, ln.Multiaddr(), id, &network.NullScope{}) require.NoError(err) toClose = append(toClose, cconn) @@ -88,7 +88,7 @@ func TestConnectionsClosedIfNotAccepted(t *testing.T) { ln := createListener(t, u) defer ln.Close() - conn, err := dial(t, u, ln.Multiaddrs()[0], id, &network.NullScope{}) + conn, err := dial(t, u, ln.Multiaddr(), id, &network.NullScope{}) require.NoError(err) errCh := make(chan error) @@ -127,7 +127,7 @@ func TestFailedUpgradeOnListen(t *testing.T) { errCh <- err }() - _, err := dial(t, u, ln.Multiaddrs()[0], id, &network.NullScope{}) + _, err := dial(t, u, ln.Multiaddr(), id, &network.NullScope{}) require.Error(err) // close the listener. @@ -161,7 +161,7 @@ func TestListenerClose(t *testing.T) { require.Contains(err.Error(), "use of closed network connection") // doesn't accept new connections when it is closed - _, err = dial(t, u, ln.Multiaddrs()[0], peer.ID("1"), &network.NullScope{}) + _, err = dial(t, u, ln.Multiaddr(), peer.ID("1"), &network.NullScope{}) require.Error(err) } @@ -173,7 +173,7 @@ func TestListenerCloseClosesQueued(t *testing.T) { var conns []transport.CapableConn for i := 0; i < 10; i++ { - conn, err := dial(t, upgrader, ln.Multiaddrs()[0], id, &network.NullScope{}) + conn, err := dial(t, upgrader, ln.Multiaddr(), id, &network.NullScope{}) require.NoError(err) conns = append(conns, conn) } @@ -233,7 +233,7 @@ func TestConcurrentAccept(t *testing.T) { go func() { defer wg.Done() - conn, err := dial(t, u, ln.Multiaddrs()[0], id, &network.NullScope{}) + conn, err := dial(t, u, ln.Multiaddr(), id, &network.NullScope{}) if err != nil { errCh <- err return @@ -263,7 +263,7 @@ func TestAcceptQueueBacklogged(t *testing.T) { // setup AcceptQueueLength connections, but don't accept any of them var counter int32 // to be used atomically doDial := func() { - conn, err := dial(t, u, ln.Multiaddrs()[0], id, &network.NullScope{}) + conn, err := dial(t, u, ln.Multiaddr(), id, &network.NullScope{}) require.NoError(err) atomic.AddInt32(&counter, 1) t.Cleanup(func() { conn.Close() }) @@ -299,7 +299,7 @@ func TestListenerConnectionGater(t *testing.T) { defer ln.Close() // no gating. - conn, err := dial(t, u, ln.Multiaddrs()[0], id, &network.NullScope{}) + conn, err := dial(t, u, ln.Multiaddr(), id, &network.NullScope{}) require.NoError(err) require.False(conn.IsClosed()) _ = conn.Close() @@ -307,28 +307,28 @@ func TestListenerConnectionGater(t *testing.T) { // rejecting after handshake. testGater.BlockSecured(true) testGater.BlockAccept(false) - conn, err = dial(t, u, ln.Multiaddrs()[0], "invalid", &network.NullScope{}) + conn, err = dial(t, u, ln.Multiaddr(), "invalid", &network.NullScope{}) require.Error(err) require.Nil(conn) // rejecting on accept will trigger firupgrader. testGater.BlockSecured(true) testGater.BlockAccept(true) - conn, err = dial(t, u, ln.Multiaddrs()[0], "invalid", &network.NullScope{}) + conn, err = dial(t, u, ln.Multiaddr(), "invalid", &network.NullScope{}) require.Error(err) require.Nil(conn) // rejecting only on acceptance. testGater.BlockSecured(false) testGater.BlockAccept(true) - conn, err = dial(t, u, ln.Multiaddrs()[0], "invalid", &network.NullScope{}) + conn, err = dial(t, u, ln.Multiaddr(), "invalid", &network.NullScope{}) require.Error(err) require.Nil(conn) // back to normal testGater.BlockSecured(false) testGater.BlockAccept(false) - conn, err = dial(t, u, ln.Multiaddrs()[0], id, &network.NullScope{}) + conn, err = dial(t, u, ln.Multiaddr(), id, &network.NullScope{}) require.NoError(err) require.False(conn.IsClosed()) _ = conn.Close() @@ -344,13 +344,13 @@ func TestListenerResourceManagement(t *testing.T) { connScope := mocknetwork.NewMockConnManagementScope(ctrl) gomock.InOrder( - rcmgr.EXPECT().OpenConnection(network.DirInbound, true, gomock.Not(ln.Multiaddrs()[0])).Return(connScope, nil), + rcmgr.EXPECT().OpenConnection(network.DirInbound, true, gomock.Not(ln.Multiaddr())).Return(connScope, nil), connScope.EXPECT().PeerScope(), connScope.EXPECT().SetPeer(id), connScope.EXPECT().PeerScope(), ) - cconn, err := dial(t, upgrader, ln.Multiaddrs()[0], id, &network.NullScope{}) + cconn, err := dial(t, upgrader, ln.Multiaddr(), id, &network.NullScope{}) require.NoError(t, err) defer cconn.Close() @@ -367,8 +367,8 @@ func TestListenerResourceManagementDenied(t *testing.T) { id, upgrader := createUpgraderWithResourceManager(t, rcmgr) ln := createListener(t, upgrader) - rcmgr.EXPECT().OpenConnection(network.DirInbound, true, gomock.Not(ln.Multiaddrs()[0])).Return(nil, errors.New("nope")) - _, err := dial(t, upgrader, ln.Multiaddrs()[0], id, &network.NullScope{}) + rcmgr.EXPECT().OpenConnection(network.DirInbound, true, gomock.Not(ln.Multiaddr())).Return(nil, errors.New("nope")) + _, err := dial(t, upgrader, ln.Multiaddr(), id, &network.NullScope{}) require.Error(t, err) done := make(chan struct{}) @@ -404,7 +404,7 @@ func TestNoCommonSecurityProto(t *testing.T) { ln.Accept() }() - _, err = dial(t, ub, ln.Multiaddrs()[0], idA, &network.NullScope{}) + _, err = dial(t, ub, ln.Multiaddr(), idA, &network.NullScope{}) require.EqualError(t, err, "failed to negotiate security protocol: protocol not supported") select { case <-done: diff --git a/p2p/net/upgrader/upgrader_test.go b/p2p/net/upgrader/upgrader_test.go index fbe9527aba..b61ba8a61e 100644 --- a/p2p/net/upgrader/upgrader_test.go +++ b/p2p/net/upgrader/upgrader_test.go @@ -141,21 +141,21 @@ func TestOutboundConnectionGating(t *testing.T) { testGater := &testGater{} _, dialUpgrader := createUpgraderWithConnGater(t, testGater) - conn, err := dial(t, dialUpgrader, ln.Multiaddrs()[0], id, &network.NullScope{}) + conn, err := dial(t, dialUpgrader, ln.Multiaddr(), id, &network.NullScope{}) require.NoError(err) require.NotNil(conn) _ = conn.Close() // blocking accepts doesn't affect the dialling side, only the listener. testGater.BlockAccept(true) - conn, err = dial(t, dialUpgrader, ln.Multiaddrs()[0], id, &network.NullScope{}) + conn, err = dial(t, dialUpgrader, ln.Multiaddr(), id, &network.NullScope{}) require.NoError(err) require.NotNil(conn) _ = conn.Close() // now let's block all connections after being secured. testGater.BlockSecured(true) - conn, err = dial(t, dialUpgrader, ln.Multiaddrs()[0], id, &network.NullScope{}) + conn, err = dial(t, dialUpgrader, ln.Multiaddr(), id, &network.NullScope{}) require.Error(err) require.Contains(err.Error(), "gater rejected connection") require.Nil(conn) @@ -176,7 +176,7 @@ func TestOutboundResourceManagement(t *testing.T) { connScope.EXPECT().PeerScope().Return(&network.NullScope{}), ) _, dialUpgrader := createUpgrader(t) - conn, err := dial(t, dialUpgrader, ln.Multiaddrs()[0], id, connScope) + conn, err := dial(t, dialUpgrader, ln.Multiaddr(), id, connScope) require.NoError(t, err) require.NotNil(t, conn) connScope.EXPECT().Done() @@ -198,7 +198,7 @@ func TestOutboundResourceManagement(t *testing.T) { connScope.EXPECT().Done(), ) _, dialUpgrader := createUpgrader(t) - _, err := dial(t, dialUpgrader, ln.Multiaddrs()[0], id, connScope) + _, err := dial(t, dialUpgrader, ln.Multiaddr(), id, connScope) require.Error(t, err) }) } diff --git a/p2p/transport/quic/cmd/server/main.go b/p2p/transport/quic/cmd/server/main.go index a939c4725f..e5144f496d 100644 --- a/p2p/transport/quic/cmd/server/main.go +++ b/p2p/transport/quic/cmd/server/main.go @@ -54,7 +54,7 @@ func run(port string) error { if err != nil { return err } - fmt.Printf("Listening. Now run: go run cmd/client/main.go %s %s\n", ln.Multiaddrs()[0], peerID) + fmt.Printf("Listening. Now run: go run cmd/client/main.go %s %s\n", ln.Multiaddr(), peerID) for { conn, err := ln.Accept() if err != nil { diff --git a/p2p/transport/quic/conn_test.go b/p2p/transport/quic/conn_test.go index 770377869a..157e2ddc3c 100644 --- a/p2p/transport/quic/conn_test.go +++ b/p2p/transport/quic/conn_test.go @@ -96,7 +96,7 @@ func testHandshake(t *testing.T, tc *connTestCase) { clientTransport, err := NewTransport(clientKey, newConnManager(t, tc.Options...), nil, nil, nil) require.NoError(t, err) defer clientTransport.(io.Closer).Close() - conn, err := clientTransport.Dial(context.Background(), ln.Multiaddrs()[0], serverID) + conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) require.NoError(t, err) defer conn.Close() serverConn, err := ln.Accept() @@ -158,7 +158,7 @@ func testResourceManagerSuccess(t *testing.T, tc *connTestCase) { connChan := make(chan tpt.CapableConn) serverConnScope := mocknetwork.NewMockConnManagementScope(ctrl) go func() { - serverRcmgr.EXPECT().OpenConnection(network.DirInbound, false, gomock.Not(ln.Multiaddrs()[0])).Return(serverConnScope, nil) + serverRcmgr.EXPECT().OpenConnection(network.DirInbound, false, gomock.Not(ln.Multiaddr())).Return(serverConnScope, nil) serverConnScope.EXPECT().SetPeer(clientID) serverConn, err := ln.Accept() require.NoError(t, err) @@ -166,9 +166,9 @@ func testResourceManagerSuccess(t *testing.T, tc *connTestCase) { }() connScope := mocknetwork.NewMockConnManagementScope(ctrl) - clientRcmgr.EXPECT().OpenConnection(network.DirOutbound, false, ln.Multiaddrs()[0]).Return(connScope, nil) + clientRcmgr.EXPECT().OpenConnection(network.DirOutbound, false, ln.Multiaddr()).Return(connScope, nil) connScope.EXPECT().SetPeer(serverID) - conn, err := clientTransport.Dial(context.Background(), ln.Multiaddrs()[0], serverID) + conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) require.NoError(t, err) serverConn := <-connChan t.Log("received conn") @@ -250,12 +250,12 @@ func testResourceManagerAcceptDenied(t *testing.T, tc *connTestCase) { }() clientConnScope := mocknetwork.NewMockConnManagementScope(ctrl) - clientRcmgr.EXPECT().OpenConnection(network.DirOutbound, false, ln.Multiaddrs()[0]).Return(clientConnScope, nil) + clientRcmgr.EXPECT().OpenConnection(network.DirOutbound, false, ln.Multiaddr()).Return(clientConnScope, nil) clientConnScope.EXPECT().SetPeer(serverID) // In rare instances, the connection gating error will already occur on Dial. // In that case, Done is called on the connection scope. clientConnScope.EXPECT().Done().MaxTimes(1) - conn, err := clientTransport.Dial(context.Background(), ln.Multiaddrs()[0], serverID) + conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) // In rare instances, the connection gating error will already occur on Dial. if err == nil { _, err = conn.AcceptStream() @@ -289,7 +289,7 @@ func testStreams(t *testing.T, tc *connTestCase) { clientTransport, err := NewTransport(clientKey, newConnManager(t, tc.Options...), nil, nil, nil) require.NoError(t, err) defer clientTransport.(io.Closer).Close() - conn, err := clientTransport.Dial(context.Background(), ln.Multiaddrs()[0], serverID) + conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) require.NoError(t, err) defer conn.Close() serverConn, err := ln.Accept() @@ -329,7 +329,7 @@ func testHandshakeFailPeerIDMismatch(t *testing.T, tc *connTestCase) { clientTransport, err := NewTransport(clientKey, newConnManager(t, tc.Options...), nil, nil, nil) require.NoError(t, err) // dial, but expect the wrong peer ID - _, err = clientTransport.Dial(context.Background(), ln.Multiaddrs()[0], thirdPartyID) + _, err = clientTransport.Dial(context.Background(), ln.Multiaddr(), thirdPartyID) require.Error(t, err) require.Contains(t, err.Error(), "CRYPTO_ERROR") defer clientTransport.(io.Closer).Close() @@ -386,7 +386,7 @@ func testConnectionGating(t *testing.T, tc *connTestCase) { require.NoError(t, err) defer clientTransport.(io.Closer).Close() // make sure that connection attempts fails - conn, err := clientTransport.Dial(context.Background(), ln.Multiaddrs()[0], serverID) + conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) // In rare instances, the connection gating error will already occur on Dial. // In most cases, it will be returned by AcceptStream. if err == nil { @@ -397,7 +397,7 @@ func testConnectionGating(t *testing.T, tc *connTestCase) { // now allow the address and make sure the connection goes through cg.EXPECT().InterceptAccept(gomock.Any()).Return(true) cg.EXPECT().InterceptSecured(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) - conn, err = clientTransport.Dial(context.Background(), ln.Multiaddrs()[0], serverID) + conn, err = clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) require.NoError(t, err) defer conn.Close() require.Eventually(t, func() bool { @@ -425,13 +425,13 @@ func testConnectionGating(t *testing.T, tc *connTestCase) { defer clientTransport.(io.Closer).Close() // make sure that connection attempts fails - _, err = clientTransport.Dial(context.Background(), ln.Multiaddrs()[0], serverID) + _, err = clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) require.Error(t, err) require.Contains(t, err.Error(), "connection gated") // now allow the peerId and make sure the connection goes through cg.EXPECT().InterceptSecured(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) - conn, err := clientTransport.Dial(context.Background(), ln.Multiaddrs()[0], serverID) + conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) require.NoError(t, err) conn.Close() }) @@ -484,10 +484,10 @@ func testDialTwo(t *testing.T, tc *connTestCase) { clientTransport, err := NewTransport(clientKey, newConnManager(t, tc.Options...), nil, nil, nil) require.NoError(t, err) defer clientTransport.(io.Closer).Close() - c1, err := clientTransport.Dial(context.Background(), ln1.Multiaddrs()[0], serverID) + c1, err := clientTransport.Dial(context.Background(), ln1.Multiaddr(), serverID) require.NoError(t, err) defer c1.Close() - c2, err := clientTransport.Dial(context.Background(), ln2.Multiaddrs()[0], serverID2) + c2, err := clientTransport.Dial(context.Background(), ln2.Multiaddr(), serverID2) require.NoError(t, err) defer c2.Close() @@ -637,7 +637,7 @@ func TestHolePunching(t *testing.T) { go func() { conn, err := t2.Dial( network.WithSimultaneousConnect(context.Background(), false, ""), - ln1.Multiaddrs()[0], + ln1.Multiaddr(), serverID, ) require.NoError(t, err) @@ -655,7 +655,7 @@ func TestHolePunching(t *testing.T) { conn1, err := t1.Dial( network.WithSimultaneousConnect(context.Background(), true, ""), - ln2.Multiaddrs()[0], + ln2.Multiaddr(), clientID, ) require.NoError(t, err) @@ -720,10 +720,20 @@ func TestClientCanDialDifferentQUICVersions(t *testing.T) { t1, err := NewTransport(serverKey, newConnManager(t, serverOpts...), nil, nil, nil) require.NoError(t, err) defer t1.(io.Closer).Close() - laddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic-v1") - require.NoError(t, err) + laddr := ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1") ln1, err := t1.Listen(laddr) require.NoError(t, err) + t.Cleanup(func() { ln1.Close() }) + + mas := []ma.Multiaddr{ln1.Multiaddr()} + var ln2 tpt.Listener + if !tc.serverDisablesDraft29 { + laddrDraft29 := ma.StringCast("/ip4/127.0.0.1/udp/0/quic") + ln2, err = t1.Listen(laddrDraft29) + require.NoError(t, err) + t.Cleanup(func() { ln2.Close() }) + mas = append(mas, ln2.Multiaddr()) + } t2, err := NewTransport(clientKey, newConnManager(t), nil, nil, nil) require.NoError(t, err) @@ -731,14 +741,22 @@ func TestClientCanDialDifferentQUICVersions(t *testing.T) { ctx := context.Background() - for _, a := range ln1.Multiaddrs() { + for _, a := range mas { _, v, err := quicreuse.FromQuicMultiaddr(a) require.NoError(t, err) done := make(chan struct{}) go func() { defer close(done) - conn, err := ln1.Accept() + var conn tpt.CapableConn + var err error + if v == quic.Version1 { + conn, err = ln1.Accept() + } else if v == quic.VersionDraft29 { + conn, err = ln2.Accept() + } else { + panic("unexpected version") + } require.NoError(t, err) defer conn.Close() diff --git a/p2p/transport/quic/listener.go b/p2p/transport/quic/listener.go index ec7929fc94..ea6b68bd6c 100644 --- a/p2p/transport/quic/listener.go +++ b/p2p/transport/quic/listener.go @@ -26,9 +26,7 @@ type listener struct { localMultiaddrs map[quic.VersionNumber]ma.Multiaddr } -var _ tpt.Listener = &listener{} - -func newListener(ln quicreuse.Listener, t *transport, localPeer peer.ID, key ic.PrivKey, rcmgr network.ResourceManager) (tpt.Listener, error) { +func newListener(ln quicreuse.Listener, t *transport, localPeer peer.ID, key ic.PrivKey, rcmgr network.ResourceManager) (listener, error) { localMultiaddrs := make(map[quic.VersionNumber]ma.Multiaddr) for _, addr := range ln.Multiaddrs() { if _, err := addr.ValueForProtocol(ma.P_QUIC); err == nil { @@ -39,7 +37,7 @@ func newListener(ln quicreuse.Listener, t *transport, localPeer peer.ID, key ic. } } - return &listener{ + return listener{ reuseListener: ln, transport: t, rcmgr: rcmgr, @@ -144,12 +142,3 @@ func (l *listener) Close() error { func (l *listener) Addr() net.Addr { return l.reuseListener.Addr() } - -// Multiaddr returns the multiaddress of this listener. -func (l *listener) Multiaddrs() []ma.Multiaddr { - mas := make([]ma.Multiaddr, 0, len(l.localMultiaddrs)) - for _, a := range l.localMultiaddrs { - mas = append(mas, a) - } - return mas -} diff --git a/p2p/transport/quic/listener_test.go b/p2p/transport/quic/listener_test.go index d45f065e18..f62cddc99b 100644 --- a/p2p/transport/quic/listener_test.go +++ b/p2p/transport/quic/listener_test.go @@ -13,6 +13,7 @@ import ( ic "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/network" tpt "github.com/libp2p/go-libp2p/core/transport" + "github.com/libp2p/go-libp2p/p2p/transport/quicreuse" ma "github.com/multiformats/go-multiaddr" "github.com/stretchr/testify/require" @@ -36,11 +37,16 @@ func TestListenAddr(t *testing.T) { localAddr := ma.StringCast("/ip4/127.0.0.1/udp/0/quic") ln, err := tr.Listen(localAddr) require.NoError(t, err) + localAddrV1 := ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1") + ln2, err := tr.Listen(localAddrV1) + require.NoError(t, err) defer ln.Close() + defer ln2.Close() port := ln.Addr().(*net.UDPAddr).Port require.NotZero(t, port) + var multiaddrsStrings []string - for _, a := range ln.Multiaddrs() { + for _, a := range []ma.Multiaddr{ln.Multiaddr(), ln2.Multiaddr()} { multiaddrsStrings = append(multiaddrsStrings, a.String()) } require.Contains(t, multiaddrsStrings, fmt.Sprintf("/ip4/127.0.0.1/udp/%d/quic", port)) @@ -51,11 +57,15 @@ func TestListenAddr(t *testing.T) { localAddr := ma.StringCast("/ip6/::/udp/0/quic") ln, err := tr.Listen(localAddr) require.NoError(t, err) + localAddrV1 := ma.StringCast("/ip6/::/udp/0/quic-v1") + ln2, err := tr.Listen(localAddrV1) + require.NoError(t, err) defer ln.Close() + defer ln2.Close() port := ln.Addr().(*net.UDPAddr).Port require.NotZero(t, port) var multiaddrsStrings []string - for _, a := range ln.Multiaddrs() { + for _, a := range []ma.Multiaddr{ln.Multiaddr(), ln2.Multiaddr()} { multiaddrsStrings = append(multiaddrsStrings, a.String()) } require.Contains(t, multiaddrsStrings, fmt.Sprintf("/ip6/::/udp/%d/quic", port)) @@ -96,3 +106,29 @@ func TestAcceptAfterClose(t *testing.T) { _, err = ln.Accept() require.Error(t, err) } + +func TestCorrectNumberOfVirtualListeners(t *testing.T) { + tr := newTransport(t, nil) + tpt := tr.(*transport) + defer tr.(io.Closer).Close() + + localAddr := ma.StringCast("/ip4/127.0.0.1/udp/0/quic") + udpAddr, _, err := quicreuse.FromQuicMultiaddr(localAddr) + require.NoError(t, err) + + ln, err := tr.Listen(localAddr) + require.NoError(t, err) + require.Equal(t, 1, len(tpt.listeners[udpAddr.String()])) + localAddrV1 := ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1") + ln2, err := tr.Listen(localAddrV1) + require.NoError(t, err) + + require.NoError(t, err) + require.Equal(t, 2, len(tpt.listeners[udpAddr.String()])) + + ln.Close() + require.Equal(t, 1, len(tpt.listeners[udpAddr.String()])) + ln2.Close() + require.Equal(t, 0, len(tpt.listeners[udpAddr.String()])) + +} diff --git a/p2p/transport/quic/transport.go b/p2p/transport/quic/transport.go index 0c34d43640..325471973e 100644 --- a/p2p/transport/quic/transport.go +++ b/p2p/transport/quic/transport.go @@ -50,6 +50,10 @@ type transport struct { connMx sync.Mutex conns map[quic.Connection]*conn + + listenersMu sync.Mutex + // map of UDPAddr as string to a virtualListeners + listeners map[string][]*virtualListener } var _ tpt.Transport = &transport{} @@ -92,6 +96,8 @@ func NewTransport(key ic.PrivKey, connManager *quicreuse.ConnManager, psk pnet.P rcmgr: rcmgr, conns: make(map[quic.Connection]*conn), holePunching: make(map[holePunchKey]*activeHolePunch), + + listeners: make(map[string][]*virtualListener), }, nil } @@ -269,16 +275,54 @@ func (t *transport) Listen(addr ma.Multiaddr) (tpt.Listener, error) { return conf, nil } tlsConf.NextProtos = []string{"libp2p"} - - ln, err := t.connManager.ListenQUIC(addr, &tlsConf, t.allowWindowIncrease) + udpAddr, version, err := quicreuse.FromQuicMultiaddr(addr) if err != nil { return nil, err } - l, err := newListener(ln, t, t.localPeer, t.privKey, t.rcmgr) - if err != nil { - _ = ln.Close() - return nil, err + + t.listenersMu.Lock() + defer t.listenersMu.Unlock() + listeners := t.listeners[udpAddr.String()] + var underlyingListener *listener + var acceptRunner *acceptLoopRunner + if len(listeners) != 0 { + // We already have an underlying listener, let's use it + underlyingListener = listeners[0].listener + acceptRunner = listeners[0].acceptRunnner + // Make sure our underlying listener is listening on the specified QUIC version + if _, ok := underlyingListener.localMultiaddrs[version]; !ok { + return nil, fmt.Errorf("can't listen on quic version %v, underlying listener doesn't support it", version) + } + } else { + ln, err := t.connManager.ListenQUIC(addr, &tlsConf, t.allowWindowIncrease) + if err != nil { + return nil, err + } + l, err := newListener(ln, t, t.localPeer, t.privKey, t.rcmgr) + if err != nil { + _ = ln.Close() + return nil, err + } + underlyingListener = &l + + acceptRunner = &acceptLoopRunner{ + acceptSem: make(chan struct{}, 1), + muxer: make(map[quic.VersionNumber]chan acceptVal), + } + } + + l := &virtualListener{ + listener: underlyingListener, + version: version, + udpAddr: udpAddr.String(), + t: t, + acceptRunnner: acceptRunner, + acceptChan: acceptRunner.AcceptForVersion(version), } + + listeners = append(listeners, l) + t.listeners[udpAddr.String()] = listeners + return l, nil } @@ -313,3 +357,30 @@ func (t *transport) String() string { func (t *transport) Close() error { return nil } + +func (t *transport) CloseVirtualListener(l *virtualListener) error { + t.listenersMu.Lock() + defer t.listenersMu.Unlock() + + var err error + listeners := t.listeners[l.udpAddr] + if len(listeners) == 1 { + // This is the last virtual listener here, so we can close the underlying listener + err = l.listener.Close() + delete(t.listeners, l.udpAddr) + return err + } + + for i := 0; i < len(listeners); i++ { + // Swap remove + if l == listeners[i] { + listeners[i] = listeners[len(listeners)-1] + listeners = listeners[:len(listeners)-1] + t.listeners[l.udpAddr] = listeners + break + } + } + + return nil + +} diff --git a/p2p/transport/quic/virtuallistener.go b/p2p/transport/quic/virtuallistener.go new file mode 100644 index 0000000000..a95b857dde --- /dev/null +++ b/p2p/transport/quic/virtuallistener.go @@ -0,0 +1,158 @@ +package libp2pquic + +import ( + "errors" + "sync" + + tpt "github.com/libp2p/go-libp2p/core/transport" + "github.com/libp2p/go-libp2p/p2p/transport/quicreuse" + "github.com/lucas-clemente/quic-go" + ma "github.com/multiformats/go-multiaddr" +) + +const acceptBufferPerVersion = 4 + +// virtualListener is a listener that exposes a single multiaddr but uses another listener under the hood +type virtualListener struct { + *listener + udpAddr string + version quic.VersionNumber + t *transport + acceptRunnner *acceptLoopRunner + acceptChan chan acceptVal +} + +var _ tpt.Listener = &virtualListener{} + +func (l *virtualListener) Multiaddr() ma.Multiaddr { + return l.listener.localMultiaddrs[l.version] +} + +func (l *virtualListener) Close() error { + l.acceptRunnner.RmAcceptForVersion(l.version) + return l.t.CloseVirtualListener(l) +} + +func (l *virtualListener) Accept() (tpt.CapableConn, error) { + return l.acceptRunnner.Accept(l.listener, l.version, l.acceptChan) +} + +type acceptVal struct { + conn tpt.CapableConn + err error +} + +type acceptLoopRunner struct { + acceptSem chan struct{} + + muxerMu sync.Mutex + muxer map[quic.VersionNumber]chan acceptVal +} + +func (r *acceptLoopRunner) AcceptForVersion(v quic.VersionNumber) chan acceptVal { + r.muxerMu.Lock() + defer r.muxerMu.Unlock() + + ch := make(chan acceptVal, acceptBufferPerVersion) + + if _, ok := r.muxer[v]; ok { + panic("unexpected chan already found in accept muxer") + } + + r.muxer[v] = ch + return ch +} + +func (r *acceptLoopRunner) RmAcceptForVersion(v quic.VersionNumber) { + r.muxerMu.Lock() + defer r.muxerMu.Unlock() + + ch, ok := r.muxer[v] + if !ok { + panic("expected chan in accept muxer") + } + ch <- acceptVal{err: errors.New("listener Accept closed")} + delete(r.muxer, v) +} + +func (r *acceptLoopRunner) sendErrAndClose(err error) { + r.muxerMu.Lock() + defer r.muxerMu.Unlock() + for k, ch := range r.muxer { + select { + case ch <- acceptVal{err: err}: + default: + } + delete(r.muxer, k) + close(ch) + } +} + +// innerAccept is the inner logic of the Accept loop. Assume caller holds the +// acceptSemaphore. May return both a nil conn and nil error if it didn't find a +// conn with the expected version +func (r *acceptLoopRunner) innerAccept(l *listener, expectedVersion quic.VersionNumber, bufferedConnChan chan acceptVal) (tpt.CapableConn, error) { + select { + // Check if we have a buffered connection first from an earlier Accept call + case v, ok := <-bufferedConnChan: + if !ok { + return nil, errors.New("listener closed") + } + return v.conn, v.err + default: + } + + conn, err := l.Accept() + + if err != nil { + r.sendErrAndClose(err) + return nil, err + } + + _, version, err := quicreuse.FromQuicMultiaddr(conn.RemoteMultiaddr()) + if err != nil { + r.sendErrAndClose(err) + return nil, err + } + + if version == expectedVersion { + return conn, nil + } + + // This wasn't the version we were expecting, lets queue it up for a + // future Accept call with a different version + r.muxerMu.Lock() + ch, ok := r.muxer[version] + r.muxerMu.Unlock() + + if !ok { + // Nothing to handle this connection version. Close it + conn.Close() + return nil, nil + } + + // Non blocking + select { + case ch <- acceptVal{conn: conn}: + default: + // accept queue filled up, drop the connection + conn.Close() + log.Warn("Accept queue filled. Dropping connection.") + } + + return nil, nil +} + +func (r *acceptLoopRunner) Accept(l *listener, expectedVersion quic.VersionNumber, bufferedConnChan chan acceptVal) (tpt.CapableConn, error) { + for { + r.acceptSem <- struct{}{} + conn, err := r.innerAccept(l, expectedVersion, bufferedConnChan) + <-r.acceptSem + + if conn == nil && err == nil { + // Didn't find a conn for the expected version and there was no error, lets try again + continue + } + return conn, err + } +} diff --git a/p2p/transport/tcp/tcp_test.go b/p2p/transport/tcp/tcp_test.go index 1dacc71952..05b6cb9352 100644 --- a/p2p/transport/tcp/tcp_test.go +++ b/p2p/transport/tcp/tcp_test.go @@ -85,10 +85,10 @@ func TestResourceManager(t *testing.T) { t.Run("success", func(t *testing.T) { scope := mocknetwork.NewMockConnManagementScope(ctrl) - rcmgr.EXPECT().OpenConnection(network.DirOutbound, true, ln.Multiaddrs()[0]).Return(scope, nil) + rcmgr.EXPECT().OpenConnection(network.DirOutbound, true, ln.Multiaddr()).Return(scope, nil) scope.EXPECT().SetPeer(peerA) scope.EXPECT().PeerScope().Return(&network.NullScope{}).AnyTimes() // called by the upgrader - conn, err := tb.Dial(context.Background(), ln.Multiaddrs()[0], peerA) + conn, err := tb.Dial(context.Background(), ln.Multiaddr(), peerA) require.NoError(t, err) scope.EXPECT().Done() defer conn.Close() @@ -96,18 +96,18 @@ func TestResourceManager(t *testing.T) { t.Run("connection denied", func(t *testing.T) { rerr := errors.New("nope") - rcmgr.EXPECT().OpenConnection(network.DirOutbound, true, ln.Multiaddrs()[0]).Return(nil, rerr) - _, err = tb.Dial(context.Background(), ln.Multiaddrs()[0], peerA) + rcmgr.EXPECT().OpenConnection(network.DirOutbound, true, ln.Multiaddr()).Return(nil, rerr) + _, err = tb.Dial(context.Background(), ln.Multiaddr(), peerA) require.ErrorIs(t, err, rerr) }) t.Run("peer denied", func(t *testing.T) { scope := mocknetwork.NewMockConnManagementScope(ctrl) - rcmgr.EXPECT().OpenConnection(network.DirOutbound, true, ln.Multiaddrs()[0]).Return(scope, nil) + rcmgr.EXPECT().OpenConnection(network.DirOutbound, true, ln.Multiaddr()).Return(scope, nil) rerr := errors.New("nope") scope.EXPECT().SetPeer(peerA).Return(rerr) scope.EXPECT().Done() - _, err = tb.Dial(context.Background(), ln.Multiaddrs()[0], peerA) + _, err = tb.Dial(context.Background(), ln.Multiaddr(), peerA) require.ErrorIs(t, err, rerr) }) } diff --git a/p2p/transport/testsuite/stream_suite.go b/p2p/transport/testsuite/stream_suite.go index 3170b72c40..b139976b91 100644 --- a/p2p/transport/testsuite/stream_suite.go +++ b/p2p/transport/testsuite/stream_suite.go @@ -197,7 +197,7 @@ func SubtestStress(t *testing.T, ta, tb transport.Transport, maddr ma.Multiaddr, serve(t, l) }() - c, err := tb.Dial(context.Background(), l.Multiaddrs()[0], peerA) + c, err := tb.Dial(context.Background(), l.Multiaddr(), peerA) if err != nil { t.Error(err) return @@ -259,7 +259,7 @@ func SubtestStreamOpenStress(t *testing.T, ta, tb transport.Transport, maddr ma. connA, err = l.Accept() accepted <- err }() - connB, err = tb.Dial(context.Background(), l.Multiaddrs()[0], peerA) + connB, err = tb.Dial(context.Background(), l.Multiaddr(), peerA) if err != nil { t.Fatal(err) } @@ -373,7 +373,7 @@ func SubtestStreamReset(t *testing.T, ta, tb transport.Transport, maddr ma.Multi }() - muxb, err := tb.Dial(context.Background(), l.Multiaddrs()[0], peerA) + muxb, err := tb.Dial(context.Background(), l.Multiaddr(), peerA) if err != nil { t.Fatal(err) } diff --git a/p2p/transport/testsuite/transport_suite.go b/p2p/transport/testsuite/transport_suite.go index f8d0ff1113..bd8892e807 100644 --- a/p2p/transport/testsuite/transport_suite.go +++ b/p2p/transport/testsuite/transport_suite.go @@ -111,11 +111,11 @@ func SubtestBasic(t *testing.T, ta, tb transport.Transport, maddr ma.Multiaddr, } }() - if !tb.CanDial(list.Multiaddrs()[0]) { + if !tb.CanDial(list.Multiaddr()) { t.Error("CanDial should have returned true") } - connA, err = tb.Dial(ctx, list.Multiaddrs()[0], peerA) + connA, err = tb.Dial(ctx, list.Multiaddr(), peerA) if err != nil { t.Fatal(err) } @@ -232,11 +232,11 @@ func SubtestPingPong(t *testing.T, ta, tb transport.Transport, maddr ma.Multiadd sWg.Wait() }() - if !tb.CanDial(list.Multiaddrs()[0]) { + if !tb.CanDial(list.Multiaddr()) { t.Error("CanDial should have returned true") } - connB, err = tb.Dial(ctx, list.Multiaddrs()[0], peerA) + connB, err = tb.Dial(ctx, list.Multiaddr(), peerA) if err != nil { t.Fatal(err) } @@ -297,7 +297,7 @@ func SubtestCancel(t *testing.T, ta, tb transport.Transport, maddr ma.Multiaddr, ctx, cancel := context.WithCancel(context.Background()) cancel() - c, err := tb.Dial(ctx, list.Multiaddrs()[0], peerA) + c, err := tb.Dial(ctx, list.Multiaddr(), peerA) if err == nil { c.Close() t.Fatal("dial should have failed") diff --git a/p2p/transport/websocket/listener.go b/p2p/transport/websocket/listener.go index c9991d71b1..5e7ed098d6 100644 --- a/p2p/transport/websocket/listener.go +++ b/p2p/transport/websocket/listener.go @@ -139,10 +139,6 @@ func (l *listener) Multiaddr() ma.Multiaddr { return l.laddr } -func (l *listener) Multiaddrs() []ma.Multiaddr { - return []ma.Multiaddr{l.laddr} -} - type transportListener struct { transport.Listener } diff --git a/p2p/transport/websocket/websocket_test.go b/p2p/transport/websocket/websocket_test.go index 0b05f94298..5b9500f612 100644 --- a/p2p/transport/websocket/websocket_test.go +++ b/p2p/transport/websocket/websocket_test.go @@ -181,7 +181,7 @@ func testWSSServer(t *testing.T, listenAddr ma.Multiaddr) (ma.Multiaddr, peer.ID close(errChan) }() - return l.Multiaddrs()[0], id, errChan + return l.Multiaddr(), id, errChan } func getTLSConf(t *testing.T, ip net.IP, start, end time.Time) *tls.Config { @@ -321,9 +321,9 @@ func connectAndExchangeData(t *testing.T, laddr ma.Multiaddr, secure bool) { l, err := tpt.Listen(laddr) require.NoError(t, err) if secure { - require.Contains(t, l.Multiaddrs()[0].String(), "tls") + require.Contains(t, l.Multiaddr().String(), "tls") } else { - require.Equal(t, lastComponent(t, l.Multiaddrs()[0]), wsComponent) + require.Equal(t, lastComponent(t, l.Multiaddr()), wsComponent) } defer l.Close() @@ -337,7 +337,7 @@ func connectAndExchangeData(t *testing.T, laddr ma.Multiaddr, secure bool) { _, u := newUpgrader(t) tpt, err := New(u, &network.NullResourceManager{}, opts...) require.NoError(t, err) - c, err := tpt.Dial(context.Background(), l.Multiaddrs()[0], server) + c, err := tpt.Dial(context.Background(), l.Multiaddr(), server) require.NoError(t, err) str, err := c.OpenStream(context.Background()) require.NoError(t, err) @@ -392,14 +392,14 @@ func TestWebsocketListenSecureAndInsecure(t *testing.T) { require.NoError(t, err) // dialing the insecure address should succeed - conn, err := client.Dial(context.Background(), lnInsecure.Multiaddrs()[0], serverID) + conn, err := client.Dial(context.Background(), lnInsecure.Multiaddr(), serverID) require.NoError(t, err) defer conn.Close() require.Equal(t, lastComponent(t, conn.RemoteMultiaddr()).String(), wsComponent.String()) require.Equal(t, lastComponent(t, conn.LocalMultiaddr()).String(), wsComponent.String()) // dialing the secure address should fail - _, err = client.Dial(context.Background(), lnSecure.Multiaddrs()[0], serverID) + _, err = client.Dial(context.Background(), lnSecure.Multiaddr(), serverID) require.NoError(t, err) }) @@ -409,14 +409,14 @@ func TestWebsocketListenSecureAndInsecure(t *testing.T) { require.NoError(t, err) // dialing the insecure address should succeed - conn, err := client.Dial(context.Background(), lnSecure.Multiaddrs()[0], serverID) + conn, err := client.Dial(context.Background(), lnSecure.Multiaddr(), serverID) require.NoError(t, err) defer conn.Close() require.Equal(t, lastComponent(t, conn.RemoteMultiaddr()), wssComponent) require.Equal(t, lastComponent(t, conn.LocalMultiaddr()), wssComponent) // dialing the insecure address should fail - _, err = client.Dial(context.Background(), lnInsecure.Multiaddrs()[0], serverID) + _, err = client.Dial(context.Background(), lnInsecure.Multiaddr(), serverID) require.NoError(t, err) }) } diff --git a/p2p/transport/webtransport/listener.go b/p2p/transport/webtransport/listener.go index f82c0944bf..c604022837 100644 --- a/p2p/transport/webtransport/listener.go +++ b/p2p/transport/webtransport/listener.go @@ -198,11 +198,11 @@ func (l *listener) Addr() net.Addr { return l.addr } -func (l *listener) Multiaddrs() []ma.Multiaddr { +func (l *listener) Multiaddr() ma.Multiaddr { if l.transport.certManager == nil { - return []ma.Multiaddr{l.multiaddr} + return l.multiaddr } - return []ma.Multiaddr{l.multiaddr.Encapsulate(l.transport.certManager.AddrComponent())} + return l.multiaddr.Encapsulate(l.transport.certManager.AddrComponent()) } func (l *listener) Close() error { diff --git a/p2p/transport/webtransport/transport_test.go b/p2p/transport/webtransport/transport_test.go index 96857564ce..b43ff9feb6 100644 --- a/p2p/transport/webtransport/transport_test.go +++ b/p2p/transport/webtransport/transport_test.go @@ -126,7 +126,7 @@ func TestTransport(t *testing.T) { require.NoError(t, err) defer tr2.(io.Closer).Close() - conn, err := tr2.Dial(context.Background(), ln.Multiaddrs()[0], serverID) + conn, err := tr2.Dial(context.Background(), ln.Multiaddr(), serverID) require.NoError(t, err) str, err := conn.OpenStream(context.Background()) require.NoError(t, err) @@ -135,7 +135,7 @@ func TestTransport(t *testing.T) { require.NoError(t, str.Close()) // check RemoteMultiaddr - _, addr, err := manet.DialArgs(ln.Multiaddrs()[0]) + _, addr, err := manet.DialArgs(ln.Multiaddr()) require.NoError(t, err) _, port, err := net.SplitHostPort(addr) require.NoError(t, err) @@ -179,14 +179,14 @@ func TestHashVerification(t *testing.T) { t.Run("fails using only a wrong hash", func(t *testing.T) { // replace the certificate hash in the multiaddr with a fake hash - addr := stripCertHashes(ln.Multiaddrs()[0]).Encapsulate(foobarHash) + addr := stripCertHashes(ln.Multiaddr()).Encapsulate(foobarHash) _, err := tr2.Dial(context.Background(), addr, serverID) require.Error(t, err) require.Contains(t, err.Error(), "CRYPTO_ERROR (0x12a): cert hash not found") }) t.Run("fails when adding a wrong hash", func(t *testing.T) { - _, err := tr2.Dial(context.Background(), ln.Multiaddrs()[0].Encapsulate(foobarHash), serverID) + _, err := tr2.Dial(context.Background(), ln.Multiaddr().Encapsulate(foobarHash), serverID) require.Error(t, err) }) @@ -260,9 +260,9 @@ func TestListenerAddrs(t *testing.T) { require.NoError(t, err) ln2, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport")) require.NoError(t, err) - hashes1 := extractCertHashes(ln1.Multiaddrs()[0]) + hashes1 := extractCertHashes(ln1.Multiaddr()) require.Len(t, hashes1, 2) - hashes2 := extractCertHashes(ln2.Multiaddrs()[0]) + hashes2 := extractCertHashes(ln2.Multiaddr()) require.Equal(t, hashes1, hashes2) } @@ -316,7 +316,7 @@ func TestResourceManagerListening(t *testing.T) { return nil, errors.New("denied") }) - _, err = cl.Dial(context.Background(), ln.Multiaddrs()[0], serverID) + _, err = cl.Dial(context.Background(), ln.Multiaddr(), serverID) require.EqualError(t, err, "received status 503") }) @@ -338,7 +338,7 @@ func TestResourceManagerListening(t *testing.T) { scope.EXPECT().Done().Do(func() { close(serverDone) }) // The handshake will complete, but the server will immediately close the connection. - conn, err := cl.Dial(context.Background(), ln.Multiaddrs()[0], serverID) + conn, err := cl.Dial(context.Background(), ln.Multiaddr(), serverID) require.NoError(t, err) defer conn.Close() clientDone := make(chan struct{}) @@ -377,13 +377,13 @@ func TestConnectionGaterDialing(t *testing.T) { defer ln.Close() connGater.EXPECT().InterceptSecured(network.DirOutbound, serverID, gomock.Any()).Do(func(_ network.Direction, _ peer.ID, addrs network.ConnMultiaddrs) { - require.Equal(t, stripCertHashes(ln.Multiaddrs()[0]), addrs.RemoteMultiaddr()) + require.Equal(t, stripCertHashes(ln.Multiaddr()), addrs.RemoteMultiaddr()) }) _, key := newIdentity(t) cl, err := libp2pwebtransport.New(key, newConnManager(t), connGater, &network.NullResourceManager{}) require.NoError(t, err) defer cl.(io.Closer).Close() - _, err = cl.Dial(context.Background(), ln.Multiaddrs()[0], serverID) + _, err = cl.Dial(context.Background(), ln.Multiaddr(), serverID) require.EqualError(t, err, "secured connection gated") } @@ -401,15 +401,15 @@ func TestConnectionGaterInterceptAccept(t *testing.T) { defer ln.Close() connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) { - require.Equal(t, stripCertHashes(ln.Multiaddrs()[0]), addrs.LocalMultiaddr()) - require.NotEqual(t, stripCertHashes(ln.Multiaddrs()[0]), addrs.RemoteMultiaddr()) + require.Equal(t, stripCertHashes(ln.Multiaddr()), addrs.LocalMultiaddr()) + require.NotEqual(t, stripCertHashes(ln.Multiaddr()), addrs.RemoteMultiaddr()) }) _, key := newIdentity(t) cl, err := libp2pwebtransport.New(key, newConnManager(t), nil, &network.NullResourceManager{}) require.NoError(t, err) defer cl.(io.Closer).Close() - _, err = cl.Dial(context.Background(), ln.Multiaddrs()[0], serverID) + _, err = cl.Dial(context.Background(), ln.Multiaddr(), serverID) require.EqualError(t, err, "received status 403") } @@ -433,11 +433,11 @@ func TestConnectionGaterInterceptSecured(t *testing.T) { connGater.EXPECT().InterceptAccept(gomock.Any()).Return(true) connGater.EXPECT().InterceptSecured(network.DirInbound, clientID, gomock.Any()).Do(func(_ network.Direction, _ peer.ID, addrs network.ConnMultiaddrs) { - require.Equal(t, stripCertHashes(ln.Multiaddrs()[0]), addrs.LocalMultiaddr()) - require.NotEqual(t, stripCertHashes(ln.Multiaddrs()[0]), addrs.RemoteMultiaddr()) + require.Equal(t, stripCertHashes(ln.Multiaddr()), addrs.LocalMultiaddr()) + require.NotEqual(t, stripCertHashes(ln.Multiaddr()), addrs.RemoteMultiaddr()) }) // The handshake will complete, but the server will immediately close the connection. - conn, err := cl.Dial(context.Background(), ln.Multiaddrs()[0], serverID) + conn, err := cl.Dial(context.Background(), ln.Multiaddr(), serverID) require.NoError(t, err) defer conn.Close() done := make(chan struct{}) @@ -491,7 +491,7 @@ func TestStaticTLSConf(t *testing.T) { ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport")) require.NoError(t, err) defer ln.Close() - require.Empty(t, extractCertHashes(ln.Multiaddrs()[0]), "listener address shouldn't contain any certhash") + require.Empty(t, extractCertHashes(ln.Multiaddr()), "listener address shouldn't contain any certhash") t.Run("fails when the certificate is invalid", func(t *testing.T) { _, key := newIdentity(t) @@ -499,7 +499,7 @@ func TestStaticTLSConf(t *testing.T) { require.NoError(t, err) defer cl.(io.Closer).Close() - _, err = cl.Dial(context.Background(), ln.Multiaddrs()[0], serverID) + _, err = cl.Dial(context.Background(), ln.Multiaddr(), serverID) require.Error(t, err) if !strings.Contains(err.Error(), "certificate is not trusted") && !strings.Contains(err.Error(), "certificate signed by unknown authority") { @@ -513,7 +513,7 @@ func TestStaticTLSConf(t *testing.T) { require.NoError(t, err) defer cl.(io.Closer).Close() - addr := ln.Multiaddrs()[0].Encapsulate(getCerthashComponent(t, []byte("foo"))) + addr := ln.Multiaddr().Encapsulate(getCerthashComponent(t, []byte("foo"))) _, err = cl.Dial(context.Background(), addr, serverID) require.Error(t, err) require.Contains(t, err.Error(), "cert hash not found") @@ -528,8 +528,8 @@ func TestStaticTLSConf(t *testing.T) { require.NoError(t, err) defer cl.(io.Closer).Close() - require.True(t, cl.CanDial(ln.Multiaddrs()[0])) - conn, err := cl.Dial(context.Background(), ln.Multiaddrs()[0], serverID) + require.True(t, cl.CanDial(ln.Multiaddr())) + conn, err := cl.Dial(context.Background(), ln.Multiaddr(), serverID) require.NoError(t, err) defer conn.Close() }) @@ -550,7 +550,7 @@ func TestAcceptQueueFilledUp(t *testing.T) { cl, err := libp2pwebtransport.New(key, newConnManager(t), nil, &network.NullResourceManager{}) require.NoError(t, err) defer cl.(io.Closer).Close() - return cl.Dial(context.Background(), ln.Multiaddrs()[0], serverID) + return cl.Dial(context.Background(), ln.Multiaddr(), serverID) } for i := 0; i < 16; i++ { @@ -589,7 +589,7 @@ func TestSNIIsSent(t *testing.T) { require.NoError(t, err) defer tr.(io.Closer).Close() - beforeQuicMa, withQuicMa := ma.SplitFunc(ln1.Multiaddrs()[0], func(c ma.Component) bool { + beforeQuicMa, withQuicMa := ma.SplitFunc(ln1.Multiaddr(), func(c ma.Component) bool { return c.Protocol().Code == ma.P_QUIC_V1 }) @@ -675,7 +675,7 @@ func TestFlowControlWindowIncrease(t *testing.T) { defer tr2.(io.Closer).Close() var addr ma.Multiaddr - for _, comp := range ma.Split(ln.Multiaddrs()[0]) { + for _, comp := range ma.Split(ln.Multiaddr()) { if _, err := comp.ValueForProtocol(ma.P_UDP); err == nil { addr = addr.Encapsulate(ma.StringCast(fmt.Sprintf("/udp/%d", proxy.LocalPort()))) continue @@ -842,7 +842,7 @@ func TestServerRotatesCertCorrectly(t *testing.T) { if err != nil { return false } - certhashes := extractCertHashes(onlyWebtransportmultiaddr(t, l.Multiaddrs())) + certhashes := extractCertHashes(l.Multiaddr()) l.Close() // These two certificates together are valid for at most certValidity - (4*clockSkewAllowance) @@ -859,8 +859,7 @@ func TestServerRotatesCertCorrectly(t *testing.T) { defer l.Close() var found bool - addrs := onlyWebtransportmultiaddr(t, l.Multiaddrs()) - ma.ForEach(addrs, func(c ma.Component) bool { + ma.ForEach(l.Multiaddr(), func(c ma.Component) bool { if c.Protocol().Code == ma.P_CERTHASH { for _, prevCerthash := range certhashes { if c.Value() == prevCerthash { @@ -877,15 +876,6 @@ func TestServerRotatesCertCorrectly(t *testing.T) { }, nil)) } -func onlyWebtransportmultiaddr(t testing.TB, addrs []ma.Multiaddr) ma.Multiaddr { - addrs = ma.FilterAddrs(addrs, func(m ma.Multiaddr) bool { - _, err := m.ValueForProtocol(ma.P_WEBTRANSPORT) - return err == nil - }) - require.NotEmpty(t, addrs) - return addrs[0] -} - func TestServerRotatesCertCorrectlyAfterSteps(t *testing.T) { cl := clock.NewMock() // Move one year ahead to avoid edge cases around epoch @@ -899,7 +889,7 @@ func TestServerRotatesCertCorrectlyAfterSteps(t *testing.T) { l, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport")) require.NoError(t, err) - certhashes := extractCertHashes(onlyWebtransportmultiaddr(t, l.Multiaddrs())) + certhashes := extractCertHashes(l.Multiaddr()) l.Close() // Traverse various time boundaries and make sure we always keep a common certhash. @@ -912,8 +902,7 @@ func TestServerRotatesCertCorrectlyAfterSteps(t *testing.T) { require.NoError(t, err) var found bool - addrs := onlyWebtransportmultiaddr(t, l.Multiaddrs()) - ma.ForEach(addrs, func(c ma.Component) bool { + ma.ForEach(l.Multiaddr(), func(c ma.Component) bool { if c.Protocol().Code == ma.P_CERTHASH { for _, prevCerthash := range certhashes { if prevCerthash == c.Value() { @@ -924,7 +913,7 @@ func TestServerRotatesCertCorrectlyAfterSteps(t *testing.T) { } return true }) - certhashes = extractCertHashes(onlyWebtransportmultiaddr(t, l.Multiaddrs())) + certhashes = extractCertHashes(l.Multiaddr()) l.Close() require.True(t, found, "Failed after hour: %v", i)