Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

webtransport: close underlying h3 connection #2862

Merged
merged 1 commit into from
Jul 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions p2p/transport/webtransport/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
tpt "github.com/libp2p/go-libp2p/core/transport"

ma "github.com/multiformats/go-multiaddr"
"github.com/quic-go/quic-go"
"github.com/quic-go/webtransport-go"
)

Expand All @@ -31,16 +32,18 @@ type conn struct {
session *webtransport.Session

scope network.ConnManagementScope
qconn quic.Connection
}

var _ tpt.CapableConn = &conn{}

func newConn(tr *transport, sess *webtransport.Session, sconn *connSecurityMultiaddrs, scope network.ConnManagementScope) *conn {
func newConn(tr *transport, sess *webtransport.Session, sconn *connSecurityMultiaddrs, scope network.ConnManagementScope, qconn quic.Connection) *conn {
return &conn{
connSecurityMultiaddrs: sconn,
transport: tr,
session: sess,
scope: scope,
qconn: qconn,
}
}

Expand Down Expand Up @@ -70,7 +73,9 @@ func (c *conn) allowWindowIncrease(size uint64) bool {
func (c *conn) Close() error {
c.scope.Done()
c.transport.removeConn(c.session)
return c.session.CloseWithError(0, "")
err := c.session.CloseWithError(0, "")
_ = c.qconn.CloseWithError(1, "")
return err
}

func (c *conn) IsClosed() bool { return c.session.Context().Err() != nil }
Expand Down
80 changes: 77 additions & 3 deletions p2p/transport/webtransport/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,61 @@ import (
"github.com/libp2p/go-libp2p/p2p/transport/quicreuse"

ma "github.com/multiformats/go-multiaddr"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3"
"github.com/quic-go/webtransport-go"
)

const queueLen = 16
const handshakeTimeout = 10 * time.Second

type connKey struct{}

// negotiatingConn is a wrapper around a quic.Connection that lets us wrap it in
// our own context for the duration of the upgrade process. Upgrading a quic
// connection to an h3 connection to a webtransport session.
type negotiatingConn struct {
quic.Connection
ctx context.Context
cancel context.CancelFunc
// stopClose is a function that stops the connection from being closed when
// the context is done. Returns true if the connection close function was
// not called.
stopClose func() bool
err error
}

func (c *negotiatingConn) Unwrap() (quic.Connection, error) {
defer c.cancel()
if c.stopClose != nil {
// unwrap the first time
if !c.stopClose() {
c.err = errTimeout
}
c.stopClose = nil
}
if c.err != nil {
return nil, c.err
}
return c.Connection, nil
}

func wrapConn(ctx context.Context, c quic.Connection, handshakeTimeout time.Duration) *negotiatingConn {
ctx, cancel := context.WithTimeout(ctx, handshakeTimeout)
stopClose := context.AfterFunc(ctx, func() {
log.Debugf("failed to handshake on conn: %s", c.RemoteAddr())
c.CloseWithError(1, "")
})
return &negotiatingConn{
Connection: c,
ctx: ctx,
cancel: cancel,
stopClose: stopClose,
}
}

var errTimeout = errors.New("timeout")

type listener struct {
transport *transport
isStaticTLSConf bool
Expand Down Expand Up @@ -56,6 +105,11 @@ func newListener(reuseListener quicreuse.Listener, t *transport, isStaticTLSConf
addr: reuseListener.Addr(),
multiaddr: localMultiaddr,
server: webtransport.Server{
H3: http3.Server{
ConnContext: func(ctx context.Context, c quic.Connection) context.Context {
return context.WithValue(ctx, connKey{}, c)
},
},
CheckOrigin: func(r *http.Request) bool { return true },
},
}
Expand All @@ -71,7 +125,8 @@ func newListener(reuseListener quicreuse.Listener, t *transport, isStaticTLSConf
log.Debugw("serving failed", "addr", ln.Addr(), "error", err)
return
}
go ln.server.ServeQUICConn(conn)
wrapped := wrapConn(ln.ctx, conn, t.handshakeTimeout)
go ln.server.ServeQUICConn(wrapped)
}
}()
return ln, nil
Expand Down Expand Up @@ -137,13 +192,32 @@ func (l *listener) httpHandlerWithConnScope(w http.ResponseWriter, r *http.Reque
return err
}

conn := newConn(l.transport, sess, sconn, connScope)
connVal := r.Context().Value(connKey{})
if connVal == nil {
log.Errorf("missing conn from context")
sess.CloseWithError(1, "")
return errors.New("invalid context")
}
nconn, ok := connVal.(*negotiatingConn)
if !ok {
log.Errorf("unexpected connection in context. invalid conn type: %T", nconn)
sess.CloseWithError(1, "")
return errors.New("invalid context")
}
qconn, err := nconn.Unwrap()
if err != nil {
log.Debugf("handshake timed out: %s", r.RemoteAddr)
sess.CloseWithError(1, "")
return err
}

conn := newConn(l.transport, sess, sconn, connScope, qconn)
l.transport.addConn(sess, conn)
select {
case l.queue <- conn:
default:
log.Debugw("accept queue full, dropping incoming connection", "peer", sconn.RemotePeer(), "addr", r.RemoteAddr, "error", err)
sess.CloseWithError(1, "")
conn.Close()
return errors.New("accept queue full")
}

Expand Down
43 changes: 27 additions & 16 deletions p2p/transport/webtransport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,13 @@ func WithTLSClientConfig(c *tls.Config) Option {
}
}

