From ebd000db1e1d1593ad5cc95025c087cd83f99e18 Mon Sep 17 00:00:00 2001 From: julian88110 <111450570+julian88110@users.noreply.github.com> Date: Fri, 7 Oct 2022 10:34:58 -0700 Subject: [PATCH] tls: use ALPN to negotiate the stream multiplexer (#1772) * Muxer selection in TLS handshake first cut * Clean up some part of the code * Change earlydata to ConnectionState for security connection. * resolve merging conflicts * Add stubs for noise * clean up code * Switch over to passing muxers to security transport constructors * Address feedback points * Update p2p/net/upgrader/upgrader.go Co-authored-by: Marten Seemann * clean up accidental checked file. * Review points round 2 * Address some go nit points * Update tls transport test to address review points Co-authored-by: Marten Seemann --- config/config.go | 2 +- config/constructor_types.go | 27 +++--- config/muxer.go | 2 +- config/reflection_magic.go | 9 +- config/security.go | 18 ++-- config/transport.go | 2 +- core/network/conn.go | 11 +++ core/sec/insecure/insecure.go | 6 ++ p2p/muxer/muxer-multistream/multistream.go | 5 ++ p2p/net/connmgr/connmgr_test.go | 1 + p2p/net/mock/mock_conn.go | 5 ++ p2p/net/swarm/swarm_conn.go | 6 ++ p2p/net/upgrader/upgrader.go | 17 +++- p2p/security/noise/session.go | 5 ++ p2p/security/tls/cmd/tlsdiag/client.go | 2 +- p2p/security/tls/cmd/tlsdiag/server.go | 2 +- p2p/security/tls/conn.go | 10 ++- p2p/security/tls/transport.go | 41 +++++++-- p2p/security/tls/transport_test.go | 96 +++++++++++++++++++--- p2p/transport/quic/conn.go | 6 ++ 20 files changed, 227 insertions(+), 46 deletions(-) diff --git a/config/config.go b/config/config.go index 0634ef4a36..adbcf1d122 100644 --- a/config/config.go +++ b/config/config.go @@ -173,7 +173,7 @@ func (cfg *Config) addTransports(h host.Host) error { secure = makeInsecureTransport(h.ID(), cfg.PeerKey) } else { var err error - secure, err = makeSecurityMuxer(h, cfg.SecurityTransports) + secure, err = makeSecurityMuxer(h, cfg.SecurityTransports, cfg.Muxers) if err != nil { return err } diff --git a/config/constructor_types.go b/config/constructor_types.go index 53a105ba39..9c14df2e2c 100644 --- a/config/constructor_types.go +++ b/config/constructor_types.go @@ -11,6 +11,7 @@ import ( "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/peerstore" "github.com/libp2p/go-libp2p/core/pnet" + "github.com/libp2p/go-libp2p/core/protocol" "github.com/libp2p/go-libp2p/core/sec" "github.com/libp2p/go-libp2p/core/transport" @@ -35,42 +36,46 @@ var ( peerIDType = reflect.TypeOf((peer.ID)("")) pskType = reflect.TypeOf((pnet.PSK)(nil)) resolverType = reflect.TypeOf((*madns.Resolver)(nil)) + muxersType = reflect.TypeOf(([]protocol.ID)(nil)) ) var argTypes = map[reflect.Type]constructor{ - upgraderType: func(_ host.Host, u transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver) interface{} { + upgraderType: func(_ host.Host, u transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, _ []protocol.ID) interface{} { return u }, - hostType: func(h host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver) interface{} { + hostType: func(h host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, _ []protocol.ID) interface{} { return h }, - networkType: func(h host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver) interface{} { + networkType: func(h host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, _ []protocol.ID) interface{} { return h.Network() }, - pskType: func(_ host.Host, _ transport.Upgrader, psk pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver) interface{} { + pskType: func(_ host.Host, _ transport.Upgrader, psk pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, _ []protocol.ID) interface{} { return psk }, - connGaterType: func(_ host.Host, _ transport.Upgrader, _ pnet.PSK, cg connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver) interface{} { + connGaterType: func(_ host.Host, _ transport.Upgrader, _ pnet.PSK, cg connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, _ []protocol.ID) interface{} { return cg }, - peerIDType: func(h host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver) interface{} { + peerIDType: func(h host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, _ []protocol.ID) interface{} { return h.ID() }, - privKeyType: func(h host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver) interface{} { + privKeyType: func(h host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, _ []protocol.ID) interface{} { return h.Peerstore().PrivKey(h.ID()) }, - pubKeyType: func(h host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver) interface{} { + pubKeyType: func(h host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, _ []protocol.ID) interface{} { return h.Peerstore().PubKey(h.ID()) }, - pstoreType: func(h host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver) interface{} { + pstoreType: func(h host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, _ []protocol.ID) interface{} { return h.Peerstore() }, - rcmgrType: func(_ host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, rcmgr network.ResourceManager, _ *madns.Resolver) interface{} { + rcmgrType: func(_ host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, rcmgr network.ResourceManager, _ *madns.Resolver, _ []protocol.ID) interface{} { return rcmgr }, - resolverType: func(_ host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, r *madns.Resolver) interface{} { + resolverType: func(_ host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, r *madns.Resolver, _ []protocol.ID) interface{} { return r }, + muxersType: func(_ host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater, _ network.ResourceManager, _ *madns.Resolver, muxers []protocol.ID) interface{} { + return muxers + }, } func newArgTypeSet(types ...reflect.Type) map[reflect.Type]constructor { diff --git a/config/muxer.go b/config/muxer.go index 30a3fa2f7e..e7b64c1345 100644 --- a/config/muxer.go +++ b/config/muxer.go @@ -35,7 +35,7 @@ func MuxerConstructor(m interface{}) (MuxC, error) { return nil, err } return func(h host.Host) (network.Multiplexer, error) { - t, err := ctor(h, nil, nil, nil, nil, nil) + t, err := ctor(h, nil, nil, nil, nil, nil, nil) if err != nil { return nil, err } diff --git a/config/reflection_magic.go b/config/reflection_magic.go index bb2f52c6b0..0189872abb 100644 --- a/config/reflection_magic.go +++ b/config/reflection_magic.go @@ -10,6 +10,7 @@ import ( "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/pnet" + "github.com/libp2p/go-libp2p/core/protocol" "github.com/libp2p/go-libp2p/core/transport" madns "github.com/multiformats/go-multiaddr-dns" @@ -82,7 +83,7 @@ func callConstructor(c reflect.Value, args []reflect.Value) (interface{}, error) return val, err } -type constructor func(host.Host, transport.Upgrader, pnet.PSK, connmgr.ConnectionGater, network.ResourceManager, *madns.Resolver) interface{} +type constructor func(host.Host, transport.Upgrader, pnet.PSK, connmgr.ConnectionGater, network.ResourceManager, *madns.Resolver, []protocol.ID) interface{} func makeArgumentConstructors(fnType reflect.Type, argTypes map[reflect.Type]constructor) ([]constructor, error) { params := fnType.NumIn() @@ -133,7 +134,7 @@ func makeConstructor( tptType reflect.Type, argTypes map[reflect.Type]constructor, opts ...interface{}, -) (func(host.Host, transport.Upgrader, pnet.PSK, connmgr.ConnectionGater, network.ResourceManager, *madns.Resolver) (interface{}, error), error) { +) (func(host.Host, transport.Upgrader, pnet.PSK, connmgr.ConnectionGater, network.ResourceManager, *madns.Resolver, []protocol.ID) (interface{}, error), error) { v := reflect.ValueOf(tpt) // avoid panicing on nil/zero value. if v == (reflect.Value{}) { @@ -157,10 +158,10 @@ func makeConstructor( return nil, err } - return func(h host.Host, u transport.Upgrader, psk pnet.PSK, cg connmgr.ConnectionGater, rcmgr network.ResourceManager, resolver *madns.Resolver) (interface{}, error) { + return func(h host.Host, u transport.Upgrader, psk pnet.PSK, cg connmgr.ConnectionGater, rcmgr network.ResourceManager, resolver *madns.Resolver, muxers []protocol.ID) (interface{}, error) { arguments := make([]reflect.Value, 0, len(argConstructors)+len(opts)) for i, makeArg := range argConstructors { - if arg := makeArg(h, u, psk, cg, rcmgr, resolver); arg != nil { + if arg := makeArg(h, u, psk, cg, rcmgr, resolver, muxers); arg != nil { arguments = append(arguments, reflect.ValueOf(arg)) } else { // ValueOf an un-typed nil yields a zero reflect diff --git a/config/security.go b/config/security.go index 330cfaf334..6f15b63462 100644 --- a/config/security.go +++ b/config/security.go @@ -6,13 +6,14 @@ import ( "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/protocol" "github.com/libp2p/go-libp2p/core/sec" "github.com/libp2p/go-libp2p/core/sec/insecure" csms "github.com/libp2p/go-libp2p/p2p/net/conn-security-multistream" ) // SecC is a security transport constructor. -type SecC func(h host.Host) (sec.SecureTransport, error) +type SecC func(h host.Host, muxers []protocol.ID) (sec.SecureTransport, error) // MsSecC is a tuple containing a security transport constructor and a protocol // ID. @@ -24,6 +25,7 @@ type MsSecC struct { var securityArgTypes = newArgTypeSet( hostType, networkType, peerIDType, privKeyType, pubKeyType, pstoreType, + muxersType, ) // SecurityConstructor creates a security constructor from the passed parameter @@ -31,7 +33,7 @@ var securityArgTypes = newArgTypeSet( func SecurityConstructor(security interface{}) (SecC, error) { // Already constructed? if t, ok := security.(sec.SecureTransport); ok { - return func(_ host.Host) (sec.SecureTransport, error) { + return func(_ host.Host, _ []protocol.ID) (sec.SecureTransport, error) { return t, nil }, nil } @@ -40,8 +42,8 @@ func SecurityConstructor(security interface{}) (SecC, error) { if err != nil { return nil, err } - return func(h host.Host) (sec.SecureTransport, error) { - t, err := ctor(h, nil, nil, nil, nil, nil) + return func(h host.Host, muxers []protocol.ID) (sec.SecureTransport, error) { + t, err := ctor(h, nil, nil, nil, nil, nil, muxers) if err != nil { return nil, err } @@ -55,7 +57,7 @@ func makeInsecureTransport(id peer.ID, privKey crypto.PrivKey) sec.SecureMuxer { return secMuxer } -func makeSecurityMuxer(h host.Host, tpts []MsSecC) (sec.SecureMuxer, error) { +func makeSecurityMuxer(h host.Host, tpts []MsSecC, muxers []MsMuxC) (sec.SecureMuxer, error) { secMuxer := new(csms.SSMuxer) transportSet := make(map[string]struct{}, len(tpts)) for _, tptC := range tpts { @@ -64,8 +66,12 @@ func makeSecurityMuxer(h host.Host, tpts []MsSecC) (sec.SecureMuxer, error) { } transportSet[tptC.ID] = struct{}{} } + muxIds := make([]protocol.ID, 0, len(muxers)) + for _, muxc := range muxers { + muxIds = append(muxIds, protocol.ID(muxc.ID)) + } for _, tptC := range tpts { - tpt, err := tptC.SecC(h) + tpt, err := tptC.SecC(h, muxIds) if err != nil { return nil, err } diff --git a/config/transport.go b/config/transport.go index 850357f5a4..6105e77f13 100644 --- a/config/transport.go +++ b/config/transport.go @@ -50,7 +50,7 @@ func TransportConstructor(tpt interface{}, opts ...interface{}) (TptC, error) { return nil, err } return func(h host.Host, u transport.Upgrader, psk pnet.PSK, cg connmgr.ConnectionGater, rcmgr network.ResourceManager, resolver *madns.Resolver) (transport.Transport, error) { - t, err := ctor(h, u, psk, cg, rcmgr, resolver) + t, err := ctor(h, u, psk, cg, rcmgr, resolver, nil) if err != nil { return nil, err } diff --git a/core/network/conn.go b/core/network/conn.go index 8554493e25..18414b062c 100644 --- a/core/network/conn.go +++ b/core/network/conn.go @@ -34,6 +34,14 @@ type Conn interface { GetStreams() []Stream } +// ConnectionState holds extra information releated to the ConnSecurity entity. +type ConnectionState struct { + // The next protocol used for stream muxer selection. This is derived from + // security protocol handshake, for example, Noise handshake payload or + // TLS/ALPN negotiation. + NextProto string +} + // ConnSecurity is the interface that one can mix into a connection interface to // give it the security methods. type ConnSecurity interface { @@ -48,6 +56,9 @@ type ConnSecurity interface { // RemotePublicKey returns the public key of the remote peer. RemotePublicKey() ic.PubKey + + // Connection state info of the secured connection. + ConnState() ConnectionState } // ConnMultiaddrs is an interface mixin for connection types that provide multiaddr diff --git a/core/sec/insecure/insecure.go b/core/sec/insecure/insecure.go index 12bd1842b4..2d94f43804 100644 --- a/core/sec/insecure/insecure.go +++ b/core/sec/insecure/insecure.go @@ -10,6 +10,7 @@ import ( "net" ci "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/sec" pb "github.com/libp2p/go-libp2p/core/sec/insecure/pb" @@ -230,5 +231,10 @@ func (ic *Conn) LocalPrivateKey() ci.PrivKey { return ic.localPrivKey } +// ConnState returns the security connection's state information. +func (ic *Conn) ConnState() network.ConnectionState { + return network.ConnectionState{} +} + var _ sec.SecureTransport = (*Transport)(nil) var _ sec.SecureConn = (*Conn)(nil) diff --git a/p2p/muxer/muxer-multistream/multistream.go b/p2p/muxer/muxer-multistream/multistream.go index bf9d41630c..e81ae0ded3 100644 --- a/p2p/muxer/muxer-multistream/multistream.go +++ b/p2p/muxer/muxer-multistream/multistream.go @@ -73,3 +73,8 @@ func (t *Transport) NewConn(nc net.Conn, isServer bool, scope network.PeerScope) return tpt.NewConn(nc, isServer, scope) } + +func (t *Transport) GetTransportByKey(key string) (network.Multiplexer, bool) { + val, ok := t.tpts[key] + return val, ok +} diff --git a/p2p/net/connmgr/connmgr_test.go b/p2p/net/connmgr/connmgr_test.go index 312bdc1f3b..2053e3e6f7 100644 --- a/p2p/net/connmgr/connmgr_test.go +++ b/p2p/net/connmgr/connmgr_test.go @@ -806,6 +806,7 @@ func (m mockConn) ID() string { panic func (m mockConn) NewStream(ctx context.Context) (network.Stream, error) { panic("implement me") } func (m mockConn) GetStreams() []network.Stream { panic("implement me") } func (m mockConn) Scope() network.ConnScope { panic("implement me") } +func (m mockConn) ConnState() network.ConnectionState { return network.ConnectionState{} } func TestPeerInfoSorting(t *testing.T) { t.Run("starts with temporary connections", func(t *testing.T) { diff --git a/p2p/net/mock/mock_conn.go b/p2p/net/mock/mock_conn.go index 5664fabb61..48015a4c61 100644 --- a/p2p/net/mock/mock_conn.go +++ b/p2p/net/mock/mock_conn.go @@ -178,6 +178,11 @@ func (c *conn) RemotePublicKey() ic.PubKey { return c.remotePubKey } +// ConnState of security connection. Empty if not supported. +func (c *conn) ConnState() network.ConnectionState { + return network.ConnectionState{} +} + // Stat returns metadata about the connection func (c *conn) Stat() network.ConnStats { return c.stat diff --git a/p2p/net/swarm/swarm_conn.go b/p2p/net/swarm/swarm_conn.go index 779ee37374..4de2727f80 100644 --- a/p2p/net/swarm/swarm_conn.go +++ b/p2p/net/swarm/swarm_conn.go @@ -178,6 +178,12 @@ func (c *Conn) RemotePublicKey() ic.PubKey { return c.conn.RemotePublicKey() } +// ConnState is the security connection state. including early data result. +// Empty if not supported. +func (c *Conn) ConnState() network.ConnectionState { + return c.conn.ConnState() +} + // Stat returns metadata pertaining to this connection func (c *Conn) Stat() network.ConnStats { c.streams.Lock() diff --git a/p2p/net/upgrader/upgrader.go b/p2p/net/upgrader/upgrader.go index 58347865a5..2ef60b82ac 100644 --- a/p2p/net/upgrader/upgrader.go +++ b/p2p/net/upgrader/upgrader.go @@ -13,6 +13,7 @@ import ( ipnet "github.com/libp2p/go-libp2p/core/pnet" "github.com/libp2p/go-libp2p/core/sec" "github.com/libp2p/go-libp2p/core/transport" + msmux "github.com/libp2p/go-libp2p/p2p/muxer/muxer-multistream" "github.com/libp2p/go-libp2p/p2p/net/pnet" manet "github.com/multiformats/go-multiaddr/net" @@ -194,12 +195,24 @@ func (u *upgrader) setupSecurity(ctx context.Context, conn net.Conn, p peer.ID, return u.secure.SecureOutbound(ctx, conn, p) } -func (u *upgrader) setupMuxer(ctx context.Context, conn net.Conn, server bool, scope network.PeerScope) (network.MuxedConn, error) { - // TODO: The muxer should take a context. +func (u *upgrader) setupMuxer(ctx context.Context, conn sec.SecureConn, server bool, scope network.PeerScope) (network.MuxedConn, error) { + msmuxer, ok := u.muxer.(*msmux.Transport) + muxerSelected := conn.ConnState().NextProto + // Use muxer selected from security handshake if available. Otherwise fall back to multistream-selection. + if ok && len(muxerSelected) > 0 { + tpt, ok := msmuxer.GetTransportByKey(muxerSelected) + if !ok { + return nil, fmt.Errorf("selected a muxer we don't know: %s", muxerSelected) + } + + return tpt.NewConn(conn, server, scope) + } + done := make(chan struct{}) var smconn network.MuxedConn var err error + // TODO: The muxer should take a context. go func() { defer close(done) smconn, err = u.muxer.NewConn(conn, server, scope) diff --git a/p2p/security/noise/session.go b/p2p/security/noise/session.go index 692c9174ad..f1286b9ffb 100644 --- a/p2p/security/noise/session.go +++ b/p2p/security/noise/session.go @@ -10,6 +10,7 @@ import ( "github.com/flynn/noise" "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" ) @@ -108,6 +109,10 @@ func (s *secureSession) RemotePublicKey() crypto.PubKey { return s.remoteKey } +func (s *secureSession) ConnState() network.ConnectionState { + return network.ConnectionState{} +} + func (s *secureSession) SetDeadline(t time.Time) error { return s.insecureConn.SetDeadline(t) } diff --git a/p2p/security/tls/cmd/tlsdiag/client.go b/p2p/security/tls/cmd/tlsdiag/client.go index 7f1a7efecd..2292bfe0e9 100644 --- a/p2p/security/tls/cmd/tlsdiag/client.go +++ b/p2p/security/tls/cmd/tlsdiag/client.go @@ -34,7 +34,7 @@ func StartClient() error { return err } fmt.Printf(" Peer ID: %s\n", id.Pretty()) - tp, err := libp2ptls.New(priv) + tp, err := libp2ptls.New(priv, nil) if err != nil { return err } diff --git a/p2p/security/tls/cmd/tlsdiag/server.go b/p2p/security/tls/cmd/tlsdiag/server.go index 16f51f4a1b..05e4be3f16 100644 --- a/p2p/security/tls/cmd/tlsdiag/server.go +++ b/p2p/security/tls/cmd/tlsdiag/server.go @@ -27,7 +27,7 @@ func StartServer() error { return err } fmt.Printf(" Peer ID: %s\n", id.Pretty()) - tp, err := libp2ptls.New(priv) + tp, err := libp2ptls.New(priv, nil) if err != nil { return err } diff --git a/p2p/security/tls/conn.go b/p2p/security/tls/conn.go index 6353eac80b..3ebc7aefc5 100644 --- a/p2p/security/tls/conn.go +++ b/p2p/security/tls/conn.go @@ -4,6 +4,7 @@ import ( "crypto/tls" ci "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/sec" ) @@ -14,8 +15,9 @@ type conn struct { localPeer peer.ID privKey ci.PrivKey - remotePeer peer.ID - remotePubKey ci.PubKey + remotePeer peer.ID + remotePubKey ci.PubKey + connectionState network.ConnectionState } var _ sec.SecureConn = &conn{} @@ -35,3 +37,7 @@ func (c *conn) RemotePeer() peer.ID { func (c *conn) RemotePublicKey() ci.PubKey { return c.remotePubKey } + +func (c *conn) ConnState() network.ConnectionState { + return c.connectionState +} diff --git a/p2p/security/tls/transport.go b/p2p/security/tls/transport.go index f6aa64f6ab..695f648465 100644 --- a/p2p/security/tls/transport.go +++ b/p2p/security/tls/transport.go @@ -11,7 +11,9 @@ import ( "github.com/libp2p/go-libp2p/core/canonicallog" ci "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/protocol" "github.com/libp2p/go-libp2p/core/sec" manet "github.com/multiformats/go-multiaddr/net" @@ -26,10 +28,11 @@ type Transport struct { localPeer peer.ID privKey ci.PrivKey + muxers []protocol.ID } // New creates a TLS encrypted transport -func New(key ci.PrivKey) (*Transport, error) { +func New(key ci.PrivKey, muxers []protocol.ID) (*Transport, error) { id, err := peer.IDFromPrivateKey(key) if err != nil { return nil, err @@ -37,6 +40,7 @@ func New(key ci.PrivKey) (*Transport, error) { t := &Transport{ localPeer: id, privKey: key, + muxers: muxers, } identity, err := NewIdentity(key) @@ -53,6 +57,12 @@ var _ sec.SecureTransport = &Transport{} // If p is empty, connections from any peer are accepted. func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { config, keyCh := t.identity.ConfigForPeer(p) + muxers := make([]string, 0, len(t.muxers)) + for _, muxer := range t.muxers { + muxers = append(muxers, string(muxer)) + } + // Prepend the prefered muxers list to TLS config. + config.NextProtos = append(muxers, config.NextProtos...) cs, err := t.handshake(ctx, tls.Server(insecure, config), keyCh) if err != nil { addr, maErr := manet.FromNetAddr(insecure.RemoteAddr()) @@ -73,6 +83,12 @@ func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer // notice this after 1 RTT when calling Read. func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { config, keyCh := t.identity.ConfigForPeer(p) + muxers := make([]string, 0, len(t.muxers)) + for _, muxer := range t.muxers { + muxers = append(muxers, (string)(muxer)) + } + // Prepend the prefered muxers list to TLS config. + config.NextProtos = append(muxers, config.NextProtos...) cs, err := t.handshake(ctx, tls.Client(insecure, config), keyCh) if err != nil { insecure.Close() @@ -89,6 +105,7 @@ func (t *Transport) handshake(ctx context.Context, tlsConn *tls.Conn, keyCh <-ch } }() + // handshaking... if err := tlsConn.HandshakeContext(ctx); err != nil { return nil, err } @@ -111,11 +128,23 @@ func (t *Transport) setupConn(tlsConn *tls.Conn, remotePubKey ci.PubKey) (sec.Se if err != nil { return nil, err } + + nextProto := tlsConn.ConnectionState().NegotiatedProtocol + // The special ALPN extension value "libp2p" is used by libp2p versions + // that don't support early muxer negotiation. If we see this sepcial + // value selected, that means we are handshaking with a version that does + // not support early muxer negotiation. In this case return empty nextProto + // to indicate no muxer is selected. + if nextProto == "libp2p" { + nextProto = "" + } + return &conn{ - Conn: tlsConn, - localPeer: t.localPeer, - privKey: t.privKey, - remotePeer: remotePeerID, - remotePubKey: remotePubKey, + Conn: tlsConn, + localPeer: t.localPeer, + privKey: t.privKey, + remotePeer: remotePeerID, + remotePubKey: remotePubKey, + connectionState: network.ConnectionState{NextProto: nextProto}, }, nil } diff --git a/p2p/security/tls/transport_test.go b/p2p/security/tls/transport_test.go index c20fa494cb..59fa8bdae8 100644 --- a/p2p/security/tls/transport_test.go +++ b/p2p/security/tls/transport_test.go @@ -22,6 +22,7 @@ import ( ic "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/protocol" "github.com/libp2p/go-libp2p/core/sec" "github.com/stretchr/testify/assert" @@ -125,9 +126,9 @@ func TestHandshakeSucceeds(t *testing.T) { } // Use standard transports with default TLS configuration - clientTransport, err := New(clientKey) + clientTransport, err := New(clientKey, nil) require.NoError(t, err) - serverTransport, err := New(serverKey) + serverTransport, err := New(serverKey, nil) require.NoError(t, err) t.Run("standard TLS with extension not critical", func(t *testing.T) { @@ -175,6 +176,81 @@ func TestHandshakeSucceeds(t *testing.T) { }) } +type testcase struct { + clientProtos []protocol.ID + serverProtos []protocol.ID + expectedResult string +} + +func TestHandshakeWithNextProtoSucceeds(t *testing.T) { + + tests := []testcase{ + {clientProtos: nil, serverProtos: nil, expectedResult: ""}, + {[]protocol.ID{"muxer1/1.0.0", "muxer2/1.0.1"}, []protocol.ID{"muxer2/1.0.1", "muxer1/1.0.0"}, "muxer2/1.0.1"}, + {[]protocol.ID{"muxer1/1.0.0", "muxer2/1.0.1", "libp2p"}, []protocol.ID{"muxer2/1.0.1", "muxer1/1.0.0", "libp2p"}, "muxer2/1.0.1"}, + {[]protocol.ID{"muxer1/1.0.0", "libp2p"}, []protocol.ID{"libp2p"}, ""}, + {[]protocol.ID{"libp2p"}, []protocol.ID{"libp2p"}, ""}, + {[]protocol.ID{"muxer1"}, []protocol.ID{}, ""}, + {[]protocol.ID{}, []protocol.ID{"muxer1"}, ""}, + {[]protocol.ID{"muxer2"}, []protocol.ID{"muxer1"}, ""}, + } + + clientID, clientKey := createPeer(t) + serverID, serverKey := createPeer(t) + + handshake := func(t *testing.T, clientTransport *Transport, serverTransport *Transport, expectedMuxer string) { + clientInsecureConn, serverInsecureConn := connect(t) + + serverConnChan := make(chan sec.SecureConn) + go func() { + serverConn, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, "") + require.NoError(t, err) + serverConnChan <- serverConn + }() + + clientConn, err := clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID) + require.NoError(t, err) + defer clientConn.Close() + + var serverConn sec.SecureConn + select { + case serverConn = <-serverConnChan: + case <-time.After(250 * time.Millisecond): + t.Fatal("expected the server to accept a connection") + } + defer serverConn.Close() + + require.Equal(t, clientConn.LocalPeer(), clientID) + require.Equal(t, serverConn.LocalPeer(), serverID) + require.True(t, clientConn.LocalPrivateKey().Equals(clientKey), "client private key mismatch") + require.True(t, serverConn.LocalPrivateKey().Equals(serverKey), "server private key mismatch") + require.Equal(t, clientConn.RemotePeer(), serverID) + require.Equal(t, serverConn.RemotePeer(), clientID) + require.True(t, clientConn.RemotePublicKey().Equals(serverKey.GetPublic()), "server public key mismatch") + require.True(t, serverConn.RemotePublicKey().Equals(clientKey.GetPublic()), "client public key mismatch") + require.Equal(t, clientConn.ConnState().NextProto, expectedMuxer) + // exchange some data + _, err = serverConn.Write([]byte("foobar")) + require.NoError(t, err) + b := make([]byte, 6) + _, err = clientConn.Read(b) + require.NoError(t, err) + require.Equal(t, string(b), "foobar") + } + + // Iterate through the NextProto combinations. + for _, test := range tests { + clientTransport, err := New(clientKey, test.clientProtos) + require.NoError(t, err) + serverTransport, err := New(serverKey, test.serverProtos) + require.NoError(t, err) + + t.Run("TLS handshake with ALPN extension", func(t *testing.T) { + handshake(t, clientTransport, serverTransport, test.expectedResult) + }) + } +} + // crypto/tls' cancellation logic works by spinning up a separate Go routine that watches the ctx. // If the ctx is canceled, it kills the handshake. // We need to make sure that the handshake doesn't complete before that Go routine picks up the cancellation. @@ -192,9 +268,9 @@ func TestHandshakeConnectionCancellations(t *testing.T) { _, clientKey := createPeer(t) serverID, serverKey := createPeer(t) - clientTransport, err := New(clientKey) + clientTransport, err := New(clientKey, nil) require.NoError(t, err) - serverTransport, err := New(serverKey) + serverTransport, err := New(serverKey, nil) require.NoError(t, err) t.Run("cancel outgoing connection", func(t *testing.T) { @@ -244,9 +320,9 @@ func TestPeerIDMismatch(t *testing.T) { _, clientKey := createPeer(t) serverID, serverKey := createPeer(t) - serverTransport, err := New(serverKey) + serverTransport, err := New(serverKey, nil) require.NoError(t, err) - clientTransport, err := New(clientKey) + clientTransport, err := New(clientKey, nil) require.NoError(t, err) t.Run("for outgoing connections", func(t *testing.T) { @@ -521,9 +597,9 @@ func TestInvalidCerts(t *testing.T) { tr := transforms[i] t.Run(fmt.Sprintf("client offending: %s", tr.name), func(t *testing.T) { - serverTransport, err := New(serverKey) + serverTransport, err := New(serverKey, nil) require.NoError(t, err) - clientTransport, err := New(clientKey) + clientTransport, err := New(clientKey, nil) require.NoError(t, err) tr.apply(clientTransport.identity) @@ -564,10 +640,10 @@ func TestInvalidCerts(t *testing.T) { }) t.Run(fmt.Sprintf("server offending: %s", tr.name), func(t *testing.T) { - serverTransport, err := New(serverKey) + serverTransport, err := New(serverKey, nil) require.NoError(t, err) tr.apply(serverTransport.identity) - clientTransport, err := New(clientKey) + clientTransport, err := New(clientKey, nil) require.NoError(t, err) clientInsecureConn, serverInsecureConn := connect(t) diff --git a/p2p/transport/quic/conn.go b/p2p/transport/quic/conn.go index a012dea3d7..1af53f2abc 100644 --- a/p2p/transport/quic/conn.go +++ b/p2p/transport/quic/conn.go @@ -90,6 +90,12 @@ func (c *conn) RemotePublicKey() ic.PubKey { return c.remotePubKey } +// ConnState is the state of security connection. +// It is empty if not supported. +func (c *conn) ConnState() network.ConnectionState { + return network.ConnectionState{} +} + // LocalMultiaddr returns the local Multiaddr associated func (c *conn) LocalMultiaddr() ma.Multiaddr { return c.localMultiaddr