diff --git a/dev/diffs/3.4.3.diff b/dev/diffs/3.4.3.diff
index e9b197e467..22fe2597dc 100644
--- a/dev/diffs/3.4.3.diff
+++ b/dev/diffs/3.4.3.diff
@@ -271,10 +271,15 @@ index 56e9520fdab..917932336df 100644
 spark.range(100).write.saveAsTable(s"$dbName.$table2Name")
 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
-index a9f69ab28a1..4056f14c893 100644
+index a9f69ab28a1..5d9d4f2cb83 100644
 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
 +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
-@@ -43,7 +43,7 @@ import org.apache.spark.sql.connector.FakeV2Provider
+@@ -39,11 +39,12 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Attri
+ import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
+ import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, LocalRelation, LogicalPlan, OneRowRelation, Statistics}
+ import org.apache.spark.sql.catalyst.util.DateTimeUtils
++import org.apache.spark.sql.comet.CometBroadcastExchangeExec
+ import org.apache.spark.sql.connector.FakeV2Provider
 import org.apache.spark.sql.execution.{FilterExec, LogicalRDD, QueryExecution, SortExec, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
@@ -283,7 +288,7 @@ index a9f69ab28a1..4056f14c893 100644
 import org.apache.spark.sql.expressions.{Aggregator, Window}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
-@@ -1981,7 +1981,7 @@ class DataFrameSuite extends QueryTest
+@@ -1981,7 +1982,7 @@ class DataFrameSuite extends QueryTest
 fail("Should not have back to back Aggregates")
 }
 atFirstAgg = true
@@ -292,7 +297,7 @@ index a9f69ab28a1..4056f14c893 100644
 case _ =>
 }
 }
-@@ -2305,7 +2305,7 @@ class DataFrameSuite extends QueryTest
+@@ -2305,7 +2306,7 @@ class DataFrameSuite extends QueryTest
 checkAnswer(join, df)
 assert(
 collect(join.queryExecution.executedPlan) {
@@ -301,7 +306,7 @@ index a9f69ab28a1..4056f14c893 100644
 assert(
 collect(join.queryExecution.executedPlan) { case e: ReusedExchangeExec => true }.size === 1)
 val broadcasted = broadcast(join)
-@@ -2313,7 +2313,7 @@ class DataFrameSuite extends QueryTest
+@@ -2313,10 +2314,12 @@ class DataFrameSuite extends QueryTest
 checkAnswer(join2, df)
 assert(
 collect(join2.queryExecution.executedPlan) {
@@ -309,8 +314,14 @@ index a9f69ab28a1..4056f14c893 100644
 + case _: ShuffleExchangeLike => true }.size == 1)
 assert(
 collect(join2.queryExecution.executedPlan) {
- case e: BroadcastExchangeExec => true }.size === 1)
++ case e: BroadcastExchangeExec => true
++ case _: CometBroadcastExchangeExec => true
++ }.size === 1)
+ assert(
+ collect(join2.queryExecution.executedPlan) { case e: ReusedExchangeExec => true }.size == 4)
+ }
+@@ -2876,7 +2879,7 @@ class DataFrameSuite extends QueryTest
 // Assert that no extra shuffle introduced by cogroup.