func WithHandshakeTimeout(d time.Duration) Option {
return func(t *transport) error {
t.handshakeTimeout = d
return nil
}
}

type transport struct {
privKey ic.PrivKey
pid peer.ID
Expand All @@ -78,8 +85,9 @@ type transport struct {

noise *noise.Transport

connMx sync.Mutex
conns map[quic.ConnectionTracingID]*conn // using quic-go's ConnectionTracingKey as map key
connMx sync.Mutex
conns map[quic.ConnectionTracingID]*conn // using quic-go's ConnectionTracingKey as map key
handshakeTimeout time.Duration
}

var _ tpt.Transport = &transport{}
Expand All @@ -99,13 +107,14 @@ func New(key ic.PrivKey, psk pnet.PSK, connManager *quicreuse.ConnManager, gater
return nil, err
}
t := &transport{
pid: id,
privKey: key,
rcmgr: rcmgr,
gater: gater,
clock: clock.New(),
connManager: connManager,
conns: map[quic.ConnectionTracingID]*conn{},
pid: id,
privKey: key,
rcmgr: rcmgr,
gater: gater,
clock: clock.New(),
connManager: connManager,
conns: map[quic.ConnectionTracingID]*conn{},
handshakeTimeout: handshakeTimeout,
}
for _, opt := range opts {
if err := opt(t); err != nil {
Expand Down Expand Up @@ -159,7 +168,7 @@ func (t *transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p pee
}

maddr, _ := ma.SplitFunc(raddr, func(c ma.Component) bool { return c.Protocol().Code == ma.P_WEBTRANSPORT })
sess, err := t.dial(ctx, maddr, url, sni, certHashes)
sess, qconn, err := t.dial(ctx, maddr, url, sni, certHashes)
if err != nil {
return nil, err
}
Expand All @@ -172,12 +181,12 @@ func (t *transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p pee
sess.CloseWithError(errorCodeConnectionGating, "")
return nil, fmt.Errorf("secured connection gated")
}
conn := newConn(t, sess, sconn, scope)
conn := newConn(t, sess, sconn, scope, qconn)
t.addConn(sess, conn)
return conn, nil
}

func (t *transport) dial(ctx context.Context, addr ma.Multiaddr, url, sni string, certHashes []multihash.DecodedMultihash) (*webtransport.Session, error) {
func (t *transport) dial(ctx context.Context, addr ma.Multiaddr, url, sni string, certHashes []multihash.DecodedMultihash) (*webtransport.Session, quic.Connection, error) {
var tlsConf *tls.Config
if t.tlsClientConf != nil {
tlsConf = t.tlsClientConf.Clone()
Expand All @@ -200,7 +209,7 @@ func (t *transport) dial(ctx context.Context, addr ma.Multiaddr, url, sni string
}
conn, err := t.connManager.DialQUIC(ctx, addr, tlsConf, t.allowWindowIncrease)
if err != nil {
return nil, err
return nil, nil, err
}
dialer := webtransport.Dialer{
DialAddr: func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
Expand All @@ -210,12 +219,14 @@ func (t *transport) dial(ctx context.Context, addr ma.Multiaddr, url, sni string
}
rsp, sess, err := dialer.Dial(ctx, url, nil)
if err != nil {
return nil, err
conn.CloseWithError(1, "")
return nil, nil, err
}
if rsp.StatusCode < 200 || rsp.StatusCode > 299 {
return nil, fmt.Errorf("invalid response status code: %d", rsp.StatusCode)
conn.CloseWithError(1, "")
return nil, nil, fmt.Errorf("invalid response status code: %d", rsp.StatusCode)
}
return sess, err
return sess, conn, err
}

func (t *transport) upgrade(ctx context.Context, sess *webtransport.Session, p peer.ID, certHashes []multihash.DecodedMultihash) (*connSecurityMultiaddrs, error) {
Expand Down
35 changes: 35 additions & 0 deletions p2p/transport/webtransport/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"fmt"
"io"
"net"
"net/http"
"os"
"runtime"
"sync/atomic"
Expand Down Expand Up @@ -827,3 +828,37 @@ func TestServerRotatesCertCorrectlyAfterSteps(t *testing.T) {
require.True(t, found, "Failed after hour: %v", i)
}
}

func TestH3ConnClosed(t *testing.T) {
_, serverKey := newIdentity(t)
tr, err := libp2pwebtransport.New(serverKey, nil, newConnManager(t), nil, nil, libp2pwebtransport.WithHandshakeTimeout(1*time.Second))
require.NoError(t, err)
defer tr.(io.Closer).Close()
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
require.NoError(t, err)
defer ln.Close()

p, err := net.ListenPacket("udp", "127.0.0.1:0")
require.NoError(t, err)
conn, err := quic.Dial(context.Background(), p, ln.Addr(), &tls.Config{
InsecureSkipVerify: true,
NextProtos: []string{http3.NextProtoH3},
}, nil)
require.NoError(t, err)
rt := &http3.SingleDestinationRoundTripper{
Connection: conn,
}
rt.Start()
require.Eventually(t, func() bool {
c := http.Client{
Transport: rt,
Timeout: 1 * time.Second,
}
resp, err := c.Get(fmt.Sprintf("https://%s", ln.Addr().String()))
if err != nil {
return true
}
resp.Body.Close()
return false
}, 10*time.Second, 1*time.Second)
}
Loading