fix(mempool)!: limit mempool gossip rate on a per-peer basis #787

Merged
merged 10 commits into from
May 20, 2024
test(p2pclient): rate limits unit tests
lklimek committed May 16, 2024
commit a92728630efddc8335dcb86c3ce5b954fa6b0b41
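The tests in this commit exercise the PR's per-peer limiter (the project's RateLimit type and WithRecvRateLimitPerPeerHandler middleware, both visible in the diff below). For orientation, here is a minimal sketch of the underlying idea, one token bucket per peer, assuming golang.org/x/time/rate; the names peerLimiter, newPeerLimiter, and allow are illustrative, not the project's API:

package gossip

import (
	"sync"

	"golang.org/x/time/rate"
)

// peerLimiter hands out one token bucket per peer ID, so a noisy peer
// cannot drain another peer's budget.
type peerLimiter struct {
	mtx      sync.Mutex
	limiters map[string]*rate.Limiter
	limit    rate.Limit // tokens refilled per second, e.g. 2.0
	burst    int        // bucket capacity, e.g. 4
}

func newPeerLimiter(limit rate.Limit, burst int) *peerLimiter {
	return &peerLimiter{
		limiters: make(map[string]*rate.Limiter),
		limit:    limit,
		burst:    burst,
	}
}

// allow reports whether one more message from peerID fits in that peer's
// budget; a gossip middleware would drop or delay messages when it is false.
func (p *peerLimiter) allow(peerID string) bool {
	p.mtx.Lock()
	lim, ok := p.limiters[peerID]
	if !ok {
		lim = rate.NewLimiter(p.limit, p.burst)
		p.limiters[peerID] = lim
	}
	p.mtx.Unlock()
	return lim.Allow()
}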
10 changes: 5 additions & 5 deletions internal/p2p/client/client_test.go
@@ -54,11 +54,11 @@ func (suite *ChannelTestSuite) SetupTest() {
suite.fakeClock = clockwork.NewFakeClock()
suite.client = New(
suite.descriptors,
func(ctx context.Context, descriptor *p2p.ChannelDescriptor) (p2p.Channel, error) {
func(_ctx context.Context, _descriptor *p2p.ChannelDescriptor) (p2p.Channel, error) {
return suite.p2pChannel, nil
},
WithClock(suite.fakeClock),
WithChanIDResolver(func(msg proto.Message) p2p.ChannelID {
WithChanIDResolver(func(_msg proto.Message) p2p.ChannelID {
return testChannelID
}),
)
@@ -185,7 +185,7 @@ func (suite *ChannelTestSuite) TestConsumeHandle() {
suite.p2pChannel.
On("Receive", ctx).
Once().
Return(func(ctx context.Context) p2p.ChannelIterator {
Return(func(_ctx context.Context) p2p.ChannelIterator {
return p2p.NewChannelIterator(outCh)
})
consumer := newMockConsumer(suite.T())
@@ -226,7 +226,7 @@ func (suite *ChannelTestSuite) TestConsumeResolve() {
suite.p2pChannel.
On("Receive", ctx).
Once().
Return(func(ctx context.Context) p2p.ChannelIterator {
Return(func(_ctx context.Context) p2p.ChannelIterator {
return p2p.NewChannelIterator(outCh)
})
resCh := suite.client.addPending(reqID)
@@ -278,7 +278,7 @@ func (suite *ChannelTestSuite) TestConsumeError() {
suite.p2pChannel.
On("Receive", ctx).
Once().
Return(func(ctx context.Context) p2p.ChannelIterator {
Return(func(_ctx context.Context) p2p.ChannelIterator {
return p2p.NewChannelIterator(outCh)
})
consumer := newMockConsumer(suite.T())
115 changes: 0 additions & 115 deletions internal/p2p/client/consumer_test.go
@@ -4,24 +4,16 @@ import (
"context"
"errors"
"fmt"
"math"
"regexp"
"strconv"
"sync"
"sync/atomic"
"testing"
"time"

"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"

"github.com/dashpay/tenderdash/internal/p2p"
tmrequire "github.com/dashpay/tenderdash/internal/test/require"
"github.com/dashpay/tenderdash/libs/log"
bcproto "github.com/dashpay/tenderdash/proto/tendermint/blocksync"
"github.com/dashpay/tenderdash/types"
)

func TestErrorLoggerP2PMessageHandler(t *testing.T) {
@@ -151,110 +143,3 @@ func TestValidateMessageHandler(t *testing.T) {
})
}
}

// TestRateLimitHandler tests the rate limit middleware.
//
// GIVEN 5 peers named 1..5 and rate limit of 2/s and burst 4,
// WHEN we send 1, 2, 3, 4 and 5 msgs per second respectively for 3 seconds,
// THEN:
// * peers 1, 2 and 3 receive all of their messages (their send rate never exhausts the budget),
// * peers 4 and 5 are capped at 2 messages per second plus the 4 burst messages, i.e. 10 each.
func TestRateLimitHandler(t *testing.T) {
const (
Peers = 5
RateLimit = 2.0
Burst = 4
TestTimeSeconds = 3
)

// don't run this if we are in short mode
if testing.Short() {
t.Skip("skipping test in short mode.")
}

fakeHandler := newMockConsumer(t)

// we cancel manually to control race conditions
ctx, cancel := context.WithCancel(context.Background())

logger := log.NewTestingLogger(t)
client := &Client{}

mw := WithRecvRateLimitPerPeerHandler(ctx, RateLimit, func(*p2p.Envelope) uint { return 1 }, false, logger)(fakeHandler).(*recvRateLimitPerPeerHandler)
mw.burst = Burst

start := sync.RWMutex{}
start.Lock()

sent := make([]atomic.Uint32, Peers)

for peer := 1; peer <= Peers; peer++ {
counter := &sent[peer-1]
peerID := types.NodeID(strconv.Itoa(peer))
fakeHandler.On("Handle", mock.Anything, mock.Anything, mock.MatchedBy(
func(e *p2p.Envelope) bool {
return e.From == peerID
},
)).Return(nil).Run(func(_args mock.Arguments) {
counter.Add(1)
})

go func(peerID types.NodeID, rate int) {
start.RLock()
defer start.RUnlock()

for s := 0; s < TestTimeSeconds; s++ {
until := time.NewTimer(time.Second)
defer until.Stop()

for i := 0; i < rate; i++ {
select {
case <-ctx.Done():
return
default:
}

envelope := &p2p.Envelope{
From: peerID,
}

err := mw.Handle(ctx, client, envelope)
require.NoError(t, err)
}

select {
case <-until.C:
// noop, we just sleep
case <-ctx.Done():
return
}
}
}(peerID, peer)
}

// start the test
startTime := time.Now()
start.Unlock()
time.Sleep(TestTimeSeconds * time.Second)
cancel()
// wait for all goroutines to finish, that is - drop RLocks
start.Lock()

// Check assertions

// we floor with 1 decimal place, as elapsed will be slightly more than TestTimeSeconds
elapsed := math.Floor(time.Since(startTime).Seconds()*10) / 10
assert.Equal(t, float64(TestTimeSeconds), elapsed, "test should run for %d seconds", TestTimeSeconds)

for peer := 1; peer <= Peers; peer++ {
expected := int(RateLimit)*TestTimeSeconds + Burst
if expected > peer*TestTimeSeconds {
expected = peer * TestTimeSeconds
}

assert.Equal(t, expected, int(sent[peer-1].Load()), "peer %d should receive %d messages", peer, expected)
}
// require.Equal(t, uint32(1*TestTimeSeconds), sent[0].Load(), "peer 0 should receive 1 message per second")
// require.Equal(t, uint32(2*TestTimeSeconds), sent[1].Load(), "peer 1 should receive 2 messages per second")
// require.Equal(t, uint32(2*TestTimeSeconds+Burst), sent[2].Load(), "peer 2 should receive 2 messages per second")
}
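Both the deleted test above and its replacement below assert the same closed-form expectation: expected = min(peer*seconds, limit*seconds + burst). A tiny standalone check of the numbers for a 2/s limit, burst 4, over 3 seconds (illustrative only, mirroring assertRateLimits in the new file):

package main

import "fmt"

func main() {
	const limit, burst, seconds = 2, 4, 3
	for peer := 1; peer <= 5; peer++ {
		// each peer sends peer msgs/s, so peer*seconds messages in total
		expected := limit*seconds + burst // rate budget plus burst = 10
		if sent := peer * seconds; sent < expected {
			expected = sent // peers 1-3 never exhaust their budget
		}
		fmt.Printf("peer %d: %d of %d messages accepted\n", peer, expected, peer*seconds)
	}
	// Prints 3/3, 6/6, 9/9, 10/12 and 10/15 for peers 1..5: only
	// peers 4 and 5 are actually throttled.
}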
219 changes: 219 additions & 0 deletions internal/p2p/client/ratelimit_test.go
@@ -0,0 +1,219 @@
package client

import (
"context"
"errors"
"math"
"runtime"
"strconv"
"sync"
"sync/atomic"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"

"github.com/dashpay/tenderdash/internal/p2p"
"github.com/dashpay/tenderdash/internal/p2p/conn"
"github.com/dashpay/tenderdash/libs/log"
"github.com/dashpay/tenderdash/types"
)

// TestRecvRateLimitHandler tests the rate limit middleware when receiving messages from peers.
// It tests that the rate limit is applied per peer.
//
// GIVEN 5 peers named 1..5 and rate limit of 2/s and burst 4,
// WHEN we send 1, 2, 3, 4 and 5 msgs per second respectively for 3 seconds,
// THEN:
// * peers 1, 2 and 3 receive all of their messages (their send rate never exhausts the budget),
// * peers 4 and 5 are capped at 2 messages per second plus the 4 burst messages, i.e. 10 each.
//
// Reuses test helpers from client_test.go
func TestRecvRateLimitHandler(t *testing.T) {
// don't run this if we are in short mode
if testing.Short() {
t.Skip("skipping test in short mode.")
}

const (
Limit = 2.0
Burst = 4
Peers = 5
TestTimeSeconds = 3
)

sent := make([]atomic.Uint32, Peers)

fakeHandler := newMockConsumer(t)
fakeHandler.On("Handle", mock.Anything, mock.Anything, mock.Anything).
Return(nil).
Run(func(args mock.Arguments) {
peerID := args.Get(2).(*p2p.Envelope).From
peerNum, err := strconv.Atoi(string(peerID))
require.NoError(t, err)
sent[peerNum-1].Add(1)
})

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

logger := log.NewTestingLogger(t)
client := &Client{}

mw := WithRecvRateLimitPerPeerHandler(ctx,
Limit,
func(*p2p.Envelope) uint { return 1 },
false,
logger,
)(fakeHandler).(*recvRateLimitPerPeerHandler)

mw.burst = Burst

sendFn := func(peerID types.NodeID) error {
envelope := p2p.Envelope{
From: peerID,
ChannelID: testChannelID,
}
return mw.Handle(ctx, client, &envelope)
}

parallelSendWithLimit(t, ctx, sendFn, Peers, TestTimeSeconds)
assertRateLimits(t, sent, Limit, Burst, TestTimeSeconds)
}

// TestSendRateLimit tests the rate limit for sending messages using p2p.client.
//
// Each peer should have its own, independent rate limit.
//
// GIVEN 5 peers named 1..5 and rate limit of 2/s and burst 4,
// WHEN we send 1, 2, 3, 4 and 5 msgs per second respectively for 3 seconds,
// THEN:
// * peers 1, 2 and 3 receive all of their messages (their send rate never exhausts the budget),
// * peers 4 and 5 are capped at 2 messages per second plus the 4 burst messages, i.e. 10 each.
func (suite *ChannelTestSuite) TestSendRateLimit() {
if testing.Short() {
suite.T().Skip("skipping test in short mode.")
}

const (
Limit = 2.0
Burst = 4
Peers = 5
TestTimeSeconds = 3
)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

client := suite.client

limiter := NewRateLimit(ctx, Limit, false, suite.client.logger)
limiter.burst = Burst
suite.client.rateLimit = map[conn.ChannelID]*RateLimit{
testChannelID: limiter,
}

sendFn := func(peerID types.NodeID) error {
envelope := p2p.Envelope{
To: peerID,
ChannelID: testChannelID,
}
return client.Send(ctx, envelope)
}
sent := make([]atomic.Uint32, Peers)

suite.p2pChannel.On("Send", mock.Anything, mock.Anything).
Run(func(args mock.Arguments) {
peerID := args.Get(1).(p2p.Envelope).To
peerNum, err := strconv.Atoi(string(peerID))
suite.NoError(err)
sent[peerNum-1].Add(1)
}).
Return(nil)

parallelSendWithLimit(suite.T(), ctx, sendFn, Peers, TestTimeSeconds)
assertRateLimits(suite.T(), sent, Limit, Burst, TestTimeSeconds)
}

// parallelSendWithLimit calls sendFn for each peer in parallel, at a per-peer rate.
//
// Each peer is assigned a number starting from 1, and that number is its send rate:
// peer 1 sends 1 msg/s, peer 2 sends 2 msg/s, etc.
func parallelSendWithLimit(t *testing.T, ctx context.Context, sendFn func(peerID types.NodeID) error,
peers int, testTimeSeconds int) {
t.Helper()

ctx, cancel := context.WithCancel(ctx)
defer cancel()

// all goroutines will wait for the start signal
start := sync.RWMutex{}
start.Lock()

for peer := 1; peer <= peers; peer++ {
peerID := types.NodeID(strconv.Itoa(peer))
// peer number is the rate limit
msgsPerSec := peer

go func(peerID types.NodeID, rate int) {
start.RLock()
defer start.RUnlock()

for s := 0; s < testTimeSeconds; s++ {
until := time.NewTimer(time.Second)
defer until.Stop()

for i := 0; i < rate; i++ {
select {
case <-ctx.Done():
return
default:
}

if err := sendFn(peerID); !errors.Is(err, context.Canceled) {
require.NoError(t, err)
}
}

select {
case <-until.C:
// noop, we just sleep until the end of the second
case <-ctx.Done():
return
}
}

}(peerID, msgsPerSec)
}

// start the test
startTime := time.Now()
start.Unlock()
runtime.Gosched()
time.Sleep(time.Duration(testTimeSeconds) * time.Second)
cancel()
// wait for all goroutines to finish, that is, to drop their RLocks
start.Lock()
defer start.Unlock()

// check if test ran for the expected time
// note we ignore up to 99 ms to account for any processing time
elapsed := math.Floor(time.Since(startTime).Seconds()*10) / 10
assert.Equal(t, float64(testTimeSeconds), elapsed, "test should run for %d seconds", testTimeSeconds)
}

// assertRateLimits checks that the rate limits were applied correctly.
// Item i of `sent` is assumed to belong to peer i+1, as described in parallelSendWithLimit.
func assertRateLimits(t *testing.T, sent []atomic.Uint32, limit float64, burst int, seconds int) {
for peer := 1; peer <= len(sent); peer++ {
expected := int(limit)*seconds + burst
if expected > peer*seconds {
expected = peer * seconds
}

assert.Equal(t, expected, int(sent[peer-1].Load()), "peer %d should receive %d messages", peer, expected)
}
}
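A closing note on parallelSendWithLimit: it coordinates its goroutines with a sync.RWMutex used as a start/finish gate. The main goroutine holds the write lock while workers queue on RLock, releases it to start them all at once, and re-acquires it to wait until every worker has dropped its read lock. A standalone sketch of that pattern (standard library only; all names are illustrative):

package main

import (
	"fmt"
	"sync"
	"time"
)

func main() {
	var gate sync.RWMutex
	gate.Lock() // gate closed: workers block on RLock below

	for i := 1; i <= 3; i++ {
		go func(id int) {
			gate.RLock() // wait until the gate opens
			defer gate.RUnlock()
			fmt.Println("worker", id, "running")
			time.Sleep(100 * time.Millisecond) // simulated work
		}(i)
	}

	gate.Unlock() // open the gate: all workers start together
	// sleep long enough for every worker to grab and release its RLock
	time.Sleep(300 * time.Millisecond)
	gate.Lock() // blocks until every RLock is released
	fmt.Println("all workers done")
}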