Skip to content

Commit

Permalink
[SPARK-33890][SQL] Improve the implement of trim/trimleft/trimright
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
The current implement of trim/trimleft/trimright have somewhat redundant.

### Why are the changes needed?
Improve the implement of trim/trimleft/trimright

### Does this PR introduce _any_ user-facing change?
'No'.

### How was this patch tested?
Jenkins test

Closes apache#30905 from beliefer/SPARK-33890.

Lead-authored-by: gengjiaan <[email protected]>
Co-authored-by: beliefer <[email protected]>
Co-authored-by: Jiaan Geng <[email protected]>
Co-authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
3 people committed Dec 30, 2020
1 parent 49aa6eb commit 687f465
Showing 1 changed file with 64 additions and 138 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -764,6 +764,55 @@ trait String2TrimExpression extends Expression with ImplicitCastInputTypes {
override def nullable: Boolean = children.exists(_.nullable)
override def foldable: Boolean = children.forall(_.foldable)

protected def doEval(srcString: UTF8String): UTF8String
protected def doEval(srcString: UTF8String, trimString: UTF8String): UTF8String

override def eval(input: InternalRow): Any = {
val srcString = srcStr.eval(input).asInstanceOf[UTF8String]
if (srcString == null) {
null
} else if (trimStr.isDefined) {
doEval(srcString, trimStr.get.eval(input).asInstanceOf[UTF8String])
} else {
doEval(srcString)
}
}

protected val trimMethod: String

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val evals = children.map(_.genCode(ctx))
val srcString = evals(0)

if (evals.length == 1) {
ev.copy(code = code"""
|${srcString.code}
|boolean ${ev.isNull} = false;
|UTF8String ${ev.value} = null;
|if (${srcString.isNull}) {
| ${ev.isNull} = true;
|} else {
| ${ev.value} = ${srcString.value}.$trimMethod();
|}""".stripMargin)
} else {
val trimString = evals(1)
ev.copy(code = code"""
|${srcString.code}
|boolean ${ev.isNull} = false;
|UTF8String ${ev.value} = null;
|if (${srcString.isNull}) {
| ${ev.isNull} = true;
|} else {
| ${trimString.code}
| if (${trimString.isNull}) {
| ${ev.isNull} = true;
| } else {
| ${ev.value} = ${srcString.value}.$trimMethod(${trimString.value});
| }
|}""".stripMargin)
}
}

