Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

noise: make it possible for the server to send early data #1750

Merged
merged 6 commits into from
Sep 19, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 39 additions & 22 deletions p2p/security/noise/handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,6 @@ func (s *secureSession) runHandshake(ctx context.Context) (err error) {
return fmt.Errorf("error initializing handshake state: %w", err)
}

payload, err := s.generateHandshakePayload(kp)
if err != nil {
return err
}

// set a deadline to complete the handshake, if one has been supplied.
// clear it after we're done.
if deadline, ok := ctx.Deadline(); ok {
Expand All @@ -82,7 +77,7 @@ func (s *secureSession) runHandshake(ctx context.Context) (err error) {
// will be the size of the maximum handshake message for the Noise XX pattern.
// Also, since we prefix every noise handshake message with its length, we need to account for
// it when we fetch the buffer from the pool
maxMsgSize := 2*noise.DH25519.DHLen() + len(payload) + 2*chacha20poly1305.Overhead
maxMsgSize := 2*noise.DH25519.DHLen() + 2*chacha20poly1305.Overhead + 1024 /* payload */
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just wondering if there is a more efficient way of setting length rather than a hard limit here, I don't have a better way readily available, just want to throw some thoughts here to explore the options.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed to a fixed value of 2048 in 24459c5.

It's not even a hard limit, it's just the size we use to get a buffer from the buffer pool. If we get a buffer that's too small, the worst thing that happens is that append would allocate a larger buffer.

We can just use a large enough buffer here. What exactly "large enough" means depends on the amount of early data we send, but 2048 should be enough for practically all use cases. And if not, we can live with that one allocation.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the clarification, that makes sense. Sorry I was not aware that we can do extra allocation if the limit is exceeded.

hbuf := pool.Get(maxMsgSize + LengthPrefixLength)
defer pool.Put(hbuf)

Expand All @@ -102,12 +97,22 @@ func (s *secureSession) runHandshake(ctx context.Context) (err error) {
if err != nil {
return fmt.Errorf("error reading handshake message: %w", err)
}
if err := s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic()); err != nil {
rcvdEd, err := s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic())
if err != nil {
return err
}
if s.earlyDataHandler != nil {
if err := s.earlyDataHandler.Received(ctx, s.insecureConn, rcvdEd); err != nil {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The early data handler may need to know the handshake state to drive its internal state machine, for example if this is the server or client side, at which stage is the early data received. Do you think we should pass the handshake state info to the earlyDataHandler too? or maybe that is available in other ways already and I missed it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For both client and server, the data is sent with their respective first flights.

I think the nicest way to resolve this is to pass separate client and server EarlyDataHandlers to the EarlyData option. I applied that change in 0b36832. What do you think?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That looks great, thanks for the changes which make the Noise based muxer selection and other early data utilization smoother.

return err
}
}

// stage 2 //
// Handshake Msg Len = len(DHT static key) + MAC(static key is encrypted) + len(Payload) + MAC(payload is encrypted)
payload, err := s.generateHandshakePayload(kp, nil)
if err != nil {
return err
}
if err := s.sendHandshakeMessage(hs, payload, hbuf); err != nil {
return fmt.Errorf("error sending handshake message: %w", err)
}
Expand All @@ -127,6 +132,14 @@ func (s *secureSession) runHandshake(ctx context.Context) (err error) {
// stage 1 //
// Handshake Msg Len = len(DH ephemeral key) + len(DHT static key) + MAC(static key is encrypted) + len(Payload) +
// MAC(payload is encrypted)
var ed []byte
if s.earlyDataHandler != nil {
ed = s.earlyDataHandler.Send(ctx, s.insecureConn, s.remoteID)
}
payload, err := s.generateHandshakePayload(kp, ed)
if err != nil {
return err
}
if err := s.sendHandshakeMessage(hs, payload, hbuf); err != nil {
return fmt.Errorf("error sending handshake message: %w", err)
}
Expand All @@ -136,7 +149,9 @@ func (s *secureSession) runHandshake(ctx context.Context) (err error) {
if err != nil {
return fmt.Errorf("error reading handshake message: %w", err)
}
return s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic())
// we don't expect any early data on this message
_, err = s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic())
return err
}
}

