Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
TimVosch committed May 13, 2024
1 parent da9c187 commit 8bd13d2
Show file tree
Hide file tree
Showing 7 changed files with 138 additions and 87 deletions.
68 changes: 46 additions & 22 deletions pkg/auth/context.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package auth

import "context"
import (
"context"
"errors"
"fmt"
)

type ctxKey int

Expand All @@ -10,6 +14,11 @@ const (
ctxPermissions
)

var (
ErrInvalidContext = errors.New("invalid auth context")
ErrContextMissing = errors.New("missing auth context")
)

func setUserID(ctx context.Context, userID string) context.Context {
return context.WithValue(ctx, ctxUserID, userID)
}
Expand All @@ -23,36 +32,51 @@ func setPermissions(ctx context.Context, permissions Permissions) context.Contex
}

func GetTenant(ctx context.Context) (int64, error) {
val, ok := fromContext[int64](ctx, ctxTenantID)
if !ok || val == 0 {
return 0, ErrNoTenantIDFound
value := ctx.Value(ctxTenantID)
if value == nil {
return 0, fmt.Errorf("%w: %w", ErrInvalidContext, ErrNoTenantIDFound)
}

typedValue, ok := value.(int64)
if !ok {
return 0, fmt.Errorf("%w: TenantID value is wrong type %T", ErrInvalidContext, value)
}

if typedValue == 0 {
return 0, fmt.Errorf("%w: %w", ErrInvalidContext, ErrNoTenantIDFound)
}
return val, nil

return typedValue, nil
}

func GetUser(ctx context.Context) (string, error) {
val, ok := fromContext[string](ctx, ctxUserID)
if !ok || val == "" {
return "", ErrNoUserID
value := ctx.Value(ctxUserID)
if value == nil {
return "", fmt.Errorf("%w: %w", ErrInvalidContext, ErrNoUserID)
}
return val, nil
}

func GetPermissions(ctx context.Context) (Permissions, error) {
val, ok := fromContext[Permissions](ctx, ctxPermissions)
typedValue, ok := value.(string)
if !ok {
return Permissions{}, ErrNoPermissions
return "", fmt.Errorf("%w: UserID value is wrong type %T", ErrInvalidContext, value)
}

if typedValue == "" {
return "", fmt.Errorf("%w: %w", ErrInvalidContext, ErrNoUserID)
}
return val, nil

return typedValue, nil
}

func fromContext[T any](ctx context.Context, key ctxKey) (T, bool) {
var val T
var ok bool
ival := ctx.Value(key)
if ival == nil {
return val, false
func GetPermissions(ctx context.Context) (Permissions, error) {
value := ctx.Value(ctxPermissions)
if value == nil {
return Permissions{}, fmt.Errorf("%w: %w", ErrInvalidContext, ErrNoPermissions)
}
val, ok = ival.(T)
return val, ok

typedValue, ok := value.(Permissions)
if !ok {
return Permissions{}, fmt.Errorf("%w: TenantID value is wrong type %T", ErrInvalidContext, value)
}

return typedValue, nil
}
14 changes: 8 additions & 6 deletions pkg/auth/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package auth
import (
"context"
"encoding/json"
"errors"
"fmt"
"log"
"net/http"
Expand Down Expand Up @@ -64,16 +65,17 @@ func Protect() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if _, err := GetTenant(r.Context()); err != nil {
log.Println("[Auth] token is missing tenant!")
log.Printf("[Auth] %v\n", err)
web.HTTPError(w, ErrUnauthorized)
return
}
if _, err := GetUser(r.Context()); err != nil && !errors.Is(err, ErrContextMissing) {
log.Printf("[Auth] %v\n", err)
web.HTTPError(w, ErrUnauthorized)
return
}
//if _, err := GetUser(r.Context()); err != nil {
// web.HTTPError(w, ErrUnauthorized)
// return
//}
if _, err := GetPermissions(r.Context()); err != nil {
log.Println("[Auth] token is missing permissions!")
log.Printf("[Auth] %v\n", err)
web.HTTPError(w, ErrUnauthorized)
return
}
Expand Down
11 changes: 5 additions & 6 deletions pkg/auth/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/golang-jwt/jwt"
"github.com/samber/lo"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// test jwks is unreachable
Expand Down Expand Up @@ -82,9 +83,8 @@ func TestProtectAndAuthenticatePassClaimsToNext(t *testing.T) {
s.Handle("/", auth(protect(next)))

req, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)

rr := httptest.NewRecorder()
token := createToken(jwt.MapClaims{
"tid": 11,
Expand Down Expand Up @@ -175,9 +175,8 @@ func TestProtect(t *testing.T) {
for scene, cfg := range scenarios {
t.Run(scene, func(t *testing.T) {
req, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)

rr := httptest.NewRecorder()
ctx := createTestContext(context.Background(), cfg.values)

Expand Down
2 changes: 1 addition & 1 deletion pkg/auth/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ func TestGetTenant(t *testing.T) {

// Assert
assert.Equal(t, cfg.expectedRes, result)
assert.Equal(t, cfg.expectedErr, err)
assert.ErrorIs(t, err, cfg.expectedErr)
})
}
}
35 changes: 18 additions & 17 deletions services/tenants/apikeys/application_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package apikeys_test

import (
"context"
"encoding/base64"
"fmt"
"testing"
Expand Down Expand Up @@ -42,7 +43,7 @@ func TestGenerateNewApiKeyCreatesNewApiKey(t *testing.T) {
}
s := apikeys.NewAPIKeyService(tenantStore, apiKeyStore)
// Act
res, err := s.GenerateNewApiKey("whatever", 905, auth.Permissions{auth.READ_DEVICES}, &exp)
res, err := s.GenerateNewApiKey(context.Background(), "whatever", 905, auth.Permissions{auth.READ_DEVICES}, &exp)

// Assert
assert.NoError(t, err)
Expand Down Expand Up @@ -78,7 +79,7 @@ func TestGenerateNewAPIKeyNameAndTenantCombinationNotUnique(t *testing.T) {
s := apikeys.NewAPIKeyService(tenantStore, apiKeyStore)

// Act
res, err := s.GenerateNewApiKey("whatever", 905, auth.Permissions{auth.READ_API_KEYS}, &exp)
res, err := s.GenerateNewApiKey(context.Background(), "whatever", 905, auth.Permissions{auth.READ_API_KEYS}, &exp)

// Assert
assert.ErrorIs(t, err, apikeys.ErrKeyNameTenantIDCombinationNotUnique)
Expand Down Expand Up @@ -108,7 +109,7 @@ func TestGenerateNewAPIKeyCheckCombinationUniqueErrorOccurs(t *testing.T) {
s := apikeys.NewAPIKeyService(tenantStore, apiKeyStore)

// Act
res, err := s.GenerateNewApiKey("whatever", 905, auth.Permissions{auth.READ_DEVICES}, &exp)
res, err := s.GenerateNewApiKey(context.Background(), "whatever", 905, auth.Permissions{auth.READ_DEVICES}, &exp)

// Assert
assert.Error(t, err)
Expand Down Expand Up @@ -144,7 +145,7 @@ func TestGenerateNewApiKeyErrorOccursWhileAddingApiKeyToStore(t *testing.T) {
s := apikeys.NewAPIKeyService(tenantStore, apiKeyStore)

// Act
res, err := s.GenerateNewApiKey("whatever", 905, auth.Permissions{auth.READ_DEVICES}, nil)
res, err := s.GenerateNewApiKey(context.Background(), "whatever", 905, auth.Permissions{auth.READ_DEVICES}, nil)

// Assert
assert.Error(t, err)
Expand All @@ -160,7 +161,7 @@ func TestGenerateNewApiKeyPermissionsContains1InvalidPermission(t *testing.T) {
s := apikeys.NewAPIKeyService(tenantStore, apiKeyStore)

// Act
res, err := s.GenerateNewApiKey("whatever", 905, auth.Permissions{auth.READ_API_KEYS, auth.READ_DEVICES, auth.Permission("invalidpermission")}, nil)
res, err := s.GenerateNewApiKey(context.Background(), "whatever", 905, auth.Permissions{auth.READ_API_KEYS, auth.READ_DEVICES, auth.Permission("invalidpermission")}, nil)

// Assert
assert.ErrorIs(t, err, apikeys.ErrPermissionsInvalid)
Expand All @@ -182,7 +183,7 @@ func TestGenerateNewApiKeyErrorOccursWhenRetrievingTenant(t *testing.T) {
s := apikeys.NewAPIKeyService(tenantStore, &ApiKeyStoreMock{})

// Act
res, err := s.GenerateNewApiKey("whatever", 905, auth.Permissions{auth.READ_DEVICES}, nil)
res, err := s.GenerateNewApiKey(context.Background(), "whatever", 905, auth.Permissions{auth.READ_DEVICES}, nil)

// Assert
assert.Error(t, err)
Expand All @@ -201,7 +202,7 @@ func TestGenerateNewApiKeyTenantDoesNotExist(t *testing.T) {
s := apikeys.NewAPIKeyService(tenantStore, &ApiKeyStoreMock{})

// Act
res, err := s.GenerateNewApiKey("whatever", 334, auth.Permissions{auth.READ_DEVICES}, nil)
res, err := s.GenerateNewApiKey(context.Background(), "whatever", 334, auth.Permissions{auth.READ_DEVICES}, nil)

// Assert
assert.ErrorIs(t, err, apikeys.ErrTenantIsNotValid)
Expand All @@ -222,7 +223,7 @@ func TestGenerateNewApiKeyTenantIsNottenantsActive(t *testing.T) {
s := apikeys.NewAPIKeyService(tenantStore, &ApiKeyStoreMock{})

// Act
res, err := s.GenerateNewApiKey("whatever", 334, auth.Permissions{auth.READ_DEVICES}, nil)
res, err := s.GenerateNewApiKey(context.Background(), "whatever", 334, auth.Permissions{auth.READ_DEVICES}, nil)

// Assert
assert.ErrorIs(t, err, apikeys.ErrTenantIsNotValid)
Expand All @@ -241,7 +242,7 @@ func TestRevokeApiKeyDeletesKey(t *testing.T) {
s := apikeys.NewAPIKeyService(&TenantStoreMock{}, apiKeyStore)

// Act
err := s.RevokeApiKey(665213432)
err := s.RevokeApiKey(context.Background(), 665213432)

// Assert
assert.NoError(t, err)
Expand All @@ -259,7 +260,7 @@ func TestRevokeApiKeyErrorOccurs(t *testing.T) {
s := apikeys.NewAPIKeyService(&TenantStoreMock{}, apiKeyStore)

// Act
err := s.RevokeApiKey(83245345)
err := s.RevokeApiKey(context.Background(), 83245345)

// Assert
assert.Error(t, err)
Expand All @@ -277,7 +278,7 @@ func TestRevokeApiKeyWasNotDeletedByStore(t *testing.T) {
s := apikeys.NewAPIKeyService(&TenantStoreMock{}, apiKeyStore)

// Act
err := s.RevokeApiKey(83245345)
err := s.RevokeApiKey(context.Background(), 83245345)

// Assert
assert.ErrorIs(t, err, apikeys.ErrKeyNotFound)
Expand All @@ -297,7 +298,7 @@ func TestValidateApiKeyInvalidEncoding(t *testing.T) {
t.Run(scenario, func(t *testing.T) {
// Act
s := &apikeys.Service{}
res, err := s.AuthenticateApiKey(input)
res, err := s.AuthenticateApiKey(context.Background(), input)

// Assert
assert.EqualValues(t, 0, res.TenantID)
Expand All @@ -318,7 +319,7 @@ func TestValidateApiKeyErrorOccursWhileRetrievingKey(t *testing.T) {
s := apikeys.NewAPIKeyService(&TenantStoreMock{}, apiKeyStore)

// Act
res, err := s.AuthenticateApiKey(asBase64("43214:somevalidapikey"))
res, err := s.AuthenticateApiKey(context.Background(), asBase64("43214:somevalidapikey"))

// Assert
assert.EqualValues(t, 0, res.TenantID)
Expand All @@ -344,7 +345,7 @@ func TestValidateApiKeyInvalidKey(t *testing.T) {
s := apikeys.NewAPIKeyService(&TenantStoreMock{}, apiKeyStore)

// Act
res, err := s.AuthenticateApiKey(asBase64("43214:someinvalidapikey"))
res, err := s.AuthenticateApiKey(context.Background(), asBase64("43214:someinvalidapikey"))

// Assert
assert.EqualValues(t, 0, res.TenantID)
Expand Down Expand Up @@ -376,7 +377,7 @@ func TestValidateApiKeyKeyIsExpired(t *testing.T) {
s := apikeys.NewAPIKeyService(&TenantStoreMock{}, apiKeyStore)

// Act
res, err := s.AuthenticateApiKey(asBase64("43214:kayJhmgiCNNQAKwtvewxN6BWSTiEINOy"))
res, err := s.AuthenticateApiKey(context.Background(), asBase64("43214:kayJhmgiCNNQAKwtvewxN6BWSTiEINOy"))

// Assert
assert.EqualValues(t, 0, res.TenantID)
Expand Down Expand Up @@ -409,7 +410,7 @@ func TestValidateApiKeyKeyIsExpiredDeleteErrorOccurs(t *testing.T) {
s := apikeys.NewAPIKeyService(&TenantStoreMock{}, apiKeyStore)

// Act
res, err := s.AuthenticateApiKey(asBase64("43214:kayJhmgiCNNQAKwtvewxN6BWSTiEINOy"))
res, err := s.AuthenticateApiKey(context.Background(), asBase64("43214:kayJhmgiCNNQAKwtvewxN6BWSTiEINOy"))

// Assert
assert.EqualValues(t, 0, res.TenantID)
Expand All @@ -436,7 +437,7 @@ func TestValidateApiKeyValidKey(t *testing.T) {
s := apikeys.NewAPIKeyService(&TenantStoreMock{}, apiKeyStore)

// Act
res, err := s.AuthenticateApiKey(asBase64("43214:kayJhmgiCNNQAKwtvewxN6BWSTiEINOy"))
res, err := s.AuthenticateApiKey(context.Background(), asBase64("43214:kayJhmgiCNNQAKwtvewxN6BWSTiEINOy"))

// Assert
assert.EqualValues(t, 534, res.TenantID)
Expand Down
Loading

0 comments on commit 8bd13d2

Please sign in to comment.