Skip to content

Commit

Permalink
fix fedauth tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lilgreenbird committed Feb 8, 2023
1 parent 4a27a0d commit 125bf5e
Show file tree
Hide file tree
Showing 14 changed files with 205 additions and 116 deletions.
2 changes: 1 addition & 1 deletion src/main/java/com/microsoft/sqlserver/jdbc/IOBuffer.java
Original file line number Diff line number Diff line change
Expand Up @@ -735,7 +735,7 @@ final InetSocketAddress open(String host, int port, int timeoutMillis, boolean u
/**
* Set TCP keep-alive options for idle connection resiliency
*/
private void setSocketOptions(Socket tcpSocket, TDSChannel channel) throws IOException {
private void setSocketOptions(Socket tcpSocket, TDSChannel channel) {
try {
if (SQLServerDriver.socketSetOptionMethod != null && SQLServerDriver.socketKeepIdleOption != null
&& SQLServerDriver.socketKeepIntervalOption != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Hashtable;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.ConcurrentHashMap;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1605,7 +1605,7 @@ SQLServerPooledConnection getPooledConnectionParent() {
return pooledConnectionParent;
}

SQLServerConnection(String parentInfo) throws SQLServerException {
SQLServerConnection(String parentInfo) {
int connectionID = nextConnectionID(); // sequential connection id
traceID = "ConnectionID:" + connectionID;
loggingClassName += ":" + connectionID;
Expand Down Expand Up @@ -6963,7 +6963,7 @@ public java.sql.Struct createStruct(String typeName, Object[] attributes) throws
return null;
}

String getTrustedServerNameAE() throws SQLServerException {
String getTrustedServerNameAE() {
return trustedServerNameAE.toUpperCase();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -809,7 +809,7 @@ private String generateAzureDWSelect(ResultSet rs, Map<Integer, String> columns)
return sb.toString();
}

private String generateAzureDWEmptyRS(Map<Integer, String> columns) throws SQLException {
private String generateAzureDWEmptyRS(Map<Integer, String> columns) {
StringBuilder sb = new StringBuilder("SELECT TOP 0 ");
for (Entry<Integer, String> p : columns.entrySet()) {
sb.append("NULL AS ").append(p.getValue()).append(",");
Expand Down Expand Up @@ -950,7 +950,7 @@ public java.sql.ResultSet getBestRowIdentifier(String catalog, String schema, St

@Override
public java.sql.ResultSet getCrossReference(String cat1, String schem1, String tab1, String cat2, String schem2,
String tab2) throws SQLException, SQLTimeoutException {
String tab2) throws SQLException {
if (loggerExternal.isLoggable(Level.FINER) && Util.isActivityTraceOn()) {
loggerExternal.finer(toString() + ACTIVITY_ID + ActivityCorrelator.getNext().toString());
}
Expand Down Expand Up @@ -1014,8 +1014,7 @@ public String getDriverVersion() throws SQLServerException {
}

@Override
public java.sql.ResultSet getExportedKeys(String cat, String schema,
String table) throws SQLException, SQLTimeoutException {
public java.sql.ResultSet getExportedKeys(String cat, String schema, String table) throws SQLException {
return getCrossReference(cat, schema, table, null, null, null);
}

Expand All @@ -1032,12 +1031,11 @@ public String getIdentifierQuoteString() throws SQLServerException {
}

@Override
public java.sql.ResultSet getImportedKeys(String cat, String schema,
String table) throws SQLException, SQLTimeoutException {
public java.sql.ResultSet getImportedKeys(String cat, String schema, String table) throws SQLException {
return getCrossReference(null, null, null, cat, schema, table);
}

private ResultSet executeSPFkeys(String[] procParams) throws SQLException, SQLTimeoutException {
private ResultSet executeSPFkeys(String[] procParams) throws SQLException {
if (!this.connection.isAzureDW()) {
String tempTableName = "@jdbc_temp_fkeys_result";
String sql = "DECLARE " + tempTableName + " table (PKTABLE_QUALIFIER sysname, " + "PKTABLE_OWNER sysname, "
Expand Down Expand Up @@ -1219,7 +1217,7 @@ public int getMaxColumnsInTable() throws SQLServerException {
}

@Override
public int getMaxConnections() throws SQLException, SQLTimeoutException {
public int getMaxConnections() throws SQLException {
checkClosed();
try (SQLServerResultSet rs = getResultSetFromInternalQueries(null,
"select maximum from sys.configurations where name = 'user connections'")) {
Expand Down Expand Up @@ -1439,7 +1437,7 @@ public ResultSet getPseudoColumns(String catalog, String schemaPattern, String t
}

@Override
public java.sql.ResultSet getSchemas() throws SQLException, SQLTimeoutException {
public java.sql.ResultSet getSchemas() throws SQLException {
if (loggerExternal.isLoggable(Level.FINER) && Util.isActivityTraceOn()) {
loggerExternal.finer(toString() + ACTIVITY_ID + ActivityCorrelator.getNext().toString());
}
Expand All @@ -1448,8 +1446,7 @@ public java.sql.ResultSet getSchemas() throws SQLException, SQLTimeoutException

}

private java.sql.ResultSet getSchemasInternal(String catalog,
String schemaPattern) throws SQLException, SQLTimeoutException {
private java.sql.ResultSet getSchemasInternal(String catalog, String schemaPattern) throws SQLException {

String s;
// The schemas that return null for catalog name, these are prebuilt
Expand Down Expand Up @@ -1606,14 +1603,13 @@ public java.sql.ResultSet getTablePrivileges(String catalog, String schema,
}

@Override
public java.sql.ResultSet getTableTypes() throws SQLException, SQLTimeoutException {
public java.sql.ResultSet getTableTypes() throws SQLException {
if (loggerExternal.isLoggable(Level.FINER) && Util.isActivityTraceOn()) {
loggerExternal.finer(toString() + ACTIVITY_ID + ActivityCorrelator.getNext().toString());
}
checkClosed();
String s = "SELECT 'VIEW' 'TABLE_TYPE' UNION SELECT 'TABLE' UNION SELECT 'SYSTEM TABLE'";
SQLServerResultSet rs = getResultSetFromInternalQueries(null, s);
return rs;
return getResultSetFromInternalQueries(null, s);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import com.microsoft.aad.msal4j.IntegratedWindowsAuthenticationParameters;
import com.microsoft.aad.msal4j.InteractiveRequestParameters;
import com.microsoft.aad.msal4j.MsalInteractionRequiredException;
import com.microsoft.aad.msal4j.MsalThrottlingException;
import com.microsoft.aad.msal4j.PublicClientApplication;
import com.microsoft.aad.msal4j.SilentParameters;
import com.microsoft.aad.msal4j.SystemBrowserOptions;
Expand Down Expand Up @@ -74,7 +75,7 @@ static SqlAuthenticationToken getSqlFedAuthToken(SqlFedAuthInfo fedAuthInfo, Str
Thread.currentThread().interrupt();

throw new SQLServerException(e.getMessage(), e);
} catch (ExecutionException e) {
} catch (MsalThrottlingException | ExecutionException e) {
throw getCorrectedException(e, user, authenticationString);
} finally {
executorService.shutdown();
Expand Down Expand Up @@ -108,7 +109,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipal(SqlFedAuthInfo fedAuth
Thread.currentThread().interrupt();

throw new SQLServerException(e.getMessage(), e);
} catch (ExecutionException e) {
} catch (MsalThrottlingException | ExecutionException e) {
throw getCorrectedException(e, aadPrincipalID, authenticationString);
} finally {
executorService.shutdown();
Expand Down Expand Up @@ -150,7 +151,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenIntegrated(SqlFedAuthInfo fedAut
Thread.currentThread().interrupt();

throw new SQLServerException(e.getMessage(), e);
} catch (ExecutionException e) {
} catch (MsalThrottlingException | ExecutionException e) {
throw getCorrectedException(e, "", authenticationString);
} finally {
executorService.shutdown();
Expand All @@ -177,11 +178,13 @@ static SqlAuthenticationToken getSqlFedAuthTokenInteractive(SqlFedAuthInfo fedAu
StringBuilder acc = new StringBuilder();
if (accountsInCache != null) {
for (IAccount account : accountsInCache) {
if (acc.length() != 0) acc.append(", ");
if (acc.length() != 0)
acc.append(", ");
acc.append(account.username());
}
}
logger.fine(logger.toString() + "Accounts in cache = " + acc + ", size = " + (accountsInCache == null ? null : accountsInCache.size()) + ", user = " + user);
logger.fine(logger.toString() + "Accounts in cache = " + acc + ", size = "
+ (accountsInCache == null ? null : accountsInCache.size()) + ", user = " + user);
}
if (null != accountsInCache && !accountsInCache.isEmpty() && null != user && !user.isEmpty()) {
IAccount account = getAccountByUsername(accountsInCache, user);
Expand Down Expand Up @@ -228,7 +231,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenInteractive(SqlFedAuthInfo fedAu
Thread.currentThread().interrupt();

throw new SQLServerException(e.getMessage(), e);
} catch (ExecutionException e) {
} catch (MsalThrottlingException | ExecutionException e) {
throw getCorrectedException(e, user, authenticationString);
} finally {
executorService.shutdown();
Expand All @@ -247,8 +250,7 @@ private static IAccount getAccountByUsername(Set<IAccount> accounts, String user
return null;
}

private static SQLServerException getCorrectedException(ExecutionException e, String user,
String authenticationString) {
private static SQLServerException getCorrectedException(Exception e, String user, String authenticationString) {
Object[] msgArgs = {user, authenticationString};

if (null == e.getCause() || null == e.getCause().getMessage()) {
Expand Down
3 changes: 2 additions & 1 deletion src/test/java/com/microsoft/sqlserver/jdbc/TestResource.java
Original file line number Diff line number Diff line change
Expand Up @@ -201,5 +201,6 @@ protected Object[][] getContents() {
{"R_objectNullOrEmpty", "The {0} is null or empty."},
{"R_cekDecryptionFailed", "Failed to decrypt a column encryption key using key store provider: {0}."},
{"R_connectTimedOut", "connect timed out"},
{"R_sessionKilled", "Cannot continue the execution because the session is in the kill state"}};
{"R_sessionKilled", "Cannot continue the execution because the session is in the kill state"},
{"R_failedFedauth", "Failed to acquire fedauth token: "}};
}
40 changes: 32 additions & 8 deletions src/test/java/com/microsoft/sqlserver/jdbc/TestUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

import org.junit.Assert;

import com.microsoft.aad.msal4j.ClientCredentialFactory;
import com.microsoft.aad.msal4j.ClientCredentialParameters;
Expand Down Expand Up @@ -91,6 +94,8 @@ public final class TestUtils {
static final int ENGINE_EDITION_FOR_SQL_AZURE_DW = 6;
static final int ENGINE_EDITION_FOR_SQL_AZURE_MI = 8;

public static final int TEST_TOKEN_EXPIRY_SECONDS = 120; // token expiry time in secs

static String applicationKey;
static String applicationClientID;

Expand All @@ -108,18 +113,20 @@ public final class TestUtils {
public static boolean expireTokenToggle = false;

public static final SQLServerAccessTokenCallback accessTokenCallback = new SQLServerAccessTokenCallback() {
@Override public SqlAuthenticationToken getAccessToken(String spn, String stsurl) {
@Override
public SqlAuthenticationToken getAccessToken(String spn, String stsurl) {
String scope = spn + "/.default";
Set<String> scopes = new HashSet<>();
scopes.add(scope);

try {
ExecutorService executorService = Executors.newSingleThreadExecutor();
IClientCredential credential = ClientCredentialFactory.createFromSecret(applicationKey);
ConfidentialClientApplication clientApplication = ConfidentialClientApplication.builder(
applicationClientID, credential).executorService(executorService).authority(stsurl).build();
CompletableFuture<IAuthenticationResult> future = clientApplication.acquireToken(
ClientCredentialParameters.builder(scopes).build());
ConfidentialClientApplication clientApplication = ConfidentialClientApplication
.builder(applicationClientID, credential).executorService(executorService).authority(stsurl)
.build();
CompletableFuture<IAuthenticationResult> future = clientApplication
.acquireToken(ClientCredentialParameters.builder(scopes).build());

IAuthenticationResult authenticationResult = future.get();
String accessToken = authenticationResult.accessToken();
Expand All @@ -139,6 +146,21 @@ public final class TestUtils {
}
};

public static void setAccessTokenExpiry(Object con, String accessToken) {
Field fedAuthTokenField;
try {
fedAuthTokenField = SQLServerConnection.class.getDeclaredField("fedAuthToken");
fedAuthTokenField.setAccessible(true);

Date newExpiry = new Date(
System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(TEST_TOKEN_EXPIRY_SECONDS));
SqlAuthenticationToken newFedAuthToken = new SqlAuthenticationToken(accessToken, newExpiry);
fedAuthTokenField.set(con, newFedAuthToken);
} catch (NoSuchFieldException | SecurityException | IllegalArgumentException | IllegalAccessException e) {
Assert.fail("Failed to set token expiry: " + e.getMessage());
}
}

private TestUtils() {}

/**
Expand Down Expand Up @@ -440,7 +462,8 @@ public static void dropTypeIfExists(String typeName, java.sql.Statement stmt) th
* @throws SQLException
*/
public static void dropUserDefinedTypeIfExists(String typeName, Statement stmt) throws SQLException {
stmt.executeUpdate("IF EXISTS (select * from sys.types where name = '" + escapeSingleQuotes(typeName) + "') DROP TYPE " + typeName);
stmt.executeUpdate("IF EXISTS (select * from sys.types where name = '" + escapeSingleQuotes(typeName)
+ "') DROP TYPE " + typeName);
}

/**
Expand Down Expand Up @@ -1029,12 +1052,13 @@ private static java.security.cert.Certificate getCertificate(String certname) th
}
}

public static String getConnectionID(SQLServerPooledConnection pc) throws ClassNotFoundException, NoSuchFieldException, IllegalAccessException {
public static String getConnectionID(
SQLServerPooledConnection pc) throws ClassNotFoundException, NoSuchFieldException, IllegalAccessException {
Class<?> pooledConnection = Class.forName("com.microsoft.sqlserver.jdbc.SQLServerPooledConnection");
Class<?> connection = Class.forName("com.microsoft.sqlserver.jdbc.SQLServerConnection");

Field physicalConnection = pooledConnection.getDeclaredField("physicalConnection");
Field traceID = connection.getDeclaredField("traceID");
Field traceID = connection.getDeclaredField("traceID");

physicalConnection.setAccessible(true);
traceID.setAccessible(true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,13 @@ private void testAccessTokenExpiredThenCreateNewStatement(SqlAuthentication auth
TestUtils.dropTableIfExists(charTable, stmt);
}

TestUtils.setAccessTokenExpiry(connection, accessToken);
secondsBeforeExpiration = TestUtils.TEST_TOKEN_EXPIRY_SECONDS;
while (secondsPassed < secondsBeforeExpiration) {
Thread.sleep(TimeUnit.MINUTES.toMillis(5)); // Sleep for 2 minutes
Thread.sleep(TimeUnit.SECONDS.toMillis(90)); // Sleep for 90s

secondsPassed = (System.currentTimeMillis() - start) / 1000;

try (Statement stmt1 = connection.createStatement()) {
testUserName(connection, azureUserName, authentication);

Expand Down Expand Up @@ -150,10 +153,12 @@ private void testAccessTokenExpiredThenExecuteUsingSameStatement(
TestUtils.dropTableIfExists(charTable, stmt);
}

TestUtils.setAccessTokenExpiry(connection, accessToken);
secondsBeforeExpiration = TestUtils.TEST_TOKEN_EXPIRY_SECONDS;
while (secondsPassed < secondsBeforeExpiration) {
Thread.sleep(TimeUnit.MINUTES.toMillis(5)); // Sleep for 5 minutes

Thread.sleep(TimeUnit.SECONDS.toMillis(90)); // Sleep for 90s
secondsPassed = (System.currentTimeMillis() - start) / 1000;

testUserName(connection, azureUserName, authentication);
}

Expand Down
Loading

0 comments on commit 125bf5e

Please sign in to comment.