proxyd: Support per-RPC rate limits #3471

Merged 3 commits on Sep 15, 2022
5 changes: 5 additions & 0 deletions .changeset/three-weeks-kneel.md
@@ -0,0 +1,5 @@
---
'@eth-optimism/proxyd': minor
---

Support per-method rate limiting
27 changes: 23 additions & 4 deletions proxyd/config.go
@@ -4,6 +4,7 @@ import (
"fmt"
"os"
"strings"
"time"
)

type ServerConfig struct {
@@ -40,10 +41,28 @@ type MetricsConfig struct {
}

type RateLimitConfig struct {
RatePerSecond int `toml:"rate_per_second"`
ExemptOrigins []string `toml:"exempt_origins"`
ExemptUserAgents []string `toml:"exempt_user_agents"`
ErrorMessage string `toml:"error_message"`
RatePerSecond int `toml:"rate_per_second"`
ExemptOrigins []string `toml:"exempt_origins"`
ExemptUserAgents []string `toml:"exempt_user_agents"`
ErrorMessage string `toml:"error_message"`
MethodOverrides map[string]*RateLimitMethodOverride `toml:"method_overrides"`
}

type RateLimitMethodOverride struct {
Limit int `toml:"limit"`
Interval TOMLDuration `toml:"interval"`
}

type TOMLDuration time.Duration

func (t *TOMLDuration) UnmarshalText(b []byte) error {
d, err := time.ParseDuration(string(b))
if err != nil {
return err
}

*t = TOMLDuration(d)
return nil
}

type BackendOptions struct {
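
Note: below is a minimal, self-contained sketch of how the new `method_overrides` table and `TOMLDuration` decode. It assumes a TOML decoder that honors `encoding.TextUnmarshaler` (e.g. BurntSushi/toml); the struct shapes mirror the config.go changes above, and the sample TOML is illustrative only.

```go
// Sketch: decoding the new per-method rate-limit fields from TOML.
// Assumes github.com/BurntSushi/toml (any decoder honoring
// encoding.TextUnmarshaler behaves the same way).
package main

import (
	"fmt"
	"time"

	"github.com/BurntSushi/toml"
)

type TOMLDuration time.Duration

// UnmarshalText lets "1s", "500ms", etc. decode directly into a duration.
func (t *TOMLDuration) UnmarshalText(b []byte) error {
	d, err := time.ParseDuration(string(b))
	if err != nil {
		return err
	}
	*t = TOMLDuration(d)
	return nil
}

type RateLimitMethodOverride struct {
	Limit    int          `toml:"limit"`
	Interval TOMLDuration `toml:"interval"`
}

type RateLimitConfig struct {
	RatePerSecond   int                                 `toml:"rate_per_second"`
	MethodOverrides map[string]*RateLimitMethodOverride `toml:"method_overrides"`
}

func main() {
	raw := `
rate_per_second = 2

[method_overrides.eth_foobar]
limit = 1
interval = "1s"
`
	var cfg RateLimitConfig
	if _, err := toml.Decode(raw, &cfg); err != nil {
		panic(err)
	}
	o := cfg.MethodOverrides["eth_foobar"]
	fmt.Println(o.Limit, time.Duration(o.Interval)) // 1 1s
}
```
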
70 changes: 60 additions & 10 deletions proxyd/integration_tests/rate_limit_test.go
@@ -1,6 +1,7 @@
package integration_tests

import (
"encoding/json"
"net/http"
"os"
"testing"
@@ -17,6 +18,8 @@ type resWithCode struct {

const frontendOverLimitResponse = `{"error":{"code":-32016,"message":"over rate limit"},"id":null,"jsonrpc":"2.0"}`

var ethChainID = "eth_chainId"

func TestBackendMaxRPSLimit(t *testing.T) {
goodBackend := NewMockBackend(BatchedResponseHandler(200, goodResponse))
defer goodBackend.Close()
@@ -28,8 +31,7 @@ func TestBackendMaxRPSLimit(t *testing.T) {
shutdown, err := proxyd.Start(config)
require.NoError(t, err)
defer shutdown()

limitedRes, codes := spamReqs(t, client, 503)
limitedRes, codes := spamReqs(t, client, ethChainID, 503)
require.Equal(t, 2, codes[200])
require.Equal(t, 1, codes[503])
RequireEqualJSON(t, []byte(noBackendsResponse), limitedRes)
@@ -48,7 +50,7 @@ func TestFrontendMaxRPSLimit(t *testing.T) {

t.Run("non-exempt over limit", func(t *testing.T) {
client := NewProxydClient("http://127.0.0.1:8545")
limitedRes, codes := spamReqs(t, client, 429)
limitedRes, codes := spamReqs(t, client, ethChainID, 429)
require.Equal(t, 1, codes[429])
require.Equal(t, 2, codes[200])
RequireEqualJSON(t, []byte(frontendOverLimitResponse), limitedRes)
@@ -58,15 +60,15 @@ func TestFrontendMaxRPSLimit(t *testing.T) {
h := make(http.Header)
h.Set("User-Agent", "exempt_agent")
client := NewProxydClientWithHeaders("http://127.0.0.1:8545", h)
_, codes := spamReqs(t, client, 429)
_, codes := spamReqs(t, client, ethChainID, 429)
require.Equal(t, 3, codes[200])
})

t.Run("exempt origin over limit", func(t *testing.T) {
h := make(http.Header)
h.Set("Origin", "exempt_origin")
client := NewProxydClientWithHeaders("http://127.0.0.1:8545", h)
_, codes := spamReqs(t, client, 429)
_, codes := spamReqs(t, client, ethChainID, 429)
require.Equal(t, 3, codes[200])
})

@@ -77,24 +79,72 @@ func TestFrontendMaxRPSLimit(t *testing.T) {
h2.Set("X-Forwarded-For", "1.1.1.1")
client1 := NewProxydClientWithHeaders("http://127.0.0.1:8545", h1)
client2 := NewProxydClientWithHeaders("http://127.0.0.1:8545", h2)
_, codes := spamReqs(t, client1, 429)
_, codes := spamReqs(t, client1, ethChainID, 429)
require.Equal(t, 1, codes[429])
require.Equal(t, 2, codes[200])
_, code, err := client2.SendRPC("eth_chainId", nil)
_, code, err := client2.SendRPC(ethChainID, nil)
require.Equal(t, 200, code)
require.NoError(t, err)
time.Sleep(time.Second)
_, code, err = client2.SendRPC("eth_chainId", nil)
_, code, err = client2.SendRPC(ethChainID, nil)
require.Equal(t, 200, code)
require.NoError(t, err)
})

time.Sleep(time.Second)

t.Run("RPC override", func(t *testing.T) {
client := NewProxydClient("http://127.0.0.1:8545")
limitedRes, codes := spamReqs(t, client, "eth_foobar", 429)
// expect 2 limited and 1 successful response, since the limit for eth_foobar is 1
require.Equal(t, 2, codes[429])
require.Equal(t, 1, codes[200])
RequireEqualJSON(t, []byte(frontendOverLimitResponse), limitedRes)
})

time.Sleep(time.Second)

t.Run("RPC override in batch", func(t *testing.T) {
client := NewProxydClient("http://127.0.0.1:8545")
req := NewRPCReq("123", "eth_foobar", nil)
out, code, err := client.SendBatchRPC(req, req, req)
require.NoError(t, err)
var res []proxyd.RPCRes
require.NoError(t, json.Unmarshal(out, &res))

expCode := proxyd.ErrOverRateLimit.Code
require.Equal(t, 200, code)
require.Equal(t, 3, len(res))
require.Nil(t, res[0].Error)
require.Equal(t, expCode, res[1].Error.Code)
require.Equal(t, expCode, res[2].Error.Code)
})

time.Sleep(time.Second)

t.Run("RPC override in batch exempt", func(t *testing.T) {
h := make(http.Header)
h.Set("User-Agent", "exempt_agent")
client := NewProxydClientWithHeaders("http://127.0.0.1:8545", h)
req := NewRPCReq("123", "eth_foobar", nil)
out, code, err := client.SendBatchRPC(req, req, req)
require.NoError(t, err)
var res []proxyd.RPCRes
require.NoError(t, json.Unmarshal(out, &res))

require.Equal(t, 200, code)
require.Equal(t, 3, len(res))
require.Nil(t, res[0].Error)
require.Nil(t, res[1].Error)
require.Nil(t, res[2].Error)
})
}

func spamReqs(t *testing.T, client *ProxydHTTPClient, limCode int) ([]byte, map[int]int) {
func spamReqs(t *testing.T, client *ProxydHTTPClient, method string, limCode int) ([]byte, map[int]int) {
resCh := make(chan *resWithCode)
for i := 0; i < 3; i++ {
go func() {
res, code, err := client.SendRPC("eth_chainId", nil)
res, code, err := client.SendRPC(method, nil)
require.NoError(t, err)
resCh <- &resWithCode{
code: code,
5 changes: 5 additions & 0 deletions proxyd/integration_tests/testdata/frontend_rate_limit.toml
@@ -15,9 +15,14 @@ backends = ["good"]

[rpc_method_mappings]
eth_chainId = "main"
eth_foobar = "main"

[rate_limit]
rate_per_second = 2
exempt_origins = ["exempt_origin"]
exempt_user_agents = ["exempt_agent"]
error_message = "over rate limit"

[rate_limit.method_overrides.eth_foobar]
limit = 1
interval = "1s"
87 changes: 68 additions & 19 deletions proxyd/server.go
@@ -49,7 +49,8 @@ type Server struct {
timeout time.Duration
maxUpstreamBatchSize int
upgrader *websocket.Upgrader
lim limiter.Store
mainLim limiter.Store
overrideLims map[string]limiter.Store
limConfig RateLimitConfig
limExemptOrigins map[string]bool
limExemptUserAgents map[string]bool
@@ -59,6 +60,8 @@
srvMu sync.Mutex
}

type limiterFunc func(method string) bool

func NewServer(
backendGroups map[string]*BackendGroup,
wsBackendGroup *BackendGroup,
@@ -89,12 +92,12 @@ func NewServer(
maxUpstreamBatchSize = defaultMaxUpstreamBatchSize
}

var lim limiter.Store
var mainLim limiter.Store
limExemptOrigins := make(map[string]bool)
limExemptUserAgents := make(map[string]bool)
if rateLimitConfig.RatePerSecond > 0 {
var err error
lim, err = memorystore.New(&memorystore.Config{
mainLim, err = memorystore.New(&memorystore.Config{
Tokens: uint64(rateLimitConfig.RatePerSecond),
Interval: time.Second,
})
@@ -109,7 +112,19 @@
limExemptUserAgents[strings.ToLower(agent)] = true
}
} else {
lim, _ = noopstore.New()
mainLim, _ = noopstore.New()
}

overrideLims := make(map[string]limiter.Store)
for method, override := range rateLimitConfig.MethodOverrides {
var err error
overrideLims[method], err = memorystore.New(&memorystore.Config{
Tokens: uint64(override.Limit),
Interval: time.Duration(override.Interval),
})
if err != nil {
return nil, err
}
}

return &Server{
@@ -127,7 +142,8 @@
upgrader: &websocket.Upgrader{
HandshakeTimeout: 5 * time.Second,
},
lim: lim,
mainLim: mainLim,
overrideLims: overrideLims,
limConfig: rateLimitConfig,
limExemptOrigins: limExemptOrigins,
limExemptUserAgents: limExemptUserAgents,
@@ -197,22 +213,37 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {

origin := r.Header.Get("Origin")
userAgent := r.Header.Get("User-Agent")
exemptOrigin := s.limExemptOrigins[strings.ToLower(origin)]
exemptUserAgent := s.limExemptUserAgents[strings.ToLower(userAgent)]
// Use XFF in context since it will automatically be replaced by the remote IP
xff := stripXFF(GetXForwardedFor(ctx))
var ok bool
if exemptOrigin || exemptUserAgent {
ok = true
} else {
if xff == "" {
log.Warn("rejecting request without XFF or remote IP")
ok = false
isUnlimitedOrigin := s.isUnlimitedOrigin(origin)
isUnlimitedUserAgent := s.isUnlimitedUserAgent(userAgent)

if xff == "" {
writeRPCError(ctx, w, nil, ErrInvalidRequest("request does not include a remote IP"))
return
}

isLimited := func(method string) bool {
if isUnlimitedOrigin || isUnlimitedUserAgent {
return false
}

var lim limiter.Store
if method == "" {
lim = s.mainLim
} else {
_, _, _, ok, _ = s.lim.Take(ctx, xff)
lim = s.overrideLims[method]
}

if lim == nil {
return false
}

_, _, _, ok, _ := lim.Take(ctx, xff)
return !ok
}
if !ok {

if isLimited("") {
rpcErr := ErrOverRateLimit.Clone()
rpcErr.Message = s.limConfig.ErrorMessage
RecordRPCError(ctx, BackendProxyd, "unknown", rpcErr)
@@ -271,7 +302,7 @@ func (s *Server) HandleRPC(w http.ResponseWriter, r *http.Request) {
return
}

batchRes, batchContainsCached, err := s.handleBatchRPC(ctx, reqs, true)
batchRes, batchContainsCached, err := s.handleBatchRPC(ctx, reqs, isLimited, true)
if err == context.DeadlineExceeded {
writeRPCError(ctx, w, nil, ErrGatewayTimeout)
return
@@ -287,7 +318,7 @@
}

rawBody := json.RawMessage(body)
backendRes, cached, err := s.handleBatchRPC(ctx, []json.RawMessage{rawBody}, false)
backendRes, cached, err := s.handleBatchRPC(ctx, []json.RawMessage{rawBody}, isLimited, false)
if err != nil {
writeRPCError(ctx, w, nil, ErrInternal)
return
@@ -296,7 +327,7 @@
writeRPCRes(ctx, w, backendRes[0])
}

func (s *Server) handleBatchRPC(ctx context.Context, reqs []json.RawMessage, isBatch bool) ([]*RPCRes, bool, error) {
func (s *Server) handleBatchRPC(ctx context.Context, reqs []json.RawMessage, isLimited limiterFunc, isBatch bool) ([]*RPCRes, bool, error) {
// A request set is transformed into groups of batches.
// Each batch group maps to a forwarded JSON-RPC batch request (subject to maxUpstreamBatchSize constraints)
// A groupID is used to decouple Requests that have duplicate ID so they're not part of the same batch that's
@@ -347,6 +378,16 @@ func (s *Server) handleBatchRPC(ctx context.Context, reqs []json.RawMessage, isB
continue
}

// Take rate limit for specific methods.
// NOTE: eventually, this should apply to all batch requests. However,
// since we don't have data right now on the size of each batch, we
// only apply this to the methods that have an additional rate limit.
if _, ok := s.overrideLims[parsedReq.Method]; ok && isLimited(parsedReq.Method) {
RecordRPCError(ctx, BackendProxyd, parsedReq.Method, ErrOverRateLimit)
responses[i] = NewRPCErrorRes(parsedReq.ID, ErrOverRateLimit)
continue
}

id := string(parsedReq.ID)
// If this is a duplicate Request ID, move the Request to a new batchGroup
ids[id]++
@@ -494,6 +535,14 @@ func (s *Server) populateContext(w http.ResponseWriter, r *http.Request) context
)
}

func (s *Server) isUnlimitedOrigin(origin string) bool {
return s.limExemptOrigins[strings.ToLower(origin)]
}

func (s *Server) isUnlimitedUserAgent(userAgent string) bool {
return s.limExemptUserAgents[strings.ToLower(userAgent)]
}

func setCacheHeader(w http.ResponseWriter, cached bool) {
if cached {
w.Header().Set(cacheStatusHdr, "HIT")
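
Note: a minimal sketch of the per-method, per-XFF dispatch that the new `isLimited` closure performs. It assumes `github.com/sethvargo/go-limiter` and its `memorystore` backend (the store types used in server.go above); the method names and XFF keys are illustrative only. Each overridden method gets its own store, and keys are budgeted independently within it, so one caller exhausting `eth_foobar` does not affect another XFF.

```go
// Sketch of the per-method, per-XFF rate-limit dispatch used above.
// Assumes github.com/sethvargo/go-limiter and its memorystore backend;
// errors are ignored for brevity.
package main

import (
	"context"
	"fmt"
	"time"

	"github.com/sethvargo/go-limiter"
	"github.com/sethvargo/go-limiter/memorystore"
)

func main() {
	ctx := context.Background()

	// Main limiter: 2 requests per second per XFF key.
	mainLim, _ := memorystore.New(&memorystore.Config{Tokens: 2, Interval: time.Second})

	// Override limiter for eth_foobar: 1 request per second per XFF key.
	overrideLims := map[string]limiter.Store{}
	overrideLims["eth_foobar"], _ = memorystore.New(&memorystore.Config{Tokens: 1, Interval: time.Second})

	// Mirrors the closure in HandleRPC: "" selects the main limiter, a method
	// name selects its override; methods without an override are not limited here.
	isLimited := func(method, xff string) bool {
		var lim limiter.Store
		if method == "" {
			lim = mainLim
		} else {
			lim = overrideLims[method]
		}
		if lim == nil {
			return false
		}
		_, _, _, ok, _ := lim.Take(ctx, xff)
		return !ok
	}

	fmt.Println(isLimited("eth_foobar", "1.1.1.1")) // false: first request passes
	fmt.Println(isLimited("eth_foobar", "1.1.1.1")) // true: budget of 1 exhausted for this key
	fmt.Println(isLimited("eth_foobar", "2.2.2.2")) // false: a different XFF key has its own budget
}
```
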