diff --git a/server/handlers.go b/server/handlers.go index 5faab2c9ae..ccd534d991 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -78,6 +78,7 @@ type discovery struct { Keys string `json:"jwks_uri"` UserInfo string `json:"userinfo_endpoint"` DeviceEndpoint string `json:"device_authorization_endpoint"` + Introspect string `json:"introspection_endpoint"` GrantTypes []string `json:"grant_types_supported"` ResponseTypes []string `json:"response_types_supported"` Subjects []string `json:"subject_types_supported"` @@ -96,6 +97,7 @@ func (s *Server) discoveryHandler() (http.HandlerFunc, error) { Keys: s.absURL("/keys"), UserInfo: s.absURL("/userinfo"), DeviceEndpoint: s.absURL("/device/code"), + Introspect: s.absURL("/token/introspect"), Subjects: []string{"public"}, IDTokenAlgs: []string{string(jose.RS256)}, CodeChallengeAlgs: []string{codeChallengeMethodS256, codeChallengeMethodPlain}, diff --git a/server/handlers_test.go b/server/handlers_test.go index f6ada3634d..d32101b1cf 100644 --- a/server/handlers_test.go +++ b/server/handlers_test.go @@ -37,6 +37,76 @@ func TestHandleHealth(t *testing.T) { } } +func TestHandleDiscovery(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + httpServer, server := newTestServer(ctx, t, nil) + defer httpServer.Close() + + rr := httptest.NewRecorder() + server.ServeHTTP(rr, httptest.NewRequest("GET", "/.well-known/openid-configuration", nil)) + if rr.Code != http.StatusOK { + t.Errorf("expected 200 got %d", rr.Code) + } + + var res discovery + err := json.NewDecoder(rr.Result().Body).Decode(&res) + require.NoError(t, err) + require.Equal(t, discovery{ + Issuer: httpServer.URL, + Auth: fmt.Sprintf("%s/auth", httpServer.URL), + Token: fmt.Sprintf("%s/token", httpServer.URL), + Keys: fmt.Sprintf("%s/keys", httpServer.URL), + UserInfo: fmt.Sprintf("%s/userinfo", httpServer.URL), + DeviceEndpoint: fmt.Sprintf("%s/device/code", httpServer.URL), + Introspect: fmt.Sprintf("%s/token/introspect", httpServer.URL), + GrantTypes: []string{ + "authorization_code", + "refresh_token", + "urn:ietf:params:oauth:grant-type:device_code", + "urn:ietf:params:oauth:grant-type:token-exchange", + }, + ResponseTypes: []string{ + "code", + }, + Subjects: []string{ + "public", + }, + IDTokenAlgs: []string{ + "RS256", + }, + CodeChallengeAlgs: []string{ + "S256", + "plain", + }, + Scopes: []string{ + "openid", + "email", + "groups", + "profile", + "offline_access", + }, + AuthMethods: []string{ + "client_secret_basic", + "client_secret_post", + }, + Claims: []string{ + "iss", + "sub", + "aud", + "iat", + "exp", + "email", + "email_verified", + "locale", + "name", + "preferred_username", + "at_hash", + }, + }, res) +} + func TestHandleHealthFailure(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() diff --git a/server/introspectionhandler.go b/server/introspectionhandler.go new file mode 100644 index 0000000000..a33f20bd9b --- /dev/null +++ b/server/introspectionhandler.go @@ -0,0 +1,347 @@ +package server + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + + "github.com/coreos/go-oidc/v3/oidc" + + "github.com/dexidp/dex/server/internal" +) + +// Introspection contains an access token's session data as specified by +// [IETF RFC 7662](https://tools.ietf.org/html/rfc7662) +type Introspection struct { + // Boolean indicator of whether or not the presented token + // is currently active. The specifics of a token's "active" state + // will vary depending on the implementation of the authorization + // server and the information it keeps about its tokens, but a "true" + // value return for the "active" property will generally indicate + // that a given token has been issued by this authorization server, + // has not been revoked by the resource owner, and is within its + // given time window of validity (e.g., after its issuance time and + // before its expiration time). + Active bool `json:"active"` + + // JSON string containing a space-separated list of + // scopes associated with this token. + Scope string `json:"scope,omitempty"` + + // Client identifier for the OAuth 2.0 client that + // requested this token. + ClientID string `json:"client_id"` + + // Subject of the token, as defined in JWT [RFC7519]. + // Usually a machine-readable identifier of the resource owner who + // authorized this token. + Subject string `json:"sub"` + + // Integer timestamp, measured in the number of seconds + // since January 1 1970 UTC, indicating when this token will expire. + Expiry int64 `json:"exp"` + + // Integer timestamp, measured in the number of seconds + // since January 1 1970 UTC, indicating when this token was + // originally issued. + IssuedAt int64 `json:"iat"` + + // Integer timestamp, measured in the number of seconds + // since January 1 1970 UTC, indicating when this token is not to be + // used before. + NotBefore int64 `json:"nbf"` + + // Human-readable identifier for the resource owner who + // authorized this token. + Username string `json:"username,omitempty"` + + // Service-specific string identifier or list of string + // identifiers representing the intended audience for this token, as + // defined in JWT + Audience audience `json:"aud"` + + // String representing the issuer of this token, as + // defined in JWT + Issuer string `json:"iss"` + + // String identifier for the token, as defined in JWT [RFC7519]. + JwtTokenID string `json:"jti,omitempty"` + + // TokenType is the introspected token's type, typically `bearer`. + TokenType string `json:"token_type"` + + // TokenUse is the introspected token's use, for example `access_token` or `refresh_token`. + TokenUse string `json:"token_use"` + + // Extra is arbitrary data set from the token claims. + Extra IntrospectionExtra `json:"ext,omitempty"` +} + +type IntrospectionExtra struct { + AuthorizingParty string `json:"azp,omitempty"` + + Email string `json:"email,omitempty"` + EmailVerified *bool `json:"email_verified,omitempty"` + + Groups []string `json:"groups,omitempty"` + + Name string `json:"name,omitempty"` + PreferredUsername string `json:"preferred_username,omitempty"` + + FederatedIDClaims *federatedIDClaims `json:"federated_claims,omitempty"` +} + +type TokenTypeEnum int + +const ( + AccessToken TokenTypeEnum = iota + RefreshToken +) + +func (t TokenTypeEnum) String() string { + switch t { + case AccessToken: + return "access_token" + case RefreshToken: + return "refresh_token" + default: + return fmt.Sprintf("TokenTypeEnum(%d)", t) + } +} + +type introspectionError struct { + typ string + code int + desc string +} + +func (e *introspectionError) Error() string { + return fmt.Sprintf("introspection error: status %d, %q %s", e.code, e.typ, e.desc) +} + +func (e *introspectionError) Is(tgt error) bool { + target, ok := tgt.(*introspectionError) + if !ok { + return false + } + + return e.typ == target.typ && + e.code == target.code && + e.desc == target.desc +} + +func newIntrospectInactiveTokenError() *introspectionError { + return &introspectionError{typ: errInactiveToken, desc: "", code: http.StatusUnauthorized} +} + +func newIntrospectInternalServerError() *introspectionError { + return &introspectionError{typ: errServerError, desc: "", code: http.StatusInternalServerError} +} + +func newIntrospectBadRequestError(desc string) *introspectionError { + return &introspectionError{typ: errInvalidRequest, desc: desc, code: http.StatusBadRequest} +} + +func (s *Server) guessTokenType(ctx context.Context, token string) (TokenTypeEnum, error) { + // We skip every checks, we only want to know if it's a valid JWT + verifierConfig := oidc.Config{ + SkipClientIDCheck: true, + SkipExpiryCheck: true, + SkipIssuerCheck: true, + + // We skip signature checks to avoid database calls; + InsecureSkipSignatureCheck: true, + } + + verifier := oidc.NewVerifier(s.issuerURL.String(), nil, &verifierConfig) + if _, err := verifier.Verify(ctx, token); err != nil { + // If it's not an access token, let's assume it's a refresh token; + return RefreshToken, nil + } + + // If it's a valid JWT, it's an access token. + return AccessToken, nil +} + +func (s *Server) getTokenFromRequest(r *http.Request) (string, TokenTypeEnum, error) { + if r.Method != "POST" { + return "", 0, newIntrospectBadRequestError(fmt.Sprintf("HTTP method is \"%s\", expected \"POST\".", r.Method)) + } else if err := r.ParseForm(); err != nil { + return "", 0, newIntrospectBadRequestError("Unable to parse HTTP body, make sure to send a properly formatted form request body.") + } else if r.PostForm == nil || len(r.PostForm) == 0 { + return "", 0, newIntrospectBadRequestError("The POST body can not be empty.") + } else if !r.PostForm.Has("token") { + return "", 0, newIntrospectBadRequestError("The POST body doesn't contain 'token' parameter.") + } + + token := r.PostForm.Get("token") + tokenType, err := s.guessTokenType(r.Context(), token) + if err != nil { + s.logger.Error(err) + return "", 0, newIntrospectInternalServerError() + } + + requestTokenType := r.PostForm.Get("token_type_hint") + if requestTokenType != "" { + if tokenType.String() != requestTokenType { + s.logger.Warnf("Token type hint doesn't match token type: %s != %s", requestTokenType, tokenType) + } + } + + return token, tokenType, nil +} + +func (s *Server) introspectRefreshToken(_ context.Context, token string) (*Introspection, error) { + rToken := new(internal.RefreshToken) + if err := internal.Unmarshal(token, rToken); err != nil { + // For backward compatibility, assume the refresh_token is a raw refresh token ID + // if it fails to decode. + // + // Because refresh_token values that aren't unmarshable were generated by servers + // that don't have a Token value, we'll still reject any attempts to claim a + // refresh_token twice. + rToken = &internal.RefreshToken{RefreshId: token, Token: ""} + } + + rCtx, err := s.getRefreshTokenFromStorage(nil, rToken) + if err != nil { + if errors.Is(err, invalidErr) || errors.Is(err, expiredErr) { + return nil, newIntrospectInactiveTokenError() + } + + s.logger.Errorf("failed to get refresh token: %v", err) + return nil, newIntrospectInternalServerError() + } + + subjectString, sErr := genSubject(rCtx.storageToken.Claims.UserID, rCtx.storageToken.ConnectorID) + if sErr != nil { + s.logger.Errorf("failed to marshal offline session ID: %v", err) + return nil, newIntrospectInternalServerError() + } + + return &Introspection{ + Active: true, + ClientID: rCtx.storageToken.ClientID, + IssuedAt: rCtx.storageToken.CreatedAt.Unix(), + NotBefore: rCtx.storageToken.CreatedAt.Unix(), + Expiry: rCtx.storageToken.CreatedAt.Add(s.refreshTokenPolicy.absoluteLifetime).Unix(), + Subject: subjectString, + Username: rCtx.storageToken.Claims.PreferredUsername, + Audience: getAudience(rCtx.storageToken.ClientID, rCtx.scopes), + Issuer: s.issuerURL.String(), + + Extra: IntrospectionExtra{ + Email: rCtx.storageToken.Claims.Email, + EmailVerified: &rCtx.storageToken.Claims.EmailVerified, + Groups: rCtx.storageToken.Claims.Groups, + Name: rCtx.storageToken.Claims.Username, + PreferredUsername: rCtx.storageToken.Claims.PreferredUsername, + }, + TokenType: "Bearer", + TokenUse: "refresh_token", + }, nil +} + +func (s *Server) introspectAccessToken(ctx context.Context, token string) (*Introspection, error) { + verifier := oidc.NewVerifier(s.issuerURL.String(), &storageKeySet{s.storage}, &oidc.Config{SkipClientIDCheck: true}) + idToken, err := verifier.Verify(ctx, token) + if err != nil { + return nil, newIntrospectInactiveTokenError() + } + + var claims IntrospectionExtra + if err := idToken.Claims(&claims); err != nil { + s.logger.Errorf("Error while fetching token claims: %s", err.Error()) + return nil, newIntrospectInternalServerError() + } + + clientID, err := getClientID(idToken.Audience, claims.AuthorizingParty) + if err != nil { + s.logger.Error("Error while fetching client_id from token: %s", err.Error()) + return nil, newIntrospectInternalServerError() + } + + client, err := s.storage.GetClient(clientID) + if err != nil { + s.logger.Error("Error while fetching client from storage: %s", err.Error()) + return nil, newIntrospectInternalServerError() + } + + return &Introspection{ + Active: true, + ClientID: client.ID, + IssuedAt: idToken.IssuedAt.Unix(), + NotBefore: idToken.IssuedAt.Unix(), + Expiry: idToken.Expiry.Unix(), + Subject: idToken.Subject, + Username: claims.PreferredUsername, + Audience: idToken.Audience, + Issuer: s.issuerURL.String(), + + Extra: claims, + TokenType: "Bearer", + TokenUse: "access_token", + }, nil +} + +func (s *Server) handleIntrospect(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + var introspect *Introspection + token, tokenType, err := s.getTokenFromRequest(r) + if err == nil { + switch tokenType { + case AccessToken: + introspect, err = s.introspectAccessToken(ctx, token) + case RefreshToken: + introspect, err = s.introspectRefreshToken(ctx, token) + default: + // Token type is neither handled token types. + s.logger.Errorf("Unknown token type: %s", tokenType) + introspectInactiveErr(w) + return + } + } + + if err != nil { + if intErr, ok := err.(*introspectionError); ok { + s.introspectErrHelper(w, intErr.typ, intErr.desc, intErr.code) + } else { + s.logger.Errorf("An unknown error occurred: %s", err.Error()) + s.introspectErrHelper(w, errServerError, "An unknown error occurred", http.StatusInternalServerError) + } + + return + } + + rawJSON, jsonErr := json.Marshal(introspect) + if jsonErr != nil { + s.introspectErrHelper(w, errServerError, jsonErr.Error(), 500) + } + + w.Header().Set("Content-Type", "application/json") + w.Write(rawJSON) +} + +func (s *Server) introspectErrHelper(w http.ResponseWriter, typ string, description string, statusCode int) { + if typ == errInactiveToken { + introspectInactiveErr(w) + return + } + + if err := tokenErr(w, typ, description, statusCode); err != nil { + s.logger.Errorf("introspect error response: %v", err) + } +} + +func introspectInactiveErr(w http.ResponseWriter) { + w.Header().Set("Cache-Control", "no-store") + w.Header().Set("Pragma", "no-cache") + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(401) + json.NewEncoder(w).Encode(struct { + Active bool `json:"active"` + }{Active: false}) +} diff --git a/server/introspectionhandler_test.go b/server/introspectionhandler_test.go new file mode 100644 index 0000000000..07504c4e60 --- /dev/null +++ b/server/introspectionhandler_test.go @@ -0,0 +1,415 @@ +package server + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "net/url" + "path" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/dexidp/dex/server/internal" + "github.com/dexidp/dex/storage" +) + +func toJSON(a interface{}) string { + b, err := json.Marshal(a) + if err != nil { + return "" + } + + return string(b) +} + +func mockTestStorage(t *testing.T, s storage.Storage) { + ctx := context.Background() + c := storage.Client{ + ID: "test", + Secret: "barfoo", + RedirectURIs: []string{"foo://bar.com/", "https://auth.example.com"}, + Name: "dex client", + LogoURL: "https://goo.gl/JIyzIC", + } + + err := s.CreateClient(ctx, c) + require.NoError(t, err) + + c1 := storage.Connector{ + ID: "test", + Type: "mockPassword", + Name: "mockPassword", + Config: []byte(`{ +"username": "test", +"password": "test" +}`), + } + + err = s.CreateConnector(ctx, c1) + require.NoError(t, err) + + err = s.CreateRefresh(ctx, storage.RefreshToken{ + ID: "test", + Token: "bar", + ObsoleteToken: "", + Nonce: "foo", + ClientID: "test", + ConnectorID: "test", + Scopes: []string{"openid", "email", "profile"}, + CreatedAt: time.Now().UTC().Round(time.Millisecond), + LastUsed: time.Now().UTC().Round(time.Millisecond), + Claims: storage.Claims{ + UserID: "1", + Username: "jane", + Email: "jane.doe@example.com", + EmailVerified: true, + Groups: []string{"a", "b"}, + }, + ConnectorData: []byte(`{"some":"data"}`), + }) + require.NoError(t, err) + + err = s.CreateRefresh(ctx, storage.RefreshToken{ + ID: "expired", + Token: "bar", + ObsoleteToken: "", + Nonce: "foo", + ClientID: "test", + ConnectorID: "test", + Scopes: []string{"openid", "email", "profile"}, + CreatedAt: time.Now().AddDate(-1, 0, 0).UTC().Round(time.Millisecond), + LastUsed: time.Now().AddDate(-1, 0, 0).UTC().Round(time.Millisecond), + Claims: storage.Claims{ + UserID: "1", + Username: "jane", + Email: "jane.doe@example.com", + EmailVerified: true, + Groups: []string{"a", "b"}, + }, + ConnectorData: []byte(`{"some":"data"}`), + }) + require.NoError(t, err) + + err = s.CreateOfflineSessions(ctx, storage.OfflineSessions{ + UserID: "1", + ConnID: "test", + Refresh: map[string]*storage.RefreshTokenRef{ + "test": {ID: "test", ClientID: "test"}, + "expired": {ID: "expired", ClientID: "test"}, + }, + ConnectorData: nil, + }) + require.NoError(t, err) +} + +func getIntrospectionValue(issuerURL url.URL, issuedAt time.Time, expiry time.Time, tokenUse string) *Introspection { + trueValue := true + return &Introspection{ + Active: true, + ClientID: "test", + Subject: "CgExEgR0ZXN0", + Expiry: expiry.Unix(), + IssuedAt: issuedAt.Unix(), + NotBefore: issuedAt.Unix(), + Audience: []string{ + "test", + }, + Issuer: issuerURL.String(), + TokenType: "Bearer", + TokenUse: tokenUse, + Extra: IntrospectionExtra{ + Email: "jane.doe@example.com", + EmailVerified: &trueValue, + Groups: []string{ + "a", + "b", + }, + Name: "jane", + }, + } +} + +func TestGetTokenFromRequestSuccess(t *testing.T) { + t0 := time.Now() + + now := func() time.Time { return t0 } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Setup a dex server. + httpServer, s := newTestServer(ctx, t, func(c *Config) { + c.Issuer += "/non-root-path" + c.Now = now + }) + defer httpServer.Close() + + tests := []struct { + testName string + expectedToken string + expectedTokenType TokenTypeEnum + }{ + // Access Token + { + testName: "Access Token", + expectedToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c", + expectedTokenType: AccessToken, + }, + // Refresh Token + { + testName: "Refresh token", + expectedToken: "CgR0ZXN0EgNiYXI", + expectedTokenType: RefreshToken, + }, + // Unknown token + { + testName: "Unknown token", + expectedToken: "AaAaAaA", + expectedTokenType: RefreshToken, + }, + } + + for _, tc := range tests { + t.Run(tc.testName, func(t *testing.T) { + data := url.Values{} + data.Set("token", tc.expectedToken) + req := httptest.NewRequest(http.MethodPost, "https://test.tech/token/introspect", bytes.NewBufferString(data.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + token, tokenType, err := s.getTokenFromRequest(req) + if err != nil { + t.Fatalf("Error returned: %s", err.Error()) + } + + if token != tc.expectedToken { + t.Fatalf("Wrong token returned. Expected %v got %v", tc.expectedToken, token) + } + + if tokenType != tc.expectedTokenType { + t.Fatalf("Wrong token type returned. Expected %v got %v", tc.expectedTokenType, tokenType) + } + }) + } +} + +func TestGetTokenFromRequestFailure(t *testing.T) { + t0 := time.Now() + + now := func() time.Time { return t0 } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Setup a dex server. + httpServer, s := newTestServer(ctx, t, func(c *Config) { + c.Issuer += "/non-root-path" + c.Now = now + }) + defer httpServer.Close() + + _, _, err := s.getTokenFromRequest(httptest.NewRequest(http.MethodGet, "https://test.tech/token/introspect", nil)) + require.ErrorIs(t, err, &introspectionError{ + typ: errInvalidRequest, + desc: "HTTP method is \"GET\", expected \"POST\".", + code: http.StatusBadRequest, + }) + + _, _, err = s.getTokenFromRequest(httptest.NewRequest(http.MethodPost, "https://test.tech/token/introspect", nil)) + require.ErrorIs(t, err, &introspectionError{ + typ: errInvalidRequest, + desc: "The POST body can not be empty.", + code: http.StatusBadRequest, + }) + + req := httptest.NewRequest(http.MethodPost, "https://test.tech/token/introspect", strings.NewReader("token_type_hint=access_token")) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + _, _, err = s.getTokenFromRequest(req) + require.ErrorIs(t, err, &introspectionError{ + typ: errInvalidRequest, + desc: "The POST body doesn't contain 'token' parameter.", + code: http.StatusBadRequest, + }) +} + +func TestHandleIntrospect(t *testing.T) { + t0 := time.Now() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Setup a dex server. + now := func() time.Time { return t0 } + + refreshTokenPolicy, err := NewRefreshTokenPolicy(logger, false, "", "24h", "") + if err != nil { + t.Fatalf("failed to prepare rotation policy: %v", err) + } + refreshTokenPolicy.now = now + + httpServer, s := newTestServer(ctx, t, func(c *Config) { + c.Issuer += "/non-root-path" + c.RefreshTokenPolicy = refreshTokenPolicy + c.Now = now + }) + defer httpServer.Close() + + mockTestStorage(t, s.storage) + + activeAccessToken, expiry, err := s.newIDToken("test", storage.Claims{ + UserID: "1", + Username: "jane", + Email: "jane.doe@example.com", + EmailVerified: true, + Groups: []string{"a", "b"}, + }, []string{"openid", "email", "profile", "groups"}, "foo", "", "", "test") + require.NoError(t, err) + + activeRefreshToken, err := internal.Marshal(&internal.RefreshToken{RefreshId: "test", Token: "bar"}) + require.NoError(t, err) + expiredRefreshToken, err := internal.Marshal(&internal.RefreshToken{RefreshId: "expired", Token: "bar"}) + require.NoError(t, err) + + inactiveResponse := "{\"active\":false}\n" + badRequestResponse := `{"error":"invalid_request","error_description":"The POST body can not be empty."}` + + tests := []struct { + testName string + token string + tokenType string + response string + responseStatusCode int + }{ + // No token + { + testName: "No token", + response: badRequestResponse, + responseStatusCode: 400, + }, + // Access token tests + { + testName: "Access Token: active", + token: activeAccessToken, + response: toJSON(getIntrospectionValue(s.issuerURL, time.Now(), expiry, "access_token")), + responseStatusCode: 200, + }, + { + testName: "Access Token: wrong", + token: "fake-token", + response: inactiveResponse, + responseStatusCode: 401, + }, + // Refresh token tests + { + testName: "Refresh Token: active", + token: activeRefreshToken, + response: toJSON(getIntrospectionValue(s.issuerURL, time.Now(), time.Now().Add(s.refreshTokenPolicy.absoluteLifetime), "refresh_token")), + responseStatusCode: 200, + }, + { + testName: "Refresh Token: expired", + token: expiredRefreshToken, + response: inactiveResponse, + responseStatusCode: 401, + }, + { + testName: "Refresh Token: active => false (wrong)", + token: "fake-token", + response: inactiveResponse, + responseStatusCode: 401, + }, + } + + for _, tc := range tests { + t.Run(tc.testName, func(t *testing.T) { + data := url.Values{} + if tc.token != "" { + data.Set("token", tc.token) + } + if tc.tokenType != "" { + data.Set("token_type_hint", tc.tokenType) + } + + u, err := url.Parse(s.issuerURL.String()) + if err != nil { + t.Fatalf("Could not parse issuer URL %v", err) + } + u.Path = path.Join(u.Path, "token", "introspect") + + req, _ := http.NewRequest("POST", u.String(), bytes.NewBufferString(data.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + rr := httptest.NewRecorder() + s.ServeHTTP(rr, req) + + if rr.Code != tc.responseStatusCode { + t.Errorf("%s: Unexpected Response Type. Expected %v got %v", tc.testName, tc.responseStatusCode, rr.Code) + } + + result, _ := io.ReadAll(rr.Body) + if string(result) != tc.response { + t.Errorf("%s: Unexpected Response. Expected %q got %q", tc.testName, tc.response, result) + } + }) + } +} + +func TestIntrospectErrHelper(t *testing.T) { + t0 := time.Now() + + now := func() time.Time { return t0 } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Setup a dex server. + httpServer, s := newTestServer(ctx, t, func(c *Config) { + c.Issuer += "/non-root-path" + c.Now = now + }) + defer httpServer.Close() + + tests := []struct { + testName string + err *introspectionError + resStatusCode int + resBody string + }{ + { + testName: "Inactive Token", + err: newIntrospectInactiveTokenError(), + resStatusCode: http.StatusUnauthorized, + resBody: "{\"active\":false}\n", + }, + { + testName: "Bad Request", + err: newIntrospectBadRequestError("This is a bad request"), + resStatusCode: http.StatusBadRequest, + resBody: `{"error":"invalid_request","error_description":"This is a bad request"}`, + }, + { + testName: "Internal Server Error", + err: newIntrospectInternalServerError(), + resStatusCode: http.StatusInternalServerError, + resBody: `{"error":"server_error"}`, + }, + } + for _, tc := range tests { + t.Run(tc.testName, func(t *testing.T) { + w1 := httptest.NewRecorder() + + s.introspectErrHelper(w1, tc.err.typ, tc.err.desc, tc.err.code) + + res := w1.Result() + require.Equal(t, tc.resStatusCode, res.StatusCode) + require.Equal(t, "application/json", res.Header.Get("Content-Type")) + + data, err := io.ReadAll(res.Body) + defer res.Body.Close() + require.NoError(t, err) + require.Equal(t, tc.resBody, string(data)) + }) + } +} diff --git a/server/oauth2.go b/server/oauth2.go index 2f2fb74f40..3589e493ea 100644 --- a/server/oauth2.go +++ b/server/oauth2.go @@ -105,6 +105,7 @@ const ( errUnsupportedGrantType = "unsupported_grant_type" errInvalidGrant = "invalid_grant" errInvalidClient = "invalid_client" + errInactiveToken = "inactive_token" ) const ( @@ -306,6 +307,49 @@ func (s *Server) newAccessToken(clientID string, claims storage.Claims, scopes [ return s.newIDToken(clientID, claims, scopes, nonce, storage.NewID(), "", connID) } +func getClientID(aud audience, azp string) (string, error) { + switch len(aud) { + case 0: + return "", fmt.Errorf("no audience is set, could not find ClientID") + case 1: + return aud[0], nil + default: + return azp, nil + } +} + +func getAudience(clientID string, scopes []string) audience { + var aud audience + + for _, scope := range scopes { + if peerID, ok := parseCrossClientScope(scope); ok { + aud = append(aud, peerID) + } + } + + if len(aud) == 0 { + // Client didn't ask for cross client audience. Set the current + // client as the audience. + aud = audience{clientID} + // Client asked for cross client audience: + // if the current client was not requested explicitly + } else if !aud.contains(clientID) { + // by default it becomes one of entries in Audience + aud = append(aud, clientID) + } + + return aud +} + +func genSubject(userID string, connID string) (string, error) { + sub := &internal.IDTokenSubject{ + UserId: userID, + ConnId: connID, + } + + return internal.Marshal(sub) +} + func (s *Server) newIDToken(clientID string, claims storage.Claims, scopes []string, nonce, accessToken, code, connID string) (idToken string, expiry time.Time, err error) { keys, err := s.storage.GetKeys() if err != nil { @@ -325,12 +369,7 @@ func (s *Server) newIDToken(clientID string, claims storage.Claims, scopes []str issuedAt := s.now() expiry = issuedAt.Add(s.idTokensValidFor) - sub := &internal.IDTokenSubject{ - UserId: claims.UserID, - ConnId: connID, - } - - subjectString, err := internal.Marshal(sub) + subjectString, err := genSubject(claims.UserID, connID) if err != nil { s.logger.Errorf("failed to marshal offline session ID: %v", err) return "", expiry, fmt.Errorf("failed to marshal offline session ID: %v", err) @@ -392,21 +431,11 @@ func (s *Server) newIDToken(clientID string, claims storage.Claims, scopes []str // TODO(ericchiang): propagate this error to the client. return "", expiry, fmt.Errorf("peer (%s) does not trust client", peerID) } - tok.Audience = append(tok.Audience, peerID) } } - if len(tok.Audience) == 0 { - // Client didn't ask for cross client audience. Set the current - // client as the audience. - tok.Audience = audience{clientID} - } else { - // Client asked for cross client audience: - // if the current client was not requested explicitly - if !tok.Audience.contains(clientID) { - // by default it becomes one of entries in Audience - tok.Audience = append(tok.Audience, clientID) - } + tok.Audience = getAudience(clientID, scopes) + if len(tok.Audience) > 1 { // The current client becomes the authorizing party. tok.AuthorizingParty = clientID } diff --git a/server/oauth2_test.go b/server/oauth2_test.go index 5b1ceff5dc..5f5fc3b663 100644 --- a/server/oauth2_test.go +++ b/server/oauth2_test.go @@ -11,11 +11,43 @@ import ( "testing" "github.com/go-jose/go-jose/v4" + "github.com/stretchr/testify/require" "github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage/memory" ) +func TestGetClientID(t *testing.T) { + cid, err := getClientID(audience{}, "") + require.Equal(t, "", cid) + require.Equal(t, "no audience is set, could not find ClientID", err.Error()) + + cid, err = getClientID(audience{"a"}, "") + require.Equal(t, "a", cid) + require.NoError(t, err) + + cid, err = getClientID(audience{"a", "b"}, "azp") + require.Equal(t, "azp", cid) + require.NoError(t, err) +} + +func TestGetAudience(t *testing.T) { + aud := getAudience("client-id", []string{}) + require.Equal(t, aud, audience{"client-id"}) + + aud = getAudience("client-id", []string{"ascope"}) + require.Equal(t, aud, audience{"client-id"}) + + aud = getAudience("client-id", []string{"ascope", "audience:server:client_id:aa", "audience:server:client_id:bb"}) + require.Equal(t, aud, audience{"aa", "bb", "client-id"}) +} + +func TestGetSubject(t *testing.T) { + sub, err := genSubject("foo", "bar") + require.Equal(t, "CgNmb28SA2Jhcg", sub) + require.NoError(t, err) +} + func TestParseAuthorizationRequest(t *testing.T) { tests := []struct { name string diff --git a/server/refreshhandlers.go b/server/refreshhandlers.go index b3918ab475..cb53802b87 100644 --- a/server/refreshhandlers.go +++ b/server/refreshhandlers.go @@ -40,6 +40,11 @@ func newBadRequestError(desc string) *refreshError { return &refreshError{msg: errInvalidRequest, desc: desc, code: http.StatusBadRequest} } +var ( + invalidErr = newBadRequestError("Refresh token is invalid or has already been claimed by another client.") + expiredErr = newBadRequestError("Refresh token expired.") +) + func (s *Server) refreshTokenErrHelper(w http.ResponseWriter, err *refreshError) { s.tokenErrHelper(w, err.msg, err.desc, err.code) } @@ -75,11 +80,9 @@ type refreshContext struct { } // getRefreshTokenFromStorage checks that refresh token is valid and exists in the storage and gets its info -func (s *Server) getRefreshTokenFromStorage(clientID string, token *internal.RefreshToken) (*refreshContext, *refreshError) { +func (s *Server) getRefreshTokenFromStorage(clientID *string, token *internal.RefreshToken) (*refreshContext, *refreshError) { refreshCtx := refreshContext{requestToken: token} - invalidErr := newBadRequestError("Refresh token is invalid or has already been claimed by another client.") - // Get RefreshToken refresh, err := s.storage.GetRefresh(token.RefreshId) if err != nil { @@ -90,7 +93,8 @@ func (s *Server) getRefreshTokenFromStorage(clientID string, token *internal.Ref return nil, invalidErr } - if refresh.ClientID != clientID { + // Only check ClientID if it was provided; + if clientID != nil && (refresh.ClientID != *clientID) { s.logger.Errorf("client %s trying to claim token for client %s", clientID, refresh.ClientID) // According to https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 Dex should respond with an // invalid grant error if token has already been claimed by another client. @@ -109,7 +113,6 @@ func (s *Server) getRefreshTokenFromStorage(clientID string, token *internal.Ref } } - expiredErr := newBadRequestError("Refresh token expired.") if s.refreshTokenPolicy.CompletelyExpired(refresh.CreatedAt) { s.logger.Errorf("refresh token with id %s expired", refresh.ID) return nil, expiredErr @@ -334,7 +337,7 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie return } - rCtx, rerr := s.getRefreshTokenFromStorage(client.ID, token) + rCtx, rerr := s.getRefreshTokenFromStorage(&client.ID, token) if rerr != nil { s.refreshTokenErrHelper(w, rerr) return diff --git a/server/server.go b/server/server.go index 1eaf191543..dddbb137e9 100644 --- a/server/server.go +++ b/server/server.go @@ -391,6 +391,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) handleWithCORS("/token", s.handleToken) handleWithCORS("/keys", s.handlePublicKeys) handleWithCORS("/userinfo", s.handleUserInfo) + handleWithCORS("/token/introspect", s.handleIntrospect) handleFunc("/auth", s.handleAuthorization) handleFunc("/auth/{connector}", s.handleConnectorLogin) handleFunc("/auth/{connector}/login", s.handlePasswordLogin)