diff --git a/server/introspectionhandler.go b/server/introspectionhandler.go index 0906640383..27134c992c 100644 --- a/server/introspectionhandler.go +++ b/server/introspectionhandler.go @@ -82,7 +82,7 @@ func (s *Server) introspectRefreshToken(_ context.Context, token string) (*Intro return nil, newIntrospectInternalServerError() } - subjectString, sErr := s.genSubject(rCtx.storageToken.Claims.UserID, rCtx.storageToken.ConnectorID) + subjectString, sErr := genSubject(rCtx.storageToken.Claims.UserID, rCtx.storageToken.ConnectorID) if sErr != nil { s.logger.Errorf("failed to marshal offline session ID: %v", err) return nil, newIntrospectInternalServerError() @@ -96,7 +96,7 @@ func (s *Server) introspectRefreshToken(_ context.Context, token string) (*Intro Expiry: rCtx.storageToken.CreatedAt.Add(s.refreshTokenPolicy.absoluteLifetime).Unix(), Subject: subjectString, Username: rCtx.storageToken.Claims.PreferredUsername, - Audience: s.getAudience(rCtx.storageToken.ClientID, rCtx.scopes), + Audience: getAudience(rCtx.storageToken.ClientID, rCtx.scopes), Issuer: s.issuerURL.String(), Extra: IntrospectionExtra{ @@ -123,7 +123,7 @@ func (s *Server) introspectAccessToken(ctx context.Context, token string) (*Intr return nil, newIntrospectInternalServerError() } - clientID, err := s.getClientID(idToken.Audience, claims.AuthorizingParty) + clientID, err := getClientID(idToken.Audience, claims.AuthorizingParty) if err != nil { return nil, newIntrospectInternalServerError() } diff --git a/server/oauth2.go b/server/oauth2.go index 47e9d1edc4..0e967ddebc 100644 --- a/server/oauth2.go +++ b/server/oauth2.go @@ -307,7 +307,7 @@ func (s *Server) newAccessToken(clientID string, claims storage.Claims, scopes [ return s.newIDToken(clientID, claims, scopes, nonce, storage.NewID(), "", connID) } -func (s *Server) getClientID(aud audience, azp string) (string, error) { +func getClientID(aud audience, azp string) (string, error) { switch len(aud) { case 0: return "", fmt.Errorf("no audience is set, could not find ClientID") @@ -318,7 +318,7 @@ func (s *Server) getClientID(aud audience, azp string) (string, error) { } } -func (s *Server) getAudience(clientID string, scopes []string) audience { +func getAudience(clientID string, scopes []string) audience { var aud audience for _, scope := range scopes { @@ -341,7 +341,7 @@ func (s *Server) getAudience(clientID string, scopes []string) audience { return aud } -func (s *Server) genSubject(userID string, connID string) (string, error) { +func genSubject(userID string, connID string) (string, error) { sub := &internal.IDTokenSubject{ UserId: userID, ConnId: connID, @@ -369,7 +369,7 @@ func (s *Server) newIDToken(clientID string, claims storage.Claims, scopes []str issuedAt := s.now() expiry = issuedAt.Add(s.idTokensValidFor) - subjectString, err := s.genSubject(claims.UserID, connID) + subjectString, err := genSubject(claims.UserID, connID) if err != nil { s.logger.Errorf("failed to marshal offline session ID: %v", err) return "", expiry, fmt.Errorf("failed to marshal offline session ID: %v", err) @@ -434,7 +434,7 @@ func (s *Server) newIDToken(clientID string, claims storage.Claims, scopes []str } } - tok.Audience = s.getAudience(clientID, scopes) + tok.Audience = getAudience(clientID, scopes) if len(tok.Audience) > 1 { // The current client becomes the authorizing party. tok.AuthorizingParty = clientID diff --git a/server/oauth2_test.go b/server/oauth2_test.go index 1acff6518a..e27fc79932 100644 --- a/server/oauth2_test.go +++ b/server/oauth2_test.go @@ -14,8 +14,40 @@ import ( "github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage/memory" + "github.com/stretchr/testify/require" ) +func TestGetClientID(t *testing.T) { + cid, err := getClientID(audience{}, "") + require.Equal(t, "", cid) + require.Equal(t, "no audience is set, could not find ClientID", err.Error()) + + cid, err = getClientID(audience{"a"}, "") + require.Equal(t, "a", cid) + require.NoError(t, err) + + cid, err = getClientID(audience{"a", "b"}, "azp") + require.Equal(t, "azp", cid) + require.NoError(t, err) +} + +func TestGetAudience(t *testing.T) { + aud := getAudience("client-id", []string{}) + require.Equal(t, aud, audience{"client-id"}) + + aud = getAudience("client-id", []string{"ascope"}) + require.Equal(t, aud, audience{"client-id"}) + + aud = getAudience("client-id", []string{"ascope", "audience:server:client_id:aa", "audience:server:client_id:bb"}) + require.Equal(t, aud, audience{"aa", "bb", "client-id"}) +} + +func TestGetSubject(t *testing.T) { + sub, err := genSubject("foo", "bar") + require.Equal(t, "CgNmb28SA2Jhcg", sub) + require.NoError(t, err) +} + func TestParseAuthorizationRequest(t *testing.T) { tests := []struct { name string