diff --git a/cmd/server/handler_oauth2_factory.go b/cmd/server/handler_oauth2_factory.go index 47907063957..137244e2f34 100644 --- a/cmd/server/handler_oauth2_factory.go +++ b/cmd/server/handler_oauth2_factory.go @@ -194,6 +194,7 @@ func newOAuth2Handler(c *config.Config, router *httprouter.Router, cm consent.Ma L: c.GetLogger(), OpenIDJWTStrategy: openIDJWTStrategy, AccessTokenJWTStrategy: accessTokenJWTStrategy, + AccessTokenStrategy: c.OAuth2AccessTokenStrategy, IDTokenLifespan: c.GetIDTokenLifespan(), } diff --git a/jwk/jwt_strategy.go b/jwk/jwt_strategy.go index 700ba3836d6..1b0a6f85d67 100644 --- a/jwk/jwt_strategy.go +++ b/jwk/jwt_strategy.go @@ -29,6 +29,12 @@ import ( "github.com/pkg/errors" ) +type JWTStrategy interface { + GetPublicKeyID() (string, error) + + jwt.JWTStrategy +} + func NewRS256JWTStrategy(m Manager, set string) (*RS256JWTStrategy, error) { j := &RS256JWTStrategy{ Manager: m, diff --git a/oauth2/handler.go b/oauth2/handler.go index bb025edf504..33327ed2ab6 100644 --- a/oauth2/handler.go +++ b/oauth2/handler.go @@ -475,7 +475,7 @@ func (h *Handler) TokenHandler(w http.ResponseWriter, r *http.Request, _ httprou if accessRequest.GetGrantTypes().Exact("client_credentials") { var accessTokenKeyID string - if h.AccessTokenJWTStrategy != nil { + if h.AccessTokenStrategy == "jwt" { accessTokenKeyID, err = h.AccessTokenJWTStrategy.GetPublicKeyID() if err != nil { pkg.LogError(err, h.L) @@ -557,7 +557,7 @@ func (h *Handler) AuthHandler(w http.ResponseWriter, r *http.Request, _ httprout } var accessTokenKeyID string - if h.AccessTokenJWTStrategy != nil { + if h.AccessTokenStrategy == "jwt" { accessTokenKeyID, err = h.AccessTokenJWTStrategy.GetPublicKeyID() if err != nil { pkg.LogError(err, h.L) diff --git a/oauth2/handler_struct.go b/oauth2/handler_struct.go index 37a1c6e2d99..d21721d287a 100644 --- a/oauth2/handler_struct.go +++ b/oauth2/handler_struct.go @@ -47,8 +47,9 @@ type Handler struct { IDTokenLifespan time.Duration CookieStore sessions.Store - OpenIDJWTStrategy *jwk.RS256JWTStrategy - AccessTokenJWTStrategy *jwk.RS256JWTStrategy + OpenIDJWTStrategy jwk.JWTStrategy + AccessTokenJWTStrategy jwk.JWTStrategy + AccessTokenStrategy string L logrus.FieldLogger