Skip to content

Commit

Permalink
Updated fedauth tests to run on diff test server (#2062)
Browse files Browse the repository at this point in the history
Signed-off-by: Jeff Wasty <[email protected]>
  • Loading branch information
Jeffery-Wasty committed Jan 29, 2025
1 parent 50510ec commit 35daf0a
Show file tree
Hide file tree
Showing 14 changed files with 326 additions and 112 deletions.
22 changes: 22 additions & 0 deletions src/main/java/com/microsoft/sqlserver/jdbc/IOBuffer.java
Original file line number Diff line number Diff line change
Expand Up @@ -796,6 +796,28 @@ final InetSocketAddress open(String host, int port, int timeoutMillis, boolean u
return (InetSocketAddress) channelSocket.getRemoteSocketAddress();
}

/**
* Set TCP keep-alive options for idle connection resiliency
*/
private void setSocketOptions(Socket tcpSocket, TDSChannel channel) {
try {
if (SQLServerDriver.socketSetOptionMethod != null && SQLServerDriver.socketKeepIdleOption != null
&& SQLServerDriver.socketKeepIntervalOption != null) {
if (logger.isLoggable(Level.FINER)) {
logger.finer(channel.toString() + ": Setting KeepAlive extended socket options.");
}

SQLServerDriver.socketSetOptionMethod.invoke(tcpSocket, SQLServerDriver.socketKeepIdleOption, 30); // 30 seconds
SQLServerDriver.socketSetOptionMethod.invoke(tcpSocket, SQLServerDriver.socketKeepIntervalOption, 1); // 1 second
}
} catch (IllegalAccessException | InvocationTargetException e) {
if (logger.isLoggable(Level.FINER)) {
logger.finer(channel.toString() + ": KeepAlive extended socket options not supported on this platform. "
+ e.getMessage());
}
}
}

/**
* Disables SSL on this TDS channel.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Hashtable;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.atomic.AtomicLong;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1481,7 +1481,7 @@ SQLServerPooledConnection getPooledConnectionParent() {
return pooledConnectionParent;
}

SQLServerConnection(String parentInfo) throws SQLServerException {
SQLServerConnection(String parentInfo) {
int connectionID = nextConnectionID(); // sequential connection id
traceID = "ConnectionID:" + connectionID;
loggingClassName += ":" + connectionID;
Expand Down Expand Up @@ -6772,7 +6772,7 @@ public java.sql.Struct createStruct(String typeName, Object[] attributes) throws
return null;
}

String getTrustedServerNameAE() throws SQLServerException {
String getTrustedServerNameAE() {
return trustedServerNameAE.toUpperCase();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -816,7 +816,7 @@ private String generateAzureDWSelect(ResultSet rs, Map<Integer, String> columns)
return sb.toString();
}

private String generateAzureDWEmptyRS(Map<Integer, String> columns) throws SQLException {
private String generateAzureDWEmptyRS(Map<Integer, String> columns) {
StringBuilder sb = new StringBuilder("SELECT TOP 0 ");
for (Entry<Integer, String> p : columns.entrySet()) {
sb.append("NULL AS ").append(p.getValue()).append(",");
Expand Down Expand Up @@ -957,7 +957,7 @@ public java.sql.ResultSet getBestRowIdentifier(String catalog, String schema, St

@Override
public java.sql.ResultSet getCrossReference(String cat1, String schem1, String tab1, String cat2, String schem2,
String tab2) throws SQLException, SQLTimeoutException {
String tab2) throws SQLException {
if (loggerExternal.isLoggable(Level.FINER) && Util.isActivityTraceOn()) {
loggerExternal.finer(toString() + " ActivityId: " + ActivityCorrelator.getNext().toString());
}
Expand Down Expand Up @@ -1021,8 +1021,7 @@ public String getDriverVersion() throws SQLServerException {
}

@Override
public java.sql.ResultSet getExportedKeys(String cat, String schema,
String table) throws SQLException, SQLTimeoutException {
public java.sql.ResultSet getExportedKeys(String cat, String schema, String table) throws SQLException {
return getCrossReference(cat, schema, table, null, null, null);
}

Expand All @@ -1039,12 +1038,11 @@ public String getIdentifierQuoteString() throws SQLServerException {
}

@Override
public java.sql.ResultSet getImportedKeys(String cat, String schema,
String table) throws SQLException, SQLTimeoutException {
public java.sql.ResultSet getImportedKeys(String cat, String schema, String table) throws SQLException {
return getCrossReference(null, null, null, cat, schema, table);
}

private ResultSet executeSPFkeys(String[] procParams) throws SQLException, SQLTimeoutException {
private ResultSet executeSPFkeys(String[] procParams) throws SQLException {
if (!this.connection.isAzureDW()) {
String tempTableName = "@jdbc_temp_fkeys_result";
String sql = "DECLARE " + tempTableName + " table (PKTABLE_QUALIFIER sysname, " + "PKTABLE_OWNER sysname, "
Expand Down Expand Up @@ -1218,7 +1216,7 @@ public int getMaxColumnsInTable() throws SQLServerException {
}

@Override
public int getMaxConnections() throws SQLException, SQLTimeoutException {
public int getMaxConnections() throws SQLException {
checkClosed();
try (SQLServerResultSet rs = getResultSetFromInternalQueries(null,
"select maximum from sys.configurations where name = 'user connections'")) {
Expand Down Expand Up @@ -1438,7 +1436,7 @@ public ResultSet getPseudoColumns(String catalog, String schemaPattern, String t
}

@Override
public java.sql.ResultSet getSchemas() throws SQLException, SQLTimeoutException {
public java.sql.ResultSet getSchemas() throws SQLException {
if (loggerExternal.isLoggable(Level.FINER) && Util.isActivityTraceOn()) {
loggerExternal.finer(toString() + " ActivityId: " + ActivityCorrelator.getNext().toString());
}
Expand All @@ -1447,8 +1445,7 @@ public java.sql.ResultSet getSchemas() throws SQLException, SQLTimeoutException

}

private java.sql.ResultSet getSchemasInternal(String catalog,
String schemaPattern) throws SQLException, SQLTimeoutException {
private java.sql.ResultSet getSchemasInternal(String catalog, String schemaPattern) throws SQLException {

String s;
// The schemas that return null for catalog name, these are prebuilt
Expand Down Expand Up @@ -1605,14 +1602,13 @@ public java.sql.ResultSet getTablePrivileges(String catalog, String schema,
}

@Override
public java.sql.ResultSet getTableTypes() throws SQLException, SQLTimeoutException {
public java.sql.ResultSet getTableTypes() throws SQLException {
if (loggerExternal.isLoggable(Level.FINER) && Util.isActivityTraceOn()) {
loggerExternal.finer(toString() + " ActivityId: " + ActivityCorrelator.getNext().toString());
}
checkClosed();
String s = "SELECT 'VIEW' 'TABLE_TYPE' UNION SELECT 'TABLE' UNION SELECT 'SYSTEM TABLE'";
SQLServerResultSet rs = getResultSetFromInternalQueries(null, s);
return rs;
return getResultSetFromInternalQueries(null, s);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import com.microsoft.aad.msal4j.IntegratedWindowsAuthenticationParameters;
import com.microsoft.aad.msal4j.InteractiveRequestParameters;
import com.microsoft.aad.msal4j.MsalInteractionRequiredException;
import com.microsoft.aad.msal4j.MsalThrottlingException;
import com.microsoft.aad.msal4j.PublicClientApplication;
import com.microsoft.aad.msal4j.SilentParameters;
import com.microsoft.aad.msal4j.SystemBrowserOptions;
Expand Down Expand Up @@ -70,7 +71,7 @@ static SqlFedAuthToken getSqlFedAuthToken(SqlFedAuthInfo fedAuthInfo, String use
Thread.currentThread().interrupt();

throw new SQLServerException(e.getMessage(), e);
} catch (ExecutionException e) {
} catch (MsalThrottlingException | ExecutionException e) {
throw getCorrectedException(e, user, authenticationString);
} finally {
executorService.shutdown();
Expand Down Expand Up @@ -105,7 +106,7 @@ static SqlFedAuthToken getSqlFedAuthTokenPrincipal(SqlFedAuthInfo fedAuthInfo, S
Thread.currentThread().interrupt();

throw new SQLServerException(e.getMessage(), e);
} catch (ExecutionException e) {
} catch (MsalThrottlingException | ExecutionException e) {
throw getCorrectedException(e, aadPrincipalID, authenticationString);
} finally {
executorService.shutdown();
Expand Down Expand Up @@ -148,7 +149,7 @@ static SqlFedAuthToken getSqlFedAuthTokenIntegrated(SqlFedAuthInfo fedAuthInfo,
Thread.currentThread().interrupt();

throw new SQLServerException(e.getMessage(), e);
} catch (ExecutionException e) {
} catch (MsalThrottlingException | ExecutionException e) {
throw getCorrectedException(e, "", authenticationString);
} finally {
executorService.shutdown();
Expand All @@ -171,6 +172,19 @@ static SqlFedAuthToken getSqlFedAuthTokenInteractive(SqlFedAuthInfo fedAuthInfo,
// try to acquire token silently if user account found in cache
try {
Set<IAccount> accountsInCache = pca.getAccounts().join();
if (logger.isLoggable(Level.FINE)) {
StringBuilder acc = new StringBuilder();
if (accountsInCache != null) {
for (IAccount account : accountsInCache) {
if (acc.length() != 0) {
acc.append(", ");
}
acc.append(account.username());
}
}
logger.fine(logger.toString() + "Accounts in cache = " + acc + ", size = "
+ (accountsInCache == null ? null : accountsInCache.size()) + ", user = " + user);
}
if (null != accountsInCache && !accountsInCache.isEmpty() && null != user && !user.isEmpty()) {
IAccount account = getAccountByUsername(accountsInCache, user);
if (null != account) {
Expand Down Expand Up @@ -214,7 +228,7 @@ static SqlFedAuthToken getSqlFedAuthTokenInteractive(SqlFedAuthInfo fedAuthInfo,
Thread.currentThread().interrupt();

throw new SQLServerException(e.getMessage(), e);
} catch (ExecutionException e) {
} catch (MsalThrottlingException | ExecutionException e) {
throw getCorrectedException(e, user, authenticationString);
} finally {
executorService.shutdown();
Expand All @@ -233,8 +247,7 @@ private static IAccount getAccountByUsername(Set<IAccount> accounts, String user
return null;
}

private static SQLServerException getCorrectedException(ExecutionException e, String user,
String authenticationString) {
private static SQLServerException getCorrectedException(Exception e, String user, String authenticationString) {
Object[] msgArgs = {user, authenticationString};

if (null == e.getCause() || null == e.getCause().getMessage()) {
Expand Down
3 changes: 2 additions & 1 deletion src/test/java/com/microsoft/sqlserver/jdbc/TestResource.java
Original file line number Diff line number Diff line change
Expand Up @@ -204,5 +204,6 @@ protected Object[][] getContents() {
{"R_sessionKilled", "Cannot continue the execution because the session is in the kill state"},
{"R_loginFailedMI", "Login failed for user '<token-identified principal>'"},
{"R_MInotAvailable", "Managed Identity authentication is not available"},
{"R_noLoginModulesConfiguredForJdbcDriver", "javax.security.auth.login.LoginException (No LoginModules configured for SQLJDBCDriver)"}};
{"R_noLoginModulesConfiguredForJdbcDriver", "javax.security.auth.login.LoginException (No LoginModules configured for SQLJDBCDriver)"},
{"R_failedFedauth", "Failed to acquire fedauth token: "}};
}
96 changes: 89 additions & 7 deletions src/test/java/com/microsoft/sqlserver/jdbc/TestUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@
import java.util.List;
import java.util.Locale;
import java.util.ResourceBundle;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

import org.junit.Assert;

import com.microsoft.sqlserver.testframework.AbstractSQLGenerator;
import com.microsoft.sqlserver.testframework.PrepUtil;
Expand Down Expand Up @@ -75,6 +82,73 @@ public final class TestUtils {
static final int ENGINE_EDITION_FOR_SQL_AZURE_DW = 6;
static final int ENGINE_EDITION_FOR_SQL_AZURE_MI = 8;

public static final int TEST_TOKEN_EXPIRY_SECONDS = 120; // token expiry time in secs

static String applicationKey;
static String applicationClientID;

static {
try (InputStream input = new FileInputStream(Constants.CONFIG_PROPERTIES_FILE)) {
Properties configProperties = new Properties();
configProperties.load(input);
applicationKey = configProperties.getProperty("applicationKey");
applicationClientID = configProperties.getProperty("applicationClientID");
} catch (IOException e) {
// No config file found
}
}

public static boolean expireTokenToggle = false;

public static final SQLServerAccessTokenCallback accessTokenCallback = new SQLServerAccessTokenCallback() {
@Override
public SqlAuthenticationToken getAccessToken(String spn, String stsurl) {
String scope = spn + "/.default";
Set<String> scopes = new HashSet<>();
scopes.add(scope);

try {
ExecutorService executorService = Executors.newSingleThreadExecutor();
IClientCredential credential = ClientCredentialFactory.createFromSecret(applicationKey);
ConfidentialClientApplication clientApplication = ConfidentialClientApplication
.builder(applicationClientID, credential).executorService(executorService).authority(stsurl)
.build();
CompletableFuture<IAuthenticationResult> future = clientApplication
.acquireToken(ClientCredentialParameters.builder(scopes).build());

IAuthenticationResult authenticationResult = future.get();
String accessToken = authenticationResult.accessToken();
long expiresOn = authenticationResult.expiresOnDate().getTime();

if (expireTokenToggle) {
Date now = new Date();
long minutesToExpireWithin = 10 * 60 * 1000; // Expire within 10 minutes
return new SqlAuthenticationToken(accessToken, now.getTime() + minutesToExpireWithin);
} else {
return new SqlAuthenticationToken(accessToken, expiresOn);
}
} catch (Exception e) {
fail(TestResource.getResource("R_unexpectedException") + e.getMessage());
}
return null;
}
};

public static void setAccessTokenExpiry(Object con, String accessToken) {
Field fedAuthTokenField;
try {
fedAuthTokenField = SQLServerConnection.class.getDeclaredField("fedAuthToken");
fedAuthTokenField.setAccessible(true);

Date newExpiry = new Date(
System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(TEST_TOKEN_EXPIRY_SECONDS));
SqlAuthenticationToken newFedAuthToken = new SqlAuthenticationToken(accessToken, newExpiry);
fedAuthTokenField.set(con, newFedAuthToken);
} catch (NoSuchFieldException | SecurityException | IllegalArgumentException | IllegalAccessException e) {
Assert.fail("Failed to set token expiry: " + e.getMessage());
}
}

private TestUtils() {}

/**
Expand Down Expand Up @@ -376,7 +450,8 @@ public static void dropTypeIfExists(String typeName, java.sql.Statement stmt) th
* @throws SQLException
*/
public static void dropUserDefinedTypeIfExists(String typeName, Statement stmt) throws SQLException {
stmt.executeUpdate("IF EXISTS (select * from sys.types where name = '" + escapeSingleQuotes(typeName) + "') DROP TYPE " + typeName);
stmt.executeUpdate("IF EXISTS (select * from sys.types where name = '" + escapeSingleQuotes(typeName)
+ "') DROP TYPE " + typeName);
}

/**
Expand Down Expand Up @@ -963,11 +1038,18 @@ private static java.security.cert.Certificate getCertificate(String certname) th
}
}

public static void freeProcCache(Statement stmt) {
try {
stmt.execute("DBCC FREEPROCCACHE");
} catch (Exception e) {
// ignore error - some tests fails due to permission issues from managed identity, this does not seem to affect tests
}
public static String getConnectionID(
SQLServerPooledConnection pc) throws ClassNotFoundException, NoSuchFieldException, IllegalAccessException {
Class<?> pooledConnection = Class.forName("com.microsoft.sqlserver.jdbc.SQLServerPooledConnection");
Class<?> connection = Class.forName("com.microsoft.sqlserver.jdbc.SQLServerConnection");

Field physicalConnection = pooledConnection.getDeclaredField("physicalConnection");
Field traceID = connection.getDeclaredField("traceID");

physicalConnection.setAccessible(true);
traceID.setAccessible(true);

SQLServerConnection conn = (SQLServerConnection) physicalConnection.get(pc);
return (String) traceID.get(conn);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,13 @@ private void testAccessTokenExpiredThenCreateNewStatement(SqlAuthentication auth
TestUtils.dropTableIfExists(charTable, stmt);
}

TestUtils.setAccessTokenExpiry(connection, accessToken);
secondsBeforeExpiration = TestUtils.TEST_TOKEN_EXPIRY_SECONDS;
while (secondsPassed < secondsBeforeExpiration) {
Thread.sleep(TimeUnit.MINUTES.toMillis(5)); // Sleep for 2 minutes
Thread.sleep(TimeUnit.SECONDS.toMillis(90)); // Sleep for 90s

secondsPassed = (System.currentTimeMillis() - start) / 1000;

try (Statement stmt1 = connection.createStatement()) {
testUserName(connection, azureUserName, authentication);

Expand Down Expand Up @@ -150,10 +153,12 @@ private void testAccessTokenExpiredThenExecuteUsingSameStatement(
TestUtils.dropTableIfExists(charTable, stmt);
}

TestUtils.setAccessTokenExpiry(connection, accessToken);
secondsBeforeExpiration = TestUtils.TEST_TOKEN_EXPIRY_SECONDS;
while (secondsPassed < secondsBeforeExpiration) {
Thread.sleep(TimeUnit.MINUTES.toMillis(5)); // Sleep for 5 minutes

Thread.sleep(TimeUnit.SECONDS.toMillis(90)); // Sleep for 90s
secondsPassed = (System.currentTimeMillis() - start) / 1000;

testUserName(connection, azureUserName, authentication);
}

Expand Down
Loading

0 comments on commit 35daf0a

Please sign in to comment.