diff --git a/common/pom.xml b/common/pom.xml
index 73b1a269b9..71d74ce056 100644
--- a/common/pom.xml
+++ b/common/pom.xml
@@ -179,7 +179,8 @@ under the License.
-                  <source>src/main/${shims.source}</source>
+                  <source>src/main/${shims.majorVerSrc}</source>
+                  <source>src/main/${shims.minorVerSrc}</source>
diff --git a/core/src/execution/datafusion/expressions/scalar_funcs.rs b/core/src/execution/datafusion/expressions/scalar_funcs.rs
index 2895937ca0..8c5e1f3916 100644
--- a/core/src/execution/datafusion/expressions/scalar_funcs.rs
+++ b/core/src/execution/datafusion/expressions/scalar_funcs.rs
@@ -52,6 +52,9 @@ use num::{
};
use unicode_segmentation::UnicodeSegmentation;
+mod unhex;
+use unhex::spark_unhex;
+
macro_rules! make_comet_scalar_udf {
($name:expr, $func:ident, $data_type:ident) => {{
let scalar_func = CometScalarFunction::new(
@@ -105,6 +108,10 @@ pub fn create_comet_physical_fun(
"make_decimal" => {
make_comet_scalar_udf!("make_decimal", spark_make_decimal, data_type)
}
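+        // `unhex` resolves its own return type, so use the macro variant that
+        // registers the function as-is instead of wrapping it in a closure that
+        // forwards `data_type` as an extra argument.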
+ "unhex" => {
+ let func = Arc::new(spark_unhex);
+ make_comet_scalar_udf!("unhex", func, without data_type)
+ }
"decimal_div" => {
make_comet_scalar_udf!("decimal_div", spark_decimal_div, data_type)
}
@@ -123,11 +130,10 @@ pub fn create_comet_physical_fun(
make_comet_scalar_udf!(spark_func_name, wrapped_func, without data_type)
}
_ => {
- let fun = BuiltinScalarFunction::from_str(fun_name);
- if fun.is_err() {
- Ok(ScalarFunctionDefinition::UDF(registry.udf(fun_name)?))
+ if let Ok(fun) = BuiltinScalarFunction::from_str(fun_name) {
+ Ok(ScalarFunctionDefinition::BuiltIn(fun))
} else {
- Ok(ScalarFunctionDefinition::BuiltIn(fun?))
+ Ok(ScalarFunctionDefinition::UDF(registry.udf(fun_name)?))
}
}
}
diff --git a/core/src/execution/datafusion/expressions/scalar_funcs/unhex.rs b/core/src/execution/datafusion/expressions/scalar_funcs/unhex.rs
new file mode 100644
index 0000000000..38d5c04787
--- /dev/null
+++ b/core/src/execution/datafusion/expressions/scalar_funcs/unhex.rs
@@ -0,0 +1,257 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::sync::Arc;
+
+use arrow_array::OffsetSizeTrait;
+use arrow_schema::DataType;
+use datafusion::logical_expr::ColumnarValue;
+use datafusion_common::{cast::as_generic_string_array, exec_err, DataFusionError, ScalarValue};
+
+/// Helper function to convert a hex digit to a binary value.
+fn unhex_digit(c: u8) -> Result<u8, DataFusionError> {
+ match c {
+ b'0'..=b'9' => Ok(c - b'0'),
+ b'A'..=b'F' => Ok(10 + c - b'A'),
+ b'a'..=b'f' => Ok(10 + c - b'a'),
+ _ => Err(DataFusionError::Execution(
+ "Input to unhex_digit is not a valid hex digit".to_string(),
+ )),
+ }
+}
+
+/// Convert a hex string to binary and store the result in `result`. Returns an error if the input
+/// is not a valid hex string.
+fn unhex(hex_str: &str, result: &mut Vec<u8>) -> Result<(), DataFusionError> {
+ let bytes = hex_str.as_bytes();
+
+ let mut i = 0;
+
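+ // Spark (with the SPARK-40924 fix) accepts odd-length input: the leading
+ // digit is decoded as a byte on its own, e.g. "A1B" becomes [0x0A, 0x1B].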
+ if (bytes.len() & 0x01) != 0 {
+ let v = unhex_digit(bytes[0])?;
+
+ result.push(v);
+ i += 1;
+ }
+
+ while i < bytes.len() {
+ let first = unhex_digit(bytes[i])?;
+ let second = unhex_digit(bytes[i + 1])?;
+ result.push((first << 4) | second);
+
+ i += 2;
+ }
+
+ Ok(())
+}
+
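+// Generic over the string offset type so a single implementation serves both
+// Utf8 (i32 offsets) and LargeUtf8 (i64 offsets) arrays.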
+fn spark_unhex_inner<T: OffsetSizeTrait>(
+ array: &ColumnarValue,
+ fail_on_error: bool,
+) -> Result<ColumnarValue, DataFusionError> {
+ match array {
+ ColumnarValue::Array(array) => {
+ let string_array = as_generic_string_array::<T>(array)?;
+
+ let mut encoded = Vec::new();
+ let mut builder = arrow::array::BinaryBuilder::new();
+
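+ // One scratch buffer is reused across rows and cleared after each value.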
+ for item in string_array.iter() {
+ if let Some(s) = item {
+ if unhex(s, &mut encoded).is_ok() {
+ builder.append_value(encoded.as_slice());
+ } else if fail_on_error {
+ return exec_err!("Input to unhex is not a valid hex string: {s}");
+ } else {
+ builder.append_null();
+ }
+ encoded.clear();
+ } else {
+ builder.append_null();
+ }
+ }
+ Ok(ColumnarValue::Array(Arc::new(builder.finish())))
+ }
+ ColumnarValue::Scalar(ScalarValue::Utf8(Some(string))) => {
+ let mut encoded = Vec::new();
+
+ if unhex(string, &mut encoded).is_ok() {
+ Ok(ColumnarValue::Scalar(ScalarValue::Binary(Some(encoded))))
+ } else if fail_on_error {
+ exec_err!("Input to unhex is not a valid hex string: {string}")
+ } else {
+ Ok(ColumnarValue::Scalar(ScalarValue::Binary(None)))
+ }
+ }
+ ColumnarValue::Scalar(ScalarValue::Utf8(None)) => {
+ Ok(ColumnarValue::Scalar(ScalarValue::Binary(None)))
+ }
+ _ => {
+ exec_err!(
+ "The first argument must be a string scalar or array, but got: {:?}",
+ array
+ )
+ }
+ }
+}
+
+pub(super) fn spark_unhex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
+ if args.len() > 2 {
+ return exec_err!("unhex takes at most 2 arguments, but got: {}", args.len());
+ }
+
+ let val_to_unhex = &args[0];
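+ // The optional second argument carries Spark's `failOnError` flag: the Spark
+ // 3.2/3.3 shims pass a literal false, while the 3.4 shim forwards Unhex.failOnError.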
+ let fail_on_error = if args.len() == 2 {
+ match &args[1] {
+ ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))) => *fail_on_error,
+ _ => {
+ return exec_err!(
+ "The second argument must be boolean scalar, but got: {:?}",
+ args[1]
+ );
+ }
+ }
+ } else {
+ false
+ };
+
+ match val_to_unhex.data_type() {
+ DataType::Utf8 => spark_unhex_inner::<i32>(val_to_unhex, fail_on_error),
+ DataType::LargeUtf8 => spark_unhex_inner::<i64>(val_to_unhex, fail_on_error),
+ other => exec_err!(
+ "The first argument must be a Utf8 or LargeUtf8: {:?}",
+ other
+ ),
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use std::sync::Arc;
+
+ use arrow::array::{BinaryBuilder, StringBuilder};
+ use arrow_array::make_array;
+ use arrow_data::ArrayData;
+ use datafusion::logical_expr::ColumnarValue;
+ use datafusion_common::ScalarValue;
+
+ use super::unhex;
+
+ #[test]
+ fn test_spark_unhex_null() -> Result<(), Box<dyn std::error::Error>> {
+ let input = ArrayData::new_null(&arrow_schema::DataType::Utf8, 2);
+ let output = ArrayData::new_null(&arrow_schema::DataType::Binary, 2);
+
+ let input = ColumnarValue::Array(Arc::new(make_array(input)));
+ let expected = ColumnarValue::Array(Arc::new(make_array(output)));
+
+ let result = super::spark_unhex(&[input])?;
+
+ match (result, expected) {
+ (ColumnarValue::Array(result), ColumnarValue::Array(expected)) => {
+ assert_eq!(*result, *expected);
+ Ok(())
+ }
+ _ => Err("Unexpected result type".into()),
+ }
+ }
+
+ #[test]
+ fn test_partial_error() -> Result<(), Box<dyn std::error::Error>> {
+ let mut input = StringBuilder::new();
+
+ input.append_value("1CGG"); // 1C is ok, but GG is invalid
+ input.append_value("537061726B2053514C"); // followed by valid
+
+ let input = ColumnarValue::Array(Arc::new(input.finish()));
+ let fail_on_error = ColumnarValue::Scalar(ScalarValue::Boolean(Some(false)));
+
+ let result = super::spark_unhex(&[input, fail_on_error])?;
+
+ let mut expected = BinaryBuilder::new();
+ expected.append_null();
+ expected.append_value("Spark SQL".as_bytes());
+
+ match (result, ColumnarValue::Array(Arc::new(expected.finish()))) {
+ (ColumnarValue::Array(result), ColumnarValue::Array(expected)) => {
+ assert_eq!(*result, *expected);
+
+ Ok(())
+ }
+ _ => Err("Unexpected result type".into()),
+ }
+ }
+
+ #[test]
+ fn test_unhex_valid() -> Result<(), Box<dyn std::error::Error>> {
+ let mut result = Vec::new();
+
+ unhex("537061726B2053514C", &mut result)?;
+ let result_str = std::str::from_utf8(&result)?;
+ assert_eq!(result_str, "Spark SQL");
+ result.clear();
+
+ unhex("1C", &mut result)?;
+ assert_eq!(result, vec![28]);
+ result.clear();
+
+ unhex("737472696E67", &mut result)?;
+ assert_eq!(result, "string".as_bytes());
+ result.clear();
+
+ unhex("1", &mut result)?;
+ assert_eq!(result, vec![1]);
+ result.clear();
+
+ Ok(())
+ }
+
+ #[test]
+ fn test_odd_length() -> Result<(), Box<dyn std::error::Error>> {
+ let mut result = Vec::new();
+
+ unhex("A1B", &mut result)?;
+ assert_eq!(result, vec![10, 27]);
+ result.clear();
+
+ unhex("0A1B", &mut result)?;
+ assert_eq!(result, vec![10, 27]);
+ result.clear();
+
+ Ok(())
+ }
+
+ #[test]
+ fn test_unhex_empty() {
+ let mut result = Vec::new();
+
+ // Empty hex string
+ unhex("", &mut result).unwrap();
+ assert!(result.is_empty());
+ }
+
+ #[test]
+ fn test_unhex_invalid() {
+ let mut result = Vec::new();
+
+ // Invalid hex strings
+ assert!(unhex("##", &mut result).is_err());
+ assert!(unhex("G123", &mut result).is_err());
+ assert!(unhex("hello", &mut result).is_err());
+ assert!(unhex("\0", &mut result).is_err());
+ }
+}
diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs
index 6a050eb8bb..b5f8201aec 100644
--- a/core/src/execution/datafusion/planner.rs
+++ b/core/src/execution/datafusion/planner.rs
@@ -1326,6 +1326,7 @@ impl PhysicalPlanner {
.iter()
.map(|x| x.data_type(input_schema.as_ref()))
.collect::<Result<Vec<_>, _>>()?;
+
let data_type = match expr.return_type.as_ref().map(to_arrow_datatype) {
Some(t) => t,
None => {
@@ -1333,17 +1334,18 @@ impl PhysicalPlanner {
// scalar function
// Note this assumes the `fun_name` is a defined function in DF. Otherwise, it'll
// throw error.
- let fun = BuiltinScalarFunction::from_str(fun_name);
- if fun.is_err() {
+
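+                // Prefer a DataFusion built-in scalar function; otherwise fall
+                // back to a UDF registered with the session context.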
+ if let Ok(fun) = BuiltinScalarFunction::from_str(fun_name) {
+ fun.return_type(&input_expr_types)?
+ } else {
self.session_ctx
.udf(fun_name)?
.inner()
.return_type(&input_expr_types)?
- } else {
- fun?.return_type(&input_expr_types)?
}
}
};
+
let fun_expr =
create_comet_physical_fun(fun_name, data_type.clone(), &self.session_ctx.state())?;
diff --git a/pom.xml b/pom.xml
index 6f0ebd8865..3dce920263 100644
--- a/pom.xml
+++ b/pom.xml
@@ -88,7 +88,8 @@ under the License.
     <argLine>-ea -Xmx4g -Xss4m ${extraJavaTestArgs}</argLine>
     <additional.3rdparty.source>spark-3.3-plus</additional.3rdparty.source>
     <additional.test.source>spark-3.4</additional.test.source>
-    <shims.source>spark-3.x</shims.source>
+    <shims.majorVerSrc>spark-3.x</shims.majorVerSrc>
+    <shims.minorVerSrc>spark-3.4</shims.minorVerSrc>
@@ -504,6 +505,7 @@ under the License.
         <additional.3rdparty.source>not-needed-yet</additional.3rdparty.source>
         <additional.test.source>not-needed-yet</additional.test.source>
+        <shims.minorVerSrc>spark-3.2</shims.minorVerSrc>
@@ -516,6 +518,7 @@ under the License.
         <parquet.version>1.12.0</parquet.version>
         <additional.3rdparty.source>spark-3.3-plus</additional.3rdparty.source>
         <additional.test.source>not-needed-yet</additional.test.source>
+        <shims.minorVerSrc>spark-3.3</shims.minorVerSrc>
@@ -527,6 +530,7 @@ under the License.
         <parquet.version>1.13.1</parquet.version>
         <additional.3rdparty.source>spark-3.3-plus</additional.3rdparty.source>
         <additional.test.source>spark-3.4</additional.test.source>
+        <shims.minorVerSrc>spark-3.4</shims.minorVerSrc>
diff --git a/spark/pom.xml b/spark/pom.xml
index 7c4524bd62..1052b5a04b 100644
--- a/spark/pom.xml
+++ b/spark/pom.xml
@@ -263,7 +263,8 @@ under the License.
-                  <source>src/main/${shims.source}</source>
+                  <source>src/main/${shims.majorVerSrc}</source>
+                  <source>src/main/${shims.minorVerSrc}</source>
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 86e9f10b90..ad47766493 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -47,12 +47,13 @@ import org.apache.comet.expressions.{CometCast, Compatible, Incompatible, Unsupp
import org.apache.comet.serde.ExprOuterClass.{AggExpr, DataType => ProtoDataType, Expr, ScalarFunc}
import org.apache.comet.serde.ExprOuterClass.DataType.{DataTypeInfo, DecimalInfo, ListInfo, MapInfo, StructInfo}
import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, JoinType, Operator}
+import org.apache.comet.shims.CometExprShim
import org.apache.comet.shims.ShimQueryPlanSerde
/**
* A utility object for query plan and expression serialization.
*/
-object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
+object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim {
def emitWarning(reason: String): Unit = {
logWarning(s"Comet native execution is disabled due to: $reason")
}
@@ -1467,6 +1468,16 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
val optExpr = scalarExprToProto("atan2", leftExpr, rightExpr)
optExprWithInfo(optExpr, expr, left, right)
+ case e: Unhex if !isSpark32 =>
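+ // unhexSerde comes from CometExprShim: it returns the child expression plus a
+ // failOnError literal (false in the 3.2/3.3 shims; Spark 3.4's shim forwards
+ // the expression's own failOnError flag).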
+ val unHex = unhexSerde(e)
+
+ val childExpr = exprToProtoInternal(unHex._1, inputs)
+ val failOnErrorExpr = exprToProtoInternal(unHex._2, inputs)
+
+ val optExpr =
+ scalarExprToProtoWithReturnType("unhex", e.dataType, childExpr, failOnErrorExpr)
+ optExprWithInfo(optExpr, expr, unHex._1)
+
case e @ Ceil(child) =>
val childExpr = exprToProtoInternal(child, inputs)
child.dataType match {
diff --git a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala
new file mode 100644
index 0000000000..0c45a9c2ce
--- /dev/null
+++ b/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.comet.shims
+
+import org.apache.spark.sql.catalyst.expressions._
+
+/**
+ * `CometExprShim` acts as a shim for parsing expressions from different Spark versions.
+ */
+trait CometExprShim {
+ /**
+ * Returns a tuple of expressions for the `unhex` function.
+ */
+ def unhexSerde(unhex: Unhex): (Expression, Expression) = {
+ (unhex.child, Literal(false))
+ }
+}
diff --git a/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala
new file mode 100644
index 0000000000..0c45a9c2ce
--- /dev/null
+++ b/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.comet.shims
+
+import org.apache.spark.sql.catalyst.expressions._
+
+/**
+ * `CometExprShim` acts as a shim for parsing expressions from different Spark versions.
+ */
+trait CometExprShim {
+ /**
+ * Returns a tuple of expressions for the `unhex` function.
+ */
+ def unhexSerde(unhex: Unhex): (Expression, Expression) = {
+ (unhex.child, Literal(false))
+ }
+}
diff --git a/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala
new file mode 100644
index 0000000000..409e1c94b1
--- /dev/null
+++ b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.comet.shims
+
+import org.apache.spark.sql.catalyst.expressions._
+
+/**
+ * `CometExprShim` acts as a shim for parsing expressions from different Spark versions.
+ */
+trait CometExprShim {
+ /**
+ * Returns a tuple of expressions for the `unhex` function.
+ */
+ def unhexSerde(unhex: Unhex): (Expression, Expression) = {
+ (unhex.child, Literal(unhex.failOnError))
+ }
+}
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index eb4429dc6c..28027c5cb5 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -1036,7 +1036,30 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
}
+ test("unhex") {
+ // Comet's unhex includes the fix for https://issues.apache.org/jira/browse/SPARK-40924,
+ // which Spark only shipped in 3.3. Against Spark 3.2, Comet is therefore more correct
+ // than Spark but does not match it, so the test is skipped there. The difference only
+ // affects edge cases: hex inputs whose length is not divisible by 2.
+ assume(!isSpark32, "unhex function has incorrect behavior in 3.2")
+ val table = "unhex_table"
+ withTable(table) {
+ sql(s"create table $table(col string) using parquet")
+
+ sql(s"""INSERT INTO $table VALUES
+ |('537061726B2053514C'),
+ |('737472696E67'),
+ |('\\0'),
+ |(''),
+ |('###'),
+ |('G123'),
+ |('hello'),
+ |('A1B'),
+ |('0A1B')""".stripMargin)
+
+ checkSparkAnswerAndOperator(s"SELECT unhex(col) FROM $table")
+ }
+ }
test("length, reverse, instr, replace, translate") {
Seq(false, true).foreach { dictionary =>
withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {