feat: Implement Spark unhex #342
Changes from 8 commits
File: native scalar function implementations (Rust)
@@ -43,7 +43,8 @@ use datafusion::{
 };
 use datafusion_common::{
     cast::{as_binary_array, as_generic_string_array},
-    exec_err, internal_err, DataFusionError, Result as DataFusionResult, ScalarValue,
+    exec_err, internal_err, not_impl_err, plan_err, DataFusionError, Result as DataFusionResult,
+    ScalarValue,
 };
 use datafusion_physical_expr::{math_expressions, udf::ScalarUDF};
 use num::{
@@ -105,6 +106,9 @@ pub fn create_comet_physical_fun(
         "make_decimal" => {
             make_comet_scalar_udf!("make_decimal", spark_make_decimal, data_type)
         }
+        "unhex" => {
+            make_comet_scalar_udf!("unhex", spark_unhex, data_type)
+        }
         "decimal_div" => {
             make_comet_scalar_udf!("decimal_div", spark_decimal_div, data_type)
         }
@@ -123,11 +127,10 @@ pub fn create_comet_physical_fun(
             make_comet_scalar_udf!(spark_func_name, wrapped_func, without data_type)
         }
         _ => {
-            let fun = BuiltinScalarFunction::from_str(fun_name);
-            if fun.is_err() {
-                Ok(ScalarFunctionDefinition::UDF(registry.udf(fun_name)?))
+            if let Ok(fun) = BuiltinScalarFunction::from_str(fun_name) {
+                Ok(ScalarFunctionDefinition::BuiltIn(fun))
             } else {
-                Ok(ScalarFunctionDefinition::BuiltIn(fun?))
+                Ok(ScalarFunctionDefinition::UDF(registry.udf(fun_name)?))
             }
         }
     }

Review comment (on the `if let` refactor): Unrelated, but more idiomatic IMO.
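For context on that note: `if let Ok(...)` binds the success value directly in the condition, avoiding the `is_err()` check followed by a re-unwrapping `?`. A minimal, self-contained sketch of the same idiom (illustrative names only, not code from this PR):

```rust
// Hypothetical example of the refactored control flow: bind the Ok value in
// the condition instead of checking is_err() and unwrapping afterwards.
fn parse_or_fallback(input: &str) -> i64 {
    if let Ok(n) = input.parse::<i64>() {
        n // success path, analogous to ScalarFunctionDefinition::BuiltIn
    } else {
        -1 // fallback path, analogous to the UDF registry lookup
    }
}

fn main() {
    assert_eq!(parse_or_fallback("42"), 42);
    assert_eq!(parse_or_fallback("not a number"), -1);
}
```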
@@ -573,6 +576,115 @@ fn spark_rpad_internal<T: OffsetSizeTrait>(
     Ok(ColumnarValue::Array(Arc::new(result)))
 }

(The rest of this hunk is new code: the `unhex` helper, `spark_unhex_inner`, and the `spark_unhex` entry point.)
fn unhex(hex_str: &str, result: &mut Vec<u8>) -> Result<(), DataFusionError> {
    // https://docs.databricks.com/en/sql/language-manual/functions/unhex.html
    // If the length of expr is odd, the first character is discarded and the
    // result is padded with a null byte. If expr contains non-hex characters
    // the result is NULL.
    //
    // Record the odd-length case before shadowing `hex_str`, so the padding
    // below checks the original length rather than the stripped one.
    let needs_padding = hex_str.len() % 2 == 1;
    let hex_str = if needs_padding { &hex_str[1..] } else { hex_str };

    let mut iter = hex_str.chars();
    while let Some(c) = iter.next() {
        // A non-hex character invalidates the whole input; surface an error
        // so the caller can emit NULL (or fail, when fail_on_error is set).
        let high = c.to_digit(16).ok_or_else(|| {
            DataFusionError::Execution(format!("Invalid hex character: {}", c))
        })?;

        let low = iter
            .next()
            .ok_or_else(|| DataFusionError::Internal("Odd number of hex characters".to_string()))?
            .to_digit(16)
            .ok_or_else(|| DataFusionError::Execution("Invalid hex character".to_string()))?;

        // Pack the two 4-bit digits into one byte: high nibble, then low.
        result.push((high << 4 | low) as u8);
    }

    if needs_padding {
        result.push(0);
    }

    Ok(())
}
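A quick worked example of the nibble packing above (not part of the PR): each hex digit is 4 bits, so the first digit is shifted into the high nibble and OR-ed with the second.

```rust
#[cfg(test)]
mod nibble_example {
    // Worked example of the `high << 4 | low` packing used by `unhex`.
    #[test]
    fn packs_two_hex_digits_into_one_byte() {
        let high = '5'.to_digit(16).unwrap(); // 5
        let low = '3'.to_digit(16).unwrap(); // 3
        assert_eq!((high << 4 | low) as u8, 0x53); // 0x53 is ASCII 'S'
    }
}
```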
fn spark_unhex_inner<T: OffsetSizeTrait>(
    array: &ColumnarValue,
    fail_on_error: bool,
) -> Result<ColumnarValue, DataFusionError> {
    let string_array = match array {
        ColumnarValue::Array(array) => as_generic_string_array::<T>(array)?,
        ColumnarValue::Scalar(ScalarValue::Utf8(Some(_string))) => {
            return not_impl_err!("unhex with scalar string is not implemented yet");
        }
        _ => {
            return internal_err!(
                "The first argument must be a string scalar or array, but got: {:?}",
                array.data_type()
            );
        }
    };

    let mut builder = arrow::array::BinaryBuilder::new();
    // Scratch buffer reused across rows to avoid a per-row allocation.
    let mut encoded = Vec::new();

    for i in 0..string_array.len() {
        if string_array.is_null(i) {
            // NULL in, NULL out.
            builder.append_null();
            continue;
        }

        let string = string_array.value(i);

        if unhex(string, &mut encoded).is_ok() {
            builder.append_value(encoded.as_slice());
        } else if fail_on_error {
            return plan_err!("Input to unhex is not a valid hex string: {:?}", string);
        } else {
            builder.append_null();
        }
        // Clear the scratch buffer, including any partially decoded bytes
        // left behind by an invalid input.
        encoded.clear();
    }
    Ok(ColumnarValue::Array(Arc::new(builder.finish())))
}
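Not part of the PR: for readers unfamiliar with the Arrow builder pattern used above, a minimal sketch of `BinaryBuilder` semantics, appending values or nulls row by row and then freezing the result with `finish()`:

```rust
use arrow::array::{Array, BinaryBuilder};

fn main() {
    let mut builder = BinaryBuilder::new();
    builder.append_value(b"Spark"); // one binary row
    builder.append_null();          // one NULL row
    let array = builder.finish();   // immutable BinaryArray

    assert_eq!(array.len(), 2);
    assert!(array.is_null(1));
}
```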
fn spark_unhex(
    args: &[ColumnarValue],
    _data_type: &DataType,
) -> Result<ColumnarValue, DataFusionError> {
    if args.len() > 2 {
        return plan_err!("unhex takes at most 2 arguments, but got: {}", args.len());
    }

    let val_to_unhex = &args[0];
    let fail_on_error = if args.len() == 2 {
        match &args[1] {
            ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))) => *fail_on_error,
            _ => {
                return plan_err!(
                    "The second argument must be boolean scalar, but got: {:?}",
                    args[1]
                );
            }
        }
    } else {
        false
    };

    match val_to_unhex.data_type() {
        DataType::Utf8 => spark_unhex_inner::<i32>(val_to_unhex, fail_on_error),
        DataType::LargeUtf8 => spark_unhex_inner::<i64>(val_to_unhex, fail_on_error),
        other => {
            internal_err!(
                "The first argument must be a string scalar or array, but got: {:?}",
                other
            )
        }
    }
}
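To make the calling convention concrete, here is a hedged usage sketch, not from the PR. It assumes `spark_unhex` above is in scope, and the import paths may vary across DataFusion versions. The optional second argument is the boolean `failOnError` scalar that the Scala side serializes:

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, StringArray};
use arrow::datatypes::DataType;
use datafusion::physical_plan::ColumnarValue;
use datafusion_common::{DataFusionError, ScalarValue};

fn main() -> Result<(), DataFusionError> {
    let input: ArrayRef = Arc::new(StringArray::from(vec![
        Some("537061726B"), // decodes to the bytes of "Spark"
        Some("not hex"),    // becomes NULL when fail_on_error is false
        None,               // NULL stays NULL
    ]));

    // One argument: fail_on_error defaults to false (lenient mode).
    let _lenient = spark_unhex(&[ColumnarValue::Array(input.clone())], &DataType::Binary)?;

    // Two arguments: an explicit failOnError flag, as serialized from Spark.
    let strict = spark_unhex(
        &[
            ColumnarValue::Array(input),
            ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))),
        ],
        &DataType::Binary,
    );
    assert!(strict.is_err()); // "not hex" triggers the error path

    Ok(())
}
```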
// Let Decimal(p3, s3) as return type i.e. Decimal(p1, s1) / Decimal(p2, s2) = Decimal(p3, s3).
// Conversely, Decimal(p1, s1) = Decimal(p2, s2) * Decimal(p3, s3). This means that, in order to
// get enough scale that matches with Spark behavior, it requires to widen s1 to s2 + s3 + 1. Since
@@ -701,3 +813,18 @@ fn wrap_digest_result_as_hex_string(
        }
    }
}
#[cfg(test)]
mod test {
    use super::unhex;

    #[test]
    fn test_unhex() {
        let mut result = Vec::new();

        unhex("537061726B2053514C", &mut result).unwrap();
        let result_str = std::str::from_utf8(&result).unwrap();
        assert_eq!(result_str, "Spark SQL");
        result.clear();
    }
}
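A sketch of the additional cases the review asks for, assuming the semantics of the `unhex` helper as written above (odd-length input drops the first character and appends a trailing null byte; non-hex input errors, which the caller maps to NULL or a query error). These would slot into the test module:

```rust
#[test]
fn test_unhex_odd_length_and_invalid() {
    let mut result = Vec::new();

    // Odd length: the first character is discarded and a null byte pads the
    // result, per the Databricks unhex documentation cited above.
    unhex("ABC", &mut result).unwrap();
    assert_eq!(result, vec![0xBC, 0x00]);
    result.clear();

    // Non-hex characters surface as an error, which the caller maps to NULL
    // (or to a query error when failOnError is set).
    assert!(unhex("GG", &mut result).is_err());
}
```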
File: physical plan creation (Rust)
@@ -1301,24 +1301,26 @@ impl PhysicalPlanner {
             .iter()
             .map(|x| x.data_type(input_schema.as_ref()))
             .collect::<Result<Vec<_>, _>>()?;

         let data_type = match expr.return_type.as_ref().map(to_arrow_datatype) {
             Some(t) => t,
             None => {
                 // If no data type is provided from Spark, we'll use DF's return type from the
                 // scalar function
                 // Note this assumes the `fun_name` is a defined function in DF. Otherwise, it'll
                 // throw error.
-                let fun = BuiltinScalarFunction::from_str(fun_name);
-                if fun.is_err() {
+                if let Ok(fun) = BuiltinScalarFunction::from_str(fun_name) {
+                    fun.return_type(&input_expr_types)?
+                } else {
                     self.session_ctx
                         .udf(fun_name)?
                         .inner()
                         .return_type(&input_expr_types)?
-                } else {
-                    fun?.return_type(&input_expr_types)?
                 }
             }
         };

         let fun_expr =
             create_comet_physical_fun(fun_name, data_type.clone(), &self.session_ctx.state())?;

Review comment (on the `if let` refactor): Unrelated, but more idiomatic IMO.
File: Maven build configuration (pom.xml)
@@ -88,7 +88,8 @@ under the License.
     <argLine>-ea -Xmx4g -Xss4m ${extraJavaTestArgs}</argLine>
     <additional.3_3.test.source>spark-3.3-plus</additional.3_3.test.source>
     <additional.3_4.test.source>spark-3.4</additional.3_4.test.source>
-    <shims.source>spark-3.x</shims.source>
+    <shims.majorSource>spark-3.x</shims.majorSource>
+    <shims.minorSource>spark-3.4.x</shims.minorSource>
 </properties>

 <dependencyManagement>

Review comment (on `shims.minorSource`): FYI …
@@ -500,6 +501,7 @@ under the License.
     <!-- we don't add special test suites for spark-3.2, so a non-existent dir is specified -->
     <additional.3_3.test.source>not-needed-yet</additional.3_3.test.source>
     <additional.3_4.test.source>not-needed-yet</additional.3_4.test.source>
+    <shims.minorSource>spark-3.3.x</shims.minorSource>
 </properties>
 </profile>

Review comment (on the added `shims.minorSource`): I guess this should be …
@@ -512,6 +514,7 @@ under the License.
     <parquet.version>1.12.0</parquet.version>
     <additional.3_3.test.source>spark-3.3-plus</additional.3_3.test.source>
     <additional.3_4.test.source>not-needed-yet</additional.3_4.test.source>
+    <shims.minorSource>spark-3.3.x</shims.minorSource>
 </properties>
 </profile>
@@ -523,6 +526,7 @@ under the License.
     <parquet.version>1.13.1</parquet.version>
     <additional.3_3.test.source>spark-3.3-plus</additional.3_3.test.source>
     <additional.3_4.test.source>spark-3.4</additional.3_4.test.source>
+    <shims.minorSource>spark-3.4.x</shims.minorSource>
 </properties>
 </profile>
File: QueryPlanSerde (Scala)
@@ -45,12 +45,13 @@ import org.apache.comet.CometSparkSessionExtensions.{isCometOperatorEnabled, isC
 import org.apache.comet.serde.ExprOuterClass.{AggExpr, DataType => ProtoDataType, Expr, ScalarFunc}
 import org.apache.comet.serde.ExprOuterClass.DataType.{DataTypeInfo, DecimalInfo, ListInfo, MapInfo, StructInfo}
 import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, JoinType, Operator}
+import org.apache.comet.shims.ShimCometUnhexExpr
 import org.apache.comet.shims.ShimQueryPlanSerde

 /**
  * A utility object for query plan and expression serialization.
  */
-object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
+object QueryPlanSerde extends Logging with ShimQueryPlanSerde with ShimCometUnhexExpr {

   def emitWarning(reason: String): Unit = {
     logWarning(s"Comet native execution is disabled due to: $reason")
   }

Review comment (on the new shim trait): I would say eventually we should merge …
@@ -1396,6 +1397,19 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
       val optExpr = scalarExprToProto("atan2", leftExpr, rightExpr)
       optExprWithInfo(optExpr, expr, left, right)

+    case e: Unhex =>
+      val unHex = unhexSerde(e)
+
+      val childCast = Cast(unHex._1, StringType)
+      val failOnErrorCast = Cast(unHex._2, BooleanType)
+
+      val childExpr = exprToProtoInternal(childCast, inputs)
+      val failOnErrorExpr = exprToProtoInternal(failOnErrorCast, inputs)
+
+      val optExpr =
+        scalarExprToProtoWithReturnType("unhex", e.dataType, childExpr, failOnErrorExpr)
+      optExprWithInfo(optExpr, expr, unHex._1)
+
     case e @ Ceil(child) =>
       val childExpr = exprToProtoInternal(child, inputs)
       child.dataType match {

Review comment (on the `Cast` wrappers): What would happen if we do not use …
File: ShimCometUnhexExpr for pre-3.4 Spark (new Scala file)
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.comet.shims

import org.apache.spark.sql.catalyst.expressions._

/**
 * `ShimCometUnhexExpr` parses the `Unhex` expression for catalyst versions before 3.4, where
 * `Unhex` has no `failOnError` flag, so the flag is hardcoded to `false`.
 */
trait ShimCometUnhexExpr {
  def unhexSerde(unhex: Unhex): (Expression, Expression) = {
    (unhex.child, Literal(false))
  }
}
File: ShimCometUnhexExpr for Spark 3.4 (new Scala file)
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.comet.shims

import org.apache.spark.sql.catalyst.expressions._

/**
 * `ShimCometUnhexExpr` parses the `Unhex` expression assuming that the catalyst version is
 * 3.4.x, where `Unhex` carries a `failOnError` flag.
 */
trait ShimCometUnhexExpr {
  def unhexSerde(unhex: Unhex): (Expression, Expression) = {
    (unhex.child, Literal(unhex.failOnError))
  }
}
File: CometExpressionSuite (Scala)
@@ -1025,6 +1025,19 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }

+  test("unhex") {
+    Seq(false, true).foreach { dictionary =>
+      withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
+        val table = "test"
+        withTable(table) {
+          sql(s"create table $table(col string) using parquet")
+          sql(s"insert into $table values('537061726B2053514C')")
+          checkSparkAnswerAndOperator(s"SELECT unhex(col) FROM $table")
+        }
+      }
+    }
+  }
+
   test("length, reverse, instr, replace, translate") {
     Seq(false, true).foreach { dictionary =>
       withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {

Review comment (on the inserted values): I'd like to see more values being tested here, both valid and invalid, and covering the padded case.
Review comment: Thanks for adding the minor version shims. These are going to help me with some of my work around supporting `cast`.