diff --git a/plugin/src/main/java/org/gradle/testretry/internal/executer/RetryTestResultProcessor.java b/plugin/src/main/java/org/gradle/testretry/internal/executer/RetryTestResultProcessor.java index 5bdda552..db0912ee 100644 --- a/plugin/src/main/java/org/gradle/testretry/internal/executer/RetryTestResultProcessor.java +++ b/plugin/src/main/java/org/gradle/testretry/internal/executer/RetryTestResultProcessor.java @@ -31,6 +31,7 @@ import java.util.HashMap; import java.util.HashSet; import java.util.Map; +import java.util.Optional; import java.util.Set; import static org.gradle.api.tasks.testing.TestResult.ResultType.SKIPPED; @@ -49,6 +50,7 @@ final class RetryTestResultProcessor implements TestResultProcessor { private Method failureMethod; private final Map activeDescriptorsById = new HashMap<>(); + private final Map parentIdByDescriptorId = new HashMap<>(); private final Set testClassesSeenInCurrentRound = new HashSet<>(); private TestNames currentRoundFailedTests = new TestNames(); @@ -80,6 +82,7 @@ public void started(TestDescriptorInternal descriptor, TestStartEvent testStartE delegate.started(descriptor, testStartEvent); } else if (!descriptor.getId().equals(rootTestDescriptorId)) { activeDescriptorsById.put(descriptor.getId(), descriptor); + parentIdByDescriptorId.put(descriptor.getId(), testStartEvent.getParentId()); registerSeenTestClass(descriptor); delegate.started(descriptor, testStartEvent); } @@ -99,7 +102,7 @@ public void completed(Object testId, TestCompleteEvent testCompleteEvent) { boolean failedInPreviousRound = previousRoundFailedTests.remove(className, name); if (failedInPreviousRound && testCompleteEvent.getResultType() == SKIPPED) { - addRetry(className, name); + addRetry(descriptor); } // class-level lifecycle failures do not guarantee that all methods that failed in the previous round will be re-executed (e.g. due to class setup failure) @@ -108,7 +111,7 @@ public void completed(Object testId, TestCompleteEvent testCompleteEvent) { if (isLifecycleFailure(className, name)) { previousRoundFailedTests.remove(className, n -> { if (isLifecycleFailure(className, n)) { - addRetry(className, n); + currentRoundFailedTests.add(className, n); } return true; }); @@ -143,14 +146,44 @@ private void registerSeenTestClass(TestDescriptorInternal descriptor) { } } - private void addRetry(String className, String name) { - if (classRetryMatcher.retryWholeClass(className)) { - currentRoundFailedTests.addClass(className); + private void addRetry(TestDescriptorInternal descriptor) { + Optional classMatchingClassRetryFilter = firstClassMatchingClassRetryFilter(descriptor); + if (classMatchingClassRetryFilter.isPresent()) { + currentRoundFailedTests.addClass(classMatchingClassRetryFilter.get().getClassName()); } else { - currentRoundFailedTests.add(className, name); + currentRoundFailedTests.add(descriptor.getClassName(), descriptor.getName()); } } + private Optional firstClassMatchingClassRetryFilter(TestDescriptorInternal descriptor) { + // top-level descriptor describes a test worker which cannot match the class retry filter + Object parentId = parentIdByDescriptorId.get(descriptor.getId()); + if (parentId == null) { + return Optional.empty(); + } + + // if the parent is not tracked for any reason, then it also cannot match the class retry filter + TestDescriptorInternal parentDescriptor = activeDescriptorsById.get(parentId); + if (parentDescriptor == null) { + return Optional.empty(); + } + + // check if any of the parent classes matches the class retry filter + Optional parentClassToRetryEntirely = firstClassMatchingClassRetryFilter(parentDescriptor); + if (parentClassToRetryEntirely.isPresent()) { + return parentClassToRetryEntirely; + } + + // check if the class on the current level matches the class retry filter + String className = descriptor.getClassName(); + if (className != null && classRetryMatcher.retryWholeClass(className)) { + return Optional.of(descriptor); + } + + // no classes in the descriptor hierarchy should be retried as a whole + return Optional.empty(); + } + private void emitFakePassedEvent(TestDescriptorInternal parent, TestCompleteEvent parentEvent, String name) { Object syntheticTestId = new Object(); TestDescriptorInternal syntheticDescriptor = new TestDescriptorImpl(syntheticTestId, parent, name); @@ -201,7 +234,7 @@ private void failure(Object testId) { String className = descriptor.getClassName(); if (className != null) { if (filter.canRetry(className)) { - addRetry(className, descriptor.getName()); + addRetry(descriptor); } else { hasRetryFilteredFailures = true; } @@ -279,6 +312,7 @@ public void reset(boolean lastRetry) { this.previousRoundFailedTests = currentRoundFailedTests; this.currentRoundFailedTests = new TestNames(); this.activeDescriptorsById.clear(); + this.parentIdByDescriptorId.clear(); } } diff --git a/plugin/src/test/groovy/org/gradle/testretry/testframework/JUnit5FuncTest.groovy b/plugin/src/test/groovy/org/gradle/testretry/testframework/JUnit5FuncTest.groovy index 261377e6..217f3952 100644 --- a/plugin/src/test/groovy/org/gradle/testretry/testframework/JUnit5FuncTest.groovy +++ b/plugin/src/test/groovy/org/gradle/testretry/testframework/JUnit5FuncTest.groovy @@ -634,16 +634,16 @@ class JUnit5FuncTest extends AbstractFrameworkFuncTest { void testOk() { } - @Test - void testFlaky() { - ${flakyAssert("topLevel")} - } - @Nested class NestedTest1 { @Test void testOk() { } + + @Test + void testFlaky() { + ${flakyAssert("topLevel")} + } } @Nested @@ -662,11 +662,12 @@ class JUnit5FuncTest extends AbstractFrameworkFuncTest { with(result.output) { // all methods of TopLevelTest are rerun it.count("${classAndMethodForNested('TopLevelTest', null, 'testOk()', gradleVersion)} PASSED") == 2 - it.count("${classAndMethodForNested('TopLevelTest', null, 'testFlaky()', gradleVersion)} FAILED") == 1 - it.count("${classAndMethodForNested('TopLevelTest', null, 'testFlaky()', gradleVersion)} PASSED") == 1 // all methods of nested classes are retried it.count("${classAndMethodForNested('TopLevelTest', 'NestedTest1', 'testOk()', gradleVersion)} PASSED") == 2 + it.count("${classAndMethodForNested('TopLevelTest', 'NestedTest1', 'testFlaky()', gradleVersion)} FAILED") == 1 + it.count("${classAndMethodForNested('TopLevelTest', 'NestedTest1', 'testFlaky()', gradleVersion)} PASSED") == 1 + it.count("${classAndMethodForNested('TopLevelTest', 'NestedTest2', 'testOk()', gradleVersion)} PASSED") == 2 }