Skip to content

Commit

Permalink
Fix NonAggregationGroupByToDistinctQueryRewriter (#9605)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jackie-Jiang authored Oct 17, 2022
1 parent 4935326 commit edf0c01
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 194 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,58 +18,72 @@
*/
package org.apache.pinot.sql.parsers.rewriter;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import org.apache.pinot.common.request.Expression;
import org.apache.pinot.common.request.Function;
import org.apache.pinot.common.request.PinotQuery;
import org.apache.pinot.common.utils.request.RequestUtils;
import org.apache.pinot.sql.parsers.CalciteSqlParser;
import org.apache.pinot.sql.parsers.SqlCompilationException;


/**
* Rewrite non-aggregation group-by query to distinct query.
* The query can be rewritten only if select expression set and group-by expression set are the same.
*
* E.g.
* SELECT col1, col2 FROM foo GROUP BY col1, col2 --> SELECT DISTINCT col1, col2 FROM foo
* SELECT col1 + col2 FROM foo GROUP BY col1 + col2 --> SELECT DISTINCT col1 + col2 FROM foo
* SELECT col1 AS c1 FROM foo GROUP BY col1 --> SELECT DISTINCT col1 AS c1 FROM foo
* SELECT col1, col1 AS c1, col2 FROM foo GROUP BY col1, col2 --> SELECT DISTINCT col1, col1 AS c1, col2 FROM foo
*
* Unsupported queries:
* SELECT col1 FROM foo GROUP BY col1, col2 (not equivalent to SELECT DISTINCT col1 FROM foo)
* SELECT col1 + col2 FROM foo GROUP BY col1, col2 (not equivalent to SELECT DISTINCT col1 + col2 FROM foo)
*/
public class NonAggregationGroupByToDistinctQueryRewriter implements QueryRewriter {
/**
* Rewrite non-aggregate group by query to distinct query.
* E.g.
* ```
* SELECT col1+col2*5 FROM foo GROUP BY col1, col2 => SELECT distinct col1+col2*5 FROM foo
* SELECT col1, col2 FROM foo GROUP BY col1, col2 => SELECT distinct col1, col2 FROM foo
* ```
* @param pinotQuery
*/

@Override
public PinotQuery rewrite(PinotQuery pinotQuery) {
boolean hasAggregation = false;
if (pinotQuery.getGroupByListSize() == 0) {
return pinotQuery;
}
for (Expression select : pinotQuery.getSelectList()) {
if (CalciteSqlParser.isAggregateExpression(select)) {
hasAggregation = true;
return pinotQuery;
}
}
if (pinotQuery.getOrderByList() != null) {
for (Expression orderBy : pinotQuery.getOrderByList()) {
if (CalciteSqlParser.isAggregateExpression(orderBy)) {
hasAggregation = true;
return pinotQuery;
}
}
}
if (!hasAggregation && pinotQuery.getGroupByListSize() > 0) {
Set<String> selectIdentifiers = CalciteSqlParser.extractIdentifiers(pinotQuery.getSelectList(), true);
Set<String> groupByIdentifiers = CalciteSqlParser.extractIdentifiers(pinotQuery.getGroupByList(), true);
if (groupByIdentifiers.containsAll(selectIdentifiers)) {
Expression distinctExpression = RequestUtils.getFunctionExpression("distinct");
for (Expression select : pinotQuery.getSelectList()) {
distinctExpression.getFunctionCall().addToOperands(select);
}
pinotQuery.setSelectList(Arrays.asList(distinctExpression));
pinotQuery.setGroupByList(Collections.emptyList());

// This rewriter is applied after AliasApplier, so all the alias in group-by are already replaced with expressions
Set<Expression> selectExpressions = new HashSet<>();
for (Expression select : pinotQuery.getSelectList()) {
Function function = select.getFunctionCall();
if (function != null && function.getOperator().equals("as")) {
selectExpressions.add(function.getOperands().get(0));
} else {
selectIdentifiers.removeAll(groupByIdentifiers);
throw new SqlCompilationException(String.format(
"For non-aggregation group by query, all the identifiers in select clause should be in groupBys. Found "
+ "identifier: %s", Arrays.toString(selectIdentifiers.toArray(new String[0]))));
selectExpressions.add(select);
}
}
return pinotQuery;
Set<Expression> groupByExpressions = new HashSet<>(pinotQuery.getGroupByList());
if (selectExpressions.equals(groupByExpressions)) {
Expression distinct = RequestUtils.getFunctionExpression("distinct");
distinct.getFunctionCall().setOperands(pinotQuery.getSelectList());
pinotQuery.setSelectList(Collections.singletonList(distinct));
pinotQuery.setGroupByList(null);
return pinotQuery;
} else {
throw new SqlCompilationException(String.format(
"For non-aggregation group-by query, select expression set and group-by expression set should be the same. "
+ "Found select: %s, group-by: %s", selectExpressions, groupByExpressions));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1812,15 +1812,15 @@ public void testOrdinalsQueryRewrite() {
}

@Test
public void testOrdinalsQueryRewriteWithDistinctOrderby() {
public void testOrdinalsQueryRewriteWithDistinctOrderBy() {
String query =
"SELECT baseballStats.playerName AS playerName FROM baseballStats GROUP BY baseballStats.playerName ORDER BY "
+ "1 ASC";
PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
Assert.assertEquals(
pinotQuery.getSelectList().get(0).getFunctionCall().getOperands().get(0).getFunctionCall().getOperands().get(0)
.getIdentifier().getName(), "baseballStats.playerName");
Assert.assertTrue(pinotQuery.getGroupByList().isEmpty());
Assert.assertNull(pinotQuery.getGroupByList());
Assert.assertEquals(
pinotQuery.getOrderByList().get(0).getFunctionCall().getOperands().get(0).getIdentifier().getName(),
"baseballStats.playerName");
Expand Down Expand Up @@ -1937,19 +1937,16 @@ public void testCompilationInvokedFunction() {
Assert.assertEquals(encodedBase64, "aGVsbG8h");
Assert.assertEquals(decodedBase64, "hello!");

query =
"select toBase64(toUtf8(upper('hello!'))), fromUtf8(fromBase64(toBase64(toUtf8(upper('hello!'))))) from "
+ "mytable";
query = "select toBase64(toUtf8(upper('hello!'))), fromUtf8(fromBase64(toBase64(toUtf8(upper('hello!'))))) from "
+ "mytable";
pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
encodedBase64 = pinotQuery.getSelectList().get(0).getLiteral().getStringValue();
decodedBase64 = pinotQuery.getSelectList().get(1).getLiteral().getStringValue();
Assert.assertEquals(encodedBase64, "SEVMTE8h");
Assert.assertEquals(decodedBase64, "HELLO!");

query =
"select reverse(fromUtf8(fromBase64(toBase64(toUtf8(upper('hello!')))))) from mytable where fromUtf8"
+ "(fromBase64(toBase64(toUtf8(upper('hello!')))))"
+ " = bar";
query = "select reverse(fromUtf8(fromBase64(toBase64(toUtf8(upper('hello!')))))) from mytable where "
+ "fromUtf8(fromBase64(toBase64(toUtf8(upper('hello!'))))) = bar";
pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
String arg1 = pinotQuery.getSelectList().get(0).getLiteral().getStringValue();
String leftOp =
Expand Down Expand Up @@ -2217,7 +2214,7 @@ public void testNonAggregationGroupByQuery() {
Assert.assertEquals(
pinotQuery.getSelectList().get(0).getFunctionCall().getOperands().get(1).getIdentifier().getName(), "col2");

query = "SELECT col1+col2*5 FROM foo GROUP BY col1, col2";
query = "SELECT col1+col2*5 FROM foo GROUP BY col1+col2*5";
pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
Assert.assertEquals(pinotQuery.getSelectListSize(), 1);
Assert.assertEquals(pinotQuery.getSelectList().get(0).getFunctionCall().getOperator().toUpperCase(), "DISTINCT");
Expand All @@ -2237,7 +2234,7 @@ public void testNonAggregationGroupByQuery() {
pinotQuery.getSelectList().get(0).getFunctionCall().getOperands().get(0).getFunctionCall().getOperands().get(1)
.getFunctionCall().getOperands().get(1).getLiteral().getLongValue(), 5L);

query = "SELECT col1+col2*5 AS col3 FROM foo GROUP BY col1, col2";
query = "SELECT col1+col2*5 AS col3 FROM foo GROUP BY col3";
pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
Assert.assertEquals(pinotQuery.getSelectListSize(), 1);
Assert.assertEquals(pinotQuery.getSelectList().get(0).getFunctionCall().getOperator().toUpperCase(), "DISTINCT");
Expand Down Expand Up @@ -2265,17 +2262,16 @@ public void testNonAggregationGroupByQuery() {
5L);
}

@Test(expectedExceptions = SqlCompilationException.class)
@Test
public void testInvalidNonAggregationGroupBy() {
// Non-aggregation group-by is only supported when the select expression set equals the group-by expression set.
try {
CalciteSqlParser.compileToPinotQuery("SELECT col1+col2 FROM foo GROUP BY col1");
} catch (SqlCompilationException e) {
Assert.assertEquals(e.getMessage(),
"For non-aggregation group by query, all the identifiers in select clause should be in groupBys. Found "
+ "identifier: [col2]");
throw e;
}
Assert.assertThrows(SqlCompilationException.class,
() -> CalciteSqlParser.compileToPinotQuery("SELECT col1 FROM foo GROUP BY col1, col2"));
Assert.assertThrows(SqlCompilationException.class,
() -> CalciteSqlParser.compileToPinotQuery("SELECT col1, col2 FROM foo GROUP BY col1"));
Assert.assertThrows(SqlCompilationException.class,
() -> CalciteSqlParser.compileToPinotQuery("SELECT col1 + col2 FROM foo GROUP BY col1"));
Assert.assertThrows(SqlCompilationException.class,
() -> CalciteSqlParser.compileToPinotQuery("SELECT col1+col2 FROM foo GROUP BY col1,col2"));
}

@Test
Expand Down Expand Up @@ -2681,7 +2677,6 @@ public void testParserExtensionImpl() {
Assert.assertEquals(sqlNodeAndOptions.getSqlType(), PinotSqlType.DML);
}


@Test
public void shouldParseBasicAtTimeZoneExtension() {
// Given:
Expand Down
Loading

0 comments on commit edf0c01

Please sign in to comment.