-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
noise: make it possible for the server to send early data #1750
Changes from 1 commit
65270f1
d27a209
4f8a466
0fe5519
e6c29ed
d2ffc90
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -64,11 +64,6 @@ func (s *secureSession) runHandshake(ctx context.Context) (err error) { | |
return fmt.Errorf("error initializing handshake state: %w", err) | ||
} | ||
|
||
payload, err := s.generateHandshakePayload(kp) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
// set a deadline to complete the handshake, if one has been supplied. | ||
// clear it after we're done. | ||
if deadline, ok := ctx.Deadline(); ok { | ||
|
@@ -82,7 +77,7 @@ func (s *secureSession) runHandshake(ctx context.Context) (err error) { | |
// will be the size of the maximum handshake message for the Noise XX pattern. | ||
// Also, since we prefix every noise handshake message with its length, we need to account for | ||
// it when we fetch the buffer from the pool | ||
maxMsgSize := 2*noise.DH25519.DHLen() + len(payload) + 2*chacha20poly1305.Overhead | ||
maxMsgSize := 2*noise.DH25519.DHLen() + 2*chacha20poly1305.Overhead + 1024 /* payload */ | ||
hbuf := pool.Get(maxMsgSize + LengthPrefixLength) | ||
defer pool.Put(hbuf) | ||
|
||
|
@@ -102,12 +97,22 @@ func (s *secureSession) runHandshake(ctx context.Context) (err error) { | |
if err != nil { | ||
return fmt.Errorf("error reading handshake message: %w", err) | ||
} | ||
if err := s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic()); err != nil { | ||
rcvdEd, err := s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic()) | ||
if err != nil { | ||
return err | ||
} | ||
if s.earlyDataHandler != nil { | ||
if err := s.earlyDataHandler.Received(ctx, s.insecureConn, rcvdEd); err != nil { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The early data handler may need to know the handshake state to drive its internal state machine, for example if this is the server or client side, at which stage is the early data received. Do you think we should pass the handshake state info to the earlyDataHandler too? or maybe that is available in other ways already and I missed it? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For both client and server, the data is sent with their respective first flights. I think the nicest way to resolve this is to pass separate client and server There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That looks great, thanks for the changes which make the Noise based muxer selection and other early data utilization smoother. |
||
return err | ||
} | ||
} | ||
|
||
// stage 2 // | ||
// Handshake Msg Len = len(DHT static key) + MAC(static key is encrypted) + len(Payload) + MAC(payload is encrypted) | ||
payload, err := s.generateHandshakePayload(kp, nil) | ||
if err != nil { | ||
return err | ||
} | ||
if err := s.sendHandshakeMessage(hs, payload, hbuf); err != nil { | ||
return fmt.Errorf("error sending handshake message: %w", err) | ||
} | ||
|
@@ -127,6 +132,14 @@ func (s *secureSession) runHandshake(ctx context.Context) (err error) { | |
// stage 1 // | ||
// Handshake Msg Len = len(DH ephemeral key) + len(DHT static key) + MAC(static key is encrypted) + len(Payload) + | ||
// MAC(payload is encrypted) | ||
var ed []byte | ||
if s.earlyDataHandler != nil { | ||
ed = s.earlyDataHandler.Send(ctx, s.insecureConn, s.remoteID) | ||
} | ||
payload, err := s.generateHandshakePayload(kp, ed) | ||
if err != nil { | ||
return err | ||
} | ||
if err := s.sendHandshakeMessage(hs, payload, hbuf); err != nil { | ||
return fmt.Errorf("error sending handshake message: %w", err) | ||
} | ||
|
@@ -136,7 +149,9 @@ func (s *secureSession) runHandshake(ctx context.Context) (err error) { | |
if err != nil { | ||
return fmt.Errorf("error reading handshake message: %w", err) | ||
} | ||
return s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic()) | ||
// we don't expect any early data on this message | ||
_, err = s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic()) | ||
return err | ||
} | ||
} | ||
|
||
|
@@ -214,8 +229,8 @@ func (s *secureSession) readHandshakeMessage(hs *noise.HandshakeState) ([]byte, | |
|
||
// generateHandshakePayload creates a libp2p handshake payload with a | ||
// signature of our static noise key. | ||
func (s *secureSession) generateHandshakePayload(localStatic noise.DHKey) ([]byte, error) { | ||
// obtain the public key from the handshake session so we can sign it with | ||
func (s *secureSession) generateHandshakePayload(localStatic noise.DHKey, data []byte) ([]byte, error) { | ||
// obtain the public key from the handshake session, so we can sign it with | ||
// our libp2p secret key. | ||
localKeyRaw, err := crypto.MarshalPublicKey(s.LocalPublicKey()) | ||
if err != nil { | ||
|
@@ -230,10 +245,11 @@ func (s *secureSession) generateHandshakePayload(localStatic noise.DHKey) ([]byt | |
} | ||
|
||
// create payload | ||
payload := new(pb.NoiseHandshakePayload) | ||
payload.IdentityKey = localKeyRaw | ||
payload.IdentitySig = signedPayload | ||
payloadEnc, err := proto.Marshal(payload) | ||
payloadEnc, err := proto.Marshal(&pb.NoiseHandshakePayload{ | ||
IdentityKey: localKeyRaw, | ||
IdentitySig: signedPayload, | ||
Data: data, | ||
}) | ||
if err != nil { | ||
return nil, fmt.Errorf("error marshaling handshake payload: %w", err) | ||
} | ||
|
@@ -242,44 +258,45 @@ func (s *secureSession) generateHandshakePayload(localStatic noise.DHKey) ([]byt | |
|
||
// handleRemoteHandshakePayload unmarshals the handshake payload object sent | ||
// by the remote peer and validates the signature against the peer's static Noise key. | ||
func (s *secureSession) handleRemoteHandshakePayload(payload []byte, remoteStatic []byte) error { | ||
// It returns the data attached to the payload. | ||
func (s *secureSession) handleRemoteHandshakePayload(payload []byte, remoteStatic []byte) ([]byte, error) { | ||
// unmarshal payload | ||
nhp := new(pb.NoiseHandshakePayload) | ||
err := proto.Unmarshal(payload, nhp) | ||
if err != nil { | ||
return fmt.Errorf("error unmarshaling remote handshake payload: %w", err) | ||
return nil, fmt.Errorf("error unmarshaling remote handshake payload: %w", err) | ||
} | ||
|
||
// unpack remote peer's public libp2p key | ||
remotePubKey, err := crypto.UnmarshalPublicKey(nhp.GetIdentityKey()) | ||
if err != nil { | ||
return err | ||
return nil, err | ||
} | ||
id, err := peer.IDFromPublicKey(remotePubKey) | ||
if err != nil { | ||
return err | ||
return nil, err | ||
} | ||
|
||
// check the peer ID for: | ||
// * all outbound connection | ||
// * inbound connections, if we know which peer we want to connect to (SecureInbound called with a peer ID) | ||
if (s.initiator && s.remoteID != id) || (!s.initiator && s.remoteID != "" && s.remoteID != id) { | ||
// use Pretty() as it produces the full b58-encoded string, rather than abbreviated forms. | ||
return fmt.Errorf("peer id mismatch: expected %s, but remote key matches %s", s.remoteID.Pretty(), id.Pretty()) | ||
return nil, fmt.Errorf("peer id mismatch: expected %s, but remote key matches %s", s.remoteID.Pretty(), id.Pretty()) | ||
} | ||
|
||
// verify payload is signed by asserted remote libp2p key. | ||
sig := nhp.GetIdentitySig() | ||
msg := append([]byte(payloadSigPrefix), remoteStatic...) | ||
ok, err := remotePubKey.Verify(msg, sig) | ||
if err != nil { | ||
return fmt.Errorf("error verifying signature: %w", err) | ||
return nil, fmt.Errorf("error verifying signature: %w", err) | ||
} else if !ok { | ||
return fmt.Errorf("handshake signature invalid") | ||
return nil, fmt.Errorf("handshake signature invalid") | ||
} | ||
|
||
// set remote peer key and id | ||
s.remoteID = id | ||
s.remoteKey = remotePubKey | ||
return nil | ||
return nhp.Data, nil | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just wondering if there is a more efficient way of setting length rather than a hard limit here, I don't have a better way readily available, just want to throw some thoughts here to explore the options.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed to a fixed value of 2048 in 24459c5.
It's not even a hard limit, it's just the size we use to get a buffer from the buffer pool. If we get a buffer that's too small, the worst thing that happens is that
append
would allocate a larger buffer.We can just use a large enough buffer here. What exactly "large enough" means depends on the amount of early data we send, but 2048 should be enough for practically all use cases. And if not, we can live with that one allocation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the clarification, that makes sense. Sorry I was not aware that we can do extra allocation if the limit is exceeded.