From 316610736f0993a1672bf6a4a16140e3d65db78b Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Thu, 16 Dec 2021 22:28:31 +0800 Subject: [PATCH] [SPARK-37483][SQL] Support push down top N to JDBC data source V2 ### What changes were proposed in this pull request? Currently, Spark supports pushing down LIMIT to the data source. However, in real user scenarios LIMIT almost always comes with ORDER BY, because the two are much more valuable together. On the other hand, pushing down top N (that is, ORDER BY ... LIMIT n) lets the data source return data that already has a basic ordering, so the sort in Spark may gain some performance improvement. ### Why are the changes needed? 1. Pushing down top N is very useful for user scenarios. 2. Pushing down top N can improve the performance of the sort. ### Does this PR introduce _any_ user-facing change? 'No'. It only changes the physical execution. ### How was this patch tested? New tests. Closes #34918 from beliefer/SPARK-37483. Authored-by: Jiaan Geng Signed-off-by: Wenchen Fan --- .../connector/read/SupportsPushDownLimit.java | 2 +- .../connector/read/SupportsPushDownTopN.java | 38 ++++++ .../sql/execution/DataSourceScanExec.scala | 12 +- .../datasources/DataSourceStrategy.scala | 26 +++- .../execution/datasources/jdbc/JDBCRDD.scala | 25 +++- .../datasources/jdbc/JDBCRelation.scala | 7 +- .../datasources/v2/PushDownUtils.scala | 15 ++- .../datasources/v2/PushedDownOperators.scala | 6 +- .../v2/V2ScanRelationPushDown.scala | 49 ++++++-- .../datasources/v2/jdbc/JDBCScan.scala | 8 +- .../datasources/v2/jdbc/JDBCScanBuilder.scala | 19 ++- .../apache/spark/sql/jdbc/DerbyDialect.scala | 6 +- .../apache/spark/sql/jdbc/JdbcDialects.scala | 5 - .../spark/sql/jdbc/MsSqlServerDialect.scala | 5 +- .../spark/sql/jdbc/TeradataDialect.scala | 5 +- .../apache/spark/sql/jdbc/JDBCV2Suite.scala | 116 ++++++++++++++---- 16 files changed, 276 insertions(+), 68 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java index 7e50bf14d7817..fa6447bc068d5 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java @@ -20,7 +20,7 @@ import org.apache.spark.annotation.Evolving; /** - * A mix-in interface for {@link Scan}. Data sources can implement this interface to + * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to * push down LIMIT. Please note that the combination of LIMIT with other operations * such as AGGREGATE, GROUP BY, SORT BY, CLUSTER BY, DISTRIBUTE BY, etc. is NOT pushed down. * diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java new file mode 100644 index 0000000000000..0212895fde079 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.read; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.SortOrder; + +/** + * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to + * push down top N(query with ORDER BY ... LIMIT n). Please note that the combination of top N + * with other operations such as AGGREGATE, GROUP BY, CLUSTER BY, DISTRIBUTE BY, etc. + * is NOT pushed down. + * + * @since 3.3.0 + */ +@Evolving +public interface SupportsPushDownTopN extends ScanBuilder { + + /** + * Pushes down top N to the data source. + */ + boolean pushTopN(SortOrder[] orders, int limit); +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 18ad5b81560e1..8bc18ef253f5c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -142,13 +142,23 @@ case class RowDataSourceScanExec( handledFilters } + val topNOrLimitInfo = + if (pushedDownOperators.limit.isDefined && pushedDownOperators.sortValues.nonEmpty) { + val pushedTopN = + s"ORDER BY ${seqToString(pushedDownOperators.sortValues.map(_.describe()))}" + + s" LIMIT ${pushedDownOperators.limit.get}" + Some("pushedTopN" -> pushedTopN) + } else { + pushedDownOperators.limit.map(value => "PushedLimit" -> s"LIMIT $value") + } + Map( "ReadSchema" -> requiredSchema.catalogString, "PushedFilters" -> seqToString(markedFilters.toSeq)) ++ pushedDownOperators.aggregation.fold(Map[String, String]()) { v => Map("PushedAggregates" -> seqToString(v.aggregateExpressions), "PushedGroupByColumns" -> seqToString(v.groupByColumns))} ++ - pushedDownOperators.limit.map(value => "PushedLimit" -> s"LIMIT $value") ++ + topNOrLimitInfo ++ pushedDownOperators.sample.map(v => "PushedSample" -> s"SAMPLE (${(v.upperBound - v.lowerBound) * 100}) ${v.withReplacement} SEED(${v.seed})" ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index e8fb9ca3e46c3..84df3f8dd5b65 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 import org.apache.spark.sql.connector.catalog.SupportsRead import org.apache.spark.sql.connector.catalog.TableCapability._ -import org.apache.spark.sql.connector.expressions.FieldReference +import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortOrder => SortOrderV2, SortValue} import 
org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, Max, Min, Sum} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan} @@ -336,7 +336,7 @@ object DataSourceStrategy l.output.toStructType, Set.empty, Set.empty, - PushedDownOperators(None, None, None), + PushedDownOperators(None, None, None, Seq.empty), toCatalystRDD(l, baseRelation.buildScan()), baseRelation, None) :: Nil @@ -410,7 +410,7 @@ object DataSourceStrategy requestedColumns.toStructType, pushedFilters.toSet, handledFilters, - PushedDownOperators(None, None, None), + PushedDownOperators(None, None, None, Seq.empty), scanBuilder(requestedColumns, candidatePredicates, pushedFilters), relation.relation, relation.catalogTable.map(_.identifier)) @@ -433,7 +433,7 @@ object DataSourceStrategy requestedColumns.toStructType, pushedFilters.toSet, handledFilters, - PushedDownOperators(None, None, None), + PushedDownOperators(None, None, None, Seq.empty), scanBuilder(requestedColumns, candidatePredicates, pushedFilters), relation.relation, relation.catalogTable.map(_.identifier)) @@ -723,6 +723,24 @@ object DataSourceStrategy } } + protected[sql] def translateSortOrders(sortOrders: Seq[SortOrder]): Seq[SortOrderV2] = { + def translateSortOrder(sortOrder: SortOrder): Option[SortOrderV2] = sortOrder match { + case SortOrder(PushableColumnWithoutNestedColumn(name), directionV1, nullOrderingV1, _) => + val directionV2 = directionV1 match { + case Ascending => SortDirection.ASCENDING + case Descending => SortDirection.DESCENDING + } + val nullOrderingV2 = nullOrderingV1 match { + case NullsFirst => NullOrdering.NULLS_FIRST + case NullsLast => NullOrdering.NULLS_LAST + } + Some(SortValue(FieldReference(name), directionV2, nullOrderingV2)) + case _ => None + } + + sortOrders.flatMap(translateSortOrder) + } + /** * Convert RDD of Row into RDD of InternalRow with objects in catalyst types */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 394ba3f8bb8c2..baee53847a5a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -25,6 +25,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.expressions.SortOrder import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects} import org.apache.spark.sql.sources._ @@ -151,12 +152,14 @@ object JDBCRDD extends Logging { * @param options - JDBC options that contains url, table and other information. * @param outputSchema - The schema of the columns or aggregate columns to SELECT. * @param groupByColumns - The pushed down group by columns. + * @param sample - The pushed down tableSample. * @param limit - The pushed down limit. If the value is 0, it means no limit or limit * is not pushed down. - * @param sample - The pushed down tableSample. + * @param sortOrders - The sort orders, which cooperate with limit to realize top N. * * @return An RDD representing "SELECT requiredColumns FROM fqTable". 
*/ + // scalastyle:off argcount def scanTable( sc: SparkContext, schema: StructType, @@ -167,7 +170,8 @@ object JDBCRDD extends Logging { outputSchema: Option[StructType] = None, groupByColumns: Option[Array[String]] = None, sample: Option[TableSampleInfo] = None, - limit: Int = 0): RDD[InternalRow] = { + limit: Int = 0, + sortOrders: Array[SortOrder] = Array.empty[SortOrder]): RDD[InternalRow] = { val url = options.url val dialect = JdbcDialects.get(url) val quotedColumns = if (groupByColumns.isEmpty) { @@ -187,8 +191,10 @@ object JDBCRDD extends Logging { options, groupByColumns, sample, - limit) + limit, + sortOrders) } + // scalastyle:on argcount } /** @@ -207,7 +213,8 @@ private[jdbc] class JDBCRDD( options: JDBCOptions, groupByColumns: Option[Array[String]], sample: Option[TableSampleInfo], - limit: Int) + limit: Int, + sortOrders: Array[SortOrder]) extends RDD[InternalRow](sc, Nil) { /** @@ -255,6 +262,14 @@ private[jdbc] class JDBCRDD( } } + private def getOrderByClause: String = { + if (sortOrders.nonEmpty) { + s" ORDER BY ${sortOrders.map(_.describe()).mkString(", ")}" + } else { + "" + } + } + /** * Runs the SQL query against the JDBC driver. * @@ -339,7 +354,7 @@ private[jdbc] class JDBCRDD( val myLimitClause: String = dialect.getLimitClause(limit) val sqlText = s"SELECT $columnList FROM ${options.tableOrQuery} $myTableSampleClause" + - s" $myWhereClause $getGroupByClause $myLimitClause" + s" $myWhereClause $getGroupByClause $getOrderByClause $myLimitClause" stmt = conn.prepareStatement(sqlText, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) stmt.setFetchSize(options.fetchSize) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index cd1eae89ee890..ecb207363cd59 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter} import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate, stringToTimestamp} +import org.apache.spark.sql.connector.expressions.SortOrder import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo import org.apache.spark.sql.internal.SQLConf @@ -301,7 +302,8 @@ private[sql] case class JDBCRelation( filters: Array[Filter], groupByColumns: Option[Array[String]], tableSample: Option[TableSampleInfo], - limit: Int): RDD[Row] = { + limit: Int, + sortOrders: Array[SortOrder]): RDD[Row] = { // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row] JDBCRDD.scanTable( sparkSession.sparkContext, @@ -313,7 +315,8 @@ private[sql] case class JDBCRelation( Some(finalSchema), groupByColumns, tableSample, - limit).asInstanceOf[RDD[Row]] + limit, + sortOrders).asInstanceOf[RDD[Row]] } override def insert(data: DataFrame, overwrite: Boolean): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index a98b8979d3e3f..2b26eee45221d 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -22,10 +22,10 @@ import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, NamedExpression, PredicateHelper, SchemaPruning} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.util.CharVarcharUtils -import org.apache.spark.sql.connector.expressions.FieldReference +import org.apache.spark.sql.connector.expressions.{FieldReference, SortOrder} import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter} -import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownV2Filters} +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters} import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PushableColumnWithoutNestedColumn} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources @@ -158,6 +158,17 @@ object PushDownUtils extends PredicateHelper { } } + /** + * Pushes down top N to the data source Scan + */ + def pushTopN(scanBuilder: ScanBuilder, order: Array[SortOrder], limit: Int): Boolean = { + scanBuilder match { + case s: SupportsPushDownTopN => + s.pushTopN(order, limit) + case _ => false + } + } + /** * Applies column pruning to the data source, w.r.t. the references of the given expressions. 
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala index c21354d646164..20ced9c17f7e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.datasources.v2 +import org.apache.spark.sql.connector.expressions.SortOrder import org.apache.spark.sql.connector.expressions.aggregate.Aggregation /** @@ -25,4 +26,7 @@ import org.apache.spark.sql.connector.expressions.aggregate.Aggregation case class PushedDownOperators( aggregation: Option[Aggregation], sample: Option[TableSampleInfo], - limit: Option[Int]) + limit: Option[Int], + sortValues: Seq[SortOrder]) { + assert((limit.isEmpty && sortValues.isEmpty) || limit.isDefined) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index f73f831903364..148864e8e4b3b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -23,8 +23,9 @@ import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeRefer import org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.planning.ScanOperation -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LogicalPlan, Project, Sample} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LocalLimit, LogicalPlan, Project, Sample, Sort} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.connector.expressions.SortOrder import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy @@ -246,17 +247,39 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { } } + private def pushDownLimit(plan: LogicalPlan, limit: Int): LogicalPlan = plan match { + case operation @ ScanOperation(_, filter, sHolder: ScanBuilderHolder) if filter.isEmpty => + val limitPushed = PushDownUtils.pushLimit(sHolder.builder, limit) + if (limitPushed) { + sHolder.pushedLimit = Some(limit) + } + operation + case s @ Sort(order, _, operation @ ScanOperation(_, filter, sHolder: ScanBuilderHolder)) + if filter.isEmpty => + val orders = DataSourceStrategy.translateSortOrders(order) + if (orders.length == order.length) { + val topNPushed = PushDownUtils.pushTopN(sHolder.builder, orders.toArray, limit) + if (topNPushed) { + sHolder.pushedLimit = Some(limit) + sHolder.sortOrders = orders + operation + } else { + s + } + } else { + s + } + case p: Project => + val newChild = pushDownLimit(p.child, limit) + p.withNewChildren(Seq(newChild)) + case other => other + } + def applyLimit(plan: LogicalPlan): LogicalPlan = plan.transform { case globalLimit @ Limit(IntegerLiteral(limitValue), child) => - child match { - case ScanOperation(_, 
filter, sHolder: ScanBuilderHolder) if filter.length == 0 => - val limitPushed = PushDownUtils.pushLimit(sHolder.builder, limitValue) - if (limitPushed) { - sHolder.pushedLimit = Some(limitValue) - } - globalLimit - case _ => globalLimit - } + val newChild = pushDownLimit(child, limitValue) + val newLocalLimit = globalLimit.child.asInstanceOf[LocalLimit].withNewChildren(Seq(newChild)) + globalLimit.withNewChildren(Seq(newLocalLimit)) } private def getWrappedScan( @@ -270,8 +293,8 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { f.pushedFilters() case _ => Array.empty[sources.Filter] } - val pushedDownOperators = - PushedDownOperators(aggregation, sHolder.pushedSample, sHolder.pushedLimit) + val pushedDownOperators = PushedDownOperators(aggregation, + sHolder.pushedSample, sHolder.pushedLimit, sHolder.sortOrders) V1ScanWrapper(v1, pushedFilters, pushedDownOperators) case _ => scan } @@ -284,6 +307,8 @@ case class ScanBuilderHolder( builder: ScanBuilder) extends LeafNode { var pushedLimit: Option[Int] = None + var sortOrders: Seq[SortOrder] = Seq.empty[SortOrder] + var pushedSample: Option[TableSampleInfo] = None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala index ff79d1a5c4144..87ec9f43804e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.v2.jdbc import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.connector.expressions.SortOrder import org.apache.spark.sql.connector.read.V1Scan import org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo @@ -31,7 +32,8 @@ case class JDBCScan( pushedAggregateColumn: Array[String] = Array(), groupByColumns: Option[Array[String]], tableSample: Option[TableSampleInfo], - pushedLimit: Int) extends V1Scan { + pushedLimit: Int, + sortOrders: Array[SortOrder]) extends V1Scan { override def readSchema(): StructType = prunedSchema @@ -46,8 +48,8 @@ case class JDBCScan( } else { pushedAggregateColumn } - relation.buildScan( - columnList, prunedSchema, pushedFilters, groupByColumns, tableSample, pushedLimit) + relation.buildScan(columnList, prunedSchema, pushedFilters, groupByColumns, tableSample, + pushedLimit, sortOrders) } }.asInstanceOf[T] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala index d3c141ed53c5c..1760122133d22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala @@ -20,8 +20,9 @@ import scala.util.control.NonFatal import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.expressions.SortOrder import org.apache.spark.sql.connector.expressions.aggregate.Aggregation -import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, 
SupportsPushDownTableSample} +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN} import org.apache.spark.sql.execution.datasources.PartitioningUtils import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation} import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo @@ -39,6 +40,7 @@ case class JDBCScanBuilder( with SupportsPushDownAggregates with SupportsPushDownLimit with SupportsPushDownTableSample + with SupportsPushDownTopN with Logging { private val isCaseSensitive = session.sessionState.conf.caseSensitiveAnalysis @@ -51,6 +53,8 @@ case class JDBCScanBuilder( private var pushedLimit = 0 + private var sortOrders: Array[SortOrder] = Array.empty[SortOrder] + override def pushFilters(filters: Array[Filter]): Array[Filter] = { if (jdbcOptions.pushDownPredicate) { val dialect = JdbcDialects.get(jdbcOptions.url) @@ -119,8 +123,17 @@ case class JDBCScanBuilder( } override def pushLimit(limit: Int): Boolean = { - if (jdbcOptions.pushDownLimit && JdbcDialects.get(jdbcOptions.url).supportsLimit) { + if (jdbcOptions.pushDownLimit) { + pushedLimit = limit + return true + } + false + } + + override def pushTopN(orders: Array[SortOrder], limit: Int): Boolean = { + if (jdbcOptions.pushDownLimit) { pushedLimit = limit + sortOrders = orders return true } false @@ -151,6 +164,6 @@ case class JDBCScanBuilder( // prunedSchema and quote them (will become "MAX(SALARY)", "MIN(BONUS)" and can't // be used in sql string. JDBCScan(JDBCRelation(schema, parts, jdbcOptions)(session), finalSchema, pushedFilter, - pushedAggregateList, pushedGroupByCols, tableSample, pushedLimit) + pushedAggregateList, pushedGroupByCols, tableSample, pushedLimit, sortOrders) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala index ecb514abac01c..f19ef7ead5f8e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala @@ -58,7 +58,7 @@ private object DerbyDialect extends JdbcDialect { throw QueryExecutionErrors.commentOnTableUnsupportedError() } - // ToDo: use fetch first n rows only for limit, e.g. 
- // select * from employee fetch first 10 rows only; - override def supportsLimit(): Boolean = false + override def getLimitClause(limit: Integer): String = { + "" + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 6d90432859d71..5a445c5d56bdf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -336,11 +336,6 @@ abstract class JdbcDialect extends Serializable with Logging{ if (limit > 0 ) s"LIMIT $limit" else "" } - /** - * returns whether the dialect supports limit or not - */ - def supportsLimit(): Boolean = true - def supportsTableSample: Boolean = false def getTableSample(sample: TableSampleInfo): String = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala index 8dad5ef8e1eae..8e5674a181e7a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala @@ -119,6 +119,7 @@ private object MsSqlServerDialect extends JdbcDialect { throw QueryExecutionErrors.commentOnTableUnsupportedError() } - // ToDo: use top n to get limit, e.g. select top 100 * from employee; - override def supportsLimit(): Boolean = false + override def getLimitClause(limit: Integer): String = { + "" + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala index 2a776bdb7ab04..13f4c5fe9c926 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala @@ -56,6 +56,7 @@ private case object TeradataDialect extends JdbcDialect { s"RENAME TABLE $oldTable TO $newTable" } - // ToDo: use top n to get limit, e.g. 
select top 100 * from employee; - override def supportsLimit(): Boolean = false + override def getLimitClause(limit: Integer): String = { + "" + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 39b3c19ac1db8..5f10f2ef105d0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -23,7 +23,8 @@ import java.util.Properties import org.apache.spark.SparkConf import org.apache.spark.sql.{DataFrame, ExplainSuiteHelper, QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException -import org.apache.spark.sql.catalyst.plans.logical.Filter +import org.apache.spark.sql.catalyst.plans.logical.{Filter, Sort} +import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortValue} import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper} import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog import org.apache.spark.sql.functions.{lit, sum, udf} @@ -102,14 +103,14 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel case s: DataSourceV2ScanRelation => s }.get assert(scan.schema.names.sameElements(Seq("NAME"))) - checkPushedLimit(df, true, 3) + checkPushedLimit(df, Some(3)) checkAnswer(df, Seq(Row("amy"), Row("alex"), Row("cathy"))) } test("simple scan with LIMIT") { val df1 = spark.read.table("h2.test.employee") .where($"dept" === 1).limit(1) - checkPushedLimit(df1, true, 1) + checkPushedLimit(df1, Some(1)) checkAnswer(df1, Seq(Row(1, "amy", 10000.00, 1000.0))) val df2 = spark.read @@ -120,7 +121,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .table("h2.test.employee") .filter($"dept" > 1) .limit(1) - checkPushedLimit(df2, true, 1) + checkPushedLimit(df2, Some(1)) checkAnswer(df2, Seq(Row(2, "alex", 12000.00, 1200.0))) val df3 = sql("SELECT name FROM h2.test.employee WHERE dept > 1 LIMIT 1") @@ -128,46 +129,117 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel case s: DataSourceV2ScanRelation => s }.get assert(scan.schema.names.sameElements(Seq("NAME"))) - checkPushedLimit(df3, true, 1) + checkPushedLimit(df3, Some(1)) checkAnswer(df3, Seq(Row("alex"))) val df4 = spark.read .table("h2.test.employee") .groupBy("DEPT").sum("SALARY") .limit(1) - checkPushedLimit(df4, false, 0) + checkPushedLimit(df4, None) checkAnswer(df4, Seq(Row(1, 19000.00))) - val df5 = spark.read - .table("h2.test.employee") - .sort("SALARY") - .limit(1) - checkPushedLimit(df5, false, 0) - checkAnswer(df5, Seq(Row(1, "cathy", 9000.00, 1200.0))) - val name = udf { (x: String) => x.matches("cat|dav|amy") } val sub = udf { (x: String) => x.substring(0, 3) } - val df6 = spark.read + val df5 = spark.read .table("h2.test.employee") .select($"SALARY", $"BONUS", sub($"NAME").as("shortName")) .filter(name($"shortName")) .limit(1) // LIMIT is pushed down only if all the filters are pushed down - checkPushedLimit(df6, false, 0) - checkAnswer(df6, Seq(Row(10000.00, 1000.0, "amy"))) + checkPushedLimit(df5, None) + checkAnswer(df5, Seq(Row(10000.00, 1000.0, "amy"))) } - private def checkPushedLimit(df: DataFrame, pushed: Boolean, limit: Int): Unit = { + private def checkPushedLimit(df: DataFrame, limit: Option[Int] = None, + sortValues: Seq[SortValue] = Nil): Unit = { df.queryExecution.optimizedPlan.collect { 
case relation: DataSourceV2ScanRelation => relation.scan match { case v1: V1ScanWrapper => - if (pushed) { - assert(v1.pushedDownOperators.limit === Some(limit)) - } else { - assert(v1.pushedDownOperators.limit.isEmpty) - } + assert(v1.pushedDownOperators.limit === limit) + assert(v1.pushedDownOperators.sortValues === sortValues) } } + if (sortValues.nonEmpty) { + val sorts = df.queryExecution.optimizedPlan.collect { + case s: Sort => s + } + assert(sorts.isEmpty) + } + } + + test("simple scan with top N") { + val df1 = spark.read + .table("h2.test.employee") + .sort("salary") + .limit(1) + checkPushedLimit(df1, Some(1), createSortValues()) + checkAnswer(df1, Seq(Row(1, "cathy", 9000.00, 1200.0))) + + val df2 = spark.read.table("h2.test.employee") + .where($"dept" === 1).orderBy($"salary").limit(1) + checkPushedLimit(df2, Some(1), createSortValues()) + checkAnswer(df2, Seq(Row(1, "cathy", 9000.00, 1200.0))) + + val df3 = spark.read + .option("partitionColumn", "dept") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "2") + .table("h2.test.employee") + .filter($"dept" > 1) + .orderBy($"salary".desc) + .limit(1) + checkPushedLimit( + df3, Some(1), createSortValues(SortDirection.DESCENDING, NullOrdering.NULLS_LAST)) + checkAnswer(df3, Seq(Row(2, "alex", 12000.00, 1200.0))) + + val df4 = + sql("SELECT name FROM h2.test.employee WHERE dept > 1 ORDER BY salary NULLS LAST LIMIT 1") + val scan = df4.queryExecution.optimizedPlan.collectFirst { + case s: DataSourceV2ScanRelation => s + }.get + assert(scan.schema.names.sameElements(Seq("NAME"))) + checkPushedLimit(df4, Some(1), createSortValues(nullOrdering = NullOrdering.NULLS_LAST)) + checkAnswer(df4, Seq(Row("david"))) + + val df5 = spark.read.table("h2.test.employee") + .where($"dept" === 1).orderBy($"salary") + checkPushedLimit(df5, None) + checkAnswer(df5, Seq(Row(1, "cathy", 9000.00, 1200.0), Row(1, "amy", 10000.00, 1000.0))) + + val df6 = spark.read + .table("h2.test.employee") + .groupBy("DEPT").sum("SALARY") + .orderBy("DEPT") + .limit(1) + checkPushedLimit(df6) + checkAnswer(df6, Seq(Row(1, 19000.00))) + + val name = udf { (x: String) => x.matches("cat|dav|amy") } + val sub = udf { (x: String) => x.substring(0, 3) } + val df7 = spark.read + .table("h2.test.employee") + .select($"SALARY", $"BONUS", sub($"NAME").as("shortName")) + .filter(name($"shortName")) + .sort($"SALARY".desc) + .limit(1) + // LIMIT is pushed down only if all the filters are pushed down + checkPushedLimit(df7) + checkAnswer(df7, Seq(Row(10000.00, 1000.0, "amy"))) + + val df8 = spark.read + .table("h2.test.employee") + .sort(sub($"NAME")) + .limit(1) + checkPushedLimit(df8) + checkAnswer(df8, Seq(Row(2, "alex", 12000.00, 1200.0))) + } + + private def createSortValues( + sortDirection: SortDirection = SortDirection.ASCENDING, + nullOrdering: NullOrdering = NullOrdering.NULLS_FIRST): Seq[SortValue] = { + Seq(SortValue(FieldReference("salary"), sortDirection, nullOrdering)) } test("scan with filter push-down") {