diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala index 5ac9a5191b010..2cfb21395d8a9 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala @@ -18,13 +18,14 @@ package org.apache.spark.sql.jdbc.v2 import java.sql.Connection +import java.util.Locale import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkConf import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog -import org.apache.spark.sql.jdbc.{DatabaseOnDocker, DockerJDBCIntegrationSuite} +import org.apache.spark.sql.jdbc.DatabaseOnDocker import org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest @@ -36,8 +37,9 @@ import org.apache.spark.tags.DockerTest * }}} */ @DockerTest -class DB2IntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTest { +class DB2IntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { override val catalogName: String = "db2" + override val namespaceOpt: Option[String] = Some("DB2INST1") override val db = new DatabaseOnDocker { override val imageName = sys.env.getOrElse("DB2_DOCKER_IMAGE_NAME", "ibmcom/db2:11.5.6.0a") override val env = Map( @@ -59,8 +61,13 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTest { override def sparkConf: SparkConf = super.sparkConf .set("spark.sql.catalog.db2", classOf[JDBCTableCatalog].getName) .set("spark.sql.catalog.db2.url", db.getJdbcUrl(dockerIp, externalPort)) + .set("spark.sql.catalog.db2.pushDownAggregate", "true") - override def dataPreparation(conn: Connection): Unit = {} + override def tablePreparation(connection: Connection): Unit = { + connection.prepareStatement( + "CREATE TABLE employee (dept INTEGER, name VARCHAR(10), salary DECIMAL(20, 2), bonus DOUBLE)") + .executeUpdate() + } override def testUpdateColumnType(tbl: String): Unit = { sql(s"CREATE TABLE $tbl (ID INTEGER)") @@ -86,4 +93,8 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTest { val expectedSchema = new StructType().add("ID", IntegerType, true, defaultMetadata) assert(t.schema === expectedSchema) } + + override def caseConvert(tableName: String): String = tableName.toUpperCase(Locale.ROOT) + + testVarPop() } diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DockerJDBCIntegrationV2Suite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DockerJDBCIntegrationV2Suite.scala new file mode 100644 index 0000000000000..72edfc9f1bf1c --- /dev/null +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DockerJDBCIntegrationV2Suite.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc.v2 + +import java.sql.Connection + +import org.apache.spark.sql.jdbc.DockerJDBCIntegrationSuite + +abstract class DockerJDBCIntegrationV2Suite extends DockerJDBCIntegrationSuite { + + /** + * Prepare databases and tables for testing. + */ + override def dataPreparation(connection: Connection): Unit = { + tablePreparation(connection) + connection.prepareStatement("INSERT INTO employee VALUES (1, 'amy', 10000, 1000)") + .executeUpdate() + connection.prepareStatement("INSERT INTO employee VALUES (2, 'alex', 12000, 1200)") + .executeUpdate() + connection.prepareStatement("INSERT INTO employee VALUES (1, 'cathy', 9000, 1200)") + .executeUpdate() + connection.prepareStatement("INSERT INTO employee VALUES (2, 'david', 10000, 1300)") + .executeUpdate() + connection.prepareStatement("INSERT INTO employee VALUES (6, 'jen', 12000, 1200)") + .executeUpdate() + } + + def tablePreparation(connection: Connection): Unit +} diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala index 75446fb50e45b..e9521ec35a8ce 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala @@ -24,7 +24,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkConf import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog -import org.apache.spark.sql.jdbc.{DatabaseOnDocker, DockerJDBCIntegrationSuite} +import org.apache.spark.sql.jdbc.DatabaseOnDocker import org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest @@ -37,7 +37,7 @@ import org.apache.spark.tags.DockerTest * }}} */ @DockerTest -class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTest { +class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { override val catalogName: String = "mssql" @@ -58,10 +58,15 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBC override def sparkConf: SparkConf = super.sparkConf .set("spark.sql.catalog.mssql", classOf[JDBCTableCatalog].getName) .set("spark.sql.catalog.mssql.url", db.getJdbcUrl(dockerIp, externalPort)) + .set("spark.sql.catalog.mssql.pushDownAggregate", "true") override val connectionTimeout = timeout(7.minutes) - override def dataPreparation(conn: Connection): Unit = {} + override def tablePreparation(connection: Connection): Unit = { + connection.prepareStatement( + "CREATE TABLE employee (dept INT, name VARCHAR(32), salary NUMERIC(20, 2), bonus FLOAT)") + .executeUpdate() + } override def notSupportsTableComment: Boolean = true @@ -91,4 +96,9 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBC assert(msg.contains("UpdateColumnNullability is not supported")) } + + testVarPop() + testVarSamp() + 
testStddevPop() + testStddevSamp() } diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala index 71adc51b87441..bc4bf54324ee5 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala @@ -24,7 +24,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkConf import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog -import org.apache.spark.sql.jdbc.{DatabaseOnDocker, DockerJDBCIntegrationSuite} +import org.apache.spark.sql.jdbc.DatabaseOnDocker import org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest @@ -39,7 +39,7 @@ import org.apache.spark.tags.DockerTest * */ @DockerTest -class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTest { +class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { override val catalogName: String = "mysql" override val db = new DatabaseOnDocker { override val imageName = sys.env.getOrElse("MYSQL_DOCKER_IMAGE_NAME", "mysql:5.7.36") @@ -57,13 +57,17 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTest { override def sparkConf: SparkConf = super.sparkConf .set("spark.sql.catalog.mysql", classOf[JDBCTableCatalog].getName) .set("spark.sql.catalog.mysql.url", db.getJdbcUrl(dockerIp, externalPort)) + .set("spark.sql.catalog.mysql.pushDownAggregate", "true") override val connectionTimeout = timeout(7.minutes) private var mySQLVersion = -1 - override def dataPreparation(conn: Connection): Unit = { - mySQLVersion = conn.getMetaData.getDatabaseMajorVersion + override def tablePreparation(connection: Connection): Unit = { + mySQLVersion = connection.getMetaData.getDatabaseMajorVersion + connection.prepareStatement( + "CREATE TABLE employee (dept INT, name VARCHAR(32), salary DECIMAL(20, 2)," + + " bonus DOUBLE)").executeUpdate() } override def testUpdateColumnType(tbl: String): Unit = { @@ -119,4 +123,9 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTest { override def supportsIndex: Boolean = true override def indexOptions: String = "KEY_BLOCK_SIZE=10" + + testVarPop() + testVarSamp() + testStddevPop() + testStddevSamp() } diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala index ef8fe5354c540..2669924dc28c0 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala @@ -18,13 +18,14 @@ package org.apache.spark.sql.jdbc.v2 import java.sql.Connection +import java.util.Locale import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkConf import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog -import org.apache.spark.sql.jdbc.{DatabaseOnDocker, DockerJDBCIntegrationSuite} +import org.apache.spark.sql.jdbc.DatabaseOnDocker import org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest @@ -54,8 +55,9 @@ import 
org.apache.spark.tags.DockerTest * This procedure has been validated with Oracle 18.4.0 Express Edition. */ @DockerTest -class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTest { +class OracleIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { override val catalogName: String = "oracle" + override val namespaceOpt: Option[String] = Some("SYSTEM") override val db = new DatabaseOnDocker { lazy override val imageName = sys.env.getOrElse("ORACLE_DOCKER_IMAGE_NAME", "gvenzl/oracle-xe:18.4.0") @@ -73,9 +75,15 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTest override def sparkConf: SparkConf = super.sparkConf .set("spark.sql.catalog.oracle", classOf[JDBCTableCatalog].getName) .set("spark.sql.catalog.oracle.url", db.getJdbcUrl(dockerIp, externalPort)) + .set("spark.sql.catalog.oracle.pushDownAggregate", "true") override val connectionTimeout = timeout(7.minutes) - override def dataPreparation(conn: Connection): Unit = {} + + override def tablePreparation(connection: Connection): Unit = { + connection.prepareStatement( + "CREATE TABLE employee (dept NUMBER(32), name VARCHAR2(32), salary NUMBER(20, 2)," + + " bonus BINARY_DOUBLE)").executeUpdate() + } override def testUpdateColumnType(tbl: String): Unit = { sql(s"CREATE TABLE $tbl (ID INTEGER)") @@ -93,4 +101,14 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTest assert(msg1.contains( s"Cannot update $catalogName.alt_table field ID: string cannot be cast to int")) } + + override def caseConvert(tableName: String): String = tableName.toUpperCase(Locale.ROOT) + + testVarPop() + testVarSamp() + testStddevPop() + testStddevSamp() + testCovarPop() + testCovarSamp() + testCorr() } diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala index 7fba6671ffe71..86f5c3c8cd418 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala @@ -22,7 +22,7 @@ import java.sql.Connection import org.apache.spark.SparkConf import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog -import org.apache.spark.sql.jdbc.{DatabaseOnDocker, DockerJDBCIntegrationSuite} +import org.apache.spark.sql.jdbc.DatabaseOnDocker import org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest @@ -34,7 +34,7 @@ import org.apache.spark.tags.DockerTest * }}} */ @DockerTest -class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTest { +class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { override val catalogName: String = "postgresql" override val db = new DatabaseOnDocker { override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:14.0-alpine") @@ -51,8 +51,13 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTes .set("spark.sql.catalog.postgresql.url", db.getJdbcUrl(dockerIp, externalPort)) .set("spark.sql.catalog.postgresql.pushDownTableSample", "true") .set("spark.sql.catalog.postgresql.pushDownLimit", "true") + .set("spark.sql.catalog.postgresql.pushDownAggregate", "true") - override def dataPreparation(conn: Connection): Unit = {} + override def 
tablePreparation(connection: Connection): Unit = { + connection.prepareStatement( + "CREATE TABLE employee (dept INTEGER, name VARCHAR(32), salary NUMERIC(20, 2)," + + " bonus double precision)").executeUpdate() + } override def testUpdateColumnType(tbl: String): Unit = { sql(s"CREATE TABLE $tbl (ID INTEGER)") @@ -84,4 +89,12 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTes override def supportsIndex: Boolean = true override def indexOptions: String = "FILLFACTOR=70" + + testVarPop() + testVarSamp() + testStddevPop() + testStddevSamp() + testCovarPop() + testCovarSamp() + testCorr() } diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala index 49aa20387e38e..6ea2099346781 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala @@ -19,12 +19,12 @@ package org.apache.spark.sql.jdbc.v2 import org.apache.logging.log4j.Level -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{AnalysisException, DataFrame} import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NoSuchIndexException} -import org.apache.spark.sql.catalyst.plans.logical.{Filter, Sample} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Sample} import org.apache.spark.sql.connector.catalog.{Catalogs, Identifier, TableCatalog} import org.apache.spark.sql.connector.catalog.index.SupportsIndex +import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper} import org.apache.spark.sql.jdbc.DockerIntegrationFunSuite import org.apache.spark.sql.test.SharedSparkSession @@ -36,6 +36,12 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu import testImplicits._ val catalogName: String + + val namespaceOpt: Option[String] = None + + private def catalogAndNamespace = + namespaceOpt.map(namespace => s"$catalogName.$namespace").getOrElse(catalogName) + // dialect specific update column type test def testUpdateColumnType(tbl: String): Unit @@ -246,22 +252,30 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu def supportsTableSample: Boolean = false - private def samplePushed(df: DataFrame): Boolean = { + private def checkSamplePushed(df: DataFrame, pushed: Boolean = true): Unit = { val sample = df.queryExecution.optimizedPlan.collect { case s: Sample => s } - sample.isEmpty + if (pushed) { + assert(sample.isEmpty) + } else { + assert(sample.nonEmpty) + } } - private def filterPushed(df: DataFrame): Boolean = { + private def checkFilterPushed(df: DataFrame, pushed: Boolean = true): Unit = { val filter = df.queryExecution.optimizedPlan.collect { case f: Filter => f } - filter.isEmpty + if (pushed) { + assert(filter.isEmpty) + } else { + assert(filter.nonEmpty) + } } private def limitPushed(df: DataFrame, limit: Int): Boolean = { - val filter = df.queryExecution.optimizedPlan.collect { + df.queryExecution.optimizedPlan.collect { case relation: DataSourceV2ScanRelation => relation.scan match { case v1: V1ScanWrapper => return v1.pushedDownOperators.limit == Some(limit) @@ -270,11 +284,11 @@ private[v2] trait V2JDBCTest extends 
SharedSparkSession with DockerIntegrationFu false } - private def columnPruned(df: DataFrame, col: String): Boolean = { + private def checkColumnPruned(df: DataFrame, col: String): Unit = { val scan = df.queryExecution.optimizedPlan.collectFirst { case s: DataSourceV2ScanRelation => s }.get - scan.schema.names.sameElements(Seq(col)) + assert(scan.schema.names.sameElements(Seq(col))) } test("SPARK-37038: Test TABLESAMPLE") { @@ -286,37 +300,37 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu // sample push down + column pruning val df1 = sql(s"SELECT col1 FROM $catalogName.new_table TABLESAMPLE (BUCKET 6 OUT OF 10)" + " REPEATABLE (12345)") - assert(samplePushed(df1)) - assert(columnPruned(df1, "col1")) + checkSamplePushed(df1) + checkColumnPruned(df1, "col1") assert(df1.collect().length < 10) // sample push down only val df2 = sql(s"SELECT * FROM $catalogName.new_table TABLESAMPLE (50 PERCENT)" + " REPEATABLE (12345)") - assert(samplePushed(df2)) + checkSamplePushed(df2) assert(df2.collect().length < 10) // sample(BUCKET ... OUT OF) push down + limit push down + column pruning val df3 = sql(s"SELECT col1 FROM $catalogName.new_table TABLESAMPLE (BUCKET 6 OUT OF 10)" + " LIMIT 2") - assert(samplePushed(df3)) + checkSamplePushed(df3) assert(limitPushed(df3, 2)) - assert(columnPruned(df3, "col1")) + checkColumnPruned(df3, "col1") assert(df3.collect().length <= 2) // sample(... PERCENT) push down + limit push down + column pruning val df4 = sql(s"SELECT col1 FROM $catalogName.new_table" + " TABLESAMPLE (50 PERCENT) REPEATABLE (12345) LIMIT 2") - assert(samplePushed(df4)) + checkSamplePushed(df4) assert(limitPushed(df4, 2)) - assert(columnPruned(df4, "col1")) + checkColumnPruned(df4, "col1") assert(df4.collect().length <= 2) // sample push down + filter push down + limit push down val df5 = sql(s"SELECT * FROM $catalogName.new_table" + " TABLESAMPLE (BUCKET 6 OUT OF 10) WHERE col1 > 0 LIMIT 2") - assert(samplePushed(df5)) - assert(filterPushed(df5)) + checkSamplePushed(df5) + checkFilterPushed(df5) assert(limitPushed(df5, 2)) assert(df5.collect().length <= 2) @@ -325,27 +339,161 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu // Todo: push down filter/limit val df6 = sql(s"SELECT col1 FROM $catalogName.new_table" + " TABLESAMPLE (BUCKET 6 OUT OF 10) WHERE col1 > 0 LIMIT 2") - assert(samplePushed(df6)) - assert(!filterPushed(df6)) + checkSamplePushed(df6) + checkFilterPushed(df6, false) assert(!limitPushed(df6, 2)) - assert(columnPruned(df6, "col1")) + checkColumnPruned(df6, "col1") assert(df6.collect().length <= 2) // sample + limit // Push down order is sample -> filter -> limit // only limit is pushed down because in this test sample is after limit val df7 = spark.read.table(s"$catalogName.new_table").limit(2).sample(0.5) - assert(!samplePushed(df7)) + checkSamplePushed(df7, false) assert(limitPushed(df7, 2)) // sample + filter // Push down order is sample -> filter -> limit // only filter is pushed down because in this test sample is after filter val df8 = spark.read.table(s"$catalogName.new_table").where($"col1" > 1).sample(0.5) - assert(!samplePushed(df8)) - assert(filterPushed(df8)) + checkSamplePushed(df8, false) + checkFilterPushed(df8) assert(df8.collect().length < 10) } } } + + protected def checkAggregateRemoved(df: DataFrame): Unit = { + val aggregates = df.queryExecution.optimizedPlan.collect { + case agg: Aggregate => agg + } + assert(aggregates.isEmpty) + } + + private def checkAggregatePushed(df: DataFrame, 
funcName: String): Unit = { + df.queryExecution.optimizedPlan.collect { + case DataSourceV2ScanRelation(_, scan, _) => + assert(scan.isInstanceOf[V1ScanWrapper]) + val wrapper = scan.asInstanceOf[V1ScanWrapper] + assert(wrapper.pushedDownOperators.aggregation.isDefined) + val aggregationExpressions = + wrapper.pushedDownOperators.aggregation.get.aggregateExpressions() + assert(aggregationExpressions.length == 1) + assert(aggregationExpressions(0).isInstanceOf[GeneralAggregateFunc]) + assert(aggregationExpressions(0).asInstanceOf[GeneralAggregateFunc].name() == funcName) + } + } + + protected def caseConvert(tableName: String): String = tableName + + protected def testVarPop(): Unit = { + test(s"scan with aggregate push-down: VAR_POP") { + val df = sql(s"SELECT VAR_POP(bonus) FROM $catalogAndNamespace.${caseConvert("employee")}" + + " WHERE dept > 0 GROUP BY dept ORDER BY dept") + checkFilterPushed(df) + checkAggregateRemoved(df) + checkAggregatePushed(df, "VAR_POP") + val row = df.collect() + assert(row.length === 3) + assert(row(0).getDouble(0) === 10000d) + assert(row(1).getDouble(0) === 2500d) + assert(row(2).getDouble(0) === 0d) + } + } + + protected def testVarSamp(): Unit = { + test(s"scan with aggregate push-down: VAR_SAMP") { + val df = sql( + s"SELECT VAR_SAMP(bonus) FROM $catalogAndNamespace.${caseConvert("employee")}" + + " WHERE dept > 0 GROUP BY dept ORDER BY dept") + checkFilterPushed(df) + checkAggregateRemoved(df) + checkAggregatePushed(df, "VAR_SAMP") + val row = df.collect() + assert(row.length === 3) + assert(row(0).getDouble(0) === 20000d) + assert(row(1).getDouble(0) === 5000d) + assert(row(2).isNullAt(0)) + } + } + + protected def testStddevPop(): Unit = { + test("scan with aggregate push-down: STDDEV_POP") { + val df = sql( + s"SELECT STDDEV_POP(bonus) FROM $catalogAndNamespace.${caseConvert("employee")}" + + " WHERE dept > 0 GROUP BY dept ORDER BY dept") + checkFilterPushed(df) + checkAggregateRemoved(df) + checkAggregatePushed(df, "STDDEV_POP") + val row = df.collect() + assert(row.length === 3) + assert(row(0).getDouble(0) === 100d) + assert(row(1).getDouble(0) === 50d) + assert(row(2).getDouble(0) === 0d) + } + } + + protected def testStddevSamp(): Unit = { + test("scan with aggregate push-down: STDDEV_SAMP") { + val df = sql( + s"SELECT STDDEV_SAMP(bonus) FROM $catalogAndNamespace.${caseConvert("employee")}" + + " WHERE dept > 0 GROUP BY dept ORDER BY dept") + checkFilterPushed(df) + checkAggregateRemoved(df) + checkAggregatePushed(df, "STDDEV_SAMP") + val row = df.collect() + assert(row.length === 3) + assert(row(0).getDouble(0) === 141.4213562373095d) + assert(row(1).getDouble(0) === 70.71067811865476d) + assert(row(2).isNullAt(0)) + } + } + + protected def testCovarPop(): Unit = { + test("scan with aggregate push-down: COVAR_POP") { + val df = sql( + s"SELECT COVAR_POP(bonus, bonus) FROM $catalogAndNamespace.${caseConvert("employee")}" + + " WHERE dept > 0 GROUP BY dept ORDER BY dept") + checkFilterPushed(df) + checkAggregateRemoved(df) + checkAggregatePushed(df, "COVAR_POP") + val row = df.collect() + assert(row.length === 3) + assert(row(0).getDouble(0) === 10000d) + assert(row(1).getDouble(0) === 2500d) + assert(row(2).getDouble(0) === 0d) + } + } + + protected def testCovarSamp(): Unit = { + test("scan with aggregate push-down: COVAR_SAMP") { + val df = sql( + s"SELECT COVAR_SAMP(bonus, bonus) FROM $catalogAndNamespace.${caseConvert("employee")}" + + " WHERE dept > 0 GROUP BY dept ORDER BY dept") + checkFilterPushed(df) + checkAggregateRemoved(df) + 
checkAggregatePushed(df, "COVAR_SAMP") + val row = df.collect() + assert(row.length === 3) + assert(row(0).getDouble(0) === 20000d) + assert(row(1).getDouble(0) === 5000d) + assert(row(2).isNullAt(0)) + } + } + + protected def testCorr(): Unit = { + test("scan with aggregate push-down: CORR") { + val df = sql( + s"SELECT CORR(bonus, bonus) FROM $catalogAndNamespace.${caseConvert("employee")}" + + " WHERE dept > 0 GROUP BY dept ORDER BY dept") + checkFilterPushed(df) + checkAggregateRemoved(df) + checkAggregatePushed(df, "CORR") + val row = df.collect() + assert(row.length === 3) + assert(row(0).getDouble(0) === 1d) + assert(row(1).getDouble(0) === 1d) + assert(row(2).isNullAt(0)) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala index 0b394db5c8932..9e9aac679ab39 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.jdbc import java.sql.Types import java.util.Locale +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} import org.apache.spark.sql.types._ private object DB2Dialect extends JdbcDialect { @@ -27,6 +28,18 @@ private object DB2Dialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:db2") + override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { + super.compileAggregate(aggFunction).orElse( + aggFunction match { + case f: GeneralAggregateFunc if f.name() == "VAR_POP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"VARIANCE($distinct${f.inputs().head})") + case _ => None + } + ) + } + override def getCatalystType( sqlType: Int, typeName: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala index f19ef7ead5f8e..e87d4d08ae031 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.jdbc import java.sql.Types import java.util.Locale +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types._ @@ -29,6 +30,30 @@ private object DerbyDialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:derby") + override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { + super.compileAggregate(aggFunction).orElse( + aggFunction match { + case f: GeneralAggregateFunc if f.name() == "VAR_POP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"VAR_POP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"VAR_SAMP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"STDDEV_POP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" => + assert(f.inputs().length 
== 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"STDDEV_SAMP($distinct${f.inputs().head})") + case _ => None + } + ) + } + override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { if (sqlType == Types.REAL) Option(FloatType) else None diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala index 8e5674a181e7a..442c5599b3ab3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.jdbc import java.util.Locale +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -36,6 +37,30 @@ private object MsSqlServerDialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:sqlserver") + override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { + super.compileAggregate(aggFunction).orElse( + aggFunction match { + case f: GeneralAggregateFunc if f.name() == "VAR_POP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"VARP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"VAR($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"STDEVP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"STDEV($distinct${f.inputs().head})") + case _ => None + } + ) + } + override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { if (typeName.contains("datetimeoffset")) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala index fb98996e6bf8b..9fcb7a27d17af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NoSuchIndexException} import org.apache.spark.sql.connector.catalog.index.TableIndex import org.apache.spark.sql.connector.expressions.{FieldReference, NamedReference} +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} import org.apache.spark.sql.types.{BooleanType, DataType, FloatType, LongType, MetadataBuilder} @@ -35,6 +36,30 @@ private case object MySQLDialect extends JdbcDialect with SQLConfHelper { override def canHandle(url : String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:mysql") + override def compileAggregate(aggFunction: AggregateFunc): Option[String] = 
{ + super.compileAggregate(aggFunction).orElse( + aggFunction match { + case f: GeneralAggregateFunc if f.name() == "VAR_POP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"VAR_POP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"VAR_SAMP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"STDDEV_POP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"STDDEV_SAMP($distinct${f.inputs().head})") + case _ => None + } + ) + } + override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala index b741ece8dda9b..4fe7d93142c1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala @@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp, Types} import java.util.{Locale, TimeZone} import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -33,6 +34,42 @@ private case object OracleDialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:oracle") + override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { + super.compileAggregate(aggFunction).orElse( + aggFunction match { + case f: GeneralAggregateFunc if f.name() == "VAR_POP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"VAR_POP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"VAR_SAMP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"STDDEV_POP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"STDDEV_SAMP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "COVAR_POP" => + assert(f.inputs().length == 2) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"COVAR_POP($distinct${f.inputs().head}, ${f.inputs().last})") + case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" => + assert(f.inputs().length == 2) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"COVAR_SAMP($distinct${f.inputs().head}, ${f.inputs().last})") + case f: GeneralAggregateFunc if f.name() == "CORR" => + assert(f.inputs().length == 2) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"CORR($distinct${f.inputs().head}, ${f.inputs().last})") + case _ => None + 
} + ) + } + private def supportTimeZoneTypes: Boolean = { val timeZone = DateTimeUtils.getTimeZone(SQLConf.get.sessionLocalTimeZone) // TODO: support timezone types when users are not using the JVM timezone, which diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index 356cb4ddbd008..3b1a2c81fffd6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NoSuchIndexException} import org.apache.spark.sql.connector.expressions.NamedReference +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo import org.apache.spark.sql.types._ @@ -35,6 +36,42 @@ private object PostgresDialect extends JdbcDialect with SQLConfHelper { override def canHandle(url: String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:postgresql") + override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { + super.compileAggregate(aggFunction).orElse( + aggFunction match { + case f: GeneralAggregateFunc if f.name() == "VAR_POP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"VAR_POP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"VAR_SAMP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"STDDEV_POP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"STDDEV_SAMP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "COVAR_POP" => + assert(f.inputs().length == 2) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"COVAR_POP($distinct${f.inputs().head}, ${f.inputs().last})") + case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" => + assert(f.inputs().length == 2) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"COVAR_SAMP($distinct${f.inputs().head}, ${f.inputs().last})") + case f: GeneralAggregateFunc if f.name() == "CORR" => + assert(f.inputs().length == 2) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"CORR($distinct${f.inputs().head}, ${f.inputs().last})") + case _ => None + } + ) + } + override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { if (sqlType == Types.REAL) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala index 13f4c5fe9c926..6344667b3180e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.jdbc import java.util.Locale +import 
org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} import org.apache.spark.sql.types._ @@ -27,6 +28,42 @@ private case object TeradataDialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:teradata") + override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { + super.compileAggregate(aggFunction).orElse( + aggFunction match { + case f: GeneralAggregateFunc if f.name() == "VAR_POP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"VAR_POP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"VAR_SAMP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"STDDEV_POP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"STDDEV_SAMP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "COVAR_POP" => + assert(f.inputs().length == 2) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"COVAR_POP($distinct${f.inputs().head}, ${f.inputs().last})") + case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" => + assert(f.inputs().length == 2) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"COVAR_SAMP($distinct${f.inputs().head}, ${f.inputs().last})") + case f: GeneralAggregateFunc if f.name() == "CORR" => + assert(f.inputs().length == 2) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"CORR($distinct${f.inputs().head}, ${f.inputs().last})") + case _ => None + } + ) + } + override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { case StringType => Some(JdbcType("VARCHAR(255)", java.sql.Types.VARCHAR)) case BooleanType => Option(JdbcType("CHAR(1)", java.sql.Types.CHAR))
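
A minimal, self-contained sketch (not part of the patch) of how the new `pushDownAggregate` catalog option added above can be exercised outside the Docker suites. The JDBC URL, credentials, and the `employee` table are placeholders mirroring `tablePreparation` in `PostgresIntegrationSuite`, and a reachable PostgreSQL instance plus its JDBC driver on the classpath are assumed.

```scala
// Illustrative sketch only: hypothetical PostgreSQL endpoint and credentials.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.Aggregate

object AggregatePushDownSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("jdbc-v2-aggregate-push-down")
      // Expose the remote database as a Spark catalog named `postgresql`.
      .config("spark.sql.catalog.postgresql",
        "org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog")
      .config("spark.sql.catalog.postgresql.url",
        "jdbc:postgresql://localhost:5432/testdb?user=postgres&password=rootpass")
      // The option used by the suites above: let the dialect compile whole aggregates.
      .config("spark.sql.catalog.postgresql.pushDownAggregate", "true")
      .getOrCreate()

    val df = spark.sql(
      "SELECT VAR_POP(bonus) FROM postgresql.employee WHERE dept > 0 GROUP BY dept")

    // When PostgresDialect compiles VAR_POP, the optimized plan keeps no Aggregate
    // node: the aggregation is evaluated by the remote database instead of Spark,
    // which is exactly what checkAggregateRemoved asserts in V2JDBCTest.
    val aggregates = df.queryExecution.optimizedPlan.collect { case a: Aggregate => a }
    println(s"Aggregate evaluated by Spark (i.e. not pushed down): ${aggregates.nonEmpty}")
    df.show()
    spark.stop()
  }
}
```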