val exchanges = collect(df3.queryExecution.executedPlan) { @@ -319,7 +330,7 @@ index a9f69ab28a1..4056f14c893 100644 } assert(exchanges.size == 2) } -@@ -3325,7 +3325,8 @@ class DataFrameSuite extends QueryTest +@@ -3325,7 +3328,8 @@ class DataFrameSuite extends QueryTest assert(df2.isLocal) } @@ -658,17 +669,19 @@ index 1792b4c32eb..1616e6f39bd 100644 assert(shuffleMergeJoins.size == 1) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala -index 7f062bfb899..400e939468f 100644 +index 7f062bfb899..7e4ae6b6677 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala -@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier +@@ -30,7 +30,8 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.expressions.{Ascending, GenericRow, SortOrder} import org.apache.spark.sql.catalyst.plans.logical.Filter +-import org.apache.spark.sql.execution.{BinaryExecNode, FilterExec, ProjectExec, SortExec, SparkPlan, WholeStageCodegenExec} +import org.apache.spark.sql.comet._ - import org.apache.spark.sql.execution.{BinaryExecNode, FilterExec, ProjectExec, SortExec, SparkPlan, WholeStageCodegenExec} ++import org.apache.spark.sql.execution.{BinaryExecNode, ColumnarToRowExec, FilterExec, InputAdapter, ProjectExec, SortExec, SparkPlan, WholeStageCodegenExec} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike} + import org.apache.spark.sql.execution.joins._ @@ -740,7 +741,8 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan } } @@ -679,7 +692,20 @@ index 7f062bfb899..400e939468f 100644 withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1", SQLConf.SORT_MERGE_JOIN_EXEC_BUFFER_IN_MEMORY_THRESHOLD.key -> "0", SQLConf.SORT_MERGE_JOIN_EXEC_BUFFER_SPILL_THRESHOLD.key -> "1") { -@@ -1115,9 +1117,11 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan +@@ -866,10 +868,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan + val physical = df.queryExecution.sparkPlan + val physicalJoins = physical.collect { + case j: SortMergeJoinExec => j ++ case j: CometSortMergeJoinExec => j.originalPlan.asInstanceOf[SortMergeJoinExec] + } + val executed = df.queryExecution.executedPlan + val executedJoins = collect(executed) { + case j: SortMergeJoinExec => j ++ case j: CometSortMergeJoinExec => j.originalPlan.asInstanceOf[SortMergeJoinExec] + } + // This only applies to the above tested queries, in which a child SortMergeJoin always + // contains the SortOrder required by its parent SortMergeJoin. 
Thus, SortExec should never +@@ -1115,9 +1119,11 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan val plan = df1.join(df2.hint("SHUFFLE_HASH"), $"k1" === $"k2", joinType) .groupBy($"k1").count() .queryExecution.executedPlan @@ -693,7 +719,7 @@ index 7f062bfb899..400e939468f 100644 }) } -@@ -1134,10 +1138,11 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan +@@ -1134,10 +1140,11 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan .join(df4.hint("SHUFFLE_MERGE"), $"k1" === $"k4", joinType) .queryExecution .executedPlan @@ -707,7 +733,7 @@ index 7f062bfb899..400e939468f 100644 }) // Test shuffled hash join -@@ -1147,10 +1152,13 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan +@@ -1147,10 +1154,13 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan .join(df4.hint("SHUFFLE_MERGE"), $"k1" === $"k4", joinType) .queryExecution .executedPlan @@ -724,7 +750,7 @@ index 7f062bfb899..400e939468f 100644 }) } -@@ -1241,12 +1249,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan +@@ -1241,12 +1251,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan inputDFs.foreach { case (df1, df2, joinExprs) => val smjDF = df1.join(df2.hint("SHUFFLE_MERGE"), joinExprs, "full") assert(collect(smjDF.queryExecution.executedPlan) { @@ -739,7 +765,35 @@ index 7f062bfb899..400e939468f 100644 // Same result between shuffled hash join and sort merge join checkAnswer(shjDF, smjResult) } -@@ -1341,7 +1349,8 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan +@@ -1282,18 +1292,25 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan + } + + // Test shuffled hash join +- withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { ++ withSQLConf("spark.comet.enabled" -> "true", ++ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { + val shjCodegenDF = df1.join(df2.hint("SHUFFLE_HASH"), $"k1" === $"k2", joinType) + assert(shjCodegenDF.queryExecution.executedPlan.collect { + case WholeStageCodegenExec(_ : ShuffledHashJoinExec) => true + case WholeStageCodegenExec(ProjectExec(_, _ : ShuffledHashJoinExec)) => true ++ case WholeStageCodegenExec(ColumnarToRowExec(InputAdapter(_: CometHashJoinExec))) => ++ true ++ case WholeStageCodegenExec(ColumnarToRowExec( ++ InputAdapter(CometProjectExec(_, _, _, _, _: CometHashJoinExec, _)))) => true + }.size === 1) + checkAnswer(shjCodegenDF, Seq.empty) + + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { + val shjNonCodegenDF = df1.join(df2.hint("SHUFFLE_HASH"), $"k1" === $"k2", joinType) + assert(shjNonCodegenDF.queryExecution.executedPlan.collect { +- case _: ShuffledHashJoinExec => true }.size === 1) ++ case _: ShuffledHashJoinExec => true ++ case _: CometHashJoinExec => true ++ }.size === 1) + checkAnswer(shjNonCodegenDF, Seq.empty) + } + } +@@ -1341,7 +1358,8 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan val plan = sql(getAggQuery(selectExpr, joinType)).queryExecution.executedPlan assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true }.size === 1) // Have shuffle before aggregation @@ -749,7 +803,7 @@ index 7f062bfb899..400e939468f 100644 } def getJoinQuery(selectExpr: String, joinType: String): String = { -@@ -1370,9 +1379,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan +@@ -1370,9 
+1388,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan } val plan = sql(getJoinQuery(selectExpr, joinType)).queryExecution.executedPlan assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true }.size === 1) @@ -764,7 +818,7 @@ index 7f062bfb899..400e939468f 100644 } // Test output ordering is not preserved -@@ -1381,9 +1393,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan +@@ -1381,9 +1402,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan val selectExpr = "/*+ BROADCAST(left_t) */ k1 as k0" val plan = sql(getJoinQuery(selectExpr, joinType)).queryExecution.executedPlan assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true }.size === 1) @@ -779,7 +833,7 @@ index 7f062bfb899..400e939468f 100644 } // Test singe partition -@@ -1393,7 +1408,8 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan +@@ -1393,7 +1417,8 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan |FROM range(0, 10, 1, 1) t1 FULL OUTER JOIN range(0, 10, 1, 1) t2 |""".stripMargin) val plan = fullJoinDF.queryExecution.executedPlan @@ -789,6 +843,28 @@ index 7f062bfb899..400e939468f 100644 checkAnswer(fullJoinDF, Row(100)) } } +@@ -1438,6 +1463,9 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan + Seq(semiJoinDF, antiJoinDF).foreach { df => + assert(collect(df.queryExecution.executedPlan) { + case j: ShuffledHashJoinExec if j.ignoreDuplicatedKey == ignoreDuplicatedKey => true ++ case j: CometHashJoinExec ++ if j.originalPlan.asInstanceOf[ShuffledHashJoinExec].ignoreDuplicatedKey == ++ ignoreDuplicatedKey => true + }.size == 1) + } + } +@@ -1489,7 +1517,10 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan + + test("SPARK-43113: Full outer join with duplicate stream-side references in condition (SHJ)") { + def check(plan: SparkPlan): Unit = { +- assert(collect(plan) { case _: ShuffledHashJoinExec => true }.size === 1) ++ assert(collect(plan) { ++ case _: ShuffledHashJoinExec => true ++ case _: CometHashJoinExec => true ++ }.size === 1) + } + dupStreamSideColTest("SHUFFLE_HASH", check) + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala index b5b34922694..a72403780c4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala @@ -1292,7 +1368,7 @@ index ac710c32296..baae214c6ee 100644 val df = spark.read.parquet(path).selectExpr(projection: _*) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala -index 593bd7bb4ba..7ad55e3ab20 100644 +index 593bd7bb4ba..32af28b0238 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -26,9 +26,11 @@ import org.scalatest.time.SpanSugar._ @@ -1316,7 +1392,7 @@ index 593bd7bb4ba..7ad55e3ab20 100644 } } -@@ -116,30 +119,38 @@ class AdaptiveQueryExecSuite +@@ -116,30 +119,39 @@ class AdaptiveQueryExecSuite private def findTopLevelSortMergeJoin(plan: SparkPlan): Seq[SortMergeJoinExec] = { collect(plan) { case j: SortMergeJoinExec => j @@ -1338,6 +1414,7 @@ index 593bd7bb4ba..7ad55e3ab20 100644 case j: BaseJoinExec => j 
+ case c: CometHashJoinExec => c.originalPlan.asInstanceOf[BaseJoinExec] + case c: CometSortMergeJoinExec => c.originalPlan.asInstanceOf[BaseJoinExec] ++ case c: CometBroadcastHashJoinExec => c.originalPlan.asInstanceOf[BaseJoinExec] } } @@ -1355,7 +1432,7 @@ index 593bd7bb4ba..7ad55e3ab20 100644 } } -@@ -176,6 +187,7 @@ class AdaptiveQueryExecSuite +@@ -176,6 +188,7 @@ class AdaptiveQueryExecSuite val parts = rdd.partitions assert(parts.forall(rdd.preferredLocations(_).nonEmpty)) } @@ -1363,7 +1440,7 @@ index 593bd7bb4ba..7ad55e3ab20 100644 assert(numShuffles === (numLocalReads.length + numShufflesWithoutLocalRead)) } -@@ -184,7 +196,7 @@ class AdaptiveQueryExecSuite +@@ -184,7 +197,7 @@ class AdaptiveQueryExecSuite val plan = df.queryExecution.executedPlan assert(plan.isInstanceOf[AdaptiveSparkPlanExec]) val shuffle = plan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect { @@ -1372,7 +1449,7 @@ index 593bd7bb4ba..7ad55e3ab20 100644 } assert(shuffle.size == 1) assert(shuffle(0).outputPartitioning.numPartitions == numPartition) -@@ -200,7 +212,8 @@ class AdaptiveQueryExecSuite +@@ -200,7 +213,8 @@ class AdaptiveQueryExecSuite assert(smj.size == 1) val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) assert(bhj.size == 1) @@ -1382,7 +1459,7 @@ index 593bd7bb4ba..7ad55e3ab20 100644 } } -@@ -227,7 +240,8 @@ class AdaptiveQueryExecSuite +@@ -227,7 +241,8 @@ class AdaptiveQueryExecSuite } } @@ -1392,7 +1469,7 @@ index 593bd7bb4ba..7ad55e3ab20 100644 withSQLConf( SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80", -@@ -259,7 +273,8 @@ class AdaptiveQueryExecSuite +@@ -259,7 +274,8 @@ class AdaptiveQueryExecSuite } } @@ -1402,7 +1479,7 @@ index 593bd7bb4ba..7ad55e3ab20 100644 withSQLConf( SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80", -@@ -273,7 +288,8 @@ class AdaptiveQueryExecSuite +@@ -273,7 +289,8 @@ class AdaptiveQueryExecSuite val localReads = collect(adaptivePlan) { case read: AQEShuffleReadExec if read.isLocalRead => read } @@ -1412,7 +1489,29 @@ index 593bd7bb4ba..7ad55e3ab20 100644 val localShuffleRDD0 = localReads(0).execute().asInstanceOf[ShuffledRowRDD] val localShuffleRDD1 = localReads(1).execute().asInstanceOf[ShuffledRowRDD] // the final parallelism is math.max(1, numReduces / numMappers): math.max(1, 5/2) = 2 -@@ -322,7 +338,7 @@ class AdaptiveQueryExecSuite +@@ -298,7 +315,9 @@ class AdaptiveQueryExecSuite + .groupBy($"a").count() + checkAnswer(testDf, Seq()) + val plan = testDf.queryExecution.executedPlan +- assert(find(plan)(_.isInstanceOf[SortMergeJoinExec]).isDefined) ++ assert(find(plan) { case p => ++ p.isInstanceOf[SortMergeJoinExec] || p.isInstanceOf[CometSortMergeJoinExec] ++ }.isDefined) + val coalescedReads = collect(plan) { + case r: AQEShuffleReadExec => r + } +@@ -312,7 +331,9 @@ class AdaptiveQueryExecSuite + .groupBy($"a").count() + checkAnswer(testDf, Seq()) + val plan = testDf.queryExecution.executedPlan +- assert(find(plan)(_.isInstanceOf[BroadcastHashJoinExec]).isDefined) ++ assert(find(plan) { case p => ++ p.isInstanceOf[BroadcastHashJoinExec] || p.isInstanceOf[CometBroadcastHashJoinExec] ++ }.isDefined) + val coalescedReads = collect(plan) { + case r: AQEShuffleReadExec => r + } +@@ -322,7 +343,7 @@ class AdaptiveQueryExecSuite } } @@ -1421,7 +1520,7 @@ index 593bd7bb4ba..7ad55e3ab20 100644 withSQLConf( SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { -@@ -337,7 +353,7 @@ class 
AdaptiveQueryExecSuite +@@ -337,7 +358,7 @@ class AdaptiveQueryExecSuite } } @@ -1430,7 +1529,7 @@ index 593bd7bb4ba..7ad55e3ab20 100644 withSQLConf( SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { -@@ -353,7 +369,7 @@ class AdaptiveQueryExecSuite +@@ -353,7 +374,7 @@ class AdaptiveQueryExecSuite } } @@ -1439,7 +1538,7 @@ index 593bd7bb4ba..7ad55e3ab20 100644 withSQLConf( SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { -@@ -398,7 +414,7 @@ class AdaptiveQueryExecSuite +@@ -398,7 +419,7 @@ class AdaptiveQueryExecSuite } } @@ -1448,7 +1547,7 @@ index 593bd7bb4ba..7ad55e3ab20 100644 withSQLConf( SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { -@@ -443,7 +459,7 @@ class AdaptiveQueryExecSuite +@@ -443,7 +464,7 @@ class AdaptiveQueryExecSuite } } @@ -1457,7 +1556,16 @@ index 593bd7bb4ba..7ad55e3ab20 100644 withSQLConf( SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "500") { -@@ -508,7 +524,7 @@ class AdaptiveQueryExecSuite +@@ -489,7 +510,7 @@ class AdaptiveQueryExecSuite + } + } + +- test("Exchange reuse") { ++ test("Exchange reuse", IgnoreComet("Comet shuffle changes shuffle metrics")) { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { +@@ -508,7 +529,7 @@ class AdaptiveQueryExecSuite } } @@ -1466,7 +1574,7 @@ index 593bd7bb4ba..7ad55e3ab20 100644 withSQLConf( SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { -@@ -539,7 +555,9 @@ class AdaptiveQueryExecSuite +@@ -539,7 +560,9 @@ class AdaptiveQueryExecSuite assert(smj.size == 1) val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) assert(bhj.size == 1) @@ -1477,7 +1585,7 @@ index 593bd7bb4ba..7ad55e3ab20 100644 // Even with local shuffle read, the query stage reuse can also work. val ex = findReusedExchange(adaptivePlan) assert(ex.nonEmpty) -@@ -560,7 +578,9 @@ class AdaptiveQueryExecSuite +@@ -560,7 +583,9 @@ class AdaptiveQueryExecSuite assert(smj.size == 1) val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) assert(bhj.size == 1) @@ -1488,7 +1596,7 @@ index 593bd7bb4ba..7ad55e3ab20 100644 // Even with local shuffle read, the query stage reuse can also work. val ex = findReusedExchange(adaptivePlan) assert(ex.isEmpty) -@@ -569,7 +589,8 @@ class AdaptiveQueryExecSuite +@@ -569,7 +594,8 @@ class AdaptiveQueryExecSuite } } @@ -1498,7 +1606,7 @@ index 593bd7bb4ba..7ad55e3ab20 100644 withSQLConf( SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "20000000", -@@ -664,7 +685,8 @@ class AdaptiveQueryExecSuite +@@ -664,7 +690,8 @@ class AdaptiveQueryExecSuite val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) assert(bhj.size == 1) // There is still a SMJ, and its two shuffles can't apply local read. 
@@ -1508,7 +1616,7 @@ index 593bd7bb4ba..7ad55e3ab20 100644 } } -@@ -786,7 +808,8 @@ class AdaptiveQueryExecSuite +@@ -786,7 +813,8 @@ class AdaptiveQueryExecSuite } } @@ -1518,7 +1626,7 @@ index 593bd7bb4ba..7ad55e3ab20 100644 Seq("SHUFFLE_MERGE", "SHUFFLE_HASH").foreach { joinHint => def getJoinNode(plan: SparkPlan): Seq[ShuffledJoin] = if (joinHint == "SHUFFLE_MERGE") { findTopLevelSortMergeJoin(plan) -@@ -1004,7 +1027,8 @@ class AdaptiveQueryExecSuite +@@ -1004,7 +1032,8 @@ class AdaptiveQueryExecSuite } } @@ -1528,7 +1636,7 @@ index 593bd7bb4ba..7ad55e3ab20 100644 withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { val (_, adaptivePlan) = runAdaptiveAndVerifyResult( "SELECT key FROM testData GROUP BY key") -@@ -1599,7 +1623,7 @@ class AdaptiveQueryExecSuite +@@ -1599,7 +1628,7 @@ class AdaptiveQueryExecSuite val (_, adaptivePlan) = runAdaptiveAndVerifyResult( "SELECT id FROM v1 GROUP BY id DISTRIBUTE BY id") assert(collect(adaptivePlan) { @@ -1537,7 +1645,7 @@ index 593bd7bb4ba..7ad55e3ab20 100644 }.length == 1) } } -@@ -1679,7 +1703,8 @@ class AdaptiveQueryExecSuite +@@ -1679,7 +1708,8 @@ class AdaptiveQueryExecSuite } } @@ -1547,7 +1655,7 @@ index 593bd7bb4ba..7ad55e3ab20 100644 def hasRepartitionShuffle(plan: SparkPlan): Boolean = { find(plan) { case s: ShuffleExchangeLike => -@@ -1864,6 +1889,9 @@ class AdaptiveQueryExecSuite +@@ -1864,6 +1894,9 @@ class AdaptiveQueryExecSuite def checkNoCoalescePartitions(ds: Dataset[Row], origin: ShuffleOrigin): Unit = { assert(collect(ds.queryExecution.executedPlan) { case s: ShuffleExchangeExec if s.shuffleOrigin == origin && s.numPartitions == 2 => s @@ -1557,7 +1665,7 @@ index 593bd7bb4ba..7ad55e3ab20 100644 }.size == 1) ds.collect() val plan = ds.queryExecution.executedPlan -@@ -1872,6 +1900,9 @@ class AdaptiveQueryExecSuite +@@ -1872,6 +1905,9 @@ class AdaptiveQueryExecSuite }.isEmpty) assert(collect(plan) { case s: ShuffleExchangeExec if s.shuffleOrigin == origin && s.numPartitions == 2 => s @@ -1567,7 +1675,7 @@ index 593bd7bb4ba..7ad55e3ab20 100644 }.size == 1) checkAnswer(ds, testData) } -@@ -2028,7 +2059,8 @@ class AdaptiveQueryExecSuite +@@ -2028,7 +2064,8 @@ class AdaptiveQueryExecSuite } } @@ -1577,7 +1685,7 @@ index 593bd7bb4ba..7ad55e3ab20 100644 withTempView("t1", "t2") { def checkJoinStrategy(shouldShuffleHashJoin: Boolean): Unit = { Seq("100", "100000").foreach { size => -@@ -2114,7 +2146,8 @@ class AdaptiveQueryExecSuite +@@ -2114,7 +2151,8 @@ class AdaptiveQueryExecSuite } } @@ -1587,7 +1695,7 @@ index 593bd7bb4ba..7ad55e3ab20 100644 withTempView("v") { withSQLConf( SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", -@@ -2213,7 +2246,7 @@ class AdaptiveQueryExecSuite +@@ -2213,7 +2251,7 @@ class AdaptiveQueryExecSuite runAdaptiveAndVerifyResult(s"SELECT $repartition key1 FROM skewData1 " + s"JOIN skewData2 ON key1 = key2 GROUP BY key1") val shuffles1 = collect(adaptive1) { @@ -1596,7 +1704,7 @@ index 593bd7bb4ba..7ad55e3ab20 100644 } assert(shuffles1.size == 3) // shuffles1.head is the top-level shuffle under the Aggregate operator -@@ -2226,7 +2259,7 @@ class AdaptiveQueryExecSuite +@@ -2226,7 +2264,7 @@ class AdaptiveQueryExecSuite runAdaptiveAndVerifyResult(s"SELECT $repartition key1 FROM skewData1 " + s"JOIN skewData2 ON key1 = key2") val shuffles2 = collect(adaptive2) { @@ -1605,7 +1713,7 @@ index 593bd7bb4ba..7ad55e3ab20 100644 } if (hasRequiredDistribution) { assert(shuffles2.size == 3) -@@ -2260,7 +2293,8 @@ class AdaptiveQueryExecSuite +@@ -2260,7 +2298,8 @@ class AdaptiveQueryExecSuite } } @@ -1615,7 
+1723,7 @@ index 593bd7bb4ba..7ad55e3ab20 100644 CostEvaluator.instantiate( classOf[SimpleShuffleSortCostEvaluator].getCanonicalName, spark.sparkContext.getConf) intercept[IllegalArgumentException] { -@@ -2404,6 +2438,7 @@ class AdaptiveQueryExecSuite +@@ -2404,6 +2443,7 @@ class AdaptiveQueryExecSuite val (_, adaptive) = runAdaptiveAndVerifyResult(query) assert(adaptive.collect { case sort: SortExec => sort @@ -1623,7 +1731,7 @@ index 593bd7bb4ba..7ad55e3ab20 100644 }.size == 1) val read = collect(adaptive) { case read: AQEShuffleReadExec => read -@@ -2421,7 +2456,8 @@ class AdaptiveQueryExecSuite +@@ -2421,7 +2461,8 @@ class AdaptiveQueryExecSuite } } @@ -1633,7 +1741,7 @@ index 593bd7bb4ba..7ad55e3ab20 100644 withTempView("v") { withSQLConf( SQLConf.ADAPTIVE_OPTIMIZE_SKEWS_IN_REBALANCE_PARTITIONS_ENABLED.key -> "true", -@@ -2533,7 +2569,7 @@ class AdaptiveQueryExecSuite +@@ -2533,7 +2574,7 @@ class AdaptiveQueryExecSuite runAdaptiveAndVerifyResult("SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 " + "JOIN skewData3 ON value2 = value3") val shuffles1 = collect(adaptive1) { @@ -1642,7 +1750,7 @@ index 593bd7bb4ba..7ad55e3ab20 100644 } assert(shuffles1.size == 4) val smj1 = findTopLevelSortMergeJoin(adaptive1) -@@ -2544,7 +2580,7 @@ class AdaptiveQueryExecSuite +@@ -2544,7 +2585,7 @@ class AdaptiveQueryExecSuite runAdaptiveAndVerifyResult("SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 " + "JOIN skewData3 ON value1 = value3") val shuffles2 = collect(adaptive2) { @@ -1672,10 +1780,10 @@ index bd9c79e5b96..ab7584e768e 100644 assert(fileSourceScanSchemata.size === expectedSchemaCatalogStrings.size, s"Found ${fileSourceScanSchemata.size} file sources in dataframe, " + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala -index ce43edb79c1..c414b19eda7 100644 +index ce43edb79c1..8436cb727c6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala -@@ -17,7 +17,7 @@ +@@ -17,9 +17,10 @@ package org.apache.spark.sql.execution.datasources @@ -1683,8 +1791,27 @@ index ce43edb79c1..c414b19eda7 100644 +import org.apache.spark.sql.{IgnoreComet, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, NullsFirst, SortOrder} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Sort} ++import org.apache.spark.sql.comet.CometSortExec import org.apache.spark.sql.execution.{QueryExecution, SortExec} -@@ -305,7 +305,8 @@ class V1WriteCommandSuite extends QueryTest with SharedSparkSession with V1Write + import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec + import org.apache.spark.sql.internal.SQLConf +@@ -225,6 +226,7 @@ class V1WriteCommandSuite extends QueryTest with SharedSparkSession with V1Write + // assert the outer most sort in the executed plan + assert(plan.collectFirst { + case s: SortExec => s ++ case s: CometSortExec => s.originalPlan.asInstanceOf[SortExec] + }.exists { + case SortExec(Seq( + SortOrder(AttributeReference("key", IntegerType, _, _), Ascending, NullsFirst, _), +@@ -272,6 +274,7 @@ class V1WriteCommandSuite extends QueryTest with SharedSparkSession with V1Write + // assert the outer most sort in the executed plan + assert(plan.collectFirst { + case s: SortExec => s ++ case s: CometSortExec => 
s.originalPlan.asInstanceOf[SortExec] + }.exists { + case SortExec(Seq( + SortOrder(AttributeReference("value", StringType, _, _), Ascending, NullsFirst, _), +@@ -305,7 +308,8 @@ class V1WriteCommandSuite extends QueryTest with SharedSparkSession with V1Write } }