
[SPARK-37483][SQL] Support push down top N to JDBC data source V2 #34738

Closed · wants to merge 16 commits
@@ -20,7 +20,7 @@
import org.apache.spark.annotation.Evolving;

/**
* A mix-in interface for {@link Scan}. Data sources can implement this interface to
* A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to
* push down LIMIT. Please note that the combination of LIMIT with other operations
* such as AGGREGATE, GROUP BY, SORT BY, CLUSTER BY, DISTRIBUTE BY, etc. is NOT pushed down.
*
@@ -0,0 +1,38 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connector.read;

import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.expressions.SortOrder;

/**
* A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to
* push down top N (query with ORDER BY ... LIMIT n). Please note that the combination of top N
* with other operations such as AGGREGATE, GROUP BY, CLUSTER BY, DISTRIBUTE BY, etc.
* is NOT pushed down.
*
* @since 3.3.0
*/
@Evolving
public interface SupportsPushDownTopN extends ScanBuilder {

/**
* Pushes down top N to the data source.
*/
boolean pushTopN(SortOrder[] orders, int limit);
}
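
For orientation, here is a minimal sketch (not part of this PR; the class name and body are hypothetical) of a Scala ScanBuilder that opts into this interface:

import org.apache.spark.sql.connector.expressions.SortOrder
import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownTopN}

// Hypothetical builder: accepts every top N it is offered and remembers it
// so that build() can produce a Scan that applies ORDER BY ... LIMIT n itself.
class ExampleScanBuilder extends SupportsPushDownTopN {
  private var pushedOrders: Array[SortOrder] = Array.empty
  private var pushedLimit: Int = -1

  override def pushTopN(orders: Array[SortOrder], limit: Int): Boolean = {
    // A real source would first check that it can evaluate every sort order.
    pushedOrders = orders
    pushedLimit = limit
    true
  }

  override def build(): Scan = ??? // construct a Scan that honors the pushed top N
}
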
@@ -142,13 +142,25 @@ case class RowDataSourceScanExec(
handledFilters
}

val topNOrLimitInfo =
if (pushedDownOperators.limit.isDefined && pushedDownOperators.sortValues.nonEmpty) {
val pushedTopN =
s"""
|ORDER BY ${seqToString(pushedDownOperators.sortValues.map(_.describe()))}
|LIMIT ${pushedDownOperators.limit.get}
Contributor: nit: shall we output only one line? otherwise the plan EXPLAIN result may look very weird.

Contributor Author: OK

|""".stripMargin.replaceAll("\n", " ")
Some("pushedTopN" -> pushedTopN)
} else {
pushedDownOperators.limit.map(value => "PushedLimit" -> s"LIMIT $value")
Contributor: indentation is wrong
}

Map(
"ReadSchema" -> requiredSchema.catalogString,
"PushedFilters" -> seqToString(markedFilters.toSeq)) ++
pushedDownOperators.aggregation.fold(Map[String, String]()) { v =>
Map("PushedAggregates" -> seqToString(v.aggregateExpressions),
"PushedGroupByColumns" -> seqToString(v.groupByColumns))} ++
pushedDownOperators.limit.map(value => "PushedLimit" -> s"LIMIT $value") ++
topNOrLimitInfo ++
pushedDownOperators.sample.map(v => "PushedSample" ->
s"SAMPLE (${(v.upperBound - v.lowerBound) * 100}) ${v.withReplacement} SEED(${v.seed})"
)
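
For illustration only (the column name and values are made up), once the multi-line string above is collapsed by replaceAll, the scan metadata gains a single-line entry along the lines of:

// Hypothetical shape of the resulting metadata entry, addressing the
// one-line concern raised above; exact spacing depends on stripMargin/replaceAll.
val exampleEntry = "pushedTopN" -> " ORDER BY SALARY ASC NULLS FIRST LIMIT 1 "
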
@@ -40,7 +40,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
import org.apache.spark.sql.connector.catalog.SupportsRead
import org.apache.spark.sql.connector.catalog.TableCapability._
import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortOrder => SortOrderV2, SortValue}
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, Max, Min, Sum}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan}
@@ -336,7 +336,7 @@ object DataSourceStrategy
l.output.toStructType,
Set.empty,
Set.empty,
PushedDownOperators(None, None, None),
PushedDownOperators(None, None, None, Seq.empty),
toCatalystRDD(l, baseRelation.buildScan()),
baseRelation,
None) :: Nil
@@ -410,7 +410,7 @@ object DataSourceStrategy
requestedColumns.toStructType,
pushedFilters.toSet,
handledFilters,
PushedDownOperators(None, None, None),
PushedDownOperators(None, None, None, Seq.empty),
scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
relation.relation,
relation.catalogTable.map(_.identifier))
@@ -433,7 +433,7 @@
requestedColumns.toStructType,
pushedFilters.toSet,
handledFilters,
PushedDownOperators(None, None, None),
PushedDownOperators(None, None, None, Seq.empty),
scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
relation.relation,
relation.catalogTable.map(_.identifier))
@@ -726,6 +726,21 @@ object DataSourceStrategy
}
}

protected[sql] def translateSortOrders(sortOrders: Seq[SortOrder]): Seq[SortOrderV2] = {
sortOrders.map {
Member: Hi, All. This broke Scala 2.13 compilation.

[error] /home/runner/work/spark/spark/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala:730:20: match may not be exhaustive.
[error] It would fail on the following input: SortOrder(_, _, _, _)
[error]     sortOrders.map {
[error]                    ^
[warn] 24 warnings found
[error] one error found
[error] (sql / Compile / compileIncremental) Compilation failed
[error] Total time: 267 s (04:27), completed Dec 15, 2021 5:57:25 AM

case SortOrder(PushableColumnWithoutNestedColumn(name), directionV1, nullOrderingV1, _) =>
Member (@dongjoon-hyun, Dec 15, 2021): In addition, this will cause scala.MatchError in Scala 2.12. We need a new test case that does not match this case, @beliefer.

val directionV2 = directionV1 match {
case Ascending => SortDirection.ASCENDING
case Descending => SortDirection.DESCENDING
}
val nullOrderingV2 = nullOrderingV1 match {
case NullsFirst => NullOrdering.NULLS_FIRST
case NullsLast => NullOrdering.NULLS_LAST
}
SortValue(FieldReference(name), directionV2, nullOrderingV2)
}
}
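
As a sketch of how the exhaustiveness/MatchError concern above could be addressed (this is not the change made in this PR; the method name is hypothetical and it assumes the same imports as translateSortOrders), the translation could return an Option so callers skip the pushdown when any sort order is untranslatable:

// Hypothetical variant: returns None when any catalyst SortOrder cannot be
// mapped to a V2 SortOrder (e.g. nested columns or arbitrary expressions),
// so the caller can simply not push the top N instead of throwing MatchError.
def translateSortOrdersSafely(sortOrders: Seq[SortOrder]): Option[Seq[SortOrderV2]] = {
  val translated = sortOrders.flatMap {
    case SortOrder(PushableColumnWithoutNestedColumn(name), directionV1, nullOrderingV1, _) =>
      val directionV2 = directionV1 match {
        case Ascending => SortDirection.ASCENDING
        case Descending => SortDirection.DESCENDING
      }
      val nullOrderingV2 = nullOrderingV1 match {
        case NullsFirst => NullOrdering.NULLS_FIRST
        case NullsLast => NullOrdering.NULLS_LAST
      }
      Some(SortValue(FieldReference(name), directionV2, nullOrderingV2))
    case _ => None
  }
  // Only report success if every sort order was translatable.
  if (translated.length == sortOrders.length) Some(translated) else None
}
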

/**
* Convert RDD of Row into RDD of InternalRow with objects in catalyst types
*/
@@ -196,6 +196,10 @@ class JDBCOptions(
// This only applies to Data Source V2 JDBC
val pushDownLimit = parameters.getOrElse(JDBC_PUSHDOWN_LIMIT, "false").toBoolean

// An option to allow/disallow pushing down top N (ORDER BY ... LIMIT n) into the V2 JDBC data source
// This only applies to Data Source V2 JDBC
val pushDownTopN = parameters.getOrElse(JDBC_PUSHDOWN_TOP_N, "false").toBoolean
Contributor: I'm wondering if we should just reuse the limit pushdown option, as it's kind of a special case of LIMIT.

// An option to allow/disallow pushing down TABLESAMPLE into JDBC data source
// This only applies to Data Source V2 JDBC
val pushDownTableSample = parameters.getOrElse(JDBC_PUSHDOWN_TABLESAMPLE, "false").toBoolean
@@ -276,6 +280,7 @@ object JDBCOptions {
val JDBC_PUSHDOWN_PREDICATE = newOption("pushDownPredicate")
val JDBC_PUSHDOWN_AGGREGATE = newOption("pushDownAggregate")
val JDBC_PUSHDOWN_LIMIT = newOption("pushDownLimit")
val JDBC_PUSHDOWN_TOP_N = newOption("pushDownTopN")
val JDBC_PUSHDOWN_TABLESAMPLE = newOption("pushDownTableSample")
val JDBC_KEYTAB = newOption("keytab")
val JDBC_PRINCIPAL = newOption("principal")
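
A rough usage sketch (the catalog name, connection details, and table are hypothetical; of the options shown, only pushDownTopN is new in this PR): since these options only apply to the Data Source V2 JDBC path, they would typically be supplied through a JDBC table catalog.

// Hypothetical configuration of a JDBC table catalog named "h2".
spark.conf.set("spark.sql.catalog.h2",
  "org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog")
spark.conf.set("spark.sql.catalog.h2.url", "jdbc:h2:mem:testdb")
spark.conf.set("spark.sql.catalog.h2.driver", "org.h2.Driver")
spark.conf.set("spark.sql.catalog.h2.pushDownTopN", "true")

// A query of this shape is then a candidate for top N pushdown.
val df = spark.sql(
  "SELECT name, salary FROM h2.test.employee ORDER BY salary DESC LIMIT 3")
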
@@ -25,6 +25,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.expressions.SortOrder
import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
import org.apache.spark.sql.sources._
Expand Down Expand Up @@ -151,12 +152,14 @@ object JDBCRDD extends Logging {
* @param options - JDBC options that contains url, table and other information.
* @param outputSchema - The schema of the columns or aggregate columns to SELECT.
* @param groupByColumns - The pushed down group by columns.
* @param sample - The pushed down tableSample.
* @param limit - The pushed down limit. If the value is 0, it means no limit or limit
* is not pushed down.
* @param sample - The pushed down tableSample.
* @param sortValues - The sort values that cooperate with limit to realize top N.
*
* @return An RDD representing "SELECT requiredColumns FROM fqTable".
*/
// scalastyle:off argcount
def scanTable(
sc: SparkContext,
schema: StructType,
@@ -167,7 +170,8 @@
outputSchema: Option[StructType] = None,
groupByColumns: Option[Array[String]] = None,
sample: Option[TableSampleInfo] = None,
limit: Int = 0): RDD[InternalRow] = {
limit: Int = 0,
sortValues: Array[SortOrder] = Array.empty[SortOrder]): RDD[InternalRow] = {
val url = options.url
val dialect = JdbcDialects.get(url)
val quotedColumns = if (groupByColumns.isEmpty) {
@@ -187,8 +191,10 @@
options,
groupByColumns,
sample,
limit)
limit,
sortValues)
}
// scalastyle:on argcount
}

/**
@@ -207,7 +213,8 @@ private[jdbc] class JDBCRDD(
options: JDBCOptions,
groupByColumns: Option[Array[String]],
sample: Option[TableSampleInfo],
limit: Int)
limit: Int,
sortValues: Array[SortOrder])
extends RDD[InternalRow](sc, Nil) {

/**
@@ -255,6 +262,14 @@
}
}

private def getOrderByClause: String = {
if (sortValues.nonEmpty) {
s" ORDER BY ${sortValues.map(_.describe()).mkString(", ")}"
} else {
""
}
}

/**
* Runs the SQL query against the JDBC driver.
*
@@ -339,7 +354,7 @@
val myLimitClause: String = dialect.getLimitClause(limit)

val sqlText = s"SELECT $columnList FROM ${options.tableOrQuery} $myTableSampleClause" +
s" $myWhereClause $getGroupByClause $myLimitClause"
s" $myWhereClause $getGroupByClause $getOrderByClause $myLimitClause"
stmt = conn.prepareStatement(sqlText,
ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
stmt.setFetchSize(options.fetchSize)
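
For illustration (identifiers are hypothetical, and the exact quoting and LIMIT syntax come from the JdbcDialect), the statement assembled above for a pushed top N would look roughly like:

// Rough shape of sqlText when a filter, a sort, and a limit were all pushed down.
val sketchSqlText =
  """SELECT "NAME","SALARY" FROM employee """ +
    """WHERE "SALARY" > 100.0 ORDER BY SALARY ASC NULLS FIRST LIMIT 3"""
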
@@ -27,6 +27,7 @@ import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter}
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate, stringToTimestamp}
import org.apache.spark.sql.connector.expressions.SortOrder
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
import org.apache.spark.sql.internal.SQLConf
@@ -301,7 +302,8 @@ private[sql] case class JDBCRelation(
filters: Array[Filter],
groupByColumns: Option[Array[String]],
tableSample: Option[TableSampleInfo],
limit: Int): RDD[Row] = {
limit: Int,
sortValues: Array[SortOrder]): RDD[Row] = {
// Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
JDBCRDD.scanTable(
sparkSession.sparkContext,
@@ -313,7 +315,8 @@
Some(finalSchema),
groupByColumns,
tableSample,
limit).asInstanceOf[RDD[Row]]
limit,
sortValues).asInstanceOf[RDD[Row]]
}

override def insert(data: DataFrame, overwrite: Boolean): Unit = {
@@ -22,10 +22,10 @@ import scala.collection.mutable
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, NamedExpression, PredicateHelper, SchemaPruning}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.connector.expressions.{FieldReference, SortOrder}
import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter}
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownV2Filters}
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters}
import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PushableColumnWithoutNestedColumn}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources
@@ -157,6 +157,17 @@ object PushDownUtils extends PredicateHelper {
}
}

/**
* Pushes down top N to the data source Scan
*/
def pushTopN(scanBuilder: ScanBuilder, order: Array[SortOrder], limit: Int): Boolean = {
scanBuilder match {
case s: SupportsPushDownTopN =>
s.pushTopN(order, limit)
case _ => false
}
}

/**
* Applies column pruning to the data source, w.r.t. the references of the given expressions.
*
@@ -17,6 +17,7 @@

package org.apache.spark.sql.execution.datasources.v2

import org.apache.spark.sql.connector.expressions.SortOrder
import org.apache.spark.sql.connector.expressions.aggregate.Aggregation

/**
@@ -25,4 +26,7 @@ import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
case class PushedDownOperators(
aggregation: Option[Aggregation],
sample: Option[TableSampleInfo],
limit: Option[Int])
limit: Option[Int],
sortValues: Seq[SortOrder]) {
assert((limit.isEmpty && sortValues.isEmpty) || limit.isDefined)
}
@@ -23,8 +23,9 @@ import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeRefer
import org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.planning.ScanOperation
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LogicalPlan, Project, Sample}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LocalLimit, LogicalPlan, Project, Sample, Sort}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.expressions.SortOrder
import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
@@ -246,16 +247,42 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
}
}

private def pushDownLimit(plan: LogicalPlan, limit: Int): LogicalPlan = plan match {
case operation @ ScanOperation(_, filter, sHolder: ScanBuilderHolder) if filter.isEmpty =>
val limitPushed = PushDownUtils.pushLimit(sHolder.builder, limit)
if (limitPushed) {
sHolder.pushedLimit = Some(limit)
}
operation
case s @ Sort(order, _, operation @ ScanOperation(_, filter, sHolder: ScanBuilderHolder))
if filter.isEmpty =>
val orders = DataSourceStrategy.translateSortOrders(order)
val topNPushed = PushDownUtils.pushTopN(sHolder.builder, orders.toArray, limit)
if (topNPushed) {
sHolder.pushedLimit = Some(limit)
sHolder.sortValues = orders
operation
} else {
s
}
case p: Project =>
val newChild = pushDownLimit(p.child, limit)
if (newChild == p.child) {
Contributor: we can call p.withNewChildren(Seq(newChild)), which does this check for you.

Contributor Author: Thank you for pointing that out.

p
} else {
p.copy(child = newChild)
}
case other => other
}

def applyLimit(plan: LogicalPlan): LogicalPlan = plan.transform {
case globalLimit @ Limit(IntegerLiteral(limitValue), child) =>
child match {
case ScanOperation(_, filter, sHolder: ScanBuilderHolder) if filter.length == 0 =>
val limitPushed = PushDownUtils.pushLimit(sHolder.builder, limitValue)
if (limitPushed) {
sHolder.pushedLimit = Some(limitValue)
}
globalLimit
case _ => globalLimit
val newChild = pushDownLimit(child, limitValue)
if (newChild == child) {
Contributor: ditto

globalLimit
} else {
val localLimit = globalLimit.child.asInstanceOf[LocalLimit].copy(child = newChild)
globalLimit.copy(child = localLimit)
}
}
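
To make the targeted plan shape concrete (table and column names are hypothetical), a query like the following is the case this rule now rewrites: a Limit over a Sort over a scan whose builder implements SupportsPushDownTopN gets both operators pushed, while a bare Limit still goes through pushLimit.

import org.apache.spark.sql.functions.desc

// Hypothetical V2 JDBC table; the Sort + Limit pair below can be replaced
// by a pushed top N when pushTopN returns true.
val topN = spark.table("h2.test.employee")
  .sort(desc("salary"))
  .limit(3)
// topN.explain() would then report the pushed operators (e.g. a pushedTopN
// entry) in the scan node instead of a Spark-side Sort.
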

@@ -270,8 +297,8 @@
f.pushedFilters()
case _ => Array.empty[sources.Filter]
}
val pushedDownOperators =
PushedDownOperators(aggregation, sHolder.pushedSample, sHolder.pushedLimit)
val pushedDownOperators = PushedDownOperators(aggregation,
sHolder.pushedSample, sHolder.pushedLimit, sHolder.sortValues)
V1ScanWrapper(v1, pushedFilters, pushedDownOperators)
case _ => scan
}
Expand All @@ -284,6 +311,8 @@ case class ScanBuilderHolder(
builder: ScanBuilder) extends LeafNode {
var pushedLimit: Option[Int] = None

var sortValues: Seq[SortOrder] = Seq.empty[SortOrder]

var pushedSample: Option[TableSampleInfo] = None
}

Expand Down
Loading