diff --git a/p2p/host/basic/basic_host_test.go b/p2p/host/basic/basic_host_test.go index d73da15c2d..e05f71d782 100644 --- a/p2p/host/basic/basic_host_test.go +++ b/p2p/host/basic/basic_host_test.go @@ -250,7 +250,7 @@ func getHostPair(t *testing.T) (host.Host, host.Host) { require.NoError(t, err) h2.Start() - ctx, cancel := context.WithTimeout(context.Background(), time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() h2pi := h2.Peerstore().PeerInfo(h2.ID()) require.NoError(t, h1.Connect(ctx, h2pi)) @@ -344,11 +344,11 @@ func TestHostProtoMismatch(t *testing.T) { } func TestHostProtoPreknowledge(t *testing.T) { - h1, err := NewHost(swarmt.GenSwarm(t), nil) + h1, err := NewHost(swarmt.GenSwarm(t, swarmt.OptDialOnly), nil) require.NoError(t, err) defer h1.Close() - h2, err := NewHost(swarmt.GenSwarm(t), nil) + h2, err := NewHost(swarmt.GenSwarm(t, swarmt.OptDisableTCP), nil) require.NoError(t, err) defer h2.Close() @@ -359,15 +359,23 @@ func TestHostProtoPreknowledge(t *testing.T) { } h2.SetStreamHandler("/super", handler) - // Prevent pushing identify information so this test actually _uses_ the super protocol. - h1.RemoveStreamHandler(identify.IDPush) h1.Start() h2.Start() + // Prevent pushing identify information so this test actually _uses_ the super protocol. + h1.RemoveStreamHandler(identify.IDPush) + h2pi := h2.Peerstore().PeerInfo(h2.ID()) + // Filter to only 1 address so that we don't have to think about parallel + // connections in this test + h2pi.Addrs = h2pi.Addrs[:1] require.NoError(t, h1.Connect(context.Background(), h2pi)) + // This test implicitly relies on 1 connection. If a background identify + // completes after we set the stream handler below things break + require.Equal(t, 1, len(h1.Network().ConnsToPeer(h2.ID()))) + // wait for identify handshake to finish completely select { case <-h1.ids.IdentifyWait(h1.Network().ConnsToPeer(h2.ID())[0]): @@ -383,6 +391,18 @@ func TestHostProtoPreknowledge(t *testing.T) { h2.SetStreamHandler("/foo", handler) + require.Never(t, func() bool { + protos, err := h1.Peerstore().GetProtocols(h2.ID()) + require.NoError(t, err) + for _, p := range protos { + fmt.Println("proto: ", p) + if p == "/foo" { + return true + } + } + return false + }, time.Second, 100*time.Millisecond) + s, err := h1.NewStream(context.Background(), h2.ID(), "/foo", "/bar", "/super") require.NoError(t, err)