From 50f8db6f2cfb6ea47c0da3467812aca6dddcbe67 Mon Sep 17 00:00:00 2001
From: lklimek <842586+lklimek@users.noreply.github.com>
Date: Mon, 20 May 2024 15:36:00 +0200
Subject: [PATCH] fix(mempool)!: limit mempool gossip rate on a per-peer basis (#787)

* fix(mempool): remove TxRecvRatePunishPeer option
* refactor(mempool): per-peer rate limiting when broadcasting txs
* feat(p2pclient): allow rate limiter to set burst
* test(p2pclient): test rate limits for incoming messages
* refactor(p2pclient): move ratelimit to separate struct
* feat(p2pclient): rate limiting of Send operation
* feat: garbage collector for unused rate limiters
* test(p2pclient): rate limits unit tests
* refactor(p2p): remove deprecated ThrottledChannel, replaced by p2pclient.RateLimit
* chore: remove old rate limit impl in reactor
---
 config/config.go                       |   6 -
 config/toml.go                         |   8 +-
 internal/mempool/p2p_msg_handler.go    |  15 +-
 internal/mempool/reactor.go            |  15 +-
 internal/p2p/channel_params.go         |   6 -
 internal/p2p/client/client.go          |  35 ++++
 internal/p2p/client/client_test.go     |  10 +-
 internal/p2p/client/consumer.go        |  41 +++++
 internal/p2p/client/consumer_test.go   |   2 +-
 internal/p2p/client/ratelimit.go       | 132 +++++++++++++++
 internal/p2p/client/ratelimit_test.go  | 219 +++++++++++++++++++++++++
 internal/p2p/conn/connection.go        |  13 --
 internal/p2p/router.go                 |   6 -
 internal/p2p/throttled_channel.go      |  91 ----------
 internal/p2p/throttled_channel_test.go | 189 ---------------------
 node/node.go                           |   2 +
 16 files changed, 464 insertions(+), 326 deletions(-)
 create mode 100644 internal/p2p/client/ratelimit.go
 create mode 100644 internal/p2p/client/ratelimit_test.go
 delete mode 100644 internal/p2p/throttled_channel.go
 delete mode 100644 internal/p2p/throttled_channel_test.go

diff --git a/config/config.go b/config/config.go
index a178e5c365..82557e8412 100644
--- a/config/config.go
+++ b/config/config.go
@@ -901,12 +901,6 @@ type MempoolConfig struct {
 	// Default: 0
 	TxRecvRateLimit float64 `mapstructure:"tx-recv-rate-limit"`
 
-	// TxRecvRatePunishPeer set to true means that when the rate limit set in TxRecvRateLimit is reached, the
-	// peer will be punished (disconnected). If set to false, the peer will be throttled (messages will be dropped).
-	//
-	// Default: false
-	TxRecvRatePunishPeer bool `mapstructure:"tx-recv-rate-punish-peer"`
-
 	// TxEnqueueTimeout defines how long new mempool transaction will wait when internal
 	// processing queue is full (most likely due to busy CheckTx execution).
 	// Once the timeout is reached, the transaction will be silently dropped.
diff --git a/config/toml.go b/config/toml.go
index 23a8d79d8e..90d3163f38 100644
--- a/config/toml.go
+++ b/config/toml.go
@@ -459,23 +459,19 @@ ttl-num-blocks = {{ .Mempool.TTLNumBlocks }}
 
 # tx-send-rate-limit is the rate limit for sending transactions to peers, in transactions per second.
+# This rate limit is individual for each peer.
 # If zero, the rate limiter is disabled.
 #
 # Default: 0
 tx-send-rate-limit = {{ .Mempool.TxSendRateLimit }}
 
 # tx-recv-rate-limit is the rate limit for receiving transactions from peers, in transactions per second.
+# This rate limit is individual for each peer.
 # If zero, the rate limiter is disabled.
 #
 # Default: 0
 tx-recv-rate-limit = {{ .Mempool.TxRecvRateLimit }}
 
-# tx-recv-rate-punish-peer set to true means that when tx-recv-rate-limit is reached, the peer will be punished
-# (disconnected). If set to false, the peer will be throttled (messages will be dropped).
-# -# Default: false -tx-recv-rate-punish-peer = {{ .Mempool.TxRecvRatePunishPeer }} - # TxEnqueueTimeout defines how many nanoseconds new mempool transaction (received # from other nodes) will wait when internal processing queue is full # (most likely due to busy CheckTx execution).Once the timeout is reached, the transaction diff --git a/internal/mempool/p2p_msg_handler.go b/internal/mempool/p2p_msg_handler.go index df8158aceb..73f21e4e54 100644 --- a/internal/mempool/p2p_msg_handler.go +++ b/internal/mempool/p2p_msg_handler.go @@ -23,8 +23,20 @@ type ( } ) -func consumerHandler(logger log.Logger, config *config.MempoolConfig, checker TxChecker, ids *IDs) client.ConsumerParams { +func consumerHandler(ctx context.Context, logger log.Logger, config *config.MempoolConfig, checker TxChecker, ids *IDs) client.ConsumerParams { chanIDs := []p2p.ChannelID{p2p.MempoolChannel} + + nTokensFunc := func(e *p2p.Envelope) uint { + if m, ok := e.Message.(*protomem.Txs); ok { + return uint(len(m.Txs)) + } + + // unknown message type; this should not happen, we expect only Txs messages + // But we don't panic, as this is not a critical error + logger.Error("received unknown message type, expected Txs; assuming weight 1", "type", fmt.Sprintf("%T", e.Message), "from", e.From) + return 1 + } + return client.ConsumerParams{ ReadChannels: chanIDs, Handler: client.HandlerWithMiddlewares( @@ -34,6 +46,7 @@ func consumerHandler(logger log.Logger, config *config.MempoolConfig, checker Tx checker: checker, ids: ids, }, + client.WithRecvRateLimitPerPeerHandler(ctx, config.TxRecvRateLimit, nTokensFunc, true, logger), client.WithValidateMessageHandler(chanIDs), client.WithErrorLoggerMiddleware(logger), client.WithRecoveryMiddleware(logger), diff --git a/internal/mempool/reactor.go b/internal/mempool/reactor.go index 3f62f958ec..415fef8bec 100644 --- a/internal/mempool/reactor.go +++ b/internal/mempool/reactor.go @@ -74,7 +74,7 @@ func (r *Reactor) OnStart(ctx context.Context) error { r.logger.Info("tx broadcasting is disabled") } go func() { - err := r.p2pClient.Consume(ctx, consumerHandler(r.logger, r.mempool.config, r.mempool, r.ids)) + err := r.p2pClient.Consume(ctx, consumerHandler(ctx, r.logger, r.mempool.config, r.mempool, r.ids)) if err != nil { r.logger.Error("failed to consume p2p checker messages", "error", err) } @@ -153,6 +153,17 @@ func (r *Reactor) processPeerUpdates(ctx context.Context, peerUpdates *p2p.PeerU } } +// sendTxs sends the given txs to the given peer. +// +// Sending txs to a peer is rate limited to prevent spamming the network. +// Each peer has its own rate limiter. +// +// As we will wait for confirmation of the txs being delivered, it is generally safe to +// drop the txs if the send fails. +func (r *Reactor) sendTxs(ctx context.Context, peerID types.NodeID, txs ...types.Tx) error { + return r.p2pClient.SendTxs(ctx, peerID, txs...) +} + func (r *Reactor) broadcastTxRoutine(ctx context.Context, peerID types.NodeID) { peerMempoolID := r.ids.GetForPeer(peerID) var nextGossipTx *clist.CElement @@ -201,7 +212,7 @@ func (r *Reactor) broadcastTxRoutine(ctx context.Context, peerID types.NodeID) { if !memTx.HasPeer(peerMempoolID) { // Send the mempool tx to the corresponding peer. Note, the peer may be // behind and thus would not be able to process the mempool tx correctly. 
-			err := r.p2pClient.SendTxs(ctx, peerID, memTx.tx)
+			err := r.sendTxs(ctx, peerID, memTx.tx)
 			if err != nil {
 				r.logger.Error("failed to gossip transaction", "peerID", peerID, "error", err)
 				return
diff --git a/internal/p2p/channel_params.go b/internal/p2p/channel_params.go
index 3e0af1d6e3..d9580e2479 100644
--- a/internal/p2p/channel_params.go
+++ b/internal/p2p/channel_params.go
@@ -4,7 +4,6 @@ import (
 	"fmt"
 
 	"github.com/gogo/protobuf/proto"
-	"golang.org/x/time/rate"
 
 	"github.com/dashpay/tenderdash/config"
 	"github.com/dashpay/tenderdash/proto/tendermint/blocksync"
@@ -80,11 +79,6 @@ func ChannelDescriptors(cfg *config.Config) map[ChannelID]*ChannelDescriptor {
 			RecvMessageCapacity: mempoolBatchSize(cfg.Mempool.MaxTxBytes),
 			RecvBufferCapacity:  1000,
 			Name:                "mempool",
-			SendRateLimit:       rate.Limit(cfg.Mempool.TxSendRateLimit),
-			SendRateBurst:       int(5 * cfg.Mempool.TxSendRateLimit),
-			RecvRateLimit:       rate.Limit(cfg.Mempool.TxRecvRateLimit),
-			RecvRateBurst:       int(10 * cfg.Mempool.TxRecvRateLimit), // twice as big as send, to avoid false punishment
-			RecvRateShouldErr:   cfg.Mempool.TxRecvRatePunishPeer,
 			EnqueueTimeout:      cfg.Mempool.TxEnqueueTimeout,
 		},
 	}
diff --git a/internal/p2p/client/client.go b/internal/p2p/client/client.go
index 8aaa4206b2..ea6d4a08b8 100644
--- a/internal/p2p/client/client.go
+++ b/internal/p2p/client/client.go
@@ -98,6 +98,8 @@ type (
 		pending        sync.Map
 		reqTimeout     time.Duration
 		chanIDResolver func(msg proto.Message) p2p.ChannelID
+		// rateLimit maps channel IDs to their send rate limiters; channels without an entry are not rate limited
+		rateLimit map[p2p.ChannelID]*RateLimit
 	}
 	// OptionFunc is a client optional function, it is used to override the default parameters in a Client
 	OptionFunc func(c *Client)
@@ -128,6 +130,18 @@ func WithChanIDResolver(resolver func(msg proto.Message) p2p.ChannelID) OptionFu
 	}
 }
 
+// WithSendRateLimits defines a rate limiter for the provided channels.
+//
+// The provided rate limiter is shared between the provided channels.
+// Use this function multiple times to set different rate limiters for different channels.
+func WithSendRateLimits(rateLimit *RateLimit, channels ...p2p.ChannelID) OptionFunc {
+	return func(c *Client) {
+		for _, ch := range channels {
+			c.rateLimit[ch] = rateLimit
+		}
+	}
+}
+
 // New creates and returns Client with optional functions
 func New(descriptors map[p2p.ChannelID]*p2p.ChannelDescriptor, creator p2p.ChannelCreator, opts ...OptionFunc) *Client {
 	client := &Client{
@@ -136,6 +150,7 @@ func New(descriptors map[p2p.ChannelID]*p2p.ChannelDescriptor, creator p2p.Chann
 		logger:         log.NewNopLogger(),
 		reqTimeout:     peerTimeout,
 		chanIDResolver: p2p.ResolveChannelID,
+		rateLimit:      make(map[p2p.ChannelID]*RateLimit),
 	}
 	for _, opt := range opts {
 		opt(client)
@@ -238,6 +253,13 @@ func (c *Client) SendTxs(ctx context.Context, peerID types.NodeID, tx ...types.T
 
 // Send sends p2p message to a peer, allowed p2p.Envelope or p2p.PeerError types
 func (c *Client) Send(ctx context.Context, msg any) error {
+	return c.SendN(ctx, msg, 1)
+}
+
+// SendN sends a p2p message to a peer, consuming `nTokens` from the rate limiter.
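+//
+// A minimal usage sketch (hypothetical caller; `peerID` and `txs` are assumptions,
+// not part of this API):
+//
+//	envelope := p2p.Envelope{To: peerID, ChannelID: p2p.MempoolChannel, Message: &protomem.Txs{Txs: txs}}
+//	err := client.SendN(ctx, envelope, len(txs)) // one token per transaction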
+// +// Allowed `msg` types are: p2p.Envelope or p2p.PeerError +func (c *Client) SendN(ctx context.Context, msg any, nTokens int) error { switch t := msg.(type) { case p2p.PeerError: ch, err := c.chanStore.get(ctx, p2p.ErrorChannel) @@ -257,6 +279,19 @@ func (c *Client) Send(ctx context.Context, msg any) error { if err != nil { return err } + if limiter, ok := c.rateLimit[t.ChannelID]; ok { + ok, err := limiter.Limit(ctx, t.To, nTokens) + if err != nil { + return fmt.Errorf("rate limited when sending message %T on channel %d to %s: %w", + t.Message, t.ChannelID, t.To, err) + } + if !ok { + c.logger.Debug("dropping message due to rate limit", + "channel", t.ChannelID, "peer", t.To, "message", t.Message) + return nil + } + } + return ch.Send(ctx, t) } return fmt.Errorf("cannot send an unsupported message type %T", msg) diff --git a/internal/p2p/client/client_test.go b/internal/p2p/client/client_test.go index 50a88c4bb4..ee353b0f62 100644 --- a/internal/p2p/client/client_test.go +++ b/internal/p2p/client/client_test.go @@ -54,11 +54,11 @@ func (suite *ChannelTestSuite) SetupTest() { suite.fakeClock = clockwork.NewFakeClock() suite.client = New( suite.descriptors, - func(ctx context.Context, descriptor *p2p.ChannelDescriptor) (p2p.Channel, error) { + func(_ctx context.Context, _descriptor *p2p.ChannelDescriptor) (p2p.Channel, error) { return suite.p2pChannel, nil }, WithClock(suite.fakeClock), - WithChanIDResolver(func(msg proto.Message) p2p.ChannelID { + WithChanIDResolver(func(_msg proto.Message) p2p.ChannelID { return testChannelID }), ) @@ -185,7 +185,7 @@ func (suite *ChannelTestSuite) TestConsumeHandle() { suite.p2pChannel. On("Receive", ctx). Once(). - Return(func(ctx context.Context) p2p.ChannelIterator { + Return(func(_ctx context.Context) p2p.ChannelIterator { return p2p.NewChannelIterator(outCh) }) consumer := newMockConsumer(suite.T()) @@ -226,7 +226,7 @@ func (suite *ChannelTestSuite) TestConsumeResolve() { suite.p2pChannel. On("Receive", ctx). Once(). - Return(func(ctx context.Context) p2p.ChannelIterator { + Return(func(_ctx context.Context) p2p.ChannelIterator { return p2p.NewChannelIterator(outCh) }) resCh := suite.client.addPending(reqID) @@ -278,7 +278,7 @@ func (suite *ChannelTestSuite) TestConsumeError() { suite.p2pChannel. On("Receive", ctx). Once(). 
-		Return(func(ctx context.Context) p2p.ChannelIterator {
+		Return(func(_ctx context.Context) p2p.ChannelIterator {
 			return p2p.NewChannelIterator(outCh)
 		})
 	consumer := newMockConsumer(suite.T())
diff --git a/internal/p2p/client/consumer.go b/internal/p2p/client/consumer.go
index 6038737d7b..976c0590ff 100644
--- a/internal/p2p/client/consumer.go
+++ b/internal/p2p/client/consumer.go
@@ -9,6 +9,9 @@ import (
 	"github.com/dashpay/tenderdash/libs/log"
 )
 
+// DefaultRecvBurstMultiplier tells how many times larger the burst is than the limit in recvRateLimitPerPeerHandler
+const DefaultRecvBurstMultiplier = 10
+
 var (
 	ErrRequestIDAttributeRequired  = errors.New("envelope requestID attribute is required")
 	ErrResponseIDAttributeRequired = errors.New("envelope responseID attribute is required")
@@ -41,6 +44,19 @@ type (
 		allowedChannelIDs map[p2p.ChannelID]struct{}
 		next              ConsumerHandler
 	}
+
+	// TokenNumberFunc is a function that returns the number of tokens to consume for a given envelope
+	TokenNumberFunc func(*p2p.Envelope) uint
+
+	recvRateLimitPerPeerHandler struct {
+		RateLimit
+
+		// next is the next handler in the chain
+		next ConsumerHandler
+
+		// nTokensFunc returns the number of tokens to consume for a given envelope; if unsure, return 1
+		nTokensFunc TokenNumberFunc
+	}
 )
 
 // WithRecoveryMiddleware creates panic recovery middleware
@@ -75,6 +91,18 @@ func WithValidateMessageHandler(allowedChannelIDs []p2p.ChannelID) ConsumerMiddl
 	}
 }
 
+func WithRecvRateLimitPerPeerHandler(ctx context.Context, limit float64, nTokensFunc TokenNumberFunc, drop bool, logger log.Logger) ConsumerMiddlewareFunc {
+	return func(next ConsumerHandler) ConsumerHandler {
+		hd := &recvRateLimitPerPeerHandler{
+			RateLimit:   *NewRateLimit(ctx, limit, drop, logger),
+			nTokensFunc: nTokensFunc,
+		}
+
+		hd.next = next
+		return hd
+	}
+}
+
 // HandlerWithMiddlewares is a function that wraps a handler in middlewares
 func HandlerWithMiddlewares(handler ConsumerHandler, mws ...ConsumerMiddlewareFunc) ConsumerHandler {
 	for _, mw := range mws {
@@ -125,3 +153,16 @@ func (h *validateMessageHandler) Handle(ctx context.Context, client *Client, env
 	}
 	return h.next.Handle(ctx, client, envelope)
 }
+
+func (h *recvRateLimitPerPeerHandler) Handle(ctx context.Context, client *Client, envelope *p2p.Envelope) error {
+	accepted, err := h.RateLimit.Limit(ctx, envelope.From, int(h.nTokensFunc(envelope)))
+	if err != nil {
+		return fmt.Errorf("rate limit failed for peer '%s': %w", envelope.From, err)
+	}
+	if !accepted {
+		h.logger.Debug("silently dropping message due to rate limit", "peer", envelope.From, "envelope", envelope)
+		return nil
+	}
+
+	return h.next.Handle(ctx, client, envelope)
+}
diff --git a/internal/p2p/client/consumer_test.go b/internal/p2p/client/consumer_test.go
index cfacda0267..31b1347cd3 100644
--- a/internal/p2p/client/consumer_test.go
+++ b/internal/p2p/client/consumer_test.go
@@ -35,7 +35,7 @@ func TestErrorLoggerP2PMessageHandler(t *testing.T) {
 			wantErr: "error",
 		},
 		{
-			mockFn: func(hd *mockConsumer, logger *log.TestingLogger) {
+			mockFn: func(hd *mockConsumer, _logger *log.TestingLogger) {
 				hd.On("Handle", mock.Anything, mock.Anything, mock.Anything).
 					Once().
 					Return(nil)
diff --git a/internal/p2p/client/ratelimit.go b/internal/p2p/client/ratelimit.go
new file mode 100644
index 0000000000..a42e22efb3
--- /dev/null
+++ b/internal/p2p/client/ratelimit.go
@@ -0,0 +1,132 @@
+package client
+
+import (
+	"context"
+	"fmt"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	"golang.org/x/time/rate"
+
+	"github.com/dashpay/tenderdash/libs/log"
+	"github.com/dashpay/tenderdash/types"
+)
+
+const PeerRateLimitLifetime = 60 // number of seconds to keep the rate limiter for a peer
+
+// RateLimit is a rate limiter for p2p messages.
+// It is used to limit the rate of incoming messages from a peer.
+// Each peer has its own independent limit.
+//
+// Use NewRateLimit to create a new rate limiter.
+// Use [Limit()] to wait for the rate limit to allow the message to be sent.
+type RateLimit struct {
+	// limit is the rate limit per peer per second; 0 means no limit
+	limit float64
+	// burst is the initial number of tokens; see rate module for more details
+	burst int
+	// map of peerID to rate.Limiter
+	limiters sync.Map
+	// drop is a flag to silently drop the message if the rate limit is exceeded; otherwise we will wait
+	drop bool
+
+	logger log.Logger
+}
+
+type limiter struct {
+	*rate.Limiter
+	// lastAccess is the last time the limiter was accessed, as Unix time (seconds)
+	lastAccess atomic.Int64
+}
+
+// NewRateLimit creates a new rate limiter.
+//
+// # Arguments
+//
+// * `ctx` - context; used to gracefully shut down the garbage collection routine
+// * `limit` - rate limit per peer per second; 0 means no limit
+// * `drop` - silently drop the message if the rate limit is exceeded; otherwise we will wait until the message is allowed
+// * `logger` - logger
+func NewRateLimit(ctx context.Context, limit float64, drop bool, logger log.Logger) *RateLimit {
+	h := &RateLimit{
+		limiters: sync.Map{},
+		limit:    limit,
+		burst:    int(DefaultRecvBurstMultiplier * limit),
+		drop:     drop,
+		logger:   logger,
+	}
+
+	// start the garbage collection routine
+	go h.gcRoutine(ctx)
+
+	return h
+}
+
+func (h *RateLimit) getLimiter(peerID types.NodeID) *limiter {
+	var limit *limiter
+	if l, ok := h.limiters.Load(peerID); ok {
+		limit = l.(*limiter)
+	} else {
+		limit = &limiter{Limiter: rate.NewLimiter(rate.Limit(h.limit), h.burst)}
+		// we have a slight race condition here, possibly overwriting the limiter, but it's not a big deal
+		// as the worst case scenario is that we allow one or two more messages than we should
+		h.limiters.Store(peerID, limit)
+	}
+
+	limit.lastAccess.Store(time.Now().Unix())
+
+	return limit
+}
+
+// Limit waits for the rate limit to allow the message to be sent, or, when
+// `drop` is set, drops the message immediately instead of waiting.
+//
+// If peerID is empty, the message is always allowed.
+//
+// Returns true when the message is allowed, false if it should be dropped.
+//
+// Arguments:
+// - ctx: context
+// - peerID: peer ID; if empty, the message is always allowed
+// - nTokens: number of tokens to consume; use 1 if unsure
+func (h *RateLimit) Limit(ctx context.Context, peerID types.NodeID, nTokens int) (allowed bool, err error) {
+	if h.limit > 0 && peerID != "" {
+		limiter := h.getLimiter(peerID)
+
+		if h.drop {
+			return limiter.AllowN(time.Now(), nTokens), nil
+		}
+
+		if err := limiter.WaitN(ctx, nTokens); err != nil {
+			return false, fmt.Errorf("rate limit failed for peer %s: %w", peerID, err)
+		}
+	}
+	return true, nil
+}
+
+// gcRoutine periodically removes unused limiters for peers, checking every `PeerRateLimitLifetime` seconds.
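+//
+// For example, with PeerRateLimitLifetime = 60, a limiter last used at t=10s
+// survives the tick at t=60s (idle for only 50s) and is removed at the t=120s
+// tick; if the peer sends again later, getLimiter simply hands it a fresh
+// limiter with a full burst allowance.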
+func (h *RateLimit) gcRoutine(ctx context.Context) {
+	ticker := time.NewTicker(PeerRateLimitLifetime * time.Second)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		case <-ticker.C:
+			h.gc()
+		}
+	}
+}
+
+// gc removes limiters that have not been accessed for at least PeerRateLimitLifetime seconds.
+func (h *RateLimit) gc() {
+	now := time.Now().Unix()
+	h.limiters.Range(func(key, value interface{}) bool {
+		if value.(*limiter).lastAccess.Load() < now-PeerRateLimitLifetime {
+			h.limiters.Delete(key)
+		}
+		return true
+	})
+}
diff --git a/internal/p2p/client/ratelimit_test.go b/internal/p2p/client/ratelimit_test.go
new file mode 100644
index 0000000000..004a330049
--- /dev/null
+++ b/internal/p2p/client/ratelimit_test.go
@@ -0,0 +1,219 @@
+package client
+
+import (
+	"context"
+	"errors"
+	"math"
+	"runtime"
+	"strconv"
+	"sync"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/mock"
+	"github.com/stretchr/testify/require"
+
+	"github.com/dashpay/tenderdash/internal/p2p"
+	"github.com/dashpay/tenderdash/internal/p2p/conn"
+	"github.com/dashpay/tenderdash/libs/log"
+	"github.com/dashpay/tenderdash/types"
+)
+
+// TestRecvRateLimitHandler tests the rate limit middleware when receiving messages from peers.
+// It tests that the rate limit is applied per peer.
+//
+// GIVEN 5 peers named 1..5 and rate limit of 2/s and burst 4,
+// WHEN we send 1, 2, 3, 4 and 5 msgs per second respectively for 3 seconds,
+// THEN:
+// * peer 1 and 2 receive all messages,
+// * other peers receive 2 messages per second plus 4 burst messages.
+//
+// Reuses testRateLimit from client_test.go
+func TestRecvRateLimitHandler(t *testing.T) {
+	// don't run this if we are in short mode
+	if testing.Short() {
+		t.Skip("skipping test in short mode.")
+	}
+
+	const (
+		Limit           = 2.0
+		Burst           = 4
+		Peers           = 5
+		TestTimeSeconds = 3
+	)
+
+	sent := make([]atomic.Uint32, Peers)
+
+	fakeHandler := newMockConsumer(t)
+	fakeHandler.On("Handle", mock.Anything, mock.Anything, mock.Anything).
+		Return(nil).
+		Run(func(args mock.Arguments) {
+			peerID := args.Get(2).(*p2p.Envelope).From
+			peerNum, err := strconv.Atoi(string(peerID))
+			require.NoError(t, err)
+			sent[peerNum-1].Add(1)
+		})
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	logger := log.NewTestingLogger(t)
+	client := &Client{}
+
+	mw := WithRecvRateLimitPerPeerHandler(ctx,
+		Limit,
+		func(*p2p.Envelope) uint { return 1 },
+		false,
+		logger,
+	)(fakeHandler).(*recvRateLimitPerPeerHandler)
+
+	mw.burst = Burst
+
+	sendFn := func(peerID types.NodeID) error {
+		envelope := p2p.Envelope{
+			From:      peerID,
+			ChannelID: testChannelID,
+		}
+		return mw.Handle(ctx, client, &envelope)
+	}
+
+	parallelSendWithLimit(t, ctx, sendFn, Peers, TestTimeSeconds)
+	assertRateLimits(t, sent, Limit, Burst, TestTimeSeconds)
+}
+
+// TestSendRateLimit tests the rate limit for sending messages using p2p.client.
+//
+// Each peer should have its own, independent rate limit.
+//
+// GIVEN 5 peers named 1..5 and rate limit of 2/s and burst 4,
+// WHEN we send 1, 2, 3, 4 and 5 msgs per second respectively for 3 seconds,
+// THEN:
+// * peer 1 and 2 receive all messages,
+// * other peers receive 2 messages per second plus 4 burst messages.
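+//
+// In general, a peer sending at r msgs/s for T seconds should get through
+// min(r*T, Limit*T + Burst) messages; see assertRateLimits below.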
+func (suite *ChannelTestSuite) TestSendRateLimit() {
+	if testing.Short() {
+		suite.T().Skip("skipping test in short mode.")
+	}
+
+	const (
+		Limit           = 2.0
+		Burst           = 4
+		Peers           = 5
+		TestTimeSeconds = 3
+	)
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	client := suite.client
+
+	limiter := NewRateLimit(ctx, Limit, false, suite.client.logger)
+	limiter.burst = Burst
+	suite.client.rateLimit = map[conn.ChannelID]*RateLimit{
+		testChannelID: limiter,
+	}
+
+	sendFn := func(peerID types.NodeID) error {
+		envelope := p2p.Envelope{
+			To:        peerID,
+			ChannelID: testChannelID,
+		}
+		return client.Send(ctx, envelope)
+
+	}
+	sent := make([]atomic.Uint32, Peers)
+
+	suite.p2pChannel.On("Send", mock.Anything, mock.Anything).
+		Run(func(args mock.Arguments) {
+			peerID := args.Get(1).(p2p.Envelope).To
+			peerNum, err := strconv.Atoi(string(peerID))
+			suite.NoError(err)
+			sent[peerNum-1].Add(1)
+		}).
+		Return(nil)
+
+	parallelSendWithLimit(suite.T(), ctx, sendFn, Peers, TestTimeSeconds)
+	assertRateLimits(suite.T(), sent, Limit, Burst, TestTimeSeconds)
+}
+
+// parallelSendWithLimit sends messages to peers in parallel at per-peer rates.
+//
+// The function sends messages to peers. Each peer gets its own number, starting from 1.
+// The send rate is equal to the peer number, e.g. peer 1 sends 1 msg/s, peer 2 sends 2 msg/s, etc.
+func parallelSendWithLimit(t *testing.T, ctx context.Context, sendFn func(peerID types.NodeID) error,
+	peers int, testTimeSeconds int) {
+	t.Helper()
+
+	ctx, cancel := context.WithCancel(ctx)
+	defer cancel()
+
+	// all goroutines will wait for the start signal
+	start := sync.RWMutex{}
+	start.Lock()
+
+	for peer := 1; peer <= peers; peer++ {
+		peerID := types.NodeID(strconv.Itoa(peer))
+		// peer number is the send rate
+		msgsPerSec := peer
+
+		go func(peerID types.NodeID, rate int) {
+			start.RLock()
+			defer start.RUnlock()
+
+			for s := 0; s < testTimeSeconds; s++ {
+				until := time.NewTimer(time.Second)
+				defer until.Stop()
+
+				for i := 0; i < rate; i++ {
+					select {
+					case <-ctx.Done():
+						return
+					default:
+					}
+
+					if err := sendFn(peerID); !errors.Is(err, context.Canceled) {
+						require.NoError(t, err)
+					}
+				}
+
+				select {
+				case <-until.C:
+					// noop, we just sleep until the end of the second
+				case <-ctx.Done():
+					return
+				}
+			}
+
+		}(peerID, msgsPerSec)
+	}
+
+	// start the test
+	startTime := time.Now()
+	start.Unlock()
+	runtime.Gosched()
+	time.Sleep(time.Duration(testTimeSeconds) * time.Second)
+	cancel()
+	// wait for all goroutines to finish, that is - drop RLocks
+	start.Lock()
+	defer start.Unlock()
+
+	// check if test ran for the expected time
+	// note we ignore up to 99 ms to account for any processing time
+	elapsed := math.Floor(time.Since(startTime).Seconds()*10) / 10
+	assert.Equal(t, float64(testTimeSeconds), elapsed, "test should run for %d seconds", testTimeSeconds)
+}
+
+// assertRateLimits checks if the rate limits were applied correctly.
+// We assume that the index of each item in `sent` is the peer number, as described in parallelSendWithLimit.
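+//
+// For example, with limit=2, burst=4 and seconds=3: peer 1 (1 msg/s) delivers
+// min(1*3, 2*3+4) = 3 messages, while peer 5 (5 msg/s) is capped at
+// min(5*3, 2*3+4) = 10.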
+func assertRateLimits(t *testing.T, sent []atomic.Uint32, limit float64, burst int, seconds int) { + for peer := 1; peer <= len(sent); peer++ { + expected := int(limit)*seconds + burst + if expected > peer*seconds { + expected = peer * seconds + } + + assert.Equal(t, expected, int(sent[peer-1].Load()), "peer %d should receive %d messages", peer, expected) + } +} diff --git a/internal/p2p/conn/connection.go b/internal/p2p/conn/connection.go index 8ba542dbf8..82494a5964 100644 --- a/internal/p2p/conn/connection.go +++ b/internal/p2p/conn/connection.go @@ -14,7 +14,6 @@ import ( "time" sync "github.com/sasha-s/go-deadlock" - "golang.org/x/time/rate" "github.com/gogo/protobuf/proto" @@ -611,18 +610,6 @@ type ChannelDescriptor struct { // RecvMessageCapacity defines the max message size for a given p2p Channel. RecvMessageCapacity int - /// SendRateLimit is used to limit the rate of sending messages, per second. - SendRateLimit rate.Limit - SendRateBurst int - - /// RecvRateLimit is used to limit the rate of receiving messages, per second. - RecvRateLimit rate.Limit - RecvRateBurst int - // RecvRateShouldErr is used to determine if the rate limiter should - // report an error whenever recv rate limit is exceeded, most likely - // causing the peer to disconnect. - RecvRateShouldErr bool - // RecvBufferCapacity defines the max number of inbound messages for a // given p2p Channel queue. RecvBufferCapacity int diff --git a/internal/p2p/router.go b/internal/p2p/router.go index 43bdde8164..31830139bd 100644 --- a/internal/p2p/router.go +++ b/internal/p2p/router.go @@ -263,12 +263,6 @@ func (r *Router) OpenChannel(ctx context.Context, chDesc *ChannelDescriptor) (Ch outCh := make(chan Envelope, chDesc.RecvBufferCapacity) errCh := make(chan PeerError, chDesc.RecvBufferCapacity) channel := NewChannel(chDesc.ID, chDesc.Name, queue.dequeue(), outCh, errCh) - if chDesc.SendRateLimit > 0 || chDesc.RecvRateLimit > 0 { - channel = NewThrottledChannel(channel, - chDesc.SendRateLimit, chDesc.SendRateBurst, - chDesc.RecvRateLimit, chDesc.RecvRateBurst, chDesc.RecvRateShouldErr, - r.logger) - } r.channelQueues[id] = queue diff --git a/internal/p2p/throttled_channel.go b/internal/p2p/throttled_channel.go deleted file mode 100644 index f82f0374b2..0000000000 --- a/internal/p2p/throttled_channel.go +++ /dev/null @@ -1,91 +0,0 @@ -package p2p - -import ( - "context" - - "golang.org/x/time/rate" - - "github.com/dashpay/tenderdash/libs/log" -) - -// / Channel that will block if the send limit is reached -type ThrottledChannel struct { - channel Channel - sendLimiter *rate.Limiter - recvLimiter *rate.Limiter - recvShouldErr bool - - logger log.Logger -} - -// NewThrottledChannel creates a new throttled channel. -// The rate is specified in messages per second. -// The burst is specified in messages. 
-func NewThrottledChannel(channel Channel, sendLimit rate.Limit, sendBurst int, - recvLimit rate.Limit, recvBurst int, recvShouldErr bool, - logger log.Logger) *ThrottledChannel { - - var ( - sendLimiter *rate.Limiter - recvLimiter *rate.Limiter - ) - if sendLimit > 0 { - sendLimiter = rate.NewLimiter(sendLimit, sendBurst) - } - if recvLimit > 0 { - recvLimiter = rate.NewLimiter(recvLimit, recvBurst) - } - - return &ThrottledChannel{ - channel: channel, - sendLimiter: sendLimiter, - recvLimiter: recvLimiter, - recvShouldErr: recvShouldErr, - logger: logger, - } -} - -var _ Channel = (*ThrottledChannel)(nil) - -func (ch *ThrottledChannel) Send(ctx context.Context, envelope Envelope) error { - // Wait until limiter allows us to proceed. - if ch.sendLimiter != nil { - if err := ch.sendLimiter.Wait(ctx); err != nil { - return err - } - } - - return ch.channel.Send(ctx, envelope) -} - -func (ch *ThrottledChannel) SendError(ctx context.Context, pe PeerError) error { - // Wait until limiter allows us to proceed. - if err := ch.sendLimiter.Wait(ctx); err != nil { - return err - } - - return ch.channel.SendError(ctx, pe) -} - -func (ch *ThrottledChannel) Receive(ctx context.Context) ChannelIterator { - if ch.recvLimiter == nil { - return ch.channel.Receive(ctx) - } - - innerIter := ch.channel.Receive(ctx) - iter, err := ThrottledChannelIterator(ctx, ch.recvLimiter, innerIter, ch.recvShouldErr, ch.channel, ch.logger) - if err != nil { - ch.logger.Error("error creating ThrottledChannelIterator", "err", err) - return nil - } - - return iter -} - -func (ch *ThrottledChannel) Err() error { - return ch.channel.Err() -} - -func (ch *ThrottledChannel) String() string { - return ch.channel.String() -} diff --git a/internal/p2p/throttled_channel_test.go b/internal/p2p/throttled_channel_test.go deleted file mode 100644 index 4571d89f10..0000000000 --- a/internal/p2p/throttled_channel_test.go +++ /dev/null @@ -1,189 +0,0 @@ -package p2p_test - -import ( - "context" - "fmt" - "sync" - "sync/atomic" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "golang.org/x/time/rate" - - "github.com/dashpay/tenderdash/internal/p2p" - "github.com/dashpay/tenderdash/libs/log" -) - -type mockChannel struct { - sent atomic.Uint32 - received atomic.Uint32 - errored atomic.Uint32 -} - -func (c *mockChannel) SendCount() int { - return int(c.sent.Load()) -} - -func (c *mockChannel) RecvCount() int { - return int(c.received.Load()) -} - -func (c *mockChannel) Send(_ context.Context, _e p2p.Envelope) error { - c.sent.Add(1) - return nil -} - -func (c *mockChannel) SendError(_ context.Context, _ p2p.PeerError) error { - c.errored.Add(1) - return nil -} - -func (c *mockChannel) Receive(ctx context.Context) p2p.ChannelIterator { - var pipe = make(chan p2p.Envelope, 5) - - go func() { - for { - select { - case pipe <- p2p.Envelope{}: - c.received.Add(1) - case <-ctx.Done(): - close(pipe) - return - } - } - }() - - return p2p.NewChannelIterator(pipe) -} - -func (c *mockChannel) Err() error { - if e := c.errored.Load(); e > 0 { - return fmt.Errorf("mock_channel_error: error count: %d", e) - } - return nil -} - -func (c *mockChannel) String() string { - return "mock_channel" -} - -func TestThrottledChannelSend(t *testing.T) { - const n = 31 - const rate rate.Limit = 10 - const burst = int(rate) - - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - - logger := log.NewTestingLogger(t) - - mock := &mockChannel{} - - ch := 
p2p.NewThrottledChannel(mock, rate, burst, 0, 0, false, logger) // 1 message per second - - wg := sync.WaitGroup{} - start := time.Now() - for i := 0; i < n; i++ { - wg.Add(1) - go func() { - err := ch.Send(ctx, p2p.Envelope{}) - require.NoError(t, err) - wg.Done() - }() - } - - time.Sleep(time.Second) - assert.LessOrEqual(t, mock.SendCount(), burst+int(rate)) - - // Wait until all finish - wg.Wait() - took := time.Since(start) - assert.Equal(t, n, mock.SendCount()) - assert.GreaterOrEqual(t, took.Seconds(), 2.0) -} - -// Given some thrrottled channel that generates messages all the time, we should error out after receiving a rate -// of 10 messages per second. -func TestThrottledChannelRecvError(t *testing.T) { - const rate rate.Limit = 10 - const burst = int(rate) - - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - - logger := log.NewTestingLogger(t) - - mock := &mockChannel{} - - ch := p2p.NewThrottledChannel(mock, rate, burst, rate, burst, true, logger) // 1 message per second - - start := time.Now() - - assert.NoError(t, mock.Err()) - - iter := ch.Receive(ctx) - - for i := 0; i < burst+int(rate)+1; i++ { - assert.True(t, iter.Next(ctx)) - - e := iter.Envelope() - if e == nil { - t.Error("nil envelope") - } - } - - err := mock.Err() - t.Logf("expected mock error: %v", err) - assert.Error(t, mock.Err()) - - // Wait until all finish - cancel() - - took := time.Since(start) - assert.GreaterOrEqual(t, took.Seconds(), 1.0) -} - -// Given some thrrottled channel that generates messages all the time, we should be able to receive them at a rate -// of 10 messages per second. -func TestThrottledChannelRecv(t *testing.T) { - const rate rate.Limit = 10 - const burst = int(rate) - - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - - logger := log.NewTestingLogger(t) - - mock := &mockChannel{} - - ch := p2p.NewThrottledChannel(mock, rate, burst, rate, burst, false, logger) // 1 message per second - - start := time.Now() - count := 0 - wg := sync.WaitGroup{} - wg.Add(1) - go func() { - iter := ch.Receive(ctx) - for iter.Next(ctx) { - e := iter.Envelope() - if e == nil { - t.Error("nil envelope") - } - count++ - } - wg.Done() - }() - - time.Sleep(time.Second) - assert.LessOrEqual(t, mock.SendCount(), burst+int(rate)) - - // Wait until all finish - cancel() - wg.Wait() - - took := time.Since(start) - assert.Greater(t, mock.RecvCount(), count*100, "we should generate much more messages than we can receive") - assert.GreaterOrEqual(t, took.Seconds(), 1.0) -} diff --git a/node/node.go b/node/node.go index 3c83a04611..67292f777e 100644 --- a/node/node.go +++ b/node/node.go @@ -306,7 +306,9 @@ func makeNode( p2p.ChannelDescriptors(cfg), node.router.OpenChannel, p2pclient.WithLogger(logger), + p2pclient.WithSendRateLimits(p2pclient.NewRateLimit(ctx, cfg.Mempool.TxSendRateLimit, false, logger), p2p.MempoolChannel), ) + evReactor, evPool, edbCloser, err := createEvidenceReactor(logger, cfg, dbProvider, stateStore, blockStore, peerManager.Subscribe, node.router.OpenChannel, nodeMetrics.evidence, eventBus) closers = append(closers, edbCloser)
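
Example config.toml settings that enable the new per-peer limits (illustrative
values; both settings default to 0, which disables the corresponding limiter):

    [mempool]
    # at most 100 txs/s sent to each peer
    tx-send-rate-limit = 100
    # at most 100 txs/s accepted from each peer
    tx-recv-rate-limit = 100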