From 289b6114f5e6acd4749be8ef0bd046881ec077dd Mon Sep 17 00:00:00 2001 From: Mike Jensen Date: Tue, 18 Apr 2023 13:59:41 -0600 Subject: [PATCH] Add the ability to only rate limit unauthenticated requests This allows requests which are authenticated to avoid rate limiting. It was seen on `ping` that authentication was already provided and thus this can help reduce risks around large clusters needing to ping frequently through a NAT. --- lib/web/apiserver.go | 37 ++++++++++++++++++++++++++++--------- 1 file changed, 28 insertions(+), 9 deletions(-) diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index 2d02a08033070..0de729f8ab77c 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -372,7 +372,7 @@ func NewHandler(cfg Config, opts ...HandlerOption) (*APIHandler, error) { return nil, trace.BadParameter("failed parsing index.html template: %v", err) } - h.Handle("GET", "/web/config.js", h.WithLimiter(h.getWebConfig)) + h.Handle("GET", "/web/config.js", h.WithUnauthenticatedLimiter(h.getWebConfig)) } routingHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -508,9 +508,9 @@ func (h *Handler) bindMinimalEndpoints() { // find is like ping, but is faster because it is optimized for servers // and does not fetch the data that servers don't need, e.g. // OIDC connectors and auth preferences - h.GET("/webapi/find", h.WithLimiter(h.find)) + h.GET("/webapi/find", h.WithUnauthenticatedLimiter(h.find)) // Issue host credentials. - h.POST("/webapi/host/credentials", h.WithLimiter(h.hostCredentials)) + h.POST("/webapi/host/credentials", h.WithUnauthenticatedLimiter(h.hostCredentials)) } // bindDefaultEndpoints binds the default endpoints for the web API. @@ -521,11 +521,11 @@ func (h *Handler) bindDefaultEndpoints() { // endpoint returns the default authentication method and configuration that // the server supports. the /webapi/ping/:connector endpoint can be used to // query the authentication configuration for a specific connector. - h.GET("/webapi/ping", h.WithLimiter(h.ping)) - h.GET("/webapi/ping/:connector", h.WithLimiter(h.pingWithConnector)) + h.GET("/webapi/ping", h.WithUnauthenticatedLimiter(h.ping)) + h.GET("/webapi/ping/:connector", h.WithUnauthenticatedLimiter(h.pingWithConnector)) // Unauthenticated access to JWT public keys. - h.GET("/.well-known/jwks.json", h.WithLimiter(h.jwks)) + h.GET("/.well-known/jwks.json", h.WithUnauthenticatedLimiter(h.jwks)) // Unauthenticated access to the message of the day h.GET("/webapi/motd", h.WithLimiter(h.motd)) @@ -565,7 +565,7 @@ func (h *Handler) bindDefaultEndpoints() { h.POST("/webapi/users/privilege/token", h.WithAuth(h.createPrivilegeTokenHandle)) // Issues SSH temp certificates based on 2FA access creds - h.POST("/webapi/ssh/certs", h.WithLimiter(h.createSSHCert)) + h.POST("/webapi/ssh/certs", h.WithUnauthenticatedLimiter(h.createSSHCert)) // list available sites h.GET("/webapi/sites", h.WithAuth(h.getClusters)) @@ -665,7 +665,7 @@ func (h *Handler) bindDefaultEndpoints() { h.POST("/webapi/mfa/authenticatechallenge/password", h.WithAuth(h.createAuthenticateChallengeWithPassword)) // trusted clusters - h.POST("/webapi/trustedclusters/validate", h.WithLimiter(h.validateTrustedCluster)) + h.POST("/webapi/trustedclusters/validate", h.WithUnauthenticatedLimiter(h.validateTrustedCluster)) // User Status (used by client to check if user session is valid) h.GET("/webapi/user/status", h.WithAuth(h.getUserStatus)) @@ -714,7 +714,7 @@ func (h *Handler) bindDefaultEndpoints() { h.GET("/webapi/connectionupgrade", h.WithLimiter(h.connectionUpgrade)) // create user events. - h.POST("/webapi/precapture", h.WithLimiter(h.createPreUserEventHandle)) + h.POST("/webapi/precapture", h.WithUnauthenticatedLimiter(h.createPreUserEventHandle)) // create authenticated user events. h.POST("/webapi/capture", h.WithAuth(h.createUserEventHandle)) @@ -3621,6 +3621,25 @@ func (h *Handler) WithAuthCookieAndCSRF(fn ContextHandler) httprouter.Handle { return httplib.WithCSRFProtection(f) } +// WithUnauthenticatedLimiter adds a conditional IP-based rate limiting that will limit only unauthenticated requests. +// This is a good default to use as both Cluster and User auth are checked here, but `WithLimiter` can be used if +// you're certain that no authenticated requests will be made. +func (h *Handler) WithUnauthenticatedLimiter(fn httplib.HandlerFunc) httprouter.Handle { + return httplib.MakeHandler(func(w http.ResponseWriter, r *http.Request, p httprouter.Params) (interface{}, error) { + _, _, err := h.authenticateRequestWithCluster(w, r, p) + if err != nil { + // retry with user auth + _, err = h.AuthenticateRequest(w, r, true) + if err != nil { + // no auth passed, limit request + return h.WithLimiterHandlerFunc(fn)(w, r, p) + } + } + // auth passed, call directly + return fn(w, r, p) + }) +} + // WithLimiter adds IP-based rate limiting to fn. func (h *Handler) WithLimiter(fn httplib.HandlerFunc) httprouter.Handle { return httplib.MakeHandler(func(w http.ResponseWriter, r *http.Request, p httprouter.Params) (interface{}, error) {