Skip to content

Commit

Permalink
Add additional High Rate Limiting
Browse files Browse the repository at this point in the history
This new `High` rate limit is designed for endpoints which are only CPU bound (and thus don't have as significant of DoS risks).
Initially this was motivated for `ping` and `find` due to the concern that these endpoints are used unauthenticated at login, and potential NAT's may result in very high rates from single egress IP's.
In my testing on my laptop, all of these endpoints can easily get 640/req/sec on a single core within a VM.  Setting the maximum of 480 burst and 120 continuous should both ensure that no single source utilizes all the CPU, as well as build in additional safety margins while providing a layer of protection.
  • Loading branch information
jentfoo committed Apr 27, 2023
1 parent 47ac329 commit 3d31f86
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 11 deletions.
10 changes: 10 additions & 0 deletions lib/defaults/defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,16 @@ const (
LimiterBurst = 40
)

// Default high rate limits for unauthenticated endpoints that are CPU constrained.
const (
// LimiterHighPeriod is the default period for high rate unauthenticated limiters.
LimiterHighPeriod = 1 * time.Minute
// LimiterHighAverage is the default average for high rate unauthenticated limiters.
LimiterHighAverage = 120
// LimiterHighBurst is the default burst for high rate unauthenticated limiters.
LimiterHighBurst = 480
)

const (
// HostCertCacheSize is the number of host certificates to cache at any moment.
HostCertCacheSize = 4000
Expand Down
69 changes: 58 additions & 11 deletions lib/web/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ type Handler struct {
sessionStreamPollPeriod time.Duration
clock clockwork.Clock
limiter *limiter.RateLimiter
highLimiter *limiter.RateLimiter
healthCheckAppServer healthCheckAppServerFunc
// sshPort specifies the SSH proxy port extracted
// from configuration
Expand Down Expand Up @@ -345,6 +346,18 @@ func NewHandler(cfg Config, opts ...HandlerOption) (*APIHandler, error) {
MaxConnections: defaults.LimiterMaxConnections,
MaxNumberOfUsers: defaults.LimiterMaxConcurrentUsers,
})
// highLimiter is used for endpoints which are only CPU constrained and require high request rates
h.highLimiter, err = limiter.NewRateLimiter(limiter.Config{
Rates: []limiter.Rate{
{
Period: defaults.LimiterHighPeriod,
Average: defaults.LimiterHighAverage,
Burst: defaults.LimiterHighBurst,
},
},
MaxConnections: defaults.LimiterMaxConnections,
MaxNumberOfUsers: defaults.LimiterMaxConcurrentUsers,
})
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down Expand Up @@ -509,9 +522,9 @@ func (h *Handler) bindMinimalEndpoints() {
// find is like ping, but is faster because it is optimized for servers
// and does not fetch the data that servers don't need, e.g.
// OIDC connectors and auth preferences
h.GET("/webapi/find", h.WithUnauthenticatedLimiter(h.find))
h.GET("/webapi/find", h.WithUnauthenticatedHighLimiter(h.find))
// Issue host credentials.
h.POST("/webapi/host/credentials", h.WithUnauthenticatedLimiter(h.hostCredentials))
h.POST("/webapi/host/credentials", h.WithUnauthenticatedHighLimiter(h.hostCredentials))
}

// bindDefaultEndpoints binds the default endpoints for the web API.
Expand All @@ -522,14 +535,14 @@ func (h *Handler) bindDefaultEndpoints() {
// endpoint returns the default authentication method and configuration that
// the server supports. the /webapi/ping/:connector endpoint can be used to
// query the authentication configuration for a specific connector.
h.GET("/webapi/ping", h.WithUnauthenticatedLimiter(h.ping))
h.GET("/webapi/ping/:connector", h.WithUnauthenticatedLimiter(h.pingWithConnector))
h.GET("/webapi/ping", h.WithUnauthenticatedHighLimiter(h.ping))
h.GET("/webapi/ping/:connector", h.WithUnauthenticatedHighLimiter(h.pingWithConnector))

// Unauthenticated access to JWT public keys.
h.GET("/.well-known/jwks.json", h.WithUnauthenticatedLimiter(h.jwks))
h.GET("/.well-known/jwks.json", h.WithUnauthenticatedHighLimiter(h.jwks))

// Unauthenticated access to the message of the day
h.GET("/webapi/motd", h.WithLimiter(h.motd))
h.GET("/webapi/motd", h.WithHighLimiter(h.motd))

// Unauthenticated access to retrieving the script used to install
// Teleport
Expand Down Expand Up @@ -3651,12 +3664,24 @@ func (h *Handler) WithAuthCookieAndCSRF(fn ContextHandler) httprouter.Handle {
// This is a good default to use as both Cluster and User auth are checked here, but `WithLimiter` can be used if
// you're certain that no authenticated requests will be made.
func (h *Handler) WithUnauthenticatedLimiter(fn httplib.HandlerFunc) httprouter.Handle {
return h.unauthenticatedLimiterFunc(fn, h.WithLimiterHandlerFunc)
}

// WithUnauthenticatedHighLimiter adds a conditional IP-based rate limiting that will limit only unauthenticated
// requests. This is similar to WithUnauthenticatedLimiter, however this one allows a much higher rate limit.
// This higher rate limit should only be used on endpoints which are only CPU constrained
// (no file or other resources used).
func (h *Handler) WithUnauthenticatedHighLimiter(fn httplib.HandlerFunc) httprouter.Handle {
return h.unauthenticatedLimiterFunc(fn, h.WithHighLimiterHandlerFunc)
}

func (h *Handler) unauthenticatedLimiterFunc(fn httplib.HandlerFunc, rateFunc func(fn httplib.HandlerFunc) httplib.HandlerFunc) httprouter.Handle {
return httplib.MakeHandler(func(w http.ResponseWriter, r *http.Request, p httprouter.Params) (interface{}, error) {
if _, _, err := h.authenticateRequestWithCluster(w, r, p); err != nil {
// retry with user auth
if _, err = h.AuthenticateRequest(w, r, true /* check token */); err != nil {
// no auth passed, limit request
return h.WithLimiterHandlerFunc(fn)(w, r, p)
return rateFunc(fn)(w, r, p)
}
}
// auth passed, call directly
Expand All @@ -3672,25 +3697,47 @@ func (h *Handler) WithLimiter(fn httplib.HandlerFunc) httprouter.Handle {
})
}

// WithHighLimiter adds high rate IP-based rate limiting to fn.
// This should only be used on functions which are CPU constrained, and don't use disk or other services.
// Limits are applied to all requests, authenticated or not.
func (h *Handler) WithHighLimiter(fn httplib.HandlerFunc) httprouter.Handle {
return httplib.MakeHandler(func(w http.ResponseWriter, r *http.Request, p httprouter.Params) (interface{}, error) {
return h.WithHighLimiterHandlerFunc(fn)(w, r, p)
})
}

// WithLimiterHandlerFunc adds IP-based rate limiting to a HandlerFunc. This
// should be used when you need to nest this inside another HandlerFunc.
func (h *Handler) WithLimiterHandlerFunc(fn httplib.HandlerFunc) httplib.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request, p httprouter.Params) (interface{}, error) {
err := h.rateLimitRequest(r)
err := rateLimitRequest(r, h.limiter)
if err != nil {
return nil, trace.Wrap(err)
}
return fn(w, r, p)
}
}

// WithHighLimiterHandlerFunc adds IP-based rate limiting to a HandlerFunc. This is similar to WithLimiterHandlerFunc
// but provides a higher rate limit. This should only be used for requests which are only CPU bound (no disk or other
// resources used).
func (h *Handler) WithHighLimiterHandlerFunc(fn httplib.HandlerFunc) httplib.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request, p httprouter.Params) (interface{}, error) {
err := rateLimitRequest(r, h.highLimiter)
if err != nil {
return nil, trace.Wrap(err)
}
return fn(w, r, p)
}
}

func (h *Handler) rateLimitRequest(r *http.Request) error {
func rateLimitRequest(r *http.Request, limiter *limiter.RateLimiter) error {
remote, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return trace.Wrap(err)
}

err = h.limiter.RegisterRequest(remote, nil /* customRate */)
err = limiter.RegisterRequest(remote, nil /* customRate */)
// MaxRateError doesn't play well with errors.Is, hence the type assertion.
if _, ok := err.(*ratelimit.MaxRateError); ok {
return trace.LimitExceeded(err.Error())
Expand Down Expand Up @@ -3934,7 +3981,7 @@ func SSOSetWebSessionAndRedirectURL(w http.ResponseWriter, r *http.Request, resp
// GET /webapi/sites/:site/auth/export?type=<auth type>
// GET /webapi/auth/export?type=<auth type>
func (h *Handler) authExportPublic(w http.ResponseWriter, r *http.Request, p httprouter.Params) {
err := h.rateLimitRequest(r)
err := rateLimitRequest(r, h.limiter)
if err != nil {
http.Error(w, err.Error(), trace.ErrorToCode(err))
return
Expand Down

0 comments on commit 3d31f86

Please sign in to comment.