Skip to content

Commit 4c2380b

Browse files
beliefer authored and chenzhx committed
[SPARK-37960][SQL] A new framework to represent catalyst expressions in DS v2 APIs
### What changes were proposed in this pull request? This PR provides a new framework to represent catalyst expressions in DS v2 APIs. `GeneralSQLExpression` is a general SQL expression to represent catalyst expression in DS v2 API. `ExpressionSQLBuilder` is a builder to generate `GeneralSQLExpression` from catalyst expressions. `CASE ... WHEN ... ELSE ... END` is just the first use case. This PR also supports aggregate push down with `CASE ... WHEN ... ELSE ... END`. ### Why are the changes needed? Support aggregate push down with `CASE ... WHEN ... ELSE ... END`. ### Does this PR introduce _any_ user-facing change? Yes. Users could use `CASE ... WHEN ... ELSE ... END` with aggregate push down. ### How was this patch tested? New tests. Closes apache#35248 from beliefer/SPARK-37960. Authored-by: Jiaan Geng <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent 762af83 commit 4c2380b

File tree

15 files changed

+371
-107
lines changed

15 files changed

+371
-107
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.connector.expressions;
19+
20+
import java.io.Serializable;
21+
22+
import org.apache.spark.annotation.Evolving;
23+
24+
/**
25+
* The general SQL string corresponding to expression.
26+
*
27+
* @since 3.3.0
28+
*/
29+
@Evolving
30+
public class GeneralSQLExpression implements Expression, Serializable {
31+
private String sql;
32+
33+
public GeneralSQLExpression(String sql) {
34+
this.sql = sql;
35+
}
36+
37+
public String sql() { return sql; }
38+
39+
@Override
40+
public String toString() { return sql; }
41+
}

sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Avg.java

+7-7
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
package org.apache.spark.sql.connector.expressions.aggregate;
1919

2020
import org.apache.spark.annotation.Evolving;
21-
import org.apache.spark.sql.connector.expressions.NamedReference;
21+
import org.apache.spark.sql.connector.expressions.Expression;
2222

