-
Notifications
You must be signed in to change notification settings - Fork 28.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-33890][SQL] Improve the implement of trim/trimleft/trimright #30905
Changes from 42 commits
4a6f903
96456e2
4314005
d6af4a7
f69094f
b86a42d
2ac5159
9021d6c
74a2ef4
9828158
9cd1aaf
abfcbb9
07c6c81
580130b
3712808
6107413
4b799b4
ee0ecbf
596bc61
0164e2f
90b79fc
2cef3a9
c26b64f
2e02cd2
a6d0741
82e5b2c
70bbf5d
126a51e
f2ceacd
5ad208f
970917e
ddc1b8b
2b1ed0b
a7d3729
17ef8fc
f7a2902
a803c9b
9d79697
7127f5e
21fb9b9
7d32023
ba91af5
42f75f3
b8ad0cb
8f15837
fb13fe4
3331b48
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -756,6 +756,54 @@ trait String2TrimExpression extends Expression with ImplicitCastInputTypes { | |
override def nullable: Boolean = children.exists(_.nullable) | ||
override def foldable: Boolean = children.forall(_.foldable) | ||
|
||
protected def doEval(srcString: UTF8String): Any | ||
protected def doEval(srcString: UTF8String, trimString: UTF8String): Any | ||
|
||
override def eval(input: InternalRow): Any = { | ||
val srcString = srcStr.eval(input).asInstanceOf[UTF8String] | ||
if (srcString == null) { | ||
null | ||
} else if (trimStr.isDefined) { | ||
doEval(srcString, trimStr.get.eval(input).asInstanceOf[UTF8String]) | ||
} else { | ||
doEval(srcString) | ||
} | ||
} | ||
|
||
protected val trimMethod: String | ||
|
||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { | ||
val evals = children.map(_.genCode(ctx)) | ||
val srcString = evals(0) | ||
|
||
if (evals.length == 1) { | ||
ev.copy(code = evals.map(_.code) :+ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK |
||
code""" | ||
|boolean ${ev.isNull} = false; | ||
|UTF8String ${ev.value} = null; | ||
|if (${srcString.isNull}) { | ||
| ${ev.isNull} = true; | ||
|} else { | ||
| ${ev.value} = ${srcString.value}.$trimMethod(); | ||
|} | ||
""") | ||
} else { | ||
val trimString = evals(1) | ||
ev.copy(code = evals.map(_.code) :+ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can skip evaluating trim string if possible
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good idea. |
||
code""" | ||
|boolean ${ev.isNull} = false; | ||
|UTF8String ${ev.value} = null; | ||
|if (${srcString.isNull}) { | ||
| ${ev.isNull} = true; | ||
|} else if (${trimString.isNull}) { | ||
| ${ev.isNull} = true; | ||
|} else { | ||
| ${ev.value} = ${srcString.value}.$trimMethod(${trimString.value}); | ||
|} | ||
""") | ||
} | ||
} | ||
|
||
override def sql: String = if (trimStr.isDefined) { | ||
s"TRIM($direction ${trimStr.get.sql} FROM ${srcStr.sql})" | ||
} else { | ||
|
@@ -844,51 +892,12 @@ case class StringTrim( | |
|
||
override protected def direction: String = "BOTH" | ||
|
||
override def eval(input: InternalRow): Any = { | ||
val srcString = srcStr.eval(input).asInstanceOf[UTF8String] | ||
if (srcString == null) { | ||
null | ||
} else { | ||
if (trimStr.isDefined) { | ||
srcString.trim(trimStr.get.eval(input).asInstanceOf[UTF8String]) | ||
} else { | ||
srcString.trim() | ||
} | ||
} | ||
} | ||
override def doEval(srcString: UTF8String): Any = srcString.trim() | ||
|
||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { | ||
val evals = children.map(_.genCode(ctx)) | ||
val srcString = evals(0) | ||
override def doEval(srcString: UTF8String, trimString: UTF8String): Any = | ||
srcString.trim(trimString) | ||
|
||
if (evals.length == 1) { | ||
ev.copy(evals.map(_.code) :+ code""" | ||
boolean ${ev.isNull} = false; | ||
UTF8String ${ev.value} = null; | ||
if (${srcString.isNull}) { | ||
${ev.isNull} = true; | ||
} else { | ||
${ev.value} = ${srcString.value}.trim(); | ||
}""") | ||
} else { | ||
val trimString = evals(1) | ||
val getTrimFunction = | ||
s""" | ||
if (${trimString.isNull}) { | ||
${ev.isNull} = true; | ||
} else { | ||
${ev.value} = ${srcString.value}.trim(${trimString.value}); | ||
}""" | ||
ev.copy(evals.map(_.code) :+ code""" | ||
boolean ${ev.isNull} = false; | ||
UTF8String ${ev.value} = null; | ||
if (${srcString.isNull}) { | ||
${ev.isNull} = true; | ||
} else { | ||
$getTrimFunction | ||
}""") | ||
} | ||
} | ||
override val trimMethod: String = "trim" | ||
} | ||
|
||
object StringTrimLeft { | ||
|
@@ -937,51 +946,12 @@ case class StringTrimLeft( | |
|
||
override protected def direction: String = "LEADING" | ||
|
||
override def eval(input: InternalRow): Any = { | ||
val srcString = srcStr.eval(input).asInstanceOf[UTF8String] | ||
if (srcString == null) { | ||
null | ||
} else { | ||
if (trimStr.isDefined) { | ||
srcString.trimLeft(trimStr.get.eval(input).asInstanceOf[UTF8String]) | ||
} else { | ||
srcString.trimLeft() | ||
} | ||
} | ||
} | ||
override def doEval(srcString: UTF8String): Any = srcString.trimLeft() | ||
|
||
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { | ||
val evals = children.map(_.genCode(ctx)) | ||
val srcString = evals(0) | ||
override def doEval(srcString: UTF8String, trimString: UTF8String): Any = | ||
srcString.trimLeft(trimString) | ||
|
||
if (evals.length == 1) { | ||
ev.copy(evals.map(_.code) :+ code""" | ||
boolean ${ev.isNull} = false; | ||
UTF8String ${ev.value} = null; | ||
if (${srcString.isNull}) { | ||
${ev.isNull} = true; | ||
} else { | ||
${ev.value} = ${srcString.value}.trimLeft(); | ||
}""") | ||
} else { | ||
val trimString = evals(1) | ||
val getTrimLeftFunction = | ||
s""" | ||
if (${trimString.isNull}) { | ||
${ev.isNull} = true; | ||
} else { | ||
${ev.value} = ${srcString.value}.trimLeft(${trimString.value}); | ||
}""" | ||
ev.copy(evals.map(_.code) :+ code""" | ||
boolean ${ev.isNull} = false; | ||
UTF8String ${ev.value} = null; | ||
if (${srcString.isNull}) { | ||
${ev.isNull} = true; | ||
} else { | ||
$getTrimLeftFunction | ||
}""") | ||
} | ||
} | ||
override val trimMethod: String = "trimLeft" | ||
} | ||
|
||
object StringTrimRight { | ||
|
@@ -1032,51 +1002,12 @@ case class StringTrimRight( | |
|
||
override protected def direction: String = "TRAILING" | ||
|
||
override def eval(input: InternalRow): Any = { | ||
val srcString = srcStr.eval(input).asInstanceOf[UTF8String] | ||
if (srcString == null) { | ||
null | ||
} else { | ||
if (trimStr.isDefined) { | ||
srcString.trimRight(trimStr.get.eval(input).asInstanceOf[UTF8String]) | ||
} else { | ||
srcString.trimRight() | ||
} | ||
} | ||
} | ||
override def doEval(srcString: UTF8String): Any = srcString.trimRight() | ||
|
||
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { | ||
val evals = children.map(_.genCode(ctx)) | ||
val srcString = evals(0) | ||
override def doEval(srcString: UTF8String, trimString: UTF8String): Any = | ||
srcString.trimRight(trimString) | ||
|
||
if (evals.length == 1) { | ||
ev.copy(evals.map(_.code) :+ code""" | ||
boolean ${ev.isNull} = false; | ||
UTF8String ${ev.value} = null; | ||
if (${srcString.isNull}) { | ||
${ev.isNull} = true; | ||
} else { | ||
${ev.value} = ${srcString.value}.trimRight(); | ||
}""") | ||
} else { | ||
val trimString = evals(1) | ||
val getTrimRightFunction = | ||
s""" | ||
if (${trimString.isNull}) { | ||
${ev.isNull} = true; | ||
} else { | ||
${ev.value} = ${srcString.value}.trimRight(${trimString.value}); | ||
}""" | ||
ev.copy(evals.map(_.code) :+ code""" | ||
boolean ${ev.isNull} = false; | ||
UTF8String ${ev.value} = null; | ||
if (${srcString.isNull}) { | ||
${ev.isNull} = true; | ||
} else { | ||
$getTrimRightFunction | ||
}""") | ||
} | ||
} | ||
override val trimMethod: String = "trimRight" | ||
} | ||
|
||
/** | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These will definitely return UTF8String right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. I updated it.