diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index b41d97520362..54a692f2f3aa 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -980,6 +980,7 @@ pub mod scalar_doc_sections { DOC_SECTION_STRUCT, DOC_SECTION_MAP, DOC_SECTION_HASHING, + DOC_SECTION_UNION, DOC_SECTION_OTHER, ] } @@ -996,6 +997,7 @@ pub mod scalar_doc_sections { DOC_SECTION_STRUCT, DOC_SECTION_MAP, DOC_SECTION_HASHING, + DOC_SECTION_UNION, DOC_SECTION_OTHER, ] } @@ -1070,4 +1072,10 @@ The following regular expression functions are supported:"#, label: "Other Functions", description: None, }; + + pub const DOC_SECTION_UNION: DocSection = DocSection { + include: true, + label: "Union Functions", + description: Some("Functions to work with the union data type, also know as tagged unions, variant types, enums or sum types. Note: Not related to the SQL UNION operator"), + }; } diff --git a/datafusion/functions/src/core/mod.rs b/datafusion/functions/src/core/mod.rs index 76fb4bbe5b47..425ce78decbe 100644 --- a/datafusion/functions/src/core/mod.rs +++ b/datafusion/functions/src/core/mod.rs @@ -34,6 +34,7 @@ pub mod nvl; pub mod nvl2; pub mod planner; pub mod r#struct; +pub mod union_extract; pub mod version; // create UDFs @@ -48,6 +49,7 @@ make_udf_function!(getfield::GetFieldFunc, get_field); make_udf_function!(coalesce::CoalesceFunc, coalesce); make_udf_function!(greatest::GreatestFunc, greatest); make_udf_function!(least::LeastFunc, least); +make_udf_function!(union_extract::UnionExtractFun, union_extract); make_udf_function!(version::VersionFunc, version); pub mod expr_fn { @@ -99,6 +101,11 @@ pub mod expr_fn { pub fn get_field(arg1: Expr, arg2: impl Literal) -> Expr { super::get_field().call(vec![arg1, arg2.lit()]) } + + #[doc = "Returns the value of the field with the given name from the union when it's selected, or NULL otherwise"] + pub fn union_extract(arg1: Expr, arg2: impl Literal) -> Expr { + super::union_extract().call(vec![arg1, arg2.lit()]) + } } /// Returns all DataFusion functions defined in this package @@ -121,6 +128,7 @@ pub fn functions() -> Vec> { coalesce(), greatest(), least(), + union_extract(), version(), r#struct(), ] diff --git a/datafusion/functions/src/core/union_extract.rs b/datafusion/functions/src/core/union_extract.rs new file mode 100644 index 000000000000..d54627f73598 --- /dev/null +++ b/datafusion/functions/src/core/union_extract.rs @@ -0,0 +1,255 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::Array; +use arrow::datatypes::{DataType, FieldRef, UnionFields}; +use datafusion_common::cast::as_union_array; +use datafusion_common::{ + exec_datafusion_err, exec_err, internal_err, Result, ScalarValue, +}; +use datafusion_doc::Documentation; +use datafusion_expr::{ColumnarValue, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs}; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_macros::user_doc; + +#[user_doc( + doc_section(label = "Union Functions"), + description = "Returns the value of the given field in the union when selected, or NULL otherwise.", + syntax_example = "union_extract(union, field_name)", + sql_example = r#"```sql +❯ select union_column, union_extract(union_column, 'a'), union_extract(union_column, 'b') from table_with_union; ++--------------+----------------------------------+----------------------------------+ +| union_column | union_extract(union_column, 'a') | union_extract(union_column, 'b') | ++--------------+----------------------------------+----------------------------------+ +| {a=1} | 1 | | +| {b=3.0} | | 3.0 | +| {a=4} | 4 | | +| {b=} | | | +| {a=} | | | ++--------------+----------------------------------+----------------------------------+ +```"#, + standard_argument(name = "union", prefix = "Union"), + argument( + name = "field_name", + description = "String expression to operate on. Must be a constant." + ) +)] +#[derive(Debug)] +pub struct UnionExtractFun { + signature: Signature, +} + +impl Default for UnionExtractFun { + fn default() -> Self { + Self::new() + } +} + +impl UnionExtractFun { + pub fn new() -> Self { + Self { + signature: Signature::any(2, Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for UnionExtractFun { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "union_extract" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _: &[DataType]) -> Result { + // should be using return_type_from_exprs and not calling the default implementation + internal_err!("union_extract should return type from exprs") + } + + fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { + if args.arg_types.len() != 2 { + return exec_err!( + "union_extract expects 2 arguments, got {} instead", + args.arg_types.len() + ); + } + + let DataType::Union(fields, _) = &args.arg_types[0] else { + return exec_err!( + "union_extract first argument must be a union, got {} instead", + args.arg_types[0] + ); + }; + + let Some(ScalarValue::Utf8(Some(field_name))) = &args.scalar_arguments[1] else { + return exec_err!( + "union_extract second argument must be a non-null string literal, got {} instead", + args.arg_types[1] + ); + }; + + let field = find_field(fields, field_name)?.1; + + Ok(ReturnInfo::new_nullable(field.data_type().clone())) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let args = args.args; + + if args.len() != 2 { + return exec_err!( + "union_extract expects 2 arguments, got {} instead", + args.len() + ); + } + + let target_name = match &args[1] { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(target_name))) => Ok(target_name), + ColumnarValue::Scalar(ScalarValue::Utf8(None)) => exec_err!("union_extract second argument must be a non-null string literal, got a null instead"), + _ => exec_err!("union_extract second argument must be a non-null string literal, got {} instead", &args[1].data_type()), + }; + + match &args[0] { + ColumnarValue::Array(array) => { + let union_array = as_union_array(&array).map_err(|_| { + exec_datafusion_err!( + "union_extract first argument must be a union, got {} instead", + array.data_type() + ) + })?; + + Ok(ColumnarValue::Array( + arrow::compute::kernels::union_extract::union_extract( + union_array, + target_name?, + )?, + )) + } + ColumnarValue::Scalar(ScalarValue::Union(value, fields, _)) => { + let target_name = target_name?; + let (target_type_id, target) = find_field(fields, target_name)?; + + let result = match value { + Some((type_id, value)) if target_type_id == *type_id => { + *value.clone() + } + _ => ScalarValue::try_from(target.data_type())?, + }; + + Ok(ColumnarValue::Scalar(result)) + } + other => exec_err!( + "union_extract first argument must be a union, got {} instead", + other.data_type() + ), + } + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +fn find_field<'a>(fields: &'a UnionFields, name: &str) -> Result<(i8, &'a FieldRef)> { + fields + .iter() + .find(|field| field.1.name() == name) + .ok_or_else(|| exec_datafusion_err!("field {name} not found on union")) +} + +#[cfg(test)] +mod tests { + + use arrow::datatypes::{DataType, Field, UnionFields, UnionMode}; + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; + + use super::UnionExtractFun; + + // when it becomes possible to construct union scalars in SQL, this should go to sqllogictests + #[test] + fn test_scalar_value() -> Result<()> { + let fun = UnionExtractFun::new(); + + let fields = UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new("int", DataType::Int32, false), + ], + ); + + let result = fun.invoke_with_args(ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::Union( + None, + fields.clone(), + UnionMode::Dense, + )), + ColumnarValue::Scalar(ScalarValue::new_utf8("str")), + ], + number_rows: 1, + return_type: &DataType::Utf8, + })?; + + assert_scalar(result, ScalarValue::Utf8(None)); + + let result = fun.invoke_with_args(ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::Union( + Some((3, Box::new(ScalarValue::Int32(Some(42))))), + fields.clone(), + UnionMode::Dense, + )), + ColumnarValue::Scalar(ScalarValue::new_utf8("str")), + ], + number_rows: 1, + return_type: &DataType::Utf8, + })?; + + assert_scalar(result, ScalarValue::Utf8(None)); + + let result = fun.invoke_with_args(ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::Union( + Some((1, Box::new(ScalarValue::new_utf8("42")))), + fields.clone(), + UnionMode::Dense, + )), + ColumnarValue::Scalar(ScalarValue::new_utf8("str")), + ], + number_rows: 1, + return_type: &DataType::Utf8, + })?; + + assert_scalar(result, ScalarValue::new_utf8("42")); + + Ok(()) + } + + fn assert_scalar(value: ColumnarValue, expected: ScalarValue) { + match value { + ColumnarValue::Array(array) => panic!("expected scalar got {array:?}"), + ColumnarValue::Scalar(scalar) => assert_eq!(scalar, expected), + } + } +} diff --git a/datafusion/sqllogictest/src/test_context.rs b/datafusion/sqllogictest/src/test_context.rs index f7c9346a8983..e58f896080db 100644 --- a/datafusion/sqllogictest/src/test_context.rs +++ b/datafusion/sqllogictest/src/test_context.rs @@ -22,10 +22,11 @@ use std::path::Path; use std::sync::Arc; use arrow::array::{ - ArrayRef, BinaryArray, Float64Array, Int32Array, LargeBinaryArray, LargeStringArray, - StringArray, TimestampNanosecondArray, + Array, ArrayRef, BinaryArray, Float64Array, Int32Array, LargeBinaryArray, + LargeStringArray, StringArray, TimestampNanosecondArray, UnionArray, }; -use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit}; +use arrow::buffer::ScalarBuffer; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit, UnionFields}; use arrow::record_batch::RecordBatch; use datafusion::logical_expr::{create_udf, ColumnarValue, Expr, ScalarUDF, Volatility}; use datafusion::physical_plan::ExecutionPlan; @@ -113,6 +114,10 @@ impl TestContext { info!("Registering metadata table tables"); register_metadata_tables(test_ctx.session_ctx()).await; } + "union_function.slt" => { + info!("Registering table with union column"); + register_union_table(test_ctx.session_ctx()) + } _ => { info!("Using default SessionContext"); } @@ -402,3 +407,24 @@ fn create_example_udf() -> ScalarUDF { adder, ) } + +fn register_union_table(ctx: &SessionContext) { + let union = UnionArray::try_new( + UnionFields::new(vec![3], vec![Field::new("int", DataType::Int32, false)]), + ScalarBuffer::from(vec![3, 3]), + None, + vec![Arc::new(Int32Array::from(vec![1, 2]))], + ) + .unwrap(); + + let schema = Schema::new(vec![Field::new( + "union_column", + union.data_type().clone(), + false, + )]); + + let batch = + RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(union)]).unwrap(); + + ctx.register_batch("union_table", batch).unwrap(); +} diff --git a/datafusion/sqllogictest/test_files/union_function.slt b/datafusion/sqllogictest/test_files/union_function.slt new file mode 100644 index 000000000000..9c70b1011f58 --- /dev/null +++ b/datafusion/sqllogictest/test_files/union_function.slt @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +########## +## UNION DataType Tests +########## + +query ?I +select union_column, union_extract(union_column, 'int') from union_table; +---- +{int=1} 1 +{int=2} 2 + +query error DataFusion error: Execution error: field bool not found on union +select union_extract(union_column, 'bool') from union_table; + +query error DataFusion error: Error during planning: 'union_extract' does not support zero arguments +select union_extract() from union_table; + +query error DataFusion error: Error during planning: The function 'union_extract' expected 2 arguments but received 1 +select union_extract(union_column) from union_table; + +query error DataFusion error: Error during planning: The function 'union_extract' expected 2 arguments but received 1 +select union_extract('a') from union_table; + +query error DataFusion error: Execution error: union_extract first argument must be a union, got Utf8 instead +select union_extract('a', union_column) from union_table; + +query error DataFusion error: Execution error: union_extract second argument must be a non\-null string literal, got Int64 instead +select union_extract(union_column, 1) from union_table; + +query error DataFusion error: Error during planning: The function 'union_extract' expected 2 arguments but received 3 +select union_extract(union_column, 'a', 'b') from union_table; diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index b769b8b7bdb0..6ebca7613660 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -4339,6 +4339,40 @@ sha512(expression) +-------------------------------------------+ ``` +## Union Functions + +Functions to work with the union data type, also know as tagged unions, variant types, enums or sum types. Note: Not related to the SQL UNION operator + +- [union_extract](#union_extract) + +### `union_extract` + +Returns the value of the given field in the union when selected, or NULL otherwise. + +``` +union_extract(union, field_name) +``` + +#### Arguments + +- **union**: Union expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **field_name**: String expression to operate on. Must be a constant. + +#### Example + +```sql +❯ select union_column, union_extract(union_column, 'a'), union_extract(union_column, 'b') from table_with_union; ++--------------+----------------------------------+----------------------------------+ +| union_column | union_extract(union_column, 'a') | union_extract(union_column, 'b') | ++--------------+----------------------------------+----------------------------------+ +| {a=1} | 1 | | +| {b=3.0} | | 3.0 | +| {a=4} | 4 | | +| {b=} | | | +| {a=} | | | ++--------------+----------------------------------+----------------------------------+ +``` + ## Other Functions - [arrow_cast](#arrow_cast)