diff --git a/docs/configs.md b/docs/configs.md
index 796d8a612c9..ae6ec6286a7 100644
--- a/docs/configs.md
+++ b/docs/configs.md
@@ -110,6 +110,7 @@ Name | Description | Default Value
spark.rapids.sql.join.leftSemi.enabled|When set to true left semi joins are enabled on the GPU|true
spark.rapids.sql.join.rightOuter.enabled|When set to true right outer joins are enabled on the GPU|true
spark.rapids.sql.metrics.level|GPU plans can produce a lot more metrics than CPU plans do. In very large queries this can sometimes result in going over the max result size limit for the driver. Supported values include DEBUG which will enable all metrics supported and typically only needs to be enabled when debugging the plugin. MODERATE which should output enough metrics to understand how long each part of the query is taking and how much data is going to each part of the query. ESSENTIAL which disables most metrics except those Apache Spark CPU plans will also report or their equivalents.|MODERATE
+spark.rapids.sql.opt.condition.maxBranchNumber|Maximum number of branches for which a GPU case-when will enable lazy evaluation of the true and else expressions when the predicates on a batch are all-true or all-false. A large number can easily cause GPU OOM since the predicates are cached during the computation.|2
spark.rapids.sql.python.gpu.enabled|This is an experimental feature and is likely to change in the future. Enable (true) or disable (false) support for scheduling Python Pandas UDFs with GPU resources. When enabled, pandas UDFs are assumed to share the same GPU that the RAPIDs accelerator uses and will honor the python GPU configs|false
spark.rapids.sql.reader.batchSizeBytes|Soft limit on the maximum number of bytes the reader reads per batch. The readers will read chunks of data until this limit is met or exceeded. Note that the reader may estimate the number of bytes that will be used on the GPU in some cases based on the schema and number of rows in each batch.|2147483647
spark.rapids.sql.reader.batchSizeRows|Soft limit on the maximum number of rows the reader will read per batch. The orc and parquet readers will read row groups until this limit is met or exceeded. The limit is respected by the csv reader.|2147483647
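As a quick aside, a minimal spark-shell sketch of overriding this limit; the value `4` is an arbitrary example, traded off against the extra GPU memory held by the cached predicate columns:

```scala
// Raise the case-when branch limit for the all-true/all-false fast path.
// NOTE: 4 is only an example value; larger limits cache more predicate
// columns per batch and so raise the risk of GPU OOM.
spark.conf.set("spark.rapids.sql.opt.condition.maxBranchNumber", "4")
```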
diff --git a/integration_tests/src/main/python/conditionals_test.py b/integration_tests/src/main/python/conditionals_test.py
index 06e43f0f54c..5c67ebe63a6 100644
--- a/integration_tests/src/main/python/conditionals_test.py
+++ b/integration_tests/src/main/python/conditionals_test.py
@@ -40,11 +40,13 @@
if_nested_gens = if_array_gens_sample + if_struct_gens_sample
@pytest.mark.parametrize('data_gen', all_gens + if_nested_gens + decimal_128_gens_no_neg, ids=idfn)
-def test_if_else(data_gen):
+@pytest.mark.parametrize('pred_value', [True, False, None, "random"])
+def test_if_else(data_gen, pred_value):
(s1, s2) = gen_scalars_for_sql(data_gen, 2, force_no_nulls=not isinstance(data_gen, NullGen))
null_lit = get_null_lit_string(data_gen.data_type)
+ bool_gen = boolean_gen if pred_value == "random" else SetValuesGen(BooleanType(), [pred_value])
assert_gpu_and_cpu_are_equal_collect(
- lambda spark : three_col_df(spark, boolean_gen, data_gen, data_gen).selectExpr(
+ lambda spark : three_col_df(spark, bool_gen, data_gen, data_gen).selectExpr(
'IF(TRUE, b, c)',
'IF(TRUE, {}, {})'.format(s1, null_lit),
'IF(FALSE, {}, {})'.format(s1, null_lit),
@@ -67,11 +69,15 @@ def test_if_else_map(data_gen):
@pytest.mark.order(1) # at the head of xdist worker queue if pytest-order is installed
@pytest.mark.parametrize('data_gen', all_gens + all_nested_gens + decimal_128_gens, ids=idfn)
-def test_case_when(data_gen):
+@pytest.mark.parametrize('pred_value', [True, False, None, "random"])
+def test_case_when(data_gen, pred_value):
num_cmps = 20
s1 = gen_scalar(data_gen, force_no_nulls=not isinstance(data_gen, NullGen))
- # we want lots of false
- bool_gen = BooleanGen().with_special_case(False, weight=1000.0)
+ if pred_value == "random":
+ # we want lots of false values
+ bool_gen = BooleanGen().with_special_case(False, weight=1000.0)
+ else:
+ bool_gen = SetValuesGen(BooleanType(), [pred_value])
gen_cols = [('_b' + str(x), bool_gen) for x in range(0, num_cmps)]
gen_cols = gen_cols + [('_c' + str(x), data_gen) for x in range(0, num_cmps)]
gen = StructGen(gen_cols, nullable=False)
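For context, a spark-shell sketch (hypothetical column names, not part of the test suite) of the query shape the new parameters exercise: when a batch's predicate column is entirely true, entirely false, or entirely null, only one side of the conditional needs to be evaluated:

```scala
// A tiny frame mirroring three_col_df(spark, bool_gen, data_gen, data_gen):
// with pred_value=True, SetValuesGen makes column `a` all-true, so the GPU
// only needs to evaluate `b` for the IF below.
val df = spark.createDataFrame(Seq(
  (true, 1, 10),
  (true, 2, 20)
)).toDF("a", "b", "c")
df.selectExpr("IF(a, b, c)").show()
```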
diff --git a/jenkins/spark-premerge-build.sh b/jenkins/spark-premerge-build.sh
index a82254ce240..82367d43840 100755
--- a/jenkins/spark-premerge-build.sh
+++ b/jenkins/spark-premerge-build.sh
@@ -108,7 +108,7 @@ ci_2() {
export TEST_TYPE="pre-commit"
export TEST_PARALLEL=4
# separate process to avoid OOM kill
- TEST='conditionals_test or window_function_test' ./integration_tests/run_pyspark_from_build.sh
+ TEST_PARALLEL=2 TEST='conditionals_test or window_function_test' ./integration_tests/run_pyspark_from_build.sh
TEST_PARALLEL=5 TEST='struct_test or time_window_test' ./integration_tests/run_pyspark_from_build.sh
TEST='not conditionals_test and not window_function_test and not struct_test and not time_window_test' \
./integration_tests/run_pyspark_from_build.sh
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index e40697b6b46..bc67af685f8 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -2045,7 +2045,7 @@ object GpuOverrides extends Logging {
} else {
None
}
- GpuCaseWhen(branches, elseValue)
+ GpuCaseWhen(branches, elseValue, conf.maxConditionBranchNumber)
}
}),
expr[If](
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
index 6783f9b5ddf..7e55a9ce11c 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
@@ -1325,6 +1325,13 @@ object RapidsConf {
.booleanConf
.createWithDefault(value = false)
+ val MAX_CONDITION_BRANCH_NUMBER = conf("spark.rapids.sql.opt.condition.maxBranchNumber")
+ .doc("Maximum number of branches for which a GPU case-when will enable lazy " +
+ "evaluation of the true and else expressions when the predicates on a batch " +
+ "are all-true or all-false. A large number can easily cause GPU OOM since " +
+ "the predicates are cached during the computation.")
+ .integerConf
+ .createWithDefault(2)
+
private def printSectionHeader(category: String): Unit =
println(s"\n### $category")
@@ -1741,6 +1748,8 @@ class RapidsConf(conf: Map[String, String]) extends Logging {
lazy val isFastSampleEnabled: Boolean = get(ENABLE_FAST_SAMPLE)
+ lazy val maxConditionBranchNumber: Int = get(MAX_CONDITION_BRANCH_NUMBER)
+
private val optimizerDefaults = Map(
// this is not accurate because CPU projections do have a cost due to appending values
// to each row that is produced, but this needs to be a really small number because
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/conditionalExpressions.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/conditionalExpressions.scala
index d108c64bf3b..62563cf4532 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/conditionalExpressions.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/conditionalExpressions.scala
@@ -29,31 +29,55 @@ trait GpuConditionalExpression extends ComplexTypeMergingExpression with GpuExpr
protected def computeIfElse(
batch: ColumnarBatch,
- predExpr: Expression,
+ pred: GpuColumnVector,
trueExpr: Expression,
falseValue: Any): GpuColumnVector = {
withResourceIfAllowed(falseValue) { falseRet =>
- withResource(GpuExpressionsUtils.columnarEvalToColumn(predExpr, batch)) { pred =>
- withResourceIfAllowed(trueExpr.columnarEval(batch)) { trueRet =>
- val finalRet = (trueRet, falseRet) match {
- case (t: GpuColumnVector, f: GpuColumnVector) =>
- pred.getBase.ifElse(t.getBase, f.getBase)
- case (t: GpuScalar, f: GpuColumnVector) =>
- pred.getBase.ifElse(t.getBase, f.getBase)
- case (t: GpuColumnVector, f: GpuScalar) =>
- pred.getBase.ifElse(t.getBase, f.getBase)
- case (t: GpuScalar, f: GpuScalar) =>
- pred.getBase.ifElse(t.getBase, f.getBase)
- case (t, f) =>
- throw new IllegalStateException(s"Unexpected inputs" +
- s" ($t: ${t.getClass}, $f: ${f.getClass})")
- }
- GpuColumnVector.from(finalRet, dataType)
+ withResourceIfAllowed(trueExpr.columnarEval(batch)) { trueRet =>
+ val finalRet = (trueRet, falseRet) match {
+ case (t: GpuColumnVector, f: GpuColumnVector) =>
+ pred.getBase.ifElse(t.getBase, f.getBase)
+ case (t: GpuScalar, f: GpuColumnVector) =>
+ pred.getBase.ifElse(t.getBase, f.getBase)
+ case (t: GpuColumnVector, f: GpuScalar) =>
+ pred.getBase.ifElse(t.getBase, f.getBase)
+ case (t: GpuScalar, f: GpuScalar) =>
+ pred.getBase.ifElse(t.getBase, f.getBase)
+ case (t, f) =>
+ throw new IllegalStateException(s"Unexpected inputs" +
+ s" ($t: ${t.getClass}, $f: ${f.getClass})")
}
+ GpuColumnVector.from(finalRet, dataType)
}
}
}
+ protected def isAllTrue(col: GpuColumnVector): Boolean = {
+ assert(BooleanType == col.dataType())
+ if (col.getRowCount == 0) {
+ // An empty batch is vacuously all-true
+ return true
+ }
+ if (col.hasNull) {
+ // Nulls are not true values, so the column cannot be all-true
+ return false
+ }
+ withResource(col.getBase.all()) { allTrue =>
+ // Guaranteed there is at least one row and no nulls so result must be valid
+ allTrue.getBoolean
+ }
+ }
+
+ protected def isAllFalse(col: GpuColumnVector): Boolean = {
+ assert(BooleanType == col.dataType())
+ if (col.getRowCount == col.numNulls()) {
+ // All rows are null, and nulls count as false here
+ return true
+ }
+ withResource(col.getBase.any()) { anyTrue =>
+ // null values are considered false values in this context
+ !anyTrue.getBoolean
+ }
+ }
+
}
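A CPU-side analogue of the two helpers above, using `Option[Boolean]` to stand in for a nullable boolean column (an illustration of the intended null semantics, not plugin code):

```scala
// A null (None) is never true, so one null disqualifies all-true; nulls
// count as false, so an all-null column is still all-false.
def isAllTrueCpu(pred: Seq[Option[Boolean]]): Boolean =
  pred.forall(_.contains(true)) // an empty column is vacuously all-true

def isAllFalseCpu(pred: Seq[Option[Boolean]]): Boolean =
  !pred.exists(_.contains(true)) // an all-null column is all-false

assert(isAllTrueCpu(Seq.empty))
assert(!isAllTrueCpu(Seq(Some(true), None)))
assert(isAllFalseCpu(Seq(None, Some(false))))
```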
case class GpuIf(
@@ -82,8 +106,19 @@ case class GpuIf(
}
}
- override def columnarEval(batch: ColumnarBatch): Any = computeIfElse(batch, predicateExpr,
- trueExpr, falseExpr.columnarEval(batch))
+ override def columnarEval(batch: ColumnarBatch): Any = {
+ withResource(GpuExpressionsUtils.columnarEvalToColumn(predicateExpr, batch)) { pred =>
+ if (isAllTrue(pred)) {
+ // All are true
+ GpuExpressionsUtils.columnarEvalToColumn(trueExpr, batch)
+ } else if (isAllFalse(pred)) {
+ // All are false
+ GpuExpressionsUtils.columnarEvalToColumn(falseExpr, batch)
+ } else {
+ computeIfElse(batch, pred, trueExpr, falseExpr.columnarEval(batch))
+ }
+ }
+ }
override def toString: String = s"if ($predicateExpr) $trueExpr else $falseExpr"
@@ -93,7 +128,8 @@ case class GpuIf(
case class GpuCaseWhen(
branches: Seq[(Expression, Expression)],
- elseValue: Option[Expression] = None) extends GpuConditionalExpression with Serializable {
+ elseValue: Option[Expression] = None,
+ maxBranchNumForOpt: Int = 2) extends GpuConditionalExpression with Serializable {
override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue
@@ -129,16 +165,59 @@ case class GpuCaseWhen(
}
}
- override def columnarEval(batch: ColumnarBatch): Any = {
- // `elseRet` will be closed in `computeIfElse`.
- val elseRet = elseValue
- .map(_.columnarEval(batch))
- .getOrElse(GpuScalar(null, branches.last._2.dataType))
- branches.foldRight[Any](elseRet) { case ((predicateExpr, trueExpr), falseRet) =>
- computeIfElse(batch, predicateExpr, trueExpr, falseRet)
+ private def computeWithTrueFalseOpt(batch: ColumnarBatch, trueExprs: Seq[Expression]): Any = {
+ val predicates = new Array[GpuColumnVector](branches.size)
+ var isAllPredsFalse = true
+
+ withResource(predicates) { preds =>
+ branches.zipWithIndex.foreach { case ((predExpr, trueExpr), i) =>
+ val p = GpuExpressionsUtils.columnarEvalToColumn(predExpr, batch)
+ preds(i) = p
+ if (isAllPredsFalse && isAllTrue(p)) {
+ // All previous predicates were all-false and this one is all-true, so
+ // evaluate its true expression and return the result directly.
+ return GpuExpressionsUtils.columnarEvalToColumn(trueExpr, batch)
+ }
+ isAllPredsFalse = isAllPredsFalse && isAllFalse(p)
+ }
+
+ val elseRet = elseValue
+ .map(_.columnarEval(batch))
+ .getOrElse(GpuScalar(null, branches.last._2.dataType))
+ if (isAllPredsFalse) {
+ // No predicate row is true, so return the else value.
+ GpuExpressionsUtils.resolveColumnVector(elseRet, batch.numRows())
+ } else {
+ preds.zip(trueExprs).foldRight[Any](elseRet) { case ((p, trueExpr), falseRet) =>
+ computeIfElse(batch, p, trueExpr, falseRet)
+ }
+ }
}
}
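A single-value analogue of the `foldRight` chaining used in the fall-through path (illustration only): folding from the last branch outward makes each branch's false side the already-folded remainder, which is exactly CASE WHEN's first-match-wins semantics:

```scala
// CASE WHEN as a right fold: the else value seeds the fold, and each branch
// wraps the accumulated "false side" in one more conditional.
def caseWhenCpu[T](branches: Seq[(Boolean, T)], elseVal: T): T =
  branches.foldRight(elseVal) { case ((pred, trueVal), falseVal) =>
    if (pred) trueVal else falseVal
  }

// WHEN false THEN 1 WHEN true THEN 2 ELSE 3 evaluates to 2
assert(caseWhenCpu(Seq((false, 1), (true, 2)), 3) == 2)
```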
+ @transient
+ private[this] lazy val computationFunc = if (branches.length <= maxBranchNumForOpt) {
+ // Apply the optimization only when the number of branches does not exceed the
+ // limit, since the predicate results are cached during the computation and
+ // caching too many predicates can easily cause GPU OOM.
+ val trueExpressions = branches.map(_._2)
+ (batch: ColumnarBatch) => computeWithTrueFalseOpt(batch, trueExpressions)
+ } else {
+ (batch: ColumnarBatch) => {
+ // `elseRet` will be closed in `computeIfElse`.
+ val elseRet = elseValue
+ .map(_.columnarEval(batch))
+ .getOrElse(GpuScalar(null, branches.last._2.dataType))
+ branches.foldRight[Any](elseRet) { case ((predicateExpr, trueExpr), falseRet) =>
+ withResource(GpuExpressionsUtils.columnarEvalToColumn(predicateExpr, batch)) { pred =>
+ computeIfElse(batch, pred, trueExpr, falseRet)
+ }
+ }
+ }
+ }
+
+ override def columnarEval(batch: ColumnarBatch): Any = computationFunc(batch)
+
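The shape of that change in miniature (a sketch, not plugin code): the strategy is chosen once per expression instance in a `@transient lazy val`, so the per-batch path never re-checks the branch count, and the captured closure is rebuilt lazily after deserialization rather than shipped with the expression:

```scala
// Decide the evaluation function on first use; every later call goes
// straight through the chosen closure.
class Dispatcher(useOpt: Boolean) {
  @transient private lazy val eval: Int => Int =
    if (useOpt) x => x * 2 else x => x + 1
  def apply(x: Int): Int = eval(x)
}
```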
override def toString: String = {
val cases = branches.map { case (c, v) => s" WHEN $c THEN $v" }.mkString
val elseCase = elseValue.map(" ELSE " + _).getOrElse("")