From 125bf5e8e01eb21fa0f3b70bbf2b1d19140aaa4b Mon Sep 17 00:00:00 2001 From: lilgreenbird Date: Wed, 8 Feb 2023 11:33:10 -0800 Subject: [PATCH] fix fedauth tests --- .../microsoft/sqlserver/jdbc/IOBuffer.java | 2 +- .../jdbc/ISQLServerEnclaveProvider.java | 1 - .../sqlserver/jdbc/SQLServerConnection.java | 4 +- .../jdbc/SQLServerDatabaseMetaData.java | 24 +++---- .../sqlserver/jdbc/SQLServerMSAL4JUtils.java | 18 ++--- .../sqlserver/jdbc/TestResource.java | 3 +- .../microsoft/sqlserver/jdbc/TestUtils.java | 40 ++++++++--- .../fedauth/ConnectionSuspensionTest.java | 11 ++- .../jdbc/fedauth/ErrorMessageTest.java | 33 ++++----- .../sqlserver/jdbc/fedauth/FedauthCommon.java | 71 +++++++++++++------ .../sqlserver/jdbc/fedauth/FedauthTest.java | 26 +++---- .../jdbc/fedauth/PooledConnectionTest.java | 30 ++++++-- .../jdbc/resiliency/BasicConnectionTest.java | 52 ++++++++++---- .../sqlserver/testframework/AbstractTest.java | 6 ++ 14 files changed, 205 insertions(+), 116 deletions(-) diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/IOBuffer.java b/src/main/java/com/microsoft/sqlserver/jdbc/IOBuffer.java index ef180a60b..19d691268 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/IOBuffer.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/IOBuffer.java @@ -735,7 +735,7 @@ final InetSocketAddress open(String host, int port, int timeoutMillis, boolean u /** * Set TCP keep-alive options for idle connection resiliency */ - private void setSocketOptions(Socket tcpSocket, TDSChannel channel) throws IOException { + private void setSocketOptions(Socket tcpSocket, TDSChannel channel) { try { if (SQLServerDriver.socketSetOptionMethod != null && SQLServerDriver.socketKeepIdleOption != null && SQLServerDriver.socketKeepIntervalOption != null) { diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/ISQLServerEnclaveProvider.java b/src/main/java/com/microsoft/sqlserver/jdbc/ISQLServerEnclaveProvider.java index 55a551ea5..899decee6 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/ISQLServerEnclaveProvider.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/ISQLServerEnclaveProvider.java @@ -35,7 +35,6 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; -import java.util.Hashtable; import java.util.Map; import java.util.Map.Entry; import java.util.concurrent.ConcurrentHashMap; diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java index 8abfb321c..e826ea0f3 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java @@ -1605,7 +1605,7 @@ SQLServerPooledConnection getPooledConnectionParent() { return pooledConnectionParent; } - SQLServerConnection(String parentInfo) throws SQLServerException { + SQLServerConnection(String parentInfo) { int connectionID = nextConnectionID(); // sequential connection id traceID = "ConnectionID:" + connectionID; loggingClassName += ":" + connectionID; @@ -6963,7 +6963,7 @@ public java.sql.Struct createStruct(String typeName, Object[] attributes) throws return null; } - String getTrustedServerNameAE() throws SQLServerException { + String getTrustedServerNameAE() { return trustedServerNameAE.toUpperCase(); } diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDatabaseMetaData.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDatabaseMetaData.java index 9b9a56bf2..17ad5a0bd 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDatabaseMetaData.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDatabaseMetaData.java @@ -809,7 +809,7 @@ private String generateAzureDWSelect(ResultSet rs, Map columns) return sb.toString(); } - private String generateAzureDWEmptyRS(Map columns) throws SQLException { + private String generateAzureDWEmptyRS(Map columns) { StringBuilder sb = new StringBuilder("SELECT TOP 0 "); for (Entry p : columns.entrySet()) { sb.append("NULL AS ").append(p.getValue()).append(","); @@ -950,7 +950,7 @@ public java.sql.ResultSet getBestRowIdentifier(String catalog, String schema, St @Override public java.sql.ResultSet getCrossReference(String cat1, String schem1, String tab1, String cat2, String schem2, - String tab2) throws SQLException, SQLTimeoutException { + String tab2) throws SQLException { if (loggerExternal.isLoggable(Level.FINER) && Util.isActivityTraceOn()) { loggerExternal.finer(toString() + ACTIVITY_ID + ActivityCorrelator.getNext().toString()); } @@ -1014,8 +1014,7 @@ public String getDriverVersion() throws SQLServerException { } @Override - public java.sql.ResultSet getExportedKeys(String cat, String schema, - String table) throws SQLException, SQLTimeoutException { + public java.sql.ResultSet getExportedKeys(String cat, String schema, String table) throws SQLException { return getCrossReference(cat, schema, table, null, null, null); } @@ -1032,12 +1031,11 @@ public String getIdentifierQuoteString() throws SQLServerException { } @Override - public java.sql.ResultSet getImportedKeys(String cat, String schema, - String table) throws SQLException, SQLTimeoutException { + public java.sql.ResultSet getImportedKeys(String cat, String schema, String table) throws SQLException { return getCrossReference(null, null, null, cat, schema, table); } - private ResultSet executeSPFkeys(String[] procParams) throws SQLException, SQLTimeoutException { + private ResultSet executeSPFkeys(String[] procParams) throws SQLException { if (!this.connection.isAzureDW()) { String tempTableName = "@jdbc_temp_fkeys_result"; String sql = "DECLARE " + tempTableName + " table (PKTABLE_QUALIFIER sysname, " + "PKTABLE_OWNER sysname, " @@ -1219,7 +1217,7 @@ public int getMaxColumnsInTable() throws SQLServerException { } @Override - public int getMaxConnections() throws SQLException, SQLTimeoutException { + public int getMaxConnections() throws SQLException { checkClosed(); try (SQLServerResultSet rs = getResultSetFromInternalQueries(null, "select maximum from sys.configurations where name = 'user connections'")) { @@ -1439,7 +1437,7 @@ public ResultSet getPseudoColumns(String catalog, String schemaPattern, String t } @Override - public java.sql.ResultSet getSchemas() throws SQLException, SQLTimeoutException { + public java.sql.ResultSet getSchemas() throws SQLException { if (loggerExternal.isLoggable(Level.FINER) && Util.isActivityTraceOn()) { loggerExternal.finer(toString() + ACTIVITY_ID + ActivityCorrelator.getNext().toString()); } @@ -1448,8 +1446,7 @@ public java.sql.ResultSet getSchemas() throws SQLException, SQLTimeoutException } - private java.sql.ResultSet getSchemasInternal(String catalog, - String schemaPattern) throws SQLException, SQLTimeoutException { + private java.sql.ResultSet getSchemasInternal(String catalog, String schemaPattern) throws SQLException { String s; // The schemas that return null for catalog name, these are prebuilt @@ -1606,14 +1603,13 @@ public java.sql.ResultSet getTablePrivileges(String catalog, String schema, } @Override - public java.sql.ResultSet getTableTypes() throws SQLException, SQLTimeoutException { + public java.sql.ResultSet getTableTypes() throws SQLException { if (loggerExternal.isLoggable(Level.FINER) && Util.isActivityTraceOn()) { loggerExternal.finer(toString() + ACTIVITY_ID + ActivityCorrelator.getNext().toString()); } checkClosed(); String s = "SELECT 'VIEW' 'TABLE_TYPE' UNION SELECT 'TABLE' UNION SELECT 'SYSTEM TABLE'"; - SQLServerResultSet rs = getResultSetFromInternalQueries(null, s); - return rs; + return getResultSetFromInternalQueries(null, s); } @Override diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java index 3568ffcd0..a099031e7 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java @@ -28,6 +28,7 @@ import com.microsoft.aad.msal4j.IntegratedWindowsAuthenticationParameters; import com.microsoft.aad.msal4j.InteractiveRequestParameters; import com.microsoft.aad.msal4j.MsalInteractionRequiredException; +import com.microsoft.aad.msal4j.MsalThrottlingException; import com.microsoft.aad.msal4j.PublicClientApplication; import com.microsoft.aad.msal4j.SilentParameters; import com.microsoft.aad.msal4j.SystemBrowserOptions; @@ -74,7 +75,7 @@ static SqlAuthenticationToken getSqlFedAuthToken(SqlFedAuthInfo fedAuthInfo, Str Thread.currentThread().interrupt(); throw new SQLServerException(e.getMessage(), e); - } catch (ExecutionException e) { + } catch (MsalThrottlingException | ExecutionException e) { throw getCorrectedException(e, user, authenticationString); } finally { executorService.shutdown(); @@ -108,7 +109,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipal(SqlFedAuthInfo fedAuth Thread.currentThread().interrupt(); throw new SQLServerException(e.getMessage(), e); - } catch (ExecutionException e) { + } catch (MsalThrottlingException | ExecutionException e) { throw getCorrectedException(e, aadPrincipalID, authenticationString); } finally { executorService.shutdown(); @@ -150,7 +151,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenIntegrated(SqlFedAuthInfo fedAut Thread.currentThread().interrupt(); throw new SQLServerException(e.getMessage(), e); - } catch (ExecutionException e) { + } catch (MsalThrottlingException | ExecutionException e) { throw getCorrectedException(e, "", authenticationString); } finally { executorService.shutdown(); @@ -177,11 +178,13 @@ static SqlAuthenticationToken getSqlFedAuthTokenInteractive(SqlFedAuthInfo fedAu StringBuilder acc = new StringBuilder(); if (accountsInCache != null) { for (IAccount account : accountsInCache) { - if (acc.length() != 0) acc.append(", "); + if (acc.length() != 0) + acc.append(", "); acc.append(account.username()); } } - logger.fine(logger.toString() + "Accounts in cache = " + acc + ", size = " + (accountsInCache == null ? null : accountsInCache.size()) + ", user = " + user); + logger.fine(logger.toString() + "Accounts in cache = " + acc + ", size = " + + (accountsInCache == null ? null : accountsInCache.size()) + ", user = " + user); } if (null != accountsInCache && !accountsInCache.isEmpty() && null != user && !user.isEmpty()) { IAccount account = getAccountByUsername(accountsInCache, user); @@ -228,7 +231,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenInteractive(SqlFedAuthInfo fedAu Thread.currentThread().interrupt(); throw new SQLServerException(e.getMessage(), e); - } catch (ExecutionException e) { + } catch (MsalThrottlingException | ExecutionException e) { throw getCorrectedException(e, user, authenticationString); } finally { executorService.shutdown(); @@ -247,8 +250,7 @@ private static IAccount getAccountByUsername(Set accounts, String user return null; } - private static SQLServerException getCorrectedException(ExecutionException e, String user, - String authenticationString) { + private static SQLServerException getCorrectedException(Exception e, String user, String authenticationString) { Object[] msgArgs = {user, authenticationString}; if (null == e.getCause() || null == e.getCause().getMessage()) { diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/TestResource.java b/src/test/java/com/microsoft/sqlserver/jdbc/TestResource.java index d24076bb4..d6702c359 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/TestResource.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/TestResource.java @@ -201,5 +201,6 @@ protected Object[][] getContents() { {"R_objectNullOrEmpty", "The {0} is null or empty."}, {"R_cekDecryptionFailed", "Failed to decrypt a column encryption key using key store provider: {0}."}, {"R_connectTimedOut", "connect timed out"}, - {"R_sessionKilled", "Cannot continue the execution because the session is in the kill state"}}; + {"R_sessionKilled", "Cannot continue the execution because the session is in the kill state"}, + {"R_failedFedauth", "Failed to acquire fedauth token: "}}; } diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/TestUtils.java b/src/test/java/com/microsoft/sqlserver/jdbc/TestUtils.java index ef061551e..88ad138f8 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/TestUtils.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/TestUtils.java @@ -39,6 +39,9 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +import org.junit.Assert; import com.microsoft.aad.msal4j.ClientCredentialFactory; import com.microsoft.aad.msal4j.ClientCredentialParameters; @@ -91,6 +94,8 @@ public final class TestUtils { static final int ENGINE_EDITION_FOR_SQL_AZURE_DW = 6; static final int ENGINE_EDITION_FOR_SQL_AZURE_MI = 8; + public static final int TEST_TOKEN_EXPIRY_SECONDS = 120; // token expiry time in secs + static String applicationKey; static String applicationClientID; @@ -108,7 +113,8 @@ public final class TestUtils { public static boolean expireTokenToggle = false; public static final SQLServerAccessTokenCallback accessTokenCallback = new SQLServerAccessTokenCallback() { - @Override public SqlAuthenticationToken getAccessToken(String spn, String stsurl) { + @Override + public SqlAuthenticationToken getAccessToken(String spn, String stsurl) { String scope = spn + "/.default"; Set scopes = new HashSet<>(); scopes.add(scope); @@ -116,10 +122,11 @@ public final class TestUtils { try { ExecutorService executorService = Executors.newSingleThreadExecutor(); IClientCredential credential = ClientCredentialFactory.createFromSecret(applicationKey); - ConfidentialClientApplication clientApplication = ConfidentialClientApplication.builder( - applicationClientID, credential).executorService(executorService).authority(stsurl).build(); - CompletableFuture future = clientApplication.acquireToken( - ClientCredentialParameters.builder(scopes).build()); + ConfidentialClientApplication clientApplication = ConfidentialClientApplication + .builder(applicationClientID, credential).executorService(executorService).authority(stsurl) + .build(); + CompletableFuture future = clientApplication + .acquireToken(ClientCredentialParameters.builder(scopes).build()); IAuthenticationResult authenticationResult = future.get(); String accessToken = authenticationResult.accessToken(); @@ -139,6 +146,21 @@ public final class TestUtils { } }; + public static void setAccessTokenExpiry(Object con, String accessToken) { + Field fedAuthTokenField; + try { + fedAuthTokenField = SQLServerConnection.class.getDeclaredField("fedAuthToken"); + fedAuthTokenField.setAccessible(true); + + Date newExpiry = new Date( + System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(TEST_TOKEN_EXPIRY_SECONDS)); + SqlAuthenticationToken newFedAuthToken = new SqlAuthenticationToken(accessToken, newExpiry); + fedAuthTokenField.set(con, newFedAuthToken); + } catch (NoSuchFieldException | SecurityException | IllegalArgumentException | IllegalAccessException e) { + Assert.fail("Failed to set token expiry: " + e.getMessage()); + } + } + private TestUtils() {} /** @@ -440,7 +462,8 @@ public static void dropTypeIfExists(String typeName, java.sql.Statement stmt) th * @throws SQLException */ public static void dropUserDefinedTypeIfExists(String typeName, Statement stmt) throws SQLException { - stmt.executeUpdate("IF EXISTS (select * from sys.types where name = '" + escapeSingleQuotes(typeName) + "') DROP TYPE " + typeName); + stmt.executeUpdate("IF EXISTS (select * from sys.types where name = '" + escapeSingleQuotes(typeName) + + "') DROP TYPE " + typeName); } /** @@ -1029,12 +1052,13 @@ private static java.security.cert.Certificate getCertificate(String certname) th } } - public static String getConnectionID(SQLServerPooledConnection pc) throws ClassNotFoundException, NoSuchFieldException, IllegalAccessException { + public static String getConnectionID( + SQLServerPooledConnection pc) throws ClassNotFoundException, NoSuchFieldException, IllegalAccessException { Class pooledConnection = Class.forName("com.microsoft.sqlserver.jdbc.SQLServerPooledConnection"); Class connection = Class.forName("com.microsoft.sqlserver.jdbc.SQLServerConnection"); Field physicalConnection = pooledConnection.getDeclaredField("physicalConnection"); - Field traceID = connection.getDeclaredField("traceID"); + Field traceID = connection.getDeclaredField("traceID"); physicalConnection.setAccessible(true); traceID.setAccessible(true); diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/fedauth/ConnectionSuspensionTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/fedauth/ConnectionSuspensionTest.java index 9e42e19b8..94e70a142 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/fedauth/ConnectionSuspensionTest.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/fedauth/ConnectionSuspensionTest.java @@ -83,10 +83,13 @@ private void testAccessTokenExpiredThenCreateNewStatement(SqlAuthentication auth TestUtils.dropTableIfExists(charTable, stmt); } + TestUtils.setAccessTokenExpiry(connection, accessToken); + secondsBeforeExpiration = TestUtils.TEST_TOKEN_EXPIRY_SECONDS; while (secondsPassed < secondsBeforeExpiration) { - Thread.sleep(TimeUnit.MINUTES.toMillis(5)); // Sleep for 2 minutes + Thread.sleep(TimeUnit.SECONDS.toMillis(90)); // Sleep for 90s secondsPassed = (System.currentTimeMillis() - start) / 1000; + try (Statement stmt1 = connection.createStatement()) { testUserName(connection, azureUserName, authentication); @@ -150,10 +153,12 @@ private void testAccessTokenExpiredThenExecuteUsingSameStatement( TestUtils.dropTableIfExists(charTable, stmt); } + TestUtils.setAccessTokenExpiry(connection, accessToken); + secondsBeforeExpiration = TestUtils.TEST_TOKEN_EXPIRY_SECONDS; while (secondsPassed < secondsBeforeExpiration) { - Thread.sleep(TimeUnit.MINUTES.toMillis(5)); // Sleep for 5 minutes - + Thread.sleep(TimeUnit.SECONDS.toMillis(90)); // Sleep for 90s secondsPassed = (System.currentTimeMillis() - start) / 1000; + testUserName(connection, azureUserName, authentication); } diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/fedauth/ErrorMessageTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/fedauth/ErrorMessageTest.java index 7ca0ccbe7..6f4e9e634 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/fedauth/ErrorMessageTest.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/fedauth/ErrorMessageTest.java @@ -226,11 +226,8 @@ public void testADPasswordUnregisteredUserWithConnectionStringUserName() throws fail(EXPECTED_EXCEPTION_NOT_THROWN); } catch (SQLServerException e) { assertTrue(INVALID_EXCEPTION_MSG + ": " + e.getMessage(), - (e.getMessage() - .contains(ERR_MSG_FAILED_AUTHENTICATE + " the user " + badUserName - + " in Active Directory (Authentication=ActiveDirectoryPassword).") - && e.getCause().getCause().getMessage().contains(ERR_MSG_SIGNIN_ADD)) - || e.getMessage().contains(ERR_MSG_REQUEST_THROTTLED)); + e.getMessage().contains(ERR_MSG_FAILED_AUTHENTICATE + " the user " + badUserName + + " in Active Directory (Authentication=ActiveDirectoryPassword).")); } } @@ -248,11 +245,8 @@ public void testADPasswordUnregisteredUserWithDatasource() throws SQLException { fail(EXPECTED_EXCEPTION_NOT_THROWN); } catch (SQLServerException e) { assertTrue(INVALID_EXCEPTION_MSG + ": " + e.getMessage(), - (e.getMessage() - .contains(ERR_MSG_FAILED_AUTHENTICATE + " the user " + badUserName - + " in Active Directory (Authentication=ActiveDirectoryPassword).") - && e.getCause().getCause().getMessage().contains(ERR_MSG_SIGNIN_ADD)) - || e.getMessage().contains(ERR_MSG_REQUEST_THROTTLED)); + e.getMessage().contains(ERR_MSG_FAILED_AUTHENTICATE + " the user " + badUserName + + " in Active Directory (Authentication=ActiveDirectoryPassword).")); } } @@ -263,11 +257,8 @@ public void testADPasswordUnregisteredUserWithConnectionStringUser() throws SQLE fail(EXPECTED_EXCEPTION_NOT_THROWN); } catch (SQLServerException e) { assertTrue(INVALID_EXCEPTION_MSG + ": " + e.getMessage(), - (e.getMessage() - .contains(ERR_MSG_FAILED_AUTHENTICATE + " the user " + badUserName - + " in Active Directory (Authentication=ActiveDirectoryPassword).") - && e.getCause().getCause().getMessage().contains(ERR_MSG_SIGNIN_ADD)) - || e.getMessage().contains(ERR_MSG_REQUEST_THROTTLED)); + e.getMessage().contains(ERR_MSG_FAILED_AUTHENTICATE + " the user " + badUserName + + " in Active Directory (Authentication=ActiveDirectoryPassword).")); } } @@ -400,11 +391,11 @@ public void testADPasswordWrongPasswordWithConnectionStringUserName() throws SQL fail(EXPECTED_EXCEPTION_NOT_THROWN); } - assertTrue(INVALID_EXCEPTION_MSG + ": " + e.getMessage(), (e.getMessage() + assertTrue(INVALID_EXCEPTION_MSG + ": " + e.getMessage(), e.getMessage() .contains(ERR_MSG_FAILED_AUTHENTICATE + " the user " + azureUserName + " in Active Directory (Authentication=ActiveDirectoryPassword).") && e.getCause().getCause().getMessage().toLowerCase().contains("invalid username or password") - || e.getCause().getCause().getMessage().contains(ERR_MSG_SIGNIN_TOO_MANY)) + || e.getCause().getCause().getMessage().contains(ERR_MSG_SIGNIN_TOO_MANY) || e.getMessage().contains(ERR_MSG_REQUEST_THROTTLED)); } } @@ -426,11 +417,11 @@ public void testADPasswordWrongPasswordWithDatasource() throws SQLException { fail(EXPECTED_EXCEPTION_NOT_THROWN); } - assertTrue(INVALID_EXCEPTION_MSG + ": " + e.getMessage(), (e.getMessage() + assertTrue(INVALID_EXCEPTION_MSG + ": " + e.getMessage(), e.getMessage() .contains(ERR_MSG_FAILED_AUTHENTICATE + " the user " + azureUserName + " in Active Directory (Authentication=ActiveDirectoryPassword).") && e.getCause().getCause().getMessage().toLowerCase().contains("invalid username or password") - || e.getCause().getCause().getMessage().contains(ERR_MSG_SIGNIN_TOO_MANY)) + || e.getCause().getCause().getMessage().contains(ERR_MSG_SIGNIN_TOO_MANY) || e.getMessage().contains(ERR_MSG_REQUEST_THROTTLED)); } } @@ -446,11 +437,11 @@ public void testADPasswordWrongPasswordWithConnectionStringUser() throws SQLExce fail(EXPECTED_EXCEPTION_NOT_THROWN); } - assertTrue(INVALID_EXCEPTION_MSG + ": " + e.getMessage(), (e.getMessage() + assertTrue(INVALID_EXCEPTION_MSG + ": " + e.getMessage(), e.getMessage() .contains(ERR_MSG_FAILED_AUTHENTICATE + " the user " + azureUserName + " in Active Directory (Authentication=ActiveDirectoryPassword).") && e.getCause().getCause().getMessage().toLowerCase().contains("invalid username or password") - || e.getCause().getCause().getMessage().contains(ERR_MSG_SIGNIN_TOO_MANY)) + || e.getCause().getCause().getMessage().contains(ERR_MSG_SIGNIN_TOO_MANY) || e.getMessage().contains(ERR_MSG_REQUEST_THROTTLED)); } } diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/fedauth/FedauthCommon.java b/src/test/java/com/microsoft/sqlserver/jdbc/fedauth/FedauthCommon.java index aa26e7af3..41beff621 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/fedauth/FedauthCommon.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/fedauth/FedauthCommon.java @@ -8,8 +8,10 @@ import static org.junit.jupiter.api.Assertions.fail; import com.microsoft.aad.msal4j.IAuthenticationResult; +import com.microsoft.aad.msal4j.MsalThrottlingException; import com.microsoft.aad.msal4j.PublicClientApplication; import com.microsoft.aad.msal4j.UserNamePasswordParameters; + import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.ResultSet; @@ -24,7 +26,6 @@ import java.util.logging.LogManager; import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Tag; import com.microsoft.sqlserver.testframework.Constants; @@ -42,8 +43,8 @@ public class FedauthCommon extends AbstractTest { static String azureUserName = null; static String azurePassword = null; static String azureGroupUserName = null; - static String azureAADPrincipialId = null; - static String azureAADPrincipialSecret = null; + static String azureAADPrincipalId = null; + static String azureAADPrincipalSecret = null; static boolean enableADIntegrated = false; @@ -84,6 +85,7 @@ public class FedauthCommon extends AbstractTest { static final String ERR_MSG_SOCKET_CLOSED = TestResource.getResource("R_socketClosed"); static final String ERR_TCPIP_CONNECTION = TestResource.getResource("R_tcpipConnectionToHost"); static final String ERR_MSG_REQUEST_THROTTLED = "Request was throttled"; + static final String ERR_FAILED_FEDAUTH = TestResource.getResource("R_failedFedauth"); enum SqlAuthentication { NotSpecified, @@ -116,11 +118,6 @@ static SqlAuthentication valueOfString(String value) throws SQLServerException { static String adPasswordConnectionStr; static String adIntegratedConnectionStr; - @BeforeEach - public void setupEachTest() { - getFedauthInfo(); - } - @BeforeAll public static void getConfigs() throws Exception { azureServer = getConfiguredProperty("azureServer"); @@ -128,8 +125,8 @@ public static void getConfigs() throws Exception { azureUserName = getConfiguredProperty("azureUserName"); azurePassword = getConfiguredProperty("azurePassword"); azureGroupUserName = getConfiguredProperty("azureGroupUserName"); - azureAADPrincipialId = getConfiguredProperty("AADSecurePrincipalId"); - azureAADPrincipialSecret = getConfiguredProperty("AADSecurePrincipalSecret"); + azureAADPrincipalId = getConfiguredProperty("AADSecurePrincipalId"); + azureAADPrincipalSecret = getConfiguredProperty("AADSecurePrincipalSecret"); String prop = getConfiguredProperty("enableADIntegrated"); enableADIntegrated = (null != prop && prop.equalsIgnoreCase("true")) ? true : false; @@ -151,6 +148,8 @@ public static void getConfigs() throws Exception { stsurl = getConfiguredProperty("stsurl"); fedauthClientId = getConfiguredProperty("fedauthClientId"); + getFedauthInfo(); + // reset logging to avoid severe logs LogManager.getLogManager().reset(); } @@ -160,22 +159,48 @@ public static void getConfigs() throws Exception { * */ static void getFedauthInfo() { - try { - - final PublicClientApplication clientApplication = PublicClientApplication.builder(fedauthClientId) - .executorService(Executors.newFixedThreadPool(1)).authority(stsurl).build(); - final CompletableFuture future = clientApplication - .acquireToken(UserNamePasswordParameters.builder(Collections.singleton(spn + "/.default"), - azureUserName, azurePassword.toCharArray()).build()); + int retry = THROTTLE_RETRY_COUNT; + long interval = THROTTLE_RETRY_INTERVAL; + while (retry > 0) { + try { + final PublicClientApplication clientApplication = PublicClientApplication.builder(fedauthClientId) + .executorService(Executors.newFixedThreadPool(1)).authority(stsurl).build(); + final CompletableFuture future = clientApplication + .acquireToken(UserNamePasswordParameters.builder(Collections.singleton(spn + "/.default"), + azureUserName, azurePassword.toCharArray()).build()); + + final IAuthenticationResult authenticationResult = future.get(); + + secondsBeforeExpiration = TimeUnit.MILLISECONDS + .toSeconds(authenticationResult.expiresOnDate().getTime() - new Date().getTime()); + accessToken = authenticationResult.accessToken(); + retry = 0; + } catch (MsalThrottlingException te) { + interval = ((MsalThrottlingException) te).retryInMs(); + if (!checkForRetry(te, retry--, interval)) { + fail(ERR_FAILED_FEDAUTH + "no more retries: " + te.getMessage()); + } + } catch (Exception e) { + if (!checkForRetry(e, retry--, interval)) { + fail(ERR_FAILED_FEDAUTH + "no more retries: " + e.getMessage()); + } + } + } + } - final IAuthenticationResult authenticationResult = future.get(); + static boolean checkForRetry(Exception e, int retry, long interval) { + if (retry <= 0) { + return false; + } + try { + System.out.println(e.getMessage() + "Get FedAuth token failed retry #" + retry + " in " + interval + " ms"); + e.printStackTrace(); - secondsBeforeExpiration = TimeUnit.MILLISECONDS - .toSeconds(authenticationResult.expiresOnDate().getTime() - new Date().getTime()); - accessToken = authenticationResult.accessToken(); - } catch (Exception e) { - fail(e.getMessage()); + Thread.sleep(interval); + } catch (InterruptedException ex) { + fail(ERR_FAILED_FEDAUTH + ex.getMessage()); } + return true; } void testUserName(Connection conn, String user, SqlAuthentication authentication) throws SQLException { diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/fedauth/FedauthTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/fedauth/FedauthTest.java index 8ea0a86a1..53c3d546e 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/fedauth/FedauthTest.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/fedauth/FedauthTest.java @@ -278,8 +278,8 @@ public void testAADPasswordApplicationName() throws Exception { @Test public void testAADServicePrincipalAuthDeprecated() { String url = "jdbc:sqlserver://" + azureServer + ";database=" + azureDatabase + ";authentication=" - + SqlAuthentication.ActiveDirectoryServicePrincipal + ";AADSecurePrincipalId=" + azureAADPrincipialId - + ";AADSecurePrincipalSecret=" + azureAADPrincipialSecret; + + SqlAuthentication.ActiveDirectoryServicePrincipal + ";AADSecurePrincipalId=" + azureAADPrincipalId + + ";AADSecurePrincipalSecret=" + azureAADPrincipalSecret; String urlEncrypted = url + ";encrypt=true;trustServerCertificate=true;"; SQLServerDataSource ds = new SQLServerDataSource(); updateDataSource(url, ds); @@ -300,8 +300,8 @@ public void testAADServicePrincipalAuthDeprecated() { @Test public void testAADServicePrincipalAuth() { String url = "jdbc:sqlserver://" + azureServer + ";database=" + azureDatabase + ";authentication=" - + SqlAuthentication.ActiveDirectoryServicePrincipal + ";Username=" + azureAADPrincipialId + ";Password=" - + azureAADPrincipialSecret; + + SqlAuthentication.ActiveDirectoryServicePrincipal + ";Username=" + azureAADPrincipalId + ";Password=" + + azureAADPrincipalSecret; String urlEncrypted = url + ";encrypt=true;trustServerCertificate=true;"; SQLServerDataSource ds = new SQLServerDataSource(); updateDataSource(url, ds); @@ -323,32 +323,32 @@ public void testAADServicePrincipalAuthWrong() { String baseUrl = "jdbc:sqlserver://" + azureServer + ";database=" + azureDatabase + ";authentication=" + SqlAuthentication.ActiveDirectoryServicePrincipal + ";"; // Wrong AADSecurePrincipalSecret provided. - String url = baseUrl + "AADSecurePrincipalId=" + azureAADPrincipialId + ";AADSecurePrincipalSecret=wrongSecret"; + String url = baseUrl + "AADSecurePrincipalId=" + azureAADPrincipalId + ";AADSecurePrincipalSecret=wrongSecret"; validateException(url, "R_MSALExecution"); // Wrong AADSecurePrincipalId provided. - url = baseUrl + "AADSecurePrincipalId=wrongId;AADSecurePrincipalSecret=" + azureAADPrincipialSecret; + url = baseUrl + "AADSecurePrincipalId=wrongId;AADSecurePrincipalSecret=" + azureAADPrincipalSecret; validateException(url, "R_MSALExecution"); // AADSecurePrincipalSecret/password not provided. - url = baseUrl + "AADSecurePrincipalId=" + azureAADPrincipialId; + url = baseUrl + "AADSecurePrincipalId=" + azureAADPrincipalId; validateException(url, "R_NoUserPasswordForActiveServicePrincipal"); - url = baseUrl + "Username=" + azureAADPrincipialId; + url = baseUrl + "Username=" + azureAADPrincipalId; validateException(url, "R_NoUserPasswordForActiveServicePrincipal"); // AADSecurePrincipalId/username not provided. - url = baseUrl + "AADSecurePrincipalSecret=" + azureAADPrincipialSecret; + url = baseUrl + "AADSecurePrincipalSecret=" + azureAADPrincipalSecret; validateException(url, "R_NoUserPasswordForActiveServicePrincipal"); - url = baseUrl + "password=" + azureAADPrincipialSecret; + url = baseUrl + "password=" + azureAADPrincipalSecret; validateException(url, "R_NoUserPasswordForActiveServicePrincipal"); // Both AADSecurePrincipalId/username and AADSecurePrincipalSecret/password not provided. validateException(baseUrl, "R_NoUserPasswordForActiveServicePrincipal"); // both username/password and AADSecurePrincipalId/AADSecurePrincipalSecret provided - url = baseUrl + "Username=" + azureAADPrincipialId + ";password=" + azureAADPrincipialSecret - + ";AADSecurePrincipalId=" + azureAADPrincipialId + ";AADSecurePrincipalSecret=" - + azureAADPrincipialSecret; + url = baseUrl + "Username=" + azureAADPrincipalId + ";password=" + azureAADPrincipalSecret + + ";AADSecurePrincipalId=" + azureAADPrincipalId + ";AADSecurePrincipalSecret=" + + azureAADPrincipalSecret; validateException(url, "R_BothUserPasswordandDeprecated"); } diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/fedauth/PooledConnectionTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/fedauth/PooledConnectionTest.java index 244fc059d..be893b662 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/fedauth/PooledConnectionTest.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/fedauth/PooledConnectionTest.java @@ -7,6 +7,7 @@ import static org.junit.Assert.assertTrue; import static org.junit.jupiter.api.Assertions.fail; +import java.lang.reflect.Field; import java.sql.Connection; import java.sql.DriverManager; import java.sql.SQLException; @@ -31,6 +32,7 @@ import com.microsoft.sqlserver.jdbc.RandomUtil; import com.microsoft.sqlserver.jdbc.SQLServerConnectionPoolDataSource; +import com.microsoft.sqlserver.jdbc.SQLServerPooledConnection; import com.microsoft.sqlserver.jdbc.TestUtils; import com.microsoft.sqlserver.testframework.AbstractSQLGenerator; import com.microsoft.sqlserver.testframework.Constants; @@ -51,23 +53,22 @@ public static void setupTests() throws Exception { @Test public void testPooledConnectionAccessTokenExpiredThenReconnectADPassword() throws SQLException { - // suspend 5 mins - testPooledConnectionAccessTokenExpiredThenReconnect((long) 5 * 60, SqlAuthentication.ActiveDirectoryPassword); + // suspend 60 secs + testPooledConnectionAccessTokenExpiredThenReconnect((long) 60, SqlAuthentication.ActiveDirectoryPassword); // get another token getFedauthInfo(); // suspend until access token expires - testPooledConnectionAccessTokenExpiredThenReconnect(secondsBeforeExpiration, - SqlAuthentication.ActiveDirectoryPassword); + testPooledConnectionAccessTokenExpiredThenReconnect((long) 60, SqlAuthentication.ActiveDirectoryPassword); } @Test public void testPooledConnectionAccessTokenExpiredThenReconnectADIntegrated() throws SQLException { org.junit.Assume.assumeTrue(enableADIntegrated); - // suspend 5 mins - testPooledConnectionAccessTokenExpiredThenReconnect((long) 5 * 60, SqlAuthentication.ActiveDirectoryIntegrated); + // suspend 60 secs + testPooledConnectionAccessTokenExpiredThenReconnect((long) 60, SqlAuthentication.ActiveDirectoryIntegrated); // get another token getFedauthInfo(); @@ -96,6 +97,11 @@ private void testPooledConnectionAccessTokenExpiredThenReconnect(long testingTim try { // create pooled connection PooledConnection pc = ds.getPooledConnection(); + SQLServerPooledConnection spc = (SQLServerPooledConnection) pc; + Field physicalConnectionField = SQLServerPooledConnection.class.getDeclaredField("physicalConnection"); + physicalConnectionField.setAccessible(true); + Object con = physicalConnectionField.get(spc); + TestUtils.setAccessTokenExpiry(con, accessToken); // get first connection from pool try (Connection connection1 = pc.getConnection(); Statement stmt = connection1.createStatement()) { @@ -111,7 +117,7 @@ private void testPooledConnectionAccessTokenExpiredThenReconnect(long testingTim } } Thread.sleep(TimeUnit.SECONDS.toMillis(testingTimeInSeconds)); - Thread.sleep(TimeUnit.SECONDS.toMillis(2)); // give 2 mins more to make sure the access token is expired. + Thread.sleep(TimeUnit.SECONDS.toMillis(90)); // give 90s more to make sure the access token is expired. // get second connection from pool try (Connection connection2 = pc.getConnection(); Statement stmt = connection2.createStatement()) { @@ -162,6 +168,11 @@ private void testPooledConnectionMultiThread(long testingTimeInSeconds, try { // create pooled connection final PooledConnection pc = ds.getPooledConnection(); + SQLServerPooledConnection spc = (SQLServerPooledConnection) pc; + Field physicalConnectionField = SQLServerPooledConnection.class.getDeclaredField("physicalConnection"); + physicalConnectionField.setAccessible(true); + Object con = physicalConnectionField.get(spc); + TestUtils.setAccessTokenExpiry(con, accessToken); // get first connection from pool try (Connection connection1 = pc.getConnection(); Statement stmt = connection1.createStatement()) { @@ -214,6 +225,11 @@ public void testPooledConnectionWithAccessToken() throws SQLException { // create pooled connection final PooledConnection pc = ds.getPooledConnection(); + SQLServerPooledConnection spc = (SQLServerPooledConnection) pc; + Field physicalConnectionField = SQLServerPooledConnection.class.getDeclaredField("physicalConnection"); + physicalConnectionField.setAccessible(true); + Object con = physicalConnectionField.get(spc); + TestUtils.setAccessTokenExpiry(con, accessToken); // get first connection from pool try (Connection connection1 = pc.getConnection()) { diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/resiliency/BasicConnectionTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/resiliency/BasicConnectionTest.java index 050eb6a58..47ff214a6 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/resiliency/BasicConnectionTest.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/resiliency/BasicConnectionTest.java @@ -9,6 +9,7 @@ import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import java.lang.reflect.Field; import java.sql.Connection; import java.sql.DriverManager; import java.sql.ResultSet; @@ -49,15 +50,33 @@ public void testBasicReconnectDefault() throws SQLException { @Test @Tag(Constants.fedAuth) - public void testBasicConnectionAAD() throws SQLException { - String azureServer = getConfiguredProperty("azureServer"); - String azureDatabase = getConfiguredProperty("azureDatabase"); - String azureUserName = getConfiguredProperty("azureUserName"); - String azurePassword = getConfiguredProperty("azurePassword"); - org.junit.Assume.assumeTrue(azureServer != null && !azureServer.isEmpty()); - - basicReconnect("jdbc:sqlserver://" + azureServer + ";database=" + azureDatabase + ";user=" + azureUserName - + ";password=" + azurePassword + ";loginTimeout=90;Authentication=ActiveDirectoryPassword"); + public void testBasicConnectionAAD() throws Exception { + // retry since this could fail due to server throttling + int retry = THROTTLE_RETRY_COUNT; + while (retry > 0) { + try { + String azureServer = getConfiguredProperty("azureServer"); + String azureDatabase = getConfiguredProperty("azureDatabase"); + String azureUserName = getConfiguredProperty("azureUserName"); + String azurePassword = getConfiguredProperty("azurePassword"); + org.junit.Assume.assumeTrue(azureServer != null && !azureServer.isEmpty()); + + basicReconnect("jdbc:sqlserver://" + azureServer + ";database=" + azureDatabase + ";user=" + + azureUserName + ";password=" + azurePassword + + ";loginTimeout=90;Authentication=ActiveDirectoryPassword"); + } catch (Exception e) { + if (e.getMessage().matches(TestUtils.formatErrorMsg("R_crClientAllRecoveryAttemptsFailed"))) { + System.out.println(e.getMessage() + "Recovery failed retry #" + retry + " in " + + THROTTLE_RETRY_INTERVAL + " ms"); + e.printStackTrace(); + + Thread.sleep(THROTTLE_RETRY_INTERVAL); + retry--; + } else { + fail(e.getMessage()); + } + } + } } @Test @@ -288,6 +307,12 @@ public void testDSPooledConnectionAccessTokenCallbackIdleConnectionResiliency() ds.setAccessTokenCallback(TestUtils.accessTokenCallback); SQLServerPooledConnection pc = (SQLServerPooledConnection) ds.getPooledConnection(); + SQLServerPooledConnection spc = (SQLServerPooledConnection) pc; + Field physicalConnectionField = SQLServerPooledConnection.class.getDeclaredField("physicalConnection"); + physicalConnectionField.setAccessible(true); + Object c = physicalConnectionField.get(spc); + String accessToken = ds.getAccessToken(); + TestUtils.setAccessTokenExpiry(c, accessToken); // Idle Connection Resiliency should reconnect after connection kill, second query should run successfully TestUtils.expireTokenToggle = false; @@ -319,8 +344,7 @@ public void testPreparedStatementCacheShouldBeCleared() throws SQLException { // add new statements to fill cache for (int i = 0; i < cacheSize; ++i) { - try (SQLServerPreparedStatement pstmt = (SQLServerPreparedStatement) con - .prepareStatement(query + i)) { + try (SQLServerPreparedStatement pstmt = (SQLServerPreparedStatement) con.prepareStatement(query + i)) { pstmt.execute(); pstmt.execute(); } @@ -344,12 +368,12 @@ public void testPreparedStatementCacheShouldBeCleared() throws SQLException { public void testUnprocessedResponseCountSuccessfulIdleConnectionRecovery() throws SQLException { try (SQLServerConnection con = (SQLServerConnection) ResiliencyUtils.getConnection(connectionString)) { int queriesToSend = 5; - String query = String.format("/*testUnprocessedResponseCountSuccessfulIdleConnectionRecovery_%s*/SELECT 1; -- ", + String query = String.format( + "/*testUnprocessedResponseCountSuccessfulIdleConnectionRecovery_%s*/SELECT 1; -- ", UUID.randomUUID()); for (int i = 0; i < queriesToSend; ++i) { - try (SQLServerPreparedStatement pstmt = (SQLServerPreparedStatement) con - .prepareStatement(query + i)) { + try (SQLServerPreparedStatement pstmt = (SQLServerPreparedStatement) con.prepareStatement(query + i)) { pstmt.executeQuery(); pstmt.executeQuery(); } diff --git a/src/test/java/com/microsoft/sqlserver/testframework/AbstractTest.java b/src/test/java/com/microsoft/sqlserver/testframework/AbstractTest.java index 4ecd403db..a504f9caa 100644 --- a/src/test/java/com/microsoft/sqlserver/testframework/AbstractTest.java +++ b/src/test/java/com/microsoft/sqlserver/testframework/AbstractTest.java @@ -116,6 +116,12 @@ public abstract class AbstractTest { protected static boolean isWindows = System.getProperty("os.name").startsWith("Windows"); + /** + * Retries due to server throttling + */ + protected static final int THROTTLE_RETRY_COUNT = 3; // max number of throttling retries + protected static final int THROTTLE_RETRY_INTERVAL = 60000; // default throttling retry interval in ms + public static Properties properties = null; /**