diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go index 86c276cb5b..0ddb904174 100644 --- a/p2p/host/basic/basic_host.go +++ b/p2p/host/basic/basic_host.go @@ -23,6 +23,10 @@ var log = logging.Logger("basichost") var NegotiateTimeout = time.Second * 60 +// AddrsFactory functions can be passed to New in order to override +// addresses returned by Addrs. +type AddrsFactory func([]ma.Multiaddr) []ma.Multiaddr + // Option is a type used to pass in options to the host. type Option int @@ -45,6 +49,7 @@ type BasicHost struct { mux *msmux.MultistreamMuxer ids *identify.IDService natmgr *natManager + addrs AddrsFactory NegotiateTimeout time.Duration @@ -72,6 +77,9 @@ func New(net inet.Network, opts ...interface{}) *BasicHost { // setup host services h.ids = identify.NewIDService(h) + // default addresses factory, can be overridden via opts argument + h.addrs = func(addrs []ma.Multiaddr) []ma.Multiaddr { return addrs } + for _, o := range opts { switch o := o.(type) { case Option: @@ -81,6 +89,8 @@ func New(net inet.Network, opts ...interface{}) *BasicHost { } case metrics.Reporter: h.bwc = o + case AddrsFactory: + h.addrs = AddrsFactory(o) } } @@ -336,9 +346,15 @@ func (h *BasicHost) dialPeer(ctx context.Context, p peer.ID) error { return nil } -// Addrs returns all the addresses of BasicHost at this moment in time. -// It's ok to not include addresses if they're not available to be used now. +// Addrs returns listening addresses that are safe to announce to the network. +// The output is the same as AllAddrs, but processed by AddrsFactory. func (h *BasicHost) Addrs() []ma.Multiaddr { + return h.addrs(h.AllAddrs()) +} + +// AllAddrs returns all the addresses of BasicHost at this moment in time. +// It's ok to not include addresses if they're not available to be used now. +func (h *BasicHost) AllAddrs() []ma.Multiaddr { addrs, err := h.Network().InterfaceListenAddresses() if err != nil { log.Debug("error retrieving network interface addrs") diff --git a/p2p/host/basic/basic_host_test.go b/p2p/host/basic/basic_host_test.go index 57d7c2edd7..89b975a0f6 100644 --- a/p2p/host/basic/basic_host_test.go +++ b/p2p/host/basic/basic_host_test.go @@ -11,6 +11,7 @@ import ( inet "github.com/libp2p/go-libp2p-net" testutil "github.com/libp2p/go-libp2p-netutil" protocol "github.com/libp2p/go-libp2p-protocol" + ma "github.com/multiformats/go-multiaddr" ) func TestHostSimple(t *testing.T) { @@ -63,6 +64,25 @@ func TestHostSimple(t *testing.T) { } } +func TestHostAddrsFactory(t *testing.T) { + maddr := ma.StringCast("/ip4/1.2.3.4/tcp/1234") + addrsFactory := func(addrs []ma.Multiaddr) []ma.Multiaddr { + return []ma.Multiaddr{maddr} + } + + ctx := context.Background() + h := New(testutil.GenSwarmNetwork(t, ctx), AddrsFactory(addrsFactory)) + defer h.Close() + + addrs := h.Addrs() + if len(addrs) != 1 { + t.Fatalf("expected 1 addr, got %d", len(addrs)) + } + if addrs[0] != maddr { + t.Fatalf("expected %s, got %s", maddr.String(), addrs[0].String()) + } +} + func getHostPair(ctx context.Context, t *testing.T) (host.Host, host.Host) { h1 := New(testutil.GenSwarmNetwork(t, ctx)) h2 := New(testutil.GenSwarmNetwork(t, ctx))