From 3731c9c6a4629b6bcb9f7225c5e416eeee046b1b Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Fri, 18 Feb 2022 22:22:04 +0800 Subject: [PATCH] [SPARK-37867][SQL][FOLLOWUP] Compile aggregate functions for the built-in DB2 dialect ### What changes were proposed in this pull request? This PR follows up https://github.com/apache/spark/pull/35166. The previously referenced DB2 documentation is incorrect, resulting in missing compilation support for some aggregate functions. The correct documentation is https://www.ibm.com/docs/en/db2/11.5?topic=af-regression-functions-regr-avgx-regr-avgy-regr-count ### Why are the changes needed? Make the built-in DB2 dialect support complete aggregate push-down for more aggregate functions. ### Does this PR introduce _any_ user-facing change? 'Yes'. Users could use complete aggregate push-down with the built-in DB2 dialect. ### How was this patch tested? New tests. Closes #35520 from beliefer/SPARK-37867_followup. Authored-by: Jiaan Geng Signed-off-by: Wenchen Fan --- .../sql/jdbc/v2/DB2IntegrationSuite.scala | 9 +++ .../jdbc/v2/MsSqlServerIntegrationSuite.scala | 4 ++ .../jdbc/v2/PostgresIntegrationSuite.scala | 7 +++ .../apache/spark/sql/jdbc/v2/V2JDBCTest.scala | 63 ++++++++++--------- .../apache/spark/sql/jdbc/DB2Dialect.scala | 19 ++++++ .../apache/spark/sql/jdbc/DerbyDialect.scala | 23 +++---- .../spark/sql/jdbc/MsSqlServerDialect.scala | 3 + .../apache/spark/sql/jdbc/MySQLDialect.scala | 21 +++---- .../apache/spark/sql/jdbc/OracleDialect.scala | 38 +++++------ .../spark/sql/jdbc/PostgresDialect.scala | 1 + .../spark/sql/jdbc/TeradataDialect.scala | 18 +++--- 11 files changed, 123 insertions(+), 83 deletions(-) diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala index d0479e9032e06..35711e57d0b72 100644 --- 
a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala @@ -97,4 +97,13 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { override def caseConvert(tableName: String): String = tableName.toUpperCase(Locale.ROOT) testVarPop() + testVarPop(true) + testVarSamp() + testVarSamp(true) + testStddevPop() + testStddevPop(true) + testStddevSamp() + testStddevSamp(true) + testCovarPop() + testCovarSamp() } diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala index 536eb465ceb11..4df5f4525a0fa 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala @@ -97,7 +97,11 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD } testVarPop() + testVarPop(true) testVarSamp() + testVarSamp(true) testStddevPop() + testStddevPop(true) testStddevSamp() + testStddevSamp(true) } diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala index b3004e1c21c89..d76e13c1cd421 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala @@ -91,10 +91,17 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCT override def indexOptions: String = 
"FILLFACTOR=70" testVarPop() + testVarPop(true) testVarSamp() + testVarSamp(true) testStddevPop() + testStddevPop(true) testStddevSamp() + testStddevSamp(true) testCovarPop() + testCovarPop(true) testCovarSamp() + testCovarSamp(true) testCorr() + testCorr(true) } diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala index 667579b20eaf7..7cab8cd77df66 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala @@ -386,10 +386,11 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu protected def caseConvert(tableName: String): String = tableName - protected def testVarPop(): Unit = { - test(s"scan with aggregate push-down: VAR_POP") { - val df = sql(s"SELECT VAR_POP(bonus) FROM $catalogAndNamespace.${caseConvert("employee")}" + - " WHERE dept > 0 GROUP BY dept ORDER BY dept") + protected def testVarPop(isDistinct: Boolean = false): Unit = { + val distinct = if (isDistinct) "DISTINCT " else "" + test(s"scan with aggregate push-down: VAR_POP with distinct: $isDistinct") { + val df = sql(s"SELECT VAR_POP(${distinct}bonus) FROM $catalogAndNamespace." 
+ + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") checkFilterPushed(df) checkAggregateRemoved(df) checkAggregatePushed(df, "VAR_POP") @@ -401,11 +402,12 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu } } - protected def testVarSamp(): Unit = { - test(s"scan with aggregate push-down: VAR_SAMP") { + protected def testVarSamp(isDistinct: Boolean = false): Unit = { + val distinct = if (isDistinct) "DISTINCT " else "" + test(s"scan with aggregate push-down: VAR_SAMP with distinct: $isDistinct") { val df = sql( - s"SELECT VAR_SAMP(bonus) FROM $catalogAndNamespace.${caseConvert("employee")}" + - " WHERE dept > 0 GROUP BY dept ORDER BY dept") + s"SELECT VAR_SAMP(${distinct}bonus) FROM $catalogAndNamespace." + + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") checkFilterPushed(df) checkAggregateRemoved(df) checkAggregatePushed(df, "VAR_SAMP") @@ -417,11 +419,12 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu } } - protected def testStddevPop(): Unit = { - test("scan with aggregate push-down: STDDEV_POP") { + protected def testStddevPop(isDistinct: Boolean = false): Unit = { + val distinct = if (isDistinct) "DISTINCT " else "" + test(s"scan with aggregate push-down: STDDEV_POP with distinct: $isDistinct") { val df = sql( - s"SELECT STDDEV_POP(bonus) FROM $catalogAndNamespace.${caseConvert("employee")}" + - " WHERE dept > 0 GROUP BY dept ORDER BY dept") + s"SELECT STDDEV_POP(${distinct}bonus) FROM $catalogAndNamespace." 
+ + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") checkFilterPushed(df) checkAggregateRemoved(df) checkAggregatePushed(df, "STDDEV_POP") @@ -433,11 +436,12 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu } } - protected def testStddevSamp(): Unit = { - test("scan with aggregate push-down: STDDEV_SAMP") { + protected def testStddevSamp(isDistinct: Boolean = false): Unit = { + val distinct = if (isDistinct) "DISTINCT " else "" + test(s"scan with aggregate push-down: STDDEV_SAMP with distinct: $isDistinct") { val df = sql( - s"SELECT STDDEV_SAMP(bonus) FROM $catalogAndNamespace.${caseConvert("employee")}" + - " WHERE dept > 0 GROUP BY dept ORDER BY dept") + s"SELECT STDDEV_SAMP(${distinct}bonus) FROM $catalogAndNamespace." + + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") checkFilterPushed(df) checkAggregateRemoved(df) checkAggregatePushed(df, "STDDEV_SAMP") @@ -449,11 +453,12 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu } } - protected def testCovarPop(): Unit = { - test("scan with aggregate push-down: COVAR_POP") { + protected def testCovarPop(isDistinct: Boolean = false): Unit = { + val distinct = if (isDistinct) "DISTINCT " else "" + test(s"scan with aggregate push-down: COVAR_POP with distinct: $isDistinct") { val df = sql( - s"SELECT COVAR_POP(bonus, bonus) FROM $catalogAndNamespace.${caseConvert("employee")}" + - " WHERE dept > 0 GROUP BY dept ORDER BY dept") + s"SELECT COVAR_POP(${distinct}bonus, bonus) FROM $catalogAndNamespace." 
+ + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") checkFilterPushed(df) checkAggregateRemoved(df) checkAggregatePushed(df, "COVAR_POP") @@ -465,11 +470,12 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu } } - protected def testCovarSamp(): Unit = { - test("scan with aggregate push-down: COVAR_SAMP") { + protected def testCovarSamp(isDistinct: Boolean = false): Unit = { + val distinct = if (isDistinct) "DISTINCT " else "" + test(s"scan with aggregate push-down: COVAR_SAMP with distinct: $isDistinct") { val df = sql( - s"SELECT COVAR_SAMP(bonus, bonus) FROM $catalogAndNamespace.${caseConvert("employee")}" + - " WHERE dept > 0 GROUP BY dept ORDER BY dept") + s"SELECT COVAR_SAMP(${distinct}bonus, bonus) FROM $catalogAndNamespace." + + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") checkFilterPushed(df) checkAggregateRemoved(df) checkAggregatePushed(df, "COVAR_SAMP") @@ -481,11 +487,12 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu } } - protected def testCorr(): Unit = { - test("scan with aggregate push-down: CORR") { + protected def testCorr(isDistinct: Boolean = false): Unit = { + val distinct = if (isDistinct) "DISTINCT " else "" + test(s"scan with aggregate push-down: CORR with distinct: $isDistinct") { val df = sql( - s"SELECT CORR(bonus, bonus) FROM $catalogAndNamespace.${caseConvert("employee")}" + - " WHERE dept > 0 GROUP BY dept ORDER BY dept") + s"SELECT CORR(${distinct}bonus, bonus) FROM $catalogAndNamespace." 
+ + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") checkFilterPushed(df) checkAggregateRemoved(df) checkAggregatePushed(df, "CORR") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala index ffda7545c6e9f..dd68953badf7a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala @@ -30,6 +30,7 @@ private object DB2Dialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:db2") + // See https://www.ibm.com/docs/en/db2/11.5?topic=functions-aggregate override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { super.compileAggregate(aggFunction).orElse( aggFunction match { @@ -37,6 +38,24 @@ private object DB2Dialect extends JdbcDialect { assert(f.inputs().length == 1) val distinct = if (f.isDistinct) "DISTINCT " else "" Some(s"VARIANCE($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"VARIANCE_SAMP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"STDDEV($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" => + assert(f.inputs().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"STDDEV_SAMP($distinct${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "COVAR_POP" && f.isDistinct == false => + assert(f.inputs().length == 2) + Some(s"COVARIANCE(${f.inputs().head}, ${f.inputs().last})") + case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" && f.isDistinct == false => + assert(f.inputs().length == 2) + 
Some(s"COVARIANCE_SAMP(${f.inputs().head}, ${f.inputs().last})") case _ => None } ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala index e87d4d08ae031..bf838b8ed66eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala @@ -30,25 +30,22 @@ private object DerbyDialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:derby") + // See https://db.apache.org/derby/docs/10.15/ref/index.html override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { super.compileAggregate(aggFunction).orElse( aggFunction match { - case f: GeneralAggregateFunc if f.name() == "VAR_POP" => + case f: GeneralAggregateFunc if f.name() == "VAR_POP" && f.isDistinct == false => assert(f.inputs().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"VAR_POP($distinct${f.inputs().head})") - case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" => + Some(s"VAR_POP(${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" && f.isDistinct == false => assert(f.inputs().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"VAR_SAMP($distinct${f.inputs().head})") - case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" => + Some(s"VAR_SAMP(${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" && f.isDistinct == false => assert(f.inputs().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"STDDEV_POP($distinct${f.inputs().head})") - case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" => + Some(s"STDDEV_POP(${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" && f.isDistinct == false => assert(f.inputs().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " 
else "" - Some(s"STDDEV_SAMP($distinct${f.inputs().head})") + Some(s"STDDEV_SAMP(${f.inputs().head})") case _ => None } ) @@ -72,7 +69,7 @@ private object DerbyDialect extends JdbcDialect { override def isCascadingTruncateTable(): Option[Boolean] = Some(false) - // See https://db.apache.org/derby/docs/10.5/ref/rrefsqljrenametablestatement.html + // See https://db.apache.org/derby/docs/10.15/ref/rrefsqljrenametablestatement.html override def renameTable(oldTable: String, newTable: String): String = { s"RENAME TABLE $oldTable TO $newTable" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala index 3d8a48a66ea8f..841f1c87319b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala @@ -40,6 +40,9 @@ private object MsSqlServerDialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:sqlserver") + // scalastyle:off line.size.limit + // See https://docs.microsoft.com/en-us/sql/t-sql/functions/aggregate-functions-transact-sql?view=sql-server-ver15 + // scalastyle:on line.size.limit override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { super.compileAggregate(aggFunction).orElse( aggFunction match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala index c32499b5f32e1..b1093a4f2f7c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala @@ -38,25 +38,22 @@ private case object MySQLDialect extends JdbcDialect with SQLConfHelper { override def canHandle(url : String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:mysql") + // See 
https://dev.mysql.com/doc/refman/8.0/en/aggregate-functions.html override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { super.compileAggregate(aggFunction).orElse( aggFunction match { - case f: GeneralAggregateFunc if f.name() == "VAR_POP" => + case f: GeneralAggregateFunc if f.name() == "VAR_POP" && f.isDistinct == false => assert(f.inputs().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"VAR_POP($distinct${f.inputs().head})") - case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" => + Some(s"VAR_POP(${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" && f.isDistinct == false => assert(f.inputs().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"VAR_SAMP($distinct${f.inputs().head})") - case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" => + Some(s"VAR_SAMP(${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" && f.isDistinct == false => assert(f.inputs().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"STDDEV_POP($distinct${f.inputs().head})") - case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" => + Some(s"STDDEV_POP(${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" && f.isDistinct == false => assert(f.inputs().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"STDDEV_SAMP($distinct${f.inputs().head})") + Some(s"STDDEV_SAMP(${f.inputs().head})") case _ => None } ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala index 4fe7d93142c1e..71db7e9285f5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala @@ -34,37 +34,33 @@ private case object OracleDialect extends JdbcDialect { override def canHandle(url: String): Boolean = 
url.toLowerCase(Locale.ROOT).startsWith("jdbc:oracle") + // scalastyle:off line.size.limit + // https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/Aggregate-Functions.html#GUID-62BE676B-AF18-4E63-BD14-25206FEA0848 + // scalastyle:on line.size.limit override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { super.compileAggregate(aggFunction).orElse( aggFunction match { - case f: GeneralAggregateFunc if f.name() == "VAR_POP" => + case f: GeneralAggregateFunc if f.name() == "VAR_POP" && f.isDistinct == false => assert(f.inputs().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"VAR_POP($distinct${f.inputs().head})") - case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" => + Some(s"VAR_POP(${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" && f.isDistinct == false => assert(f.inputs().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"VAR_SAMP($distinct${f.inputs().head})") - case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" => + Some(s"VAR_SAMP(${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" && f.isDistinct == false => assert(f.inputs().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"STDDEV_POP($distinct${f.inputs().head})") - case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" => + Some(s"STDDEV_POP(${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" && f.isDistinct == false => assert(f.inputs().length == 1) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"STDDEV_SAMP($distinct${f.inputs().head})") - case f: GeneralAggregateFunc if f.name() == "COVAR_POP" => + Some(s"STDDEV_SAMP(${f.inputs().head})") + case f: GeneralAggregateFunc if f.name() == "COVAR_POP" && f.isDistinct == false => assert(f.inputs().length == 2) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"COVAR_POP($distinct${f.inputs().head}, 
${f.inputs().last})") - case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" => + Some(s"COVAR_POP(${f.inputs().head}, ${f.inputs().last})") + case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" && f.isDistinct == false => assert(f.inputs().length == 2) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"COVAR_SAMP($distinct${f.inputs().head}, ${f.inputs().last})") - case f: GeneralAggregateFunc if f.name() == "CORR" => + Some(s"COVAR_SAMP(${f.inputs().head}, ${f.inputs().last})") + case f: GeneralAggregateFunc if f.name() == "CORR" && f.isDistinct == false => assert(f.inputs().length == 2) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"CORR($distinct${f.inputs().head}, ${f.inputs().last})") + Some(s"CORR(${f.inputs().head}, ${f.inputs().last})") case _ => None } ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index 46e79404f3e54..e2023d110ae4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -36,6 +36,7 @@ private object PostgresDialect extends JdbcDialect with SQLConfHelper { override def canHandle(url: String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:postgresql") + // See https://www.postgresql.org/docs/8.4/functions-aggregate.html override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { super.compileAggregate(aggFunction).orElse( aggFunction match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala index 6344667b3180e..13e16d24d048d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala @@ -28,6 +28,9 @@ private case object TeradataDialect extends 
JdbcDialect { override def canHandle(url: String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:teradata") + // scalastyle:off line.size.limit + // See https://docs.teradata.com/r/Teradata-VantageTM-SQL-Functions-Expressions-and-Predicates/March-2019/Aggregate-Functions + // scalastyle:on line.size.limit override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { super.compileAggregate(aggFunction).orElse( aggFunction match { @@ -47,18 +50,15 @@ private case object TeradataDialect extends JdbcDialect { assert(f.inputs().length == 1) val distinct = if (f.isDistinct) "DISTINCT " else "" Some(s"STDDEV_SAMP($distinct${f.inputs().head})") - case f: GeneralAggregateFunc if f.name() == "COVAR_POP" => + case f: GeneralAggregateFunc if f.name() == "COVAR_POP" && f.isDistinct == false => assert(f.inputs().length == 2) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"COVAR_POP($distinct${f.inputs().head}, ${f.inputs().last})") - case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" => + Some(s"COVAR_POP(${f.inputs().head}, ${f.inputs().last})") + case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" && f.isDistinct == false => assert(f.inputs().length == 2) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"COVAR_SAMP($distinct${f.inputs().head}, ${f.inputs().last})") - case f: GeneralAggregateFunc if f.name() == "CORR" => + Some(s"COVAR_SAMP(${f.inputs().head}, ${f.inputs().last})") + case f: GeneralAggregateFunc if f.name() == "CORR" && f.isDistinct == false => assert(f.inputs().length == 2) - val distinct = if (f.isDistinct) "DISTINCT " else "" - Some(s"CORR($distinct${f.inputs().head}, ${f.inputs().last})") + Some(s"CORR(${f.inputs().head}, ${f.inputs().last})") case _ => None } )