Skip to content

Commit

Permalink
use the transport.Upgrader interface
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Jan 2, 2022
1 parent 5155c77 commit d7b7473
Show file tree
Hide file tree
Showing 20 changed files with 155 additions and 259 deletions.
29 changes: 19 additions & 10 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/libp2p/go-libp2p-core/peerstore"
"github.com/libp2p/go-libp2p-core/pnet"
"github.com/libp2p/go-libp2p-core/routing"
"github.com/libp2p/go-libp2p-core/sec"
"github.com/libp2p/go-libp2p-core/transport"
"github.com/libp2p/go-libp2p-peerstore/pstoremem"

Expand Down Expand Up @@ -139,30 +140,38 @@ func (cfg *Config) makeSwarm() (*swarm.Swarm, error) {
return swarm.NewSwarm(pid, cfg.Peerstore, swarm.WithMetrics(cfg.Reporter), swarm.WithConnectionGater(cfg.ConnectionGater))
}

func (cfg *Config) addTransports(h host.Host) (err error) {
func (cfg *Config) addTransports(h host.Host) error {
swrm, ok := h.Network().(transport.TransportNetwork)
if !ok {
// Should probably skip this if no transports.
return fmt.Errorf("swarm does not support transports")
}
upgrader := new(tptu.Upgrader)
upgrader.PSK = cfg.PSK
upgrader.ConnGater = cfg.ConnectionGater
var secure sec.SecureMuxer
if cfg.Insecure {
upgrader.Secure = makeInsecureTransport(h.ID(), cfg.PeerKey)
secure = makeInsecureTransport(h.ID(), cfg.PeerKey)
} else {
upgrader.Secure, err = makeSecurityMuxer(h, cfg.SecurityTransports)
var err error
secure, err = makeSecurityMuxer(h, cfg.SecurityTransports)
if err != nil {
return err
}
}

upgrader.Muxer, err = makeMuxer(h, cfg.Muxers)
muxer, err := makeMuxer(h, cfg.Muxers)
if err != nil {
return err
}

tpts, err := makeTransports(h, upgrader, cfg.ConnectionGater, cfg.Transports)
var opts []tptu.Option
if len(cfg.PSK) > 0 {
opts = append(opts, tptu.WithPSK(cfg.PSK))
}
if cfg.ConnectionGater != nil {
opts = append(opts, tptu.WithConnectionGater(cfg.ConnectionGater))
}
upgrader, err := tptu.New(secure, muxer, opts...)
if err != nil {
return err
}
tpts, err := makeTransports(h, upgrader, cfg.ConnectionGater, cfg.PSK, cfg.Transports)
if err != nil {
return err
}
Expand Down
36 changes: 20 additions & 16 deletions config/constructor_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ import (
"github.com/libp2p/go-libp2p-core/pnet"
"github.com/libp2p/go-libp2p-core/sec"
"github.com/libp2p/go-libp2p-core/transport"

tptu "github.com/libp2p/go-libp2p-transport-upgrader"
)

var (
Expand All @@ -29,29 +27,35 @@ var (
pubKeyType = reflect.TypeOf((*crypto.PubKey)(nil)).Elem()
pstoreType = reflect.TypeOf((*peerstore.Peerstore)(nil)).Elem()
connGaterType = reflect.TypeOf((*connmgr.ConnectionGater)(nil)).Elem()
upgraderType = reflect.TypeOf((*transport.Upgrader)(nil)).Elem()

// concrete types
peerIDType = reflect.TypeOf((peer.ID)(""))
upgraderType = reflect.TypeOf((*tptu.Upgrader)(nil))
pskType = reflect.TypeOf((pnet.PSK)(nil))
peerIDType = reflect.TypeOf((peer.ID)(""))
pskType = reflect.TypeOf((pnet.PSK)(nil))
)

var argTypes = map[reflect.Type]constructor{
upgraderType: func(h host.Host, u *tptu.Upgrader, cg connmgr.ConnectionGater) interface{} { return u },
hostType: func(h host.Host, u *tptu.Upgrader, cg connmgr.ConnectionGater) interface{} { return h },
networkType: func(h host.Host, u *tptu.Upgrader, cg connmgr.ConnectionGater) interface{} { return h.Network() },
muxType: func(h host.Host, u *tptu.Upgrader, cg connmgr.ConnectionGater) interface{} { return u.Muxer },
securityType: func(h host.Host, u *tptu.Upgrader, cg connmgr.ConnectionGater) interface{} { return u.Secure },
pskType: func(h host.Host, u *tptu.Upgrader, cg connmgr.ConnectionGater) interface{} { return u.PSK },
connGaterType: func(h host.Host, u *tptu.Upgrader, cg connmgr.ConnectionGater) interface{} { return cg },
peerIDType: func(h host.Host, u *tptu.Upgrader, cg connmgr.ConnectionGater) interface{} { return h.ID() },
privKeyType: func(h host.Host, u *tptu.Upgrader, cg connmgr.ConnectionGater) interface{} {
upgraderType: func(_ host.Host, u transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater) interface{} { return u },
hostType: func(h host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater) interface{} { return h },
networkType: func(h host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater) interface{} {
return h.Network()
},
pskType: func(_ host.Host, _ transport.Upgrader, psk pnet.PSK, _ connmgr.ConnectionGater) interface{} {
return psk
},
connGaterType: func(_ host.Host, _ transport.Upgrader, _ pnet.PSK, cg connmgr.ConnectionGater) interface{} { return cg },
peerIDType: func(h host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater) interface{} {
return h.ID()
},
privKeyType: func(h host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater) interface{} {
return h.Peerstore().PrivKey(h.ID())
},
pubKeyType: func(h host.Host, u *tptu.Upgrader, cg connmgr.ConnectionGater) interface{} {
pubKeyType: func(h host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater) interface{} {
return h.Peerstore().PubKey(h.ID())
},
pstoreType: func(h host.Host, u *tptu.Upgrader, cg connmgr.ConnectionGater) interface{} { return h.Peerstore() },
pstoreType: func(h host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater) interface{} {
return h.Peerstore()
},
}

func newArgTypeSet(types ...reflect.Type) map[reflect.Type]constructor {
Expand Down
2 changes: 1 addition & 1 deletion config/muxer.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func MuxerConstructor(m interface{}) (MuxC, error) {
return nil, err
}
return func(h host.Host) (mux.Multiplexer, error) {
t, err := ctor(h, nil, nil)
t, err := ctor(h, nil, nil, nil)
if err != nil {
return nil, err
}
Expand Down
13 changes: 7 additions & 6 deletions config/reflection_magic.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@ import (
"reflect"
"runtime"

"github.com/libp2p/go-libp2p-core/pnet"

"github.com/libp2p/go-libp2p-core/connmgr"
"github.com/libp2p/go-libp2p-core/host"

tptu "github.com/libp2p/go-libp2p-transport-upgrader"
"github.com/libp2p/go-libp2p-core/transport"
)

var errorType = reflect.TypeOf((*error)(nil)).Elem()
Expand Down Expand Up @@ -79,7 +80,7 @@ func callConstructor(c reflect.Value, args []reflect.Value) (interface{}, error)
return val, err
}

type constructor func(h host.Host, u *tptu.Upgrader, cg connmgr.ConnectionGater) interface{}
type constructor func(h host.Host, u transport.Upgrader, psk pnet.PSK, cg connmgr.ConnectionGater) interface{}

func makeArgumentConstructors(fnType reflect.Type, argTypes map[reflect.Type]constructor) ([]constructor, error) {
params := fnType.NumIn()
Expand Down Expand Up @@ -130,7 +131,7 @@ func makeConstructor(
tptType reflect.Type,
argTypes map[reflect.Type]constructor,
opts ...interface{},
) (func(host.Host, *tptu.Upgrader, connmgr.ConnectionGater) (interface{}, error), error) {
) (func(host.Host, transport.Upgrader, pnet.PSK, connmgr.ConnectionGater) (interface{}, error), error) {
v := reflect.ValueOf(tpt)
// avoid panicing on nil/zero value.
if v == (reflect.Value{}) {
Expand All @@ -154,10 +155,10 @@ func makeConstructor(
return nil, err
}

return func(h host.Host, u *tptu.Upgrader, cg connmgr.ConnectionGater) (interface{}, error) {
return func(h host.Host, u transport.Upgrader, psk pnet.PSK, cg connmgr.ConnectionGater) (interface{}, error) {
arguments := make([]reflect.Value, 0, len(argConstructors)+len(opts))
for i, makeArg := range argConstructors {
if arg := makeArg(h, u, cg); arg != nil {
if arg := makeArg(h, u, psk, cg); arg != nil {
arguments = append(arguments, reflect.ValueOf(arg))
} else {
// ValueOf an un-typed nil yields a zero reflect
Expand Down
2 changes: 1 addition & 1 deletion config/security.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func SecurityConstructor(security interface{}) (SecC, error) {
return nil, err
}
return func(h host.Host) (sec.SecureTransport, error) {
t, err := ctor(h, nil, nil)
t, err := ctor(h, nil, nil, nil)
if err != nil {
return nil, err
}
Expand Down
15 changes: 7 additions & 8 deletions config/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,14 @@ package config
import (
"github.com/libp2p/go-libp2p-core/connmgr"
"github.com/libp2p/go-libp2p-core/host"
"github.com/libp2p/go-libp2p-core/pnet"
"github.com/libp2p/go-libp2p-core/transport"

tptu "github.com/libp2p/go-libp2p-transport-upgrader"
)

// TptC is the type for libp2p transport constructors. You probably won't ever
// implement this function interface directly. Instead, pass your transport
// constructor to TransportConstructor.
type TptC func(h host.Host, u *tptu.Upgrader, cg connmgr.ConnectionGater) (transport.Transport, error)
type TptC func(host.Host, transport.Upgrader, pnet.PSK, connmgr.ConnectionGater) (transport.Transport, error)

var transportArgTypes = argTypes

Expand Down Expand Up @@ -39,27 +38,27 @@ var transportArgTypes = argTypes
func TransportConstructor(tpt interface{}, opts ...interface{}) (TptC, error) {
// Already constructed?
if t, ok := tpt.(transport.Transport); ok {
return func(_ host.Host, _ *tptu.Upgrader, _ connmgr.ConnectionGater) (transport.Transport, error) {
return func(_ host.Host, _ transport.Upgrader, _ pnet.PSK, _ connmgr.ConnectionGater) (transport.Transport, error) {
return t, nil
}, nil
}
ctor, err := makeConstructor(tpt, transportType, transportArgTypes, opts...)
if err != nil {
return nil, err
}
return func(h host.Host, u *tptu.Upgrader, cg connmgr.ConnectionGater) (transport.Transport, error) {
t, err := ctor(h, u, cg)
return func(h host.Host, u transport.Upgrader, psk pnet.PSK, cg connmgr.ConnectionGater) (transport.Transport, error) {
t, err := ctor(h, u, psk, cg)
if err != nil {
return nil, err
}
return t.(transport.Transport), nil
}, nil
}

func makeTransports(h host.Host, u *tptu.Upgrader, cg connmgr.ConnectionGater, tpts []TptC) ([]transport.Transport, error) {
func makeTransports(h host.Host, u transport.Upgrader, cg connmgr.ConnectionGater, psk pnet.PSK, tpts []TptC) ([]transport.Transport, error) {
transports := make([]transport.Transport, len(tpts))
for i, tC := range tpts {
t, err := tC(h, u, cg)
t, err := tC(h, u, psk, cg)
if err != nil {
return nil, err
}
Expand Down
9 changes: 4 additions & 5 deletions config/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (

"github.com/libp2p/go-libp2p-core/peer"
"github.com/libp2p/go-libp2p-core/transport"
tptu "github.com/libp2p/go-libp2p-transport-upgrader"
"github.com/libp2p/go-tcp-transport"

"github.com/stretchr/testify/require"
Expand All @@ -17,27 +16,27 @@ func TestTransportVariadicOptions(t *testing.T) {
}

func TestConstructorWithoutOptsCalledWithOpts(t *testing.T) {
_, err := TransportConstructor(func(_ *tptu.Upgrader) transport.Transport {
_, err := TransportConstructor(func(_ transport.Upgrader) transport.Transport {
return nil
}, 42)
require.EqualError(t, err, "constructor doesn't accept any options")
}

func TestConstructorWithOptsTypeMismatch(t *testing.T) {
_, err := TransportConstructor(func(_ *tptu.Upgrader, opts ...int) transport.Transport {
_, err := TransportConstructor(func(_ transport.Upgrader, opts ...int) transport.Transport {
return nil
}, 42, "foo")
require.EqualError(t, err, "expected option of type int, got string")
}

func TestConstructorWithOpts(t *testing.T) {
var options []int
c, err := TransportConstructor(func(_ *tptu.Upgrader, opts ...int) (transport.Transport, error) {
c, err := TransportConstructor(func(_ transport.Upgrader, opts ...int) (transport.Transport, error) {
options = opts
return tcp.NewTCPTransport(nil)
}, 42, 1337)
require.NoError(t, err)
_, err = c(nil, nil, nil)
_, err = c(nil, nil, nil, nil)
require.NoError(t, err)
require.Equal(t, []int{42, 1337}, options)
}
6 changes: 3 additions & 3 deletions examples/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ require (
github.com/ipfs/go-log/v2 v2.4.0
github.com/libp2p/go-libp2p v0.14.4
github.com/libp2p/go-libp2p-connmgr v0.2.4
github.com/libp2p/go-libp2p-core v0.13.0
github.com/libp2p/go-libp2p-core v0.13.1-0.20220102114513-ec2f775d92a4
github.com/libp2p/go-libp2p-discovery v0.6.0
github.com/libp2p/go-libp2p-kad-dht v0.15.0
github.com/libp2p/go-libp2p-noise v0.3.0
github.com/libp2p/go-libp2p-swarm v0.9.0
github.com/libp2p/go-libp2p-swarm v0.9.1-0.20220102125925-52f8e01834d2
github.com/libp2p/go-libp2p-tls v0.3.1
github.com/multiformats/go-multiaddr v0.4.0
github.com/multiformats/go-multiaddr v0.4.1
)

// Ensure that examples always use the go-libp2p version in the same git checkout.
Expand Down
Loading

0 comments on commit d7b7473

Please sign in to comment.