Add support for reading variant type in Delta Lake #22403

Merged · 2 commits · Feb 6, 2025
Changes from all commits

@@ -34,6 +34,7 @@
import static io.trino.metastore.type.TypeConstants.TIMESTAMPLOCALTZ_TYPE_NAME;
import static io.trino.metastore.type.TypeConstants.TIMESTAMP_TYPE_NAME;
import static io.trino.metastore.type.TypeConstants.TINYINT_TYPE_NAME;
import static io.trino.metastore.type.TypeConstants.VARIANT_TYPE_NAME;
import static io.trino.metastore.type.TypeInfoFactory.getPrimitiveTypeInfo;
import static io.trino.metastore.type.TypeInfoUtils.getTypeInfoFromTypeString;
import static io.trino.metastore.type.TypeInfoUtils.getTypeInfosFromTypeString;
@@ -55,6 +56,7 @@ public final class HiveType
public static final HiveType HIVE_TIMESTAMPLOCALTZ = new HiveType(getPrimitiveTypeInfo(TIMESTAMPLOCALTZ_TYPE_NAME));
public static final HiveType HIVE_DATE = new HiveType(getPrimitiveTypeInfo(DATE_TYPE_NAME));
public static final HiveType HIVE_BINARY = new HiveType(getPrimitiveTypeInfo(BINARY_TYPE_NAME));
public static final HiveType HIVE_VARIANT = new HiveType(getPrimitiveTypeInfo(VARIANT_TYPE_NAME));

private final HiveTypeName hiveTypeName;
private final TypeInfo typeInfo;

@@ -18,5 +18,5 @@ public enum PrimitiveCategory
{
VOID, BOOLEAN, BYTE, SHORT, INT, LONG, FLOAT, DOUBLE, STRING,
DATE, TIMESTAMP, TIMESTAMPLOCALTZ, BINARY, DECIMAL, VARCHAR, CHAR,
INTERVAL_YEAR_MONTH, INTERVAL_DAY_TIME, UNKNOWN
INTERVAL_YEAR_MONTH, INTERVAL_DAY_TIME, VARIANT, UNKNOWN
}

@@ -35,6 +35,7 @@ private TypeConstants() {}
public static final String BINARY_TYPE_NAME = "binary";
public static final String INTERVAL_YEAR_MONTH_TYPE_NAME = "interval_year_month";
public static final String INTERVAL_DAY_TIME_TYPE_NAME = "interval_day_time";
public static final String VARIANT_TYPE_NAME = "variant";

public static final String LIST_TYPE_NAME = "array";
public static final String MAP_TYPE_NAME = "map";

@@ -61,6 +61,7 @@ private TypeInfoUtils() {}
registerType(new PrimitiveTypeEntry(PrimitiveCategory.INTERVAL_YEAR_MONTH, TypeConstants.INTERVAL_YEAR_MONTH_TYPE_NAME));
registerType(new PrimitiveTypeEntry(PrimitiveCategory.INTERVAL_DAY_TIME, TypeConstants.INTERVAL_DAY_TIME_TYPE_NAME));
registerType(new PrimitiveTypeEntry(PrimitiveCategory.DECIMAL, TypeConstants.DECIMAL_TYPE_NAME));
registerType(new PrimitiveTypeEntry(PrimitiveCategory.VARIANT, TypeConstants.VARIANT_TYPE_NAME));
registerType(new PrimitiveTypeEntry(PrimitiveCategory.UNKNOWN, "unknown"));
}

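Taken together, the hunks above make variant a first-class Hive primitive: a VARIANT category in PrimitiveCategory, a "variant" type name in TypeConstants, a registered PrimitiveTypeEntry in TypeInfoUtils, and a HIVE_VARIANT constant in HiveType. A minimal sketch of how the new name resolves through the type-string parser touched above; the io.trino.metastore.type.TypeInfo import is assumed from the surrounding imports in this diff:

import io.trino.metastore.type.TypeInfo;

import static io.trino.metastore.type.TypeInfoUtils.getTypeInfoFromTypeString;

public class VariantTypeNameSketch
{
    public static void main(String[] args)
    {
        // "variant" now resolves to the PrimitiveTypeEntry registered above
        // instead of falling through to the UNKNOWN category.
        TypeInfo variant = getTypeInfoFromTypeString("variant");
        System.out.println(variant);
    }
}
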
5 changes: 5 additions & 0 deletions lib/trino-parquet/pom.xml
@@ -13,6 +13,11 @@
<description>Trino - Parquet file format support</description>

<dependencies>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-core</artifactId>
</dependency>

<dependency>
<groupId>com.google.errorprone</groupId>
<artifactId>error_prone_annotations</artifactId>

@@ -39,8 +39,11 @@
import java.util.Map;
import java.util.Optional;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED;
import static io.trino.spi.type.StandardTypes.JSON;
import static io.trino.spi.type.VarbinaryType.VARBINARY;
import static java.lang.String.format;
import static org.apache.parquet.schema.Type.Repetition.OPTIONAL;
import static org.apache.parquet.schema.Type.Repetition.REPEATED;
@@ -291,6 +294,15 @@ private static Optional<Field> constructField(Type type, ColumnIO columnIO, bool
boolean required = columnIO.getType().getRepetition() != OPTIONAL;
int repetitionLevel = columnIO.getRepetitionLevel();
int definitionLevel = columnIO.getDefinitionLevel();
if (isVariantType(type, columnIO)) {
checkArgument(type.getTypeParameters().isEmpty(), "Expected type parameters to be empty for variant but got %s", type.getTypeParameters());
if (!(columnIO instanceof GroupColumnIO groupColumnIo)) {
throw new IllegalStateException("Expected columnIO to be GroupColumnIO but got %s".formatted(columnIO.getClass().getSimpleName()));
}
Field valueField = constructField(VARBINARY, groupColumnIo.getChild(0), false).orElseThrow();
Field metadataField = constructField(VARBINARY, groupColumnIo.getChild(1), false).orElseThrow();
return Optional.of(new VariantField(type, repetitionLevel, definitionLevel, required, valueField, metadataField));
}
if (type instanceof RowType rowType) {
GroupColumnIO groupColumnIO = (GroupColumnIO) columnIO;
ImmutableList.Builder<Optional<Field>> fieldsBuilder = ImmutableList.builder();
@@ -350,4 +362,13 @@ private static Optional<Field> constructField(Type type, ColumnIO columnIO, bool
}
return Optional.of(new PrimitiveField(type, required, primitiveColumnIO.getColumnDescriptor(), primitiveColumnIO.getId()));
}

private static boolean isVariantType(Type type, ColumnIO columnIO)
{
return type.getTypeSignature().getBase().equals(JSON) &&
columnIO instanceof GroupColumnIO groupColumnIo &&
groupColumnIo.getChildrenCount() == 2 &&
groupColumnIo.getChild("value") != null &&
groupColumnIo.getChild("metadata") != null;
}
}
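
For context, isVariantType() above treats a Trino json-typed column as a variant when its Parquet layout is a two-child group named value and metadata, which is how Spark writes the variant encoding. A hedged sketch of that shape using the parquet-mr schema builder; the column name event_data and the required repetition of the children are illustrative assumptions, since the check only inspects the child names and count:

import org.apache.parquet.schema.GroupType;
import org.apache.parquet.schema.Types;

import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.BINARY;

public class VariantParquetShapeSketch
{
    public static void main(String[] args)
    {
        // A variant column as constructField() expects it: child 0 holds the
        // encoded value bytes, child 1 the metadata (dictionary) bytes.
        GroupType variantColumn = Types.optionalGroup()
                .required(BINARY).named("value")
                .required(BINARY).named("metadata")
                .named("event_data");
        System.out.println(variantColumn);
    }
}
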
56 changes: 56 additions & 0 deletions lib/trino-parquet/src/main/java/io/trino/parquet/VariantField.java
@@ -0,0 +1,56 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.parquet;

import io.trino.spi.type.Type;

import static com.google.common.base.MoreObjects.toStringHelper;
import static java.util.Objects.requireNonNull;

public class VariantField
extends Field
{
private final Field value;
private final Field metadata;

public VariantField(Type type, int repetitionLevel, int definitionLevel, boolean required, Field value, Field metadata)
{
super(type, repetitionLevel, definitionLevel, required);
this.value = requireNonNull(value, "value is null");
this.metadata = requireNonNull(metadata, "metadata is null");
}

public Field getValue()
{
return value;
}

public Field getMetadata()
{
return metadata;
}

@Override
public String toString()
{
return toStringHelper(this)
.add("type", getType())
.add("repetitionLevel", getRepetitionLevel())
.add("definitionLevel", getDefinitionLevel())
.add("required", isRequired())
.add("value", value)
.add("metadata", getMetadata())
.toString();
}
}

@@ -19,6 +19,7 @@
import com.google.common.collect.ListMultimap;
import com.google.errorprone.annotations.FormatMethod;
import io.airlift.log.Logger;
import io.airlift.slice.Slice;
import io.trino.memory.context.AggregatedMemoryContext;
import io.trino.parquet.ChunkKey;
import io.trino.parquet.Column;
@@ -30,14 +31,17 @@
import io.trino.parquet.ParquetReaderOptions;
import io.trino.parquet.ParquetWriteValidation;
import io.trino.parquet.PrimitiveField;
import io.trino.parquet.VariantField;
import io.trino.parquet.metadata.ColumnChunkMetadata;
import io.trino.parquet.metadata.PrunedBlockMetadata;
import io.trino.parquet.predicate.TupleDomainParquetPredicate;
import io.trino.parquet.reader.FilteredOffsetIndex.OffsetRange;
import io.trino.parquet.spark.Variant;
import io.trino.plugin.base.metrics.LongCount;
import io.trino.spi.Page;
import io.trino.spi.block.ArrayBlock;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.DictionaryBlock;
import io.trino.spi.block.RowBlock;
import io.trino.spi.block.RunLengthEncodedBlock;
@@ -59,6 +63,7 @@

import java.io.Closeable;
import java.io.IOException;
import java.time.ZoneId;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -69,13 +74,16 @@
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static io.airlift.slice.Slices.utf8Slice;
import static io.trino.parquet.ParquetValidationUtils.validateParquet;
import static io.trino.parquet.ParquetWriteValidation.StatisticsValidation;
import static io.trino.parquet.ParquetWriteValidation.StatisticsValidation.createStatisticsValidationBuilder;
import static io.trino.parquet.ParquetWriteValidation.WriteChecksumBuilder;
import static io.trino.parquet.ParquetWriteValidation.WriteChecksumBuilder.createWriteChecksumBuilder;
import static io.trino.parquet.reader.ListColumnReader.calculateCollectionOffsets;
import static io.trino.parquet.reader.PageReader.createPageReader;
import static io.trino.spi.type.VarbinaryType.VARBINARY;
import static io.trino.spi.type.VarcharType.VARCHAR;
import static java.lang.Math.max;
import static java.lang.Math.min;
import static java.lang.Math.toIntExact;
@@ -97,6 +105,7 @@ public class ParquetReader
private final List<Column> columnFields;
private final List<PrimitiveField> primitiveFields;
private final ParquetDataSource dataSource;
private final ZoneId zoneId;
private final ColumnReaderFactory columnReaderFactory;
private final AggregatedMemoryContext memoryContext;

@@ -149,6 +158,7 @@ public ParquetReader(
this.primitiveFields = getPrimitiveFields(columnFields.stream().map(Column::field).collect(toImmutableList()));
this.rowGroups = requireNonNull(rowGroups, "rowGroups is null");
this.dataSource = requireNonNull(dataSource, "dataSource is null");
this.zoneId = requireNonNull(timeZone, "timeZone is null").toTimeZone().toZoneId();
this.columnReaderFactory = new ColumnReaderFactory(timeZone, options);
this.memoryContext = requireNonNull(memoryContext, "memoryContext is null");
this.currentRowGroupMemoryContext = memoryContext.newAggregatedMemoryContext();
@@ -332,6 +342,25 @@ private void freeCurrentRowGroupBuffers()
}
}

private ColumnChunk readVariant(VariantField field)
throws IOException
{
ColumnChunk valueChunk = readColumnChunk(field.getValue());

BlockBuilder variantBlock = VARCHAR.createBlockBuilder(null, 1);
if (valueChunk.getBlock().getPositionCount() == 0) {
variantBlock.appendNull();
}
else {
ColumnChunk metadataChunk = readColumnChunk(field.getMetadata());
Slice value = VARBINARY.getSlice(valueChunk.getBlock(), 0);
Slice metadata = VARBINARY.getSlice(metadataChunk.getBlock(), 0);
Variant variant = new Variant(value.byteArray(), metadata.byteArray());
VARCHAR.writeSlice(variantBlock, utf8Slice(variant.toJson(zoneId)));
}
return new ColumnChunk(variantBlock.build(), valueChunk.getDefinitionLevels(), valueChunk.getRepetitionLevels());
}

private ColumnChunk readArray(GroupField field)
throws IOException
{
@@ -523,6 +552,10 @@ else if (field instanceof GroupField groupField) {
.flatMap(Optional::stream)
.forEach(child -> parseField(child, primitiveFields));
}
else if (field instanceof VariantField variantField) {
parseField(variantField.getValue(), primitiveFields);
parseField(variantField.getMetadata(), primitiveFields);
}
}

public Block readBlock(Field field)
@@ -535,7 +568,10 @@ private ColumnChunk readColumnChunk(Field field)
throws IOException
{
ColumnChunk columnChunk;
if (field.getType() instanceof RowType) {
if (field instanceof VariantField variantField) {
columnChunk = readVariant(variantField);
}
else if (field.getType() instanceof RowType) {
columnChunk = readStruct((GroupField) field);
}
else if (field.getType() instanceof MapType) {
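
End to end, ParquetReader surfaces each variant cell as JSON text: readVariant() reads the value and metadata VARBINARY children, decodes them with the Spark variant codec this change depends on, and writes the resulting string into a VARCHAR-backed block. A minimal sketch of that decode step, using only the Variant API visible in this diff (the two-array constructor and toJson(ZoneId)); the byte arrays would come from the underlying column chunks:

import io.trino.parquet.spark.Variant;

import java.time.ZoneId;

public class VariantDecodeSketch
{
    // Pairs one cell's raw value and metadata bytes and renders them as JSON,
    // mirroring the conversion performed in readVariant() above.
    static String toJson(byte[] value, byte[] metadata, ZoneId zone)
    {
        Variant variant = new Variant(value, metadata);
        return variant.toJson(zone);
    }
}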