feat: Implement Spark unhex (apache#342)
tshauck authored and Steve Vaughan committed May 9, 2024
1 parent 6acbeba commit 86582b9
Showing 11 changed files with 416 additions and 12 deletions.
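For context: Spark's `unhex(expr)` decodes a string of hexadecimal characters into the binary value it represents; for example, `unhex('537061726B2053514C')` returns the bytes of 'Spark SQL'. Malformed input yields NULL, or an error when the expression is created with `failOnError`, which the native function added below accepts as an optional second boolean argument.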
3 changes: 2 additions & 1 deletion common/pom.xml
@@ -179,7 +179,8 @@ under the License.
           </goals>
           <configuration>
             <sources>
-              <source>src/main/${shims.source}</source>
+              <source>src/main/${shims.majorVerSrc}</source>
+              <source>src/main/${shims.minorVerSrc}</source>
             </sources>
           </configuration>
         </execution>
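This replaces the single `${shims.source}` source root with two: `${shims.majorVerSrc}` for shims shared by all Spark 3.x versions and `${shims.minorVerSrc}` for shims specific to one minor version. The same change is applied to spark/pom.xml below, and the per-profile values of `shims.minorVerSrc` are defined in the root pom.xml.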
14 changes: 10 additions & 4 deletions core/src/execution/datafusion/expressions/scalar_funcs.rs
@@ -52,6 +52,9 @@ use num::{
 };
 use unicode_segmentation::UnicodeSegmentation;
 
+mod unhex;
+use unhex::spark_unhex;
+
 macro_rules! make_comet_scalar_udf {
     ($name:expr, $func:ident, $data_type:ident) => {{
         let scalar_func = CometScalarFunction::new(
@@ -105,6 +108,10 @@ pub fn create_comet_physical_fun(
         "make_decimal" => {
            make_comet_scalar_udf!("make_decimal", spark_make_decimal, data_type)
         }
+        "unhex" => {
+            let func = Arc::new(spark_unhex);
+            make_comet_scalar_udf!("unhex", func, without data_type)
+        }
         "decimal_div" => {
            make_comet_scalar_udf!("decimal_div", spark_decimal_div, data_type)
         }
@@ -123,11 +130,10 @@ pub fn create_comet_physical_fun(
             make_comet_scalar_udf!(spark_func_name, wrapped_func, without data_type)
         }
         _ => {
-            let fun = BuiltinScalarFunction::from_str(fun_name);
-            if fun.is_err() {
-                Ok(ScalarFunctionDefinition::UDF(registry.udf(fun_name)?))
+            if let Ok(fun) = BuiltinScalarFunction::from_str(fun_name) {
+                Ok(ScalarFunctionDefinition::BuiltIn(fun))
             } else {
-                Ok(ScalarFunctionDefinition::BuiltIn(fun?))
+                Ok(ScalarFunctionDefinition::UDF(registry.udf(fun_name)?))
             }
         }
     }
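The rewritten fallback arm above resolves a function name against DataFusion's built-in scalar functions first and only falls back to the session's UDF registry when that fails. A minimal, self-contained sketch of this lookup order, using stand-in types rather than DataFusion's actual API:

```rust
use std::collections::HashMap;
use std::str::FromStr;

// Stand-ins for BuiltinScalarFunction and ScalarFunctionDefinition (illustration only).
#[derive(Debug, PartialEq)]
enum Builtin {
    Sqrt,
}

impl FromStr for Builtin {
    type Err = ();
    fn from_str(s: &str) -> Result<Self, ()> {
        match s {
            "sqrt" => Ok(Builtin::Sqrt),
            _ => Err(()),
        }
    }
}

#[derive(Debug, PartialEq)]
enum Definition {
    BuiltIn(Builtin),
    Udf(String),
}

// Mirrors the `if let Ok(fun)` arm: built-ins win, unknown names fall back to the registry.
fn resolve(name: &str, registry: &HashMap<String, String>) -> Result<Definition, String> {
    if let Ok(fun) = Builtin::from_str(name) {
        Ok(Definition::BuiltIn(fun))
    } else {
        registry
            .get(name)
            .map(|udf| Definition::Udf(udf.clone()))
            .ok_or_else(|| format!("no UDF named {name}"))
    }
}

fn main() {
    let mut registry = HashMap::new();
    registry.insert("unhex".to_string(), "spark_unhex".to_string());

    assert_eq!(resolve("sqrt", &registry), Ok(Definition::BuiltIn(Builtin::Sqrt)));
    assert_eq!(
        resolve("unhex", &registry),
        Ok(Definition::Udf("spark_unhex".to_string()))
    );
    assert!(resolve("no_such_fn", &registry).is_err());
}
```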
257 changes: 257 additions & 0 deletions core/src/execution/datafusion/expressions/scalar_funcs/unhex.rs
@@ -0,0 +1,257 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use arrow_array::OffsetSizeTrait;
use arrow_schema::DataType;
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::{cast::as_generic_string_array, exec_err, DataFusionError, ScalarValue};

/// Helper function to convert a hex digit to a binary value.
fn unhex_digit(c: u8) -> Result<u8, DataFusionError> {
    match c {
        b'0'..=b'9' => Ok(c - b'0'),
        b'A'..=b'F' => Ok(10 + c - b'A'),
        b'a'..=b'f' => Ok(10 + c - b'a'),
        _ => Err(DataFusionError::Execution(
            "Input to unhex_digit is not a valid hex digit".to_string(),
        )),
    }
}

/// Convert a hex string to binary and store the result in `result`. Returns an error if the input
/// is not a valid hex string.
fn unhex(hex_str: &str, result: &mut Vec<u8>) -> Result<(), DataFusionError> {
    let bytes = hex_str.as_bytes();

    let mut i = 0;

    if (bytes.len() & 0x01) != 0 {
        let v = unhex_digit(bytes[0])?;

        result.push(v);
        i += 1;
    }

    while i < bytes.len() {
        let first = unhex_digit(bytes[i])?;
        let second = unhex_digit(bytes[i + 1])?;
        result.push((first << 4) | second);

        i += 2;
    }

    Ok(())
}

fn spark_unhex_inner<T: OffsetSizeTrait>(
    array: &ColumnarValue,
    fail_on_error: bool,
) -> Result<ColumnarValue, DataFusionError> {
    match array {
        ColumnarValue::Array(array) => {
            let string_array = as_generic_string_array::<T>(array)?;

            let mut encoded = Vec::new();
            let mut builder = arrow::array::BinaryBuilder::new();

            for item in string_array.iter() {
                if let Some(s) = item {
                    if unhex(s, &mut encoded).is_ok() {
                        builder.append_value(encoded.as_slice());
                    } else if fail_on_error {
                        return exec_err!("Input to unhex is not a valid hex string: {s}");
                    } else {
                        builder.append_null();
                    }
                    encoded.clear();
                } else {
                    builder.append_null();
                }
            }
            Ok(ColumnarValue::Array(Arc::new(builder.finish())))
        }
        ColumnarValue::Scalar(ScalarValue::Utf8(Some(string))) => {
            let mut encoded = Vec::new();

            if unhex(string, &mut encoded).is_ok() {
                Ok(ColumnarValue::Scalar(ScalarValue::Binary(Some(encoded))))
            } else if fail_on_error {
                exec_err!("Input to unhex is not a valid hex string: {string}")
            } else {
                Ok(ColumnarValue::Scalar(ScalarValue::Binary(None)))
            }
        }
        ColumnarValue::Scalar(ScalarValue::Utf8(None)) => {
            Ok(ColumnarValue::Scalar(ScalarValue::Binary(None)))
        }
        _ => {
            exec_err!(
                "The first argument must be a string scalar or array, but got: {:?}",
                array
            )
        }
    }
}

pub(super) fn spark_unhex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
    if args.len() > 2 {
        return exec_err!("unhex takes at most 2 arguments, but got: {}", args.len());
    }

    let val_to_unhex = &args[0];
    let fail_on_error = if args.len() == 2 {
        match &args[1] {
            ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))) => *fail_on_error,
            _ => {
                return exec_err!(
                    "The second argument must be boolean scalar, but got: {:?}",
                    args[1]
                );
            }
        }
    } else {
        false
    };

    match val_to_unhex.data_type() {
        DataType::Utf8 => spark_unhex_inner::<i32>(val_to_unhex, fail_on_error),
        DataType::LargeUtf8 => spark_unhex_inner::<i64>(val_to_unhex, fail_on_error),
        other => exec_err!(
            "The first argument must be a Utf8 or LargeUtf8: {:?}",
            other
        ),
    }
}

#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow::array::{BinaryBuilder, StringBuilder};
    use arrow_array::make_array;
    use arrow_data::ArrayData;
    use datafusion::logical_expr::ColumnarValue;
    use datafusion_common::ScalarValue;

    use super::unhex;

    #[test]
    fn test_spark_unhex_null() -> Result<(), Box<dyn std::error::Error>> {
        let input = ArrayData::new_null(&arrow_schema::DataType::Utf8, 2);
        let output = ArrayData::new_null(&arrow_schema::DataType::Binary, 2);

        let input = ColumnarValue::Array(Arc::new(make_array(input)));
        let expected = ColumnarValue::Array(Arc::new(make_array(output)));

        let result = super::spark_unhex(&[input])?;

        match (result, expected) {
            (ColumnarValue::Array(result), ColumnarValue::Array(expected)) => {
                assert_eq!(*result, *expected);
                Ok(())
            }
            _ => Err("Unexpected result type".into()),
        }
    }

    #[test]
    fn test_partial_error() -> Result<(), Box<dyn std::error::Error>> {
        let mut input = StringBuilder::new();

        input.append_value("1CGG"); // 1C is ok, but GG is invalid
        input.append_value("537061726B2053514C"); // followed by valid

        let input = ColumnarValue::Array(Arc::new(input.finish()));
        let fail_on_error = ColumnarValue::Scalar(ScalarValue::Boolean(Some(false)));

        let result = super::spark_unhex(&[input, fail_on_error])?;

        let mut expected = BinaryBuilder::new();
        expected.append_null();
        expected.append_value("Spark SQL".as_bytes());

        match (result, ColumnarValue::Array(Arc::new(expected.finish()))) {
            (ColumnarValue::Array(result), ColumnarValue::Array(expected)) => {
                assert_eq!(*result, *expected);

                Ok(())
            }
            _ => Err("Unexpected result type".into()),
        }
    }

    #[test]
    fn test_unhex_valid() -> Result<(), Box<dyn std::error::Error>> {
        let mut result = Vec::new();

        unhex("537061726B2053514C", &mut result)?;
        let result_str = std::str::from_utf8(&result)?;
        assert_eq!(result_str, "Spark SQL");
        result.clear();

        unhex("1C", &mut result)?;
        assert_eq!(result, vec![28]);
        result.clear();

        unhex("737472696E67", &mut result)?;
        assert_eq!(result, "string".as_bytes());
        result.clear();

        unhex("1", &mut result)?;
        assert_eq!(result, vec![1]);
        result.clear();

        Ok(())
    }

    #[test]
    fn test_odd_length() -> Result<(), Box<dyn std::error::Error>> {
        let mut result = Vec::new();

        unhex("A1B", &mut result)?;
        assert_eq!(result, vec![10, 27]);
        result.clear();

        unhex("0A1B", &mut result)?;
        assert_eq!(result, vec![10, 27]);
        result.clear();

        Ok(())
    }

    #[test]
    fn test_unhex_empty() {
        let mut result = Vec::new();

        // Empty hex string
        unhex("", &mut result).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn test_unhex_invalid() {
        let mut result = Vec::new();

        // Invalid hex strings
        assert!(unhex("##", &mut result).is_err());
        assert!(unhex("G123", &mut result).is_err());
        assert!(unhex("hello", &mut result).is_err());
        assert!(unhex("\0", &mut result).is_err());
    }
}
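The decoding rules above are easy to check in isolation: an odd-length input consumes its first digit as a standalone byte, and every remaining pair becomes `(hi << 4) | lo`. A self-contained sketch that mirrors `unhex_digit`/`unhex` outside the crate (local re-implementations, not the module's API):

```rust
fn unhex_digit(c: u8) -> Result<u8, String> {
    match c {
        b'0'..=b'9' => Ok(c - b'0'),
        b'A'..=b'F' => Ok(10 + c - b'A'),
        b'a'..=b'f' => Ok(10 + c - b'a'),
        _ => Err(format!("not a hex digit: {}", c as char)),
    }
}

// Same pairing rules as the `unhex` helper above, returning the buffer directly.
fn unhex(s: &str) -> Result<Vec<u8>, String> {
    let bytes = s.as_bytes();
    let mut out = Vec::with_capacity(bytes.len() / 2 + 1);
    let mut i = 0;
    if bytes.len() % 2 == 1 {
        out.push(unhex_digit(bytes[0])?); // odd length: the leading digit stands alone
        i = 1;
    }
    while i < bytes.len() {
        out.push((unhex_digit(bytes[i])? << 4) | unhex_digit(bytes[i + 1])?);
        i += 2;
    }
    Ok(out)
}

fn main() {
    assert_eq!(unhex("A1B").unwrap(), vec![0x0A, 0x1B]); // odd length
    assert_eq!(unhex("537061726B2053514C").unwrap(), b"Spark SQL".to_vec());
    assert!(unhex("GG").is_err()); // invalid input: NULL in Spark unless fail_on_error is set
}
```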
10 changes: 6 additions & 4 deletions core/src/execution/datafusion/planner.rs
@@ -1326,24 +1326,26 @@ impl PhysicalPlanner {
                     .iter()
                     .map(|x| x.data_type(input_schema.as_ref()))
                     .collect::<Result<Vec<_>, _>>()?;
+
                 let data_type = match expr.return_type.as_ref().map(to_arrow_datatype) {
                     Some(t) => t,
                     None => {
                         // If no data type is provided from Spark, we'll use DF's return type from the
                         // scalar function
                         // Note this assumes the `fun_name` is a defined function in DF. Otherwise, it'll
                         // throw error.
-                        let fun = BuiltinScalarFunction::from_str(fun_name);
-                        if fun.is_err() {
+
+                        if let Ok(fun) = BuiltinScalarFunction::from_str(fun_name) {
+                            fun.return_type(&input_expr_types)?
+                        } else {
                             self.session_ctx
                                 .udf(fun_name)?
                                 .inner()
                                 .return_type(&input_expr_types)?
-                        } else {
-                            fun?.return_type(&input_expr_types)?
                         }
                     }
                 };
+
                 let fun_expr =
                     create_comet_physical_fun(fun_name, data_type.clone(), &self.session_ctx.state())?;
 
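The hunk above keeps the return-type precedence intact while simplifying the control flow: a type serialized by Spark always wins, and only when it is absent does the planner ask DataFusion, trying the built-in function before the UDF registry. A simplified sketch of that precedence, with stand-in types (not the planner's real signatures):

```rust
// Stand-in for Arrow's DataType, for illustration only.
#[derive(Clone, Debug, PartialEq)]
enum DataType {
    Binary,
    Utf8,
}

fn resolve_return_type(
    spark_type: Option<DataType>,
    builtin: Option<DataType>,
    udf: Option<DataType>,
) -> Result<DataType, String> {
    match spark_type {
        // Prefer the type Spark serialized into the plan.
        Some(t) => Ok(t),
        // Otherwise fall back to DataFusion: built-in first, then a registered UDF.
        None => builtin.or(udf).ok_or_else(|| "unknown function".to_string()),
    }
}

fn main() {
    // A Spark-provided type always wins.
    assert_eq!(
        resolve_return_type(Some(DataType::Binary), Some(DataType::Utf8), None),
        Ok(DataType::Binary)
    );
    // With no Spark type and no built-in match, the UDF's declared type is used.
    assert_eq!(
        resolve_return_type(None, None, Some(DataType::Binary)),
        Ok(DataType::Binary)
    );
}
```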
6 changes: 5 additions & 1 deletion pom.xml
@@ -88,7 +88,8 @@ under the License.
     <argLine>-ea -Xmx4g -Xss4m ${extraJavaTestArgs}</argLine>
     <additional.3_3.test.source>spark-3.3-plus</additional.3_3.test.source>
     <additional.3_4.test.source>spark-3.4</additional.3_4.test.source>
-    <shims.source>spark-3.x</shims.source>
+    <shims.majorVerSrc>spark-3.x</shims.majorVerSrc>
+    <shims.minorVerSrc>spark-3.4</shims.minorVerSrc>
   </properties>
 
   <dependencyManagement>
@@ -504,6 +505,7 @@ under the License.
         <!-- we don't add special test suits for spark-3.2, so a not existed dir is specified-->
         <additional.3_3.test.source>not-needed-yet</additional.3_3.test.source>
         <additional.3_4.test.source>not-needed-yet</additional.3_4.test.source>
+        <shims.minorVerSrc>spark-3.2</shims.minorVerSrc>
       </properties>
     </profile>
 
@@ -516,6 +518,7 @@ under the License.
         <parquet.version>1.12.0</parquet.version>
         <additional.3_3.test.source>spark-3.3-plus</additional.3_3.test.source>
         <additional.3_4.test.source>not-needed-yet</additional.3_4.test.source>
+        <shims.minorVerSrc>spark-3.3</shims.minorVerSrc>
       </properties>
     </profile>
 
@@ -527,6 +530,7 @@ under the License.
         <parquet.version>1.13.1</parquet.version>
         <additional.3_3.test.source>spark-3.3-plus</additional.3_3.test.source>
         <additional.3_4.test.source>spark-3.4</additional.3_4.test.source>
+        <shims.minorVerSrc>spark-3.4</shims.minorVerSrc>
       </properties>
     </profile>
3 changes: 2 additions & 1 deletion spark/pom.xml
@@ -263,7 +263,8 @@ under the License.
           </goals>
           <configuration>
             <sources>
-              <source>src/main/${shims.source}</source>
+              <source>src/main/${shims.majorVerSrc}</source>
+              <source>src/main/${shims.minorVerSrc}</source>
             </sources>
           </configuration>
         </execution>