From 4bad8c435129512699d47a136548a43cfc502832 Mon Sep 17 00:00:00 2001 From: Frederic BIDON Date: Fri, 8 Dec 2023 16:32:47 +0100 Subject: [PATCH] test: refactored tests for authenticators * addressed code duplication in tests * used sub tests to improve readability * covered a few additional edged cases (passing unexpected types, empty realm) Signed-off-by: Frederic BIDON --- security/apikey_auth_test.go | 365 +++++++++++---------- security/authenticator.go | 13 +- security/authorizer_test.go | 55 ++++ security/basic_auth_test.go | 296 +++++++++-------- security/bearer_auth_test.go | 597 ++++++++++++++++------------------- 5 files changed, 694 insertions(+), 632 deletions(-) diff --git a/security/apikey_auth_test.go b/security/apikey_auth_test.go index 8a0c5a7..02b83cd 100644 --- a/security/apikey_auth_test.go +++ b/security/apikey_auth_test.go @@ -16,6 +16,7 @@ package security import ( "context" + "fmt" "net/http" "testing" @@ -25,181 +26,199 @@ import ( ) const ( - apiToken = "token123" - apiTokenPrincipal = "admin" + apiKeyParam = "api_key" + apiKeyHeader = "X-API-KEY" ) -var tokenAuth = TokenAuthentication(func(token string) (interface{}, error) { - if token == apiToken { - return apiTokenPrincipal, nil - } - return nil, errors.Unauthenticated("token") -}) - -var tokenAuthCtx = TokenAuthenticationCtx(func(ctx context.Context, token string) (context.Context, interface{}, error) { - if token == apiToken { - return context.WithValue(ctx, extra, extraWisdom), apiTokenPrincipal, nil - } - return context.WithValue(ctx, reason, expReason), nil, errors.Unauthenticated("token") -}) - -func TestInvalidApiKeyAuthInitialization(t *testing.T) { - assert.Panics(t, func() { APIKeyAuth("api_key", "qery", tokenAuth) }) +func TestApiKeyAuth(t *testing.T) { + tokenAuth := TokenAuthentication(func(token string) (interface{}, error) { + if token == validToken { + return principal, nil + } + return nil, errors.Unauthenticated("token") + }) + + t.Run("with invalid initialization", func(t *testing.T) { + assert.Panics(t, func() { APIKeyAuth(apiKeyParam, "qery", tokenAuth) }) + }) + + t.Run("with token in query param", func(t *testing.T) { + ta := APIKeyAuth(apiKeyParam, query, tokenAuth) + + t.Run("with valid token", func(t *testing.T) { + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, fmt.Sprintf("%s?%s=%s", authPath, apiKeyParam, validToken), nil) + require.NoError(t, err) + + ok, usr, err := ta.Authenticate(req) + assert.True(t, ok) + assert.Equal(t, principal, usr) + require.NoError(t, err) + }) + + t.Run("with invalid token", func(t *testing.T) { + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, fmt.Sprintf("%s?%s=%s", authPath, apiKeyParam, invalidToken), nil) + require.NoError(t, err) + + ok, usr, err := ta.Authenticate(req) + assert.True(t, ok) + assert.Equal(t, nil, usr) + require.Error(t, err) + }) + + t.Run("with missing token", func(t *testing.T) { + // put the token in the header, but query param is expected + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, authPath, nil) + require.NoError(t, err) + req.Header.Set(apiKeyHeader, validToken) + + ok, usr, err := ta.Authenticate(req) + assert.False(t, ok) + assert.Equal(t, nil, usr) + require.NoError(t, err) + }) + }) + + t.Run("with token in header", func(t *testing.T) { + ta := APIKeyAuth(apiKeyHeader, header, tokenAuth) + + t.Run("with valid token", func(t *testing.T) { + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, authPath, nil) + require.NoError(t, err) + req.Header.Set(apiKeyHeader, validToken) + + ok, usr, err := ta.Authenticate(req) + assert.True(t, ok) + assert.Equal(t, principal, usr) + require.NoError(t, err) + }) + + t.Run("with invalid token", func(t *testing.T) { + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, authPath, nil) + require.NoError(t, err) + req.Header.Set(apiKeyHeader, invalidToken) + + ok, usr, err := ta.Authenticate(req) + assert.True(t, ok) + assert.Equal(t, nil, usr) + require.Error(t, err) + }) + + t.Run("with missing token", func(t *testing.T) { + // put the token in the query param, but header is expected + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, fmt.Sprintf("%s?%s=%s", authPath, apiKeyParam, validToken), nil) + require.NoError(t, err) + + ok, usr, err := ta.Authenticate(req) + assert.False(t, ok) + assert.Equal(t, nil, usr) + require.NoError(t, err) + }) + }) } -func TestValidApiKeyAuth(t *testing.T) { - ta := APIKeyAuth("api_key", "query", tokenAuth) - ta2 := APIKeyAuth("X-API-KEY", "header", tokenAuth) - - req1, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/blah?api_key=token123", nil) - require.NoError(t, err) - - ok, usr, err := ta.Authenticate(req1) - assert.True(t, ok) - assert.Equal(t, apiTokenPrincipal, usr) - require.NoError(t, err) - - req2, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/blah", nil) - require.NoError(t, err) - req2.Header.Set("X-API-KEY", apiToken) - - ok, usr, err = ta2.Authenticate(req2) - assert.True(t, ok) - assert.Equal(t, apiTokenPrincipal, usr) - require.NoError(t, err) -} - -func TestInvalidApiKeyAuth(t *testing.T) { - ta := APIKeyAuth("api_key", "query", tokenAuth) - ta2 := APIKeyAuth("X-API-KEY", "header", tokenAuth) - - req1, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/blah?api_key=token124", nil) - require.NoError(t, err) - - ok, usr, err := ta.Authenticate(req1) - assert.True(t, ok) - assert.Equal(t, nil, usr) - require.Error(t, err) - - req2, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/blah", nil) - require.NoError(t, err) - req2.Header.Set("X-API-KEY", "token124") - - ok, usr, err = ta2.Authenticate(req2) - assert.True(t, ok) - assert.Equal(t, nil, usr) - require.Error(t, err) -} - -func TestMissingApiKeyAuth(t *testing.T) { - ta := APIKeyAuth("api_key", "query", tokenAuth) - ta2 := APIKeyAuth("X-API-KEY", "header", tokenAuth) - - req1, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/blah", nil) - require.NoError(t, err) - req1.Header.Set("X-API-KEY", apiToken) - - ok, usr, err := ta.Authenticate(req1) - assert.False(t, ok) - assert.Equal(t, nil, usr) - require.NoError(t, err) - - req2, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/blah?api_key=token123", nil) - require.NoError(t, err) - - ok, usr, err = ta2.Authenticate(req2) - assert.False(t, ok) - assert.Equal(t, nil, usr) - require.NoError(t, err) -} - -func TestInvalidApiKeyAuthInitializationCtx(t *testing.T) { - assert.Panics(t, func() { APIKeyAuthCtx("api_key", "qery", tokenAuthCtx) }) -} - -func TestValidApiKeyAuthCtx(t *testing.T) { - ta := APIKeyAuthCtx("api_key", "query", tokenAuthCtx) - ta2 := APIKeyAuthCtx("X-API-KEY", "header", tokenAuthCtx) - - req1, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/blah?api_key=token123", nil) - require.NoError(t, err) - req1 = req1.WithContext(context.WithValue(req1.Context(), original, wisdom)) - ok, usr, err := ta.Authenticate(req1) - assert.True(t, ok) - assert.Equal(t, apiTokenPrincipal, usr) - require.NoError(t, err) - assert.Equal(t, wisdom, req1.Context().Value(original)) - assert.Equal(t, extraWisdom, req1.Context().Value(extra)) - assert.Nil(t, req1.Context().Value(reason)) - - req2, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/blah", nil) - require.NoError(t, err) - req2 = req2.WithContext(context.WithValue(req2.Context(), original, wisdom)) - req2.Header.Set("X-API-KEY", apiToken) - - ok, usr, err = ta2.Authenticate(req2) - assert.True(t, ok) - assert.Equal(t, apiTokenPrincipal, usr) - require.NoError(t, err) - assert.Equal(t, wisdom, req2.Context().Value(original)) - assert.Equal(t, extraWisdom, req2.Context().Value(extra)) - assert.Nil(t, req2.Context().Value(reason)) -} - -func TestInvalidApiKeyAuthCtx(t *testing.T) { - ta := APIKeyAuthCtx("api_key", "query", tokenAuthCtx) - ta2 := APIKeyAuthCtx("X-API-KEY", "header", tokenAuthCtx) - - req1, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/blah?api_key=token124", nil) - require.NoError(t, err) - req1 = req1.WithContext(context.WithValue(req1.Context(), original, wisdom)) - ok, usr, err := ta.Authenticate(req1) - assert.True(t, ok) - assert.Equal(t, nil, usr) - require.Error(t, err) - assert.Equal(t, wisdom, req1.Context().Value(original)) - assert.Equal(t, expReason, req1.Context().Value(reason)) - assert.Nil(t, req1.Context().Value(extra)) - - req2, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/blah", nil) - require.NoError(t, err) - req2 = req2.WithContext(context.WithValue(req2.Context(), original, wisdom)) - req2.Header.Set("X-API-KEY", "token124") - - ok, usr, err = ta2.Authenticate(req2) - assert.True(t, ok) - assert.Equal(t, nil, usr) - require.Error(t, err) - assert.Equal(t, wisdom, req2.Context().Value(original)) - assert.Equal(t, expReason, req2.Context().Value(reason)) - assert.Nil(t, req2.Context().Value(extra)) -} - -func TestMissingApiKeyAuthCtx(t *testing.T) { - ta := APIKeyAuthCtx("api_key", "query", tokenAuthCtx) - ta2 := APIKeyAuthCtx("X-API-KEY", "header", tokenAuthCtx) - - req1, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/blah", nil) - require.NoError(t, err) - req1 = req1.WithContext(context.WithValue(req1.Context(), original, wisdom)) - req1.Header.Set("X-API-KEY", apiToken) - - ok, usr, err := ta.Authenticate(req1) - assert.False(t, ok) - assert.Equal(t, nil, usr) - require.NoError(t, err) - assert.Equal(t, wisdom, req1.Context().Value(original)) - assert.Nil(t, req1.Context().Value(reason)) - assert.Nil(t, req1.Context().Value(extra)) - - req2, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/blah?api_key=token123", nil) - require.NoError(t, err) - req2 = req2.WithContext(context.WithValue(req2.Context(), original, wisdom)) - ok, usr, err = ta2.Authenticate(req2) - assert.False(t, ok) - assert.Equal(t, nil, usr) - require.NoError(t, err) - assert.Equal(t, wisdom, req2.Context().Value(original)) - assert.Nil(t, req2.Context().Value(reason)) - assert.Nil(t, req2.Context().Value(extra)) +func TestApiKeyAuthCtx(t *testing.T) { + tokenAuthCtx := TokenAuthenticationCtx(func(ctx context.Context, token string) (context.Context, interface{}, error) { + if token == validToken { + return context.WithValue(ctx, extra, extraWisdom), principal, nil + } + return context.WithValue(ctx, reason, expReason), nil, errors.Unauthenticated("token") + }) + ctx := context.WithValue(context.Background(), original, wisdom) + + t.Run("with invalid initialization", func(t *testing.T) { + assert.Panics(t, func() { APIKeyAuthCtx(apiKeyParam, "qery", tokenAuthCtx) }) + }) + + t.Run("with token in query param", func(t *testing.T) { + ta := APIKeyAuthCtx(apiKeyParam, query, tokenAuthCtx) + + t.Run("with valid token", func(t *testing.T) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("%s?%s=%s", authPath, apiKeyParam, validToken), nil) + require.NoError(t, err) + ok, usr, err := ta.Authenticate(req) + assert.True(t, ok) + assert.Equal(t, principal, usr) + require.NoError(t, err) + + assert.Equal(t, wisdom, req.Context().Value(original)) + assert.Equal(t, extraWisdom, req.Context().Value(extra)) + assert.Nil(t, req.Context().Value(reason)) + }) + + t.Run("with invalid token", func(t *testing.T) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("%s?%s=%s", authPath, apiKeyParam, invalidToken), nil) + require.NoError(t, err) + ok, usr, err := ta.Authenticate(req) + assert.True(t, ok) + assert.Equal(t, nil, usr) + require.Error(t, err) + + assert.Equal(t, wisdom, req.Context().Value(original)) + assert.Equal(t, expReason, req.Context().Value(reason)) + assert.Nil(t, req.Context().Value(extra)) + }) + + t.Run("with missing token", func(t *testing.T) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, authPath, nil) + require.NoError(t, err) + req.Header.Set(apiKeyHeader, validToken) + + ok, usr, err := ta.Authenticate(req) + assert.False(t, ok) + assert.Equal(t, nil, usr) + require.NoError(t, err) + + assert.Equal(t, wisdom, req.Context().Value(original)) + assert.Nil(t, req.Context().Value(reason)) + assert.Nil(t, req.Context().Value(extra)) + }) + }) + + t.Run("with token in header", func(t *testing.T) { + ta := APIKeyAuthCtx(apiKeyHeader, header, tokenAuthCtx) + + t.Run("with valid token", func(t *testing.T) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, authPath, nil) + require.NoError(t, err) + req.Header.Set(apiKeyHeader, validToken) + + ok, usr, err := ta.Authenticate(req) + assert.True(t, ok) + assert.Equal(t, principal, usr) + require.NoError(t, err) + + assert.Equal(t, wisdom, req.Context().Value(original)) + assert.Equal(t, extraWisdom, req.Context().Value(extra)) + assert.Nil(t, req.Context().Value(reason)) + }) + + t.Run("with invalid token", func(t *testing.T) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, authPath, nil) + require.NoError(t, err) + req.Header.Set(apiKeyHeader, invalidToken) + + ok, usr, err := ta.Authenticate(req) + assert.True(t, ok) + assert.Equal(t, nil, usr) + require.Error(t, err) + + assert.Equal(t, wisdom, req.Context().Value(original)) + assert.Equal(t, expReason, req.Context().Value(reason)) + assert.Nil(t, req.Context().Value(extra)) + }) + + t.Run("with missing token", func(t *testing.T) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("%s?%s=%s", authPath, apiKeyParam, validToken), nil) + require.NoError(t, err) + + ok, usr, err := ta.Authenticate(req) + assert.False(t, ok) + assert.Equal(t, nil, usr) + require.NoError(t, err) + + assert.Equal(t, wisdom, req.Context().Value(original)) + assert.Nil(t, req.Context().Value(reason)) + assert.Nil(t, req.Context().Value(extra)) + }) + }) } diff --git a/security/authenticator.go b/security/authenticator.go index a57d5e8..bb30472 100644 --- a/security/authenticator.go +++ b/security/authenticator.go @@ -25,8 +25,9 @@ import ( ) const ( - query = "query" - header = "header" + query = "query" + header = "header" + accessTokenParam = "access_token" ) // HttpAuthenticator is a function that authenticates a HTTP request @@ -226,12 +227,12 @@ func BearerAuth(name string, authenticate ScopedTokenAuthentication) runtime.Aut } if token == "" { qs := r.Request.URL.Query() - token = qs.Get("access_token") + token = qs.Get(accessTokenParam) } //#nosec ct, _, _ := runtime.ContentType(r.Request.Header) if token == "" && (ct == "application/x-www-form-urlencoded" || ct == "multipart/form-data") { - token = r.Request.FormValue("access_token") + token = r.Request.FormValue(accessTokenParam) } if token == "" { @@ -256,12 +257,12 @@ func BearerAuthCtx(name string, authenticate ScopedTokenAuthenticationCtx) runti } if token == "" { qs := r.Request.URL.Query() - token = qs.Get("access_token") + token = qs.Get(accessTokenParam) } //#nosec ct, _, _ := runtime.ContentType(r.Request.Header) if token == "" && (ct == "application/x-www-form-urlencoded" || ct == "multipart/form-data") { - token = r.Request.FormValue("access_token") + token = r.Request.FormValue(accessTokenParam) } if token == "" { diff --git a/security/authorizer_test.go b/security/authorizer_test.go index 92eb1e3..5463748 100644 --- a/security/authorizer_test.go +++ b/security/authorizer_test.go @@ -15,8 +15,11 @@ package security import ( + "context" + "net/http" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -26,3 +29,55 @@ func TestAuthorized(t *testing.T) { err := authorizer.Authorize(nil, nil) require.NoError(t, err) } + +func TestAuthenticator(t *testing.T) { + r, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/", nil) + require.NoError(t, err) + + t.Run("with HttpAuthenticator", func(t *testing.T) { + auth := HttpAuthenticator(func(_ *http.Request) (bool, interface{}, error) { return true, "test", nil }) + + t.Run("authenticator should work on *http.Request", func(t *testing.T) { + isAuth, user, err := auth.Authenticate(r) + require.NoError(t, err) + assert.True(t, isAuth) + assert.Equal(t, "test", user) + }) + + t.Run("authenticator should work on *ScopedAuthRequest", func(t *testing.T) { + scoped := &ScopedAuthRequest{Request: r} + + isAuth, user, err := auth.Authenticate(scoped) + require.NoError(t, err) + assert.True(t, isAuth) + assert.Equal(t, "test", user) + }) + + t.Run("authenticator should return false on other inputs", func(t *testing.T) { + isAuth, user, err := auth.Authenticate("") + require.NoError(t, err) + assert.False(t, isAuth) + assert.Empty(t, user) + }) + }) + + t.Run("with ScopedAuthenticator", func(t *testing.T) { + auth := ScopedAuthenticator(func(_ *ScopedAuthRequest) (bool, interface{}, error) { return true, "test", nil }) + + t.Run("authenticator should work on *ScopedAuthRequest", func(t *testing.T) { + scoped := &ScopedAuthRequest{Request: r} + + isAuth, user, err := auth.Authenticate(scoped) + require.NoError(t, err) + assert.True(t, isAuth) + assert.Equal(t, "test", user) + }) + + t.Run("authenticator should return false on other inputs", func(t *testing.T) { + isAuth, user, err := auth.Authenticate("") + require.NoError(t, err) + assert.False(t, isAuth) + assert.Empty(t, user) + }) + }) +} diff --git a/security/basic_auth_test.go b/security/basic_auth_test.go index 1e5f12f..8a90451 100644 --- a/security/basic_auth_test.go +++ b/security/basic_auth_test.go @@ -33,137 +33,181 @@ const ( ) const ( - wisdom = "The man who is swimming against the stream knows the strength of it." - extraWisdom = "Our greatest glory is not in never falling, but in rising every time we fall." - expReason = "I like the dreams of the future better than the history of the past." - authenticatedPath = "/blah" - testPassword = "123456" - basicPrincipal = "admin" + wisdom = "The man who is swimming against the stream knows the strength of it." + extraWisdom = "Our greatest glory is not in never falling, but in rising every time we fall." + expReason = "I like the dreams of the future better than the history of the past." + testPassword = "123456" ) -var basicAuthHandler = UserPassAuthentication(func(user, pass string) (interface{}, error) { - if user == basicPrincipal && pass == testPassword { - return basicPrincipal, nil - } - return "", errors.Unauthenticated("basic") -}) - -func TestValidBasicAuth(t *testing.T) { - ba := BasicAuth(basicAuthHandler) - - req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, authenticatedPath, nil) - require.NoError(t, err) - - req.SetBasicAuth(basicPrincipal, testPassword) - ok, usr, err := ba.Authenticate(req) - require.NoError(t, err) - assert.True(t, ok) - assert.Equal(t, basicPrincipal, usr) -} - -func TestInvalidBasicAuth(t *testing.T) { - ba := BasicAuth(basicAuthHandler) - - req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, authenticatedPath, nil) - require.NoError(t, err) - req.SetBasicAuth(basicPrincipal, basicPrincipal) - - ok, usr, err := ba.Authenticate(req) - require.Error(t, err) - assert.True(t, ok) - assert.Equal(t, "", usr) - - assert.NotEmpty(t, FailedBasicAuth(req)) - assert.Equal(t, DefaultRealmName, FailedBasicAuth(req)) -} - -func TestMissingbasicAuth(t *testing.T) { - ba := BasicAuth(basicAuthHandler) - - req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, authenticatedPath, nil) - require.NoError(t, err) - - ok, usr, err := ba.Authenticate(req) - require.NoError(t, err) - assert.False(t, ok) - assert.Equal(t, nil, usr) - - assert.NotEmpty(t, FailedBasicAuth(req)) - assert.Equal(t, DefaultRealmName, FailedBasicAuth(req)) -} - -func TestNoRequestBasicAuth(t *testing.T) { +func TestBasicAuth(t *testing.T) { + basicAuthHandler := UserPassAuthentication(func(user, pass string) (interface{}, error) { + if user == principal && pass == testPassword { + return principal, nil + } + return "", errors.Unauthenticated("basic") + }) ba := BasicAuth(basicAuthHandler) - ok, usr, err := ba.Authenticate("token") - require.NoError(t, err) - assert.False(t, ok) - assert.Nil(t, usr) + t.Run("with valid basic auth", func(t *testing.T) { + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, authPath, nil) + require.NoError(t, err) + req.SetBasicAuth(principal, testPassword) + + ok, usr, err := ba.Authenticate(req) + require.NoError(t, err) + assert.True(t, ok) + assert.Equal(t, principal, usr) + }) + + t.Run("with invalid basic auth", func(t *testing.T) { + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, authPath, nil) + require.NoError(t, err) + req.SetBasicAuth(principal, principal) + + ok, usr, err := ba.Authenticate(req) + require.Error(t, err) + assert.True(t, ok) + assert.Equal(t, "", usr) + + assert.NotEmpty(t, FailedBasicAuth(req)) + assert.Equal(t, DefaultRealmName, FailedBasicAuth(req)) + }) + + t.Run("with missing basic auth", func(t *testing.T) { + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, authPath, nil) + require.NoError(t, err) + + ok, usr, err := ba.Authenticate(req) + require.NoError(t, err) + assert.False(t, ok) + assert.Equal(t, nil, usr) + + assert.NotEmpty(t, FailedBasicAuth(req)) + assert.Equal(t, DefaultRealmName, FailedBasicAuth(req)) + }) + + t.Run("basic auth without request", func(*testing.T) { + ok, usr, err := ba.Authenticate("token") + require.NoError(t, err) + assert.False(t, ok) + assert.Nil(t, usr) + }) + + t.Run("with realm, invalid basic auth", func(t *testing.T) { + br := BasicAuthRealm("realm", basicAuthHandler) + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, authPath, nil) + require.NoError(t, err) + req.SetBasicAuth(principal, principal) + + ok, usr, err := br.Authenticate(req) + require.Error(t, err) + assert.True(t, ok) + assert.Equal(t, "", usr) + assert.Equal(t, "realm", FailedBasicAuth(req)) + }) + + t.Run("with empty realm, invalid basic auth", func(t *testing.T) { + br := BasicAuthRealm("", basicAuthHandler) + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, authPath, nil) + require.NoError(t, err) + req.SetBasicAuth(principal, principal) + + ok, usr, err := br.Authenticate(req) + require.Error(t, err) + assert.True(t, ok) + assert.Equal(t, "", usr) + assert.Equal(t, DefaultRealmName, FailedBasicAuth(req)) + }) } -var basicAuthHandlerCtx = UserPassAuthenticationCtx(func(ctx context.Context, user, pass string) (context.Context, interface{}, error) { - if user == basicPrincipal && pass == testPassword { - return context.WithValue(ctx, extra, extraWisdom), basicPrincipal, nil - } - return context.WithValue(ctx, reason, expReason), "", errors.Unauthenticated("basic") -}) - -func TestValidBasicAuthCtx(t *testing.T) { +func TestBasicAuthCtx(t *testing.T) { + basicAuthHandlerCtx := UserPassAuthenticationCtx(func(ctx context.Context, user, pass string) (context.Context, interface{}, error) { + if user == principal && pass == testPassword { + return context.WithValue(ctx, extra, extraWisdom), principal, nil + } + return context.WithValue(ctx, reason, expReason), "", errors.Unauthenticated("basic") + }) ba := BasicAuthCtx(basicAuthHandlerCtx) - - req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, authenticatedPath, nil) - require.NoError(t, err) - req = req.WithContext(context.WithValue(req.Context(), original, wisdom)) - - req.SetBasicAuth(basicPrincipal, testPassword) - ok, usr, err := ba.Authenticate(req) - require.NoError(t, err) - assert.True(t, ok) - assert.Equal(t, basicPrincipal, usr) - assert.Equal(t, wisdom, req.Context().Value(original)) - assert.Equal(t, extraWisdom, req.Context().Value(extra)) - assert.Nil(t, req.Context().Value(reason)) -} - -func TestInvalidBasicAuthCtx(t *testing.T) { - ba := BasicAuthCtx(basicAuthHandlerCtx) - - req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, authenticatedPath, nil) - require.NoError(t, err) - req = req.WithContext(context.WithValue(req.Context(), original, wisdom)) - req.SetBasicAuth(basicPrincipal, basicPrincipal) - - ok, usr, err := ba.Authenticate(req) - require.Error(t, err) - assert.True(t, ok) - assert.Equal(t, "", usr) - assert.Equal(t, wisdom, req.Context().Value(original)) - assert.Nil(t, req.Context().Value(extra)) - assert.Equal(t, expReason, req.Context().Value(reason)) -} - -func TestMissingbasicAuthCtx(t *testing.T) { - ba := BasicAuthCtx(basicAuthHandlerCtx) - - req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, authenticatedPath, nil) - require.NoError(t, err) - req = req.WithContext(context.WithValue(req.Context(), original, wisdom)) - - ok, usr, err := ba.Authenticate(req) - require.NoError(t, err) - assert.False(t, ok) - assert.Equal(t, nil, usr) - - assert.Equal(t, wisdom, req.Context().Value(original)) - assert.Nil(t, req.Context().Value(extra)) - assert.Nil(t, req.Context().Value(reason)) -} - -func TestNoRequestBasicAuthCtx(t *testing.T) { - ba := BasicAuthCtx(basicAuthHandlerCtx) - - ok, usr, err := ba.Authenticate("token") - require.NoError(t, err) - assert.False(t, ok) - assert.Nil(t, usr) + ctx := context.WithValue(context.Background(), original, wisdom) + + t.Run("with valid basic auth", func(t *testing.T) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, authPath, nil) + require.NoError(t, err) + + req.SetBasicAuth(principal, testPassword) + ok, usr, err := ba.Authenticate(req) + require.NoError(t, err) + assert.True(t, ok) + assert.Equal(t, principal, usr) + + assert.Equal(t, wisdom, req.Context().Value(original)) + assert.Equal(t, extraWisdom, req.Context().Value(extra)) + assert.Nil(t, req.Context().Value(reason)) + }) + + t.Run("with invalid basic auth", func(t *testing.T) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, authPath, nil) + require.NoError(t, err) + req.SetBasicAuth(principal, principal) + + ok, usr, err := ba.Authenticate(req) + require.Error(t, err) + assert.True(t, ok) + assert.Equal(t, "", usr) + + assert.Equal(t, wisdom, req.Context().Value(original)) + assert.Nil(t, req.Context().Value(extra)) + assert.Equal(t, expReason, req.Context().Value(reason)) + }) + + t.Run("with missing basic auth", func(t *testing.T) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, authPath, nil) + require.NoError(t, err) + + ok, usr, err := ba.Authenticate(req) + require.NoError(t, err) + assert.False(t, ok) + assert.Equal(t, nil, usr) + + assert.Equal(t, wisdom, req.Context().Value(original)) + assert.Nil(t, req.Context().Value(extra)) + assert.Nil(t, req.Context().Value(reason)) + }) + + t.Run("basic auth without request", func(*testing.T) { + ok, usr, err := ba.Authenticate("token") + require.NoError(t, err) + assert.False(t, ok) + assert.Nil(t, usr) + }) + + t.Run("with realm, invalid basic auth", func(t *testing.T) { + br := BasicAuthRealmCtx("realm", basicAuthHandlerCtx) + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, authPath, nil) + require.NoError(t, err) + req.SetBasicAuth(principal, principal) + + ok, usr, err := br.Authenticate(req) + require.Error(t, err) + assert.True(t, ok) + assert.Equal(t, "", usr) + assert.Equal(t, "realm", FailedBasicAuth(req)) + }) + + t.Run("with empty realm, invalid basic auth", func(t *testing.T) { + br := BasicAuthRealmCtx("", basicAuthHandlerCtx) + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, authPath, nil) + require.NoError(t, err) + req.SetBasicAuth(principal, principal) + + ok, usr, err := br.Authenticate(req) + require.Error(t, err) + assert.True(t, ok) + assert.Equal(t, "", usr) + assert.Equal(t, DefaultRealmName, FailedBasicAuth(req)) + }) } diff --git a/security/bearer_auth_test.go b/security/bearer_auth_test.go index 476d35b..42ff57b 100644 --- a/security/bearer_auth_test.go +++ b/security/bearer_auth_test.go @@ -3,6 +3,7 @@ package security import ( "bytes" "context" + "fmt" "mime/multipart" "net/http" "net/url" @@ -15,343 +16,285 @@ import ( "github.com/stretchr/testify/require" ) -var bearerAuth = ScopedTokenAuthentication(func(token string, _ []string) (interface{}, error) { - if token == "token123" { - return "admin", nil +const ( + owners = "owners_auth" + validToken = "token123" + invalidToken = "token124" + principal = "admin" + authPath = "/blah" + invalidParam = "access_toke" +) + +type authExpectation uint8 + +const ( + expectIsAuthorized authExpectation = iota + expectInvalidAuthorization + expectNoAuthorization +) + +func TestBearerAuth(t *testing.T) { + bearerAuth := ScopedTokenAuthentication(func(token string, _ []string) (interface{}, error) { + if token == validToken { + return principal, nil + } + return nil, errors.Unauthenticated("bearer") + }) + ba := BearerAuth(owners, bearerAuth) + ctx := context.Background() + + t.Run("with valid bearer auth", func(t *testing.T) { + t.Run("token in query param", + testAuthenticateBearerInQuery(ctx, ba, "", validToken, expectIsAuthorized), + ) + t.Run("token in header", + testAuthenticateBearerInHeader(ctx, ba, "", validToken, expectIsAuthorized), + ) + t.Run("token in urlencoded form", + testAuthenticateBearerInForm(ctx, ba, "", validToken, expectIsAuthorized), + ) + t.Run("token in multipart form", + testAuthenticateBearerInMultipartForm(ctx, ba, "", validToken, expectIsAuthorized), + ) + }) + + t.Run("with invalid token", func(t *testing.T) { + t.Run("token in query param", + testAuthenticateBearerInQuery(ctx, ba, "", invalidToken, expectInvalidAuthorization), + ) + t.Run("token in header", + testAuthenticateBearerInHeader(ctx, ba, "", invalidToken, expectInvalidAuthorization), + ) + t.Run("token in urlencoded form", + testAuthenticateBearerInForm(ctx, ba, "", invalidToken, expectInvalidAuthorization), + ) + t.Run("token in multipart form", + testAuthenticateBearerInMultipartForm(ctx, ba, "", invalidToken, expectInvalidAuthorization), + ) + }) + + t.Run("with missing auth", func(t *testing.T) { + t.Run("token in query param", + testAuthenticateBearerInQuery(ctx, ba, invalidParam, validToken, expectNoAuthorization), + ) + t.Run("token in header", + testAuthenticateBearerInHeader(ctx, ba, "Beare", validToken, expectNoAuthorization), + ) + t.Run("token in urlencoded form", + testAuthenticateBearerInForm(ctx, ba, invalidParam, validToken, expectNoAuthorization), + ) + t.Run("token in multipart form", + testAuthenticateBearerInMultipartForm(ctx, ba, invalidParam, validToken, expectNoAuthorization), + ) + }) +} + +func TestBearerAuthCtx(t *testing.T) { + bearerAuthCtx := ScopedTokenAuthenticationCtx(func(ctx context.Context, token string, _ []string) (context.Context, interface{}, error) { + if token == validToken { + return context.WithValue(ctx, extra, extraWisdom), principal, nil + } + return context.WithValue(ctx, reason, expReason), nil, errors.Unauthenticated("bearer") + }) + ba := BearerAuthCtx(owners, bearerAuthCtx) + ctx := context.WithValue(context.Background(), original, wisdom) + + assertContextOK := func(requestContext context.Context, t *testing.T) { + // when authorized, we have an "extra" key in context + assert.Equal(t, wisdom, requestContext.Value(original)) + assert.Equal(t, extraWisdom, requestContext.Value(extra)) + assert.Nil(t, requestContext.Value(reason)) + } + + assertContextKO := func(requestContext context.Context, t *testing.T) { + // when not authorized, we have a "reason" key in context + assert.Equal(t, wisdom, requestContext.Value(original)) + assert.Nil(t, requestContext.Value(extra)) + assert.Equal(t, expReason, requestContext.Value(reason)) + } + + assertContextNone := func(requestContext context.Context, t *testing.T) { + // when missing authorization, we only have the original context + assert.Equal(t, wisdom, requestContext.Value(original)) + assert.Nil(t, requestContext.Value(extra)) + assert.Nil(t, requestContext.Value(reason)) + } + + t.Run("with valid bearer auth", func(t *testing.T) { + t.Run("token in query param", + testAuthenticateBearerInQuery(ctx, ba, "", validToken, expectIsAuthorized, assertContextOK), + ) + t.Run("token in header", + testAuthenticateBearerInHeader(ctx, ba, "", validToken, expectIsAuthorized, assertContextOK), + ) + t.Run("token in urlencoded form", + testAuthenticateBearerInForm(ctx, ba, "", validToken, expectIsAuthorized, assertContextOK), + ) + t.Run("token in multipart form", + testAuthenticateBearerInMultipartForm(ctx, ba, "", validToken, expectIsAuthorized, assertContextOK), + ) + }) + + t.Run("with invalid token", func(t *testing.T) { + t.Run("token in query param", + testAuthenticateBearerInQuery(ctx, ba, "", invalidToken, expectInvalidAuthorization, assertContextKO), + ) + t.Run("token in header", + testAuthenticateBearerInHeader(ctx, ba, "", invalidToken, expectInvalidAuthorization, assertContextKO), + ) + t.Run("token in urlencoded form", + testAuthenticateBearerInForm(ctx, ba, "", invalidToken, expectInvalidAuthorization, assertContextKO), + ) + t.Run("token in multipart form", + testAuthenticateBearerInMultipartForm(ctx, ba, "", invalidToken, expectInvalidAuthorization, assertContextKO), + ) + }) + + t.Run("with missing auth", func(t *testing.T) { + t.Run("token in query param", + testAuthenticateBearerInQuery(ctx, ba, invalidParam, validToken, expectNoAuthorization, assertContextNone), + ) + t.Run("token in header", + testAuthenticateBearerInHeader(ctx, ba, "Beare", validToken, expectNoAuthorization, assertContextNone), + ) + t.Run("token in urlencoded form", + testAuthenticateBearerInForm(ctx, ba, invalidParam, validToken, expectNoAuthorization, assertContextNone), + ) + t.Run("token in multipart form", + testAuthenticateBearerInMultipartForm(ctx, ba, invalidParam, validToken, expectNoAuthorization, assertContextNone), + ) + }) +} + +func testIsAuthorized(_ context.Context, req *http.Request, authorizer runtime.Authenticator, expectAuthorized authExpectation, extraAsserters ...func(context.Context, *testing.T)) func(*testing.T) { + return func(t *testing.T) { + hasToken, usr, err := authorizer.Authenticate(&ScopedAuthRequest{Request: req}) + switch expectAuthorized { + + case expectIsAuthorized: + require.NoError(t, err) + assert.True(t, hasToken) + assert.Equal(t, principal, usr) + assert.Equal(t, owners, OAuth2SchemeName(req)) + + case expectInvalidAuthorization: + require.Error(t, err) + require.ErrorContains(t, err, "unauthenticated") + assert.True(t, hasToken) + assert.Nil(t, usr) + assert.Equal(t, owners, OAuth2SchemeName(req)) + + case expectNoAuthorization: + require.NoError(t, err) + assert.False(t, hasToken) + assert.Nil(t, usr) + assert.Empty(t, OAuth2SchemeName(req)) + } + + for _, contextAsserter := range extraAsserters { + contextAsserter(req.Context(), t) + } } - return nil, errors.Unauthenticated("bearer") -}) - -func TestValidBearerAuth(t *testing.T) { - ba := BearerAuth("owners_auth", bearerAuth) - - req1, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/blah?access_token=token123", nil) - require.NoError(t, err) - - ok, usr, err := ba.Authenticate(&ScopedAuthRequest{Request: req1}) - assert.True(t, ok) - assert.Equal(t, "admin", usr) - require.NoError(t, err) - assert.Equal(t, "owners_auth", OAuth2SchemeName(req1)) - - req2, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/blah", nil) - require.NoError(t, err) - req2.Header.Set(runtime.HeaderAuthorization, "Bearer token123") - - ok, usr, err = ba.Authenticate(&ScopedAuthRequest{Request: req2}) - assert.True(t, ok) - assert.Equal(t, "admin", usr) - require.NoError(t, err) - assert.Equal(t, "owners_auth", OAuth2SchemeName(req2)) - - body := url.Values(map[string][]string{}) - body.Set("access_token", "token123") - req3, err := http.NewRequestWithContext(context.Background(), http.MethodPost, "/blah", strings.NewReader(body.Encode())) - require.NoError(t, err) - req3.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - ok, usr, err = ba.Authenticate(&ScopedAuthRequest{Request: req3}) - assert.True(t, ok) - assert.Equal(t, "admin", usr) - require.NoError(t, err) - assert.Equal(t, "owners_auth", OAuth2SchemeName(req3)) - - mpbody := bytes.NewBuffer(nil) - writer := multipart.NewWriter(mpbody) - err = writer.WriteField("access_token", "token123") - require.NoError(t, err) - writer.Close() - req4, err := http.NewRequestWithContext(context.Background(), http.MethodPost, "/blah", mpbody) - require.NoError(t, err) - req4.Header.Set("Content-Type", writer.FormDataContentType()) - - ok, usr, err = ba.Authenticate(&ScopedAuthRequest{Request: req4}) - assert.True(t, ok) - assert.Equal(t, "admin", usr) - require.NoError(t, err) - assert.Equal(t, "owners_auth", OAuth2SchemeName(req4)) } -//nolint:dupl -func TestInvalidBearerAuth(t *testing.T) { - ba := BearerAuth("owners_auth", bearerAuth) - - req1, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/blah?access_token=token124", nil) - require.NoError(t, err) - - ok, usr, err := ba.Authenticate(&ScopedAuthRequest{Request: req1}) - assert.True(t, ok) - assert.Equal(t, nil, usr) - require.Error(t, err) - - req2, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/blah", nil) - require.NoError(t, err) - req2.Header.Set(runtime.HeaderAuthorization, "Bearer token124") - - ok, usr, err = ba.Authenticate(&ScopedAuthRequest{Request: req2}) - assert.True(t, ok) - assert.Equal(t, nil, usr) - require.Error(t, err) - - body := url.Values(map[string][]string{}) - body.Set("access_token", "token124") - req3, err := http.NewRequestWithContext(context.Background(), http.MethodPost, "/blah", strings.NewReader(body.Encode())) - require.NoError(t, err) - req3.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - ok, usr, err = ba.Authenticate(&ScopedAuthRequest{Request: req3}) - assert.True(t, ok) - assert.Equal(t, nil, usr) - require.Error(t, err) - - mpbody := bytes.NewBuffer(nil) - writer := multipart.NewWriter(mpbody) - require.NoError(t, writer.WriteField("access_token", "token124")) - writer.Close() - req4, err := http.NewRequestWithContext(context.Background(), http.MethodPost, "/blah", mpbody) - require.NoError(t, err) - req4.Header.Set("Content-Type", writer.FormDataContentType()) - - ok, usr, err = ba.Authenticate(&ScopedAuthRequest{Request: req4}) - assert.True(t, ok) - assert.Equal(t, nil, usr) - require.Error(t, err) +func shouldAuthorizeOrNot(expectAuthorized authExpectation) string { + if expectAuthorized == expectIsAuthorized { + return "should authorize" + } + + return "should not authorize" } -//nolint:dupl -func TestMissingBearerAuth(t *testing.T) { - ba := BearerAuth("owners_auth", bearerAuth) - - req1, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/blah?access_toke=token123", nil) - require.NoError(t, err) - - ok, usr, err := ba.Authenticate(&ScopedAuthRequest{Request: req1}) - assert.False(t, ok) - assert.Equal(t, nil, usr) - require.NoError(t, err) - - req2, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/blah", nil) - require.NoError(t, err) - req2.Header.Set(runtime.HeaderAuthorization, "Beare token123") - - ok, usr, err = ba.Authenticate(&ScopedAuthRequest{Request: req2}) - assert.False(t, ok) - assert.Equal(t, nil, usr) - require.NoError(t, err) - - body := url.Values(map[string][]string{}) - body.Set("access_toke", "token123") - req3, err := http.NewRequestWithContext(context.Background(), http.MethodPost, "/blah", strings.NewReader(body.Encode())) - require.NoError(t, err) - req3.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - ok, usr, err = ba.Authenticate(&ScopedAuthRequest{Request: req3}) - assert.False(t, ok) - assert.Equal(t, nil, usr) - require.NoError(t, err) - - mpbody := bytes.NewBuffer(nil) - writer := multipart.NewWriter(mpbody) - require.NoError(t, writer.WriteField("access_toke", "token123")) - writer.Close() - req4, err := http.NewRequestWithContext(context.Background(), http.MethodPost, "/blah", mpbody) - require.NoError(t, err) - req4.Header.Set("Content-Type", writer.FormDataContentType()) - - ok, usr, err = ba.Authenticate(&ScopedAuthRequest{Request: req4}) - assert.False(t, ok) - assert.Equal(t, nil, usr) - require.NoError(t, err) +func testAuthenticateBearerInQuery( + // build a request with the token as a query parameter, then check against the authorizer + // + // the request context after authorization may be checked with the extraAsserters. + ctx context.Context, authorizer runtime.Authenticator, parameter, token string, expectAuthorized authExpectation, + extraAsserters ...func(context.Context, *testing.T), +) func(*testing.T) { + if parameter == "" { + parameter = accessTokenParam + } + + return func(t *testing.T) { + req, err := http.NewRequestWithContext( + ctx, http.MethodGet, + fmt.Sprintf("%s?%s=%s", authPath, parameter, token), + nil, + ) + require.NoError(t, err) + + t.Run( + shouldAuthorizeOrNot(expectAuthorized), + testIsAuthorized(ctx, req, authorizer, expectAuthorized, extraAsserters...), + ) + } } -var bearerAuthCtx = ScopedTokenAuthenticationCtx(func(ctx context.Context, token string, requiredScopes []string) (context.Context, interface{}, error) { - if token == "token123" { - return context.WithValue(ctx, extra, extraWisdom), "admin", nil +func testAuthenticateBearerInHeader( + // build a request with the token as a header, then check against the authorizer + ctx context.Context, authorizer runtime.Authenticator, parameter, token string, expectAuthorized authExpectation, + extraAsserters ...func(context.Context, *testing.T), +) func(*testing.T) { + if parameter == "" { + parameter = "Bearer" + } + + return func(t *testing.T) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, authPath, nil) + require.NoError(t, err) + req.Header.Set(runtime.HeaderAuthorization, fmt.Sprintf("%s %s", parameter, token)) + + t.Run( + shouldAuthorizeOrNot(expectAuthorized), + testIsAuthorized(ctx, req, authorizer, expectAuthorized, extraAsserters...), + ) } - return context.WithValue(ctx, reason, expReason), nil, errors.Unauthenticated("bearer") -}) - -func TestValidBearerAuthCtx(t *testing.T) { - ba := BearerAuthCtx("owners_auth", bearerAuthCtx) - - req1, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/blah?access_token=token123", nil) - require.NoError(t, err) - req1 = req1.WithContext(context.WithValue(req1.Context(), original, wisdom)) - ok, usr, err := ba.Authenticate(&ScopedAuthRequest{Request: req1}) - assert.True(t, ok) - assert.Equal(t, "admin", usr) - require.NoError(t, err) - assert.Equal(t, wisdom, req1.Context().Value(original)) - assert.Equal(t, extraWisdom, req1.Context().Value(extra)) - assert.Nil(t, req1.Context().Value(reason)) - assert.Equal(t, "owners_auth", OAuth2SchemeName(req1)) - - req2, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/blah", nil) - require.NoError(t, err) - req2 = req2.WithContext(context.WithValue(req2.Context(), original, wisdom)) - req2.Header.Set(runtime.HeaderAuthorization, "Bearer token123") - - ok, usr, err = ba.Authenticate(&ScopedAuthRequest{Request: req2}) - assert.True(t, ok) - assert.Equal(t, "admin", usr) - require.NoError(t, err) - assert.Equal(t, wisdom, req2.Context().Value(original)) - assert.Equal(t, extraWisdom, req2.Context().Value(extra)) - assert.Nil(t, req2.Context().Value(reason)) - assert.Equal(t, "owners_auth", OAuth2SchemeName(req2)) - - body := url.Values(map[string][]string{}) - body.Set("access_token", "token123") - req3, err := http.NewRequestWithContext(context.Background(), http.MethodPost, "/blah", strings.NewReader(body.Encode())) - require.NoError(t, err) - req3 = req3.WithContext(context.WithValue(req3.Context(), original, wisdom)) - req3.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - ok, usr, err = ba.Authenticate(&ScopedAuthRequest{Request: req3}) - assert.True(t, ok) - assert.Equal(t, "admin", usr) - require.NoError(t, err) - assert.Equal(t, wisdom, req3.Context().Value(original)) - assert.Equal(t, extraWisdom, req3.Context().Value(extra)) - assert.Nil(t, req3.Context().Value(reason)) - assert.Equal(t, "owners_auth", OAuth2SchemeName(req3)) - - mpbody := bytes.NewBuffer(nil) - writer := multipart.NewWriter(mpbody) - require.NoError(t, writer.WriteField("access_token", "token123")) - writer.Close() - req4, err := http.NewRequestWithContext(context.Background(), http.MethodPost, "/blah", mpbody) - require.NoError(t, err) - req4 = req4.WithContext(context.WithValue(req4.Context(), original, wisdom)) - req4.Header.Set("Content-Type", writer.FormDataContentType()) - - ok, usr, err = ba.Authenticate(&ScopedAuthRequest{Request: req4}) - assert.True(t, ok) - assert.Equal(t, "admin", usr) - require.NoError(t, err) - assert.Equal(t, wisdom, req4.Context().Value(original)) - assert.Equal(t, extraWisdom, req4.Context().Value(extra)) - assert.Nil(t, req4.Context().Value(reason)) - assert.Equal(t, "owners_auth", OAuth2SchemeName(req4)) } -func TestInvalidBearerAuthCtx(t *testing.T) { - ba := BearerAuthCtx("owners_auth", bearerAuthCtx) - - req1, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/blah?access_token=token124", nil) - require.NoError(t, err) - req1 = req1.WithContext(context.WithValue(req1.Context(), original, wisdom)) - ok, usr, err := ba.Authenticate(&ScopedAuthRequest{Request: req1}) - assert.True(t, ok) - assert.Equal(t, nil, usr) - require.Error(t, err) - assert.Equal(t, wisdom, req1.Context().Value(original)) - assert.Equal(t, expReason, req1.Context().Value(reason)) - assert.Nil(t, req1.Context().Value(extra)) - - req2, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/blah", nil) - require.NoError(t, err) - req2 = req2.WithContext(context.WithValue(req2.Context(), original, wisdom)) - req2.Header.Set(runtime.HeaderAuthorization, "Bearer token124") - - ok, usr, err = ba.Authenticate(&ScopedAuthRequest{Request: req2}) - assert.True(t, ok) - assert.Equal(t, nil, usr) - require.Error(t, err) - assert.Equal(t, wisdom, req2.Context().Value(original)) - assert.Equal(t, expReason, req2.Context().Value(reason)) - assert.Nil(t, req2.Context().Value(extra)) - - body := url.Values(map[string][]string{}) - body.Set("access_token", "token124") - req3, err := http.NewRequestWithContext(context.Background(), http.MethodPost, "/blah", strings.NewReader(body.Encode())) - require.NoError(t, err) - req3 = req3.WithContext(context.WithValue(req3.Context(), original, wisdom)) - req3.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - ok, usr, err = ba.Authenticate(&ScopedAuthRequest{Request: req3}) - assert.True(t, ok) - assert.Equal(t, nil, usr) - require.Error(t, err) - assert.Equal(t, wisdom, req3.Context().Value(original)) - assert.Equal(t, expReason, req3.Context().Value(reason)) - assert.Nil(t, req3.Context().Value(extra)) - - mpbody := bytes.NewBuffer(nil) - writer := multipart.NewWriter(mpbody) - require.NoError(t, writer.WriteField("access_token", "token124")) - writer.Close() - req4, err := http.NewRequestWithContext(context.Background(), http.MethodPost, "/blah", mpbody) - require.NoError(t, err) - req4 = req4.WithContext(context.WithValue(req4.Context(), original, wisdom)) - req4.Header.Set("Content-Type", writer.FormDataContentType()) - - ok, usr, err = ba.Authenticate(&ScopedAuthRequest{Request: req4}) - assert.True(t, ok) - assert.Equal(t, nil, usr) - require.Error(t, err) - assert.Equal(t, wisdom, req4.Context().Value(original)) - assert.Equal(t, expReason, req4.Context().Value(reason)) - assert.Nil(t, req4.Context().Value(extra)) +func testAuthenticateBearerInForm( + // build a request with the token as a form field, then check against the authorizer + ctx context.Context, authorizer runtime.Authenticator, parameter, token string, expectAuthorized authExpectation, + extraAsserters ...func(context.Context, *testing.T), +) func(*testing.T) { + if parameter == "" { + parameter = accessTokenParam + } + + return func(t *testing.T) { + body := url.Values(map[string][]string{}) + body.Set(parameter, token) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, authPath, strings.NewReader(body.Encode())) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + t.Run( + shouldAuthorizeOrNot(expectAuthorized), + testIsAuthorized(ctx, req, authorizer, expectAuthorized, extraAsserters...), + ) + } } +func testAuthenticateBearerInMultipartForm( + // build a request with the token as a multipart form field, then check against the authorizer + ctx context.Context, authorizer runtime.Authenticator, parameter, token string, expectAuthorized authExpectation, + extraAsserters ...func(context.Context, *testing.T), +) func(*testing.T) { + if parameter == "" { + parameter = accessTokenParam + } -func TestMissingBearerAuthCtx(t *testing.T) { - ba := BearerAuthCtx("owners_auth", bearerAuthCtx) - - req1, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/blah?access_toke=token123", nil) - require.NoError(t, err) - req1 = req1.WithContext(context.WithValue(req1.Context(), original, wisdom)) - ok, usr, err := ba.Authenticate(&ScopedAuthRequest{Request: req1}) - assert.False(t, ok) - assert.Equal(t, nil, usr) - require.NoError(t, err) - assert.Equal(t, wisdom, req1.Context().Value(original)) - assert.Nil(t, req1.Context().Value(reason)) - assert.Nil(t, req1.Context().Value(extra)) - - req2, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/blah", nil) - require.NoError(t, err) - req2.Header.Set(runtime.HeaderAuthorization, "Beare token123") - - ok, usr, err = ba.Authenticate(&ScopedAuthRequest{Request: req2}) - req2 = req2.WithContext(context.WithValue(req2.Context(), original, wisdom)) - assert.False(t, ok) - assert.Equal(t, nil, usr) - require.NoError(t, err) - assert.Equal(t, wisdom, req2.Context().Value(original)) - assert.Nil(t, req2.Context().Value(reason)) - assert.Nil(t, req2.Context().Value(extra)) - - body := url.Values(map[string][]string{}) - body.Set("access_toke", "token123") - req3, err := http.NewRequestWithContext(context.Background(), http.MethodPost, "/blah", strings.NewReader(body.Encode())) - require.NoError(t, err) - req3 = req3.WithContext(context.WithValue(req3.Context(), original, wisdom)) - req3.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - ok, usr, err = ba.Authenticate(&ScopedAuthRequest{Request: req3}) - assert.False(t, ok) - assert.Equal(t, nil, usr) - require.NoError(t, err) - assert.Equal(t, wisdom, req3.Context().Value(original)) - assert.Nil(t, req3.Context().Value(reason)) - assert.Nil(t, req3.Context().Value(extra)) - - mpbody := bytes.NewBuffer(nil) - writer := multipart.NewWriter(mpbody) - err = writer.WriteField("access_toke", "token123") - require.NoError(t, err) - require.NoError(t, writer.Close()) - req4, err := http.NewRequestWithContext(context.Background(), http.MethodPost, "/blah", mpbody) - require.NoError(t, err) - req4 = req4.WithContext(context.WithValue(req4.Context(), original, wisdom)) - req4.Header.Set("Content-Type", writer.FormDataContentType()) - - ok, usr, err = ba.Authenticate(&ScopedAuthRequest{Request: req4}) - assert.False(t, ok) - assert.Equal(t, nil, usr) - require.NoError(t, err) - assert.Equal(t, wisdom, req4.Context().Value(original)) - assert.Nil(t, req4.Context().Value(reason)) - assert.Nil(t, req4.Context().Value(extra)) + return func(t *testing.T) { + body := bytes.NewBuffer(nil) + writer := multipart.NewWriter(body) + require.NoError(t, writer.WriteField(parameter, token)) + require.NoError(t, writer.Close()) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, authPath, body) + require.NoError(t, err) + req.Header.Set("Content-Type", writer.FormDataContentType()) + + t.Run( + shouldAuthorizeOrNot(expectAuthorized), + testIsAuthorized(ctx, req, authorizer, expectAuthorized, extraAsserters...), + ) + } }