From 464aa57ac8d02ee62581ccaf4ae62c35d235dd08 Mon Sep 17 00:00:00 2001 From: sukun Date: Thu, 4 Jul 2024 18:47:50 +0530 Subject: [PATCH] webtransport: close underlying h3 connection --- p2p/transport/webtransport/conn.go | 9 ++- p2p/transport/webtransport/listener.go | 80 +++++++++++++++++++- p2p/transport/webtransport/transport.go | 43 +++++++---- p2p/transport/webtransport/transport_test.go | 35 +++++++++ 4 files changed, 146 insertions(+), 21 deletions(-) diff --git a/p2p/transport/webtransport/conn.go b/p2p/transport/webtransport/conn.go index 0e83b1d16f..0525124711 100644 --- a/p2p/transport/webtransport/conn.go +++ b/p2p/transport/webtransport/conn.go @@ -7,6 +7,7 @@ import ( tpt "github.com/libp2p/go-libp2p/core/transport" ma "github.com/multiformats/go-multiaddr" + "github.com/quic-go/quic-go" "github.com/quic-go/webtransport-go" ) @@ -31,16 +32,18 @@ type conn struct { session *webtransport.Session scope network.ConnManagementScope + qconn quic.Connection } var _ tpt.CapableConn = &conn{} -func newConn(tr *transport, sess *webtransport.Session, sconn *connSecurityMultiaddrs, scope network.ConnManagementScope) *conn { +func newConn(tr *transport, sess *webtransport.Session, sconn *connSecurityMultiaddrs, scope network.ConnManagementScope, qconn quic.Connection) *conn { return &conn{ connSecurityMultiaddrs: sconn, transport: tr, session: sess, scope: scope, + qconn: qconn, } } @@ -70,7 +73,9 @@ func (c *conn) allowWindowIncrease(size uint64) bool { func (c *conn) Close() error { c.scope.Done() c.transport.removeConn(c.session) - return c.session.CloseWithError(0, "") + err := c.session.CloseWithError(0, "") + _ = c.qconn.CloseWithError(1, "") + return err } func (c *conn) IsClosed() bool { return c.session.Context().Err() != nil } diff --git a/p2p/transport/webtransport/listener.go b/p2p/transport/webtransport/listener.go index 2a7c3546f2..ff611fe927 100644 --- a/p2p/transport/webtransport/listener.go +++ b/p2p/transport/webtransport/listener.go @@ -15,12 +15,61 @@ import ( "github.com/libp2p/go-libp2p/p2p/transport/quicreuse" ma "github.com/multiformats/go-multiaddr" + "github.com/quic-go/quic-go" + "github.com/quic-go/quic-go/http3" "github.com/quic-go/webtransport-go" ) const queueLen = 16 const handshakeTimeout = 10 * time.Second +type connKey struct{} + +// negotiatingConn is a wrapper around a quic.Connection that lets us wrap it in +// our own context for the duration of the upgrade process. Upgrading a quic +// connection to an h3 connection to a webtransport session. +type negotiatingConn struct { + quic.Connection + ctx context.Context + cancel context.CancelFunc + // stopClose is a function that stops the connection from being closed when + // the context is done. Returns true if the connection close function was + // not called. + stopClose func() bool + err error +} + +func (c *negotiatingConn) Unwrap() (quic.Connection, error) { + defer c.cancel() + if c.stopClose != nil { + // unwrap the first time + if !c.stopClose() { + c.err = errTimeout + } + c.stopClose = nil + } + if c.err != nil { + return nil, c.err + } + return c.Connection, nil +} + +func wrapConn(ctx context.Context, c quic.Connection, handshakeTimeout time.Duration) *negotiatingConn { + ctx, cancel := context.WithTimeout(ctx, handshakeTimeout) + stopClose := context.AfterFunc(ctx, func() { + log.Debugf("failed to handshake on conn: %s", c.RemoteAddr()) + c.CloseWithError(1, "") + }) + return &negotiatingConn{ + Connection: c, + ctx: ctx, + cancel: cancel, + stopClose: stopClose, + } +} + +var errTimeout = errors.New("timeout") + type listener struct { transport *transport isStaticTLSConf bool @@ -56,6 +105,11 @@ func newListener(reuseListener quicreuse.Listener, t *transport, isStaticTLSConf addr: reuseListener.Addr(), multiaddr: localMultiaddr, server: webtransport.Server{ + H3: http3.Server{ + ConnContext: func(ctx context.Context, c quic.Connection) context.Context { + return context.WithValue(ctx, connKey{}, c) + }, + }, CheckOrigin: func(r *http.Request) bool { return true }, }, } @@ -71,7 +125,8 @@ func newListener(reuseListener quicreuse.Listener, t *transport, isStaticTLSConf log.Debugw("serving failed", "addr", ln.Addr(), "error", err) return } - go ln.server.ServeQUICConn(conn) + wrapped := wrapConn(ln.ctx, conn, t.handshakeTimeout) + go ln.server.ServeQUICConn(wrapped) } }() return ln, nil @@ -137,13 +192,32 @@ func (l *listener) httpHandlerWithConnScope(w http.ResponseWriter, r *http.Reque return err } - conn := newConn(l.transport, sess, sconn, connScope) + connVal := r.Context().Value(connKey{}) + if connVal == nil { + log.Errorf("missing conn from context") + sess.CloseWithError(1, "") + return errors.New("invalid context") + } + nconn, ok := connVal.(*negotiatingConn) + if !ok { + log.Errorf("unexpected connection in context. invalid conn type: %T", nconn) + sess.CloseWithError(1, "") + return errors.New("invalid context") + } + qconn, err := nconn.Unwrap() + if err != nil { + log.Debugf("handshake timed out: %s", r.RemoteAddr) + sess.CloseWithError(1, "") + return err + } + + conn := newConn(l.transport, sess, sconn, connScope, qconn) l.transport.addConn(sess, conn) select { case l.queue <- conn: default: log.Debugw("accept queue full, dropping incoming connection", "peer", sconn.RemotePeer(), "addr", r.RemoteAddr, "error", err) - sess.CloseWithError(1, "") + conn.Close() return errors.New("accept queue full") } diff --git a/p2p/transport/webtransport/transport.go b/p2p/transport/webtransport/transport.go index 97172703f7..ef8551d60f 100644 --- a/p2p/transport/webtransport/transport.go +++ b/p2p/transport/webtransport/transport.go @@ -60,6 +60,13 @@ func WithTLSClientConfig(c *tls.Config) Option { } } +func WithHandshakeTimeout(d time.Duration) Option { + return func(t *transport) error { + t.handshakeTimeout = d + return nil + } +} + type transport struct { privKey ic.PrivKey pid peer.ID @@ -78,8 +85,9 @@ type transport struct { noise *noise.Transport - connMx sync.Mutex - conns map[quic.ConnectionTracingID]*conn // using quic-go's ConnectionTracingKey as map key + connMx sync.Mutex + conns map[quic.ConnectionTracingID]*conn // using quic-go's ConnectionTracingKey as map key + handshakeTimeout time.Duration } var _ tpt.Transport = &transport{} @@ -99,13 +107,14 @@ func New(key ic.PrivKey, psk pnet.PSK, connManager *quicreuse.ConnManager, gater return nil, err } t := &transport{ - pid: id, - privKey: key, - rcmgr: rcmgr, - gater: gater, - clock: clock.New(), - connManager: connManager, - conns: map[quic.ConnectionTracingID]*conn{}, + pid: id, + privKey: key, + rcmgr: rcmgr, + gater: gater, + clock: clock.New(), + connManager: connManager, + conns: map[quic.ConnectionTracingID]*conn{}, + handshakeTimeout: handshakeTimeout, } for _, opt := range opts { if err := opt(t); err != nil { @@ -159,7 +168,7 @@ func (t *transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p pee } maddr, _ := ma.SplitFunc(raddr, func(c ma.Component) bool { return c.Protocol().Code == ma.P_WEBTRANSPORT }) - sess, err := t.dial(ctx, maddr, url, sni, certHashes) + sess, qconn, err := t.dial(ctx, maddr, url, sni, certHashes) if err != nil { return nil, err } @@ -172,12 +181,12 @@ func (t *transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p pee sess.CloseWithError(errorCodeConnectionGating, "") return nil, fmt.Errorf("secured connection gated") } - conn := newConn(t, sess, sconn, scope) + conn := newConn(t, sess, sconn, scope, qconn) t.addConn(sess, conn) return conn, nil } -func (t *transport) dial(ctx context.Context, addr ma.Multiaddr, url, sni string, certHashes []multihash.DecodedMultihash) (*webtransport.Session, error) { +func (t *transport) dial(ctx context.Context, addr ma.Multiaddr, url, sni string, certHashes []multihash.DecodedMultihash) (*webtransport.Session, quic.Connection, error) { var tlsConf *tls.Config if t.tlsClientConf != nil { tlsConf = t.tlsClientConf.Clone() @@ -200,7 +209,7 @@ func (t *transport) dial(ctx context.Context, addr ma.Multiaddr, url, sni string } conn, err := t.connManager.DialQUIC(ctx, addr, tlsConf, t.allowWindowIncrease) if err != nil { - return nil, err + return nil, nil, err } dialer := webtransport.Dialer{ DialAddr: func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { @@ -210,12 +219,14 @@ func (t *transport) dial(ctx context.Context, addr ma.Multiaddr, url, sni string } rsp, sess, err := dialer.Dial(ctx, url, nil) if err != nil { - return nil, err + conn.CloseWithError(1, "") + return nil, nil, err } if rsp.StatusCode < 200 || rsp.StatusCode > 299 { - return nil, fmt.Errorf("invalid response status code: %d", rsp.StatusCode) + conn.CloseWithError(1, "") + return nil, nil, fmt.Errorf("invalid response status code: %d", rsp.StatusCode) } - return sess, err + return sess, conn, err } func (t *transport) upgrade(ctx context.Context, sess *webtransport.Session, p peer.ID, certHashes []multihash.DecodedMultihash) (*connSecurityMultiaddrs, error) { diff --git a/p2p/transport/webtransport/transport_test.go b/p2p/transport/webtransport/transport_test.go index f6c850a2b9..bd41446218 100644 --- a/p2p/transport/webtransport/transport_test.go +++ b/p2p/transport/webtransport/transport_test.go @@ -11,6 +11,7 @@ import ( "fmt" "io" "net" + "net/http" "os" "runtime" "sync/atomic" @@ -827,3 +828,37 @@ func TestServerRotatesCertCorrectlyAfterSteps(t *testing.T) { require.True(t, found, "Failed after hour: %v", i) } } + +func TestH3ConnClosed(t *testing.T) { + _, serverKey := newIdentity(t) + tr, err := libp2pwebtransport.New(serverKey, nil, newConnManager(t), nil, nil, libp2pwebtransport.WithHandshakeTimeout(1*time.Second)) + require.NoError(t, err) + defer tr.(io.Closer).Close() + ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport")) + require.NoError(t, err) + defer ln.Close() + + p, err := net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + conn, err := quic.Dial(context.Background(), p, ln.Addr(), &tls.Config{ + InsecureSkipVerify: true, + NextProtos: []string{http3.NextProtoH3}, + }, nil) + require.NoError(t, err) + rt := &http3.SingleDestinationRoundTripper{ + Connection: conn, + } + rt.Start() + require.Eventually(t, func() bool { + c := http.Client{ + Transport: rt, + Timeout: 1 * time.Second, + } + resp, err := c.Get(fmt.Sprintf("https://%s", ln.Addr().String())) + if err != nil { + return true + } + resp.Body.Close() + return false + }, 10*time.Second, 1*time.Second) +}