Add a class to represent general aggregate functions in DS V2
cloud-fan committed Dec 30, 2021
1 parent 4c58f12 commit 4829767
Showing 17 changed files with 120 additions and 67 deletions.
@@ -29,5 +29,5 @@ public interface Expression {
/**
* Format the expression as a human readable SQL-like string.
*/
String describe();
default String describe() { return this.toString(); }
}
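Illustration (not part of the diff): with the default method above, concrete expressions such as CountStar only need toString, and describe() falls back to it.

import org.apache.spark.sql.connector.expressions.aggregate.CountStar

val agg = new CountStar()
agg.toString    // "COUNT(*)"
agg.describe()  // also "COUNT(*)" via the new default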
@@ -46,7 +46,4 @@ public String toString() {
return "COUNT(" + column.describe() + ")";
}
}

@Override
public String describe() { return this.toString(); }
}
@@ -32,7 +32,4 @@ public CountStar() {

@Override
public String toString() { return "COUNT(*)"; }

@Override
public String describe() { return this.toString(); }
}
@@ -0,0 +1,66 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connector.expressions.aggregate;

import java.util.Arrays;
import java.util.stream.Collectors;

import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.expressions.Expression;
import org.apache.spark.sql.connector.expressions.NamedReference;

/**
* The general implementation of {@link AggregateFunc}, which contains the upper-cased function
* name, the `isDistinct` flag and all the inputs. Note that Spark cannot push down a partial
* aggregate with this function to the source; it can only push down the entire aggregate.
* <p>
* The currently supported SQL aggregate functions:
* <ol>
* <li><pre>AVG(input1)</pre> Since 3.3.0</li>
* </ol>
*
* @since 3.3.0
*/
@Evolving
public final class GeneralAggregateFunc implements AggregateFunc {
private final String name;
private final boolean isDistinct;
private final NamedReference[] inputs;

public String name() { return name; }
public boolean isDistinct() { return isDistinct; }
public NamedReference[] inputs() { return inputs; }

public GeneralAggregateFunc(String name, boolean isDistinct, NamedReference[] inputs) {
this.name = name;
this.isDistinct = isDistinct;
this.inputs = inputs;
}

@Override
public String toString() {
String inputsString = Arrays.stream(inputs)
.map(Expression::describe)
.collect(Collectors.joining(", "));
if (isDistinct) {
return name + "(DISTINCT " + inputsString + ")";
} else {
return name + "(" + inputsString + ")";
}
}
}
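A usage sketch for the class above (not part of the diff; FieldReference is a Spark-internal helper, shown here only to illustrate the rendering):

import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc

val avg = new GeneralAggregateFunc("AVG", false, Array(FieldReference("price")))
avg.toString  // "AVG(price)"

val avgDistinct = new GeneralAggregateFunc("AVG", true, Array(FieldReference("price")))
avgDistinct.describe()  // "AVG(DISTINCT price)"; describe() delegates to toString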
@@ -35,7 +35,4 @@ public final class Max implements AggregateFunc {

@Override
public String toString() { return "MAX(" + column.describe() + ")"; }

@Override
public String describe() { return this.toString(); }
}
@@ -35,7 +35,4 @@ public final class Min implements AggregateFunc {

@Override
public String toString() { return "MIN(" + column.describe() + ")"; }

@Override
public String describe() { return this.toString(); }
}
@@ -46,7 +46,4 @@ public String toString() {
return "SUM(" + column.describe() + ")";
}
}

@Override
public String describe() { return this.toString(); }
}
@@ -37,7 +37,4 @@ public abstract class Filter implements Expression, Serializable {
* Returns list of columns that are referenced by this filter.
*/
public abstract NamedReference[] references();

@Override
public String describe() { return this.toString(); }
}
@@ -22,18 +22,19 @@

/**
* A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to
* push down aggregates. Spark assumes that the data source can't fully complete the
* grouping work, and will group the data source output again. For queries like
* "SELECT min(value) AS m FROM t GROUP BY key", after pushing down the aggregate
* to the data source, the data source can still output data with duplicated keys, which is OK
* as Spark will do GROUP BY key again. The final query plan can be something like this:
* push down aggregates.
* <p>
* If the data source can't fully complete the grouping work, then
* {@link #supportCompletePushDown()} should return false, and Spark will group the data source
* output again. For queries like "SELECT min(value) AS m FROM t GROUP BY key", after pushing down
* the aggregate to the data source, the data source can still output data with duplicated keys,
* which is OK as Spark will do GROUP BY key again. The final query plan can be something like this:
* <pre>
* Aggregate [key#1], [min(min(value)#2) AS m#3]
* +- RelationV2[key#1, min(value)#2]
* Aggregate [key#1], [min(min_value#2) AS m#3]
* +- RelationV2[key#1, min_value#2]
* </pre>
* Similarly, if there is no grouping expression, the data source can still output more than one
* row.
*
* <p>
* When pushing down operators, Spark pushes down filters to the data source first, then pushes down
* aggregates or applies column pruning. Depending on the data source implementation, aggregates may or
@@ -46,8 +47,8 @@
public interface SupportsPushDownAggregates extends ScanBuilder {

/**
* Whether the datasource support complete aggregation push-down. Spark could avoid partial-agg
* and final-agg when the aggregation operation can be pushed down to the datasource completely.
* Whether the datasource supports complete aggregation push-down. Spark will do grouping again
* if this method returns false.
*
* @return true if the aggregation can be pushed down to datasource completely, false otherwise.
*/
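To make the contract concrete, a hedged connector-side sketch (class and field names invented; it assumes the mix-in's other method is pushAggregation(Aggregation) together with the no-arg supportCompletePushDown() used elsewhere in this diff):

import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Min}
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates}

class ExampleScanBuilder extends ScanBuilder with SupportsPushDownAggregates {
  private var pushed: Option[Aggregation] = None

  // This toy source can only evaluate MIN and cannot guarantee one row per group,
  // so Spark keeps its own final aggregate (partial push-down).
  override def supportCompletePushDown(): Boolean = false

  override def pushAggregation(aggregation: Aggregation): Boolean = {
    if (aggregation.aggregateExpressions.forall(_.isInstanceOf[Min])) {
      pushed = Some(aggregation)
      true
    } else {
      false
    }
  }

  override def build(): Scan = ???  // return a Scan producing the pre-aggregated rows
}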
@@ -88,9 +88,7 @@ private[sql] abstract class SingleColumnTransform(ref: NamedReference) extends R

override def arguments: Array[Expression] = Array(ref)

override def describe: String = name + "(" + reference.describe + ")"

override def toString: String = describe
override def toString: String = name + "(" + reference.describe + ")"

protected def withNewRef(ref: NamedReference): Transform

@@ -114,16 +112,14 @@ private[sql] final case class BucketTransform(

override def arguments: Array[Expression] = numBuckets +: columns.toArray

override def describe: String =
override def toString: String =
if (sortedColumns.nonEmpty) {
s"bucket(${arguments.map(_.describe).mkString(", ")}," +
s" ${sortedColumns.map(_.describe).mkString(", ")})"
} else {
s"bucket(${arguments.map(_.describe).mkString(", ")})"
}

override def toString: String = describe

override def withReferences(newReferences: Seq[NamedReference]): Transform = {
this.copy(columns = newReferences)
}
@@ -169,9 +165,7 @@ private[sql] final case class ApplyTransform(
arguments.collect { case named: NamedReference => named }
}

override def describe: String = s"$name(${arguments.map(_.describe).mkString(", ")})"

override def toString: String = describe
override def toString: String = s"$name(${arguments.map(_.describe).mkString(", ")})"
}

/**
@@ -338,21 +332,19 @@ private[sql] object HoursTransform {
}

private[sql] final case class LiteralValue[T](value: T, dataType: DataType) extends Literal[T] {
override def describe: String = {
override def toString: String = {
if (dataType.isInstanceOf[StringType]) {
s"'$value'"
} else {
s"$value"
}
}
override def toString: String = describe
}

private[sql] final case class FieldReference(parts: Seq[String]) extends NamedReference {
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
override def fieldNames: Array[String] = parts.toArray
override def describe: String = parts.quoted
override def toString: String = describe
override def toString: String = parts.quoted
}

private[sql] object FieldReference {
@@ -366,7 +358,7 @@ private[sql] final case class SortValue(
direction: SortDirection,
nullOrdering: NullOrdering) extends SortOrder {

override def describe(): String = s"$expression $direction $nullOrdering"
override def toString(): String = s"$expression $direction $nullOrdering"
}

private[sql] object SortValue {
@@ -28,15 +28,15 @@ class TransformExtractorSuite extends SparkFunSuite {
private def lit[T](literal: T): Literal[T] = new Literal[T] {
override def value: T = literal
override def dataType: DataType = catalyst.expressions.Literal(literal).dataType
override def describe: String = literal.toString
override def toString: String = literal.toString
}

/**
* Creates a NamedReference using an anonymous class.
*/
private def ref(names: String*): NamedReference = new NamedReference {
override def fieldNames: Array[String] = names.toArray
override def describe: String = names.mkString(".")
override def toString: String = names.mkString(".")
}

/**
@@ -46,7 +46,7 @@ class TransformExtractorSuite extends SparkFunSuite {
override def name: String = func
override def references: Array[NamedReference] = Array(ref)
override def arguments: Array[Expression] = Array(ref)
override def describe: String = ref.describe
override def toString: String = ref.describe
}

test("Identity extractor") {
@@ -135,7 +135,7 @@ class TransformExtractorSuite extends SparkFunSuite {
override def name: String = "bucket"
override def references: Array[NamedReference] = Array(col)
override def arguments: Array[Expression] = Array(lit(16), col)
override def describe: String = s"bucket(16, ${col.describe})"
override def toString: String = s"bucket(16, ${col.describe})"
}

bucketTransform match {
@@ -157,8 +157,8 @@ case class RowDataSourceScanExec(
"ReadSchema" -> requiredSchema.catalogString,
"PushedFilters" -> seqToString(markedFilters.toSeq)) ++
pushedDownOperators.aggregation.fold(Map[String, String]()) { v =>
Map("PushedAggregates" -> seqToString(v.aggregateExpressions),
"PushedGroupByColumns" -> seqToString(v.groupByColumns))} ++
Map("PushedAggregates" -> seqToString(v.aggregateExpressions.map(_.describe())),
"PushedGroupByColumns" -> seqToString(v.groupByColumns.map(_.describe())))} ++
topNOrLimitInfo ++
pushedDownOperators.sample.map(v => "PushedSample" ->
s"SAMPLE (${(v.upperBound - v.lowerBound) * 100}) ${v.withReplacement} SEED(${v.seed})"
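Side note (not from the commit): the fix above maps the pushed functions through describe() because the metadata map holds plain strings, e.g.:

import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.connector.expressions.aggregate.{GeneralAggregateFunc, Min}

val pushedAggs = Seq(
  new Min(FieldReference("value")),
  new GeneralAggregateFunc("AVG", false, Array(FieldReference("price"))))
pushedAggs.map(_.describe())  // Seq("MIN(value)", "AVG(price)")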
@@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
import org.apache.spark.sql.connector.catalog.SupportsRead
import org.apache.spark.sql.connector.catalog.TableCapability._
import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortOrder => SortOrderV2, SortValue}
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, Max, Min, Sum}
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan}
import org.apache.spark.sql.execution.command._
@@ -717,8 +717,10 @@ object DataSourceStrategy
Some(new Count(FieldReference(name), agg.isDistinct))
case _ => None
}
case sum @ aggregate.Sum(PushableColumnWithoutNestedColumn(name), _) =>
case aggregate.Sum(PushableColumnWithoutNestedColumn(name), _) =>
Some(new Sum(FieldReference(name), agg.isDistinct))
case aggregate.Average(PushableColumnWithoutNestedColumn(name), _) =>
Some(new GeneralAggregateFunc("AVG", agg.isDistinct, Array(FieldReference(name))))
case _ => None
}
} else {
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.planning.ScanOperation
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LocalLimit, LogicalPlan, Project, Sample, Sort}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.expressions.SortOrder
import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, GeneralAggregateFunc}
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.sources
@@ -109,6 +109,9 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
r, normalizedAggregates, normalizedGroupingExpressions)
if (pushedAggregates.isEmpty) {
aggNode // return original plan node
} else if (!supportPartialAggPushDown(pushedAggregates.get) &&
!r.supportCompletePushDown()) {
aggNode // return original plan node
} else {
// No need to do column pruning because only the aggregate columns are used as
// DataSourceV2ScanRelation output columns. All the other columns are not
@@ -145,9 +148,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
""".stripMargin)

val wrappedScan = getWrappedScan(scan, sHolder, pushedAggregates)

val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output)

if (r.supportCompletePushDown()) {
val projectExpressions = resultExpressions.map { expr =>
// TODO At present, only push down group by attribute is supported.
@@ -209,6 +210,11 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
}
}

private def supportPartialAggPushDown(agg: Aggregation): Boolean = {
// We don't know the agg buffer of `GeneralAggregateFunc`, so can't do partial agg push down.
agg.aggregateExpressions().forall(!_.isInstanceOf[GeneralAggregateFunc])
}

private def addCastIfNeeded(aggAttribute: AttributeReference, aggDataType: DataType) =
if (aggAttribute.dataType == aggDataType) {
aggAttribute
@@ -256,7 +262,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {

def pushDownSample(plan: LogicalPlan): LogicalPlan = plan.transform {
case sample: Sample => sample.child match {
case ScanOperation(_, filter, sHolder: ScanBuilderHolder) if filter.length == 0 =>
case ScanOperation(_, filter, sHolder: ScanBuilderHolder) if filter.isEmpty =>
val tableSample = TableSampleInfo(
sample.lowerBound,
sample.upperBound,
@@ -282,7 +288,7 @@
}
operation
case s @ Sort(order, _, operation @ ScanOperation(_, filter, sHolder: ScanBuilderHolder))
if filter.isEmpty =>
if filter.isEmpty =>
val orders = DataSourceStrategy.translateSortOrders(order)
if (orders.length == order.length) {
val topNPushed = PushDownUtils.pushTopN(sHolder.builder, orders.toArray, limit)
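A condensed view of the new gating in this rule (sketch only; it mirrors the check added above, under the same assumption about the no-arg supportCompletePushDown()):

import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, GeneralAggregateFunc}
import org.apache.spark.sql.connector.read.SupportsPushDownAggregates

// GeneralAggregateFunc has no known aggregation buffer, so it is only pushed
// when the source completes the aggregation itself.
def shouldPush(r: SupportsPushDownAggregates, agg: Aggregation): Boolean = {
  val partialOk = agg.aggregateExpressions().forall(!_.isInstanceOf[GeneralAggregateFunc])
  partialOk || r.supportCompletePushDown()
}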
@@ -79,7 +79,7 @@ case class JDBCScanBuilder(
if (!jdbcOptions.pushDownAggregate) return false

val dialect = JdbcDialects.get(jdbcOptions.url)
val compiledAggs = aggregation.aggregateExpressions.flatMap(dialect.compileAggregate(_))
val compiledAggs = aggregation.aggregateExpressions.flatMap(dialect.compileAggregate)
if (compiledAggs.length != aggregation.aggregateExpressions.length) return false

val groupByCols = aggregation.groupByColumns.map { col =>
@@ -33,7 +33,7 @@ import org.apache.spark.sql.connector.catalog.TableChange
import org.apache.spark.sql.connector.catalog.TableChange._
import org.apache.spark.sql.connector.catalog.index.TableIndex
import org.apache.spark.sql.connector.expressions.NamedReference
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, Max, Min, Sum}
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
@@ -219,7 +219,11 @@ abstract class JdbcDialect extends Serializable with Logging{
val column = quoteIdentifier(sum.column.fieldNames.head)
Some(s"SUM($distinct$column)")
case _: CountStar =>
Some(s"COUNT(*)")
Some("COUNT(*)")
case f: GeneralAggregateFunc if f.name() == "AVG" =>
assert(f.inputs().length == 1)
val distinct = if (f.isDistinct) "DISTINCT " else ""
Some(s"AVG($distinct${f.inputs().head})")
case _ => None
}
}
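A quick sketch of what the new AVG branch produces (the JDBC URL and column name are invented; FieldReference is Spark-internal and used here only for illustration):

import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc
import org.apache.spark.sql.jdbc.JdbcDialects

val dialect = JdbcDialects.get("jdbc:postgresql://localhost/test")
val avg = new GeneralAggregateFunc("AVG", true, Array(FieldReference("price")))
dialect.compileAggregate(avg)  // Some("AVG(DISTINCT price)")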