-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
rcmgr: Add conn_limiter to limit number of conns per ip cidr (#2788)
* Add conn_limiter to limit number of conns per ip cidr * Handle the case where we want to call OpenConnection without an IP address * Delete key when count==0
- Loading branch information
Showing
4 changed files
with
364 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
package rcmgr | ||
|
||
import ( | ||
"net/netip" | ||
"sync" | ||
) | ||
|
||
type ConnLimitPerCIDR struct { | ||
// How many leading 1 bits in the mask | ||
BitMask int | ||
ConnCount int | ||
} | ||
|
||
// 8 for now so that it matches the number of concurrent dials we may do | ||
// in swarm_dial.go. With future smart dialing work we should bring this | ||
// down | ||
var defaultMaxConcurrentConns = 8 | ||
|
||
var defaultIP4Limit = ConnLimitPerCIDR{ | ||
ConnCount: defaultMaxConcurrentConns, | ||
BitMask: 32, | ||
} | ||
var defaultIP6Limits = []ConnLimitPerCIDR{ | ||
{ | ||
ConnCount: defaultMaxConcurrentConns, | ||
BitMask: 56, | ||
}, | ||
{ | ||
ConnCount: 8 * defaultMaxConcurrentConns, | ||
BitMask: 48, | ||
}, | ||
} | ||
|
||
func WithLimitPeersPerCIDR(ipv4 []ConnLimitPerCIDR, ipv6 []ConnLimitPerCIDR) Option { | ||
return func(rm *resourceManager) error { | ||
if ipv4 != nil { | ||
rm.connLimiter.connLimitPerCIDRIP4 = ipv4 | ||
} | ||
if ipv6 != nil { | ||
rm.connLimiter.connLimitPerCIDRIP6 = ipv6 | ||
} | ||
return nil | ||
} | ||
} | ||
|
||
type connLimiter struct { | ||
mu sync.Mutex | ||
connLimitPerCIDRIP4 []ConnLimitPerCIDR | ||
connLimitPerCIDRIP6 []ConnLimitPerCIDR | ||
ip4connsPerLimit []map[string]int | ||
ip6connsPerLimit []map[string]int | ||
} | ||
|
||
func newConnLimiter() *connLimiter { | ||
return &connLimiter{ | ||
connLimitPerCIDRIP4: []ConnLimitPerCIDR{defaultIP4Limit}, | ||
connLimitPerCIDRIP6: defaultIP6Limits, | ||
} | ||
} | ||
|
||
// addConn adds a connection for the given IP address. It returns true if the connection is allowed. | ||
func (cl *connLimiter) addConn(ip netip.Addr) bool { | ||
cl.mu.Lock() | ||
defer cl.mu.Unlock() | ||
limits := cl.connLimitPerCIDRIP4 | ||
countsPerLimit := cl.ip4connsPerLimit | ||
isIP6 := ip.Is6() | ||
if isIP6 { | ||
limits = cl.connLimitPerCIDRIP6 | ||
countsPerLimit = cl.ip6connsPerLimit | ||
} | ||
|
||
if len(countsPerLimit) == 0 && len(limits) > 0 { | ||
countsPerLimit = make([]map[string]int, len(limits)) | ||
if isIP6 { | ||
cl.ip6connsPerLimit = countsPerLimit | ||
} else { | ||
cl.ip4connsPerLimit = countsPerLimit | ||
} | ||
} | ||
|
||
for i, limit := range limits { | ||
prefix, err := ip.Prefix(limit.BitMask) | ||
if err != nil { | ||
return false | ||
} | ||
masked := prefix.String() | ||
|
||
counts, ok := countsPerLimit[i][masked] | ||
if !ok { | ||
if countsPerLimit[i] == nil { | ||
countsPerLimit[i] = make(map[string]int) | ||
} | ||
countsPerLimit[i][masked] = 0 | ||
} | ||
if counts+1 > limit.ConnCount { | ||
return false | ||
} | ||
} | ||
|
||
// All limit checks passed, now we update the counts | ||
for i, limit := range limits { | ||
prefix, _ := ip.Prefix(limit.BitMask) | ||
masked := prefix.String() | ||
countsPerLimit[i][masked]++ | ||
} | ||
|
||
return true | ||
} | ||
|
||
func (cl *connLimiter) rmConn(ip netip.Addr) { | ||
cl.mu.Lock() | ||
defer cl.mu.Unlock() | ||
limits := cl.connLimitPerCIDRIP4 | ||
countsPerLimit := cl.ip4connsPerLimit | ||
isIP6 := ip.Is6() | ||
if isIP6 { | ||
limits = cl.connLimitPerCIDRIP6 | ||
countsPerLimit = cl.ip6connsPerLimit | ||
} | ||
|
||
for i, limit := range limits { | ||
prefix, err := ip.Prefix(limit.BitMask) | ||
if err != nil { | ||
// Unexpected since we should have seen this IP before in addConn | ||
log.Errorf("unexpected error getting prefix: %v", err) | ||
continue | ||
} | ||
masked := prefix.String() | ||
counts, ok := countsPerLimit[i][masked] | ||
if !ok || counts == 0 { | ||
// Unexpected, but don't panic | ||
log.Errorf("unexpected conn count for %s ok=%v count=%v", masked, ok, counts) | ||
continue | ||
} | ||
countsPerLimit[i][masked]-- | ||
if countsPerLimit[i][masked] == 0 { | ||
delete(countsPerLimit[i], masked) | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
package rcmgr | ||
|
||
import ( | ||
"encoding/binary" | ||
"fmt" | ||
"net" | ||
"net/netip" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestItLimits(t *testing.T) { | ||
t.Run("IPv4", func(t *testing.T) { | ||
ip, err := netip.ParseAddr("1.2.3.4") | ||
require.NoError(t, err) | ||
cl := newConnLimiter() | ||
cl.connLimitPerCIDRIP4[0].ConnCount = 1 | ||
require.True(t, cl.addConn(ip)) | ||
|
||
// should fail the second time | ||
require.False(t, cl.addConn(ip)) | ||
|
||
otherIP, err := netip.ParseAddr("1.2.3.5") | ||
require.NoError(t, err) | ||
require.True(t, cl.addConn(otherIP)) | ||
}) | ||
t.Run("IPv6", func(t *testing.T) { | ||
ip, err := netip.ParseAddr("1:2:3:4::1") | ||
require.NoError(t, err) | ||
cl := newConnLimiter() | ||
original := cl.connLimitPerCIDRIP6[0].ConnCount | ||
cl.connLimitPerCIDRIP6[0].ConnCount = 1 | ||
defer func() { | ||
cl.connLimitPerCIDRIP6[0].ConnCount = original | ||
}() | ||
require.True(t, cl.addConn(ip)) | ||
|
||
// should fail the second time | ||
require.False(t, cl.addConn(ip)) | ||
otherIPSameSubnet := netip.MustParseAddr("1:2:3:4::2") | ||
require.False(t, cl.addConn(otherIPSameSubnet)) | ||
|
||
otherIP := netip.MustParseAddr("2:2:3:4::2") | ||
require.True(t, cl.addConn(otherIP)) | ||
}) | ||
|
||
t.Run("IPv6 with multiple limits", func(t *testing.T) { | ||
cl := newConnLimiter() | ||
for i := 0; i < defaultMaxConcurrentConns; i++ { | ||
ip := net.ParseIP("ff:2:3:4::1") | ||
binary.BigEndian.PutUint16(ip[14:], uint16(i)) | ||
ipAddr := netip.MustParseAddr(ip.String()) | ||
require.True(t, cl.addConn(ipAddr)) | ||
} | ||
|
||
// Next one should fail | ||
ip := net.ParseIP("ff:2:3:4::1") | ||
binary.BigEndian.PutUint16(ip[14:], uint16(defaultMaxConcurrentConns+1)) | ||
require.False(t, cl.addConn(netip.MustParseAddr(ip.String()))) | ||
|
||
// But on a different root subnet should work | ||
otherIP := netip.MustParseAddr("ffef:2:3::1") | ||
require.True(t, cl.addConn(otherIP)) | ||
|
||
// But too many on the next subnet limit will fail too | ||
for i := 0; i < defaultMaxConcurrentConns*8; i++ { | ||
ip := net.ParseIP("ffef:2:3:4::1") | ||
binary.BigEndian.PutUint16(ip[5:7], uint16(i)) | ||
fmt.Println(ip.String()) | ||
ipAddr := netip.MustParseAddr(ip.String()) | ||
require.True(t, cl.addConn(ipAddr)) | ||
} | ||
|
||
ip = net.ParseIP("ffef:2:3:4::1") | ||
binary.BigEndian.PutUint16(ip[5:7], uint16(defaultMaxConcurrentConns*8+1)) | ||
ipAddr := netip.MustParseAddr(ip.String()) | ||
require.False(t, cl.addConn(ipAddr)) | ||
}) | ||
} | ||
|
||
func genIP(data *[]byte) (netip.Addr, bool) { | ||
if len(*data) < 1 { | ||
return netip.Addr{}, false | ||
} | ||
|
||
genIP6 := (*data)[0]&0x01 == 1 | ||
bytesRequired := 4 | ||
if genIP6 { | ||
bytesRequired = 16 | ||
} | ||
|
||
if len((*data)[1:]) < bytesRequired { | ||
return netip.Addr{}, false | ||
} | ||
|
||
*data = (*data)[1:] | ||
ip, ok := netip.AddrFromSlice((*data)[:bytesRequired]) | ||
*data = (*data)[bytesRequired:] | ||
return ip, ok | ||
} | ||
|
||
func FuzzConnLimiter(f *testing.F) { | ||
// The goal is to try to enter a state where the count is incorrectly 0 | ||
f.Fuzz(func(t *testing.T, data []byte) { | ||
ips := make([]netip.Addr, 0, len(data)/5) | ||
for { | ||
ip, ok := genIP(&data) | ||
if !ok { | ||
break | ||
} | ||
ips = append(ips, ip) | ||
} | ||
|
||
cl := newConnLimiter() | ||
addedConns := make([]netip.Addr, 0, len(ips)) | ||
for _, ip := range ips { | ||
if cl.addConn(ip) { | ||
addedConns = append(addedConns, ip) | ||
} | ||
} | ||
|
||
addedCount := 0 | ||
for _, ip := range cl.ip4connsPerLimit { | ||
for _, count := range ip { | ||
addedCount += count | ||
} | ||
} | ||
for _, ip := range cl.ip6connsPerLimit { | ||
for _, count := range ip { | ||
addedCount += count | ||
} | ||
} | ||
if addedCount == 0 && len(addedConns) > 0 { | ||
t.Fatalf("added count: %d", addedCount) | ||
} | ||
|
||
for _, ip := range addedConns { | ||
cl.rmConn(ip) | ||
} | ||
|
||
leftoverCount := 0 | ||
for _, ip := range cl.ip4connsPerLimit { | ||
for _, count := range ip { | ||
leftoverCount += count | ||
} | ||
} | ||
for _, ip := range cl.ip6connsPerLimit { | ||
for _, count := range ip { | ||
leftoverCount += count | ||
} | ||
} | ||
if leftoverCount != 0 { | ||
t.Fatalf("leftover count: %d", leftoverCount) | ||
} | ||
}) | ||
|
||
} |
Oops, something went wrong.