2323
/**
2424
* An aggregate function that returns the mean of all the values in a group.
@@ -27,23 +27,23 @@
2727
*/
2828
@Evolving
2929
public final class Avg implements AggregateFunc {
30-
private final NamedReference column;
30+
private final Expression input;
3131
private final boolean isDistinct;
3232

33-
public Avg(NamedReference column, boolean isDistinct) {
34-
this.column = column;
33+
public Avg(Expression column, boolean isDistinct) {
34+
this.input = column;
3535
this.isDistinct = isDistinct;
3636
}
3737

38-
public NamedReference column() { return column; }
38+
public Expression column() { return input; }
3939
public boolean isDistinct() { return isDistinct; }
4040

4141
@Override
4242
public String toString() {
4343
if (isDistinct) {
44-
return "AVG(DISTINCT " + column.describe() + ")";
44+
return "AVG(DISTINCT " + input.describe() + ")";
4545
} else {
46-
return "AVG(" + column.describe() + ")";
46+
return "AVG(" + input.describe() + ")";
4747
}
4848
}
4949
}

sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Count.java

+7-7
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
package org.apache.spark.sql.connector.expressions.aggregate;
1919

2020
import org.apache.spark.annotation.Evolving;
21-
import org.apache.spark.sql.connector.expressions.NamedReference;
21+
import org.apache.spark.sql.connector.expressions.Expression;
2222

2323
/**
2424
* An aggregate function that returns the number of the specific row in a group.
@@ -27,23 +27,23 @@
2727
*/
2828
@Evolving
2929
public final class Count implements AggregateFunc {
30-
private final NamedReference column;
30+
private final Expression input;
3131
private final boolean isDistinct;
3232

33-
public Count(NamedReference column, boolean isDistinct) {
34-
this.column = column;
33+
public Count(Expression column, boolean isDistinct) {
34+
this.input = column;
3535
this.isDistinct = isDistinct;
3636
}
3737

38-
public NamedReference column() { return column; }
38+
public Expression column() { return input; }
3939
public boolean isDistinct() { return isDistinct; }
4040

4141
@Override
4242
public String toString() {
4343
if (isDistinct) {
44-
return "COUNT(DISTINCT " + column.describe() + ")";
44+
return "COUNT(DISTINCT " + input.describe() + ")";
4545
} else {
46-
return "COUNT(" + column.describe() + ")";
46+
return "COUNT(" + input.describe() + ")";
4747
}
4848
}
4949
}

sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Max.java

+5-5
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
package org.apache.spark.sql.connector.expressions.aggregate;
1919

2020
import org.apache.spark.annotation.Evolving;
21-
import org.apache.spark.sql.connector.expressions.NamedReference;
21+
import org.apache.spark.sql.connector.expressions.Expression;
2222

2323
/**
2424
* An aggregate function that returns the maximum value in a group.
@@ -27,12 +27,12 @@
2727
*/
2828
@Evolving
2929
public final class Max implements AggregateFunc {
30-
private final NamedReference column;
30+
private final Expression input;
3131

32-
public Max(NamedReference column) { this.column = column; }
32+
public Max(Expression column) { this.input = column; }
3333

34-
public NamedReference column() { return column; }
34+
public Expression column() { return input; }
3535

3636
@Override
37-
public String toString() { return "MAX(" + column.describe() + ")"; }
37+
public String toString() { return "MAX(" + input.describe() + ")"; }
3838
}

sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Min.java

+5-5
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
package org.apache.spark.sql.connector.expressions.aggregate;
1919

2020
import org.apache.spark.annotation.Evolving;
21-
import org.apache.spark.sql.connector.expressions.NamedReference;
21+
import org.apache.spark.sql.connector.expressions.Expression;
2222

2323
/**
2424
* An aggregate function that returns the minimum value in a group.
@@ -27,12 +27,12 @@
2727
*/
2828
@Evolving
2929
public final class Min implements AggregateFunc {
30-
private final NamedReference column;
30+
private final Expression input;
3131

32-
public Min(NamedReference column) { this.column = column; }
32+
public Min(Expression column) { this.input = column; }
3333

34-
public NamedReference column() { return column; }
34+
public Expression column() { return input; }
3535

3636
@Override
37-
public String toString() { return "MIN(" + column.describe() + ")"; }
37+
public String toString() { return "MIN(" + input.describe() + ")"; }
3838
}

sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Sum.java

+7-7
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
package org.apache.spark.sql.connector.expressions.aggregate;
1919

2020
import org.apache.spark.annotation.Evolving;
21-
import org.apache.spark.sql.connector.expressions.NamedReference;
21+
import org.apache.spark.sql.connector.expressions.Expression;
2222

2323
/**
2424
* An aggregate function that returns the summation of all the values in a group.
@@ -27,23 +27,23 @@
2727
*/
2828
@Evolving
2929
public final class Sum implements AggregateFunc {
30-
private final NamedReference column;
30+
private final Expression input;
3131
private final boolean isDistinct;
3232

33-
public Sum(NamedReference column, boolean isDistinct) {
34-
this.column = column;
33+
public Sum(Expression column, boolean isDistinct) {
34+
this.input = column;
3535
this.isDistinct = isDistinct;
3636
}
3737

38-
public NamedReference column() { return column; }
38+
public Expression column() { return input; }
3939
public boolean isDistinct() { return isDistinct; }
4040

4141
@Override
4242
public String toString() {
4343
if (isDistinct) {
44-
return "SUM(DISTINCT " + column.describe() + ")";
44+
return "SUM(DISTINCT " + input.describe() + ")";
4545
} else {
46-
return "SUM(" + column.describe() + ")";
46+
return "SUM(" + input.describe() + ")";
4747
}
4848
}
4949
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.util
19+
20+
import org.apache.spark.sql.catalyst.expressions.{Attribute, BinaryOperator, CaseWhen, EqualTo, Expression, IsNotNull, IsNull, Literal, Not}
21+
import org.apache.spark.sql.connector.expressions.LiteralValue
22+
23+
/**
 * The builder to generate SQL string from catalyst expressions.
 *
 * Only literals, attributes, `IS NULL` / `IS NOT NULL`, binary operators, `NOT` and
 * `CASE WHEN` are handled. If the expression, or any expression nested inside it, is of
 * another kind, no SQL is produced.
 *
 * @param e the catalyst expression to translate
 */
class ExpressionSQLBuilder(e: Expression) {

  /**
   * Translates the expression passed to the constructor.
   *
   * @return the SQL string, or `None` if some part of the expression is not supported
   */
  def build(): Option[String] = generateSQL(e)

  /**
   * Recursively translates `expr`. A `None` from any child makes the whole result `None`.
   *
   * Every operand placed next to an operator is wrapped in parentheses, so the generated
   * text does not depend on the operator precedence rules of whatever parses it. Without
   * them a compound operand such as `(a) AND (b)` would re-associate when followed by
   * `!= c` or `IS NULL`.
   */
  private def generateSQL(expr: Expression): Option[String] = expr match {
    case Literal(value, dataType) => Some(LiteralValue(value, dataType).toString)
    case a: Attribute => Some(quoteIfNeeded(a.name))
    case IsNull(col) => generateSQL(col).map(c => s"($c) IS NULL")
    case IsNotNull(col) => generateSQL(col).map(c => s"($c) IS NOT NULL")
    case b: BinaryOperator =>
      for {
        l <- generateSQL(b.left)
        r <- generateSQL(b.right)
      } yield s"($l) ${b.sqlOperator} ($r)"
    // Must stay ahead of the generic Not case so that NOT (a = b) is rendered as a != b.
    case Not(EqualTo(left, right)) =>
      for {
        l <- generateSQL(left)
        r <- generateSQL(right)
      } yield s"($l) != ($r)"
    case Not(child) => generateSQL(child).map(v => s"NOT ($v)")
    case CaseWhen(branches, elseValue) =>
      // One optional " WHEN c THEN v" fragment per branch; None marks an unsupported branch.
      val branchFragments = branches.map { case (condition, value) =>
        for {
          c <- generateSQL(condition)
          v <- generateSQL(value)
        } yield s" WHEN $c THEN $v"
      }
      if (branchFragments.forall(_.isDefined)) {
        val branchSQL = branchFragments.flatten.mkString
        elseValue match {
          case Some(elseExpr) =>
            generateSQL(elseExpr).map(v => s"CASE$branchSQL ELSE $v END")
          case None =>
            Some(s"CASE$branchSQL END")
        }
      } else {
        None
      }
    // TODO: support other expressions
    case _ => None
  }
}

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala

+23-16
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ package org.apache.spark.sql.execution.datasources
1919

2020
import org.apache.spark.sql.catalyst.InternalRow
2121
import org.apache.spark.sql.catalyst.expressions.Expression
22-
import org.apache.spark.sql.connector.expressions.NamedReference
2322
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Count, CountStar, Max, Min}
2423
import org.apache.spark.sql.execution.RowToColumnConverter
24+
import org.apache.spark.sql.execution.datasources.v2.V2ColumnUtils
2525
import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector}
2626
import org.apache.spark.sql.types.{BooleanType, ByteType, DateType, DoubleType, FloatType, IntegerType, LongType, ShortType, StructField, StructType}
2727
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
@@ -42,27 +42,28 @@ object AggregatePushDownUtils {
4242

4343
var finalSchema = new StructType()
4444

45-
def getStructFieldForCol(col: NamedReference): StructField = {
46-
schema.apply(col.fieldNames.head)
45+
def getStructFieldForCol(colName: String): StructField = {
46+
schema.apply(colName)
4747
}
4848

49-
def isPartitionCol(col: NamedReference) = {
50-
partitionNames.contains(col.fieldNames.head)
49+
def isPartitionCol(colName: String) = {
50+
partitionNames.contains(colName)
5151
}
5252

5353
def processMinOrMax(agg: AggregateFunc): Boolean = {
54-
val (column, aggType) = agg match {
55-
case max: Max => (max.column, "max")
56-
case min: Min => (min.column, "min")
57-
case _ =>
58-
throw new IllegalArgumentException(s"Unexpected type of AggregateFunc ${agg.describe}")
54+
val (columnName, aggType) = agg match {
55+
case max: Max if V2ColumnUtils.extractV2Column(max.column).isDefined =>
56+
(V2ColumnUtils.extractV2Column(max.column).get, "max")
57+
case min: Min if V2ColumnUtils.extractV2Column(min.column).isDefined =>
58+
(V2ColumnUtils.extractV2Column(min.column).get, "min")
59+
case _ => return false
5960
}
6061

61-
if (isPartitionCol(column)) {
62+
if (isPartitionCol(columnName)) {
6263
// don't push down partition column, footer doesn't have max/min for partition column
6364
return false
6465
}
65-
val structField = getStructFieldForCol(column)
66+
val structField = getStructFieldForCol(columnName)
6667

6768
structField.dataType match {
6869
// not push down complex type
@@ -93,16 +94,22 @@ object AggregatePushDownUtils {
9394
// (https://issues.apache.org/jira/browse/SPARK-36646)
9495
return None
9596
}
97+
aggregation.groupByColumns.foreach { col =>
98+
// don't push down if the group by columns are not the same as the partition columns (order
99+
// doesn't matter because reordering can be done at the data source layer)
100+
if (col.fieldNames.length != 1 || !isPartitionCol(col.fieldNames.head)) return None
101+
finalSchema = finalSchema.add(getStructFieldForCol(col.fieldNames.head))
102+
}
96103

97104
aggregation.aggregateExpressions.foreach {
98105
case max: Max =>
99106
if (!processMinOrMax(max)) return None
100107
case min: Min =>
101108
if (!processMinOrMax(min)) return None
102-
case count: Count =>
103-
if (count.column.fieldNames.length != 1 || count.isDistinct) return None
104-
finalSchema =
105-
finalSchema.add(StructField(s"count(" + count.column.fieldNames.head + ")", LongType))
109+
case count: Count
110+
if V2ColumnUtils.extractV2Column(count.column).isDefined && !count.isDistinct =>
111+
val columnName = V2ColumnUtils.extractV2Column(count.column).get
112+
finalSchema = finalSchema.add(StructField(s"count($columnName)", LongType))
106113
case _: CountStar =>
107114
finalSchema = finalSchema.add(StructField("count(*)", LongType))
108115
case _ =>

0 commit comments

Comments
 (0)