
Commit cdb3f71

beliefer authored and cloud-fan committed
[SPARK-37483][SQL] Support push down top N to JDBC data source V2
### What changes were proposed in this pull request?
Currently, Spark supports pushing down LIMIT to the data source. In practice, however, a LIMIT usually comes with an ORDER BY, because the two are much more valuable together. Pushing down top N (that is, ORDER BY ... LIMIT n) also hands Spark's sort data that already has a basic order, so the sort in Spark may see some performance improvement.

### Why are the changes needed?
1. Pushing down top N is very useful in users' scenarios.
2. Pushing down top N can improve the performance of sort.

### Does this PR introduce _any_ user-facing change?
'No'. It only changes the physical execution.

### How was this patch tested?
New tests.

Closes #34918 from beliefer/SPARK-37483.

Authored-by: Jiaan Geng <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 45a145e commit cdb3f71

16 files changed: +276 -68 lines changed
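For context, a hedged end-to-end sketch of the feature (not part of the diff): with a DSv2 JDBC catalog configured, an ORDER BY ... LIMIT pair over a JDBC table becomes eligible for top-N pushdown and shows up as a pushedTopN entry in the physical plan. The catalog name h2, the in-memory URL, and the test.employee table are illustrative, and an active SparkSession named spark is assumed.

// Route reads through the DataSource V2 JDBC path via a table catalog
// (URL and driver values are placeholders).
spark.conf.set("spark.sql.catalog.h2",
  "org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog")
spark.conf.set("spark.sql.catalog.h2.url", "jdbc:h2:mem:testdb")
spark.conf.set("spark.sql.catalog.h2.driver", "org.h2.Driver")

// ORDER BY ... LIMIT over the scan: the sort and the limit can now be
// pushed to the database as a single top N.
val df = spark.sql(
  "SELECT name, salary FROM h2.test.employee ORDER BY salary DESC NULLS LAST LIMIT 3")
df.explain()
// With the pushdown applied, the scan node is expected to report something like:
//   pushedTopN: ORDER BY SALARY DESC NULLS LAST LIMIT 3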

sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java
+1 -1

@@ -20,7 +20,7 @@
 import org.apache.spark.annotation.Evolving;
 
 /**
- * A mix-in interface for {@link Scan}. Data sources can implement this interface to
+ * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to
  * push down LIMIT. Please note that the combination of LIMIT with other operations
  * such as AGGREGATE, GROUP BY, SORT BY, CLUSTER BY, DISTRIBUTE BY, etc. is NOT pushed down.
  *
sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java (new file)
+38 -0

@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.read;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.expressions.SortOrder;
+
+/**
+ * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to
+ * push down top N(query with ORDER BY ... LIMIT n). Please note that the combination of top N
+ * with other operations such as AGGREGATE, GROUP BY, CLUSTER BY, DISTRIBUTE BY, etc.
+ * is NOT pushed down.
+ *
+ * @since 3.3.0
+ */
+@Evolving
+public interface SupportsPushDownTopN extends ScanBuilder {
+
+  /**
+   * Pushes down top N to the data source.
+   */
+  boolean pushTopN(SortOrder[] orders, int limit);
+}
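For orientation, a minimal connector-side sketch (not part of this commit) of a ScanBuilder picking up the new mix-in; MyTopNScanBuilder and its in-line Scan are hypothetical:

import org.apache.spark.sql.connector.expressions.SortOrder
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownTopN}
import org.apache.spark.sql.types.StructType

// Hypothetical builder: remembers the pushed top N and reports success, which
// lets Spark drop its own Sort and Limit over this scan.
class MyTopNScanBuilder(schema: StructType) extends SupportsPushDownTopN {
  private var topNOrders: Array[SortOrder] = Array.empty
  private var topNLimit: Option[Int] = None

  override def pushTopN(orders: Array[SortOrder], limit: Int): Boolean = {
    // A real source would first verify it can evaluate every sort key
    // (e.g. plain top-level columns) and return false otherwise.
    topNOrders = orders
    topNLimit = Some(limit)
    true
  }

  override def build(): Scan = new Scan {
    // A real Scan would feed topNOrders/topNLimit into its query generation.
    override def readSchema(): StructType = schema
  }
}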

sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+11 -1

@@ -142,13 +142,23 @@ case class RowDataSourceScanExec(
       handledFilters
     }
 
+    val topNOrLimitInfo =
+      if (pushedDownOperators.limit.isDefined && pushedDownOperators.sortValues.nonEmpty) {
+        val pushedTopN =
+          s"ORDER BY ${seqToString(pushedDownOperators.sortValues.map(_.describe()))}" +
+            s" LIMIT ${pushedDownOperators.limit.get}"
+        Some("pushedTopN" -> pushedTopN)
+      } else {
+        pushedDownOperators.limit.map(value => "PushedLimit" -> s"LIMIT $value")
+      }
+
     Map(
       "ReadSchema" -> requiredSchema.catalogString,
       "PushedFilters" -> seqToString(markedFilters.toSeq)) ++
       pushedDownOperators.aggregation.fold(Map[String, String]()) { v =>
         Map("PushedAggregates" -> seqToString(v.aggregateExpressions),
           "PushedGroupByColumns" -> seqToString(v.groupByColumns))} ++
-      pushedDownOperators.limit.map(value => "PushedLimit" -> s"LIMIT $value") ++
+      topNOrLimitInfo ++
      pushedDownOperators.sample.map(v => "PushedSample" ->
        s"SAMPLE (${(v.upperBound - v.lowerBound) * 100}) ${v.withReplacement} SEED(${v.seed})"
      )
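As a standalone illustration (not Spark code), the metadata choice above reduces to: report a single pushedTopN entry when both a limit and sort values were pushed, otherwise fall back to the old PushedLimit entry. Sort orders are simplified to pre-rendered strings here:

// Simplified stand-in for the topNOrLimitInfo decision in RowDataSourceScanExec.
def topNOrLimitInfo(limit: Option[Int], sortValues: Seq[String]): Option[(String, String)] =
  if (limit.isDefined && sortValues.nonEmpty) {
    // Sort and limit were pushed together: describe them as one top N.
    Some("pushedTopN" -> s"ORDER BY ${sortValues.mkString(", ")} LIMIT ${limit.get}")
  } else {
    // Only a bare limit (or nothing) was pushed.
    limit.map(value => "PushedLimit" -> s"LIMIT $value")
  }

topNOrLimitInfo(Some(3), Seq("SALARY DESC NULLS LAST"))
// Some((pushedTopN, ORDER BY SALARY DESC NULLS LAST LIMIT 3))
topNOrLimitInfo(Some(3), Nil)  // Some((PushedLimit, LIMIT 3))
topNOrLimitInfo(None, Nil)     // None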

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+22 -4

@@ -40,7 +40,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
 import org.apache.spark.sql.connector.catalog.SupportsRead
 import org.apache.spark.sql.connector.catalog.TableCapability._
-import org.apache.spark.sql.connector.expressions.FieldReference
+import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortOrder => SortOrderV2, SortValue}
 import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, Max, Min, Sum}
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan}

@@ -336,7 +336,7 @@ object DataSourceStrategy
         l.output.toStructType,
         Set.empty,
         Set.empty,
-        PushedDownOperators(None, None, None),
+        PushedDownOperators(None, None, None, Seq.empty),
         toCatalystRDD(l, baseRelation.buildScan()),
         baseRelation,
         None) :: Nil

@@ -410,7 +410,7 @@ object DataSourceStrategy
         requestedColumns.toStructType,
         pushedFilters.toSet,
         handledFilters,
-        PushedDownOperators(None, None, None),
+        PushedDownOperators(None, None, None, Seq.empty),
         scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
         relation.relation,
         relation.catalogTable.map(_.identifier))

@@ -433,7 +433,7 @@ object DataSourceStrategy
         requestedColumns.toStructType,
         pushedFilters.toSet,
         handledFilters,
-        PushedDownOperators(None, None, None),
+        PushedDownOperators(None, None, None, Seq.empty),
         scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
         relation.relation,
         relation.catalogTable.map(_.identifier))

@@ -726,6 +726,24 @@ object DataSourceStrategy
       }
     }
 
+  protected[sql] def translateSortOrders(sortOrders: Seq[SortOrder]): Seq[SortOrderV2] = {
+    def translateOortOrder(sortOrder: SortOrder): Option[SortOrderV2] = sortOrder match {
+      case SortOrder(PushableColumnWithoutNestedColumn(name), directionV1, nullOrderingV1, _) =>
+        val directionV2 = directionV1 match {
+          case Ascending => SortDirection.ASCENDING
+          case Descending => SortDirection.DESCENDING
+        }
+        val nullOrderingV2 = nullOrderingV1 match {
+          case NullsFirst => NullOrdering.NULLS_FIRST
+          case NullsLast => NullOrdering.NULLS_LAST
+        }
+        Some(SortValue(FieldReference(name), directionV2, nullOrderingV2))
+      case _ => None
+    }
+
+    sortOrders.flatMap(translateOortOrder)
+  }
+
   /**
   * Convert RDD of Row into RDD of InternalRow with objects in catalyst types
   */
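To see what the new translation produces, a small sketch (not part of the diff) that builds the equivalent V2 sort orders directly through the public connector expression API; the column names are illustrative:

import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortValue}

// V2 counterparts of the catalyst orders `name ASC NULLS FIRST, salary DESC NULLS LAST`,
// i.e. what translateSortOrders is expected to emit for pushable columns.
val translated = Seq(
  SortValue(FieldReference("name"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST),
  SortValue(FieldReference("salary"), SortDirection.DESCENDING, NullOrdering.NULLS_LAST))

translated.foreach(order => println(order.describe()))
// prints roughly: name ASC NULLS FIRST
//                 salary DESC NULLS LAST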

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+20 -5

@@ -25,6 +25,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.expressions.SortOrder
 import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
 import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
 import org.apache.spark.sql.sources._

@@ -151,12 +152,14 @@ object JDBCRDD extends Logging {
   * @param options - JDBC options that contains url, table and other information.
   * @param outputSchema - The schema of the columns or aggregate columns to SELECT.
   * @param groupByColumns - The pushed down group by columns.
+   * @param sample - The pushed down tableSample.
    * @param limit - The pushed down limit. If the value is 0, it means no limit or limit
    *                is not pushed down.
-   * @param sample - The pushed down tableSample.
+   * @param sortValues - The sort values cooperates with limit to realize top N.
    *
    * @return An RDD representing "SELECT requiredColumns FROM fqTable".
    */
+  // scalastyle:off argcount
   def scanTable(
       sc: SparkContext,
       schema: StructType,

@@ -167,7 +170,8 @@
       outputSchema: Option[StructType] = None,
       groupByColumns: Option[Array[String]] = None,
       sample: Option[TableSampleInfo] = None,
-      limit: Int = 0): RDD[InternalRow] = {
+      limit: Int = 0,
+      sortOrders: Array[SortOrder] = Array.empty[SortOrder]): RDD[InternalRow] = {
     val url = options.url
     val dialect = JdbcDialects.get(url)
     val quotedColumns = if (groupByColumns.isEmpty) {

@@ -187,8 +191,10 @@
       options,
       groupByColumns,
       sample,
-      limit)
+      limit,
+      sortOrders)
   }
+  // scalastyle:on argcount
 }
 
 /**

@@ -207,7 +213,8 @@ private[jdbc] class JDBCRDD(
     options: JDBCOptions,
     groupByColumns: Option[Array[String]],
     sample: Option[TableSampleInfo],
-    limit: Int)
+    limit: Int,
+    sortOrders: Array[SortOrder])
   extends RDD[InternalRow](sc, Nil) {
 
   /**

@@ -255,6 +262,14 @@
     }
   }
 
+  private def getOrderByClause: String = {
+    if (sortOrders.nonEmpty) {
+      s" ORDER BY ${sortOrders.map(_.describe()).mkString(", ")}"
+    } else {
+      ""
+    }
+  }
+
   /**
   * Runs the SQL query against the JDBC driver.
   *

@@ -339,7 +354,7 @@
     val myLimitClause: String = dialect.getLimitClause(limit)
 
     val sqlText = s"SELECT $columnList FROM ${options.tableOrQuery} $myTableSampleClause" +
-      s" $myWhereClause $getGroupByClause $myLimitClause"
+      s" $myWhereClause $getGroupByClause $getOrderByClause $myLimitClause"
     stmt = conn.prepareStatement(sqlText,
       ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
     stmt.setFetchSize(options.fetchSize)
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
+5 -2

@@ -27,6 +27,7 @@ import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext}
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate, stringToTimestamp}
+import org.apache.spark.sql.connector.expressions.SortOrder
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
 import org.apache.spark.sql.internal.SQLConf

@@ -301,7 +302,8 @@ private[sql] case class JDBCRelation(
       filters: Array[Filter],
       groupByColumns: Option[Array[String]],
       tableSample: Option[TableSampleInfo],
-      limit: Int): RDD[Row] = {
+      limit: Int,
+      sortOrders: Array[SortOrder]): RDD[Row] = {
     // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
     JDBCRDD.scanTable(
       sparkSession.sparkContext,

@@ -313,7 +315,8 @@
       Some(finalSchema),
       groupByColumns,
       tableSample,
-      limit).asInstanceOf[RDD[Row]]
+      limit,
+      sortOrders).asInstanceOf[RDD[Row]]
   }
 
   override def insert(data: DataFrame, overwrite: Boolean): Unit = {

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
+13 -2

@@ -22,10 +22,10 @@ import scala.collection.mutable
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, NamedExpression, PredicateHelper, SchemaPruning}
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.util.CharVarcharUtils
-import org.apache.spark.sql.connector.expressions.FieldReference
+import org.apache.spark.sql.connector.expressions.{FieldReference, SortOrder}
 import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
 import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter}
-import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownV2Filters}
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters}
 import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PushableColumnWithoutNestedColumn}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources

@@ -157,6 +157,17 @@ object PushDownUtils extends PredicateHelper {
     }
   }
 
+  /**
+   * Pushes down top N to the data source Scan
+   */
+  def pushTopN(scanBuilder: ScanBuilder, order: Array[SortOrder], limit: Int): Boolean = {
+    scanBuilder match {
+      case s: SupportsPushDownTopN =>
+        s.pushTopN(order, limit)
+      case _ => false
+    }
+  }
+
   /**
   * Applies column pruning to the data source, w.r.t. the references of the given expressions.
   *
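A self-contained sketch (not part of the diff) of the dispatch this helper performs; tryPushTopN mirrors pushTopN above, and the two anonymous builders are throwaway stand-ins for real connector ScanBuilders:

import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortOrder, SortValue}
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownTopN}
import org.apache.spark.sql.types.StructType

// Mirrors PushDownUtils.pushTopN: only builders that opt in get the top N.
def tryPushTopN(builder: ScanBuilder, orders: Array[SortOrder], limit: Int): Boolean =
  builder match {
    case s: SupportsPushDownTopN => s.pushTopN(orders, limit)
    case _ => false // the source cannot take it; Spark keeps its Sort and Limit
  }

val emptyScan = new Scan { override def readSchema(): StructType = new StructType() }
val topNBuilder = new SupportsPushDownTopN {
  override def pushTopN(orders: Array[SortOrder], limit: Int): Boolean = true
  override def build(): Scan = emptyScan
}
val plainBuilder = new ScanBuilder { override def build(): Scan = emptyScan }

val orders: Array[SortOrder] =
  Array(SortValue(FieldReference("salary"), SortDirection.DESCENDING, NullOrdering.NULLS_LAST))
tryPushTopN(topNBuilder, orders, 3)  // true
tryPushTopN(plainBuilder, orders, 3) // false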

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala
+5 -1

@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.execution.datasources.v2
 
+import org.apache.spark.sql.connector.expressions.SortOrder
 import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
 
 /**

@@ -25,4 +26,7 @@ import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
 case class PushedDownOperators(
     aggregation: Option[Aggregation],
     sample: Option[TableSampleInfo],
-    limit: Option[Int])
+    limit: Option[Int],
+    sortValues: Seq[SortOrder]) {
+  assert((limit.isEmpty && sortValues.isEmpty) || limit.isDefined)
+}
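The new assert encodes the rule that pushed sort values only make sense together with a pushed limit. A quick sketch, using a simplified stand-in (sort values as plain strings), of which combinations it accepts:

// Simplified stand-in for PushedDownOperators, keeping only the two relevant fields.
case class TopNState(limit: Option[Int], sortValues: Seq[String]) {
  assert((limit.isEmpty && sortValues.isEmpty) || limit.isDefined)
}

TopNState(Some(3), Seq("SALARY DESC NULLS LAST")) // top N: limit plus sort orders
TopNState(Some(3), Nil)                           // plain pushed limit
TopNState(None, Nil)                              // nothing pushed
// TopNState(None, Seq("SALARY DESC NULLS LAST")) // would trip the assert: sort without limit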

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+37 -12

@@ -23,8 +23,9 @@ import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeRefer
 import org.apache.spark.sql.catalyst.expressions.aggregate
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.planning.ScanOperation
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LogicalPlan, Project, Sample}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LocalLimit, LogicalPlan, Project, Sample, Sort}
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.connector.expressions.SortOrder
 import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
 import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan}
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy

@@ -246,17 +247,39 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
     }
   }
 
+  private def pushDownLimit(plan: LogicalPlan, limit: Int): LogicalPlan = plan match {
+    case operation @ ScanOperation(_, filter, sHolder: ScanBuilderHolder) if filter.isEmpty =>
+      val limitPushed = PushDownUtils.pushLimit(sHolder.builder, limit)
+      if (limitPushed) {
+        sHolder.pushedLimit = Some(limit)
+      }
+      operation
+    case s @ Sort(order, _, operation @ ScanOperation(_, filter, sHolder: ScanBuilderHolder))
+        if filter.isEmpty =>
+      val orders = DataSourceStrategy.translateSortOrders(order)
+      if (orders.length == order.length) {
+        val topNPushed = PushDownUtils.pushTopN(sHolder.builder, orders.toArray, limit)
+        if (topNPushed) {
+          sHolder.pushedLimit = Some(limit)
+          sHolder.sortOrders = orders
+          operation
+        } else {
+          s
+        }
+      } else {
+        s
+      }
+    case p: Project =>
+      val newChild = pushDownLimit(p.child, limit)
+      p.withNewChildren(Seq(newChild))
+    case other => other
+  }
+
   def applyLimit(plan: LogicalPlan): LogicalPlan = plan.transform {
     case globalLimit @ Limit(IntegerLiteral(limitValue), child) =>
-      child match {
-        case ScanOperation(_, filter, sHolder: ScanBuilderHolder) if filter.length == 0 =>
-          val limitPushed = PushDownUtils.pushLimit(sHolder.builder, limitValue)
-          if (limitPushed) {
-            sHolder.pushedLimit = Some(limitValue)
-          }
-          globalLimit
-        case _ => globalLimit
-      }
+      val newChild = pushDownLimit(child, limitValue)
+      val newLocalLimit = globalLimit.child.asInstanceOf[LocalLimit].withNewChildren(Seq(newChild))
+      globalLimit.withNewChildren(Seq(newLocalLimit))
   }
 
   private def getWrappedScan(

@@ -270,8 +293,8 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
           f.pushedFilters()
         case _ => Array.empty[sources.Filter]
       }
-      val pushedDownOperators =
-        PushedDownOperators(aggregation, sHolder.pushedSample, sHolder.pushedLimit)
+      val pushedDownOperators = PushedDownOperators(aggregation,
+        sHolder.pushedSample, sHolder.pushedLimit, sHolder.sortOrders)
       V1ScanWrapper(v1, pushedFilters, pushedDownOperators)
     case _ => scan
   }

@@ -284,6 +307,8 @@ case class ScanBuilderHolder(
     builder: ScanBuilder) extends LeafNode {
   var pushedLimit: Option[Int] = None
 
+  var sortOrders: Seq[SortOrder] = Seq.empty[SortOrder]
+
   var pushedSample: Option[TableSampleInfo] = None
 }
 
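A toy sketch (not catalyst code) of the traversal pushDownLimit introduces: the rule looks through Project nodes under the Limit, folds a Sort whose orders all translate into the scan as top N (dropping the Sort), and otherwise pushes only the bare limit; anything else blocks the push. The plan nodes below are simplified stand-ins:

sealed trait Plan
case class ScanHolder(var pushedLimit: Option[Int] = None,
    var sortOrders: Seq[String] = Nil) extends Plan
case class Sort(orders: Seq[String], child: Plan) extends Plan
case class Project(child: Plan) extends Plan

def pushDownLimit(plan: Plan, limit: Int): Plan = plan match {
  case holder: ScanHolder =>               // bare LIMIT over the scan
    holder.pushedLimit = Some(limit)
    holder
  case Sort(orders, holder: ScanHolder) => // ORDER BY ... LIMIT over the scan
    holder.pushedLimit = Some(limit)
    holder.sortOrders = orders
    holder                                 // the Sort is removed from the plan
  case p: Project =>                       // Project is transparent to the push
    p.copy(child = pushDownLimit(p.child, limit))
  case other => other                      // anything else blocks the push
}

pushDownLimit(Project(Sort(Seq("salary DESC"), ScanHolder())), 3)
// Project(ScanHolder(Some(3), List(salary DESC))) -- top N pushed, Sort removed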

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala
+5 -3

@@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.v2.jdbc
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.connector.expressions.SortOrder
 import org.apache.spark.sql.connector.read.V1Scan
 import org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation
 import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo

@@ -31,7 +32,8 @@ case class JDBCScan(
     pushedAggregateColumn: Array[String] = Array(),
     groupByColumns: Option[Array[String]],
     tableSample: Option[TableSampleInfo],
-    pushedLimit: Int) extends V1Scan {
+    pushedLimit: Int,
+    sortOrders: Array[SortOrder]) extends V1Scan {
 
   override def readSchema(): StructType = prunedSchema
 

@@ -46,8 +48,8 @@
       } else {
         pushedAggregateColumn
       }
-      relation.buildScan(
-        columnList, prunedSchema, pushedFilters, groupByColumns, tableSample, pushedLimit)
+      relation.buildScan(columnList, prunedSchema, pushedFilters, groupByColumns, tableSample,
+        pushedLimit, sortOrders)
     }
   }.asInstanceOf[T]
 }
