Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Performance of LoginInfoEndpoint #3141

Merged
merged 15 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@
import static java.util.Collections.emptyMap;
import static java.util.Objects.isNull;
import static java.util.Optional.ofNullable;
import static org.cloudfoundry.identity.uaa.constants.OriginKeys.OAUTH20;
import static org.cloudfoundry.identity.uaa.constants.OriginKeys.OIDC10;
import static org.cloudfoundry.identity.uaa.constants.OriginKeys.UAA;
import static org.cloudfoundry.identity.uaa.util.UaaUrlUtils.addSubdomainToUrl;
import static org.springframework.util.StringUtils.hasText;
Expand Down Expand Up @@ -629,8 +631,10 @@ private Map<String, SamlIdentityProviderDefinition> getSamlIdentityProviderDefin

protected Map<String, AbstractExternalOAuthIdentityProviderDefinition> getOauthIdentityProviderDefinitions(List<String> allowedIdps) {

List<IdentityProvider> identityProviders =
externalOAuthProviderConfigurator.retrieveAll(true, IdentityZoneHolder.get().getId());
List<IdentityProvider> identityProviders = externalOAuthProviderConfigurator.retrieveActiveByTypes(
IdentityZoneHolder.get().getId(),
OIDC10, OAUTH20
);

return identityProviders.stream()
.filter(p -> allowedIdps == null || allowedIdps.contains(p.getOriginKey()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ public interface IdentityProviderProvisioning {

List<IdentityProvider> retrieveActive(String zoneId);

List<IdentityProvider> retrieveActiveByTypes(String zoneId, String... types);

List<IdentityProvider> retrieveAll(boolean activeOnly, String zoneId);

IdentityProvider retrieveByOrigin(String origin, String zoneId);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package org.cloudfoundry.identity.uaa.provider;

import static java.sql.Types.VARCHAR;
import static java.util.Collections.emptyList;
import static java.util.stream.Collectors.joining;
import static org.cloudfoundry.identity.uaa.util.UaaStringUtils.isNotEmpty;

import org.cloudfoundry.identity.uaa.audit.event.SystemDeletable;
Expand All @@ -19,9 +21,13 @@
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Timestamp;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;

@Component("identityProviderProvisioning")
Expand All @@ -37,6 +43,8 @@ public class JdbcIdentityProviderProvisioning implements IdentityProviderProvisi

public static final String IDENTITY_ACTIVE_PROVIDERS_QUERY = IDENTITY_PROVIDERS_QUERY + " and active=?";

public static final String IDENTITY_ACTIVE_PROVIDERS_OF_TYPE_QUERY_TEMPLATE = IDENTITY_ACTIVE_PROVIDERS_QUERY + " and type in (%s)";

public static final String IDP_WITH_ALIAS_EXISTS_QUERY = "select 1 from identity_provider idp where idp.identity_zone_id = ? and idp.alias_zid <> '' limit 1";

public static final String ID_PROVIDER_UPDATE_FIELDS = "version,lastmodified,name,type,config,active,alias_id,alias_zid,external_key".replace(",", "=?,") + "=?";
Expand Down Expand Up @@ -85,6 +93,26 @@ public List<IdentityProvider> retrieveActive(String zoneId) {
return jdbcTemplate.query(IDENTITY_ACTIVE_PROVIDERS_QUERY, mapper, zoneId, true);
}

@Override
public List<IdentityProvider> retrieveActiveByTypes(final String zoneId, final String... types) {
if (ObjectUtils.isNotEmpty(types)) {
// eliminate duplicates
final Set<String> typesAsSet = new HashSet<>(Arrays.asList(types));

// adjust the number of SQL parameters in the prepared statement
final String sqlPlaceholdersForTypes = typesAsSet.stream().map(type -> "?").collect(joining(","));
final String sql = IDENTITY_ACTIVE_PROVIDERS_OF_TYPE_QUERY_TEMPLATE.formatted(sqlPlaceholdersForTypes);

final ArrayList<Object> arrayList = new ArrayList<>(typesAsSet.size() + 2);
arrayList.add(zoneId);
arrayList.add(true);
arrayList.addAll(typesAsSet);
return jdbcTemplate.query(sql, mapper, arrayList.toArray());
} else {
return emptyList();
}
}

@Override
public List<IdentityProvider> retrieveAll(boolean activeOnly, String zoneId) {
if (activeOnly) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,12 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

import static java.util.Collections.emptyList;
import static java.util.Collections.emptyMap;
import static java.util.Optional.ofNullable;
import static java.util.stream.Collectors.toSet;
import static org.cloudfoundry.identity.uaa.constants.OriginKeys.OAUTH20;
import static org.cloudfoundry.identity.uaa.constants.OriginKeys.OIDC10;

Expand Down Expand Up @@ -179,6 +181,27 @@ public List<IdentityProvider> retrieveActive(String zoneId) {
return retrieveAll(true, zoneId);
}

@Override
public List<IdentityProvider> retrieveActiveByTypes(final String zoneId, final String... types) {
if (types == null || types.length == 0) {
return emptyList();
}

// intersect passed types with "oidc1.0" and "oauth2.0"
final Set<String> filteredTypes = Arrays.stream(types)
.filter(type -> OIDC10.equals(type) || OAUTH20.equals(type))
.collect(toSet());
if (filteredTypes.isEmpty()) {
return emptyList();
}

final List<IdentityProvider> idps = providerProvisioning.retrieveActiveByTypes(
zoneId,
filteredTypes.toArray(new String[0])
);
return overlayConfigurationsOfOidcIdps(idps);
}

public IdentityProvider retrieveByIssuer(String issuer, String zoneId) throws IncorrectResultSizeDataAccessException {
IdentityProvider issuedProvider = null;
int originLoopCheckDone = -1;
Expand Down Expand Up @@ -214,16 +237,25 @@ public IdentityProvider retrieveByIssuer(String issuer, String zoneId) throws In
@Override
public List<IdentityProvider> retrieveAll(boolean activeOnly, String zoneId) {
final List<String> types = Arrays.asList(OAUTH20, OIDC10);
List<IdentityProvider> providers = providerProvisioning.retrieveAll(activeOnly, zoneId);
List<IdentityProvider> overlayedProviders = new ArrayList<>();
ofNullable(providers).orElse(emptyList()).stream()
final List<IdentityProvider> providers = Optional.ofNullable(
providerProvisioning.retrieveAll(activeOnly, zoneId)
).orElse(emptyList());
final List<IdentityProvider> oauthAndOidcProviders = providers.stream()
.filter(p -> types.contains(p.getType()))
.toList();
return overlayConfigurationsOfOidcIdps(oauthAndOidcProviders);
}

private List<IdentityProvider> overlayConfigurationsOfOidcIdps(final List<IdentityProvider> providers) {
final List<IdentityProvider> overlayedProviders = new ArrayList<>();
providers.stream()
.forEach(p -> {
if (p.getType().equals(OIDC10)) {
try {
OIDCIdentityProviderDefinition overlayedDefinition = overlay((OIDCIdentityProviderDefinition) p.getConfig());
final OIDCIdentityProviderDefinition overlayedDefinition = overlay(
(OIDCIdentityProviderDefinition) p.getConfig());
p.setConfig(overlayedDefinition);
} catch (Exception e) {
} catch (final Exception e) {
LOGGER.error("Identity provider excluded from login page due to a problem.", e);
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import org.apache.http.client.utils.URIBuilder;
import org.cloudfoundry.identity.uaa.constants.OriginKeys;
import org.cloudfoundry.identity.uaa.provider.IdentityProvider;
import org.cloudfoundry.identity.uaa.provider.IdentityProviderProvisioning;
import org.cloudfoundry.identity.uaa.provider.SamlIdentityProviderDefinition;
import org.cloudfoundry.identity.uaa.zone.IdentityZone;
Expand All @@ -19,7 +18,6 @@
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import java.util.LinkedList;
import java.util.List;

import static org.springframework.util.StringUtils.hasText;
Expand All @@ -44,27 +42,14 @@ public List<SamlIdentityProviderDefinition> getIdentityProviderDefinitions() {
}

public List<SamlIdentityProviderDefinition> getIdentityProviderDefinitionsForZone(IdentityZone zone) {
List<SamlIdentityProviderDefinition> result = new LinkedList<>();
for (IdentityProvider provider : providerProvisioning.retrieveActive(zone.getId())) {
if (OriginKeys.SAML.equals(provider.getType())) {
result.add((SamlIdentityProviderDefinition) provider.getConfig());
}
}
return result;
return providerProvisioning.retrieveActiveByTypes(zone.getId(), OriginKeys.SAML).stream()
.map(samlIdp -> (SamlIdentityProviderDefinition) samlIdp.getConfig())
.toList();
}

public List<SamlIdentityProviderDefinition> getIdentityProviderDefinitions(List<String> allowedIdps, IdentityZone zone) {
List<SamlIdentityProviderDefinition> idpsInTheZone = getIdentityProviderDefinitionsForZone(zone);
if (allowedIdps != null) {
List<SamlIdentityProviderDefinition> result = new LinkedList<>();
for (SamlIdentityProviderDefinition def : idpsInTheZone) {
if (allowedIdps.contains(def.getIdpEntityAlias())) {
result.add(def);
}
}
return result;
}
return idpsInTheZone;
return getIdentityProviderDefinitionsForZone(zone).stream()
.filter(def -> allowedIdps == null || allowedIdps.contains(def.getIdpEntityAlias())).toList();
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CREATE INDEX IF NOT EXISTS active_and_type_in_zone ON identity_provider (identity_zone_id, active, type);
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CREATE INDEX active_and_type_in_zone ON identity_provider (identity_zone_id, active, type) LOCK = SHARED;
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CREATE INDEX CONCURRENTLY IF NOT EXISTS active_and_type_in_zone on identity_provider (identity_zone_id, active, type);
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyList;
import static org.mockito.ArgumentMatchers.anyString;
Expand Down Expand Up @@ -831,7 +830,7 @@ void allowedIdpsforClientOIDCProvider() throws Exception {
clientAllowedIdps.add(createOIDCIdentityProvider("my-OIDC-idp2"));
clientAllowedIdps.add(createOIDCIdentityProvider("my-OIDC-idp1"));

when(mockIdentityProviderProvisioning.retrieveAll(eq(true), anyString())).thenReturn(clientAllowedIdps);
when(mockIdentityProviderProvisioning.retrieveActiveByTypes(anyString(), eq(OriginKeys.OIDC10), eq(OriginKeys.OAUTH20))).thenReturn(clientAllowedIdps);

LoginInfoEndpoint endpoint = getEndpoint(IdentityZoneHolder.get(), clientDetailsService);

Expand All @@ -856,7 +855,7 @@ void oauth_provider_links_shown() throws Exception {
IdentityProvider<AbstractExternalOAuthIdentityProviderDefinition> identityProvider = MultitenancyFixture.identityProvider("oauth-idp-alias", "uaa");
identityProvider.setConfig(definition);

when(mockIdentityProviderProvisioning.retrieveAll(anyBoolean(), anyString())).thenReturn(singletonList(identityProvider));
when(mockIdentityProviderProvisioning.retrieveActiveByTypes(anyString(), any())).thenReturn(singletonList(identityProvider));
endpoint.loginForHtml(extendedModelMap, null, new MockHttpServletRequest(), singletonList(MediaType.TEXT_HTML));

assertThat(extendedModelMap.get("showLoginLinks"), equalTo(true));
Expand All @@ -873,7 +872,8 @@ void passcode_prompt_present_whenThereIsAtleastOneActiveOauthProvider() throws E
IdentityProvider<AbstractExternalOAuthIdentityProviderDefinition> identityProvider = MultitenancyFixture.identityProvider("oauth-idp-alias", "uaa");
identityProvider.setConfig(definition);

when(mockIdentityProviderProvisioning.retrieveAll(anyBoolean(), anyString())).thenReturn(singletonList(identityProvider));
when(mockIdentityProviderProvisioning.retrieveActiveByTypes(anyString(), eq(OriginKeys.OIDC10), eq(OriginKeys.OAUTH20)))
.thenReturn(singletonList(identityProvider));
endpoint.infoForLoginJson(extendedModelMap, null, new MockHttpServletRequest("GET", "http://someurl"));

Map mapPrompts = (Map) extendedModelMap.get("prompts");
Expand All @@ -892,7 +892,8 @@ void passcode_prompt_present_whenThereIsAtleastOneActiveOauthProvider_stillWorks
IdentityProvider<AbstractExternalOAuthIdentityProviderDefinition> identityProvider = MultitenancyFixture.identityProvider("oauth-idp-alias", "uaa");
identityProvider.setConfig(definition);

when(mockIdentityProviderProvisioning.retrieveAll(anyBoolean(), anyString())).thenReturn(singletonList(identityProvider));
when(mockIdentityProviderProvisioning.retrieveActiveByTypes(anyString(), eq(OriginKeys.OIDC10), eq(OriginKeys.OAUTH20)))
.thenReturn(singletonList(identityProvider));
endpoint.infoForLoginJson(extendedModelMap, null, new MockHttpServletRequest("GET", "http://someurl"));

Map mapPrompts = (Map) extendedModelMap.get("prompts");
Expand All @@ -911,7 +912,7 @@ void passcode_prompt_present_whenThereIsAtleastOneActiveOauthProvider_stillWorks
IdentityProvider<AbstractExternalOAuthIdentityProviderDefinition> identityProvider = MultitenancyFixture.identityProvider("oauth-idp-alias", "uaa");
identityProvider.setConfig(definition);

when(mockIdentityProviderProvisioning.retrieveAll(anyBoolean(), anyString())).thenReturn(singletonList(identityProvider));
when(mockIdentityProviderProvisioning.retrieveActiveByTypes(anyString(), any())).thenReturn(singletonList(identityProvider));
endpoint.infoForLoginJson(extendedModelMap, null, new MockHttpServletRequest("GET", "http://someurl"));

Map mapPrompts = (Map) extendedModelMap.get("prompts");
Expand All @@ -935,7 +936,8 @@ void we_return_both_oauth_and_oidc_providers() throws Exception {
IdentityProvider<AbstractExternalOAuthIdentityProviderDefinition> oidcProvider = MultitenancyFixture.identityProvider("oidc-idp-alias", "uaa");
oidcProvider.setConfig(oidcDefinition);

when(mockIdentityProviderProvisioning.retrieveAll(anyBoolean(), anyString())).thenReturn(Arrays.asList(oauthProvider, oidcProvider));
when(mockIdentityProviderProvisioning.retrieveActiveByTypes(anyString(), eq(OriginKeys.OIDC10), eq(OriginKeys.OAUTH20)))
.thenReturn(Arrays.asList(oauthProvider, oidcProvider));
assertEquals(2, endpoint.getOauthIdentityProviderDefinitions(null).size());
}

Expand Down Expand Up @@ -981,7 +983,8 @@ void loginHintEmailDomain() throws Exception {
when(mockOidcConfig.getResponseType()).thenReturn("token");
when(mockOidcConfig.getEmailDomain()).thenReturn(singletonList("example.com"));
when(mockProvider.getConfig()).thenReturn(mockOidcConfig);
when(mockIdentityProviderProvisioning.retrieveAll(anyBoolean(), any())).thenReturn(singletonList(mockProvider));
when(mockIdentityProviderProvisioning.retrieveActiveByTypes(anyString(), eq(OriginKeys.OIDC10), eq(OriginKeys.OAUTH20)))
.thenReturn(singletonList(mockProvider));

LoginInfoEndpoint endpoint = getEndpoint(IdentityZoneHolder.get(), clientDetailsService);

Expand Down Expand Up @@ -1237,7 +1240,7 @@ public void testInvalidLoginHintLoginPageReturnsList() throws Exception {
List<IdentityProvider> clientAllowedIdps = new LinkedList<>();
clientAllowedIdps.add(createOIDCIdentityProvider("my-OIDC-idp1"));
clientAllowedIdps.add(createOIDCIdentityProvider("my-OIDC-idp2"));
when(mockIdentityProviderProvisioning.retrieveAll(eq(true), anyString())).thenReturn(clientAllowedIdps);
when(mockIdentityProviderProvisioning.retrieveActiveByTypes(anyString(), eq(OriginKeys.OIDC10), eq(OriginKeys.OAUTH20))).thenReturn(clientAllowedIdps);
when(mockIdentityProviderProvisioning.retrieveByOrigin(eq("invalidorigin"), anyString())).thenThrow(new EmptyResultDataAccessException(1));

SavedRequest savedRequest = SessionUtils.getSavedRequestSession(mockHttpServletRequest.getSession());
Expand Down Expand Up @@ -1848,7 +1851,7 @@ private static void mockOidcProvider(IdentityProviderProvisioning mockIdentityPr
when(mockOidcConfig.getResponseType()).thenReturn("token");
when(mockProvider.getConfig()).thenReturn(mockOidcConfig);
when(mockOidcConfig.isShowLinkText()).thenReturn(true);
when(mockIdentityProviderProvisioning.retrieveAll(anyBoolean(), any())).thenReturn(singletonList(mockProvider));
when(mockIdentityProviderProvisioning.retrieveActiveByTypes(anyString(), any())).thenReturn(singletonList(mockProvider));
}

private static void mockLoginHintProvider(ExternalOAuthProviderConfigurator mockIdentityProviderProvisioning)
Expand Down
Loading
Loading