Skip to content

Commit

Permalink
First attempt at JWT-based client-auth
Browse files Browse the repository at this point in the history
  • Loading branch information
andreaceccanti committed Oct 17, 2021
1 parent 767e86e commit fc7148d
Show file tree
Hide file tree
Showing 11 changed files with 399 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,31 +15,42 @@
*/
package it.infn.mw.iam.config.security;

import static java.util.Collections.singletonList;
import static org.springframework.http.HttpMethod.OPTIONS;

import java.time.Clock;

import org.mitre.jwt.signer.service.impl.ClientKeyCacheService;
import org.mitre.oauth2.service.impl.DefaultClientUserDetailsService;
import org.mitre.oauth2.web.CorsFilter;
import org.mitre.openid.connect.assertion.JWTBearerClientAssertionTokenEndpointFilter;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.annotation.Order;
import org.springframework.security.authentication.ProviderManager;
import org.springframework.security.config.annotation.authentication.builders.AuthenticationManagerBuilder;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
import org.springframework.security.config.http.SessionCreationPolicy;
import org.springframework.security.core.userdetails.UserDetailsService;
import org.springframework.security.oauth2.provider.client.ClientCredentialsTokenEndpointFilter;
import org.springframework.security.oauth2.provider.error.OAuth2AccessDeniedHandler;
import org.springframework.security.oauth2.provider.error.OAuth2AuthenticationEntryPoint;
import org.springframework.security.web.authentication.preauth.AbstractPreAuthenticatedProcessingFilter;
import org.springframework.security.web.authentication.www.BasicAuthenticationFilter;
import org.springframework.security.web.context.SecurityContextPersistenceFilter;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;

import it.infn.mw.iam.config.IamProperties;
import it.infn.mw.iam.core.oauth.assertion.IAMJWTBearerAuthenticationProvider;

