feat: Implement Spark unhex #342

Merged (20 commits, May 9, 2024)
Changes from 8 commits

3 changes: 2 additions & 1 deletion common/pom.xml
@@ -179,7 +179,8 @@ under the License.
</goals>
<configuration>
<sources>
<source>src/main/${shims.source}</source>
<source>src/main/${shims.majorSource}</source>
<source>src/main/${shims.minorSource}</source>
Member: Thanks for adding the minor version shims. These are going to help me with some of my work around supporting cast.

</sources>
</configuration>
</execution>
137 changes: 132 additions & 5 deletions core/src/execution/datafusion/expressions/scalar_funcs.rs
@@ -43,7 +43,8 @@ use datafusion::{
};
use datafusion_common::{
cast::{as_binary_array, as_generic_string_array},
exec_err, internal_err, DataFusionError, Result as DataFusionResult, ScalarValue,
exec_err, internal_err, not_impl_err, plan_err, DataFusionError, Result as DataFusionResult,
ScalarValue,
};
use datafusion_physical_expr::{math_expressions, udf::ScalarUDF};
use num::{
@@ -105,6 +106,9 @@ pub fn create_comet_physical_fun(
"make_decimal" => {
make_comet_scalar_udf!("make_decimal", spark_make_decimal, data_type)
}
"unhex" => {
make_comet_scalar_udf!("unhex", spark_unhex, data_type)
}
"decimal_div" => {
make_comet_scalar_udf!("decimal_div", spark_decimal_div, data_type)
}
@@ -123,11 +127,10 @@
make_comet_scalar_udf!(spark_func_name, wrapped_func, without data_type)
}
_ => {
let fun = BuiltinScalarFunction::from_str(fun_name);
if fun.is_err() {
Ok(ScalarFunctionDefinition::UDF(registry.udf(fun_name)?))
if let Ok(fun) = BuiltinScalarFunction::from_str(fun_name) {
Contributor Author: Unrelated, but more idiomatic IMO.

Ok(ScalarFunctionDefinition::BuiltIn(fun))
} else {
Ok(ScalarFunctionDefinition::BuiltIn(fun?))
Ok(ScalarFunctionDefinition::UDF(registry.udf(fun_name)?))
}
}
}
@@ -573,6 +576,115 @@ fn spark_rpad_internal<T: OffsetSizeTrait>(
Ok(ColumnarValue::Array(Arc::new(result)))
}

fn unhex(string: &str, result: &mut Vec<u8>) -> Result<(), DataFusionError> {
// https://docs.databricks.com/en/sql/language-manual/functions/unhex.html
// If the length of expr is odd, the first character is discarded and the result is padded with
// a null byte. If expr contains non hex characters the result is NULL.
let string = if string.len() % 2 == 1 {
&string[1..]
} else {
string
};

let mut iter = string.chars().peekable();
while let Some(c) = iter.next() {
let high = if let Some(high) = c.to_digit(16) {
high
} else {
return Ok(());
};

let low = iter
.next()
.ok_or_else(|| DataFusionError::Internal("Odd number of hex characters".to_string()))?
.to_digit(16);

let low = if let Some(low) = low {
low
} else {
return Ok(());
};

result.push((high << 4 | low) as u8);
}

if string.len() % 2 == 1 {
result.push(0);
}

Ok(())
}

fn spark_unhex_inner<T: OffsetSizeTrait>(
array: &ColumnarValue,
fail_on_error: bool,
) -> Result<ColumnarValue, DataFusionError> {
let string_array = match array {
ColumnarValue::Array(array) => as_generic_string_array::<T>(array)?,
ColumnarValue::Scalar(ScalarValue::Utf8(Some(_string))) => {
return not_impl_err!("unhex with scalar string is not implemented yet");
}
_ => {
return internal_err!(
"The first argument must be a string scalar or array, but got: {:?}",
array.data_type()
);
}
};

let mut builder = arrow::array::BinaryBuilder::new();
let mut encoded = Vec::new();

for i in 0..string_array.len() {
let string = string_array.value(i);

if unhex(string, &mut encoded).is_ok() {
builder.append_value(encoded.as_slice());
encoded.clear();
} else if fail_on_error {
return plan_err!("Input to unhex is not a valid hex string: {:?}", string);
} else {
builder.append_null();
}
}
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
}

fn spark_unhex(
args: &[ColumnarValue],
_data_type: &DataType,
) -> Result<ColumnarValue, DataFusionError> {
if args.len() > 2 {
return plan_err!("unhex takes at most 2 arguments, but got: {}", args.len());
}

let val_to_unhex = &args[0];
let fail_on_error = if args.len() == 2 {
match &args[1] {
ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))) => *fail_on_error,
_ => {
return plan_err!(
"The second argument must be boolean scalar, but got: {:?}",
args[1]
);
}
}
} else {
false
};

match val_to_unhex.data_type() {
DataType::Utf8 => spark_unhex_inner::<i32>(val_to_unhex, fail_on_error),
DataType::LargeUtf8 => spark_unhex_inner::<i64>(val_to_unhex, fail_on_error),
other => {
internal_err!(
"The first argument must be a string scalar or array, but got: {:?}",
other
)
}
}
}

// Let Decimal(p3, s3) as return type i.e. Decimal(p1, s1) / Decimal(p2, s2) = Decimal(p3, s3).
// Conversely, Decimal(p1, s1) = Decimal(p2, s2) * Decimal(p3, s3). This means that, in order to
// get enough scale that matches with Spark behavior, it requires to widen s1 to s2 + s3 + 1. Since
@@ -701,3 +813,18 @@ fn wrap_digest_result_as_hex_string(
}
}
}

#[cfg(test)]
mod test {
use super::unhex;

#[test]
fn test_unhex() {
let mut result = Vec::new();

unhex("537061726B2053514C", &mut result).unwrap();
let result_str = std::str::from_utf8(&result).unwrap();
assert_eq!(result_str, "Spark SQL");
result.clear();
}
}
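
As a point of reference, the new kernel is meant to match Spark's built-in unhex. Below is a minimal end-to-end sanity check from a Spark session, reusing the hex value from the unit test above; the `spark` session handle and a Comet-enabled configuration are assumptions, not part of this diff:

// decode() turns the binary produced by unhex back into a UTF-8 string.
val row = spark.sql("SELECT decode(unhex('537061726B2053514C'), 'UTF-8') AS s").collect().head
assert(row.getString(0) == "Spark SQL")
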
10 changes: 6 additions & 4 deletions core/src/execution/datafusion/planner.rs
@@ -1301,24 +1301,26 @@ impl PhysicalPlanner {
.iter()
.map(|x| x.data_type(input_schema.as_ref()))
.collect::<Result<Vec<_>, _>>()?;

let data_type = match expr.return_type.as_ref().map(to_arrow_datatype) {
Some(t) => t,
None => {
// If no data type is provided from Spark, we'll use DF's return type from the
// scalar function
// Note this assumes the `fun_name` is a defined function in DF. Otherwise, it'll
// throw error.
let fun = BuiltinScalarFunction::from_str(fun_name);
if fun.is_err() {

if let Ok(fun) = BuiltinScalarFunction::from_str(fun_name) {
Contributor Author: Unrelated, but more idiomatic IMO.

fun.return_type(&input_expr_types)?
} else {
self.session_ctx
.udf(fun_name)?
.inner()
.return_type(&input_expr_types)?
} else {
fun?.return_type(&input_expr_types)?
}
}
};

let fun_expr =
create_comet_physical_fun(fun_name, data_type.clone(), &self.session_ctx.state())?;

6 changes: 5 additions & 1 deletion pom.xml
@@ -88,7 +88,8 @@ under the License.
<argLine>-ea -Xmx4g -Xss4m ${extraJavaTestArgs}</argLine>
<additional.3_3.test.source>spark-3.3-plus</additional.3_3.test.source>
<additional.3_4.test.source>spark-3.4</additional.3_4.test.source>
<shims.source>spark-3.x</shims.source>
<shims.majorSource>spark-3.x</shims.majorSource>
<shims.minorSource>spark-3.4.x</shims.minorSource>
Contributor: FYI, spark-3.x will be gone and is supposed to be replaced by independent spark-3.2, spark-3.3, and spark-3.4 dirs. So let's go ahead and start using spark-3.4 instead of spark-3.4.x, i.e.:

<shims.source.shared>spark-3.x</shims.source.shared>
<shims.source>spark-3.4</shims.source>

</properties>

<dependencyManagement>
@@ -500,6 +501,7 @@ under the License.
<!-- we don't add special test suits for spark-3.2, so a not existed dir is specified-->
<additional.3_3.test.source>not-needed-yet</additional.3_3.test.source>
<additional.3_4.test.source>not-needed-yet</additional.3_4.test.source>
<shims.minorSource>spark-3.3.x</shims.minorSource>
Contributor: I guess this should be spark-3.2?

</properties>
</profile>

@@ -512,6 +514,7 @@ under the License.
<parquet.version>1.12.0</parquet.version>
<additional.3_3.test.source>spark-3.3-plus</additional.3_3.test.source>
<additional.3_4.test.source>not-needed-yet</additional.3_4.test.source>
<shims.minorSource>spark-3.3.x</shims.minorSource>
</properties>
</profile>

@@ -523,6 +526,7 @@ under the License.
<parquet.version>1.13.1</parquet.version>
<additional.3_3.test.source>spark-3.3-plus</additional.3_3.test.source>
<additional.3_4.test.source>spark-3.4</additional.3_4.test.source>
<shims.minorSource>spark-3.4.x</shims.minorSource>
</properties>
</profile>

3 changes: 2 additions & 1 deletion spark/pom.xml
@@ -258,7 +258,8 @@ under the License.
</goals>
<configuration>
<sources>
<source>src/main/${shims.source}</source>
<source>src/main/${shims.majorSource}</source>
<source>src/main/${shims.minorSource}</source>
</sources>
</configuration>
</execution>
16 changes: 15 additions & 1 deletion spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -45,12 +45,13 @@ import org.apache.comet.CometSparkSessionExtensions.{isCometOperatorEnabled, isC
import org.apache.comet.serde.ExprOuterClass.{AggExpr, DataType => ProtoDataType, Expr, ScalarFunc}
import org.apache.comet.serde.ExprOuterClass.DataType.{DataTypeInfo, DecimalInfo, ListInfo, MapInfo, StructInfo}
import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, JoinType, Operator}
import org.apache.comet.shims.ShimCometUnhexExpr
import org.apache.comet.shims.ShimQueryPlanSerde

/**
* An utility object for query plan and expression serialization.
*/
object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
object QueryPlanSerde extends Logging with ShimQueryPlanSerde with ShimCometUnhexExpr {
Contributor: I would say ShimCometUnhexExpr should have a more generic name, not an unhex-specific one. What about ShimCometExpr (or if you have another idea, please feel free to propose)? Otherwise, we will have to keep adding a trait class per function.

Eventually, we should merge ShimQueryPlanSerde and ShimCometExpr into one when we remove the spark-3.x dir.
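
A rough sketch of what that merged trait might look like; the name ShimCometExpr and its shape here are only the reviewer's suggestion, not part of this PR:

package org.apache.comet.shims

import org.apache.spark.sql.catalyst.expressions._

// Hypothetical generic shim trait gathering per-version expression helpers,
// so each new function does not need a dedicated trait.
trait ShimCometExpr {
  // Spark < 3.4: Unhex carries no failOnError flag, so default it to false.
  def unhexSerde(unhex: Unhex): (Expression, Expression) =
    (unhex.child, Literal(false))
}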

def emitWarning(reason: String): Unit = {
logWarning(s"Comet native execution is disabled due to: $reason")
}
@@ -1396,6 +1397,19 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
val optExpr = scalarExprToProto("atan2", leftExpr, rightExpr)
optExprWithInfo(optExpr, expr, left, right)

case e: Unhex =>
val unHex = unhexSerde(e)

val childCast = Cast(unHex._1, StringType)
val failOnErrorCast = Cast(unHex._2, BooleanType)
Contributor: What would happen if we do not use Cast?


val childExpr = exprToProtoInternal(childCast, inputs)
val failOnErrorExpr = exprToProtoInternal(failOnErrorCast, inputs)

val optExpr =
scalarExprToProtoWithReturnType("unhex", e.dataType, childExpr, failOnErrorExpr)
optExprWithInfo(optExpr, expr, unHex._1)

case e @ Ceil(child) =>
val childExpr = exprToProtoInternal(child, inputs)
child.dataType match {
@@ -0,0 +1,30 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.comet.shims

import org.apache.spark.sql.catalyst.expressions._

/**
 * `ShimCometUnhexExpr` parses the `Unhex` expression assuming that the catalyst version is below 3.4, where `Unhex` has no `failOnError` flag.
*/
trait ShimCometUnhexExpr {
def unhexSerde(unhex: Unhex): (Expression, Expression) = {
(unhex.child, Literal(false))
}
}
@@ -0,0 +1,30 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.comet.shims

import org.apache.spark.sql.catalyst.expressions._

/**
* `ShimCometUnhexExpr` parses the `Unhex` expression assuming that the catalyst version is 3.4.x.
*/
trait ShimCometUnhexExpr {
def unhexSerde(unhex: Unhex): (Expression, Expression) = {
(unhex.child, Literal(unhex.failOnError))
}
}
13 changes: 13 additions & 0 deletions spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -1025,6 +1025,19 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}

test("unhex") {
Seq(false, true).foreach { dictionary =>
withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
val table = "test"
withTable(table) {
sql(s"create table $table(col string) using parquet")
sql(s"insert into $table values('537061726B2053514C')")
Member:
I'd like to see more values being tested here, both valid and invalid, and covering the padded case.
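
For instance, the table population could be extended with a few more rows before the existing checkSparkAnswerAndOperator call; a sketch with illustrative values only:

sql(s"insert into $table values('1C'), ('')")        // short and empty inputs
sql(s"insert into $table values('ABC'), ('F')")      // odd length (padded case)
sql(s"insert into $table values('GG'), ('not hex')") // non-hex characters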

checkSparkAnswerAndOperator(s"SELECT unhex(col) FROM $table")
}
}
}
}

test("length, reverse, instr, replace, translate") {
Seq(false, true).foreach { dictionary =>
withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {