From a53f0b54f4102dc499e4c4ab0fdb3bc445932f76 Mon Sep 17 00:00:00 2001 From: Parth Chandra Date: Thu, 23 Jan 2025 15:03:55 -0800 Subject: [PATCH] fix: address post merge comet-parquet-exec review comments (#1327) - remove CometArrowUtils - add (ignored) v2 tests for get_struct_field ## Which issue does this PR close? Addresses review comments from: https://github.com/apache/datafusion-comet/pull/1318 --- .../comet/parquet/NativeBatchReader.java | 4 +- .../spark/sql/comet/CometArrowUtils.scala | 180 ------------------ .../apache/comet/CometExpressionSuite.scala | 57 +++--- 3 files changed, 35 insertions(+), 206 deletions(-) delete mode 100644 common/src/main/scala/org/apache/spark/sql/comet/CometArrowUtils.scala diff --git a/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java b/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java index 9d79b707a5..b7b0285ac3 100644 --- a/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java +++ b/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java @@ -57,8 +57,8 @@ import org.apache.spark.TaskContext$; import org.apache.spark.executor.TaskMetrics; import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.comet.CometArrowUtils; import org.apache.spark.sql.comet.parquet.CometParquetReadSupport; +import org.apache.spark.sql.comet.util.Utils$; import org.apache.spark.sql.execution.datasources.PartitionedFile; import org.apache.spark.sql.execution.datasources.parquet.ParquetToSparkSchemaConverter; import org.apache.spark.sql.execution.metric.SQLMetric; @@ -260,7 +260,7 @@ public void init() throws URISyntaxException, IOException { } ////// End get requested schema String timeZoneId = conf.get("spark.sql.session.timeZone"); - Schema arrowSchema = CometArrowUtils.toArrowSchema(sparkSchema, timeZoneId); + Schema arrowSchema = Utils$.MODULE$.toArrowSchema(sparkSchema, timeZoneId); ByteArrayOutputStream out = new ByteArrayOutputStream(); WriteChannel writeChannel = new WriteChannel(Channels.newChannel(out)); MessageSerializer.serialize(writeChannel, arrowSchema); diff --git a/common/src/main/scala/org/apache/spark/sql/comet/CometArrowUtils.scala b/common/src/main/scala/org/apache/spark/sql/comet/CometArrowUtils.scala deleted file mode 100644 index 2f4f55fc0b..0000000000 --- a/common/src/main/scala/org/apache/spark/sql/comet/CometArrowUtils.scala +++ /dev/null @@ -1,180 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -package org.apache.spark.sql.comet - -import scala.collection.JavaConverters._ - -import org.apache.arrow.memory.RootAllocator -import org.apache.arrow.vector.complex.MapVector -import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, IntervalUnit, TimeUnit} -import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types._ - -object CometArrowUtils { - - val rootAllocator = new RootAllocator(Long.MaxValue) - - // todo: support more types. - - /** Maps data type from Spark to Arrow. NOTE: timeZoneId required for TimestampTypes */ - def toArrowType(dt: DataType, timeZoneId: String): ArrowType = dt match { - case BooleanType => ArrowType.Bool.INSTANCE - case ByteType => new ArrowType.Int(8, true) - case ShortType => new ArrowType.Int(8 * 2, true) - case IntegerType => new ArrowType.Int(8 * 4, true) - case LongType => new ArrowType.Int(8 * 8, true) - case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE) - case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE) - case StringType => ArrowType.Utf8.INSTANCE - case BinaryType => ArrowType.Binary.INSTANCE - case DecimalType.Fixed(precision, scale) => new ArrowType.Decimal(precision, scale) - case DateType => new ArrowType.Date(DateUnit.DAY) - case TimestampType if timeZoneId == null => - throw new IllegalStateException("Missing timezoneId where it is mandatory.") - case TimestampType => new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId) - case TimestampNTZType => - new ArrowType.Timestamp(TimeUnit.MICROSECOND, null) - case NullType => ArrowType.Null.INSTANCE - case _: YearMonthIntervalType => new ArrowType.Interval(IntervalUnit.YEAR_MONTH) - case _: DayTimeIntervalType => new ArrowType.Duration(TimeUnit.MICROSECOND) - case _ => - throw new IllegalArgumentException() - } - - def fromArrowType(dt: ArrowType): DataType = dt match { - case ArrowType.Bool.INSTANCE => BooleanType - case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 => ByteType - case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 2 => ShortType - case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 4 => IntegerType - case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 8 => LongType - case float: ArrowType.FloatingPoint - if float.getPrecision() == FloatingPointPrecision.SINGLE => - FloatType - case float: ArrowType.FloatingPoint - if float.getPrecision() == FloatingPointPrecision.DOUBLE => - DoubleType - case ArrowType.Utf8.INSTANCE => StringType - case ArrowType.Binary.INSTANCE => BinaryType - case d: ArrowType.Decimal => DecimalType(d.getPrecision, d.getScale) - case date: ArrowType.Date if date.getUnit == DateUnit.DAY => DateType - case ts: ArrowType.Timestamp - if ts.getUnit == TimeUnit.MICROSECOND && ts.getTimezone == null => - TimestampNTZType - case ts: ArrowType.Timestamp if ts.getUnit == TimeUnit.MICROSECOND => TimestampType - case ArrowType.Null.INSTANCE => NullType - case yi: ArrowType.Interval if yi.getUnit == IntervalUnit.YEAR_MONTH => - YearMonthIntervalType() - case di: ArrowType.Duration if di.getUnit == TimeUnit.MICROSECOND => DayTimeIntervalType() - case _ => throw new IllegalArgumentException() - // throw QueryExecutionErrors.unsupportedArrowTypeError(dt) - } - - /** Maps field from Spark to Arrow. 
NOTE: timeZoneId required for TimestampType */ - def toArrowField(name: String, dt: DataType, nullable: Boolean, timeZoneId: String): Field = { - dt match { - case ArrayType(elementType, containsNull) => - val fieldType = new FieldType(nullable, ArrowType.List.INSTANCE, null) - new Field( - name, - fieldType, - Seq(toArrowField("element", elementType, containsNull, timeZoneId)).asJava) - case StructType(fields) => - val fieldType = new FieldType(nullable, ArrowType.Struct.INSTANCE, null) - new Field( - name, - fieldType, - fields - .map { field => - toArrowField(field.name, field.dataType, field.nullable, timeZoneId) - } - .toSeq - .asJava) - case MapType(keyType, valueType, valueContainsNull) => - val mapType = new FieldType(nullable, new ArrowType.Map(false), null) - // Note: Map Type struct can not be null, Struct Type key field can not be null - new Field( - name, - mapType, - Seq( - toArrowField( - MapVector.DATA_VECTOR_NAME, - new StructType() - .add(MapVector.KEY_NAME, keyType, nullable = false) - .add(MapVector.VALUE_NAME, valueType, nullable = valueContainsNull), - nullable = false, - timeZoneId)).asJava) - case udt: UserDefinedType[_] => toArrowField(name, udt.sqlType, nullable, timeZoneId) - case dataType => - val fieldType = new FieldType(nullable, toArrowType(dataType, timeZoneId), null) - new Field(name, fieldType, Seq.empty[Field].asJava) - } - } - - def fromArrowField(field: Field): DataType = { - field.getType match { - case _: ArrowType.Map => - val elementField = field.getChildren.get(0) - val keyType = fromArrowField(elementField.getChildren.get(0)) - val valueType = fromArrowField(elementField.getChildren.get(1)) - MapType(keyType, valueType, elementField.getChildren.get(1).isNullable) - case ArrowType.List.INSTANCE => - val elementField = field.getChildren().get(0) - val elementType = fromArrowField(elementField) - ArrayType(elementType, containsNull = elementField.isNullable) - case ArrowType.Struct.INSTANCE => - val fields = field.getChildren().asScala.map { child => - val dt = fromArrowField(child) - StructField(child.getName, dt, child.isNullable) - } - StructType(fields.toArray) - case arrowType => fromArrowType(arrowType) - } - } - - /** - * Maps schema from Spark to Arrow. 
NOTE: timeZoneId required for TimestampType in StructType - */ - def toArrowSchema(schema: StructType, timeZoneId: String): Schema = { - new Schema(schema.map { field => - toArrowField(field.name, field.dataType, field.nullable, timeZoneId) - }.asJava) - } - - def fromArrowSchema(schema: Schema): StructType = { - StructType(schema.getFields.asScala.map { field => - val dt = fromArrowField(field) - StructField(field.getName, dt, field.isNullable) - }.toArray) - } - - /** Return Map with conf settings to be used in ArrowPythonRunner */ - def getPythonRunnerConfMap(conf: SQLConf): Map[String, String] = { - val timeZoneConf = Seq(SQLConf.SESSION_LOCAL_TIMEZONE.key -> conf.sessionLocalTimeZone) - val pandasColsByName = Seq( - SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME.key -> - conf.pandasGroupedMapAssignColumnsByName.toString) - val arrowSafeTypeCheck = Seq( - SQLConf.PANDAS_ARROW_SAFE_TYPE_CONVERSION.key -> - conf.arrowSafeTypeConversion.toString) - Map(timeZoneConf ++ pandasColsByName ++ arrowSafeTypeCheck: _*) - } - -} diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 99598174cf..b43fa77286 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -2307,7 +2307,22 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } - test("get_struct_field with DataFusion ParquetExec - simple case") { + private def testV1AndV2(testName: String)(f: => Unit): Unit = { + test(s"$testName - V1") { + withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "parquet") { f } + } + + // The test will fail because it will produce a different plan and the operator check will fail + // We could get the test to pass anyway by skipping the operator check, but when V2 does get supported, + // we want to make sure we enable the operator check and marking the test as ignore will make it + // more obvious + // + ignore(s"$testName - V2") { + withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") { f } + } + } + + testV1AndV2("get_struct_field with DataFusion ParquetExec - simple case") { withTempPath { dir => // create input file with Comet disabled withSQLConf(CometConf.COMET_ENABLED.key -> "false") { @@ -2320,21 +2335,18 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { df.write.parquet(dir.toString()) } - Seq("parquet").foreach { v1List => - withSQLConf( - SQLConf.USE_V1_SOURCE_LIST.key -> v1List, - CometConf.COMET_ENABLED.key -> "true", - CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_DATAFUSION, - CometConf.COMET_EXPLAIN_FALLBACK_ENABLED.key -> "true") { + withSQLConf( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_DATAFUSION, + CometConf.COMET_EXPLAIN_FALLBACK_ENABLED.key -> "true") { - val df = spark.read.parquet(dir.toString()) - checkSparkAnswerAndOperator(df.select("nested1.id")) - } + val df = spark.read.parquet(dir.toString()) + checkSparkAnswerAndOperator(df.select("nested1.id")) } } } - test("get_struct_field with DataFusion ParquetExec - select subset of struct") { + testV1AndV2("get_struct_field with DataFusion ParquetExec - select subset of struct") { withTempPath { dir => // create input file with Comet disabled withSQLConf(CometConf.COMET_ENABLED.key -> "false") { @@ -2353,22 +2365,19 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { 
df.write.parquet(dir.toString()) } - Seq("parquet").foreach { v1List => - withSQLConf( - SQLConf.USE_V1_SOURCE_LIST.key -> v1List, - CometConf.COMET_ENABLED.key -> "true", - CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_DATAFUSION, - CometConf.COMET_EXPLAIN_FALLBACK_ENABLED.key -> "true") { + withSQLConf( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_DATAFUSION, + CometConf.COMET_EXPLAIN_FALLBACK_ENABLED.key -> "true") { - val df = spark.read.parquet(dir.toString()) + val df = spark.read.parquet(dir.toString()) - checkSparkAnswerAndOperator(df.select("nested1.id")) + checkSparkAnswerAndOperator(df.select("nested1.id")) - checkSparkAnswerAndOperator(df.select("nested1.id", "nested1.nested2.id")) + checkSparkAnswerAndOperator(df.select("nested1.id", "nested1.nested2.id")) - // unsupported cast from Int64 to Struct([Field { name: "id", data_type: Int64, ... - // checkSparkAnswerAndOperator(df.select("nested1.nested2.id")) - } + // unsupported cast from Int64 to Struct([Field { name: "id", data_type: Int64, ... + // checkSparkAnswerAndOperator(df.select("nested1.nested2.id")) } } } @@ -2393,7 +2402,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { df.write.parquet(dir.toString()) } - Seq("parquet").foreach { v1List => + Seq("", "parquet").foreach { v1List => withSQLConf( SQLConf.USE_V1_SOURCE_LIST.key -> v1List, CometConf.COMET_ENABLED.key -> "true",
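
For context on the `NativeBatchReader` change above: a Scala `object` such as `org.apache.spark.sql.comet.util.Utils` compiles to a class named `Utils$` whose singleton instance is exposed to Java through the static `MODULE$` field, which is why the call site now reads `Utils$.MODULE$.toArrowSchema(sparkSchema, timeZoneId)`. The sketch below is a minimal illustration of that call pattern and of the IPC serialization that follows it in `init()`; the `ArrowSchemaExample` class name and the sample nested schema are hypothetical, and `toArrowSchema(StructType, String)` is assumed to keep the signature used in the patch.

```java
// Minimal sketch (hypothetical example, not part of the patch): call the Scala
// helper object from Java and serialize the resulting Arrow schema, mirroring
// what NativeBatchReader.init() does with the requested schema.
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.channels.Channels;

import org.apache.arrow.vector.ipc.WriteChannel;
import org.apache.arrow.vector.ipc.message.MessageSerializer;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.spark.sql.comet.util.Utils$; // Scala `object Utils`, seen from Java as Utils$
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public class ArrowSchemaExample {
  public static void main(String[] args) throws IOException {
    // Spark schema with a nested struct, similar to the get_struct_field tests.
    StructType nested = new StructType().add("id", DataTypes.LongType);
    StructType sparkSchema = new StructType().add("nested1", nested);

    // Scala `object Utils` compiles to class Utils$ with a static MODULE$ singleton.
    // The timezone id is required when the schema contains TimestampType fields.
    Schema arrowSchema = Utils$.MODULE$.toArrowSchema(sparkSchema, "UTC");

    // Serialize the schema to Arrow IPC bytes, as the reader does in init().
    ByteArrayOutputStream out = new ByteArrayOutputStream();
    WriteChannel writeChannel = new WriteChannel(Channels.newChannel(out));
    MessageSerializer.serialize(writeChannel, arrowSchema);
    System.out.println("Serialized Arrow schema: " + out.size() + " bytes");
  }
}
```

Because the same Spark-to-Arrow conversion already lives in `org.apache.spark.sql.comet.util.Utils`, the Java reader can call that object directly, which is what allows the patch to delete `CometArrowUtils` outright.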