Skip to content

Commit

Permalink
streamline http.Error calls
Browse files Browse the repository at this point in the history
Signed-off-by: Kristoffer Dalby <[email protected]>
  • Loading branch information
kradalby committed Jan 27, 2025
1 parent a7cb02f commit 2872cd4
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 148 deletions.
4 changes: 1 addition & 3 deletions hscontrol/noise.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,7 @@ func (h *Headscale) NoiseUpgradeHandler(
noiseServer.earlyNoise,
)
if err != nil {
log.Error().Err(err).Msg("noise upgrade failed")
http.Error(writer, err.Error(), http.StatusInternalServerError)

httpError(writer, err, "noise upgrade failed", http.StatusInternalServerError)
return
}

Expand Down
45 changes: 19 additions & 26 deletions hscontrol/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,34 +134,28 @@ func (a *AuthProviderOIDC) RegisterHandler(
req *http.Request,
) {
vars := mux.Vars(req)
registrationIdStr, ok := vars["registration_id"]
registrationIdStr, _ := vars["registration_id"]

// We need to make sure we dont open for XSS style injections, if the parameter that
// is passed as a key is not parsable/validated as a NodePublic key, then fail to render
// the template and log an error.
registrationId, err := types.RegistrationIDFromString(registrationIdStr)
if err != nil {
http.Error(writer, "invalid registration ID", http.StatusBadRequest)
httpError(writer, err, "invalid registration ID", http.StatusBadRequest)
return
}

log.Debug().
Caller().
Str("registration_id", registrationId.String()).
Bool("ok", ok).
Msg("Received oidc register call")

// Set the state and nonce cookies to protect against CSRF attacks
state, err := setCSRFCookie(writer, req, "state")
if err != nil {
http.Error(writer, "Internal server error", http.StatusInternalServerError)
httpError(writer, err, "Internal server error", http.StatusInternalServerError)
return
}

// Set the state and nonce cookies to protect against CSRF attacks
nonce, err := setCSRFCookie(writer, req, "nonce")
if err != nil {
http.Error(writer, "Internal server error", http.StatusInternalServerError)
httpError(writer, err, "Internal server error", http.StatusInternalServerError)
return
}

Expand Down Expand Up @@ -225,64 +219,64 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
) {
code, state, err := extractCodeAndStateParamFromRequest(req)
if err != nil {
http.Error(writer, err.Error(), http.StatusBadRequest)
httpError(writer, err, err.Error(), http.StatusBadRequest)
return
}

log.Debug().Interface("cookies", req.Cookies()).Msg("Received oidc callback")
cookieState, err := req.Cookie("state")
if err != nil {
http.Error(writer, "state not found", http.StatusBadRequest)
httpError(writer, err, "state not found", http.StatusBadRequest)
return
}

if state != cookieState.Value {
http.Error(writer, "state did not match", http.StatusBadRequest)
httpError(writer, err, "state did not match", http.StatusBadRequest)
return
}

idToken, err := a.extractIDToken(req.Context(), code, state)
if err != nil {
http.Error(writer, err.Error(), http.StatusBadRequest)
httpError(writer, err, err.Error(), http.StatusBadRequest)
return
}

nonce, err := req.Cookie("nonce")
if err != nil {
http.Error(writer, "nonce not found", http.StatusBadRequest)
httpError(writer, err, "nonce not found", http.StatusBadRequest)
return
}
if idToken.Nonce != nonce.Value {
http.Error(writer, "nonce did not match", http.StatusBadRequest)
httpError(writer, err, "nonce did not match", http.StatusBadRequest)
return
}

nodeExpiry := a.determineNodeExpiry(idToken.Expiry)

var claims types.OIDCClaims
if err := idToken.Claims(&claims); err != nil {
http.Error(writer, fmt.Errorf("failed to decode ID token claims: %w", err).Error(), http.StatusInternalServerError)
err = fmt.Errorf("decoding ID token claims: %w", err)
httpError(writer, err, err.Error(), http.StatusInternalServerError)
return
}

if err := validateOIDCAllowedDomains(a.cfg.AllowedDomains, &claims); err != nil {
http.Error(writer, err.Error(), http.StatusUnauthorized)
httpError(writer, err, err.Error(), http.StatusUnauthorized)
return
}

if err := validateOIDCAllowedGroups(a.cfg.AllowedGroups, &claims); err != nil {
http.Error(writer, err.Error(), http.StatusUnauthorized)
httpError(writer, err, err.Error(), http.StatusUnauthorized)
return
}

if err := validateOIDCAllowedUsers(a.cfg.AllowedUsers, &claims); err != nil {
http.Error(writer, err.Error(), http.StatusUnauthorized)
httpError(writer, err, err.Error(), http.StatusUnauthorized)
return
}

user, err := a.createOrUpdateUserFromClaim(&claims)
if err != nil {
http.Error(writer, err.Error(), http.StatusInternalServerError)
httpError(writer, err, err.Error(), http.StatusInternalServerError)
return
}

Expand All @@ -297,7 +291,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
verb := "Reauthenticated"
newNode, err := a.handleRegistrationID(user, *registrationId, nodeExpiry)
if err != nil {
http.Error(writer, err.Error(), http.StatusInternalServerError)
httpError(writer, err, err.Error(), http.StatusInternalServerError)
return
}

Expand All @@ -308,7 +302,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
// TODO(kradalby): replace with go-elem
content, err := renderOIDCCallbackTemplate(user, verb)
if err != nil {
http.Error(writer, err.Error(), http.StatusInternalServerError)
httpError(writer, err, err.Error(), http.StatusInternalServerError)
return
}

Expand All @@ -323,7 +317,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(

// Neither node nor machine key was found in the state cache meaning
// that we could not reauth nor register the node.
http.Error(writer, "login session expired, try again", http.StatusInternalServerError)
httpError(writer, nil, "login session expired, try again", http.StatusInternalServerError)
return
}

Expand Down Expand Up @@ -423,7 +417,6 @@ func validateOIDCAllowedUsers(
) error {
if len(allowedUsers) > 0 &&
!slices.Contains(allowedUsers, claims.Email) {
log.Trace().Msg("authenticated principal does not match any allowed user")
return errOIDCAllowedUsers
}

Expand Down
130 changes: 11 additions & 119 deletions hscontrol/platform_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"github.com/gofrs/uuid/v5"
"github.com/gorilla/mux"
"github.com/juanfont/headscale/hscontrol/templates"
"github.com/rs/zerolog/log"
)

// WindowsConfigMessage shows a simple message in the browser for how to configure the Windows Tailscale client.
Expand All @@ -20,13 +19,7 @@ func (h *Headscale) WindowsConfigMessage(
) {
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.WriteHeader(http.StatusOK)

if _, err := writer.Write([]byte(templates.Windows(h.cfg.ServerURL).Render())); err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
writer.Write([]byte(templates.Windows(h.cfg.ServerURL).Render()))
}

// AppleConfigMessage shows a simple message in the browser to point the user to the iOS/MacOS profile and instructions for how to install it.
Expand All @@ -36,13 +29,7 @@ func (h *Headscale) AppleConfigMessage(
) {
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.WriteHeader(http.StatusOK)

if _, err := writer.Write([]byte(templates.Apple(h.cfg.ServerURL).Render())); err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
writer.Write([]byte(templates.Apple(h.cfg.ServerURL).Render()))
}

func (h *Headscale) ApplePlatformConfig(
Expand All @@ -52,51 +39,19 @@ func (h *Headscale) ApplePlatformConfig(
vars := mux.Vars(req)
platform, ok := vars["platform"]
if !ok {
log.Error().
Str("handler", "ApplePlatformConfig").
Msg("No platform specified")
http.Error(writer, "No platform specified", http.StatusBadRequest)

httpError(writer, nil, "No platform specified", http.StatusBadRequest)
return
}

id, err := uuid.NewV4()
if err != nil {
log.Error().
Str("handler", "ApplePlatformConfig").
Err(err).
Msg("Failed not create UUID")

writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, err := writer.Write([]byte("Failed to create UUID"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}

httpError(writer, nil, "Failed to create UUID", http.StatusInternalServerError)
return
}

contentID, err := uuid.NewV4()
if err != nil {
log.Error().
Str("handler", "ApplePlatformConfig").
Err(err).
Msg("Failed not create UUID")

writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, err := writer.Write([]byte("Failed to create content UUID"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}

httpError(writer, nil, "Failed to create UUID", http.StatusInternalServerError)
return
}

Expand All @@ -106,68 +61,25 @@ func (h *Headscale) ApplePlatformConfig(
}

var payload bytes.Buffer
handleMacError := func(ierr error) {
log.Error().
Str("handler", "ApplePlatformConfig").
Err(ierr).
Msg("Could not render Apple macOS template")

writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, err := writer.Write([]byte("Could not render Apple macOS template"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
}

switch platform {
case "macos-standalone":
if err := macosStandaloneTemplate.Execute(&payload, platformConfig); err != nil {
handleMacError(err)

httpError(writer, err, "Could not render Apple macOS template", http.StatusInternalServerError)
return
}
case "macos-app-store":
if err := macosAppStoreTemplate.Execute(&payload, platformConfig); err != nil {
handleMacError(err)

httpError(writer, err, "Could not render Apple macOS template", http.StatusInternalServerError)
return
}
case "ios":
if err := iosTemplate.Execute(&payload, platformConfig); err != nil {
log.Error().
Str("handler", "ApplePlatformConfig").
Err(err).
Msg("Could not render Apple iOS template")

writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, err := writer.Write([]byte("Could not render Apple iOS template"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}

httpError(writer, err, "Could not render Apple iOS template", http.StatusInternalServerError)
return
}
default:
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write(
[]byte("Invalid platform. Only ios, macos-app-store and macos-standalone are supported"),
)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}

httpError(writer, err, "Invalid platform. Only ios, macos-app-store and macos-standalone are supported", http.StatusInternalServerError)
return
}

Expand All @@ -179,34 +91,14 @@ func (h *Headscale) ApplePlatformConfig(

var content bytes.Buffer
if err := commonTemplate.Execute(&content, config); err != nil {
log.Error().
Str("handler", "ApplePlatformConfig").
Err(err).
Msg("Could not render Apple platform template")

writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, err := writer.Write([]byte("Could not render Apple platform template"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}

httpError(writer, err, "Could not render platform iOS template", http.StatusInternalServerError)
return
}

writer.Header().
Set("Content-Type", "application/x-apple-aspen-config; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err = writer.Write(content.Bytes())
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
writer.Write(content.Bytes())
}

type AppleMobileConfig struct {
Expand Down

0 comments on commit 2872cd4

Please sign in to comment.