[SPARK-37483][SQL] Support push down top N to JDBC data source V2 #34918

Closed
wants to merge 18 commits
@@ -20,7 +20,7 @@
import org.apache.spark.annotation.Evolving;

/**
* A mix-in interface for {@link Scan}. Data sources can implement this interface to
* A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to
* push down LIMIT. Please note that the combination of LIMIT with other operations
* such as AGGREGATE, GROUP BY, SORT BY, CLUSTER BY, DISTRIBUTE BY, etc. is NOT pushed down.
*
@@ -0,0 +1,38 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connector.read;

import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.expressions.SortOrder;

/**
* A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to
* push down top N (query with ORDER BY ... LIMIT n). Please note that the combination of top N
* with other operations such as AGGREGATE, GROUP BY, CLUSTER BY, DISTRIBUTE BY, etc.
* is NOT pushed down.
*
* @since 3.3.0
*/
@Evolving
public interface SupportsPushDownTopN extends ScanBuilder {

/**
* Pushes down top N to the data source.
*/
boolean pushTopN(SortOrder[] orders, int limit);
}
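For context, a minimal connector-side sketch of how a ScanBuilder might implement this interface; MyScanBuilder and its bookkeeping are hypothetical names used only for illustration, not part of this PR, and a real source would also verify that it can honor every sort key before accepting the pushdown.

import org.apache.spark.sql.connector.expressions.SortOrder
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownTopN}
import org.apache.spark.sql.types.StructType

class MyScanBuilder(schema: StructType) extends ScanBuilder with SupportsPushDownTopN {
  // Top N state is captured here and consumed when the physical scan is built.
  private var topNOrders: Array[SortOrder] = Array.empty
  private var topNLimit: Option[Int] = None

  override def pushTopN(orders: Array[SortOrder], limit: Int): Boolean = {
    // Returning true promises that the source will apply ORDER BY ... LIMIT itself;
    // returning false lets Spark keep its own Sort operator on top of the scan.
    topNOrders = orders
    topNLimit = Some(limit)
    true
  }

  override def build(): Scan = new Scan {
    override def readSchema(): StructType = schema
    // A real Scan would translate topNOrders/topNLimit into the source's query,
    // much like the JDBC changes below append ORDER BY and LIMIT clauses.
  }
}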
@@ -142,13 +142,23 @@ case class RowDataSourceScanExec(
handledFilters
}

val topNOrLimitInfo =
if (pushedDownOperators.limit.isDefined && pushedDownOperators.sortValues.nonEmpty) {
val pushedTopN =
s"ORDER BY ${seqToString(pushedDownOperators.sortValues.map(_.describe()))}" +
s" LIMIT ${pushedDownOperators.limit.get}"
Some("PushedTopN" -> pushedTopN)
} else {
pushedDownOperators.limit.map(value => "PushedLimit" -> s"LIMIT $value")
}

Map(
"ReadSchema" -> requiredSchema.catalogString,
"PushedFilters" -> seqToString(markedFilters.toSeq)) ++
pushedDownOperators.aggregation.fold(Map[String, String]()) { v =>
Map("PushedAggregates" -> seqToString(v.aggregateExpressions),
"PushedGroupByColumns" -> seqToString(v.groupByColumns))} ++
pushedDownOperators.limit.map(value => "PushedLimit" -> s"LIMIT $value") ++
topNOrLimitInfo ++
pushedDownOperators.sample.map(v => "PushedSample" ->
s"SAMPLE (${(v.upperBound - v.lowerBound) * 100}) ${v.withReplacement} SEED(${v.seed})"
)
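A hedged usage illustration of the new metadata key (assuming an active SparkSession named spark and a JDBC V2 catalog registered as h2 with a test.employee table, mirroring the H2 setup in JDBCV2Suite; the exact text depends on the dialect and on how each sort order describes itself):

val df = spark.sql("SELECT name, salary FROM h2.test.employee ORDER BY salary LIMIT 3")
df.explain()
// The scan node is expected to report something like
//   PushedTopN: ORDER BY SALARY ASC NULLS FIRST LIMIT 3
// instead of only PushedLimit: LIMIT 3 as before this change.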
@@ -40,7 +40,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
import org.apache.spark.sql.connector.catalog.SupportsRead
import org.apache.spark.sql.connector.catalog.TableCapability._
import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortOrder => SortOrderV2, SortValue}
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, Max, Min, Sum}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan}
@@ -336,7 +336,7 @@ object DataSourceStrategy
l.output.toStructType,
Set.empty,
Set.empty,
PushedDownOperators(None, None, None),
PushedDownOperators(None, None, None, Seq.empty),
toCatalystRDD(l, baseRelation.buildScan()),
baseRelation,
None) :: Nil
@@ -410,7 +410,7 @@ object DataSourceStrategy
requestedColumns.toStructType,
pushedFilters.toSet,
handledFilters,
PushedDownOperators(None, None, None),
PushedDownOperators(None, None, None, Seq.empty),
scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
relation.relation,
relation.catalogTable.map(_.identifier))
@@ -433,7 +433,7 @@
requestedColumns.toStructType,
pushedFilters.toSet,
handledFilters,
PushedDownOperators(None, None, None),
PushedDownOperators(None, None, None, Seq.empty),
scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
relation.relation,
relation.catalogTable.map(_.identifier))
@@ -726,6 +726,24 @@ object DataSourceStrategy
}
}

protected[sql] def translateSortOrders(sortOrders: Seq[SortOrder]): Seq[SortOrderV2] = {
def translateSortOrder(sortOrder: SortOrder): Option[SortOrderV2] = sortOrder match {
case SortOrder(PushableColumnWithoutNestedColumn(name), directionV1, nullOrderingV1, _) =>
val directionV2 = directionV1 match {
case Ascending => SortDirection.ASCENDING
case Descending => SortDirection.DESCENDING
}
val nullOrderingV2 = nullOrderingV1 match {
case NullsFirst => NullOrdering.NULLS_FIRST
case NullsLast => NullOrdering.NULLS_LAST
}
Some(SortValue(FieldReference(name), directionV2, nullOrderingV2))
case _ => None
}

sortOrders.flatMap(translateSortOrder)
Review comment (PR author):
The original PR #34738 raised an incompatibility error when building with Scala 2.13 (ref #34738 (comment)).
I updated the code in this PR.

Review comment (member):
Thank you, @beliefer !

}
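A minimal sketch of what translateSortOrders yields for a single ascending key; the catalyst SortOrder is built by hand purely for illustration, and because the method is protected[sql] this mirrors how sql-internal code and tests would call it.

import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, NullsFirst, SortOrder}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.types.IntegerType

// Catalyst-side representation of ORDER BY price ASC NULLS FIRST ...
val catalystOrder =
  SortOrder(AttributeReference("price", IntegerType)(), Ascending, NullsFirst, Seq.empty)
// ... translated into the connector-side (V2) representation:
val v2Orders = DataSourceStrategy.translateSortOrders(Seq(catalystOrder))
// => Seq(SortValue(FieldReference("price"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST))
// A key that is not a plain top-level column translates to None and is dropped, which is
// why the caller further below only pushes top N when every sort key translates.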

/**
* Convert RDD of Row into RDD of InternalRow with objects in catalyst types
*/
@@ -25,6 +25,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.expressions.SortOrder
import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
import org.apache.spark.sql.sources._
@@ -151,12 +152,14 @@ object JDBCRDD extends Logging {
* @param options - JDBC options that contains url, table and other information.
* @param outputSchema - The schema of the columns or aggregate columns to SELECT.
* @param groupByColumns - The pushed down group by columns.
* @param sample - The pushed down tableSample.
* @param limit - The pushed down limit. If the value is 0, it means no limit or limit
* is not pushed down.
* @param sample - The pushed down tableSample.
* @param sortValues - The pushed down sort values that cooperate with limit to realize top N.
*
* @return An RDD representing "SELECT requiredColumns FROM fqTable".
*/
// scalastyle:off argcount
def scanTable(
sc: SparkContext,
schema: StructType,
@@ -167,7 +170,8 @@
outputSchema: Option[StructType] = None,
groupByColumns: Option[Array[String]] = None,
sample: Option[TableSampleInfo] = None,
limit: Int = 0): RDD[InternalRow] = {
limit: Int = 0,
sortOrders: Array[SortOrder] = Array.empty[SortOrder]): RDD[InternalRow] = {
val url = options.url
val dialect = JdbcDialects.get(url)
val quotedColumns = if (groupByColumns.isEmpty) {
@@ -187,8 +191,10 @@
options,
groupByColumns,
sample,
limit)
limit,
sortOrders)
}
// scalastyle:on argcount
}

/**
@@ -207,7 +213,8 @@ private[jdbc] class JDBCRDD(
options: JDBCOptions,
groupByColumns: Option[Array[String]],
sample: Option[TableSampleInfo],
limit: Int)
limit: Int,
sortOrders: Array[SortOrder])
extends RDD[InternalRow](sc, Nil) {

/**
@@ -255,6 +262,14 @@
}
}

private def getOrderByClause: String = {
if (sortOrders.nonEmpty) {
s" ORDER BY ${sortOrders.map(_.describe()).mkString(", ")}"
} else {
""
}
}

/**
* Runs the SQL query against the JDBC driver.
*
@@ -339,7 +354,7 @@
val myLimitClause: String = dialect.getLimitClause(limit)

val sqlText = s"SELECT $columnList FROM ${options.tableOrQuery} $myTableSampleClause" +
s" $myWhereClause $getGroupByClause $myLimitClause"
s" $myWhereClause $getGroupByClause $getOrderByClause $myLimitClause"
stmt = conn.prepareStatement(sqlText,
ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
stmt.setFetchSize(options.fetchSize)
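To make the clause assembly concrete, a hedged sketch with hypothetical inputs; real values come from JDBCOptions, the dialect, and the pushed-down operators, and empty clauses simply contribute extra whitespace.

// Hypothetical inputs for illustration only:
val columnList = "NAME,SALARY"                          // quoted required columns
val table = "test.employee"                             // options.tableOrQuery
val orderByClause = " ORDER BY SALARY ASC NULLS FIRST"  // getOrderByClause
val limitClause = "LIMIT 3"                             // dialect.getLimitClause(3)
val sqlText = s"SELECT $columnList FROM $table  $orderByClause $limitClause"
// => roughly: SELECT NAME,SALARY FROM test.employee  ORDER BY SALARY ASC NULLS FIRST LIMIT 3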
@@ -27,6 +27,7 @@ import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter}
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate, stringToTimestamp}
import org.apache.spark.sql.connector.expressions.SortOrder
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
import org.apache.spark.sql.internal.SQLConf
@@ -301,7 +302,8 @@ private[sql] case class JDBCRelation(
filters: Array[Filter],
groupByColumns: Option[Array[String]],
tableSample: Option[TableSampleInfo],
limit: Int): RDD[Row] = {
limit: Int,
sortOrders: Array[SortOrder]): RDD[Row] = {
// Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
JDBCRDD.scanTable(
sparkSession.sparkContext,
@@ -313,7 +315,8 @@
Some(finalSchema),
groupByColumns,
tableSample,
limit).asInstanceOf[RDD[Row]]
limit,
sortOrders).asInstanceOf[RDD[Row]]
}

override def insert(data: DataFrame, overwrite: Boolean): Unit = {
@@ -22,10 +22,10 @@ import scala.collection.mutable
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, NamedExpression, PredicateHelper, SchemaPruning}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.connector.expressions.{FieldReference, SortOrder}
import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter}
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownV2Filters}
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters}
import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PushableColumnWithoutNestedColumn}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources
@@ -157,6 +157,17 @@ object PushDownUtils extends PredicateHelper {
}
}

/**
* Pushes down top N to the data source Scan
*/
def pushTopN(scanBuilder: ScanBuilder, order: Array[SortOrder], limit: Int): Boolean = {
scanBuilder match {
case s: SupportsPushDownTopN =>
s.pushTopN(order, limit)
case _ => false
}
}

/**
* Applies column pruning to the data source, w.r.t. the references of the given expressions.
*
@@ -17,6 +17,7 @@

package org.apache.spark.sql.execution.datasources.v2

import org.apache.spark.sql.connector.expressions.SortOrder
import org.apache.spark.sql.connector.expressions.aggregate.Aggregation

/**
@@ -25,4 +26,7 @@ import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
case class PushedDownOperators(
aggregation: Option[Aggregation],
sample: Option[TableSampleInfo],
limit: Option[Int])
limit: Option[Int],
sortValues: Seq[SortOrder]) {
assert((limit.isEmpty && sortValues.isEmpty) || limit.isDefined)
}
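The added assert encodes the invariant that sort values only appear together with a limit, since an ORDER BY alone is never pushed down; a small hedged illustration with placeholder values:

import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortValue}
import org.apache.spark.sql.execution.datasources.v2.PushedDownOperators

// Plain limit pushdown: allowed.
PushedDownOperators(None, None, Some(10), Seq.empty)
// Top N pushdown, i.e. limit plus sort values: allowed.
PushedDownOperators(None, None, Some(10),
  Seq(SortValue(FieldReference("price"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST)))
// Sort values without a limit would trip the assert:
// PushedDownOperators(None, None, None, Seq(...))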
@@ -23,8 +23,9 @@ import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeRefer
import org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.planning.ScanOperation
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LogicalPlan, Project, Sample}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LocalLimit, LogicalPlan, Project, Sample, Sort}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.expressions.SortOrder
import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
@@ -246,17 +247,39 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
}
}

private def pushDownLimit(plan: LogicalPlan, limit: Int): LogicalPlan = plan match {
case operation @ ScanOperation(_, filter, sHolder: ScanBuilderHolder) if filter.isEmpty =>
val limitPushed = PushDownUtils.pushLimit(sHolder.builder, limit)
if (limitPushed) {
sHolder.pushedLimit = Some(limit)
}
operation
case s @ Sort(order, _, operation @ ScanOperation(_, filter, sHolder: ScanBuilderHolder))
if filter.isEmpty =>
val orders = DataSourceStrategy.translateSortOrders(order)
if (orders.length == order.length) {
val topNPushed = PushDownUtils.pushTopN(sHolder.builder, orders.toArray, limit)
if (topNPushed) {
sHolder.pushedLimit = Some(limit)
sHolder.sortOrders = orders
operation
} else {
s
}
} else {
s
}
case p: Project =>
val newChild = pushDownLimit(p.child, limit)
p.withNewChildren(Seq(newChild))
case other => other
}

def applyLimit(plan: LogicalPlan): LogicalPlan = plan.transform {
case globalLimit @ Limit(IntegerLiteral(limitValue), child) =>
child match {
case ScanOperation(_, filter, sHolder: ScanBuilderHolder) if filter.length == 0 =>
val limitPushed = PushDownUtils.pushLimit(sHolder.builder, limitValue)
if (limitPushed) {
sHolder.pushedLimit = Some(limitValue)
}
globalLimit
case _ => globalLimit
}
val newChild = pushDownLimit(child, limitValue)
val newLocalLimit = globalLimit.child.asInstanceOf[LocalLimit].withNewChildren(Seq(newChild))
globalLimit.withNewChildren(Seq(newLocalLimit))
}

private def getWrappedScan(
@@ -270,8 +293,8 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
f.pushedFilters()
case _ => Array.empty[sources.Filter]
}
val pushedDownOperators =
PushedDownOperators(aggregation, sHolder.pushedSample, sHolder.pushedLimit)
val pushedDownOperators = PushedDownOperators(aggregation,
sHolder.pushedSample, sHolder.pushedLimit, sHolder.sortOrders)
V1ScanWrapper(v1, pushedFilters, pushedDownOperators)
case _ => scan
}
@@ -284,6 +307,8 @@ case class ScanBuilderHolder(
builder: ScanBuilder) extends LeafNode {
var pushedLimit: Option[Int] = None

var sortOrders: Seq[SortOrder] = Seq.empty[SortOrder]

var pushedSample: Option[TableSampleInfo] = None
}
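Schematically, the rewrite performed by applyLimit and pushDownLimit for a top N query looks roughly as follows (node names abbreviated; the Sort is removed only when the source accepts pushTopN):

Before pushdown:
  GlobalLimit 3
  +- LocalLimit 3
     +- Sort [salary ASC NULLS FIRST]
        +- ScanBuilderHolder (JDBC scan builder)

After pushdown, when pushTopN returns true:
  GlobalLimit 3
  +- LocalLimit 3
     +- ScanBuilderHolder (pushedLimit = Some(3), sortOrders = [salary ASC NULLS FIRST])

The Sort node disappears because the source returns rows already ordered, while the Limit operators stay on the Spark side; re-applying the limit there is harmless and still caps the result at N rows.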

@@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.v2.jdbc

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.connector.expressions.SortOrder
import org.apache.spark.sql.connector.read.V1Scan
import org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation
import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
@@ -31,7 +32,8 @@ case class JDBCScan(
pushedAggregateColumn: Array[String] = Array(),
groupByColumns: Option[Array[String]],
tableSample: Option[TableSampleInfo],
pushedLimit: Int) extends V1Scan {
pushedLimit: Int,
sortOrders: Array[SortOrder]) extends V1Scan {

override def readSchema(): StructType = prunedSchema

@@ -46,8 +48,8 @@
} else {
pushedAggregateColumn
}
relation.buildScan(
columnList, prunedSchema, pushedFilters, groupByColumns, tableSample, pushedLimit)
relation.buildScan(columnList, prunedSchema, pushedFilters, groupByColumns, tableSample,
pushedLimit, sortOrders)
}
}.asInstanceOf[T]
}