From 2f564f628ee59037ff8535a85fbeac83a0cb7d0d Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 25 Jul 2021 18:05:29 +0200 Subject: [PATCH 1/8] migrate the transport tests away from Ginkgo --- go.mod | 1 + transport_test.go | 126 +++++++++++++++++++++++++++------------------- 2 files changed, 75 insertions(+), 52 deletions(-) diff --git a/go.mod b/go.mod index d2c50b0..9907fba 100644 --- a/go.mod +++ b/go.mod @@ -17,6 +17,7 @@ require ( github.com/onsi/ginkgo v1.16.4 github.com/onsi/gomega v1.13.0 github.com/prometheus/client_golang v1.9.0 + github.com/stretchr/testify v1.7.0 golang.org/x/crypto v0.0.0-20210506145944-38f3c27a63bf golang.org/x/sync v0.0.0-20210220032951-036812b2e83c ) diff --git a/transport_test.go b/transport_test.go index 5f52ce0..6ea09b6 100644 --- a/transport_test.go +++ b/transport_test.go @@ -9,69 +9,91 @@ import ( "errors" "io" "net" + "testing" + + "github.com/stretchr/testify/require" ic "github.com/libp2p/go-libp2p-core/crypto" tpt "github.com/libp2p/go-libp2p-core/transport" - "github.com/lucas-clemente/quic-go" ma "github.com/multiformats/go-multiaddr" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" + "github.com/lucas-clemente/quic-go" ) -var _ = Describe("Transport", func() { - var t tpt.Transport +func getTransport(t *testing.T) tpt.Transport { + t.Helper() + rsaKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + key, err := ic.UnmarshalRsaPrivateKey(x509.MarshalPKCS1PrivateKey(rsaKey)) + require.NoError(t, err) + tr, err := NewTransport(key, nil, nil) + require.NoError(t, err) + return tr +} - BeforeEach(func() { - rsaKey, err := rsa.GenerateKey(rand.Reader, 2048) - Expect(err).ToNot(HaveOccurred()) - key, err := ic.UnmarshalRsaPrivateKey(x509.MarshalPKCS1PrivateKey(rsaKey)) - Expect(err).ToNot(HaveOccurred()) - t, err = NewTransport(key, nil, nil) - Expect(err).ToNot(HaveOccurred()) - }) +func TestQUICProtocol(t *testing.T) { + tr := getTransport(t) + defer tr.(io.Closer).Close() - AfterEach(func() { - Expect(t.(io.Closer).Close()).To(Succeed()) - }) + protocols := tr.Protocols() + if len(protocols) != 1 { + t.Fatalf("expected to only support a single protocol, got %v", protocols) + } + if protocols[0] != ma.P_QUIC { + t.Fatalf("expected the supported protocol to be QUIC, got %d", protocols[0]) + } +} - It("says if it can dial an address", func() { - invalidAddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/1234") - Expect(err).ToNot(HaveOccurred()) - validAddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/1234/quic") - Expect(err).ToNot(HaveOccurred()) - Expect(t.CanDial(invalidAddr)).To(BeFalse()) - Expect(t.CanDial(validAddr)).To(BeTrue()) - }) +func TestCanDial(t *testing.T) { + tr := getTransport(t) + defer tr.(io.Closer).Close() - It("says that it cannot dial /dns addresses", func() { - addr, err := ma.NewMultiaddr("/dns/google.com/udp/443/quic") - Expect(err).ToNot(HaveOccurred()) - Expect(t.CanDial(addr)).To(BeFalse()) - }) + invalid := []string{ + "/ip4/127.0.0.1/udp/1234", + "/ip4/5.5.5.5/tcp/1234", + "/dns/google.com/udp/443/quic", + } + valid := []string{ + "/ip4/127.0.0.1/udp/1234/quic", + "/ip4/5.5.5.5/udp/0/quic", + } + for _, s := range invalid { + invalidAddr, err := ma.NewMultiaddr(s) + require.NoError(t, err) + if tr.CanDial(invalidAddr) { + t.Errorf("didn't expect to be able to dial a non-quic address (%s)", invalidAddr) + } + } + for _, s := range valid { + validAddr, err := ma.NewMultiaddr(s) + require.NoError(t, err) + if !tr.CanDial(validAddr) { + t.Errorf("expected to be able to dial QUIC address (%s)", validAddr) + } + } +} - It("supports the QUIC protocol", func() { - protocols := t.Protocols() - Expect(protocols).To(HaveLen(1)) - Expect(protocols[0]).To(Equal(ma.P_QUIC)) - }) +// The connection passed to quic-go needs to be type-assertable to a net.UDPConn, +// in order to enable features like batch processing and ECN. +func TestConnectionPassedToQUIC(t *testing.T) { + tr := getTransport(t) + defer tr.(io.Closer).Close() - It("uses a conn that can interface assert to a UDPConn for dialing", func() { - origQuicDialContext := quicDialContext - defer func() { quicDialContext = origQuicDialContext }() + origQuicDialContext := quicDialContext + defer func() { quicDialContext = origQuicDialContext }() - var conn net.PacketConn - quicDialContext = func(_ context.Context, c net.PacketConn, _ net.Addr, _ string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { - conn = c - return nil, errors.New("listen error") - } - remoteAddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic") - Expect(err).ToNot(HaveOccurred()) - _, err = t.Dial(context.Background(), remoteAddr, "remote peer id") - Expect(err).To(MatchError("listen error")) - Expect(conn).ToNot(BeNil()) - defer conn.Close() - _, ok := conn.(udpConn) - Expect(ok).To(BeTrue()) - }) -}) + var conn net.PacketConn + quicDialContext = func(_ context.Context, c net.PacketConn, _ net.Addr, _ string, _ *tls.Config, _ *quic.Config) (quic.Session, error) { + conn = c + return nil, errors.New("listen error") + } + remoteAddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic") + require.NoError(t, err) + _, err = tr.Dial(context.Background(), remoteAddr, "remote peer id") + require.EqualError(t, err, "listen error") + require.NotNil(t, conn) + defer conn.Close() + if _, ok := conn.(udpConn); !ok { + t.Fatal("connection passed to quic-go cannot be type asserted to a *net.UDPConn") + } +} From f13c30d63e64c17abf9f23feef0b34e9ff810bc0 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 3 Jan 2022 17:26:32 +0400 Subject: [PATCH 2/8] migrate the conn tests away from Ginkgo --- conn_test.go | 688 ++++++++++++++++++++------------------- libp2pquic_suite_test.go | 6 - 2 files changed, 358 insertions(+), 336 deletions(-) diff --git a/conn_test.go b/conn_test.go index 82e2d62..8c69b55 100644 --- a/conn_test.go +++ b/conn_test.go @@ -10,6 +10,7 @@ import ( mrand "math/rand" "net" "sync/atomic" + "testing" "time" ic "github.com/libp2p/go-libp2p-core/crypto" @@ -20,406 +21,433 @@ import ( ma "github.com/multiformats/go-multiaddr" "github.com/golang/mock/gomock" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" + "github.com/stretchr/testify/require" ) //go:generate sh -c "mockgen -package libp2pquic -destination mock_connection_gater_test.go github.com/libp2p/go-libp2p-core/connmgr ConnectionGater && goimports -w mock_connection_gater_test.go" -var _ = Describe("Connection", func() { - var ( - serverKey, clientKey ic.PrivKey - serverID, clientID peer.ID - ) - - createPeer := func() (peer.ID, ic.PrivKey) { - var priv ic.PrivKey - var err error - switch mrand.Int() % 4 { - case 0: - fmt.Fprintf(GinkgoWriter, " using an ECDSA key: ") - priv, _, err = ic.GenerateECDSAKeyPair(rand.Reader) - case 1: - fmt.Fprintf(GinkgoWriter, " using an RSA key: ") - priv, _, err = ic.GenerateRSAKeyPair(2048, rand.Reader) - case 2: - fmt.Fprintf(GinkgoWriter, " using an Ed25519 key: ") - priv, _, err = ic.GenerateEd25519Key(rand.Reader) - case 3: - fmt.Fprintf(GinkgoWriter, " using an secp256k1 key: ") - priv, _, err = ic.GenerateSecp256k1Key(rand.Reader) - } - Expect(err).ToNot(HaveOccurred()) - id, err := peer.IDFromPrivateKey(priv) - Expect(err).ToNot(HaveOccurred()) - fmt.Fprintln(GinkgoWriter, id.Pretty()) - return id, priv - } - runServer := func(tr tpt.Transport, multiaddr string) tpt.Listener { - addr, err := ma.NewMultiaddr(multiaddr) - ExpectWithOffset(1, err).ToNot(HaveOccurred()) - ln, err := tr.Listen(addr) - ExpectWithOffset(1, err).ToNot(HaveOccurred()) - return ln +func createPeer(t *testing.T) (peer.ID, ic.PrivKey) { + var priv ic.PrivKey + var err error + switch mrand.Int() % 4 { + case 0: + priv, _, err = ic.GenerateECDSAKeyPair(rand.Reader) + case 1: + priv, _, err = ic.GenerateRSAKeyPair(2048, rand.Reader) + case 2: + priv, _, err = ic.GenerateEd25519Key(rand.Reader) + case 3: + priv, _, err = ic.GenerateSecp256k1Key(rand.Reader) } - - BeforeEach(func() { - serverID, serverKey = createPeer() - clientID, clientKey = createPeer() - }) - - It("handshakes on IPv4", func() { - serverTransport, err := NewTransport(serverKey, nil, nil) - Expect(err).ToNot(HaveOccurred()) - defer serverTransport.(io.Closer).Close() - ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic") - defer ln.Close() - + require.NoError(t, err) + id, err := peer.IDFromPrivateKey(priv) + require.NoError(t, err) + t.Logf("using a %s key: %s", priv.Type(), id.Pretty()) + return id, priv +} + +func runServer(t *testing.T, tr tpt.Transport, multiaddr string) tpt.Listener { + t.Helper() + addr, err := ma.NewMultiaddr(multiaddr) + require.NoError(t, err) + ln, err := tr.Listen(addr) + require.NoError(t, err) + return ln +} + +func TestHandshake(t *testing.T) { + serverID, serverKey := createPeer(t) + clientID, clientKey := createPeer(t) + serverTransport, err := NewTransport(serverKey, nil, nil) + require.NoError(t, err) + defer serverTransport.(io.Closer).Close() + + handshake := func(t *testing.T, ln tpt.Listener) { clientTransport, err := NewTransport(clientKey, nil, nil) - Expect(err).ToNot(HaveOccurred()) + require.NoError(t, err) defer clientTransport.(io.Closer).Close() conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) - Expect(err).ToNot(HaveOccurred()) + require.NoError(t, err) defer conn.Close() serverConn, err := ln.Accept() - Expect(err).ToNot(HaveOccurred()) + require.NoError(t, err) defer serverConn.Close() - Expect(conn.LocalPeer()).To(Equal(clientID)) - Expect(conn.LocalPrivateKey()).To(Equal(clientKey)) - Expect(conn.RemotePeer()).To(Equal(serverID)) - Expect(conn.RemotePublicKey().Equals(serverKey.GetPublic())).To(BeTrue()) - Expect(serverConn.LocalPeer()).To(Equal(serverID)) - Expect(serverConn.LocalPrivateKey()).To(Equal(serverKey)) - Expect(serverConn.RemotePeer()).To(Equal(clientID)) - Expect(serverConn.RemotePublicKey().Equals(clientKey.GetPublic())).To(BeTrue()) - }) - It("handshakes on IPv6", func() { - serverTransport, err := NewTransport(serverKey, nil, nil) - Expect(err).ToNot(HaveOccurred()) - defer serverTransport.(io.Closer).Close() - ln := runServer(serverTransport, "/ip6/::1/udp/0/quic") - defer ln.Close() + require.Equal(t, conn.LocalPeer(), clientID) + require.True(t, conn.LocalPrivateKey().Equals(clientKey), "local private key doesn't match") + require.Equal(t, conn.RemotePeer(), serverID) + require.True(t, conn.RemotePublicKey().Equals(serverKey.GetPublic()), "remote public key doesn't match") - clientTransport, err := NewTransport(clientKey, nil, nil) - Expect(err).ToNot(HaveOccurred()) - defer clientTransport.(io.Closer).Close() - conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) - Expect(err).ToNot(HaveOccurred()) - defer conn.Close() - serverConn, err := ln.Accept() - Expect(err).ToNot(HaveOccurred()) - defer serverConn.Close() - Expect(conn.LocalPeer()).To(Equal(clientID)) - Expect(conn.LocalPrivateKey()).To(Equal(clientKey)) - Expect(conn.RemotePeer()).To(Equal(serverID)) - Expect(conn.RemotePublicKey().Equals(serverKey.GetPublic())).To(BeTrue()) - Expect(serverConn.LocalPeer()).To(Equal(serverID)) - Expect(serverConn.LocalPrivateKey()).To(Equal(serverKey)) - Expect(serverConn.RemotePeer()).To(Equal(clientID)) - Expect(serverConn.RemotePublicKey().Equals(clientKey.GetPublic())).To(BeTrue()) - }) + require.Equal(t, serverConn.LocalPeer(), serverID) + require.True(t, serverConn.LocalPrivateKey().Equals(serverKey), "local private key doesn't match") + require.Equal(t, serverConn.RemotePeer(), clientID) + require.True(t, serverConn.RemotePublicKey().Equals(clientKey.GetPublic()), "remote public key doesn't match") + } - It("opens and accepts streams", func() { - serverTransport, err := NewTransport(serverKey, nil, nil) - Expect(err).ToNot(HaveOccurred()) - defer serverTransport.(io.Closer).Close() - ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic") + t.Run("on IPv4", func(t *testing.T) { + ln := runServer(t, serverTransport, "/ip4/127.0.0.1/udp/0/quic") defer ln.Close() - - clientTransport, err := NewTransport(clientKey, nil, nil) - Expect(err).ToNot(HaveOccurred()) - defer clientTransport.(io.Closer).Close() - conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) - Expect(err).ToNot(HaveOccurred()) - defer conn.Close() - serverConn, err := ln.Accept() - Expect(err).ToNot(HaveOccurred()) - defer serverConn.Close() - - str, err := conn.OpenStream(context.Background()) - Expect(err).ToNot(HaveOccurred()) - _, err = str.Write([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - str.Close() - sstr, err := serverConn.AcceptStream() - Expect(err).ToNot(HaveOccurred()) - data, err := ioutil.ReadAll(sstr) - Expect(err).ToNot(HaveOccurred()) - Expect(data).To(Equal([]byte("foobar"))) + handshake(t, ln) }) - It("fails if the peer ID doesn't match", func() { - thirdPartyID, _ := createPeer() + t.Run("on IPv6", func(t *testing.T) { + ln := runServer(t, serverTransport, "/ip6/::1/udp/0/quic") + defer ln.Close() + handshake(t, ln) + }) +} + +func TestStreams(t *testing.T) { + serverID, serverKey := createPeer(t) + _, clientKey := createPeer(t) + + serverTransport, err := NewTransport(serverKey, nil, nil) + require.NoError(t, err) + defer serverTransport.(io.Closer).Close() + ln := runServer(t, serverTransport, "/ip4/127.0.0.1/udp/0/quic") + defer ln.Close() + + clientTransport, err := NewTransport(clientKey, nil, nil) + require.NoError(t, err) + defer clientTransport.(io.Closer).Close() + conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) + require.NoError(t, err) + defer conn.Close() + serverConn, err := ln.Accept() + require.NoError(t, err) + defer serverConn.Close() + + str, err := conn.OpenStream(context.Background()) + require.NoError(t, err) + _, err = str.Write([]byte("foobar")) + require.NoError(t, err) + str.Close() + sstr, err := serverConn.AcceptStream() + require.NoError(t, err) + data, err := ioutil.ReadAll(sstr) + require.NoError(t, err) + require.Equal(t, data, []byte("foobar")) +} + +func TestHandshakeFailPeerIDMismatch(t *testing.T) { + _, serverKey := createPeer(t) + _, clientKey := createPeer(t) + thirdPartyID, _ := createPeer(t) + + serverTransport, err := NewTransport(serverKey, nil, nil) + require.NoError(t, err) + defer serverTransport.(io.Closer).Close() + ln := runServer(t, serverTransport, "/ip4/127.0.0.1/udp/0/quic") + + clientTransport, err := NewTransport(clientKey, nil, nil) + require.NoError(t, err) + // dial, but expect the wrong peer ID + _, err = clientTransport.Dial(context.Background(), ln.Multiaddr(), thirdPartyID) + require.Error(t, err) + require.Contains(t, err.Error(), "CRYPTO_ERROR") + defer clientTransport.(io.Closer).Close() + + acceptErr := make(chan error) + go func() { + _, err := ln.Accept() + acceptErr <- err + }() + + select { + case <-acceptErr: + t.Fatal("didn't expect Accept to return before being closed") + case <-time.After(100 * time.Millisecond): + } - serverTransport, err := NewTransport(serverKey, nil, nil) - Expect(err).ToNot(HaveOccurred()) - defer serverTransport.(io.Closer).Close() - ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic") + require.NoError(t, ln.Close()) + require.Error(t, <-acceptErr) +} - clientTransport, err := NewTransport(clientKey, nil, nil) - Expect(err).ToNot(HaveOccurred()) - // dial, but expect the wrong peer ID - _, err = clientTransport.Dial(context.Background(), ln.Multiaddr(), thirdPartyID) - Expect(err).To(HaveOccurred()) - defer clientTransport.(io.Closer).Close() - Expect(err.Error()).To(ContainSubstring("CRYPTO_ERROR")) +func TestConnectionGating(t *testing.T) { + serverID, serverKey := createPeer(t) + _, clientKey := createPeer(t) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(done) - ln.Accept() - }() - Consistently(done).ShouldNot(BeClosed()) - ln.Close() - Eventually(done).Should(BeClosed()) - }) + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + cg := NewMockConnectionGater(mockCtrl) - It("gates accepted connections", func() { - cg := NewMockConnectionGater(mockCtrl) - cg.EXPECT().InterceptAccept(gomock.Any()) + t.Run("accepted connections", func(t *testing.T) { serverTransport, err := NewTransport(serverKey, nil, cg) - Expect(err).ToNot(HaveOccurred()) defer serverTransport.(io.Closer).Close() - ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic") + require.NoError(t, err) + ln := runServer(t, serverTransport, "/ip4/127.0.0.1/udp/0/quic") defer ln.Close() + cg.EXPECT().InterceptAccept(gomock.Any()) + accepted := make(chan struct{}) go func() { - defer GinkgoRecover() defer close(accepted) _, err := ln.Accept() - Expect(err).ToNot(HaveOccurred()) + require.NoError(t, err) }() clientTransport, err := NewTransport(clientKey, nil, nil) - Expect(err).ToNot(HaveOccurred()) + require.NoError(t, err) defer clientTransport.(io.Closer).Close() // make sure that connection attempts fails conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) - Expect(err).ToNot(HaveOccurred()) + require.NoError(t, err) _, err = conn.AcceptStream() - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("connection gated")) + require.Error(t, err) + require.Contains(t, err.Error(), "connection gated") // now allow the address and make sure the connection goes through cg.EXPECT().InterceptAccept(gomock.Any()).Return(true) cg.EXPECT().InterceptSecured(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) clientTransport.(*transport).clientConfig.HandshakeIdleTimeout = 2 * time.Second conn, err = clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) - Expect(err).ToNot(HaveOccurred()) + require.NoError(t, err) defer conn.Close() - Eventually(accepted).Should(BeClosed()) + require.Eventually(t, func() bool { + select { + case <-accepted: + return true + default: + return false + } + }, time.Second, 10*time.Millisecond) }) - It("gates secured connections", func() { + t.Run("secured connections", func(t *testing.T) { serverTransport, err := NewTransport(serverKey, nil, nil) - Expect(err).ToNot(HaveOccurred()) + require.NoError(t, err) defer serverTransport.(io.Closer).Close() - ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic") + ln := runServer(t, serverTransport, "/ip4/127.0.0.1/udp/0/quic") defer ln.Close() cg := NewMockConnectionGater(mockCtrl) cg.EXPECT().InterceptSecured(gomock.Any(), gomock.Any(), gomock.Any()) clientTransport, err := NewTransport(clientKey, nil, cg) - Expect(err).ToNot(HaveOccurred()) + require.NoError(t, err) defer clientTransport.(io.Closer).Close() // make sure that connection attempts fails _, err = clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("connection gated")) + require.Error(t, err) + require.Contains(t, err.Error(), "connection gated") // now allow the peerId and make sure the connection goes through cg.EXPECT().InterceptSecured(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) clientTransport.(*transport).clientConfig.HandshakeIdleTimeout = 2 * time.Second conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) - Expect(err).ToNot(HaveOccurred()) + require.NoError(t, err) conn.Close() }) - - It("dials to two servers at the same time", func() { - serverID2, serverKey2 := createPeer() - - serverTransport, err := NewTransport(serverKey, nil, nil) - Expect(err).ToNot(HaveOccurred()) - defer serverTransport.(io.Closer).Close() - ln1 := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic") - defer ln1.Close() - serverTransport2, err := NewTransport(serverKey2, nil, nil) - Expect(err).ToNot(HaveOccurred()) - defer serverTransport2.(io.Closer).Close() - ln2 := runServer(serverTransport2, "/ip4/127.0.0.1/udp/0/quic") - defer ln2.Close() - - data := bytes.Repeat([]byte{'a'}, 5*1<<20) // 5 MB - // wait for both servers to accept a connection - // then send some data - go func() { - serverConn1, err := ln1.Accept() - Expect(err).ToNot(HaveOccurred()) - serverConn2, err := ln2.Accept() - Expect(err).ToNot(HaveOccurred()) - - for _, c := range []tpt.CapableConn{serverConn1, serverConn2} { - go func(conn tpt.CapableConn) { - defer GinkgoRecover() - str, err := conn.OpenStream(context.Background()) - Expect(err).ToNot(HaveOccurred()) - defer str.Close() - _, err = str.Write(data) - Expect(err).ToNot(HaveOccurred()) - }(c) - } - }() - - clientTransport, err := NewTransport(clientKey, nil, nil) - Expect(err).ToNot(HaveOccurred()) - defer clientTransport.(io.Closer).Close() - c1, err := clientTransport.Dial(context.Background(), ln1.Multiaddr(), serverID) - Expect(err).ToNot(HaveOccurred()) - defer c1.Close() - c2, err := clientTransport.Dial(context.Background(), ln2.Multiaddr(), serverID2) - Expect(err).ToNot(HaveOccurred()) - defer c2.Close() - - done := make(chan struct{}, 2) - // receive the data on both connections at the same time - for _, c := range []tpt.CapableConn{c1, c2} { +} + +func TestDialTwo(t *testing.T) { + serverID, serverKey := createPeer(t) + _, clientKey := createPeer(t) + serverID2, serverKey2 := createPeer(t) + + serverTransport, err := NewTransport(serverKey, nil, nil) + require.NoError(t, err) + defer serverTransport.(io.Closer).Close() + ln1 := runServer(t, serverTransport, "/ip4/127.0.0.1/udp/0/quic") + defer ln1.Close() + serverTransport2, err := NewTransport(serverKey2, nil, nil) + require.NoError(t, err) + defer serverTransport2.(io.Closer).Close() + ln2 := runServer(t, serverTransport2, "/ip4/127.0.0.1/udp/0/quic") + defer ln2.Close() + + data := bytes.Repeat([]byte{'a'}, 5*1<<20) // 5 MB + // wait for both servers to accept a connection + // then send some data + go func() { + serverConn1, err := ln1.Accept() + require.NoError(t, err) + serverConn2, err := ln2.Accept() + require.NoError(t, err) + + for _, c := range []tpt.CapableConn{serverConn1, serverConn2} { go func(conn tpt.CapableConn) { - defer GinkgoRecover() - str, err := conn.AcceptStream() - Expect(err).ToNot(HaveOccurred()) - str.CloseWrite() - d, err := ioutil.ReadAll(str) - Expect(err).ToNot(HaveOccurred()) - Expect(d).To(Equal(data)) - done <- struct{}{} + str, err := conn.OpenStream(context.Background()) + require.NoError(t, err) + defer str.Close() + _, err = str.Write(data) + require.NoError(t, err) }(c) } + }() + + clientTransport, err := NewTransport(clientKey, nil, nil) + require.NoError(t, err) + defer clientTransport.(io.Closer).Close() + c1, err := clientTransport.Dial(context.Background(), ln1.Multiaddr(), serverID) + require.NoError(t, err) + defer c1.Close() + c2, err := clientTransport.Dial(context.Background(), ln2.Multiaddr(), serverID2) + require.NoError(t, err) + defer c2.Close() + + done := make(chan struct{}, 2) + // receive the data on both connections at the same time + for _, c := range []tpt.CapableConn{c1, c2} { + go func(conn tpt.CapableConn) { + str, err := conn.AcceptStream() + require.NoError(t, err) + str.CloseWrite() + d, err := ioutil.ReadAll(str) + require.NoError(t, err) + require.Equal(t, d, data) + done <- struct{}{} + }(c) + } - Eventually(done, 15*time.Second).Should(Receive()) - Eventually(done, 15*time.Second).Should(Receive()) - }) + for i := 0; i < 2; i++ { + require.Eventually(t, func() bool { + select { + case <-done: + return true + default: + return false + } + }, 15*time.Second, 50*time.Millisecond) + } +} - It("sends stateless resets", func() { - serverTransport, err := NewTransport(serverKey, nil, nil) - Expect(err).ToNot(HaveOccurred()) - defer serverTransport.(io.Closer).Close() - ln := runServer(serverTransport, "/ip4/127.0.0.1/udp/0/quic") - - var drop uint32 - serverPort := ln.Addr().(*net.UDPAddr).Port - proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ - RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), - DropPacket: func(quicproxy.Direction, []byte) bool { - return atomic.LoadUint32(&drop) > 0 - }, - }) - Expect(err).ToNot(HaveOccurred()) - defer proxy.Close() - - // establish a connection - clientTransport, err := NewTransport(clientKey, nil, nil) - Expect(err).ToNot(HaveOccurred()) - defer clientTransport.(io.Closer).Close() - proxyAddr, err := toQuicMultiaddr(proxy.LocalAddr()) - Expect(err).ToNot(HaveOccurred()) - conn, err := clientTransport.Dial(context.Background(), proxyAddr, serverID) - Expect(err).ToNot(HaveOccurred()) - go func() { - defer GinkgoRecover() - conn, err := ln.Accept() - Expect(err).ToNot(HaveOccurred()) - str, err := conn.OpenStream(context.Background()) - Expect(err).ToNot(HaveOccurred()) - str.Write([]byte("foobar")) - }() +func TestStatelessReset(t *testing.T) { + origGarbageCollectInterval := garbageCollectInterval + origMaxUnusedDuration := maxUnusedDuration - str, err := conn.AcceptStream() - Expect(err).ToNot(HaveOccurred()) - _, err = str.Read(make([]byte, 6)) - Expect(err).ToNot(HaveOccurred()) - - // Stop forwarding packets and close the server. - // This prevents the CONNECTION_CLOSE from reaching the client. - atomic.StoreUint32(&drop, 1) - Expect(ln.Close()).To(Succeed()) - time.Sleep(100 * time.Millisecond) // give the kernel some time to free the UDP port - ln = runServer(serverTransport, fmt.Sprintf("/ip4/127.0.0.1/udp/%d/quic", serverPort)) - defer ln.Close() - // Now that the new server is up, re-enable packet forwarding. - atomic.StoreUint32(&drop, 0) - - // Trigger something (not too small) to be sent, so that we receive the stateless reset. - // The new server doesn't have any state for the previously established connection. - // We expect it to send a stateless reset. - _, rerr := str.Write([]byte("Lorem ipsum dolor sit amet.")) - if rerr == nil { - _, rerr = str.Read([]byte{0, 0}) - } - Expect(rerr).To(HaveOccurred()) - Expect(rerr.Error()).To(ContainSubstring("received a stateless reset")) - }) + garbageCollectInterval = 50 * time.Millisecond + maxUnusedDuration = 0 - It("hole punches", func() { - t1, err := NewTransport(serverKey, nil, nil) - Expect(err).ToNot(HaveOccurred()) - defer t1.(io.Closer).Close() - laddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic") - Expect(err).ToNot(HaveOccurred()) - ln1, err := t1.Listen(laddr) - Expect(err).ToNot(HaveOccurred()) - done1 := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(done1) - if _, err := ln1.Accept(); err == nil { - Fail("didn't expect to accept any connections") - } - }() + t.Cleanup(func() { + garbageCollectInterval = origGarbageCollectInterval + maxUnusedDuration = origMaxUnusedDuration + }) - t2, err := NewTransport(clientKey, nil, nil) - Expect(err).ToNot(HaveOccurred()) - defer t2.(io.Closer).Close() - ln2, err := t2.Listen(laddr) - Expect(err).ToNot(HaveOccurred()) - done2 := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(done2) - if _, err := ln2.Accept(); err == nil { - Fail("didn't expect to accept any connections") - } - }() - connChan := make(chan tpt.CapableConn) - go func() { - defer GinkgoRecover() - conn, err := t2.Dial( - n.WithSimultaneousConnect(context.Background(), false, ""), - ln1.Multiaddr(), - serverID, - ) - Expect(err).ToNot(HaveOccurred()) - connChan <- conn - }() - conn1, err := t1.Dial( - n.WithSimultaneousConnect(context.Background(), true, ""), - ln2.Multiaddr(), - clientID, - ) - Expect(err).ToNot(HaveOccurred()) - defer conn1.Close() - Expect(conn1.RemotePeer()).To(Equal(clientID)) - var conn2 tpt.CapableConn - Eventually(connChan).Should(Receive(&conn2)) - defer conn2.Close() - Expect(conn2.RemotePeer()).To(Equal(serverID)) - ln1.Close() - ln2.Close() - Eventually(done1).Should(BeClosed()) - Eventually(done2).Should(BeClosed()) + serverID, serverKey := createPeer(t) + _, clientKey := createPeer(t) + + serverTransport, err := NewTransport(serverKey, nil, nil) + require.NoError(t, err) + defer serverTransport.(io.Closer).Close() + ln := runServer(t, serverTransport, "/ip4/127.0.0.1/udp/0/quic") + + var drop uint32 + serverPort := ln.Addr().(*net.UDPAddr).Port + proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ + RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), + DropPacket: func(quicproxy.Direction, []byte) bool { + return atomic.LoadUint32(&drop) > 0 + }, }) -}) + require.NoError(t, err) + defer proxy.Close() + + // establish a connection + clientTransport, err := NewTransport(clientKey, nil, nil) + require.NoError(t, err) + defer clientTransport.(io.Closer).Close() + proxyAddr, err := toQuicMultiaddr(proxy.LocalAddr()) + require.NoError(t, err) + conn, err := clientTransport.Dial(context.Background(), proxyAddr, serverID) + require.NoError(t, err) + go func() { + conn, err := ln.Accept() + require.NoError(t, err) + str, err := conn.OpenStream(context.Background()) + require.NoError(t, err) + str.Write([]byte("foobar")) + }() + + str, err := conn.AcceptStream() + require.NoError(t, err) + _, err = str.Read(make([]byte, 6)) + require.NoError(t, err) + + // Stop forwarding packets and close the server. + // This prevents the CONNECTION_CLOSE from reaching the client. + atomic.StoreUint32(&drop, 1) + ln.Close() + // require.NoError(t, ln.Close()) + time.Sleep(2000 * time.Millisecond) // give the kernel some time to free the UDP port + ln = runServer(t, serverTransport, fmt.Sprintf("/ip4/127.0.0.1/udp/%d/quic", serverPort)) + defer ln.Close() + // Now that the new server is up, re-enable packet forwarding. + atomic.StoreUint32(&drop, 0) + + // Trigger something (not too small) to be sent, so that we receive the stateless reset. + // The new server doesn't have any state for the previously established connection. + // We expect it to send a stateless reset. + _, rerr := str.Write([]byte("Lorem ipsum dolor sit amet.")) + if rerr == nil { + _, rerr = str.Read([]byte{0, 0}) + } + require.Error(t, rerr) + require.Contains(t, rerr.Error(), "received a stateless reset") +} + +func TestHolePunching(t *testing.T) { + serverID, serverKey := createPeer(t) + clientID, clientKey := createPeer(t) + + t1, err := NewTransport(serverKey, nil, nil) + require.NoError(t, err) + defer t1.(io.Closer).Close() + laddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic") + require.NoError(t, err) + ln1, err := t1.Listen(laddr) + require.NoError(t, err) + done1 := make(chan struct{}) + go func() { + defer close(done1) + _, err := ln1.Accept() + require.Error(t, err, "didn't expect to accept any connections") + }() + + t2, err := NewTransport(clientKey, nil, nil) + require.NoError(t, err) + defer t2.(io.Closer).Close() + ln2, err := t2.Listen(laddr) + require.NoError(t, err) + done2 := make(chan struct{}) + go func() { + defer close(done2) + _, err := ln2.Accept() + require.Error(t, err, "didn't expect to accept any connections") + }() + connChan := make(chan tpt.CapableConn) + go func() { + conn, err := t2.Dial( + n.WithSimultaneousConnect(context.Background(), false, ""), + ln1.Multiaddr(), + serverID, + ) + require.NoError(t, err) + connChan <- conn + }() + conn1, err := t1.Dial( + n.WithSimultaneousConnect(context.Background(), true, ""), + ln2.Multiaddr(), + clientID, + ) + require.NoError(t, err) + defer conn1.Close() + require.Equal(t, conn1.RemotePeer(), clientID) + var conn2 tpt.CapableConn + require.Eventually(t, func() bool { + select { + case conn2 = <-connChan: + return true + default: + return false + } + }, 100*time.Millisecond, 10*time.Millisecond) + defer conn2.Close() + require.Equal(t, conn2.RemotePeer(), serverID) + ln1.Close() + ln2.Close() + <-done1 + <-done2 +} diff --git a/libp2pquic_suite_test.go b/libp2pquic_suite_test.go index 8bef230..5144ae3 100644 --- a/libp2pquic_suite_test.go +++ b/libp2pquic_suite_test.go @@ -5,7 +5,6 @@ import ( "testing" "time" - "github.com/golang/mock/gomock" "github.com/lucas-clemente/quic-go" . "github.com/onsi/ginkgo" @@ -25,12 +24,9 @@ var ( garbageCollectIntervalOrig time.Duration maxUnusedDurationOrig time.Duration origQuicConfig *quic.Config - mockCtrl *gomock.Controller ) var _ = BeforeEach(func() { - mockCtrl = gomock.NewController(GinkgoT()) - garbageCollectIntervalOrig = garbageCollectInterval maxUnusedDurationOrig = maxUnusedDuration garbageCollectInterval = 50 * time.Millisecond @@ -40,8 +36,6 @@ var _ = BeforeEach(func() { }) var _ = AfterEach(func() { - mockCtrl.Finish() - garbageCollectInterval = garbageCollectIntervalOrig maxUnusedDuration = maxUnusedDurationOrig quicConfig = origQuicConfig From d7838088ca0672f6d34c2b4a4b6a63e33fb90378 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 3 Jan 2022 17:57:23 +0400 Subject: [PATCH 3/8] use the quic.OOBCapablePacketConn in tests --- listener_test.go | 10 +--------- transport_test.go | 2 +- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/listener_test.go b/listener_test.go index f8d2041..3958778 100644 --- a/listener_test.go +++ b/listener_test.go @@ -9,7 +9,6 @@ import ( "fmt" "io" "net" - "syscall" ic "github.com/libp2p/go-libp2p-core/crypto" tpt "github.com/libp2p/go-libp2p-core/transport" @@ -20,13 +19,6 @@ import ( . "github.com/onsi/gomega" ) -// interface containing some methods defined on the net.UDPConn, but not the net.PacketConn -type udpConn interface { - ReadFromUDP(b []byte) (int, *net.UDPAddr, error) - SetReadBuffer(bytes int) error - SyscallConn() (syscall.RawConn, error) -} - var _ = Describe("Listener", func() { var t tpt.Transport @@ -58,7 +50,7 @@ var _ = Describe("Listener", func() { Expect(err).To(MatchError("listen error")) Expect(conn).ToNot(BeNil()) defer conn.Close() - _, ok := conn.(udpConn) + _, ok := conn.(quic.OOBCapablePacketConn) Expect(ok).To(BeTrue()) }) diff --git a/transport_test.go b/transport_test.go index 6ea09b6..72eb0c4 100644 --- a/transport_test.go +++ b/transport_test.go @@ -93,7 +93,7 @@ func TestConnectionPassedToQUIC(t *testing.T) { require.EqualError(t, err, "listen error") require.NotNil(t, conn) defer conn.Close() - if _, ok := conn.(udpConn); !ok { + if _, ok := conn.(quic.OOBCapablePacketConn); !ok { t.Fatal("connection passed to quic-go cannot be type asserted to a *net.UDPConn") } } From 1303b904568de427a7408b2d96692e20308808d0 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 3 Jan 2022 18:12:40 +0400 Subject: [PATCH 4/8] migrate the listener tests away from Ginkgo --- listener_test.go | 191 +++++++++++++++++++++-------------------------- 1 file changed, 87 insertions(+), 104 deletions(-) diff --git a/listener_test.go b/listener_test.go index 3958778..4de9da3 100644 --- a/listener_test.go +++ b/listener_test.go @@ -9,123 +9,106 @@ import ( "fmt" "io" "net" + "testing" + "time" ic "github.com/libp2p/go-libp2p-core/crypto" tpt "github.com/libp2p/go-libp2p-core/transport" - "github.com/lucas-clemente/quic-go" + "github.com/lucas-clemente/quic-go" ma "github.com/multiformats/go-multiaddr" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" + "github.com/stretchr/testify/require" ) -var _ = Describe("Listener", func() { - var t tpt.Transport - - BeforeEach(func() { - rsaKey, err := rsa.GenerateKey(rand.Reader, 2048) - Expect(err).ToNot(HaveOccurred()) - key, err := ic.UnmarshalRsaPrivateKey(x509.MarshalPKCS1PrivateKey(rsaKey)) - Expect(err).ToNot(HaveOccurred()) - t, err = NewTransport(key, nil, nil) - Expect(err).ToNot(HaveOccurred()) - }) - - AfterEach(func() { - Expect(t.(io.Closer).Close()).To(Succeed()) - }) +func newTransport(t *testing.T) tpt.Transport { + rsaKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + key, err := ic.UnmarshalRsaPrivateKey(x509.MarshalPKCS1PrivateKey(rsaKey)) + require.NoError(t, err) + tr, err := NewTransport(key, nil, nil) + require.NoError(t, err) + return tr +} - It("uses a conn that can interface assert to a UDPConn for listening", func() { - origQuicListen := quicListen - defer func() { quicListen = origQuicListen }() +// The conn passed to quic-go should be a conn that quic-go can be +// type-asserted to a UDPConn. That way, it can use all kinds of optimizations. +func TestConnUsedForListening(t *testing.T) { + origQuicListen := quicListen + t.Cleanup(func() { quicListen = origQuicListen }) - var conn net.PacketConn - quicListen = func(c net.PacketConn, _ *tls.Config, _ *quic.Config) (quic.Listener, error) { - conn = c - return nil, errors.New("listen error") - } - localAddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic") - Expect(err).ToNot(HaveOccurred()) - _, err = t.Listen(localAddr) - Expect(err).To(MatchError("listen error")) - Expect(conn).ToNot(BeNil()) - defer conn.Close() - _, ok := conn.(quic.OOBCapablePacketConn) - Expect(ok).To(BeTrue()) - }) + var conn net.PacketConn + quicListen = func(c net.PacketConn, _ *tls.Config, _ *quic.Config) (quic.Listener, error) { + conn = c + return nil, errors.New("listen error") + } + localAddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic") + require.NoError(t, err) - Context("listening on the right address", func() { - It("returns the address it is listening on", func() { - localAddr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic") - Expect(err).ToNot(HaveOccurred()) - ln, err := t.Listen(localAddr) - Expect(err).ToNot(HaveOccurred()) - defer ln.Close() - netAddr := ln.Addr() - Expect(netAddr).To(BeAssignableToTypeOf(&net.UDPAddr{})) - port := netAddr.(*net.UDPAddr).Port - Expect(port).ToNot(BeZero()) - Expect(ln.Multiaddr().String()).To(Equal(fmt.Sprintf("/ip4/127.0.0.1/udp/%d/quic", port))) - }) + tr := newTransport(t) + defer tr.(io.Closer).Close() + _, err = tr.Listen(localAddr) + require.EqualError(t, err, "listen error") + require.NotNil(t, conn) + defer conn.Close() + _, ok := conn.(quic.OOBCapablePacketConn) + require.True(t, ok) +} - It("returns the address it is listening on, for listening on IPv4", func() { - localAddr, err := ma.NewMultiaddr("/ip4/0.0.0.0/udp/0/quic") - Expect(err).ToNot(HaveOccurred()) - ln, err := t.Listen(localAddr) - Expect(err).ToNot(HaveOccurred()) - defer ln.Close() - netAddr := ln.Addr() - Expect(netAddr).To(BeAssignableToTypeOf(&net.UDPAddr{})) - port := netAddr.(*net.UDPAddr).Port - Expect(port).ToNot(BeZero()) - Expect(ln.Multiaddr().String()).To(Equal(fmt.Sprintf("/ip4/0.0.0.0/udp/%d/quic", port))) - }) +func TestListenAddr(t *testing.T) { + tr := newTransport(t) + defer tr.(io.Closer).Close() - It("returns the address it is listening on, for listening on IPv6", func() { - localAddr, err := ma.NewMultiaddr("/ip6/::/udp/0/quic") - Expect(err).ToNot(HaveOccurred()) - ln, err := t.Listen(localAddr) - Expect(err).ToNot(HaveOccurred()) - defer ln.Close() - netAddr := ln.Addr() - Expect(netAddr).To(BeAssignableToTypeOf(&net.UDPAddr{})) - port := netAddr.(*net.UDPAddr).Port - Expect(port).ToNot(BeZero()) - Expect(ln.Multiaddr().String()).To(Equal(fmt.Sprintf("/ip6/::/udp/%d/quic", port))) - }) + t.Run("for IPv4", func(t *testing.T) { + localAddr := ma.StringCast("/ip4/127.0.0.1/udp/0/quic") + ln, err := tr.Listen(localAddr) + require.NoError(t, err) + defer ln.Close() + port := ln.Addr().(*net.UDPAddr).Port + require.NotZero(t, port) + require.Equal(t, ln.Multiaddr().String(), fmt.Sprintf("/ip4/127.0.0.1/udp/%d/quic", port)) }) - Context("accepting connections", func() { - var localAddr ma.Multiaddr - - BeforeEach(func() { - var err error - localAddr, err = ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic") - Expect(err).ToNot(HaveOccurred()) - }) + t.Run("for IPv6", func(t *testing.T) { + localAddr := ma.StringCast("/ip6/::/udp/0/quic") + ln, err := tr.Listen(localAddr) + require.NoError(t, err) + defer ln.Close() + port := ln.Addr().(*net.UDPAddr).Port + require.NotZero(t, port) + require.Equal(t, ln.Multiaddr().String(), fmt.Sprintf("/ip6/::/udp/%d/quic", port)) + }) +} - It("returns Accept when it is closed", func() { - addr, err := ma.NewMultiaddr("/ip4/127.0.0.1/udp/0/quic") - Expect(err).ToNot(HaveOccurred()) - ln, err := t.Listen(addr) - Expect(err).ToNot(HaveOccurred()) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - ln.Accept() - close(done) - }() - Consistently(done).ShouldNot(BeClosed()) - Expect(ln.Close()).To(Succeed()) - Eventually(done).Should(BeClosed()) - }) +func TestAccepting(t *testing.T) { + tr := newTransport(t) + defer tr.(io.Closer).Close() + ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic")) + require.NoError(t, err) + done := make(chan struct{}) + go func() { + ln.Accept() + close(done) + }() + time.Sleep(100 * time.Millisecond) + select { + case <-done: + t.Fatal("Accept didn't block") + default: + } + require.NoError(t, ln.Close()) + select { + case <-done: + case <-time.After(100 * time.Millisecond): + t.Fatal("Accept didn't return after the listener was closed") + } +} - It("doesn't accept Accept calls after it is closed", func() { - ln, err := t.Listen(localAddr) - Expect(err).ToNot(HaveOccurred()) - Expect(ln.Close()).To(Succeed()) - _, err = ln.Accept() - Expect(err).To(HaveOccurred()) - }) - }) -}) +func TestAcceptAfterClose(t *testing.T) { + tr := newTransport(t) + defer tr.(io.Closer).Close() + ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic")) + require.NoError(t, err) + require.NoError(t, ln.Close()) + _, err = ln.Accept() + require.Error(t, err) +} From 53419a4a1472864f50f812d66dc8ee7e3fd02cbf Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 3 Jan 2022 18:30:29 +0400 Subject: [PATCH 5/8] migrate the multiaddr tests away from Ginkgo --- quic_multiaddr_test.go | 39 ++++++++++++++++++--------------------- 1 file changed, 18 insertions(+), 21 deletions(-) diff --git a/quic_multiaddr_test.go b/quic_multiaddr_test.go index 48949d3..db7cdb3 100644 --- a/quic_multiaddr_test.go +++ b/quic_multiaddr_test.go @@ -2,29 +2,26 @@ package libp2pquic import ( "net" + "testing" ma "github.com/multiformats/go-multiaddr" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" + "github.com/stretchr/testify/require" ) -var _ = Describe("QUIC Multiaddr", func() { - It("converts a net.Addr to a QUIC Multiaddr", func() { - addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 42), Port: 1337} - maddr, err := toQuicMultiaddr(addr) - Expect(err).ToNot(HaveOccurred()) - Expect(maddr.String()).To(Equal("/ip4/192.168.0.42/udp/1337/quic")) - }) +func TestConvertToQuicMultiaddr(t *testing.T) { + addr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 42), Port: 1337} + maddr, err := toQuicMultiaddr(addr) + require.NoError(t, err) + require.Equal(t, maddr.String(), "/ip4/192.168.0.42/udp/1337/quic") +} - It("converts a QUIC Multiaddr to a net.Addr", func() { - maddr, err := ma.NewMultiaddr("/ip4/192.168.0.42/udp/1337/quic") - Expect(err).ToNot(HaveOccurred()) - addr, err := fromQuicMultiaddr(maddr) - Expect(err).ToNot(HaveOccurred()) - Expect(addr).To(BeAssignableToTypeOf(&net.UDPAddr{})) - udpAddr := addr.(*net.UDPAddr) - Expect(udpAddr.IP).To(Equal(net.IPv4(192, 168, 0, 42))) - Expect(udpAddr.Port).To(Equal(1337)) - }) -}) +func TestConvertFromQuicMultiaddr(t *testing.T) { + maddr, err := ma.NewMultiaddr("/ip4/192.168.0.42/udp/1337/quic") + require.NoError(t, err) + addr, err := fromQuicMultiaddr(maddr) + require.NoError(t, err) + udpAddr, ok := addr.(*net.UDPAddr) + require.True(t, ok) + require.Equal(t, udpAddr.IP, net.IPv4(192, 168, 0, 42)) + require.Equal(t, udpAddr.Port, 1337) +} From 9362cf2582155add6118a456ff7971f4cce3f96b Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 3 Jan 2022 18:57:22 +0400 Subject: [PATCH 6/8] migrate the reuse tests away from Ginkgo --- reuse_test.go | 233 +++++++++++++++++++++++++++----------------------- 1 file changed, 124 insertions(+), 109 deletions(-) diff --git a/reuse_test.go b/reuse_test.go index b43d23d..46fb821 100644 --- a/reuse_test.go +++ b/reuse_test.go @@ -5,11 +5,11 @@ import ( "net" "runtime/pprof" "strings" + "testing" "time" "github.com/libp2p/go-netroute" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" + "github.com/stretchr/testify/require" ) func (c *reuseConn) GetCount() int { @@ -35,124 +35,139 @@ func closeAllConns(reuse *reuse) { reuse.mutex.Unlock() } -func OnPlatformsWithRoutingTablesIt(description string, f interface{}) { - if _, err := netroute.New(); err == nil { - It(description, f) - } else { - PIt(description, f) - } +func platformHasRoutingTables() bool { + _, err := netroute.New() + return err == nil } -var _ = Describe("Reuse", func() { - var reuse *reuse +func isGarbageCollectorRunning() bool { + var b bytes.Buffer + pprof.Lookup("goroutine").WriteTo(&b, 1) + return strings.Contains(b.String(), "go-libp2p-quic-transport.(*reuse).gc") +} - BeforeEach(func() { - reuse = newReuse() +func cleanup(t *testing.T, reuse *reuse) { + t.Cleanup(func() { + closeAllConns(reuse) + reuse.Close() + require.False(t, isGarbageCollectorRunning(), "reuse gc still running") }) +} - AfterEach(func() { - isGarbageCollectorRunning := func() bool { - var b bytes.Buffer - pprof.Lookup("goroutine").WriteTo(&b, 1) - return strings.Contains(b.String(), "go-libp2p-quic-transport.(*reuse).gc") - } +func TestReuseListenOnAllIPv4(t *testing.T) { + reuse := newReuse() + cleanup(t, reuse) - Expect(reuse.Close()).To(Succeed()) - Expect(isGarbageCollectorRunning()).To(BeFalse()) - }) + addr, err := net.ResolveUDPAddr("udp4", "0.0.0.0:0") + require.NoError(t, err) + conn, err := reuse.Listen("udp4", addr) + require.NoError(t, err) + require.Equal(t, conn.GetCount(), 1) +} + +func TestReuseListenOnAllIPv6(t *testing.T) { + reuse := newReuse() + cleanup(t, reuse) + + addr, err := net.ResolveUDPAddr("udp6", "[::]:1234") + require.NoError(t, err) + conn, err := reuse.Listen("udp6", addr) + require.NoError(t, err) + require.Equal(t, conn.GetCount(), 1) +} + +func TestReuseCreateNewGlobalConnOnDial(t *testing.T) { + reuse := newReuse() + cleanup(t, reuse) + + addr, err := net.ResolveUDPAddr("udp4", "1.1.1.1:1234") + require.NoError(t, err) + conn, err := reuse.Dial("udp4", addr) + require.NoError(t, err) + require.Equal(t, conn.GetCount(), 1) + laddr := conn.LocalAddr().(*net.UDPAddr) + require.Equal(t, laddr.IP.String(), "0.0.0.0") + require.NotEqual(t, laddr.Port, 0) +} - Context("creating and reusing connections", func() { - AfterEach(func() { closeAllConns(reuse) }) - - It("creates a new global connection when listening on 0.0.0.0", func() { - addr, err := net.ResolveUDPAddr("udp4", "0.0.0.0:0") - Expect(err).ToNot(HaveOccurred()) - conn, err := reuse.Listen("udp4", addr) - Expect(err).ToNot(HaveOccurred()) - Expect(conn.GetCount()).To(Equal(1)) - }) - - It("creates a new global connection when listening on [::]", func() { - addr, err := net.ResolveUDPAddr("udp6", "[::]:1234") - Expect(err).ToNot(HaveOccurred()) - conn, err := reuse.Listen("udp6", addr) - Expect(err).ToNot(HaveOccurred()) - Expect(conn.GetCount()).To(Equal(1)) - }) - - It("creates a new global connection when dialing", func() { - addr, err := net.ResolveUDPAddr("udp4", "1.1.1.1:1234") - Expect(err).ToNot(HaveOccurred()) - conn, err := reuse.Dial("udp4", addr) - Expect(err).ToNot(HaveOccurred()) - Expect(conn.GetCount()).To(Equal(1)) - laddr := conn.LocalAddr().(*net.UDPAddr) - Expect(laddr.IP.String()).To(Equal("0.0.0.0")) - Expect(laddr.Port).ToNot(BeZero()) - }) - - It("reuses a connection it created for listening when dialing", func() { - // listen - addr, err := net.ResolveUDPAddr("udp4", "0.0.0.0:0") - Expect(err).ToNot(HaveOccurred()) - lconn, err := reuse.Listen("udp4", addr) - Expect(err).ToNot(HaveOccurred()) - Expect(lconn.GetCount()).To(Equal(1)) - // dial - raddr, err := net.ResolveUDPAddr("udp4", "1.1.1.1:1234") - Expect(err).ToNot(HaveOccurred()) - conn, err := reuse.Dial("udp4", raddr) - Expect(err).ToNot(HaveOccurred()) - Expect(conn.GetCount()).To(Equal(2)) - }) - - OnPlatformsWithRoutingTablesIt("reuses a connection it created for listening on a specific interface", func() { - router, err := netroute.New() - Expect(err).ToNot(HaveOccurred()) - - raddr, err := net.ResolveUDPAddr("udp4", "1.1.1.1:1234") - Expect(err).ToNot(HaveOccurred()) - _, _, ip, err := router.Route(raddr.IP) - Expect(err).ToNot(HaveOccurred()) - // listen - addr, err := net.ResolveUDPAddr("udp4", ip.String()+":0") - Expect(err).ToNot(HaveOccurred()) - lconn, err := reuse.Listen("udp4", addr) - Expect(err).ToNot(HaveOccurred()) - Expect(lconn.GetCount()).To(Equal(1)) - // dial - conn, err := reuse.Dial("udp4", raddr) - Expect(err).ToNot(HaveOccurred()) - Expect(conn.GetCount()).To(Equal(2)) - }) +func TestReuseConnectionWhenDialing(t *testing.T) { + reuse := newReuse() + cleanup(t, reuse) + + addr, err := net.ResolveUDPAddr("udp4", "0.0.0.0:0") + require.NoError(t, err) + lconn, err := reuse.Listen("udp4", addr) + require.NoError(t, err) + require.Equal(t, lconn.GetCount(), 1) + // dial + raddr, err := net.ResolveUDPAddr("udp4", "1.1.1.1:1234") + require.NoError(t, err) + conn, err := reuse.Dial("udp4", raddr) + require.NoError(t, err) + require.Equal(t, conn.GetCount(), 2) +} + +func TestReuseListenOnSpecificInterface(t *testing.T) { + if platformHasRoutingTables() { + t.Skip("this test only works on platforms that support routing tables") + } + reuse := newReuse() + cleanup(t, reuse) + + router, err := netroute.New() + require.NoError(t, err) + + raddr, err := net.ResolveUDPAddr("udp4", "1.1.1.1:1234") + require.NoError(t, err) + _, _, ip, err := router.Route(raddr.IP) + require.NoError(t, err) + // listen + addr, err := net.ResolveUDPAddr("udp4", ip.String()+":0") + require.NoError(t, err) + lconn, err := reuse.Listen("udp4", addr) + require.NoError(t, err) + require.Equal(t, lconn.GetCount(), 1) + // dial + conn, err := reuse.Dial("udp4", raddr) + require.NoError(t, err) + require.Equal(t, conn.GetCount(), 1) +} + +func TestReuseGarbageCollect(t *testing.T) { + maxUnusedDurationOrig := maxUnusedDuration + garbageCollectIntervalOrig := garbageCollectInterval + t.Cleanup(func() { + maxUnusedDuration = maxUnusedDurationOrig + garbageCollectInterval = garbageCollectIntervalOrig }) + garbageCollectInterval = 50 * time.Millisecond + maxUnusedDuration = 100 * time.Millisecond - It("garbage collects connections once they're not used any more for a certain time", func() { - numGlobals := func() int { - reuse.mutex.Lock() - defer reuse.mutex.Unlock() - return len(reuse.global) - } + reuse := newReuse() + cleanup(t, reuse) - maxUnusedDuration = 100 * time.Millisecond + numGlobals := func() int { + reuse.mutex.Lock() + defer reuse.mutex.Unlock() + return len(reuse.global) + } - addr, err := net.ResolveUDPAddr("udp4", "0.0.0.0:0") - Expect(err).ToNot(HaveOccurred()) - lconn, err := reuse.Listen("udp4", addr) - Expect(err).ToNot(HaveOccurred()) - Expect(lconn.GetCount()).To(Equal(1)) + addr, err := net.ResolveUDPAddr("udp4", "0.0.0.0:0") + require.NoError(t, err) + lconn, err := reuse.Listen("udp4", addr) + require.NoError(t, err) + require.Equal(t, lconn.GetCount(), 1) - closeTime := time.Now() - lconn.DecreaseCount() + closeTime := time.Now() + lconn.DecreaseCount() - for { - num := numGlobals() - if closeTime.Add(maxUnusedDuration).Before(time.Now()) { - break - } - Expect(num).To(Equal(1)) - time.Sleep(2 * time.Millisecond) + for { + num := numGlobals() + if closeTime.Add(maxUnusedDuration).Before(time.Now()) { + break } - Eventually(numGlobals).Should(BeZero()) - }) -}) + require.Equal(t, num, 1) + time.Sleep(2 * time.Millisecond) + } + require.Eventually(t, func() bool { return numGlobals() == 0 }, 100*time.Millisecond, 5*time.Millisecond) +} From 0b9d47ec3d51e1e19dcd2cf3787700945611f098 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 3 Jan 2022 19:07:13 +0400 Subject: [PATCH 7/8] migrate the tracer tests away from Ginkgo --- tracer_test.go | 124 +++++++++++++++++++++++-------------------------- 1 file changed, 58 insertions(+), 66 deletions(-) diff --git a/tracer_test.go b/tracer_test.go index 145a326..2f6d582 100644 --- a/tracer_test.go +++ b/tracer_test.go @@ -2,81 +2,73 @@ package libp2pquic import ( "bytes" - "fmt" "io/ioutil" "os" + "strings" + "testing" + + "github.com/stretchr/testify/require" "github.com/klauspost/compress/zstd" "github.com/lucas-clemente/quic-go/logging" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" ) -var _ = Describe("qlogger", func() { - var qlogDir string - - BeforeEach(func() { - var err error - qlogDir, err = ioutil.TempDir("", "libp2p-quic-transport-test") - Expect(err).ToNot(HaveOccurred()) - fmt.Fprintf(GinkgoWriter, "Creating temporary directory: %s\n", qlogDir) - initQlogger(qlogDir) - }) - - AfterEach(func() { - Expect(os.RemoveAll(qlogDir)).To(Succeed()) - }) +func createLogDir(t *testing.T) string { + dir, err := ioutil.TempDir("", "libp2p-quic-transport-test") + require.NoError(t, err) + t.Cleanup(func() { os.RemoveAll(dir) }) + return dir +} - getFile := func() os.FileInfo { - files, err := ioutil.ReadDir(qlogDir) - Expect(err).ToNot(HaveOccurred()) - Expect(files).To(HaveLen(1)) - return files[0] - } +func getFile(t *testing.T, dir string) os.FileInfo { + files, err := ioutil.ReadDir(dir) + require.NoError(t, err) + require.Len(t, files, 1) + return files[0] +} - It("saves a qlog", func() { - logger := newQlogger(qlogDir, logging.PerspectiveServer, []byte{0xde, 0xad, 0xbe, 0xef}) - file := getFile() - Expect(string(file.Name()[0])).To(Equal(".")) - Expect(file.Name()).To(HaveSuffix(".qlog.swp")) - // close the logger. This should move the file. - Expect(logger.Close()).To(Succeed()) - file = getFile() - Expect(string(file.Name()[0])).ToNot(Equal(".")) - Expect(file.Name()).To(HaveSuffix(".qlog.zst")) - Expect(file.Name()).To(And( - ContainSubstring("server"), - ContainSubstring("deadbeef"), - )) - }) +func TestSaveQlog(t *testing.T) { + qlogDir := createLogDir(t) + logger := newQlogger(qlogDir, logging.PerspectiveServer, []byte{0xde, 0xad, 0xbe, 0xef}) + file := getFile(t, qlogDir) + require.Equal(t, string(file.Name()[0]), ".") + require.Truef(t, strings.HasSuffix(file.Name(), ".qlog.swp"), "expected %s to have the .qlog.swp file ending", file.Name()) + // close the logger. This should move the file. + require.NoError(t, logger.Close()) + file = getFile(t, qlogDir) + require.NotEqual(t, string(file.Name()[0]), ".") + require.Truef(t, strings.HasSuffix(file.Name(), ".qlog.zst"), "expected %s to have the .qlog.zst file ending", file.Name()) + require.Contains(t, file.Name(), "server") + require.Contains(t, file.Name(), "deadbeef") +} - It("buffers", func() { - logger := newQlogger(qlogDir, logging.PerspectiveServer, []byte("connid")) - initialSize := getFile().Size() - // Do a small write. - // Since the writter is buffered, this should not be written to disk yet. - logger.Write([]byte("foobar")) - Expect(getFile().Size()).To(Equal(initialSize)) - // Close the logger. This should flush the buffer to disk. - Expect(logger.Close()).To(Succeed()) - finalSize := getFile().Size() - fmt.Fprintf(GinkgoWriter, "initial log file size: %d, final log file size: %d\n", initialSize, finalSize) - Expect(finalSize).To(BeNumerically(">", initialSize)) - }) +func TestQlogBuffering(t *testing.T) { + qlogDir := createLogDir(t) + logger := newQlogger(qlogDir, logging.PerspectiveServer, []byte("connid")) + initialSize := getFile(t, qlogDir).Size() + // Do a small write. + // Since the writter is buffered, this should not be written to disk yet. + logger.Write([]byte("foobar")) + require.Equal(t, getFile(t, qlogDir).Size(), initialSize) + // Close the logger. This should flush the buffer to disk. + require.NoError(t, logger.Close()) + finalSize := getFile(t, qlogDir).Size() + t.Logf("initial log file size: %d, final log file size: %d\n", initialSize, finalSize) + require.Greater(t, finalSize, initialSize) +} - It("compresses", func() { - logger := newQlogger(qlogDir, logging.PerspectiveServer, []byte("connid")) - logger.Write([]byte("foobar")) - Expect(logger.Close()).To(Succeed()) - compressed, err := ioutil.ReadFile(qlogDir + "/" + getFile().Name()) - Expect(err).ToNot(HaveOccurred()) - Expect(compressed).ToNot(Equal("foobar")) - c, err := zstd.NewReader(bytes.NewReader(compressed)) - Expect(err).ToNot(HaveOccurred()) - data, err := ioutil.ReadAll(c) - Expect(err).ToNot(HaveOccurred()) - Expect(data).To(Equal([]byte("foobar"))) - }) -}) +func TestQlogCompression(t *testing.T) { + qlogDir := createLogDir(t) + logger := newQlogger(qlogDir, logging.PerspectiveServer, []byte("connid")) + logger.Write([]byte("foobar")) + require.NoError(t, logger.Close()) + compressed, err := ioutil.ReadFile(qlogDir + "/" + getFile(t, qlogDir).Name()) + require.NoError(t, err) + require.NotEqual(t, compressed, "foobar") + c, err := zstd.NewReader(bytes.NewReader(compressed)) + require.NoError(t, err) + data, err := ioutil.ReadAll(c) + require.NoError(t, err) + require.Equal(t, data, []byte("foobar")) +} From 46e8a938fd3549a8cdd84aaf1ecedcfb6c535d46 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 3 Jan 2022 19:07:53 +0400 Subject: [PATCH 8/8] remove Ginkgo suite --- go.mod | 2 -- libp2pquic_suite_test.go | 42 ---------------------------------------- 2 files changed, 44 deletions(-) delete mode 100644 libp2pquic_suite_test.go diff --git a/go.mod b/go.mod index 9907fba..3641415 100644 --- a/go.mod +++ b/go.mod @@ -14,8 +14,6 @@ require ( github.com/minio/sha256-simd v0.1.1 github.com/multiformats/go-multiaddr v0.3.1 github.com/multiformats/go-multiaddr-fmt v0.1.0 - github.com/onsi/ginkgo v1.16.4 - github.com/onsi/gomega v1.13.0 github.com/prometheus/client_golang v1.9.0 github.com/stretchr/testify v1.7.0 golang.org/x/crypto v0.0.0-20210506145944-38f3c27a63bf diff --git a/libp2pquic_suite_test.go b/libp2pquic_suite_test.go deleted file mode 100644 index 5144ae3..0000000 --- a/libp2pquic_suite_test.go +++ /dev/null @@ -1,42 +0,0 @@ -package libp2pquic - -import ( - mrand "math/rand" - "testing" - "time" - - "github.com/lucas-clemente/quic-go" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -func TestLibp2pQuicTransport(t *testing.T) { - RegisterFailHandler(Fail) - RunSpecs(t, "libp2p QUIC Transport Suite") -} - -var _ = BeforeSuite(func() { - mrand.Seed(GinkgoRandomSeed()) -}) - -var ( - garbageCollectIntervalOrig time.Duration - maxUnusedDurationOrig time.Duration - origQuicConfig *quic.Config -) - -var _ = BeforeEach(func() { - garbageCollectIntervalOrig = garbageCollectInterval - maxUnusedDurationOrig = maxUnusedDuration - garbageCollectInterval = 50 * time.Millisecond - maxUnusedDuration = 0 - origQuicConfig = quicConfig - quicConfig = quicConfig.Clone() -}) - -var _ = AfterEach(func() { - garbageCollectInterval = garbageCollectIntervalOrig - maxUnusedDuration = maxUnusedDurationOrig - quicConfig = origQuicConfig -})