Skip to content

Commit

Permalink
restrict output message size, add top level option to disable address
Browse files Browse the repository at this point in the history
discovery
  • Loading branch information
sukunrt committed May 15, 2024
1 parent 45b6ccd commit 815089d
Show file tree
Hide file tree
Showing 11 changed files with 294 additions and 71 deletions.
29 changes: 16 additions & 13 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ type Config struct {
DialRanker network.DialRanker

SwarmOpts []swarm.Option

DisableIdentifyAddressDiscovery bool
}

func (cfg *Config) makeSwarm(eventBus event.Bus, enableMetrics bool) (*swarm.Swarm, error) {
Expand Down Expand Up @@ -290,19 +292,20 @@ func (cfg *Config) addTransports() ([]fx.Option, error) {

func (cfg *Config) newBasicHost(swrm *swarm.Swarm, eventBus event.Bus) (*bhost.BasicHost, error) {
h, err := bhost.NewHost(swrm, &bhost.HostOpts{
EventBus: eventBus,
ConnManager: cfg.ConnManager,
AddrsFactory: cfg.AddrsFactory,
NATManager: cfg.NATManager,
EnablePing: !cfg.DisablePing,
UserAgent: cfg.UserAgent,
ProtocolVersion: cfg.ProtocolVersion,
EnableHolePunching: cfg.EnableHolePunching,
HolePunchingOptions: cfg.HolePunchingOptions,
EnableRelayService: cfg.EnableRelayService,
RelayServiceOpts: cfg.RelayServiceOpts,
EnableMetrics: !cfg.DisableMetrics,
PrometheusRegisterer: cfg.PrometheusRegisterer,
EventBus: eventBus,
ConnManager: cfg.ConnManager,
AddrsFactory: cfg.AddrsFactory,
NATManager: cfg.NATManager,
EnablePing: !cfg.DisablePing,
UserAgent: cfg.UserAgent,
ProtocolVersion: cfg.ProtocolVersion,
EnableHolePunching: cfg.EnableHolePunching,
HolePunchingOptions: cfg.HolePunchingOptions,
EnableRelayService: cfg.EnableRelayService,
RelayServiceOpts: cfg.RelayServiceOpts,
EnableMetrics: !cfg.DisableMetrics,
PrometheusRegisterer: cfg.PrometheusRegisterer,
DisableIdentifyAddressDiscovery: cfg.DisableIdentifyAddressDiscovery,
})
if err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion core/peerstore/peerstore.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ var (

// RecentlyConnectedAddrTTL is used when we recently connected to a peer.
// It means that we are reasonably certain of the peer's address.
RecentlyConnectedAddrTTL = time.Minute * 30
RecentlyConnectedAddrTTL = time.Minute * 15

// OwnObservedAddrTTL is used for our own external addresses observed by peers.
// Deprecated: observed addresses are maintained till we disconnect from the peer which provided it
Expand Down
6 changes: 6 additions & 0 deletions libp2p_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,12 @@ func TestAutoNATService(t *testing.T) {
h.Close()
}

func TestDisableIdentifyAddressDiscovery(t *testing.T) {
h, err := New(DisableIdentifyAddressDiscovery())
require.NoError(t, err)
h.Close()
}

func TestMain(m *testing.M) {
goleak.VerifyTestMain(
m,
Expand Down
11 changes: 11 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -598,3 +598,14 @@ func SwarmOpts(opts ...swarm.Option) Option {
return nil
}
}

// DisableIdentifyAddressDiscovery disables address discovery using peer provided observed addresses
// in identify. If you know your public addresses upfront, the recommended way is to use
// AddressFactory to provide the external adddress to the host and using this option to disable
// discovery from identify as it is more error prone
func DisableIdentifyAddressDiscovery() Option {
return func(cfg *Config) error {
cfg.DisableIdentifyAddressDiscovery = true
return nil
}
}
79 changes: 72 additions & 7 deletions p2p/host/basic/basic_host.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"io"
"net"
"slices"
"sync"
"time"

Expand Down Expand Up @@ -53,6 +54,8 @@ var (
DefaultAddrsFactory = func(addrs []ma.Multiaddr) []ma.Multiaddr { return addrs }
)

const maxPeerRecordSize = 4096 // 4k to be compatible with rust-libp2p identify implementation

// AddrsFactory functions can be passed to New in order to override
// addresses returned by Addrs.
type AddrsFactory func([]ma.Multiaddr) []ma.Multiaddr
Expand Down Expand Up @@ -161,6 +164,9 @@ type HostOpts struct {
EnableMetrics bool
// PrometheusRegisterer is the PrometheusRegisterer used for metrics
PrometheusRegisterer prometheus.Registerer

// DisableIdentifyAddressDiscovery disables address discovery using peer provided observed addresses in identify
DisableIdentifyAddressDiscovery bool
}

// NewHost constructs a new *BasicHost and activates it by attaching its stream and connection handlers to the given inet.Network.
Expand Down Expand Up @@ -244,6 +250,9 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) {
identify.WithMetricsTracer(
identify.NewMetricsTracer(identify.WithRegisterer(opts.PrometheusRegisterer))))
}
if opts.DisableIdentifyAddressDiscovery {
idOpts = append(idOpts, identify.DisableObservedAddrManager())
}

h.ids, err = identify.NewIDService(h, idOpts...)
if err != nil {
Expand Down Expand Up @@ -482,15 +491,18 @@ func makeUpdatedAddrEvent(prev, current []ma.Multiaddr) *event.EvtLocalAddresses
return &evt
}

func (h *BasicHost) makeSignedPeerRecord(evt *event.EvtLocalAddressesUpdated) (*record.Envelope, error) {
current := make([]ma.Multiaddr, 0, len(evt.Current))
for _, a := range evt.Current {
current = append(current, a.Address)
func (h *BasicHost) makeSignedPeerRecord(addrs []ma.Multiaddr) (*record.Envelope, error) {
// Limit the length of currentAddrs to ensure that our signed peer records aren't rejected
peerRecordSize := 64 // HostID
k, err := h.signKey.Raw()
if err != nil {
peerRecordSize += 2 * len(k) // 1 for signature, 1 for public key
}

// we want the final address list to be small for keeping the signed peer record in size
addrs = trimHostAddrList(addrs, maxPeerRecordSize-peerRecordSize-256) // 256 B of buffer
rec := peer.PeerRecordFromAddrInfo(peer.AddrInfo{
ID: h.ID(),
Addrs: current,
Addrs: addrs,
})
return record.Seal(rec, h.signKey)
}
Expand All @@ -513,7 +525,7 @@ func (h *BasicHost) background() {

if !h.disableSignedPeerRecord {
// add signed peer record to the event
sr, err := h.makeSignedPeerRecord(changeEvt)
sr, err := h.makeSignedPeerRecord(currentAddrs)
if err != nil {
log.Errorf("error creating a signed peer record from the set of current addresses, err=%s", err)
return
Expand Down Expand Up @@ -805,6 +817,7 @@ func (h *BasicHost) Addrs() []ma.Multiaddr {
addrs[i] = addrWithCerthash
}
}

return addrs
}

Expand Down Expand Up @@ -997,6 +1010,58 @@ func inferWebtransportAddrsFromQuic(in []ma.Multiaddr) []ma.Multiaddr {
return out
}

func trimHostAddrList(addrs []ma.Multiaddr, maxSize int) []ma.Multiaddr {
totalSize := 0
for _, a := range addrs {
totalSize += len(a.Bytes())
}
if totalSize <= maxSize {
return addrs
}

score := func(addr ma.Multiaddr) int {
var res int
if manet.IsPublicAddr(addr) {
res |= 1 << 12
} else if !manet.IsIPLoopback(addr) {
res |= 1 << 11
}
var protocolWeight int
ma.ForEach(addr, func(c ma.Component) bool {
switch c.Protocol().Code {
case ma.P_QUIC_V1:
protocolWeight = 5
case ma.P_TCP:
protocolWeight = 4
case ma.P_WSS:
protocolWeight = 3
case ma.P_WEBTRANSPORT:
protocolWeight = 2
case ma.P_WEBRTC_DIRECT:
protocolWeight = 1
case ma.P_P2P:
return false
}
return true
})
res |= 1 << protocolWeight
return res
}

slices.SortStableFunc(addrs, func(a, b ma.Multiaddr) int {
return score(b) - score(a) // b-a for reverse order
})
totalSize = 0
for i, a := range addrs {
totalSize += len(a.Bytes())
if totalSize > maxSize {
addrs = addrs[:i]
break
}
}
return addrs
}

// SetAutoNat sets the autonat service for the host.
func (h *BasicHost) SetAutoNat(a autonat.AutoNAT) {
h.addrMu.Lock()
Expand Down
52 changes: 52 additions & 0 deletions p2p/host/basic/basic_host_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -896,3 +896,55 @@ func TestInferWebtransportAddrsFromQuic(t *testing.T) {
}

}

func TestTrimHostAddrList(t *testing.T) {
type testCase struct {
name string
in []ma.Multiaddr
threshold int
out []ma.Multiaddr
}

tcpPublic := ma.StringCast("/ip4/1.1.1.1/tcp/1")
quicPublic := ma.StringCast("/ip4/1.1.1.1/udp/1/quic-v1")

tcpPrivate := ma.StringCast("/ip4/192.168.1.1/tcp/1")
quicPrivate := ma.StringCast("/ip4/192.168.1.1/udp/1/quic-v1")

tcpLocal := ma.StringCast("/ip4/127.0.0.1/tcp/1")
quicLocal := ma.StringCast("/ip4/127.0.0.1/udp/1/quic-v1")

testCases := []testCase{
{
name: "Public preferred over private",
in: []ma.Multiaddr{tcpPublic, quicPrivate},
threshold: len(tcpLocal.Bytes()),
out: []ma.Multiaddr{tcpPublic},
},
{
name: "Public and private preffered over local",
in: []ma.Multiaddr{tcpPublic, tcpPrivate, quicLocal},
threshold: len(tcpPublic.Bytes()) + len(tcpPrivate.Bytes()),
out: []ma.Multiaddr{tcpPublic, tcpPrivate},
},
{
name: "quic preferred over tcp",
in: []ma.Multiaddr{tcpPublic, quicPublic},
threshold: len(quicPublic.Bytes()),
out: []ma.Multiaddr{quicPublic},
},
{
name: "no filtering on large threshold",
in: []ma.Multiaddr{tcpPublic, quicPublic, quicLocal, tcpLocal, tcpPrivate},
threshold: 10000,
out: []ma.Multiaddr{tcpPublic, quicPublic, quicLocal, tcpLocal, tcpPrivate},
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
got := trimHostAddrList(tc.in, tc.threshold)
require.ElementsMatch(t, got, tc.out)
})
}
}
Loading

0 comments on commit 815089d

Please sign in to comment.