@Configuration
@Order(-1)
public class IamTokenEndointSecurityConfig extends WebSecurityConfigurerAdapter {

public static final String TOKEN_ENDPOINT = "/token";

@Autowired
private CorsFilter corsFilter;

Expand All @@ -48,7 +59,16 @@ public class IamTokenEndointSecurityConfig extends WebSecurityConfigurerAdapter

@Autowired
@Qualifier("clientUserDetailsService")
private UserDetailsService userDetailsService;
private DefaultClientUserDetailsService userDetailsService;

@Autowired
private Clock clock;

@Autowired
private ClientKeyCacheService validators;

@Autowired
private IamProperties iamProperties;

@Override
protected void configure(AuthenticationManagerBuilder auth) throws Exception {
Expand All @@ -63,6 +83,20 @@ public ClientCredentialsTokenEndpointFilter ccFilter() throws Exception {
return filter;
}

@Bean
public JWTBearerClientAssertionTokenEndpointFilter jwtBearerFilter() throws Exception {

JWTBearerClientAssertionTokenEndpointFilter filter =
new JWTBearerClientAssertionTokenEndpointFilter(new AntPathRequestMatcher(TOKEN_ENDPOINT));

IAMJWTBearerAuthenticationProvider authProvider = new IAMJWTBearerAuthenticationProvider(clock,
iamProperties, userDetailsService.getClientDetailsService(), validators);

filter.setAuthenticationManager(new ProviderManager(singletonList(authProvider)));

return filter;
}

@Override
protected void configure(HttpSecurity http) throws Exception {

Expand All @@ -78,7 +112,8 @@ protected void configure(HttpSecurity http) throws Exception {
.antMatchers(OPTIONS, TOKEN_ENDPOINT).permitAll()
.antMatchers(TOKEN_ENDPOINT).authenticated()
.and()
.addFilterBefore(ccFilter(), AbstractPreAuthenticatedProcessingFilter.class)
.addFilterBefore(jwtBearerFilter(), AbstractPreAuthenticatedProcessingFilter.class)
.addFilterBefore(ccFilter(), BasicAuthenticationFilter.class)
.addFilterBefore(corsFilter, SecurityContextPersistenceFilter.class)
.exceptionHandling()
.authenticationEntryPoint(authenticationEntryPoint)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
package it.infn.mw.iam.core.oauth.assertion;

import static java.lang.String.format;
import static java.util.Objects.isNull;

import java.text.ParseException;
import java.time.Clock;
import java.time.Instant;
import java.util.Date;
import java.util.HashSet;
import java.util.Optional;
import java.util.Set;

import org.mitre.jwt.signer.service.JWTSigningAndValidationService;
import org.mitre.jwt.signer.service.impl.ClientKeyCacheService;
import org.mitre.oauth2.model.ClientDetailsEntity;
import org.mitre.oauth2.model.ClientDetailsEntity.AuthMethod;
import org.mitre.oauth2.service.ClientDetailsEntityService;
import org.mitre.openid.connect.assertion.JWTBearerAssertionAuthenticationToken;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.authentication.AuthenticationServiceException;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.core.userdetails.UsernameNotFoundException;

import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jwt.JWT;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;

import it.infn.mw.iam.config.IamProperties;

public class IAMJWTBearerAuthenticationProvider implements AuthenticationProvider {

public static final Logger LOG =
LoggerFactory.getLogger(IAMJWTBearerAuthenticationProvider.class);

private static final GrantedAuthority ROLE_CLIENT = new SimpleGrantedAuthority("ROLE_CLIENT");

private final int CLOCK_SKEW_IN_SECONDS = 300;

private final Clock clock;
private final ClientDetailsEntityService clientService;
private final ClientKeyCacheService validators;

private final String TOKEN_ENDPOINT;

public IAMJWTBearerAuthenticationProvider(Clock clock, IamProperties iamProperties,
ClientDetailsEntityService clientService, ClientKeyCacheService validators) {

this.clock = clock;
this.clientService = clientService;
this.validators = validators;

if (iamProperties.getIssuer().endsWith("/")) {
TOKEN_ENDPOINT = iamProperties.getIssuer() + "token";
} else {
TOKEN_ENDPOINT = iamProperties.getIssuer() + "/token";
}

}

private void clientAuthMethodChecks(ClientDetailsEntity client, SignedJWT jws) {

if (client.getTokenEndpointAuthMethod() == null
|| client.getTokenEndpointAuthMethod().equals(AuthMethod.NONE)
|| client.getTokenEndpointAuthMethod().equals(AuthMethod.SECRET_BASIC)
|| client.getTokenEndpointAuthMethod().equals(AuthMethod.SECRET_POST)) {

throw new AuthenticationServiceException("Unsupported authentication method.");
}

JWSAlgorithm alg = jws.getHeader().getAlgorithm();

if (client.getTokenEndpointAuthSigningAlg() != null
&& !client.getTokenEndpointAuthSigningAlg().equals(alg)) {
invalidBearerAssertion("Invalid signature algorithm: " + alg.getName());
}

if (client.getTokenEndpointAuthMethod().equals(AuthMethod.PRIVATE_KEY)) {
if (!JWSAlgorithm.Family.RSA.contains(alg) && !JWSAlgorithm.Family.EC.contains(alg)) {
invalidBearerAssertion("Invalid signature algorithm: " + alg.getName());
}
} else if (client.getTokenEndpointAuthMethod().equals(AuthMethod.SECRET_JWT)) {
if (!JWSAlgorithm.Family.HMAC_SHA.contains(alg)) {
invalidBearerAssertion("Invalid signature algorithm: " + alg.getName());
}
}
}

private void signatureChecks(ClientDetailsEntity client, SignedJWT jws) {
JWSAlgorithm alg = jws.getHeader().getAlgorithm();

JWTSigningAndValidationService validator =
Optional.ofNullable(validators.getValidator(client, alg))
.orElseThrow(() -> new AuthenticationServiceException(
format("Unable to resolve validator for client '%s' and algorithm '%s'",
client.getClientId(), alg.getName())));

if (!validator.validateSignature(jws)) {
invalidBearerAssertion("invalid signature");
}
}

private void invalidBearerAssertion(String msg) {
throw new AuthenticationServiceException(
String.format("invalid jwt bearer assertion: %s", msg));
}

private void assertionChecks(ClientDetailsEntity client, SignedJWT jws) throws ParseException {

JWTClaimsSet jwtClaims = jws.getJWTClaimsSet();

if (isNull(jwtClaims.getIssuer())) {
invalidBearerAssertion("issuer is null");
} else if (!jwtClaims.getIssuer().equals(client.getClientId())) {
invalidBearerAssertion("issuer does not match client id");
}

if (isNull(jwtClaims.getExpirationTime())) {
invalidBearerAssertion("expiration time not set");
}

Instant nowSkewed = clock.instant().minusSeconds(CLOCK_SKEW_IN_SECONDS);

if (Date.from(nowSkewed).after(jwtClaims.getExpirationTime())) {
invalidBearerAssertion("expired assertion token");
}

if (!isNull(jwtClaims.getNotBeforeTime())) {

nowSkewed = clock.instant().plusSeconds(CLOCK_SKEW_IN_SECONDS);
if (Date.from(nowSkewed).before(jwtClaims.getNotBeforeTime())) {
invalidBearerAssertion("assertion is not yet valid");
}
}

if (!isNull(jwtClaims.getIssueTime())) {
nowSkewed = clock.instant().plusSeconds(CLOCK_SKEW_IN_SECONDS);
if (Date.from(nowSkewed).before(jwtClaims.getIssueTime())) {
invalidBearerAssertion("assertion was issued in the future");
}
}

if (isNull(jwtClaims.getAudience())) {
invalidBearerAssertion("assertion audience is null");
} else {
if (!jwtClaims.getAudience().contains(TOKEN_ENDPOINT)) {
invalidBearerAssertion("invalid audience");
}
}
}

@Override
public Authentication authenticate(Authentication authentication) throws AuthenticationException {

JWTBearerAssertionAuthenticationToken jwtAuth =
(JWTBearerAssertionAuthenticationToken) authentication;

ClientDetailsEntity client = clientService.loadClientByClientId(jwtAuth.getName());

if (isNull(client)) {
throw new UsernameNotFoundException("Unknown client: " + jwtAuth.getName());
}

try {


final JWT jwt = jwtAuth.getJwt();

if (!(jwt instanceof SignedJWT)) {
invalidBearerAssertion("Unsupported JWT type: " + jwt.getClass().getName());
}

SignedJWT jws = (SignedJWT) jwt;

clientAuthMethodChecks(client, jws);

signatureChecks(client, jws);

assertionChecks(client, jws);

Set<GrantedAuthority> authorities = new HashSet<>(client.getAuthorities());
authorities.add(ROLE_CLIENT);

return new JWTBearerAssertionAuthenticationToken(jwt, authorities);

} catch (ParseException e) {
throw new AuthenticationServiceException("JWT parse error:" + e.getMessage(), e);
}
}


@Override
public boolean supports(Class<?> authentication) {
return JWTBearerAssertionAuthenticationToken.class.isAssignableFrom(authentication);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,8 @@ protected JWTClaimsSet.Builder baseJWTSetup(OAuth2AccessTokenEntity token,
.subject(subject)
.jwtID(UUID.randomUUID().toString());

if (!authentication.isClientOnly()) {
builder.claim(CLIENT_ID_CLAIM_NAME, token.getClient().getClientId());
}

builder.claim(CLIENT_ID_CLAIM_NAME, token.getClient().getClientId());

String audience = null;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,8 @@ public IamJWTProfileUserinfoHelper(IamProperties props, UserInfoService userInfo

@Override
public UserInfo resolveUserInfo(OAuth2Authentication authentication) {
final String username = authentication.getName();

UserInfo ui = getUserInfoService().getByUsernameAndClientId(username,
authentication.getOAuth2Request().getClientId());
UserInfo ui = lookupUserinfo(authentication);

if (isNull(ui)) {
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ public JsonObject toJson() {
JsonObject json = super.toJson();

json.remove("groups");
json.remove("organisation_name");

return json;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ logging.level.org.opensaml.saml2.metadata.provider=INFO
#logging.level.org.eclipse.persistence=DEBUG

# Test logging
# logging.level.org.springframework.test.web.servlet.result=DEBUG
logging.level.org.springframework.test.web.servlet.result=DEBUG

#logging.level.org.apache.jasper.servlet.TldScanner=DEBUG

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package it.infn.mw.iam.test.oauth.assertions;

import static java.util.Collections.singletonList;

import java.time.Instant;
import java.util.Date;

import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.JWSHeader;
import com.nimbusds.jose.JWSSigner;
import com.nimbusds.jose.crypto.MACSigner;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;

import it.infn.mw.iam.test.oauth.EndpointsTestUtils;

public class JWTBearerClientAuthenticationTestSupport extends EndpointsTestUtils {

public static final String CLIENT_ID_SECRET_JWT = "jwt-auth-client_secret_jwt";
public static final String CLIENT_ID_SECRET_JWT_SECRET = "c8e9eed0-e6e4-4a66-b16e-6f37096356a7";
public static final String TOKEN_ENDPOINT_AUDIENCE = "http://localhost:8080/token";
public static final String TOKEN_ENDPOINT = "/token";
public static final String JWT_BEARER_ASSERTION_TYPE =
"urn:ietf:params:oauth:client-assertion-type:jwt-bearer";

public SignedJWT createClientAuthToken(String clientId, Instant expirationTime)
throws JOSEException {

JWSSigner signer = new MACSigner(CLIENT_ID_SECRET_JWT_SECRET);
JWTClaimsSet claimsSet = new JWTClaimsSet.Builder().subject(clientId)
.issuer(clientId)
.expirationTime(Date.from(expirationTime))
.audience(singletonList(TOKEN_ENDPOINT_AUDIENCE))
.build();

SignedJWT signedJWT = new SignedJWT(new JWSHeader(JWSAlgorithm.HS256), claimsSet);

signedJWT.sign(signer);

return signedJWT;
}

}
Loading

0 comments on commit fc7148d

Please sign in to comment.