From 872107f67fd6c2093531e8a8976ff713359cba01 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Tue, 29 Dec 2020 13:34:43 +0000 Subject: [PATCH] [SPARK-33848][SQL][FOLLOWUP] Introduce allowList for push into (if / case) branches ### What changes were proposed in this pull request? Introduce an allowList of the unary and binary expressions that may be pushed into (if / case) branches, to fix a potential bug. ### Why are the changes needed? To fix a potential bug: previously every UnaryExpression and BinaryExpression except Alias was eligible to be pushed into the branches, but not all of them can be. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing test. Closes #30955 from wangyum/SPARK-33848-2. Authored-by: Yuming Wang Signed-off-by: Wenchen Fan --- .../sql/catalyst/optimizer/expressions.scala | 41 +++++++++++++++---- 1 file changed, 34 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 6c5dec133d2a7..1b93d514964e6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -553,41 +553,68 @@ object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper { foldables.nonEmpty && others.length < 2 } + // Not all UnaryExpression can be pushed into (if / case) branches, e.g. Alias. + private def supportedUnaryExpression(e: UnaryExpression): Boolean = e match { + case _: IsNull | _: IsNotNull => true + case _: UnaryMathExpression | _: Abs | _: Bin | _: Factorial | _: Hex => true + case _: String2StringExpression | _: Ascii | _: Base64 | _: BitLength | _: Chr | _: Length => + true + case _: CastBase => true + case _: GetDateField | _: LastDay => true + case _: ExtractIntervalPart => true + case _: ArraySetLike => true + case _: ExtractValue => true + case _ => false + } + + // Not all BinaryExpression can be pushed into (if / case) branches. 
+ private def supportedBinaryExpression(e: BinaryExpression): Boolean = e match { + case _: BinaryComparison | _: StringPredicate | _: StringRegexExpression => true + case _: BinaryArithmetic => true + case _: BinaryMathExpression => true + case _: AddMonths | _: DateAdd | _: DateAddInterval | _: DateDiff | _: DateSub => true + case _: FindInSet | _: RoundBase => true + case _ => false + } + def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { - case a: Alias => a // Skip an alias. case u @ UnaryExpression(i @ If(_, trueValue, falseValue)) - if atMostOneUnfoldable(Seq(trueValue, falseValue)) => + if supportedUnaryExpression(u) && atMostOneUnfoldable(Seq(trueValue, falseValue)) => i.copy( trueValue = u.withNewChildren(Array(trueValue)), falseValue = u.withNewChildren(Array(falseValue))) case u @ UnaryExpression(c @ CaseWhen(branches, elseValue)) - if atMostOneUnfoldable(branches.map(_._2) ++ elseValue) => + if supportedUnaryExpression(u) && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) => c.copy( branches.map(e => e.copy(_2 = u.withNewChildren(Array(e._2)))), elseValue.map(e => u.withNewChildren(Array(e)))) case b @ BinaryExpression(i @ If(_, trueValue, falseValue), right) - if right.foldable && atMostOneUnfoldable(Seq(trueValue, falseValue)) => + if supportedBinaryExpression(b) && right.foldable && + atMostOneUnfoldable(Seq(trueValue, falseValue)) => i.copy( trueValue = b.withNewChildren(Array(trueValue, right)), falseValue = b.withNewChildren(Array(falseValue, right))) case b @ BinaryExpression(left, i @ If(_, trueValue, falseValue)) - if left.foldable && atMostOneUnfoldable(Seq(trueValue, falseValue)) => + if supportedBinaryExpression(b) && left.foldable && + atMostOneUnfoldable(Seq(trueValue, falseValue)) => i.copy( trueValue = b.withNewChildren(Array(left, trueValue)), falseValue = b.withNewChildren(Array(left, falseValue))) case b @ BinaryExpression(c @ CaseWhen(branches, elseValue), right) - if 
right.foldable && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) => + if supportedBinaryExpression(b) && right.foldable && + atMostOneUnfoldable(branches.map(_._2) ++ elseValue) => c.copy( branches.map(e => e.copy(_2 = b.withNewChildren(Array(e._2, right)))), elseValue.map(e => b.withNewChildren(Array(e, right)))) case b @ BinaryExpression(left, c @ CaseWhen(branches, elseValue)) - if left.foldable && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) => + if supportedBinaryExpression(b) && left.foldable && + atMostOneUnfoldable(branches.map(_._2) ++ elseValue) => c.copy( branches.map(e => e.copy(_2 = b.withNewChildren(Array(left, e._2)))), elseValue.map(e => b.withNewChildren(Array(left, e))))