override def sql: String = if (trimStr.isDefined) {
s"TRIM($direction ${trimStr.get.sql} FROM ${srcStr.sql})"
} else {
Expand Down Expand Up @@ -840,9 +889,7 @@ object StringTrim {
""",
since = "1.5.0",
group = "string_funcs")
case class StringTrim(
srcStr: Expression,
trimStr: Option[Expression] = None)
case class StringTrim(srcStr: Expression, trimStr: Option[Expression] = None)
extends String2TrimExpression {

def this(trimStr: Expression, srcStr: Expression) = this(srcStr, Option(trimStr))
Expand All @@ -853,51 +900,12 @@ case class StringTrim(

override protected def direction: String = "BOTH"

override def eval(input: InternalRow): Any = {
val srcString = srcStr.eval(input).asInstanceOf[UTF8String]
if (srcString == null) {
null
} else {
if (trimStr.isDefined) {
srcString.trim(trimStr.get.eval(input).asInstanceOf[UTF8String])
} else {
srcString.trim()
}
}
}
override def doEval(srcString: UTF8String): UTF8String = srcString.trim()

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val evals = children.map(_.genCode(ctx))
val srcString = evals(0)
override def doEval(srcString: UTF8String, trimString: UTF8String): UTF8String =
srcString.trim(trimString)

if (evals.length == 1) {
ev.copy(evals.map(_.code) :+ code"""
boolean ${ev.isNull} = false;
UTF8String ${ev.value} = null;
if (${srcString.isNull}) {
${ev.isNull} = true;
} else {
${ev.value} = ${srcString.value}.trim();
}""")
} else {
val trimString = evals(1)
val getTrimFunction =
s"""
if (${trimString.isNull}) {
${ev.isNull} = true;
} else {
${ev.value} = ${srcString.value}.trim(${trimString.value});
}"""
ev.copy(evals.map(_.code) :+ code"""
boolean ${ev.isNull} = false;
UTF8String ${ev.value} = null;
if (${srcString.isNull}) {
${ev.isNull} = true;
} else {
$getTrimFunction
}""")
}
}
override val trimMethod: String = "trim"
}

object StringTrimLeft {
Expand Down Expand Up @@ -934,9 +942,7 @@ object StringTrimLeft {
""",
since = "1.5.0",
group = "string_funcs")
case class StringTrimLeft(
srcStr: Expression,
trimStr: Option[Expression] = None)
case class StringTrimLeft(srcStr: Expression, trimStr: Option[Expression] = None)
extends String2TrimExpression {

def this(trimStr: Expression, srcStr: Expression) = this(srcStr, Option(trimStr))
Expand All @@ -947,51 +953,12 @@ case class StringTrimLeft(

override protected def direction: String = "LEADING"

override def eval(input: InternalRow): Any = {
val srcString = srcStr.eval(input).asInstanceOf[UTF8String]
if (srcString == null) {
null
} else {
if (trimStr.isDefined) {
srcString.trimLeft(trimStr.get.eval(input).asInstanceOf[UTF8String])
} else {
srcString.trimLeft()
}
}
}
override def doEval(srcString: UTF8String): UTF8String = srcString.trimLeft()

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val evals = children.map(_.genCode(ctx))
val srcString = evals(0)
override def doEval(srcString: UTF8String, trimString: UTF8String): UTF8String =
srcString.trimLeft(trimString)

if (evals.length == 1) {
ev.copy(evals.map(_.code) :+ code"""
boolean ${ev.isNull} = false;
UTF8String ${ev.value} = null;
if (${srcString.isNull}) {
${ev.isNull} = true;
} else {
${ev.value} = ${srcString.value}.trimLeft();
}""")
} else {
val trimString = evals(1)
val getTrimLeftFunction =
s"""
if (${trimString.isNull}) {
${ev.isNull} = true;
} else {
${ev.value} = ${srcString.value}.trimLeft(${trimString.value});
}"""
ev.copy(evals.map(_.code) :+ code"""
boolean ${ev.isNull} = false;
UTF8String ${ev.value} = null;
if (${srcString.isNull}) {
${ev.isNull} = true;
} else {
$getTrimLeftFunction
}""")
}
}
override val trimMethod: String = "trimLeft"
}

object StringTrimRight {
Expand Down Expand Up @@ -1030,9 +997,7 @@ object StringTrimRight {
since = "1.5.0",
group = "string_funcs")
// scalastyle:on line.size.limit
case class StringTrimRight(
srcStr: Expression,
trimStr: Option[Expression] = None)
case class StringTrimRight(srcStr: Expression, trimStr: Option[Expression] = None)
extends String2TrimExpression {

def this(trimStr: Expression, srcStr: Expression) = this(srcStr, Option(trimStr))
Expand All @@ -1043,51 +1008,12 @@ case class StringTrimRight(

override protected def direction: String = "TRAILING"

override def eval(input: InternalRow): Any = {
val srcString = srcStr.eval(input).asInstanceOf[UTF8String]
if (srcString == null) {
null
} else {
if (trimStr.isDefined) {
srcString.trimRight(trimStr.get.eval(input).asInstanceOf[UTF8String])
} else {
srcString.trimRight()
}
}
}
override def doEval(srcString: UTF8String): UTF8String = srcString.trimRight()

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val evals = children.map(_.genCode(ctx))
val srcString = evals(0)
override def doEval(srcString: UTF8String, trimString: UTF8String): UTF8String =
srcString.trimRight(trimString)

if (evals.length == 1) {
ev.copy(evals.map(_.code) :+ code"""
boolean ${ev.isNull} = false;
UTF8String ${ev.value} = null;
if (${srcString.isNull}) {
${ev.isNull} = true;
} else {
${ev.value} = ${srcString.value}.trimRight();
}""")
} else {
val trimString = evals(1)
val getTrimRightFunction =
s"""
if (${trimString.isNull}) {
${ev.isNull} = true;
} else {
${ev.value} = ${srcString.value}.trimRight(${trimString.value});
}"""
ev.copy(evals.map(_.code) :+ code"""
boolean ${ev.isNull} = false;
UTF8String ${ev.value} = null;
if (${srcString.isNull}) {
${ev.isNull} = true;
} else {
$getTrimRightFunction
}""")
}
}
override val trimMethod: String = "trimRight"
}

/**
Expand Down

0 comments on commit 687f465

Please sign in to comment.