Expand Down Expand Up @@ -214,8 +229,8 @@ func (s *secureSession) readHandshakeMessage(hs *noise.HandshakeState) ([]byte,

// generateHandshakePayload creates a libp2p handshake payload with a
// signature of our static noise key.
func (s *secureSession) generateHandshakePayload(localStatic noise.DHKey) ([]byte, error) {
// obtain the public key from the handshake session so we can sign it with
func (s *secureSession) generateHandshakePayload(localStatic noise.DHKey, data []byte) ([]byte, error) {
// obtain the public key from the handshake session, so we can sign it with
// our libp2p secret key.
localKeyRaw, err := crypto.MarshalPublicKey(s.LocalPublicKey())
if err != nil {
Expand All @@ -230,10 +245,11 @@ func (s *secureSession) generateHandshakePayload(localStatic noise.DHKey) ([]byt
}

// create payload
payload := new(pb.NoiseHandshakePayload)
payload.IdentityKey = localKeyRaw
payload.IdentitySig = signedPayload
payloadEnc, err := proto.Marshal(payload)
payloadEnc, err := proto.Marshal(&pb.NoiseHandshakePayload{
IdentityKey: localKeyRaw,
IdentitySig: signedPayload,
Data: data,
})
if err != nil {
return nil, fmt.Errorf("error marshaling handshake payload: %w", err)
}
Expand All @@ -242,44 +258,45 @@ func (s *secureSession) generateHandshakePayload(localStatic noise.DHKey) ([]byt

// handleRemoteHandshakePayload unmarshals the handshake payload object sent
// by the remote peer and validates the signature against the peer's static Noise key.
func (s *secureSession) handleRemoteHandshakePayload(payload []byte, remoteStatic []byte) error {
// It returns the data attached to the payload.
func (s *secureSession) handleRemoteHandshakePayload(payload []byte, remoteStatic []byte) ([]byte, error) {
// unmarshal payload
nhp := new(pb.NoiseHandshakePayload)
err := proto.Unmarshal(payload, nhp)
if err != nil {
return fmt.Errorf("error unmarshaling remote handshake payload: %w", err)
return nil, fmt.Errorf("error unmarshaling remote handshake payload: %w", err)
}

// unpack remote peer's public libp2p key
remotePubKey, err := crypto.UnmarshalPublicKey(nhp.GetIdentityKey())
if err != nil {
return err
return nil, err
}
id, err := peer.IDFromPublicKey(remotePubKey)
if err != nil {
return err
return nil, err
}

// check the peer ID for:
// * all outbound connection
// * inbound connections, if we know which peer we want to connect to (SecureInbound called with a peer ID)
if (s.initiator && s.remoteID != id) || (!s.initiator && s.remoteID != "" && s.remoteID != id) {
// use Pretty() as it produces the full b58-encoded string, rather than abbreviated forms.
return fmt.Errorf("peer id mismatch: expected %s, but remote key matches %s", s.remoteID.Pretty(), id.Pretty())
return nil, fmt.Errorf("peer id mismatch: expected %s, but remote key matches %s", s.remoteID.Pretty(), id.Pretty())
}

// verify payload is signed by asserted remote libp2p key.
sig := nhp.GetIdentitySig()
msg := append([]byte(payloadSigPrefix), remoteStatic...)
ok, err := remotePubKey.Verify(msg, sig)
if err != nil {
return fmt.Errorf("error verifying signature: %w", err)
return nil, fmt.Errorf("error verifying signature: %w", err)
} else if !ok {
return fmt.Errorf("handshake signature invalid")
return nil, fmt.Errorf("handshake signature invalid")
}

// set remote peer key and id
s.remoteID = id
s.remoteKey = remotePubKey
return nil
return nhp.Data, nil
}
113 changes: 72 additions & 41 deletions p2p/security/noise/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -447,67 +447,98 @@ func (e *earlyDataHandler) Received(ctx context.Context, conn net.Conn, data []b
}

func TestEarlyDataAccepted(t *testing.T) {
handshake := func(t *testing.T, client, server EarlyDataHandler) {
t.Helper()
initTransport, err := newTestTransport(t, crypto.Ed25519, 2048).WithSessionOptions(EarlyData(client))
require.NoError(t, err)
tpt := newTestTransport(t, crypto.Ed25519, 2048)
respTransport, err := tpt.WithSessionOptions(EarlyData(server))
require.NoError(t, err)

initConn, respConn := newConnPair(t)

errChan := make(chan error)
go func() {
_, err := respTransport.SecureInbound(context.Background(), initConn, "")
errChan <- err
}()

conn, err := initTransport.SecureOutbound(context.Background(), respConn, tpt.localID)
require.NoError(t, err)
defer conn.Close()
}

var receivedEarlyData []byte
serverEDH := &earlyDataHandler{
receivingEDH := &earlyDataHandler{
received: func(_ context.Context, _ net.Conn, data []byte) error {
receivedEarlyData = data
return nil
},
}
clientEDH := &earlyDataHandler{
sendingEDH := &earlyDataHandler{
send: func(ctx context.Context, conn net.Conn, id peer.ID) []byte { return []byte("foobar") },
}
initTransport, err := newTestTransport(t, crypto.Ed25519, 2048).WithSessionOptions(EarlyData(clientEDH))
require.NoError(t, err)
tpt := newTestTransport(t, crypto.Ed25519, 2048)
respTransport, err := tpt.WithSessionOptions(EarlyData(serverEDH))
require.NoError(t, err)

initConn, respConn := newConnPair(t)

errChan := make(chan error)
go func() {
_, err := respTransport.SecureInbound(context.Background(), initConn, "")
errChan <- err
}()

conn, err := initTransport.SecureOutbound(context.Background(), respConn, tpt.localID)
require.NoError(t, err)
defer conn.Close()
t.Run("client sending", func(t *testing.T) {
handshake(t, sendingEDH, receivingEDH)
require.Equal(t, []byte("foobar"), receivedEarlyData)
receivedEarlyData = nil
})

require.Equal(t, []byte("foobar"), receivedEarlyData)
t.Run("server sending", func(t *testing.T) {
handshake(t, receivingEDH, sendingEDH)
require.Equal(t, []byte("foobar"), receivedEarlyData)
receivedEarlyData = nil
})
}

func TestEarlyDataRejected(t *testing.T) {
serverEDH := &earlyDataHandler{
handshake := func(t *testing.T, client, server EarlyDataHandler) (clientErr, serverErr error) {
t.Helper()
initTransport, err := newTestTransport(t, crypto.Ed25519, 2048).WithSessionOptions(EarlyData(client))
require.NoError(t, err)
tpt := newTestTransport(t, crypto.Ed25519, 2048)
respTransport, err := tpt.WithSessionOptions(EarlyData(server))
require.NoError(t, err)

initConn, respConn := newConnPair(t)

errChan := make(chan error)
go func() {
_, err := respTransport.SecureInbound(context.Background(), initConn, "")
errChan <- err
}()

_, clientErr = initTransport.SecureOutbound(context.Background(), respConn, tpt.localID)

select {
case <-time.After(500 * time.Millisecond):
t.Fatal("timeout")
case err := <-errChan:
serverErr = err
}
return
}

receivingEDH := &earlyDataHandler{
marten-seemann marked this conversation as resolved.
Show resolved Hide resolved
received: func(_ context.Context, _ net.Conn, data []byte) error { return errors.New("nope") },
}
clientEDH := &earlyDataHandler{
sendingEDH := &earlyDataHandler{
send: func(ctx context.Context, conn net.Conn, id peer.ID) []byte { return []byte("foobar") },
}
initTransport, err := newTestTransport(t, crypto.Ed25519, 2048).WithSessionOptions(EarlyData(clientEDH))
require.NoError(t, err)
tpt := newTestTransport(t, crypto.Ed25519, 2048)
respTransport, err := tpt.WithSessionOptions(EarlyData(serverEDH))
require.NoError(t, err)

initConn, respConn := newConnPair(t)

errChan := make(chan error)
go func() {
_, err := respTransport.SecureInbound(context.Background(), initConn, "")
errChan <- err
}()
t.Run("client sending", func(t *testing.T) {
clientErr, serverErr := handshake(t, sendingEDH, receivingEDH)
require.Error(t, clientErr)
require.EqualError(t, serverErr, "nope")

_, err = initTransport.SecureOutbound(context.Background(), respConn, tpt.localID)
require.Error(t, err)
})

select {
case <-time.After(500 * time.Millisecond):
t.Fatal("timeout")
case err := <-errChan:
require.EqualError(t, err, "nope")
}
t.Run("server sending", func(t *testing.T) {
clientErr, serverErr := handshake(t, receivingEDH, sendingEDH)
require.Error(t, serverErr)
require.EqualError(t, clientErr, "nope")
})
}

func TestEarlyDataAcceptedWithNoHandler(t *testing.T) {
Expand Down