From 3d73d2ba5b5265836bced882d28dc2a3371627c7 Mon Sep 17 00:00:00 2001 From: Xiang Fu Date: Mon, 10 Jun 2024 14:19:05 -0700 Subject: [PATCH] BugFixing for ArrayList elements() usage --- .../apache/pinot/common/utils/DataSchema.java | 8 ++-- .../results/AggregationResultsBlock.java | 10 ++--- .../blocks/results/GroupByResultsBlock.java | 8 ++-- .../integration/tests/custom/ArrayTest.java | 42 +++++++++++++++++++ .../runtime/operator/utils/TypeUtils.java | 8 ++-- 5 files changed, 59 insertions(+), 17 deletions(-) diff --git a/pinot-common/src/main/java/org/apache/pinot/common/utils/DataSchema.java b/pinot-common/src/main/java/org/apache/pinot/common/utils/DataSchema.java index 9bd8dacef109..75c60aa78468 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/utils/DataSchema.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/utils/DataSchema.java @@ -523,7 +523,7 @@ private static int[] toIntArray(Object value) { return (int[]) value; } else if (value instanceof IntArrayList) { // For ArrayAggregationFunction - return ((IntArrayList) value).elements(); + return ((IntArrayList) value).toIntArray(); } throw new IllegalStateException(String.format("Cannot convert: '%s' to int[]", value)); } @@ -533,7 +533,7 @@ private static float[] toFloatArray(Object value) { return (float[]) value; } else if (value instanceof FloatArrayList) { // For ArrayAggregationFunction - return ((FloatArrayList) value).elements(); + return ((FloatArrayList) value).toFloatArray(); } throw new IllegalStateException(String.format("Cannot convert: '%s' to float[]", value)); } @@ -543,7 +543,7 @@ private static double[] toDoubleArray(Object value) { return (double[]) value; } else if (value instanceof DoubleArrayList) { // For HistogramAggregationFunction and ArrayAggregationFunction - return ((DoubleArrayList) value).elements(); + return ((DoubleArrayList) value).toDoubleArray(); } else if (value instanceof int[]) { int[] intValues = (int[]) value; int length = intValues.length; @@ -576,7 +576,7 @@ private static long[] toLongArray(Object value) { return (long[]) value; } else if (value instanceof LongArrayList) { // For FunnelCountAggregationFunction and ArrayAggregationFunction - return ((LongArrayList) value).elements(); + return ((LongArrayList) value).toLongArray(); } else { int[] intValues = (int[]) value; int length = intValues.length; diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/AggregationResultsBlock.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/AggregationResultsBlock.java index 8c3e025af305..5333ca8be12c 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/AggregationResultsBlock.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/AggregationResultsBlock.java @@ -195,19 +195,19 @@ private void setFinalResult(DataTableBuilder dataTableBuilder, ColumnDataType[] dataTableBuilder.setColumn(index, (ByteArray) result); break; case INT_ARRAY: - dataTableBuilder.setColumn(index, ((IntArrayList) result).elements()); + dataTableBuilder.setColumn(index, ((IntArrayList) result).toIntArray()); break; case LONG_ARRAY: - dataTableBuilder.setColumn(index, ((LongArrayList) result).elements()); + dataTableBuilder.setColumn(index, ((LongArrayList) result).toLongArray()); break; case FLOAT_ARRAY: - dataTableBuilder.setColumn(index, ((FloatArrayList) result).elements()); + dataTableBuilder.setColumn(index, ((FloatArrayList) result).toFloatArray()); break; case DOUBLE_ARRAY: - dataTableBuilder.setColumn(index, ((DoubleArrayList) result).elements()); + dataTableBuilder.setColumn(index, ((DoubleArrayList) result).toDoubleArray()); break; case STRING_ARRAY: - dataTableBuilder.setColumn(index, ((ObjectArrayList) result).elements()); + dataTableBuilder.setColumn(index, ((ObjectArrayList) result).toArray(new String[0])); break; default: throw new IllegalStateException("Illegal column data type in final result: " + columnDataType); diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/GroupByResultsBlock.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/GroupByResultsBlock.java index 81f46e6277c1..508717e661a0 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/GroupByResultsBlock.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/blocks/results/GroupByResultsBlock.java @@ -251,28 +251,28 @@ private void setDataTableColumn(ColumnDataType storedColumnDataType, DataTableBu break; case INT_ARRAY: if (value instanceof IntArrayList) { - dataTableBuilder.setColumn(columnIndex, ((IntArrayList) value).elements()); + dataTableBuilder.setColumn(columnIndex, ((IntArrayList) value).toIntArray()); } else { dataTableBuilder.setColumn(columnIndex, (int[]) value); } break; case LONG_ARRAY: if (value instanceof LongArrayList) { - dataTableBuilder.setColumn(columnIndex, ((LongArrayList) value).elements()); + dataTableBuilder.setColumn(columnIndex, ((LongArrayList) value).toLongArray()); } else { dataTableBuilder.setColumn(columnIndex, (long[]) value); } break; case FLOAT_ARRAY: if (value instanceof FloatArrayList) { - dataTableBuilder.setColumn(columnIndex, ((FloatArrayList) value).elements()); + dataTableBuilder.setColumn(columnIndex, ((FloatArrayList) value).toFloatArray()); } else { dataTableBuilder.setColumn(columnIndex, (float[]) value); } break; case DOUBLE_ARRAY: if (value instanceof DoubleArrayList) { - dataTableBuilder.setColumn(columnIndex, ((DoubleArrayList) value).elements()); + dataTableBuilder.setColumn(columnIndex, ((DoubleArrayList) value).toDoubleArray()); } else { dataTableBuilder.setColumn(columnIndex, (double[]) value); } diff --git a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/ArrayTest.java b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/ArrayTest.java index ceeefa28d295..00e41b6f22a6 100644 --- a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/ArrayTest.java +++ b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/ArrayTest.java @@ -28,6 +28,7 @@ import org.apache.avro.generic.GenericData; import org.apache.avro.generic.GenericDatumWriter; import org.apache.commons.lang3.RandomStringUtils; +import org.apache.pinot.core.util.DoubleComparisonUtil; import org.apache.pinot.spi.data.FieldSpec; import org.apache.pinot.spi.data.Schema; import org.testng.annotations.Test; @@ -48,6 +49,8 @@ public class ArrayTest extends CustomDataQueryClusterIntegrationTest { private static final String STRING_COLUMN = "stringCol"; private static final String TIMESTAMP_COLUMN = "timestampCol"; private static final String GROUP_BY_COLUMN = "groupKey"; + private static final String LONG_ARRAY_COLUMN = "longArrayCol"; + private static final String DOUBLE_ARRAY_COLUMN = "doubleArrayCol"; @Override protected long getCountStarResult() { @@ -460,6 +463,35 @@ public void testLongArrayLiteral(boolean useMultiStageQueryEngine) } } + @Test(dataProvider = "useBothQueryEngines") + public void testArraySum(boolean useMultiStageQueryEngine) + throws Exception { + setUseMultiStageQueryEngine(useMultiStageQueryEngine); + String query = String.format("SELECT sumArrayLong(%s), sumArrayDouble(%s) FROM %s", LONG_ARRAY_COLUMN, + DOUBLE_ARRAY_COLUMN, getTableName()); + JsonNode result = postQuery(query).get("resultTable"); + JsonNode columnDataTypesNode = result.get("dataSchema").get("columnDataTypes"); + assertEquals(columnDataTypesNode.get(0).textValue(), "LONG_ARRAY"); + assertEquals(columnDataTypesNode.get(1).textValue(), "DOUBLE_ARRAY"); + JsonNode rows = result.get("rows"); + assertEquals(rows.size(), 1); + JsonNode row = rows.get(0); + assertEquals(row.size(), 2); + JsonNode entry0 = row.get(0); + assertEquals(entry0.size(), 4); + assertEquals(entry0.get(0).longValue(), 0L); + assertEquals(entry0.get(1).longValue(), 1000L); + assertEquals(entry0.get(2).longValue(), 2000L); + assertEquals(entry0.get(3).longValue(), 3000L); + JsonNode entry1 = row.get(1); + assertEquals(entry1.size(), 4); + assertEquals(entry1.get(0).doubleValue(), 0.0); + // Compare double values: + assertEquals(DoubleComparisonUtil.doubleCompare(entry1.get(1).doubleValue(), 100.0, 0.00000000001), 0); + assertEquals(DoubleComparisonUtil.doubleCompare(entry1.get(2).doubleValue(), 200.0, 0.00000000001), 0); + assertEquals(DoubleComparisonUtil.doubleCompare(entry1.get(3).doubleValue(), 300.0, 0.00000000001), 0); + } + @Test(dataProvider = "useBothQueryEngines") public void testFloatArrayLiteral(boolean useMultiStageQueryEngine) throws Exception { @@ -541,6 +573,8 @@ public Schema createSchema() { .addSingleValueDimension(STRING_COLUMN, FieldSpec.DataType.STRING) .addSingleValueDimension(TIMESTAMP_COLUMN, FieldSpec.DataType.TIMESTAMP) .addSingleValueDimension(GROUP_BY_COLUMN, FieldSpec.DataType.STRING) + .addMultiValueDimension(LONG_ARRAY_COLUMN, FieldSpec.DataType.LONG) + .addMultiValueDimension(DOUBLE_ARRAY_COLUMN, FieldSpec.DataType.DOUBLE) .build(); } @@ -570,6 +604,12 @@ public File createAvroFile() null, null), new org.apache.avro.Schema.Field(GROUP_BY_COLUMN, org.apache.avro.Schema.create(org.apache.avro.Schema.Type.STRING), + null, null), + new org.apache.avro.Schema.Field(LONG_ARRAY_COLUMN, + org.apache.avro.Schema.createArray(org.apache.avro.Schema.create(org.apache.avro.Schema.Type.LONG)), + null, null), + new org.apache.avro.Schema.Field(DOUBLE_ARRAY_COLUMN, + org.apache.avro.Schema.createArray(org.apache.avro.Schema.create(org.apache.avro.Schema.Type.DOUBLE)), null, null) )); @@ -592,6 +632,8 @@ public File createAvroFile() record.put(STRING_COLUMN, RandomStringUtils.random(finalI)); record.put(TIMESTAMP_COLUMN, finalI); record.put(GROUP_BY_COLUMN, String.valueOf(finalI % 10)); + record.put(LONG_ARRAY_COLUMN, ImmutableList.of(0, 1, 2, 3)); + record.put(DOUBLE_ARRAY_COLUMN, ImmutableList.of(0.0, 0.1, 0.2, 0.3)); return record; } )); diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/TypeUtils.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/TypeUtils.java index 933c3be95f05..518d633e0dbe 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/TypeUtils.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/TypeUtils.java @@ -51,21 +51,21 @@ public static Object convert(Object value, ColumnDataType storedType) { case INT_ARRAY: if (value instanceof IntArrayList) { // For ArrayAggregationFunction - return ((IntArrayList) value).elements(); + return ((IntArrayList) value).toIntArray(); } else { return value; } case LONG_ARRAY: if (value instanceof LongArrayList) { // For FunnelCountAggregationFunction and ArrayAggregationFunction - return ((LongArrayList) value).elements(); + return ((LongArrayList) value).toLongArray(); } else { return value; } case FLOAT_ARRAY: if (value instanceof FloatArrayList) { // For ArrayAggregationFunction - return ((FloatArrayList) value).elements(); + return ((FloatArrayList) value).toFloatArray(); } else if (value instanceof double[]) { // This is due to for parsing array literal value like [0.1, 0.2, 0.3]. // The parsed value is stored as double[] in java, however the calcite type is FLOAT_ARRAY. @@ -80,7 +80,7 @@ public static Object convert(Object value, ColumnDataType storedType) { case DOUBLE_ARRAY: if (value instanceof DoubleArrayList) { // For HistogramAggregationFunction and ArrayAggregationFunction - return ((DoubleArrayList) value).elements(); + return ((DoubleArrayList) value).toDoubleArray(); } else { return value; }