diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java
index 7e50bf14d7817..fa6447bc068d5 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java
@@ -20,7 +20,7 @@
 import org.apache.spark.annotation.Evolving;
 
 /**
- * A mix-in interface for {@link Scan}. Data sources can implement this interface to
+ * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to
  * push down LIMIT. Please note that the combination of LIMIT with other operations
  * such as AGGREGATE, GROUP BY, SORT BY, CLUSTER BY, DISTRIBUTE BY, etc. is NOT pushed down.
  *
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java
new file mode 100644
index 0000000000000..0212895fde079
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.read;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.expressions.SortOrder;
+
+/**
+ * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to
+ * push down top N (query with ORDER BY ... LIMIT n). Please note that the combination of top N
+ * with other operations such as AGGREGATE, GROUP BY, CLUSTER BY, DISTRIBUTE BY, etc.
+ * is NOT pushed down.
+ *
+ * @since 3.3.0
+ */
+@Evolving
+public interface SupportsPushDownTopN extends ScanBuilder {
+
+  /**
+   * Pushes down top N to the data source.
+   */
+  boolean pushTopN(SortOrder[] orders, int limit);
+}
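For connector authors, the new mix-in is adopted on the ScanBuilder side. Below is a minimal sketch, not part of the patch, of a hypothetical builder that records the pushed ordering and limit; MyScanBuilder and the anonymous Scan are illustrative names only.

import org.apache.spark.sql.connector.expressions.SortOrder
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownTopN}
import org.apache.spark.sql.types.StructType

class MyScanBuilder(schema: StructType) extends ScanBuilder with SupportsPushDownTopN {
  // The ORDER BY keys and LIMIT value Spark has pushed, if any.
  private var topN: Option[(Array[SortOrder], Int)] = None

  override def pushTopN(orders: Array[SortOrder], limit: Int): Boolean = {
    // Returning true promises that the source emits at most `limit` rows in this order.
    topN = Some((orders, limit))
    true
  }

  override def build(): Scan = new Scan {
    override def readSchema(): StructType = schema
    // A real connector would translate `topN` into its native query here.
  }
}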
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index 5ad1dc88d3776..9304566801278 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -142,13 +142,23 @@ case class RowDataSourceScanExec(
       handledFilters
     }
 
+    val topNOrLimitInfo =
+      if (pushedDownOperators.limit.isDefined && pushedDownOperators.sortValues.nonEmpty) {
+        val pushedTopN =
+          s"ORDER BY ${seqToString(pushedDownOperators.sortValues.map(_.describe()))}" +
+            s" LIMIT ${pushedDownOperators.limit.get}"
+        Some("PushedTopN" -> pushedTopN)
+      } else {
+        pushedDownOperators.limit.map(value => "PushedLimit" -> s"LIMIT $value")
+      }
+
     Map(
       "ReadSchema" -> requiredSchema.catalogString,
       "PushedFilters" -> seqToString(markedFilters.toSeq)) ++
       pushedDownOperators.aggregation.fold(Map[String, String]()) { v =>
         Map("PushedAggregates" -> seqToString(v.aggregateExpressions),
           "PushedGroupByColumns" -> seqToString(v.groupByColumns))} ++
-      pushedDownOperators.limit.map(value => "PushedLimit" -> s"LIMIT $value") ++
+      topNOrLimitInfo ++
       pushedDownOperators.sample.map(v => "PushedSample" ->
         s"SAMPLE (${(v.upperBound - v.lowerBound) * 100}) ${v.withReplacement} SEED(${v.seed})"
       )
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 6febbd590f246..c296ba9f29dd3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -40,7 +40,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
 import org.apache.spark.sql.connector.catalog.SupportsRead
 import org.apache.spark.sql.connector.catalog.TableCapability._
-import org.apache.spark.sql.connector.expressions.FieldReference
+import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortOrder => SortOrderV2, SortValue}
 import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, Max, Min, Sum}
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan}
@@ -336,7 +336,7 @@ object DataSourceStrategy
         l.output.toStructType,
         Set.empty,
         Set.empty,
-        PushedDownOperators(None, None, None),
+        PushedDownOperators(None, None, None, Seq.empty),
         toCatalystRDD(l, baseRelation.buildScan()),
         baseRelation,
         None) :: Nil
@@ -410,7 +410,7 @@ object DataSourceStrategy
         requestedColumns.toStructType,
         pushedFilters.toSet,
         handledFilters,
-        PushedDownOperators(None, None, None),
+        PushedDownOperators(None, None, None, Seq.empty),
         scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
         relation.relation,
         relation.catalogTable.map(_.identifier))
@@ -433,7 +433,7 @@ object DataSourceStrategy
         requestedColumns.toStructType,
         pushedFilters.toSet,
         handledFilters,
-        PushedDownOperators(None, None, None),
+        PushedDownOperators(None, None, None, Seq.empty),
         scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
         relation.relation,
         relation.catalogTable.map(_.identifier))
@@ -726,6 +726,24 @@ object DataSourceStrategy
     }
   }
 
+  protected[sql] def translateSortOrders(sortOrders: Seq[SortOrder]): Seq[SortOrderV2] = {
+    def translateSortOrder(sortOrder: SortOrder): Option[SortOrderV2] = sortOrder match {
+      case SortOrder(PushableColumnWithoutNestedColumn(name), directionV1, nullOrderingV1, _) =>
+        val directionV2 = directionV1 match {
+          case Ascending => SortDirection.ASCENDING
+          case Descending => SortDirection.DESCENDING
+        }
+        val nullOrderingV2 = nullOrderingV1 match {
+          case NullsFirst => NullOrdering.NULLS_FIRST
+          case NullsLast => NullOrdering.NULLS_LAST
+        }
+        Some(SortValue(FieldReference(name), directionV2, nullOrderingV2))
+      case _ => None
+    }
+
+    sortOrders.flatMap(translateSortOrder)
+  }
+
   /**
    * Convert RDD of Row into RDD of InternalRow with objects in catalyst types
   */
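To illustrate the Catalyst-to-V2 translation above: a sort key is only translatable when it resolves to a plain top-level column, and the result is a SortValue built from the column's FieldReference plus the matching direction and null-ordering enums. The value below is what the translation is expected to produce for a descending, nulls-last sort on a salary column; it is illustrative only and, like the test suite in this patch, assumes code living under the org.apache.spark.sql namespace since these helper classes may not be public API.

import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortValue}

// Expected V2 form of a Catalyst sort such as $"salary".desc_nulls_last, assuming the
// sort key is a plain (non-nested) column named "salary".
val v2Order = SortValue(
  FieldReference("salary"), SortDirection.DESCENDING, NullOrdering.NULLS_LAST)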
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 394ba3f8bb8c2..baee53847a5a4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -25,6 +25,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.expressions.SortOrder
 import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
 import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
 import org.apache.spark.sql.sources._
@@ -151,12 +152,14 @@ object JDBCRDD extends Logging {
    * @param options - JDBC options that contains url, table and other information.
    * @param outputSchema - The schema of the columns or aggregate columns to SELECT.
    * @param groupByColumns - The pushed down group by columns.
+   * @param sample - The pushed down tableSample.
    * @param limit - The pushed down limit. If the value is 0, it means no limit or limit
    *                is not pushed down.
-   * @param sample - The pushed down tableSample.
+   * @param sortOrders - The sort orders cooperating with limit to realize top N.
    *
    * @return An RDD representing "SELECT requiredColumns FROM fqTable".
    */
+  // scalastyle:off argcount
   def scanTable(
       sc: SparkContext,
       schema: StructType,
@@ -167,7 +170,8 @@
       outputSchema: Option[StructType] = None,
       groupByColumns: Option[Array[String]] = None,
       sample: Option[TableSampleInfo] = None,
-      limit: Int = 0): RDD[InternalRow] = {
+      limit: Int = 0,
+      sortOrders: Array[SortOrder] = Array.empty[SortOrder]): RDD[InternalRow] = {
     val url = options.url
     val dialect = JdbcDialects.get(url)
     val quotedColumns = if (groupByColumns.isEmpty) {
@@ -187,8 +191,10 @@
       options,
       groupByColumns,
       sample,
-      limit)
+      limit,
+      sortOrders)
   }
+  // scalastyle:on argcount
 }
 
 /**
@@ -207,7 +213,8 @@
     options: JDBCOptions,
     groupByColumns: Option[Array[String]],
     sample: Option[TableSampleInfo],
-    limit: Int)
+    limit: Int,
+    sortOrders: Array[SortOrder])
   extends RDD[InternalRow](sc, Nil) {
 
   /**
@@ -255,6 +262,14 @@
     }
   }
 
+  private def getOrderByClause: String = {
+    if (sortOrders.nonEmpty) {
+      s" ORDER BY ${sortOrders.map(_.describe()).mkString(", ")}"
+    } else {
+      ""
+    }
+  }
+
   /**
    * Runs the SQL query against the JDBC driver.
    *
@@ -339,7 +354,7 @@
     val myLimitClause: String = dialect.getLimitClause(limit)
 
     val sqlText = s"SELECT $columnList FROM ${options.tableOrQuery} $myTableSampleClause" +
-      s" $myWhereClause $getGroupByClause $myLimitClause"
+      s" $myWhereClause $getGroupByClause $getOrderByClause $myLimitClause"
     stmt = conn.prepareStatement(sqlText,
         ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
     stmt.setFetchSize(options.fetchSize)
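For a sense of the statement the new clause produces, here is a hedged illustration of how the clauses compose once ORDER BY is spliced in between GROUP BY and LIMIT; the identifiers and quoting below are made up for the example and the real rendering depends on the dialect and options in use.

// Illustrative only: roughly the statement shape for a pushed top N.
val columnList = "NAME, SALARY"
val whereClause = "WHERE DEPT > 1"
val orderByClause = " ORDER BY SALARY DESC NULLS LAST" // built by getOrderByClause
val limitClause = "LIMIT 1"                            // built by dialect.getLimitClause
val sqlText = s"SELECT $columnList FROM test.employee  $whereClause $orderByClause $limitClause"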
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
index cd1eae89ee890..ecb207363cd59 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext}
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate, stringToTimestamp}
+import org.apache.spark.sql.connector.expressions.SortOrder
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
 import org.apache.spark.sql.internal.SQLConf
@@ -301,7 +302,8 @@
       filters: Array[Filter],
       groupByColumns: Option[Array[String]],
       tableSample: Option[TableSampleInfo],
-      limit: Int): RDD[Row] = {
+      limit: Int,
+      sortOrders: Array[SortOrder]): RDD[Row] = {
     // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
     JDBCRDD.scanTable(
       sparkSession.sparkContext,
@@ -313,7 +315,8 @@
       Some(finalSchema),
       groupByColumns,
       tableSample,
-      limit).asInstanceOf[RDD[Row]]
+      limit,
+      sortOrders).asInstanceOf[RDD[Row]]
   }
 
   override def insert(data: DataFrame, overwrite: Boolean): Unit = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
index d51b4752175a2..39374b6924820 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
@@ -22,10 +22,10 @@ import scala.collection.mutable
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, NamedExpression, PredicateHelper, SchemaPruning}
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.util.CharVarcharUtils
-import org.apache.spark.sql.connector.expressions.FieldReference
+import org.apache.spark.sql.connector.expressions.{FieldReference, SortOrder}
 import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
 import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter}
-import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownV2Filters}
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters}
 import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PushableColumnWithoutNestedColumn}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources
@@ -157,6 +157,17 @@
     }
   }
 
+  /**
+   * Pushes down top N to the data source Scan
+   */
+  def pushTopN(scanBuilder: ScanBuilder, order: Array[SortOrder], limit: Int): Boolean = {
+    scanBuilder match {
+      case s: SupportsPushDownTopN =>
+        s.pushTopN(order, limit)
+      case _ => false
+    }
+  }
+
   /**
    * Applies column pruning to the data source, w.r.t. the references of the given expressions.
   *
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala
index c21354d646164..20ced9c17f7e0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.execution.datasources.v2
 
+import org.apache.spark.sql.connector.expressions.SortOrder
 import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
 
 /**
@@ -25,4 +26,7 @@ import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
 case class PushedDownOperators(
     aggregation: Option[Aggregation],
     sample: Option[TableSampleInfo],
-    limit: Option[Int])
+    limit: Option[Int],
+    sortValues: Seq[SortOrder]) {
+  assert((limit.isEmpty && sortValues.isEmpty) || limit.isDefined)
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index 36923d1c0bd64..e7c06d0b7520e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -23,8 +23,9 @@ import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeRefer
 import org.apache.spark.sql.catalyst.expressions.aggregate
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.planning.ScanOperation
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LogicalPlan, Project, Sample}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LocalLimit, LogicalPlan, Project, Sample, Sort}
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.connector.expressions.SortOrder
 import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
 import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan}
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
@@ -246,17 +247,39 @@
     }
   }
 
+  private def pushDownLimit(plan: LogicalPlan, limit: Int): LogicalPlan = plan match {
+    case operation @ ScanOperation(_, filter, sHolder: ScanBuilderHolder) if filter.isEmpty =>
+      val limitPushed = PushDownUtils.pushLimit(sHolder.builder, limit)
+      if (limitPushed) {
+        sHolder.pushedLimit = Some(limit)
+      }
+      operation
+    case s @ Sort(order, _, operation @ ScanOperation(_, filter, sHolder: ScanBuilderHolder))
+        if filter.isEmpty =>
+      val orders = DataSourceStrategy.translateSortOrders(order)
+      if (orders.length == order.length) {
+        val topNPushed = PushDownUtils.pushTopN(sHolder.builder, orders.toArray, limit)
+        if (topNPushed) {
+          sHolder.pushedLimit = Some(limit)
+          sHolder.sortOrders = orders
+          operation
+        } else {
+          s
+        }
+      } else {
+        s
+      }
+    case p: Project =>
+      val newChild = pushDownLimit(p.child, limit)
+      p.withNewChildren(Seq(newChild))
+    case other => other
+  }
+
   def applyLimit(plan: LogicalPlan): LogicalPlan = plan.transform {
     case globalLimit @ Limit(IntegerLiteral(limitValue), child) =>
-      child match {
-        case ScanOperation(_, filter, sHolder: ScanBuilderHolder) if filter.length == 0 =>
-          val limitPushed = PushDownUtils.pushLimit(sHolder.builder, limitValue)
-          if (limitPushed) {
-            sHolder.pushedLimit = Some(limitValue)
-          }
-          globalLimit
-        case _ => globalLimit
-      }
+      val newChild = pushDownLimit(child, limitValue)
+      val newLocalLimit = globalLimit.child.asInstanceOf[LocalLimit].withNewChildren(Seq(newChild))
+      globalLimit.withNewChildren(Seq(newLocalLimit))
   }
 
   private def getWrappedScan(
@@ -270,8 +293,8 @@
           f.pushedFilters()
         case _ => Array.empty[sources.Filter]
       }
-      val pushedDownOperators =
-        PushedDownOperators(aggregation, sHolder.pushedSample, sHolder.pushedLimit)
+      val pushedDownOperators = PushedDownOperators(aggregation,
+        sHolder.pushedSample, sHolder.pushedLimit, sHolder.sortOrders)
       V1ScanWrapper(v1, pushedFilters, pushedDownOperators)
     case _ => scan
   }
@@ -284,6 +307,8 @@ case class ScanBuilderHolder(
     builder: ScanBuilder) extends LeafNode {
   var pushedLimit: Option[Int] = None
 
+  var sortOrders: Seq[SortOrder] = Seq.empty[SortOrder]
+
   var pushedSample: Option[TableSampleInfo] = None
 }
 
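As a usage-level illustration of what the rewritten rule matches — a sketch only, assuming a SparkSession named spark with spark.implicits._ in scope and the H2 test catalog from the suite registered:

import org.apache.spark.sql.functions.upper

// Qualifies: the filter is fully pushed, the Sort sits directly on the scan, and every
// sort key is a plain column, so the builder receives pushTopN(orders, 1).
val pushed = spark.read.table("h2.test.employee")
  .where($"dept" > 1)
  .orderBy($"salary".desc)
  .limit(1)

// Does not qualify: the sort key is an expression, so translateSortOrders drops it,
// nothing is pushed, and the Sort (and limit) stay in the Spark plan.
val notPushed = spark.read.table("h2.test.employee")
  .orderBy(upper($"name"))
  .limit(1)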
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala
index ff79d1a5c4144..87ec9f43804e4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.v2.jdbc
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.connector.expressions.SortOrder
 import org.apache.spark.sql.connector.read.V1Scan
 import org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation
 import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
@@ -31,7 +32,8 @@ case class JDBCScan(
     pushedAggregateColumn: Array[String] = Array(),
     groupByColumns: Option[Array[String]],
     tableSample: Option[TableSampleInfo],
-    pushedLimit: Int) extends V1Scan {
+    pushedLimit: Int,
+    sortOrders: Array[SortOrder]) extends V1Scan {
 
   override def readSchema(): StructType = prunedSchema
 
@@ -46,8 +48,8 @@
       } else {
         pushedAggregateColumn
       }
-      relation.buildScan(
-        columnList, prunedSchema, pushedFilters, groupByColumns, tableSample, pushedLimit)
+      relation.buildScan(columnList, prunedSchema, pushedFilters, groupByColumns, tableSample,
+        pushedLimit, sortOrders)
     }
   }.asInstanceOf[T]
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
index d3c141ed53c5c..1760122133d22 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
@@ -20,8 +20,9 @@ import scala.util.control.NonFatal
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.connector.expressions.SortOrder
 import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
-import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample}
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN}
 import org.apache.spark.sql.execution.datasources.PartitioningUtils
 import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation}
 import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
@@ -39,6 +40,7 @@ case class JDBCScanBuilder(
     with SupportsPushDownAggregates
     with SupportsPushDownLimit
     with SupportsPushDownTableSample
+    with SupportsPushDownTopN
     with Logging {
 
   private val isCaseSensitive = session.sessionState.conf.caseSensitiveAnalysis
@@ -51,6 +53,8 @@
 
   private var pushedLimit = 0
 
+  private var sortOrders: Array[SortOrder] = Array.empty[SortOrder]
+
   override def pushFilters(filters: Array[Filter]): Array[Filter] = {
     if (jdbcOptions.pushDownPredicate) {
       val dialect = JdbcDialects.get(jdbcOptions.url)
@@ -119,8 +123,17 @@
   }
 
   override def pushLimit(limit: Int): Boolean = {
-    if (jdbcOptions.pushDownLimit && JdbcDialects.get(jdbcOptions.url).supportsLimit) {
+    if (jdbcOptions.pushDownLimit) {
+      pushedLimit = limit
+      return true
+    }
+    false
+  }
+
+  override def pushTopN(orders: Array[SortOrder], limit: Int): Boolean = {
+    if (jdbcOptions.pushDownLimit) {
       pushedLimit = limit
+      sortOrders = orders
       return true
     }
     false
@@ -151,6 +164,6 @@
     // prunedSchema and quote them (will become "MAX(SALARY)", "MIN(BONUS)" and can't
     // be used in sql string.
     JDBCScan(JDBCRelation(schema, parts, jdbcOptions)(session), finalSchema, pushedFilter,
-      pushedAggregateList, pushedGroupByCols, tableSample, pushedLimit)
+      pushedAggregateList, pushedGroupByCols, tableSample, pushedLimit, sortOrders)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala
index ecb514abac01c..f19ef7ead5f8e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala
@@ -58,7 +58,7 @@
     throw QueryExecutionErrors.commentOnTableUnsupportedError()
   }
 
-  // ToDo: use fetch first n rows only for limit, e.g.
-  // select * from employee fetch first 10 rows only;
-  override def supportsLimit(): Boolean = false
+  override def getLimitClause(limit: Integer): String = {
+    ""
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index 9a647e545d836..e0f11afcc2550 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -396,11 +396,6 @@ abstract class JdbcDialect extends Serializable with Logging{
     if (limit > 0 ) s"LIMIT $limit" else ""
   }
 
-  /**
-   * returns whether the dialect supports limit or not
-   */
-  def supportsLimit(): Boolean = true
-
   def supportsTableSample: Boolean = false
 
   def getTableSample(sample: TableSampleInfo): String =
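With supportsLimit gone, a dialect now opts out simply by returning an empty clause, as Derby does above and MsSqlServer/Teradata do below. Conversely, a dialect whose engine accepts a trailing clause of its own could keep limit (and hence top N) pushdown by overriding getLimitClause. The following is a sketch only, with a made-up dialect name and URL prefix.

import org.apache.spark.sql.jdbc.JdbcDialect

// Hypothetical dialect for an engine that uses the SQL-standard FETCH FIRST syntax.
case object MyFetchFirstDialect extends JdbcDialect {
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:mydb")

  override def getLimitClause(limit: Integer): String = {
    // Returning "" (as the dialects in this patch do) disables the clause entirely.
    if (limit > 0) s"FETCH FIRST $limit ROWS ONLY" else ""
  }
}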
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
index 8dad5ef8e1eae..8e5674a181e7a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
@@ -119,6 +119,7 @@
     throw QueryExecutionErrors.commentOnTableUnsupportedError()
   }
 
-  // ToDo: use top n to get limit, e.g. select top 100 * from employee;
-  override def supportsLimit(): Boolean = false
+  override def getLimitClause(limit: Integer): String = {
+    ""
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala
index 2a776bdb7ab04..13f4c5fe9c926 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala
@@ -56,6 +56,7 @@ private case object TeradataDialect extends JdbcDialect {
     s"RENAME TABLE $oldTable TO $newTable"
   }
 
-  // ToDo: use top n to get limit, e.g. select top 100 * from employee;
-  override def supportsLimit(): Boolean = false
+  override def getLimitClause(limit: Integer): String = {
+    ""
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index b87b4f6d86fd1..5df875b569669 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -23,7 +23,8 @@ import java.util.Properties
 import org.apache.spark.SparkConf
 import org.apache.spark.sql.{DataFrame, ExplainSuiteHelper, QueryTest, Row}
 import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException
-import org.apache.spark.sql.catalyst.plans.logical.Filter
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, Sort}
+import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortValue}
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper}
 import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog
 import org.apache.spark.sql.functions.{lit, sum, udf}
@@ -102,14 +103,14 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       case s: DataSourceV2ScanRelation => s
     }.get
     assert(scan.schema.names.sameElements(Seq("NAME")))
-    checkPushedLimit(df, true, 3)
+    checkPushedLimit(df, Some(3))
     checkAnswer(df, Seq(Row("amy"), Row("alex"), Row("cathy")))
   }
 
   test("simple scan with LIMIT") {
     val df1 = spark.read.table("h2.test.employee")
       .where($"dept" === 1).limit(1)
-    checkPushedLimit(df1, true, 1)
+    checkPushedLimit(df1, Some(1))
     checkAnswer(df1, Seq(Row(1, "amy", 10000.00, 1000.0)))
 
     val df2 = spark.read
@@ -120,7 +121,7 @@
       .table("h2.test.employee")
       .filter($"dept" > 1)
       .limit(1)
-    checkPushedLimit(df2, true, 1)
+    checkPushedLimit(df2, Some(1))
     checkAnswer(df2, Seq(Row(2, "alex", 12000.00, 1200.0)))
 
     val df3 = sql("SELECT name FROM h2.test.employee WHERE dept > 1 LIMIT 1")
@@ -128,46 +129,117 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       case s: DataSourceV2ScanRelation => s
     }.get
     assert(scan.schema.names.sameElements(Seq("NAME")))
-    checkPushedLimit(df3, true, 1)
+    checkPushedLimit(df3, Some(1))
     checkAnswer(df3, Seq(Row("alex")))
 
     val df4 = spark.read
       .table("h2.test.employee")
       .groupBy("DEPT").sum("SALARY")
       .limit(1)
-    checkPushedLimit(df4, false, 0)
+    checkPushedLimit(df4, None)
     checkAnswer(df4, Seq(Row(1, 19000.00)))
 
-    val df5 = spark.read
-      .table("h2.test.employee")
-      .sort("SALARY")
-      .limit(1)
-    checkPushedLimit(df5, false, 0)
-    checkAnswer(df5, Seq(Row(1, "cathy", 9000.00, 1200.0)))
-
     val name = udf { (x: String) => x.matches("cat|dav|amy") }
     val sub = udf { (x: String) => x.substring(0, 3) }
-    val df6 = spark.read
+    val df5 = spark.read
       .table("h2.test.employee")
      .select($"SALARY", $"BONUS", sub($"NAME").as("shortName"))
       .filter(name($"shortName"))
       .limit(1)
     // LIMIT is pushed down only if all the filters are pushed down
-    checkPushedLimit(df6, false, 0)
-    checkAnswer(df6, Seq(Row(10000.00, 1000.0, "amy")))
+    checkPushedLimit(df5, None)
+    checkAnswer(df5, Seq(Row(10000.00, 1000.0, "amy")))
   }
 
-  private def checkPushedLimit(df: DataFrame, pushed: Boolean, limit: Int): Unit = {
+  private def checkPushedLimit(df: DataFrame, limit: Option[Int] = None,
+      sortValues: Seq[SortValue] = Nil): Unit = {
     df.queryExecution.optimizedPlan.collect {
       case relation: DataSourceV2ScanRelation => relation.scan match {
         case v1: V1ScanWrapper =>
-          if (pushed) {
-            assert(v1.pushedDownOperators.limit === Some(limit))
-          } else {
-            assert(v1.pushedDownOperators.limit.isEmpty)
-          }
+          assert(v1.pushedDownOperators.limit === limit)
+          assert(v1.pushedDownOperators.sortValues === sortValues)
       }
     }
+    if (sortValues.nonEmpty) {
+      val sorts = df.queryExecution.optimizedPlan.collect {
+        case s: Sort => s
+      }
+      assert(sorts.isEmpty)
+    }
+  }
+
+  test("simple scan with top N") {
+    val df1 = spark.read
+      .table("h2.test.employee")
+      .sort("salary")
+      .limit(1)
+    checkPushedLimit(df1, Some(1), createSortValues())
+    checkAnswer(df1, Seq(Row(1, "cathy", 9000.00, 1200.0)))
+
+    val df2 = spark.read.table("h2.test.employee")
+      .where($"dept" === 1).orderBy($"salary").limit(1)
+    checkPushedLimit(df2, Some(1), createSortValues())
+    checkAnswer(df2, Seq(Row(1, "cathy", 9000.00, 1200.0)))
+
+    val df3 = spark.read
+      .option("partitionColumn", "dept")
+      .option("lowerBound", "0")
+      .option("upperBound", "2")
+      .option("numPartitions", "2")
+      .table("h2.test.employee")
+      .filter($"dept" > 1)
+      .orderBy($"salary".desc)
+      .limit(1)
+    checkPushedLimit(
+      df3, Some(1), createSortValues(SortDirection.DESCENDING, NullOrdering.NULLS_LAST))
+    checkAnswer(df3, Seq(Row(2, "alex", 12000.00, 1200.0)))
+
+    val df4 =
+      sql("SELECT name FROM h2.test.employee WHERE dept > 1 ORDER BY salary NULLS LAST LIMIT 1")
+    val scan = df4.queryExecution.optimizedPlan.collectFirst {
+      case s: DataSourceV2ScanRelation => s
+    }.get
+    assert(scan.schema.names.sameElements(Seq("NAME")))
+    checkPushedLimit(df4, Some(1), createSortValues(nullOrdering = NullOrdering.NULLS_LAST))
+    checkAnswer(df4, Seq(Row("david")))
+
+    val df5 = spark.read.table("h2.test.employee")
+      .where($"dept" === 1).orderBy($"salary")
+    checkPushedLimit(df5, None)
+    checkAnswer(df5, Seq(Row(1, "cathy", 9000.00, 1200.0), Row(1, "amy", 10000.00, 1000.0)))
+
+    val df6 = spark.read
+      .table("h2.test.employee")
+      .groupBy("DEPT").sum("SALARY")
+      .orderBy("DEPT")
+      .limit(1)
+    checkPushedLimit(df6)
+    checkAnswer(df6, Seq(Row(1, 19000.00)))
+
+    val name = udf { (x: String) => x.matches("cat|dav|amy") }
+    val sub = udf { (x: String) => x.substring(0, 3) }
+    val df7 = spark.read
+      .table("h2.test.employee")
+      .select($"SALARY", $"BONUS", sub($"NAME").as("shortName"))
+      .filter(name($"shortName"))
+      .sort($"SALARY".desc)
+      .limit(1)
+    // LIMIT is pushed down only if all the filters are pushed down
+    checkPushedLimit(df7)
+    checkAnswer(df7, Seq(Row(10000.00, 1000.0, "amy")))
+
+    val df8 = spark.read
+      .table("h2.test.employee")
+      .sort(sub($"NAME"))
+      .limit(1)
+    checkPushedLimit(df8)
+    checkAnswer(df8, Seq(Row(2, "alex", 12000.00, 1200.0)))
+  }
+
+  private def createSortValues(
+      sortDirection: SortDirection = SortDirection.ASCENDING,
+      nullOrdering: NullOrdering = NullOrdering.NULLS_FIRST): Seq[SortValue] = {
+    Seq(SortValue(FieldReference("salary"), sortDirection, nullOrdering))
   }
 
   test("scan with filter push-down") {