Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent NullPointerException in LIVE TestMode #19620

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

package com.azure.core.test;

import com.azure.core.test.implementation.TestRunMetrics;
import org.junit.jupiter.api.extension.AfterTestExecutionCallback;
import org.junit.jupiter.api.extension.BeforeTestExecutionCallback;
import org.junit.jupiter.api.extension.ExtensionContext;
Expand Down Expand Up @@ -39,29 +40,19 @@ public void beforeTestExecution(ExtensionContext extensionContext) {
logPrefixBuilder.append(",");

getStore(extensionContext).put(extensionContext.getRequiredTestMethod(),
new TestInformation(logPrefixBuilder.toString(), System.currentTimeMillis()));
new TestRunMetrics(logPrefixBuilder.toString(), System.currentTimeMillis()));
}

@Override
public void afterTestExecution(ExtensionContext context) {
TestInformation testInformation = getStore(context)
.remove(context.getRequiredTestMethod(), TestInformation.class);
long duration = System.currentTimeMillis() - testInformation.startMillis;
TestRunMetrics testInformation = getStore(context)
.remove(context.getRequiredTestMethod(), TestRunMetrics.class);
long duration = System.currentTimeMillis() - testInformation.getStartMillis();

System.out.printf("%s completed in %d ms.%n", testInformation.logPrefix, duration);
System.out.printf("%s completed in %d ms.%n", testInformation.getLogPrefix(), duration);
}

private static ExtensionContext.Store getStore(ExtensionContext context) {
return context.getStore(ExtensionContext.Namespace.create(AzureTestWatcher.class, context));
}

