diff --git a/p2p/test/reconnects/reconnect_test.go b/p2p/test/reconnects/reconnect_test.go index 5e5bbcca49..12fdb36780 100644 --- a/p2p/test/reconnects/reconnect_test.go +++ b/p2p/test/reconnects/reconnect_test.go @@ -12,204 +12,95 @@ import ( "github.com/libp2p/go-libp2p-core/host" "github.com/libp2p/go-libp2p-core/network" + "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/protocol" swarmt "github.com/libp2p/go-libp2p-swarm/testing" - u "github.com/ipfs/go-ipfs-util" - logging "github.com/ipfs/go-log/v2" "github.com/stretchr/testify/require" ) -var log = logging.Logger("reconnect") - func EchoStreamHandler(stream network.Stream) { - c := stream.Conn() - log.Debugf("%s echoing %s", c.LocalPeer(), c.RemotePeer()) - go func() { - _, err := io.Copy(stream, stream) - if err == nil { - stream.Close() - } else { - stream.Reset() - } - }() -} - -type sendChans struct { - send chan struct{} - sent chan struct{} - read chan struct{} - close_ chan struct{} - closed chan struct{} -} - -func newSendChans() sendChans { - return sendChans{ - send: make(chan struct{}), - sent: make(chan struct{}), - read: make(chan struct{}), - close_: make(chan struct{}), - closed: make(chan struct{}), + _, err := io.CopyBuffer(stream, stream, make([]byte, 64)) // use a small buffer here to avoid problems with flow control + if err == nil { + stream.Close() + } else { + stream.Reset() } } -func newSender() (chan sendChans, func(s network.Stream)) { - scc := make(chan sendChans) - return scc, func(s network.Stream) { - sc := newSendChans() - scc <- sc - - defer func() { - s.Close() - sc.closed <- struct{}{} - }() - - buf := make([]byte, 65536) - buf2 := make([]byte, 65536) - u.NewTimeSeededRand().Read(buf) - - for { - select { - case <-sc.close_: - return - case <-sc.send: - } - - // send a randomly sized subchunk - from := rand.Intn(len(buf) / 2) - to := rand.Intn(len(buf) / 2) - sendbuf := buf[from : from+to] - - log.Debugf("sender sending %d bytes", len(sendbuf)) - n, err := s.Write(sendbuf) - if err != nil { - log.Debug("sender error. exiting:", err) - return - } - - log.Debugf("sender wrote %d bytes", n) - sc.sent <- struct{}{} - - if n, err = io.ReadFull(s, buf2[:len(sendbuf)]); err != nil { - log.Debug("sender error. failed to read:", err) - return - } - - log.Debugf("sender read %d bytes", n) - sc.read <- struct{}{} +func TestReconnect5(t *testing.T) { + runTest := func(t *testing.T, swarmOpt swarmt.Option) { + t.Helper() + const num = 5 + hosts := make([]host.Host, 0, num) + + for i := 0; i < num; i++ { + h, err := bhost.NewHost(swarmt.GenSwarm(t, swarmOpt), nil) + require.NoError(t, err) + defer h.Close() + hosts = append(hosts, h) + h.SetStreamHandler(protocol.TestingID, EchoStreamHandler) } - } -} -// TestReconnect tests whether hosts are able to disconnect and reconnect. -func TestReconnect2(t *testing.T) { - // TCP RST handling is flaky in OSX, see https://github.com/golang/go/issues/50254. - // We can avoid this by using QUIC in this test. - h1, err := bhost.NewHost(swarmt.GenSwarm(t, swarmt.OptDisableTCP), nil) - require.NoError(t, err) - h2, err := bhost.NewHost(swarmt.GenSwarm(t, swarmt.OptDisableTCP), nil) - require.NoError(t, err) - hosts := []host.Host{h1, h2} - - h1.SetStreamHandler(protocol.TestingID, EchoStreamHandler) - h2.SetStreamHandler(protocol.TestingID, EchoStreamHandler) - - rounds := 8 - if testing.Short() { - rounds = 4 - } - for i := 0; i < rounds; i++ { - log.Debugf("TestReconnect: %d/%d\n", i, rounds) - subtestConnSendDisc(t, hosts) + for i := 0; i < 4; i++ { + runRound(t, hosts) + } } -} -// TestReconnect tests whether hosts are able to disconnect and reconnect. -func TestReconnect5(t *testing.T) { - const num = 5 - hosts := make([]host.Host, 0, num) - for i := 0; i < num; i++ { - // TCP RST handling is flaky in OSX, see https://github.com/golang/go/issues/50254. - // We can avoid this by using QUIC in this test. - h, err := bhost.NewHost(swarmt.GenSwarm(t, swarmt.OptDisableTCP), nil) - require.NoError(t, err) - h.SetStreamHandler(protocol.TestingID, EchoStreamHandler) - hosts = append(hosts, h) - } + t.Run("using TCP", func(t *testing.T) { + runTest(t, swarmt.OptDisableQUIC) + }) - rounds := 4 - if testing.Short() { - rounds = 2 - } - for i := 0; i < rounds; i++ { - log.Debugf("TestReconnect: %d/%d\n", i, rounds) - subtestConnSendDisc(t, hosts) - } + t.Run("using QUIC", func(t *testing.T) { + runTest(t, swarmt.OptDisableTCP) + }) } -func subtestConnSendDisc(t *testing.T, hosts []host.Host) { - ctx := context.Background() - numStreams := 3 * len(hosts) - numMsgs := 10 - - if testing.Short() { - numStreams = 5 * len(hosts) - numMsgs = 4 +func runRound(t *testing.T, hosts []host.Host) { + for _, h := range hosts { + h.SetStreamHandler(protocol.TestingID, EchoStreamHandler) } - ss, sF := newSender() - + // connect all hosts for _, h1 := range hosts { for _, h2 := range hosts { if h1.ID() >= h2.ID() { continue } - - h2pi := h2.Peerstore().PeerInfo(h2.ID()) - log.Debugf("dialing %s", h2pi.Addrs) - if err := h1.Connect(ctx, h2pi); err != nil { - t.Fatal("Failed to connect:", err) - } + require.NoError(t, h1.Connect(context.Background(), peer.AddrInfo{ID: h2.ID(), Addrs: h2.Peerstore().Addrs(h2.ID())})) } } - var wg sync.WaitGroup - for i := 0; i < numStreams; i++ { - h1 := hosts[i%len(hosts)] - h2 := hosts[(i+1)%len(hosts)] - s, err := h1.NewStream(context.Background(), h2.ID(), protocol.TestingID) - if err != nil { - t.Error(err) - } - - wg.Add(1) - go func(j int) { - defer wg.Done() - - go sF(s) - log.Debugf("getting handle %d", j) - sc := <-ss // wait to get handle. - log.Debugf("spawning worker %d", j) - - for k := 0; k < numMsgs; k++ { - sc.send <- struct{}{} - <-sc.sent - log.Debugf("%d sent %d", j, k) - <-sc.read - log.Debugf("%d read %d", j, k) + const ( + numStreams = 5 + maxDataLen = 64 << 10 + ) + // exchange some data + for _, h1 := range hosts { + for _, h2 := range hosts { + if h1 == h2 { + continue } - sc.close_ <- struct{}{} - <-sc.closed - log.Debugf("closed %d", j) - }(i) - } - wg.Wait() - - for i, h1 := range hosts { - log.Debugf("host %d has %d conns", i, len(h1.Network().Conns())) + var wg sync.WaitGroup + wg.Add(numStreams) + for i := 0; i < numStreams; i++ { + go func() { + defer wg.Done() + data := make([]byte, rand.Intn(maxDataLen)+1) + rand.Read(data) + str, err := h1.NewStream(context.Background(), h2.ID(), protocol.TestingID) + require.NoError(t, err) + _, err = str.Write(data) + require.NoError(t, err) + require.NoError(t, str.Close()) + }() + } + wg.Wait() + } } + // disconnect all hosts for _, h1 := range hosts { // close connection cs := h1.Network().Conns() @@ -217,16 +108,18 @@ func subtestConnSendDisc(t *testing.T, hosts []host.Host) { if c.LocalPeer() > c.RemotePeer() { continue } - log.Debugf("closing: %s", c) c.Close() } } - <-time.After(20 * time.Millisecond) - - for i, h := range hosts { - if len(h.Network().Conns()) > 0 { - t.Fatalf("host %d %s has %d conns! not zero.", i, h.ID(), len(h.Network().Conns())) + require.Eventually(t, func() bool { + for _, h1 := range hosts { + for _, h2 := range hosts { + if len(h1.Network().ConnsToPeer(h2.ID())) > 0 { + return false + } + } } - } + return true + }, 500*time.Millisecond, 10*time.Millisecond) }