From cef4884ec7871c938ebe1db9d19a9960eb9a9419 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Fri, 2 Aug 2024 12:14:13 -0700
Subject: [PATCH] fix: Fallback to Spark for unsupported input besides ordering

---
 .../apache/comet/serde/QueryPlanSerde.scala   |  9 +++--
 .../exec/CometColumnarShuffleSuite.scala      | 38 ++++++++++++++-----
 2 files changed, 35 insertions(+), 12 deletions(-)

diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index f8f53ad2c0..3fbe6d7443 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -2926,7 +2926,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
       case HashPartitioning(expressions, _) =>
         val supported =
           expressions.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
-            expressions.forall(e => supportedDataType(e.dataType))
+            expressions.forall(e => supportedDataType(e.dataType)) &&
+            inputs.forall(attr => supportedDataType(attr.dataType))
         if (!supported) {
           msg = s"unsupported Spark partitioning expressions: $expressions"
         }
@@ -2936,7 +2937,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
       case RangePartitioning(orderings, _) =>
         val supported =
           orderings.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
-            orderings.forall(e => supportedDataType(e.dataType))
+            orderings.forall(e => supportedDataType(e.dataType)) &&
+            inputs.forall(attr => supportedDataType(attr.dataType))
         if (!supported) {
           msg = s"unsupported Spark partitioning expressions: $orderings"
         }
@@ -2975,7 +2977,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
       case HashPartitioning(expressions, _) =>
         val supported =
           expressions.map(QueryPlanSerde.exprToProto(_, inputs)).forall(_.isDefined) &&
-            expressions.forall(e => supportedDataType(e.dataType))
+            expressions.forall(e => supportedDataType(e.dataType)) &&
+            inputs.forall(attr => supportedDataType(attr.dataType))
         if (!supported) {
           msg = s"unsupported Spark partitioning expressions: $expressions"
         }
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala
index 25f2a7537e..9ae882d861 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala
@@ -19,12 +19,14 @@
 
 package org.apache.comet.exec
 
+import scala.util.Random
+
 import org.scalactic.source.Position
 import org.scalatest.Tag
 
 import org.apache.hadoop.fs.Path
 import org.apache.spark.{Partitioner, SparkConf}
-import org.apache.spark.sql.{CometTestBase, DataFrame, Row}
+import org.apache.spark.sql.{CometTestBase, DataFrame, RandomDataGenerator, Row}
 import org.apache.spark.sql.comet.execution.shuffle.{CometShuffleDependency, CometShuffleExchangeExec, CometShuffleManager}
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, AQEShuffleReadExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
@@ -68,17 +70,35 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
 
   test("Unsupported types for SinglePartition should fallback to Spark") {
     checkSparkAnswer(spark.sql("""
-      |SELECT
-      | AVG(null),
-      | COUNT(null),
-      | FIRST(null),
-      | LAST(null),
-      | MAX(null),
-      | MIN(null),
-      | SUM(null)
+        |SELECT
+        | AVG(null),
+        | COUNT(null),
+        | FIRST(null),
+        | LAST(null),
+        | MAX(null),
+        | MIN(null),
+        | SUM(null)
       """.stripMargin))
   }
 
+  test("Fallback to Spark for unsupported input besides ordering") {
+    val dataGenerator = RandomDataGenerator
+      .forType(
+        dataType = NullType,
+        nullable = true,
+        new Random(System.nanoTime()),
+        validJulianDatetime = false)
+      .get
+
+    val schema = new StructType()
+      .add("index", IntegerType, nullable = false)
+      .add("col", NullType, nullable = true)
+    val rdd =
+      spark.sparkContext.parallelize((1 to 20).map(i => Row(i, dataGenerator())))
+    val df = spark.createDataFrame(rdd, schema).orderBy("index").coalesce(1)
+    checkSparkAnswer(df)
+  }
+
   test("Disable Comet shuffle with AQE coalesce partitions enabled") {
     Seq(true, false).foreach { coalescePartitionsEnabled =>
       withSQLConf(