Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
huan233usc committed Feb 15, 2025
1 parent 3f535f7 commit e4d11c1
Show file tree
Hide file tree
Showing 2 changed files with 170 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ public static <K, V> Map<K, V> toJavaMap(MapValue mapValue) {
return values;
}


/**
* Creates a {@link MapValue} from map of string keys and string values. The type {@code
* map(string -> string)} is a common occurrence in Delta Log schema.
Expand Down Expand Up @@ -105,9 +104,7 @@ public ColumnVector getValues() {
};
}

/**
* Creates an {@link ArrayValue} from list of objects.
*/
/** Creates an {@link ArrayValue} from list of objects. */
public static ArrayValue buildArrayValue(List<?> values, DataType dataType) {
if (values == null) {
return null;
Expand Down Expand Up @@ -175,13 +172,13 @@ public short getShort(int rowId) {

@Override
public int getInt(int rowId) {
checkArgument(IntegerType.INTEGER.equals(dataType));
checkArgument(IntegerType.INTEGER.equals(dataType) || DateType.DATE.equals(dataType));
return (Integer) getValidatedValue(rowId, Integer.class);
}

@Override
public long getLong(int rowId) {
checkArgument(LongType.LONG.equals(dataType));
checkArgument(LongType.LONG.equals(dataType) || TimestampType.TIMESTAMP.equals(dataType));
return (Long) getValidatedValue(rowId, Long.class);
}

Expand Down Expand Up @@ -245,16 +242,17 @@ private void validateRowId(int rowId) {
private Object getValidatedValue(int rowId, Class<?> expectedType) {
validateRowId(rowId);
Object value = values.get(rowId);
checkArgument(expectedType.isInstance(value),
"Value must be of type %s", expectedType.getSimpleName());
checkArgument(
expectedType.isInstance(value),
"Value must be of type %s",
expectedType.getSimpleName());
return value;
}


private List<?> extractChildValues(int ordinal, DataType childDatatype) {
return values.stream()
.map(e -> extractChildValue(e, ordinal, childDatatype))
.collect(Collectors.toList());
.map(e -> extractChildValue(e, ordinal, childDatatype))
.collect(Collectors.toList());
}

private Object extractChildValue(Object element, int ordinal, DataType childDatatype) {
Expand All @@ -270,28 +268,52 @@ private Object extractChildValue(Object element, int ordinal, DataType childData

private Object extractTypedValue(Row row, int ordinal, DataType childDatatype) {
// Primitive Types
if (childDatatype instanceof BooleanType) return row.getBoolean(ordinal);
if (childDatatype instanceof ByteType) return row.getByte(ordinal);
if (childDatatype instanceof ShortType) return row.getShort(ordinal);
if (childDatatype instanceof IntegerType ||
childDatatype instanceof DateType) return row.getInt(ordinal);
if (childDatatype instanceof LongType ||
childDatatype instanceof TimestampType) return row.getLong(ordinal);
if (childDatatype instanceof FloatType) return row.getFloat(ordinal);
if (childDatatype instanceof DoubleType) return row.getDouble(ordinal);
if (childDatatype instanceof BooleanType) {
return row.getBoolean(ordinal);
}
if (childDatatype instanceof ByteType) {
return row.getByte(ordinal);
}
if (childDatatype instanceof ShortType) {
return row.getShort(ordinal);
}
if (childDatatype instanceof IntegerType || childDatatype instanceof DateType) {
return row.getInt(ordinal);
}
if (childDatatype instanceof LongType || childDatatype instanceof TimestampType) {
return row.getLong(ordinal);
}
if (childDatatype instanceof FloatType) {
return row.getFloat(ordinal);
}
if (childDatatype instanceof DoubleType) {
return row.getDouble(ordinal);
}

// Complex Types
if (childDatatype instanceof StringType) return row.getString(ordinal);
if (childDatatype instanceof BinaryType) return row.getBinary(ordinal);
if (childDatatype instanceof DecimalType) return row.getDecimal(ordinal);
if (childDatatype instanceof StringType) {
return row.getString(ordinal);
}
if (childDatatype instanceof BinaryType) {
return row.getBinary(ordinal);
}
if (childDatatype instanceof DecimalType) {
return row.getDecimal(ordinal);
}

// Nested Types
if (childDatatype instanceof StructType) return row.getStruct(ordinal);
if (childDatatype instanceof ArrayType) return row.getArray(ordinal);
if (childDatatype instanceof MapType) return row.getMap(ordinal);
if (childDatatype instanceof StructType) {
return row.getStruct(ordinal);
}
if (childDatatype instanceof ArrayType) {
return row.getArray(ordinal);
}
if (childDatatype instanceof MapType) {
return row.getMap(ordinal);
}

throw new UnsupportedOperationException(
String.format("Unsupported data type: %s", childDatatype.getClass().getSimpleName()));
String.format("Unsupported data type: %s", childDatatype.getClass().getSimpleName()));
}
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,133 @@

package io.delta.kernel.internal.util

import java.sql.{Date, Timestamp}
import io.delta.kernel.test.VectorTestUtils
import io.delta.kernel.types.BooleanType
import io.delta.kernel.types.{
BinaryType,
BooleanType,
ByteType,
DateType,
DecimalType,
DoubleType,
FloatType,
IntegerType,
LongType,
ShortType,
StringType,
TimestampType
}

import java.lang.{
Boolean => BooleanJ,
Byte => ByteJ,
Double => DoubleJ,
Float => FloatJ,
Integer => IntegerJ,
Long => LongJ,
Short => ShortJ
}
import scala.collection.JavaConverters._
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.prop.Tables.Table

import java.lang.{Boolean => BooleanJ}
import java.util
import java.math.BigDecimal

class VectorUtilsSuite extends AnyFunSuite with VectorTestUtils {

test("test build column vector from list of primitives") {
checkVectors(booleanVector(Seq[BooleanJ](true, false, null)),
VectorUtils.buildColumnVector(util.Arrays.asList(true, false, null), BooleanType.BOOLEAN),
BooleanType.BOOLEAN,
(vec, id) => vec.getBoolean(id)
Table(
("values", "dataType"),
(List[ByteJ](1.toByte, 2.toByte, 3.toByte, null), ByteType.BYTE),
(List[ShortJ](1.toShort, 2.toShort, 3.toShort, null), ShortType.SHORT),
(List[IntegerJ](1, 2, 3, null), IntegerType.INTEGER),
(List[LongJ](1L, 2L, 3L, null), LongType.LONG),
(List[FloatJ](1.0f, 2.0f, 3.0f, null), FloatType.FLOAT),
(List[DoubleJ](1.0, 2.0, 3.0, null), DoubleType.DOUBLE),
(List[Array[Byte]]("one".getBytes, "two".getBytes, "three".getBytes, null), BinaryType.BINARY),
(List[BooleanJ](true, false, false, null), BooleanType.BOOLEAN),
(
List[BigDecimal](new BigDecimal("1"), new BigDecimal("2"), new BigDecimal("3"), null),
new DecimalType(10, 2)
),
(List[String]("one", "two", "three", null), StringType.STRING),
(
List[IntegerJ](10, 20, 30, null),
DateType.DATE
),
(
List[LongJ](
Timestamp.valueOf("2023-01-01 00:00:00").getTime,
Timestamp.valueOf("2023-01-02 00:00:00").getTime,
Timestamp.valueOf("2023-01-03 00:00:00").getTime,
null
),
TimestampType.TIMESTAMP
)
).foreach(
testCase =>
test(s"handle ${testCase._2} array correctly") {
val values = testCase._1
val dataType = testCase._2
val columnVector = VectorUtils.buildColumnVector(values.asJava, dataType)
assert(columnVector.getSize == 4)


}

dataType match {
case ByteType.BYTE =>
assert(columnVector.getByte(0) == 1.toByte)
assert(columnVector.getByte(1) == 2.toByte)
assert(columnVector.getByte(2) == 3.toByte)
case ShortType.SHORT =>
assert(columnVector.getShort(0) == 1.toShort)
assert(columnVector.getShort(1) == 2.toShort)
assert(columnVector.getShort(2) == 3.toShort)
case IntegerType.INTEGER =>
assert(columnVector.getInt(0) == 1)
assert(columnVector.getInt(1) == 2)
assert(columnVector.getInt(2) == 3)
case LongType.LONG =>
assert(columnVector.getLong(0) == 1L)
assert(columnVector.getLong(1) == 2L)
assert(columnVector.getLong(2) == 3L)
case FloatType.FLOAT =>
assert(columnVector.getFloat(0) == 1.0f)
assert(columnVector.getFloat(1) == 2.0f)
assert(columnVector.getFloat(2) == 3.0f)
case DoubleType.DOUBLE =>
assert(columnVector.getDouble(0) == 1.0)
assert(columnVector.getDouble(1) == 2.0)
assert(columnVector.getDouble(2) == 3.0)
case BooleanType.BOOLEAN =>
assert(columnVector.getBoolean(0))
assert(!columnVector.getBoolean(1))
assert(!columnVector.getBoolean(2))
case _: DecimalType =>
assert(columnVector.getDecimal(0) == new BigDecimal("1"))
assert(columnVector.getDecimal(1) == new BigDecimal("2"))
assert(columnVector.getDecimal(2) == new BigDecimal("3"))
case BinaryType.BINARY =>
assert(columnVector.getBinary(0) sameElements "one".getBytes)
assert(columnVector.getBinary(1) sameElements "two".getBytes)
assert(columnVector.getBinary(2) sameElements "three".getBytes)
case StringType.STRING =>
assert(columnVector.getString(0) == "one")
assert(columnVector.getString(1) == "two")
assert(columnVector.getString(2) == "three")
case DateType.DATE =>
assert(columnVector.getInt(0) == 10)
assert(columnVector.getInt(1) == 20)
assert(columnVector.getInt(2) == 30)
case TimestampType.TIMESTAMP =>
assert(
columnVector.getLong(0) == Timestamp.valueOf("2023-01-01 00:00:00").getTime
)
assert(
columnVector.getLong(1) == Timestamp.valueOf("2023-01-02 00:00:00").getTime
)
assert(
columnVector.getLong(2) == Timestamp.valueOf("2023-01-03 00:00:00").getTime
)
}
assert(columnVector.isNullAt(3))
}
)
}

0 comments on commit e4d11c1

Please sign in to comment.