Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(mempool)!: limit mempool gossip rate on a per-peer basis #787

Merged
merged 10 commits into from
May 20, 2024
6 changes: 0 additions & 6 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -901,12 +901,6 @@ type MempoolConfig struct {
// Default: 0
TxRecvRateLimit float64 `mapstructure:"tx-recv-rate-limit"`

// TxRecvRatePunishPeer set to true means that when the rate limit set in TxRecvRateLimit is reached, the
// peer will be punished (disconnected). If set to false, the peer will be throttled (messages will be dropped).
//
// Default: false
TxRecvRatePunishPeer bool `mapstructure:"tx-recv-rate-punish-peer"`

// TxEnqueueTimeout defines how long new mempool transaction will wait when internal
// processing queue is full (most likely due to busy CheckTx execution).
// Once the timeout is reached, the transaction will be silently dropped.
Expand Down
8 changes: 2 additions & 6 deletions config/toml.go
Original file line number Diff line number Diff line change
Expand Up @@ -459,23 +459,19 @@ ttl-num-blocks = {{ .Mempool.TTLNumBlocks }}


# tx-send-rate-limit is the rate limit for sending transactions to peers, in transactions per second.
# This rate limit is individual for each peer.
# If zero, the rate limiter is disabled.
#
# Default: 0
tx-send-rate-limit = {{ .Mempool.TxSendRateLimit }}

# tx-recv-rate-limit is the rate limit for receiving transactions from peers, in transactions per second.
# This rate limit is individual for each peer.
# If zero, the rate limiter is disabled.
#
# Default: 0
tx-recv-rate-limit = {{ .Mempool.TxRecvRateLimit }}

# tx-recv-rate-punish-peer set to true means that when tx-recv-rate-limit is reached, the peer will be punished
# (disconnected). If set to false, the peer will be throttled (messages will be dropped).
#
# Default: false
tx-recv-rate-punish-peer = {{ .Mempool.TxRecvRatePunishPeer }}

# TxEnqueueTimeout defines how many nanoseconds new mempool transaction (received
# from other nodes) will wait when internal processing queue is full
# (most likely due to busy CheckTx execution). Once the timeout is reached, the transaction
Expand Down
15 changes: 14 additions & 1 deletion internal/mempool/p2p_msg_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,20 @@ type (
}
)

