From f9ca4fb23ba7c4de910f7f3c2ab13fb65b1a40f8 Mon Sep 17 00:00:00 2001
From: Raza Jafri
Date: Thu, 14 Oct 2021 11:55:44 -0700
Subject: [PATCH] Add Std dev for windowing

Signed-off-by: Raza Jafri
---
 .../src/main/python/window_function_test.py     | 17 +++++++++++++++++
 .../com/nvidia/spark/rapids/GpuOverrides.scala  | 13 ++++++++++++-
 2 files changed, 29 insertions(+), 1 deletion(-)

diff --git a/integration_tests/src/main/python/window_function_test.py b/integration_tests/src/main/python/window_function_test.py
index e625e01bc50b..e925c15863d5 100644
--- a/integration_tests/src/main/python/window_function_test.py
+++ b/integration_tests/src/main/python/window_function_test.py
@@ -606,6 +606,7 @@ def do_it(spark):
             .withColumn('inc_min_c', f.min('c').over(inclusiveWindowSpec)) \
             .withColumn('rank_val', f.rank().over(baseWindowSpec)) \
             .withColumn('dense_rank_val', f.dense_rank().over(baseWindowSpec)) \
+            .withColumn('stddev_val', f.stddev('c').over(baseWindowSpec)) \
             .withColumn('row_num', f.row_number().over(baseWindowSpec))
 
     assert_gpu_and_cpu_are_equal_collect(do_it, conf={'spark.rapids.sql.hasNans': 'false'})
@@ -905,3 +906,19 @@ def test_window_ride_along(ride_along):
         ' row_number() over (order by a) as row_num '
         'from window_agg_table ',
         conf = allow_negative_scale_of_decimal_conf)
+
+def test_window_stddev():
+    window_spec_agg = Window.partitionBy('_1')
+    window_spec = Window.partitionBy('_1').orderBy('_2')
+
+    def do_it(spark):
+        data = [[1, 3], [1, 5], [2, 3], [2, 7], [9, 9]]
+        schema = [StructField('_1', IntegerType(), True), StructField('_2', IntegerType(), True)]
+        df = spark.createDataFrame(SparkContext.getOrCreate().parallelize(data), StructType(schema))
+        return df.withColumn('row', f.row_number().over(window_spec)) \
+            .withColumn('stddev', f.stddev('_2').over(window_spec_agg)).select('stddev')
+
+    assert_gpu_and_cpu_are_equal_collect(do_it, conf={
+        'spark.rapids.sql.decimalType.enabled': 'true',
+        'spark.rapids.sql.castDecimalToFloat.enabled': 'true'})
+
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index b0adcb2f042c..8727f5bad504 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -3128,9 +3128,20 @@ object GpuOverrides extends Logging {
       }),
     expr[StddevSamp](
       "Aggregation computing sample standard deviation",
+      ExprChecksImpl(
       ExprChecks.groupByOnly(
         TypeSig.DOUBLE, TypeSig.DOUBLE,
-        Seq(ParamCheck("input", TypeSig.DOUBLE, TypeSig.DOUBLE))),
+        Seq(ParamCheck("input", TypeSig.DOUBLE,
+          TypeSig.DOUBLE))).asInstanceOf[ExprChecksImpl].contexts
+        ++
+        ExprChecks.windowOnly(
+          TypeSig.commonCudfTypes + TypeSig.DECIMAL_64 + TypeSig.NULL, TypeSig.orderable,
+          Seq(ParamCheck("input",
+            (TypeSig.commonCudfTypes + TypeSig.DECIMAL_64 + TypeSig.NULL)
+              .withPsNote(TypeEnum.DOUBLE, nanAggPsNote)
+              .withPsNote(TypeEnum.FLOAT, nanAggPsNote),
+            TypeSig.orderable))
+        ).asInstanceOf[ExprChecksImpl].contexts),
       (a, conf, p, r) => new AggExprMeta[StddevSamp](a, conf, p, r) {
         override def convertToGpu(childExprs: Seq[Expression]): GpuExpression = {
           val legacyStatisticalAggregate = ShimLoader.getSparkShims.getLegacyStatisticalAggregate
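
---
For reference, the semantics the new window path must reproduce are just
Spark's own stddev (an alias of stddev_samp) evaluated over a window
partition. The sketch below is a minimal CPU-only illustration that runs on
plain PySpark without the RAPIDS plugin; the session setup is illustrative
boilerplate and not part of the patch, though the data mirrors
test_window_stddev.

    # Illustrative sketch only: plain PySpark, no spark-rapids required.
    from pyspark.sql import SparkSession, Window
    from pyspark.sql import functions as f

    spark = SparkSession.builder.appName("window-stddev-sketch").getOrCreate()

    # Same data as test_window_stddev: partitions {3, 5}, {3, 7}, and {9}.
    df = spark.createDataFrame([(1, 3), (1, 5), (2, 3), (2, 7), (9, 9)],
                               ["_1", "_2"])

    # Unordered window: every row in a partition gets that partition's
    # sample standard deviation of _2.
    w = Window.partitionBy("_1")
    df.withColumn("stddev", f.stddev("_2").over(w)).show()
    # _1 == 1 holds {3, 5}: sqrt(((3-4)^2 + (5-4)^2) / (2-1)) = sqrt(2) ~ 1.414
    # _1 == 2 holds {3, 7}: sqrt(((3-5)^2 + (7-5)^2) / (2-1)) = sqrt(8) ~ 2.828
    # _1 == 9 holds one row: stddev_samp is null because n - 1 == 0

These are the values assert_gpu_and_cpu_are_equal_collect compares between
the CPU plan and the GPU plan enabled by the GpuOverrides change above.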