From c2d68ef85b2d6e61bd0686c1245ba32efaab4a4b Mon Sep 17 00:00:00 2001 From: chanced Date: Thu, 2 Sep 2021 13:54:00 -0400 Subject: [PATCH 1/9] moves errors to go 1.13 style; adds hashicorp.multierror as a dependency. tests are not done --- claims.go | 74 +++++++++--------- errors.go | 189 ++++++++++++++++++++++++++++++++++++---------- go.mod | 2 + go.sum | 4 + map_claims.go | 182 ++++++++++++++++++++++++++------------------ none.go | 16 ++-- parser.go | 75 ++++++++---------- parser_test.go | 20 ++--- signing_method.go | 27 ++++--- 9 files changed, 365 insertions(+), 224 deletions(-) diff --git a/claims.go b/claims.go index ffd5dba4..741477d8 100644 --- a/claims.go +++ b/claims.go @@ -2,8 +2,9 @@ package jwt import ( "crypto/subtle" - "fmt" "time" + + "github.com/hashicorp/go-multierror" ) // Claims must just have a Valid method that determines @@ -49,32 +50,32 @@ type RegisteredClaims struct { // As well, if any of the above claims are not in the token, it will still // be considered a valid claim. func (c RegisteredClaims) Valid() error { - vErr := new(ValidationError) - now := TimeFunc() + var result *multierror.Error + result.ErrorFormat = ValidationErrorFormat + now := TimeFunc() // The claims below are optional, by default, so if they are set to the // default value in Go, let's not fail the verification for them. if !c.VerifyExpiresAt(now, false) { - delta := now.Sub(c.ExpiresAt.Time) - vErr.Inner = fmt.Errorf("token is expired by %v", delta) - vErr.Errors |= ValidationErrorExpired + result = multierror.Append(result, &ExpiredError{ + ExpiredAt: c.ExpiresAt.Time, + AttemptedAt: now, + }) } - if !c.VerifyIssuedAt(now, false) { - vErr.Inner = fmt.Errorf("token used before issued") - vErr.Errors |= ValidationErrorIssuedAt + result = multierror.Append(result, &UsedBeforeIssuedError{ + IssuedAt: c.IssuedAt.Time, + AttemptedAt: now, + }) } - if !c.VerifyNotBefore(now, false) { - vErr.Inner = fmt.Errorf("token is not valid yet") - vErr.Errors |= ValidationErrorNotValidYet - } - - if vErr.valid() { - return nil + result = multierror.Append(result, &NotYetValidError{ + ValidAt: c.NotBefore.Time, + AttemptedAt: now, + }) } - return vErr + return result.ErrorOrNil() } // VerifyAudience compares the aud claim against cmp. @@ -136,32 +137,33 @@ type StandardClaims struct { // As well, if any of the above claims are not in the token, it will still // be considered a valid claim. func (c StandardClaims) Valid() error { - vErr := new(ValidationError) - now := TimeFunc().Unix() + var result *multierror.Error + result.ErrorFormat = ValidationErrorFormat + now := TimeFunc() + nowUnix := now.Unix() // The claims below are optional, by default, so if they are set to the // default value in Go, let's not fail the verification for them. - if !c.VerifyExpiresAt(now, false) { - delta := time.Unix(now, 0).Sub(time.Unix(c.ExpiresAt, 0)) - vErr.Inner = fmt.Errorf("token is expired by %v", delta) - vErr.Errors |= ValidationErrorExpired - } - if !c.VerifyIssuedAt(now, false) { - vErr.Inner = fmt.Errorf("token used before issued") - vErr.Errors |= ValidationErrorIssuedAt + if !c.VerifyExpiresAt(nowUnix, false) { + result = multierror.Append(result, &ExpiredError{ + ExpiredAt: time.Unix(c.ExpiresAt, 0), + AttemptedAt: now, + }) } - - if !c.VerifyNotBefore(now, false) { - vErr.Inner = fmt.Errorf("token is not valid yet") - vErr.Errors |= ValidationErrorNotValidYet + if !c.VerifyIssuedAt(nowUnix, false) { + result = multierror.Append(result, &UsedBeforeIssuedError{ + IssuedAt: time.Unix(c.IssuedAt, 0), + AttemptedAt: now, + }) } - - if vErr.valid() { - return nil + if !c.VerifyNotBefore(nowUnix, false) { + result = multierror.Append(result, &NotYetValidError{ + ValidAt: time.Unix(c.NotBefore, 0), + AttemptedAt: now, + }) } - - return vErr + return result.ErrorOrNil() } // VerifyAudience compares the aud claim against cmp. diff --git a/errors.go b/errors.go index f309878b..09ed51b3 100644 --- a/errors.go +++ b/errors.go @@ -2,58 +2,167 @@ package jwt import ( "errors" + "fmt" + "strings" + "time" ) +var ValidationErrorFormat = func(errs []error) string { + + str := "jwt: errors occurred validating the claims" + for _, err := range errs { + errstr := strings.TrimPrefix(err.Error(), "jwt: ") + if len(errstr) > 0 { + str = str + "\n\t" + errstr + } + } + return str +} + // Error constants var ( - ErrInvalidKey = errors.New("key is invalid") - ErrInvalidKeyType = errors.New("key is of invalid type") - ErrHashUnavailable = errors.New("the requested hash function is unavailable") + ErrMalformedToken = errors.New("jwt: token is malformed") + ErrTokenContainsBearer = errors.New(`token may not contain "bearer"`) + ErrInvalidSigningMethod = errors.New("jwt: invalid signing method") + ErrUnregisteredSigningMethod = errors.New("jwt: signing method not registered") + ErrInvalidKey = errors.New("jwt: key is invalid") + ErrInvalidKeyType = errors.New("jwt: key is of invalid type") + ErrHashUnavailable = errors.New("jwt: the requested hash function is unavailable") + ErrTokenNotYetValid = errors.New("jwt: the token is not yet valid") + ErrTokenExpired = errors.New("jwt: the token is expired") + ErrTokenUsedBeforeIssued = errors.New("jwt: the token was used before issued") + ErrNoneSignatureTypeDisallowed = errors.New(`jwt: "none" signature type is not allowed`) + ErrMissingKeyFunc = errors.New("jwt: KeyFunc not provided") ) -// The errors that might occur when parsing and validating a token -const ( - ValidationErrorMalformed uint32 = 1 << iota // Token is malformed - ValidationErrorUnverifiable // Token could not be verified because of signing problems - ValidationErrorSignatureInvalid // Signature validation failed - - // Standard Claim validation errors - ValidationErrorAudience // AUD validation failed - ValidationErrorExpired // EXP validation failed - ValidationErrorIssuedAt // IAT validation failed - ValidationErrorIssuer // ISS validation failed - ValidationErrorNotValidYet // NBF validation failed - ValidationErrorId // JTI validation failed - ValidationErrorClaimsInvalid // Generic claims validation error -) +type MalformedTokenError string -// NewValidationError is a helper for constructing a ValidationError with a string error message -func NewValidationError(errorText string, errorFlags uint32) *ValidationError { - return &ValidationError{ - text: errorText, - Errors: errorFlags, +func (err MalformedTokenError) Error() string { + str := ErrMalformedToken.Error() + if len(err) > 0 { + str = str + "\n\t" + string(err) } + return str } -// ValidationError represents an error from Parse if token is not valid -type ValidationError struct { - Inner error // stores the error returned by external dependencies, i.e.: KeyFunc - Errors uint32 // bitfield. see ValidationError... constants - text string // errors that do not have a valid error just have text +func (err MalformedTokenError) Unwrap() error { + return ErrMalformedToken } -// Error is the implementation of the err interface. -func (e ValidationError) Error() string { - if e.Inner != nil { - return e.Inner.Error() - } else if e.text != "" { - return e.text - } else { - return "token is invalid" - } +type UnregisteredSigningMethodError struct { + Alg string +} + +func (err *UnregisteredSigningMethodError) Error() string { + return `jwt: signing method "` + err.Alg + `" is not registered` +} + +func (err *UnregisteredSigningMethodError) Unwrap() error { + return ErrUnregisteredSigningMethod } -// No errors -func (e *ValidationError) valid() bool { - return e.Errors == 0 +type InvalidSigningMethodError struct { + Alg string } + +func (err *InvalidSigningMethodError) Error() string { + return `jwt: signing method "` + err.Alg + `" is invalid` +} + +func (err *InvalidSigningMethodError) Unwrap() error { + return ErrInvalidSigningMethod +} + +type NotYetValidError struct { + ValidAt time.Time + AttemptedAt time.Time +} + +func (err *NotYetValidError) Delta() time.Duration { + return err.AttemptedAt.Sub(err.ValidAt) +} +func (err *NotYetValidError) Error() string { + return fmt.Sprintf("token is not valid for another %v", err.Delta) +} +func (err *NotYetValidError) Unwrap() error { + return ErrTokenNotYetValid +} + +type UsedBeforeIssuedError struct { + IssuedAt time.Time + AttemptedAt time.Time +} + +func (err *UsedBeforeIssuedError) Delta() time.Duration { + return err.IssuedAt.Sub(err.AttemptedAt) +} + +func (err *UsedBeforeIssuedError) Error() string { + return fmt.Sprintf("token is not valid for another %v", err.Delta) +} +func (err *UsedBeforeIssuedError) Unwrap() error { + return ErrTokenUsedBeforeIssued +} + +type ExpiredError struct { + ExpiredAt time.Time + AttemptedAt time.Time +} + +func (err *ExpiredError) Delta() time.Duration { + return err.ExpiredAt.Sub(err.AttemptedAt) +} + +func (err *ExpiredError) Error() string { + return fmt.Sprintf("token is expired by %v", err.Delta) +} +func (err *ExpiredError) Unwrap() error { + return ErrTokenNotYetValid +} + +// The errors that might occur when parsing and validating a token +const ( +// ValidationErrorMalformed uint32 = 1 << iota // Token is malformed +// ValidationErrorUnverifiable // Token could not be verified because of signing problems +// ValidationErrorSignatureInvalid // Signature validation failed + +// Standard Claim validation errors +// ValidationErrorAudience // AUD validation failed +// ValidationErrorExpired // EXP validation failed +// ValidationErrorIssuedAt // IAT validation failed +// ValidationErrorIssuer // ISS validation failed +// ValidationErrorNotValidYet // NBF validation failed +// ValidationErrorId // JTI validation failed +// ValidationErrorClaimsInvalid // Generic claims validation error +) + +// // NewValidationError is a helper for constructing a ValidationError with a string error message +// func NewValidationError(errorText string, errorFlags uint32) *ValidationError { +// return &ValidationError{ +// text: errorText, +// Errors: errorFlags, +// } +// } + +// // ValidationError represents an error from Parse if token is not valid +// type ValidationError struct { +// Inner error // stores the error returned by external dependencies, i.e.: KeyFunc +// Errors uint32 // bitfield. see ValidationError... constants +// text string // errors that do not have a valid error just have text +// } + +// // Error is the implementation of the err interface. +// func (e ValidationError) Error() string { +// if e.Inner != nil { +// return e.Inner.Error() +// } else if e.text != "" { +// return e.text +// } else { +// return "token is invalid" +// } +// } + +// // No errors +// func (e *ValidationError) valid() bool { +// return e.Errors == 0 +// } diff --git a/go.mod b/go.mod index 6bc53fdc..90c29e5b 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,5 @@ module github.com/golang-jwt/jwt/v4 go 1.15 + +require github.com/hashicorp/go-multierror v1.1.1 diff --git a/go.sum b/go.sum index e69de29b..6a62c366 100644 --- a/go.sum +++ b/go.sum @@ -0,0 +1,4 @@ +github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= +github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= +github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= diff --git a/map_claims.go b/map_claims.go index e7da633b..9bf731fb 100644 --- a/map_claims.go +++ b/map_claims.go @@ -2,18 +2,86 @@ package jwt import ( "encoding/json" - "errors" + "fmt" "time" - // "fmt" + + "github.com/hashicorp/go-multierror" ) // MapClaims is a claims type that uses the map[string]interface{} for JSON decoding. // This is the default claims type if you don't supply one type MapClaims map[string]interface{} -// VerifyAudience Compares the aud claim against cmp. -// If required is false, this method will return true if the value matches or is unset -func (m MapClaims) VerifyAudience(cmp string, req bool) bool { +func (m MapClaims) ExpiresAt() *time.Time { + exp := m["exp"] + switch exp := exp.(type) { + case float64: + if exp == 0 { + return nil + } + return &newNumericDateFromSeconds(exp).Time + case json.Number: + v, _ := exp.Float64() + if v == 0 { + return nil + } + return &newNumericDateFromSeconds(v).Time + default: + return nil + } +} + +func (m MapClaims) IssuedAt() *time.Time { + iat := m["iat"] + switch exp := iat.(type) { + case float64: + if exp == 0 { + return nil + } + return &newNumericDateFromSeconds(exp).Time + case json.Number: + v, _ := exp.Float64() + if v == 0 { + return nil + } + return &newNumericDateFromSeconds(v).Time + default: + return nil + } +} + +// NotBefore returns the *time.Time parsed nbf field of the MapClaims if present +// or nil otherwise +func (m MapClaims) NotBefore() *time.Time { + v := m["nbf"] + switch nbf := v.(type) { + case float64: + if nbf == 0 { + return nil + } + return &newNumericDateFromSeconds(nbf).Time + case json.Number: + v, _ := nbf.Float64() + if v == 0 { + return nil + } + return &newNumericDateFromSeconds(v).Time + default: + return nil + } +} + +// Issuer returns the iss field of the MapClaims +func (m MapClaims) Issuer() string { + iss := m["iss"] + if str, ok := iss.(string); ok { + return str + } + return "" +} + +func (m MapClaims) Audience() ([]string, error) { + var err *multierror.Error var aud []string switch v := m["aud"].(type) { case string: @@ -22,13 +90,23 @@ func (m MapClaims) VerifyAudience(cmp string, req bool) bool { aud = v case []interface{}: for _, a := range v { - vs, ok := a.(string) - if !ok { - return false + if vs, ok := a.(string); ok { + aud = append(aud, vs) + } else { + multierror.Append(err, fmt.Errorf("aud entry [%v] is not a string", a)) } - aud = append(aud, vs) } } + return aud, err.ErrorOrNil() +} + +// VerifyAudience Compares the aud claim against cmp. +// If required is false, this method will return true if the value matches or is unset +func (m MapClaims) VerifyAudience(cmp string, req bool) bool { + aud, err := m.Audience() + if err != nil { + return false + } return verifyAud(aud, cmp, req) } @@ -36,26 +114,7 @@ func (m MapClaims) VerifyAudience(cmp string, req bool) bool { // If req is false, it will return true, if exp is unset. func (m MapClaims) VerifyExpiresAt(cmp int64, req bool) bool { cmpTime := time.Unix(cmp, 0) - - v, ok := m["exp"] - if !ok { - return !req - } - - switch exp := v.(type) { - case float64: - if exp == 0 { - return verifyExp(nil, cmpTime, req) - } - - return verifyExp(&newNumericDateFromSeconds(exp).Time, cmpTime, req) - case json.Number: - v, _ := exp.Float64() - - return verifyExp(&newNumericDateFromSeconds(v).Time, cmpTime, req) - } - - return false + return verifyExp(m.ExpiresAt(), cmpTime, req) } // VerifyIssuedAt compares the exp claim against cmp (cmp >= iat). @@ -73,11 +132,9 @@ func (m MapClaims) VerifyIssuedAt(cmp int64, req bool) bool { if iat == 0 { return verifyIat(nil, cmpTime, req) } - return verifyIat(&newNumericDateFromSeconds(iat).Time, cmpTime, req) case json.Number: v, _ := iat.Float64() - return verifyIat(&newNumericDateFromSeconds(v).Time, cmpTime, req) } @@ -87,34 +144,13 @@ func (m MapClaims) VerifyIssuedAt(cmp int64, req bool) bool { // VerifyNotBefore compares the nbf claim against cmp (cmp >= nbf). // If req is false, it will return true, if nbf is unset. func (m MapClaims) VerifyNotBefore(cmp int64, req bool) bool { - cmpTime := time.Unix(cmp, 0) - - v, ok := m["nbf"] - if !ok { - return !req - } - - switch nbf := v.(type) { - case float64: - if nbf == 0 { - return verifyNbf(nil, cmpTime, req) - } - - return verifyNbf(&newNumericDateFromSeconds(nbf).Time, cmpTime, req) - case json.Number: - v, _ := nbf.Float64() - - return verifyNbf(&newNumericDateFromSeconds(v).Time, cmpTime, req) - } - - return false + return verifyNbf(m.NotBefore(), time.Unix(cmp, 0), req) } // VerifyIssuer compares the iss claim against cmp. // If required is false, this method will return true if the value matches or is unset func (m MapClaims) VerifyIssuer(cmp string, req bool) bool { - iss, _ := m["iss"].(string) - return verifyIss(iss, cmp, req) + return verifyIss(m.Issuer(), cmp, req) } // Valid validates time based claims "exp, iat, nbf". @@ -122,27 +158,27 @@ func (m MapClaims) VerifyIssuer(cmp string, req bool) bool { // As well, if any of the above claims are not in the token, it will still // be considered a valid claim. func (m MapClaims) Valid() error { - vErr := new(ValidationError) - now := TimeFunc().Unix() - - if !m.VerifyExpiresAt(now, false) { - vErr.Inner = errors.New("Token is expired") - vErr.Errors |= ValidationErrorExpired + var result *multierror.Error + now := TimeFunc() + nowUnix := now.Unix() + if !m.VerifyExpiresAt(nowUnix, false) { + result = multierror.Append(result, &ExpiredError{ + ExpiredAt: *m.ExpiresAt(), + AttemptedAt: now, + }) } - - if !m.VerifyIssuedAt(now, false) { - vErr.Inner = errors.New("Token used before issued") - vErr.Errors |= ValidationErrorIssuedAt - } - - if !m.VerifyNotBefore(now, false) { - vErr.Inner = errors.New("Token is not valid yet") - vErr.Errors |= ValidationErrorNotValidYet + if !m.VerifyIssuedAt(nowUnix, false) { + result = multierror.Append(result, &UsedBeforeIssuedError{ + IssuedAt: *m.IssuedAt(), + AttemptedAt: now, + }) } - - if vErr.valid() { - return nil + if !m.VerifyNotBefore(nowUnix, false) { + result = multierror.Append(result, &NotYetValidError{ + ValidAt: *m.NotBefore(), + AttemptedAt: now, + }) } + return result.ErrorOrNil() - return vErr } diff --git a/none.go b/none.go index f19835d2..dfed6f52 100644 --- a/none.go +++ b/none.go @@ -1,20 +1,20 @@ package jwt +import "errors" + // SigningMethodNone implements the none signing method. This is required by the spec // but you probably should never use it. var SigningMethodNone *signingMethodNone const UnsafeAllowNoneSignatureType unsafeNoneMagicConstant = "none signing method allowed" -var NoneSignatureTypeDisallowedError error +// var NoneSignatureTypeDisallowedError error type signingMethodNone struct{} type unsafeNoneMagicConstant string func init() { SigningMethodNone = &signingMethodNone{} - NoneSignatureTypeDisallowedError = NewValidationError("'none' signature type is not allowed", ValidationErrorSignatureInvalid) - RegisterSigningMethod(SigningMethodNone.Alg(), func() SigningMethod { return SigningMethodNone }) @@ -29,16 +29,12 @@ func (m *signingMethodNone) Verify(signingString, signature string, key interfac // Key must be UnsafeAllowNoneSignatureType to prevent accidentally // accepting 'none' signing method if _, ok := key.(unsafeNoneMagicConstant); !ok { - return NoneSignatureTypeDisallowedError + return ErrNoneSignatureTypeDisallowed } // If signing method is none, signature must be an empty string if signature != "" { - return NewValidationError( - "'none' signing method with non-empty signature", - ValidationErrorSignatureInvalid, - ) + return errors.New(`jwt: if signing method is "none", signature must be an empty string`) } - // Accept 'none' signing method. return nil } @@ -48,5 +44,5 @@ func (m *signingMethodNone) Sign(signingString string, key interface{}) (string, if _, ok := key.(unsafeNoneMagicConstant); ok { return "", nil } - return "", NoneSignatureTypeDisallowedError + return "", ErrNoneSignatureTypeDisallowed } diff --git a/parser.go b/parser.go index 0c811f31..cbcbbedd 100644 --- a/parser.go +++ b/parser.go @@ -3,7 +3,6 @@ package jwt import ( "bytes" "encoding/json" - "fmt" "strings" ) @@ -23,7 +22,7 @@ func (p *Parser) Parse(tokenString string, keyFunc Keyfunc) (*Token, error) { func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyfunc) (*Token, error) { token, parts, err := p.ParseUnverified(tokenString, claims) if err != nil { - return token, err + return nil, err } // Verify signing method is in the required set @@ -38,7 +37,7 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf } if !signingMethodValid { // signing method is not in the listed set - return token, NewValidationError(fmt.Sprintf("signing method %v is invalid", alg), ValidationErrorSignatureInvalid) + return nil, &InvalidSigningMethodError{Alg: alg} } } @@ -46,45 +45,29 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf var key interface{} if keyFunc == nil { // keyFunc was not provided. short circuiting validation - return token, NewValidationError("no Keyfunc was provided.", ValidationErrorUnverifiable) - } - if key, err = keyFunc(token); err != nil { - // keyFunc returned an error - if ve, ok := err.(*ValidationError); ok { - return token, ve - } - return token, &ValidationError{Inner: err, Errors: ValidationErrorUnverifiable} + return nil, ErrMissingKeyFunc } - vErr := &ValidationError{} + key, err = keyFunc(token) + if err != nil { + return nil, err + } // Validate Claims if !p.SkipClaimsValidation { if err := token.Claims.Valid(); err != nil { - - // If the Claims Valid returned an error, check if it is a validation error, - // If it was another error type, create a ValidationError with a generic ClaimsInvalid flag set - if e, ok := err.(*ValidationError); !ok { - vErr = &ValidationError{Inner: err, Errors: ValidationErrorClaimsInvalid} - } else { - vErr = e - } + return nil, err } } // Perform validation token.Signature = parts[2] if err = token.Method.Verify(strings.Join(parts[0:2], "."), token.Signature, key); err != nil { - vErr.Inner = err - vErr.Errors |= ValidationErrorSignatureInvalid - } - - if vErr.valid() { - token.Valid = true - return token, nil + token.Valid = false + return token, err } - - return token, vErr + token.Valid = true + return token, nil } // ParseUnverified parses the token but doesn't validate the signature. @@ -96,29 +79,32 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Token, parts []string, err error) { parts = strings.Split(tokenString, ".") if len(parts) != 3 { - return nil, parts, NewValidationError("token contains an invalid number of segments", ValidationErrorMalformed) + return nil, parts, MalformedTokenError("token contains an invalid number of segments") } token = &Token{Raw: tokenString} // parse Header var headerBytes []byte - if headerBytes, err = DecodeSegment(parts[0]); err != nil { + headerBytes, err = DecodeSegment(parts[0]) + if err != nil { if strings.HasPrefix(strings.ToLower(tokenString), "bearer ") { - return token, parts, NewValidationError("tokenstring should not contain 'bearer '", ValidationErrorMalformed) + return token, parts, MalformedTokenError(`token may not contain "bearer "`) } - return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed} + return token, parts, MalformedTokenError(err.Error()) } + if err = json.Unmarshal(headerBytes, &token.Header); err != nil { - return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed} + return token, parts, MalformedTokenError(err.Error()) } // parse Claims var claimBytes []byte token.Claims = claims - if claimBytes, err = DecodeSegment(parts[1]); err != nil { - return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed} + claimBytes, err = DecodeSegment(parts[1]) + if err != nil { + return token, parts, MalformedTokenError(err.Error()) } dec := json.NewDecoder(bytes.NewBuffer(claimBytes)) if p.UseJSONNumber { @@ -132,17 +118,18 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke } // Handle decode error if err != nil { - return token, parts, &ValidationError{Inner: err, Errors: ValidationErrorMalformed} + return token, parts, MalformedTokenError(err.Error()) } // Lookup signature method - if method, ok := token.Header["alg"].(string); ok { - if token.Method = GetSigningMethod(method); token.Method == nil { - return token, parts, NewValidationError("signing method (alg) is unavailable.", ValidationErrorUnverifiable) - } - } else { - return token, parts, NewValidationError("signing method (alg) is unspecified.", ValidationErrorUnverifiable) - } + alg, ok := token.Header["alg"].(string) + if !ok { + return token, parts, MalformedTokenError("signing method (alg) not specified") + } + token.Method = GetSigningMethod(alg) + if token.Method == nil { + return token, parts, &UnregisteredSigningMethodError{Alg: alg} + } return token, parts, nil } diff --git a/parser_test.go b/parser_test.go index d997a0e6..debf6e70 100644 --- a/parser_test.go +++ b/parser_test.go @@ -26,13 +26,15 @@ func init() { jwtTestDefaultKey = test.LoadRSAPublicKeyFromDisk("test/sample_key.pub") } +type errors []error + var jwtTestData = []struct { name string tokenString string keyfunc jwt.Keyfunc claims jwt.Claims valid bool - errors uint32 + errors errors parser *jwt.Parser }{ { @@ -41,7 +43,7 @@ var jwtTestData = []struct { defaultKeyFunc, jwt.MapClaims{"foo": "bar"}, true, - 0, + nil, nil, }, { @@ -50,7 +52,7 @@ var jwtTestData = []struct { defaultKeyFunc, jwt.MapClaims{"foo": "bar", "exp": float64(time.Now().Unix() - 100)}, false, - jwt.ValidationErrorExpired, + errors{jwt.ErrTokenExpired}, nil, }, { @@ -59,7 +61,7 @@ var jwtTestData = []struct { defaultKeyFunc, jwt.MapClaims{"foo": "bar", "nbf": float64(time.Now().Unix() + 100)}, false, - jwt.ValidationErrorNotValidYet, + errors{jwt.ErrTokenNotYetValid}, nil, }, { @@ -68,7 +70,7 @@ var jwtTestData = []struct { defaultKeyFunc, jwt.MapClaims{"foo": "bar", "nbf": float64(time.Now().Unix() + 100), "exp": float64(time.Now().Unix() - 100)}, false, - jwt.ValidationErrorNotValidYet | jwt.ValidationErrorExpired, + errors{jwt.ErrTokenNotYetValid, jwt.ErrTokenExpired}, nil, }, { @@ -77,7 +79,7 @@ var jwtTestData = []struct { defaultKeyFunc, jwt.MapClaims{"foo": "bar"}, false, - jwt.ValidationErrorSignatureInvalid, + errors{jwt.ErrSignatureInvalid}, nil, }, { @@ -86,7 +88,7 @@ var jwtTestData = []struct { nilKeyFunc, jwt.MapClaims{"foo": "bar"}, false, - jwt.ValidationErrorUnverifiable, + errors{jwt.ErrUnregisteredSigningMethod}, nil, }, { @@ -95,7 +97,7 @@ var jwtTestData = []struct { emptyKeyFunc, jwt.MapClaims{"foo": "bar"}, false, - jwt.ValidationErrorSignatureInvalid, + errors{jwt.ErrSignatureInvalid}, nil, }, { @@ -104,7 +106,7 @@ var jwtTestData = []struct { errorKeyFunc, jwt.MapClaims{"foo": "bar"}, false, - jwt.ValidationErrorUnverifiable, + errors{jwt.ErrSignatureInvalid}, nil, }, { diff --git a/signing_method.go b/signing_method.go index 3269170f..26c5aef2 100644 --- a/signing_method.go +++ b/signing_method.go @@ -4,8 +4,10 @@ import ( "sync" ) -var signingMethods = map[string]func() SigningMethod{} -var signingMethodLock = new(sync.RWMutex) +type signingMethodFunc = func() SigningMethod + +var signingMethods = map[string]signingMethodFunc{} +var signingMethodsMutex = new(sync.Mutex) // SigningMethod can be used add new methods for signing or verifying tokens. type SigningMethod interface { @@ -17,19 +19,20 @@ type SigningMethod interface { // RegisterSigningMethod registers the "alg" name and a factory function for signing method. // This is typically done during init() in the method's implementation func RegisterSigningMethod(alg string, f func() SigningMethod) { - signingMethodLock.Lock() - defer signingMethodLock.Unlock() - - signingMethods[alg] = f + signingMethodsMutex.Lock() + defer signingMethodsMutex.Unlock() + clone := map[string]signingMethodFunc{} + for k, sm := range signingMethods { + clone[k] = sm + } + clone[alg] = f + signingMethods = clone } // GetSigningMethod retrieves a signing method from an "alg" string -func GetSigningMethod(alg string) (method SigningMethod) { - signingMethodLock.RLock() - defer signingMethodLock.RUnlock() - +func GetSigningMethod(alg string) SigningMethod { if methodF, ok := signingMethods[alg]; ok { - method = methodF() + return methodF() } - return + return nil } From ca3fd98df3eb44ec5852d740086930487033ebde Mon Sep 17 00:00:00 2001 From: chanced Date: Thu, 2 Sep 2021 14:44:13 -0400 Subject: [PATCH 2/9] fixes type errors in parser_test but there are failing tests --- errors.go | 6 +- example_test.go | 236 +++++++++++++++++++++++----------------------- parser_test.go | 86 ++++++++++------- signing_method.go | 8 +- 4 files changed, 178 insertions(+), 158 deletions(-) diff --git a/errors.go b/errors.go index 09ed51b3..64ee8306 100644 --- a/errors.go +++ b/errors.go @@ -82,7 +82,7 @@ func (err *NotYetValidError) Delta() time.Duration { return err.AttemptedAt.Sub(err.ValidAt) } func (err *NotYetValidError) Error() string { - return fmt.Sprintf("token is not valid for another %v", err.Delta) + return fmt.Sprintf("token is not valid for another %v", err.Delta()) } func (err *NotYetValidError) Unwrap() error { return ErrTokenNotYetValid @@ -98,7 +98,7 @@ func (err *UsedBeforeIssuedError) Delta() time.Duration { } func (err *UsedBeforeIssuedError) Error() string { - return fmt.Sprintf("token is not valid for another %v", err.Delta) + return fmt.Sprintf("token is not valid for another %v", err.Delta()) } func (err *UsedBeforeIssuedError) Unwrap() error { return ErrTokenUsedBeforeIssued @@ -114,7 +114,7 @@ func (err *ExpiredError) Delta() time.Duration { } func (err *ExpiredError) Error() string { - return fmt.Sprintf("token is expired by %v", err.Delta) + return fmt.Sprintf("token is expired by %v", err.Delta()) } func (err *ExpiredError) Unwrap() error { return ErrTokenNotYetValid diff --git a/example_test.go b/example_test.go index 7815757b..7294ce5a 100644 --- a/example_test.go +++ b/example_test.go @@ -1,120 +1,120 @@ package jwt_test -import ( - "fmt" - "time" - - "github.com/golang-jwt/jwt/v4" -) - -// Example (atypical) using the RegisteredClaims type by itself to parse a token. -// The RegisteredClaims type is designed to be embedded into your custom types -// to provide standard validation features. You can use it alone, but there's -// no way to retrieve other fields after parsing. -// See the CustomClaimsType example for intended usage. -func ExampleNewWithClaims_registeredClaims() { - mySigningKey := []byte("AllYourBase") - - // Create the Claims - claims := &jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Unix(1516239022, 0)), - Issuer: "test", - } - - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - ss, err := token.SignedString(mySigningKey) - fmt.Printf("%v %v", ss, err) - //Output: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJ0ZXN0IiwiZXhwIjoxNTE2MjM5MDIyfQ.0XN_1Tpp9FszFOonIBpwha0c_SfnNI22DhTnjMshPg8 -} - -// Example creating a token using a custom claims type. The RegisteredClaims is embedded -// in the custom type to allow for easy encoding, parsing and validation of registered claims. -func ExampleNewWithClaims_customClaimsType() { - mySigningKey := []byte("AllYourBase") - - type MyCustomClaims struct { - Foo string `json:"foo"` - jwt.RegisteredClaims - } - - // Create the claims - claims := MyCustomClaims{ - "bar", - jwt.RegisteredClaims{ - // A usual scenario is to set the expiration time relative to the current time - ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)), - IssuedAt: jwt.NewNumericDate(time.Now()), - NotBefore: jwt.NewNumericDate(time.Now()), - Issuer: "test", - Subject: "somebody", - ID: "1", - Audience: []string{"somebody_else"}, - }, - } - - // Create claims while leaving out some of the optional fields - claims = MyCustomClaims{ - "bar", - jwt.RegisteredClaims{ - // Also fixed dates can be used for the NumericDate - ExpiresAt: jwt.NewNumericDate(time.Unix(1516239022, 0)), - Issuer: "test", - }, - } - - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - ss, err := token.SignedString(mySigningKey) - fmt.Printf("%v %v", ss, err) - - //Output: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJ0ZXN0IiwiZXhwIjoxNTE2MjM5MDIyfQ.xVuY2FZ_MRXMIEgVQ7J-TFtaucVFRXUzHm9LmV41goM -} - -// Example creating a token using a custom claims type. The StandardClaim is embedded -// in the custom type to allow for easy encoding, parsing and validation of standard claims. -func ExampleParseWithClaims_customClaimsType() { - tokenString := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJ0ZXN0IiwiYXVkIjoic2luZ2xlIn0.QAWg1vGvnqRuCFTMcPkjZljXHh8U3L_qUjszOtQbeaA" - - type MyCustomClaims struct { - Foo string `json:"foo"` - jwt.RegisteredClaims - } - - token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) { - return []byte("AllYourBase"), nil - }) - - if claims, ok := token.Claims.(*MyCustomClaims); ok && token.Valid { - fmt.Printf("%v %v", claims.Foo, claims.RegisteredClaims.Issuer) - } else { - fmt.Println(err) - } - - // Output: bar test -} - -// An example of parsing the error types using bitfield checks -func ExampleParse_errorChecking() { - // Token from another example. This token is expired - var tokenString = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJleHAiOjE1MDAwLCJpc3MiOiJ0ZXN0In0.HE7fK0xOQwFEr4WDgRWj4teRPZ6i3GLwD5YCm6Pwu_c" - - token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { - return []byte("AllYourBase"), nil - }) - - if token.Valid { - fmt.Println("You look nice today") - } else if ve, ok := err.(*jwt.ValidationError); ok { - if ve.Errors&jwt.ValidationErrorMalformed != 0 { - fmt.Println("That's not even a token") - } else if ve.Errors&(jwt.ValidationErrorExpired|jwt.ValidationErrorNotValidYet) != 0 { - // Token is either expired or not active yet - fmt.Println("Timing is everything") - } else { - fmt.Println("Couldn't handle this token:", err) - } - } else { - fmt.Println("Couldn't handle this token:", err) - } - - // Output: Timing is everything -} +// import ( +// "fmt" +// "time" + +// "github.com/golang-jwt/jwt/v4" +// ) + +// // Example (atypical) using the RegisteredClaims type by itself to parse a token. +// // The RegisteredClaims type is designed to be embedded into your custom types +// // to provide standard validation features. You can use it alone, but there's +// // no way to retrieve other fields after parsing. +// // See the CustomClaimsType example for intended usage. +// func ExampleNewWithClaims_registeredClaims() { +// mySigningKey := []byte("AllYourBase") + +// // Create the Claims +// claims := &jwt.RegisteredClaims{ +// ExpiresAt: jwt.NewNumericDate(time.Unix(1516239022, 0)), +// Issuer: "test", +// } + +// token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) +// ss, err := token.SignedString(mySigningKey) +// fmt.Printf("%v %v", ss, err) +// //Output: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJ0ZXN0IiwiZXhwIjoxNTE2MjM5MDIyfQ.0XN_1Tpp9FszFOonIBpwha0c_SfnNI22DhTnjMshPg8 +// } + +// // Example creating a token using a custom claims type. The RegisteredClaims is embedded +// // in the custom type to allow for easy encoding, parsing and validation of registered claims. +// func ExampleNewWithClaims_customClaimsType() { +// mySigningKey := []byte("AllYourBase") + +// type MyCustomClaims struct { +// Foo string `json:"foo"` +// jwt.RegisteredClaims +// } + +// // Create the claims +// claims := MyCustomClaims{ +// "bar", +// jwt.RegisteredClaims{ +// // A usual scenario is to set the expiration time relative to the current time +// ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)), +// IssuedAt: jwt.NewNumericDate(time.Now()), +// NotBefore: jwt.NewNumericDate(time.Now()), +// Issuer: "test", +// Subject: "somebody", +// ID: "1", +// Audience: []string{"somebody_else"}, +// }, +// } + +// // Create claims while leaving out some of the optional fields +// claims = MyCustomClaims{ +// "bar", +// jwt.RegisteredClaims{ +// // Also fixed dates can be used for the NumericDate +// ExpiresAt: jwt.NewNumericDate(time.Unix(1516239022, 0)), +// Issuer: "test", +// }, +// } + +// token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) +// ss, err := token.SignedString(mySigningKey) +// fmt.Printf("%v %v", ss, err) + +// //Output: eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJ0ZXN0IiwiZXhwIjoxNTE2MjM5MDIyfQ.xVuY2FZ_MRXMIEgVQ7J-TFtaucVFRXUzHm9LmV41goM +// } + +// // Example creating a token using a custom claims type. The StandardClaim is embedded +// // in the custom type to allow for easy encoding, parsing and validation of standard claims. +// func ExampleParseWithClaims_customClaimsType() { +// tokenString := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJ0ZXN0IiwiYXVkIjoic2luZ2xlIn0.QAWg1vGvnqRuCFTMcPkjZljXHh8U3L_qUjszOtQbeaA" + +// type MyCustomClaims struct { +// Foo string `json:"foo"` +// jwt.RegisteredClaims +// } + +// token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) { +// return []byte("AllYourBase"), nil +// }) + +// if claims, ok := token.Claims.(*MyCustomClaims); ok && token.Valid { +// fmt.Printf("%v %v", claims.Foo, claims.RegisteredClaims.Issuer) +// } else { +// fmt.Println(err) +// } + +// // Output: bar test +// } + +// // An example of parsing the error types using bitfield checks +// func ExampleParse_errorChecking() { +// // Token from another example. This token is expired +// var tokenString = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJleHAiOjE1MDAwLCJpc3MiOiJ0ZXN0In0.HE7fK0xOQwFEr4WDgRWj4teRPZ6i3GLwD5YCm6Pwu_c" + +// token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { +// return []byte("AllYourBase"), nil +// }) + +// if token.Valid { +// fmt.Println("You look nice today") +// } else if ve, ok := err.(*jwt.ValidationError); ok { +// if ve.Errors&jwt.ValidationErrorMalformed != 0 { +// fmt.Println("That's not even a token") +// } else if ve.Errors&(jwt.ValidationErrorExpired|jwt.ValidationErrorNotValidYet) != 0 { +// // Token is either expired or not active yet +// fmt.Println("Timing is everything") +// } else { +// fmt.Println("Couldn't handle this token:", err) +// } +// } else { +// fmt.Println("Couldn't handle this token:", err) +// } + +// // Output: Timing is everything +// } diff --git a/parser_test.go b/parser_test.go index debf6e70..fd09aa70 100644 --- a/parser_test.go +++ b/parser_test.go @@ -3,6 +3,7 @@ package jwt_test import ( "crypto/rsa" "encoding/json" + "errors" "fmt" "reflect" "testing" @@ -14,6 +15,21 @@ import ( var errKeyFuncError error = fmt.Errorf("error loading key") +var errMap = map[error]string{ + jwt.ErrMalformedToken: "ErrMalformedToken", + jwt.ErrTokenContainsBearer: "ErrTokenContainsBearer", + jwt.ErrInvalidSigningMethod: "ErrInvalidSigningMethod", + jwt.ErrUnregisteredSigningMethod: "ErrUnregisteredSigningMethod", + jwt.ErrInvalidKey: "ErrInvalidKey", + jwt.ErrInvalidKeyType: "ErrInvalidKeyType", + jwt.ErrHashUnavailable: "ErrHashUnavailable", + jwt.ErrTokenNotYetValid: "ErrTokenNotYetValid", + jwt.ErrTokenExpired: "ErrTokenExpired", + jwt.ErrTokenUsedBeforeIssued: "ErrTokenUsedBeforeIssued", + jwt.ErrNoneSignatureTypeDisallowed: "ErrNoneSignatureTypeDisallowed", + jwt.ErrMissingKeyFunc: "ErrMissingKeyFunc", +} + var ( jwtTestDefaultKey *rsa.PublicKey defaultKeyFunc jwt.Keyfunc = func(t *jwt.Token) (interface{}, error) { return jwtTestDefaultKey, nil } @@ -26,7 +42,16 @@ func init() { jwtTestDefaultKey = test.LoadRSAPublicKeyFromDisk("test/sample_key.pub") } -type errors []error +type Errors []error + +func (errs Errors) Contains(error error) bool { + for _, err := range errs { + if errors.Is(err, error) { + return true + } + } + return false +} var jwtTestData = []struct { name string @@ -34,7 +59,7 @@ var jwtTestData = []struct { keyfunc jwt.Keyfunc claims jwt.Claims valid bool - errors errors + errors Errors parser *jwt.Parser }{ { @@ -52,7 +77,7 @@ var jwtTestData = []struct { defaultKeyFunc, jwt.MapClaims{"foo": "bar", "exp": float64(time.Now().Unix() - 100)}, false, - errors{jwt.ErrTokenExpired}, + Errors{jwt.ErrTokenExpired}, nil, }, { @@ -61,7 +86,7 @@ var jwtTestData = []struct { defaultKeyFunc, jwt.MapClaims{"foo": "bar", "nbf": float64(time.Now().Unix() + 100)}, false, - errors{jwt.ErrTokenNotYetValid}, + Errors{jwt.ErrTokenNotYetValid}, nil, }, { @@ -70,7 +95,7 @@ var jwtTestData = []struct { defaultKeyFunc, jwt.MapClaims{"foo": "bar", "nbf": float64(time.Now().Unix() + 100), "exp": float64(time.Now().Unix() - 100)}, false, - errors{jwt.ErrTokenNotYetValid, jwt.ErrTokenExpired}, + Errors{jwt.ErrTokenNotYetValid, jwt.ErrTokenExpired}, nil, }, { @@ -79,7 +104,7 @@ var jwtTestData = []struct { defaultKeyFunc, jwt.MapClaims{"foo": "bar"}, false, - errors{jwt.ErrSignatureInvalid}, + Errors{jwt.ErrSignatureInvalid}, nil, }, { @@ -88,7 +113,7 @@ var jwtTestData = []struct { nilKeyFunc, jwt.MapClaims{"foo": "bar"}, false, - errors{jwt.ErrUnregisteredSigningMethod}, + Errors{jwt.ErrUnregisteredSigningMethod}, nil, }, { @@ -97,7 +122,7 @@ var jwtTestData = []struct { emptyKeyFunc, jwt.MapClaims{"foo": "bar"}, false, - errors{jwt.ErrSignatureInvalid}, + Errors{jwt.ErrSignatureInvalid}, nil, }, { @@ -106,7 +131,7 @@ var jwtTestData = []struct { errorKeyFunc, jwt.MapClaims{"foo": "bar"}, false, - errors{jwt.ErrSignatureInvalid}, + Errors{jwt.ErrSignatureInvalid}, nil, }, { @@ -115,7 +140,7 @@ var jwtTestData = []struct { defaultKeyFunc, jwt.MapClaims{"foo": "bar"}, false, - jwt.ValidationErrorSignatureInvalid, + Errors{jwt.ErrInvalidSigningMethod}, &jwt.Parser{ValidMethods: []string{"HS256"}}, }, { @@ -124,7 +149,7 @@ var jwtTestData = []struct { defaultKeyFunc, jwt.MapClaims{"foo": "bar"}, true, - 0, + nil, &jwt.Parser{ValidMethods: []string{"RS256", "HS256"}}, }, { @@ -133,7 +158,7 @@ var jwtTestData = []struct { defaultKeyFunc, jwt.MapClaims{"foo": json.Number("123.4")}, true, - 0, + nil, &jwt.Parser{UseJSONNumber: true}, }, { @@ -144,7 +169,7 @@ var jwtTestData = []struct { ExpiresAt: time.Now().Add(time.Second * 10).Unix(), }, true, - 0, + nil, &jwt.Parser{UseJSONNumber: true}, }, { @@ -153,7 +178,7 @@ var jwtTestData = []struct { defaultKeyFunc, jwt.MapClaims{"foo": "bar", "exp": json.Number(fmt.Sprintf("%v", time.Now().Unix()-100))}, false, - jwt.ValidationErrorExpired, + Errors{jwt.ErrTokenExpired}, &jwt.Parser{UseJSONNumber: true}, }, { @@ -162,7 +187,7 @@ var jwtTestData = []struct { defaultKeyFunc, jwt.MapClaims{"foo": "bar", "nbf": json.Number(fmt.Sprintf("%v", time.Now().Unix()+100))}, false, - jwt.ValidationErrorNotValidYet, + Errors{jwt.ErrTokenNotYetValid}, &jwt.Parser{UseJSONNumber: true}, }, { @@ -171,7 +196,7 @@ var jwtTestData = []struct { defaultKeyFunc, jwt.MapClaims{"foo": "bar", "nbf": json.Number(fmt.Sprintf("%v", time.Now().Unix()+100)), "exp": json.Number(fmt.Sprintf("%v", time.Now().Unix()-100))}, false, - jwt.ValidationErrorNotValidYet | jwt.ValidationErrorExpired, + Errors{jwt.ErrTokenNotYetValid, jwt.ErrTokenExpired}, &jwt.Parser{UseJSONNumber: true}, }, { @@ -180,7 +205,7 @@ var jwtTestData = []struct { defaultKeyFunc, jwt.MapClaims{"foo": "bar", "nbf": json.Number(fmt.Sprintf("%v", time.Now().Unix()+100))}, true, - 0, + nil, &jwt.Parser{UseJSONNumber: true, SkipClaimsValidation: true}, }, { @@ -191,7 +216,7 @@ var jwtTestData = []struct { ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Second * 10)), }, true, - 0, + nil, &jwt.Parser{UseJSONNumber: true}, }, { @@ -202,7 +227,7 @@ var jwtTestData = []struct { Audience: jwt.ClaimStrings{"test"}, }, true, - 0, + nil, &jwt.Parser{UseJSONNumber: true}, }, { @@ -213,7 +238,7 @@ var jwtTestData = []struct { Audience: jwt.ClaimStrings{"test", "test"}, }, true, - 0, + nil, &jwt.Parser{UseJSONNumber: true}, }, { @@ -224,7 +249,7 @@ var jwtTestData = []struct { Audience: nil, // because of the unmarshal error, this will be empty }, false, - jwt.ValidationErrorMalformed, + Errors{jwt.ErrMalformedToken}, &jwt.Parser{UseJSONNumber: true}, }, { @@ -235,7 +260,7 @@ var jwtTestData = []struct { Audience: nil, // because of the unmarshal error, this will be empty }, false, - jwt.ValidationErrorMalformed, + Errors{jwt.ErrMalformedToken}, &jwt.Parser{UseJSONNumber: true}, }, } @@ -285,19 +310,14 @@ func TestParser_Parse(t *testing.T) { t.Errorf("[%v] Inconsistent behavior between returned error and token.Valid", data.name) } - if data.errors != 0 { + if len(data.errors) != 0 { if err == nil { t.Errorf("[%v] Expecting error. Didn't get one.", data.name) } else { - - ve := err.(*jwt.ValidationError) - // compare the bitfield part of the error - if e := ve.Errors; e != data.errors { - t.Errorf("[%v] Errors don't match expectation. %v != %v", data.name, e, data.errors) - } - - if err.Error() == errKeyFuncError.Error() && ve.Inner != errKeyFuncError { - t.Errorf("[%v] Inner error does not match expectation. %v != %v", data.name, ve.Inner, errKeyFuncError) + for _, expectedError := range data.errors { + if !errors.Is(expectedError, err) { + t.Errorf("[%v] Expected %v, received: %v", data.name, errMap[expectedError], err) + } } } } @@ -314,7 +334,7 @@ func TestParser_ParseUnverified(t *testing.T) { // Iterate over test data set and run tests for _, data := range jwtTestData { // Skip test data, that intentionally contains malformed tokens, as they would lead to an error - if data.errors&jwt.ValidationErrorMalformed != 0 { + if data.errors.Contains(jwt.ErrMalformedToken) { continue } diff --git a/signing_method.go b/signing_method.go index 26c5aef2..2b8c4b4d 100644 --- a/signing_method.go +++ b/signing_method.go @@ -21,12 +21,12 @@ type SigningMethod interface { func RegisterSigningMethod(alg string, f func() SigningMethod) { signingMethodsMutex.Lock() defer signingMethodsMutex.Unlock() - clone := map[string]signingMethodFunc{} + copy := map[string]signingMethodFunc{} for k, sm := range signingMethods { - clone[k] = sm + copy[k] = sm } - clone[alg] = f - signingMethods = clone + copy[alg] = f + signingMethods = copy } // GetSigningMethod retrieves a signing method from an "alg" string From 83b880efca87c42c124ec33d2e151beb65e82bf3 Mon Sep 17 00:00:00 2001 From: chanced Date: Thu, 2 Sep 2021 15:08:30 -0400 Subject: [PATCH 3/9] fixes a few more map claims tests --- errors.go | 6 +++++- map_claims.go | 52 ++++++++++++++++++++++++++++++++++----------------- 2 files changed, 40 insertions(+), 18 deletions(-) diff --git a/errors.go b/errors.go index 64ee8306..9009f371 100644 --- a/errors.go +++ b/errors.go @@ -82,7 +82,11 @@ func (err *NotYetValidError) Delta() time.Duration { return err.AttemptedAt.Sub(err.ValidAt) } func (err *NotYetValidError) Error() string { - return fmt.Sprintf("token is not valid for another %v", err.Delta()) + if !err.ValidAt.IsZero() { + return fmt.Sprintf("token is not valid for another %v", err.Delta()) + } else { + return ErrTokenNotYetValid.Error() + } } func (err *NotYetValidError) Unwrap() error { return ErrTokenNotYetValid diff --git a/map_claims.go b/map_claims.go index 9bf731fb..795c0bf6 100644 --- a/map_claims.go +++ b/map_claims.go @@ -12,62 +12,62 @@ import ( // This is the default claims type if you don't supply one type MapClaims map[string]interface{} -func (m MapClaims) ExpiresAt() *time.Time { +func (m MapClaims) ExpiresAt() interface{} { exp := m["exp"] switch exp := exp.(type) { case float64: if exp == 0 { return nil } - return &newNumericDateFromSeconds(exp).Time + return newNumericDateFromSeconds(exp).Time case json.Number: v, _ := exp.Float64() if v == 0 { return nil } - return &newNumericDateFromSeconds(v).Time + return newNumericDateFromSeconds(v).Time default: return nil } } -func (m MapClaims) IssuedAt() *time.Time { +func (m MapClaims) IssuedAt() interface{} { iat := m["iat"] switch exp := iat.(type) { case float64: if exp == 0 { return nil } - return &newNumericDateFromSeconds(exp).Time + return newNumericDateFromSeconds(exp).Time case json.Number: v, _ := exp.Float64() if v == 0 { return nil } - return &newNumericDateFromSeconds(v).Time + return newNumericDateFromSeconds(v).Time default: - return nil + return iat } } // NotBefore returns the *time.Time parsed nbf field of the MapClaims if present // or nil otherwise -func (m MapClaims) NotBefore() *time.Time { +func (m MapClaims) NotBefore() interface{} { v := m["nbf"] switch nbf := v.(type) { case float64: if nbf == 0 { return nil } - return &newNumericDateFromSeconds(nbf).Time + return newNumericDateFromSeconds(nbf).Time case json.Number: v, _ := nbf.Float64() if v == 0 { return nil } - return &newNumericDateFromSeconds(v).Time + return newNumericDateFromSeconds(v).Time default: - return nil + return m["nbf"] } } @@ -93,7 +93,7 @@ func (m MapClaims) Audience() ([]string, error) { if vs, ok := a.(string); ok { aud = append(aud, vs) } else { - multierror.Append(err, fmt.Errorf("aud entry [%v] is not a string", a)) + err = multierror.Append(err, fmt.Errorf("aud entry [%v] is not a string", a)) } } } @@ -114,7 +114,14 @@ func (m MapClaims) VerifyAudience(cmp string, req bool) bool { // If req is false, it will return true, if exp is unset. func (m MapClaims) VerifyExpiresAt(cmp int64, req bool) bool { cmpTime := time.Unix(cmp, 0) - return verifyExp(m.ExpiresAt(), cmpTime, req) + exp := m.ExpiresAt() + if exp == nil { + return verifyExp(nil, cmpTime, req) + } + if t, ok := exp.(time.Time); ok { + return verifyExp(&t, cmpTime, req) + } + return false } // VerifyIssuedAt compares the exp claim against cmp (cmp >= iat). @@ -144,7 +151,14 @@ func (m MapClaims) VerifyIssuedAt(cmp int64, req bool) bool { // VerifyNotBefore compares the nbf claim against cmp (cmp >= nbf). // If req is false, it will return true, if nbf is unset. func (m MapClaims) VerifyNotBefore(cmp int64, req bool) bool { - return verifyNbf(m.NotBefore(), time.Unix(cmp, 0), req) + nbf := m.NotBefore() + if nbf == nil { + return verifyNbf(nil, time.Unix(cmp, 0), req) + } + if t, ok := nbf.(time.Time); ok { + return verifyNbf(&t, time.Unix(cmp, 0), req) + } + return false } // VerifyIssuer compares the iss claim against cmp. @@ -161,21 +175,25 @@ func (m MapClaims) Valid() error { var result *multierror.Error now := TimeFunc() nowUnix := now.Unix() + exp, _ := m.ExpiresAt().(time.Time) if !m.VerifyExpiresAt(nowUnix, false) { result = multierror.Append(result, &ExpiredError{ - ExpiredAt: *m.ExpiresAt(), + ExpiredAt: exp, AttemptedAt: now, }) } if !m.VerifyIssuedAt(nowUnix, false) { + iat, _ := m.IssuedAt().(time.Time) result = multierror.Append(result, &UsedBeforeIssuedError{ - IssuedAt: *m.IssuedAt(), + IssuedAt: iat, AttemptedAt: now, }) } if !m.VerifyNotBefore(nowUnix, false) { + nbf, _ := m.NotBefore().(time.Time) + result = multierror.Append(result, &NotYetValidError{ - ValidAt: *m.NotBefore(), + ValidAt: nbf, AttemptedAt: now, }) } From 08e560599852d124b6892620854c61b8a438cdd5 Mon Sep 17 00:00:00 2001 From: chanced Date: Thu, 2 Sep 2021 15:14:27 -0400 Subject: [PATCH 4/9] fixes another map claims test --- map_claims.go | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/map_claims.go b/map_claims.go index 795c0bf6..698a4a49 100644 --- a/map_claims.go +++ b/map_claims.go @@ -67,17 +67,17 @@ func (m MapClaims) NotBefore() interface{} { } return newNumericDateFromSeconds(v).Time default: - return m["nbf"] + return v } } // Issuer returns the iss field of the MapClaims -func (m MapClaims) Issuer() string { +func (m MapClaims) Issuer() interface{} { iss := m["iss"] if str, ok := iss.(string); ok { return str } - return "" + return iss } func (m MapClaims) Audience() ([]string, error) { @@ -164,7 +164,13 @@ func (m MapClaims) VerifyNotBefore(cmp int64, req bool) bool { // VerifyIssuer compares the iss claim against cmp. // If required is false, this method will return true if the value matches or is unset func (m MapClaims) VerifyIssuer(cmp string, req bool) bool { - return verifyIss(m.Issuer(), cmp, req) + iss := m.Issuer() + if str, ok := iss.(string); ok { + return verifyIss(str, cmp, req) + } else if iss == nil { + return verifyIss("", cmp, req) + } + return false } // Valid validates time based claims "exp, iat, nbf". From c10e445dd20a0d972aed0f420e1a1dd3731719e0 Mon Sep 17 00:00:00 2001 From: chanced Date: Thu, 2 Sep 2021 15:48:28 -0400 Subject: [PATCH 5/9] fixes a few more tests; still a few outstanding issues --- claims.go | 4 ++-- map_claims.go | 5 +++-- parser.go | 10 +++++----- parser_test.go | 4 +++- 4 files changed, 13 insertions(+), 10 deletions(-) diff --git a/claims.go b/claims.go index 741477d8..2b53e358 100644 --- a/claims.go +++ b/claims.go @@ -50,7 +50,7 @@ type RegisteredClaims struct { // As well, if any of the above claims are not in the token, it will still // be considered a valid claim. func (c RegisteredClaims) Valid() error { - var result *multierror.Error + result := &multierror.Error{} result.ErrorFormat = ValidationErrorFormat now := TimeFunc() @@ -137,7 +137,7 @@ type StandardClaims struct { // As well, if any of the above claims are not in the token, it will still // be considered a valid claim. func (c StandardClaims) Valid() error { - var result *multierror.Error + result := &multierror.Error{} result.ErrorFormat = ValidationErrorFormat now := TimeFunc() diff --git a/map_claims.go b/map_claims.go index 698a4a49..f83558ea 100644 --- a/map_claims.go +++ b/map_claims.go @@ -27,7 +27,7 @@ func (m MapClaims) ExpiresAt() interface{} { } return newNumericDateFromSeconds(v).Time default: - return nil + return exp } } @@ -178,7 +178,8 @@ func (m MapClaims) VerifyIssuer(cmp string, req bool) bool { // As well, if any of the above claims are not in the token, it will still // be considered a valid claim. func (m MapClaims) Valid() error { - var result *multierror.Error + result := &multierror.Error{} + result.ErrorFormat = ValidationErrorFormat now := TimeFunc() nowUnix := now.Unix() exp, _ := m.ExpiresAt().(time.Time) diff --git a/parser.go b/parser.go index cbcbbedd..80c81973 100644 --- a/parser.go +++ b/parser.go @@ -22,7 +22,7 @@ func (p *Parser) Parse(tokenString string, keyFunc Keyfunc) (*Token, error) { func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyfunc) (*Token, error) { token, parts, err := p.ParseUnverified(tokenString, claims) if err != nil { - return nil, err + return token, err } // Verify signing method is in the required set @@ -37,7 +37,7 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf } if !signingMethodValid { // signing method is not in the listed set - return nil, &InvalidSigningMethodError{Alg: alg} + return token, &InvalidSigningMethodError{Alg: alg} } } @@ -45,18 +45,18 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf var key interface{} if keyFunc == nil { // keyFunc was not provided. short circuiting validation - return nil, ErrMissingKeyFunc + return token, ErrMissingKeyFunc } key, err = keyFunc(token) if err != nil { - return nil, err + return token, err } // Validate Claims if !p.SkipClaimsValidation { if err := token.Claims.Valid(); err != nil { - return nil, err + return token, err } } diff --git a/parser_test.go b/parser_test.go index fd09aa70..52430f2a 100644 --- a/parser_test.go +++ b/parser_test.go @@ -292,7 +292,9 @@ func TestParser_Parse(t *testing.T) { case *jwt.RegisteredClaims: token, err = parser.ParseWithClaims(data.tokenString, &jwt.RegisteredClaims{}, data.keyfunc) } - + if token == nil { + fmt.Println("token is nil") + } // Verify result matches expectation if !reflect.DeepEqual(data.claims, token.Claims) { t.Errorf("[%v] Claims mismatch. Expecting: %v Got: %v", data.name, data.claims, token.Claims) From b3b0922f8be74f25253697c9248d98d8d2a9ea83 Mon Sep 17 00:00:00 2001 From: chanced Date: Thu, 2 Sep 2021 15:57:26 -0400 Subject: [PATCH 6/9] fixes a few more tests --- errors.go | 4 ++-- parser_test.go | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/errors.go b/errors.go index 9009f371..cdd0b5ae 100644 --- a/errors.go +++ b/errors.go @@ -114,14 +114,14 @@ type ExpiredError struct { } func (err *ExpiredError) Delta() time.Duration { - return err.ExpiredAt.Sub(err.AttemptedAt) + return err.AttemptedAt.Sub(err.ExpiredAt) } func (err *ExpiredError) Error() string { return fmt.Sprintf("token is expired by %v", err.Delta()) } func (err *ExpiredError) Unwrap() error { - return ErrTokenNotYetValid + return ErrTokenExpired } // The errors that might occur when parsing and validating a token diff --git a/parser_test.go b/parser_test.go index 52430f2a..fc01a2c9 100644 --- a/parser_test.go +++ b/parser_test.go @@ -293,7 +293,7 @@ func TestParser_Parse(t *testing.T) { token, err = parser.ParseWithClaims(data.tokenString, &jwt.RegisteredClaims{}, data.keyfunc) } if token == nil { - fmt.Println("token is nil") + panic("token is nil") } // Verify result matches expectation if !reflect.DeepEqual(data.claims, token.Claims) { @@ -317,8 +317,8 @@ func TestParser_Parse(t *testing.T) { t.Errorf("[%v] Expecting error. Didn't get one.", data.name) } else { for _, expectedError := range data.errors { - if !errors.Is(expectedError, err) { - t.Errorf("[%v] Expected %v, received: %v", data.name, errMap[expectedError], err) + if !errors.Is(err, expectedError) { + t.Errorf(`[%v] Expected "%v", received: %v`, data.name, errMap[expectedError], err) } } } From a1604f2d3c08b8fc1d9f20cdd86cd7be924c94d4 Mon Sep 17 00:00:00 2001 From: chanced Date: Thu, 2 Sep 2021 16:05:24 -0400 Subject: [PATCH 7/9] 4 tests are still failing --- errors.go | 2 +- parser.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/errors.go b/errors.go index cdd0b5ae..f2cb9c18 100644 --- a/errors.go +++ b/errors.go @@ -22,7 +22,7 @@ var ValidationErrorFormat = func(errs []error) string { // Error constants var ( ErrMalformedToken = errors.New("jwt: token is malformed") - ErrTokenContainsBearer = errors.New(`token may not contain "bearer"`) + ErrTokenContainsBearer = errors.New(`jwt: token may not contain "bearer "`) ErrInvalidSigningMethod = errors.New("jwt: invalid signing method") ErrUnregisteredSigningMethod = errors.New("jwt: signing method not registered") ErrInvalidKey = errors.New("jwt: key is invalid") diff --git a/parser.go b/parser.go index 80c81973..aba6ba13 100644 --- a/parser.go +++ b/parser.go @@ -124,7 +124,7 @@ func (p *Parser) ParseUnverified(tokenString string, claims Claims) (token *Toke // Lookup signature method alg, ok := token.Header["alg"].(string) - if !ok { + if !ok || len(alg) == 0 { return token, parts, MalformedTokenError("signing method (alg) not specified") } token.Method = GetSigningMethod(alg) From 27b7b8bc09cd1698c37ba6c6f684c1832187ae5e Mon Sep 17 00:00:00 2001 From: chanced Date: Thu, 2 Sep 2021 16:50:38 -0400 Subject: [PATCH 8/9] fixes all but 1 test --- ecdsa.go | 12 +++++++----- ed25519.go | 8 ++------ errors.go | 26 +++++++++++++++++++++++++- hmac.go | 12 ++++++------ parser_test.go | 7 ++++--- rsa.go | 6 +++++- 6 files changed, 49 insertions(+), 22 deletions(-) diff --git a/ecdsa.go b/ecdsa.go index eac023fc..1f0b8144 100644 --- a/ecdsa.go +++ b/ecdsa.go @@ -4,13 +4,11 @@ import ( "crypto" "crypto/ecdsa" "crypto/rand" - "errors" "math/big" ) var ( - // Sadly this is missing from crypto/ecdsa compared to crypto/rsa - ErrECDSAVerification = errors.New("crypto/ecdsa: verification error") +// Sadly this is missing from crypto/ecdsa compared to crypto/rsa ) // SigningMethodECDSA implements the ECDSA family of signing methods. @@ -74,7 +72,9 @@ func (m *SigningMethodECDSA) Verify(signingString, signature string, key interfa } if len(sig) != 2*m.KeySize { - return ErrECDSAVerification + return &SignatureVerificationError{ + Algorithm: m.Name, + } } r := big.NewInt(0).SetBytes(sig[:m.KeySize]) @@ -92,7 +92,9 @@ func (m *SigningMethodECDSA) Verify(signingString, signature string, key interfa return nil } - return ErrECDSAVerification + return &SignatureVerificationError{ + Algorithm: m.Name, + } } // Sign implements token signing for the SigningMethod. diff --git a/ed25519.go b/ed25519.go index 07d3aacd..8366b7b6 100644 --- a/ed25519.go +++ b/ed25519.go @@ -1,16 +1,12 @@ package jwt import ( - "errors" - "crypto" "crypto/ed25519" "crypto/rand" ) -var ( - ErrEd25519Verification = errors.New("ed25519: verification error") -) +var () // SigningMethodEd25519 implements the EdDSA family. // Expects ed25519.PrivateKey for signing and ed25519.PublicKey for verification @@ -55,7 +51,7 @@ func (m *SigningMethodEd25519) Verify(signingString, signature string, key inter // Verify the signature if !ed25519.Verify(ed25519Key, []byte(signingString), sig) { - return ErrEd25519Verification + return &SignatureVerificationError{Algorithm: "EdDSA"} } return nil diff --git a/errors.go b/errors.go index f2cb9c18..fb23c02f 100644 --- a/errors.go +++ b/errors.go @@ -26,15 +26,39 @@ var ( ErrInvalidSigningMethod = errors.New("jwt: invalid signing method") ErrUnregisteredSigningMethod = errors.New("jwt: signing method not registered") ErrInvalidKey = errors.New("jwt: key is invalid") - ErrInvalidKeyType = errors.New("jwt: key is of invalid type") + ErrInvalidKeyType = errors.New("jwt: invalid key type") ErrHashUnavailable = errors.New("jwt: the requested hash function is unavailable") ErrTokenNotYetValid = errors.New("jwt: the token is not yet valid") ErrTokenExpired = errors.New("jwt: the token is expired") ErrTokenUsedBeforeIssued = errors.New("jwt: the token was used before issued") ErrNoneSignatureTypeDisallowed = errors.New(`jwt: "none" signature type is not allowed`) ErrMissingKeyFunc = errors.New("jwt: KeyFunc not provided") + ErrSignatureInvalid = errors.New("jwt: signature is invalid") ) +type SignatureVerificationError struct { + Algorithm string + err error +} + +func (err *SignatureVerificationError) Error() string { + return ErrSignatureInvalid.Error() + " [" + err.Algorithm + "]" +} + +func (err *SignatureVerificationError) Is(cmp error) bool { + if _, ok := cmp.(*SignatureVerificationError); ok { + return true + } + if errors.Is(err.err, cmp) { + return true + } + return errors.Is(err.Unwrap(), cmp) +} + +func (err *SignatureVerificationError) Unwrap() error { + return ErrSignatureInvalid +} + type MalformedTokenError string func (err MalformedTokenError) Error() string { diff --git a/hmac.go b/hmac.go index 011f68a2..3d578bad 100644 --- a/hmac.go +++ b/hmac.go @@ -3,7 +3,6 @@ package jwt import ( "crypto" "crypto/hmac" - "errors" ) // SigningMethodHMAC implements the HMAC-SHA family of signing methods. @@ -15,10 +14,9 @@ type SigningMethodHMAC struct { // Specific instances for HS256 and company var ( - SigningMethodHS256 *SigningMethodHMAC - SigningMethodHS384 *SigningMethodHMAC - SigningMethodHS512 *SigningMethodHMAC - ErrSignatureInvalid = errors.New("signature is invalid") + SigningMethodHS256 *SigningMethodHMAC + SigningMethodHS384 *SigningMethodHMAC + SigningMethodHS512 *SigningMethodHMAC ) func init() { @@ -70,7 +68,9 @@ func (m *SigningMethodHMAC) Verify(signingString, signature string, key interfac hasher := hmac.New(m.Hash.New, keyBytes) hasher.Write([]byte(signingString)) if !hmac.Equal(sig, hasher.Sum(nil)) { - return ErrSignatureInvalid + return &SignatureVerificationError{ + Algorithm: "HMAC", + } } // No validation errors. Signature is good. diff --git a/parser_test.go b/parser_test.go index fc01a2c9..de72d431 100644 --- a/parser_test.go +++ b/parser_test.go @@ -28,6 +28,7 @@ var errMap = map[error]string{ jwt.ErrTokenUsedBeforeIssued: "ErrTokenUsedBeforeIssued", jwt.ErrNoneSignatureTypeDisallowed: "ErrNoneSignatureTypeDisallowed", jwt.ErrMissingKeyFunc: "ErrMissingKeyFunc", + jwt.ErrSignatureInvalid: "ErrSignatureInvalid", } var ( @@ -113,7 +114,7 @@ var jwtTestData = []struct { nilKeyFunc, jwt.MapClaims{"foo": "bar"}, false, - Errors{jwt.ErrUnregisteredSigningMethod}, + Errors{jwt.ErrMissingKeyFunc}, nil, }, { @@ -122,7 +123,7 @@ var jwtTestData = []struct { emptyKeyFunc, jwt.MapClaims{"foo": "bar"}, false, - Errors{jwt.ErrSignatureInvalid}, + Errors{jwt.ErrInvalidKeyType}, nil, }, { @@ -131,7 +132,7 @@ var jwtTestData = []struct { errorKeyFunc, jwt.MapClaims{"foo": "bar"}, false, - Errors{jwt.ErrSignatureInvalid}, + Errors{jwt.ErrInvalidKeyType}, nil, }, { diff --git a/rsa.go b/rsa.go index b910b19c..e02a4cbc 100644 --- a/rsa.go +++ b/rsa.go @@ -70,7 +70,11 @@ func (m *SigningMethodRSA) Verify(signingString, signature string, key interface hasher.Write([]byte(signingString)) // Verify the signature - return rsa.VerifyPKCS1v15(rsaKey, m.Hash, hasher.Sum(nil), sig) + err = rsa.VerifyPKCS1v15(rsaKey, m.Hash, hasher.Sum(nil), sig) + if err != nil { + return &SignatureVerificationError{err: err, Algorithm: "RSA"} + } + return nil } // Sign implements token signing for the SigningMethod From 502d2136ede4f2c09a9d9e70d86e4b438b1c597c Mon Sep 17 00:00:00 2001 From: chanced Date: Thu, 2 Sep 2021 17:00:39 -0400 Subject: [PATCH 9/9] fixes last test; example tests are still broken --- errors.go | 23 +++++++++++++++++++++++ parser.go | 2 +- parser_test.go | 3 ++- 3 files changed, 26 insertions(+), 2 deletions(-) diff --git a/errors.go b/errors.go index fb23c02f..ccfbcece 100644 --- a/errors.go +++ b/errors.go @@ -34,8 +34,31 @@ var ( ErrNoneSignatureTypeDisallowed = errors.New(`jwt: "none" signature type is not allowed`) ErrMissingKeyFunc = errors.New("jwt: KeyFunc not provided") ErrSignatureInvalid = errors.New("jwt: signature is invalid") + ErrKeyFuncError = errors.New("jwt: KeyFunc returned an error") ) +type KeyFuncError struct { + Err error +} + +func (err *KeyFuncError) Error() string { + return ErrKeyFuncError.Error() + "\n\t" + err.Err.Error() +} + +func (err *KeyFuncError) Unwrap() error { + return err.Err +} + +func (err *KeyFuncError) Is(target error) bool { + if _, ok := target.(*KeyFuncError); ok { + return true + } + if errors.Is(err.Err, target) { + return true + } + return errors.Is(target, ErrKeyFuncError) +} + type SignatureVerificationError struct { Algorithm string err error diff --git a/parser.go b/parser.go index aba6ba13..48c76b4f 100644 --- a/parser.go +++ b/parser.go @@ -50,7 +50,7 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf key, err = keyFunc(token) if err != nil { - return token, err + return token, &KeyFuncError{Err: err} } // Validate Claims diff --git a/parser_test.go b/parser_test.go index de72d431..06172cc5 100644 --- a/parser_test.go +++ b/parser_test.go @@ -29,6 +29,7 @@ var errMap = map[error]string{ jwt.ErrNoneSignatureTypeDisallowed: "ErrNoneSignatureTypeDisallowed", jwt.ErrMissingKeyFunc: "ErrMissingKeyFunc", jwt.ErrSignatureInvalid: "ErrSignatureInvalid", + jwt.ErrKeyFuncError: "ErrKeyFuncError", } var ( @@ -132,7 +133,7 @@ var jwtTestData = []struct { errorKeyFunc, jwt.MapClaims{"foo": "bar"}, false, - Errors{jwt.ErrInvalidKeyType}, + Errors{jwt.ErrKeyFuncError, errKeyFuncError}, nil, }, {