From d16a3b18f1d73274b8101854355aefffad8d4d42 Mon Sep 17 00:00:00 2001 From: Mike Jensen Date: Mon, 1 May 2023 15:54:22 -0600 Subject: [PATCH] Rate limit all unauthenticated HTTP endpoints (#24623) * Rate limit all unauthenticated HTTP endpoints This commit is an extension to what was done in #172. And is designed to fix https://github.com/gravitational/teleport/issues/4330 and https://github.com/gravitational/teleport-private/issues/403. Rather than audit endpoints and choose what endpoints should be rate limited, this commit proposes that for safety and reduced cognitive load, all unauthenticated endpoints become rate limited. The primary concern in this type of change would be if our rate limit becomes too aggressive for general use. There are two considered strategies to make sure this does not become impacting: 1. Adjust the rate limiter so the rate limit becomes endpoint specific. This would avoid the need to consider how activity on one endpoint effects another. 2. Accept that rate limit interactions are possible and instead ensure rate limits are high enough to avoid this concern. This commit chooses option #2. While #1 has advantages, particularly as endpoints and new use cases are added. #2 provides the strictest and safest rate limits. Our rate limits were configured to: period: 1 min avg rate: 10 burst rate: 20 In order to build a safety buffer with option #2 those allowed rates were doubled. Additionally the ability to avoid rate limits by authenticating your request (even if the endpoint is otherwise unauthenticated) was added. This is particularly useful for the `ping` endpoint which may have high levels of activity on large clusters, but which has a portion of that activity over authenticated requests. * Add additional `High` Rate Limiting This new `High` rate limit is designed for endpoints which are only CPU bound (and thus don't have as significant of DoS risks). Initially this was motivated for `ping` and `find` due to the concern that these endpoints are used unauthenticated at login, and potential NAT's may result in very high rates from single egress IP's. In my testing on my laptop, all of these endpoints can easily get 640/req/sec on a single core within a VM. Setting the maximum of 480 burst and 120 continuous should both ensure that no single source utilizes all the CPU, as well as build in additional safety margins while providing a layer of protection. * Fix for missing error check --- lib/auth/grpcserver_test.go | 2 +- lib/auth/middleware.go | 6 +- lib/defaults/defaults.go | 25 ++++-- lib/web/apiserver.go | 147 ++++++++++++++++++++++++-------- lib/web/apiserver_login_test.go | 2 +- 5 files changed, 135 insertions(+), 47 deletions(-) diff --git a/lib/auth/grpcserver_test.go b/lib/auth/grpcserver_test.go index 7c808ece7ca88..d9c984761acfb 100644 --- a/lib/auth/grpcserver_test.go +++ b/lib/auth/grpcserver_test.go @@ -3167,7 +3167,7 @@ func TestCustomRateLimiting(t *testing.T) { }, { name: "RPC CreateAuthenticateChallenge", - burst: defaults.LimiterPasswordlessBurst, + burst: defaults.LimiterBurst, fn: func(clt *Client) error { _, err := clt.CreateAuthenticateChallenge(ctx, &proto.CreateAuthenticateChallengeRequest{}) return err diff --git a/lib/auth/middleware.go b/lib/auth/middleware.go index 0cbf38b6a9966..3c7ef6c985c99 100644 --- a/lib/auth/middleware.go +++ b/lib/auth/middleware.go @@ -383,9 +383,9 @@ func getCustomRate(endpoint string) *ratelimit.RateSet { return rates // Passwordless RPCs (potential unauthenticated challenge generation). case "/proto.AuthService/CreateAuthenticateChallenge": - const period = defaults.LimiterPasswordlessPeriod - const average = defaults.LimiterPasswordlessAverage - const burst = defaults.LimiterPasswordlessBurst + const period = defaults.LimiterPeriod + const average = defaults.LimiterAverage + const burst = defaults.LimiterBurst rates := ratelimit.NewRateSet() if err := rates.Add(period, average, burst); err != nil { log.WithError(err).Debugf("Failed to define a custom rate for rpc method %q, using default rate", endpoint) diff --git a/lib/defaults/defaults.go b/lib/defaults/defaults.go index e331bdcdaef40..0e73470ff7f58 100644 --- a/lib/defaults/defaults.go +++ b/lib/defaults/defaults.go @@ -362,15 +362,24 @@ const ( LimiterMaxConcurrentSignatures = 10 ) -// Default rate limits for unauthenticated passwordless endpoints. +// Default rate limits for unauthenticated endpoints. const ( - // LimiterPasswordlessPeriod is the default period for passwordless limiters. - LimiterPasswordlessPeriod = 1 * time.Minute - // LimiterPasswordlessAverage is the default average for passwordless - // limiters. - LimiterPasswordlessAverage = 10 - // LimiterPasswordlessBurst is the default burst for passwordless limiters. - LimiterPasswordlessBurst = 20 + // LimiterPeriod is the default period for unauthenticated limiters. + LimiterPeriod = 1 * time.Minute + // LimiterAverage is the default average for unauthenticated limiters. + LimiterAverage = 20 + // LimiterBurst is the default burst for unauthenticated limiters. + LimiterBurst = 40 +) + +// Default high rate limits for unauthenticated endpoints that are CPU constrained. +const ( + // LimiterHighPeriod is the default period for high rate unauthenticated limiters. + LimiterHighPeriod = 1 * time.Minute + // LimiterHighAverage is the default average for high rate unauthenticated limiters. + LimiterHighAverage = 120 + // LimiterHighBurst is the default burst for high rate unauthenticated limiters. + LimiterHighBurst = 480 ) const ( diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index e3b3bfb14b8ee..1e8d4f8508b14 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -109,6 +109,7 @@ type Handler struct { sessionStreamPollPeriod time.Duration clock clockwork.Clock limiter *limiter.RateLimiter + highLimiter *limiter.RateLimiter healthCheckAppServer healthCheckAppServerFunc // sshPort specifies the SSH proxy port extracted // from configuration @@ -337,9 +338,24 @@ func NewHandler(cfg Config, opts ...HandlerOption) (*APIHandler, error) { h.limiter, err = limiter.NewRateLimiter(limiter.Config{ Rates: []limiter.Rate{ { - Period: defaults.LimiterPasswordlessPeriod, - Average: defaults.LimiterPasswordlessAverage, - Burst: defaults.LimiterPasswordlessBurst, + Period: defaults.LimiterPeriod, + Average: defaults.LimiterAverage, + Burst: defaults.LimiterBurst, + }, + }, + MaxConnections: defaults.LimiterMaxConnections, + MaxNumberOfUsers: defaults.LimiterMaxConcurrentUsers, + }) + if err != nil { + return nil, trace.Wrap(err) + } + // highLimiter is used for endpoints which are only CPU constrained and require high request rates + h.highLimiter, err = limiter.NewRateLimiter(limiter.Config{ + Rates: []limiter.Rate{ + { + Period: defaults.LimiterHighPeriod, + Average: defaults.LimiterHighAverage, + Burst: defaults.LimiterHighBurst, }, }, MaxConnections: defaults.LimiterMaxConnections, @@ -373,7 +389,7 @@ func NewHandler(cfg Config, opts ...HandlerOption) (*APIHandler, error) { return nil, trace.BadParameter("failed parsing index.html template: %v", err) } - h.Handle("GET", "/web/config.js", httplib.MakeHandler(h.getWebConfig)) + h.Handle("GET", "/web/config.js", h.WithUnauthenticatedLimiter(h.getWebConfig)) } routingHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -509,9 +525,9 @@ func (h *Handler) bindMinimalEndpoints() { // find is like ping, but is faster because it is optimized for servers // and does not fetch the data that servers don't need, e.g. // OIDC connectors and auth preferences - h.GET("/webapi/find", httplib.MakeHandler(h.find)) + h.GET("/webapi/find", h.WithUnauthenticatedHighLimiter(h.find)) // Issue host credentials. - h.POST("/webapi/host/credentials", httplib.MakeHandler(h.hostCredentials)) + h.POST("/webapi/host/credentials", h.WithUnauthenticatedHighLimiter(h.hostCredentials)) } // bindDefaultEndpoints binds the default endpoints for the web API. @@ -522,23 +538,23 @@ func (h *Handler) bindDefaultEndpoints() { // endpoint returns the default authentication method and configuration that // the server supports. the /webapi/ping/:connector endpoint can be used to // query the authentication configuration for a specific connector. - h.GET("/webapi/ping", httplib.MakeHandler(h.ping)) - h.GET("/webapi/ping/:connector", httplib.MakeHandler(h.pingWithConnector)) + h.GET("/webapi/ping", h.WithUnauthenticatedHighLimiter(h.ping)) + h.GET("/webapi/ping/:connector", h.WithUnauthenticatedHighLimiter(h.pingWithConnector)) // Unauthenticated access to JWT public keys. - h.GET("/.well-known/jwks.json", httplib.MakeHandler(h.jwks)) + h.GET("/.well-known/jwks.json", h.WithUnauthenticatedHighLimiter(h.jwks)) // Unauthenticated access to the message of the day - h.GET("/webapi/motd", httplib.MakeHandler(h.motd)) + h.GET("/webapi/motd", h.WithHighLimiter(h.motd)) // Unauthenticated access to retrieving the script used to install // Teleport - h.GET("/webapi/scripts/installer/:name", httplib.MakeHandler(h.installer)) + h.GET("/webapi/scripts/installer/:name", h.WithLimiter(h.installer)) // desktop access configuration scripts - h.GET("/webapi/scripts/desktop-access/install-ad-ds.ps1", httplib.MakeHandler(h.desktopAccessScriptInstallADDSHandle)) - h.GET("/webapi/scripts/desktop-access/install-ad-cs.ps1", httplib.MakeHandler(h.desktopAccessScriptInstallADCSHandle)) - h.GET("/webapi/scripts/desktop-access/configure/:token/configure-ad.ps1", httplib.MakeHandler(h.desktopAccessScriptConfigureHandle)) + h.GET("/webapi/scripts/desktop-access/install-ad-ds.ps1", h.WithLimiter(h.desktopAccessScriptInstallADDSHandle)) + h.GET("/webapi/scripts/desktop-access/install-ad-cs.ps1", h.WithLimiter(h.desktopAccessScriptInstallADCSHandle)) + h.GET("/webapi/scripts/desktop-access/configure/:token/configure-ad.ps1", h.WithLimiter(h.desktopAccessScriptConfigureHandle)) // Forwards traces to the configured upstream collector h.POST("/webapi/traces", h.WithAuth(h.traces)) @@ -557,7 +573,7 @@ func (h *Handler) bindDefaultEndpoints() { // We have an overlap route here, please see godoc of handleGetUserOrResetToken // h.GET("/webapi/users/:username", h.WithAuth(h.getUserHandle)) - // h.GET("/webapi/users/password/token/:token", httplib.MakeHandler(h.getResetPasswordTokenHandle)) + // h.GET("/webapi/users/password/token/:token", h.WithLimiter(h.getResetPasswordTokenHandle)) h.GET("/webapi/users/*wildcard", h.handleGetUserOrResetToken) h.PUT("/webapi/users/password/token", httplib.WithCSRFProtection(h.changeUserAuthentication)) @@ -566,7 +582,7 @@ func (h *Handler) bindDefaultEndpoints() { h.POST("/webapi/users/privilege/token", h.WithAuth(h.createPrivilegeTokenHandle)) // Issues SSH temp certificates based on 2FA access creds - h.POST("/webapi/ssh/certs", h.WithLimiter(h.createSSHCert)) + h.POST("/webapi/ssh/certs", h.WithUnauthenticatedLimiter(h.createSSHCert)) // list available sites h.GET("/webapi/sites", h.WithAuth(h.getClusters)) @@ -618,9 +634,9 @@ func (h *Handler) bindDefaultEndpoints() { h.POST("/webapi/token", h.WithAuth(h.createTokenHandle)) // join scripts - h.GET("/scripts/:token/install-node.sh", httplib.MakeHandler(h.getNodeJoinScriptHandle)) - h.GET("/scripts/:token/install-app.sh", httplib.MakeHandler(h.getAppJoinScriptHandle)) - h.GET("/scripts/:token/install-database.sh", httplib.MakeHandler(h.getDatabaseJoinScriptHandle)) + h.GET("/scripts/:token/install-node.sh", h.WithLimiter(h.getNodeJoinScriptHandle)) + h.GET("/scripts/:token/install-app.sh", h.WithLimiter(h.getAppJoinScriptHandle)) + h.GET("/scripts/:token/install-database.sh", h.WithLimiter(h.getDatabaseJoinScriptHandle)) // web context h.GET("/webapi/sites/:site/context", h.WithClusterAuth(h.getUserContext)) h.GET("/webapi/sites/:site/resources/check", h.WithClusterAuth(h.checkAccessToRegisteredResource)) @@ -647,12 +663,12 @@ func (h *Handler) bindDefaultEndpoints() { // MFA public endpoints. h.POST("/webapi/sites/:site/mfa/required", h.WithClusterAuth(h.isMFARequired)) h.POST("/webapi/mfa/login/begin", h.WithLimiter(h.mfaLoginBegin)) - h.POST("/webapi/mfa/login/finish", httplib.MakeHandler(h.mfaLoginFinish)) - h.POST("/webapi/mfa/login/finishsession", httplib.MakeHandler(h.mfaLoginFinishSession)) - h.DELETE("/webapi/mfa/token/:token/devices/:devicename", httplib.MakeHandler(h.deleteMFADeviceWithTokenHandle)) - h.GET("/webapi/mfa/token/:token/devices", httplib.MakeHandler(h.getMFADevicesWithTokenHandle)) - h.POST("/webapi/mfa/token/:token/authenticatechallenge", httplib.MakeHandler(h.createAuthenticateChallengeWithTokenHandle)) - h.POST("/webapi/mfa/token/:token/registerchallenge", httplib.MakeHandler(h.createRegisterChallengeWithTokenHandle)) + h.POST("/webapi/mfa/login/finish", h.WithLimiter(h.mfaLoginFinish)) + h.POST("/webapi/mfa/login/finishsession", h.WithLimiter(h.mfaLoginFinishSession)) + h.DELETE("/webapi/mfa/token/:token/devices/:devicename", h.WithLimiter(h.deleteMFADeviceWithTokenHandle)) + h.GET("/webapi/mfa/token/:token/devices", h.WithLimiter(h.getMFADevicesWithTokenHandle)) + h.POST("/webapi/mfa/token/:token/authenticatechallenge", h.WithLimiter(h.createAuthenticateChallengeWithTokenHandle)) + h.POST("/webapi/mfa/token/:token/registerchallenge", h.WithLimiter(h.createRegisterChallengeWithTokenHandle)) // MFA private endpoints. h.GET("/webapi/mfa/devices", h.WithAuth(h.getMFADevicesHandle)) @@ -661,7 +677,7 @@ func (h *Handler) bindDefaultEndpoints() { h.POST("/webapi/mfa/authenticatechallenge/password", h.WithAuth(h.createAuthenticateChallengeWithPassword)) // trusted clusters - h.POST("/webapi/trustedclusters/validate", httplib.MakeHandler(h.validateTrustedCluster)) + h.POST("/webapi/trustedclusters/validate", h.WithUnauthenticatedLimiter(h.validateTrustedCluster)) // User Status (used by client to check if user session is valid) h.GET("/webapi/user/status", h.WithAuth(h.getUserStatus)) @@ -715,10 +731,10 @@ func (h *Handler) bindDefaultEndpoints() { h.GET("/webapi/thumbprint", h.WithLimiter(h.thumbprint)) // Connection upgrades. - h.GET("/webapi/connectionupgrade", httplib.MakeHandler(h.connectionUpgrade)) + h.GET("/webapi/connectionupgrade", h.WithLimiter(h.connectionUpgrade)) // create user events. - h.POST("/webapi/precapture", h.WithLimiter(h.createPreUserEventHandle)) + h.POST("/webapi/precapture", h.WithUnauthenticatedLimiter(h.createPreUserEventHandle)) // create authenticated user events. h.POST("/webapi/capture", h.WithAuth(h.createUserEventHandle)) @@ -3537,26 +3553,70 @@ func (h *Handler) WithAuthCookieAndCSRF(fn ContextHandler) httprouter.Handle { return httplib.WithCSRFProtection(f) } +// WithUnauthenticatedLimiter adds a conditional IP-based rate limiting that will limit only unauthenticated requests. +// This is a good default to use as both Cluster and User auth are checked here, but `WithLimiter` can be used if +// you're certain that no authenticated requests will be made. +func (h *Handler) WithUnauthenticatedLimiter(fn httplib.HandlerFunc) httprouter.Handle { + return h.unauthenticatedLimiterFunc(fn, h.WithLimiterHandlerFunc) +} + +// WithUnauthenticatedHighLimiter adds a conditional IP-based rate limiting that will limit only unauthenticated +// requests. This is similar to WithUnauthenticatedLimiter, however this one allows a much higher rate limit. +// This higher rate limit should only be used on endpoints which are only CPU constrained +// (no file or other resources used). +func (h *Handler) WithUnauthenticatedHighLimiter(fn httplib.HandlerFunc) httprouter.Handle { + return h.unauthenticatedLimiterFunc(fn, h.WithHighLimiterHandlerFunc) +} + +func (h *Handler) unauthenticatedLimiterFunc(fn httplib.HandlerFunc, rateFunc func(fn httplib.HandlerFunc) httplib.HandlerFunc) httprouter.Handle { + return httplib.MakeHandler(func(w http.ResponseWriter, r *http.Request, p httprouter.Params) (interface{}, error) { + if _, _, err := h.authenticateRequestWithCluster(w, r, p); err != nil { + // retry with user auth + if _, err = h.AuthenticateRequest(w, r, true /* check token */); err != nil { + // no auth passed, limit request + return rateFunc(fn)(w, r, p) + } + } + // auth passed, call directly + return fn(w, r, p) + }) +} + // WithLimiter adds IP-based rate limiting to fn. +// Limits are applied to all requests, authenticated or not. func (h *Handler) WithLimiter(fn httplib.HandlerFunc) httprouter.Handle { return httplib.MakeHandler(func(w http.ResponseWriter, r *http.Request, p httprouter.Params) (interface{}, error) { return h.WithLimiterHandlerFunc(fn)(w, r, p) }) } +// WithHighLimiter adds high rate IP-based rate limiting to fn. +// This should only be used on functions which are CPU constrained, and don't use disk or other services. +// Limits are applied to all requests, authenticated or not. +func (h *Handler) WithHighLimiter(fn httplib.HandlerFunc) httprouter.Handle { + return httplib.MakeHandler(func(w http.ResponseWriter, r *http.Request, p httprouter.Params) (interface{}, error) { + return h.WithHighLimiterHandlerFunc(fn)(w, r, p) + }) +} + // WithLimiterHandlerFunc adds IP-based rate limiting to a HandlerFunc. This // should be used when you need to nest this inside another HandlerFunc. func (h *Handler) WithLimiterHandlerFunc(fn httplib.HandlerFunc) httplib.HandlerFunc { return func(w http.ResponseWriter, r *http.Request, p httprouter.Params) (interface{}, error) { - remote, _, err := net.SplitHostPort(r.RemoteAddr) + err := rateLimitRequest(r, h.limiter) if err != nil { return nil, trace.Wrap(err) } - err = h.limiter.RegisterRequest(remote, nil /* customRate */) - // MaxRateError doesn't play well with errors.Is, hence the cast. - if _, ok := err.(*ratelimit.MaxRateError); ok { - return nil, trace.LimitExceeded(err.Error()) - } + return fn(w, r, p) + } +} + +// WithHighLimiterHandlerFunc adds IP-based rate limiting to a HandlerFunc. This is similar to WithLimiterHandlerFunc +// but provides a higher rate limit. This should only be used for requests which are only CPU bound (no disk or other +// resources used). +func (h *Handler) WithHighLimiterHandlerFunc(fn httplib.HandlerFunc) httplib.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request, p httprouter.Params) (interface{}, error) { + err := rateLimitRequest(r, h.highLimiter) if err != nil { return nil, trace.Wrap(err) } @@ -3564,6 +3624,20 @@ func (h *Handler) WithLimiterHandlerFunc(fn httplib.HandlerFunc) httplib.Handler } } +func rateLimitRequest(r *http.Request, limiter *limiter.RateLimiter) error { + remote, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + return trace.Wrap(err) + } + + err = limiter.RegisterRequest(remote, nil /* customRate */) + // MaxRateError doesn't play well with errors.Is, hence the type assertion. + if _, ok := err.(*ratelimit.MaxRateError); ok { + return trace.LimitExceeded(err.Error()) + } + return trace.Wrap(err) +} + // AuthenticateRequest authenticates request using combination of a session cookie // and bearer token func (h *Handler) AuthenticateRequest(w http.ResponseWriter, r *http.Request, checkBearerToken bool) (*SessionContext, error) { @@ -3800,6 +3874,11 @@ func SSOSetWebSessionAndRedirectURL(w http.ResponseWriter, r *http.Request, resp // GET /webapi/sites/:site/auth/export?type= // GET /webapi/auth/export?type= func (h *Handler) authExportPublic(w http.ResponseWriter, r *http.Request, p httprouter.Params) { + err := rateLimitRequest(r, h.limiter) + if err != nil { + http.Error(w, err.Error(), trace.ErrorToCode(err)) + return + } authorities, err := client.ExportAuthorities( r.Context(), h.GetProxyClient(), diff --git a/lib/web/apiserver_login_test.go b/lib/web/apiserver_login_test.go index 6231c6dafb071..e2a3aa1d6a2c7 100644 --- a/lib/web/apiserver_login_test.go +++ b/lib/web/apiserver_login_test.go @@ -305,7 +305,7 @@ func TestAuthenticate_rateLimiting(t *testing.T) { }{ { name: "/webapi/mfa/login/begin", - burst: defaults.LimiterPasswordlessBurst, + burst: defaults.LimiterBurst, fn: func(clt *client.WebClient) error { ep := clt.Endpoint("webapi", "mfa", "login", "begin") _, err := clt.PostJSON(ctx, ep, &client.MFAChallengeRequest{})