[SPARK-37789][SQL] Add a class to represent general aggregate functions in DS V2 #35070
@@ -0,0 +1,66 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.connector.expressions.aggregate;

import java.util.Arrays;
import java.util.stream.Collectors;

import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.expressions.Expression;
import org.apache.spark.sql.connector.expressions.NamedReference;

/**
 * The general implementation of {@link AggregateFunc}, which contains the upper-cased function
 * name, the `isDistinct` flag and all the inputs. Note that Spark cannot push down partial
 * aggregate with this function to the source, but can only push down the entire aggregate.
 * <p>
 * The currently supported SQL aggregate functions:
 * <ol>
 *  <li><pre>AVG(input1)</pre> Since 3.3.0</li>
 * </ol>
 *
 * @since 3.3.0
 */
@Evolving
public final class GeneralAggregateFunc implements AggregateFunc {
  private final String name;
  private final boolean isDistinct;
  private final NamedReference[] inputs;

  public String name() { return name; }
  public boolean isDistinct() { return isDistinct; }
  public NamedReference[] inputs() { return inputs; }

  public GeneralAggregateFunc(String name, boolean isDistinct, NamedReference[] inputs) {
    this.name = name;
    this.isDistinct = isDistinct;
    this.inputs = inputs;
  }

  @Override
  public String toString() {
    String inputsString = Arrays.stream(inputs)
      .map(Expression::describe)
      .collect(Collectors.joining(", "));
    if (isDistinct) {
      return name + "(DISTINCT " + inputsString + ")";
    } else {
      return name + "(" + inputsString + ")";
    }
  }
}
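For context (not part of the diff): a connector that receives a pushed GeneralAggregateFunc can rebuild SQL text from the accessors above. In the sketch below, AggregateSqlBuilder and toSql are hypothetical names; only GeneralAggregateFunc, AggregateFunc and Expression.describe come from Spark, and a real source would refuse the push-down whenever this returns None.

import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc}

// Illustrative sketch: compile a pushed aggregate function back into SQL text.
object AggregateSqlBuilder {
  def toSql(func: AggregateFunc): Option[String] = func match {
    case f: GeneralAggregateFunc if f.name == "AVG" =>
      // name() is already upper-cased; inputs() are column references.
      val cols = f.inputs.map(_.describe).mkString(", ")
      val distinct = if (f.isDistinct) "DISTINCT " else ""
      Some(s"${f.name}($distinct$cols)")
    case _ =>
      None // anything this hypothetical source does not understand
  }
}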
@@ -22,18 +22,19 @@

 /**
  * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to
- * push down aggregates. Spark assumes that the data source can't fully complete the
- * grouping work, and will group the data source output again. For queries like
- * "SELECT min(value) AS m FROM t GROUP BY key", after pushing down the aggregate
- * to the data source, the data source can still output data with duplicated keys, which is OK
- * as Spark will do GROUP BY key again. The final query plan can be something like this:
+ * push down aggregates.
+ * <p>
+ * If the data source can't fully complete the grouping work, then
+ * {@link #supportCompletePushDown()} should return false, and Spark will group the data source
+ * output again. For queries like "SELECT min(value) AS m FROM t GROUP BY key", after pushing down
+ * the aggregate to the data source, the data source can still output data with duplicated keys,
+ * which is OK as Spark will do GROUP BY key again. The final query plan can be something like this:
  * <pre>
- *   Aggregate [key#1], [min(min(value)#2) AS m#3]
- *     +- RelationV2[key#1, min(value)#2]
+ *   Aggregate [key#1], [min(min_value#2) AS m#3]
+ *     +- RelationV2[key#1, min_value#2]
[review thread] "I think it is …" / "It's actually decided by the data source, and I pick …"
  * </pre>
  * Similarly, if there is no grouping expression, the data source can still output more than one
  * rows.
- *
  * <p>
  * When pushing down operators, Spark pushes down filters to the data source first, then push down
  * aggregates or apply column pruning. Depends on data source implementation, aggregates may or
@@ -46,8 +47,8 @@
 public interface SupportsPushDownAggregates extends ScanBuilder {

   /**
-   * Whether the datasource support complete aggregation push-down. Spark could avoid partial-agg
-   * and final-agg when the aggregation operation can be pushed down to the datasource completely.
+   * Whether the datasource support complete aggregation push-down. Spark will do grouping again
+   * if this method returns false.
    *
    * @return true if the aggregation can be pushed down to datasource completely, false otherwise.
    */
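As background (not part of the patch), a minimal ScanBuilder that opts into aggregate push-down might look like the sketch below. MyScanBuilder and its fields are hypothetical; SupportsPushDownAggregates, Aggregation and GeneralAggregateFunc are the real Spark interfaces. Note that since Spark cannot partially push a GeneralAggregateFunc, accepting AVG only has an effect when the source also reports complete push-down.

import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, GeneralAggregateFunc}
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates}

// Hypothetical builder that accepts MIN/MAX/SUM/COUNT plus AVG.
class MyScanBuilder extends ScanBuilder with SupportsPushDownAggregates {
  private var pushed: Option[Aggregation] = None

  override def pushAggregation(aggregation: Aggregation): Boolean = {
    // Accept the push-down only if this source understands every aggregate function.
    val supported = aggregation.aggregateExpressions.forall {
      case f: GeneralAggregateFunc => f.name == "AVG"
      case _ => true // MIN/MAX/SUM/COUNT/COUNT(*)
    }
    if (supported) pushed = Some(aggregation)
    supported
  }

  // This hypothetical source can run the whole GROUP BY itself, so Spark may skip
  // its own partial and final aggregation; returning false would instead keep
  // Spark's grouping on top of the scan output, as described in the javadoc above.
  override def supportCompletePushDown(): Boolean = true

  override def build(): Scan = ??? // build the scan using `pushed` if defined
}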
@@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
 import org.apache.spark.sql.connector.catalog.SupportsRead
 import org.apache.spark.sql.connector.catalog.TableCapability._
 import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortOrder => SortOrderV2, SortValue}
-import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, Max, Min, Sum}
+import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum}
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan}
 import org.apache.spark.sql.execution.command._
@@ -717,8 +717,10 @@ object DataSourceStrategy
             Some(new Count(FieldReference(name), agg.isDistinct))
           case _ => None
         }
-      case sum @ aggregate.Sum(PushableColumnWithoutNestedColumn(name), _) =>
+      case aggregate.Sum(PushableColumnWithoutNestedColumn(name), _) =>
         Some(new Sum(FieldReference(name), agg.isDistinct))
+      case aggregate.Average(PushableColumnWithoutNestedColumn(name), _) =>
[review thread] "Do we need add …" / "where is …" / "The value of …" / "It's not available here." / "We can pass it."
+        Some(new GeneralAggregateFunc("AVG", agg.isDistinct, Array(FieldReference(name))))
       case _ => None
     }
   } else {
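To make the new branch concrete (illustration only, not part of the patch): for AVG(price) and AVG(DISTINCT price) the translation builds a GeneralAggregateFunc whose string form matches the SQL call, per the constructor and toString shown in the new class above. FieldReference mirrors the translation code and is Spark-internal; connector authors would receive the references via inputs().

import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc

// Roughly what the new case produces for AVG(price) / AVG(DISTINCT price).
val avg = new GeneralAggregateFunc("AVG", false, Array(FieldReference("price")))
println(avg) // AVG(price)

val avgDistinct = new GeneralAggregateFunc("AVG", true, Array(FieldReference("price")))
println(avgDistinct) // AVG(DISTINCT price)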
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.planning.ScanOperation
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LocalLimit, LogicalPlan, Project, Sample, Sort}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.connector.expressions.SortOrder
-import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
+import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, GeneralAggregateFunc}
 import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan}
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
 import org.apache.spark.sql.sources
@@ -109,6 +109,9 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
             r, normalizedAggregates, normalizedGroupingExpressions)
           if (pushedAggregates.isEmpty) {
             aggNode // return original plan node
+          } else if (!supportPartialAggPushDown(pushedAggregates.get) &&
+              !r.supportCompletePushDown()) {
+            aggNode // return original plan node
           } else {
             // No need to do column pruning because only the aggregate columns are used as
             // DataSourceV2ScanRelation output columns. All the other columns are not
@@ -145,9 +148,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
              """.stripMargin)

           val wrappedScan = getWrappedScan(scan, sHolder, pushedAggregates)
-
           val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output)
-
           if (r.supportCompletePushDown()) {
             val projectExpressions = resultExpressions.map { expr =>
               // TODO At present, only push down group by attribute is supported.
@@ -209,6 +210,11 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
       }
     }

+  private def supportPartialAggPushDown(agg: Aggregation): Boolean = {
+    // We don't know the agg buffer of `GeneralAggregateFunc`, so can't do partial agg push down.
+    agg.aggregateExpressions().forall(!_.isInstanceOf[GeneralAggregateFunc])
+  }
+
   private def addCastIfNeeded(aggAttribute: AttributeReference, aggDataType: DataType) =
     if (aggAttribute.dataType == aggDataType) {
       aggAttribute
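For intuition on why supportPartialAggPushDown excludes GeneralAggregateFunc (a sketch, not Spark code): partial aggregation is only safe when Spark knows the merge buffer. MIN merges by taking the MIN of partial MINs, but AVG needs the partial SUM and COUNT; averaging partial averages gives the wrong answer, and a GeneralAggregateFunc does not tell Spark which buffer to use.

// Standalone Scala illustration with made-up data.
val partitions = Seq(Seq(1.0, 2.0), Seq(10.0))

// MIN: min of partial mins equals the global min.
val globalMin = partitions.map(_.min).min // 1.0

// AVG, done naively: avg of partial avgs is wrong.
val naiveAvg = partitions.map(p => p.sum / p.size).sum / partitions.size // 5.75

// AVG, done with a (sum, count) buffer: correct.
val (s, c) = partitions.map(p => (p.sum, p.size)).reduce((a, b) => (a._1 + b._1, a._2 + b._2))
val globalAvg = s / c // 13.0 / 3 ≈ 4.33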
@@ -256,7 +262,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {

   def pushDownSample(plan: LogicalPlan): LogicalPlan = plan.transform {
     case sample: Sample => sample.child match {
-      case ScanOperation(_, filter, sHolder: ScanBuilderHolder) if filter.length == 0 =>
+      case ScanOperation(_, filter, sHolder: ScanBuilderHolder) if filter.isEmpty =>
         val tableSample = TableSampleInfo(
           sample.lowerBound,
           sample.upperBound,
@@ -282,7 +288,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
         }
         operation
       case s @ Sort(order, _, operation @ ScanOperation(_, filter, sHolder: ScanBuilderHolder))
-        if filter.isEmpty =>
+          if filter.isEmpty =>
[review thread] "I feel there should be two indent." / "it should be 4 spaces. You can check it out in other code files." / "Ref: spark/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala Line 493 and Line 561 in fe73039. Maybe we could unify the code style."
         val orders = DataSourceStrategy.translateSortOrders(order)
         if (orders.length == order.length) {
           val topNPushed = PushDownUtils.pushTopN(sHolder.builder, orders.toArray, limit)
[review comment] I realized that almost all the implementations override both toString and describe, so making this change to make it simpler.