Skip to content
This repository has been archived by the owner on Feb 24, 2021. It is now read-only.

Commit

Permalink
Merge pull request #25 from libp2p/fix/secio-handshake
Browse files Browse the repository at this point in the history
refactor secio handshake
  • Loading branch information
Stebalien authored Dec 15, 2017
2 parents 255efc8 + 6d7c50d commit 326d97c
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 69 deletions.
79 changes: 57 additions & 22 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"io"
"time"

proto "github.com/gogo/protobuf/proto"
logging "github.com/ipfs/go-log"
ci "github.com/libp2p/go-libp2p-crypto"
peer "github.com/libp2p/go-libp2p-peer"
Expand All @@ -25,6 +26,9 @@ var ErrUnsupportedKeyType = errors.New("unsupported key type")
// ErrClosed signals the closing of a connection.
var ErrClosed = errors.New("connection closed")

// ErrBadSig signals that the peer sent us a handshake packet with a bad signature.
var ErrBadSig = errors.New("bad signature")

// ErrEcho is returned when we're attempting to handshake with the same keys and nonces.
var ErrEcho = errors.New("same keys and nonces. one side talking to self")

Expand Down Expand Up @@ -105,6 +109,26 @@ func hashSha256(data []byte) mh.Multihash {
// keys, IDs, and initiate communication, assigning all necessary params.
// requires the duplex channel to be a msgio.ReadWriter (for framed messaging)
func (s *secureSession) runHandshake(ctx context.Context) error {
defer log.EventBegin(ctx, "secureHandshake", s).Done()

result := make(chan error, 1)
go func() {
// do *not* close the channel (will look like a success).
result <- s.runHandshakeSync()
}()

var err error
select {
case <-ctx.Done():
// State unknown. We *have* to close this.
s.insecure.Close()
err = ctx.Err()
case err = <-result:
}
return err
}

func (s *secureSession) runHandshakeSync() error {
// =============================================================================
// step 1. Propose -- propose cipher suite + send pubkeys + nonce

Expand All @@ -116,8 +140,6 @@ func (s *secureSession) runHandshake(ctx context.Context) error {
return err
}

defer log.EventBegin(ctx, "secureHandshake", s).Done()

s.local.permanentPubKey = s.localKey.GetPublic()
myPubKeyBytes, err := s.local.permanentPubKey.Bytes()
if err != nil {
Expand All @@ -134,18 +156,24 @@ func (s *secureSession) runHandshake(ctx context.Context) error {
// log.Debugf("1.0 Propose: nonce:%s exchanges:%s ciphers:%s hashes:%s",
// nonceOut, SupportedExchanges, SupportedCiphers, SupportedHashes)

// Send Propose packet (respects ctx)
proposeOutBytes, err := writeMsgCtx(ctx, s.insecureM, proposeOut)
// Marshal our propose packet
proposeOutBytes, err := proto.Marshal(proposeOut)
if err != nil {
return err
}

// Receive + Parse their Propose packet and generate an Exchange packet.
proposeIn := new(pb.Propose)
proposeInBytes, err := readMsgCtx(ctx, s.insecureM, proposeIn)
// Send Propose packet and Receive their Propose packet
proposeInBytes, err := readWriteMsg(s.insecureM, proposeOutBytes)
if err != nil {
return err
}
defer s.insecureM.ReleaseMsg(proposeInBytes)

// Parse their propose packet
proposeIn := new(pb.Propose)
if err = proto.Unmarshal(proposeInBytes, proposeIn); err != nil {
return err
}

// log.Debugf("1.0.1 Propose recv: nonce:%s exchanges:%s ciphers:%s hashes:%s",
// proposeIn.GetRand(), proposeIn.GetExchanges(), proposeIn.GetCiphers(), proposeIn.GetHashes())
Expand Down Expand Up @@ -208,6 +236,9 @@ func (s *secureSession) runHandshake(ctx context.Context) error {
// Generate EphemeralPubKey
var genSharedKey ci.GenSharedKey
s.local.ephemeralPubKey, genSharedKey, err = ci.GenerateEKeyPair(s.local.curveT)
if err != nil {
return err
}

// Gather corpus to sign.
selectionOut := new(bytes.Buffer)
Expand All @@ -224,14 +255,22 @@ func (s *secureSession) runHandshake(ctx context.Context) error {
return err
}

// Send Propose packet (respects ctx)
if _, err := writeMsgCtx(ctx, s.insecureM, exchangeOut); err != nil {
// Marshal our exchange packet
exchangeOutBytes, err := proto.Marshal(exchangeOut)
if err != nil {
return err
}

// Receive + Parse their Exchange packet.
// Send Exchange packet and receive their Exchange packet
exchangeInBytes, err := readWriteMsg(s.insecureM, exchangeOutBytes)
if err != nil {
return err
}
defer s.insecureM.ReleaseMsg(exchangeInBytes)

// Parse their Exchange packet.
exchangeIn := new(pb.Exchange)
if _, err := readMsgCtx(ctx, s.insecureM, exchangeIn); err != nil {
if err = proto.Unmarshal(exchangeInBytes, exchangeIn); err != nil {
return err
}

Expand All @@ -256,9 +295,8 @@ func (s *secureSession) runHandshake(ctx context.Context) error {
}

if !sigOK {
err := errors.New("Bad signature!")
// log.Error("2.1 Verify: failed: %s", err)
return err
// log.Error("2.1 Verify: failed: %s", ErrBadSig)
return ErrBadSig
}
// log.Debugf("2.1 Verify: signature verified.")

Expand Down Expand Up @@ -312,16 +350,13 @@ func (s *secureSession) runHandshake(ctx context.Context) error {
s.secure = msgio.Combine(w, r).(msgio.ReadWriteCloser)

// log.Debug("3.0 finish. sending: %v", proposeIn.GetRand())
// send their Nonce.
if _, err := s.secure.Write(proposeIn.GetRand()); err != nil {
return fmt.Errorf("Failed to write Finish nonce: %s", err)
}

// read our Nonce
nonceOut2 := make([]byte, len(nonceOut))
if _, err := io.ReadFull(s.secure, nonceOut2); err != nil {
return fmt.Errorf("Failed to read Finish nonce: %s", err)
// send their Nonce and receive ours
nonceOut2, err := readWriteMsg(s.secure, proposeIn.GetRand())
if err != nil {
return err
}
defer s.secure.ReleaseMsg(nonceOut2)

// log.Debug("3.0 finish.\n\texpect: %v\n\tactual: %v", nonceOut, nonceOut2)
if !bytes.Equal(nonceOut, nonceOut2) {
Expand Down
64 changes: 17 additions & 47 deletions rw.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package secio

import (
"context"
"crypto/cipher"
"crypto/hmac"
"encoding/binary"
Expand All @@ -10,7 +9,6 @@ import (
"io"
"sync"

proto "github.com/gogo/protobuf/proto"
msgio "github.com/libp2p/go-msgio"
mpool "github.com/libp2p/go-msgio/mpool"
)
Expand Down Expand Up @@ -201,6 +199,7 @@ func (r *etmReader) ReadMsg() ([]byte, error) {

n, err := r.macCheckThenDecrypt(msg)
if err != nil {
r.msg.ReleaseMsg(msg)
return nil, err
}
return msg[:n], nil
Expand Down Expand Up @@ -243,53 +242,24 @@ func (r *etmReader) ReleaseMsg(b []byte) {
r.msg.ReleaseMsg(b)
}

// writeMsgCtx is used by the
func writeMsgCtx(ctx context.Context, w msgio.Writer, msg proto.Message) ([]byte, error) {
enc, err := proto.Marshal(msg)
if err != nil {
return nil, err
}

// write in a goroutine so we can exit when our context is cancelled.
done := make(chan error)
go func(m []byte) {
err := w.WriteMsg(m)
select {
case done <- err:
case <-ctx.Done():
}
}(enc)

select {
case <-ctx.Done():
return nil, ctx.Err()
case e := <-done:
return enc, e
}
}

func readMsgCtx(ctx context.Context, r msgio.Reader, p proto.Message) ([]byte, error) {
var msg []byte

// read in a goroutine so we can exit when our context is cancelled.
done := make(chan error)
// read and write a message at the same time.
func readWriteMsg(c msgio.ReadWriter, out []byte) ([]byte, error) {
wresult := make(chan error)
go func() {
var err error
msg, err = r.ReadMsg()
select {
case done <- err:
case <-ctx.Done():
}
wresult <- c.WriteMsg(out)
}()

select {
case <-ctx.Done():
return nil, ctx.Err()
case e := <-done:
if e != nil {
return nil, e
}
}
msg, err1 := c.ReadMsg()

// Always wait for the read to finish.
err2 := <-wresult

return msg, proto.Unmarshal(msg, p)
if err1 != nil {
return nil, err1
}
if err2 != nil {
c.ReleaseMsg(msg)
return nil, err2
}
return msg, nil
}

0 comments on commit 326d97c

Please sign in to comment.