
Commit 4042e1a

beliefer authored and chenzhx committed
[SPARK-37644][SQL] Support datasource v2 complete aggregate pushdown
### What changes were proposed in this pull request?
Currently, Spark supports aggregate push-down with partial-agg and final-agg. For some data sources (e.g. JDBC), we can avoid the partial-agg and final-agg by running the aggregation completely on the database.

### Why are the changes needed?
Improves performance for aggregate push-down.

### Does this PR introduce _any_ user-facing change?
'No'. It only changes the internal implementation.

### How was this patch tested?
New tests.

Closes apache#34904 from beliefer/SPARK-37644.

Authored-by: Jiaan Geng <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 50c0451 commit 4042e1a
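
For context, a hedged end-to-end sketch of the user-visible effect (the URL, table, and column names are hypothetical; `url`, `dbtable`, and `pushDownAggregate` are the standard JDBC data source options). When the source reports complete push-down, the optimized plan keeps only a scan plus a project instead of a partial and a final aggregate on top of the scan:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{max, min}

val spark = SparkSession.builder().appName("complete-agg-pushdown-demo").getOrCreate()

// Hypothetical single-partition JDBC source; with pushDownAggregate enabled the
// MIN/MAX/GROUP BY below can be evaluated entirely by the database.
val orders = spark.read
  .format("jdbc")
  .option("url", "jdbc:postgresql://db.example.com/sales") // hypothetical URL
  .option("dbtable", "orders")                             // hypothetical table
  .option("pushDownAggregate", "true")
  .load()

// Before this change the plan always contained a partial and a final aggregate;
// with complete push-down Spark can skip both.
orders.groupBy("region").agg(min("amount"), max("amount")).explain()
```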

4 files changed (+172, -40 lines)


sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java (+8)

@@ -45,6 +45,14 @@
 @Evolving
 public interface SupportsPushDownAggregates extends ScanBuilder {

+  /**
+   * Whether the datasource support complete aggregation push-down. Spark could avoid partial-agg
+   * and final-agg when the aggregation operation can be pushed down to the datasource completely.
+   *
+   * @return true if the aggregation can be pushed down to datasource completely, false otherwise.
+   */
+  default boolean supportCompletePushDown() { return false; }
+
   /**
    * Pushes down Aggregation to datasource. The order of the datasource scan output columns should
    * be: grouping columns, aggregate columns (in the same order as the aggregate functions in
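
For connector authors, opting in is a one-method change on the scan builder. A minimal sketch, assuming a hypothetical `MyScanBuilder` (only the `SupportsPushDownAggregates` methods come from the interface above; everything else is illustrative):

```scala
import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates}

// Hypothetical ScanBuilder that can evaluate the whole aggregate itself.
class MyScanBuilder extends ScanBuilder with SupportsPushDownAggregates {
  private var pushedAggregation: Option[Aggregation] = None

  // Tell Spark it can drop both the partial and the final aggregate node.
  override def supportCompletePushDown(): Boolean = true

  override def pushAggregation(aggregation: Aggregation): Boolean = {
    // Accept the aggregation only if this source can compute it fully.
    pushedAggregation = Some(aggregation)
    true
  }

  // The built Scan must return rows ordered as: grouping columns first,
  // then aggregate columns, per the interface contract quoted above.
  override def build(): Scan = ??? // construct a Scan over pre-aggregated data
}
```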

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala (+64, -37)

@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2

 import scala.collection.mutable

-import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, Expression, IntegerLiteral, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, Cast, Expression, IntegerLiteral, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.planning.ScanOperation
@@ -30,7 +30,7 @@ import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
 import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan}
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
 import org.apache.spark.sql.sources
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DataType, LongType, StructType}
 import org.apache.spark.sql.util.SchemaUtils._

 object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
@@ -131,7 +131,8 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
         case (a: Attribute, b: Attribute) => b.withExprId(a.exprId)
         case (_, b) => b
       }
-      val output = groupAttrs ++ newOutput.drop(groupAttrs.length)
+      val aggOutput = newOutput.drop(groupAttrs.length)
+      val output = groupAttrs ++ aggOutput

       logInfo(
         s"""
@@ -147,40 +148,59 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {

       val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output)

-      val plan = Aggregate(
-        output.take(groupingExpressions.length), resultExpressions, scanRelation)
-
-      // scalastyle:off
-      // Change the optimized logical plan to reflect the pushed down aggregate
-      // e.g. TABLE t (c1 INT, c2 INT, c3 INT)
-      // SELECT min(c1), max(c1) FROM t GROUP BY c2;
-      // The original logical plan is
-      // Aggregate [c2#10],[min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18]
-      // +- RelationV2[c1#9, c2#10] ...
-      //
-      // After change the V2ScanRelation output to [c2#10, min(c1)#21, max(c1)#22]
-      // we have the following
-      // !Aggregate [c2#10], [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18]
-      // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...
-      //
-      // We want to change it to
-      // == Optimized Logical Plan ==
-      // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18]
-      // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...
-      // scalastyle:on
-      val aggOutput = output.drop(groupAttrs.length)
-      plan.transformExpressions {
-        case agg: AggregateExpression =>
-          val ordinal = aggExprToOutputOrdinal(agg.canonicalized)
-          val aggFunction: aggregate.AggregateFunction =
-            agg.aggregateFunction match {
-              case max: aggregate.Max => max.copy(child = aggOutput(ordinal))
-              case min: aggregate.Min => min.copy(child = aggOutput(ordinal))
-              case sum: aggregate.Sum => sum.copy(child = aggOutput(ordinal))
-              case _: aggregate.Count => aggregate.Sum(aggOutput(ordinal))
-              case other => other
-            }
-          agg.copy(aggregateFunction = aggFunction)
+      if (r.supportCompletePushDown()) {
+        val projectExpressions = resultExpressions.map { expr =>
+          // TODO At present, only push down group by attribute is supported.
+          // In future, more attribute conversion is extended here. e.g. GetStructField
+          expr.transform {
+            case agg: AggregateExpression =>
+              val ordinal = aggExprToOutputOrdinal(agg.canonicalized)
+              val child =
+                addCastIfNeeded(aggOutput(ordinal), agg.resultAttribute.dataType)
+              Alias(child, agg.resultAttribute.name)(agg.resultAttribute.exprId)
+          }
+        }.asInstanceOf[Seq[NamedExpression]]
+        Project(projectExpressions, scanRelation)
+      } else {
+        val plan = Aggregate(
+          output.take(groupingExpressions.length), resultExpressions, scanRelation)
+
+        // scalastyle:off
+        // Change the optimized logical plan to reflect the pushed down aggregate
+        // e.g. TABLE t (c1 INT, c2 INT, c3 INT)
+        // SELECT min(c1), max(c1) FROM t GROUP BY c2;
+        // The original logical plan is
+        // Aggregate [c2#10],[min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18]
+        // +- RelationV2[c1#9, c2#10] ...
+        //
+        // After change the V2ScanRelation output to [c2#10, min(c1)#21, max(c1)#22]
+        // we have the following
+        // !Aggregate [c2#10], [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18]
+        // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...
+        //
+        // We want to change it to
+        // == Optimized Logical Plan ==
+        // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18]
+        // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...
+        // scalastyle:on
+        plan.transformExpressions {
+          case agg: AggregateExpression =>
+            val ordinal = aggExprToOutputOrdinal(agg.canonicalized)
+            val aggAttribute = aggOutput(ordinal)
+            val aggFunction: aggregate.AggregateFunction =
+              agg.aggregateFunction match {
+                case max: aggregate.Max =>
+                  max.copy(child = addCastIfNeeded(aggAttribute, max.child.dataType))
+                case min: aggregate.Min =>
+                  min.copy(child = addCastIfNeeded(aggAttribute, min.child.dataType))
+                case sum: aggregate.Sum =>
+                  sum.copy(child = addCastIfNeeded(aggAttribute, sum.child.dataType))
+                case _: aggregate.Count =>
+                  aggregate.Sum(addCastIfNeeded(aggAttribute, LongType))
+                case other => other
+              }
+            agg.copy(aggregateFunction = aggFunction)
+        }
       }
     }
 case _ => aggNode
@@ -189,6 +209,13 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
       }
   }

+  private def addCastIfNeeded(aggAttribute: AttributeReference, aggDataType: DataType) =
+    if (aggAttribute.dataType == aggDataType) {
+      aggAttribute
+    } else {
+      Cast(aggAttribute, aggDataType)
+    }
+
   def applyColumnPruning(plan: LogicalPlan): LogicalPlan = plan.transform {
     case ScanOperation(project, filters, sHolder: ScanBuilderHolder) =>
       // column pruning
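
As a standalone illustration of the `addCastIfNeeded` pattern introduced above (a hedged sketch for a Spark shell; the attribute name `cnt` is hypothetical): a COUNT result returned by the source is re-aggregated with `Sum` on the Spark side, so it is cast to `LongType` when the source reports a narrower type.

```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, Expression}
import org.apache.spark.sql.types.{DataType, IntegerType, LongType}

// Same shape as the private helper above: wrap in a Cast only when types differ.
def addCastIfNeeded(attr: AttributeReference, target: DataType): Expression =
  if (attr.dataType == target) attr else Cast(attr, target)

val cnt = AttributeReference("cnt", IntegerType)() // hypothetical COUNT output column
println(addCastIfNeeded(cnt, LongType))            // cast(cnt#... as bigint)
println(addCastIfNeeded(cnt, IntegerType))         // cnt#... (no cast added)
```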

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala (+3)

@@ -72,6 +72,9 @@ case class JDBCScanBuilder(

   private var pushedGroupByCols: Option[Array[String]] = None

+  override def supportCompletePushDown: Boolean =
+    jdbcOptions.numPartitions.map(_ == 1).getOrElse(true)
+
   override def pushAggregation(aggregation: Aggregation): Boolean = {
     if (!jdbcOptions.pushDownAggregate) return false
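
The guard above ties complete push-down for JDBC to a single read partition. A hedged sketch of why, continuing the SparkSession and imports from the sketch near the top of this page (the URL, table, and bounds are hypothetical; the option names are the standard JDBC data source options): when the read is split into several partitions, each partition query aggregates only its own slice, so Spark must keep a final aggregate to merge the per-partition results.

```scala
// Hypothetical multi-partition JDBC read: each of the 4 partition queries computes
// MIN/MAX over one id-range of `orders`, so Spark still needs a final aggregate to
// merge them. Complete push-down is therefore only reported when numPartitions is 1
// (or left unset).
val partitioned = spark.read
  .format("jdbc")
  .option("url", "jdbc:postgresql://db.example.com/sales") // hypothetical URL
  .option("dbtable", "orders")                             // hypothetical table
  .option("pushDownAggregate", "true")
  .option("partitionColumn", "id")
  .option("lowerBound", "0")
  .option("upperBound", "1000000")
  .option("numPartitions", "4")
  .load()

partitioned.groupBy("region").agg(min("amount"), max("amount")).explain()
```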
