Skip to content

Commit

Permalink
rcmgr: Add conn_limiter to limit number of conns per ip cidr (#2788)
Browse files Browse the repository at this point in the history
* Add conn_limiter to limit number of conns per ip cidr

* Handle the case where we want to call OpenConnection without an IP address

* Delete key when count==0
  • Loading branch information
MarcoPolo authored May 16, 2024
1 parent 6861cec commit 5d547cf
Show file tree
Hide file tree
Showing 4 changed files with 364 additions and 9 deletions.
141 changes: 141 additions & 0 deletions p2p/host/resource-manager/conn_limiter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
package rcmgr

import (
"net/netip"
"sync"
)

type ConnLimitPerCIDR struct {
// How many leading 1 bits in the mask
BitMask int
ConnCount int
}

// 8 for now so that it matches the number of concurrent dials we may do
// in swarm_dial.go. With future smart dialing work we should bring this
// down
var defaultMaxConcurrentConns = 8

var defaultIP4Limit = ConnLimitPerCIDR{
ConnCount: defaultMaxConcurrentConns,
BitMask: 32,
}
var defaultIP6Limits = []ConnLimitPerCIDR{
{
ConnCount: defaultMaxConcurrentConns,
BitMask: 56,
},
{
ConnCount: 8 * defaultMaxConcurrentConns,
BitMask: 48,
},
}

func WithLimitPeersPerCIDR(ipv4 []ConnLimitPerCIDR, ipv6 []ConnLimitPerCIDR) Option {
return func(rm *resourceManager) error {
if ipv4 != nil {
rm.connLimiter.connLimitPerCIDRIP4 = ipv4
}
if ipv6 != nil {
rm.connLimiter.connLimitPerCIDRIP6 = ipv6
}
return nil
}
}

type connLimiter struct {
mu sync.Mutex
connLimitPerCIDRIP4 []ConnLimitPerCIDR
connLimitPerCIDRIP6 []ConnLimitPerCIDR
ip4connsPerLimit []map[string]int
ip6connsPerLimit []map[string]int
}

func newConnLimiter() *connLimiter {
return &connLimiter{
connLimitPerCIDRIP4: []ConnLimitPerCIDR{defaultIP4Limit},
connLimitPerCIDRIP6: defaultIP6Limits,
}
}

// addConn adds a connection for the given IP address. It returns true if the connection is allowed.
func (cl *connLimiter) addConn(ip netip.Addr) bool {
cl.mu.Lock()
defer cl.mu.Unlock()
limits := cl.connLimitPerCIDRIP4
countsPerLimit := cl.ip4connsPerLimit
isIP6 := ip.Is6()
if isIP6 {
limits = cl.connLimitPerCIDRIP6
countsPerLimit = cl.ip6connsPerLimit
}

if len(countsPerLimit) == 0 && len(limits) > 0 {
countsPerLimit = make([]map[string]int, len(limits))
if isIP6 {
cl.ip6connsPerLimit = countsPerLimit
} else {
cl.ip4connsPerLimit = countsPerLimit
}
}

for i, limit := range limits {
prefix, err := ip.Prefix(limit.BitMask)
if err != nil {
return false
}
masked := prefix.String()

counts, ok := countsPerLimit[i][masked]
if !ok {
if countsPerLimit[i] == nil {
countsPerLimit[i] = make(map[string]int)
}
countsPerLimit[i][masked] = 0
}
if counts+1 > limit.ConnCount {
return false
}
}

// All limit checks passed, now we update the counts
for i, limit := range limits {
prefix, _ := ip.Prefix(limit.BitMask)
masked := prefix.String()
countsPerLimit[i][masked]++
}

return true
}

func (cl *connLimiter) rmConn(ip netip.Addr) {
cl.mu.Lock()
defer cl.mu.Unlock()
limits := cl.connLimitPerCIDRIP4
countsPerLimit := cl.ip4connsPerLimit
isIP6 := ip.Is6()
if isIP6 {
limits = cl.connLimitPerCIDRIP6
countsPerLimit = cl.ip6connsPerLimit
}

for i, limit := range limits {
prefix, err := ip.Prefix(limit.BitMask)
if err != nil {
// Unexpected since we should have seen this IP before in addConn
log.Errorf("unexpected error getting prefix: %v", err)
continue
}
masked := prefix.String()
counts, ok := countsPerLimit[i][masked]
if !ok || counts == 0 {
// Unexpected, but don't panic
log.Errorf("unexpected conn count for %s ok=%v count=%v", masked, ok, counts)
continue
}
countsPerLimit[i][masked]--
if countsPerLimit[i][masked] == 0 {
delete(countsPerLimit[i], masked)
}
}
}
158 changes: 158 additions & 0 deletions p2p/host/resource-manager/conn_limiter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
package rcmgr

import (
"encoding/binary"
"fmt"
"net"
"net/netip"
"testing"

"github.com/stretchr/testify/require"
)

func TestItLimits(t *testing.T) {
t.Run("IPv4", func(t *testing.T) {
ip, err := netip.ParseAddr("1.2.3.4")
require.NoError(t, err)
cl := newConnLimiter()
cl.connLimitPerCIDRIP4[0].ConnCount = 1
require.True(t, cl.addConn(ip))

// should fail the second time
require.False(t, cl.addConn(ip))

otherIP, err := netip.ParseAddr("1.2.3.5")
require.NoError(t, err)
require.True(t, cl.addConn(otherIP))
})
t.Run("IPv6", func(t *testing.T) {
ip, err := netip.ParseAddr("1:2:3:4::1")
require.NoError(t, err)
cl := newConnLimiter()
original := cl.connLimitPerCIDRIP6[0].ConnCount
cl.connLimitPerCIDRIP6[0].ConnCount = 1
defer func() {
cl.connLimitPerCIDRIP6[0].ConnCount = original
}()
require.True(t, cl.addConn(ip))

// should fail the second time
require.False(t, cl.addConn(ip))
otherIPSameSubnet := netip.MustParseAddr("1:2:3:4::2")
require.False(t, cl.addConn(otherIPSameSubnet))

otherIP := netip.MustParseAddr("2:2:3:4::2")
require.True(t, cl.addConn(otherIP))
})

t.Run("IPv6 with multiple limits", func(t *testing.T) {
cl := newConnLimiter()
for i := 0; i < defaultMaxConcurrentConns; i++ {
ip := net.ParseIP("ff:2:3:4::1")
binary.BigEndian.PutUint16(ip[14:], uint16(i))
ipAddr := netip.MustParseAddr(ip.String())
require.True(t, cl.addConn(ipAddr))
}

// Next one should fail
ip := net.ParseIP("ff:2:3:4::1")
binary.BigEndian.PutUint16(ip[14:], uint16(defaultMaxConcurrentConns+1))
require.False(t, cl.addConn(netip.MustParseAddr(ip.String())))

// But on a different root subnet should work
otherIP := netip.MustParseAddr("ffef:2:3::1")
require.True(t, cl.addConn(otherIP))

// But too many on the next subnet limit will fail too
for i := 0; i < defaultMaxConcurrentConns*8; i++ {
ip := net.ParseIP("ffef:2:3:4::1")
binary.BigEndian.PutUint16(ip[5:7], uint16(i))
fmt.Println(ip.String())
ipAddr := netip.MustParseAddr(ip.String())
require.True(t, cl.addConn(ipAddr))
}

ip = net.ParseIP("ffef:2:3:4::1")
binary.BigEndian.PutUint16(ip[5:7], uint16(defaultMaxConcurrentConns*8+1))
ipAddr := netip.MustParseAddr(ip.String())
require.False(t, cl.addConn(ipAddr))
})
}

func genIP(data *[]byte) (netip.Addr, bool) {
if len(*data) < 1 {
return netip.Addr{}, false
}

genIP6 := (*data)[0]&0x01 == 1
bytesRequired := 4
if genIP6 {
bytesRequired = 16
}

if len((*data)[1:]) < bytesRequired {
return netip.Addr{}, false
}

*data = (*data)[1:]
ip, ok := netip.AddrFromSlice((*data)[:bytesRequired])
*data = (*data)[bytesRequired:]
return ip, ok
}

func FuzzConnLimiter(f *testing.F) {
// The goal is to try to enter a state where the count is incorrectly 0
f.Fuzz(func(t *testing.T, data []byte) {
ips := make([]netip.Addr, 0, len(data)/5)
for {
ip, ok := genIP(&data)
if !ok {
break
}
ips = append(ips, ip)
}

cl := newConnLimiter()
addedConns := make([]netip.Addr, 0, len(ips))
for _, ip := range ips {
if cl.addConn(ip) {
addedConns = append(addedConns, ip)
}
}

addedCount := 0
for _, ip := range cl.ip4connsPerLimit {
for _, count := range ip {
addedCount += count
}
}
for _, ip := range cl.ip6connsPerLimit {
for _, count := range ip {
addedCount += count
}
}
if addedCount == 0 && len(addedConns) > 0 {
t.Fatalf("added count: %d", addedCount)
}

for _, ip := range addedConns {
cl.rmConn(ip)
}

leftoverCount := 0
for _, ip := range cl.ip4connsPerLimit {
for _, count := range ip {
leftoverCount += count
}
}
for _, ip := range cl.ip6connsPerLimit {
for _, count := range ip {
leftoverCount += count
}
}
if leftoverCount != 0 {
t.Fatalf("leftover count: %d", leftoverCount)
}
})

}
Loading

0 comments on commit 5d547cf

Please sign in to comment.