diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index c1d63299e..e7daa4279 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -1072,7 +1072,12 @@ class CometSparkSessionExtensions
   case class EliminateRedundantTransitions(session: SparkSession) extends Rule[SparkPlan] {
     override def apply(plan: SparkPlan): SparkPlan = {
       val eliminatedPlan = plan transformUp {
         case ColumnarToRowExec(sparkToColumnar: CometSparkToColumnarExec) => sparkToColumnar.child
+        // Keep this catch-all below the more specific case above: `transformUp` visits each node
+        // once, so matching it first would leave the redundant transition pair in the plan.
+        case ColumnarToRowExec(child) => CometColumnarToRowExec(child)
+        case CometColumnarToRowExec(sparkToColumnar: CometSparkToColumnarExec) =>
+          sparkToColumnar.child
         case CometSparkToColumnarExec(child: CometSparkToColumnarExec) => child
         // Spark adds `RowToColumnar` under Comet columnar shuffle. But it's redundant as the
         // shuffle takes row-based input.
@@ -1089,6 +1094,8 @@ class CometSparkSessionExtensions
       eliminatedPlan match {
         case ColumnarToRowExec(child: CometCollectLimitExec) =>
           child
+        case CometColumnarToRowExec(child: CometCollectLimitExec) =>
+          child
         case other =>
           other
       }
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometColumnarToRowExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometColumnarToRowExec.scala
new file mode 100644
index 000000000..b50f0cd07
--- /dev/null
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometColumnarToRowExec.scala
@@ -0,0 +1,208 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder, UnsafeProjection}
+import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.execution.{CodegenSupport, ColumnarToRowTransition, SparkPlan}
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.execution.vectorized.WritableColumnVector
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
+import org.apache.spark.util.Utils
+
+/**
+ * Copied from Spark `ColumnarToRowExec`. Comet needs the fix for SPARK-50235 but cannot wait for
+ * the fix to be released in Spark versions. We copy the implementation here to apply the fix.
+ */
+case class CometColumnarToRowExec(child: SparkPlan)
+    extends ColumnarToRowTransition
+    with CodegenSupport {
+  // supportsColumnar requires to be only called on driver side, see also SPARK-37779.
+  assert(Utils.isInRunningSparkTask || child.supportsColumnar)
+
+  override def output: Seq[Attribute] = child.output
+
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+
+  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+
+  // `ColumnarToRowExec` processes the input RDD directly, which is kind of a leaf node in the
+  // codegen stage and needs to do the limit check.
+  protected override def canCheckLimitNotReached: Boolean = true
+
+  override lazy val metrics: Map[String, SQLMetric] = Map(
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
+    "numInputBatches" -> SQLMetrics.createMetric(sparkContext, "number of input batches"))
+
+  /**
+   * Non-codegen path: converts each row of every input batch into an
+   * [[org.apache.spark.sql.catalyst.expressions.UnsafeRow]], updating both metrics per batch.
+   */
+  override def doExecute(): RDD[InternalRow] = {
+    val numOutputRows = longMetric("numOutputRows")
+    val numInputBatches = longMetric("numInputBatches")
+    // This avoids calling `output` in the RDD closure, so that we don't need to include the entire
+    // plan (this) in the closure.
+    val localOutput = this.output
+    child.executeColumnar().mapPartitionsInternal { batches =>
+      val toUnsafe = UnsafeProjection.create(localOutput, localOutput)
+      batches.flatMap { batch =>
+        numInputBatches += 1
+        numOutputRows += batch.numRows()
+        batch.rowIterator().asScala.map(toUnsafe)
+      }
+    }
+  }
+
+  /**
+   * Generate [[ColumnVector]] expressions for our parent to consume as rows. This is called once
+   * per [[ColumnVector]] in the batch.
+   */
+  private def genCodeColumnVector(
+      ctx: CodegenContext,
+      columnVar: String,
+      ordinal: String,
+      dataType: DataType,
+      nullable: Boolean): ExprCode = {
+    val javaType = CodeGenerator.javaType(dataType)
+    val value = CodeGenerator.getValueFromVector(columnVar, dataType, ordinal)
+    val isNullVar = if (nullable) {
+      JavaCode.isNullVariable(ctx.freshName("isNull"))
+    } else {
+      FalseLiteral
+    }
+    val valueVar = ctx.freshName("value")
+    val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]"
+    val code = code"${ctx.registerComment(str)}" + (if (nullable) {
+                                                      code"""
+        boolean $isNullVar = $columnVar.isNullAt($ordinal);
+        $javaType $valueVar = $isNullVar ? ${CodeGenerator.defaultValue(dataType)} : ($value);
+      """
+                                                    } else {
+      code"$javaType $valueVar = $value;"
+    })
+    ExprCode(code, isNullVar, JavaCode.variable(valueVar, dataType))
+  }
+
+  /**
+   * Produce code to process the input iterator as [[ColumnarBatch]]es. This produces an
+   * [[org.apache.spark.sql.catalyst.expressions.UnsafeRow]] for each row in each batch.
+   */
+  override protected def doProduce(ctx: CodegenContext): String = {
+    // PhysicalRDD always just has one input
+    val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];")
+
+    // metrics
+    val numOutputRows = metricTerm(ctx, "numOutputRows")
+    val numInputBatches = metricTerm(ctx, "numInputBatches")
+
+    val columnarBatchClz = classOf[ColumnarBatch].getName
+    val batch = ctx.addMutableState(columnarBatchClz, "batch")
+
+    val idx = ctx.addMutableState(CodeGenerator.JAVA_INT, "batchIdx") // init as batchIdx = 0
+    val columnVectorClzs =
+      child.vectorTypes.getOrElse(Seq.fill(output.indices.size)(classOf[ColumnVector].getName))
+    val (colVars, columnAssigns) = columnVectorClzs.zipWithIndex.map {
+      case (columnVectorClz, i) =>
+        val name = ctx.addMutableState(columnVectorClz, s"colInstance$i")
+        (name, s"$name = ($columnVectorClz) $batch.column($i);")
+    }.unzip
+
+    val nextBatch = ctx.freshName("nextBatch")
+    val nextBatchFuncName = ctx.addNewFunction(
+      nextBatch,
+      s"""
+         |private void $nextBatch() throws java.io.IOException {
+         |  if ($input.hasNext()) {
+         |    $batch = ($columnarBatchClz)$input.next();
+         |    $numInputBatches.add(1);
+         |    $numOutputRows.add($batch.numRows());
+         |    $idx = 0;
+         |    ${columnAssigns.mkString("", "\n", "\n")}
+         |  }
+         |}""".stripMargin)
+
+    ctx.currentVars = null
+    val rowidx = ctx.freshName("rowIdx")
+    val columnsBatchInput = (output zip colVars).map { case (attr, colVar) =>
+      genCodeColumnVector(ctx, colVar, rowidx, attr.dataType, attr.nullable)
+    }
+    val localIdx = ctx.freshName("localIdx")
+    val localEnd = ctx.freshName("localEnd")
+    val numRows = ctx.freshName("numRows")
+    val shouldStop = if (parent.needStopCheck) {
+      s"if (shouldStop()) { $idx = $rowidx + 1; return; }"
+    } else {
+      "// shouldStop check is eliminated"
+    }
+
+    // Comet fix for SPARK-50235: after a batch has been fully consumed, the generated code closes
+    // every column vector that is not a `WritableColumnVector`; a batch still held when the loop
+    // exits is closed as a whole.
+    val writableColumnVectorClz = classOf[WritableColumnVector].getName
+    // Fresh name, like every other generated local in this method, so the clean-up loop index
+    // cannot clash with other generated identifiers.
+    val colIdx = ctx.freshName("colIdx")
+
+    s"""
+       |if ($batch == null) {
+       |  $nextBatchFuncName();
+       |}
+       |while ($limitNotReachedCond $batch != null) {
+       |  int $numRows = $batch.numRows();
+       |  int $localEnd = $numRows - $idx;
+       |  for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) {
+       |    int $rowidx = $idx + $localIdx;
+       |    ${consume(ctx, columnsBatchInput).trim}
+       |    $shouldStop
+       |  }
+       |  $idx = $numRows;
+       |
+       |  // Comet fix for SPARK-50235
+       |  for (int $colIdx = 0; $colIdx < ${colVars.length}; $colIdx++) {
+       |    if (!($batch.column($colIdx) instanceof $writableColumnVectorClz)) {
+       |      $batch.column($colIdx).close();
+       |    }
+       |  }
+       |
+       |  $batch = null;
+       |  $nextBatchFuncName();
+       |}
+       |// Comet fix for SPARK-50235: clean up resources
+       |if ($batch != null) {
+       |  $batch.close();
+       |}
+     """.stripMargin
+  }
+
+  override def inputRDDs(): Seq[RDD[InternalRow]] = {
+    Seq(child.executeColumnar().asInstanceOf[RDD[InternalRow]]) // Hack because of type erasure
+  }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): CometColumnarToRowExec =
+    copy(child = newChild)
+}
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 0d00867d1..7b74efb53 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -28,8 +28,8 @@ import scala.util.Random
 import org.apache.hadoop.fs.Path
 import org.apache.spark.sql.{CometTestBase, DataFrame, Row}
 import org.apache.spark.sql.catalyst.optimizer.SimplifyExtractValueOps
