[SPARK-37789][SQL] Add a class to represent general aggregate functions in DS V2 #35070

Closed · wants to merge 5 commits · Changes from 1 commit
@@ -29,5 +29,5 @@ public interface Expression {
/**
* Format the expression as a human readable SQL-like string.
*/
String describe();
default String describe() { return this.toString(); }
Comment (Contributor Author): I realized that almost all the implementations override both toString and describe, so I made this change to simplify things.

}
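To illustrate the new default (a hedged sketch; the anonymous class below is hypothetical and mirrors the test helpers updated later in this diff), an implementation now only needs to override toString, and describe() falls back to it:

import org.apache.spark.sql.connector.expressions.NamedReference

// Hypothetical reference; no describe() override is needed any more.
val ref = new NamedReference {
  override def fieldNames: Array[String] = Array("dept", "id")
  override def toString: String = fieldNames.mkString(".")
}
ref.describe()  // "dept.id", via the new default that delegates to toString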
@@ -46,7 +46,4 @@ public String toString() {
return "COUNT(" + column.describe() + ")";
}
}

@Override
public String describe() { return this.toString(); }
}
@@ -32,7 +32,4 @@ public CountStar() {

@Override
public String toString() { return "COUNT(*)"; }

@Override
public String describe() { return this.toString(); }
}
@@ -0,0 +1,66 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connector.expressions.aggregate;

import java.util.Arrays;
import java.util.stream.Collectors;

import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.expressions.Expression;
import org.apache.spark.sql.connector.expressions.NamedReference;

/**
* The general implementation of {@link AggregateFunc}, which contains the upper-cased function
* name, the `isDistinct` flag and all the inputs. Note that Spark cannot push down partial
* aggregate with this function to the source, but can only push down the entire aggregate.
* <p>
* The currently supported SQL aggregate functions:
* <ol>
* <li><pre>AVG(input1)</pre> Since 3.3.0</li>
* </ol>
*
* @since 3.3.0
*/
@Evolving
public final class GeneralAggregateFunc implements AggregateFunc {
private final String name;
private final boolean isDistinct;
private final NamedReference[] inputs;

public String name() { return name; }
public boolean isDistinct() { return isDistinct; }
public NamedReference[] inputs() { return inputs; }

public GeneralAggregateFunc(String name, boolean isDistinct, NamedReference[] inputs) {
this.name = name;
this.isDistinct = isDistinct;
this.inputs = inputs;
}

@Override
public String toString() {
String inputsString = Arrays.stream(inputs)
.map(Expression::describe)
.collect(Collectors.joining(", "));
if (isDistinct) {
return name + "(DISTINCT " + inputsString + ")";
} else {
return name + "(" + inputsString + ")";
}
}
}
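A quick usage sketch of the new class (hedged: FieldReference is private[sql], so this assumes Spark-internal code, and the column name is made up):

import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc

val avg = new GeneralAggregateFunc("AVG", true, Array(FieldReference("price")))
avg.name()        // "AVG"
avg.isDistinct()  // true
avg.toString      // "AVG(DISTINCT price)", each input rendered via describe()/toString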
@@ -35,7 +35,4 @@ public final class Max implements AggregateFunc {

@Override
public String toString() { return "MAX(" + column.describe() + ")"; }

@Override
public String describe() { return this.toString(); }
}
@@ -35,7 +35,4 @@ public final class Min implements AggregateFunc {

@Override
public String toString() { return "MIN(" + column.describe() + ")"; }

@Override
public String describe() { return this.toString(); }
}
@@ -46,7 +46,4 @@ public String toString() {
return "SUM(" + column.describe() + ")";
}
}

@Override
public String describe() { return this.toString(); }
}
@@ -37,7 +37,4 @@ public abstract class Filter implements Expression, Serializable {
* Returns list of columns that are referenced by this filter.
*/
public abstract NamedReference[] references();

@Override
public String describe() { return this.toString(); }
}
@@ -22,18 +22,19 @@

/**
* A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to
* push down aggregates. Spark assumes that the data source can't fully complete the
* grouping work, and will group the data source output again. For queries like
* "SELECT min(value) AS m FROM t GROUP BY key", after pushing down the aggregate
* to the data source, the data source can still output data with duplicated keys, which is OK
* as Spark will do GROUP BY key again. The final query plan can be something like this:
* push down aggregates.
* <p>
* If the data source can't fully complete the grouping work, then
* {@link #supportCompletePushDown()} should return false, and Spark will group the data source
* output again. For queries like "SELECT min(value) AS m FROM t GROUP BY key", after pushing down
* the aggregate to the data source, the data source can still output data with duplicated keys,
* which is OK as Spark will do GROUP BY key again. The final query plan can be something like this:
* <pre>
* Aggregate [key#1], [min(min(value)#2) AS m#3]
* +- RelationV2[key#1, min(value)#2]
* Aggregate [key#1], [min(min_value#2) AS m#3]
* +- RelationV2[key#1, min_value#2]
Comment (Contributor): I think it is min(value)#2?
Comment (Contributor Author): It's actually decided by the data source; I picked min_value to make it more readable.
* </pre>
* Similarly, if there is no grouping expression, the data source can still output more than one
* row.
*
* <p>
* When pushing down operators, Spark pushes down filters to the data source first, then pushes down
* aggregates or applies column pruning. Depending on the data source implementation, aggregates may or
@@ -46,8 +47,8 @@
public interface SupportsPushDownAggregates extends ScanBuilder {

/**
* Whether the datasource support complete aggregation push-down. Spark could avoid partial-agg
* and final-agg when the aggregation operation can be pushed down to the datasource completely.
* Whether the datasource supports complete aggregation push-down. Spark will do the grouping again
* if this method returns false.
*
* @return true if the aggregation can be pushed down to datasource completely, false otherwise.
*/
@@ -88,9 +88,7 @@ private[sql] abstract class SingleColumnTransform(ref: NamedReference) extends R

override def arguments: Array[Expression] = Array(ref)

override def describe: String = name + "(" + reference.describe + ")"

override def toString: String = describe
override def toString: String = name + "(" + reference.describe + ")"

protected def withNewRef(ref: NamedReference): Transform

@@ -114,16 +112,14 @@ private[sql] final case class BucketTransform(

override def arguments: Array[Expression] = numBuckets +: columns.toArray

override def describe: String =
override def toString: String =
if (sortedColumns.nonEmpty) {
s"bucket(${arguments.map(_.describe).mkString(", ")}," +
s" ${sortedColumns.map(_.describe).mkString(", ")})"
} else {
s"bucket(${arguments.map(_.describe).mkString(", ")})"
}

override def toString: String = describe

override def withReferences(newReferences: Seq[NamedReference]): Transform = {
this.copy(columns = newReferences)
}
@@ -169,9 +165,7 @@ private[sql] final case class ApplyTransform(
arguments.collect { case named: NamedReference => named }
}

override def describe: String = s"$name(${arguments.map(_.describe).mkString(", ")})"

override def toString: String = describe
override def toString: String = s"$name(${arguments.map(_.describe).mkString(", ")})"
}

/**
@@ -338,21 +332,19 @@ private[sql] object HoursTransform {
}

private[sql] final case class LiteralValue[T](value: T, dataType: DataType) extends Literal[T] {
override def describe: String = {
override def toString: String = {
if (dataType.isInstanceOf[StringType]) {
s"'$value'"
} else {
s"$value"
}
}
override def toString: String = describe
}

private[sql] final case class FieldReference(parts: Seq[String]) extends NamedReference {
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
override def fieldNames: Array[String] = parts.toArray
override def describe: String = parts.quoted
override def toString: String = describe
override def toString: String = parts.quoted
}

private[sql] object FieldReference {
@@ -366,7 +358,7 @@ private[sql] final case class SortValue(
direction: SortDirection,
nullOrdering: NullOrdering) extends SortOrder {

override def describe(): String = s"$expression $direction $nullOrdering"
override def toString(): String = s"$expression $direction $nullOrdering"
}

private[sql] object SortValue {
@@ -28,15 +28,15 @@ class TransformExtractorSuite extends SparkFunSuite {
private def lit[T](literal: T): Literal[T] = new Literal[T] {
override def value: T = literal
override def dataType: DataType = catalyst.expressions.Literal(literal).dataType
override def describe: String = literal.toString
override def toString: String = literal.toString
}

/**
* Creates a NamedReference using an anonymous class.
*/
private def ref(names: String*): NamedReference = new NamedReference {
override def fieldNames: Array[String] = names.toArray
override def describe: String = names.mkString(".")
override def toString: String = names.mkString(".")
}

/**
@@ -46,7 +46,7 @@ class TransformExtractorSuite extends SparkFunSuite {
override def name: String = func
override def references: Array[NamedReference] = Array(ref)
override def arguments: Array[Expression] = Array(ref)
override def describe: String = ref.describe
override def toString: String = ref.describe
}

test("Identity extractor") {
@@ -135,7 +135,7 @@ class TransformExtractorSuite extends SparkFunSuite {
override def name: String = "bucket"
override def references: Array[NamedReference] = Array(col)
override def arguments: Array[Expression] = Array(lit(16), col)
override def describe: String = s"bucket(16, ${col.describe})"
override def toString: String = s"bucket(16, ${col.describe})"
}

bucketTransform match {
@@ -157,8 +157,8 @@ case class RowDataSourceScanExec(
"ReadSchema" -> requiredSchema.catalogString,
"PushedFilters" -> seqToString(markedFilters.toSeq)) ++
pushedDownOperators.aggregation.fold(Map[String, String]()) { v =>
Map("PushedAggregates" -> seqToString(v.aggregateExpressions),
"PushedGroupByColumns" -> seqToString(v.groupByColumns))} ++
Map("PushedAggregates" -> seqToString(v.aggregateExpressions.map(_.describe())),
"PushedGroupByColumns" -> seqToString(v.groupByColumns.map(_.describe())))} ++
topNOrLimitInfo ++
pushedDownOperators.sample.map(v => "PushedSample" ->
s"SAMPLE (${(v.upperBound - v.lowerBound) * 100}) ${v.withReplacement} SEED(${v.seed})"
@@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
import org.apache.spark.sql.connector.catalog.SupportsRead
import org.apache.spark.sql.connector.catalog.TableCapability._
import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortOrder => SortOrderV2, SortValue}
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, Max, Min, Sum}
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan}
import org.apache.spark.sql.execution.command._
@@ -717,8 +717,10 @@ object DataSourceStrategy
Some(new Count(FieldReference(name), agg.isDistinct))
case _ => None
}
case sum @ aggregate.Sum(PushableColumnWithoutNestedColumn(name), _) =>
case aggregate.Sum(PushableColumnWithoutNestedColumn(name), _) =>
Some(new Sum(FieldReference(name), agg.isDistinct))
case aggregate.Average(PushableColumnWithoutNestedColumn(name), _) =>
Comment (Contributor): Do we need to add if completePushdown here?
Comment (Contributor Author): Where is completePushdown defined?
Comment (Contributor): The value of supportCompletePushDown().
Comment (Contributor Author): It's not available here.
Comment (Contributor): We can pass it.
Some(new GeneralAggregateFunc("AVG", agg.isDistinct, Array(FieldReference(name))))
case _ => None
}
} else {
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.planning.ScanOperation
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LocalLimit, LogicalPlan, Project, Sample, Sort}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.expressions.SortOrder
import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, GeneralAggregateFunc}
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.sources
@@ -109,6 +109,9 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
r, normalizedAggregates, normalizedGroupingExpressions)
if (pushedAggregates.isEmpty) {
aggNode // return original plan node
} else if (!supportPartialAggPushDown(pushedAggregates.get) &&
!r.supportCompletePushDown()) {
aggNode // return original plan node
} else {
// No need to do column pruning because only the aggregate columns are used as
// DataSourceV2ScanRelation output columns. All the other columns are not
@@ -145,9 +148,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
""".stripMargin)

val wrappedScan = getWrappedScan(scan, sHolder, pushedAggregates)

val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output)

if (r.supportCompletePushDown()) {
val projectExpressions = resultExpressions.map { expr =>
// TODO At present, only push down group by attribute is supported.
@@ -209,6 +210,11 @@
}
}

private def supportPartialAggPushDown(agg: Aggregation): Boolean = {
// We don't know the agg buffer of `GeneralAggregateFunc`, so can't do partial agg push down.
agg.aggregateExpressions().forall(!_.isInstanceOf[GeneralAggregateFunc])
Comment (Contributor): agg.aggregateExpressions().exists(_.isInstanceOf[GeneralAggregateFunc])

}
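For instance (a hedged sketch; FieldReference is private[sql], so this assumes Spark-internal code, and the column names are illustrative), any AVG in the pushed aggregation makes this check return false, because Spark does not know the aggregation buffer of a GeneralAggregateFunc and so cannot merge partial results:

import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, GeneralAggregateFunc}

val avg = new GeneralAggregateFunc("AVG", false, Array(FieldReference("value")))
val agg = new Aggregation(Array[AggregateFunc](avg), Array(FieldReference("key")))
// supportPartialAggPushDown(agg) returns false, so unless the scan builder also reports
// supportCompletePushDown() == true, the original Aggregate node is kept and nothing is pushed.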

private def addCastIfNeeded(aggAttribute: AttributeReference, aggDataType: DataType) =
if (aggAttribute.dataType == aggDataType) {
aggAttribute
@@ -256,7 +262,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {

def pushDownSample(plan: LogicalPlan): LogicalPlan = plan.transform {
case sample: Sample => sample.child match {
case ScanOperation(_, filter, sHolder: ScanBuilderHolder) if filter.length == 0 =>
case ScanOperation(_, filter, sHolder: ScanBuilderHolder) if filter.isEmpty =>
val tableSample = TableSampleInfo(
sample.lowerBound,
sample.upperBound,
@@ -282,7 +288,7 @@
}
operation
case s @ Sort(order, _, operation @ ScanOperation(_, filter, sHolder: ScanBuilderHolder))
if filter.isEmpty =>
if filter.isEmpty =>
Comment (Contributor): I feel there should be two indents.
Comment (Contributor Author): It should be 4 spaces. You can check it out in other code files.
Comment (Contributor): Ref:
  if aggExpressions.forall(expr => expr.isInstanceOf[AggregateExpression]) =>
  if aggExpressions.forall(expr => expr.isInstanceOf[PythonUDF]) =>
Maybe we could unify the code style.

val orders = DataSourceStrategy.translateSortOrders(order)
if (orders.length == order.length) {
val topNPushed = PushDownUtils.pushTopN(sHolder.builder, orders.toArray, limit)
@@ -79,7 +79,7 @@ case class JDBCScanBuilder(
if (!jdbcOptions.pushDownAggregate) return false

val dialect = JdbcDialects.get(jdbcOptions.url)
val compiledAggs = aggregation.aggregateExpressions.flatMap(dialect.compileAggregate(_))
val compiledAggs = aggregation.aggregateExpressions.flatMap(dialect.compileAggregate)
if (compiledAggs.length != aggregation.aggregateExpressions.length) return false

val groupByCols = aggregation.groupByColumns.map { col =>
@@ -33,7 +33,7 @@ import org.apache.spark.sql.connector.catalog.TableChange
import org.apache.spark.sql.connector.catalog.TableChange._
import org.apache.spark.sql.connector.catalog.index.TableIndex
import org.apache.spark.sql.connector.expressions.NamedReference
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, Max, Min, Sum}
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
@@ -219,7 +219,11 @@ abstract class JdbcDialect extends Serializable with Logging{
val column = quoteIdentifier(sum.column.fieldNames.head)
Some(s"SUM($distinct$column)")
case _: CountStar =>
Some(s"COUNT(*)")
Some("COUNT(*)")
case f: GeneralAggregateFunc if f.name() == "AVG" =>
assert(f.inputs().length == 1)
val distinct = if (f.isDistinct) "DISTINCT " else ""
Some(s"AVG($distinct${f.inputs().head})")
case _ => None
}
}
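A hedged end-to-end illustration of what the new branch emits (the JDBC URL and column name are illustrative, and FieldReference is private[sql]):

import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc
import org.apache.spark.sql.jdbc.JdbcDialects

val dialect = JdbcDialects.get("jdbc:h2:mem:testdb")
dialect.compileAggregate(
  new GeneralAggregateFunc("AVG", false, Array(FieldReference("salary"))))
// Some("AVG(salary)"): the fragment pushed into the generated SELECT
dialect.compileAggregate(
  new GeneralAggregateFunc("AVG", true, Array(FieldReference("salary"))))
// Some("AVG(DISTINCT salary)")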