Skip to content

Commit

Permalink
[SDK-3231] Added support for multiple checks on a single claim (#573)
Browse files Browse the repository at this point in the history
* Added support for multiple checks on a single claim

* Fix codecov CI failure

* Allow pseudo comparison for codecov

* Remove newly added parameters and check codecov disabled

* Reenabled changes check in codecov

* Trigger Build

* Refactor ExpectedCheckHolder from interfaces to impl package

* Refactored code to improve coverage report
  • Loading branch information
poovamraj authored Apr 13, 2022
1 parent 0f8a9bf commit af04b22
Show file tree
Hide file tree
Showing 3 changed files with 282 additions and 137 deletions.
153 changes: 95 additions & 58 deletions lib/src/main/java/com/auth0/jwt/JWTVerifier.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import com.auth0.jwt.impl.PublicClaims;
import com.auth0.jwt.interfaces.Claim;
import com.auth0.jwt.interfaces.DecodedJWT;
import com.auth0.jwt.impl.ExpectedCheckHolder;
import com.auth0.jwt.interfaces.Verification;

import java.time.Clock;
Expand All @@ -25,12 +26,12 @@
*/
public final class JWTVerifier implements com.auth0.jwt.interfaces.JWTVerifier {
private final Algorithm algorithm;
final Map<String, BiPredicate<Claim, DecodedJWT>> expectedChecks;
final List<ExpectedCheckHolder> expectedChecks;
private final JWTParser parser;

JWTVerifier(Algorithm algorithm, Map<String, BiPredicate<Claim, DecodedJWT>> expectedChecks) {
JWTVerifier(Algorithm algorithm, List<ExpectedCheckHolder> expectedChecks) {
this.algorithm = algorithm;
this.expectedChecks = Collections.unmodifiableMap(expectedChecks);
this.expectedChecks = Collections.unmodifiableList(expectedChecks);
this.parser = new JWTParser();
}

Expand All @@ -50,7 +51,7 @@ static Verification init(Algorithm algorithm) throws IllegalArgumentException {
*/
public static class BaseVerification implements Verification {
private final Algorithm algorithm;
private final Map<String, BiPredicate<Claim, DecodedJWT>> expectedChecks;
private final List<ExpectedCheckHolder> expectedChecks;
private long defaultLeeway;
private final Map<String, Long> customLeeways;
private boolean ignoreIssuedAt;
Expand All @@ -62,15 +63,18 @@ public static class BaseVerification implements Verification {
}

this.algorithm = algorithm;
this.expectedChecks = new LinkedHashMap<>();
this.expectedChecks = new ArrayList<>();
this.customLeeways = new HashMap<>();
this.defaultLeeway = 0;
}

@Override
public Verification withIssuer(String... issuer) {
List<String> value = isNullOrEmpty(issuer) ? null : Arrays.asList(issuer);
checkIfNeedToRemove(PublicClaims.ISSUER, value, ((claim, decodedJWT) -> {
addCheck(PublicClaims.ISSUER, ((claim, decodedJWT) -> {
if (verifyNull(claim, value)) {
return true;
}
if (value == null || !value.contains(claim.asString())) {
throw new IncorrectClaimException(
"The Claim 'iss' value doesn't match the required issuer.", PublicClaims.ISSUER, claim);
Expand All @@ -82,23 +86,40 @@ public Verification withIssuer(String... issuer) {

@Override
public Verification withSubject(String subject) {
checkIfNeedToRemove(PublicClaims.SUBJECT, subject, (claim, decodedJWT) -> subject.equals(claim.asString()));
addCheck(PublicClaims.SUBJECT, (claim, decodedJWT) ->
verifyNull(claim, subject) || subject.equals(claim.asString()));
return this;
}

@Override
public Verification withAudience(String... audience) {
List<String> value = isNullOrEmpty(audience) ? null : Arrays.asList(audience);
checkIfNeedToRemove(PublicClaims.AUDIENCE, value, ((claim, decodedJWT) ->
assertValidAudienceClaim(claim, decodedJWT.getAudience(), value, true)));
addCheck(PublicClaims.AUDIENCE, ((claim, decodedJWT) -> {
if (verifyNull(claim, value)) {
return true;
}
if (!assertValidAudienceClaim(decodedJWT.getAudience(), value, true)) {
throw new IncorrectClaimException("The Claim 'aud' value doesn't contain the required audience.",
PublicClaims.AUDIENCE, claim);
}
return true;
}));
return this;
}

@Override
public Verification withAnyOfAudience(String... audience) {
List<String> value = isNullOrEmpty(audience) ? null : Arrays.asList(audience);
checkIfNeedToRemove(PublicClaims.AUDIENCE, value, ((claim, decodedJWT) ->
assertValidAudienceClaim(claim, decodedJWT.getAudience(), value, false)));
addCheck(PublicClaims.AUDIENCE, ((claim, decodedJWT) -> {
if (verifyNull(claim, value)) {
return true;
}
if (!assertValidAudienceClaim(decodedJWT.getAudience(), value, false)) {
throw new IncorrectClaimException("The Claim 'aud' value doesn't contain the required audience.",
PublicClaims.AUDIENCE, claim);
}
return true;
}));
return this;
}

Expand Down Expand Up @@ -138,14 +159,16 @@ public Verification ignoreIssuedAt() {

@Override
public Verification withJWTId(String jwtId) {
checkIfNeedToRemove(PublicClaims.JWT_ID, jwtId, ((claim, decodedJWT) -> jwtId.equals(claim.asString())));
addCheck(PublicClaims.JWT_ID, ((claim, decodedJWT) ->
verifyNull(claim, jwtId) || jwtId.equals(claim.asString())));
return this;
}

@Override
public Verification withClaimPresence(String name) throws IllegalArgumentException {
assertNonNull(name);
withClaim(name, ((claim, decodedJWT) -> assertClaimPresence(name, claim)));
//since addCheck already checks presence, we just return true
withClaim(name, ((claim, decodedJWT) -> true));
return this;
}

Expand All @@ -159,35 +182,40 @@ public Verification withNullClaim(String name) throws IllegalArgumentException {
@Override
public Verification withClaim(String name, Boolean value) throws IllegalArgumentException {
assertNonNull(name);
checkIfNeedToRemove(name, value, ((claim, decodedJWT) -> value.equals(claim.asBoolean())));
addCheck(name, ((claim, decodedJWT) -> verifyNull(claim, value)
|| value.equals(claim.asBoolean())));
return this;
}

@Override
public Verification withClaim(String name, Integer value) throws IllegalArgumentException {
assertNonNull(name);
checkIfNeedToRemove(name, value, ((claim, decodedJWT) -> value.equals(claim.asInt())));
addCheck(name, ((claim, decodedJWT) -> verifyNull(claim, value)
|| value.equals(claim.asInt())));
return this;
}

@Override
public Verification withClaim(String name, Long value) throws IllegalArgumentException {
assertNonNull(name);
checkIfNeedToRemove(name, value, ((claim, decodedJWT) -> value.equals(claim.asLong())));
addCheck(name, ((claim, decodedJWT) -> verifyNull(claim, value)
|| value.equals(claim.asLong())));
return this;
}

@Override
public Verification withClaim(String name, Double value) throws IllegalArgumentException {
assertNonNull(name);
checkIfNeedToRemove(name, value, ((claim, decodedJWT) -> value.equals(claim.asDouble())));
addCheck(name, ((claim, decodedJWT) -> verifyNull(claim, value)
|| value.equals(claim.asDouble())));
return this;
}

@Override
public Verification withClaim(String name, String value) throws IllegalArgumentException {
assertNonNull(name);
checkIfNeedToRemove(name, value, ((claim, decodedJWT) -> value.equals(claim.asString())));
addCheck(name, ((claim, decodedJWT) -> verifyNull(claim, value)
|| value.equals(claim.asString())));
return this;
}

Expand All @@ -201,37 +229,42 @@ public Verification withClaim(String name, Instant value) throws IllegalArgument
assertNonNull(name);
// Since date-time claims are serialized as epoch seconds,
// we need to compare them with only seconds-granularity
checkIfNeedToRemove(name, value,
((claim, decodedJWT) -> value.truncatedTo(ChronoUnit.SECONDS).equals(claim.asInstant())));
addCheck(name,
((claim, decodedJWT) -> verifyNull(claim, value)
|| value.truncatedTo(ChronoUnit.SECONDS).equals(claim.asInstant())));
return this;
}

@Override
public Verification withClaim(String name, BiPredicate<Claim, DecodedJWT> predicate)
throws IllegalArgumentException {
assertNonNull(name);
checkIfNeedToRemove(name, predicate, predicate);
addCheck(name, ((claim, decodedJWT) -> verifyNull(claim, predicate)
|| predicate.test(claim, decodedJWT)));
return this;
}

@Override
public Verification withArrayClaim(String name, String... items) throws IllegalArgumentException {
assertNonNull(name);
checkIfNeedToRemove(name, items, ((claim, decodedJWT) -> assertValidCollectionClaim(claim, items)));
addCheck(name, ((claim, decodedJWT) -> verifyNull(claim, items)
|| assertValidCollectionClaim(claim, items)));
return this;
}

@Override
public Verification withArrayClaim(String name, Integer... items) throws IllegalArgumentException {
assertNonNull(name);
checkIfNeedToRemove(name, items, ((claim, decodedJWT) -> assertValidCollectionClaim(claim, items)));
addCheck(name, ((claim, decodedJWT) -> verifyNull(claim, items)
|| assertValidCollectionClaim(claim, items)));
return this;
}

@Override
public Verification withArrayClaim(String name, Long... items) throws IllegalArgumentException {
assertNonNull(name);
checkIfNeedToRemove(name, items, ((claim, decodedJWT) -> assertValidCollectionClaim(claim, items)));
addCheck(name, ((claim, decodedJWT) -> verifyNull(claim, items)
|| assertValidCollectionClaim(claim, items)));
return this;
}

Expand Down Expand Up @@ -268,13 +301,13 @@ private void addMandatoryClaimChecks() {
long notBeforeLeeway = getLeewayFor(PublicClaims.NOT_BEFORE);
long issuedAtLeeway = getLeewayFor(PublicClaims.ISSUED_AT);

expectedChecks.put(PublicClaims.EXPIRES_AT, (claim, decodedJWT) ->
assertValidInstantClaim(PublicClaims.EXPIRES_AT, claim, expiresAtLeeway, true));
expectedChecks.put(PublicClaims.NOT_BEFORE, (claim, decodedJWT) ->
assertValidInstantClaim(PublicClaims.NOT_BEFORE, claim, notBeforeLeeway, false));
expectedChecks.add(constructExpectedCheck(PublicClaims.EXPIRES_AT, (claim, decodedJWT) ->
assertValidInstantClaim(PublicClaims.EXPIRES_AT, claim, expiresAtLeeway, true)));
expectedChecks.add(constructExpectedCheck(PublicClaims.NOT_BEFORE, (claim, decodedJWT) ->
assertValidInstantClaim(PublicClaims.NOT_BEFORE, claim, notBeforeLeeway, false)));
if (!ignoreIssuedAt) {
expectedChecks.put(PublicClaims.ISSUED_AT, (claim, decodedJWT) ->
assertValidInstantClaim(PublicClaims.ISSUED_AT, claim, issuedAtLeeway, false));
expectedChecks.add(constructExpectedCheck(PublicClaims.ISSUED_AT, (claim, decodedJWT) ->
assertValidInstantClaim(PublicClaims.ISSUED_AT, claim, issuedAtLeeway, false)));
}
}

Expand All @@ -294,8 +327,7 @@ private boolean assertValidCollectionClaim(Claim claim, Object[] expectedClaimVa
}
}
} else {
claimArr = claim.isNull() || claim.isMissing()
? Collections.emptyList() : Arrays.asList(claim.as(Object[].class));
claimArr = Arrays.asList(claim.as(Object[].class));
}
List<Object> valueArr = Arrays.asList(expectedClaimValue);
return claimArr.containsAll(valueArr);
Expand Down Expand Up @@ -329,24 +361,12 @@ private boolean assertInstantIsPast(Instant claimVal, long leeway, Instant now)
}

private boolean assertValidAudienceClaim(
Claim claim,
List<String> audience,
List<String> values,
boolean shouldContainAll
) {
if (audience == null || (shouldContainAll && !audience.containsAll(values))
|| (!shouldContainAll && Collections.disjoint(audience, values))) {
throw new IncorrectClaimException(
"The Claim 'aud' value doesn't contain the required audience.", PublicClaims.AUDIENCE, claim);
}
return true;
}

private boolean assertClaimPresence(String name, Claim claim) {
if (claim.isMissing()) {
throw new MissingClaimException(name);
}
return true;
return !(audience == null || (shouldContainAll && !audience.containsAll(values))
|| (!shouldContainAll && Collections.disjoint(audience, values)));
}

private void assertPositive(long leeway) {
Expand All @@ -361,13 +381,31 @@ private void assertNonNull(String name) {
}
}

private void checkIfNeedToRemove(String name, Object value, BiPredicate<Claim, DecodedJWT> predicate) {
if (value == null) {
expectedChecks.remove(name);
return;
}
expectedChecks.put(name, (claim, decodedJWT) -> assertClaimPresence(name, claim)
&& predicate.test(claim, decodedJWT));
private void addCheck(String name, BiPredicate<Claim, DecodedJWT> predicate) {
expectedChecks.add(constructExpectedCheck(name, (claim, decodedJWT) -> {
if (claim.isMissing()) {
throw new MissingClaimException(name);
}
return predicate.test(claim, decodedJWT);
}));
}

private ExpectedCheckHolder constructExpectedCheck(String claimName, BiPredicate<Claim, DecodedJWT> check) {
return new ExpectedCheckHolder() {
@Override
public String getClaimName() {
return claimName;
}

@Override
public boolean verify(Claim claim, DecodedJWT decodedJWT) {
return check.test(claim, decodedJWT);
}
};
}

private boolean verifyNull(Claim claim, Object value) {
return value == null && claim.isNull();
}

private boolean isNullOrEmpty(String[] args) {
Expand Down Expand Up @@ -431,15 +469,14 @@ private void verifyAlgorithm(DecodedJWT jwt, Algorithm expectedAlgorithm) throws
}
}

private void verifyClaims(DecodedJWT jwt, Map<String, BiPredicate<Claim, DecodedJWT>> claims)
private void verifyClaims(DecodedJWT jwt, List<ExpectedCheckHolder> expectedChecks)
throws TokenExpiredException, InvalidClaimException {
for (Map.Entry<String, BiPredicate<Claim, DecodedJWT>> entry : claims.entrySet()) {
for (ExpectedCheckHolder expectedCheck : expectedChecks) {
boolean isValid;
String claimName = entry.getKey();
BiPredicate<Claim, DecodedJWT> expectedCheck = entry.getValue();
String claimName = expectedCheck.getClaimName();
Claim claim = jwt.getClaim(claimName);

isValid = expectedCheck.test(claim, jwt);
isValid = expectedCheck.verify(claim, jwt);

if (!isValid) {
throw new IncorrectClaimException(
Expand Down
25 changes: 25 additions & 0 deletions lib/src/main/java/com/auth0/jwt/impl/ExpectedCheckHolder.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package com.auth0.jwt.impl;

import com.auth0.jwt.interfaces.Claim;
import com.auth0.jwt.interfaces.DecodedJWT;

/**
* This holds the checks that are run to verify a JWT.
*/
public interface ExpectedCheckHolder {
/**
* The claim name that will be checked.
*
* @return the claim name
*/
String getClaimName();

/**
* The verification that will be run.
*
* @param claim the claim for which verification is done
* @param decodedJWT the JWT on which verification is done
* @return whether the verification passed or not
*/
boolean verify(Claim claim, DecodedJWT decodedJWT);
}
Loading

0 comments on commit af04b22

Please sign in to comment.