diff --git a/p2p/transport/webrtc/listener.go b/p2p/transport/webrtc/listener.go index 1834fc812b..b6711d66ed 100644 --- a/p2p/transport/webrtc/listener.go +++ b/p2p/transport/webrtc/listener.go @@ -329,8 +329,8 @@ func (l *listener) Multiaddr() ma.Multiaddr { func addOnConnectionStateChangeCallback(pc *webrtc.PeerConnection) <-chan error { errC := make(chan error, 1) var once sync.Once - pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) { - switch state { + pc.OnConnectionStateChange(func(_ webrtc.PeerConnectionState) { + switch pc.ConnectionState() { case webrtc.PeerConnectionStateConnected: once.Do(func() { close(errC) }) case webrtc.PeerConnectionStateFailed: diff --git a/p2p/transport/webrtc/transport.go b/p2p/transport/webrtc/transport.go index b04753ecab..e9fcb2ac2a 100644 --- a/p2p/transport/webrtc/transport.go +++ b/p2p/transport/webrtc/transport.go @@ -73,9 +73,9 @@ const ( // timeout values for the peerconnection // https://github.com/pion/webrtc/blob/v3.1.50/settingengine.go#L102-L109 const ( - DefaultDisconnectedTimeout = 20 * time.Second - DefaultFailedTimeout = 30 * time.Second - DefaultKeepaliveTimeout = 15 * time.Second + DefaultDisconnectedTimeout = 100 * time.Second + DefaultFailedTimeout = 300 * time.Second + DefaultKeepaliveTimeout = 50 * time.Second sctpReceiveBufferSize = 100_000 ) diff --git a/p2p/transport/webrtc/transport_test.go b/p2p/transport/webrtc/transport_test.go index a3054a82df..3b56ec4001 100644 --- a/p2p/transport/webrtc/transport_test.go +++ b/p2p/transport/webrtc/transport_test.go @@ -17,6 +17,7 @@ import ( "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" + tpt "github.com/libp2p/go-libp2p/core/transport" ma "github.com/multiformats/go-multiaddr" manet "github.com/multiformats/go-multiaddr/net" "github.com/multiformats/go-multibase" @@ -751,8 +752,8 @@ func TestTransportWebRTC_PeerConnectionDTLSFailed(t *testing.T) { func TestConnectionTimeoutOnListener(t *testing.T) { tr, listeningPeer := getTransport(t) - tr.peerConnectionTimeouts.Disconnect = 100 * time.Millisecond - tr.peerConnectionTimeouts.Failed = 150 * time.Millisecond + tr.peerConnectionTimeouts.Disconnect = 300 * time.Millisecond + tr.peerConnectionTimeouts.Failed = 500 * time.Millisecond tr.peerConnectionTimeouts.Keepalive = 50 * time.Millisecond listenMultiaddr := ma.StringCast("/ip4/127.0.0.1/udp/0/webrtc-direct") @@ -770,17 +771,25 @@ func TestConnectionTimeoutOnListener(t *testing.T) { tr1, connectingPeer := getTransport(t) go func() { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) defer cancel() addr, err := manet.FromNetAddr(proxy.LocalAddr()) - require.NoError(t, err) + if !assert.NoError(t, err) { + ln.Close() + return + } _, webrtcComponent := ma.SplitFunc(ln.Multiaddr(), func(c ma.Component) bool { return c.Protocol().Code == ma.P_WEBRTC_DIRECT }) addr = addr.Encapsulate(webrtcComponent) conn, err := tr1.Dial(ctx, addr, listeningPeer) - require.NoError(t, err) + if !assert.NoError(t, err) { + ln.Close() + return + } t.Cleanup(func() { conn.Close() }) str, err := conn.OpenStream(ctx) - require.NoError(t, err) + if !assert.NoError(t, err) { + return + } str.Write([]byte("foobar")) }() @@ -860,3 +869,108 @@ func TestMaxInFlightRequests(t *testing.T) { require.Equal(t, count, int(success.Load()), "expected exactly 3 dial successes") require.Equal(t, 1, int(fails.Load()), "expected exactly 1 dial failure") } + +func TestStressConnectionCreation(t *testing.T) { + var listeners []tpt.Listener + var listenerPeerIDs []peer.ID + + const numListeners = 10 + const dialersPerListener = 10 + const connsPerDialer = 10 + errCh := make(chan error, 10*numListeners*dialersPerListener*connsPerDialer) + successCh := make(chan struct{}, 10*numListeners*dialersPerListener*connsPerDialer) + + for i := 0; i < numListeners; i++ { + tr, lp := getTransport(t) + listenerPeerIDs = append(listenerPeerIDs, lp) + ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/webrtc-direct")) + require.NoError(t, err) + defer ln.Close() + listeners = append(listeners, ln) + } + + runListenConn := func(conn tpt.CapableConn) { + s, err := conn.AcceptStream() + if err != nil { + t.Errorf("accept stream failed for listener: %s", err) + errCh <- err + return + } + var b [4]byte + if _, err := s.Read(b[:]); err != nil { + t.Errorf("read stream failed for listener: %s", err) + errCh <- err + return + } + s.Write(b[:]) + s.Read(b[:]) // peer will close the connection after read + successCh <- struct{}{} + } + + runDialConn := func(conn tpt.CapableConn) { + s, err := conn.OpenStream(context.Background()) + if err != nil { + t.Errorf("accept stream failed for listener: %s", err) + errCh <- err + return + } + var b [4]byte + if _, err := s.Write(b[:]); err != nil { + t.Errorf("write stream failed for dialer: %s", err) + } + if _, err := s.Read(b[:]); err != nil { + t.Errorf("read stream failed for dialer: %s", err) + errCh <- err + return + } + s.Close() + } + + runListener := func(ln tpt.Listener) { + for i := 0; i < dialersPerListener*connsPerDialer; i++ { + conn, err := ln.Accept() + if err != nil { + t.Errorf("listener failed to accept conneciton: %s", err) + return + } + go runListenConn(conn) + } + } + + runDialer := func(ln tpt.Listener, lp peer.ID) { + tp, _ := getTransport(t) + for i := 0; i < connsPerDialer; i++ { + ctx, cancel := context.WithTimeout(context.Background(), 300*time.Second) + conn, err := tp.Dial(ctx, ln.Multiaddr(), lp) + if err != nil { + t.Errorf("dial failed: %s", err) + errCh <- err + cancel() + return + } + go runDialConn(conn) + cancel() + } + } + + for i := 0; i < numListeners; i++ { + go runListener(listeners[i]) + } + time.Sleep(10 * time.Second) + for i := 0; i < numListeners; i++ { + for j := 0; j < dialersPerListener; j++ { + go runDialer(listeners[i], listenerPeerIDs[i]) + } + } + + for i := 0; i < numListeners*dialersPerListener*connsPerDialer; i++ { + select { + case <-successCh: + fmt.Println(i) + case err := <-errCh: + t.Fatalf("failed: %s", err) + case <-time.After(300 * time.Second): + t.Fatalf("timed out") + } + } +}