diff --git a/p2p/security/noise/crypto_test.go b/p2p/security/noise/crypto_test.go index ca5125cd70..87efb8487f 100644 --- a/p2p/security/noise/crypto_test.go +++ b/p2p/security/noise/crypto_test.go @@ -93,7 +93,7 @@ func TestCryptoFailsIfHandshakeIncomplete(t *testing.T) { init, resp := net.Pipe() _ = resp.Close() - session, _ := newSecureSession(initTransport, context.TODO(), init, "remote-peer", true) + session, _ := newSecureSession(initTransport, context.TODO(), init, "remote-peer", nil, true) _, err := session.encrypt(nil, []byte("hi")) if err == nil { t.Error("expected encryption error when handshake incomplete") diff --git a/p2p/security/noise/handshake.go b/p2p/security/noise/handshake.go index 65fe294bda..504a2b155c 100644 --- a/p2p/security/noise/handshake.go +++ b/p2p/security/noise/handshake.go @@ -57,6 +57,7 @@ func (s *secureSession) runHandshake(ctx context.Context) (err error) { Pattern: noise.HandshakeXX, Initiator: s.initiator, StaticKeypair: kp, + Prologue: s.prologue, } hs, err := noise.NewHandshakeState(cfg) diff --git a/p2p/security/noise/session.go b/p2p/security/noise/session.go index 9675cc03bb..9770178a5c 100644 --- a/p2p/security/noise/session.go +++ b/p2p/security/noise/session.go @@ -34,11 +34,14 @@ type secureSession struct { enc *noise.CipherState dec *noise.CipherState + + // noise prologue + prologue []byte } // newSecureSession creates a Noise session over the given insecureConn Conn, using // the libp2p identity keypair from the given Transport. -func newSecureSession(tpt *Transport, ctx context.Context, insecure net.Conn, remote peer.ID, initiator bool) (*secureSession, error) { +func newSecureSession(tpt *Transport, ctx context.Context, insecure net.Conn, remote peer.ID, prologue []byte, initiator bool) (*secureSession, error) { s := &secureSession{ insecureConn: insecure, insecureReader: bufio.NewReader(insecure), @@ -46,6 +49,7 @@ func newSecureSession(tpt *Transport, ctx context.Context, insecure net.Conn, re localID: tpt.localID, localKey: tpt.privateKey, remoteID: remote, + prologue: prologue, } // the go-routine we create to run the handshake will diff --git a/p2p/security/noise/session_transport.go b/p2p/security/noise/session_transport.go new file mode 100644 index 0000000000..b996d95b16 --- /dev/null +++ b/p2p/security/noise/session_transport.go @@ -0,0 +1,58 @@ +package noise + +import ( + "context" + "net" + + "github.com/libp2p/go-libp2p-core/canonicallog" + "github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-core/sec" + manet "github.com/multiformats/go-multiaddr/net" +) + +type SessionOption = func(*SessionTransport) error + +var _ sec.SecureTransport = &SessionTransport{} + +// SessionTransport can be used +// to provide per-connection options +type SessionTransport struct { + t *Transport + // options + prologue []byte +} + +// SecureInbound runs the Noise handshake as the responder. +// If p is empty, connections from any peer are accepted. +func (i *SessionTransport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { + c, err := newSecureSession(i.t, ctx, insecure, p, i.prologue, false) + if err != nil { + addr, maErr := manet.FromNetAddr(insecure.RemoteAddr()) + if maErr == nil { + canonicallog.LogPeerStatus(100, p, addr, "handshake_failure", "noise", "err", err.Error()) + } + } + return c, err +} + +// SecureOutbound runs the Noise handshake as the initiator. +func (i *SessionTransport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { + return newSecureSession(i.t, ctx, insecure, p, i.prologue, true) +} + +func (t *Transport) WithSessionOptions(opts ...SessionOption) (sec.SecureTransport, error) { + st := &SessionTransport{t: t} + for _, opt := range opts { + if err := opt(st); err != nil { + return nil, err + } + } + return st, nil +} + +func Prologue(prologue []byte) SessionOption { + return func(s *SessionTransport) error { + s.prologue = prologue + return nil + } +} diff --git a/p2p/security/noise/transport.go b/p2p/security/noise/transport.go index c39da8697b..689a5e123c 100644 --- a/p2p/security/noise/transport.go +++ b/p2p/security/noise/transport.go @@ -40,7 +40,7 @@ func New(privkey crypto.PrivKey) (*Transport, error) { // SecureInbound runs the Noise handshake as the responder. // If p is empty, connections from any peer are accepted. func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { - c, err := newSecureSession(t, ctx, insecure, p, false) + c, err := newSecureSession(t, ctx, insecure, p, nil, false) if err != nil { addr, maErr := manet.FromNetAddr(insecure.RemoteAddr()) if maErr == nil { @@ -52,5 +52,5 @@ func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer // SecureOutbound runs the Noise handshake as the initiator. func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { - return newSecureSession(t, ctx, insecure, p, true) + return newSecureSession(t, ctx, insecure, p, nil, true) } diff --git a/p2p/security/noise/transport_test.go b/p2p/security/noise/transport_test.go index e88fbe5943..4ce476139c 100644 --- a/p2p/security/noise/transport_test.go +++ b/p2p/security/noise/transport_test.go @@ -372,3 +372,57 @@ func TestReadUnencryptedFails(t *testing.T) { require.Error(t, err) require.Equal(t, 0, afterLen) } + +func TestPrologueMatches(t *testing.T) { + commonPrologue := []byte("test") + initTransport := newTestTransport(t, crypto.Ed25519, 2048) + respTransport := newTestTransport(t, crypto.Ed25519, 2048) + + initConn, respConn := newConnPair(t) + + done := make(chan struct{}) + + go func() { + defer close(done) + tpt, err := initTransport. + WithSessionOptions(Prologue(commonPrologue)) + require.NoError(t, err) + conn, err := tpt.SecureOutbound(context.TODO(), initConn, respTransport.localID) + require.NoError(t, err) + defer conn.Close() + }() + + tpt, err := respTransport. + WithSessionOptions(Prologue(commonPrologue)) + require.NoError(t, err) + conn, err := tpt.SecureInbound(context.TODO(), respConn, "") + require.NoError(t, err) + defer conn.Close() + <-done +} + +func TestPrologueDoesNotMatchFailsHandshake(t *testing.T) { + initPrologue, respPrologue := []byte("initPrologue"), []byte("respPrologue") + initTransport := newTestTransport(t, crypto.Ed25519, 2048) + respTransport := newTestTransport(t, crypto.Ed25519, 2048) + + initConn, respConn := newConnPair(t) + + done := make(chan struct{}) + + go func() { + defer close(done) + tpt, err := initTransport. + WithSessionOptions(Prologue(initPrologue)) + require.NoError(t, err) + _, err = tpt.SecureOutbound(context.TODO(), initConn, respTransport.localID) + require.Error(t, err) + }() + + tpt, err := respTransport.WithSessionOptions(Prologue(respPrologue)) + require.NoError(t, err) + + _, err = tpt.SecureInbound(context.TODO(), respConn, "") + require.Error(t, err) + <-done +}