diff --git a/test_runner/src/main/kotlin/ftl/args/AndroidArgs.kt b/test_runner/src/main/kotlin/ftl/args/AndroidArgs.kt index 08b32b66ae..16d3c21003 100644 --- a/test_runner/src/main/kotlin/ftl/args/AndroidArgs.kt +++ b/test_runner/src/main/kotlin/ftl/args/AndroidArgs.kt @@ -65,6 +65,7 @@ class AndroidArgs( private val flank = flankYml.flank override val testShards = cli?.testShards ?: flank.testShards + override val shardTime = cli?.shardTime ?: flank.shardTime override val repeatTests = cli?.repeatTests ?: flank.repeatTests override val smartFlankGcsPath = flank.smartFlankGcsPath override val testTargetsAlwaysRun = cli?.testTargetsAlwaysRun ?: flank.testTargetsAlwaysRun @@ -160,6 +161,7 @@ ${devicesToString(devices)} flank: testShards: $testShards + shardTime: $shardTime repeatTests: $repeatTests smartFlankGcsPath: $smartFlankGcsPath files-to-download: diff --git a/test_runner/src/main/kotlin/ftl/args/ArgsHelper.kt b/test_runner/src/main/kotlin/ftl/args/ArgsHelper.kt index ab4d21371e..663fb87557 100644 --- a/test_runner/src/main/kotlin/ftl/args/ArgsHelper.kt +++ b/test_runner/src/main/kotlin/ftl/args/ArgsHelper.kt @@ -178,9 +178,12 @@ object ArgsHelper { fun calculateShards(filteredTests: List, args: IArgs): List> { val oldTestResult = GcStorage.downloadJunitXml(args) ?: JUnitTestResult(mutableListOf()) - val shardsByTime = Shard.calculateShardsByTime(filteredTests, oldTestResult, args) - return testMethodsAlwaysRun(shardsByTime.stringShards(), args) + val shardCount = Shard.shardCountByTime(filteredTests, oldTestResult, args) + + val shards = Shard.createShardsByShardCount(filteredTests, oldTestResult, args, shardCount) + + return testMethodsAlwaysRun(shards.stringShards(), args) } private fun testMethodsAlwaysRun(shards: StringShards, args: IArgs): StringShards { diff --git a/test_runner/src/main/kotlin/ftl/args/IArgs.kt b/test_runner/src/main/kotlin/ftl/args/IArgs.kt index d1b0ec376a..51912cdb09 100644 --- a/test_runner/src/main/kotlin/ftl/args/IArgs.kt +++ b/test_runner/src/main/kotlin/ftl/args/IArgs.kt @@ -16,6 +16,7 @@ interface IArgs { // FlankYml val testShards: Int + val shardTime: Int val repeatTests: Int val smartFlankGcsPath: String val testTargetsAlwaysRun: List diff --git a/test_runner/src/main/kotlin/ftl/args/IosArgs.kt b/test_runner/src/main/kotlin/ftl/args/IosArgs.kt index 7bc9727f31..edacaa1dc7 100644 --- a/test_runner/src/main/kotlin/ftl/args/IosArgs.kt +++ b/test_runner/src/main/kotlin/ftl/args/IosArgs.kt @@ -50,6 +50,7 @@ class IosArgs( private val flank = flankYml.flank override val testShards = cli?.testShards ?: flank.testShards + override val shardTime = cli?.shardTime ?: flank.shardTime override val repeatTests = cli?.repeatTests ?: flank.repeatTests override val smartFlankGcsPath = flank.smartFlankGcsPath override val testTargetsAlwaysRun = cli?.testTargetsAlwaysRun ?: flank.testTargetsAlwaysRun @@ -122,6 +123,7 @@ ${devicesToString(devices)} flank: testShards: $testShards + shardTime: $shardTime repeatTests: $repeatTests smartFlankGcsPath: $smartFlankGcsPath test-targets-always-run: diff --git a/test_runner/src/main/kotlin/ftl/args/yml/FlankYml.kt b/test_runner/src/main/kotlin/ftl/args/yml/FlankYml.kt index 8f97e54be6..2a39f22659 100644 --- a/test_runner/src/main/kotlin/ftl/args/yml/FlankYml.kt +++ b/test_runner/src/main/kotlin/ftl/args/yml/FlankYml.kt @@ -9,6 +9,7 @@ import ftl.util.Utils.fatalError @JsonIgnoreProperties(ignoreUnknown = true) class FlankYmlParams( val testShards: Int = 1, + val shardTime: Int = -1, val repeatTests: Int = 1, val smartFlankGcsPath: String = "", @@ -19,11 +20,14 @@ class FlankYmlParams( val filesToDownload: List = emptyList() ) { companion object : IYmlKeys { - override val keys = listOf("testShards", "repeatTests", "smartFlankGcsPath", "test-targets-always-run", "files-to-download") + override val keys = listOf( + "testShards", "shardTime", "repeatTests", "smartFlankGcsPath", "test-targets-always-run", "files-to-download" + ) } init { if (testShards <= 0 && testShards != -1) fatalError("testShards must be >= 1 or -1") + if (shardTime <= 0 && shardTime != -1) fatalError("shardTime must be >= 1 or -1") if (repeatTests < 1) fatalError("repeatTests must be >= 1") if (smartFlankGcsPath.isNotEmpty()) { diff --git a/test_runner/src/main/kotlin/ftl/cli/firebase/test/android/AndroidRunCommand.kt b/test_runner/src/main/kotlin/ftl/cli/firebase/test/android/AndroidRunCommand.kt index 256c937df9..2bcb570f54 100644 --- a/test_runner/src/main/kotlin/ftl/cli/firebase/test/android/AndroidRunCommand.kt +++ b/test_runner/src/main/kotlin/ftl/cli/firebase/test/android/AndroidRunCommand.kt @@ -221,6 +221,12 @@ class AndroidRunCommand : Runnable { ) var testShards: Int? = null + @Option( + names = ["--shard-time"], + description = ["The max amount of seconds each shard should run."] + ) + var shardTime: Int? = null + @Option( names = ["--repeat-tests"], description = ["The amount of times to repeat the test executions."] diff --git a/test_runner/src/main/kotlin/ftl/cli/firebase/test/ios/IosRunCommand.kt b/test_runner/src/main/kotlin/ftl/cli/firebase/test/ios/IosRunCommand.kt index 190cb4c03a..280d24eddc 100644 --- a/test_runner/src/main/kotlin/ftl/cli/firebase/test/ios/IosRunCommand.kt +++ b/test_runner/src/main/kotlin/ftl/cli/firebase/test/ios/IosRunCommand.kt @@ -105,6 +105,12 @@ class IosRunCommand : Runnable { ) var testShards: Int? = null + @Option( + names = ["--shard-time"], + description = ["The max amount of seconds each shard should run."] + ) + var shardTime: Int? = null + @Option( names = ["--repeat-tests"], description = ["The amount of times to repeat the test executions."] diff --git a/test_runner/src/main/kotlin/ftl/shard/Shard.kt b/test_runner/src/main/kotlin/ftl/shard/Shard.kt index e0226fa78b..78ed9a9f4b 100644 --- a/test_runner/src/main/kotlin/ftl/shard/Shard.kt +++ b/test_runner/src/main/kotlin/ftl/shard/Shard.kt @@ -52,13 +52,31 @@ object Shard { return "$classname/$testName" } - // take in the XML with timing info then return list of shards - fun calculateShardsByTime( + // take in the XML with timing info then return the shard count based on execution time + fun shardCountByTime( testsToRun: List, oldTestResult: JUnitTestResult, args: IArgs + ): Int { + if (args.shardTime == -1) return -1 + + val junitMap = createJunitMap(oldTestResult, args) + val testsTotalTime = testsToRun.sumByDouble { junitMap[it] ?: 10.0 } + + val shardsByTime = Math.ceil(testsTotalTime / args.shardTime).toInt() + + // We need to respect the testShards + return Math.min(shardsByTime, args.testShards) + } + + // take in the XML with timing info then return list of shards based on the amount of shards to use + fun createShardsByShardCount( + testsToRun: List, + oldTestResult: JUnitTestResult, + args: IArgs, + forcedShardCount: Int = -1 ): List { - val maxShards = args.testShards + val maxShards = if (forcedShardCount == -1) args.testShards else forcedShardCount val junitMap = createJunitMap(oldTestResult, args) var cacheMiss = 0 diff --git a/test_runner/src/test/kotlin/ftl/args/AndroidArgsTest.kt b/test_runner/src/test/kotlin/ftl/args/AndroidArgsTest.kt index db0abbb1d7..cafb841519 100644 --- a/test_runner/src/test/kotlin/ftl/args/AndroidArgsTest.kt +++ b/test_runner/src/test/kotlin/ftl/args/AndroidArgsTest.kt @@ -59,6 +59,7 @@ class AndroidArgsTest { flank: testShards: 7 + shardTime: 60 repeatTests: 8 files-to-download: - /sdcard/screenshots @@ -158,6 +159,7 @@ class AndroidArgsTest { // FlankYml assert(testShards, 7) + assert(shardTime, 60) assert(repeatTests, 8) assert(filesToDownload, listOf("/sdcard/screenshots", "/sdcard/screenshots2")) assert( @@ -211,6 +213,7 @@ AndroidArgs flank: testShards: 7 + shardTime: 60 repeatTests: 8 smartFlankGcsPath:${' '} files-to-download: @@ -669,6 +672,23 @@ AndroidArgs assertThat(AndroidArgs.load(yaml, cli).testShards).isEqualTo(3) } + @Test + fun `cli shardTime`() { + val cli = AndroidRunCommand() + CommandLine(cli).parse("--shard-time=3") + + val yaml = """ + gcloud: + app: $appApk + test: $testApk + + flank: + shardTime: 2 + """ + assertThat(AndroidArgs.load(yaml).shardTime).isEqualTo(2) + assertThat(AndroidArgs.load(yaml, cli).shardTime).isEqualTo(3) + } + @Test fun cli_repeatTests() { val cli = AndroidRunCommand() diff --git a/test_runner/src/test/kotlin/ftl/args/FlankYmlTest.kt b/test_runner/src/test/kotlin/ftl/args/FlankYmlTest.kt index ffae063c55..4afe09dec0 100644 --- a/test_runner/src/test/kotlin/ftl/args/FlankYmlTest.kt +++ b/test_runner/src/test/kotlin/ftl/args/FlankYmlTest.kt @@ -25,9 +25,10 @@ class FlankYmlTest { fun testValidArgs() { FlankYml() FlankYml(FlankYmlParams(testShards = -1)) - val yml = FlankYml(FlankYmlParams(testShards = 1, repeatTests = 1)) + val yml = FlankYml(FlankYmlParams(testShards = 1, repeatTests = 1, shardTime = 58)) assertThat(yml.flank.repeatTests).isEqualTo(1) assertThat(yml.flank.testShards).isEqualTo(1) + assertThat(yml.flank.shardTime).isEqualTo(58) assertThat(yml.flank.testTargetsAlwaysRun).isEqualTo(emptyList()) assertThat(FlankYml.map).isNotEmpty() } @@ -38,6 +39,12 @@ class FlankYmlTest { FlankYml(FlankYmlParams(testShards = -2)) } + @Test + fun testInvalidShardTime() { + exceptionRule.expectMessage("shardTime must be >= 1 or -1") + FlankYml(FlankYmlParams(shardTime = -2)) + } + @Test fun testInvalidrepeatTests() { exceptionRule.expectMessage("repeatTests must be >= 1") diff --git a/test_runner/src/test/kotlin/ftl/args/IosArgsTest.kt b/test_runner/src/test/kotlin/ftl/args/IosArgsTest.kt index 24aff71e1a..469f6e27d5 100644 --- a/test_runner/src/test/kotlin/ftl/args/IosArgsTest.kt +++ b/test_runner/src/test/kotlin/ftl/args/IosArgsTest.kt @@ -53,6 +53,7 @@ class IosArgsTest { flank: testShards: 7 + shardTime: 60 repeatTests: 8 files-to-download: - /sdcard/screenshots @@ -119,6 +120,7 @@ class IosArgsTest { // FlankYml assert(testShards, 7) + assert(shardTime, 60) assert(repeatTests, 8) assert(testTargetsAlwaysRun, listOf("a/testGrantPermissions", "a/testGrantPermissions2")) @@ -160,6 +162,7 @@ IosArgs flank: testShards: 7 + shardTime: 60 repeatTests: 8 smartFlankGcsPath:${' '} test-targets-always-run: @@ -201,6 +204,7 @@ IosArgs // FlankYml assert(testShards, 1) + assert(shardTime, -1) assert(repeatTests, 1) assert(testTargetsAlwaysRun, emptyList()) assert(filesToDownload, emptyList()) @@ -372,6 +376,23 @@ IosArgs assertThat(IosArgs.load(yaml, cli).testShards).isEqualTo(3) } + @Test + fun `cli shardTime`() { + val cli = IosRunCommand() + CommandLine(cli).parse("--shard-time=3") + + val yaml = """ + gcloud: + test: $testPath + xctestrun-file: $xctestrunFile + + flank: + shardTime: 2 + """ + assertThat(IosArgs.load(yaml).shardTime).isEqualTo(2) + assertThat(IosArgs.load(yaml, cli).shardTime).isEqualTo(3) + } + @Test fun cli_repeatTests() { val cli = IosRunCommand() diff --git a/test_runner/src/test/kotlin/ftl/cli/firebase/test/android/AndroidRunCommandTest.kt b/test_runner/src/test/kotlin/ftl/cli/firebase/test/android/AndroidRunCommandTest.kt index 6dcf3942a7..b8f9504477 100644 --- a/test_runner/src/test/kotlin/ftl/cli/firebase/test/android/AndroidRunCommandTest.kt +++ b/test_runner/src/test/kotlin/ftl/cli/firebase/test/android/AndroidRunCommandTest.kt @@ -78,6 +78,7 @@ class AndroidRunCommandTest { assertThat(cmd.project).isNull() assertThat(cmd.resultsHistoryName).isNull() assertThat(cmd.testShards).isNull() + assertThat(cmd.shardTime).isNull() assertThat(cmd.repeatTests).isNull() assertThat(cmd.testTargetsAlwaysRun).isNull() assertThat(cmd.filesToDownload).isNull() @@ -291,4 +292,12 @@ class AndroidRunCommandTest { assertThat(cmd.flakyTestAttempts).isEqualTo(10) } + + @Test + fun `shardTime parse`() { + val cmd = AndroidRunCommand() + CommandLine(cmd).parse("--shard-time=99") + + assertThat(cmd.shardTime).isEqualTo(99) + } } diff --git a/test_runner/src/test/kotlin/ftl/cli/firebase/test/ios/IosRunCommandTest.kt b/test_runner/src/test/kotlin/ftl/cli/firebase/test/ios/IosRunCommandTest.kt index 4bd669cdba..a7a3fc1680 100644 --- a/test_runner/src/test/kotlin/ftl/cli/firebase/test/ios/IosRunCommandTest.kt +++ b/test_runner/src/test/kotlin/ftl/cli/firebase/test/ios/IosRunCommandTest.kt @@ -66,6 +66,7 @@ class IosRunCommandTest { assertThat(cmd.project).isNull() assertThat(cmd.resultsHistoryName).isNull() assertThat(cmd.testShards).isNull() + assertThat(cmd.shardTime).isNull() assertThat(cmd.repeatTests).isNull() assertThat(cmd.testTargetsAlwaysRun).isNull() assertThat(cmd.testTargets).isNull() @@ -225,4 +226,12 @@ class IosRunCommandTest { assertThat(cmd.flakyTestAttempts).isEqualTo(10) } + + @Test + fun `shardTime parse`() { + val cmd = IosRunCommand() + CommandLine(cmd).parse("--shard-time=99") + + assertThat(cmd.shardTime).isEqualTo(99) + } } diff --git a/test_runner/src/test/kotlin/ftl/shard/ShardTest.kt b/test_runner/src/test/kotlin/ftl/shard/ShardTest.kt index 231bfc5b45..45dc25166e 100644 --- a/test_runner/src/test/kotlin/ftl/shard/ShardTest.kt +++ b/test_runner/src/test/kotlin/ftl/shard/ShardTest.kt @@ -9,7 +9,9 @@ import ftl.reports.xml.model.JUnitTestSuite import ftl.test.util.FlankTestRunner import java.util.concurrent.TimeUnit import kotlin.system.measureNanoTime +import org.junit.Rule import org.junit.Test +import org.junit.rules.ExpectedException import org.junit.runner.RunWith import org.mockito.Mockito.`when` import org.mockito.Mockito.mock @@ -17,6 +19,9 @@ import org.mockito.Mockito.mock @RunWith(FlankTestRunner::class) class ShardTest { + @Rule @JvmField + val exceptionRule = ExpectedException.none()!! + private fun sample(): JUnitTestResult { val testCases = mutableListOf( @@ -35,9 +40,10 @@ class ShardTest { return JUnitTestResult(mutableListOf(suite1, suite2)) } - private fun mockArgs(testShards: Int): IArgs { + private fun mockArgs(testShards: Int, shardTime: Int = 0): IArgs { val mockArgs = mock(IosArgs::class.java) `when`(mockArgs.testShards).thenReturn(testShards) + `when`(mockArgs.shardTime).thenReturn(shardTime) return mockArgs } @@ -46,7 +52,7 @@ class ShardTest { val reRunTestsToRun = listOf("a", "b", "c", "d", "e", "f", "g") val suite = sample() - val result = Shard.calculateShardsByTime(reRunTestsToRun, suite, mockArgs(100)) + val result = Shard.createShardsByShardCount(reRunTestsToRun, suite, mockArgs(100)) assertThat(result.size).isEqualTo(7) result.forEach { @@ -58,7 +64,7 @@ class ShardTest { fun sampleTest() { val reRunTestsToRun = listOf("a/a", "b/b", "c/c", "d/d", "e/e", "f/f", "g/g") val suite = sample() - val result = Shard.calculateShardsByTime(reRunTestsToRun, suite, mockArgs(3)) + val result = Shard.createShardsByShardCount(reRunTestsToRun, suite, mockArgs(3)) assertThat(result.size).isEqualTo(3) result.forEach { @@ -81,7 +87,7 @@ class ShardTest { @Test fun firstRun() { val testsToRun = listOf("a", "b", "c") - val result = Shard.calculateShardsByTime(testsToRun, JUnitTestResult(null), mockArgs(2)) + val result = Shard.createShardsByShardCount(testsToRun, JUnitTestResult(null), mockArgs(2)) assertThat(result.size).isEqualTo(2) assertThat(result.sumByDouble { it.time }).isEqualTo(30.0) @@ -94,7 +100,7 @@ class ShardTest { @Test fun mixedNewAndOld() { val testsToRun = listOf("a/a", "b/b", "c/c", "w", "y", "z") - val result = Shard.calculateShardsByTime(testsToRun, sample(), mockArgs(4)) + val result = Shard.createShardsByShardCount(testsToRun, sample(), mockArgs(4)) assertThat(result.size).isEqualTo(4) assertThat(result.sumByDouble { it.time }).isEqualTo(37.0) @@ -111,11 +117,29 @@ class ShardTest { repeat(1_000_000) { index -> testsToRun.add("$index/$index") } val nano = measureNanoTime { - Shard.calculateShardsByTime(testsToRun, JUnitTestResult(null), mockArgs(4)) + Shard.createShardsByShardCount(testsToRun, JUnitTestResult(null), mockArgs(4)) } val ms = TimeUnit.NANOSECONDS.toMillis(nano) println("Shards calculated in $ms ms") assertThat(ms).isLessThan(5000) } + + @Test + fun createShardsByShardTime_workingSample() { + val testsToRun = listOf("a/a", "b/b", "c/c", "d/d", "e/e", "f/f", "g/g") + val suite = sample() + val result = Shard.shardCountByTime(testsToRun, suite, mockArgs(20, 7)) + + assertThat(result).isEqualTo(3) + } + + @Test + fun createShardsByShardTime_countShouldNeverBeHigherThanMaxAvailable() { + val testsToRun = listOf("a/a", "b/b", "c/c", "d/d", "e/e", "f/f", "g/g") + val suite = sample() + val result = Shard.shardCountByTime(testsToRun, suite, mockArgs(2, 7)) + + assertThat(result).isEqualTo(2) + } }