Commit 005ba4e

Add Std dev for windowing

Signed-off-by: Raza Jafri <[email protected]>
razajafri committed Oct 26, 2021
1 parent 40b35b2 commit 005ba4e
Showing 6 changed files with 206 additions and 121 deletions.
integration_tests/src/main/python/window_function_test.py (18 additions, 0 deletions)

@@ -905,3 +905,21 @@ def test_window_ride_along(ride_along):
        ' row_number() over (order by a) as row_num '
        'from window_agg_table ',
        conf = allow_negative_scale_of_decimal_conf)

+@approximate_float
+@pytest.mark.parametrize('part_gen', numeric_gens, ids=idfn)
+def test_window_stddev(part_gen):
+    window_spec_agg = Window.partitionBy('_1')
+    window_spec = Window.partitionBy('_1').orderBy("_2")
+
+    def do_it(spark):
+        data_gen = [('_1', IntegerGen()), ('_2', part_gen)]
+        df = gen_df(spark, data_gen)
+        return df.withColumn("row", f.row_number().over(window_spec))\
+            .withColumn("standard_dev", f.stddev("_2").over(window_spec_agg))\
+            .selectExpr("standard_dev")
+
+    assert_gpu_and_cpu_are_equal_collect(do_it, conf={
+        'spark.rapids.sql.decimalType.enabled': 'true',
+        'spark.rapids.sql.castDecimalToFloat.enabled': 'true'})

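For context, the test drives Spark's DataFrame window API with a partition-only window spec, so stddev runs over every row of the partition, alongside an ordered spec for the ride-along row_number column. A minimal Scala sketch of the equivalent query (the data, object name, and app name are illustrative, not taken from the test harness):

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions.{row_number, stddev}

    object WindowStddevSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().appName("window-stddev-sketch").getOrCreate()
        import spark.implicits._

        val df = Seq((1, 1.0), (1, 2.0), (1, 4.0), (2, 3.0), (2, 7.0)).toDF("_1", "_2")

        // Partition-only spec: the aggregation sees the whole partition.
        // This is the window case the commit adds GPU stddev support for.
        val wholePartition = Window.partitionBy($"_1")
        // Ordered spec for the row_number ride-along column.
        val ordered = Window.partitionBy($"_1").orderBy($"_2")

        df.withColumn("row", row_number().over(ordered))
          .withColumn("standard_dev", stddev($"_2").over(wholePartition))
          .show()

        spark.stop()
      }
    }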

@@ -3128,9 +3128,20 @@ object GpuOverrides extends Logging {
      }),
    expr[StddevSamp](
      "Aggregation computing sample standard deviation",
-      ExprChecks.groupByOnly(
-        TypeSig.DOUBLE, TypeSig.DOUBLE,
-        Seq(ParamCheck("input", TypeSig.DOUBLE, TypeSig.DOUBLE))),
+      ExprChecksImpl(
+        ExprChecks.groupByOnly(
+          TypeSig.DOUBLE, TypeSig.DOUBLE,
+          Seq(ParamCheck("input", TypeSig.DOUBLE,
+            TypeSig.DOUBLE))).asInstanceOf[ExprChecksImpl].contexts
+        ++
+        ExprChecks.windowOnly(
+          TypeSig.commonCudfTypes + TypeSig.DECIMAL_64 + TypeSig.NULL, TypeSig.orderable,
+          Seq(ParamCheck("input",
+            (TypeSig.commonCudfTypes + TypeSig.DECIMAL_64 + TypeSig.NULL)
+              .withPsNote(TypeEnum.DOUBLE, nanAggPsNote)
+              .withPsNote(TypeEnum.FLOAT, nanAggPsNote),
+            TypeSig.orderable))
+        ).asInstanceOf[ExprChecksImpl].contexts),
    (a, conf, p, r) => new AggExprMeta[StddevSamp](a, conf, p, r) {
      override def convertToGpu(childExprs: Seq[Expression]): GpuExpression = {
        val legacyStatisticalAggregate = ShimLoader.getSparkShims.getLegacyStatisticalAggregate

@@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistrib
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.rapids.GpuAggregateExpression
-import org.apache.spark.sql.types.{ArrayType, ByteType, CalendarIntervalType, DataType, IntegerType, LongType, MapType, ShortType, StructType}
+import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
import org.apache.spark.unsafe.types.CalendarInterval

@@ -319,9 +319,19 @@ object GpuWindowExec extends Arm {

    exprs.foreach { expr =>
      if (hasGpuWindowFunction(expr)) {
-       // First pass looks for GpuWindowFunctions and GpuWindowSpecDefinitions to build up
+       // First pass replaces any operations that should be totally replaced.
+       val replacePass = expr.transformDown {
+         case GpuWindowExpression(
+             GpuAggregateExpression(rep: GpuReplaceWindowFunction, _, _, _, _), spec) =>
+           // We don't actually care about the GpuAggregateExpression because it is ignored
+           // by our GPU window operations anyways.
+           rep.windowReplacement(spec)
+         case GpuWindowExpression(rep: GpuReplaceWindowFunction, spec) =>
+           rep.windowReplacement(spec)
+       }
+       // Second pass looks for GpuWindowFunctions and GpuWindowSpecDefinitions to build up
        // the preProject phase
-       val firstPass = expr.transformDown {
+       val secondPass = replacePass.transformDown {
          case wf: GpuWindowFunction =>
            // All window functions, including those that are also aggregation functions, are
            // wrapped in a GpuWindowExpression, so dedup and save their children into the pre

@@ -340,14 +350,15 @@ object GpuWindowExec extends Arm {
          }.toArray.toSeq
          wsc.copy(partitionSpec = newPartitionSpec, orderSpec = newOrderSpec)
        }
-       val secondPass = firstPass.transformDown {
+       // Final pass is to extract, dedup, and save the results.
+       val finalPass = secondPass.transformDown {
          case we: GpuWindowExpression =>
            // A window Expression holds a window function or an aggregate function, so put it into
            // the windowOps phase, and create a new alias for it for the post phase
            extractAndSave(we, windowOps, windowDedupe)
        }.asInstanceOf[NamedExpression]

-       postProject += secondPass
+       postProject += finalPass
      } else {
        // There is no window function so pass the result through all of the phases (with deduping)
        postProject += extractAndSave(
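To make the new replace pass concrete, here is a small, self-contained sketch using stand-in types. These case classes are invented for illustration and are not the plugin's real class hierarchy, and the sqrt-over-variance rewrite is one plausible expansion of a sample standard deviation, not necessarily the rewrite the plugin performs:

    object ReplacePassSketch {
      // Stand-in expression tree, for illustration only.
      sealed trait Expr
      case class Col(name: String) extends Expr
      case class Spec(partitionBy: Seq[Col])
      case class Windowed(func: Expr, spec: Spec) extends Expr
      case class VarSamp(input: Expr) extends Expr // directly runnable window aggregation
      case class Sqrt(child: Expr) extends Expr    // ordinary row-wise expression

      // A replacement-style function: it cannot run in a window directly, but
      // it knows how to rewrite itself, mirroring GpuReplaceWindowFunction.
      case class StddevLike(input: Expr) extends Expr {
        def windowReplacement(spec: Spec): Expr = Sqrt(Windowed(VarSamp(input), spec))
      }

      // The real code applies this rule recursively via transformDown; the
      // sketch only handles the top-level node.
      def replacePass(e: Expr): Expr = e match {
        case Windowed(f: StddevLike, spec) => f.windowReplacement(spec)
        case other => other
      }

      // replacePass(Windowed(StddevLike(Col("x")), Spec(Seq(Col("g")))))
      //   returns Sqrt(Windowed(VarSamp(Col("x")), Spec(Seq(Col("g")))))
      // The later passes then route Col("x") into preProject, run the VarSamp
      // window operation in windowOps, and leave the Sqrt wrapper to postProject.
    }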

@@ -470,28 +481,17 @@ case class BoundGpuWindowFunction(
    aggFunc.windowAggregation(inputs).overWindow(windowOpts)
  }

+  def windowOutput(cv: cudf.ColumnVector): cudf.ColumnVector = {
+    val aggFunc = windowFunc.asInstanceOf[GpuAggregateWindowFunction]
+    aggFunc.windowOutput(cv)
+  }
+
  val dataType: DataType = windowFunc.dataType
}

case class ParsedBoundary(isUnbounded: Boolean, valueAsLong: Long)

object GroupedAggregations extends Arm {
-  // In some cases a scan or a group by scan produces a different type than window would for the
-  // same aggregation. A lot of this is because scan has a limited set of aggregations so we can
-  // end up using a SUM aggregation to work around other issues, and cudf rightly makes the output
-  // an INT64 instead of an INT32. This is here to fix that up.
-  private def castIfNeeded(
-      col: cudf.ColumnVector,
-      dataType: DataType): GpuColumnVector = {
-    dataType match {
-      case _: ArrayType | _: StructType | _: MapType =>
-        GpuColumnVector.from(col, dataType).incRefCount()
-      case other =>
-        val dtype = GpuColumnVector.getNonNestedRapidsType(other)
-        GpuColumnVector.from(col.castTo(dtype), dataType)
-    }
-  }
-
  /**
   * Get the window options for an aggregation
   * @param orderSpec the order by spec

@@ -702,11 +702,11 @@ class GroupedAggregations extends Arm {
    }
    withResource(result) { result =>
      functions.zipWithIndex.foreach {
-       case ((_, outputIndexes), resultIndex) =>
+       case ((func, outputIndexes), resultIndex) =>
          val aggColumn = result.getColumn(resultIndex)

          outputIndexes.foreach { outIndex =>
-           outputColumns(outIndex) = aggColumn.incRefCount()
+           outputColumns(outIndex) = func.windowOutput(aggColumn)
          }
      }
    }

@@ -956,18 +956,17 @@ class GroupedAggregations extends Arm {
  }

  /**
-   * Turn the final result of the aggregations into a ColumnarBatch. Because of some differences in
-   * output types between cudf and Spark a cast may be done before to fix it up.
+   * Turn the final result of the aggregations into a ColumnarBatch.
   */
-  def castAggOutputsIfNeeded(dataTypes: Array[DataType],
+  def convertToColumnarBatch(dataTypes: Array[DataType],
      aggOutputColumns: Array[cudf.ColumnVector]): ColumnarBatch = {
    assert(dataTypes.length == aggOutputColumns.length)
    val numRows = aggOutputColumns.head.getRowCount.toInt
    closeOnExcept(new Array[ColumnVector](aggOutputColumns.length)) { finalOutputColumns =>
      dataTypes.indices.foreach { index =>
        val dt = dataTypes(index)
        val col = aggOutputColumns(index)
-       finalOutputColumns(index) = castIfNeeded(col, dt)
+       finalOutputColumns(index) = GpuColumnVector.from(col, dt).incRefCount()
      }
      new ColumnarBatch(finalOutputColumns, numRows)
    }

@@ -1066,9 +1065,9 @@ trait BasicWindowCalc extends Arm {
    }
  }

-  def castResultsIfNeeded(dataTypes: Array[DataType],
+  def convertToBatch(dataTypes: Array[DataType],
      cols: Array[cudf.ColumnVector]): ColumnarBatch =
-   aggregations.castAggOutputsIfNeeded(dataTypes, cols)
+   aggregations.convertToColumnarBatch(dataTypes, cols)
}

/**

@@ -1094,7 +1093,7 @@ class GpuWindowIterator(
    withResource(input.next()) { cb =>
      withResource(new NvtxWithMetrics("window", NvtxColor.CYAN, opTime)) { _ =>
        val ret = withResource(computeBasicWindow(cb)) { cols =>
-         castResultsIfNeeded(outputTypes, cols)
+         convertToBatch(outputTypes, cols)
        }
        numOutputBatches += 1
        numOutputRows += ret.numRows()

@@ -1295,7 +1294,7 @@ class GpuRunningWindowIterator(
      }
      withResource(fixedUp) { fixed =>
        saveLastParts(getScalarRow(numRows - 1, partColumns))
-       castResultsIfNeeded(outputTypes, fixed)
+       convertToBatch(outputTypes, fixed)
      }
    }
  }

@@ -618,6 +618,14 @@ case class GpuSpecialFrameBoundary(boundary : SpecialFrameBoundary)
// Spark. This may expand in the future if other types of window functions show up.
trait GpuWindowFunction extends GpuUnevaluable with ShimExpression

+/**
+ * This is a special window function that simply replaces itself with one or more
+ * window functions and other expressions that can be executed.
+ */
+trait GpuReplaceWindowFunction extends GpuWindowFunction {
+  def windowReplacement(spec: GpuWindowSpecDefinition): Expression
+}

/**
* GPU Counterpart of `AggregateWindowFunction`.
* On the CPU this would extend `DeclarativeAggregate` and use the provided methods

@@ -638,6 +646,13 @@ trait GpuAggregateWindowFunction extends GpuWindowFunction {
* corresponding ColumnVector. Some aggregations need extra values.
*/
def windowAggregation(inputs: Seq[(ColumnVector, Int)]): RollingAggregationOnColumn

+  /**
+   * Do a final pass over the window aggregation output. This lets us cast the result to a
+   * desired type or check for overflow. This is not used for GpuRunningWindowFunction;
+   * there you can use `scanCombine` instead.
+   */
+  def windowOutput(result: ColumnVector): ColumnVector = result.incRefCount()
}

/**

@@ -693,7 +708,7 @@ trait GpuRunningWindowFunction extends GpuWindowFunction {
def scanAggregation(isRunningBatched: Boolean): Seq[AggAndReplace[ScanAggregation]]

/**
-   * Should a group by scan be run or not. This should never return false unless this is also an
+   * Should a scan be run or not. This should never return false unless this is also an
* instance of `GpuAggregateWindowFunction` so the window code can fall back to it for
* computation.
*/

@@ -1318,6 +1333,10 @@ case object GpuRowNumber extends GpuRunningWindowFunction
groupByScanInputProjection(isRunningBatched)
override def scanAggregation(isRunningBatched: Boolean): Seq[AggAndReplace[ScanAggregation]] =
Seq(AggAndReplace(ScanAggregation.sum(), None))

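+  // The scan computes the running row number with a SUM aggregation, and cudf
+  // widens a SUM's output to INT64, so cast back to the INT32 Spark expects.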
+  override def scanCombine(isRunningBatched: Boolean, cols: Seq[ColumnVector]): ColumnVector = {
+    cols.head.castTo(DType.INT32)
+  }
}

trait GpuOffsetWindowFunction extends GpuAggregateWindowFunction {