Skip to content

Commit

Permalink
[multistage] Support multi-value aggregation functions (apache#11216)
Browse files Browse the repository at this point in the history
* temp

* Support multi-value aggregation functions
  • Loading branch information
xiangfu0 authored and s0nskar committed Aug 10, 2023
1 parent e1659f6 commit 288f3a4
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,60 @@ public void testMultiValueColumnSelectionQuery()
testQueryWithMatchingRowCount(pinotQuery, h2Query);
}

/**
 * Checks the multi-value (MV) aggregation functions against the {@code DivAirportIDs} column of {@code mytable}.
 *
 * <p>The test is driven by the {@code useBothQueryEngines} data provider and switches the query engine according to
 * the supplied flag before issuing any query. Exact-valued functions are compared against fixed expected results;
 * the approximate percentile functions are only required to fall strictly inside a value range.
 *
 * @param useMultiStageQueryEngine whether the queries should be issued with the multi-stage engine enabled
 * @throws Exception if posting a query fails
 */
@Test(dataProvider = "useBothQueryEngines")
public void testMultiValueColumnAggregationQuery(boolean useMultiStageQueryEngine)
    throws Exception {
  setUseMultiStageQueryEngine(useMultiStageQueryEngine);
  try {
    String[] multiValueFunctions = new String[]{
        "sumMV", "countMV", "minMV", "maxMV", "avgMV", "minMaxRangeMV", "distinctCountMV", "distinctCountBitmapMV",
        "distinctCountHLLMV", "distinctSumMV", "distinctAvgMV"
    };
    double[] expectedResults = new double[]{
        -5.421344202E9, 577725, -9999.0, 16271.0, -9383.95292223809, 26270.0, 312, 312, 328, 3954484.0,
        12674.628205128205
    };

    // Guard against the two parallel arrays drifting apart when a function is added or removed.
    Assert.assertEquals(multiValueFunctions.length, expectedResults.length);

    for (int i = 0; i < multiValueFunctions.length; i++) {
      String pinotQuery = String.format("SELECT %s(DivAirportIDs) FROM mytable", multiValueFunctions[i]);
      double actual = postQuery(pinotQuery).get("resultTable").get("rows").get(0).get(0).asDouble();
      // The message names the failing query; without it a failure only shows two numbers.
      Assert.assertEquals(actual, expectedResults[i], "Unexpected result for query: " + pinotQuery);
    }

    String exactPercentileQuery = "SELECT percentileMV(DivAirportIDs, 99) FROM mytable";
    double exactPercentile = postQuery(exactPercentileQuery).get("resultTable").get("rows").get(0).get(0).asDouble();
    Assert.assertEquals(exactPercentile, 13433.0, "Unexpected result for query: " + exactPercentileQuery);

    // Approximate percentile functions: each result must lie strictly between the paired {lower, upper} bounds.
    String[] approximateQueries = new String[]{
        "SELECT percentileEstMV(DivAirportIDs, 99) FROM mytable",
        "SELECT percentileTDigestMV(DivAirportIDs, 99) FROM mytable",
        "SELECT percentileKLLMV(DivAirportIDs, 99) FROM mytable",
        "SELECT percentileKLLMV(DivAirportIDs, 99, 100) FROM mytable"
    };
    double[][] approximateBounds = new double[][]{
        {13000, 14000}, {13000, 14000}, {12000, 15000}, {12000, 15000}
    };
    Assert.assertEquals(approximateQueries.length, approximateBounds.length);

    for (int i = 0; i < approximateQueries.length; i++) {
      String pinotQuery = approximateQueries[i];
      double lowerBound = approximateBounds[i][0];
      double upperBound = approximateBounds[i][1];
      double actual = postQuery(pinotQuery).get("resultTable").get("rows").get(0).get(0).asDouble();
      Assert.assertTrue(actual > lowerBound,
          "Expected a value above " + lowerBound + " but got " + actual + " for query: " + pinotQuery);
      Assert.assertTrue(actual < upperBound,
          "Expected a value below " + upperBound + " but got " + actual + " for query: " + pinotQuery);
    }
  } finally {
    // Always switch the multi-stage engine back on, even when an assertion or query above fails, so a failure
    // in this test cannot leave the engine setting in the single-stage state for whatever runs next.
    setUseMultiStageQueryEngine(true);
  }
}

