
Commit c6d90e8

cloud-fan authored and chenzhx committed
[SPARK-37789][SQL] Add a class to represent general aggregate functions in DS V2
### What changes were proposed in this pull request?

There are a lot of aggregate functions in SQL and it's a lot of work to add them one by one in the DS v2 API. This PR proposes to add a new `GeneralAggregateFunc` class to represent all the general SQL aggregate functions. Since it's general, Spark doesn't know its aggregation buffer and can only push down the aggregation to the source completely.

As an example, this PR also translates `AVG` to `GeneralAggregateFunc` and pushes it to JDBC V2.

### Why are the changes needed?

To make it easier to add aggregate functions in DS v2.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

JDBC v2 test

Closes apache#35070 from cloud-fan/agg.

Lead-authored-by: Wenchen Fan <[email protected]>
Co-authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent aad72ad commit c6d90e8
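
For context, a hedged end-to-end sketch of how the change surfaces through the JDBC v2 catalog (assuming a running SparkSession `spark`; the catalog name, H2 URL, and `test.employee` table are made up, while the option names mirror the JDBC v2 test setup): with `pushDownAggregate` enabled, an `AVG` can now be handed to the remote database as a whole instead of being split into partial and final aggregation in Spark.

```scala
// Register an in-memory H2 database as a DS v2 catalog named "h2" (illustrative values).
spark.conf.set("spark.sql.catalog.h2",
  "org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog")
spark.conf.set("spark.sql.catalog.h2.url", "jdbc:h2:mem:testdb")
spark.conf.set("spark.sql.catalog.h2.driver", "org.h2.Driver")
spark.conf.set("spark.sql.catalog.h2.pushDownAggregate", "true")

// AVG is translated to GeneralAggregateFunc("AVG", ...); if the source reports complete
// push-down and its dialect can compile the function, the whole GROUP BY runs in the database.
spark.sql("SELECT dept, AVG(salary) FROM h2.test.employee GROUP BY dept").show()
```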

File tree

17 files changed, +137 -67 lines changed


sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Expression.java

+1 -1

@@ -29,5 +29,5 @@ public interface Expression {
   /**
    * Format the expression as a human readable SQL-like string.
    */
-  String describe();
+  default String describe() { return this.toString(); }
 }

sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Count.java

-3

@@ -46,7 +46,4 @@ public String toString() {
       return "COUNT(" + column.describe() + ")";
     }
   }
-
-  @Override
-  public String describe() { return this.toString(); }
 }

sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/CountStar.java

-3

@@ -32,7 +32,4 @@ public CountStar() {
 
   @Override
   public String toString() { return "COUNT(*)"; }
-
-  @Override
-  public String describe() { return this.toString(); }
 }
sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java

+66 -0 (new file)

@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.expressions.aggregate;
+
+import java.util.Arrays;
+import java.util.stream.Collectors;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.expressions.Expression;
+import org.apache.spark.sql.connector.expressions.NamedReference;
+
+/**
+ * The general implementation of {@link AggregateFunc}, which contains the upper-cased function
+ * name, the `isDistinct` flag and all the inputs. Note that Spark cannot push down partial
+ * aggregate with this function to the source, but can only push down the entire aggregate.
+ * <p>
+ * The currently supported SQL aggregate functions:
+ * <ol>
+ *  <li><pre>AVG(input1)</pre> Since 3.3.0</li>
+ * </ol>
+ *
+ * @since 3.3.0
+ */
+@Evolving
+public final class GeneralAggregateFunc implements AggregateFunc {
+  private final String name;
+  private final boolean isDistinct;
+  private final NamedReference[] inputs;
+
+  public String name() { return name; }
+  public boolean isDistinct() { return isDistinct; }
+  public NamedReference[] inputs() { return inputs; }
+
+  public GeneralAggregateFunc(String name, boolean isDistinct, NamedReference[] inputs) {
+    this.name = name;
+    this.isDistinct = isDistinct;
+    this.inputs = inputs;
+  }
+
+  @Override
+  public String toString() {
+    String inputsString = Arrays.stream(inputs)
+      .map(Expression::describe)
+      .collect(Collectors.joining(", "));
+    if (isDistinct) {
+      return name + "(DISTINCT " + inputsString + ")";
+    } else {
+      return name + "(" + inputsString + ")";
+    }
+  }
+}
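
A minimal usage sketch of the new class (not part of the patch): it builds a `GeneralAggregateFunc` against an anonymous `NamedReference`, the same trick `TransformExtractorSuite` uses, and prints the SQL-like form produced by `toString()` and the new default `describe()`.

```scala
import org.apache.spark.sql.connector.expressions.NamedReference
import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc

// Anonymous NamedReference, mirroring the ref(...) helper in TransformExtractorSuite.
val salary: NamedReference = new NamedReference {
  override def fieldNames: Array[String] = Array("salary")
  override def toString: String = "salary"
}

val avg = new GeneralAggregateFunc("AVG", true, Array(salary))
// describe() now defaults to toString(), so this should print: AVG(DISTINCT salary)
println(avg.describe())
```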

sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Max.java

-3

@@ -35,7 +35,4 @@ public final class Max implements AggregateFunc {
 
   @Override
   public String toString() { return "MAX(" + column.describe() + ")"; }
-
-  @Override
-  public String describe() { return this.toString(); }
 }

sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Min.java

-3

@@ -35,7 +35,4 @@ public final class Min implements AggregateFunc {
 
   @Override
   public String toString() { return "MIN(" + column.describe() + ")"; }
-
-  @Override
-  public String describe() { return this.toString(); }
 }

sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Sum.java

-3

@@ -46,7 +46,4 @@ public String toString() {
       return "SUM(" + column.describe() + ")";
     }
   }
-
-  @Override
-  public String describe() { return this.toString(); }
 }

sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Filter.java

-3

@@ -37,7 +37,4 @@ public abstract class Filter implements Expression, Serializable {
    * Returns list of columns that are referenced by this filter.
    */
   public abstract NamedReference[] references();
-
-  @Override
-  public String describe() { return this.toString(); }
 }

sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java

+11 -10

@@ -22,18 +22,19 @@
 
 /**
  * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to
- * push down aggregates. Spark assumes that the data source can't fully complete the
- * grouping work, and will group the data source output again. For queries like
- * "SELECT min(value) AS m FROM t GROUP BY key", after pushing down the aggregate
- * to the data source, the data source can still output data with duplicated keys, which is OK
- * as Spark will do GROUP BY key again. The final query plan can be something like this:
+ * push down aggregates.
+ * <p>
+ * If the data source can't fully complete the grouping work, then
+ * {@link #supportCompletePushDown()} should return false, and Spark will group the data source
+ * output again. For queries like "SELECT min(value) AS m FROM t GROUP BY key", after pushing down
+ * the aggregate to the data source, the data source can still output data with duplicated keys,
+ * which is OK as Spark will do GROUP BY key again. The final query plan can be something like this:
  * <pre>
- *   Aggregate [key#1], [min(min(value)#2) AS m#3]
- *   +- RelationV2[key#1, min(value)#2]
+ *   Aggregate [key#1], [min(min_value#2) AS m#3]
+ *   +- RelationV2[key#1, min_value#2]
  * </pre>
  * Similarly, if there is no grouping expression, the data source can still output more than one
  * rows.
- *
  * <p>
  * When pushing down operators, Spark pushes down filters to the data source first, then push down
  * aggregates or apply column pruning. Depends on data source implementation, aggregates may or
@@ -46,8 +47,8 @@
 public interface SupportsPushDownAggregates extends ScanBuilder {
 
   /**
-   * Whether the datasource support complete aggregation push-down. Spark could avoid partial-agg
-   * and final-agg when the aggregation operation can be pushed down to the datasource completely.
+   * Whether the datasource support complete aggregation push-down. Spark will do grouping again
+   * if this method returns false.
    *
   * @return true if the aggregation can be pushed down to datasource completely, false otherwise.
   */
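
The reworded contract is easiest to see from the implementor's side. A hedged sketch of a connector `ScanBuilder` (all names hypothetical, `build()` left as a stub): it accepts the push-down only when every function is understood and then reports complete push-down, so Spark skips its own partial/final aggregation; a source that cannot finish the grouping would instead return false from `supportCompletePushDown()` and Spark would group its output again.

```scala
import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, GeneralAggregateFunc, Max, Min}
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates}

// Hypothetical connector ScanBuilder; only MIN, MAX and the general AVG are understood here.
class ExampleScanBuilder extends ScanBuilder with SupportsPushDownAggregates {
  private var pushed: Option[Aggregation] = None

  override def pushAggregation(aggregation: Aggregation): Boolean = {
    val supported = aggregation.aggregateExpressions.forall {
      case _: Min | _: Max => true
      case g: GeneralAggregateFunc => g.name == "AVG"  // unknown general functions are rejected
      case _ => false
    }
    if (supported) pushed = Some(aggregation)
    supported
  }

  // This source evaluates the accepted aggregate fully, so Spark needs no partial/final agg.
  override def supportCompletePushDown(): Boolean = pushed.isDefined

  override def build(): Scan = ???  // scan construction is out of scope for this sketch
}
```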

sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala

+6 -14

@@ -88,9 +88,7 @@ private[sql] abstract class SingleColumnTransform(ref: NamedReference) extends R
 
   override def arguments: Array[Expression] = Array(ref)
 
-  override def describe: String = name + "(" + reference.describe + ")"
-
-  override def toString: String = describe
+  override def toString: String = name + "(" + reference.describe + ")"
 
   protected def withNewRef(ref: NamedReference): Transform
 
@@ -114,16 +112,14 @@ private[sql] final case class BucketTransform(
 
   override def arguments: Array[Expression] = numBuckets +: columns.toArray
 
-  override def describe: String =
+  override def toString: String =
     if (sortedColumns.nonEmpty) {
       s"bucket(${arguments.map(_.describe).mkString(", ")}," +
         s" ${sortedColumns.map(_.describe).mkString(", ")})"
     } else {
       s"bucket(${arguments.map(_.describe).mkString(", ")})"
     }
 
-  override def toString: String = describe
-
   override def withReferences(newReferences: Seq[NamedReference]): Transform = {
     this.copy(columns = newReferences)
   }
@@ -169,9 +165,7 @@ private[sql] final case class ApplyTransform(
     arguments.collect { case named: NamedReference => named }
   }
 
-  override def describe: String = s"$name(${arguments.map(_.describe).mkString(", ")})"
-
-  override def toString: String = describe
+  override def toString: String = s"$name(${arguments.map(_.describe).mkString(", ")})"
 }
 
 /**
@@ -338,21 +332,19 @@ private[sql] object HoursTransform {
 }
 
 private[sql] final case class LiteralValue[T](value: T, dataType: DataType) extends Literal[T] {
-  override def describe: String = {
+  override def toString: String = {
     if (dataType.isInstanceOf[StringType]) {
       s"'$value'"
     } else {
       s"$value"
     }
   }
-  override def toString: String = describe
 }
 
 private[sql] final case class FieldReference(parts: Seq[String]) extends NamedReference {
   import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
   override def fieldNames: Array[String] = parts.toArray
-  override def describe: String = parts.quoted
-  override def toString: String = describe
+  override def toString: String = parts.quoted
 }
 
 private[sql] object FieldReference {
@@ -366,7 +358,7 @@ private[sql] final case class SortValue(
     direction: SortDirection,
     nullOrdering: NullOrdering) extends SortOrder {
 
-  override def describe(): String = s"$expression $direction $nullOrdering"
+  override def toString(): String = s"$expression $direction $nullOrdering"
 }
 
 private[sql] object SortValue {

sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala

+4 -4

@@ -28,15 +28,15 @@ class TransformExtractorSuite extends SparkFunSuite {
   private def lit[T](literal: T): Literal[T] = new Literal[T] {
     override def value: T = literal
     override def dataType: DataType = catalyst.expressions.Literal(literal).dataType
-    override def describe: String = literal.toString
+    override def toString: String = literal.toString
   }
 
   /**
    * Creates a NamedReference using an anonymous class.
    */
   private def ref(names: String*): NamedReference = new NamedReference {
     override def fieldNames: Array[String] = names.toArray
-    override def describe: String = names.mkString(".")
+    override def toString: String = names.mkString(".")
   }
 
   /**
@@ -46,7 +46,7 @@ class TransformExtractorSuite extends SparkFunSuite {
     override def name: String = func
     override def references: Array[NamedReference] = Array(ref)
     override def arguments: Array[Expression] = Array(ref)
-    override def describe: String = ref.describe
+    override def toString: String = ref.describe
   }
 
   test("Identity extractor") {
@@ -135,7 +135,7 @@ class TransformExtractorSuite extends SparkFunSuite {
       override def name: String = "bucket"
       override def references: Array[NamedReference] = Array(col)
       override def arguments: Array[Expression] = Array(lit(16), col)
-      override def describe: String = s"bucket(16, ${col.describe})"
+      override def toString: String = s"bucket(16, ${col.describe})"
     }
 
     bucketTransform match {

sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala

+2 -2

@@ -156,8 +156,8 @@ case class RowDataSourceScanExec(
       "ReadSchema" -> requiredSchema.catalogString,
       "PushedFilters" -> seqToString(markedFilters.toSeq)) ++
       pushedDownOperators.aggregation.fold(Map[String, String]()) { v =>
-        Map("PushedAggregates" -> seqToString(v.aggregateExpressions),
-          "PushedGroupByColumns" -> seqToString(v.groupByColumns))} ++
+        Map("PushedAggregates" -> seqToString(v.aggregateExpressions.map(_.describe())),
+          "PushedGroupByColumns" -> seqToString(v.groupByColumns.map(_.describe())))} ++
       topNOrLimitInfo ++
       pushedDownOperators.sample.map(v => "PushedSample" ->
         s"SAMPLE (${(v.upperBound - v.lowerBound) * 100}) ${v.withReplacement} SEED(${v.seed})"

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala

+4 -2

@@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
 import org.apache.spark.sql.connector.catalog.SupportsRead
 import org.apache.spark.sql.connector.catalog.TableCapability._
 import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortOrder => SortOrderV2, SortValue}
-import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, Max, Min, Sum}
+import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum}
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan}
 import org.apache.spark.sql.execution.command._
@@ -714,8 +714,10 @@ object DataSourceStrategy
             Some(new Count(FieldReference(name), aggregates.isDistinct))
           case _ => None
         }
-      case sum @ aggregate.Sum(PushableColumnWithoutNestedColumn(name), _) =>
+      case aggregate.Sum(PushableColumnWithoutNestedColumn(name), _) =>
         Some(new Sum(FieldReference(name), aggregates.isDistinct))
+      case aggregate.Average(PushableColumnWithoutNestedColumn(name), _) =>
+        Some(new GeneralAggregateFunc("AVG", aggregates.isDistinct, Array(FieldReference(name))))
       case _ => None
     }
   } else {

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala

+11 -5

@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.planning.ScanOperation
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LocalLimit, LogicalPlan, Project, Sample, Sort}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.connector.expressions.SortOrder
-import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
+import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, GeneralAggregateFunc}
 import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan}
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
 import org.apache.spark.sql.sources
@@ -109,6 +109,9 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
             sHolder.builder, normalizedAggregates, normalizedGroupingExpressions)
           if (pushedAggregates.isEmpty) {
             aggNode // return original plan node
+          } else if (!supportPartialAggPushDown(pushedAggregates.get) &&
+              !r.supportCompletePushDown()) {
+            aggNode // return original plan node
           } else {
             // No need to do column pruning because only the aggregate columns are used as
             // DataSourceV2ScanRelation output columns. All the other columns are not
@@ -145,9 +148,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
              """.stripMargin)
 
            val wrappedScan = getWrappedScan(scan, sHolder, pushedAggregates)
-
            val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output)
-
            if (r.supportCompletePushDown()) {
              val projectExpressions = resultExpressions.map { expr =>
                // TODO At present, only push down group by attribute is supported.
@@ -209,6 +210,11 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
      }
    }
 
+  private def supportPartialAggPushDown(agg: Aggregation): Boolean = {
+    // We don't know the agg buffer of `GeneralAggregateFunc`, so can't do partial agg push down.
+    agg.aggregateExpressions().forall(!_.isInstanceOf[GeneralAggregateFunc])
+  }
+
  private def addCastIfNeeded(aggAttribute: AttributeReference, aggDataType: DataType) =
    if (aggAttribute.dataType == aggDataType) {
      aggAttribute
@@ -256,7 +262,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
 
  def pushDownSample(plan: LogicalPlan): LogicalPlan = plan.transform {
    case sample: Sample => sample.child match {
-      case ScanOperation(_, filter, sHolder: ScanBuilderHolder) if filter.length == 0 =>
+      case ScanOperation(_, filter, sHolder: ScanBuilderHolder) if filter.isEmpty =>
        val tableSample = TableSampleInfo(
          sample.lowerBound,
          sample.upperBound,
@@ -282,7 +288,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
        }
        operation
      case s @ Sort(order, _, operation @ ScanOperation(_, filter, sHolder: ScanBuilderHolder))
-        if filter.isEmpty =>
+          if filter.isEmpty =>
        val orders = DataSourceStrategy.translateSortOrders(order)
        if (orders.length == order.length) {
          val topNPushed = PushDownUtils.pushTopN(sHolder.builder, orders.toArray, limit)
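
The new `supportPartialAggPushDown` guard is the crux of the PR: Spark does not know the aggregation buffer of a `GeneralAggregateFunc`, so it can only hand the aggregate to the source whole. A small illustration of the predicate with hand-built values (column names made up; `FieldReference` is private[sql], so this assumes it runs inside Spark's own sql packages, as `DataSourceStrategy` does):

```scala
import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, GeneralAggregateFunc, Min}

val dept = FieldReference("dept")
val salary = FieldReference("salary")

val minOnly = new Aggregation(Array[AggregateFunc](new Min(salary)), Array(dept))
val withAvg = new Aggregation(
  Array[AggregateFunc](new GeneralAggregateFunc("AVG", false, Array(salary))), Array(dept))

// Same check as supportPartialAggPushDown: MIN alone still allows partial push-down,
// but once AVG (a GeneralAggregateFunc) appears the aggregate is pushed completely or not at all.
minOnly.aggregateExpressions.forall(!_.isInstanceOf[GeneralAggregateFunc])  // true
withAvg.aggregateExpressions.forall(!_.isInstanceOf[GeneralAggregateFunc])  // false
```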

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala

+1 -1

@@ -79,7 +79,7 @@ case class JDBCScanBuilder(
     if (!jdbcOptions.pushDownAggregate) return false
 
     val dialect = JdbcDialects.get(jdbcOptions.url)
-    val compiledAggs = aggregation.aggregateExpressions.flatMap(dialect.compileAggregate(_))
+    val compiledAggs = aggregation.aggregateExpressions.flatMap(dialect.compileAggregate)
     if (compiledAggs.length != aggregation.aggregateExpressions.length) return false
 
     val groupByCols = aggregation.groupByColumns.map { col =>

sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala

+6 -2

@@ -30,7 +30,7 @@ import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter}
 import org.apache.spark.sql.connector.catalog.TableChange
 import org.apache.spark.sql.connector.catalog.TableChange._
-import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, Max, Min, Sum}
+import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum}
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
 import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
@@ -216,7 +216,11 @@ abstract class JdbcDialect extends Serializable with Logging{
         val column = quoteIdentifier(sum.column.fieldNames.head)
         Some(s"SUM($distinct$column)")
       case _: CountStar =>
-        Some(s"COUNT(*)")
+        Some("COUNT(*)")
+      case f: GeneralAggregateFunc if f.name() == "AVG" =>
+        assert(f.inputs().length == 1)
+        val distinct = if (f.isDistinct) "DISTINCT " else ""
+        Some(s"AVG($distinct${f.inputs().head})")
       case _ => None
     }
   }
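
A minimal sketch of asking a dialect to compile the new representation (assumed usage, not part of the patch; the URL is illustrative, and `FieldReference` is private[sql], so this is written as if from inside Spark's sql packages):

```scala
import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc
import org.apache.spark.sql.jdbc.JdbcDialects

val dialect = JdbcDialects.get("jdbc:h2:mem:testdb")  // resolves to the H2 dialect
val avg = new GeneralAggregateFunc("AVG", true, Array(FieldReference("SALARY")))

// The AVG branch above interpolates the input reference directly, so this should yield
// something like Some("AVG(DISTINCT SALARY)"). Functions the dialect does not understand
// fall through to None, which makes JDBCScanBuilder reject the whole aggregate push-down.
println(dialect.compileAggregate(avg))
```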