-import org.apache.spark.sql.comet.CometProjectExec
-import org.apache.spark.sql.execution.{ColumnarToRowExec, InputAdapter, ProjectExec, WholeStageCodegenExec}
+import org.apache.spark.sql.comet.{CometColumnarToRowExec, CometProjectExec}
+import org.apache.spark.sql.execution.{InputAdapter, ProjectExec, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -752,7 +752,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
         val project = cometPlan
           .asInstanceOf[WholeStageCodegenExec]
           .child
-          .asInstanceOf[ColumnarToRowExec]
+          .asInstanceOf[CometColumnarToRowExec]
           .child
           .asInstanceOf[InputAdapter]
           .child
diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
index 1709cce61..4ff5acfb6 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
@@ -36,7 +36,7 @@ import org.apache.parquet.hadoop.example.ExampleParquetWriter
 import org.apache.parquet.schema.{MessageType, MessageTypeParser}
 import org.apache.spark._
 import org.apache.spark.internal.config.{MEMORY_OFFHEAP_ENABLED, MEMORY_OFFHEAP_SIZE, SHUFFLE_MANAGER}
-import org.apache.spark.sql.comet.{CometBatchScanExec, CometBroadcastExchangeExec, CometExec, CometScanExec, CometScanWrapper, CometSinkPlaceHolder, CometSparkToColumnarExec}
+import org.apache.spark.sql.comet.{CometBatchScanExec, CometBroadcastExchangeExec, CometColumnarToRowExec, CometExec, CometScanExec, CometScanWrapper, CometSinkPlaceHolder, CometSparkToColumnarExec}
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec}
 import org.apache.spark.sql.execution.{ColumnarToRowExec, ExtendedMode, InputAdapter, SparkPlan, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
@@ -174,6 +174,7 @@ abstract class CometTestBase
     wrapped.foreach {
       case _: CometScanExec | _: CometBatchScanExec =>
       case _: CometSinkPlaceHolder | _: CometScanWrapper =>
+      case _: CometColumnarToRowExec =>
       case _: CometSparkToColumnarExec =>
       case _: CometExec | _: CometShuffleExchangeExec =>
       case _: CometBroadcastExchangeExec =>