@Test
public void testTimeFunc()
throws Exception {
String sqlQuery = "SELECT toDateTime(now(), 'yyyy-MM-dd z'), toDateTime(ago('PT1H'), 'yyyy-MM-dd z') FROM mytable";
JsonNode response = postQuery(sqlQuery);
System.out.println("response = " + response);
String todayStr = response.get("resultTable").get("rows").get(0).get(0).asText();
String expectedTodayStr =
Instant.now().atZone(ZoneId.of("UTC")).format(DateTimeFormatter.ofPattern("yyyy-MM-dd z"));
Expand Down Expand Up @@ -441,7 +489,6 @@ public void testLiteralOnlyFunc()
+ "decodeUrl('key1%3Dvalue+1%26key2%3Dvalue%40%21%242%26key3%3Dvalue%253') as decodedUrl, toBase64"
+ "(toUtf8('hello!')) as toBase64, fromUtf8(fromBase64('aGVsbG8h')) as fromBase64";
JsonNode response = postQuery(sqlQuery);
System.out.println("response = " + response.toPrettyString());
long queryEndTimeMs = System.currentTimeMillis();

JsonNode resultTable = response.get("resultTable");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,25 +150,56 @@ public enum AggregationFunctionType {
STUNION("STUnion"),

// Aggregation functions for multi-valued columns
COUNTMV("countMV"),
MINMV("minMV"),
MAXMV("maxMV"),
SUMMV("sumMV"),
AVGMV("avgMV"),
MINMAXRANGEMV("minMaxRangeMV"),
DISTINCTCOUNTMV("distinctCountMV"),
DISTINCTCOUNTBITMAPMV("distinctCountBitmapMV"),
DISTINCTCOUNTHLLMV("distinctCountHLLMV"),
DISTINCTCOUNTRAWHLLMV("distinctCountRawHLLMV"),
DISTINCTSUMMV("distinctSumMV"),
DISTINCTAVGMV("distinctAvgMV"),
PERCENTILEMV("percentileMV"),
PERCENTILEESTMV("percentileEstMV"),
PERCENTILERAWESTMV("percentileRawEstMV"),
PERCENTILETDIGESTMV("percentileTDigestMV"),
PERCENTILERAWTDIGESTMV("percentileRawTDigestMV"),
PERCENTILEKLLMV("percentileKLLMV"),
PERCENTILERAWKLLMV("percentileRawKLLMV"),
// Multi-value (MV) aggregation functions declared with explicit SQL operator metadata.
// Every constant below passes, in order: the function name, null, SqlKind.OTHER_FUNCTION,
// SqlFunctionCategory.USER_DEFINED_FUNCTION, an operand type checker, and two ReturnTypes values.
// NOTE(review): the enum constructor is not visible in this hunk. The first ReturnTypes argument is presumably the
// final result type and the second the intermediate result type, and the null presumably an unused alternative
// name; confirm against the constructor before relying on this.

// Single-operand functions: OperandTypes.family(SqlTypeFamily.ARRAY) accepts one operand of the ARRAY family.
COUNTMV("countMV", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION,
OperandTypes.family(SqlTypeFamily.ARRAY), ReturnTypes.explicit(SqlTypeName.BIGINT),
ReturnTypes.explicit(SqlTypeName.BIGINT)),
MINMV("minMV", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION,
OperandTypes.family(SqlTypeFamily.ARRAY), ReturnTypes.DOUBLE, ReturnTypes.explicit(SqlTypeName.DOUBLE)),
MAXMV("maxMV", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION,
OperandTypes.family(SqlTypeFamily.ARRAY), ReturnTypes.DOUBLE, ReturnTypes.explicit(SqlTypeName.DOUBLE)),
SUMMV("sumMV", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION,
OperandTypes.family(SqlTypeFamily.ARRAY), ReturnTypes.DOUBLE, ReturnTypes.explicit(SqlTypeName.DOUBLE)),
// From here on the last argument is SqlTypeName.OTHER rather than a concrete SQL type.
AVGMV("avgMV", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION,
OperandTypes.family(SqlTypeFamily.ARRAY), ReturnTypes.DOUBLE, ReturnTypes.explicit(SqlTypeName.OTHER)),
MINMAXRANGEMV("minMaxRangeMV", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION,
OperandTypes.family(SqlTypeFamily.ARRAY), ReturnTypes.explicit(SqlTypeName.DOUBLE),
ReturnTypes.explicit(SqlTypeName.OTHER)),
DISTINCTCOUNTMV("distinctCountMV", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION,
OperandTypes.family(SqlTypeFamily.ARRAY), ReturnTypes.BIGINT, ReturnTypes.explicit(SqlTypeName.OTHER)),
DISTINCTCOUNTBITMAPMV("distinctCountBitmapMV", null, SqlKind.OTHER_FUNCTION,
SqlFunctionCategory.USER_DEFINED_FUNCTION, OperandTypes.family(SqlTypeFamily.ARRAY), ReturnTypes.BIGINT,
ReturnTypes.explicit(SqlTypeName.OTHER)),
DISTINCTCOUNTHLLMV("distinctCountHLLMV", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION,
OperandTypes.family(SqlTypeFamily.ARRAY), ReturnTypes.BIGINT, ReturnTypes.explicit(SqlTypeName.OTHER)),
// The "Raw" variants in this list use ReturnTypes.VARCHAR_2000 where their non-raw counterparts use a numeric type.
DISTINCTCOUNTRAWHLLMV("distinctCountRawHLLMV", null, SqlKind.OTHER_FUNCTION,
SqlFunctionCategory.USER_DEFINED_FUNCTION, OperandTypes.family(SqlTypeFamily.ARRAY), ReturnTypes.VARCHAR_2000,
ReturnTypes.explicit(SqlTypeName.OTHER)),
DISTINCTSUMMV("distinctSumMV", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION,
OperandTypes.family(SqlTypeFamily.ARRAY), ReturnTypes.DOUBLE, ReturnTypes.explicit(SqlTypeName.OTHER)),
DISTINCTAVGMV("distinctAvgMV", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION,
OperandTypes.family(SqlTypeFamily.ARRAY), ReturnTypes.DOUBLE, ReturnTypes.explicit(SqlTypeName.OTHER)),
// Percentile functions: two operands, an ARRAY-family operand followed by a NUMERIC-family operand.
PERCENTILEMV("percentileMV", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION,
OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY, SqlTypeFamily.NUMERIC)), ReturnTypes.DOUBLE,
ReturnTypes.explicit(SqlTypeName.OTHER)),
PERCENTILEESTMV("percentileEstMV", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION,
OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY, SqlTypeFamily.NUMERIC)), ReturnTypes.DOUBLE,
ReturnTypes.explicit(SqlTypeName.OTHER)),
PERCENTILERAWESTMV("percentileRawEstMV", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION,
OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY, SqlTypeFamily.NUMERIC)), ReturnTypes.VARCHAR_2000,
ReturnTypes.explicit(SqlTypeName.OTHER)),
PERCENTILETDIGESTMV("percentileTDigestMV", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION,
OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY, SqlTypeFamily.NUMERIC)), ReturnTypes.DOUBLE,
ReturnTypes.explicit(SqlTypeName.OTHER)),
PERCENTILERAWTDIGESTMV("percentileRawTDigestMV", null, SqlKind.OTHER_FUNCTION,
SqlFunctionCategory.USER_DEFINED_FUNCTION,
OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY, SqlTypeFamily.NUMERIC)), ReturnTypes.VARCHAR_2000,
ReturnTypes.explicit(SqlTypeName.OTHER)),
// KLL percentile functions: ARRAY, NUMERIC, NUMERIC operand families plus a predicate over operand ordinals.
// The predicate is true only for ordinals 2 and 3, and only ordinal 2 exists in the three-entry family list.
// NOTE(review): presumably this marks the third operand as optional (0-based ordinals) — confirm against the
// OperandTypes.family(List, Predicate) contract.
PERCENTILEKLLMV("percentileKLLMV", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION,
OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY, SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC),
ordinal -> ordinal > 1 && ordinal < 4), ReturnTypes.DOUBLE, ReturnTypes.explicit(SqlTypeName.OTHER)),
PERCENTILERAWKLLMV("percentileRawKLLMV", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION,
OperandTypes.family(ImmutableList.of(SqlTypeFamily.ARRAY, SqlTypeFamily.NUMERIC, SqlTypeFamily.NUMERIC),
ordinal -> ordinal > 1 && ordinal < 4), ReturnTypes.VARCHAR_2000, ReturnTypes.explicit(SqlTypeName.OTHER)),

// boolean aggregate functions
BOOLAND("boolAnd", null, SqlKind.OTHER_FUNCTION, SqlFunctionCategory.USER_DEFINED_FUNCTION,
Expand Down

0 comments on commit 288f3a4

Please sign in to comment.