func consumerHandler(logger log.Logger, config *config.MempoolConfig, checker TxChecker, ids *IDs) client.ConsumerParams {
func consumerHandler(ctx context.Context, logger log.Logger, config *config.MempoolConfig, checker TxChecker, ids *IDs) client.ConsumerParams {
chanIDs := []p2p.ChannelID{p2p.MempoolChannel}

nTokensFunc := func(e *p2p.Envelope) uint {
if m, ok := e.Message.(*protomem.Txs); ok {
return uint(len(m.Txs))
}

// unknown message type; this should not happen, we expect only Txs messages
// But we don't panic, as this is not a critical error
logger.Error("received unknown message type, expected Txs; assuming weight 1", "type", fmt.Sprintf("%T", e.Message), "from", e.From)
return 1
}

return client.ConsumerParams{
ReadChannels: chanIDs,
Handler: client.HandlerWithMiddlewares(
Expand All @@ -34,6 +46,7 @@ func consumerHandler(logger log.Logger, config *config.MempoolConfig, checker Tx
checker: checker,
ids: ids,
},
client.WithRecvRateLimitPerPeerHandler(ctx, config.TxRecvRateLimit, nTokensFunc, true, logger),
client.WithValidateMessageHandler(chanIDs),
client.WithErrorLoggerMiddleware(logger),
client.WithRecoveryMiddleware(logger),
Expand Down
15 changes: 13 additions & 2 deletions internal/mempool/reactor.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func (r *Reactor) OnStart(ctx context.Context) error {
r.logger.Info("tx broadcasting is disabled")
}
go func() {
err := r.p2pClient.Consume(ctx, consumerHandler(r.logger, r.mempool.config, r.mempool, r.ids))
err := r.p2pClient.Consume(ctx, consumerHandler(ctx, r.logger, r.mempool.config, r.mempool, r.ids))
if err != nil {
r.logger.Error("failed to consume p2p checker messages", "error", err)
}
Expand Down Expand Up @@ -153,6 +153,17 @@ func (r *Reactor) processPeerUpdates(ctx context.Context, peerUpdates *p2p.PeerU
}
}

// sendTxs delivers the given transactions to a single peer by delegating to
// the reactor's p2p client.
//
// Outgoing transactions are rate limited to prevent spamming the network;
// every peer is tracked by its own rate limiter.
//
// Because delivery of the txs is confirmed separately, a caller can generally
// afford to drop the txs when this call returns an error.
func (r *Reactor) sendTxs(ctx context.Context, peerID types.NodeID, txs ...types.Tx) error {
	err := r.p2pClient.SendTxs(ctx, peerID, txs...)
	return err
}

func (r *Reactor) broadcastTxRoutine(ctx context.Context, peerID types.NodeID) {
peerMempoolID := r.ids.GetForPeer(peerID)
var nextGossipTx *clist.CElement
Expand Down Expand Up @@ -201,7 +212,7 @@ func (r *Reactor) broadcastTxRoutine(ctx context.Context, peerID types.NodeID) {
if !memTx.HasPeer(peerMempoolID) {
// Send the mempool tx to the corresponding peer. Note, the peer may be
// behind and thus would not be able to process the mempool tx correctly.
err := r.p2pClient.SendTxs(ctx, peerID, memTx.tx)
err := r.sendTxs(ctx, peerID, memTx.tx)
if err != nil {
r.logger.Error("failed to gossip transaction", "peerID", peerID, "error", err)
return
Expand Down
6 changes: 0 additions & 6 deletions internal/p2p/channel_params.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"fmt"

"github.com/gogo/protobuf/proto"
"golang.org/x/time/rate"

"github.com/dashpay/tenderdash/config"
"github.com/dashpay/tenderdash/proto/tendermint/blocksync"
Expand Down Expand Up @@ -80,11 +79,6 @@ func ChannelDescriptors(cfg *config.Config) map[ChannelID]*ChannelDescriptor {
RecvMessageCapacity: mempoolBatchSize(cfg.Mempool.MaxTxBytes),
RecvBufferCapacity: 1000,
Name: "mempool",
SendRateLimit: rate.Limit(cfg.Mempool.TxSendRateLimit),
SendRateBurst: int(5 * cfg.Mempool.TxSendRateLimit),
RecvRateLimit: rate.Limit(cfg.Mempool.TxRecvRateLimit),
RecvRateBurst: int(10 * cfg.Mempool.TxRecvRateLimit), // twice as big as send, to avoid false punishment
RecvRateShouldErr: cfg.Mempool.TxRecvRatePunishPeer,
EnqueueTimeout: cfg.Mempool.TxEnqueueTimeout,
},
}
Expand Down
35 changes: 35 additions & 0 deletions internal/p2p/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ type (
pending sync.Map
reqTimeout time.Duration
chanIDResolver func(msg proto.Message) p2p.ChannelID
// rateLimit maps a channel ID to the rate limiter applied when sending on that channel;
// a channel without an entry is not rate limited
rateLimit map[p2p.ChannelID]*RateLimit
}
// OptionFunc is a client optional function, it is used to override the default parameters in a Client
OptionFunc func(c *Client)
Expand Down Expand Up @@ -128,6 +130,18 @@ func WithChanIDResolver(resolver func(msg proto.Message) p2p.ChannelID) OptionFu
}
}

// WithSendRateLimits defines a rate limiter for the provided channels.
//
// Provided rate limiter will be shared between provided channels.
// Use this function multiple times to set different rate limiters for different channels.
//
// Passing a nil rateLimit removes any limiter previously registered for the
// given channels, i.e. sending on them is no longer rate limited. Storing a
// nil limiter instead would leave an entry that the send path treats as
// present and then invokes Limit on a nil receiver.
func WithSendRateLimits(rateLimit *RateLimit, channels ...p2p.ChannelID) OptionFunc {
	return func(c *Client) {
		if rateLimit == nil {
			// delete is a no-op on a nil map, so no initialization is needed here.
			for _, ch := range channels {
				delete(c.rateLimit, ch)
			}
			return
		}

		// Guard against a Client that was not built with New: assigning to a
		// nil map panics.
		if c.rateLimit == nil {
			c.rateLimit = make(map[p2p.ChannelID]*RateLimit, len(channels))
		}
		for _, ch := range channels {
			c.rateLimit[ch] = rateLimit
		}
	}
}

// New creates and returns Client with optional functions
func New(descriptors map[p2p.ChannelID]*p2p.ChannelDescriptor, creator p2p.ChannelCreator, opts ...OptionFunc) *Client {
client := &Client{
Expand All @@ -136,6 +150,7 @@ func New(descriptors map[p2p.ChannelID]*p2p.ChannelDescriptor, creator p2p.Chann
logger: log.NewNopLogger(),
reqTimeout: peerTimeout,
chanIDResolver: p2p.ResolveChannelID,
rateLimit: make(map[p2p.ChannelID]*RateLimit),
}
for _, opt := range opts {
opt(client)
Expand Down Expand Up @@ -238,6 +253,13 @@ func (c *Client) SendTxs(ctx context.Context, peerID types.NodeID, tx ...types.T

// Send delivers a p2p message to a peer, charging a single token to the rate limiter.
//
// The msg argument is forwarded unchanged to SendN; allowed types are
// p2p.Envelope or p2p.PeerError.
func (c *Client) Send(ctx context.Context, msg any) error {
	const singleToken = 1
	return c.SendN(ctx, msg, singleToken)
}

// SendN sends p2p message to a peer, consuming `nTokens` from rate limiter.
//
// Allowed `msg` types are: p2p.Envelope or p2p.PeerError
func (c *Client) SendN(ctx context.Context, msg any, nTokens int) error {
switch t := msg.(type) {
case p2p.PeerError:
ch, err := c.chanStore.get(ctx, p2p.ErrorChannel)
Expand All @@ -257,6 +279,19 @@ func (c *Client) Send(ctx context.Context, msg any) error {
if err != nil {
return err
}
if limiter, ok := c.rateLimit[t.ChannelID]; ok {
ok, err := limiter.Limit(ctx, t.To, nTokens)
if err != nil {
return fmt.Errorf("rate limited when sending message %T on channel %d to %s: %w",
t.Message, t.ChannelID, t.To, err)
}
if !ok {
c.logger.Debug("dropping message due to rate limit",
"channel", t.ChannelID, "peer", t.To, "message", t.Message)
return nil
}
}

return ch.Send(ctx, t)
}
return fmt.Errorf("cannot send an unsupported message type %T", msg)
Expand Down
10 changes: 5 additions & 5 deletions internal/p2p/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@ func (suite *ChannelTestSuite) SetupTest() {
suite.fakeClock = clockwork.NewFakeClock()
suite.client = New(
suite.descriptors,
func(ctx context.Context, descriptor *p2p.ChannelDescriptor) (p2p.Channel, error) {
func(_ctx context.Context, _descriptor *p2p.ChannelDescriptor) (p2p.Channel, error) {
return suite.p2pChannel, nil
},
WithClock(suite.fakeClock),
WithChanIDResolver(func(msg proto.Message) p2p.ChannelID {
WithChanIDResolver(func(_msg proto.Message) p2p.ChannelID {
return testChannelID
}),
)
Expand Down Expand Up @@ -185,7 +185,7 @@ func (suite *ChannelTestSuite) TestConsumeHandle() {
suite.p2pChannel.
On("Receive", ctx).
Once().
Return(func(ctx context.Context) p2p.ChannelIterator {
Return(func(_ctx context.Context) p2p.ChannelIterator {
return p2p.NewChannelIterator(outCh)
})
consumer := newMockConsumer(suite.T())
Expand Down Expand Up @@ -226,7 +226,7 @@ func (suite *ChannelTestSuite) TestConsumeResolve() {
suite.p2pChannel.
On("Receive", ctx).
Once().
Return(func(ctx context.Context) p2p.ChannelIterator {
Return(func(_ctx context.Context) p2p.ChannelIterator {
return p2p.NewChannelIterator(outCh)
})
resCh := suite.client.addPending(reqID)
Expand Down Expand Up @@ -278,7 +278,7 @@ func (suite *ChannelTestSuite) TestConsumeError() {
suite.p2pChannel.
On("Receive", ctx).
Once().
Return(func(ctx context.Context) p2p.ChannelIterator {
Return(func(_ctx context.Context) p2p.ChannelIterator {
return p2p.NewChannelIterator(outCh)
})
consumer := newMockConsumer(suite.T())
Expand Down
41 changes: 41 additions & 0 deletions internal/p2p/client/consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ import (
"github.com/dashpay/tenderdash/libs/log"
)

// DefaultRecvBurstMultiplier tells how many times burst is bigger than the limit in recvRateLimitPerPeerHandler
const DefaultRecvBurstMultiplier = 10

var (
ErrRequestIDAttributeRequired = errors.New("envelope requestID attribute is required")
ErrResponseIDAttributeRequired = errors.New("envelope responseID attribute is required")
Expand Down Expand Up @@ -41,6 +44,19 @@ type (
allowedChannelIDs map[p2p.ChannelID]struct{}
next ConsumerHandler
}

// TokenNumberFunc is a function that returns number of tokens to consume for a given envelope
TokenNumberFunc func(*p2p.Envelope) uint

recvRateLimitPerPeerHandler struct {
RateLimit

// next is the next handler in the chain
next ConsumerHandler

// nTokensFunc is a function that returns the number of tokens to consume for a given envelope; if unsure, return 1
nTokensFunc TokenNumberFunc
}
)

// WithRecoveryMiddleware creates panic recovery middleware
Expand Down Expand Up @@ -75,6 +91,18 @@ func WithValidateMessageHandler(allowedChannelIDs []p2p.ChannelID) ConsumerMiddl
}
}

// WithRecvRateLimitPerPeerHandler creates a middleware that applies a receive
// rate limit, keyed by the sending peer, before invoking the next handler.
//
// ctx, limit, drop and logger are forwarded to NewRateLimit; nTokensFunc
// determines how many tokens each received envelope consumes. A new rate
// limiter is created every time the returned middleware wraps a handler.
//
// NOTE(review): the *RateLimit returned by NewRateLimit is dereferenced and
// stored by value. If RateLimit holds synchronization primitives or starts
// background work bound to the original pointer, this copy may diverge from
// it - confirm against the RateLimit implementation.
func WithRecvRateLimitPerPeerHandler(ctx context.Context, limit float64, nTokensFunc TokenNumberFunc, drop bool, logger log.Logger) ConsumerMiddlewareFunc {
	return func(next ConsumerHandler) ConsumerHandler {
		return &recvRateLimitPerPeerHandler{
			RateLimit:   *NewRateLimit(ctx, limit, drop, logger),
			next:        next,
			nTokensFunc: nTokensFunc,
		}
	}
}

// HandlerWithMiddlewares is a function that wraps a handler in middlewares
func HandlerWithMiddlewares(handler ConsumerHandler, mws ...ConsumerMiddlewareFunc) ConsumerHandler {
for _, mw := range mws {
Expand Down Expand Up @@ -125,3 +153,16 @@ func (h *validateMessageHandler) Handle(ctx context.Context, client *Client, env
}
return h.next.Handle(ctx, client, envelope)
}

// Handle applies the per-peer receive rate limit to the envelope and, when the
// envelope is accepted, passes it to the next handler in the chain.
//
// The number of tokens charged to the sending peer is computed by nTokensFunc;
// if no function was configured, the envelope costs one token.
//
// It returns a wrapped error when the rate limiter itself fails, nil when the
// envelope is dropped because the limit was exceeded, and otherwise whatever
// the next handler returns.
func (h *recvRateLimitPerPeerHandler) Handle(ctx context.Context, client *Client, envelope *p2p.Envelope) error {
	// Default to a weight of 1 so that a missing token function does not panic.
	nTokens := 1
	if h.nTokensFunc != nil {
		nTokens = int(h.nTokensFunc(envelope))
	}

	accepted, err := h.RateLimit.Limit(ctx, envelope.From, nTokens)
	if err != nil {
		return fmt.Errorf("rate limit failed for peer '%s': %w", envelope.From, err)
	}
	if !accepted {
		// Dropping is not an error from the caller's point of view.
		h.logger.Debug("silently dropping message due to rate limit", "peer", envelope.From, "envelope", envelope)
		return nil
	}

	return h.next.Handle(ctx, client, envelope)
}
2 changes: 1 addition & 1 deletion internal/p2p/client/consumer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func TestErrorLoggerP2PMessageHandler(t *testing.T) {
wantErr: "error",
},
{
mockFn: func(hd *mockConsumer, logger *log.TestingLogger) {
mockFn: func(hd *mockConsumer, _logger *log.TestingLogger) {
hd.On("Handle", mock.Anything, mock.Anything, mock.Anything).
Once().
Return(nil)
Expand Down
Loading
Loading