
Commit

fix compile failure for 3.2 and 3.3
Huaxin Gao committed May 20, 2024
1 parent fb1ac87 commit 6d87d34
Showing 3 changed files with 164 additions and 106 deletions.
18 changes: 1 addition & 17 deletions common/src/main/java/org/apache/comet/parquet/BatchReader.java
@@ -57,7 +57,6 @@
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.comet.parquet.CometParquetReadSupport;
import org.apache.spark.sql.execution.datasources.PartitionedFile;
import org.apache.spark.sql.execution.datasources.parquet.ParquetToSparkSchemaConverter;
import org.apache.spark.sql.execution.metric.SQLMetric;
import org.apache.spark.sql.types.*;
import org.apache.spark.sql.vectorized.ColumnarBatch;
@@ -257,13 +256,7 @@ public void init() throws URISyntaxException, IOException {
MessageType fileSchema = requestedSchema;

if (sparkSchema == null) {
// TODO: remove this after we drop the support for Spark 3.2 and 3.3
boolean isSpark34 = classExists("org.apache.spark.sql.catalyst.util.ResolveDefaultColumns$");
if (isSpark34) {
sparkSchema = new CometParquetToSparkSchemaConverter(conf).convert(requestedSchema);
} else {
sparkSchema = new ParquetToSparkSchemaConverter(conf).convert(requestedSchema);
}
sparkSchema = new CometParquetToSparkSchemaConverter(conf).convert(requestedSchema);
} else {
requestedSchema =
CometParquetReadSupport.clipParquetSchema(
@@ -586,15 +579,6 @@ public void submitPrefetchTask(ExecutorService threadPool) {
this.prefetchTask = threadPool.submit(new PrefetchTask());
}

private boolean classExists(String className) {
try {
Class<?> cls = Class.forName(className);
return true;
} catch (ClassNotFoundException e) {
return false;
}
}

// A task for prefetching parquet row groups.
private class PrefetchTask implements Callable<Option<Throwable>> {
private long getBytesRead() {
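The BatchReader change above removes the reflection-based classExists probe for Spark 3.4 and always routes schema conversion through Comet's own converter. A minimal, script-style sketch of that call path, assuming Comet and Spark are on the classpath (the Configuration keys, schema string, and printed result below are illustrative assumptions, not part of this commit):

import org.apache.hadoop.conf.Configuration
import org.apache.parquet.schema.MessageTypeParser
import org.apache.spark.sql.internal.SQLConf

import org.apache.comet.parquet.CometParquetToSparkSchemaConverter

val conf = new Configuration()
// These three keys are read via conf.get(...).toBoolean, so they must be present.
conf.setBoolean(SQLConf.PARQUET_BINARY_AS_STRING.key, false)
conf.setBoolean(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, true)
conf.setBoolean(SQLConf.CASE_SENSITIVE.key, false)

val requestedSchema = MessageTypeParser.parseMessageType(
  "message spark_schema { required int32 id; optional binary name (UTF8); }")

// Same code path on Spark 3.2, 3.3 and 3.4: the converter ships with Comet, so no
// runtime probe for version-specific Spark classes is needed.
val sparkSchema = new CometParquetToSparkSchemaConverter(conf).convert(requestedSchema)
println(sparkSchema.catalogString) // expected: struct<id:int,name:string>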
New file: CometParquetColumn.scala (56 additions)
@@ -0,0 +1,56 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.comet.parquet

import org.apache.parquet.column.ColumnDescriptor
import org.apache.parquet.io.GroupColumnIO
import org.apache.parquet.io.PrimitiveColumnIO
import org.apache.parquet.schema.Type.Repetition

import org.apache.spark.sql.types.DataType

/**
* Rich information for a Parquet column together with its SparkSQL type.
*/
case class CometParquetColumn(
sparkType: DataType,
descriptor: Option[ColumnDescriptor], // only set when this is a primitive column
repetitionLevel: Int,
definitionLevel: Int,
required: Boolean,
path: Seq[String],
children: Seq[CometParquetColumn]) {

def isPrimitive: Boolean = descriptor.nonEmpty
}

object CometParquetColumn {
def apply(sparkType: DataType, io: PrimitiveColumnIO): CometParquetColumn = {
this(sparkType, Some(io.getColumnDescriptor), io.getRepetitionLevel,
io.getDefinitionLevel, io.getType.isRepetition(Repetition.REQUIRED),
io.getFieldPath, Seq.empty)
}

def apply(sparkType: DataType, io: GroupColumnIO, children: Seq[CometParquetColumn]): CometParquetColumn = {
this(sparkType, None, io.getRepetitionLevel,
io.getDefinitionLevel, io.getType.isRepetition(Repetition.REQUIRED),
io.getFieldPath, children)
}
}
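The new CometParquetColumn case class mirrors the shape of Spark 3.4's ParquetColumn, so the converter below no longer references a class that is absent from Spark 3.2 and 3.3. A small sketch of building a leaf column from Parquet column IO (the schema and assertions are illustrative assumptions, not part of this commit):

import org.apache.comet.parquet.CometParquetColumn
import org.apache.parquet.io.{ColumnIOFactory, PrimitiveColumnIO}
import org.apache.parquet.schema.MessageTypeParser
import org.apache.spark.sql.types.IntegerType

val schema = MessageTypeParser.parseMessageType(
  "message spark_schema { required int32 id; }")
val root = new ColumnIOFactory().getColumnIO(schema)
val idIO = root.getChild(0).asInstanceOf[PrimitiveColumnIO]

val column = CometParquetColumn(IntegerType, idIO)
assert(column.isPrimitive)       // descriptor is populated for leaf columns
assert(column.required)          // "required int32 id" is non-nullable
assert(column.path == Seq("id"))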
Modified: CometParquetToSparkSchemaConverter (Scala, 107 additions and 89 deletions)
@@ -20,33 +20,116 @@
package org.apache.comet.parquet

import org.apache.hadoop.conf.Configuration
import org.apache.parquet.io.{ColumnIO, GroupColumnIO, PrimitiveColumnIO}
import org.apache.parquet.io.{ColumnIO, ColumnIOFactory, GroupColumnIO, PrimitiveColumnIO}
import org.apache.parquet.schema._
import org.apache.parquet.schema.LogicalTypeAnnotation._
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._
import org.apache.parquet.schema.Type.Repetition._
import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.normalizeFieldName
import org.apache.spark.sql.execution.datasources.parquet.{ParquetColumn, ParquetToSparkSchemaConverter}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

import java.util.Locale

class CometParquetToSparkSchemaConverter(
assumeBinaryIsString: Boolean = SQLConf.PARQUET_BINARY_AS_STRING.defaultValue.get,
assumeInt96IsTimestamp: Boolean = SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get,
caseSensitive: Boolean = SQLConf.CASE_SENSITIVE.defaultValue.get,
inferTimestampNTZ: Boolean = SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED.defaultValue.get,
nanosAsLong: Boolean = SQLConf.LEGACY_PARQUET_NANOS_AS_LONG.defaultValue.get) extends ParquetToSparkSchemaConverter {
inferTimestampNTZ: Boolean = true,
nanosAsLong: Boolean = false) {

def this(conf: Configuration) = {
this(
conf.get(SQLConf.PARQUET_BINARY_AS_STRING.key).toBoolean,
conf.get(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key).toBoolean,
conf.get(SQLConf.CASE_SENSITIVE.key).toBoolean,
Option(conf.get("spark.sql.parquet.inferTimestampNTZ.enabled")).map(_.toBoolean).getOrElse(false),
Option(conf.get("spark.sql.legacy.parquet.nanosAsLong")).map(_.toBoolean).getOrElse(true)
)
}

/**
* Converts Parquet [[MessageType]] `parquetSchema` to a Spark SQL [[StructType]].
*/
def convert(parquetSchema: MessageType): StructType = {
val column = new ColumnIOFactory().getColumnIO(parquetSchema)
val converted = convertInternal(column)
converted.sparkType.asInstanceOf[StructType]
}

private def convertInternal(
groupColumn: GroupColumnIO,
sparkReadSchema: Option[StructType] = None): CometParquetColumn = {
// First convert the read schema into a map from field name to the field itself, to avoid O(n)
// lookup cost below.
val schemaMapOpt = sparkReadSchema.map { schema =>
schema.map(f => normalizeFieldName(f.name) -> f).toMap
}

val converted = (0 until groupColumn.getChildrenCount).map { i =>
val field = groupColumn.getChild(i)
val fieldFromReadSchema = schemaMapOpt.flatMap { schemaMap =>
schemaMap.get(normalizeFieldName(field.getName))
}
var fieldReadType = fieldFromReadSchema.map(_.dataType)

def this(conf: Configuration) = this(
assumeBinaryIsString = conf.get(SQLConf.PARQUET_BINARY_AS_STRING.key).toBoolean,
assumeInt96IsTimestamp = conf.get(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key).toBoolean,
caseSensitive = conf.get(SQLConf.CASE_SENSITIVE.key).toBoolean,
inferTimestampNTZ = conf.get(SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED.key).toBoolean,
nanosAsLong = conf.get(SQLConf.LEGACY_PARQUET_NANOS_AS_LONG.key).toBoolean)
// If a field is repeated here then it is neither contained by a `LIST` nor `MAP`
// annotated group (these should've been handled in `convertGroupField`), e.g.:
//
// message schema {
// repeated int32 int_array;
// }
// or
// message schema {
// repeated group struct_array {
// optional int32 field;
// }
// }
//
// the corresponding Spark read type should be an array and we should pass the element type
// to the group or primitive type conversion method.
if (field.getType.getRepetition == REPEATED) {
fieldReadType = fieldReadType.flatMap {
case at: ArrayType => Some(at.elementType)
case _ =>
throw new UnsupportedOperationException("Illegal Parquet type " + groupColumn.toString)
}
}

val convertedField = convertField(field, fieldReadType)
val fieldName = fieldFromReadSchema.map(_.name).getOrElse(field.getType.getName)

field.getType.getRepetition match {
case OPTIONAL | REQUIRED =>
val nullable = field.getType.getRepetition == OPTIONAL
(StructField(fieldName, convertedField.sparkType, nullable = nullable),
convertedField)

case REPEATED =>
// A repeated field that is neither contained by a `LIST`- or `MAP`-annotated group nor
// annotated by `LIST` or `MAP` should be interpreted as a required list of required
// elements where the element type is the type of the field.
val arrayType = ArrayType(convertedField.sparkType, containsNull = false)
(StructField(fieldName, arrayType, nullable = false),
CometParquetColumn(arrayType, None, convertedField.repetitionLevel - 1,
convertedField.definitionLevel - 1, required = true, convertedField.path,
Seq(convertedField.copy(required = true))))
}
}

CometParquetColumn(StructType(converted.map(_._1)), groupColumn, converted.map(_._2))
}

override def convertField(
private def normalizeFieldName(name: String): String =
if (caseSensitive) name else name.toLowerCase(Locale.ROOT)

/**
* Converts a Parquet [[Type]] to a [[CometParquetColumn]] which wraps a Spark SQL [[DataType]] with
* additional information such as the Parquet column's repetition & definition level, column
* path, column descriptor etc.
*/
def convertField(
field: ColumnIO,
sparkReadType: Option[DataType] = None): ParquetColumn = {
sparkReadType: Option[DataType] = None): CometParquetColumn = {
val targetType = sparkReadType.map {
case udt: UserDefinedType[_] => udt.sqlType
case otherType => otherType
@@ -59,7 +142,7 @@ class CometParquetToSparkSchemaConverter(

private def convertPrimitiveField(
primitiveColumn: PrimitiveColumnIO,
sparkReadType: Option[DataType] = None): ParquetColumn = {
sparkReadType: Option[DataType] = None): CometParquetColumn = {
val parquetType = primitiveColumn.getType.asPrimitiveType()
val typeAnnotation = primitiveColumn.getType.getLogicalTypeAnnotation
val typeName = primitiveColumn.getPrimitive
@@ -138,7 +221,6 @@
}
case timestamp: TimestampLogicalTypeAnnotation
if timestamp.getUnit == TimeUnit.MICROS || timestamp.getUnit == TimeUnit.MILLIS =>
val inferTimestampNTZ = SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED.defaultValue.get
if (timestamp.isAdjustedToUTC || !inferTimestampNTZ) {
TimestampType
} else {
@@ -147,14 +229,14 @@
// SPARK-40819: NANOS are not supported as a Timestamp, convert to LongType without
// timezone awareness to address behaviour regression introduced by SPARK-34661
case timestamp: TimestampLogicalTypeAnnotation
if timestamp.getUnit == TimeUnit.NANOS && SQLConf.LEGACY_PARQUET_NANOS_AS_LONG.defaultValue.get =>
if timestamp.getUnit == TimeUnit.NANOS && nanosAsLong =>
LongType
case _ => illegalType()
}

case INT96 =>
CometParquetSchemaConverter.checkConversionRequirement(
SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get,
assumeInt96IsTimestamp,
"INT96 is not supported unless it's interpreted as timestamp. " +
s"Please try to set ${SQLConf.PARQUET_INT96_AS_TIMESTAMP.key} to true.")
TimestampType
@@ -163,7 +245,7 @@
typeAnnotation match {
case _: StringLogicalTypeAnnotation | _: EnumLogicalTypeAnnotation |
_: JsonLogicalTypeAnnotation => StringType
case null if SQLConf.PARQUET_BINARY_AS_STRING.defaultValue.get => StringType
case null if assumeBinaryIsString => StringType
case null => BinaryType
case _: BsonLogicalTypeAnnotation => BinaryType
case _: DecimalLogicalTypeAnnotation => makeDecimalType()
@@ -183,12 +265,12 @@
case _ => illegalType()
})

ParquetColumn(sparkType, primitiveColumn)
CometParquetColumn(sparkType, primitiveColumn)
}

private def convertGroupField(
groupColumn: GroupColumnIO,
sparkReadType: Option[DataType] = None): ParquetColumn = {
sparkReadType: Option[DataType] = None): CometParquetColumn = {
val field = groupColumn.getType.asGroupType()
Option(field.getLogicalTypeAnnotation).fold(
convertInternal(groupColumn, sparkReadType.map(_.asInstanceOf[StructType]))) {
@@ -218,7 +300,7 @@
repeatedType.isRepetition(REPEATED), s"Invalid list type $field")
val sparkReadElementType = sparkReadType.map(_.asInstanceOf[ArrayType].elementType)

if (isElementType2(repeatedType, field.getName)) {
if (isElementType(repeatedType, field.getName)) {
var converted = convertField(repeated, sparkReadElementType)
val convertedType = sparkReadElementType.getOrElse(converted.sparkType)

@@ -229,14 +311,14 @@
// we should mark the primitive field as required
if (repeatedType.isPrimitive) converted = converted.copy(required = true)

ParquetColumn(ArrayType(convertedType, containsNull = false),
CometParquetColumn(ArrayType(convertedType, containsNull = false),
groupColumn, Seq(converted))
} else {
val element = repeated.asInstanceOf[GroupColumnIO].getChild(0)
val converted = convertField(element, sparkReadElementType)
val convertedType = sparkReadElementType.getOrElse(converted.sparkType)
val optional = element.getType.isRepetition(OPTIONAL)
ParquetColumn(ArrayType(convertedType, containsNull = optional),
CometParquetColumn(ArrayType(convertedType, containsNull = optional),
groupColumn, Seq(converted))
}

@@ -267,7 +349,7 @@
val convertedKeyType = sparkReadKeyType.getOrElse(convertedKey.sparkType)
val convertedValueType = sparkReadValueType.getOrElse(convertedValue.sparkType)
val valueOptional = value.getType.isRepetition(OPTIONAL)
ParquetColumn(
CometParquetColumn(
MapType(convertedKeyType, convertedValueType,
valueContainsNull = valueOptional),
groupColumn, Seq(convertedKey, convertedValue))
@@ -280,7 +362,7 @@
// Here we implement Parquet LIST backwards-compatibility rules.
// See: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#backward-compatibility-rules
// scalastyle:on
private[parquet] def isElementType2(repeatedType: Type, parentName: String): Boolean = {
private[parquet] def isElementType(repeatedType: Type, parentName: String): Boolean = {
{
// For legacy 2-level list types with primitive element type, e.g.:
//
@@ -327,70 +409,6 @@ class CometParquetToSparkSchemaConverter(
repeatedType.getName == s"${parentName}_tuple"
}
}


private def convertInternal(
groupColumn: GroupColumnIO,
sparkReadSchema: Option[StructType] = None): ParquetColumn = {
// First convert the read schema into a map from field name to the field itself, to avoid O(n)
// lookup cost below.
val schemaMapOpt = sparkReadSchema.map { schema =>
schema.map(f => normalizeFieldName(f.name) -> f).toMap
}

val converted = (0 until groupColumn.getChildrenCount).map { i =>
val field = groupColumn.getChild(i)
val fieldFromReadSchema = schemaMapOpt.flatMap { schemaMap =>
schemaMap.get(normalizeFieldName(field.getName))
}
var fieldReadType = fieldFromReadSchema.map(_.dataType)

// If a field is repeated here then it is neither contained by a `LIST` nor `MAP`
// annotated group (these should've been handled in `convertGroupField`), e.g.:
//
// message schema {
// repeated int32 int_array;
// }
// or
// message schema {
// repeated group struct_array {
// optional int32 field;
// }
// }
//
// the corresponding Spark read type should be an array and we should pass the element type
// to the group or primitive type conversion method.
if (field.getType.getRepetition == REPEATED) {
fieldReadType = fieldReadType.flatMap {
case at: ArrayType => Some(at.elementType)
case _ =>
throw new UnsupportedOperationException("Illegal Parquet type " + groupColumn.toString)
}
}

val convertedField = convertField(field, fieldReadType)
val fieldName = fieldFromReadSchema.map(_.name).getOrElse(field.getType.getName)

field.getType.getRepetition match {
case OPTIONAL | REQUIRED =>
val nullable = field.getType.getRepetition == OPTIONAL
(StructField(fieldName, convertedField.sparkType, nullable = nullable),
convertedField)

case REPEATED =>
// A repeated field that is neither contained by a `LIST`- or `MAP`-annotated group nor
// annotated by `LIST` or `MAP` should be interpreted as a required list of required
// elements where the element type is the type of the field.
val arrayType = ArrayType(convertedField.sparkType, containsNull = false)
(StructField(fieldName, arrayType, nullable = false),
ParquetColumn(arrayType, None, convertedField.repetitionLevel - 1,
convertedField.definitionLevel - 1, required = true, convertedField.path,
Seq(convertedField.copy(required = true))))
}
}

ParquetColumn(StructType(converted.map(_._1)), groupColumn, converted.map(_._2))
}
}

private object CometParquetSchemaConverter {
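Because the converter now lives in Comet and reads the inferTimestampNTZ and nanosAsLong settings by string key with a fallback, it can be constructed on Spark 3.2 and 3.3, where the corresponding SQLConf entries do not exist. A script-style sketch converting a standard 3-level LIST group, which exercises convertGroupField and isElementType (the schema and expected result are illustrative assumptions, not part of this commit):

import org.apache.comet.parquet.CometParquetToSparkSchemaConverter
import org.apache.hadoop.conf.Configuration
import org.apache.parquet.schema.MessageTypeParser
import org.apache.spark.sql.internal.SQLConf

val conf = new Configuration()
conf.setBoolean(SQLConf.PARQUET_BINARY_AS_STRING.key, false)
conf.setBoolean(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, true)
conf.setBoolean(SQLConf.CASE_SENSITIVE.key, false)

val fileSchema = MessageTypeParser.parseMessageType(
  """message spark_schema {
    |  optional group scores (LIST) {
    |    repeated group list {
    |      optional double element;
    |    }
    |  }
    |}""".stripMargin)

val converted = new CometParquetToSparkSchemaConverter(conf).convert(fileSchema)
// Expected: scores maps to ArrayType(DoubleType, containsNull = true) and is nullable,
// since the outer group is OPTIONAL and the repeated "list" wrapper is the 3-level encoding.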
