diff --git a/common/pom.xml b/common/pom.xml index 73b1a269b9..71d74ce056 100644 --- a/common/pom.xml +++ b/common/pom.xml @@ -179,7 +179,8 @@ under the License. - src/main/${shims.source} + src/main/${shims.majorVerSrc} + src/main/${shims.minorVerSrc} diff --git a/core/src/execution/datafusion/expressions/scalar_funcs.rs b/core/src/execution/datafusion/expressions/scalar_funcs.rs index 2895937ca0..8c5e1f3916 100644 --- a/core/src/execution/datafusion/expressions/scalar_funcs.rs +++ b/core/src/execution/datafusion/expressions/scalar_funcs.rs @@ -52,6 +52,9 @@ use num::{ }; use unicode_segmentation::UnicodeSegmentation; +mod unhex; +use unhex::spark_unhex; + macro_rules! make_comet_scalar_udf { ($name:expr, $func:ident, $data_type:ident) => {{ let scalar_func = CometScalarFunction::new( @@ -105,6 +108,10 @@ pub fn create_comet_physical_fun( "make_decimal" => { make_comet_scalar_udf!("make_decimal", spark_make_decimal, data_type) } + "unhex" => { + let func = Arc::new(spark_unhex); + make_comet_scalar_udf!("unhex", func, without data_type) + } "decimal_div" => { make_comet_scalar_udf!("decimal_div", spark_decimal_div, data_type) } @@ -123,11 +130,10 @@ pub fn create_comet_physical_fun( make_comet_scalar_udf!(spark_func_name, wrapped_func, without data_type) } _ => { - let fun = BuiltinScalarFunction::from_str(fun_name); - if fun.is_err() { - Ok(ScalarFunctionDefinition::UDF(registry.udf(fun_name)?)) + if let Ok(fun) = BuiltinScalarFunction::from_str(fun_name) { + Ok(ScalarFunctionDefinition::BuiltIn(fun)) } else { - Ok(ScalarFunctionDefinition::BuiltIn(fun?)) + Ok(ScalarFunctionDefinition::UDF(registry.udf(fun_name)?)) } } } diff --git a/core/src/execution/datafusion/expressions/scalar_funcs/unhex.rs b/core/src/execution/datafusion/expressions/scalar_funcs/unhex.rs new file mode 100644 index 0000000000..38d5c04787 --- /dev/null +++ b/core/src/execution/datafusion/expressions/scalar_funcs/unhex.rs @@ -0,0 +1,257 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow_array::OffsetSizeTrait; +use arrow_schema::DataType; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::{cast::as_generic_string_array, exec_err, DataFusionError, ScalarValue}; + +/// Helper function to convert a hex digit to a binary value. +fn unhex_digit(c: u8) -> Result { + match c { + b'0'..=b'9' => Ok(c - b'0'), + b'A'..=b'F' => Ok(10 + c - b'A'), + b'a'..=b'f' => Ok(10 + c - b'a'), + _ => Err(DataFusionError::Execution( + "Input to unhex_digit is not a valid hex digit".to_string(), + )), + } +} + +/// Convert a hex string to binary and store the result in `result`. Returns an error if the input +/// is not a valid hex string. +fn unhex(hex_str: &str, result: &mut Vec) -> Result<(), DataFusionError> { + let bytes = hex_str.as_bytes(); + + let mut i = 0; + + if (bytes.len() & 0x01) != 0 { + let v = unhex_digit(bytes[0])?; + + result.push(v); + i += 1; + } + + while i < bytes.len() { + let first = unhex_digit(bytes[i])?; + let second = unhex_digit(bytes[i + 1])?; + result.push((first << 4) | second); + + i += 2; + } + + Ok(()) +} + +fn spark_unhex_inner( + array: &ColumnarValue, + fail_on_error: bool, +) -> Result { + match array { + ColumnarValue::Array(array) => { + let string_array = as_generic_string_array::(array)?; + + let mut encoded = Vec::new(); + let mut builder = arrow::array::BinaryBuilder::new(); + + for item in string_array.iter() { + if let Some(s) = item { + if unhex(s, &mut encoded).is_ok() { + builder.append_value(encoded.as_slice()); + } else if fail_on_error { + return exec_err!("Input to unhex is not a valid hex string: {s}"); + } else { + builder.append_null(); + } + encoded.clear(); + } else { + builder.append_null(); + } + } + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) + } + ColumnarValue::Scalar(ScalarValue::Utf8(Some(string))) => { + let mut encoded = Vec::new(); + + if unhex(string, &mut encoded).is_ok() { + Ok(ColumnarValue::Scalar(ScalarValue::Binary(Some(encoded)))) + } else if fail_on_error { + exec_err!("Input to unhex is not a valid hex string: {string}") + } else { + Ok(ColumnarValue::Scalar(ScalarValue::Binary(None))) + } + } + ColumnarValue::Scalar(ScalarValue::Utf8(None)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Binary(None))) + } + _ => { + exec_err!( + "The first argument must be a string scalar or array, but got: {:?}", + array + ) + } + } +} + +pub(super) fn spark_unhex(args: &[ColumnarValue]) -> Result { + if args.len() > 2 { + return exec_err!("unhex takes at most 2 arguments, but got: {}", args.len()); + } + + let val_to_unhex = &args[0]; + let fail_on_error = if args.len() == 2 { + match &args[1] { + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))) => *fail_on_error, + _ => { + return exec_err!( + "The second argument must be boolean scalar, but got: {:?}", + args[1] + ); + } + } + } else { + false + }; + + match val_to_unhex.data_type() { + DataType::Utf8 => spark_unhex_inner::(val_to_unhex, fail_on_error), + DataType::LargeUtf8 => spark_unhex_inner::(val_to_unhex, fail_on_error), + other => exec_err!( + "The first argument must be a Utf8 or LargeUtf8: {:?}", + other + ), + } +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use arrow::array::{BinaryBuilder, StringBuilder}; + use arrow_array::make_array; + use arrow_data::ArrayData; + use datafusion::logical_expr::ColumnarValue; + use datafusion_common::ScalarValue; + + use super::unhex; + + #[test] + fn test_spark_unhex_null() -> Result<(), Box> { + let input = ArrayData::new_null(&arrow_schema::DataType::Utf8, 2); + let output = ArrayData::new_null(&arrow_schema::DataType::Binary, 2); + + let input = ColumnarValue::Array(Arc::new(make_array(input))); + let expected = ColumnarValue::Array(Arc::new(make_array(output))); + + let result = super::spark_unhex(&[input])?; + + match (result, expected) { + (ColumnarValue::Array(result), ColumnarValue::Array(expected)) => { + assert_eq!(*result, *expected); + Ok(()) + } + _ => Err("Unexpected result type".into()), + } + } + + #[test] + fn test_partial_error() -> Result<(), Box> { + let mut input = StringBuilder::new(); + + input.append_value("1CGG"); // 1C is ok, but GG is invalid + input.append_value("537061726B2053514C"); // followed by valid + + let input = ColumnarValue::Array(Arc::new(input.finish())); + let fail_on_error = ColumnarValue::Scalar(ScalarValue::Boolean(Some(false))); + + let result = super::spark_unhex(&[input, fail_on_error])?; + + let mut expected = BinaryBuilder::new(); + expected.append_null(); + expected.append_value("Spark SQL".as_bytes()); + + match (result, ColumnarValue::Array(Arc::new(expected.finish()))) { + (ColumnarValue::Array(result), ColumnarValue::Array(expected)) => { + assert_eq!(*result, *expected); + + Ok(()) + } + _ => Err("Unexpected result type".into()), + } + } + + #[test] + fn test_unhex_valid() -> Result<(), Box> { + let mut result = Vec::new(); + + unhex("537061726B2053514C", &mut result)?; + let result_str = std::str::from_utf8(&result)?; + assert_eq!(result_str, "Spark SQL"); + result.clear(); + + unhex("1C", &mut result)?; + assert_eq!(result, vec![28]); + result.clear(); + + unhex("737472696E67", &mut result)?; + assert_eq!(result, "string".as_bytes()); + result.clear(); + + unhex("1", &mut result)?; + assert_eq!(result, vec![1]); + result.clear(); + + Ok(()) + } + + #[test] + fn test_odd_length() -> Result<(), Box> { + let mut result = Vec::new(); + + unhex("A1B", &mut result)?; + assert_eq!(result, vec![10, 27]); + result.clear(); + + unhex("0A1B", &mut result)?; + assert_eq!(result, vec![10, 27]); + result.clear(); + + Ok(()) + } + + #[test] + fn test_unhex_empty() { + let mut result = Vec::new(); + + // Empty hex string + unhex("", &mut result).unwrap(); + assert!(result.is_empty()); + } + + #[test] + fn test_unhex_invalid() { + let mut result = Vec::new(); + + // Invalid hex strings + assert!(unhex("##", &mut result).is_err()); + assert!(unhex("G123", &mut result).is_err()); + assert!(unhex("hello", &mut result).is_err()); + assert!(unhex("\0", &mut result).is_err()); + } +} diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs index 6a050eb8bb..b5f8201aec 100644 --- a/core/src/execution/datafusion/planner.rs +++ b/core/src/execution/datafusion/planner.rs @@ -1326,6 +1326,7 @@ impl PhysicalPlanner { .iter() .map(|x| x.data_type(input_schema.as_ref())) .collect::, _>>()?; + let data_type = match expr.return_type.as_ref().map(to_arrow_datatype) { Some(t) => t, None => { @@ -1333,17 +1334,18 @@ impl PhysicalPlanner { // scalar function // Note this assumes the `fun_name` is a defined function in DF. Otherwise, it'll // throw error. - let fun = BuiltinScalarFunction::from_str(fun_name); - if fun.is_err() { + + if let Ok(fun) = BuiltinScalarFunction::from_str(fun_name) { + fun.return_type(&input_expr_types)? + } else { self.session_ctx .udf(fun_name)? .inner() .return_type(&input_expr_types)? - } else { - fun?.return_type(&input_expr_types)? } } }; + let fun_expr = create_comet_physical_fun(fun_name, data_type.clone(), &self.session_ctx.state())?; diff --git a/pom.xml b/pom.xml index 6f0ebd8865..3dce920263 100644 --- a/pom.xml +++ b/pom.xml @@ -88,7 +88,8 @@ under the License. -ea -Xmx4g -Xss4m ${extraJavaTestArgs} spark-3.3-plus spark-3.4 - spark-3.x + spark-3.x + spark-3.4 @@ -504,6 +505,7 @@ under the License. not-needed-yet not-needed-yet + spark-3.2 @@ -516,6 +518,7 @@ under the License. 1.12.0 spark-3.3-plus not-needed-yet + spark-3.3 @@ -527,6 +530,7 @@ under the License. 1.13.1 spark-3.3-plus spark-3.4 + spark-3.4 diff --git a/spark/pom.xml b/spark/pom.xml index 7c4524bd62..1052b5a04b 100644 --- a/spark/pom.xml +++ b/spark/pom.xml @@ -263,7 +263,8 @@ under the License. - src/main/${shims.source} + src/main/${shims.majorVerSrc} + src/main/${shims.minorVerSrc} diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 86e9f10b90..ad47766493 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -47,12 +47,13 @@ import org.apache.comet.expressions.{CometCast, Compatible, Incompatible, Unsupp import org.apache.comet.serde.ExprOuterClass.{AggExpr, DataType => ProtoDataType, Expr, ScalarFunc} import org.apache.comet.serde.ExprOuterClass.DataType.{DataTypeInfo, DecimalInfo, ListInfo, MapInfo, StructInfo} import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, JoinType, Operator} +import org.apache.comet.shims.CometExprShim import org.apache.comet.shims.ShimQueryPlanSerde /** * An utility object for query plan and expression serialization. */ -object QueryPlanSerde extends Logging with ShimQueryPlanSerde { +object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim { def emitWarning(reason: String): Unit = { logWarning(s"Comet native execution is disabled due to: $reason") } @@ -1467,6 +1468,16 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { val optExpr = scalarExprToProto("atan2", leftExpr, rightExpr) optExprWithInfo(optExpr, expr, left, right) + case e: Unhex if !isSpark32 => + val unHex = unhexSerde(e) + + val childExpr = exprToProtoInternal(unHex._1, inputs) + val failOnErrorExpr = exprToProtoInternal(unHex._2, inputs) + + val optExpr = + scalarExprToProtoWithReturnType("unhex", e.dataType, childExpr, failOnErrorExpr) + optExprWithInfo(optExpr, expr, unHex._1) + case e @ Ceil(child) => val childExpr = exprToProtoInternal(child, inputs) child.dataType match { diff --git a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala new file mode 100644 index 0000000000..0c45a9c2ce --- /dev/null +++ b/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.comet.shims + +import org.apache.spark.sql.catalyst.expressions._ + +/** + * `CometExprShim` acts as a shim for for parsing expressions from different Spark versions. + */ +trait CometExprShim { + /** + * Returns a tuple of expressions for the `unhex` function. + */ + def unhexSerde(unhex: Unhex): (Expression, Expression) = { + (unhex.child, Literal(false)) + } +} diff --git a/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala new file mode 100644 index 0000000000..0c45a9c2ce --- /dev/null +++ b/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.comet.shims + +import org.apache.spark.sql.catalyst.expressions._ + +/** + * `CometExprShim` acts as a shim for for parsing expressions from different Spark versions. + */ +trait CometExprShim { + /** + * Returns a tuple of expressions for the `unhex` function. + */ + def unhexSerde(unhex: Unhex): (Expression, Expression) = { + (unhex.child, Literal(false)) + } +} diff --git a/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala new file mode 100644 index 0000000000..409e1c94b1 --- /dev/null +++ b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.comet.shims + +import org.apache.spark.sql.catalyst.expressions._ + +/** + * `CometExprShim` acts as a shim for for parsing expressions from different Spark versions. + */ +trait CometExprShim { + /** + * Returns a tuple of expressions for the `unhex` function. + */ + def unhexSerde(unhex: Unhex): (Expression, Expression) = { + (unhex.child, Literal(unhex.failOnError)) + } +} diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index eb4429dc6c..28027c5cb5 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -1036,7 +1036,30 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } } + test("unhex") { + // When running against Spark 3.2, we include a bug fix for https://issues.apache.org/jira/browse/SPARK-40924 that + // was added in Spark 3.3, so although Comet's behavior is more correct when running against Spark 3.2, it is not + // the same (and this only applies to edge cases with hex inputs with lengths that are not divisible by 2) + assume(!isSpark32, "unhex function has incorrect behavior in 3.2") + val table = "unhex_table" + withTable(table) { + sql(s"create table $table(col string) using parquet") + + sql(s"""INSERT INTO $table VALUES + |('537061726B2053514C'), + |('737472696E67'), + |('\\0'), + |(''), + |('###'), + |('G123'), + |('hello'), + |('A1B'), + |('0A1B')""".stripMargin) + + checkSparkAnswerAndOperator(s"SELECT unhex(col) FROM $table") + } + } test("length, reverse, instr, replace, translate") { Seq(false, true).foreach { dictionary => withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {