
Commit ed3ba66

beliefer authored and chenzhx committed
[SPARK-37527][SQL] Translate more standard aggregate functions for pushdown
### What changes were proposed in this pull request?

Currently, Spark aggregate push-down translates only some standard aggregate functions, so that dialects can compile these functions for a specific database. After this change, users can override `JdbcDialect.compileAggregate` to implement the standard aggregate functions supported by a given database.

This PR translates only the ANSI standard aggregate functions. Mainstream database support for these functions is shown below:

| Name | ClickHouse | Presto | Teradata | Snowflake | Oracle | Postgresql | Vertica | MySQL | RedShift | ElasticSearch | Impala | Druid | SyBase | DB2 | H2 | Exasol | Mariadb | Phoenix | Yellowbrick | Singlestore | Influxdata | Dolphindb | Intersystems |
|------|------|------|------|------|------|------|------|------|------|------|------|------|------|------|------|------|------|------|------|------|------|------|------|
| `VAR_POP` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No | Yes | Yes | No | Yes | Yes |
| `VAR_SAMP` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No | Yes | Yes | Yes | No | Yes | Yes | No | Yes | Yes |
| `STDDEV_POP` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes |
| `STDDEV_SAMP` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No | Yes | Yes | Yes | Yes | Yes | Yes | No | Yes | Yes |
| `COVAR_POP` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No | No | No | No | No | Yes | No | Yes | Yes | No | No | No | No | Yes | Yes | No |
| `COVAR_SAMP` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No | No | No | No | No | Yes | No | Yes | Yes | No | No | No | No | No | No | No |
| `CORR` | Yes | Yes | Yes | Yes | Yes | Yes | Yes | No | No | No | No | No | Yes | No | Yes | Yes | No | No | No | No | No | Yes | No |

Because the aggregate functions below are converted by the Optimizer as shown, this PR does not need to match them:

| Input | Parsed | Optimized |
|-------|--------|-----------|
| `Every` | `aggregate.BoolAnd` | `Min` |
| `Any` | `aggregate.BoolOr` | `Max` |
| `Some` | `aggregate.BoolOr` | `Max` |

### Why are the changes needed?

Let `*Dialect` implementations extend the supported aggregate functions by overriding `JdbcDialect.compileAggregate`.

### Does this PR introduce _any_ user-facing change?

Yes. Users can push down more aggregate functions.

### How was this patch tested?

Existing tests.

Closes apache#35101 from beliefer/SPARK-37527-new2.

Authored-by: Jiaan Geng <[email protected]>
Signed-off-by: Huaxin Gao <[email protected]>
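As an illustration of the extension point this PR opens up, here is a minimal sketch of a third-party dialect overriding `compileAggregate`; `MyDialect`, the `jdbc:mydb` URL prefix, and the choice of `CORR` are assumptions for the example, not part of the commit:

```scala
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc}
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}

// Hypothetical dialect for a database that supports CORR natively.
object MyDialect extends JdbcDialect {
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:mydb")

  override def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
    // Let the base dialect handle MIN/MAX/COUNT/SUM/AVG first.
    super.compileAggregate(aggFunction).orElse(aggFunction match {
      case f: GeneralAggregateFunc if f.name() == "CORR" =>
        assert(f.inputs().length == 2)
        Some(s"CORR(${f.inputs().head}, ${f.inputs().last})")
      case _ => None
    })
  }
}

// Registered once at startup so JDBC sources with matching URLs pick it up.
JdbcDialects.registerDialect(MyDialect)
```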

4 files changed: +113 -0 lines changed


sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java

+7 lines

```diff
@@ -32,6 +32,13 @@
  * The currently supported SQL aggregate functions:
  * <ol>
  *  <li><pre>AVG(input1)</pre> Since 3.3.0</li>
+ *  <li><pre>VAR_POP(input1)</pre> Since 3.3.0</li>
+ *  <li><pre>VAR_SAMP(input1)</pre> Since 3.3.0</li>
+ *  <li><pre>STDDEV_POP(input1)</pre> Since 3.3.0</li>
+ *  <li><pre>STDDEV_SAMP(input1)</pre> Since 3.3.0</li>
+ *  <li><pre>COVAR_POP(input1, input2)</pre> Since 3.3.0</li>
+ *  <li><pre>COVAR_SAMP(input1, input2)</pre> Since 3.3.0</li>
+ *  <li><pre>CORR(input1, input2)</pre> Since 3.3.0</li>
  * </ol>
  *
  * @since 3.3.0
```
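For reviewers mapping the doc list to the API: each entry corresponds to a `GeneralAggregateFunc` built from the function name, a distinct flag, and the input references. A small sketch of the two-argument form (column names are placeholders):

```scala
import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc

// CORR(input1, input2) from the list above, as a connector expression.
val corr = new GeneralAggregateFunc(
  "CORR", false, Array(FieldReference("input1"), FieldReference("input2")))
```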

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala

+21 lines

```diff
@@ -718,6 +718,27 @@ object DataSourceStrategy
         Some(new Sum(FieldReference(name), aggregates.isDistinct))
       case aggregate.Average(PushableColumnWithoutNestedColumn(name), _) =>
         Some(new GeneralAggregateFunc("AVG", aggregates.isDistinct, Array(FieldReference(name))))
+      case aggregate.VariancePop(PushableColumnWithoutNestedColumn(name), _) =>
+        Some(new GeneralAggregateFunc("VAR_POP", aggregates.isDistinct, Array(FieldReference(name))))
+      case aggregate.VarianceSamp(PushableColumnWithoutNestedColumn(name), _) =>
+        Some(new GeneralAggregateFunc("VAR_SAMP", aggregates.isDistinct, Array(FieldReference(name))))
+      case aggregate.StddevPop(PushableColumnWithoutNestedColumn(name), _) =>
+        Some(new GeneralAggregateFunc("STDDEV_POP", aggregates.isDistinct, Array(FieldReference(name))))
+      case aggregate.StddevSamp(PushableColumnWithoutNestedColumn(name), _) =>
+        Some(new GeneralAggregateFunc("STDDEV_SAMP", aggregates.isDistinct, Array(FieldReference(name))))
+      case aggregate.CovPopulation(PushableColumnWithoutNestedColumn(left),
+          PushableColumnWithoutNestedColumn(right), _) =>
+        Some(new GeneralAggregateFunc("COVAR_POP", aggregates.isDistinct,
+          Array(FieldReference(left), FieldReference(right))))
+      case aggregate.CovSample(PushableColumnWithoutNestedColumn(left),
+          PushableColumnWithoutNestedColumn(right), _) =>
+        Some(new GeneralAggregateFunc("COVAR_SAMP", aggregates.isDistinct,
+          Array(FieldReference(left), FieldReference(right))))
+      case aggregate.Corr(PushableColumnWithoutNestedColumn(left),
+          PushableColumnWithoutNestedColumn(right), _) =>
+        Some(new GeneralAggregateFunc("CORR", aggregates.isDistinct,
+          Array(FieldReference(left), FieldReference(right))))
+
       case _ => None
     }
   } else {
```
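Review note: the seven new cases come in exactly two shapes, single-input and two-input, differing only in the V2 function name. A hedged sketch of the shared shapes (the helper object and names are hypothetical; the PR keeps the cases explicit, which reads naturally in a pattern match):

```scala
import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc

object AggTranslationShapes {
  // Single-input shape, e.g. VAR_POP, VAR_SAMP, STDDEV_POP, STDDEV_SAMP.
  def oneInput(name: String, isDistinct: Boolean, col: String): Option[GeneralAggregateFunc] =
    Some(new GeneralAggregateFunc(name, isDistinct, Array(FieldReference(col))))

  // Two-input shape, e.g. COVAR_POP, COVAR_SAMP, CORR.
  def twoInput(name: String, isDistinct: Boolean,
      left: String, right: String): Option[GeneralAggregateFunc] =
    Some(new GeneralAggregateFunc(name, isDistinct,
      Array(FieldReference(left), FieldReference(right))))
}
```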

sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala

+25 lines

```diff
@@ -22,11 +22,36 @@ import java.util.Locale
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis.{NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException}
+import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc}
 
 private object H2Dialect extends JdbcDialect {
   override def canHandle(url: String): Boolean =
     url.toLowerCase(Locale.ROOT).startsWith("jdbc:h2")
 
+  override def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
+    super.compileAggregate(aggFunction).orElse(
+      aggFunction match {
+        case f: GeneralAggregateFunc if f.name() == "VAR_POP" =>
+          assert(f.inputs().length == 1)
+          val distinct = if (f.isDistinct) "DISTINCT " else ""
+          Some(s"VAR_POP($distinct${f.inputs().head})")
+        case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" =>
+          assert(f.inputs().length == 1)
+          val distinct = if (f.isDistinct) "DISTINCT " else ""
+          Some(s"VAR_SAMP($distinct${f.inputs().head})")
+        case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" =>
+          assert(f.inputs().length == 1)
+          val distinct = if (f.isDistinct) "DISTINCT " else ""
+          Some(s"STDDEV_POP($distinct${f.inputs().head})")
+        case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" =>
+          assert(f.inputs().length == 1)
+          val distinct = if (f.isDistinct) "DISTINCT " else ""
+          Some(s"STDDEV_SAMP($distinct${f.inputs().head})")
+        case _ => None
+      }
+    )
+  }
+
   override def classifyException(message: String, e: Throwable): AnalysisException = {
     if (e.isInstanceOf[SQLException]) {
       // Error codes are from https://www.h2database.com/javadoc/org/h2/api/ErrorCode.html
```
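The four new branches differ only in the function name; inside `H2Dialect` they could equivalently be driven by a name set, sketched below (the `supportedAggs` name is illustrative, not the PR's code). Note that `COVAR_POP`, `COVAR_SAMP`, and `CORR` are not compiled here, so those aggregates are not pushed to H2; the corresponding tests below reflect that.

```scala
// Inside H2Dialect: a consolidated equivalent of the four branches above (sketch only).
private val supportedAggs = Set("VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP")

override def compileAggregate(aggFunction: AggregateFunc): Option[String] =
  super.compileAggregate(aggFunction).orElse(aggFunction match {
    case f: GeneralAggregateFunc if supportedAggs.contains(f.name()) =>
      assert(f.inputs().length == 1)
      val distinct = if (f.isDistinct) "DISTINCT " else ""
      Some(s"${f.name()}($distinct${f.inputs().head})")
    case _ => None
  })
```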

sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala

+60 lines

```diff
@@ -713,6 +713,66 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     checkAnswer(df, Seq(Row(1, 2), Row(2, 2), Row(6, 1)))
   }
 
+  test("scan with aggregate push-down: VAR_POP VAR_SAMP with filter and group by") {
+    val df = sql("select VAR_POP(bonus), VAR_SAMP(bonus) FROM h2.test.employee where dept > 0" +
+      " group by DePt")
+    checkFiltersRemoved(df)
+    checkAggregateRemoved(df)
+    df.queryExecution.optimizedPlan.collect {
+      case _: DataSourceV2ScanRelation =>
+        val expected_plan_fragment =
+          "PushedAggregates: [VAR_POP(BONUS), VAR_SAMP(BONUS)], " +
+            "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " +
+            "PushedGroupByColumns: [DEPT]"
+        checkKeywordsExistsInExplain(df, expected_plan_fragment)
+    }
+    checkAnswer(df, Seq(Row(10000d, 20000d), Row(2500d, 5000d), Row(0d, null)))
+  }
+
+  test("scan with aggregate push-down: STDDEV_POP STDDEV_SAMP with filter and group by") {
+    val df = sql("select STDDEV_POP(bonus), STDDEV_SAMP(bonus) FROM h2.test.employee" +
+      " where dept > 0 group by DePt")
+    checkFiltersRemoved(df)
+    checkAggregateRemoved(df)
+    df.queryExecution.optimizedPlan.collect {
+      case _: DataSourceV2ScanRelation =>
+        val expected_plan_fragment =
+          "PushedAggregates: [STDDEV_POP(BONUS), STDDEV_SAMP(BONUS)], " +
+            "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " +
+            "PushedGroupByColumns: [DEPT]"
+        checkKeywordsExistsInExplain(df, expected_plan_fragment)
+    }
+    checkAnswer(df, Seq(Row(100d, 141.4213562373095d), Row(50d, 70.71067811865476d), Row(0d, null)))
+  }
+
+  test("scan with aggregate push-down: COVAR_POP COVAR_SAMP with filter and group by") {
+    val df = sql("select COVAR_POP(bonus, bonus), COVAR_SAMP(bonus, bonus)" +
+      " FROM h2.test.employee where dept > 0 group by DePt")
+    checkFiltersRemoved(df)
+    checkAggregateRemoved(df, false)
+    df.queryExecution.optimizedPlan.collect {
+      case _: DataSourceV2ScanRelation =>
+        val expected_plan_fragment =
+          "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)]"
+        checkKeywordsExistsInExplain(df, expected_plan_fragment)
+    }
+    checkAnswer(df, Seq(Row(10000d, 20000d), Row(2500d, 5000d), Row(0d, null)))
+  }
+
+  test("scan with aggregate push-down: CORR with filter and group by") {
+    val df = sql("select CORR(bonus, bonus) FROM h2.test.employee where dept > 0" +
+      " group by DePt")
+    checkFiltersRemoved(df)
+    checkAggregateRemoved(df, false)
+    df.queryExecution.optimizedPlan.collect {
+      case _: DataSourceV2ScanRelation =>
+        val expected_plan_fragment =
+          "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)]"
+        checkKeywordsExistsInExplain(df, expected_plan_fragment)
+    }
+    checkAnswer(df, Seq(Row(1d), Row(1d), Row(null)))
+  }
+
   test("scan with aggregate push-down: aggregate over alias NOT push down") {
     val cols = Seq("a", "b", "c", "d")
     val df1 = sql("select * from h2.test.employee").toDF(cols: _*)
```
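A note on the last two tests: because `H2Dialect.compileAggregate` above does not handle `COVAR_POP`, `COVAR_SAMP`, or `CORR`, `checkAggregateRemoved(df, false)` asserts that the aggregate still executes in Spark, and the explain check expects only pushed filters. For the functions that are compiled, the push-down is visible directly in the plan; a rough sketch of checking it interactively (assuming the same `h2` catalog this suite configures):

```scala
val df = spark.sql(
  "SELECT VAR_POP(bonus) FROM h2.test.employee WHERE dept > 0 GROUP BY dept")
df.explain()
// Per the expectations above, the DataSourceV2 scan should report:
//   PushedAggregates: [VAR_POP(BONUS)],
//   PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)],
//   PushedGroupByColumns: [DEPT]
```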
