Skip to content

Commit

Permalink
BugFixing for ArrayList elements() usage
Browse files Browse the repository at this point in the history
  • Loading branch information
xiangfu0 committed Jun 10, 2024
1 parent 8256b27 commit 3d73d2b
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,7 @@ private static int[] toIntArray(Object value) {
return (int[]) value;
} else if (value instanceof IntArrayList) {
// For ArrayAggregationFunction
return ((IntArrayList) value).elements();
return ((IntArrayList) value).toIntArray();
}
throw new IllegalStateException(String.format("Cannot convert: '%s' to int[]", value));
}
Expand All @@ -533,7 +533,7 @@ private static float[] toFloatArray(Object value) {
return (float[]) value;
} else if (value instanceof FloatArrayList) {
// For ArrayAggregationFunction
return ((FloatArrayList) value).elements();
return ((FloatArrayList) value).toFloatArray();
}
throw new IllegalStateException(String.format("Cannot convert: '%s' to float[]", value));
}
Expand All @@ -543,7 +543,7 @@ private static double[] toDoubleArray(Object value) {
return (double[]) value;
} else if (value instanceof DoubleArrayList) {
// For HistogramAggregationFunction and ArrayAggregationFunction
return ((DoubleArrayList) value).elements();
return ((DoubleArrayList) value).toDoubleArray();
} else if (value instanceof int[]) {
int[] intValues = (int[]) value;
int length = intValues.length;
Expand Down Expand Up @@ -576,7 +576,7 @@ private static long[] toLongArray(Object value) {
return (long[]) value;
} else if (value instanceof LongArrayList) {
// For FunnelCountAggregationFunction and ArrayAggregationFunction
return ((LongArrayList) value).elements();
return ((LongArrayList) value).toLongArray();
} else {
int[] intValues = (int[]) value;
int length = intValues.length;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,19 +195,19 @@ private void setFinalResult(DataTableBuilder dataTableBuilder, ColumnDataType[]
dataTableBuilder.setColumn(index, (ByteArray) result);
break;
case INT_ARRAY:
dataTableBuilder.setColumn(index, ((IntArrayList) result).elements());
dataTableBuilder.setColumn(index, ((IntArrayList) result).toIntArray());
break;
case LONG_ARRAY:
dataTableBuilder.setColumn(index, ((LongArrayList) result).elements());
dataTableBuilder.setColumn(index, ((LongArrayList) result).toLongArray());
break;
case FLOAT_ARRAY:
dataTableBuilder.setColumn(index, ((FloatArrayList) result).elements());
dataTableBuilder.setColumn(index, ((FloatArrayList) result).toFloatArray());
break;
case DOUBLE_ARRAY:
dataTableBuilder.setColumn(index, ((DoubleArrayList) result).elements());
dataTableBuilder.setColumn(index, ((DoubleArrayList) result).toDoubleArray());
break;
case STRING_ARRAY:
dataTableBuilder.setColumn(index, ((ObjectArrayList<String>) result).elements());
dataTableBuilder.setColumn(index, ((ObjectArrayList<String>) result).toArray(new String[0]));
break;
default:
throw new IllegalStateException("Illegal column data type in final result: " + columnDataType);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,28 +251,28 @@ private void setDataTableColumn(ColumnDataType storedColumnDataType, DataTableBu
break;
case INT_ARRAY:
if (value instanceof IntArrayList) {
dataTableBuilder.setColumn(columnIndex, ((IntArrayList) value).elements());
dataTableBuilder.setColumn(columnIndex, ((IntArrayList) value).toIntArray());
} else {
dataTableBuilder.setColumn(columnIndex, (int[]) value);
}
break;
case LONG_ARRAY:
if (value instanceof LongArrayList) {
dataTableBuilder.setColumn(columnIndex, ((LongArrayList) value).elements());
dataTableBuilder.setColumn(columnIndex, ((LongArrayList) value).toLongArray());
} else {
dataTableBuilder.setColumn(columnIndex, (long[]) value);
}
break;
case FLOAT_ARRAY:
if (value instanceof FloatArrayList) {
dataTableBuilder.setColumn(columnIndex, ((FloatArrayList) value).elements());
dataTableBuilder.setColumn(columnIndex, ((FloatArrayList) value).toFloatArray());
} else {
dataTableBuilder.setColumn(columnIndex, (float[]) value);
}
break;
case DOUBLE_ARRAY:
if (value instanceof DoubleArrayList) {
dataTableBuilder.setColumn(columnIndex, ((DoubleArrayList) value).elements());
dataTableBuilder.setColumn(columnIndex, ((DoubleArrayList) value).toDoubleArray());
} else {
dataTableBuilder.setColumn(columnIndex, (double[]) value);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.apache.avro.generic.GenericData;
import org.apache.avro.generic.GenericDatumWriter;
import org.apache.commons.lang3.RandomStringUtils;
import org.apache.pinot.core.util.DoubleComparisonUtil;
import org.apache.pinot.spi.data.FieldSpec;
import org.apache.pinot.spi.data.Schema;
import org.testng.annotations.Test;
Expand All @@ -48,6 +49,8 @@ public class ArrayTest extends CustomDataQueryClusterIntegrationTest {
private static final String STRING_COLUMN = "stringCol";
private static final String TIMESTAMP_COLUMN = "timestampCol";
private static final String GROUP_BY_COLUMN = "groupKey";
private static final String LONG_ARRAY_COLUMN = "longArrayCol";
private static final String DOUBLE_ARRAY_COLUMN = "doubleArrayCol";

@Override
protected long getCountStarResult() {
Expand Down Expand Up @@ -460,6 +463,35 @@ public void testLongArrayLiteral(boolean useMultiStageQueryEngine)
}
}

@Test(dataProvider = "useBothQueryEngines")
public void testArraySum(boolean useMultiStageQueryEngine)
throws Exception {
setUseMultiStageQueryEngine(useMultiStageQueryEngine);
String query = String.format("SELECT sumArrayLong(%s), sumArrayDouble(%s) FROM %s", LONG_ARRAY_COLUMN,
DOUBLE_ARRAY_COLUMN, getTableName());
JsonNode result = postQuery(query).get("resultTable");
JsonNode columnDataTypesNode = result.get("dataSchema").get("columnDataTypes");
assertEquals(columnDataTypesNode.get(0).textValue(), "LONG_ARRAY");
assertEquals(columnDataTypesNode.get(1).textValue(), "DOUBLE_ARRAY");
JsonNode rows = result.get("rows");
assertEquals(rows.size(), 1);
JsonNode row = rows.get(0);
assertEquals(row.size(), 2);
JsonNode entry0 = row.get(0);
assertEquals(entry0.size(), 4);
assertEquals(entry0.get(0).longValue(), 0L);
assertEquals(entry0.get(1).longValue(), 1000L);
assertEquals(entry0.get(2).longValue(), 2000L);
assertEquals(entry0.get(3).longValue(), 3000L);
JsonNode entry1 = row.get(1);
assertEquals(entry1.size(), 4);
assertEquals(entry1.get(0).doubleValue(), 0.0);
// Compare double values:
assertEquals(DoubleComparisonUtil.doubleCompare(entry1.get(1).doubleValue(), 100.0, 0.00000000001), 0);
assertEquals(DoubleComparisonUtil.doubleCompare(entry1.get(2).doubleValue(), 200.0, 0.00000000001), 0);
assertEquals(DoubleComparisonUtil.doubleCompare(entry1.get(3).doubleValue(), 300.0, 0.00000000001), 0);
}

@Test(dataProvider = "useBothQueryEngines")
public void testFloatArrayLiteral(boolean useMultiStageQueryEngine)
throws Exception {
Expand Down Expand Up @@ -541,6 +573,8 @@ public Schema createSchema() {
.addSingleValueDimension(STRING_COLUMN, FieldSpec.DataType.STRING)
.addSingleValueDimension(TIMESTAMP_COLUMN, FieldSpec.DataType.TIMESTAMP)
.addSingleValueDimension(GROUP_BY_COLUMN, FieldSpec.DataType.STRING)
.addMultiValueDimension(LONG_ARRAY_COLUMN, FieldSpec.DataType.LONG)
.addMultiValueDimension(DOUBLE_ARRAY_COLUMN, FieldSpec.DataType.DOUBLE)
.build();
}

Expand Down Expand Up @@ -570,6 +604,12 @@ public File createAvroFile()
null, null),
new org.apache.avro.Schema.Field(GROUP_BY_COLUMN,
org.apache.avro.Schema.create(org.apache.avro.Schema.Type.STRING),
null, null),
new org.apache.avro.Schema.Field(LONG_ARRAY_COLUMN,
org.apache.avro.Schema.createArray(org.apache.avro.Schema.create(org.apache.avro.Schema.Type.LONG)),
null, null),
new org.apache.avro.Schema.Field(DOUBLE_ARRAY_COLUMN,
org.apache.avro.Schema.createArray(org.apache.avro.Schema.create(org.apache.avro.Schema.Type.DOUBLE)),
null, null)
));

Expand All @@ -592,6 +632,8 @@ public File createAvroFile()
record.put(STRING_COLUMN, RandomStringUtils.random(finalI));
record.put(TIMESTAMP_COLUMN, finalI);
record.put(GROUP_BY_COLUMN, String.valueOf(finalI % 10));
record.put(LONG_ARRAY_COLUMN, ImmutableList.of(0, 1, 2, 3));
record.put(DOUBLE_ARRAY_COLUMN, ImmutableList.of(0.0, 0.1, 0.2, 0.3));
return record;
}
));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,21 +51,21 @@ public static Object convert(Object value, ColumnDataType storedType) {
case INT_ARRAY:
if (value instanceof IntArrayList) {
// For ArrayAggregationFunction
return ((IntArrayList) value).elements();
return ((IntArrayList) value).toIntArray();
} else {
return value;
}
case LONG_ARRAY:
if (value instanceof LongArrayList) {
// For FunnelCountAggregationFunction and ArrayAggregationFunction
return ((LongArrayList) value).elements();
return ((LongArrayList) value).toLongArray();
} else {
return value;
}
case FLOAT_ARRAY:
if (value instanceof FloatArrayList) {
// For ArrayAggregationFunction
return ((FloatArrayList) value).elements();
return ((FloatArrayList) value).toFloatArray();
} else if (value instanceof double[]) {
// This is due to for parsing array literal value like [0.1, 0.2, 0.3].
// The parsed value is stored as double[] in java, however the calcite type is FLOAT_ARRAY.
Expand All @@ -80,7 +80,7 @@ public static Object convert(Object value, ColumnDataType storedType) {
case DOUBLE_ARRAY:
if (value instanceof DoubleArrayList) {
// For HistogramAggregationFunction and ArrayAggregationFunction
return ((DoubleArrayList) value).elements();
return ((DoubleArrayList) value).toDoubleArray();
} else {
return value;
}
Expand Down

0 comments on commit 3d73d2b

Please sign in to comment.