diff --git a/server/src/main/java/org/cloudfoundry/identity/uaa/login/LoginInfoEndpoint.java b/server/src/main/java/org/cloudfoundry/identity/uaa/login/LoginInfoEndpoint.java index b0c617589ad..23f12936aca 100755 --- a/server/src/main/java/org/cloudfoundry/identity/uaa/login/LoginInfoEndpoint.java +++ b/server/src/main/java/org/cloudfoundry/identity/uaa/login/LoginInfoEndpoint.java @@ -92,6 +92,8 @@ import static java.util.Collections.emptyMap; import static java.util.Objects.isNull; import static java.util.Optional.ofNullable; +import static org.cloudfoundry.identity.uaa.constants.OriginKeys.OAUTH20; +import static org.cloudfoundry.identity.uaa.constants.OriginKeys.OIDC10; import static org.cloudfoundry.identity.uaa.constants.OriginKeys.UAA; import static org.cloudfoundry.identity.uaa.util.UaaUrlUtils.addSubdomainToUrl; import static org.springframework.util.StringUtils.hasText; @@ -629,8 +631,10 @@ private Map getSamlIdentityProviderDefin protected Map getOauthIdentityProviderDefinitions(List allowedIdps) { - List identityProviders = - externalOAuthProviderConfigurator.retrieveAll(true, IdentityZoneHolder.get().getId()); + List identityProviders = externalOAuthProviderConfigurator.retrieveActiveByTypes( + IdentityZoneHolder.get().getId(), + OIDC10, OAUTH20 + ); return identityProviders.stream() .filter(p -> allowedIdps == null || allowedIdps.contains(p.getOriginKey())) diff --git a/server/src/main/java/org/cloudfoundry/identity/uaa/provider/IdentityProviderProvisioning.java b/server/src/main/java/org/cloudfoundry/identity/uaa/provider/IdentityProviderProvisioning.java index 24833de5143..4b6209af953 100644 --- a/server/src/main/java/org/cloudfoundry/identity/uaa/provider/IdentityProviderProvisioning.java +++ b/server/src/main/java/org/cloudfoundry/identity/uaa/provider/IdentityProviderProvisioning.java @@ -27,6 +27,8 @@ public interface IdentityProviderProvisioning { List retrieveActive(String zoneId); + List retrieveActiveByTypes(String zoneId, String... types); + List retrieveAll(boolean activeOnly, String zoneId); IdentityProvider retrieveByOrigin(String origin, String zoneId); diff --git a/server/src/main/java/org/cloudfoundry/identity/uaa/provider/JdbcIdentityProviderProvisioning.java b/server/src/main/java/org/cloudfoundry/identity/uaa/provider/JdbcIdentityProviderProvisioning.java index 0e25ecc90fb..114fe6c4b1d 100644 --- a/server/src/main/java/org/cloudfoundry/identity/uaa/provider/JdbcIdentityProviderProvisioning.java +++ b/server/src/main/java/org/cloudfoundry/identity/uaa/provider/JdbcIdentityProviderProvisioning.java @@ -1,6 +1,8 @@ package org.cloudfoundry.identity.uaa.provider; import static java.sql.Types.VARCHAR; +import static java.util.Collections.emptyList; +import static java.util.stream.Collectors.joining; import static org.cloudfoundry.identity.uaa.util.UaaStringUtils.isNotEmpty; import org.cloudfoundry.identity.uaa.audit.event.SystemDeletable; @@ -19,9 +21,13 @@ import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Timestamp; +import java.util.ArrayList; +import java.util.Arrays; import java.util.Date; +import java.util.HashSet; import java.util.List; import java.util.Optional; +import java.util.Set; import java.util.UUID; @Component("identityProviderProvisioning") @@ -37,6 +43,8 @@ public class JdbcIdentityProviderProvisioning implements IdentityProviderProvisi public static final String IDENTITY_ACTIVE_PROVIDERS_QUERY = IDENTITY_PROVIDERS_QUERY + " and active=?"; + public static final String IDENTITY_ACTIVE_PROVIDERS_OF_TYPE_QUERY_TEMPLATE = IDENTITY_ACTIVE_PROVIDERS_QUERY + " and type in (%s)"; + public static final String IDP_WITH_ALIAS_EXISTS_QUERY = "select 1 from identity_provider idp where idp.identity_zone_id = ? and idp.alias_zid <> '' limit 1"; public static final String ID_PROVIDER_UPDATE_FIELDS = "version,lastmodified,name,type,config,active,alias_id,alias_zid,external_key".replace(",", "=?,") + "=?"; @@ -85,6 +93,26 @@ public List retrieveActive(String zoneId) { return jdbcTemplate.query(IDENTITY_ACTIVE_PROVIDERS_QUERY, mapper, zoneId, true); } + @Override + public List retrieveActiveByTypes(final String zoneId, final String... types) { + if (ObjectUtils.isNotEmpty(types)) { + // eliminate duplicates + final Set typesAsSet = new HashSet<>(Arrays.asList(types)); + + // adjust the number of SQL parameters in the prepared statement + final String sqlPlaceholdersForTypes = typesAsSet.stream().map(type -> "?").collect(joining(",")); + final String sql = IDENTITY_ACTIVE_PROVIDERS_OF_TYPE_QUERY_TEMPLATE.formatted(sqlPlaceholdersForTypes); + + final ArrayList arrayList = new ArrayList<>(typesAsSet.size() + 2); + arrayList.add(zoneId); + arrayList.add(true); + arrayList.addAll(typesAsSet); + return jdbcTemplate.query(sql, mapper, arrayList.toArray()); + } else { + return emptyList(); + } + } + @Override public List retrieveAll(boolean activeOnly, String zoneId) { if (activeOnly) { diff --git a/server/src/main/java/org/cloudfoundry/identity/uaa/provider/oauth/ExternalOAuthProviderConfigurator.java b/server/src/main/java/org/cloudfoundry/identity/uaa/provider/oauth/ExternalOAuthProviderConfigurator.java index 31ef18f1594..745975f16a1 100644 --- a/server/src/main/java/org/cloudfoundry/identity/uaa/provider/oauth/ExternalOAuthProviderConfigurator.java +++ b/server/src/main/java/org/cloudfoundry/identity/uaa/provider/oauth/ExternalOAuthProviderConfigurator.java @@ -29,10 +29,12 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Set; import static java.util.Collections.emptyList; import static java.util.Collections.emptyMap; import static java.util.Optional.ofNullable; +import static java.util.stream.Collectors.toSet; import static org.cloudfoundry.identity.uaa.constants.OriginKeys.OAUTH20; import static org.cloudfoundry.identity.uaa.constants.OriginKeys.OIDC10; @@ -179,6 +181,27 @@ public List retrieveActive(String zoneId) { return retrieveAll(true, zoneId); } + @Override + public List retrieveActiveByTypes(final String zoneId, final String... types) { + if (types == null || types.length == 0) { + return emptyList(); + } + + // intersect passed types with "oidc1.0" and "oauth2.0" + final Set filteredTypes = Arrays.stream(types) + .filter(type -> OIDC10.equals(type) || OAUTH20.equals(type)) + .collect(toSet()); + if (filteredTypes.isEmpty()) { + return emptyList(); + } + + final List idps = providerProvisioning.retrieveActiveByTypes( + zoneId, + filteredTypes.toArray(new String[0]) + ); + return overlayConfigurationsOfOidcIdps(idps); + } + public IdentityProvider retrieveByIssuer(String issuer, String zoneId) throws IncorrectResultSizeDataAccessException { IdentityProvider issuedProvider = null; int originLoopCheckDone = -1; @@ -214,16 +237,25 @@ public IdentityProvider retrieveByIssuer(String issuer, String zoneId) throws In @Override public List retrieveAll(boolean activeOnly, String zoneId) { final List types = Arrays.asList(OAUTH20, OIDC10); - List providers = providerProvisioning.retrieveAll(activeOnly, zoneId); - List overlayedProviders = new ArrayList<>(); - ofNullable(providers).orElse(emptyList()).stream() + final List providers = Optional.ofNullable( + providerProvisioning.retrieveAll(activeOnly, zoneId) + ).orElse(emptyList()); + final List oauthAndOidcProviders = providers.stream() .filter(p -> types.contains(p.getType())) + .toList(); + return overlayConfigurationsOfOidcIdps(oauthAndOidcProviders); + } + + private List overlayConfigurationsOfOidcIdps(final List providers) { + final List overlayedProviders = new ArrayList<>(); + providers.stream() .forEach(p -> { if (p.getType().equals(OIDC10)) { try { - OIDCIdentityProviderDefinition overlayedDefinition = overlay((OIDCIdentityProviderDefinition) p.getConfig()); + final OIDCIdentityProviderDefinition overlayedDefinition = overlay( + (OIDCIdentityProviderDefinition) p.getConfig()); p.setConfig(overlayedDefinition); - } catch (Exception e) { + } catch (final Exception e) { LOGGER.error("Identity provider excluded from login page due to a problem.", e); return; } diff --git a/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/SamlIdentityProviderConfigurator.java b/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/SamlIdentityProviderConfigurator.java index 556113ab95b..38addd6e110 100644 --- a/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/SamlIdentityProviderConfigurator.java +++ b/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/SamlIdentityProviderConfigurator.java @@ -2,7 +2,6 @@ import org.apache.http.client.utils.URIBuilder; import org.cloudfoundry.identity.uaa.constants.OriginKeys; -import org.cloudfoundry.identity.uaa.provider.IdentityProvider; import org.cloudfoundry.identity.uaa.provider.IdentityProviderProvisioning; import org.cloudfoundry.identity.uaa.provider.SamlIdentityProviderDefinition; import org.cloudfoundry.identity.uaa.zone.IdentityZone; @@ -19,7 +18,6 @@ import java.net.URI; import java.net.URISyntaxException; import java.nio.charset.StandardCharsets; -import java.util.LinkedList; import java.util.List; import static org.springframework.util.StringUtils.hasText; @@ -44,27 +42,14 @@ public List getIdentityProviderDefinitions() { } public List getIdentityProviderDefinitionsForZone(IdentityZone zone) { - List result = new LinkedList<>(); - for (IdentityProvider provider : providerProvisioning.retrieveActive(zone.getId())) { - if (OriginKeys.SAML.equals(provider.getType())) { - result.add((SamlIdentityProviderDefinition) provider.getConfig()); - } - } - return result; + return providerProvisioning.retrieveActiveByTypes(zone.getId(), OriginKeys.SAML).stream() + .map(samlIdp -> (SamlIdentityProviderDefinition) samlIdp.getConfig()) + .toList(); } public List getIdentityProviderDefinitions(List allowedIdps, IdentityZone zone) { - List idpsInTheZone = getIdentityProviderDefinitionsForZone(zone); - if (allowedIdps != null) { - List result = new LinkedList<>(); - for (SamlIdentityProviderDefinition def : idpsInTheZone) { - if (allowedIdps.contains(def.getIdpEntityAlias())) { - result.add(def); - } - } - return result; - } - return idpsInTheZone; + return getIdentityProviderDefinitionsForZone(zone).stream() + .filter(def -> allowedIdps == null || allowedIdps.contains(def.getIdpEntityAlias())).toList(); } /** diff --git a/server/src/main/resources/org/cloudfoundry/identity/uaa/db/hsqldb/V4_111__IdP_active_and_type_in_zone_Index.sql b/server/src/main/resources/org/cloudfoundry/identity/uaa/db/hsqldb/V4_111__IdP_active_and_type_in_zone_Index.sql new file mode 100644 index 00000000000..611c1da850f --- /dev/null +++ b/server/src/main/resources/org/cloudfoundry/identity/uaa/db/hsqldb/V4_111__IdP_active_and_type_in_zone_Index.sql @@ -0,0 +1 @@ +CREATE INDEX IF NOT EXISTS active_and_type_in_zone ON identity_provider (identity_zone_id, active, type); \ No newline at end of file diff --git a/server/src/main/resources/org/cloudfoundry/identity/uaa/db/mysql/V4_111__IdP_active_and_type_in_zone_Index.sql b/server/src/main/resources/org/cloudfoundry/identity/uaa/db/mysql/V4_111__IdP_active_and_type_in_zone_Index.sql new file mode 100644 index 00000000000..99383753152 --- /dev/null +++ b/server/src/main/resources/org/cloudfoundry/identity/uaa/db/mysql/V4_111__IdP_active_and_type_in_zone_Index.sql @@ -0,0 +1 @@ +CREATE INDEX active_and_type_in_zone ON identity_provider (identity_zone_id, active, type) LOCK = SHARED; \ No newline at end of file diff --git a/server/src/main/resources/org/cloudfoundry/identity/uaa/db/postgresql/V4_111__IdP_active_and_type_in_zone_Index.sql b/server/src/main/resources/org/cloudfoundry/identity/uaa/db/postgresql/V4_111__IdP_active_and_type_in_zone_Index.sql new file mode 100644 index 00000000000..ab798445d05 --- /dev/null +++ b/server/src/main/resources/org/cloudfoundry/identity/uaa/db/postgresql/V4_111__IdP_active_and_type_in_zone_Index.sql @@ -0,0 +1 @@ +CREATE INDEX CONCURRENTLY IF NOT EXISTS active_and_type_in_zone on identity_provider (identity_zone_id, active, type); \ No newline at end of file diff --git a/server/src/test/java/org/cloudfoundry/identity/uaa/login/LoginInfoEndpointTests.java b/server/src/test/java/org/cloudfoundry/identity/uaa/login/LoginInfoEndpointTests.java index 986cb508593..36d10f798de 100755 --- a/server/src/test/java/org/cloudfoundry/identity/uaa/login/LoginInfoEndpointTests.java +++ b/server/src/test/java/org/cloudfoundry/identity/uaa/login/LoginInfoEndpointTests.java @@ -83,7 +83,6 @@ import static org.junit.Assert.assertTrue; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyList; import static org.mockito.ArgumentMatchers.anyString; @@ -831,7 +830,7 @@ void allowedIdpsforClientOIDCProvider() throws Exception { clientAllowedIdps.add(createOIDCIdentityProvider("my-OIDC-idp2")); clientAllowedIdps.add(createOIDCIdentityProvider("my-OIDC-idp1")); - when(mockIdentityProviderProvisioning.retrieveAll(eq(true), anyString())).thenReturn(clientAllowedIdps); + when(mockIdentityProviderProvisioning.retrieveActiveByTypes(anyString(), eq(OriginKeys.OIDC10), eq(OriginKeys.OAUTH20))).thenReturn(clientAllowedIdps); LoginInfoEndpoint endpoint = getEndpoint(IdentityZoneHolder.get(), clientDetailsService); @@ -856,7 +855,7 @@ void oauth_provider_links_shown() throws Exception { IdentityProvider identityProvider = MultitenancyFixture.identityProvider("oauth-idp-alias", "uaa"); identityProvider.setConfig(definition); - when(mockIdentityProviderProvisioning.retrieveAll(anyBoolean(), anyString())).thenReturn(singletonList(identityProvider)); + when(mockIdentityProviderProvisioning.retrieveActiveByTypes(anyString(), any())).thenReturn(singletonList(identityProvider)); endpoint.loginForHtml(extendedModelMap, null, new MockHttpServletRequest(), singletonList(MediaType.TEXT_HTML)); assertThat(extendedModelMap.get("showLoginLinks"), equalTo(true)); @@ -873,7 +872,8 @@ void passcode_prompt_present_whenThereIsAtleastOneActiveOauthProvider() throws E IdentityProvider identityProvider = MultitenancyFixture.identityProvider("oauth-idp-alias", "uaa"); identityProvider.setConfig(definition); - when(mockIdentityProviderProvisioning.retrieveAll(anyBoolean(), anyString())).thenReturn(singletonList(identityProvider)); + when(mockIdentityProviderProvisioning.retrieveActiveByTypes(anyString(), eq(OriginKeys.OIDC10), eq(OriginKeys.OAUTH20))) + .thenReturn(singletonList(identityProvider)); endpoint.infoForLoginJson(extendedModelMap, null, new MockHttpServletRequest("GET", "http://someurl")); Map mapPrompts = (Map) extendedModelMap.get("prompts"); @@ -892,7 +892,8 @@ void passcode_prompt_present_whenThereIsAtleastOneActiveOauthProvider_stillWorks IdentityProvider identityProvider = MultitenancyFixture.identityProvider("oauth-idp-alias", "uaa"); identityProvider.setConfig(definition); - when(mockIdentityProviderProvisioning.retrieveAll(anyBoolean(), anyString())).thenReturn(singletonList(identityProvider)); + when(mockIdentityProviderProvisioning.retrieveActiveByTypes(anyString(), eq(OriginKeys.OIDC10), eq(OriginKeys.OAUTH20))) + .thenReturn(singletonList(identityProvider)); endpoint.infoForLoginJson(extendedModelMap, null, new MockHttpServletRequest("GET", "http://someurl")); Map mapPrompts = (Map) extendedModelMap.get("prompts"); @@ -911,7 +912,7 @@ void passcode_prompt_present_whenThereIsAtleastOneActiveOauthProvider_stillWorks IdentityProvider identityProvider = MultitenancyFixture.identityProvider("oauth-idp-alias", "uaa"); identityProvider.setConfig(definition); - when(mockIdentityProviderProvisioning.retrieveAll(anyBoolean(), anyString())).thenReturn(singletonList(identityProvider)); + when(mockIdentityProviderProvisioning.retrieveActiveByTypes(anyString(), any())).thenReturn(singletonList(identityProvider)); endpoint.infoForLoginJson(extendedModelMap, null, new MockHttpServletRequest("GET", "http://someurl")); Map mapPrompts = (Map) extendedModelMap.get("prompts"); @@ -935,7 +936,8 @@ void we_return_both_oauth_and_oidc_providers() throws Exception { IdentityProvider oidcProvider = MultitenancyFixture.identityProvider("oidc-idp-alias", "uaa"); oidcProvider.setConfig(oidcDefinition); - when(mockIdentityProviderProvisioning.retrieveAll(anyBoolean(), anyString())).thenReturn(Arrays.asList(oauthProvider, oidcProvider)); + when(mockIdentityProviderProvisioning.retrieveActiveByTypes(anyString(), eq(OriginKeys.OIDC10), eq(OriginKeys.OAUTH20))) + .thenReturn(Arrays.asList(oauthProvider, oidcProvider)); assertEquals(2, endpoint.getOauthIdentityProviderDefinitions(null).size()); } @@ -981,7 +983,8 @@ void loginHintEmailDomain() throws Exception { when(mockOidcConfig.getResponseType()).thenReturn("token"); when(mockOidcConfig.getEmailDomain()).thenReturn(singletonList("example.com")); when(mockProvider.getConfig()).thenReturn(mockOidcConfig); - when(mockIdentityProviderProvisioning.retrieveAll(anyBoolean(), any())).thenReturn(singletonList(mockProvider)); + when(mockIdentityProviderProvisioning.retrieveActiveByTypes(anyString(), eq(OriginKeys.OIDC10), eq(OriginKeys.OAUTH20))) + .thenReturn(singletonList(mockProvider)); LoginInfoEndpoint endpoint = getEndpoint(IdentityZoneHolder.get(), clientDetailsService); @@ -1237,7 +1240,7 @@ public void testInvalidLoginHintLoginPageReturnsList() throws Exception { List clientAllowedIdps = new LinkedList<>(); clientAllowedIdps.add(createOIDCIdentityProvider("my-OIDC-idp1")); clientAllowedIdps.add(createOIDCIdentityProvider("my-OIDC-idp2")); - when(mockIdentityProviderProvisioning.retrieveAll(eq(true), anyString())).thenReturn(clientAllowedIdps); + when(mockIdentityProviderProvisioning.retrieveActiveByTypes(anyString(), eq(OriginKeys.OIDC10), eq(OriginKeys.OAUTH20))).thenReturn(clientAllowedIdps); when(mockIdentityProviderProvisioning.retrieveByOrigin(eq("invalidorigin"), anyString())).thenThrow(new EmptyResultDataAccessException(1)); SavedRequest savedRequest = SessionUtils.getSavedRequestSession(mockHttpServletRequest.getSession()); @@ -1848,7 +1851,7 @@ private static void mockOidcProvider(IdentityProviderProvisioning mockIdentityPr when(mockOidcConfig.getResponseType()).thenReturn("token"); when(mockProvider.getConfig()).thenReturn(mockOidcConfig); when(mockOidcConfig.isShowLinkText()).thenReturn(true); - when(mockIdentityProviderProvisioning.retrieveAll(anyBoolean(), any())).thenReturn(singletonList(mockProvider)); + when(mockIdentityProviderProvisioning.retrieveActiveByTypes(anyString(), any())).thenReturn(singletonList(mockProvider)); } private static void mockLoginHintProvider(ExternalOAuthProviderConfigurator mockIdentityProviderProvisioning) diff --git a/server/src/test/java/org/cloudfoundry/identity/uaa/provider/JdbcIdentityProviderProvisioningTests.java b/server/src/test/java/org/cloudfoundry/identity/uaa/provider/JdbcIdentityProviderProvisioningTests.java index 6f8d84f6814..28590e2562d 100644 --- a/server/src/test/java/org/cloudfoundry/identity/uaa/provider/JdbcIdentityProviderProvisioningTests.java +++ b/server/src/test/java/org/cloudfoundry/identity/uaa/provider/JdbcIdentityProviderProvisioningTests.java @@ -1,5 +1,9 @@ package org.cloudfoundry.identity.uaa.provider; +import static java.util.stream.Collectors.toSet; +import static org.cloudfoundry.identity.uaa.constants.OriginKeys.KEYSTONE; +import static org.cloudfoundry.identity.uaa.constants.OriginKeys.LDAP; +import static org.cloudfoundry.identity.uaa.constants.OriginKeys.LOGIN_SERVER; import static org.cloudfoundry.identity.uaa.constants.OriginKeys.OAUTH20; import static org.cloudfoundry.identity.uaa.constants.OriginKeys.OIDC10; import static org.cloudfoundry.identity.uaa.constants.OriginKeys.SAML; @@ -17,23 +21,31 @@ import static org.mockito.Mockito.when; import java.sql.Timestamp; +import java.util.Arrays; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.UUID; +import java.util.stream.Stream; +import org.apache.commons.collections4.SetUtils; import org.assertj.core.api.Assertions; import org.cloudfoundry.identity.uaa.annotations.WithDatabaseContext; import org.cloudfoundry.identity.uaa.audit.event.EntityDeletedEvent; -import org.cloudfoundry.identity.uaa.constants.OriginKeys; import org.cloudfoundry.identity.uaa.util.JsonUtils; import org.cloudfoundry.identity.uaa.zone.IdentityZone; import org.cloudfoundry.identity.uaa.zone.MultitenancyFixture; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.dao.EmptyResultDataAccessException; import org.springframework.jdbc.core.JdbcTemplate; import org.cloudfoundry.identity.uaa.oauth.common.util.RandomValueStringGenerator; +import org.springframework.util.StringUtils; @WithDatabaseContext class JdbcIdentityProviderProvisioningTests { @@ -49,6 +61,8 @@ class JdbcIdentityProviderProvisioningTests { private String otherZoneId1; private String otherZoneId2; + private static final Set ALL_TYPES = Set.of(LDAP, OIDC10, UAA, OAUTH20, SAML, KEYSTONE, LOGIN_SERVER); + @BeforeEach void createDatasource() { generator = new RandomValueStringGenerator(); @@ -312,7 +326,7 @@ void updateIdentityProviderInDefaultZone() { String idpId = "idpId-" + generator.generate(); IdentityProvider idp = MultitenancyFixture.identityProvider(origin, uaaZoneId); idp.setId(idpId); - idp.setType(OriginKeys.LDAP); + idp.setType(LDAP); idp = jdbcIdentityProviderProvisioning.create(idp, uaaZoneId); LdapIdentityProviderDefinition definition = new LdapIdentityProviderDefinition(); @@ -399,6 +413,60 @@ void retrieveIdentityProviderByOriginInDifferentZone() { assertThrows(EmptyResultDataAccessException.class, () -> jdbcIdentityProviderProvisioning.retrieveByOrigin(idp1.getOriginKey(), otherZoneId2)); } + @ParameterizedTest + @MethodSource + void retrieveActiveByTypes(final String[] types) { + final Set expectedTypes = new HashSet<>(Arrays.asList(types)); // eliminate duplicates + + // create one IdP for every expected type in the correct zone + final List expectedIdpIds = expectedTypes.stream() + .map(type -> createIdp(type, "origin-" + generator.generate(), otherZoneId1)) + .toList(); + + // have another type -> should not be in the result + final Set otherTypes = SetUtils.difference(ALL_TYPES, expectedTypes); + for (final String otherType : otherTypes) { + createIdp(otherType, "origin-" + generator.generate(), otherZoneId1); + } + + // have the correct type, but another zone -> should not be in the result + for (final String type : expectedTypes) { + createIdp(type, "origin-" + generator.generate(), otherZoneId2); + } + + final List result = jdbcIdentityProviderProvisioning.retrieveActiveByTypes(otherZoneId1, + types); + final Set idsInResult = result.stream().map(IdentityProvider::getId).collect(toSet()); + assertEquals(expectedIdpIds.size(), idsInResult.size()); + for (final String id : expectedIdpIds) { + assertTrue(idsInResult.contains(id)); + } + } + + private static Stream retrieveActiveByTypes() { + return Stream.of( + new String[] { }, + new String[] { OAUTH20, OIDC10 }, + new String[] { OAUTH20, OIDC10, SAML }, + new String[] { SAML }, + new String[] { OIDC10 }, + new String[] { LDAP, UAA, OAUTH20, OIDC10 }, + new String[] { LDAP, UAA, OAUTH20, LDAP, LDAP, OIDC10 }, // contains duplicates + (Object) new String[] { LDAP, UAA, OAUTH20, OIDC10, OIDC10, UAA } // contains duplicates + ).map(Arguments::of); + } + + private String createIdp(final String type, final String originKey, final String zoneId) { + final String idpId = "idpId-" + generator.generate(); + final IdentityProvider idp = MultitenancyFixture.identityProvider(originKey, idpId); + idp.setId(idpId); + idp.setType(type); + final IdentityProvider createdIdp = jdbcIdentityProviderProvisioning.create(idp, zoneId); + final String idpIdCreated = createdIdp.getId(); + assertTrue(StringUtils.hasText(idpIdCreated)); + return idpIdCreated; + } + @Test void testIdpWithAliasExistsInZone_TrueCase() { final IdentityProvider idpWithAlias = MultitenancyFixture.identityProvider( diff --git a/server/src/test/java/org/cloudfoundry/identity/uaa/provider/oauth/ExternalOAuthProviderConfiguratorTests.java b/server/src/test/java/org/cloudfoundry/identity/uaa/provider/oauth/ExternalOAuthProviderConfiguratorTests.java index 1c9fc34b1d0..649ec771d98 100644 --- a/server/src/test/java/org/cloudfoundry/identity/uaa/provider/oauth/ExternalOAuthProviderConfiguratorTests.java +++ b/server/src/test/java/org/cloudfoundry/identity/uaa/provider/oauth/ExternalOAuthProviderConfiguratorTests.java @@ -1,5 +1,6 @@ package org.cloudfoundry.identity.uaa.provider.oauth; +import org.apache.commons.lang.RandomStringUtils; import org.cloudfoundry.identity.uaa.extensions.PollutionPreventionExtension; import org.cloudfoundry.identity.uaa.provider.AbstractExternalOAuthIdentityProviderDefinition; import org.cloudfoundry.identity.uaa.provider.IdentityProvider; @@ -16,6 +17,8 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; @@ -27,12 +30,18 @@ import java.net.MalformedURLException; import java.net.URL; import java.util.Arrays; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; +import java.util.stream.Stream; +import static java.util.stream.Collectors.toSet; import static org.cloudfoundry.identity.uaa.constants.OriginKeys.LDAP; import static org.cloudfoundry.identity.uaa.constants.OriginKeys.OAUTH20; import static org.cloudfoundry.identity.uaa.constants.OriginKeys.OIDC10; +import static org.cloudfoundry.identity.uaa.constants.OriginKeys.SAML; +import static org.cloudfoundry.identity.uaa.constants.OriginKeys.UAA; import static org.cloudfoundry.identity.uaa.provider.ExternalIdentityProviderDefinition.USER_NAME_ATTRIBUTE_NAME; import static org.cloudfoundry.identity.uaa.util.AssertThrowsWithMessage.assertThrowsWithMessageThat; import static org.hamcrest.MatcherAssert.assertThat; @@ -46,12 +55,14 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.Mockito.anyString; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.lenient; import static org.mockito.Mockito.never; import static org.mockito.Mockito.reset; import static org.mockito.Mockito.spy; @@ -155,6 +166,110 @@ void retrieveActive() { verify(configurator, times(1)).retrieveAll(eq(true), anyString()); } + @ParameterizedTest + @MethodSource + void retrieveActiveByTypes_ShouldReturnEmptyListWhenNeitherOidcNorOAuthInTypes(final String[] types) { + final String zoneId = RandomStringUtils.randomAlphanumeric(8); + + /* arrange one active IdP per type being present in the zone + * -> however, they should not be returned since the types don't match */ + final String originKeyPrefix = RandomStringUtils.randomAlphanumeric(8) + "-"; + final List idps = new HashSet<>(Arrays.asList(types)).stream() + .map(type -> { + final IdentityProvider idp = new IdentityProvider<>(); + final String originKey = "%s%s".formatted(originKeyPrefix, type); + idp.setOriginKey(originKey); + idp.setId(originKey); + idp.setType(type); + idp.setActive(true); + return idp; + }).toList(); + lenient().when(mockIdentityProviderProvisioning.retrieveActiveByTypes(zoneId, types)).thenReturn(idps); + + assertTrue(configurator.retrieveActiveByTypes(zoneId, types).isEmpty()); + } + + private static Stream retrieveActiveByTypes_ShouldReturnEmptyListWhenNeitherOidcNorOAuthInTypes() { + return Stream.of( + new String[] { SAML }, + new String[] { SAML, LDAP }, + new String[] { }, + (Object) new String[] { UAA, LDAP, LDAP } // contains duplicates + ).map(Arguments::of); + } + + @Test + void retrieveActiveByNullType() { + assertEquals(0, configurator.retrieveActiveByTypes(IdentityZone.getUaaZoneId(), null).size()); + } + + @ParameterizedTest + @MethodSource + void retrieveActiveByTypes(final String[] types) throws OidcMetadataFetchingException { + final String zoneId = RandomStringUtils.randomAlphanumeric(8); + + // eliminate duplicates + final Set typesAsSet = new HashSet<>(Arrays.asList(types)); + final boolean inputContainsOidc = typesAsSet.contains(OIDC10); + final boolean inputContainsOauth = typesAsSet.contains(OAUTH20); + + // arrange one active IdP of every type in "oauth2.0" and "oidc1.0" exists in the zone + final String originKeyPrefix = RandomStringUtils.randomAlphanumeric(8) + "-"; + final List idps = Stream.of(OIDC10, OAUTH20) + .filter(type -> !OIDC10.equals(type) || inputContainsOidc) + .filter(type -> !OAUTH20.equals(type) || inputContainsOauth) + .map(type -> { + final IdentityProvider idp = new IdentityProvider<>(); + final String originKey = "%s%s".formatted(originKeyPrefix, type); + idp.setOriginKey(originKey); + idp.setId(originKey); + idp.setType(type); + if (OIDC10.equals(type)) { + idp.setConfig(new OIDCIdentityProviderDefinition()); + } + idp.setActive(true); + return idp; + }).toList(); + if (inputContainsOidc && inputContainsOauth) { + lenient().when(mockIdentityProviderProvisioning.retrieveActiveByTypes(zoneId, OIDC10, OAUTH20)) + .thenReturn(idps); + lenient().when(mockIdentityProviderProvisioning.retrieveActiveByTypes(zoneId, OAUTH20, OIDC10)) + .thenReturn(idps); + } else if (inputContainsOidc) { + when(mockIdentityProviderProvisioning.retrieveActiveByTypes(zoneId, OIDC10)).thenReturn(idps); + } else if (inputContainsOauth) { + when(mockIdentityProviderProvisioning.retrieveActiveByTypes(zoneId, OAUTH20)).thenReturn(idps); + } + + final List result = configurator.retrieveActiveByTypes(zoneId, types); + + /* the result should contain only IdPs of type "oauth2.0" and "oidc1.0" and only if the corresponding type + * was part of the input types */ + final int expectedSize = (inputContainsOauth ? 1 : 0) + (inputContainsOidc ? 1 : 0); + assertEquals(expectedSize, result.size()); + + final Set typesInResult = result.stream().map(IdentityProvider::getType).collect(toSet()); + assertEquals(expectedSize, typesInResult.size()); + assertEquals(inputContainsOauth, typesInResult.contains(OAUTH20)); + assertEquals(inputContainsOidc, typesInResult.contains(OIDC10)); + + if (inputContainsOidc) { + verify(mockOidcMetadataFetcher, times(1)).fetchMetadataAndUpdateDefinition(any()); + } + } + + private static Stream retrieveActiveByTypes() { + return Stream.of( + new String[] { OIDC10, OAUTH20 }, + new String[] { OIDC10 }, + new String[] { OAUTH20 }, + new String[] { OIDC10, OIDC10, OAUTH20 }, // contains duplicates + new String[] { OIDC10, LDAP, SAML }, // ldap and saml should be ignored + new String[] { OIDC10, OIDC10, LDAP, SAML }, // ldap and saml should be ignored + (Object) new String[] { OIDC10, OIDC10, OAUTH20, LDAP, SAML } // ldap and saml should be ignored + ).map(Arguments::of); + } + @Test void retrieve_by_issuer() throws Exception { when(mockIdentityProviderProvisioning.retrieveAll(eq(true), anyString())).thenReturn(Arrays.asList(oidcProvider, oauthProvider, new IdentityProvider<>().setType(LDAP))); diff --git a/server/src/test/java/org/cloudfoundry/identity/uaa/provider/saml/SamlIdentityProviderConfiguratorTests.java b/server/src/test/java/org/cloudfoundry/identity/uaa/provider/saml/SamlIdentityProviderConfiguratorTests.java index edf6dad7330..58087037db6 100644 --- a/server/src/test/java/org/cloudfoundry/identity/uaa/provider/saml/SamlIdentityProviderConfiguratorTests.java +++ b/server/src/test/java/org/cloudfoundry/identity/uaa/provider/saml/SamlIdentityProviderConfiguratorTests.java @@ -16,7 +16,6 @@ package org.cloudfoundry.identity.uaa.provider.saml; -import org.cloudfoundry.identity.uaa.constants.OriginKeys; import org.cloudfoundry.identity.uaa.provider.IdentityProvider; import org.cloudfoundry.identity.uaa.provider.IdentityProviderProvisioning; import org.cloudfoundry.identity.uaa.provider.JdbcIdentityProviderProvisioning; @@ -39,6 +38,7 @@ import static java.time.Duration.ofSeconds; import static java.util.Arrays.asList; +import static org.cloudfoundry.identity.uaa.constants.OriginKeys.SAML; import static org.cloudfoundry.identity.uaa.util.AssertThrowsWithMessage.assertThrowsWithMessageThat; import static org.hamcrest.Matchers.startsWith; import static org.junit.Assert.assertEquals; @@ -48,6 +48,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -187,7 +188,7 @@ public void testGetEntityID() throws Exception { case "okta-local-2": { ComparableProvider provider = (ComparableProvider) configurator.getExtendedMetadataDelegateFromCache(def).getDelegate(); IdentityProvider idp2 = mock(IdentityProvider.class); - when(idp2.getType()).thenReturn(OriginKeys.SAML); + when(idp2.getType()).thenReturn(SAML); when(idp2.getConfig()).thenReturn(def); when(provisioning.retrieveActive(anyString())).thenReturn(asList(idp2)); configurator.validateSamlIdentityProviderDefinition(def, true); @@ -219,9 +220,9 @@ void testGetEntityIDExists() { for (SamlIdentityProviderDefinition def : bootstrap.getIdentityProviderDefinitions()) { if ("okta-local-2".equalsIgnoreCase(def.getIdpEntityAlias())) { IdentityProvider idp2 = mock(IdentityProvider.class); - when(idp2.getType()).thenReturn(OriginKeys.SAML); + when(idp2.getType()).thenReturn(SAML); when(idp2.getConfig()).thenReturn(def.clone().setIdpEntityAlias("okta-local-1")); - when(provisioning.retrieveActive(anyString())).thenReturn(Arrays.asList(idp2)); + when(provisioning.retrieveActiveByTypes(anyString(), eq(SAML))).thenReturn(Arrays.asList(idp2)); assertThrowsWithMessageThat( MetadataProviderException.class, () -> configurator.validateSamlIdentityProviderDefinition(def, true), @@ -252,18 +253,18 @@ protected List getSamlIdentityProviderDefinition .setIconUrl("sample-icon-url") .setZoneId("other-zone-id"); IdentityProvider idp1 = mock(IdentityProvider.class); - when(idp1.getType()).thenReturn(OriginKeys.SAML); + when(idp1.getType()).thenReturn(SAML); when(idp1.getConfig()).thenReturn(def1); IdentityProvider idp2 = mock(IdentityProvider.class); - when(idp2.getType()).thenReturn(OriginKeys.SAML); + when(idp2.getType()).thenReturn(SAML); when(idp2.getConfig()).thenReturn(def1.clone().setIdpEntityAlias("okta-local-2")); IdentityProvider idp3 = mock(IdentityProvider.class); - when(idp3.getType()).thenReturn(OriginKeys.SAML); + when(idp3.getType()).thenReturn(SAML); when(idp3.getConfig()).thenReturn(def1.clone().setIdpEntityAlias("okta-local-3")); - when(provisioning.retrieveActive(anyString())).thenReturn(Arrays.asList(idp1, idp2)); + when(provisioning.retrieveActiveByTypes(anyString(), eq(SAML))).thenReturn(Arrays.asList(idp1, idp2)); return configurator.getIdentityProviderDefinitions(clientIdpAliases, IdentityZoneHolder.get()); }