Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add custom HTTP handler for authentication errors #211

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
Open
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -359,8 +359,9 @@ There are several ways to adjust functionality of the library:
1. `ClaimsUpdater` - interface with `Update(claims Claims) Claims` method. This is the primary way to alter a token at login time and add any attributes, set ip, email, admin status, roles and so on.
1. `Validator` - interface with `Validate(token string, claims Claims) bool` method. This is post-token hook and will be called on **each request** wrapped with `Auth` middleware. This will be the place for special logic to reject some tokens or users.
1. `UserUpdater` - interface with `Update(claims token.User) token.User` method. This method will be called on **each request** wrapped with `UpdateUser` middleware. This will be the place for special logic modify User Info in request context. [Example of usage.](https://github.com/go-pkgz/auth/blob/19c1b6d26608494955a4480f8f6165af85b1deab/_example/main.go#L189)
1. `AuthErrorHTTPHandler` - interface with `ServeAuthError(w http.ResponseWriter, r *http.Request, and other params)` method. It is possible to change how authentication errors are written into HTTP responses by configuring custom implementations of this interface for the middlewares.

All of the interfaces above have corresponding Func adapters - `SecretFunc`, `ClaimsUpdFunc`, `ValidatorFunc` and `UserUpdFunc`.
Some of the interfaces above have corresponding Func adapters - `SecretFunc`, `ClaimsUpdFunc`, `ValidatorFunc` and `UserUpdFunc`.
paskal marked this conversation as resolved.
Show resolved Hide resolved

### Implementing black list logic or some other filters

Expand Down
22 changes: 12 additions & 10 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,13 @@ type Opts struct {
AvatarRoutePath string // avatar routing prefix, i.e. "/api/v1/avatar", default `/avatar`
UseGravatar bool // for email based auth (verified provider) use gravatar service

AdminPasswd string // if presented, allows basic auth with user admin and given password
BasicAuthChecker middleware.BasicAuthFunc // user custom checker for basic auth, if one defined then "AdminPasswd" will ignored
AudienceReader token.Audience // list of allowed aud values, default (empty) allows any
AudSecrets bool // allow multiple secrets (secret per aud)
Logger logger.L // logger interface, default is no logging at all
RefreshCache middleware.RefreshCache // optional cache to keep refreshed tokens
AdminPasswd string // if presented, allows basic auth with user admin and given password
BasicAuthChecker middleware.BasicAuthFunc // user custom checker for basic auth, if one defined then "AdminPasswd" will ignored
AudienceReader token.Audience // list of allowed aud values, default (empty) allows any
AudSecrets bool // allow multiple secrets (secret per aud)
Logger logger.L // logger interface, default is no logging at all
RefreshCache middleware.RefreshCache // optional cache to keep refreshed tokens
AuthErrorHTTPHandler middleware.AuthErrorHTTPHandler // optional HTTP handler for authentication errors
}

// NewService initializes everything
Expand All @@ -81,10 +82,11 @@ func NewService(opts Opts) (res *Service) {
opts: opts,
logger: opts.Logger,
authMiddleware: middleware.Authenticator{
Validator: opts.Validator,
AdminPasswd: opts.AdminPasswd,
BasicAuthChecker: opts.BasicAuthChecker,
RefreshCache: opts.RefreshCache,
Validator: opts.Validator,
AdminPasswd: opts.AdminPasswd,
BasicAuthChecker: opts.BasicAuthChecker,
RefreshCache: opts.RefreshCache,
AuthErrorHTTPHandler: opts.AuthErrorHTTPHandler,
},
issuer: opts.Issuer,
useGravatar: opts.UseGravatar,
Expand Down
201 changes: 201 additions & 0 deletions auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package auth
import (
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
Expand Down Expand Up @@ -246,6 +247,206 @@ func TestIntegrationList(t *testing.T) {
assert.Equal(t, `["dev","github","custom123"]`+"\n", string(b))
}

type testAuthErrorHTTPHandler struct {
wasCalled bool
statusCode int
contentType string
responseBody string
}

func (h *testAuthErrorHTTPHandler) ServeAuthError(
w http.ResponseWriter,
_ *http.Request,
authError error,
reason string,
statusCode int,
) {
h.wasCalled = true
w.Header().Set("Content-Type", h.contentType)
w.WriteHeader(h.statusCode)
fmt.Fprint(w, h.responseBody)
}

func TestIntegrationAuthErrorHTTPHandler(t *testing.T) {
paskal marked this conversation as resolved.
Show resolved Hide resolved
testErrorHandler1 := &testAuthErrorHTTPHandler{
statusCode: 401,
contentType: "application/json",
responseBody: `{"code": 401, "message": "from general error handler"}`,
}
testErrorHandler2 := &testAuthErrorHTTPHandler{
statusCode: 403,
contentType: "text/html",
responseBody: `<html><body><h1>from private2 error handler</h1></body></html>`,
}
testErrorHandler3 := &testAuthErrorHTTPHandler{
statusCode: 403,
contentType: "application/json",
responseBody: `{"code": 401, "message": "from admin error handler"}`,
}
testErrorHandler4 := &testAuthErrorHTTPHandler{
statusCode: 403,
contentType: "text/html",
responseBody: `<html><body><h1>from RBAC error handler</h1></body></html>`,
}

options := Opts{
SecretReader: token.SecretFunc(func(string) (string, error) { return "secret", nil }),
Issuer: "my-test-app",
URL: "http://127.0.0.1:8089",
}

svc := NewService(options)
svc.AddDevProvider("localhost", 18084) // add dev provider on 18084
svc.authMiddleware.AuthErrorHTTPHandler = testErrorHandler1

// setup http server
m := svc.Middleware()
mux := http.NewServeMux()
mux.Handle("/private1",
m.Auth(
http.HandlerFunc(
func(w http.ResponseWriter, _ *http.Request) { // token required
_, _ = w.Write([]byte("protected route1\n"))
},
),
),
)
mux.Handle("/private2",
m.AuthWithErrorHTTPHandler(
http.HandlerFunc(
func(w http.ResponseWriter, _ *http.Request) { // token required
_, _ = w.Write([]byte("protected route2\n"))
},
),
testErrorHandler2,
),
)
mux.Handle("/admin1",
m.AdminOnly(
http.HandlerFunc(
func(w http.ResponseWriter, _ *http.Request) { // token required
_, _ = w.Write([]byte("admin route1\n"))
},
),
),
)
mux.Handle("/admin2",
m.AdminOnlyWithErrorHTTPHandler(
http.HandlerFunc(
func(w http.ResponseWriter, _ *http.Request) { // token required
_, _ = w.Write([]byte("admin route2\n"))
},
),
testErrorHandler3,
),
)
mux.Handle("/rbac1",
m.RBAC("role1", "role2")(
http.HandlerFunc(
func(w http.ResponseWriter, _ *http.Request) { // token required
_, _ = w.Write([]byte("rbac route1\n"))
},
),
),
)
mux.Handle("/rbac2",
m.RBACwithErrorHTTPHandler(testErrorHandler4, "role1", "role2")(
http.HandlerFunc(
func(w http.ResponseWriter, _ *http.Request) { // token required
_, _ = w.Write([]byte("rbac route2\n"))
},
),
),
)

l, listenErr := net.Listen("tcp", "127.0.0.1:8089")
require.Nil(t, listenErr)
ts := httptest.NewUnstartedServer(mux)
assert.NoError(t, ts.Listener.Close())
ts.Listener = l
ts.Start()
defer func() {
ts.Close()
}()

assertBodyEquals := func(t *testing.T, r *http.Response, expectedBody string) {
b, err := io.ReadAll(r.Body)
require.NoError(t, err)
assert.Equal(t, expectedBody, string(b))
}
assertContentTypeEquals := func(t *testing.T, r *http.Response, expectedContentType string) {
assert.Equal(t, expectedContentType, r.Header.Get("Content-Type"))
}

// private1
resp, err := http.Get("http://127.0.0.1:8089/private1")
require.NoError(t, err)
defer resp.Body.Close()

require.True(t, testErrorHandler1.wasCalled)

assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
assertContentTypeEquals(t, resp, "application/json")
assertBodyEquals(t, resp, `{"code": 401, "message": "from general error handler"}`)

// private2
resp, err = http.Get("http://127.0.0.1:8089/private2")
require.NoError(t, err)
defer resp.Body.Close()

require.True(t, testErrorHandler2.wasCalled)

assert.Equal(t, http.StatusForbidden, resp.StatusCode)
assertContentTypeEquals(t, resp, "text/html")
assertBodyEquals(t, resp, `<html><body><h1>from private2 error handler</h1></body></html>`)

// admin1
testErrorHandler1.wasCalled = false
resp, err = http.Get("http://127.0.0.1:8089/admin1")
require.NoError(t, err)
defer resp.Body.Close()

require.True(t, testErrorHandler1.wasCalled)

assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
assertContentTypeEquals(t, resp, "application/json")
assertBodyEquals(t, resp, `{"code": 401, "message": "from general error handler"}`)

// admin2
resp, err = http.Get("http://127.0.0.1:8089/admin2")
require.NoError(t, err)
defer resp.Body.Close()

require.True(t, testErrorHandler3.wasCalled)

assert.Equal(t, http.StatusForbidden, resp.StatusCode)
assertContentTypeEquals(t, resp, "application/json")
assertBodyEquals(t, resp, `{"code": 401, "message": "from admin error handler"}`)

// rbac1
testErrorHandler1.wasCalled = false
resp, err = http.Get("http://127.0.0.1:8089/rbac1")
require.NoError(t, err)
defer resp.Body.Close()

require.True(t, testErrorHandler1.wasCalled)

assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
assertContentTypeEquals(t, resp, "application/json")
assertBodyEquals(t, resp, `{"code": 401, "message": "from general error handler"}`)

// rbac2
resp, err = http.Get("http://127.0.0.1:8089/rbac2")
require.NoError(t, err)
defer resp.Body.Close()

require.True(t, testErrorHandler4.wasCalled)

assert.Equal(t, http.StatusForbidden, resp.StatusCode)
assertContentTypeEquals(t, resp, "text/html")
assertBodyEquals(t, resp, `<html><body><h1>from RBAC error handler</h1></body></html>`)
}

func TestIntegrationUserInfo(t *testing.T) {
_, teardown := prepService(t)
defer teardown()
Expand Down
Loading
Loading