Skip to content

Commit

Permalink
Add the ability to only rate limit unauthenticated requests
Browse files Browse the repository at this point in the history
This allows requests which are authenticated to avoid rate limiting.
It was seen on `ping` that authentication was already provided and thus this can help reduce risks around large clusters needing to ping frequently through a NAT.
  • Loading branch information
jentfoo committed Apr 18, 2023
1 parent d192821 commit 289b611
Showing 1 changed file with 28 additions and 9 deletions.
37 changes: 28 additions & 9 deletions lib/web/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ func NewHandler(cfg Config, opts ...HandlerOption) (*APIHandler, error) {
return nil, trace.BadParameter("failed parsing index.html template: %v", err)
}

h.Handle("GET", "/web/config.js", h.WithLimiter(h.getWebConfig))
h.Handle("GET", "/web/config.js", h.WithUnauthenticatedLimiter(h.getWebConfig))
}

routingHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -508,9 +508,9 @@ func (h *Handler) bindMinimalEndpoints() {
// find is like ping, but is faster because it is optimized for servers
// and does not fetch the data that servers don't need, e.g.
// OIDC connectors and auth preferences
h.GET("/webapi/find", h.WithLimiter(h.find))
h.GET("/webapi/find", h.WithUnauthenticatedLimiter(h.find))
// Issue host credentials.
h.POST("/webapi/host/credentials", h.WithLimiter(h.hostCredentials))
h.POST("/webapi/host/credentials", h.WithUnauthenticatedLimiter(h.hostCredentials))
}

// bindDefaultEndpoints binds the default endpoints for the web API.
Expand All @@ -521,11 +521,11 @@ func (h *Handler) bindDefaultEndpoints() {
// endpoint returns the default authentication method and configuration that
// the server supports. the /webapi/ping/:connector endpoint can be used to
// query the authentication configuration for a specific connector.
h.GET("/webapi/ping", h.WithLimiter(h.ping))
h.GET("/webapi/ping/:connector", h.WithLimiter(h.pingWithConnector))
h.GET("/webapi/ping", h.WithUnauthenticatedLimiter(h.ping))
h.GET("/webapi/ping/:connector", h.WithUnauthenticatedLimiter(h.pingWithConnector))

// Unauthenticated access to JWT public keys.
h.GET("/.well-known/jwks.json", h.WithLimiter(h.jwks))
h.GET("/.well-known/jwks.json", h.WithUnauthenticatedLimiter(h.jwks))

// Unauthenticated access to the message of the day
h.GET("/webapi/motd", h.WithLimiter(h.motd))
Expand Down Expand Up @@ -565,7 +565,7 @@ func (h *Handler) bindDefaultEndpoints() {
h.POST("/webapi/users/privilege/token", h.WithAuth(h.createPrivilegeTokenHandle))

// Issues SSH temp certificates based on 2FA access creds
h.POST("/webapi/ssh/certs", h.WithLimiter(h.createSSHCert))
h.POST("/webapi/ssh/certs", h.WithUnauthenticatedLimiter(h.createSSHCert))

// list available sites
h.GET("/webapi/sites", h.WithAuth(h.getClusters))
Expand Down Expand Up @@ -665,7 +665,7 @@ func (h *Handler) bindDefaultEndpoints() {
h.POST("/webapi/mfa/authenticatechallenge/password", h.WithAuth(h.createAuthenticateChallengeWithPassword))

// trusted clusters
h.POST("/webapi/trustedclusters/validate", h.WithLimiter(h.validateTrustedCluster))
h.POST("/webapi/trustedclusters/validate", h.WithUnauthenticatedLimiter(h.validateTrustedCluster))

// User Status (used by client to check if user session is valid)
h.GET("/webapi/user/status", h.WithAuth(h.getUserStatus))
Expand Down Expand Up @@ -714,7 +714,7 @@ func (h *Handler) bindDefaultEndpoints() {
h.GET("/webapi/connectionupgrade", h.WithLimiter(h.connectionUpgrade))

// create user events.
h.POST("/webapi/precapture", h.WithLimiter(h.createPreUserEventHandle))
h.POST("/webapi/precapture", h.WithUnauthenticatedLimiter(h.createPreUserEventHandle))
// create authenticated user events.
h.POST("/webapi/capture", h.WithAuth(h.createUserEventHandle))

Expand Down Expand Up @@ -3621,6 +3621,25 @@ func (h *Handler) WithAuthCookieAndCSRF(fn ContextHandler) httprouter.Handle {
return httplib.WithCSRFProtection(f)
}

// WithUnauthenticatedLimiter adds a conditional IP-based rate limiting that will limit only unauthenticated requests.
// This is a good default to use as both Cluster and User auth are checked here, but `WithLimiter` can be used if
// you're certain that no authenticated requests will be made.
func (h *Handler) WithUnauthenticatedLimiter(fn httplib.HandlerFunc) httprouter.Handle {
return httplib.MakeHandler(func(w http.ResponseWriter, r *http.Request, p httprouter.Params) (interface{}, error) {
_, _, err := h.authenticateRequestWithCluster(w, r, p)
if err != nil {
// retry with user auth
_, err = h.AuthenticateRequest(w, r, true)
if err != nil {
// no auth passed, limit request
return h.WithLimiterHandlerFunc(fn)(w, r, p)
}
}
// auth passed, call directly
return fn(w, r, p)
})
}

// WithLimiter adds IP-based rate limiting to fn.
func (h *Handler) WithLimiter(fn httplib.HandlerFunc) httprouter.Handle {
return httplib.MakeHandler(func(w http.ResponseWriter, r *http.Request, p httprouter.Params) (interface{}, error) {
Expand Down

0 comments on commit 289b611

Please sign in to comment.