private static final class TestInformation {
private final String logPrefix;
private final long startMillis;

private TestInformation(String logPrefix, long startMillis) {
this.logPrefix = logPrefix;
this.startMillis = startMillis;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import com.azure.core.http.HttpClient;
import com.azure.core.http.HttpClientProvider;
import com.azure.core.test.implementation.ImplUtils;
import com.azure.core.test.implementation.TestIterationContext;
import com.azure.core.test.utils.TestResourceNamer;
import com.azure.core.util.Configuration;
import com.azure.core.util.CoreUtils;
Expand All @@ -22,8 +24,6 @@
import java.util.Arrays;
import java.util.Locale;
import java.util.ServiceLoader;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

Expand All @@ -38,8 +38,6 @@ public abstract class TestBase implements BeforeEachCallback {
public static final String AZURE_TEST_HTTP_CLIENTS_VALUE_NETTY = "NettyAsyncHttpClient";
public static final String AZURE_TEST_SERVICE_VERSIONS_VALUE_ALL = "ALL";

private static final Pattern TEST_ITERATION_PATTERN = Pattern.compile("test-template-invocation:#(\\d+)");

private static TestMode testMode;

private final ClientLogger logger = new ClientLogger(TestBase.class);
Expand Down Expand Up @@ -76,16 +74,14 @@ public void beforeEach(ExtensionContext extensionContext) {
@BeforeEach
public void setupTest(TestInfo testInfo) {
this.testContextManager = new TestContextManager(testInfo.getTestMethod().get(), testMode);
if (testIterationContext != null) {
testContextManager.setTestIteration(testIterationContext.testIteration);
}
testContextManager.setTestIteration(testIterationContext.getTestIteration());
logger.info("Test Mode: {}, Name: {}", testMode, testContextManager.getTestName());

try {
interceptorManager = new InterceptorManager(testContextManager);
} catch (UncheckedIOException e) {
logger.error("Could not create interceptor for {}", testContextManager.getTestName(), e);
Assertions.fail();
Assertions.fail(e);
}
testResourceNamer = new TestResourceNamer(testContextManager, interceptorManager.getRecordedData());

Expand Down Expand Up @@ -193,21 +189,15 @@ public static boolean shouldClientBeTested(HttpClient client) {
.contains(configuredHttpClient.trim().toLowerCase(Locale.ROOT)));
}

private static TestMode initializeTestMode() {
final ClientLogger logger = new ClientLogger(TestBase.class);
final String azureTestMode = Configuration.getGlobalConfiguration().get(AZURE_TEST_MODE);

if (azureTestMode != null) {
try {
return TestMode.valueOf(azureTestMode.toUpperCase(Locale.US));
} catch (IllegalArgumentException e) {
logger.error("Could not parse '{}' into TestEnum. Using 'Playback' mode.", azureTestMode);
return TestMode.PLAYBACK;
}
}

logger.info("Environment variable '{}' has not been set yet. Using 'Playback' mode.", AZURE_TEST_MODE);
return TestMode.PLAYBACK;
/**
* Initializes the {@link TestMode} from the environment configuration {@code AZURE_TEST_MODE}.
* <p>
* If {@code AZURE_TEST_MODE} isn't configured or is invalid then {@link TestMode#PLAYBACK} is returned.
*
* @return The {@link TestMode} being used for testing.
*/
static TestMode initializeTestMode() {
return ImplUtils.getTestMode();
}

/**
Expand All @@ -227,16 +217,4 @@ protected void sleepIfRunningAgainstService(long millis) {
throw logger.logExceptionAsWarning(new IllegalStateException(ex));
}
}

private static final class TestIterationContext implements BeforeEachCallback {
Integer testIteration;

@Override
public void beforeEach(ExtensionContext extensionContext) {
Matcher matcher = TEST_ITERATION_PATTERN.matcher(extensionContext.getUniqueId());
if (matcher.find()) {
testIteration = Integer.valueOf(matcher.group(1));
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.azure.core.test.implementation;

import com.azure.core.test.TestMode;
import com.azure.core.util.Configuration;
import com.azure.core.util.logging.ClientLogger;

import java.util.Locale;

/**
* Implementation utility class.
*/
public final class ImplUtils {
public static final String AZURE_TEST_MODE = "AZURE_TEST_MODE";

/**
* Gets the {@link TestMode} being used to run tests.
*
* @return The {@link TestMode} being used to run tests.
*/
public static TestMode getTestMode() {
final ClientLogger logger = new ClientLogger(ImplUtils.class);
final String azureTestMode = Configuration.getGlobalConfiguration().get(AZURE_TEST_MODE);

if (azureTestMode != null) {
try {
return TestMode.valueOf(azureTestMode.toUpperCase(Locale.US));
} catch (IllegalArgumentException e) {
logger.error("Could not parse '{}' into TestEnum. Using 'Playback' mode.", azureTestMode);
return TestMode.PLAYBACK;
}
}

logger.info("Environment variable '{}' has not been set yet. Using 'Playback' mode.", AZURE_TEST_MODE);
return TestMode.PLAYBACK;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.azure.core.test.implementation;

import org.junit.jupiter.api.extension.BeforeEachCallback;
import org.junit.jupiter.api.extension.ExtensionContext;

import java.util.regex.Matcher;
import java.util.regex.Pattern;

public final class TestIterationContext implements BeforeEachCallback {
private static final Pattern TEST_ITERATION_PATTERN = Pattern.compile("test-template-invocation:#(\\d+)");

private Integer testIteration;

@Override
public void beforeEach(ExtensionContext extensionContext) {
Matcher matcher = TEST_ITERATION_PATTERN.matcher(extensionContext.getUniqueId());
if (matcher.find()) {
testIteration = Integer.valueOf(matcher.group(1));
}
}

/**
* Gets the current test iteration.
*
* @return The current test iteration.
*/
public Integer getTestIteration() {
return testIteration;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.azure.core.test.implementation;

public class TestRunMetrics {
private final String logPrefix;
private final long startMillis;

public TestRunMetrics(String logPrefix, long startMillis) {
this.logPrefix = logPrefix;
this.startMillis = startMillis;
}

public String getLogPrefix() {
return logPrefix;
}

public long getStartMillis() {
return startMillis;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import com.azure.core.http.HttpPipelineNextPolicy;
import com.azure.core.http.HttpResponse;
import com.azure.core.http.policy.HttpPipelinePolicy;
import com.azure.core.test.TestMode;
import com.azure.core.test.implementation.ImplUtils;
import com.azure.core.test.models.NetworkCallError;
import com.azure.core.test.models.NetworkCallRecord;
import com.azure.core.test.models.RecordedData;
Expand All @@ -29,7 +31,6 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.zip.GZIPInputStream;

Expand All @@ -50,6 +51,8 @@ public class RecordNetworkCallPolicy implements HttpPipelinePolicy {
private static final String BODY = "Body";
private static final String SIG = "sig";

private static final TestMode TEST_MODE = ImplUtils.getTestMode();

private final ClientLogger logger = new ClientLogger(RecordNetworkCallPolicy.class);
private final RecordedData recordedData;
private final RecordingRedactor redactor;
Expand All @@ -70,13 +73,18 @@ public RecordNetworkCallPolicy(RecordedData recordedData) {
* @param redactors The custom redactor functions to apply to redact sensitive information from recorded data.
*/
public RecordNetworkCallPolicy(RecordedData recordedData, List<Function<String, String>> redactors) {
Objects.requireNonNull(recordedData, "'recordedData' cannot be null.");
this.recordedData = recordedData;
redactor = new RecordingRedactor(redactors);

}

@Override
public Mono<HttpResponse> process(HttpPipelineCallContext context, HttpPipelineNextPolicy next) {
// Test is running in LIVE mode so it won't be able to record the network call, skip recording code.
if (TEST_MODE == TestMode.LIVE) {
return next.process();
}

final NetworkCallRecord networkCallRecord = new NetworkCallRecord();
Map<String, String> headers = new HashMap<>();

Expand Down Expand Up @@ -105,7 +113,7 @@ public Mono<HttpResponse> process(HttpPipelineCallContext context, HttpPipelineN
}).flatMap(httpResponse -> {
final HttpResponse bufferedResponse = httpResponse.buffer();

return extractResponseData(bufferedResponse).map(responseData -> {
return extractResponseData(bufferedResponse, redactor, logger).map(responseData -> {
networkCallRecord.setResponse(responseData);
String body = responseData.get(BODY);

Expand All @@ -122,14 +130,14 @@ public Mono<HttpResponse> process(HttpPipelineCallContext context, HttpPipelineN
});
}

private void redactedAccountName(UrlBuilder urlBuilder) {
private static void redactedAccountName(UrlBuilder urlBuilder) {
String[] hostParts = urlBuilder.getHost().split("\\.");
hostParts[0] = "REDACTED";

urlBuilder.setHost(String.join(".", hostParts));
}

private void captureRequestHeaders(HttpHeaders requestHeaders, Map<String, String> captureHeaders,
private static void captureRequestHeaders(HttpHeaders requestHeaders, Map<String, String> captureHeaders,
String... headerNames) {
for (String headerName : headerNames) {
if (requestHeaders.getValue(headerName) != null) {
Expand All @@ -138,7 +146,8 @@ private void captureRequestHeaders(HttpHeaders requestHeaders, Map<String, Strin
}
}

private Mono<Map<String, String>> extractResponseData(final HttpResponse response) {
private static Mono<Map<String, String>> extractResponseData(final HttpResponse response,
final RecordingRedactor redactor, final ClientLogger logger) {
final Map<String, String> responseData = new HashMap<>();
responseData.put(STATUS_CODE, Integer.toString(response.getStatusCode()));

Expand Down