Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rate limit all unauthenticated HTTP endpoints #24623

Merged
merged 3 commits into from
May 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/auth/grpcserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3167,7 +3167,7 @@ func TestCustomRateLimiting(t *testing.T) {
},
{
name: "RPC CreateAuthenticateChallenge",
burst: defaults.LimiterPasswordlessBurst,
burst: defaults.LimiterBurst,
fn: func(clt *Client) error {
_, err := clt.CreateAuthenticateChallenge(ctx, &proto.CreateAuthenticateChallengeRequest{})
return err
Expand Down
6 changes: 3 additions & 3 deletions lib/auth/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -383,9 +383,9 @@ func getCustomRate(endpoint string) *ratelimit.RateSet {
return rates
// Passwordless RPCs (potential unauthenticated challenge generation).
case "/proto.AuthService/CreateAuthenticateChallenge":
const period = defaults.LimiterPasswordlessPeriod
const average = defaults.LimiterPasswordlessAverage
const burst = defaults.LimiterPasswordlessBurst
const period = defaults.LimiterPeriod
const average = defaults.LimiterAverage
const burst = defaults.LimiterBurst
rates := ratelimit.NewRateSet()
if err := rates.Add(period, average, burst); err != nil {
log.WithError(err).Debugf("Failed to define a custom rate for rpc method %q, using default rate", endpoint)
Expand Down
25 changes: 17 additions & 8 deletions lib/defaults/defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -362,15 +362,24 @@ const (
LimiterMaxConcurrentSignatures = 10
)

// Default rate limits for unauthenticated passwordless endpoints.
// Default rate limits for unauthenticated endpoints.
const (
// LimiterPasswordlessPeriod is the default period for passwordless limiters.
LimiterPasswordlessPeriod = 1 * time.Minute
// LimiterPasswordlessAverage is the default average for passwordless
// limiters.
LimiterPasswordlessAverage = 10
// LimiterPasswordlessBurst is the default burst for passwordless limiters.
LimiterPasswordlessBurst = 20
// LimiterPeriod is the default period for unauthenticated limiters.
LimiterPeriod = 1 * time.Minute
// LimiterAverage is the default average for unauthenticated limiters.
LimiterAverage = 20
// LimiterBurst is the default burst for unauthenticated limiters.
LimiterBurst = 40
)

// Default high rate limits for unauthenticated endpoints that are CPU constrained.
const (
// LimiterHighPeriod is the default period for high rate unauthenticated limiters.
LimiterHighPeriod = 1 * time.Minute
// LimiterHighAverage is the default average for high rate unauthenticated limiters.
LimiterHighAverage = 120
// LimiterHighBurst is the default burst for high rate unauthenticated limiters.
LimiterHighBurst = 480
)

const (
Expand Down
147 changes: 113 additions & 34 deletions lib/web/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ type Handler struct {
sessionStreamPollPeriod time.Duration
clock clockwork.Clock
limiter *limiter.RateLimiter
highLimiter *limiter.RateLimiter
healthCheckAppServer healthCheckAppServerFunc
// sshPort specifies the SSH proxy port extracted
// from configuration
Expand Down Expand Up @@ -337,9 +338,24 @@ func NewHandler(cfg Config, opts ...HandlerOption) (*APIHandler, error) {
h.limiter, err = limiter.NewRateLimiter(limiter.Config{
Rates: []limiter.Rate{
{
Period: defaults.LimiterPasswordlessPeriod,
Average: defaults.LimiterPasswordlessAverage,
Burst: defaults.LimiterPasswordlessBurst,
Period: defaults.LimiterPeriod,
Average: defaults.LimiterAverage,
Burst: defaults.LimiterBurst,
},
},
MaxConnections: defaults.LimiterMaxConnections,
MaxNumberOfUsers: defaults.LimiterMaxConcurrentUsers,
})
if err != nil {
return nil, trace.Wrap(err)
}
// highLimiter is used for endpoints which are only CPU constrained and require high request rates
h.highLimiter, err = limiter.NewRateLimiter(limiter.Config{
Rates: []limiter.Rate{
{
Period: defaults.LimiterHighPeriod,
Average: defaults.LimiterHighAverage,
Burst: defaults.LimiterHighBurst,
},
},
MaxConnections: defaults.LimiterMaxConnections,
Expand Down Expand Up @@ -373,7 +389,7 @@ func NewHandler(cfg Config, opts ...HandlerOption) (*APIHandler, error) {
return nil, trace.BadParameter("failed parsing index.html template: %v", err)
}

h.Handle("GET", "/web/config.js", httplib.MakeHandler(h.getWebConfig))
h.Handle("GET", "/web/config.js", h.WithUnauthenticatedLimiter(h.getWebConfig))
}

routingHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -509,9 +525,9 @@ func (h *Handler) bindMinimalEndpoints() {
// find is like ping, but is faster because it is optimized for servers
// and does not fetch the data that servers don't need, e.g.
// OIDC connectors and auth preferences
h.GET("/webapi/find", httplib.MakeHandler(h.find))
h.GET("/webapi/find", h.WithUnauthenticatedHighLimiter(h.find))
// Issue host credentials.
h.POST("/webapi/host/credentials", httplib.MakeHandler(h.hostCredentials))
h.POST("/webapi/host/credentials", h.WithUnauthenticatedHighLimiter(h.hostCredentials))
}

// bindDefaultEndpoints binds the default endpoints for the web API.
Expand All @@ -522,23 +538,23 @@ func (h *Handler) bindDefaultEndpoints() {
// endpoint returns the default authentication method and configuration that
// the server supports. the /webapi/ping/:connector endpoint can be used to
// query the authentication configuration for a specific connector.
h.GET("/webapi/ping", httplib.MakeHandler(h.ping))
h.GET("/webapi/ping/:connector", httplib.MakeHandler(h.pingWithConnector))
h.GET("/webapi/ping", h.WithUnauthenticatedHighLimiter(h.ping))
h.GET("/webapi/ping/:connector", h.WithUnauthenticatedHighLimiter(h.pingWithConnector))

// Unauthenticated access to JWT public keys.
h.GET("/.well-known/jwks.json", httplib.MakeHandler(h.jwks))
h.GET("/.well-known/jwks.json", h.WithUnauthenticatedHighLimiter(h.jwks))

// Unauthenticated access to the message of the day
h.GET("/webapi/motd", httplib.MakeHandler(h.motd))
h.GET("/webapi/motd", h.WithHighLimiter(h.motd))

// Unauthenticated access to retrieving the script used to install
// Teleport
h.GET("/webapi/scripts/installer/:name", httplib.MakeHandler(h.installer))
h.GET("/webapi/scripts/installer/:name", h.WithLimiter(h.installer))

// desktop access configuration scripts
h.GET("/webapi/scripts/desktop-access/install-ad-ds.ps1", httplib.MakeHandler(h.desktopAccessScriptInstallADDSHandle))
h.GET("/webapi/scripts/desktop-access/install-ad-cs.ps1", httplib.MakeHandler(h.desktopAccessScriptInstallADCSHandle))
h.GET("/webapi/scripts/desktop-access/configure/:token/configure-ad.ps1", httplib.MakeHandler(h.desktopAccessScriptConfigureHandle))
h.GET("/webapi/scripts/desktop-access/install-ad-ds.ps1", h.WithLimiter(h.desktopAccessScriptInstallADDSHandle))
h.GET("/webapi/scripts/desktop-access/install-ad-cs.ps1", h.WithLimiter(h.desktopAccessScriptInstallADCSHandle))
h.GET("/webapi/scripts/desktop-access/configure/:token/configure-ad.ps1", h.WithLimiter(h.desktopAccessScriptConfigureHandle))

// Forwards traces to the configured upstream collector
h.POST("/webapi/traces", h.WithAuth(h.traces))
Expand All @@ -557,7 +573,7 @@ func (h *Handler) bindDefaultEndpoints() {

// We have an overlap route here, please see godoc of handleGetUserOrResetToken
// h.GET("/webapi/users/:username", h.WithAuth(h.getUserHandle))
// h.GET("/webapi/users/password/token/:token", httplib.MakeHandler(h.getResetPasswordTokenHandle))
// h.GET("/webapi/users/password/token/:token", h.WithLimiter(h.getResetPasswordTokenHandle))
h.GET("/webapi/users/*wildcard", h.handleGetUserOrResetToken)

h.PUT("/webapi/users/password/token", httplib.WithCSRFProtection(h.changeUserAuthentication))
Expand All @@ -566,7 +582,7 @@ func (h *Handler) bindDefaultEndpoints() {
h.POST("/webapi/users/privilege/token", h.WithAuth(h.createPrivilegeTokenHandle))

// Issues SSH temp certificates based on 2FA access creds
h.POST("/webapi/ssh/certs", h.WithLimiter(h.createSSHCert))
h.POST("/webapi/ssh/certs", h.WithUnauthenticatedLimiter(h.createSSHCert))

// list available sites
h.GET("/webapi/sites", h.WithAuth(h.getClusters))
Expand Down Expand Up @@ -623,9 +639,9 @@ func (h *Handler) bindDefaultEndpoints() {
h.POST("/webapi/token", h.WithAuth(h.createTokenHandle))

// join scripts
h.GET("/scripts/:token/install-node.sh", httplib.MakeHandler(h.getNodeJoinScriptHandle))
h.GET("/scripts/:token/install-app.sh", httplib.MakeHandler(h.getAppJoinScriptHandle))
h.GET("/scripts/:token/install-database.sh", httplib.MakeHandler(h.getDatabaseJoinScriptHandle))
h.GET("/scripts/:token/install-node.sh", h.WithLimiter(h.getNodeJoinScriptHandle))
h.GET("/scripts/:token/install-app.sh", h.WithLimiter(h.getAppJoinScriptHandle))
h.GET("/scripts/:token/install-database.sh", h.WithLimiter(h.getDatabaseJoinScriptHandle))
// web context
h.GET("/webapi/sites/:site/context", h.WithClusterAuth(h.getUserContext))
h.GET("/webapi/sites/:site/resources/check", h.WithClusterAuth(h.checkAccessToRegisteredResource))
Expand All @@ -652,12 +668,12 @@ func (h *Handler) bindDefaultEndpoints() {
// MFA public endpoints.
h.POST("/webapi/sites/:site/mfa/required", h.WithClusterAuth(h.isMFARequired))
h.POST("/webapi/mfa/login/begin", h.WithLimiter(h.mfaLoginBegin))
h.POST("/webapi/mfa/login/finish", httplib.MakeHandler(h.mfaLoginFinish))
h.POST("/webapi/mfa/login/finishsession", httplib.MakeHandler(h.mfaLoginFinishSession))
h.DELETE("/webapi/mfa/token/:token/devices/:devicename", httplib.MakeHandler(h.deleteMFADeviceWithTokenHandle))
h.GET("/webapi/mfa/token/:token/devices", httplib.MakeHandler(h.getMFADevicesWithTokenHandle))
h.POST("/webapi/mfa/token/:token/authenticatechallenge", httplib.MakeHandler(h.createAuthenticateChallengeWithTokenHandle))
h.POST("/webapi/mfa/token/:token/registerchallenge", httplib.MakeHandler(h.createRegisterChallengeWithTokenHandle))
h.POST("/webapi/mfa/login/finish", h.WithLimiter(h.mfaLoginFinish))
h.POST("/webapi/mfa/login/finishsession", h.WithLimiter(h.mfaLoginFinishSession))
h.DELETE("/webapi/mfa/token/:token/devices/:devicename", h.WithLimiter(h.deleteMFADeviceWithTokenHandle))
h.GET("/webapi/mfa/token/:token/devices", h.WithLimiter(h.getMFADevicesWithTokenHandle))
h.POST("/webapi/mfa/token/:token/authenticatechallenge", h.WithLimiter(h.createAuthenticateChallengeWithTokenHandle))
h.POST("/webapi/mfa/token/:token/registerchallenge", h.WithLimiter(h.createRegisterChallengeWithTokenHandle))

// MFA private endpoints.
h.GET("/webapi/mfa/devices", h.WithAuth(h.getMFADevicesHandle))
Expand All @@ -666,7 +682,7 @@ func (h *Handler) bindDefaultEndpoints() {
h.POST("/webapi/mfa/authenticatechallenge/password", h.WithAuth(h.createAuthenticateChallengeWithPassword))

// trusted clusters
h.POST("/webapi/trustedclusters/validate", httplib.MakeHandler(h.validateTrustedCluster))
h.POST("/webapi/trustedclusters/validate", h.WithUnauthenticatedLimiter(h.validateTrustedCluster))

// User Status (used by client to check if user session is valid)
h.GET("/webapi/user/status", h.WithAuth(h.getUserStatus))
Expand Down Expand Up @@ -719,10 +735,10 @@ func (h *Handler) bindDefaultEndpoints() {
h.GET(OIDCJWKWURI, h.WithLimiter(h.jwksOIDC))

// Connection upgrades.
h.GET("/webapi/connectionupgrade", httplib.MakeHandler(h.connectionUpgrade))
h.GET("/webapi/connectionupgrade", h.WithLimiter(h.connectionUpgrade))

// create user events.
h.POST("/webapi/precapture", h.WithLimiter(h.createPreUserEventHandle))
h.POST("/webapi/precapture", h.WithUnauthenticatedLimiter(h.createPreUserEventHandle))
// create authenticated user events.
h.POST("/webapi/capture", h.WithAuth(h.createUserEventHandle))

Expand Down Expand Up @@ -3647,33 +3663,91 @@ func (h *Handler) WithAuthCookieAndCSRF(fn ContextHandler) httprouter.Handle {
return httplib.WithCSRFProtection(f)
}

// WithUnauthenticatedLimiter adds a conditional IP-based rate limiting that will limit only unauthenticated requests.
// This is a good default to use as both Cluster and User auth are checked here, but `WithLimiter` can be used if
// you're certain that no authenticated requests will be made.
func (h *Handler) WithUnauthenticatedLimiter(fn httplib.HandlerFunc) httprouter.Handle {
return h.unauthenticatedLimiterFunc(fn, h.WithLimiterHandlerFunc)
}

// WithUnauthenticatedHighLimiter adds a conditional IP-based rate limiting that will limit only unauthenticated
// requests. This is similar to WithUnauthenticatedLimiter, however this one allows a much higher rate limit.
// This higher rate limit should only be used on endpoints which are only CPU constrained
// (no file or other resources used).
func (h *Handler) WithUnauthenticatedHighLimiter(fn httplib.HandlerFunc) httprouter.Handle {
return h.unauthenticatedLimiterFunc(fn, h.WithHighLimiterHandlerFunc)
}

func (h *Handler) unauthenticatedLimiterFunc(fn httplib.HandlerFunc, rateFunc func(fn httplib.HandlerFunc) httplib.HandlerFunc) httprouter.Handle {
return httplib.MakeHandler(func(w http.ResponseWriter, r *http.Request, p httprouter.Params) (interface{}, error) {
if _, _, err := h.authenticateRequestWithCluster(w, r, p); err != nil {
// retry with user auth
if _, err = h.AuthenticateRequest(w, r, true /* check token */); err != nil {
// no auth passed, limit request
return rateFunc(fn)(w, r, p)
}
}
// auth passed, call directly
return fn(w, r, p)
})
}

// WithLimiter adds IP-based rate limiting to fn.
// Limits are applied to all requests, authenticated or not.
func (h *Handler) WithLimiter(fn httplib.HandlerFunc) httprouter.Handle {
return httplib.MakeHandler(func(w http.ResponseWriter, r *http.Request, p httprouter.Params) (interface{}, error) {
return h.WithLimiterHandlerFunc(fn)(w, r, p)
})
}

// WithHighLimiter adds high rate IP-based rate limiting to fn.
// This should only be used on functions which are CPU constrained, and don't use disk or other services.
// Limits are applied to all requests, authenticated or not.
func (h *Handler) WithHighLimiter(fn httplib.HandlerFunc) httprouter.Handle {
return httplib.MakeHandler(func(w http.ResponseWriter, r *http.Request, p httprouter.Params) (interface{}, error) {
return h.WithHighLimiterHandlerFunc(fn)(w, r, p)
})
}

// WithLimiterHandlerFunc adds IP-based rate limiting to a HandlerFunc. This
// should be used when you need to nest this inside another HandlerFunc.
func (h *Handler) WithLimiterHandlerFunc(fn httplib.HandlerFunc) httplib.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request, p httprouter.Params) (interface{}, error) {
remote, _, err := net.SplitHostPort(r.RemoteAddr)
err := rateLimitRequest(r, h.limiter)
if err != nil {
return nil, trace.Wrap(err)
}
err = h.limiter.RegisterRequest(remote, nil /* customRate */)
// MaxRateError doesn't play well with errors.Is, hence the cast.
if _, ok := err.(*ratelimit.MaxRateError); ok {
return nil, trace.LimitExceeded(err.Error())
}
return fn(w, r, p)
}
}

// WithHighLimiterHandlerFunc adds IP-based rate limiting to a HandlerFunc. This is similar to WithLimiterHandlerFunc
// but provides a higher rate limit. This should only be used for requests which are only CPU bound (no disk or other
// resources used).
func (h *Handler) WithHighLimiterHandlerFunc(fn httplib.HandlerFunc) httplib.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request, p httprouter.Params) (interface{}, error) {
err := rateLimitRequest(r, h.highLimiter)
if err != nil {
return nil, trace.Wrap(err)
}
return fn(w, r, p)
}
}

func rateLimitRequest(r *http.Request, limiter *limiter.RateLimiter) error {
remote, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return trace.Wrap(err)
}

err = limiter.RegisterRequest(remote, nil /* customRate */)
// MaxRateError doesn't play well with errors.Is, hence the type assertion.
if _, ok := err.(*ratelimit.MaxRateError); ok {
return trace.LimitExceeded(err.Error())
}
return trace.Wrap(err)
}

// AuthenticateRequest authenticates request using combination of a session cookie
// and bearer token
func (h *Handler) AuthenticateRequest(w http.ResponseWriter, r *http.Request, checkBearerToken bool) (*SessionContext, error) {
Expand Down Expand Up @@ -3910,6 +3984,11 @@ func SSOSetWebSessionAndRedirectURL(w http.ResponseWriter, r *http.Request, resp
// GET /webapi/sites/:site/auth/export?type=<auth type>
// GET /webapi/auth/export?type=<auth type>
func (h *Handler) authExportPublic(w http.ResponseWriter, r *http.Request, p httprouter.Params) {
err := rateLimitRequest(r, h.limiter)
if err != nil {
http.Error(w, err.Error(), trace.ErrorToCode(err))
return
}
authorities, err := client.ExportAuthorities(
r.Context(),
h.GetProxyClient(),
Expand Down
2 changes: 1 addition & 1 deletion lib/web/apiserver_login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ func TestAuthenticate_rateLimiting(t *testing.T) {
}{
{
name: "/webapi/mfa/login/begin",
burst: defaults.LimiterPasswordlessBurst,
burst: defaults.LimiterBurst,
fn: func(clt *client.WebClient) error {
ep := clt.Endpoint("webapi", "mfa", "login", "begin")
_, err := clt.PostJSON(ctx, ep, &client.MFAChallengeRequest{})
Expand Down