diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VectorUtils.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VectorUtils.java index 3889de84900..91e7d6cde62 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VectorUtils.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VectorUtils.java @@ -72,7 +72,6 @@ public static Map toJavaMap(MapValue mapValue) { return values; } - /** * Creates a {@link MapValue} from map of string keys and string values. The type {@code * map(string -> string)} is a common occurrence in Delta Log schema. @@ -105,9 +104,7 @@ public ColumnVector getValues() { }; } - /** - * Creates an {@link ArrayValue} from list of objects. - */ + /** Creates an {@link ArrayValue} from list of objects. */ public static ArrayValue buildArrayValue(List values, DataType dataType) { if (values == null) { return null; @@ -175,13 +172,13 @@ public short getShort(int rowId) { @Override public int getInt(int rowId) { - checkArgument(IntegerType.INTEGER.equals(dataType)); + checkArgument(IntegerType.INTEGER.equals(dataType) || DateType.DATE.equals(dataType)); return (Integer) getValidatedValue(rowId, Integer.class); } @Override public long getLong(int rowId) { - checkArgument(LongType.LONG.equals(dataType)); + checkArgument(LongType.LONG.equals(dataType) || TimestampType.TIMESTAMP.equals(dataType)); return (Long) getValidatedValue(rowId, Long.class); } @@ -245,16 +242,17 @@ private void validateRowId(int rowId) { private Object getValidatedValue(int rowId, Class expectedType) { validateRowId(rowId); Object value = values.get(rowId); - checkArgument(expectedType.isInstance(value), - "Value must be of type %s", expectedType.getSimpleName()); + checkArgument( + expectedType.isInstance(value), + "Value must be of type %s", + expectedType.getSimpleName()); return value; } - private List extractChildValues(int ordinal, DataType childDatatype) { return values.stream() - .map(e -> extractChildValue(e, ordinal, childDatatype)) - .collect(Collectors.toList()); + .map(e -> extractChildValue(e, ordinal, childDatatype)) + .collect(Collectors.toList()); } private Object extractChildValue(Object element, int ordinal, DataType childDatatype) { @@ -270,28 +268,52 @@ private Object extractChildValue(Object element, int ordinal, DataType childData private Object extractTypedValue(Row row, int ordinal, DataType childDatatype) { // Primitive Types - if (childDatatype instanceof BooleanType) return row.getBoolean(ordinal); - if (childDatatype instanceof ByteType) return row.getByte(ordinal); - if (childDatatype instanceof ShortType) return row.getShort(ordinal); - if (childDatatype instanceof IntegerType || - childDatatype instanceof DateType) return row.getInt(ordinal); - if (childDatatype instanceof LongType || - childDatatype instanceof TimestampType) return row.getLong(ordinal); - if (childDatatype instanceof FloatType) return row.getFloat(ordinal); - if (childDatatype instanceof DoubleType) return row.getDouble(ordinal); + if (childDatatype instanceof BooleanType) { + return row.getBoolean(ordinal); + } + if (childDatatype instanceof ByteType) { + return row.getByte(ordinal); + } + if (childDatatype instanceof ShortType) { + return row.getShort(ordinal); + } + if (childDatatype instanceof IntegerType || childDatatype instanceof DateType) { + return row.getInt(ordinal); + } + if (childDatatype instanceof LongType || childDatatype instanceof TimestampType) { + return row.getLong(ordinal); + } + if (childDatatype instanceof FloatType) { + return row.getFloat(ordinal); + } + if (childDatatype instanceof DoubleType) { + return row.getDouble(ordinal); + } // Complex Types - if (childDatatype instanceof StringType) return row.getString(ordinal); - if (childDatatype instanceof BinaryType) return row.getBinary(ordinal); - if (childDatatype instanceof DecimalType) return row.getDecimal(ordinal); + if (childDatatype instanceof StringType) { + return row.getString(ordinal); + } + if (childDatatype instanceof BinaryType) { + return row.getBinary(ordinal); + } + if (childDatatype instanceof DecimalType) { + return row.getDecimal(ordinal); + } // Nested Types - if (childDatatype instanceof StructType) return row.getStruct(ordinal); - if (childDatatype instanceof ArrayType) return row.getArray(ordinal); - if (childDatatype instanceof MapType) return row.getMap(ordinal); + if (childDatatype instanceof StructType) { + return row.getStruct(ordinal); + } + if (childDatatype instanceof ArrayType) { + return row.getArray(ordinal); + } + if (childDatatype instanceof MapType) { + return row.getMap(ordinal); + } throw new UnsupportedOperationException( - String.format("Unsupported data type: %s", childDatatype.getClass().getSimpleName())); + String.format("Unsupported data type: %s", childDatatype.getClass().getSimpleName())); } }; } diff --git a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/VectorUtilsSuite.scala b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/VectorUtilsSuite.scala index 6ca9cee8fa8..4bfb1e1f41c 100644 --- a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/VectorUtilsSuite.scala +++ b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/util/VectorUtilsSuite.scala @@ -16,23 +16,133 @@ package io.delta.kernel.internal.util +import java.sql.{Date, Timestamp} import io.delta.kernel.test.VectorTestUtils -import io.delta.kernel.types.BooleanType +import io.delta.kernel.types.{ + BinaryType, + BooleanType, + ByteType, + DateType, + DecimalType, + DoubleType, + FloatType, + IntegerType, + LongType, + ShortType, + StringType, + TimestampType +} + +import java.lang.{ + Boolean => BooleanJ, + Byte => ByteJ, + Double => DoubleJ, + Float => FloatJ, + Integer => IntegerJ, + Long => LongJ, + Short => ShortJ +} +import scala.collection.JavaConverters._ import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.prop.Tables.Table -import java.lang.{Boolean => BooleanJ} -import java.util +import java.math.BigDecimal class VectorUtilsSuite extends AnyFunSuite with VectorTestUtils { - test("test build column vector from list of primitives") { - checkVectors(booleanVector(Seq[BooleanJ](true, false, null)), - VectorUtils.buildColumnVector(util.Arrays.asList(true, false, null), BooleanType.BOOLEAN), - BooleanType.BOOLEAN, - (vec, id) => vec.getBoolean(id) + Table( + ("values", "dataType"), + (List[ByteJ](1.toByte, 2.toByte, 3.toByte, null), ByteType.BYTE), + (List[ShortJ](1.toShort, 2.toShort, 3.toShort, null), ShortType.SHORT), + (List[IntegerJ](1, 2, 3, null), IntegerType.INTEGER), + (List[LongJ](1L, 2L, 3L, null), LongType.LONG), + (List[FloatJ](1.0f, 2.0f, 3.0f, null), FloatType.FLOAT), + (List[DoubleJ](1.0, 2.0, 3.0, null), DoubleType.DOUBLE), + (List[Array[Byte]]("one".getBytes, "two".getBytes, "three".getBytes, null), BinaryType.BINARY), + (List[BooleanJ](true, false, false, null), BooleanType.BOOLEAN), + ( + List[BigDecimal](new BigDecimal("1"), new BigDecimal("2"), new BigDecimal("3"), null), + new DecimalType(10, 2) + ), + (List[String]("one", "two", "three", null), StringType.STRING), + ( + List[IntegerJ](10, 20, 30, null), + DateType.DATE + ), + ( + List[LongJ]( + Timestamp.valueOf("2023-01-01 00:00:00").getTime, + Timestamp.valueOf("2023-01-02 00:00:00").getTime, + Timestamp.valueOf("2023-01-03 00:00:00").getTime, + null + ), + TimestampType.TIMESTAMP ) + ).foreach( + testCase => + test(s"handle ${testCase._2} array correctly") { + val values = testCase._1 + val dataType = testCase._2 + val columnVector = VectorUtils.buildColumnVector(values.asJava, dataType) + assert(columnVector.getSize == 4) - - } - + dataType match { + case ByteType.BYTE => + assert(columnVector.getByte(0) == 1.toByte) + assert(columnVector.getByte(1) == 2.toByte) + assert(columnVector.getByte(2) == 3.toByte) + case ShortType.SHORT => + assert(columnVector.getShort(0) == 1.toShort) + assert(columnVector.getShort(1) == 2.toShort) + assert(columnVector.getShort(2) == 3.toShort) + case IntegerType.INTEGER => + assert(columnVector.getInt(0) == 1) + assert(columnVector.getInt(1) == 2) + assert(columnVector.getInt(2) == 3) + case LongType.LONG => + assert(columnVector.getLong(0) == 1L) + assert(columnVector.getLong(1) == 2L) + assert(columnVector.getLong(2) == 3L) + case FloatType.FLOAT => + assert(columnVector.getFloat(0) == 1.0f) + assert(columnVector.getFloat(1) == 2.0f) + assert(columnVector.getFloat(2) == 3.0f) + case DoubleType.DOUBLE => + assert(columnVector.getDouble(0) == 1.0) + assert(columnVector.getDouble(1) == 2.0) + assert(columnVector.getDouble(2) == 3.0) + case BooleanType.BOOLEAN => + assert(columnVector.getBoolean(0)) + assert(!columnVector.getBoolean(1)) + assert(!columnVector.getBoolean(2)) + case _: DecimalType => + assert(columnVector.getDecimal(0) == new BigDecimal("1")) + assert(columnVector.getDecimal(1) == new BigDecimal("2")) + assert(columnVector.getDecimal(2) == new BigDecimal("3")) + case BinaryType.BINARY => + assert(columnVector.getBinary(0) sameElements "one".getBytes) + assert(columnVector.getBinary(1) sameElements "two".getBytes) + assert(columnVector.getBinary(2) sameElements "three".getBytes) + case StringType.STRING => + assert(columnVector.getString(0) == "one") + assert(columnVector.getString(1) == "two") + assert(columnVector.getString(2) == "three") + case DateType.DATE => + assert(columnVector.getInt(0) == 10) + assert(columnVector.getInt(1) == 20) + assert(columnVector.getInt(2) == 30) + case TimestampType.TIMESTAMP => + assert( + columnVector.getLong(0) == Timestamp.valueOf("2023-01-01 00:00:00").getTime + ) + assert( + columnVector.getLong(1) == Timestamp.valueOf("2023-01-02 00:00:00").getTime + ) + assert( + columnVector.getLong(2) == Timestamp.valueOf("2023-01-03 00:00:00").getTime + ) + } + assert(columnVector.isNullAt(3)) + } + ) }