diff --git a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MultiStageEngineIntegrationTest.java b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MultiStageEngineIntegrationTest.java
index f74e1bd24fde..bd12f0901b51 100644
--- a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MultiStageEngineIntegrationTest.java
+++ b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/MultiStageEngineIntegrationTest.java
@@ -128,12 +128,63 @@ public void testMultiValueColumnSelectionQuery()
     testQueryWithMatchingRowCount(pinotQuery, h2Query);
   }
 
+  /**
+   * Runs multi-value aggregation functions over {@code DivAirportIDs} on the query engine selected by the data
+   * provider and compares each result with a fixed expected value (or a fixed range, for Est/TDigest/KLL).
+   */
+  @Test(dataProvider = "useBothQueryEngines")
+  public void testMultiValueColumnAggregationQuery(boolean useMultiStageQueryEngine)
+      throws Exception {
+    setUseMultiStageQueryEngine(useMultiStageQueryEngine);
+    try {
+      String[] multiValueFunctions = new String[]{
+          "sumMV", "countMV", "minMV", "maxMV", "avgMV", "minMaxRangeMV", "distinctCountMV", "distinctCountBitmapMV",
+          "distinctCountHLLMV", "distinctSumMV", "distinctAvgMV"
+      };
+      double[] expectedResults = new double[]{
+          -5.421344202E9, 577725, -9999.0, 16271.0, -9383.95292223809, 26270.0, 312, 312, 328, 3954484.0,
+          12674.628205128205
+      };
+
+      Assert.assertEquals(multiValueFunctions.length, expectedResults.length);
+
+      for (int i = 0; i < multiValueFunctions.length; i++) {
+        String pinotQuery = String.format("SELECT %s(DivAirportIDs) FROM mytable", multiValueFunctions[i]);
+        Assert.assertEquals(queryFirstCellAsDouble(pinotQuery), expectedResults[i], "Query: " + pinotQuery);
+      }
+
+      Assert.assertEquals(queryFirstCellAsDouble("SELECT percentileMV(DivAirportIDs, 99) FROM mytable"), 13433.0);
+
+      assertFirstCellBetween("SELECT percentileEstMV(DivAirportIDs, 99) FROM mytable", 13000, 14000);
+      assertFirstCellBetween("SELECT percentileTDigestMV(DivAirportIDs, 99) FROM mytable", 13000, 14000);
+      assertFirstCellBetween("SELECT percentileKLLMV(DivAirportIDs, 99) FROM mytable", 12000, 15000);
+      assertFirstCellBetween("SELECT percentileKLLMV(DivAirportIDs, 99, 100) FROM mytable", 12000, 15000);
+    } finally {
+      // Always restore the flag: without the finally, a failed assertion would skip the reset and leave the
+      // data provider's engine choice in effect for later tests.
+      setUseMultiStageQueryEngine(true);
+    }
+  }
+
+  /** Posts the query and returns the first column of the first result row as a double. */
+  private double queryFirstCellAsDouble(String pinotQuery)
+      throws Exception {
+    return postQuery(pinotQuery).get("resultTable").get("rows").get(0).get(0).asDouble();
+  }
+
+  /** Posts the query and asserts that its first result cell lies strictly between the two bounds. */
+  private void assertFirstCellBetween(String pinotQuery, double lowerExclusive, double upperExclusive)
+      throws Exception {
+    double value = queryFirstCellAsDouble(pinotQuery);
+    Assert.assertTrue(value > lowerExclusive && value < upperExclusive,
+        "Query: " + pinotQuery + ", value: " + value);
+  }
+
   @Test
   public void testTimeFunc()
       throws Exception {
     String sqlQuery = "SELECT toDateTime(now(), 'yyyy-MM-dd z'), toDateTime(ago('PT1H'), 'yyyy-MM-dd z') FROM mytable";
     JsonNode response = postQuery(sqlQuery);
-    System.out.println("response = " + response);
     String todayStr = response.get("resultTable").get("rows").get(0).get(0).asText();
     String expectedTodayStr =
         Instant.now().atZone(ZoneId.of("UTC")).format(DateTimeFormatter.ofPattern("yyyy-MM-dd z"));
@@ -441,7 +492,6 @@ public void testLiteralOnlyFunc()
             + "decodeUrl('key1%3Dvalue+1%26key2%3Dvalue%40%21%242%26key3%3Dvalue%253') as decodedUrl, toBase64"
             + "(toUtf8('hello!')) as toBase64, fromUtf8(fromBase64('aGVsbG8h')) as fromBase64";
     JsonNode response = postQuery(sqlQuery);
-    System.out.println("response = " + response.toPrettyString());
     long queryEndTimeMs = System.currentTimeMillis();
JsonNode resultTable = response.get("resultTable"); diff --git a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java index b636ed4f2443..460bf61fafab 100644 --- a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java +++ b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java @@ -150,25 +150,56 @@ public enum AggregationFunctionType { STUNION("STUnion"), // Aggregation functions for multi-valued columns - COUNTMV("countMV"), - MINMV("minMV"), - MAXMV("maxMV"), - SUMMV("sumMV"), - AVGMV("avgMV"), - MINMAXRANGEMV("minMaxRangeMV"), - DISTINCTCOUNTMV("distinctCountMV"), - DISTINCTCOUNTBITMAPMV("distinctCountBitmapMV"), - DISTINCTCOUNTHLLMV("distinctCountHLLMV"), - DISTINCTCOUNTRAWHLLMV("distinctCountRawHLLMV"), - DISTINCTSUMMV("distinctSumMV"), - DISTINCTAVGMV("distinctAvgMV"), - PERCENTILEMV("percentileMV"), - PERCENTILEESTMV("percentileEstMV"), - PERCENTILERAWESTMV("percentileRawEstMV"), - PERCENTILETDIGESTMV("percentileTDigestMV"), - PERCENTILERAWTDIGESTMV("percentileRawTDigestMV"), - PERCENTILEKLLMV("percentileKLLMV"), - PERCENTILERAWKLLMV("percentileRawKLLMV"), + COUNTMV("countMV", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, + OperandTypes.family(SqlTypeFamily.ARRAY), ReturnTypes.explicit(SqlTypeName.BIGINT), + ReturnTypes.explicit(SqlTypeName.BIGINT)), + MINMV("minMV", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, + OperandTypes.family(SqlTypeFamily.ARRAY), ReturnTypes.DOUBLE, ReturnTypes.explicit(SqlTypeName.DOUBLE)), + MAXMV("maxMV", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, + OperandTypes.family(SqlTypeFamily.ARRAY), ReturnTypes.DOUBLE, ReturnTypes.explicit(SqlTypeName.DOUBLE)), + SUMMV("sumMV", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, + 
OperandTypes.family(SqlTypeFamily.ARRAY), ReturnTypes.DOUBLE, ReturnTypes.explicit(SqlTypeName.DOUBLE)), + AVGMV("avgMV", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, + OperandTypes.family(SqlTypeFamily.ARRAY), ReturnTypes.DOUBLE, ReturnTypes.explicit(SqlTypeName.OTHER)), + MINMAXRANGEMV("minMaxRangeMV", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, + OperandTypes.family(SqlTypeFamily.ARRAY), ReturnTypes.explicit(SqlTypeName.DOUBLE), + ReturnTypes.explicit(SqlTypeName.OTHER)), + DISTINCTCOUNTMV("distinctCountMV", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, + OperandTypes.family(SqlTypeFamily.ARRAY), ReturnTypes.BIGINT, ReturnTypes.explicit(SqlTypeName.OTHER)), + DISTINCTCOUNTBITMAPMV("distinctCountBitmapMV", null, SqlKind.OTHER_FUNCTION, + SqlFunctionCategory.USER_DEFINED_FUNCTION, OperandTypes.family(SqlTypeFamily.ARRAY), ReturnTypes.BIGINT, + ReturnTypes.explicit(SqlTypeName.OTHER)), + DISTINCTCOUNTHLLMV("distinctCountHLLMV", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, + OperandTypes.family(SqlTypeFamily.ARRAY), ReturnTypes.BIGINT, ReturnTypes.explicit(SqlTypeName.OTHER)), + DISTINCTCOUNTRAWHLLMV("distinctCountRawHLLMV", null, SqlKind.OTHER_FUNCTION, + SqlFunctionCategory.USER_DEFINED_FUNCTION, OperandTypes.family(SqlTypeFamily.ARRAY), ReturnTypes.VARCHAR_2000, + ReturnTypes.explicit(SqlTypeName.OTHER)), + DISTINCTSUMMV("distinctSumMV", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, + OperandTypes.family(SqlTypeFamily.ARRAY), ReturnTypes.DOUBLE, ReturnTypes.explicit(SqlTypeName.OTHER)), + DISTINCTAVGMV("distinctAvgMV", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, + OperandTypes.family(SqlTypeFamily.ARRAY), ReturnTypes.DOUBLE, ReturnTypes.explicit(SqlTypeName.OTHER)), + PERCENTILEMV("percentileMV", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, + 
OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY, SqlTypeFamily.NUMERIC)), ReturnTypes.DOUBLE, + ReturnTypes.explicit(SqlTypeName.OTHER)), + PERCENTILEESTMV("percentileEstMV", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, + OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY, SqlTypeFamily.NUMERIC)), ReturnTypes.DOUBLE, + ReturnTypes.explicit(SqlTypeName.OTHER)), + PERCENTILERAWESTMV("percentileRawEstMV", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, + OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY, SqlTypeFamily.NUMERIC)), ReturnTypes.VARCHAR_2000, + ReturnTypes.explicit(SqlTypeName.OTHER)), + PERCENTILETDIGESTMV("percentileTDigestMV", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, + OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY, SqlTypeFamily.NUMERIC)), ReturnTypes.DOUBLE, + ReturnTypes.explicit(SqlTypeName.OTHER)), + PERCENTILERAWTDIGESTMV("percentileRawTDigestMV", null, SqlKind.OTHER_FUNCTION, + SqlFunctionCategory.USER_DEFINED_FUNCTION, + OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY, SqlTypeFamily.NUMERIC)), ReturnTypes.VARCHAR_2000, + ReturnTypes.explicit(SqlTypeName.OTHER)), + PERCENTILEKLLMV("percentileKLLMV", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, + OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY, SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC), + ordinal -> ordinal > 1 && ordinal < 4), ReturnTypes.DOUBLE, ReturnTypes.explicit(SqlTypeName.OTHER)), + PERCENTILERAWKLLMV("percentileRawKLLMV", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION, + OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY, SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC), + ordinal -> ordinal > 1 && ordinal < 4), ReturnTypes.VARCHAR_2000, ReturnTypes.explicit(SqlTypeName.OTHER)), // boolean aggregate functions BOOLAND("boolAnd", null, SqlKind.OTHER_FUNCTION, 
SqlFunctionCategory.USER_DEFINED_FUNCTION,