From c2eac1de020bd64501acbdfe341f2f4b6657a6e9 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 28 Dec 2020 16:44:57 -0800 Subject: [PATCH] [SPARK-33845][SQL][FOLLOWUP] fix SimplifyConditionals ### What changes were proposed in this pull request? This is a followup of https://github.com/apache/spark/pull/30849, to fix a correctness issue caused by null value handling. ### Why are the changes needed? Fix a correctness issue. `If(null, true, false)` should return false, not true. ### Does this PR introduce _any_ user-facing change? Yes, but the bug only exists in the master branch. ### How was this patch tested? Updated tests. Closes #30953 from cloud-fan/bug. Authored-by: Wenchen Fan Signed-off-by: Dongjoon Hyun --- .../ReplaceNullWithFalseInPredicate.scala | 11 +++-- .../sql/catalyst/optimizer/expressions.scala | 6 ++- .../PushFoldableIntoBranchesSuite.scala | 4 +- ...ReplaceNullWithFalseInPredicateSuite.scala | 14 +++--- .../optimizer/SimplifyConditionalSuite.scala | 44 ++++++++++++++----- 5 files changed, 53 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala index 92401131e8b82..df3da3e8a9982 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala @@ -17,9 +17,8 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, CaseWhen, Expression, If} -import org.apache.spark.sql.catalyst.expressions.{LambdaFunction, Literal, MapFilter, Or} -import org.apache.spark.sql.catalyst.expressions.Literal.FalseLiteral +import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, CaseWhen, EqualNullSafe, 
Expression, If, LambdaFunction, Literal, MapFilter, Or} +import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.plans.logical.{DeleteFromTable, Filter, Join, LogicalPlan, UpdateTable} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types.BooleanType @@ -56,6 +55,12 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] { case d @ DeleteFromTable(_, Some(cond)) => d.copy(condition = Some(replaceNullWithFalse(cond))) case u @ UpdateTable(_, _, Some(cond)) => u.copy(condition = Some(replaceNullWithFalse(cond))) case p: LogicalPlan => p transformExpressions { + // For `EqualNullSafe` with a `TrueLiteral`, whether the other side is null or false has no + // difference, as `null <=> true` and `false <=> true` both return false. + case EqualNullSafe(left, TrueLiteral) => + EqualNullSafe(replaceNullWithFalse(left), TrueLiteral) + case EqualNullSafe(TrueLiteral, right) => + EqualNullSafe(TrueLiteral, replaceNullWithFalse(right)) case i @ If(pred, _, _) => i.copy(predicate = replaceNullWithFalse(pred)) case cw @ CaseWhen(branches, _) => val newBranches = branches.map { case (cond, value) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index f01df5e5e6768..b2625bddeecf4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -475,8 +475,10 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { case If(TrueLiteral, trueValue, _) => trueValue case If(FalseLiteral, _, falseValue) => falseValue case If(Literal(null, _), _, falseValue) => falseValue - case If(cond, TrueLiteral, FalseLiteral) => cond - case If(cond, FalseLiteral, TrueLiteral) => Not(cond) + case If(cond, 
TrueLiteral, FalseLiteral) => + if (cond.nullable) EqualNullSafe(cond, TrueLiteral) else cond + case If(cond, FalseLiteral, TrueLiteral) => + if (cond.nullable) Not(EqualNullSafe(cond, TrueLiteral)) else Not(cond) case If(cond, trueValue, falseValue) if cond.deterministic && trueValue.semanticEquals(falseValue) => trueValue case If(cond, l @ Literal(null, _), FalseLiteral) if !cond.nullable => And(cond, l) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala index 2d826e7b55a68..7c9a67d7554e2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala @@ -53,7 +53,7 @@ class PushFoldableIntoBranchesSuite test("Push down EqualTo through If") { assertEquivalent(EqualTo(ifExp, Literal(4)), FalseLiteral) - assertEquivalent(EqualTo(ifExp, Literal(3)), Not(a)) + assertEquivalent(EqualTo(ifExp, Literal(3)), Not(a <=> TrueLiteral)) // Push down at most one not foldable expressions. 
assertEquivalent( @@ -102,7 +102,7 @@ class PushFoldableIntoBranchesSuite assertEquivalent(Remainder(ifExp, Literal(4)), If(a, Literal(2), Literal(3))) assertEquivalent(Divide(If(a, Literal(2.0), Literal(3.0)), Literal(1.0)), If(a, Literal(2.0), Literal(3.0))) - assertEquivalent(And(If(a, FalseLiteral, TrueLiteral), TrueLiteral), Not(a)) + assertEquivalent(And(If(a, FalseLiteral, TrueLiteral), TrueLiteral), Not(a <=> TrueLiteral)) assertEquivalent(Or(If(a, FalseLiteral, TrueLiteral), TrueLiteral), TrueLiteral) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala index f49e6921fd46a..ae97d53256837 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, LessThanOrEqual, Literal, MapFilter, NamedExpression, Or, UnresolvedNamedLambdaVariable} +import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, Literal, MapFilter, NamedExpression, Or, UnresolvedNamedLambdaVariable} import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} import org.apache.spark.sql.catalyst.plans.logical.{DeleteFromTable, LocalRelation, LogicalPlan, UpdateTable} @@ -237,8 
+237,8 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { TrueLiteral, FalseLiteral) val condition = CaseWhen(Seq((UnresolvedAttribute("i") > Literal(10)) -> branchValue)) - val expectedCond = - CaseWhen(Seq((UnresolvedAttribute("i") > Literal(10)) -> (Literal(2) === nestedCaseWhen))) + val expectedCond = CaseWhen(Seq( + (UnresolvedAttribute("i") > Literal(10), (Literal(2) === nestedCaseWhen) <=> TrueLiteral))) testFilter(originalCond = condition, expectedCond = expectedCond) testJoin(originalCond = condition, expectedCond = expectedCond) testDelete(originalCond = condition, expectedCond = expectedCond) @@ -253,10 +253,10 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { Literal(3)), TrueLiteral, FalseLiteral) - val expectedCond = Literal(5) > If( + val expectedCond = (Literal(5) > If( UnresolvedAttribute("i") === Literal(15), Literal(null, IntegerType), - Literal(3)) + Literal(3))) <=> TrueLiteral testFilter(originalCond = condition, expectedCond = expectedCond) testJoin(originalCond = condition, expectedCond = expectedCond) testDelete(originalCond = condition, expectedCond = expectedCond) @@ -443,9 +443,9 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { val lambda1 = LambdaFunction( function = If(cond, Literal(null, BooleanType), TrueLiteral), arguments = lambdaArgs) - // the optimized lambda body is: if(arg > 0, false, true) => arg <= 0 + // the optimized lambda body is: if(arg > 0, false, true) => !((arg > 0) <=> true) val lambda2 = LambdaFunction( - function = LessThanOrEqual(condArg, Literal(0)), + function = !(cond <=> TrueLiteral), arguments = lambdaArgs) testProjection( originalExpr = createExpr(argument, lambda1) as 'x, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala index 1876be21dea4b..317984eba2261 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala @@ -201,19 +201,39 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P } test("SPARK-33845: remove unnecessary if when the outputs are boolean type") { - assertEquivalent( - If(IsNotNull(UnresolvedAttribute("a")), TrueLiteral, FalseLiteral), - IsNotNull(UnresolvedAttribute("a"))) - assertEquivalent( - If(IsNotNull(UnresolvedAttribute("a")), FalseLiteral, TrueLiteral), - IsNull(UnresolvedAttribute("a"))) + // verify the boolean equivalence of all transformations involved + val fields = Seq( + 'cond.boolean.notNull, + 'cond_nullable.boolean, + 'a.boolean, + 'b.boolean + ) + val Seq(cond, cond_nullable, a, b) = fields.zipWithIndex.map { case (f, i) => f.at(i) } + + val exprs = Seq( + // actual expressions of the transformations: original -> transformed + If(cond, true, false) -> cond, + If(cond, false, true) -> !cond, + If(cond_nullable, true, false) -> (cond_nullable <=> true), + If(cond_nullable, false, true) -> (!(cond_nullable <=> true))) + + // check plans + for ((originalExpr, expectedExpr) <- exprs) { + assertEquivalent(originalExpr, expectedExpr) + } - assertEquivalent( - If(GreaterThan(Rand(0), UnresolvedAttribute("a")), TrueLiteral, FalseLiteral), - GreaterThan(Rand(0), UnresolvedAttribute("a"))) - assertEquivalent( - If(GreaterThan(Rand(0), UnresolvedAttribute("a")), FalseLiteral, TrueLiteral), - LessThanOrEqual(Rand(0), UnresolvedAttribute("a"))) + // check evaluation + val binaryBooleanValues = Seq(true, false) + val ternaryBooleanValues = Seq(true, false, null) + for (condVal <- binaryBooleanValues; + condNullableVal <- ternaryBooleanValues; + aVal <- ternaryBooleanValues; + bVal <- ternaryBooleanValues; + (originalExpr, expectedExpr) <- exprs) { + val inputRow = create_row(condVal, condNullableVal, aVal, bVal) + val 
optimizedVal = evaluateWithoutCodegen(expectedExpr, inputRow) + checkEvaluation(originalExpr, optimizedVal, inputRow) + } } test("SPARK-33847: Remove the CaseWhen if elseValue is empty and other outputs are null") {