From 70117e94d6c1f19baf0806baaf06cc355a948970 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Mon, 29 Nov 2021 18:01:10 +0800 Subject: [PATCH 01/18] [SPARK-37483][SQL] Support pushdown down top N to JDBC data source V2 --- .../connector/read/SupportsPushDownLimit.java | 6 ++ .../sql/execution/DataSourceScanExec.scala | 1 + .../datasources/DataSourceStrategy.scala | 25 +++++- .../datasources/jdbc/JDBCOptions.scala | 5 ++ .../execution/datasources/jdbc/JDBCRDD.scala | 23 ++++-- .../datasources/jdbc/JDBCRelation.scala | 7 +- .../datasources/v2/PushDownUtils.scala | 13 ++- .../datasources/v2/PushedDownOperators.scala | 4 +- .../v2/V2ScanRelationPushDown.scala | 24 +++++- .../datasources/v2/jdbc/JDBCScan.scala | 8 +- .../datasources/v2/jdbc/JDBCScanBuilder.scala | 14 +++- .../apache/spark/sql/jdbc/JdbcDialects.scala | 2 + .../apache/spark/sql/jdbc/JDBCV2Suite.scala | 80 ++++++++++++++++++- 13 files changed, 188 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java index 7e50bf14d7817..2466a7ce5025e 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.connector.read; import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.SortValue; /** * A mix-in interface for {@link Scan}. Data sources can implement this interface to @@ -33,4 +34,9 @@ public interface SupportsPushDownLimit extends ScanBuilder { * Pushes down LIMIT to the data source. */ boolean pushLimit(int limit); + + /** + * Pushes down top N to the data source. 
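For orientation, the end-to-end behavior this series enables is that an ORDER BY followed by LIMIT over a single JDBC V2 scan can be evaluated by the database itself once the new `pushDownTopN` option is set; the method signature being documented here follows just below. A minimal user-side sketch, not part of the diff: the catalog and table names are borrowed from the test suite later in this patch, and the in-memory H2 URL is a placeholder that assumes the table already exists.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

val spark = SparkSession.builder()
  .master("local[*]")
  // Catalog settings mirror the JDBCV2Suite configuration in this patch.
  .config("spark.sql.catalog.h2",
    "org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog")
  .config("spark.sql.catalog.h2.url", "jdbc:h2:mem:testdb0")       // placeholder URL
  .config("spark.sql.catalog.h2.driver", "org.h2.Driver")
  .config("spark.sql.catalog.h2.pushDownTopN", "true")             // option added by this series
  .getOrCreate()

// ORDER BY followed by LIMIT over a single scan: the "top N" shape this series targets.
val top1 = spark.read.table("h2.test.employee")
  .orderBy(col("SALARY").desc)
  .limit(1)
top1.explain()  // with pushdown enabled, the scan node reports the pushed limit and ordering
```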
+ */ + boolean pushTopN(SortValue[] orders, int limit); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 5ad1dc88d3776..072db3d8d32b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -149,6 +149,7 @@ case class RowDataSourceScanExec( Map("PushedAggregates" -> seqToString(v.aggregateExpressions), "PushedGroupByColumns" -> seqToString(v.groupByColumns))} ++ pushedDownOperators.limit.map(value => "PushedLimit" -> s"LIMIT $value") ++ +// Map("PushedSortValues" -> seqToString(pushedDownOperators.sortValues)) ++ pushedDownOperators.sample.map(v => "PushedSample" -> s"SAMPLE (${(v.upperBound - v.lowerBound) * 100}) ${v.withReplacement} SEED(${v.seed})" ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 6febbd590f246..c5aca70ad315f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 import org.apache.spark.sql.connector.catalog.SupportsRead import org.apache.spark.sql.connector.catalog.TableCapability._ -import org.apache.spark.sql.connector.expressions.FieldReference +import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortValue} import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, Max, Min, Sum} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan} @@ -336,7 +336,7 @@ object DataSourceStrategy l.output.toStructType, Set.empty, Set.empty, - PushedDownOperators(None, None, None), + PushedDownOperators(None, None, None, Seq.empty), toCatalystRDD(l, baseRelation.buildScan()), baseRelation, None) :: Nil @@ -410,7 +410,7 @@ object DataSourceStrategy requestedColumns.toStructType, pushedFilters.toSet, handledFilters, - PushedDownOperators(None, None, None), + PushedDownOperators(None, None, None, Seq.empty), scanBuilder(requestedColumns, candidatePredicates, pushedFilters), relation.relation, relation.catalogTable.map(_.identifier)) @@ -433,7 +433,7 @@ object DataSourceStrategy requestedColumns.toStructType, pushedFilters.toSet, handledFilters, - PushedDownOperators(None, None, None), + PushedDownOperators(None, None, None, Seq.empty), scanBuilder(requestedColumns, candidatePredicates, pushedFilters), relation.relation, relation.catalogTable.map(_.identifier)) @@ -726,6 +726,23 @@ object DataSourceStrategy } } + protected[sql] def translateSortOrders(sortOrders: Seq[SortOrder]): Seq[SortValue] = { + sortOrders.map { + case SortOrder(PushableColumnWithoutNestedColumn(name), directionV1, nullOrderingV1, _) => + val (directionV2, nullOrderingV2) = (directionV1, nullOrderingV1) match { + case (Ascending, NullsFirst) => + (SortDirection.ASCENDING, NullOrdering.NULLS_FIRST) + case (Ascending, NullsLast) => + (SortDirection.ASCENDING, NullOrdering.NULLS_LAST) + case (Descending, NullsFirst) => + (SortDirection.DESCENDING, 
NullOrdering.NULLS_FIRST) + case (Descending, NullsLast) => + (SortDirection.DESCENDING, NullOrdering.NULLS_LAST) + } + SortValue(FieldReference(name), directionV2, nullOrderingV2) + } + } + /** * Convert RDD of Row into RDD of InternalRow with objects in catalyst types */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index d081e0ace0e44..eb8841b5b699a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -196,6 +196,10 @@ class JDBCOptions( // This only applies to Data Source V2 JDBC val pushDownLimit = parameters.getOrElse(JDBC_PUSHDOWN_LIMIT, "false").toBoolean + // An option to allow/disallow pushing down query of top N into V2 JDBC data source + // This only applies to Data Source V2 JDBC + val pushDownTopN = parameters.getOrElse(JDBC_PUSHDOWN_TOP_N, "false").toBoolean + // An option to allow/disallow pushing down TABLESAMPLE into JDBC data source // This only applies to Data Source V2 JDBC val pushDownTableSample = parameters.getOrElse(JDBC_PUSHDOWN_TABLESAMPLE, "false").toBoolean @@ -276,6 +280,7 @@ object JDBCOptions { val JDBC_PUSHDOWN_PREDICATE = newOption("pushDownPredicate") val JDBC_PUSHDOWN_AGGREGATE = newOption("pushDownAggregate") val JDBC_PUSHDOWN_LIMIT = newOption("pushDownLimit") + val JDBC_PUSHDOWN_TOP_N = newOption("pushDownTopN") val JDBC_PUSHDOWN_TABLESAMPLE = newOption("pushDownTableSample") val JDBC_KEYTAB = newOption("keytab") val JDBC_PRINCIPAL = newOption("principal") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 394ba3f8bb8c2..c0d5dd973b959 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -25,6 +25,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.expressions.SortValue import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects} import org.apache.spark.sql.sources._ @@ -151,9 +152,10 @@ object JDBCRDD extends Logging { * @param options - JDBC options that contains url, table and other information. * @param outputSchema - The schema of the columns or aggregate columns to SELECT. * @param groupByColumns - The pushed down group by columns. + * @param sample - The pushed down tableSample. * @param limit - The pushed down limit. If the value is 0, it means no limit or limit * is not pushed down. - * @param sample - The pushed down tableSample. + * @param sortValues - The sort values cooperates with limit to realize top N. * * @return An RDD representing "SELECT requiredColumns FROM fqTable". 
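Concretely, once a top N is pushed the statement this RDD executes is the base SELECT with ORDER BY and LIMIT appended. A simplified sketch of that composition (illustrative only, not the code added below; `columnList`, `tableOrQuery` and `whereClause` stand in for values JDBCRDD already builds):

```scala
// Simplified sketch of how the pushed pieces end up in the generated statement text.
def topNQuery(
    columnList: String,
    tableOrQuery: String,
    whereClause: String,
    orderBy: Seq[String],  // rendered sort keys, e.g. Seq("SALARY ASC NULLS LAST")
    limit: Int): String = {
  val orderByClause = if (orderBy.nonEmpty) s" ORDER BY ${orderBy.mkString(", ")}" else ""
  val limitClause = if (limit > 0) s" LIMIT $limit" else ""  // dialect-specific in practice
  s"SELECT $columnList FROM $tableOrQuery $whereClause$orderByClause$limitClause"
}

// topNQuery("\"NAME\",\"SALARY\"", "\"test\".\"employee\"", "WHERE \"DEPT\" > 1",
//   Seq("SALARY ASC NULLS LAST"), 1)
// returns roughly:
//   SELECT "NAME","SALARY" FROM "test"."employee" WHERE "DEPT" > 1
//     ORDER BY SALARY ASC NULLS LAST LIMIT 1
```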
*/ @@ -167,7 +169,8 @@ object JDBCRDD extends Logging { outputSchema: Option[StructType] = None, groupByColumns: Option[Array[String]] = None, sample: Option[TableSampleInfo] = None, - limit: Int = 0): RDD[InternalRow] = { + limit: Int = 0, + sortValues: Array[SortValue] = Array.empty[SortValue]): RDD[InternalRow] = { val url = options.url val dialect = JdbcDialects.get(url) val quotedColumns = if (groupByColumns.isEmpty) { @@ -187,7 +190,8 @@ object JDBCRDD extends Logging { options, groupByColumns, sample, - limit) + limit, + sortValues) } } @@ -207,7 +211,8 @@ private[jdbc] class JDBCRDD( options: JDBCOptions, groupByColumns: Option[Array[String]], sample: Option[TableSampleInfo], - limit: Int) + limit: Int, + sortValues: Array[SortValue]) extends RDD[InternalRow](sc, Nil) { /** @@ -255,6 +260,14 @@ private[jdbc] class JDBCRDD( } } + private def getOrderByClause: String = { + if (sortValues.nonEmpty) { + s" ORDER BY ${sortValues.map(_.describe()).mkString(", ")}" + } else { + "" + } + } + /** * Runs the SQL query against the JDBC driver. * @@ -339,7 +352,7 @@ private[jdbc] class JDBCRDD( val myLimitClause: String = dialect.getLimitClause(limit) val sqlText = s"SELECT $columnList FROM ${options.tableOrQuery} $myTableSampleClause" + - s" $myWhereClause $getGroupByClause $myLimitClause" + s" $myWhereClause $getGroupByClause $getOrderByClause $myLimitClause" stmt = conn.prepareStatement(sqlText, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) stmt.setFetchSize(options.fetchSize) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index cd1eae89ee890..8cb3f59f47cdf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter} import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate, stringToTimestamp} +import org.apache.spark.sql.connector.expressions.SortValue import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo import org.apache.spark.sql.internal.SQLConf @@ -301,7 +302,8 @@ private[sql] case class JDBCRelation( filters: Array[Filter], groupByColumns: Option[Array[String]], tableSample: Option[TableSampleInfo], - limit: Int): RDD[Row] = { + limit: Int, + sortValues: Array[SortValue]): RDD[Row] = { // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row] JDBCRDD.scanTable( sparkSession.sparkContext, @@ -313,7 +315,8 @@ private[sql] case class JDBCRelation( Some(finalSchema), groupByColumns, tableSample, - limit).asInstanceOf[RDD[Row]] + limit, + sortValues).asInstanceOf[RDD[Row]] } override def insert(data: DataFrame, overwrite: Boolean): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index d51b4752175a2..9d9d4b63f4fbe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -22,7 +22,7 @@ import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, NamedExpression, PredicateHelper, SchemaPruning} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.util.CharVarcharUtils -import org.apache.spark.sql.connector.expressions.FieldReference +import org.apache.spark.sql.connector.expressions.{FieldReference, SortValue} import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter} import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownV2Filters} @@ -157,6 +157,17 @@ object PushDownUtils extends PredicateHelper { } } + /** + * Pushes down top N to the data source Scan + */ + def pushTopN(scanBuilder: ScanBuilder, order: Array[SortValue], limit: Int): Boolean = { + scanBuilder match { + case s: SupportsPushDownLimit => + s.pushTopN(order, limit) + case _ => false + } + } + /** * Applies column pruning to the data source, w.r.t. the references of the given expressions. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala index c21354d646164..4f8a59756f5a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.datasources.v2 +import org.apache.spark.sql.connector.expressions.SortValue import org.apache.spark.sql.connector.expressions.aggregate.Aggregation /** @@ -25,4 +26,5 @@ import org.apache.spark.sql.connector.expressions.aggregate.Aggregation case class PushedDownOperators( aggregation: Option[Aggregation], sample: Option[TableSampleInfo], - limit: Option[Int]) + limit: Option[Int], + sortValues: Seq[SortValue]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 36923d1c0bd64..c0033c51cd882 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -23,8 +23,9 @@ import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeRefer import org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.planning.ScanOperation -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LogicalPlan, Project, Sample} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LogicalPlan, Project, Sample, Sort} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.connector.expressions.SortValue import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, 
SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy @@ -255,7 +256,20 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { sHolder.pushedLimit = Some(limitValue) } globalLimit - case _ => globalLimit + case _ => + child transform { + case sort @ Sort(order, _, ScanOperation(_, filter, sHolder: ScanBuilderHolder)) + if filter.length == 0 => + val orders = DataSourceStrategy.translateSortOrders(order) + val topNPushed = PushDownUtils.pushTopN(sHolder.builder, orders.toArray, limitValue) + if (topNPushed) { + sHolder.pushedLimit = Some(limitValue) + sHolder.sortValues = orders + } + sort + case other => other + } + globalLimit } } @@ -270,8 +284,8 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { f.pushedFilters() case _ => Array.empty[sources.Filter] } - val pushedDownOperators = - PushedDownOperators(aggregation, sHolder.pushedSample, sHolder.pushedLimit) + val pushedDownOperators = PushedDownOperators(aggregation, + sHolder.pushedSample, sHolder.pushedLimit, sHolder.sortValues) V1ScanWrapper(v1, pushedFilters, pushedDownOperators) case _ => scan } @@ -284,6 +298,8 @@ case class ScanBuilderHolder( builder: ScanBuilder) extends LeafNode { var pushedLimit: Option[Int] = None + var sortValues: Seq[SortValue] = Seq.empty[SortValue] + var pushedSample: Option[TableSampleInfo] = None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala index ff79d1a5c4144..d958f4d9b4043 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.v2.jdbc import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.connector.expressions.SortValue import org.apache.spark.sql.connector.read.V1Scan import org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo @@ -31,7 +32,8 @@ case class JDBCScan( pushedAggregateColumn: Array[String] = Array(), groupByColumns: Option[Array[String]], tableSample: Option[TableSampleInfo], - pushedLimit: Int) extends V1Scan { + pushedLimit: Int, + sortValues: Array[SortValue]) extends V1Scan { override def readSchema(): StructType = prunedSchema @@ -46,8 +48,8 @@ case class JDBCScan( } else { pushedAggregateColumn } - relation.buildScan( - columnList, prunedSchema, pushedFilters, groupByColumns, tableSample, pushedLimit) + relation.buildScan(columnList, prunedSchema, pushedFilters, groupByColumns, tableSample, + pushedLimit, sortValues) } }.asInstanceOf[T] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala index d3c141ed53c5c..24cfb72025d97 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala @@ -20,6 +20,7 @@ import scala.util.control.NonFatal import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession +import 
org.apache.spark.sql.connector.expressions.SortValue import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample} import org.apache.spark.sql.execution.datasources.PartitioningUtils @@ -51,6 +52,8 @@ case class JDBCScanBuilder( private var pushedLimit = 0 + private var sortValues: Array[SortValue] = Array.empty[SortValue] + override def pushFilters(filters: Array[Filter]): Array[Filter] = { if (jdbcOptions.pushDownPredicate) { val dialect = JdbcDialects.get(jdbcOptions.url) @@ -126,6 +129,15 @@ case class JDBCScanBuilder( false } + override def pushTopN(orders: Array[SortValue], limit: Int): Boolean = { + if (jdbcOptions.pushDownTopN && JdbcDialects.get(jdbcOptions.url).supportsTopN) { + pushedLimit = limit + sortValues = orders + return true + } + false + } + override def pruneColumns(requiredSchema: StructType): Unit = { // JDBC doesn't support nested column pruning. // TODO (SPARK-32593): JDBC support nested column and nested column pruning. @@ -151,6 +163,6 @@ case class JDBCScanBuilder( // prunedSchema and quote them (will become "MAX(SALARY)", "MIN(BONUS)" and can't // be used in sql string. JDBCScan(JDBCRelation(schema, parts, jdbcOptions)(session), finalSchema, pushedFilter, - pushedAggregateList, pushedGroupByCols, tableSample, pushedLimit) + pushedAggregateList, pushedGroupByCols, tableSample, pushedLimit, sortValues) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 9a647e545d836..a5749431672b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -401,6 +401,8 @@ abstract class JdbcDialect extends Serializable with Logging{ */ def supportsLimit(): Boolean = true + def supportsTopN(): Boolean = true + def supportsTableSample: Boolean = false def getTableSample(sample: TableSampleInfo): String = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index b87b4f6d86fd1..9fd481ee6eb59 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -24,6 +24,7 @@ import org.apache.spark.SparkConf import org.apache.spark.sql.{DataFrame, ExplainSuiteHelper, QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException import org.apache.spark.sql.catalyst.plans.logical.Filter +import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortValue} import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper} import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog import org.apache.spark.sql.functions.{lit, sum, udf} @@ -43,6 +44,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .set("spark.sql.catalog.h2.driver", "org.h2.Driver") .set("spark.sql.catalog.h2.pushDownAggregate", "true") .set("spark.sql.catalog.h2.pushDownLimit", "true") + .set("spark.sql.catalog.h2.pushDownTopN", "true") private def withConnection[T](f: Connection => T): T = { val conn = DriverManager.getConnection(url, new Properties()) @@ -138,11 +140,79 
@@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkPushedLimit(df4, false, 0) checkAnswer(df4, Seq(Row(1, 19000.00))) + val name = udf { (x: String) => x.matches("cat|dav|amy") } + val sub = udf { (x: String) => x.substring(0, 3) } val df5 = spark.read .table("h2.test.employee") - .sort("SALARY") + .select($"SALARY", $"BONUS", sub($"NAME").as("shortName")) + .filter(name($"shortName")) .limit(1) + // LIMIT is pushed down only if all the filters are pushed down checkPushedLimit(df5, false, 0) + checkAnswer(df5, Seq(Row(10000.00, 1000.0, "amy"))) + } + + private def checkPushedLimit(df: DataFrame, pushed: Boolean, limit: Int): Unit = { + df.queryExecution.optimizedPlan.collect { + case relation: DataSourceV2ScanRelation => relation.scan match { + case v1: V1ScanWrapper => + if (pushed) { + assert(v1.pushedDownOperators.limit === Some(limit)) + } else { + assert(v1.pushedDownOperators.limit.isEmpty) + } + } + } + } + + test("simple scan with top N") { + val df1 = spark.read.table("h2.test.employee") + .where($"dept" === 1).orderBy($"salary").limit(1) + val expectedSorts1 = + Seq(SortValue(FieldReference("salary"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST)) + checkPushedTopN(df1, true, 1, expectedSorts1) + checkAnswer(df1, Seq(Row(1, "cathy", 9000.00, 1200.0))) + + val df2 = spark.read + .option("partitionColumn", "dept") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "2") + .table("h2.test.employee") + .filter($"dept" > 1) + .orderBy($"salary".desc) + .limit(1) + val expectedSorts2 = + Seq(SortValue(FieldReference("salary"), SortDirection.DESCENDING, NullOrdering.NULLS_LAST)) + checkPushedTopN(df2, true, 1, expectedSorts2) + checkAnswer(df2, Seq(Row(2, "alex", 12000.00, 1200.0))) + + val df3 = + sql("SELECT name FROM h2.test.employee WHERE dept > 1 ORDER BY salary NULLS LAST LIMIT 1") + val scan = df3.queryExecution.optimizedPlan.collectFirst { + case s: DataSourceV2ScanRelation => s + }.get + assert(scan.schema.names.sameElements(Seq("NAME", "SALARY"))) + val expectedSorts3 = + Seq(SortValue(FieldReference("salary"), SortDirection.ASCENDING, NullOrdering.NULLS_LAST)) + checkPushedTopN(df3, true, 1, expectedSorts3) + checkAnswer(df3, Seq(Row("david"))) + + val df4 = spark.read + .table("h2.test.employee") + .groupBy("DEPT").sum("SALARY") + .orderBy("DEPT") + .limit(1) + checkPushedTopN(df4, false, 0) + checkAnswer(df4, Seq(Row(1, 19000.00))) + + val df5 = spark.read + .table("h2.test.employee") + .sort("SALARY") + .limit(1) + val expectedSorts5 = + Seq(SortValue(FieldReference("SALARY"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST)) + checkPushedTopN(df5, true, 1, expectedSorts5) checkAnswer(df5, Seq(Row(1, "cathy", 9000.00, 1200.0))) val name = udf { (x: String) => x.matches("cat|dav|amy") } @@ -151,20 +221,24 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .table("h2.test.employee") .select($"SALARY", $"BONUS", sub($"NAME").as("shortName")) .filter(name($"shortName")) + .sort($"SALARY".desc) .limit(1) // LIMIT is pushed down only if all the filters are pushed down - checkPushedLimit(df6, false, 0) + checkPushedTopN(df6, false, 0) checkAnswer(df6, Seq(Row(10000.00, 1000.0, "amy"))) } - private def checkPushedLimit(df: DataFrame, pushed: Boolean, limit: Int): Unit = { + private def checkPushedTopN(df: DataFrame, pushed: Boolean, limit: Int = 0, + sortValues: Seq[SortValue] = Seq.empty): Unit = { df.queryExecution.optimizedPlan.collect { case relation: 
DataSourceV2ScanRelation => relation.scan match { case v1: V1ScanWrapper => if (pushed) { assert(v1.pushedDownOperators.limit === Some(limit)) + assert(v1.pushedDownOperators.sortValues === sortValues) } else { assert(v1.pushedDownOperators.limit.isEmpty) + assert(v1.pushedDownOperators.sortValues.isEmpty) } } } From 45f70841940dfac8fc1f498cf3ec59c146de6b20 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Mon, 29 Nov 2021 18:35:34 +0800 Subject: [PATCH 02/18] Update code --- .../apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index c0d5dd973b959..05f014df15422 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -159,6 +159,7 @@ object JDBCRDD extends Logging { * * @return An RDD representing "SELECT requiredColumns FROM fqTable". */ + // scalastyle:off argcount def scanTable( sc: SparkContext, schema: StructType, @@ -193,6 +194,7 @@ object JDBCRDD extends Logging { limit, sortValues) } + // scalastyle:on argcount } /** From e3e87bdf70de7fe19a70e5431daa717a7648e19b Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Mon, 29 Nov 2021 18:57:28 +0800 Subject: [PATCH 03/18] Update code --- .../org/apache/spark/sql/execution/DataSourceScanExec.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 072db3d8d32b6..cab65d1a5396a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -149,7 +149,7 @@ case class RowDataSourceScanExec( Map("PushedAggregates" -> seqToString(v.aggregateExpressions), "PushedGroupByColumns" -> seqToString(v.groupByColumns))} ++ pushedDownOperators.limit.map(value => "PushedLimit" -> s"LIMIT $value") ++ -// Map("PushedSortValues" -> seqToString(pushedDownOperators.sortValues)) ++ + Map("PushedSortValues" -> seqToString(pushedDownOperators.sortValues)) ++ pushedDownOperators.sample.map(v => "PushedSample" -> s"SAMPLE (${(v.upperBound - v.lowerBound) * 100}) ${v.withReplacement} SEED(${v.seed})" ) From 02fb7a207ea09d5267a9d9d2c74c9c25ac38fde9 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Tue, 30 Nov 2021 15:56:33 +0800 Subject: [PATCH 04/18] Update code --- .../connector/read/SupportsPushDownLimit.java | 5 --- .../connector/read/SupportsPushDownTopN.java | 37 +++++++++++++++++++ .../datasources/DataSourceStrategy.scala | 19 +--------- .../execution/datasources/jdbc/JDBCRDD.scala | 13 +++++-- .../datasources/jdbc/JDBCRelation.scala | 4 +- .../datasources/v2/PushDownUtils.scala | 10 ++--- .../datasources/v2/PushedDownOperators.scala | 4 +- .../v2/V2ScanRelationPushDown.scala | 13 +++---- .../datasources/v2/jdbc/JDBCScan.scala | 4 +- .../datasources/v2/jdbc/JDBCScanBuilder.scala | 9 +++-- .../apache/spark/sql/jdbc/JDBCV2Suite.scala | 15 ++++---- 11 files changed, 75 insertions(+), 58 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java 
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java index 2466a7ce5025e..af188b67d7a8e 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java @@ -34,9 +34,4 @@ public interface SupportsPushDownLimit extends ScanBuilder { * Pushes down LIMIT to the data source. */ boolean pushLimit(int limit); - - /** - * Pushes down top N to the data source. - */ - boolean pushTopN(SortValue[] orders, int limit); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java new file mode 100644 index 0000000000000..7bdcc54ab3e0c --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.read; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.catalyst.expressions.SortOrder; + +/** + * A mix-in interface for {@link Scan}. Data sources can implement this interface to + * push down top N. Please note that the combination of LIMIT with other operations + * such as AGGREGATE, GROUP BY, CLUSTER BY, DISTRIBUTE BY, etc. is NOT pushed down. + * + * @since 3.3.0 + */ +@Evolving +public interface SupportsPushDownTopN extends ScanBuilder { + + /** + * Pushes down top N to the data source. 
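To illustrate the new mix-in from a connector author's point of view, here is a minimal, hypothetical builder that simply records the ordering and limit passed to the `pushTopN` method declared just below; the class and field names are not from this patch, and at this point in the series the sort keys arrive as catalyst `SortOrder` values (a later patch switches to the connector-level type).

```scala
import org.apache.spark.sql.catalyst.expressions.SortOrder
import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownTopN}
import org.apache.spark.sql.types.StructType

// Hypothetical builder for a custom source; only sketches how the hook could be honored.
class ExampleTopNScanBuilder(schema: StructType) extends SupportsPushDownTopN {
  private var topNOrders: Seq[SortOrder] = Seq.empty
  private var topNLimit: Option[Int] = None

  override def pushTopN(orders: Array[SortOrder], limit: Int): Boolean = {
    // Returning false would tell Spark to keep the Sort and Limit in its own plan.
    topNOrders = orders.toSeq
    topNLimit = Some(limit)
    true
  }

  override def build(): Scan = new Scan {
    // A real scan would fold topNOrders/topNLimit into its native query here.
    override def readSchema(): StructType = schema
  }
}
```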
+ */ + boolean pushTopN(SortOrder[] orders, int limit); +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index c5aca70ad315f..a3b77387d5374 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 import org.apache.spark.sql.connector.catalog.SupportsRead import org.apache.spark.sql.connector.catalog.TableCapability._ -import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortValue} +import org.apache.spark.sql.connector.expressions.FieldReference import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, Max, Min, Sum} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan} @@ -726,23 +726,6 @@ object DataSourceStrategy } } - protected[sql] def translateSortOrders(sortOrders: Seq[SortOrder]): Seq[SortValue] = { - sortOrders.map { - case SortOrder(PushableColumnWithoutNestedColumn(name), directionV1, nullOrderingV1, _) => - val (directionV2, nullOrderingV2) = (directionV1, nullOrderingV1) match { - case (Ascending, NullsFirst) => - (SortDirection.ASCENDING, NullOrdering.NULLS_FIRST) - case (Ascending, NullsLast) => - (SortDirection.ASCENDING, NullOrdering.NULLS_LAST) - case (Descending, NullsFirst) => - (SortDirection.DESCENDING, NullOrdering.NULLS_FIRST) - case (Descending, NullsLast) => - (SortDirection.DESCENDING, NullOrdering.NULLS_LAST) - } - SortValue(FieldReference(name), directionV2, nullOrderingV2) - } - } - /** * Convert RDD of Row into RDD of InternalRow with objects in catalyst types */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 05f014df15422..fa2936aed02f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -25,7 +25,8 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.connector.expressions.SortValue +import org.apache.spark.sql.catalyst.expressions.SortOrder +import org.apache.spark.sql.execution.datasources.PushableColumnWithoutNestedColumn import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects} import org.apache.spark.sql.sources._ @@ -171,7 +172,7 @@ object JDBCRDD extends Logging { groupByColumns: Option[Array[String]] = None, sample: Option[TableSampleInfo] = None, limit: Int = 0, - sortValues: Array[SortValue] = Array.empty[SortValue]): RDD[InternalRow] = { + sortValues: Array[SortOrder] = Array.empty[SortOrder]): RDD[InternalRow] = { val url = options.url val dialect = JdbcDialects.get(url) val quotedColumns = if (groupByColumns.isEmpty) { @@ -214,7 +215,7 @@ private[jdbc] class JDBCRDD( groupByColumns: Option[Array[String]], sample: 
Option[TableSampleInfo], limit: Int, - sortValues: Array[SortValue]) + sortValues: Array[SortOrder]) extends RDD[InternalRow](sc, Nil) { /** @@ -264,7 +265,11 @@ private[jdbc] class JDBCRDD( private def getOrderByClause: String = { if (sortValues.nonEmpty) { - s" ORDER BY ${sortValues.map(_.describe()).mkString(", ")}" + val values = sortValues.map { + case SortOrder(PushableColumnWithoutNestedColumn(name), direction, nullOrdering, _) => + s"$name ${direction.sql} ${nullOrdering.sql}" + } + s" ORDER BY ${values.mkString(", ")}" } else { "" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index 8cb3f59f47cdf..40e59ed3f4e16 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -25,9 +25,9 @@ import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext} import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.expressions.SortOrder import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter} import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate, stringToTimestamp} -import org.apache.spark.sql.connector.expressions.SortValue import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo import org.apache.spark.sql.internal.SQLConf @@ -303,7 +303,7 @@ private[sql] case class JDBCRelation( groupByColumns: Option[Array[String]], tableSample: Option[TableSampleInfo], limit: Int, - sortValues: Array[SortValue]): RDD[Row] = { + sortValues: Array[SortOrder]): RDD[Row] = { // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row] JDBCRDD.scanTable( sparkSession.sparkContext, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index 9d9d4b63f4fbe..7a12ec819fb32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -19,13 +19,13 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.mutable -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, NamedExpression, PredicateHelper, SchemaPruning} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, NamedExpression, PredicateHelper, SchemaPruning, SortOrder} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.util.CharVarcharUtils -import org.apache.spark.sql.connector.expressions.{FieldReference, SortValue} +import org.apache.spark.sql.connector.expressions.FieldReference import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter} -import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownV2Filters} +import 
org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters} import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PushableColumnWithoutNestedColumn} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources @@ -160,9 +160,9 @@ object PushDownUtils extends PredicateHelper { /** * Pushes down top N to the data source Scan */ - def pushTopN(scanBuilder: ScanBuilder, order: Array[SortValue], limit: Int): Boolean = { + def pushTopN(scanBuilder: ScanBuilder, order: Array[SortOrder], limit: Int): Boolean = { scanBuilder match { - case s: SupportsPushDownLimit => + case s: SupportsPushDownTopN => s.pushTopN(order, limit) case _ => false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala index 4f8a59756f5a5..ff404e9079581 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.v2 -import org.apache.spark.sql.connector.expressions.SortValue +import org.apache.spark.sql.catalyst.expressions.SortOrder import org.apache.spark.sql.connector.expressions.aggregate.Aggregation /** @@ -27,4 +27,4 @@ case class PushedDownOperators( aggregation: Option[Aggregation], sample: Option[TableSampleInfo], limit: Option[Int], - sortValues: Seq[SortValue]) + sortValues: Seq[SortOrder]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index c0033c51cd882..1628162bc1fda 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -19,13 +19,11 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.mutable -import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, Expression, IntegerLiteral, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression} -import org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.sql.catalyst.expressions.{aggregate, And, Attribute, AttributeReference, Expression, IntegerLiteral, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.planning.ScanOperation import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LogicalPlan, Project, Sample, Sort} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.connector.expressions.SortValue import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy @@ -250,7 +248,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { 
def applyLimit(plan: LogicalPlan): LogicalPlan = plan.transform { case globalLimit @ Limit(IntegerLiteral(limitValue), child) => child match { - case ScanOperation(_, filter, sHolder: ScanBuilderHolder) if filter.length == 0 => + case ScanOperation(_, filter, sHolder: ScanBuilderHolder) if filter.isEmpty => val limitPushed = PushDownUtils.pushLimit(sHolder.builder, limitValue) if (limitPushed) { sHolder.pushedLimit = Some(limitValue) @@ -258,9 +256,8 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { globalLimit case _ => child transform { - case sort @ Sort(order, _, ScanOperation(_, filter, sHolder: ScanBuilderHolder)) - if filter.length == 0 => - val orders = DataSourceStrategy.translateSortOrders(order) + case sort @ Sort(orders, _, ScanOperation(_, filter, sHolder: ScanBuilderHolder)) + if filter.isEmpty => val topNPushed = PushDownUtils.pushTopN(sHolder.builder, orders.toArray, limitValue) if (topNPushed) { sHolder.pushedLimit = Some(limitValue) @@ -298,7 +295,7 @@ case class ScanBuilderHolder( builder: ScanBuilder) extends LeafNode { var pushedLimit: Option[Int] = None - var sortValues: Seq[SortValue] = Seq.empty[SortValue] + var sortValues: Seq[SortOrder] = Seq.empty[SortOrder] var pushedSample: Option[TableSampleInfo] = None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala index d958f4d9b4043..a1e0fe23bee1f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources.v2.jdbc import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.sql.connector.expressions.SortValue +import org.apache.spark.sql.catalyst.expressions.SortOrder import org.apache.spark.sql.connector.read.V1Scan import org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo @@ -33,7 +33,7 @@ case class JDBCScan( groupByColumns: Option[Array[String]], tableSample: Option[TableSampleInfo], pushedLimit: Int, - sortValues: Array[SortValue]) extends V1Scan { + sortValues: Array[SortOrder]) extends V1Scan { override def readSchema(): StructType = prunedSchema diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala index 24cfb72025d97..3510df0568b78 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala @@ -20,9 +20,9 @@ import scala.util.control.NonFatal import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.expressions.SortValue +import org.apache.spark.sql.catalyst.expressions.SortOrder import org.apache.spark.sql.connector.expressions.aggregate.Aggregation -import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample} +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, 
SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN} import org.apache.spark.sql.execution.datasources.PartitioningUtils import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation} import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo @@ -40,6 +40,7 @@ case class JDBCScanBuilder( with SupportsPushDownAggregates with SupportsPushDownLimit with SupportsPushDownTableSample + with SupportsPushDownTopN with Logging { private val isCaseSensitive = session.sessionState.conf.caseSensitiveAnalysis @@ -52,7 +53,7 @@ case class JDBCScanBuilder( private var pushedLimit = 0 - private var sortValues: Array[SortValue] = Array.empty[SortValue] + private var sortValues: Array[SortOrder] = Array.empty[SortOrder] override def pushFilters(filters: Array[Filter]): Array[Filter] = { if (jdbcOptions.pushDownPredicate) { @@ -129,7 +130,7 @@ case class JDBCScanBuilder( false } - override def pushTopN(orders: Array[SortValue], limit: Int): Boolean = { + override def pushTopN(orders: Array[SortOrder], limit: Int): Boolean = { if (jdbcOptions.pushDownTopN && JdbcDialects.get(jdbcOptions.url).supportsTopN) { pushedLimit = limit sortValues = orders diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 9fd481ee6eb59..64016b97d88ba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -23,8 +23,8 @@ import java.util.Properties import org.apache.spark.SparkConf import org.apache.spark.sql.{DataFrame, ExplainSuiteHelper, QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException +import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, NullsFirst, NullsLast} import org.apache.spark.sql.catalyst.plans.logical.Filter -import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortValue} import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper} import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog import org.apache.spark.sql.functions.{lit, sum, udf} @@ -168,8 +168,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel test("simple scan with top N") { val df1 = spark.read.table("h2.test.employee") .where($"dept" === 1).orderBy($"salary").limit(1) - val expectedSorts1 = - Seq(SortValue(FieldReference("salary"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST)) + val expectedSorts1 = Seq(s"h2.test.employee.salary ${Ascending.sql} ${NullsFirst.sql}") checkPushedTopN(df1, true, 1, expectedSorts1) checkAnswer(df1, Seq(Row(1, "cathy", 9000.00, 1200.0))) @@ -183,7 +182,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .orderBy($"salary".desc) .limit(1) val expectedSorts2 = - Seq(SortValue(FieldReference("salary"), SortDirection.DESCENDING, NullOrdering.NULLS_LAST)) + Seq(s"h2.test.employee.salary ${Descending.sql} ${NullsLast.sql}") checkPushedTopN(df2, true, 1, expectedSorts2) checkAnswer(df2, Seq(Row(2, "alex", 12000.00, 1200.0))) @@ -194,7 +193,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel }.get assert(scan.schema.names.sameElements(Seq("NAME", "SALARY"))) val expectedSorts3 = - Seq(SortValue(FieldReference("salary"), SortDirection.ASCENDING, 
NullOrdering.NULLS_LAST)) + Seq(s"h2.test.employee.salary ${Ascending.sql} ${NullsLast.sql}") checkPushedTopN(df3, true, 1, expectedSorts3) checkAnswer(df3, Seq(Row("david"))) @@ -211,7 +210,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .sort("SALARY") .limit(1) val expectedSorts5 = - Seq(SortValue(FieldReference("SALARY"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST)) + Seq(s"h2.test.employee.SALARY ${Ascending.sql} ${NullsFirst.sql}") checkPushedTopN(df5, true, 1, expectedSorts5) checkAnswer(df5, Seq(Row(1, "cathy", 9000.00, 1200.0))) @@ -229,13 +228,13 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel } private def checkPushedTopN(df: DataFrame, pushed: Boolean, limit: Int = 0, - sortValues: Seq[SortValue] = Seq.empty): Unit = { + sortValues: Seq[String] = Seq.empty): Unit = { df.queryExecution.optimizedPlan.collect { case relation: DataSourceV2ScanRelation => relation.scan match { case v1: V1ScanWrapper => if (pushed) { assert(v1.pushedDownOperators.limit === Some(limit)) - assert(v1.pushedDownOperators.sortValues === sortValues) + assert(v1.pushedDownOperators.sortValues.map(_.sql) === sortValues) } else { assert(v1.pushedDownOperators.limit.isEmpty) assert(v1.pushedDownOperators.sortValues.isEmpty) From ae42f090018f8b9290b6cab9a8ca2f32777d5bf2 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Tue, 30 Nov 2021 15:59:39 +0800 Subject: [PATCH 05/18] Update code --- .../apache/spark/sql/connector/read/SupportsPushDownLimit.java | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java index af188b67d7a8e..7e50bf14d7817 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java @@ -18,7 +18,6 @@ package org.apache.spark.sql.connector.read; import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.expressions.SortValue; /** * A mix-in interface for {@link Scan}. 
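The interpolated sort expectations in the test changes above are built from catalyst's SQL keyword forms; spelled out, they evaluate roughly as shown in the comments of this small illustration (not part of the diff):

```scala
import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, NullsFirst, NullsLast}

val asc = s"h2.test.employee.salary ${Ascending.sql} ${NullsFirst.sql}"
// "h2.test.employee.salary ASC NULLS FIRST"
val desc = s"h2.test.employee.salary ${Descending.sql} ${NullsLast.sql}"
// "h2.test.employee.salary DESC NULLS LAST"
```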
Data sources can implement this interface to From 79caffade309f395c2d6fef557c80b9a41f85152 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Tue, 30 Nov 2021 16:01:07 +0800 Subject: [PATCH 06/18] Update code --- .../org/apache/spark/sql/execution/DataSourceScanExec.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index cab65d1a5396a..34943f2d1afdb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -149,7 +149,7 @@ case class RowDataSourceScanExec( Map("PushedAggregates" -> seqToString(v.aggregateExpressions), "PushedGroupByColumns" -> seqToString(v.groupByColumns))} ++ pushedDownOperators.limit.map(value => "PushedLimit" -> s"LIMIT $value") ++ - Map("PushedSortValues" -> seqToString(pushedDownOperators.sortValues)) ++ + Map("PushedSortOrders" -> seqToString(pushedDownOperators.sortValues)) ++ pushedDownOperators.sample.map(v => "PushedSample" -> s"SAMPLE (${(v.upperBound - v.lowerBound) * 100}) ${v.withReplacement} SEED(${v.seed})" ) From 0dd9bb458a98b0e3318d7409ef27225fa0584064 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Wed, 1 Dec 2021 22:36:56 +0800 Subject: [PATCH 07/18] Update code --- .../connector/read/SupportsPushDownTopN.java | 2 +- .../sql/execution/DataSourceScanExec.scala | 12 ++++++- .../datasources/DataSourceStrategy.scala | 19 ++++++++++- .../execution/datasources/jdbc/JDBCRDD.scala | 7 +--- .../datasources/jdbc/JDBCRelation.scala | 2 +- .../datasources/v2/PushDownUtils.scala | 4 +-- .../datasources/v2/PushedDownOperators.scala | 6 ++-- .../v2/V2ScanRelationPushDown.scala | 32 ++++++++++++------- .../datasources/v2/jdbc/JDBCScan.scala | 2 +- .../datasources/v2/jdbc/JDBCScanBuilder.scala | 6 ++-- .../apache/spark/sql/jdbc/DerbyDialect.scala | 6 ++-- .../apache/spark/sql/jdbc/JdbcDialects.scala | 7 ---- .../spark/sql/jdbc/MsSqlServerDialect.scala | 5 +-- .../spark/sql/jdbc/TeradataDialect.scala | 5 +-- .../apache/spark/sql/jdbc/JDBCV2Suite.scala | 15 +++++---- 15 files changed, 79 insertions(+), 51 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java index 7bdcc54ab3e0c..7392fbf110ec6 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java @@ -18,7 +18,7 @@ package org.apache.spark.sql.connector.read; import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.catalyst.expressions.SortOrder; +import org.apache.spark.sql.connector.expressions.SortOrder; /** * A mix-in interface for {@link Scan}. 
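With the interface now taking the connector-level `SortOrder`, an implementation can examine the pushed keys through the public accessors before accepting them. A hedged sketch under that assumption (the helper object and method names are hypothetical, not from this patch):

```scala
import org.apache.spark.sql.connector.expressions.{NamedReference, SortOrder}

object TopNHelpers {
  // Accept the pushdown only when every sort key is a plain column reference; returning
  // false from pushTopN keeps the Sort and Limit in Spark's own plan.
  def canOrderBy(orders: Array[SortOrder]): Boolean =
    orders.forall(_.expression().isInstanceOf[NamedReference])

  // Render the pushed keys as SQL text, much like the JDBC scan in this series does.
  def toOrderBySql(orders: Array[SortOrder]): String =
    if (orders.isEmpty) "" else " ORDER BY " + orders.map(_.describe()).mkString(", ")
}
```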
Data sources can implement this interface to diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 34943f2d1afdb..77c32272905f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -142,6 +142,16 @@ case class RowDataSourceScanExec( handledFilters } + val pushedTopN = + if (pushedDownOperators.limit.isDefined && pushedDownOperators.sortValues.nonEmpty) { + s""" + |ORDER BY ${seqToString(pushedDownOperators.sortValues.map(_.describe()))} + |LIMIT ${pushedDownOperators.limit.get} + |""".stripMargin + } else { + "" + } + Map( "ReadSchema" -> requiredSchema.catalogString, "PushedFilters" -> seqToString(markedFilters.toSeq)) ++ @@ -149,7 +159,7 @@ case class RowDataSourceScanExec( Map("PushedAggregates" -> seqToString(v.aggregateExpressions), "PushedGroupByColumns" -> seqToString(v.groupByColumns))} ++ pushedDownOperators.limit.map(value => "PushedLimit" -> s"LIMIT $value") ++ - Map("PushedSortOrders" -> seqToString(pushedDownOperators.sortValues)) ++ + Map("pushedTopN" -> pushedTopN) ++ pushedDownOperators.sample.map(v => "PushedSample" -> s"SAMPLE (${(v.upperBound - v.lowerBound) * 100}) ${v.withReplacement} SEED(${v.seed})" ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index a3b77387d5374..c1983b238bf1f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 import org.apache.spark.sql.connector.catalog.SupportsRead import org.apache.spark.sql.connector.catalog.TableCapability._ -import org.apache.spark.sql.connector.expressions.FieldReference +import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortOrder => SortOrderV2, SortValue} import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, Max, Min, Sum} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan} @@ -726,6 +726,23 @@ object DataSourceStrategy } } + protected[sql] def translateSortOrders(sortOrders: Seq[SortOrder]): Seq[SortOrderV2] = { + sortOrders.map { + case SortOrder(PushableColumnWithoutNestedColumn(name), directionV1, nullOrderingV1, _) => + val (directionV2, nullOrderingV2) = (directionV1, nullOrderingV1) match { + case (Ascending, NullsFirst) => + (SortDirection.ASCENDING, NullOrdering.NULLS_FIRST) + case (Ascending, NullsLast) => + (SortDirection.ASCENDING, NullOrdering.NULLS_LAST) + case (Descending, NullsFirst) => + (SortDirection.DESCENDING, NullOrdering.NULLS_FIRST) + case (Descending, NullsLast) => + (SortDirection.DESCENDING, NullOrdering.NULLS_LAST) + } + SortValue(FieldReference(name), directionV2, nullOrderingV2) + } + } + /** * Convert RDD of Row into RDD of InternalRow with objects in catalyst types */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index fa2936aed02f2..e008c81858c6f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -26,7 +26,6 @@ import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SortOrder -import org.apache.spark.sql.execution.datasources.PushableColumnWithoutNestedColumn import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects} import org.apache.spark.sql.sources._ @@ -265,11 +264,7 @@ private[jdbc] class JDBCRDD( private def getOrderByClause: String = { if (sortValues.nonEmpty) { - val values = sortValues.map { - case SortOrder(PushableColumnWithoutNestedColumn(name), direction, nullOrdering, _) => - s"$name ${direction.sql} ${nullOrdering.sql}" - } - s" ORDER BY ${values.mkString(", ")}" + s" ORDER BY ${sortValues.map(_.describe()).mkString(", ")}" } else { "" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index 40e59ed3f4e16..1f9c3e0fa73b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -25,9 +25,9 @@ import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext} import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions.SortOrder import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter} import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate, stringToTimestamp} +import org.apache.spark.sql.connector.expressions.SortOrder import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo import org.apache.spark.sql.internal.SQLConf diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index 7a12ec819fb32..39374b6924820 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -19,10 +19,10 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.mutable -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, NamedExpression, PredicateHelper, SchemaPruning, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, NamedExpression, PredicateHelper, SchemaPruning} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.util.CharVarcharUtils -import org.apache.spark.sql.connector.expressions.FieldReference +import org.apache.spark.sql.connector.expressions.{FieldReference, SortOrder} import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import 
org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter} import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala index ff404e9079581..20ced9c17f7e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.v2 -import org.apache.spark.sql.catalyst.expressions.SortOrder +import org.apache.spark.sql.connector.expressions.SortOrder import org.apache.spark.sql.connector.expressions.aggregate.Aggregation /** @@ -27,4 +27,6 @@ case class PushedDownOperators( aggregation: Option[Aggregation], sample: Option[TableSampleInfo], limit: Option[Int], - sortValues: Seq[SortOrder]) + sortValues: Seq[SortOrder]) { + assert((limit.isEmpty && sortValues.isEmpty) || limit.isDefined) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 1628162bc1fda..0510c8c762a09 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -19,11 +19,13 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.mutable -import org.apache.spark.sql.catalyst.expressions.{aggregate, And, Attribute, AttributeReference, Expression, IntegerLiteral, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, Expression, IntegerLiteral, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.planning.ScanOperation import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LogicalPlan, Project, Sample, Sort} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.connector.expressions.SortOrder import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy @@ -254,19 +256,25 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { sHolder.pushedLimit = Some(limitValue) } globalLimit - case _ => - child transform { - case sort @ Sort(orders, _, ScanOperation(_, filter, sHolder: ScanBuilderHolder)) - if filter.isEmpty => - val topNPushed = PushDownUtils.pushTopN(sHolder.builder, orders.toArray, limitValue) - if (topNPushed) { - sHolder.pushedLimit = Some(limitValue) - sHolder.sortValues = orders - } - sort - case other => other + case Sort(order, _, ScanOperation(_, filter, sHolder: 
ScanBuilderHolder)) + if filter.isEmpty => + val orders = DataSourceStrategy.translateSortOrders(order) + val topNPushed = PushDownUtils.pushTopN(sHolder.builder, orders.toArray, limitValue) + if (topNPushed) { + sHolder.pushedLimit = Some(limitValue) + sHolder.sortValues = orders + } + globalLimit + case Project(_, Sort(order, _, ScanOperation(_, filter, sHolder: ScanBuilderHolder))) + if filter.isEmpty => + val orders = DataSourceStrategy.translateSortOrders(order) + val topNPushed = PushDownUtils.pushTopN(sHolder.builder, orders.toArray, limitValue) + if (topNPushed) { + sHolder.pushedLimit = Some(limitValue) + sHolder.sortValues = orders } globalLimit + case _ => globalLimit } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala index a1e0fe23bee1f..0bb3812e5a33a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources.v2.jdbc import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.sql.catalyst.expressions.SortOrder +import org.apache.spark.sql.connector.expressions.SortOrder import org.apache.spark.sql.connector.read.V1Scan import org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala index 3510df0568b78..83ea1317552ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala @@ -20,7 +20,7 @@ import scala.util.control.NonFatal import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.expressions.SortOrder +import org.apache.spark.sql.connector.expressions.SortOrder import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN} import org.apache.spark.sql.execution.datasources.PartitioningUtils @@ -123,7 +123,7 @@ case class JDBCScanBuilder( } override def pushLimit(limit: Int): Boolean = { - if (jdbcOptions.pushDownLimit && JdbcDialects.get(jdbcOptions.url).supportsLimit) { + if (jdbcOptions.pushDownLimit) { pushedLimit = limit return true } @@ -131,7 +131,7 @@ case class JDBCScanBuilder( } override def pushTopN(orders: Array[SortOrder], limit: Int): Boolean = { - if (jdbcOptions.pushDownTopN && JdbcDialects.get(jdbcOptions.url).supportsTopN) { + if (jdbcOptions.pushDownTopN) { pushedLimit = limit sortValues = orders return true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala index ecb514abac01c..f19ef7ead5f8e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala @@ 
-58,7 +58,7 @@ private object DerbyDialect extends JdbcDialect { throw QueryExecutionErrors.commentOnTableUnsupportedError() } - // ToDo: use fetch first n rows only for limit, e.g. - // select * from employee fetch first 10 rows only; - override def supportsLimit(): Boolean = false + override def getLimitClause(limit: Integer): String = { + "" + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index a5749431672b0..e0f11afcc2550 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -396,13 +396,6 @@ abstract class JdbcDialect extends Serializable with Logging{ if (limit > 0 ) s"LIMIT $limit" else "" } - /** - * returns whether the dialect supports limit or not - */ - def supportsLimit(): Boolean = true - - def supportsTopN(): Boolean = true - def supportsTableSample: Boolean = false def getTableSample(sample: TableSampleInfo): String = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala index 8dad5ef8e1eae..8e5674a181e7a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala @@ -119,6 +119,7 @@ private object MsSqlServerDialect extends JdbcDialect { throw QueryExecutionErrors.commentOnTableUnsupportedError() } - // ToDo: use top n to get limit, e.g. select top 100 * from employee; - override def supportsLimit(): Boolean = false + override def getLimitClause(limit: Integer): String = { + "" + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala index 2a776bdb7ab04..13f4c5fe9c926 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala @@ -56,6 +56,7 @@ private case object TeradataDialect extends JdbcDialect { s"RENAME TABLE $oldTable TO $newTable" } - // ToDo: use top n to get limit, e.g. 
select top 100 * from employee; - override def supportsLimit(): Boolean = false + override def getLimitClause(limit: Integer): String = { + "" + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 64016b97d88ba..9fd481ee6eb59 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -23,8 +23,8 @@ import java.util.Properties import org.apache.spark.SparkConf import org.apache.spark.sql.{DataFrame, ExplainSuiteHelper, QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException -import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, NullsFirst, NullsLast} import org.apache.spark.sql.catalyst.plans.logical.Filter +import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortValue} import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper} import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog import org.apache.spark.sql.functions.{lit, sum, udf} @@ -168,7 +168,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel test("simple scan with top N") { val df1 = spark.read.table("h2.test.employee") .where($"dept" === 1).orderBy($"salary").limit(1) - val expectedSorts1 = Seq(s"h2.test.employee.salary ${Ascending.sql} ${NullsFirst.sql}") + val expectedSorts1 = + Seq(SortValue(FieldReference("salary"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST)) checkPushedTopN(df1, true, 1, expectedSorts1) checkAnswer(df1, Seq(Row(1, "cathy", 9000.00, 1200.0))) @@ -182,7 +183,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .orderBy($"salary".desc) .limit(1) val expectedSorts2 = - Seq(s"h2.test.employee.salary ${Descending.sql} ${NullsLast.sql}") + Seq(SortValue(FieldReference("salary"), SortDirection.DESCENDING, NullOrdering.NULLS_LAST)) checkPushedTopN(df2, true, 1, expectedSorts2) checkAnswer(df2, Seq(Row(2, "alex", 12000.00, 1200.0))) @@ -193,7 +194,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel }.get assert(scan.schema.names.sameElements(Seq("NAME", "SALARY"))) val expectedSorts3 = - Seq(s"h2.test.employee.salary ${Ascending.sql} ${NullsLast.sql}") + Seq(SortValue(FieldReference("salary"), SortDirection.ASCENDING, NullOrdering.NULLS_LAST)) checkPushedTopN(df3, true, 1, expectedSorts3) checkAnswer(df3, Seq(Row("david"))) @@ -210,7 +211,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .sort("SALARY") .limit(1) val expectedSorts5 = - Seq(s"h2.test.employee.SALARY ${Ascending.sql} ${NullsFirst.sql}") + Seq(SortValue(FieldReference("SALARY"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST)) checkPushedTopN(df5, true, 1, expectedSorts5) checkAnswer(df5, Seq(Row(1, "cathy", 9000.00, 1200.0))) @@ -228,13 +229,13 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel } private def checkPushedTopN(df: DataFrame, pushed: Boolean, limit: Int = 0, - sortValues: Seq[String] = Seq.empty): Unit = { + sortValues: Seq[SortValue] = Seq.empty): Unit = { df.queryExecution.optimizedPlan.collect { case relation: DataSourceV2ScanRelation => relation.scan match { case v1: V1ScanWrapper => if (pushed) { assert(v1.pushedDownOperators.limit === Some(limit)) - 
assert(v1.pushedDownOperators.sortValues.map(_.sql) === sortValues) + assert(v1.pushedDownOperators.sortValues === sortValues) } else { assert(v1.pushedDownOperators.limit.isEmpty) assert(v1.pushedDownOperators.sortValues.isEmpty) From a95a2f3027932f5e99de4163601d1c4569db396d Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Sat, 4 Dec 2021 00:11:21 +0800 Subject: [PATCH 08/18] Update code --- .../connector/read/SupportsPushDownTopN.java | 6 ++-- .../sql/execution/DataSourceScanExec.scala | 27 ++++++++-------- .../datasources/DataSourceStrategy.scala | 16 ++++------ .../v2/V2ScanRelationPushDown.scala | 19 +++++++---- .../apache/spark/sql/jdbc/JDBCV2Suite.scala | 32 +++++++++++++------ 5 files changed, 59 insertions(+), 41 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java index 7392fbf110ec6..5fd6f55f5f403 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java @@ -21,9 +21,9 @@ import org.apache.spark.sql.connector.expressions.SortOrder; /** - * A mix-in interface for {@link Scan}. Data sources can implement this interface to - * push down top N. Please note that the combination of LIMIT with other operations - * such as AGGREGATE, GROUP BY, CLUSTER BY, DISTRIBUTE BY, etc. is NOT pushed down. + * A mix-in interface for {@link Scan}. Data sources can implement this interface to push down + * top N(query with ORDER BY ... LIMIT n). Please note that the combination of top N with other + * operations such as AGGREGATE, GROUP BY, CLUSTER BY, DISTRIBUTE BY, etc. is NOT pushed down. 
* * @since 3.3.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 77c32272905f4..ed50aefba0fa1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -142,24 +142,25 @@ case class RowDataSourceScanExec( handledFilters } - val pushedTopN = + val optionOutputMap = if (pushedDownOperators.limit.isDefined && pushedDownOperators.sortValues.nonEmpty) { - s""" - |ORDER BY ${seqToString(pushedDownOperators.sortValues.map(_.describe()))} - |LIMIT ${pushedDownOperators.limit.get} - |""".stripMargin - } else { - "" - } + val pushedTopN = + s""" + |ORDER BY ${seqToString(pushedDownOperators.sortValues.map(_.describe()))} + |LIMIT ${pushedDownOperators.limit.get} + |""".stripMargin + Map("pushedTopN" -> pushedTopN) + } else { + pushedDownOperators.aggregation.fold(Map[String, String]()) { v => + Map("PushedAggregates" -> seqToString(v.aggregateExpressions), + "PushedGroupByColumns" -> seqToString(v.groupByColumns))} ++ + pushedDownOperators.limit.map(value => "PushedLimit" -> s"LIMIT $value") + } Map( "ReadSchema" -> requiredSchema.catalogString, "PushedFilters" -> seqToString(markedFilters.toSeq)) ++ - pushedDownOperators.aggregation.fold(Map[String, String]()) { v => - Map("PushedAggregates" -> seqToString(v.aggregateExpressions), - "PushedGroupByColumns" -> seqToString(v.groupByColumns))} ++ - pushedDownOperators.limit.map(value => "PushedLimit" -> s"LIMIT $value") ++ - Map("pushedTopN" -> pushedTopN) ++ + optionOutputMap ++ pushedDownOperators.sample.map(v => "PushedSample" -> s"SAMPLE (${(v.upperBound - v.lowerBound) * 100}) ${v.withReplacement} SEED(${v.seed})" ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index c1983b238bf1f..ae30b5eeef862 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -729,15 +729,13 @@ object DataSourceStrategy protected[sql] def translateSortOrders(sortOrders: Seq[SortOrder]): Seq[SortOrderV2] = { sortOrders.map { case SortOrder(PushableColumnWithoutNestedColumn(name), directionV1, nullOrderingV1, _) => - val (directionV2, nullOrderingV2) = (directionV1, nullOrderingV1) match { - case (Ascending, NullsFirst) => - (SortDirection.ASCENDING, NullOrdering.NULLS_FIRST) - case (Ascending, NullsLast) => - (SortDirection.ASCENDING, NullOrdering.NULLS_LAST) - case (Descending, NullsFirst) => - (SortDirection.DESCENDING, NullOrdering.NULLS_FIRST) - case (Descending, NullsLast) => - (SortDirection.DESCENDING, NullOrdering.NULLS_LAST) + val directionV2 = directionV1 match { + case Ascending => SortDirection.ASCENDING + case Descending => SortDirection.DESCENDING + } + val nullOrderingV2 = nullOrderingV1 match { + case NullsFirst => NullOrdering.NULLS_FIRST + case NullsLast => NullOrdering.NULLS_LAST } SortValue(FieldReference(name), directionV2, nullOrderingV2) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 0510c8c762a09..91d50483566a3 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeRefer import org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.planning.ScanOperation -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LogicalPlan, Project, Sample, Sort} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LocalLimit, LogicalPlan, Project, Sample, Sort} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.expressions.SortOrder import org.apache.spark.sql.connector.expressions.aggregate.Aggregation @@ -256,24 +256,31 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { sHolder.pushedLimit = Some(limitValue) } globalLimit - case Sort(order, _, ScanOperation(_, filter, sHolder: ScanBuilderHolder)) + case Sort(order, _, operation @ ScanOperation(_, filter, sHolder: ScanBuilderHolder)) if filter.isEmpty => val orders = DataSourceStrategy.translateSortOrders(order) val topNPushed = PushDownUtils.pushTopN(sHolder.builder, orders.toArray, limitValue) if (topNPushed) { sHolder.pushedLimit = Some(limitValue) sHolder.sortValues = orders + val localLimit = globalLimit.child.asInstanceOf[LocalLimit].copy(child = operation) + globalLimit.copy(child = localLimit) + } else { + globalLimit } - globalLimit - case Project(_, Sort(order, _, ScanOperation(_, filter, sHolder: ScanBuilderHolder))) - if filter.isEmpty => + case project @ Project(_, Sort(order, _, + operation @ ScanOperation(_, filter, sHolder: ScanBuilderHolder))) if filter.isEmpty => val orders = DataSourceStrategy.translateSortOrders(order) val topNPushed = PushDownUtils.pushTopN(sHolder.builder, orders.toArray, limitValue) if (topNPushed) { sHolder.pushedLimit = Some(limitValue) sHolder.sortValues = orders + val localLimit = globalLimit.child.asInstanceOf[LocalLimit] + .copy(child = project.copy(child = operation)) + globalLimit.copy(child = localLimit) + } else { + globalLimit } - globalLimit case _ => globalLimit } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 9fd481ee6eb59..a2777d14eccb7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -192,40 +192,50 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel val scan = df3.queryExecution.optimizedPlan.collectFirst { case s: DataSourceV2ScanRelation => s }.get - assert(scan.schema.names.sameElements(Seq("NAME", "SALARY"))) + assert(scan.schema.names.sameElements(Seq("NAME"))) val expectedSorts3 = Seq(SortValue(FieldReference("salary"), SortDirection.ASCENDING, NullOrdering.NULLS_LAST)) checkPushedTopN(df3, true, 1, expectedSorts3) checkAnswer(df3, Seq(Row("david"))) - val df4 = spark.read + val df4 = spark.read.table("h2.test.employee") + .where($"dept" === 1).orderBy($"salary") + checkPushedTopN(df4, false, 0) + checkAnswer(df4, Seq(Row(1, "cathy", 9000.00, 1200.0), Row(1, "amy", 10000.00, 1000.0))) + + val df5 = spark.read.table("h2.test.employee") + .where($"dept" === 1).limit(1) + 
checkPushedTopN(df5, false, 1) + checkAnswer(df5, Seq(Row(1, "amy", 10000.00, 1000.0))) + + val df6 = spark.read .table("h2.test.employee") .groupBy("DEPT").sum("SALARY") .orderBy("DEPT") .limit(1) - checkPushedTopN(df4, false, 0) - checkAnswer(df4, Seq(Row(1, 19000.00))) + checkPushedTopN(df6, false, 0) + checkAnswer(df6, Seq(Row(1, 19000.00))) - val df5 = spark.read + val df7 = spark.read .table("h2.test.employee") .sort("SALARY") .limit(1) val expectedSorts5 = Seq(SortValue(FieldReference("SALARY"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST)) - checkPushedTopN(df5, true, 1, expectedSorts5) - checkAnswer(df5, Seq(Row(1, "cathy", 9000.00, 1200.0))) + checkPushedTopN(df7, true, 1, expectedSorts5) + checkAnswer(df7, Seq(Row(1, "cathy", 9000.00, 1200.0))) val name = udf { (x: String) => x.matches("cat|dav|amy") } val sub = udf { (x: String) => x.substring(0, 3) } - val df6 = spark.read + val df8 = spark.read .table("h2.test.employee") .select($"SALARY", $"BONUS", sub($"NAME").as("shortName")) .filter(name($"shortName")) .sort($"SALARY".desc) .limit(1) // LIMIT is pushed down only if all the filters are pushed down - checkPushedTopN(df6, false, 0) - checkAnswer(df6, Seq(Row(10000.00, 1000.0, "amy"))) + checkPushedTopN(df8, false, 0) + checkAnswer(df8, Seq(Row(10000.00, 1000.0, "amy"))) } private def checkPushedTopN(df: DataFrame, pushed: Boolean, limit: Int = 0, @@ -236,6 +246,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel if (pushed) { assert(v1.pushedDownOperators.limit === Some(limit)) assert(v1.pushedDownOperators.sortValues === sortValues) + } else if (limit > 0) { + assert(v1.pushedDownOperators.limit === Some(limit)) } else { assert(v1.pushedDownOperators.limit.isEmpty) assert(v1.pushedDownOperators.sortValues.isEmpty) From bfdf849fe7732a8e35ed588a8f2b700e87be00bd Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Thu, 9 Dec 2021 21:18:32 +0800 Subject: [PATCH 09/18] Update sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java Co-authored-by: Wenchen Fan --- .../apache/spark/sql/connector/read/SupportsPushDownTopN.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java index 5fd6f55f5f403..d79b372d4e747 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java @@ -21,7 +21,7 @@ import org.apache.spark.sql.connector.expressions.SortOrder; /** - * A mix-in interface for {@link Scan}. Data sources can implement this interface to push down + * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to push down * top N(query with ORDER BY ... LIMIT n). Please note that the combination of top N with other * operations such as AGGREGATE, GROUP BY, CLUSTER BY, DISTRIBUTE BY, etc. is NOT pushed down. 
* From afc2bc63b936ec6b344d5e30c6c230f3e15841ff Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Thu, 9 Dec 2021 22:43:53 +0800 Subject: [PATCH 10/18] Update code --- .../connector/read/SupportsPushDownLimit.java | 2 +- .../sql/execution/DataSourceScanExec.scala | 16 +-- .../apache/spark/sql/jdbc/JDBCV2Suite.scala | 104 ++++++++---------- 3 files changed, 55 insertions(+), 67 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java index 7e50bf14d7817..fa6447bc068d5 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java @@ -20,7 +20,7 @@ import org.apache.spark.annotation.Evolving; /** - * A mix-in interface for {@link Scan}. Data sources can implement this interface to + * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to * push down LIMIT. Please note that the combination of LIMIT with other operations * such as AGGREGATE, GROUP BY, SORT BY, CLUSTER BY, DISTRIBUTE BY, etc. is NOT pushed down. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index ed50aefba0fa1..007f92a0c5e0b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -142,25 +142,25 @@ case class RowDataSourceScanExec( handledFilters } - val optionOutputMap = + val topNOrLimitInfo = if (pushedDownOperators.limit.isDefined && pushedDownOperators.sortValues.nonEmpty) { val pushedTopN = s""" |ORDER BY ${seqToString(pushedDownOperators.sortValues.map(_.describe()))} |LIMIT ${pushedDownOperators.limit.get} - |""".stripMargin - Map("pushedTopN" -> pushedTopN) + |""".stripMargin.replaceAll("\n", " ") + Some("pushedTopN" -> pushedTopN) } else { - pushedDownOperators.aggregation.fold(Map[String, String]()) { v => - Map("PushedAggregates" -> seqToString(v.aggregateExpressions), - "PushedGroupByColumns" -> seqToString(v.groupByColumns))} ++ - pushedDownOperators.limit.map(value => "PushedLimit" -> s"LIMIT $value") + pushedDownOperators.limit.map(value => "PushedLimit" -> s"LIMIT $value") } Map( "ReadSchema" -> requiredSchema.catalogString, "PushedFilters" -> seqToString(markedFilters.toSeq)) ++ - optionOutputMap ++ + pushedDownOperators.aggregation.fold(Map[String, String]()) { v => + Map("PushedAggregates" -> seqToString(v.aggregateExpressions), + "PushedGroupByColumns" -> seqToString(v.groupByColumns))} ++ + topNOrLimitInfo ++ pushedDownOperators.sample.map(v => "PushedSample" -> s"SAMPLE (${(v.upperBound - v.lowerBound) * 100}) ${v.withReplacement} SEED(${v.seed})" ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index a2777d14eccb7..b2f9a52cbb1d4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -104,14 +104,14 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel case s: DataSourceV2ScanRelation => s }.get assert(scan.schema.names.sameElements(Seq("NAME"))) - checkPushedLimit(df, true, 3) + checkPushedLimit(df, 
Some(3)) checkAnswer(df, Seq(Row("amy"), Row("alex"), Row("cathy"))) } test("simple scan with LIMIT") { val df1 = spark.read.table("h2.test.employee") .where($"dept" === 1).limit(1) - checkPushedLimit(df1, true, 1) + checkPushedLimit(df1, Some(1)) checkAnswer(df1, Seq(Row(1, "amy", 10000.00, 1000.0))) val df2 = spark.read @@ -122,7 +122,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .table("h2.test.employee") .filter($"dept" > 1) .limit(1) - checkPushedLimit(df2, true, 1) + checkPushedLimit(df2, Some(1)) checkAnswer(df2, Seq(Row(2, "alex", 12000.00, 1200.0))) val df3 = sql("SELECT name FROM h2.test.employee WHERE dept > 1 LIMIT 1") @@ -130,14 +130,14 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel case s: DataSourceV2ScanRelation => s }.get assert(scan.schema.names.sameElements(Seq("NAME"))) - checkPushedLimit(df3, true, 1) + checkPushedLimit(df3, Some(1)) checkAnswer(df3, Seq(Row("alex"))) val df4 = spark.read .table("h2.test.employee") .groupBy("DEPT").sum("SALARY") .limit(1) - checkPushedLimit(df4, false, 0) + checkPushedLimit(df4, None) checkAnswer(df4, Seq(Row(1, 19000.00))) val name = udf { (x: String) => x.matches("cat|dav|amy") } @@ -148,32 +148,33 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .filter(name($"shortName")) .limit(1) // LIMIT is pushed down only if all the filters are pushed down - checkPushedLimit(df5, false, 0) + checkPushedLimit(df5, None) checkAnswer(df5, Seq(Row(10000.00, 1000.0, "amy"))) } - private def checkPushedLimit(df: DataFrame, pushed: Boolean, limit: Int): Unit = { + private def checkPushedLimit(df: DataFrame, limit: Option[Int]): Unit = { df.queryExecution.optimizedPlan.collect { case relation: DataSourceV2ScanRelation => relation.scan match { case v1: V1ScanWrapper => - if (pushed) { - assert(v1.pushedDownOperators.limit === Some(limit)) - } else { - assert(v1.pushedDownOperators.limit.isEmpty) - } + assert(v1.pushedDownOperators.limit === limit) } } } test("simple scan with top N") { - val df1 = spark.read.table("h2.test.employee") - .where($"dept" === 1).orderBy($"salary").limit(1) - val expectedSorts1 = - Seq(SortValue(FieldReference("salary"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST)) - checkPushedTopN(df1, true, 1, expectedSorts1) + val df1 = spark.read + .table("h2.test.employee") + .sort("salary") + .limit(1) + checkPushedTopN(df1, Some(1), createSortValues()) checkAnswer(df1, Seq(Row(1, "cathy", 9000.00, 1200.0))) - val df2 = spark.read + val df2 = spark.read.table("h2.test.employee") + .where($"dept" === 1).orderBy($"salary").limit(1) + checkPushedTopN(df2, Some(1), createSortValues()) + checkAnswer(df2, Seq(Row(1, "cathy", 9000.00, 1200.0))) + + val df3 = spark.read .option("partitionColumn", "dept") .option("lowerBound", "0") .option("upperBound", "2") @@ -182,48 +183,36 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .filter($"dept" > 1) .orderBy($"salary".desc) .limit(1) - val expectedSorts2 = - Seq(SortValue(FieldReference("salary"), SortDirection.DESCENDING, NullOrdering.NULLS_LAST)) - checkPushedTopN(df2, true, 1, expectedSorts2) - checkAnswer(df2, Seq(Row(2, "alex", 12000.00, 1200.0))) + checkPushedTopN( + df3, Some(1), createSortValues(SortDirection.DESCENDING, NullOrdering.NULLS_LAST)) + checkAnswer(df3, Seq(Row(2, "alex", 12000.00, 1200.0))) - val df3 = + val df4 = sql("SELECT name FROM h2.test.employee WHERE dept > 1 ORDER BY salary NULLS LAST LIMIT 1") - val scan = 
df3.queryExecution.optimizedPlan.collectFirst { + val scan = df4.queryExecution.optimizedPlan.collectFirst { case s: DataSourceV2ScanRelation => s }.get assert(scan.schema.names.sameElements(Seq("NAME"))) - val expectedSorts3 = - Seq(SortValue(FieldReference("salary"), SortDirection.ASCENDING, NullOrdering.NULLS_LAST)) - checkPushedTopN(df3, true, 1, expectedSorts3) - checkAnswer(df3, Seq(Row("david"))) + checkPushedTopN(df4, Some(1), createSortValues(nullOrdering = NullOrdering.NULLS_LAST)) + checkAnswer(df4, Seq(Row("david"))) - val df4 = spark.read.table("h2.test.employee") + val df5 = spark.read.table("h2.test.employee") .where($"dept" === 1).orderBy($"salary") - checkPushedTopN(df4, false, 0) - checkAnswer(df4, Seq(Row(1, "cathy", 9000.00, 1200.0), Row(1, "amy", 10000.00, 1000.0))) + checkPushedTopN(df5, None, Seq.empty) + checkAnswer(df5, Seq(Row(1, "cathy", 9000.00, 1200.0), Row(1, "amy", 10000.00, 1000.0))) - val df5 = spark.read.table("h2.test.employee") + val df6 = spark.read.table("h2.test.employee") .where($"dept" === 1).limit(1) - checkPushedTopN(df5, false, 1) - checkAnswer(df5, Seq(Row(1, "amy", 10000.00, 1000.0))) + checkPushedTopN(df6, Some(1), Seq.empty) + checkAnswer(df6, Seq(Row(1, "amy", 10000.00, 1000.0))) - val df6 = spark.read + val df7 = spark.read .table("h2.test.employee") .groupBy("DEPT").sum("SALARY") .orderBy("DEPT") .limit(1) - checkPushedTopN(df6, false, 0) - checkAnswer(df6, Seq(Row(1, 19000.00))) - - val df7 = spark.read - .table("h2.test.employee") - .sort("SALARY") - .limit(1) - val expectedSorts5 = - Seq(SortValue(FieldReference("SALARY"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST)) - checkPushedTopN(df7, true, 1, expectedSorts5) - checkAnswer(df7, Seq(Row(1, "cathy", 9000.00, 1200.0))) + checkPushedTopN(df7) + checkAnswer(df7, Seq(Row(1, 19000.00))) val name = udf { (x: String) => x.matches("cat|dav|amy") } val sub = udf { (x: String) => x.substring(0, 3) } @@ -234,24 +223,23 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .sort($"SALARY".desc) .limit(1) // LIMIT is pushed down only if all the filters are pushed down - checkPushedTopN(df8, false, 0) + checkPushedTopN(df8) checkAnswer(df8, Seq(Row(10000.00, 1000.0, "amy"))) } - private def checkPushedTopN(df: DataFrame, pushed: Boolean, limit: Int = 0, - sortValues: Seq[SortValue] = Seq.empty): Unit = { + private def createSortValues( + sortDirection: SortDirection = SortDirection.ASCENDING, + nullOrdering: NullOrdering = NullOrdering.NULLS_FIRST): Seq[SortValue] = { + Seq(SortValue(FieldReference("salary"), sortDirection, nullOrdering)) + } + + private def checkPushedTopN(df: DataFrame, limit: Option[Int] = None, + sortValues: Seq[SortValue] = Seq.empty[SortValue]): Unit = { df.queryExecution.optimizedPlan.collect { case relation: DataSourceV2ScanRelation => relation.scan match { case v1: V1ScanWrapper => - if (pushed) { - assert(v1.pushedDownOperators.limit === Some(limit)) - assert(v1.pushedDownOperators.sortValues === sortValues) - } else if (limit > 0) { - assert(v1.pushedDownOperators.limit === Some(limit)) - } else { - assert(v1.pushedDownOperators.limit.isEmpty) - assert(v1.pushedDownOperators.sortValues.isEmpty) - } + assert(v1.pushedDownOperators.limit === limit) + assert(v1.pushedDownOperators.sortValues === sortValues) } } } From b24a804900272ccfb8c8124b2fbfe0531283628b Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Fri, 10 Dec 2021 10:09:09 +0800 Subject: [PATCH 11/18] Update code --- .../connector/read/SupportsPushDownTopN.java | 7 +- 
.../v2/V2ScanRelationPushDown.scala | 67 ++++++++++--------- 2 files changed, 38 insertions(+), 36 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java index d79b372d4e747..0212895fde079 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java @@ -21,9 +21,10 @@ import org.apache.spark.sql.connector.expressions.SortOrder; /** - * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to push down - * top N(query with ORDER BY ... LIMIT n). Please note that the combination of top N with other - * operations such as AGGREGATE, GROUP BY, CLUSTER BY, DISTRIBUTE BY, etc. is NOT pushed down. + * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to + * push down top N(query with ORDER BY ... LIMIT n). Please note that the combination of top N + * with other operations such as AGGREGATE, GROUP BY, CLUSTER BY, DISTRIBUTE BY, etc. + * is NOT pushed down. * * @since 3.3.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 91d50483566a3..5472795abc352 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -247,41 +247,42 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { } } + private def pushDownLimit(plan: LogicalPlan, limit: Int): LogicalPlan = plan match { + case operation @ ScanOperation(_, filter, sHolder: ScanBuilderHolder) if filter.isEmpty => + val limitPushed = PushDownUtils.pushLimit(sHolder.builder, limit) + if (limitPushed) { + sHolder.pushedLimit = Some(limit) + } + operation + case s @ Sort(order, _, operation @ ScanOperation(_, filter, sHolder: ScanBuilderHolder)) + if filter.isEmpty => + val orders = DataSourceStrategy.translateSortOrders(order) + val topNPushed = PushDownUtils.pushTopN(sHolder.builder, orders.toArray, limit) + if (topNPushed) { + sHolder.pushedLimit = Some(limit) + sHolder.sortValues = orders + operation + } else { + s + } + case p: Project => + val newChild = pushDownLimit(p.child, limit) + if (newChild == p.child) { + p + } else { + p.copy(child = newChild) + } + case other => other + } + def applyLimit(plan: LogicalPlan): LogicalPlan = plan.transform { case globalLimit @ Limit(IntegerLiteral(limitValue), child) => - child match { - case ScanOperation(_, filter, sHolder: ScanBuilderHolder) if filter.isEmpty => - val limitPushed = PushDownUtils.pushLimit(sHolder.builder, limitValue) - if (limitPushed) { - sHolder.pushedLimit = Some(limitValue) - } - globalLimit - case Sort(order, _, operation @ ScanOperation(_, filter, sHolder: ScanBuilderHolder)) - if filter.isEmpty => - val orders = DataSourceStrategy.translateSortOrders(order) - val topNPushed = PushDownUtils.pushTopN(sHolder.builder, orders.toArray, limitValue) - if (topNPushed) { - sHolder.pushedLimit = Some(limitValue) - sHolder.sortValues = orders - val localLimit = globalLimit.child.asInstanceOf[LocalLimit].copy(child = operation) - globalLimit.copy(child = localLimit) - } else { - globalLimit - 
} - case project @ Project(_, Sort(order, _, - operation @ ScanOperation(_, filter, sHolder: ScanBuilderHolder))) if filter.isEmpty => - val orders = DataSourceStrategy.translateSortOrders(order) - val topNPushed = PushDownUtils.pushTopN(sHolder.builder, orders.toArray, limitValue) - if (topNPushed) { - sHolder.pushedLimit = Some(limitValue) - sHolder.sortValues = orders - val localLimit = globalLimit.child.asInstanceOf[LocalLimit] - .copy(child = project.copy(child = operation)) - globalLimit.copy(child = localLimit) - } else { - globalLimit - } - case _ => globalLimit + val newChild = pushDownLimit(child, limitValue) + if (newChild == child) { + globalLimit + } else { + val localLimit = globalLimit.child.asInstanceOf[LocalLimit].copy(child = newChild) + globalLimit.copy(child = localLimit) } } From 29e8265313bdd451464f22294a1ae4a5407505f9 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Mon, 13 Dec 2021 18:35:15 +0800 Subject: [PATCH 12/18] Update code --- .../sql/execution/DataSourceScanExec.scala | 8 ++--- .../datasources/jdbc/JDBCOptions.scala | 4 --- .../v2/V2ScanRelationPushDown.scala | 14 ++------ .../datasources/v2/jdbc/JDBCScanBuilder.scala | 2 +- .../apache/spark/sql/jdbc/JDBCV2Suite.scala | 33 +++++++------------ 5 files changed, 19 insertions(+), 42 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 007f92a0c5e0b..9304566801278 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -145,13 +145,11 @@ case class RowDataSourceScanExec( val topNOrLimitInfo = if (pushedDownOperators.limit.isDefined && pushedDownOperators.sortValues.nonEmpty) { val pushedTopN = - s""" - |ORDER BY ${seqToString(pushedDownOperators.sortValues.map(_.describe()))} - |LIMIT ${pushedDownOperators.limit.get} - |""".stripMargin.replaceAll("\n", " ") + s"ORDER BY ${seqToString(pushedDownOperators.sortValues.map(_.describe()))}" + + s" LIMIT ${pushedDownOperators.limit.get}" Some("pushedTopN" -> pushedTopN) } else { - pushedDownOperators.limit.map(value => "PushedLimit" -> s"LIMIT $value") + pushedDownOperators.limit.map(value => "PushedLimit" -> s"LIMIT $value") } Map( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index eb8841b5b699a..fc7353ad889d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -196,10 +196,6 @@ class JDBCOptions( // This only applies to Data Source V2 JDBC val pushDownLimit = parameters.getOrElse(JDBC_PUSHDOWN_LIMIT, "false").toBoolean - // An option to allow/disallow pushing down query of top N into V2 JDBC data source - // This only applies to Data Source V2 JDBC - val pushDownTopN = parameters.getOrElse(JDBC_PUSHDOWN_TOP_N, "false").toBoolean - // An option to allow/disallow pushing down TABLESAMPLE into JDBC data source // This only applies to Data Source V2 JDBC val pushDownTableSample = parameters.getOrElse(JDBC_PUSHDOWN_TABLESAMPLE, "false").toBoolean diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 5472795abc352..99d43b64d8e6a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -267,23 +267,15 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { } case p: Project => val newChild = pushDownLimit(p.child, limit) - if (newChild == p.child) { - p - } else { - p.copy(child = newChild) - } + p.withNewChildren(Seq(newChild)) case other => other } def applyLimit(plan: LogicalPlan): LogicalPlan = plan.transform { case globalLimit @ Limit(IntegerLiteral(limitValue), child) => val newChild = pushDownLimit(child, limitValue) - if (newChild == child) { - globalLimit - } else { - val localLimit = globalLimit.child.asInstanceOf[LocalLimit].copy(child = newChild) - globalLimit.copy(child = localLimit) - } + val newLocalLimit = globalLimit.child.asInstanceOf[LocalLimit].withNewChildren(Seq(newChild)) + globalLimit } private def getWrappedScan( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala index 83ea1317552ba..fd01e67984bbb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala @@ -131,7 +131,7 @@ case class JDBCScanBuilder( } override def pushTopN(orders: Array[SortOrder], limit: Int): Boolean = { - if (jdbcOptions.pushDownTopN) { + if (jdbcOptions.pushDownLimit) { pushedLimit = limit sortValues = orders return true diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index b2f9a52cbb1d4..66449cce7c75f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -152,11 +152,13 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAnswer(df5, Seq(Row(10000.00, 1000.0, "amy"))) } - private def checkPushedLimit(df: DataFrame, limit: Option[Int]): Unit = { + private def checkPushedLimit(df: DataFrame, limit: Option[Int] = None, + sortValues: Seq[SortValue] = Nil): Unit = { df.queryExecution.optimizedPlan.collect { case relation: DataSourceV2ScanRelation => relation.scan match { case v1: V1ScanWrapper => assert(v1.pushedDownOperators.limit === limit) + assert(v1.pushedDownOperators.sortValues === sortValues) } } } @@ -166,12 +168,12 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .table("h2.test.employee") .sort("salary") .limit(1) - checkPushedTopN(df1, Some(1), createSortValues()) + checkPushedLimit(df1, Some(1), createSortValues()) checkAnswer(df1, Seq(Row(1, "cathy", 9000.00, 1200.0))) val df2 = spark.read.table("h2.test.employee") .where($"dept" === 1).orderBy($"salary").limit(1) - checkPushedTopN(df2, Some(1), createSortValues()) + checkPushedLimit(df2, Some(1), createSortValues()) checkAnswer(df2, Seq(Row(1, "cathy", 9000.00, 1200.0))) val df3 = spark.read @@ -183,7 +185,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .filter($"dept" > 1) .orderBy($"salary".desc) .limit(1) - 
checkPushedTopN( + checkPushedLimit( df3, Some(1), createSortValues(SortDirection.DESCENDING, NullOrdering.NULLS_LAST)) checkAnswer(df3, Seq(Row(2, "alex", 12000.00, 1200.0))) @@ -192,18 +194,18 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel val scan = df4.queryExecution.optimizedPlan.collectFirst { case s: DataSourceV2ScanRelation => s }.get - assert(scan.schema.names.sameElements(Seq("NAME"))) - checkPushedTopN(df4, Some(1), createSortValues(nullOrdering = NullOrdering.NULLS_LAST)) + assert(scan.schema.names.sameElements(Seq("NAME", "SALARY"))) + checkPushedLimit(df4, Some(1), createSortValues(nullOrdering = NullOrdering.NULLS_LAST)) checkAnswer(df4, Seq(Row("david"))) val df5 = spark.read.table("h2.test.employee") .where($"dept" === 1).orderBy($"salary") - checkPushedTopN(df5, None, Seq.empty) + checkPushedLimit(df5, None, Seq.empty) checkAnswer(df5, Seq(Row(1, "cathy", 9000.00, 1200.0), Row(1, "amy", 10000.00, 1000.0))) val df6 = spark.read.table("h2.test.employee") .where($"dept" === 1).limit(1) - checkPushedTopN(df6, Some(1), Seq.empty) + checkPushedLimit(df6, Some(1), Seq.empty) checkAnswer(df6, Seq(Row(1, "amy", 10000.00, 1000.0))) val df7 = spark.read @@ -211,7 +213,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .groupBy("DEPT").sum("SALARY") .orderBy("DEPT") .limit(1) - checkPushedTopN(df7) + checkPushedLimit(df7) checkAnswer(df7, Seq(Row(1, 19000.00))) val name = udf { (x: String) => x.matches("cat|dav|amy") } @@ -223,7 +225,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .sort($"SALARY".desc) .limit(1) // LIMIT is pushed down only if all the filters are pushed down - checkPushedTopN(df8) + checkPushedLimit(df8) checkAnswer(df8, Seq(Row(10000.00, 1000.0, "amy"))) } @@ -233,17 +235,6 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel Seq(SortValue(FieldReference("salary"), sortDirection, nullOrdering)) } - private def checkPushedTopN(df: DataFrame, limit: Option[Int] = None, - sortValues: Seq[SortValue] = Seq.empty[SortValue]): Unit = { - df.queryExecution.optimizedPlan.collect { - case relation: DataSourceV2ScanRelation => relation.scan match { - case v1: V1ScanWrapper => - assert(v1.pushedDownOperators.limit === limit) - assert(v1.pushedDownOperators.sortValues === sortValues) - } - } - } - test("scan with filter push-down") { val df = spark.table("h2.test.people").filter($"id" > 1) val filters = df.queryExecution.optimizedPlan.collect { From 0682f062e6d8a59bb09b5c2f0cecda06e06d8685 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Mon, 13 Dec 2021 18:39:16 +0800 Subject: [PATCH 13/18] Update code --- .../spark/sql/execution/datasources/jdbc/JDBCOptions.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index fc7353ad889d5..d081e0ace0e44 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -276,7 +276,6 @@ object JDBCOptions { val JDBC_PUSHDOWN_PREDICATE = newOption("pushDownPredicate") val JDBC_PUSHDOWN_AGGREGATE = newOption("pushDownAggregate") val JDBC_PUSHDOWN_LIMIT = newOption("pushDownLimit") - val JDBC_PUSHDOWN_TOP_N = newOption("pushDownTopN") val JDBC_PUSHDOWN_TABLESAMPLE = 
newOption("pushDownTableSample") val JDBC_KEYTAB = newOption("keytab") val JDBC_PRINCIPAL = newOption("principal") From 8f11248db6d41dad41cab4d4ea784f296a802348 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Wed, 15 Dec 2021 09:44:09 +0800 Subject: [PATCH 14/18] Update code --- .../v2/V2ScanRelationPushDown.scala | 2 +- .../apache/spark/sql/jdbc/JDBCV2Suite.scala | 42 ++++++++++++------- 2 files changed, 28 insertions(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 99d43b64d8e6a..c977c3cf66291 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -275,7 +275,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { case globalLimit @ Limit(IntegerLiteral(limitValue), child) => val newChild = pushDownLimit(child, limitValue) val newLocalLimit = globalLimit.child.asInstanceOf[LocalLimit].withNewChildren(Seq(newChild)) - globalLimit + globalLimit.withNewChildren(Seq(newLocalLimit)) } private def getWrappedScan( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 66449cce7c75f..2dbf3c97605b0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -23,7 +23,7 @@ import java.util.Properties import org.apache.spark.SparkConf import org.apache.spark.sql.{DataFrame, ExplainSuiteHelper, QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException -import org.apache.spark.sql.catalyst.plans.logical.Filter +import org.apache.spark.sql.catalyst.plans.logical.{Filter, Sort} import org.apache.spark.sql.connector.expressions.{FieldReference, NullOrdering, SortDirection, SortValue} import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper} import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog @@ -44,7 +44,6 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .set("spark.sql.catalog.h2.driver", "org.h2.Driver") .set("spark.sql.catalog.h2.pushDownAggregate", "true") .set("spark.sql.catalog.h2.pushDownLimit", "true") - .set("spark.sql.catalog.h2.pushDownTopN", "true") private def withConnection[T](f: Connection => T): T = { val conn = DriverManager.getConnection(url, new Properties()) @@ -168,11 +167,13 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .table("h2.test.employee") .sort("salary") .limit(1) + checkSortRemoved(df1) checkPushedLimit(df1, Some(1), createSortValues()) checkAnswer(df1, Seq(Row(1, "cathy", 9000.00, 1200.0))) val df2 = spark.read.table("h2.test.employee") .where($"dept" === 1).orderBy($"salary").limit(1) + checkSortRemoved(df2) checkPushedLimit(df2, Some(1), createSortValues()) checkAnswer(df2, Seq(Row(1, "cathy", 9000.00, 1200.0))) @@ -185,6 +186,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .filter($"dept" > 1) .orderBy($"salary".desc) .limit(1) + checkSortRemoved(df3) checkPushedLimit( df3, Some(1), createSortValues(SortDirection.DESCENDING, NullOrdering.NULLS_LAST)) checkAnswer(df3, Seq(Row(2, "alex", 
12000.00, 1200.0))) @@ -194,39 +196,49 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel val scan = df4.queryExecution.optimizedPlan.collectFirst { case s: DataSourceV2ScanRelation => s }.get - assert(scan.schema.names.sameElements(Seq("NAME", "SALARY"))) + assert(scan.schema.names.sameElements(Seq("NAME"))) + checkSortRemoved(df4) checkPushedLimit(df4, Some(1), createSortValues(nullOrdering = NullOrdering.NULLS_LAST)) checkAnswer(df4, Seq(Row("david"))) val df5 = spark.read.table("h2.test.employee") .where($"dept" === 1).orderBy($"salary") - checkPushedLimit(df5, None, Seq.empty) + checkSortRemoved(df5, false) + checkPushedLimit(df5, None) checkAnswer(df5, Seq(Row(1, "cathy", 9000.00, 1200.0), Row(1, "amy", 10000.00, 1000.0))) - val df6 = spark.read.table("h2.test.employee") - .where($"dept" === 1).limit(1) - checkPushedLimit(df6, Some(1), Seq.empty) - checkAnswer(df6, Seq(Row(1, "amy", 10000.00, 1000.0))) - - val df7 = spark.read + val df6 = spark.read .table("h2.test.employee") .groupBy("DEPT").sum("SALARY") .orderBy("DEPT") .limit(1) - checkPushedLimit(df7) - checkAnswer(df7, Seq(Row(1, 19000.00))) + checkSortRemoved(df6, false) + checkPushedLimit(df6) + checkAnswer(df6, Seq(Row(1, 19000.00))) val name = udf { (x: String) => x.matches("cat|dav|amy") } val sub = udf { (x: String) => x.substring(0, 3) } - val df8 = spark.read + val df7 = spark.read .table("h2.test.employee") .select($"SALARY", $"BONUS", sub($"NAME").as("shortName")) .filter(name($"shortName")) .sort($"SALARY".desc) .limit(1) + checkSortRemoved(df7, false) // LIMIT is pushed down only if all the filters are pushed down - checkPushedLimit(df8) - checkAnswer(df8, Seq(Row(10000.00, 1000.0, "amy"))) + checkPushedLimit(df7) + checkAnswer(df7, Seq(Row(10000.00, 1000.0, "amy"))) + } + + private def checkSortRemoved(df: DataFrame, removed: Boolean = true): Unit = { + val sorts = df.queryExecution.optimizedPlan.collect { + case s: Sort => s + } + if (removed) { + assert(sorts.isEmpty) + } else { + assert(sorts.nonEmpty) + } } private def createSortValues( From c44097027ffffaa0ca753284771c5effb9159d20 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Wed, 15 Dec 2021 13:50:12 +0800 Subject: [PATCH 15/18] Update code --- .../apache/spark/sql/jdbc/JDBCV2Suite.scala | 24 +++++-------------- 1 file changed, 6 insertions(+), 18 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 2dbf3c97605b0..f999bb0edb58b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -160,6 +160,12 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel assert(v1.pushedDownOperators.sortValues === sortValues) } } + if (sortValues.nonEmpty) { + val sorts = df.queryExecution.optimizedPlan.collect { + case s: Sort => s + } + assert(sorts.isEmpty) + } } test("simple scan with top N") { @@ -167,13 +173,11 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .table("h2.test.employee") .sort("salary") .limit(1) - checkSortRemoved(df1) checkPushedLimit(df1, Some(1), createSortValues()) checkAnswer(df1, Seq(Row(1, "cathy", 9000.00, 1200.0))) val df2 = spark.read.table("h2.test.employee") .where($"dept" === 1).orderBy($"salary").limit(1) - checkSortRemoved(df2) checkPushedLimit(df2, Some(1), createSortValues()) checkAnswer(df2, Seq(Row(1, "cathy", 9000.00, 
1200.0)))
@@ -186,7 +190,6 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       .filter($"dept" > 1)
       .orderBy($"salary".desc)
       .limit(1)
-    checkSortRemoved(df3)
     checkPushedLimit(
       df3, Some(1), createSortValues(SortDirection.DESCENDING, NullOrdering.NULLS_LAST))
     checkAnswer(df3, Seq(Row(2, "alex", 12000.00, 1200.0)))
@@ -197,13 +200,11 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       case s: DataSourceV2ScanRelation => s
     }.get
     assert(scan.schema.names.sameElements(Seq("NAME")))
-    checkSortRemoved(df4)
     checkPushedLimit(df4, Some(1), createSortValues(nullOrdering = NullOrdering.NULLS_LAST))
     checkAnswer(df4, Seq(Row("david")))
 
     val df5 = spark.read.table("h2.test.employee")
       .where($"dept" === 1).orderBy($"salary")
-    checkSortRemoved(df5, false)
     checkPushedLimit(df5, None)
     checkAnswer(df5,
       Seq(Row(1, "cathy", 9000.00, 1200.0), Row(1, "amy", 10000.00, 1000.0)))
@@ -212,7 +213,6 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       .groupBy("DEPT").sum("SALARY")
       .orderBy("DEPT")
       .limit(1)
-    checkSortRemoved(df6, false)
     checkPushedLimit(df6)
     checkAnswer(df6, Seq(Row(1, 19000.00)))
 
@@ -224,23 +224,11 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       .filter(name($"shortName"))
       .sort($"SALARY".desc)
       .limit(1)
-    checkSortRemoved(df7, false)
     // LIMIT is pushed down only if all the filters are pushed down
     checkPushedLimit(df7)
     checkAnswer(df7, Seq(Row(10000.00, 1000.0, "amy")))
   }
 
-  private def checkSortRemoved(df: DataFrame, removed: Boolean = true): Unit = {
-    val sorts = df.queryExecution.optimizedPlan.collect {
-      case s: Sort => s
-    }
-    if (removed) {
-      assert(sorts.isEmpty)
-    } else {
-      assert(sorts.nonEmpty)
-    }
-  }
-
   private def createSortValues(
       sortDirection: SortDirection = SortDirection.ASCENDING,
       nullOrdering: NullOrdering = NullOrdering.NULLS_FIRST): Seq[SortValue] = {

From 5a2f8b851dae13e9d29d06e8380217a835280f01 Mon Sep 17 00:00:00 2001
From: Jiaan Geng
Date: Thu, 16 Dec 2021 13:33:54 +0800
Subject: [PATCH 16/18] Update code

---
 .../sql/execution/datasources/DataSourceStrategy.scala | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index ae30b5eeef862..c296ba9f29dd3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -727,7 +727,7 @@ object DataSourceStrategy
   }
 
   protected[sql] def translateSortOrders(sortOrders: Seq[SortOrder]): Seq[SortOrderV2] = {
-    sortOrders.map {
+    def translateSortOrder(sortOrder: SortOrder): Option[SortOrderV2] = sortOrder match {
       case SortOrder(PushableColumnWithoutNestedColumn(name), directionV1, nullOrderingV1, _) =>
         val directionV2 = directionV1 match {
           case Ascending => SortDirection.ASCENDING
@@ -737,8 +737,11 @@ object DataSourceStrategy
           case NullsFirst => NullOrdering.NULLS_FIRST
           case NullsLast => NullOrdering.NULLS_LAST
         }
-        SortValue(FieldReference(name), directionV2, nullOrderingV2)
+        Some(SortValue(FieldReference(name), directionV2, nullOrderingV2))
+      case _ => None
     }
+
+    sortOrders.flatMap(translateSortOrder)
   }
 
   /**

From 31d4ea52ed5ba9e50cfd29a66799e758da9effc3 Mon Sep 17 00:00:00 2001
From: Jiaan Geng
Date: Thu, 16 Dec 2021 16:01:54 +0800
Subject: [PATCH 17/18] Update code
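
Only push down top N when every SortOrder of the Sort could be translated to a data
source sort expression; otherwise keep the Sort (and the Limit above it) in the Spark
plan. A minimal sketch of the two cases, assuming a SparkSession named spark with the
H2 catalog registered as in JDBCV2Suite (the UDF mirrors the one used in the test):

    import org.apache.spark.sql.functions.udf
    import spark.implicits._

    // Sort on a plain column: both the ORDER BY and the LIMIT are pushed to the JDBC source.
    spark.read.table("h2.test.employee").sort("salary").limit(1)

    // Sort on a derived expression (a UDF result): that SortOrder cannot be translated,
    // orders.length != order.length, so the Sort and the final Limit stay in Spark.
    val sub = udf { (x: String) => x.substring(0, 3) }
    spark.read.table("h2.test.employee").sort(sub($"NAME")).limit(1)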
--- .../datasources/v2/V2ScanRelationPushDown.scala | 14 +++++++++----- .../org/apache/spark/sql/jdbc/JDBCV2Suite.scala | 7 +++++++ 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index c977c3cf66291..fb5ee3312b52d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -257,11 +257,15 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { case s @ Sort(order, _, operation @ ScanOperation(_, filter, sHolder: ScanBuilderHolder)) if filter.isEmpty => val orders = DataSourceStrategy.translateSortOrders(order) - val topNPushed = PushDownUtils.pushTopN(sHolder.builder, orders.toArray, limit) - if (topNPushed) { - sHolder.pushedLimit = Some(limit) - sHolder.sortValues = orders - operation + if (orders.length == order.length) { + val topNPushed = PushDownUtils.pushTopN(sHolder.builder, orders.toArray, limit) + if (topNPushed) { + sHolder.pushedLimit = Some(limit) + sHolder.sortValues = orders + operation + } else { + s + } } else { s } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index f999bb0edb58b..5df875b569669 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -227,6 +227,13 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel // LIMIT is pushed down only if all the filters are pushed down checkPushedLimit(df7) checkAnswer(df7, Seq(Row(10000.00, 1000.0, "amy"))) + + val df8 = spark.read + .table("h2.test.employee") + .sort(sub($"NAME")) + .limit(1) + checkPushedLimit(df8) + checkAnswer(df8, Seq(Row(2, "alex", 12000.00, 1200.0))) } private def createSortValues( From bbecf9d956db74c0815f7fe6928befd17d31893c Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Thu, 16 Dec 2021 19:53:11 +0800 Subject: [PATCH 18/18] Update code --- .../sql/execution/datasources/jdbc/JDBCRDD.scala | 12 ++++++------ .../execution/datasources/jdbc/JDBCRelation.scala | 4 ++-- .../datasources/v2/V2ScanRelationPushDown.scala | 6 +++--- .../sql/execution/datasources/v2/jdbc/JDBCScan.scala | 4 ++-- .../datasources/v2/jdbc/JDBCScanBuilder.scala | 6 +++--- 5 files changed, 16 insertions(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index e008c81858c6f..baee53847a5a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -25,7 +25,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.SortOrder +import org.apache.spark.sql.connector.expressions.SortOrder import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects} import 
org.apache.spark.sql.sources._ @@ -171,7 +171,7 @@ object JDBCRDD extends Logging { groupByColumns: Option[Array[String]] = None, sample: Option[TableSampleInfo] = None, limit: Int = 0, - sortValues: Array[SortOrder] = Array.empty[SortOrder]): RDD[InternalRow] = { + sortOrders: Array[SortOrder] = Array.empty[SortOrder]): RDD[InternalRow] = { val url = options.url val dialect = JdbcDialects.get(url) val quotedColumns = if (groupByColumns.isEmpty) { @@ -192,7 +192,7 @@ object JDBCRDD extends Logging { groupByColumns, sample, limit, - sortValues) + sortOrders) } // scalastyle:on argcount } @@ -214,7 +214,7 @@ private[jdbc] class JDBCRDD( groupByColumns: Option[Array[String]], sample: Option[TableSampleInfo], limit: Int, - sortValues: Array[SortOrder]) + sortOrders: Array[SortOrder]) extends RDD[InternalRow](sc, Nil) { /** @@ -263,8 +263,8 @@ private[jdbc] class JDBCRDD( } private def getOrderByClause: String = { - if (sortValues.nonEmpty) { - s" ORDER BY ${sortValues.map(_.describe()).mkString(", ")}" + if (sortOrders.nonEmpty) { + s" ORDER BY ${sortOrders.map(_.describe()).mkString(", ")}" } else { "" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index 1f9c3e0fa73b0..ecb207363cd59 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -303,7 +303,7 @@ private[sql] case class JDBCRelation( groupByColumns: Option[Array[String]], tableSample: Option[TableSampleInfo], limit: Int, - sortValues: Array[SortOrder]): RDD[Row] = { + sortOrders: Array[SortOrder]): RDD[Row] = { // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row] JDBCRDD.scanTable( sparkSession.sparkContext, @@ -316,7 +316,7 @@ private[sql] case class JDBCRelation( groupByColumns, tableSample, limit, - sortValues).asInstanceOf[RDD[Row]] + sortOrders).asInstanceOf[RDD[Row]] } override def insert(data: DataFrame, overwrite: Boolean): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index fb5ee3312b52d..e7c06d0b7520e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -261,7 +261,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { val topNPushed = PushDownUtils.pushTopN(sHolder.builder, orders.toArray, limit) if (topNPushed) { sHolder.pushedLimit = Some(limit) - sHolder.sortValues = orders + sHolder.sortOrders = orders operation } else { s @@ -294,7 +294,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { case _ => Array.empty[sources.Filter] } val pushedDownOperators = PushedDownOperators(aggregation, - sHolder.pushedSample, sHolder.pushedLimit, sHolder.sortValues) + sHolder.pushedSample, sHolder.pushedLimit, sHolder.sortOrders) V1ScanWrapper(v1, pushedFilters, pushedDownOperators) case _ => scan } @@ -307,7 +307,7 @@ case class ScanBuilderHolder( builder: ScanBuilder) extends LeafNode { var pushedLimit: Option[Int] = None - var sortValues: Seq[SortOrder] = Seq.empty[SortOrder] + var sortOrders: Seq[SortOrder] = 
Seq.empty[SortOrder] var pushedSample: Option[TableSampleInfo] = None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala index 0bb3812e5a33a..87ec9f43804e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala @@ -33,7 +33,7 @@ case class JDBCScan( groupByColumns: Option[Array[String]], tableSample: Option[TableSampleInfo], pushedLimit: Int, - sortValues: Array[SortOrder]) extends V1Scan { + sortOrders: Array[SortOrder]) extends V1Scan { override def readSchema(): StructType = prunedSchema @@ -49,7 +49,7 @@ case class JDBCScan( pushedAggregateColumn } relation.buildScan(columnList, prunedSchema, pushedFilters, groupByColumns, tableSample, - pushedLimit, sortValues) + pushedLimit, sortOrders) } }.asInstanceOf[T] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala index fd01e67984bbb..1760122133d22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala @@ -53,7 +53,7 @@ case class JDBCScanBuilder( private var pushedLimit = 0 - private var sortValues: Array[SortOrder] = Array.empty[SortOrder] + private var sortOrders: Array[SortOrder] = Array.empty[SortOrder] override def pushFilters(filters: Array[Filter]): Array[Filter] = { if (jdbcOptions.pushDownPredicate) { @@ -133,7 +133,7 @@ case class JDBCScanBuilder( override def pushTopN(orders: Array[SortOrder], limit: Int): Boolean = { if (jdbcOptions.pushDownLimit) { pushedLimit = limit - sortValues = orders + sortOrders = orders return true } false @@ -164,6 +164,6 @@ case class JDBCScanBuilder( // prunedSchema and quote them (will become "MAX(SALARY)", "MIN(BONUS)" and can't // be used in sql string. JDBCScan(JDBCRelation(schema, parts, jdbcOptions)(session), finalSchema, pushedFilter, - pushedAggregateList, pushedGroupByCols, tableSample, pushedLimit, sortValues) + pushedAggregateList, pushedGroupByCols, tableSample, pushedLimit, sortOrders) } }
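
For connector authors, the JDBCScanBuilder changes above show the shape of the new hook as
it stands at the end of this series: pushTopN receives the translated sort orders together
with the limit. A minimal sketch of a custom DSv2 builder adopting it (the class name is
hypothetical; error handling and the actual Scan construction are omitted):

    import org.apache.spark.sql.connector.expressions.SortOrder
    import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownLimit}

    class MyTopNScanBuilder extends SupportsPushDownLimit {
      private var pushedLimit: Int = 0
      private var pushedOrders: Array[SortOrder] = Array.empty[SortOrder]

      // Plain LIMIT without a sort.
      override def pushLimit(limit: Int): Boolean = {
        pushedLimit = limit
        true
      }

      // Called by V2ScanRelationPushDown only when a Limit over a Sort sits directly on the
      // scan and every SortOrder was translated; return false to keep the Sort in Spark.
      override def pushTopN(orders: Array[SortOrder], limit: Int): Boolean = {
        pushedOrders = orders
        pushedLimit = limit
        true
      }

      // A real builder would render the pushed orders into the remote query, much as
      // JDBCRDD.getOrderByClause builds the ORDER BY clause above.
      override def build(): Scan =
        throw new UnsupportedOperationException("sketch only")
    }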