From 3faf76843a81a84d886963bce5370a344bf27bcb Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Thu, 22 Aug 2024 17:49:30 +0000 Subject: [PATCH 01/15] feat: add union_extract scalar function --- datafusion/functions/Cargo.toml | 5 + datafusion/functions/benches/union_extract.rs | 1121 +++++++++++++++++ datafusion/functions/src/core/mod.rs | 8 + .../functions/src/core/union_extract.rs | 722 +++++++++++ datafusion/sqllogictest/src/test_context.rs | 904 ++++++++++++- .../test_files/union_datatype.slt | 270 ++++ .../source/user-guide/sql/scalar_functions.md | 55 + 7 files changed, 3082 insertions(+), 3 deletions(-) create mode 100644 datafusion/functions/benches/union_extract.rs create mode 100644 datafusion/functions/src/core/union_extract.rs create mode 100644 datafusion/sqllogictest/test_files/union_datatype.slt diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 9ef020b772f0..5ce49763bb8e 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -161,3 +161,8 @@ required-features = ["string_expressions"] harness = false name = "random" required-features = ["math_expressions"] + +[[bench]] +harness = false +name = "union_extract" +required-features = ["core_expressions"] diff --git a/datafusion/functions/benches/union_extract.rs b/datafusion/functions/benches/union_extract.rs new file mode 100644 index 000000000000..1e748ff5a0c0 --- /dev/null +++ b/datafusion/functions/benches/union_extract.rs @@ -0,0 +1,1121 @@ +#[macro_use] +extern crate criterion; + +use crate::criterion::Criterion; +use arrow::{ + array::{ + Array, BooleanArray, Int32Array, Int8Array, NullArray, StringArray, UnionArray, + }, + datatypes::{DataType, Field, Int32Type, Int8Type, UnionFields, UnionMode}, + util::bench_util::{ + create_boolean_array, create_primitive_array, create_string_array, + }, +}; +use arrow_buffer::ScalarBuffer; +use criterion::black_box; +use datafusion_common::ScalarValue; +use 
datafusion_expr::{ColumnarValue, ScalarUDFImpl}; +use datafusion_functions::core::union_extract::{ + eq_scalar_generic, is_sequential_generic, UnionExtractFun, +}; +use itertools::repeat_n; +use rand::random; +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + let union_extract = UnionExtractFun::new(); + + c.bench_function("union_extract case 1.1 sparse single field", |b| { + let union = UnionArray::try_new( + UnionFields::new(vec![1], vec![Field::new("str", DataType::Utf8, false)]), //single field + ScalarBuffer::from(vec![1; 2048]), //non empty union + None, //sparse + vec![Arc::new(create_string_array::(2048, 0.0))], //non null target + ) + .unwrap(); + + let args = [ + ColumnarValue::Array(Arc::new(union)), + ColumnarValue::Scalar(ScalarValue::new_utf8("str")), + ]; + + b.iter(|| { + union_extract.invoke(&args).unwrap(); + }) + }); + + c.bench_function("union_extract case 1.2 sparse empty union", |b| { + let union = UnionArray::try_new( + // multiple fields + UnionFields::new( + vec![1, 2], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new("str2", DataType::Utf8, false), + ], + ), + ScalarBuffer::from(vec![]), // empty union + None, //sparse + vec![ + Arc::new(StringArray::new_null(0)), + Arc::new(StringArray::new_null(0)), + ], + ) + .unwrap(); + + let args = [ + ColumnarValue::Array(Arc::new(union)), + ColumnarValue::Scalar(ScalarValue::new_utf8("str")), + ]; + + b.iter(|| { + union_extract.invoke(&args).unwrap(); + }) + }); + + c.bench_function("union_extract case 1.3a sparse child null", |b| { + let union = UnionArray::try_new( + // multiple fields + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new("null", DataType::Null, true), + ], + ), + ScalarBuffer::from(vec![1; 2048]), // non empty union + None, //sparse + vec![ + Arc::new(StringArray::new_null(2048)), // null target + Arc::new(NullArray::new(2048)), + ], + ) + .unwrap(); + + let args = [ + 
ColumnarValue::Array(Arc::new(union)), + ColumnarValue::Scalar(ScalarValue::new_utf8("null")), + ]; + + b.iter(|| { + union_extract.invoke(&args).unwrap(); + }) + }); + + c.bench_function("union_extract case 1.3b sparse child null", |b| { + let union = UnionArray::try_new( + // multiple fields + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new("int", DataType::Int32, false), + ], + ), + ScalarBuffer::from(vec![1; 2048]), // non empty union + None, //sparse + vec![ + Arc::new(StringArray::new_null(2048)), // null target + Arc::new(create_primitive_array::(2048, 0.0)), + ], + ) + .unwrap(); + + let args = [ + ColumnarValue::Array(Arc::new(union)), + ColumnarValue::Scalar(ScalarValue::new_utf8("str")), + ]; + + b.iter(|| { + union_extract.invoke(&args).unwrap(); + }) + }); + + c.bench_function("union_extract case 2 sparse all types match", |b| { + let union = UnionArray::try_new( + // multiple fields + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new("int", DataType::Int32, false), + ], + ), + ScalarBuffer::from(vec![1; 2048]), //all types match & non empty union + None, //sparse + vec![ + Arc::new(create_string_array::(2048, 0.0)), //non null target + Arc::new(Int32Array::new_null(2048)), + ], + ) + .unwrap(); + + let args = [ + ColumnarValue::Array(Arc::new(union)), + ColumnarValue::Scalar(ScalarValue::new_utf8("str")), + ]; + + b.iter(|| { + union_extract.invoke(&args).unwrap(); + }) + }); + + c.bench_function( + "union_extract case 3.1 none selected target can contain null mask", + |b| { + let union = UnionArray::try_new( + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new("int", DataType::Int32, false), + ], + ), + ScalarBuffer::from(vec![3; 2048]), //none selected + None, + vec![ + Arc::new(create_string_array::(2048, 0.5)), + Arc::new(Int32Array::new_null(2048)), + ], + ) + .unwrap(); + + let args = [ + 
ColumnarValue::Array(Arc::new(union)), + ColumnarValue::Scalar(ScalarValue::new_utf8("str")), + ]; + + b.iter(|| { + union_extract.invoke(&args).unwrap(); + }) + }, + ); + + c.bench_function( + "union_extract case 3.2 none matches sparse cant contain null mask", + |b| { + let target_fields = + UnionFields::new([10], [Field::new("bool", DataType::Boolean, true)]); + + let union = UnionArray::try_new( + // multiple fields + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, true), + Field::new( + "union", + DataType::Union(target_fields.clone(), UnionMode::Sparse), + false, + ), + ], + ), + ScalarBuffer::from_iter(repeat_n(1, 2048)), //none matches + None, //sparse + vec![ + Arc::new(create_string_array::(2048, 0.5)), + Arc::new( + UnionArray::try_new( + target_fields, + ScalarBuffer::from(vec![10; 2048]), + None, + vec![Arc::new(BooleanArray::from(vec![true; 2048]))], + ) + .unwrap(), + ), + ], + ) + .unwrap(); + + let args = [ + ColumnarValue::Array(Arc::new(union)), + ColumnarValue::Scalar(ScalarValue::new_utf8("union")), + ]; + + b.iter(|| { + union_extract.invoke(&args).unwrap(); + }) + }, + ); + + c.bench_function( + "union_extract case 4.1.1 sparse some matches target with nulls", + |b| { + let union = UnionArray::try_new( + //multiple fields + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new("int", DataType::Int32, false), + ], + ), + ScalarBuffer::from_iter(repeat_n(1, 2047).chain([3])), //multiple types + None, //sparse + vec![ + Arc::new(create_string_array::(2048, 0.5)), //target with some nulls, but not all + Arc::new(Int32Array::new_null(2048)), + ], + ) + .unwrap(); + + let args = [ + ColumnarValue::Array(Arc::new(union)), + ColumnarValue::Scalar(ScalarValue::new_utf8("str")), + ]; + + b.iter(|| { + union_extract.invoke(&args).unwrap(); + }) + }, + ); + + c.bench_function( + "union_extract case 4.1.2 sparse some matches target without nulls", + |b| { + let union = 
UnionArray::try_new( + //multiple fields + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new("int", DataType::Int32, false), + ], + ), + ScalarBuffer::from_iter(repeat_n(1, 2047).chain([3])), //multiple types + None, //sparse + vec![ + Arc::new(create_string_array::(2048, 0.0)), //target without nulls + Arc::new(Int32Array::new_null(2048)), + ], + ) + .unwrap(); + + let args = [ + ColumnarValue::Array(Arc::new(union)), + ColumnarValue::Scalar(ScalarValue::new_utf8("str")), + ]; + + b.iter(|| { + union_extract.invoke(&args).unwrap(); + }) + }, + ); + + c.bench_function( + "union_extract case 4.2 some matches sparse cant contain null mask", + |b| { + let target_fields = + UnionFields::new([10], [Field::new("bool", DataType::Boolean, true)]); + + let union = UnionArray::try_new( + // multiple fields + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new( + "union", + DataType::Union(target_fields.clone(), UnionMode::Sparse), + false, + ), + ], + ), + ScalarBuffer::from_iter(repeat_n([1, 3], 1024).flatten()), //some matches + None, //sparse + vec![ + Arc::new(NullArray::new(2048)), //null target + Arc::new( + UnionArray::try_new( + target_fields, + ScalarBuffer::from(vec![10; 2048]), + None, + vec![Arc::new(BooleanArray::from(vec![true; 2048]))], + ) + .unwrap(), + ), + ], + ) + .unwrap(); + + let args = [ + ColumnarValue::Array(Arc::new(union)), + ColumnarValue::Scalar(ScalarValue::new_utf8("union")), + ]; + + b.iter(|| { + union_extract.invoke(&args).unwrap(); + }) + }, + ); + + c.bench_function( + "union_extract case 1.1 dense empty union empty target", + |b| { + let union = UnionArray::try_new( + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new("int", DataType::Int32, false), + ], + ), + ScalarBuffer::from(vec![]), //empty union + Some(ScalarBuffer::from(vec![])), //dense + vec![ + Arc::new(StringArray::new_null(0)), //empty 
target + Arc::new(Int32Array::new_null(0)), + ], + ) + .unwrap(); + + let args = [ + ColumnarValue::Array(Arc::new(union)), + ColumnarValue::Scalar(ScalarValue::new_utf8("str")), + ]; + + b.iter(|| { + union_extract.invoke(&args).unwrap(); + }) + }, + ); + + c.bench_function( + "union_extract case 1.2 dense empty union non-empty target", + |b| { + let union = UnionArray::try_new( + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new("int", DataType::Int32, false), + ], + ), + ScalarBuffer::from(vec![]), // empty union + Some(ScalarBuffer::from(vec![])), // dense + vec![ + Arc::new(StringArray::from(vec!["a1", "s2"])), // non empty target + Arc::new(Int32Array::new_null(0)), + ], + ) + .unwrap(); + + let args = [ + ColumnarValue::Array(Arc::new(union)), + ColumnarValue::Scalar(ScalarValue::new_utf8("str")), + ]; + + b.iter(|| { + union_extract.invoke(&args).unwrap(); + }) + }, + ); + + c.bench_function( + "union_extract case 2 dense non empty union, empty target", + |b| { + let union = UnionArray::try_new( + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new("int", DataType::Int32, false), + ], + ), + ScalarBuffer::from(vec![3, 3]), // non empty union + Some(ScalarBuffer::from(vec![0, 1])), // dense + vec![ + Arc::new(StringArray::new_null(0)), // empty target + Arc::new(Int32Array::new_null(2)), + ], + ) + .unwrap(); + + let args = [ + ColumnarValue::Array(Arc::new(union)), + ColumnarValue::Scalar(ScalarValue::new_utf8("str")), + ]; + + b.iter(|| { + union_extract.invoke(&args).unwrap(); + }) + }, + ); + + c.bench_function( + "union_extract case 3.1 dense null target len smaller", + |b| { + let union = UnionArray::try_new( + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new("int", DataType::Int32, false), + ], + ), + ScalarBuffer::from(vec![1; 2048]), + Some(ScalarBuffer::from(vec![0; 2048])), // dense + vec![ + 
Arc::new(StringArray::new_null(1)), // null & len smaller + Arc::new(Int32Array::new_null(64)), + ], + ) + .unwrap(); + + let args = [ + ColumnarValue::Array(Arc::new(union)), + ColumnarValue::Scalar(ScalarValue::new_utf8("str")), + ]; + + b.iter(|| { + union_extract.invoke(&args).unwrap(); + }) + }, + ); + + c.bench_function("union_extract case 3.2 dense null target len equal", |b| { + let union = UnionArray::try_new( + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new("int", DataType::Int32, false), + ], + ), + ScalarBuffer::from(vec![1; 2048]), + Some(ScalarBuffer::from_iter(0..2048)), // dense + vec![ + Arc::new(StringArray::new_null(2048)), // null & same len as parent + Arc::new(Int32Array::new_null(64)), + ], + ) + .unwrap(); + + let args = [ + ColumnarValue::Array(Arc::new(union)), + ColumnarValue::Scalar(ScalarValue::new_utf8("str")), + ]; + + b.iter(|| { + union_extract.invoke(&args).unwrap(); + }) + }); + + c.bench_function("union_extract case 3.3 dense null target len bigger", |b| { + let union = UnionArray::try_new( + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new("int", DataType::Int32, false), + ], + ), + ScalarBuffer::from(vec![1; 2048]), + Some(ScalarBuffer::from(vec![0; 2048])), + vec![ + Arc::new(StringArray::new_null(4096)), // null, bigger than parent + Arc::new(Int32Array::new_null(64)), + ], + ) + .unwrap(); + + let args = [ + ColumnarValue::Array(Arc::new(union)), + ColumnarValue::Scalar(ScalarValue::new_utf8("str")), + ]; + + b.iter(|| { + union_extract.invoke(&args).unwrap(); + }) + }); + + c.bench_function( + "union_extract case 4.1A dense single field sequential offsets equal lens", + |b| { + let union = UnionArray::try_new( + //single field + UnionFields::new(vec![3], vec![Field::new("int", DataType::Int8, false)]), + ScalarBuffer::from(vec![3; 2048]), + Some(ScalarBuffer::from_iter(0..2048)), //sequential offsets + 
vec![Arc::new(create_primitive_array::(2048, 0.0))], //same len as parent, not null + ) + .unwrap(); + + let args = [ + ColumnarValue::Array(Arc::new(union)), + ColumnarValue::Scalar(ScalarValue::new_utf8("int")), + ]; + + b.iter(|| { + union_extract.invoke(&args).unwrap(); + }) + }, + ); + + c.bench_function( + "union_extract case 4.2A dense single field sequential offsets bigger len", + |b| { + let union = UnionArray::try_new( + // single field + UnionFields::new(vec![3], vec![Field::new("int", DataType::Int8, false)]), + ScalarBuffer::from(vec![3; 2048]), + Some(ScalarBuffer::from_iter(0..2048)), //sequential offsets + vec![Arc::new(create_primitive_array::(4096, 0.0))], //bigger than parent, not null + ) + .unwrap(); + + let args = [ + ColumnarValue::Array(Arc::new(union)), + ColumnarValue::Scalar(ScalarValue::new_utf8("int")), + ]; + + b.iter(|| { + union_extract.invoke(&args).unwrap(); + }) + }, + ); + + c.bench_function( + "union_extract case 4.3A dense single field non-sequential offsets", + |b| { + let union = UnionArray::try_new( + // single field + UnionFields::new(vec![3], vec![Field::new("int", DataType::Int8, false)]), + ScalarBuffer::from(vec![3; 2048]), + Some(ScalarBuffer::from_iter((0..2046).chain([2047, 2047]))), // non sequential offsets, avoid fast paths + vec![Arc::new(create_primitive_array::(2048, 0.0))], // not null + ) + .unwrap(); + + let args = [ + ColumnarValue::Array(Arc::new(union)), + ColumnarValue::Scalar(ScalarValue::new_utf8("int")), + ]; + + b.iter(|| { + union_extract.invoke(&args).unwrap(); + }) + }, + ); + + c.bench_function( + "union_extract case 4.1B dense empty siblings sequential offsets equal len", + |b| { + let union = UnionArray::try_new( + // multiple fields + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new("int", DataType::Int8, false), + ], + ), + ScalarBuffer::from(vec![3; 2048]), // all types must match + Some(ScalarBuffer::from_iter(0..2048)), // sequential offsets 
+ vec![ + Arc::new(StringArray::new_null(0)), // empty sibling + Arc::new(create_primitive_array::(2048, 0.0)), // same len as parent, not null + ], + ) + .unwrap(); + + let args = [ + ColumnarValue::Array(Arc::new(union)), + ColumnarValue::Scalar(ScalarValue::new_utf8("int")), + ]; + + b.iter(|| { + union_extract.invoke(&args).unwrap(); + }) + }, + ); + + c.bench_function( + "union_extract case 4.2B dense empty siblings sequential offsets bigger target", + |b| { + let union = UnionArray::try_new( + // multiple fields + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new("int", DataType::Int8, false), + ], + ), + ScalarBuffer::from(vec![3; 2048]), // all types match + Some(ScalarBuffer::from_iter(0..2048)), // sequential offsets + vec![ + Arc::new(StringArray::new_null(0)), // empty sibling + Arc::new(create_primitive_array::(4096, 0.0)), // target is bigger than parent, not null + ], + ) + .unwrap(); + + let args = [ + ColumnarValue::Array(Arc::new(union)), + ColumnarValue::Scalar(ScalarValue::new_utf8("int")), + ]; + + b.iter(|| { + union_extract.invoke(&args).unwrap(); + }) + }, + ); + + c.bench_function( + "union_extract case 4.3B dense empty sibling non-sequential offsets", + |b| { + let union = UnionArray::try_new( + // multiple fields + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new("int", DataType::Int32, false), + ], + ), + ScalarBuffer::from(vec![3; 2048]), // all types must match + Some(ScalarBuffer::from_iter((0..2046).chain([2047, 2047]))), // non sequential offsets, avoid fast paths + vec![ + Arc::new(StringArray::new_null(0)), // empty sibling + Arc::new(create_primitive_array::(2048, 0.0)), // not null + ], + ) + .unwrap(); + + let args = [ + ColumnarValue::Array(Arc::new(union)), + ColumnarValue::Scalar(ScalarValue::new_utf8("int")), + ]; + + b.iter(|| { + union_extract.invoke(&args).unwrap(); + }) + }, + ); + + c.bench_function( + "union_extract case 
4.1C dense all types match sequential offsets equal lens", + |b| { + let union = UnionArray::try_new( + // multiple fields + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new("int", DataType::Int8, false), + ], + ), + ScalarBuffer::from(vec![3; 2048]), // all types match + Some(ScalarBuffer::from_iter(0..2048)), // sequential offsets + vec![ + Arc::new(StringArray::new_null(1)), // non empty sibling + Arc::new(create_primitive_array::(2048, 0.0)), // same len as parent, not null + ], + ) + .unwrap(); + + let args = [ + ColumnarValue::Array(Arc::new(union)), + ColumnarValue::Scalar(ScalarValue::new_utf8("int")), + ]; + + b.iter(|| { + union_extract.invoke(&args).unwrap(); + }) + }, + ); + + c.bench_function( + "union_extract case 4.2C dense all types match sequential offsets bigger len", + |b| { + let union = UnionArray::try_new( + // multiple fields + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new("int", DataType::Int8, false), + ], + ), + ScalarBuffer::from(vec![3; 2048]), // all types match + Some(ScalarBuffer::from_iter(0..2048)), // sequential offsets + vec![ + Arc::new(StringArray::new_null(1)), // non empty sibling + Arc::new(create_primitive_array::(4096, 0.0)), // bigger than parent union, not null + ], + ) + .unwrap(); + + let args = [ + ColumnarValue::Array(Arc::new(union)), + ColumnarValue::Scalar(ScalarValue::new_utf8("int")), + ]; + + b.iter(|| { + union_extract.invoke(&args).unwrap(); + }) + }, + ); + + c.bench_function( + "union_extract case 4.3C dense all types match non-sequential offsets", + |b| { + let union = UnionArray::try_new( + // multiple fields + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new("int", DataType::Int8, false), + ], + ), + ScalarBuffer::from(vec![3; 2048]), // all types match + Some(ScalarBuffer::from_iter((0..2046).chain([2047, 2047]))), //non sequential, avoid fast paths + vec![ + 
Arc::new(StringArray::new_null(1)), // non empty sibling + Arc::new(create_primitive_array::(2048, 0.0)), // not null + ], + ) + .unwrap(); + + let args = [ + ColumnarValue::Array(Arc::new(union)), + ColumnarValue::Scalar(ScalarValue::new_utf8("int")), + ]; + + b.iter(|| { + union_extract.invoke(&args).unwrap(); + }) + }, + ); + + c.bench_function("union_extract case 5.1a dense none match less len", |b| { + let union = UnionArray::try_new( + // multiple fields + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new("int", DataType::Int32, false), + ], + ), + ScalarBuffer::from(vec![1; 2048]), //none match + Some(ScalarBuffer::from_iter(0..2048)), //dense + vec![ + Arc::new(create_string_array::(2048, 0.0)), // non empty + Arc::new(create_primitive_array::(1024, 0.0)), //less len, not null + ], + ) + .unwrap(); + + let args = [ + ColumnarValue::Array(Arc::new(union)), + ColumnarValue::Scalar(ScalarValue::new_utf8("int")), + ]; + + b.iter(|| { + union_extract.invoke(&args).unwrap(); + }) + }); + + c.bench_function( + "union_extract case 5.1b dense none match cant contain null mask", + |b| { + let union_target = UnionArray::try_new( + UnionFields::new([1], vec![Field::new("a", DataType::Boolean, true)]), + vec![1; 2048].into(), + None, + vec![Arc::new(create_boolean_array(2048, 0.0, 0.0))], + ) + .unwrap(); + + let parent_union = UnionArray::try_new( + // multiple fields + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new("union", union_target.data_type().clone(), false), + ], + ), + ScalarBuffer::from(vec![1; 2048]), //none match + Some(ScalarBuffer::from_iter(0..2048)), //dense + vec![ + Arc::new(create_string_array::(2048, 0.0)), // non empty + Arc::new(union_target), + ], + ) + .unwrap(); + + let args = [ + ColumnarValue::Array(Arc::new(parent_union)), + ColumnarValue::Scalar(ScalarValue::new_utf8("union")), + ]; + + b.iter(|| { + union_extract.invoke(&args).unwrap(); + }) + 
}, + ); + + c.bench_function("union_extract case 5.2 dense none match equal len", |b| { + let union = UnionArray::try_new( + // multiple fields + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new("int", DataType::Int32, false), + ], + ), + ScalarBuffer::from(vec![1; 2048]), //none match: every row selects "str" (type id 1) while the extracted field is "int" (type id 3) + Some(ScalarBuffer::from_iter(0..2048)), //dense + vec![ + Arc::new(create_string_array::(2048, 0.0)), // non empty + Arc::new(create_primitive_array::(2048, 0.0)), //equal len, not null + ], + ) + .unwrap(); + + let args = [ + ColumnarValue::Array(Arc::new(union)), + ColumnarValue::Scalar(ScalarValue::new_utf8("int")), + ]; + + b.iter(|| { + union_extract.invoke(&args).unwrap(); + }) + }); + + c.bench_function("union_extract case 5.3 dense none match greater len", |b| { + let union = UnionArray::try_new( + // multiple fields + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new("int", DataType::Int32, false), + ], + ), + ScalarBuffer::from(vec![1; 2048]), //none match: every row selects "str" (type id 1) while the extracted field is "int" (type id 3) + Some(ScalarBuffer::from_iter(0..2048)), //dense + vec![ + Arc::new(create_string_array::(2048, 0.0)), // non empty + Arc::new(create_primitive_array::(2049, 0.0)), //greater len, not null + ], + ) + .unwrap(); + + let args = [ + ColumnarValue::Array(Arc::new(union)), + ColumnarValue::Scalar(ScalarValue::new_utf8("int")), + ]; + + b.iter(|| { + union_extract.invoke(&args).unwrap(); + }) + }); + + c.bench_function("union_extract case 6 some match", |b| { + let union = UnionArray::try_new( + // multiple fields + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new("int", DataType::Int32, false), + ], + ), + ScalarBuffer::from_iter(repeat_n([1, 3], 1024).flatten()), //some matches but not all + Some(ScalarBuffer::from_iter( + std::iter::zip(1024..2048, 0..1024).flat_map(|(a, b)| [a, b]), + )), + vec![ + Arc::new(create_string_array::(2048, 0.0)), // sibling is not empty + 
Arc::new(create_primitive_array::(1024, 0.0)), //not null + ], + ) + .unwrap(); + + let args = [ + ColumnarValue::Array(Arc::new(union.clone())), + ColumnarValue::Scalar(ScalarValue::new_utf8("int")), + ]; + + b.iter(|| { + union_extract.invoke(&args).unwrap(); + }) + }); + + { + let mut is_sequential_group = c.benchmark_group("offsets"); + + let start = random::() as i32; + let offsets = (start..start + 4096).collect::>(); + + //compare performance to simpler alternatives + + is_sequential_group.bench_function("offsets sequential windows all", |b| { + b.iter(|| { + black_box(offsets.windows(2).all(|window| window[0] + 1 == window[1])); + }) + }); + + is_sequential_group.bench_function("offsets sequential windows fold &&", |b| { + b.iter(|| { + black_box( + offsets + .windows(2) + .fold(true, |b, w| b && (w[0] + 1 == w[1])), + ) + }) + }); + + is_sequential_group.bench_function("offsets sequential windows fold &", |b| { + b.iter(|| { + black_box(offsets.windows(2).fold(true, |b, w| b & (w[0] + 1 == w[1]))) + }) + }); + + is_sequential_group.bench_function("offsets sequential all", |b| { + b.iter(|| { + black_box( + offsets + .iter() + .copied() + .enumerate() + .all(|(i, v)| v == offsets[0] + i as i32), + ) + }) + }); + + is_sequential_group.bench_function("offsets sequential fold &&", |b| { + b.iter(|| { + black_box( + offsets + .iter() + .copied() + .enumerate() + .fold(true, |b, (i, v)| b && (v == offsets[0] + i as i32)), + ) + }) + }); + + is_sequential_group.bench_function("offsets sequential fold &", |b| { + b.iter(|| { + black_box( + offsets + .iter() + .copied() + .enumerate() + .fold(true, |b, (i, v)| b & (v == offsets[0] + i as i32)), + ) + }) + }); + + macro_rules! 
bench_sequential { + ($n:literal) => { + is_sequential_group + .bench_function(&format!("offsets sequential chunk {}", $n), |b| { + b.iter(|| black_box(is_sequential_generic::<$n>(&offsets))) + }); + }; + } + + bench_sequential!(8); + bench_sequential!(16); + bench_sequential!(32); + bench_sequential!(64); + bench_sequential!(128); + bench_sequential!(256); + bench_sequential!(512); + bench_sequential!(1024); + bench_sequential!(2048); + bench_sequential!(4096); + + is_sequential_group.finish(); + } + + { + let mut type_ids_eq = c.benchmark_group("type_ids_eq"); + + let type_id = random::(); + let type_ids = vec![type_id; 65536]; + + //compare performance to simpler alternatives + + type_ids_eq.bench_function("type_ids equal all", |b| { + b.iter(|| { + type_ids + .iter() + .copied() + .all(|value_type_id| value_type_id == type_id) + }) + }); + + type_ids_eq.bench_function("type_ids equal fold &&", |b| { + b.iter(|| type_ids.iter().fold(true, |b, v| b && (*v == type_id))) + }); + + type_ids_eq.bench_function("type_ids equal fold &", |b| { + b.iter(|| type_ids.iter().fold(true, |b, v| b & (*v == type_id))) + }); + + type_ids_eq.bench_function("type_ids equal compute::eq", |b| { + let type_ids_array = Int8Array::new(type_ids.clone().into(), None); + + b.iter(|| { + let eq = arrow::compute::kernels::cmp::eq( + &type_ids_array, + &Int8Array::new_scalar(black_box(type_id)), + ) + .unwrap(); + + eq.true_count() == type_ids.len() + }) + }); + + macro_rules! 
bench_type_ids_eq { + ($n:literal) => { + type_ids_eq.bench_function(&format!("type_ids equal true {}", $n), |b| { + b.iter(|| eq_scalar_generic::<$n>(&type_ids, type_ids[0])) + }); + + type_ids_eq + .bench_function(&format!("type_ids equal false {}", $n), |b| { + b.iter(|| eq_scalar_generic::<$n>(&type_ids, type_ids[0] + 1)) + }); + + type_ids_eq.bench_function(&format!("type_ids worst case {}", $n), |b| { + let mut type_ids = type_ids.clone(); + + type_ids[65535] = !type_ids[65535]; + + b.iter(|| eq_scalar_generic::<$n>(&type_ids, type_ids[0])) + }); + + type_ids_eq.bench_function(&format!("type_ids best case {}", $n), |b| { + let mut type_ids = type_ids.clone(); + + type_ids[$n - 1] += 1; + + b.iter(|| eq_scalar_generic::<$n>(&type_ids, type_ids[0])) + }); + }; + } + + bench_type_ids_eq!(16); + bench_type_ids_eq!(32); + bench_type_ids_eq!(64); + bench_type_ids_eq!(128); + bench_type_ids_eq!(256); + bench_type_ids_eq!(512); + bench_type_ids_eq!(1024); + bench_type_ids_eq!(2048); + bench_type_ids_eq!(4096); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/src/core/mod.rs b/datafusion/functions/src/core/mod.rs index af340930eabc..8e5b68b39199 100644 --- a/datafusion/functions/src/core/mod.rs +++ b/datafusion/functions/src/core/mod.rs @@ -31,6 +31,7 @@ pub mod nvl; pub mod nvl2; pub mod planner; pub mod r#struct; +pub mod union_extract; // create UDFs make_udf_function!(arrow_cast::ArrowCastFunc, ARROW_CAST, arrow_cast); @@ -42,6 +43,7 @@ make_udf_function!(r#struct::StructFunc, STRUCT, r#struct); make_udf_function!(named_struct::NamedStructFunc, NAMED_STRUCT, named_struct); make_udf_function!(getfield::GetFieldFunc, GET_FIELD, get_field); make_udf_function!(coalesce::CoalesceFunc, COALESCE, coalesce); +make_udf_function!(union_extract::UnionExtractFun, UNION_EXTRACT, union_extract); pub mod expr_fn { use datafusion_expr::{Expr, Literal}; @@ -84,6 +86,11 @@ pub mod expr_fn { pub fn get_field(arg1: 
Expr, arg2: impl Literal) -> Expr { super::get_field().call(vec![arg1, arg2.lit()]) } + + #[doc = "Returns the value of the field with the given name from the union when it's selected, or NULL otherwise"] + pub fn union_extract(arg1: Expr, arg2: impl Literal) -> Expr { + super::union_extract().call(vec![arg1, arg2.lit()]) + } } /// Returns all DataFusion functions defined in this package @@ -104,5 +111,6 @@ pub fn functions() -> Vec> { // calls to `get_field` get_field(), coalesce(), + union_extract(), ] } diff --git a/datafusion/functions/src/core/union_extract.rs b/datafusion/functions/src/core/union_extract.rs new file mode 100644 index 000000000000..433348b0e6fd --- /dev/null +++ b/datafusion/functions/src/core/union_extract.rs @@ -0,0 +1,722 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::cmp::Ordering; +use std::sync::Arc; + +use arrow::array::{ + layout, make_array, new_empty_array, new_null_array, Array, ArrayRef, BooleanArray, + Int32Array, Scalar, UnionArray, +}; +use arrow::compute::take; +use arrow::datatypes::{DataType, FieldRef, UnionFields, UnionMode}; + +use arrow::buffer::{BooleanBuffer, MutableBuffer, NullBuffer, ScalarBuffer}; +use arrow::util::bit_util; +use datafusion_common::cast::as_union_array; +use datafusion_common::{ + exec_datafusion_err, exec_err, internal_err, ExprSchema, Result, ScalarValue, +}; +use datafusion_expr::{ColumnarValue, Expr}; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; + +#[derive(Debug)] +pub struct UnionExtractFun { + signature: Signature, +} + +impl Default for UnionExtractFun { + fn default() -> Self { + Self::new() + } +} + +impl UnionExtractFun { + pub fn new() -> Self { + Self { + signature: Signature::any(2, Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for UnionExtractFun { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "union_extract" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _: &[DataType]) -> Result { + // should be using return_type_from_exprs and not calling the default implementation + internal_err!("union_extract should return type from exprs") + } + + fn return_type_from_exprs( + &self, + args: &[Expr], + _: &dyn ExprSchema, + arg_types: &[DataType], + ) -> Result { + if args.len() != 2 { + return exec_err!( + "union_extract expects 2 arguments, got {} instead", + args.len() + ); + } + + let fields = if let DataType::Union(fields, _) = &arg_types[0] { + fields + } else { + return exec_err!( + "union_extract first argument must be a union, got {} instead", + arg_types[0] + ); + }; + + let field_name = if let Expr::Literal(ScalarValue::Utf8(Some(field_name))) = + &args[1] + { + field_name + } else { + return exec_err!( + "union_extract second argument must be a 
non-null string literal, got {} instead", + arg_types[1] + ); + }; + + let field = find_field(fields, field_name)?.1; + + Ok(field.data_type().clone()) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + if args.len() != 2 { + return exec_err!( + "union_extract expects 2 arguments, got {} instead", + args.len() + ); + } + + let union = &args[0]; + + let target_name = match &args[1] { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(target_name))) => Ok(target_name), + ColumnarValue::Scalar(ScalarValue::Utf8(None)) => exec_err!("union_extract second argument must be a non-null string literal, got a null instead"), + _ => exec_err!("union_extract second argument must be a non-null string literal, got {} instead", &args[1].data_type()), + }; + + match union { + ColumnarValue::Array(array) => { + let union_array = as_union_array(&array).map_err(|_| { + exec_datafusion_err!( + "union_extract first argument must be a union, got {} instead", + array.data_type() + ) + })?; + + let (fields, mode) = match union_array.data_type() { + DataType::Union(fields, mode) => (fields, mode), + _ => unreachable!(), + }; + + let target_type_id = find_field(fields, target_name?)?.0; + + match mode { + UnionMode::Sparse => { + Ok(extract_sparse(union_array, fields, target_type_id)?) + } + UnionMode::Dense => { + Ok(extract_dense(union_array, fields, target_type_id)?) 
+ } + } + } + ColumnarValue::Scalar(ScalarValue::Union(value, fields, _)) => { + let target_name = target_name?; + let (target_type_id, target) = find_field(fields, target_name)?; + + let result = match value { + Some((type_id, value)) if target_type_id == *type_id => { + *value.clone() + } + _ => ScalarValue::try_from(target.data_type())?, + }; + + Ok(ColumnarValue::Scalar(result)) + } + other => exec_err!( + "union_extract first argument must be a union, got {} instead", + other.data_type() + ), + } + } +} + +fn find_field<'a>(fields: &'a UnionFields, name: &str) -> Result<(i8, &'a FieldRef)> { + fields + .iter() + .find(|field| field.1.name() == name) + .ok_or_else(|| exec_datafusion_err!("field {name} not found on union")) +} + +fn extract_sparse( + union_array: &UnionArray, + fields: &UnionFields, + target_type_id: i8, +) -> Result { + let target = union_array.child(target_type_id); + + if fields.len() == 1 // case 1.1: if there is a single field, all type ids are the same, and since union doesn't have a null mask, the result array is exactly the same as it only child + || union_array.is_empty() // case 1.2: sparse union length and childrens length must match, if the union is empty, so is any children + || target.null_count() == target.len() || target.data_type().is_null() + // case 1.3: if all values of the target children are null, regardless of selected type ids, the result will also be completely null + { + Ok(ColumnarValue::Array(Arc::clone(target))) + } else { + match eq_scalar(union_array.type_ids(), target_type_id) { + // case 2: all type ids equals our target, and since unions doesn't have a null mask, the result array is exactly the same as our target + BoolValue::Scalar(true) => Ok(ColumnarValue::Array(Arc::clone(target))), + // case 3: none type_id matches our target, the result is a null array + BoolValue::Scalar(false) => { + if layout(target.data_type()).can_contain_null_mask { + // case 3.1: target array can contain a null mask + //SAFETY: The 
only change to the array data is the addition of a null mask, and if the target data type can contain a null mask was just checked above + let data = unsafe { + target + .into_data() + .into_builder() + .nulls(Some(NullBuffer::new_null(target.len()))) + .build_unchecked() + }; + + Ok(ColumnarValue::Array(make_array(data))) + } else { + // case 3.2: target can't contain a null mask + Ok(new_null_columnar_value(target.data_type(), target.len())) + } + } + // case 4: some but not all type_id matches our target + BoolValue::Buffer(selected) => { + if layout(target.data_type()).can_contain_null_mask { + // case 4.1: target array can contain a null mask + let nulls = match target.nulls().filter(|n| n.null_count() > 0) { + // case 4.1.1: our target child has nulls and types other than our target are selected, union the masks + // the case where n.null_count() == n.len() is cheaply handled at case 1.3 + Some(nulls) => &selected & nulls.inner(), + // case 4.1.2: target child has no nulls, but types other than our target are selected, use the selected mask as a null mask + None => selected, + }; + + //SAFETY: The only change to the array data is the addition of a null mask, and if the target data type can contain a null mask was just checked above + let data = unsafe { + assert_eq!(nulls.len(), target.len()); + + target + .into_data() + .into_builder() + .nulls(Some(nulls.into())) + .build_unchecked() + }; + + Ok(ColumnarValue::Array(make_array(data))) + } else { + // case 4.2: target can't containt a null mask, zip the values that match with a null value + Ok(ColumnarValue::Array(arrow::compute::kernels::zip::zip( + &BooleanArray::new(selected, None), + target, + &Scalar::new(new_null_array(target.data_type(), 1)), + )?)) + } + } + } + } +} + +fn extract_dense( + union_array: &UnionArray, + fields: &UnionFields, + target_type_id: i8, +) -> Result { + let target = union_array.child(target_type_id); + let offsets = union_array.offsets().unwrap(); + + if union_array.is_empty() 
{ + // case 1: the union is empty + if target.is_empty() { + // case 1.1: the target is also empty, do a cheap Arc::clone instead of allocating a new empty array + Ok(ColumnarValue::Array(Arc::clone(target))) + } else { + // case 1.2: the target is not empty, allocate a new empty array + Ok(ColumnarValue::Array(new_empty_array(target.data_type()))) + } + } else if target.is_empty() { + // case 2: the union is not empty but the target is, which implies that none type_id points to it. The result is a null array + Ok(new_null_columnar_value( + target.data_type(), + union_array.len(), + )) + } else if target.null_count() == target.len() || target.data_type().is_null() { + // case 3: since all values on our target are null, regardless of selected type ids and offsets, the result is a null array + match target.len().cmp(&union_array.len()) { + // case 3.1: since the target is smaller than the union, allocate a new correclty sized null array + Ordering::Less => Ok(new_null_columnar_value( + target.data_type(), + union_array.len(), + )), + // case 3.2: target equals the union len, return it direcly + Ordering::Equal => Ok(ColumnarValue::Array(Arc::clone(target))), + // case 3.3: target len is bigger than the union len, slice it + Ordering::Greater => { + Ok(ColumnarValue::Array(target.slice(0, union_array.len()))) + } + } + } else if fields.len() == 1 // case A: since there's a single field, our target, every type id must matches our target + || fields + .iter() + .filter(|(field_type_id, _)| *field_type_id != target_type_id) + .all(|(sibling_type_id, _)| union_array.child(sibling_type_id).is_empty()) + // case B: since siblings are empty, every type id must matches our target + { + // case 4: every type id matches our target + Ok(ColumnarValue::Array(extract_dense_all_selected( + union_array, + target, + offsets, + )?)) + } else { + match eq_scalar(union_array.type_ids(), target_type_id) { + // case 4C: all type ids matches our target. 
+ // Non empty sibling without any selected value may happen after slicing the parent union, + // since only type_ids and offsets are sliced, not the children + BoolValue::Scalar(true) => Ok(ColumnarValue::Array( + extract_dense_all_selected(union_array, target, offsets)?, + )), + BoolValue::Scalar(false) => { + // case 5: none type_id matches our target, so the result array will be completely null + // Non empty target without any selected value may happen after slicing the parent union, + // since only type_ids and offsets are sliced, not the children + match (target.len().cmp(&union_array.len()), layout(target.data_type()).can_contain_null_mask) { + (Ordering::Less, _) // case 5.1A: our target is smaller than the parent union, allocate a new correclty sized null array + | (_, false) => { // case 5.1B: target array can't contain a null mask + Ok(new_null_columnar_value(target.data_type(), union_array.len())) + } + // case 5.2: target and parent union lengths are equal, and the target can contain a null mask, let's set it to a all-null null-buffer + (Ordering::Equal, true) => { + //SAFETY: The only change to the array data is the addition of a null mask, and if the target data type can contain a null mask was just checked above + let data = unsafe { + target + .into_data() + .into_builder() + .nulls(Some(NullBuffer::new_null(union_array.len()))) + .build_unchecked() + }; + + Ok(ColumnarValue::Array(make_array(data))) + } + // case 5.3: target is bigger than it's parent union and can contain a null mask, let's slice it, and set it's nulls to a all-null null-buffer + (Ordering::Greater, true) => { + //SAFETY: The only change to the array data is the addition of a null mask, and if the target data type can contain a null mask was just checked above + let data = unsafe { + target + .into_data() + .slice(0, union_array.len()) + .into_builder() + .nulls(Some(NullBuffer::new_null(union_array.len()))) + .build_unchecked() + }; + + 
Ok(ColumnarValue::Array(make_array(data))) + } + } + } + BoolValue::Buffer(selected) => { + //case 6: some type_ids matches our target, but not all. For selected values, take the value pointed by the offset. For unselected, take a valid null + Ok(ColumnarValue::Array(take( + target, + &Int32Array::new(offsets.clone(), Some(selected.into())), + None, + )?)) + } + } + } +} + +fn extract_dense_all_selected( + union_array: &UnionArray, + target: &Arc, + offsets: &ScalarBuffer, +) -> Result { + let sequential = + target.len() - offsets[0] as usize >= union_array.len() && is_sequential(offsets); + + if sequential && target.len() == union_array.len() { + // case 1: all offsets are sequential and both lengths match, return the array directly + Ok(Arc::clone(target)) + } else if sequential && target.len() > union_array.len() { + // case 2: All offsets are sequential, but our target is bigger than our union, slice it, starting at the first offset + Ok(target.slice(offsets[0] as usize, union_array.len())) + } else { + // case 3: Since offsets are not sequential, take them from the child to a new sequential and correcly sized array + let indices = Int32Array::try_new(offsets.clone(), None)?; + + Ok(take(target, &indices, None)?) + } +} + +const EQ_SCALAR_CHUNK_SIZE: usize = 512; + +#[doc(hidden)] +#[derive(Debug, PartialEq)] +pub enum BoolValue { + Scalar(bool), + Buffer(BooleanBuffer), +} + +fn eq_scalar(type_ids: &[i8], target: i8) -> BoolValue { + eq_scalar_generic::(type_ids, target) +} + +// This is like MutableBuffer::collect_bool(type_ids.len(), |i| type_ids[i] == target) with fast paths for all true or all false values. 
+#[doc(hidden)] +pub fn eq_scalar_generic(type_ids: &[i8], target: i8) -> BoolValue { + fn count_sequence( + type_ids: &[i8], + mut f: impl FnMut(i8) -> bool, + ) -> usize { + type_ids + .chunks(N) + .take_while(|chunk| chunk.iter().copied().fold(true, |b, v| b & f(v))) + .map(|chunk| chunk.len()) + .sum() + } + + let true_bits = count_sequence::(type_ids, |v| v == target); + + let (set_bits, val) = if true_bits == type_ids.len() { + return BoolValue::Scalar(true); + } else if true_bits == 0 { + let false_bits = count_sequence::(type_ids, |v| v != target); + + if false_bits == type_ids.len() { + return BoolValue::Scalar(false); + } else { + (false_bits, false) + } + } else { + (true_bits, true) + }; + + // restrict to chunk boundaries + let set_bits = set_bits - set_bits % 64; + + let mut buffer = MutableBuffer::new(bit_util::ceil(type_ids.len(), 64) * 8) + .with_bitset(set_bits / 8, val); + + buffer.extend(type_ids[set_bits..].chunks(64).map(|chunk| { + chunk + .iter() + .copied() + .enumerate() + .fold(0, |packed, (bit_idx, v)| { + packed | ((v == target) as u64) << bit_idx + }) + })); + + buffer.truncate(bit_util::ceil(type_ids.len(), 8)); + + BoolValue::Buffer(BooleanBuffer::new(buffer.into(), 0, type_ids.len())) +} + +const IS_SEQUENTIAL_CHUNK_SIZE: usize = 64; + +fn is_sequential(offsets: &[i32]) -> bool { + is_sequential_generic::(offsets) +} + +#[doc(hidden)] +pub fn is_sequential_generic(offsets: &[i32]) -> bool { + if offsets.is_empty() { + return true; + } + + // the most common form of non sequential offsets is when sequential nulls reuses the same value, + // pointed by the same offset, while valid values offsets increases one by one + // this also checks if the last chunk/remainder is sequential + if offsets[0] + offsets.len() as i32 - 1 != offsets[offsets.len() - 1] { + return false; + } + + let chunks = offsets.chunks_exact(N); + + let remainder = chunks.remainder(); + + chunks.enumerate().all(|(i, chunk)| { + let chunk_array = <&[i32; 
N]>::try_from(chunk).unwrap(); + + //checks if values within chunk are sequential + chunk_array + .iter() + .copied() + .enumerate() + .fold(true, |b, (i, o)| b & (o == chunk_array[0] + i as i32)) + && offsets[0] + (i * N) as i32 == chunk_array[0] //checks if chunk is sequential relative to the first offset + }) && remainder + .iter() + .copied() + .enumerate() + .fold(true, |b, (i, o)| b & (o == remainder[0] + i as i32)) //if the remainder is sequential relative to the first offset is checked at the start of the function +} + +fn new_null_columnar_value(data_type: &DataType, len: usize) -> ColumnarValue { + match ScalarValue::try_from(data_type) { + Ok(null_scalar) => ColumnarValue::Scalar(null_scalar), + Err(_) => ColumnarValue::Array(new_null_array(data_type, len)), + } +} + +#[cfg(test)] +mod tests { + use crate::core::union_extract::{ + eq_scalar_generic, is_sequential_generic, new_null_columnar_value, BoolValue, + }; + + use std::sync::Arc; + + use arrow::array::{new_null_array, Array, Int8Array}; + use arrow::buffer::BooleanBuffer; + use arrow::datatypes::{DataType, Field, UnionFields, UnionMode}; + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + use super::UnionExtractFun; + + // when it becomes possible to construct union scalars in SQL, this should go to sqllogictests + #[test] + fn test_scalar_value() -> Result<()> { + let fun = UnionExtractFun::new(); + + let fields = UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new("int", DataType::Int32, false), + ], + ); + + fn assert_scalar(value: ColumnarValue, expected: ScalarValue) { + match value { + ColumnarValue::Array(array) => panic!("expected scalar got {array:?}"), + ColumnarValue::Scalar(scalar) => assert_eq!(scalar, expected), + } + } + + let result = fun.invoke(&[ + ColumnarValue::Scalar(ScalarValue::Union( + None, + fields.clone(), + UnionMode::Dense, + )), + 
ColumnarValue::Scalar(ScalarValue::new_utf8("str")), + ])?; + + assert_scalar(result, ScalarValue::Utf8(None)); + + let result = fun.invoke(&[ + ColumnarValue::Scalar(ScalarValue::Union( + Some((3, Box::new(ScalarValue::Int32(Some(42))))), + fields.clone(), + UnionMode::Dense, + )), + ColumnarValue::Scalar(ScalarValue::new_utf8("str")), + ])?; + + assert_scalar(result, ScalarValue::Utf8(None)); + + let result = fun.invoke(&[ + ColumnarValue::Scalar(ScalarValue::Union( + Some((1, Box::new(ScalarValue::new_utf8("42")))), + fields.clone(), + UnionMode::Dense, + )), + ColumnarValue::Scalar(ScalarValue::new_utf8("str")), + ])?; + + assert_scalar(result, ScalarValue::new_utf8("42")); + + Ok(()) + } + + #[test] + fn test_eq_scalar() { + //multiple all equal chunks, so it's loop and sum logic it's tested + //multiple chunks after, so it's loop logic it's tested + const ARRAY_LEN: usize = 64 * 4; + + //so out of 64 boundaries chunks can be generated and checked for + const EQ_SCALAR_CHUNK_SIZE: usize = 3; + + fn eq_scalar(type_ids: &[i8], target: i8) -> BoolValue { + eq_scalar_generic::(type_ids, target) + } + + fn eq(left: &[i8], right: i8) -> BooleanBuffer { + arrow::compute::kernels::cmp::eq( + &Int8Array::from(left.to_vec()), + &Int8Array::new_scalar(right), + ) + .unwrap() + .into_parts() + .0 + } + + assert_eq!(eq_scalar(&[], 1), BoolValue::Scalar(true)); + + assert_eq!(eq_scalar(&[1], 1), BoolValue::Scalar(true)); + assert_eq!(eq_scalar(&[2], 1), BoolValue::Scalar(false)); + + let mut values = [1; ARRAY_LEN]; + + assert_eq!(eq_scalar(&values, 1), BoolValue::Scalar(true)); + assert_eq!(eq_scalar(&values, 2), BoolValue::Scalar(false)); + + //every subslice should return the same value + for i in 1..ARRAY_LEN { + assert_eq!(eq_scalar(&values[..i], 1), BoolValue::Scalar(true)); + assert_eq!(eq_scalar(&values[..i], 2), BoolValue::Scalar(false)); + } + + // test that a single change anywhere is checked for + for i in 0..ARRAY_LEN { + values[i] = 2; + + 
assert_eq!(eq_scalar(&values, 1), BoolValue::Buffer(eq(&values, 1))); + assert_eq!(eq_scalar(&values, 2), BoolValue::Buffer(eq(&values, 2))); + + values[i] = 1; + } + } + + #[test] + fn test_is_sequential() { + /* + the smallest value that satisfies: + >1 so the fold logic of a exact chunk executes + >2 so a >1 non-exact remainder can exist, and it's fold logic executes + */ + const CHUNK_SIZE: usize = 3; + //we test arrays of size up to 8 = 2 * CHUNK_SIZE + 2: + //multiple(2) exact chunks, so the AND logic between them executes + //a >1(2) remainder, so: + // the AND logic between all exact chunks and the remainder executes + // the remainder fold logic executes + + fn is_sequential(v: &[i32]) -> bool { + is_sequential_generic::(v) + } + + assert!(is_sequential(&[])); //empty + assert!(is_sequential(&[1])); //single + + assert!(is_sequential(&[1, 2])); + assert!(is_sequential(&[1, 2, 3])); + assert!(is_sequential(&[1, 2, 3, 4])); + assert!(is_sequential(&[1, 2, 3, 4, 5])); + assert!(is_sequential(&[1, 2, 3, 4, 5, 6])); + assert!(is_sequential(&[1, 2, 3, 4, 5, 6, 7])); + assert!(is_sequential(&[1, 2, 3, 4, 5, 6, 7, 8])); + + assert!(!is_sequential(&[8, 7])); + assert!(!is_sequential(&[8, 7, 6])); + assert!(!is_sequential(&[8, 7, 6, 5])); + assert!(!is_sequential(&[8, 7, 6, 5, 4])); + assert!(!is_sequential(&[8, 7, 6, 5, 4, 3])); + assert!(!is_sequential(&[8, 7, 6, 5, 4, 3, 2])); + assert!(!is_sequential(&[8, 7, 6, 5, 4, 3, 2, 1])); + + assert!(!is_sequential(&[0, 2])); + assert!(!is_sequential(&[1, 0])); + + assert!(!is_sequential(&[0, 2, 3])); + assert!(!is_sequential(&[1, 0, 3])); + assert!(!is_sequential(&[1, 2, 0])); + + assert!(!is_sequential(&[0, 2, 3, 4])); + assert!(!is_sequential(&[1, 0, 3, 4])); + assert!(!is_sequential(&[1, 2, 0, 4])); + assert!(!is_sequential(&[1, 2, 3, 0])); + + assert!(!is_sequential(&[0, 2, 3, 4, 5])); + assert!(!is_sequential(&[1, 0, 3, 4, 5])); + assert!(!is_sequential(&[1, 2, 0, 4, 5])); + assert!(!is_sequential(&[1, 2, 3, 0, 
5])); + assert!(!is_sequential(&[1, 2, 3, 4, 0])); + + assert!(!is_sequential(&[0, 2, 3, 4, 5, 6])); + assert!(!is_sequential(&[1, 0, 3, 4, 5, 6])); + assert!(!is_sequential(&[1, 2, 0, 4, 5, 6])); + assert!(!is_sequential(&[1, 2, 3, 0, 5, 6])); + assert!(!is_sequential(&[1, 2, 3, 4, 0, 6])); + assert!(!is_sequential(&[1, 2, 3, 4, 5, 0])); + + assert!(!is_sequential(&[0, 2, 3, 4, 5, 6, 7])); + assert!(!is_sequential(&[1, 0, 3, 4, 5, 6, 7])); + assert!(!is_sequential(&[1, 2, 0, 4, 5, 6, 7])); + assert!(!is_sequential(&[1, 2, 3, 0, 5, 6, 7])); + assert!(!is_sequential(&[1, 2, 3, 4, 0, 6, 7])); + assert!(!is_sequential(&[1, 2, 3, 4, 5, 0, 7])); + assert!(!is_sequential(&[1, 2, 3, 4, 5, 6, 0])); + + assert!(!is_sequential(&[0, 2, 3, 4, 5, 6, 7, 8])); + assert!(!is_sequential(&[1, 0, 3, 4, 5, 6, 7, 8])); + assert!(!is_sequential(&[1, 2, 0, 4, 5, 6, 7, 8])); + assert!(!is_sequential(&[1, 2, 3, 0, 5, 6, 7, 8])); + assert!(!is_sequential(&[1, 2, 3, 4, 0, 6, 7, 8])); + assert!(!is_sequential(&[1, 2, 3, 4, 5, 0, 7, 8])); + assert!(!is_sequential(&[1, 2, 3, 4, 5, 6, 0, 8])); + assert!(!is_sequential(&[1, 2, 3, 4, 5, 6, 7, 0])); + } + + #[test] + fn test_new_null_columnar_value() { + match new_null_columnar_value(&DataType::Int8, 2) { + ColumnarValue::Array(_) => { + panic!("new_null_columnar_value should've returned a scalar for Int8") + } + ColumnarValue::Scalar(scalar) => assert_eq!(scalar, ScalarValue::Int8(None)), + } + + let run_data_type = DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", DataType::Int16, false)), + Arc::new(Field::new("values", DataType::Utf8, false)), + ); + + match new_null_columnar_value(&run_data_type, 2) { + ColumnarValue::Array(array) => assert_eq!( + array.into_data(), + new_null_array(&run_data_type, 2).into_data() + ), + ColumnarValue::Scalar(_) => panic!( + "new_null_columnar_value should've returned a array for RunEndEncoded" + ), + } + } +} diff --git a/datafusion/sqllogictest/src/test_context.rs 
b/datafusion/sqllogictest/src/test_context.rs index 224a0e18eac4..0e4b3b782f58 100644 --- a/datafusion/sqllogictest/src/test_context.rs +++ b/datafusion/sqllogictest/src/test_context.rs @@ -22,10 +22,14 @@ use std::path::Path; use std::sync::Arc; use arrow::array::{ - ArrayRef, BinaryArray, Float64Array, Int32Array, LargeBinaryArray, LargeStringArray, - StringArray, TimestampNanosecondArray, + new_null_array, Array, ArrayRef, BinaryArray, BooleanArray, Float64Array, Int32Array, + LargeBinaryArray, LargeStringArray, NullArray, StringArray, TimestampNanosecondArray, + UnionArray, +}; +use arrow::buffer::ScalarBuffer; +use arrow::datatypes::{ + DataType, Field, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode, }; -use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit}; use arrow::record_batch::RecordBatch; use datafusion::logical_expr::{create_udf, ColumnarValue, Expr, ScalarUDF, Volatility}; use datafusion::physical_plan::ExecutionPlan; @@ -108,6 +112,10 @@ impl TestContext { info!("Registering metadata table tables"); register_metadata_tables(test_ctx.session_ctx()).await; } + "union_datatype.slt" => { + info!("Registering tables with union column"); + register_union_tables(test_ctx.session_ctx()) + } _ => { info!("Using default SessionContext"); } @@ -361,3 +369,893 @@ fn create_example_udf() -> ScalarUDF { adder, ) } + +fn register_union_tables(ctx: &SessionContext) { + sparse_1_1_single_field(ctx); + sparse_1_2_empty(ctx); + sparse_1_3a_null_target(ctx); + sparse_1_3b_null_target(ctx); + sparse_2_all_types_match(ctx); + sparse_3_1_none_match_target_can_contain_null_mask(ctx); + sparse_3_2_none_match_cant_contain_null_mask_union_target(ctx); + sparse_4_1_1_target_with_nulls(ctx); + sparse_4_1_2_target_without_nulls(ctx); + sparse_4_2_some_match_target_cant_contain_null_mask(ctx); + dense_1_1_both_empty(ctx); + dense_1_2_empty_union_target_non_empty(ctx); + dense_2_non_empty_union_target_empty(ctx); + dense_3_1_null_target_smaller_len(ctx); + 
dense_3_2_null_target_equal_len(ctx); + dense_3_3_null_target_bigger_len(ctx); + dense_4_1a_single_type_sequential_offsets_equal_len(ctx); + dense_4_2a_single_type_sequential_offsets_bigger(ctx); + dense_4_3a_single_type_non_sequential(ctx); + dense_4_1b_empty_siblings_sequential_equal_len(ctx); + dense_4_2b_empty_siblings_sequential_bigger_len(ctx); + dense_4_3b_empty_sibling_non_sequential(ctx); + dense_4_1c_all_types_match_sequential_equal_len(ctx); + dense_4_2c_all_types_match_sequential_bigger_len(ctx); + dense_4_3c_all_types_match_non_sequential(ctx); + dense_5_1a_none_match_less_len(ctx); + dense_5_1b_cant_contain_null_mask(ctx); + dense_5_2_none_match_equal_len(ctx); + dense_5_3_none_match_greater_len(ctx); + dense_6_some_matches(ctx); + empty_sparse_union(ctx); + empty_dense_union(ctx); +} + +fn register_union_table( + ctx: &SessionContext, + union: UnionArray, + table_name: &str, + expected: impl Array + 'static, +) { + let schema = Schema::new(vec![ + Field::new("my_union", union.data_type().clone(), false), + Field::new("expected", expected.data_type().clone(), true), + ]); + + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(union), Arc::new(expected)], + ) + .unwrap(); + + ctx.register_batch(table_name, batch).unwrap(); +} + +fn sparse_1_1_single_field(ctx: &SessionContext) { + let union = UnionArray::try_new( + //single field + UnionFields::new(vec![3], vec![Field::new("int", DataType::Int32, false)]), + ScalarBuffer::from(vec![3, 3]), // non empty, every type id must match + None, //sparse + vec![ + Arc::new(Int32Array::from(vec![1, 2])), // not null + ], + ) + .unwrap(); + + register_union_table( + ctx, + union, + "sparse_1_1_single_field", + Int32Array::from(vec![1, 2]), + ); +} + +fn sparse_1_2_empty(ctx: &SessionContext) { + let union = UnionArray::try_new( + // multiple fields + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), //target type is not Null + Field::new("int", 
DataType::Int32, false), + ], + ), + ScalarBuffer::from(vec![]), //empty union + None, // sparse + vec![ + Arc::new(StringArray::new_null(0)), + Arc::new(Int32Array::new_null(0)), + ], + ) + .unwrap(); + + register_union_table(ctx, union, "sparse_1_2_empty", StringArray::new_null(0)); +} + +fn sparse_1_3a_null_target(ctx: &SessionContext) { + let union = UnionArray::try_new( + // multiple fields + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new("null", DataType::Null, true), + ], + ), + ScalarBuffer::from(vec![1]), //not empty + None, // sparse + vec![ + Arc::new(StringArray::new_null(1)), + Arc::new(NullArray::new(1)), // null data type + ], + ) + .unwrap(); + + register_union_table(ctx, union, "sparse_1_3a_null_target", NullArray::new(1)); +} + +fn sparse_1_3b_null_target(ctx: &SessionContext) { + let union = UnionArray::try_new( + // multiple fields + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), //target type is not Null + Field::new("int", DataType::Int32, false), + ], + ), + ScalarBuffer::from(vec![1]), //not empty + None, // sparse + vec![ + Arc::new(StringArray::new_null(1)), //all null + Arc::new(Int32Array::new_null(1)), + ], + ) + .unwrap(); + + register_union_table( + ctx, + union, + "sparse_1_3b_null_target", + StringArray::new_null(1), + ); +} + +fn sparse_2_all_types_match(ctx: &SessionContext) { + let union = UnionArray::try_new( + //multiple fields + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new("int", DataType::Int32, false), + ], + ), + ScalarBuffer::from(vec![3, 3]), // all types match + None, //sparse + vec![ + Arc::new(StringArray::new_null(2)), + Arc::new(Int32Array::from(vec![1, 4])), // not null + ], + ) + .unwrap(); + + register_union_table( + ctx, + union, + "sparse_2_all_types_match", + Int32Array::from(vec![1, 4]), + ); +} + +fn sparse_3_1_none_match_target_can_contain_null_mask(ctx: &SessionContext) 
{ + let union = UnionArray::try_new( + //multiple fields + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new("int", DataType::Int32, false), + ], + ), + ScalarBuffer::from(vec![1, 1, 1, 1]), // none match + None, // sparse + vec![ + Arc::new(StringArray::new_null(4)), + Arc::new(Int32Array::from(vec![None, Some(4), None, Some(8)])), // target is not null + ], + ) + .unwrap(); + + register_union_table( + ctx, + union, + "sparse_3_1_none_match_target_can_contain_null_mask", + Int32Array::new_null(4), + ); +} + +fn sparse_3_2_none_match_cant_contain_null_mask_union_target(ctx: &SessionContext) { + let target_fields = + UnionFields::new([10], [Field::new("bool", DataType::Boolean, true)]); + + let target_data_type = DataType::Union(target_fields.clone(), UnionMode::Sparse); + + let union = UnionArray::try_new( + //multiple fields + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new("union", target_data_type.clone(), false), + ], + ), + ScalarBuffer::from(vec![1, 1]), // none match + None, //sparse + vec![ + Arc::new(StringArray::new_null(2)), + //target is not null + Arc::new( + UnionArray::try_new( + target_fields.clone(), + ScalarBuffer::from(vec![10, 10]), + None, + vec![Arc::new(BooleanArray::from(vec![true, false]))], + ) + .unwrap(), + ), + ], + ) + .unwrap(); + + register_union_table( + ctx, + union, + "sparse_3_2_none_match_cant_contain_null_mask_union_target", + new_null_array(&target_data_type, 2), + ); +} + +fn sparse_4_1_1_target_with_nulls(ctx: &SessionContext) { + let union = UnionArray::try_new( + //multiple fields + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new("int", DataType::Int32, false), + ], + ), + ScalarBuffer::from(vec![3, 3, 1, 1]), // multiple selected types + None, // sparse + vec![ + Arc::new(StringArray::new_null(4)), + Arc::new(Int32Array::from(vec![None, Some(4), None, Some(8)])), // target with 
nulls + ], + ) + .unwrap(); + + register_union_table( + ctx, + union, + "sparse_4_1_1_target_with_nulls", + Int32Array::from(vec![None, Some(4), None, None]), + ); +} + +fn sparse_4_1_2_target_without_nulls(ctx: &SessionContext) { + let union = UnionArray::try_new( + //multiple fields + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new("int", DataType::Int32, false), + ], + ), + ScalarBuffer::from(vec![1, 3, 3]), // multiple selected types + None, // sparse + vec![ + Arc::new(StringArray::new_null(3)), + Arc::new(Int32Array::from(vec![2, 4, 8])), // target without nulls + ], + ) + .unwrap(); + + register_union_table( + ctx, + union, + "sparse_4_1_2_target_without_nulls", + Int32Array::from(vec![None, Some(4), Some(8)]), + ); +} + +fn sparse_4_2_some_match_target_cant_contain_null_mask(ctx: &SessionContext) { + let target_fields = + UnionFields::new([10], [Field::new("bool", DataType::Boolean, true)]); + + let union = UnionArray::try_new( + //multiple fields + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new( + "union", + DataType::Union(target_fields.clone(), UnionMode::Sparse), + false, + ), + ], + ), + ScalarBuffer::from(vec![3, 1]), // some types match, but not all + None, //sparse + vec![ + Arc::new(StringArray::new_null(2)), + Arc::new( + UnionArray::try_new( + target_fields.clone(), + ScalarBuffer::from(vec![10, 10]), + None, + vec![Arc::new(BooleanArray::from(vec![true, false]))], + ) + .unwrap(), + ), + ], + ) + .unwrap(); + + register_union_table( + ctx, + union, + "sparse_4_2_some_match_target_cant_contain_null_mask", + UnionArray::try_new( + target_fields, + ScalarBuffer::from(vec![10, 10]), + None, + vec![Arc::new(BooleanArray::from(vec![Some(true), None]))], + ) + .unwrap(), + ); +} + +fn dense_1_1_both_empty(ctx: &SessionContext) { + let union = UnionArray::try_new( + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + 
Field::new("int", DataType::Int32, false), + ], + ), + ScalarBuffer::from(vec![]), //empty union + Some(ScalarBuffer::from(vec![])), // dense + vec![ + Arc::new(StringArray::new_null(0)), //empty target + Arc::new(Int32Array::new_null(0)), + ], + ) + .unwrap(); + + register_union_table(ctx, union, "dense_1_1_both_empty", StringArray::new_null(0)); +} + +fn dense_1_2_empty_union_target_non_empty(ctx: &SessionContext) { + let union = UnionArray::try_new( + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new("int", DataType::Int32, false), + ], + ), + ScalarBuffer::from(vec![]), //empty union + Some(ScalarBuffer::from(vec![])), // dense + vec![ + Arc::new(StringArray::new_null(1)), //non empty target + Arc::new(Int32Array::new_null(0)), + ], + ) + .unwrap(); + + register_union_table( + ctx, + union, + "dense_1_2_empty_union_target_non_empty", + StringArray::new_null(0), + ); +} + +fn dense_2_non_empty_union_target_empty(ctx: &SessionContext) { + let union = UnionArray::try_new( + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new("int", DataType::Int32, false), + ], + ), + ScalarBuffer::from(vec![3, 3]), //non empty union + Some(ScalarBuffer::from(vec![0, 1])), // dense + vec![ + Arc::new(StringArray::new_null(0)), //empty target + Arc::new(Int32Array::new_null(2)), + ], + ) + .unwrap(); + + register_union_table( + ctx, + union, + "dense_2_non_empty_union_target_empty", + StringArray::new_null(2), + ); +} + +fn dense_3_1_null_target_smaller_len(ctx: &SessionContext) { + let union = UnionArray::try_new( + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new("int", DataType::Int32, false), + ], + ), + ScalarBuffer::from(vec![3, 3]), //non empty union + Some(ScalarBuffer::from(vec![0, 0])), //dense + vec![ + Arc::new(StringArray::new_null(1)), //smaller target + Arc::new(Int32Array::new_null(2)), + ], + ) + .unwrap(); + + 
register_union_table( + ctx, + union, + "dense_3_1_null_target_smaller_len", + StringArray::new_null(2), + ); +} + +fn dense_3_2_null_target_equal_len(ctx: &SessionContext) { + let union = UnionArray::try_new( + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new("int", DataType::Int32, false), + ], + ), + ScalarBuffer::from(vec![3, 3]), //non empty union + Some(ScalarBuffer::from(vec![0, 0])), //dense + vec![ + Arc::new(StringArray::new_null(2)), //equal len + Arc::new(Int32Array::new_null(2)), + ], + ) + .unwrap(); + + register_union_table( + ctx, + union, + "dense_3_2_null_target_equal_len", + StringArray::new_null(2), + ); +} + +fn dense_3_3_null_target_bigger_len(ctx: &SessionContext) { + let union = UnionArray::try_new( + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new("int", DataType::Int32, false), + ], + ), + ScalarBuffer::from(vec![3, 3]), //non empty union + Some(ScalarBuffer::from(vec![0, 0])), //dense + vec![ + Arc::new(StringArray::new_null(3)), //bigger len + Arc::new(Int32Array::new_null(3)), + ], + ) + .unwrap(); + + register_union_table( + ctx, + union, + "dense_3_3_null_target_bigger_len", + StringArray::new_null(2), + ); +} + +fn dense_4_1a_single_type_sequential_offsets_equal_len(ctx: &SessionContext) { + let union = UnionArray::try_new( + // single field + UnionFields::new(vec![1], vec![Field::new("str", DataType::Utf8, false)]), + ScalarBuffer::from(vec![1, 1]), //non empty union + Some(ScalarBuffer::from(vec![0, 1])), //sequential + vec![ + Arc::new(StringArray::from_iter_values(["a1", "b2"])), //equal len, non null + ], + ) + .unwrap(); + + register_union_table( + ctx, + union, + "dense_4_1a_single_type_sequential_offsets_equal_len", + StringArray::from_iter_values(["a1", "b2"]), + ); +} + +fn dense_4_2a_single_type_sequential_offsets_bigger(ctx: &SessionContext) { + let union = UnionArray::try_new( + // single field + UnionFields::new(vec![1], 
vec![Field::new("str", DataType::Utf8, false)]), + ScalarBuffer::from(vec![1, 1]), //non empty union + Some(ScalarBuffer::from(vec![0, 1])), //sequential + vec![ + Arc::new(StringArray::from_iter_values(["a1", "b2", "c3"])), //equal len, non null + ], + ) + .unwrap(); + + register_union_table( + ctx, + union, + "dense_4_2a_single_type_sequential_offsets_bigger", + StringArray::from_iter_values(["a1", "b2"]), + ); +} + +fn dense_4_3a_single_type_non_sequential(ctx: &SessionContext) { + let union = UnionArray::try_new( + // single field + UnionFields::new(vec![1], vec![Field::new("str", DataType::Utf8, false)]), + ScalarBuffer::from(vec![1, 1]), //non empty union + Some(ScalarBuffer::from(vec![0, 2])), //non sequential + vec![ + Arc::new(StringArray::from_iter_values(["a1", "b2", "c3"])), //equal len, non null + ], + ) + .unwrap(); + + register_union_table( + ctx, + union, + "dense_4_3a_single_type_non_sequential", + StringArray::from_iter_values(["a1", "c3"]), + ); +} + +fn dense_4_1b_empty_siblings_sequential_equal_len(ctx: &SessionContext) { + let union = UnionArray::try_new( + // multiple fields + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new("int", DataType::Int32, false), + ], + ), + ScalarBuffer::from(vec![1, 1]), //non empty union + Some(ScalarBuffer::from(vec![0, 1])), //sequential + vec![ + Arc::new(StringArray::from(vec!["a", "b"])), //equal len, non null + Arc::new(Int32Array::new_null(0)), //empty sibling + ], + ) + .unwrap(); + + register_union_table( + ctx, + union, + "dense_4_1b_empty_siblings_sequential_equal_len", + StringArray::from(vec!["a", "b"]), + ); +} + +fn dense_4_2b_empty_siblings_sequential_bigger_len(ctx: &SessionContext) { + let union = UnionArray::try_new( + // multiple fields + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new("int", DataType::Int32, false), + ], + ), + ScalarBuffer::from(vec![1, 1]), //non empty union + 
Some(ScalarBuffer::from(vec![0, 1])), //sequential + vec![ + Arc::new(StringArray::from(vec!["a", "b", "c"])), //bigger len, non null + Arc::new(Int32Array::new_null(0)), //empty sibling + ], + ) + .unwrap(); + + register_union_table( + ctx, + union, + "dense_4_2b_empty_siblings_sequential_bigger_len", + StringArray::from(vec!["a", "b"]), + ); +} + +fn dense_4_3b_empty_sibling_non_sequential(ctx: &SessionContext) { + let union = UnionArray::try_new( + // multiple fields + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new("int", DataType::Int32, false), + ], + ), + ScalarBuffer::from(vec![1, 1]), //non empty union + Some(ScalarBuffer::from(vec![0, 2])), //non sequential + vec![ + Arc::new(StringArray::from(vec!["a", "b", "c"])), //non null + Arc::new(Int32Array::new_null(0)), //empty sibling + ], + ) + .unwrap(); + + register_union_table( + ctx, + union, + "dense_4_3b_empty_sibling_non_sequential", + StringArray::from(vec!["a", "c"]), + ); +} + +fn dense_4_1c_all_types_match_sequential_equal_len(ctx: &SessionContext) { + let union = UnionArray::try_new( + // multiple fields + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new("int", DataType::Int32, false), + ], + ), + ScalarBuffer::from(vec![1, 1]), //all types match + Some(ScalarBuffer::from(vec![0, 1])), //sequential + vec![ + Arc::new(StringArray::from(vec!["a1", "b2"])), //equal len + Arc::new(Int32Array::new_null(2)), //non empty sibling + ], + ) + .unwrap(); + + register_union_table( + ctx, + union, + "dense_4_1c_all_types_match_sequential_equal_len", + StringArray::from(vec!["a1", "b2"]), + ); +} + +fn dense_4_2c_all_types_match_sequential_bigger_len(ctx: &SessionContext) { + let union = UnionArray::try_new( + // multiple fields + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new("int", DataType::Int32, false), + ], + ), + ScalarBuffer::from(vec![1, 1]), //all types 
match + Some(ScalarBuffer::from(vec![0, 1])), //sequential + vec![ + Arc::new(StringArray::from(vec!["a1", "b2", "b3"])), //bigger len + Arc::new(Int32Array::new_null(2)), //non empty sibling + ], + ) + .unwrap(); + + register_union_table( + ctx, + union, + "dense_4_2c_all_types_match_sequential_bigger_len", + StringArray::from(vec!["a1", "b2"]), + ); +} + +fn dense_4_3c_all_types_match_non_sequential(ctx: &SessionContext) { + let union = UnionArray::try_new( + // multiple fields + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new("int", DataType::Int32, false), + ], + ), + ScalarBuffer::from(vec![1, 1]), //all types match + Some(ScalarBuffer::from(vec![0, 2])), //non sequential + vec![ + Arc::new(StringArray::from(vec!["a1", "b2", "b3"])), + Arc::new(Int32Array::new_null(2)), //non empty sibling + ], + ) + .unwrap(); + + register_union_table( + ctx, + union, + "dense_4_3c_all_types_match_non_sequential", + StringArray::from(vec!["a1", "b3"]), + ); +} + +fn dense_5_1a_none_match_less_len(ctx: &SessionContext) { + let union = UnionArray::try_new( + // multiple fields + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new("int", DataType::Int32, false), + ], + ), + ScalarBuffer::from(vec![3, 3, 3, 3, 3]), //none matches + Some(ScalarBuffer::from(vec![0, 0, 0, 1, 1])), // dense + vec![ + Arc::new(StringArray::from(vec!["a1", "b2", "c3"])), // less len + Arc::new(Int32Array::from(vec![1, 2])), + ], + ) + .unwrap(); + + register_union_table( + ctx, + union, + "dense_5_1a_none_match_less_len", + StringArray::new_null(5), + ); +} + +fn dense_5_1b_cant_contain_null_mask(ctx: &SessionContext) { + let target_fields = + UnionFields::new([10], [Field::new("bool", DataType::Boolean, true)]); + + let union = UnionArray::try_new( + // multiple fields + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new( + "union", + 
DataType::Union(target_fields.clone(), UnionMode::Sparse), + false, + ), + ], + ), + ScalarBuffer::from(vec![1, 1, 1, 1, 1]), //none matches + Some(ScalarBuffer::from(vec![0, 0, 0, 1, 1])), // dense + vec![ + Arc::new(StringArray::from(vec!["a1", "b2", "c3"])), // less len + Arc::new( + UnionArray::try_new( + target_fields.clone(), + ScalarBuffer::from(vec![10]), + None, + vec![Arc::new(BooleanArray::from(vec![true]))], + ) + .unwrap(), + ), // non empty + ], + ) + .unwrap(); + + register_union_table( + ctx, + union, + "dense_5_1b_cant_contain_null_mask", + new_null_array(&DataType::Union(target_fields, UnionMode::Sparse), 5), + ); +} + +fn dense_5_2_none_match_equal_len(ctx: &SessionContext) { + let union = UnionArray::try_new( + // multiple fields + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new("int", DataType::Int32, false), + ], + ), + ScalarBuffer::from(vec![3, 3, 3, 3, 3]), //none matches + Some(ScalarBuffer::from(vec![0, 0, 0, 1, 1])), // dense + vec![ + Arc::new(StringArray::from(vec!["a1", "b2", "c3", "d4", "e5"])), // equal len + Arc::new(Int32Array::from(vec![1, 2])), + ], + ) + .unwrap(); + + register_union_table( + ctx, + union, + "dense_5_2_none_match_equal_len", + StringArray::new_null(5), + ); +} + +fn dense_5_3_none_match_greater_len(ctx: &SessionContext) { + let union = UnionArray::try_new( + // multiple fields + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new("int", DataType::Int32, false), + ], + ), + ScalarBuffer::from(vec![3, 3, 3, 3, 3]), //none matches + Some(ScalarBuffer::from(vec![0, 0, 0, 1, 1])), // dense + vec![ + Arc::new(StringArray::from(vec!["a1", "b2", "c3", "d4", "e5", "f6"])), // greater len + Arc::new(Int32Array::from(vec![1, 2])), //non null + ], + ) + .unwrap(); + + register_union_table( + ctx, + union, + "dense_5_3_none_match_greater_len", + StringArray::new_null(5), + ); +} + +fn dense_6_some_matches(ctx: &SessionContext) { + 
let union = UnionArray::try_new( + // multiple fields + UnionFields::new( + vec![1, 3], + vec![ + Field::new("str", DataType::Utf8, false), + Field::new("int", DataType::Int32, false), + ], + ), + ScalarBuffer::from(vec![3, 3, 1, 1, 1]), //some matches + Some(ScalarBuffer::from(vec![0, 1, 0, 1, 2])), // dense + vec![ + Arc::new(StringArray::from(vec!["a1", "b2", "c3"])), // non null + Arc::new(Int32Array::from(vec![1, 2])), + ], + ) + .unwrap(); + + register_union_table( + ctx, + union, + "dense_6_some_matches", + Int32Array::from(vec![Some(1), Some(2), None, None, None]), + ); +} + +fn empty_sparse_union(ctx: &SessionContext) { + let union = UnionArray::try_new( + UnionFields::empty(), + ScalarBuffer::from(vec![]), + None, + vec![], + ) + .unwrap(); + + register_union_table(ctx, union, "empty_sparse_union", NullArray::new(0)) +} + +fn empty_dense_union(ctx: &SessionContext) { + let union = UnionArray::try_new( + UnionFields::empty(), + ScalarBuffer::from(vec![]), + Some(ScalarBuffer::from(vec![])), + vec![], + ) + .unwrap(); + + register_union_table(ctx, union, "empty_dense_union", NullArray::new(0)) +} diff --git a/datafusion/sqllogictest/test_files/union_datatype.slt b/datafusion/sqllogictest/test_files/union_datatype.slt new file mode 100644 index 000000000000..12f6fa6bfca8 --- /dev/null +++ b/datafusion/sqllogictest/test_files/union_datatype.slt @@ -0,0 +1,270 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +########## +## UNION DataType Tests +########## + + +query II? +select union_extract(my_union, 'int'), expected, my_union from sparse_1_1_single_field; +---- +1 1 {int=1} +2 2 {int=2} + +query IT? +select union_extract(my_union, 'int'), expected, my_union from sparse_1_2_empty; +---- + +query ??? +select union_extract(my_union, 'null'), expected, my_union from sparse_1_3a_null_target; +---- +NULL NULL {str=} + +query IT? +select union_extract(my_union, 'int'), expected, my_union from sparse_1_3b_null_target; +---- +NULL NULL {str=} + +query II? +select union_extract(my_union, 'int'), expected, my_union from sparse_2_all_types_match; +---- +1 1 {int=1} +4 4 {int=4} + +query II? +select union_extract(my_union, 'int'), expected, my_union from sparse_3_1_none_match_target_can_contain_null_mask; +---- +NULL NULL {str=} +NULL NULL {str=} +NULL NULL {str=} +NULL NULL {str=} + +query ??? +select union_extract(my_union, 'union'), expected, my_union from sparse_3_2_none_match_cant_contain_null_mask_union_target; +---- +{bool=} {bool=} {str=} +{bool=} {bool=} {str=} + +query II? +select union_extract(my_union, 'int'), expected, my_union from sparse_4_1_1_target_with_nulls; +---- +NULL NULL {int=} +4 4 {int=4} +NULL NULL {str=} +NULL NULL {str=} + +query II? +select union_extract(my_union, 'int'), expected, my_union from sparse_4_1_2_target_without_nulls; +---- +NULL NULL {str=} +4 4 {int=4} +8 8 {int=8} + +query ??? 
+select union_extract(my_union, 'union'), expected, my_union from sparse_4_2_some_match_target_cant_contain_null_mask; +---- +{bool=true} {bool=true} {union={bool=true}} +{bool=} {bool=} {str=} + +query IT? +select union_extract(my_union, 'int'), expected, my_union from dense_1_1_both_empty; +---- + +query IT? +select union_extract(my_union, 'int'), expected, my_union from dense_1_2_empty_union_target_non_empty; +---- + +query IT? +select union_extract(my_union, 'int'), expected, my_union from dense_2_non_empty_union_target_empty; +---- +NULL NULL {int=} +NULL NULL {int=} + +query IT? +select union_extract(my_union, 'int'), expected, my_union from dense_3_1_null_target_smaller_len; +---- +NULL NULL {int=} +NULL NULL {int=} + +query IT? +select union_extract(my_union, 'int'), expected, my_union from dense_3_2_null_target_equal_len; +---- +NULL NULL {int=} +NULL NULL {int=} + +query IT? +select union_extract(my_union, 'int'), expected, my_union from dense_3_3_null_target_bigger_len; +---- +NULL NULL {int=} +NULL NULL {int=} + +query TT? +select union_extract(my_union, 'str'), expected, my_union from dense_4_1a_single_type_sequential_offsets_equal_len; +---- +a1 a1 {str=a1} +b2 b2 {str=b2} + +query TT? +select union_extract(my_union, 'str'), expected, my_union from dense_4_2a_single_type_sequential_offsets_bigger; +---- +a1 a1 {str=a1} +b2 b2 {str=b2} + +query TT? +select union_extract(my_union, 'str'), expected, my_union from dense_4_3a_single_type_non_sequential; +---- +a1 a1 {str=a1} +c3 c3 {str=c3} + +query TT? +select union_extract(my_union, 'str'), expected, my_union from dense_4_1b_empty_siblings_sequential_equal_len; +---- +a a {str=a} +b b {str=b} + + +query TT? +select union_extract(my_union, 'str'), expected, my_union from dense_4_2b_empty_siblings_sequential_bigger_len; +---- +a a {str=a} +b b {str=b} + + +query TT? 
+select union_extract(my_union, 'str'), expected, my_union from dense_4_3b_empty_sibling_non_sequential; +---- +a a {str=a} +c c {str=c} + + +query TT? +select union_extract(my_union, 'str'), expected, my_union from dense_4_1c_all_types_match_sequential_equal_len; +---- +a1 a1 {str=a1} +b2 b2 {str=b2} + + +query TT? +select union_extract(my_union, 'str'), expected, my_union from dense_4_2c_all_types_match_sequential_bigger_len; +---- +a1 a1 {str=a1} +b2 b2 {str=b2} + + +query TT? +select union_extract(my_union, 'str'), expected, my_union from dense_4_3c_all_types_match_non_sequential; +---- +a1 a1 {str=a1} +b3 b3 {str=b3} + + +query TT? +select union_extract(my_union, 'str'), expected, my_union from dense_5_1a_none_match_less_len; +---- +NULL NULL {int=1} +NULL NULL {int=1} +NULL NULL {int=1} +NULL NULL {int=2} +NULL NULL {int=2} + + +query ??? +select union_extract(my_union, 'union'), expected, my_union from dense_5_1b_cant_contain_null_mask; +---- +{bool=} {bool=} {str=a1} +{bool=} {bool=} {str=a1} +{bool=} {bool=} {str=a1} +{bool=} {bool=} {str=b2} +{bool=} {bool=} {str=b2} + +query TT? +select union_extract(my_union, 'str'), expected, my_union from dense_5_2_none_match_equal_len; +---- +NULL NULL {int=1} +NULL NULL {int=1} +NULL NULL {int=1} +NULL NULL {int=2} +NULL NULL {int=2} + + +query TT? +select union_extract(my_union, 'str'), expected, my_union from dense_5_3_none_match_greater_len; +---- +NULL NULL {int=1} +NULL NULL {int=1} +NULL NULL {int=1} +NULL NULL {int=2} +NULL NULL {int=2} + + +query II? 
+select union_extract(my_union, 'int'), expected, my_union from dense_6_some_matches; +---- +1 1 {int=1} +2 2 {int=2} +NULL NULL {str=a1} +NULL NULL {str=b2} +NULL NULL {str=c3} + + +query error DataFusion error: Execution error: field int not found on union +select union_extract(my_union, 'int'), expected, my_union from empty_sparse_union; + + +query error DataFusion error: Execution error: field int not found on union +select union_extract(my_union, 'int'), expected, my_union from empty_dense_union; + + +query error +select union_extract() from empty_dense_union; +---- +DataFusion error: Error during planning: Error during planning: union_extract does not support zero arguments. No function matches the given name and argument types 'union_extract()'. You might need to add explicit type casts. + Candidate functions: + union_extract(Any, Any) + + +query error +select union_extract(my_union) from empty_dense_union; +---- +DataFusion error: Error during planning: Error during planning: The function expected 2 arguments but received 1 No function matches the given name and argument types 'union_extract(Union([], Dense))'. You might need to add explicit type casts. + Candidate functions: + union_extract(Any, Any) + + +query error +select union_extract('a') from empty_dense_union; +---- +DataFusion error: Error during planning: Error during planning: The function expected 2 arguments but received 1 No function matches the given name and argument types 'union_extract(Utf8)'. You might need to add explicit type casts. 
+ Candidate functions: + union_extract(Any, Any) + + +query error DataFusion error: Execution error: union_extract first argument must be a union, got Utf8 instead +select union_extract('a', my_union) from empty_dense_union; + + +query error DataFusion error: Execution error: union_extract second argument must be a non\-null string literal, got Int64 instead +select union_extract(my_union, 1) from empty_dense_union; + + +query error +select union_extract(my_union, 'a', 'b') from empty_dense_union; +---- +DataFusion error: Error during planning: Error during planning: The function expected 2 arguments but received 3 No function matches the given name and argument types 'union_extract(Union([], Dense), Utf8, Utf8)'. You might need to add explicit type casts. + Candidate functions: + union_extract(Any, Any) diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index c7b3409ba7cd..2ab1a6e8f7bb 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -3827,6 +3827,61 @@ sha512(expression) - **expression**: String expression to operate on. Can be a constant, column, or function, and any combination of string operators. +## Union Functions + +- [union_extract](#union_extract) + +### `union_extract` + +Returns the value of the given field when selected, or NULL otherwise. 
+ +``` +union_extract(union, field_name) +``` + +#### Arguments + +- **union**: Union expression to extract the field from +- **field_name**: Literal string, name of the field to extract + +``` +❯ select my_union from table_with_union; ++----------+ +| my_union | ++----------+ +| {a=1} | +| {b=3.0} | +| {a=4} | +| {b=} | +| {a=} | ++----------+ + +❯ select union_extract(my_union, 'a'); ++------------------------------+ +| union_extract(my_union, 'a') | ++------------------------------+ +| 1 | +| | +| 4 | +| | +| | ++------------------------------+ + +❯ select union_extract(my_union, 'b'); ++------------------------------+ +| union_extract(my_union, 'b') | ++------------------------------+ +| | +| 3.0 | +| | +| | +| | ++------------------------------+ + + +``` + + ## Other Functions - [arrow_cast](#arrow_cast) From 4f34083284e17b401063e42b912f95ec8787c6bd Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Thu, 22 Aug 2024 23:01:25 +0000 Subject: [PATCH 02/15] fix: docs fmt, add clippy atr, sql error msg --- datafusion/functions/benches/union_extract.rs | 24 ++++++++++++++++--- .../test_files/union_datatype.slt | 17 +------------ .../source/user-guide/sql/scalar_functions.md | 1 - 3 files changed, 22 insertions(+), 20 deletions(-) diff --git a/datafusion/functions/benches/union_extract.rs b/datafusion/functions/benches/union_extract.rs index 1e748ff5a0c0..ceea5fa8e039 100644 --- a/datafusion/functions/benches/union_extract.rs +++ b/datafusion/functions/benches/union_extract.rs @@ -1,7 +1,22 @@ -#[macro_use] +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + extern crate criterion; -use crate::criterion::Criterion; use arrow::{ array::{ Array, BooleanArray, Int32Array, Int8Array, NullArray, StringArray, UnionArray, @@ -12,7 +27,7 @@ use arrow::{ }, }; use arrow_buffer::ScalarBuffer; -use criterion::black_box; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_common::ScalarValue; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; use datafusion_functions::core::union_extract::{ @@ -965,6 +980,7 @@ fn criterion_benchmark(c: &mut Criterion) { is_sequential_group.bench_function("offsets sequential windows fold &&", |b| { b.iter(|| { black_box( + #[allow(clippy::unnecessary_fold)] offsets .windows(2) .fold(true, |b, w| b && (w[0] + 1 == w[1])), @@ -993,6 +1009,7 @@ fn criterion_benchmark(c: &mut Criterion) { is_sequential_group.bench_function("offsets sequential fold &&", |b| { b.iter(|| { black_box( + #[allow(clippy::unnecessary_fold)] offsets .iter() .copied() @@ -1055,6 +1072,7 @@ fn criterion_benchmark(c: &mut Criterion) { }); type_ids_eq.bench_function("type_ids equal fold &&", |b| { + #[allow(clippy::unnecessary_fold)] b.iter(|| type_ids.iter().fold(true, |b, v| b && (*v == type_id))) }); diff --git a/datafusion/sqllogictest/test_files/union_datatype.slt b/datafusion/sqllogictest/test_files/union_datatype.slt index 12f6fa6bfca8..10e1b692ba0b 100644 --- a/datafusion/sqllogictest/test_files/union_datatype.slt +++ b/datafusion/sqllogictest/test_files/union_datatype.slt @@ -232,26 +232,15 @@ select union_extract(my_union, 'int'), expected, my_union from 
empty_dense_union query error select union_extract() from empty_dense_union; ----- -DataFusion error: Error during planning: Error during planning: union_extract does not support zero arguments. No function matches the given name and argument types 'union_extract()'. You might need to add explicit type casts. - Candidate functions: - union_extract(Any, Any) query error select union_extract(my_union) from empty_dense_union; ----- -DataFusion error: Error during planning: Error during planning: The function expected 2 arguments but received 1 No function matches the given name and argument types 'union_extract(Union([], Dense))'. You might need to add explicit type casts. - Candidate functions: - union_extract(Any, Any) query error select union_extract('a') from empty_dense_union; ----- -DataFusion error: Error during planning: Error during planning: The function expected 2 arguments but received 1 No function matches the given name and argument types 'union_extract(Utf8)'. You might need to add explicit type casts. - Candidate functions: - union_extract(Any, Any) + query error DataFusion error: Execution error: union_extract first argument must be a union, got Utf8 instead @@ -264,7 +253,3 @@ select union_extract(my_union, 1) from empty_dense_union; query error select union_extract(my_union, 'a', 'b') from empty_dense_union; ----- -DataFusion error: Error during planning: Error during planning: The function expected 2 arguments but received 3 No function matches the given name and argument types 'union_extract(Union([], Dense), Utf8, Utf8)'. You might need to add explicit type casts. 
- Candidate functions: - union_extract(Any, Any) diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 2ab1a6e8f7bb..3b6ff310c35c 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -3881,7 +3881,6 @@ union_extract(union, field_name) ``` - ## Other Functions - [arrow_cast](#arrow_cast) From 7250a2a134466b6c23ea63702e91a65d309363e3 Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Thu, 12 Sep 2024 13:05:54 -0300 Subject: [PATCH 03/15] use arrow-rs implementation --- datafusion/functions/Cargo.toml | 5 - datafusion/functions/benches/union_extract.rs | 1139 ----------------- .../functions/src/core/union_extract.rs | 554 +------- datafusion/sqllogictest/src/test_context.rs | 844 +----------- .../test_files/union_datatype.slt | 255 ---- .../test_files/union_function.slt | 97 ++ 6 files changed, 151 insertions(+), 2743 deletions(-) delete mode 100644 datafusion/functions/benches/union_extract.rs delete mode 100644 datafusion/sqllogictest/test_files/union_datatype.slt create mode 100644 datafusion/sqllogictest/test_files/union_function.slt diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 3776e1c17590..337379a74670 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -162,11 +162,6 @@ harness = false name = "random" required-features = ["math_expressions"] -[[bench]] -harness = false -name = "union_extract" -required-features = ["core_expressions"] - [[bench]] harness = false name = "substr" diff --git a/datafusion/functions/benches/union_extract.rs b/datafusion/functions/benches/union_extract.rs deleted file mode 100644 index ceea5fa8e039..000000000000 --- a/datafusion/functions/benches/union_extract.rs +++ /dev/null @@ -1,1139 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. 
See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -extern crate criterion; - -use arrow::{ - array::{ - Array, BooleanArray, Int32Array, Int8Array, NullArray, StringArray, UnionArray, - }, - datatypes::{DataType, Field, Int32Type, Int8Type, UnionFields, UnionMode}, - util::bench_util::{ - create_boolean_array, create_primitive_array, create_string_array, - }, -}; -use arrow_buffer::ScalarBuffer; -use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use datafusion_common::ScalarValue; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; -use datafusion_functions::core::union_extract::{ - eq_scalar_generic, is_sequential_generic, UnionExtractFun, -}; -use itertools::repeat_n; -use rand::random; -use std::sync::Arc; - -fn criterion_benchmark(c: &mut Criterion) { - let union_extract = UnionExtractFun::new(); - - c.bench_function("union_extract case 1.1 sparse single field", |b| { - let union = UnionArray::try_new( - UnionFields::new(vec![1], vec![Field::new("str", DataType::Utf8, false)]), //single field - ScalarBuffer::from(vec![1; 2048]), //non empty union - None, //sparse - vec![Arc::new(create_string_array::(2048, 0.0))], //non null target - ) - .unwrap(); - - let args = [ - ColumnarValue::Array(Arc::new(union)), - ColumnarValue::Scalar(ScalarValue::new_utf8("str")), - ]; - - b.iter(|| { - 
union_extract.invoke(&args).unwrap(); - }) - }); - - c.bench_function("union_extract case 1.2 sparse empty union", |b| { - let union = UnionArray::try_new( - // multiple fields - UnionFields::new( - vec![1, 2], - vec![ - Field::new("str", DataType::Utf8, false), - Field::new("str2", DataType::Utf8, false), - ], - ), - ScalarBuffer::from(vec![]), // empty union - None, //sparse - vec![ - Arc::new(StringArray::new_null(0)), - Arc::new(StringArray::new_null(0)), - ], - ) - .unwrap(); - - let args = [ - ColumnarValue::Array(Arc::new(union)), - ColumnarValue::Scalar(ScalarValue::new_utf8("str")), - ]; - - b.iter(|| { - union_extract.invoke(&args).unwrap(); - }) - }); - - c.bench_function("union_extract case 1.3a sparse child null", |b| { - let union = UnionArray::try_new( - // multiple fields - UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), - Field::new("null", DataType::Null, true), - ], - ), - ScalarBuffer::from(vec![1; 2048]), // non empty union - None, //sparse - vec![ - Arc::new(StringArray::new_null(2048)), // null target - Arc::new(NullArray::new(2048)), - ], - ) - .unwrap(); - - let args = [ - ColumnarValue::Array(Arc::new(union)), - ColumnarValue::Scalar(ScalarValue::new_utf8("null")), - ]; - - b.iter(|| { - union_extract.invoke(&args).unwrap(); - }) - }); - - c.bench_function("union_extract case 1.3b sparse child null", |b| { - let union = UnionArray::try_new( - // multiple fields - UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), - Field::new("int", DataType::Int32, false), - ], - ), - ScalarBuffer::from(vec![1; 2048]), // non empty union - None, //sparse - vec![ - Arc::new(StringArray::new_null(2048)), // null target - Arc::new(create_primitive_array::(2048, 0.0)), - ], - ) - .unwrap(); - - let args = [ - ColumnarValue::Array(Arc::new(union)), - ColumnarValue::Scalar(ScalarValue::new_utf8("str")), - ]; - - b.iter(|| { - union_extract.invoke(&args).unwrap(); - }) - }); - - 
c.bench_function("union_extract case 2 sparse all types match", |b| { - let union = UnionArray::try_new( - // multiple fields - UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), - Field::new("int", DataType::Int32, false), - ], - ), - ScalarBuffer::from(vec![1; 2048]), //all types match & non empty union - None, //sparse - vec![ - Arc::new(create_string_array::(2048, 0.0)), //non null target - Arc::new(Int32Array::new_null(2048)), - ], - ) - .unwrap(); - - let args = [ - ColumnarValue::Array(Arc::new(union)), - ColumnarValue::Scalar(ScalarValue::new_utf8("str")), - ]; - - b.iter(|| { - union_extract.invoke(&args).unwrap(); - }) - }); - - c.bench_function( - "union_extract case 3.1 none selected target can contain null mask", - |b| { - let union = UnionArray::try_new( - UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), - Field::new("int", DataType::Int32, false), - ], - ), - ScalarBuffer::from(vec![3; 2048]), //none selected - None, - vec![ - Arc::new(create_string_array::(2048, 0.5)), - Arc::new(Int32Array::new_null(2048)), - ], - ) - .unwrap(); - - let args = [ - ColumnarValue::Array(Arc::new(union)), - ColumnarValue::Scalar(ScalarValue::new_utf8("str")), - ]; - - b.iter(|| { - union_extract.invoke(&args).unwrap(); - }) - }, - ); - - c.bench_function( - "union_extract case 3.2 none matches sparse cant contain null mask", - |b| { - let target_fields = - UnionFields::new([10], [Field::new("bool", DataType::Boolean, true)]); - - let union = UnionArray::try_new( - // multiple fields - UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, true), - Field::new( - "union", - DataType::Union(target_fields.clone(), UnionMode::Sparse), - false, - ), - ], - ), - ScalarBuffer::from_iter(repeat_n(1, 2048)), //none matches - None, //sparse - vec![ - Arc::new(create_string_array::(2048, 0.5)), - Arc::new( - UnionArray::try_new( - target_fields, - ScalarBuffer::from(vec![10; 2048]), - None, - 
vec![Arc::new(BooleanArray::from(vec![true; 2048]))], - ) - .unwrap(), - ), - ], - ) - .unwrap(); - - let args = [ - ColumnarValue::Array(Arc::new(union)), - ColumnarValue::Scalar(ScalarValue::new_utf8("union")), - ]; - - b.iter(|| { - union_extract.invoke(&args).unwrap(); - }) - }, - ); - - c.bench_function( - "union_extract case 4.1.1 sparse some matches target with nulls", - |b| { - let union = UnionArray::try_new( - //multiple fields - UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), - Field::new("int", DataType::Int32, false), - ], - ), - ScalarBuffer::from_iter(repeat_n(1, 2047).chain([3])), //multiple types - None, //sparse - vec![ - Arc::new(create_string_array::(2048, 0.5)), //target with some nulls, but not all - Arc::new(Int32Array::new_null(2048)), - ], - ) - .unwrap(); - - let args = [ - ColumnarValue::Array(Arc::new(union)), - ColumnarValue::Scalar(ScalarValue::new_utf8("str")), - ]; - - b.iter(|| { - union_extract.invoke(&args).unwrap(); - }) - }, - ); - - c.bench_function( - "union_extract case 4.1.2 sparse some matches target without nulls", - |b| { - let union = UnionArray::try_new( - //multiple fields - UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), - Field::new("int", DataType::Int32, false), - ], - ), - ScalarBuffer::from_iter(repeat_n(1, 2047).chain([3])), //multiple types - None, //sparse - vec![ - Arc::new(create_string_array::(2048, 0.0)), //target without nulls - Arc::new(Int32Array::new_null(2048)), - ], - ) - .unwrap(); - - let args = [ - ColumnarValue::Array(Arc::new(union)), - ColumnarValue::Scalar(ScalarValue::new_utf8("str")), - ]; - - b.iter(|| { - union_extract.invoke(&args).unwrap(); - }) - }, - ); - - c.bench_function( - "union_extract case 4.2 some matches sparse cant contain null mask", - |b| { - let target_fields = - UnionFields::new([10], [Field::new("bool", DataType::Boolean, true)]); - - let union = UnionArray::try_new( - // multiple fields - 
UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), - Field::new( - "union", - DataType::Union(target_fields.clone(), UnionMode::Sparse), - false, - ), - ], - ), - ScalarBuffer::from_iter(repeat_n([1, 3], 1024).flatten()), //some matches - None, //sparse - vec![ - Arc::new(NullArray::new(2048)), //null target - Arc::new( - UnionArray::try_new( - target_fields, - ScalarBuffer::from(vec![10; 2048]), - None, - vec![Arc::new(BooleanArray::from(vec![true; 2048]))], - ) - .unwrap(), - ), - ], - ) - .unwrap(); - - let args = [ - ColumnarValue::Array(Arc::new(union)), - ColumnarValue::Scalar(ScalarValue::new_utf8("union")), - ]; - - b.iter(|| { - union_extract.invoke(&args).unwrap(); - }) - }, - ); - - c.bench_function( - "union_extract case 1.1 dense empty union empty target", - |b| { - let union = UnionArray::try_new( - UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), - Field::new("int", DataType::Int32, false), - ], - ), - ScalarBuffer::from(vec![]), //empty union - Some(ScalarBuffer::from(vec![])), //dense - vec![ - Arc::new(StringArray::new_null(0)), //empty target - Arc::new(Int32Array::new_null(0)), - ], - ) - .unwrap(); - - let args = [ - ColumnarValue::Array(Arc::new(union)), - ColumnarValue::Scalar(ScalarValue::new_utf8("str")), - ]; - - b.iter(|| { - union_extract.invoke(&args).unwrap(); - }) - }, - ); - - c.bench_function( - "union_extract case 1.2 dense empty union non-empty target", - |b| { - let union = UnionArray::try_new( - UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), - Field::new("int", DataType::Int32, false), - ], - ), - ScalarBuffer::from(vec![]), // empty union - Some(ScalarBuffer::from(vec![])), // dense - vec![ - Arc::new(StringArray::from(vec!["a1", "s2"])), // non empty target - Arc::new(Int32Array::new_null(0)), - ], - ) - .unwrap(); - - let args = [ - ColumnarValue::Array(Arc::new(union)), - 
ColumnarValue::Scalar(ScalarValue::new_utf8("str")), - ]; - - b.iter(|| { - union_extract.invoke(&args).unwrap(); - }) - }, - ); - - c.bench_function( - "union_extract case 2 dense non empty union, empty target", - |b| { - let union = UnionArray::try_new( - UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), - Field::new("int", DataType::Int32, false), - ], - ), - ScalarBuffer::from(vec![3, 3]), // non empty union - Some(ScalarBuffer::from(vec![0, 1])), // dense - vec![ - Arc::new(StringArray::new_null(0)), // empty target - Arc::new(Int32Array::new_null(2)), - ], - ) - .unwrap(); - - let args = [ - ColumnarValue::Array(Arc::new(union)), - ColumnarValue::Scalar(ScalarValue::new_utf8("str")), - ]; - - b.iter(|| { - union_extract.invoke(&args).unwrap(); - }) - }, - ); - - c.bench_function( - "union_extract case 3.1 dense null target len smaller", - |b| { - let union = UnionArray::try_new( - UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), - Field::new("int", DataType::Int32, false), - ], - ), - ScalarBuffer::from(vec![1; 2048]), - Some(ScalarBuffer::from(vec![0; 2048])), // dense - vec![ - Arc::new(StringArray::new_null(1)), // null & len smaller - Arc::new(Int32Array::new_null(64)), - ], - ) - .unwrap(); - - let args = [ - ColumnarValue::Array(Arc::new(union)), - ColumnarValue::Scalar(ScalarValue::new_utf8("str")), - ]; - - b.iter(|| { - union_extract.invoke(&args).unwrap(); - }) - }, - ); - - c.bench_function("union_extract case 3.2 dense null target len equal", |b| { - let union = UnionArray::try_new( - UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), - Field::new("int", DataType::Int32, false), - ], - ), - ScalarBuffer::from(vec![1; 2048]), - Some(ScalarBuffer::from_iter(0..2048)), // dense - vec![ - Arc::new(StringArray::new_null(2048)), // null & same len as parent - Arc::new(Int32Array::new_null(64)), - ], - ) - .unwrap(); - - let args = [ - 
ColumnarValue::Array(Arc::new(union)), - ColumnarValue::Scalar(ScalarValue::new_utf8("str")), - ]; - - b.iter(|| { - union_extract.invoke(&args).unwrap(); - }) - }); - - c.bench_function("union_extract case 3.3 dense null target len bigger", |b| { - let union = UnionArray::try_new( - UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), - Field::new("int", DataType::Int32, false), - ], - ), - ScalarBuffer::from(vec![1; 2048]), - Some(ScalarBuffer::from(vec![0; 2048])), - vec![ - Arc::new(StringArray::new_null(4096)), // null, bigger than parent - Arc::new(Int32Array::new_null(64)), - ], - ) - .unwrap(); - - let args = [ - ColumnarValue::Array(Arc::new(union)), - ColumnarValue::Scalar(ScalarValue::new_utf8("str")), - ]; - - b.iter(|| { - union_extract.invoke(&args).unwrap(); - }) - }); - - c.bench_function( - "union_extract case 4.1A dense single field sequential offsets equal lens", - |b| { - let union = UnionArray::try_new( - //single field - UnionFields::new(vec![3], vec![Field::new("int", DataType::Int8, false)]), - ScalarBuffer::from(vec![3; 2048]), - Some(ScalarBuffer::from_iter(0..2048)), //sequential offsets - vec![Arc::new(create_primitive_array::(2048, 0.0))], //same len as parent, not null - ) - .unwrap(); - - let args = [ - ColumnarValue::Array(Arc::new(union)), - ColumnarValue::Scalar(ScalarValue::new_utf8("int")), - ]; - - b.iter(|| { - union_extract.invoke(&args).unwrap(); - }) - }, - ); - - c.bench_function( - "union_extract case 4.2A dense single field sequential offsets bigger len", - |b| { - let union = UnionArray::try_new( - // single field - UnionFields::new(vec![3], vec![Field::new("int", DataType::Int8, false)]), - ScalarBuffer::from(vec![3; 2048]), - Some(ScalarBuffer::from_iter(0..2048)), //sequential offsets - vec![Arc::new(create_primitive_array::(4096, 0.0))], //bigger than parent, not null - ) - .unwrap(); - - let args = [ - ColumnarValue::Array(Arc::new(union)), - 
ColumnarValue::Scalar(ScalarValue::new_utf8("int")), - ]; - - b.iter(|| { - union_extract.invoke(&args).unwrap(); - }) - }, - ); - - c.bench_function( - "union_extract case 4.3A dense single field non-sequential offsets", - |b| { - let union = UnionArray::try_new( - // single field - UnionFields::new(vec![3], vec![Field::new("int", DataType::Int8, false)]), - ScalarBuffer::from(vec![3; 2048]), - Some(ScalarBuffer::from_iter((0..2046).chain([2047, 2047]))), // non sequential offsets, avoid fast paths - vec![Arc::new(create_primitive_array::(2048, 0.0))], // not null - ) - .unwrap(); - - let args = [ - ColumnarValue::Array(Arc::new(union)), - ColumnarValue::Scalar(ScalarValue::new_utf8("int")), - ]; - - b.iter(|| { - union_extract.invoke(&args).unwrap(); - }) - }, - ); - - c.bench_function( - "union_extract case 4.1B dense empty siblings sequential offsets equal len", - |b| { - let union = UnionArray::try_new( - // multiple fields - UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), - Field::new("int", DataType::Int8, false), - ], - ), - ScalarBuffer::from(vec![3; 2048]), // all types must match - Some(ScalarBuffer::from_iter(0..2048)), // sequential offsets - vec![ - Arc::new(StringArray::new_null(0)), // empty sibling - Arc::new(create_primitive_array::(2048, 0.0)), // same len as parent, not null - ], - ) - .unwrap(); - - let args = [ - ColumnarValue::Array(Arc::new(union)), - ColumnarValue::Scalar(ScalarValue::new_utf8("int")), - ]; - - b.iter(|| { - union_extract.invoke(&args).unwrap(); - }) - }, - ); - - c.bench_function( - "union_extract case 4.2B dense empty siblings sequential offsets bigger target", - |b| { - let union = UnionArray::try_new( - // multiple fields - UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), - Field::new("int", DataType::Int8, false), - ], - ), - ScalarBuffer::from(vec![3; 2048]), // all types match - Some(ScalarBuffer::from_iter(0..2048)), // sequential offsets - vec![ 
- Arc::new(StringArray::new_null(0)), // empty sibling - Arc::new(create_primitive_array::(4096, 0.0)), // target is bigger than parent, not null - ], - ) - .unwrap(); - - let args = [ - ColumnarValue::Array(Arc::new(union)), - ColumnarValue::Scalar(ScalarValue::new_utf8("int")), - ]; - - b.iter(|| { - union_extract.invoke(&args).unwrap(); - }) - }, - ); - - c.bench_function( - "union_extract case 4.3B dense empty sibling non-sequential offsets", - |b| { - let union = UnionArray::try_new( - // multiple fields - UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), - Field::new("int", DataType::Int32, false), - ], - ), - ScalarBuffer::from(vec![3; 2048]), // all types must match - Some(ScalarBuffer::from_iter((0..2046).chain([2047, 2047]))), // non sequential offsets, avois fast paths - vec![ - Arc::new(StringArray::new_null(0)), // empty sibling - Arc::new(create_primitive_array::(2048, 0.0)), // not null - ], - ) - .unwrap(); - - let args = [ - ColumnarValue::Array(Arc::new(union)), - ColumnarValue::Scalar(ScalarValue::new_utf8("int")), - ]; - - b.iter(|| { - union_extract.invoke(&args).unwrap(); - }) - }, - ); - - c.bench_function( - "union_extract case 4.1C dense all types match sequential offsets equal lens", - |b| { - let union = UnionArray::try_new( - // multiple fields - UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), - Field::new("int", DataType::Int8, false), - ], - ), - ScalarBuffer::from(vec![3; 2048]), // all types match - Some(ScalarBuffer::from_iter(0..2048)), // sequential offsets - vec![ - Arc::new(StringArray::new_null(1)), // non empty sibling - Arc::new(create_primitive_array::(2048, 0.0)), // same len as parent, not null - ], - ) - .unwrap(); - - let args = [ - ColumnarValue::Array(Arc::new(union)), - ColumnarValue::Scalar(ScalarValue::new_utf8("int")), - ]; - - b.iter(|| { - union_extract.invoke(&args).unwrap(); - }) - }, - ); - - c.bench_function( - "union_extract case 4.2C 
dense all types match sequential offsets bigger len", - |b| { - let union = UnionArray::try_new( - // multiple fields - UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), - Field::new("int", DataType::Int8, false), - ], - ), - ScalarBuffer::from(vec![3; 2048]), // all types match - Some(ScalarBuffer::from_iter(0..2048)), // sequential offsets - vec![ - Arc::new(StringArray::new_null(1)), // non empty sibling - Arc::new(create_primitive_array::(4096, 0.0)), // bigger than parent union, not null - ], - ) - .unwrap(); - - let args = [ - ColumnarValue::Array(Arc::new(union)), - ColumnarValue::Scalar(ScalarValue::new_utf8("int")), - ]; - - b.iter(|| { - union_extract.invoke(&args).unwrap(); - }) - }, - ); - - c.bench_function( - "union_extract case 4.3C dense all types match non-sequential offsets", - |b| { - let union = UnionArray::try_new( - // multiple fields - UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), - Field::new("int", DataType::Int8, false), - ], - ), - ScalarBuffer::from(vec![3; 2048]), // all types match - Some(ScalarBuffer::from_iter((0..2046).chain([2047, 2047]))), //non sequential, avoid fast paths - vec![ - Arc::new(StringArray::new_null(1)), // non empty sibling - Arc::new(create_primitive_array::(2048, 0.0)), // not null - ], - ) - .unwrap(); - - let args = [ - ColumnarValue::Array(Arc::new(union)), - ColumnarValue::Scalar(ScalarValue::new_utf8("int")), - ]; - - b.iter(|| { - union_extract.invoke(&args).unwrap(); - }) - }, - ); - - c.bench_function("union_extract case 5.1a dense none match less len", |b| { - let union = UnionArray::try_new( - // multiple fields - UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), - Field::new("int", DataType::Int32, false), - ], - ), - ScalarBuffer::from(vec![1; 2048]), //none match - Some(ScalarBuffer::from_iter(0..2048)), //dense - vec![ - Arc::new(create_string_array::(2048, 0.0)), // non empty - 
Arc::new(create_primitive_array::(1024, 0.0)), //less len, not null - ], - ) - .unwrap(); - - let args = [ - ColumnarValue::Array(Arc::new(union)), - ColumnarValue::Scalar(ScalarValue::new_utf8("int")), - ]; - - b.iter(|| { - union_extract.invoke(&args).unwrap(); - }) - }); - - c.bench_function( - "union_extract case 5.1b dense none match cant contain null mask", - |b| { - let union_target = UnionArray::try_new( - UnionFields::new([1], vec![Field::new("a", DataType::Boolean, true)]), - vec![1; 2048].into(), - None, - vec![Arc::new(create_boolean_array(2048, 0.0, 0.0))], - ) - .unwrap(); - - let parent_union = UnionArray::try_new( - // multiple fields - UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), - Field::new("union", union_target.data_type().clone(), false), - ], - ), - ScalarBuffer::from(vec![1; 2048]), //none match - Some(ScalarBuffer::from_iter(0..2048)), //dense - vec![ - Arc::new(create_string_array::(2048, 0.0)), // non empty - Arc::new(union_target), - ], - ) - .unwrap(); - - let args = [ - ColumnarValue::Array(Arc::new(parent_union)), - ColumnarValue::Scalar(ScalarValue::new_utf8("union")), - ]; - - b.iter(|| { - union_extract.invoke(&args).unwrap(); - }) - }, - ); - - c.bench_function("union_extract case 5.2 dense none match equal len", |b| { - let union = UnionArray::try_new( - // multiple fields - UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), - Field::new("int", DataType::Int32, false), - ], - ), - ScalarBuffer::from(vec![3; 2048]), //none match - Some(ScalarBuffer::from_iter(0..2048)), //dense - vec![ - Arc::new(create_string_array::(2048, 0.0)), // non empty - Arc::new(create_primitive_array::(2048, 0.0)), //equal len, not null - ], - ) - .unwrap(); - - let args = [ - ColumnarValue::Array(Arc::new(union)), - ColumnarValue::Scalar(ScalarValue::new_utf8("int")), - ]; - - b.iter(|| { - union_extract.invoke(&args).unwrap(); - }) - }); - - c.bench_function("union_extract case 5.3 
dense none match greater len", |b| { - let union = UnionArray::try_new( - // multiple fields - UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), - Field::new("int", DataType::Int32, false), - ], - ), - ScalarBuffer::from(vec![3; 2048]), //none match - Some(ScalarBuffer::from_iter(0..2048)), //dense - vec![ - Arc::new(create_string_array::(2048, 0.0)), // non empty - Arc::new(create_primitive_array::(2049, 0.0)), //greater len, not null - ], - ) - .unwrap(); - - let args = [ - ColumnarValue::Array(Arc::new(union)), - ColumnarValue::Scalar(ScalarValue::new_utf8("int")), - ]; - - b.iter(|| { - union_extract.invoke(&args).unwrap(); - }) - }); - - c.bench_function("union_extract case 6 some match", |b| { - let union = UnionArray::try_new( - // multiple fields - UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), - Field::new("int", DataType::Int32, false), - ], - ), - ScalarBuffer::from_iter(repeat_n([1, 3], 1024).flatten()), //some matches but not all - Some(ScalarBuffer::from_iter( - std::iter::zip(1024..2048, 0..1024).flat_map(|(a, b)| [a, b]), - )), - vec![ - Arc::new(create_string_array::(2048, 0.0)), // sibling is not empty - Arc::new(create_primitive_array::(1024, 0.0)), //not null - ], - ) - .unwrap(); - - let args = [ - ColumnarValue::Array(Arc::new(union.clone())), - ColumnarValue::Scalar(ScalarValue::new_utf8("int")), - ]; - - b.iter(|| { - union_extract.invoke(&args).unwrap(); - }) - }); - - { - let mut is_sequential_group = c.benchmark_group("offsets"); - - let start = random::() as i32; - let offsets = (start..start + 4096).collect::>(); - - //compare performance to simpler alternatives - - is_sequential_group.bench_function("offsets sequential windows all", |b| { - b.iter(|| { - black_box(offsets.windows(2).all(|window| window[0] + 1 == window[1])); - }) - }); - - is_sequential_group.bench_function("offsets sequential windows fold &&", |b| { - b.iter(|| { - black_box( - 
#[allow(clippy::unnecessary_fold)] - offsets - .windows(2) - .fold(true, |b, w| b && (w[0] + 1 == w[1])), - ) - }) - }); - - is_sequential_group.bench_function("offsets sequential windows fold &", |b| { - b.iter(|| { - black_box(offsets.windows(2).fold(true, |b, w| b & (w[0] + 1 == w[1]))) - }) - }); - - is_sequential_group.bench_function("offsets sequential all", |b| { - b.iter(|| { - black_box( - offsets - .iter() - .copied() - .enumerate() - .all(|(i, v)| v == offsets[0] + i as i32), - ) - }) - }); - - is_sequential_group.bench_function("offsets sequential fold &&", |b| { - b.iter(|| { - black_box( - #[allow(clippy::unnecessary_fold)] - offsets - .iter() - .copied() - .enumerate() - .fold(true, |b, (i, v)| b && (v == offsets[0] + i as i32)), - ) - }) - }); - - is_sequential_group.bench_function("offsets sequential fold &", |b| { - b.iter(|| { - black_box( - offsets - .iter() - .copied() - .enumerate() - .fold(true, |b, (i, v)| b & (v == offsets[0] + i as i32)), - ) - }) - }); - - macro_rules! 
bench_sequential { - ($n:literal) => { - is_sequential_group - .bench_function(&format!("offsets sequential chunk {}", $n), |b| { - b.iter(|| black_box(is_sequential_generic::<$n>(&offsets))) - }); - }; - } - - bench_sequential!(8); - bench_sequential!(16); - bench_sequential!(32); - bench_sequential!(64); - bench_sequential!(128); - bench_sequential!(256); - bench_sequential!(512); - bench_sequential!(1024); - bench_sequential!(2048); - bench_sequential!(4096); - - is_sequential_group.finish(); - } - - { - let mut type_ids_eq = c.benchmark_group("type_ids_eq"); - - let type_id = random::(); - let type_ids = vec![type_id; 65536]; - - //compare performance to simpler alternatives - - type_ids_eq.bench_function("type_ids equal all", |b| { - b.iter(|| { - type_ids - .iter() - .copied() - .all(|value_type_id| value_type_id == type_id) - }) - }); - - type_ids_eq.bench_function("type_ids equal fold &&", |b| { - #[allow(clippy::unnecessary_fold)] - b.iter(|| type_ids.iter().fold(true, |b, v| b && (*v == type_id))) - }); - - type_ids_eq.bench_function("type_ids equal fold &", |b| { - b.iter(|| type_ids.iter().fold(true, |b, v| b & (*v == type_id))) - }); - - type_ids_eq.bench_function("type_ids equal compute::eq", |b| { - let type_ids_array = Int8Array::new(type_ids.clone().into(), None); - - b.iter(|| { - let eq = arrow::compute::kernels::cmp::eq( - &type_ids_array, - &Int8Array::new_scalar(black_box(type_id)), - ) - .unwrap(); - - eq.true_count() == type_ids.len() - }) - }); - - macro_rules! 
bench_type_ids_eq { - ($n:literal) => { - type_ids_eq.bench_function(&format!("type_ids equal true {}", $n), |b| { - b.iter(|| eq_scalar_generic::<$n>(&type_ids, type_ids[0])) - }); - - type_ids_eq - .bench_function(&format!("type_ids equal false {}", $n), |b| { - b.iter(|| eq_scalar_generic::<$n>(&type_ids, type_ids[0] + 1)) - }); - - type_ids_eq.bench_function(&format!("type_ids worst case {}", $n), |b| { - let mut type_ids = type_ids.clone(); - - type_ids[65535] = !type_ids[65535]; - - b.iter(|| eq_scalar_generic::<$n>(&type_ids, type_ids[0])) - }); - - type_ids_eq.bench_function(&format!("type_ids best case {}", $n), |b| { - let mut type_ids = type_ids.clone(); - - type_ids[$n - 1] += 1; - - b.iter(|| eq_scalar_generic::<$n>(&type_ids, type_ids[0])) - }); - }; - } - - bench_type_ids_eq!(16); - bench_type_ids_eq!(32); - bench_type_ids_eq!(64); - bench_type_ids_eq!(128); - bench_type_ids_eq!(256); - bench_type_ids_eq!(512); - bench_type_ids_eq!(1024); - bench_type_ids_eq!(2048); - bench_type_ids_eq!(4096); - } -} - -criterion_group!(benches, criterion_benchmark); -criterion_main!(benches); diff --git a/datafusion/functions/src/core/union_extract.rs b/datafusion/functions/src/core/union_extract.rs index 433348b0e6fd..acb364d4ef05 100644 --- a/datafusion/functions/src/core/union_extract.rs +++ b/datafusion/functions/src/core/union_extract.rs @@ -15,18 +15,8 @@ // specific language governing permissions and limitations // under the License. 
-use std::cmp::Ordering; -use std::sync::Arc; - -use arrow::array::{ - layout, make_array, new_empty_array, new_null_array, Array, ArrayRef, BooleanArray, - Int32Array, Scalar, UnionArray, -}; -use arrow::compute::take; -use arrow::datatypes::{DataType, FieldRef, UnionFields, UnionMode}; - -use arrow::buffer::{BooleanBuffer, MutableBuffer, NullBuffer, ScalarBuffer}; -use arrow::util::bit_util; +use arrow::array::Array; +use arrow::datatypes::{DataType, FieldRef, UnionFields}; use datafusion_common::cast::as_union_array; use datafusion_common::{ exec_datafusion_err, exec_err, internal_err, ExprSchema, Result, ScalarValue, @@ -48,7 +38,7 @@ impl Default for UnionExtractFun { impl UnionExtractFun { pub fn new() -> Self { Self { - signature: Signature::any(2, Volatility::Immutable), + signature: Signature::user_defined(Volatility::Immutable), } } } @@ -84,20 +74,14 @@ impl ScalarUDFImpl for UnionExtractFun { ); } - let fields = if let DataType::Union(fields, _) = &arg_types[0] { - fields - } else { + let DataType::Union(fields, _) = &arg_types[0] else { return exec_err!( "union_extract first argument must be a union, got {} instead", arg_types[0] ); }; - let field_name = if let Expr::Literal(ScalarValue::Utf8(Some(field_name))) = - &args[1] - { - field_name - } else { + let Expr::Literal(ScalarValue::Utf8(Some(field_name))) = &args[1] else { return exec_err!( "union_extract second argument must be a non-null string literal, got {} instead", arg_types[1] @@ -127,28 +111,20 @@ impl ScalarUDFImpl for UnionExtractFun { match union { ColumnarValue::Array(array) => { - let union_array = as_union_array(&array).map_err(|_| { + let _union_array = as_union_array(&array).map_err(|_| { exec_datafusion_err!( "union_extract first argument must be a union, got {} instead", array.data_type() ) })?; - let (fields, mode) = match union_array.data_type() { - DataType::Union(fields, mode) => (fields, mode), - _ => unreachable!(), - }; - - let target_type_id = find_field(fields, 
target_name?)?.0; - - match mode { - UnionMode::Sparse => { - Ok(extract_sparse(union_array, fields, target_type_id)?) - } - UnionMode::Dense => { - Ok(extract_dense(union_array, fields, target_type_id)?) - } - } + // Ok(arrow::compute::kernels::union_extract::union_extract( + // &union_array, + // target_name, + // )?) + Ok(ColumnarValue::Array(std::sync::Arc::new( + arrow::array::Int32Array::from(vec![1, 2]), + ))) } ColumnarValue::Scalar(ScalarValue::Union(value, fields, _)) => { let target_name = target_name?; @@ -169,340 +145,43 @@ impl ScalarUDFImpl for UnionExtractFun { ), } } -} - -fn find_field<'a>(fields: &'a UnionFields, name: &str) -> Result<(i8, &'a FieldRef)> { - fields - .iter() - .find(|field| field.1.name() == name) - .ok_or_else(|| exec_datafusion_err!("field {name} not found on union")) -} -fn extract_sparse( - union_array: &UnionArray, - fields: &UnionFields, - target_type_id: i8, -) -> Result { - let target = union_array.child(target_type_id); - - if fields.len() == 1 // case 1.1: if there is a single field, all type ids are the same, and since union doesn't have a null mask, the result array is exactly the same as it only child - || union_array.is_empty() // case 1.2: sparse union length and childrens length must match, if the union is empty, so is any children - || target.null_count() == target.len() || target.data_type().is_null() - // case 1.3: if all values of the target children are null, regardless of selected type ids, the result will also be completely null - { - Ok(ColumnarValue::Array(Arc::clone(target))) - } else { - match eq_scalar(union_array.type_ids(), target_type_id) { - // case 2: all type ids equals our target, and since unions doesn't have a null mask, the result array is exactly the same as our target - BoolValue::Scalar(true) => Ok(ColumnarValue::Array(Arc::clone(target))), - // case 3: none type_id matches our target, the result is a null array - BoolValue::Scalar(false) => { - if 
layout(target.data_type()).can_contain_null_mask { - // case 3.1: target array can contain a null mask - //SAFETY: The only change to the array data is the addition of a null mask, and if the target data type can contain a null mask was just checked above - let data = unsafe { - target - .into_data() - .into_builder() - .nulls(Some(NullBuffer::new_null(target.len()))) - .build_unchecked() - }; - - Ok(ColumnarValue::Array(make_array(data))) - } else { - // case 3.2: target can't contain a null mask - Ok(new_null_columnar_value(target.data_type(), target.len())) - } - } - // case 4: some but not all type_id matches our target - BoolValue::Buffer(selected) => { - if layout(target.data_type()).can_contain_null_mask { - // case 4.1: target array can contain a null mask - let nulls = match target.nulls().filter(|n| n.null_count() > 0) { - // case 4.1.1: our target child has nulls and types other than our target are selected, union the masks - // the case where n.null_count() == n.len() is cheaply handled at case 1.3 - Some(nulls) => &selected & nulls.inner(), - // case 4.1.2: target child has no nulls, but types other than our target are selected, use the selected mask as a null mask - None => selected, - }; - - //SAFETY: The only change to the array data is the addition of a null mask, and if the target data type can contain a null mask was just checked above - let data = unsafe { - assert_eq!(nulls.len(), target.len()); - - target - .into_data() - .into_builder() - .nulls(Some(nulls.into())) - .build_unchecked() - }; - - Ok(ColumnarValue::Array(make_array(data))) - } else { - // case 4.2: target can't containt a null mask, zip the values that match with a null value - Ok(ColumnarValue::Array(arrow::compute::kernels::zip::zip( - &BooleanArray::new(selected, None), - target, - &Scalar::new(new_null_array(target.data_type(), 1)), - )?)) - } - } + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + if arg_types.len() != 2 { + return exec_err!( + "union_extract 
expects 2 arguments, got {} instead", + arg_types.len() + ); } - } -} -fn extract_dense( - union_array: &UnionArray, - fields: &UnionFields, - target_type_id: i8, -) -> Result { - let target = union_array.child(target_type_id); - let offsets = union_array.offsets().unwrap(); - - if union_array.is_empty() { - // case 1: the union is empty - if target.is_empty() { - // case 1.1: the target is also empty, do a cheap Arc::clone instead of allocating a new empty array - Ok(ColumnarValue::Array(Arc::clone(target))) - } else { - // case 1.2: the target is not empty, allocate a new empty array - Ok(ColumnarValue::Array(new_empty_array(target.data_type()))) - } - } else if target.is_empty() { - // case 2: the union is not empty but the target is, which implies that none type_id points to it. The result is a null array - Ok(new_null_columnar_value( - target.data_type(), - union_array.len(), - )) - } else if target.null_count() == target.len() || target.data_type().is_null() { - // case 3: since all values on our target are null, regardless of selected type ids and offsets, the result is a null array - match target.len().cmp(&union_array.len()) { - // case 3.1: since the target is smaller than the union, allocate a new correclty sized null array - Ordering::Less => Ok(new_null_columnar_value( - target.data_type(), - union_array.len(), - )), - // case 3.2: target equals the union len, return it direcly - Ordering::Equal => Ok(ColumnarValue::Array(Arc::clone(target))), - // case 3.3: target len is bigger than the union len, slice it - Ordering::Greater => { - Ok(ColumnarValue::Array(target.slice(0, union_array.len()))) - } - } - } else if fields.len() == 1 // case A: since there's a single field, our target, every type id must matches our target - || fields - .iter() - .filter(|(field_type_id, _)| *field_type_id != target_type_id) - .all(|(sibling_type_id, _)| union_array.child(sibling_type_id).is_empty()) - // case B: since siblings are empty, every type id must matches our 
target - { - // case 4: every type id matches our target - Ok(ColumnarValue::Array(extract_dense_all_selected( - union_array, - target, - offsets, - )?)) - } else { - match eq_scalar(union_array.type_ids(), target_type_id) { - // case 4C: all type ids matches our target. - // Non empty sibling without any selected value may happen after slicing the parent union, - // since only type_ids and offsets are sliced, not the children - BoolValue::Scalar(true) => Ok(ColumnarValue::Array( - extract_dense_all_selected(union_array, target, offsets)?, - )), - BoolValue::Scalar(false) => { - // case 5: none type_id matches our target, so the result array will be completely null - // Non empty target without any selected value may happen after slicing the parent union, - // since only type_ids and offsets are sliced, not the children - match (target.len().cmp(&union_array.len()), layout(target.data_type()).can_contain_null_mask) { - (Ordering::Less, _) // case 5.1A: our target is smaller than the parent union, allocate a new correclty sized null array - | (_, false) => { // case 5.1B: target array can't contain a null mask - Ok(new_null_columnar_value(target.data_type(), union_array.len())) - } - // case 5.2: target and parent union lengths are equal, and the target can contain a null mask, let's set it to a all-null null-buffer - (Ordering::Equal, true) => { - //SAFETY: The only change to the array data is the addition of a null mask, and if the target data type can contain a null mask was just checked above - let data = unsafe { - target - .into_data() - .into_builder() - .nulls(Some(NullBuffer::new_null(union_array.len()))) - .build_unchecked() - }; - - Ok(ColumnarValue::Array(make_array(data))) - } - // case 5.3: target is bigger than it's parent union and can contain a null mask, let's slice it, and set it's nulls to a all-null null-buffer - (Ordering::Greater, true) => { - //SAFETY: The only change to the array data is the addition of a null mask, and if the target data 
type can contain a null mask was just checked above - let data = unsafe { - target - .into_data() - .slice(0, union_array.len()) - .into_builder() - .nulls(Some(NullBuffer::new_null(union_array.len()))) - .build_unchecked() - }; - - Ok(ColumnarValue::Array(make_array(data))) - } - } - } - BoolValue::Buffer(selected) => { - //case 6: some type_ids matches our target, but not all. For selected values, take the value pointed by the offset. For unselected, take a valid null - Ok(ColumnarValue::Array(take( - target, - &Int32Array::new(offsets.clone(), Some(selected.into())), - None, - )?)) - } + if !matches!(arg_types[0], DataType::Union(_, _)) { + return exec_err!( + "union_extract first argument must be a union, got {} instead", + arg_types[0] + ); } - } -} - -fn extract_dense_all_selected( - union_array: &UnionArray, - target: &Arc, - offsets: &ScalarBuffer, -) -> Result { - let sequential = - target.len() - offsets[0] as usize >= union_array.len() && is_sequential(offsets); - - if sequential && target.len() == union_array.len() { - // case 1: all offsets are sequential and both lengths match, return the array directly - Ok(Arc::clone(target)) - } else if sequential && target.len() > union_array.len() { - // case 2: All offsets are sequential, but our target is bigger than our union, slice it, starting at the first offset - Ok(target.slice(offsets[0] as usize, union_array.len())) - } else { - // case 3: Since offsets are not sequential, take them from the child to a new sequential and correcly sized array - let indices = Int32Array::try_new(offsets.clone(), None)?; - - Ok(take(target, &indices, None)?) 
- } -} -const EQ_SCALAR_CHUNK_SIZE: usize = 512; - -#[doc(hidden)] -#[derive(Debug, PartialEq)] -pub enum BoolValue { - Scalar(bool), - Buffer(BooleanBuffer), -} - -fn eq_scalar(type_ids: &[i8], target: i8) -> BoolValue { - eq_scalar_generic::(type_ids, target) -} - -// This is like MutableBuffer::collect_bool(type_ids.len(), |i| type_ids[i] == target) with fast paths for all true or all false values. -#[doc(hidden)] -pub fn eq_scalar_generic(type_ids: &[i8], target: i8) -> BoolValue { - fn count_sequence( - type_ids: &[i8], - mut f: impl FnMut(i8) -> bool, - ) -> usize { - type_ids - .chunks(N) - .take_while(|chunk| chunk.iter().copied().fold(true, |b, v| b & f(v))) - .map(|chunk| chunk.len()) - .sum() - } - - let true_bits = count_sequence::(type_ids, |v| v == target); - - let (set_bits, val) = if true_bits == type_ids.len() { - return BoolValue::Scalar(true); - } else if true_bits == 0 { - let false_bits = count_sequence::(type_ids, |v| v != target); - - if false_bits == type_ids.len() { - return BoolValue::Scalar(false); - } else { - (false_bits, false) + if !matches!(arg_types[1], DataType::Utf8) { + return exec_err!( + "union_extract second argument must be a non-null string literal, got {} instead", + arg_types[1] + ); } - } else { - (true_bits, true) - }; - - // restrict to chunk boundaries - let set_bits = set_bits - set_bits % 64; - let mut buffer = MutableBuffer::new(bit_util::ceil(type_ids.len(), 64) * 8) - .with_bitset(set_bits / 8, val); - - buffer.extend(type_ids[set_bits..].chunks(64).map(|chunk| { - chunk - .iter() - .copied() - .enumerate() - .fold(0, |packed, (bit_idx, v)| { - packed | ((v == target) as u64) << bit_idx - }) - })); - - buffer.truncate(bit_util::ceil(type_ids.len(), 8)); - - BoolValue::Buffer(BooleanBuffer::new(buffer.into(), 0, type_ids.len())) -} - -const IS_SEQUENTIAL_CHUNK_SIZE: usize = 64; - -fn is_sequential(offsets: &[i32]) -> bool { - is_sequential_generic::(offsets) -} - -#[doc(hidden)] -pub fn 
is_sequential_generic(offsets: &[i32]) -> bool { - if offsets.is_empty() { - return true; - } - - // the most common form of non sequential offsets is when sequential nulls reuses the same value, - // pointed by the same offset, while valid values offsets increases one by one - // this also checks if the last chunk/remainder is sequential - if offsets[0] + offsets.len() as i32 - 1 != offsets[offsets.len() - 1] { - return false; + Ok(arg_types.to_vec()) } - - let chunks = offsets.chunks_exact(N); - - let remainder = chunks.remainder(); - - chunks.enumerate().all(|(i, chunk)| { - let chunk_array = <&[i32; N]>::try_from(chunk).unwrap(); - - //checks if values within chunk are sequential - chunk_array - .iter() - .copied() - .enumerate() - .fold(true, |b, (i, o)| b & (o == chunk_array[0] + i as i32)) - && offsets[0] + (i * N) as i32 == chunk_array[0] //checks if chunk is sequential relative to the first offset - }) && remainder - .iter() - .copied() - .enumerate() - .fold(true, |b, (i, o)| b & (o == remainder[0] + i as i32)) //if the remainder is sequential relative to the first offset is checked at the start of the function } -fn new_null_columnar_value(data_type: &DataType, len: usize) -> ColumnarValue { - match ScalarValue::try_from(data_type) { - Ok(null_scalar) => ColumnarValue::Scalar(null_scalar), - Err(_) => ColumnarValue::Array(new_null_array(data_type, len)), - } +fn find_field<'a>(fields: &'a UnionFields, name: &str) -> Result<(i8, &'a FieldRef)> { + fields + .iter() + .find(|field| field.1.name() == name) + .ok_or_else(|| exec_datafusion_err!("field {name} not found on union")) } #[cfg(test)] mod tests { - use crate::core::union_extract::{ - eq_scalar_generic, is_sequential_generic, new_null_columnar_value, BoolValue, - }; - use std::sync::Arc; - - use arrow::array::{new_null_array, Array, Int8Array}; - use arrow::buffer::BooleanBuffer; use arrow::datatypes::{DataType, Field, UnionFields, UnionMode}; use datafusion_common::{Result, ScalarValue}; use 
datafusion_expr::{ColumnarValue, ScalarUDFImpl}; @@ -522,13 +201,6 @@ mod tests { ], ); - fn assert_scalar(value: ColumnarValue, expected: ScalarValue) { - match value { - ColumnarValue::Array(array) => panic!("expected scalar got {array:?}"), - ColumnarValue::Scalar(scalar) => assert_eq!(scalar, expected), - } - } - let result = fun.invoke(&[ ColumnarValue::Scalar(ScalarValue::Union( None, @@ -565,158 +237,10 @@ mod tests { Ok(()) } - #[test] - fn test_eq_scalar() { - //multiple all equal chunks, so it's loop and sum logic it's tested - //multiple chunks after, so it's loop logic it's tested - const ARRAY_LEN: usize = 64 * 4; - - //so out of 64 boundaries chunks can be generated and checked for - const EQ_SCALAR_CHUNK_SIZE: usize = 3; - - fn eq_scalar(type_ids: &[i8], target: i8) -> BoolValue { - eq_scalar_generic::(type_ids, target) - } - - fn eq(left: &[i8], right: i8) -> BooleanBuffer { - arrow::compute::kernels::cmp::eq( - &Int8Array::from(left.to_vec()), - &Int8Array::new_scalar(right), - ) - .unwrap() - .into_parts() - .0 - } - - assert_eq!(eq_scalar(&[], 1), BoolValue::Scalar(true)); - - assert_eq!(eq_scalar(&[1], 1), BoolValue::Scalar(true)); - assert_eq!(eq_scalar(&[2], 1), BoolValue::Scalar(false)); - - let mut values = [1; ARRAY_LEN]; - - assert_eq!(eq_scalar(&values, 1), BoolValue::Scalar(true)); - assert_eq!(eq_scalar(&values, 2), BoolValue::Scalar(false)); - - //every subslice should return the same value - for i in 1..ARRAY_LEN { - assert_eq!(eq_scalar(&values[..i], 1), BoolValue::Scalar(true)); - assert_eq!(eq_scalar(&values[..i], 2), BoolValue::Scalar(false)); - } - - // test that a single change anywhere is checked for - for i in 0..ARRAY_LEN { - values[i] = 2; - - assert_eq!(eq_scalar(&values, 1), BoolValue::Buffer(eq(&values, 1))); - assert_eq!(eq_scalar(&values, 2), BoolValue::Buffer(eq(&values, 2))); - - values[i] = 1; - } - } - - #[test] - fn test_is_sequential() { - /* - the smallest value that satisfies: - >1 so the fold logic of a exact 
chunk executes - >2 so a >1 non-exact remainder can exist, and it's fold logic executes - */ - const CHUNK_SIZE: usize = 3; - //we test arrays of size up to 8 = 2 * CHUNK_SIZE + 2: - //multiple(2) exact chunks, so the AND logic between them executes - //a >1(2) remainder, so: - // the AND logic between all exact chunks and the remainder executes - // the remainder fold logic executes - - fn is_sequential(v: &[i32]) -> bool { - is_sequential_generic::(v) - } - - assert!(is_sequential(&[])); //empty - assert!(is_sequential(&[1])); //single - - assert!(is_sequential(&[1, 2])); - assert!(is_sequential(&[1, 2, 3])); - assert!(is_sequential(&[1, 2, 3, 4])); - assert!(is_sequential(&[1, 2, 3, 4, 5])); - assert!(is_sequential(&[1, 2, 3, 4, 5, 6])); - assert!(is_sequential(&[1, 2, 3, 4, 5, 6, 7])); - assert!(is_sequential(&[1, 2, 3, 4, 5, 6, 7, 8])); - - assert!(!is_sequential(&[8, 7])); - assert!(!is_sequential(&[8, 7, 6])); - assert!(!is_sequential(&[8, 7, 6, 5])); - assert!(!is_sequential(&[8, 7, 6, 5, 4])); - assert!(!is_sequential(&[8, 7, 6, 5, 4, 3])); - assert!(!is_sequential(&[8, 7, 6, 5, 4, 3, 2])); - assert!(!is_sequential(&[8, 7, 6, 5, 4, 3, 2, 1])); - - assert!(!is_sequential(&[0, 2])); - assert!(!is_sequential(&[1, 0])); - - assert!(!is_sequential(&[0, 2, 3])); - assert!(!is_sequential(&[1, 0, 3])); - assert!(!is_sequential(&[1, 2, 0])); - - assert!(!is_sequential(&[0, 2, 3, 4])); - assert!(!is_sequential(&[1, 0, 3, 4])); - assert!(!is_sequential(&[1, 2, 0, 4])); - assert!(!is_sequential(&[1, 2, 3, 0])); - - assert!(!is_sequential(&[0, 2, 3, 4, 5])); - assert!(!is_sequential(&[1, 0, 3, 4, 5])); - assert!(!is_sequential(&[1, 2, 0, 4, 5])); - assert!(!is_sequential(&[1, 2, 3, 0, 5])); - assert!(!is_sequential(&[1, 2, 3, 4, 0])); - - assert!(!is_sequential(&[0, 2, 3, 4, 5, 6])); - assert!(!is_sequential(&[1, 0, 3, 4, 5, 6])); - assert!(!is_sequential(&[1, 2, 0, 4, 5, 6])); - assert!(!is_sequential(&[1, 2, 3, 0, 5, 6])); - assert!(!is_sequential(&[1, 2, 3, 4, 0, 
6])); - assert!(!is_sequential(&[1, 2, 3, 4, 5, 0])); - - assert!(!is_sequential(&[0, 2, 3, 4, 5, 6, 7])); - assert!(!is_sequential(&[1, 0, 3, 4, 5, 6, 7])); - assert!(!is_sequential(&[1, 2, 0, 4, 5, 6, 7])); - assert!(!is_sequential(&[1, 2, 3, 0, 5, 6, 7])); - assert!(!is_sequential(&[1, 2, 3, 4, 0, 6, 7])); - assert!(!is_sequential(&[1, 2, 3, 4, 5, 0, 7])); - assert!(!is_sequential(&[1, 2, 3, 4, 5, 6, 0])); - - assert!(!is_sequential(&[0, 2, 3, 4, 5, 6, 7, 8])); - assert!(!is_sequential(&[1, 0, 3, 4, 5, 6, 7, 8])); - assert!(!is_sequential(&[1, 2, 0, 4, 5, 6, 7, 8])); - assert!(!is_sequential(&[1, 2, 3, 0, 5, 6, 7, 8])); - assert!(!is_sequential(&[1, 2, 3, 4, 0, 6, 7, 8])); - assert!(!is_sequential(&[1, 2, 3, 4, 5, 0, 7, 8])); - assert!(!is_sequential(&[1, 2, 3, 4, 5, 6, 0, 8])); - assert!(!is_sequential(&[1, 2, 3, 4, 5, 6, 7, 0])); - } - - #[test] - fn test_new_null_columnar_value() { - match new_null_columnar_value(&DataType::Int8, 2) { - ColumnarValue::Array(_) => { - panic!("new_null_columnar_value should've returned a scalar for Int8") - } - ColumnarValue::Scalar(scalar) => assert_eq!(scalar, ScalarValue::Int8(None)), - } - - let run_data_type = DataType::RunEndEncoded( - Arc::new(Field::new("run_ends", DataType::Int16, false)), - Arc::new(Field::new("values", DataType::Utf8, false)), - ); - - match new_null_columnar_value(&run_data_type, 2) { - ColumnarValue::Array(array) => assert_eq!( - array.into_data(), - new_null_array(&run_data_type, 2).into_data() - ), - ColumnarValue::Scalar(_) => panic!( - "new_null_columnar_value should've returned a array for RunEndEncoded" - ), + fn assert_scalar(value: ColumnarValue, expected: ScalarValue) { + match value { + ColumnarValue::Array(array) => panic!("expected scalar got {array:?}"), + ColumnarValue::Scalar(scalar) => assert_eq!(scalar, expected), } } } diff --git a/datafusion/sqllogictest/src/test_context.rs b/datafusion/sqllogictest/src/test_context.rs index 0e4b3b782f58..b645a3e02575 100644 --- 
a/datafusion/sqllogictest/src/test_context.rs +++ b/datafusion/sqllogictest/src/test_context.rs @@ -22,14 +22,11 @@ use std::path::Path; use std::sync::Arc; use arrow::array::{ - new_null_array, Array, ArrayRef, BinaryArray, BooleanArray, Float64Array, Int32Array, - LargeBinaryArray, LargeStringArray, NullArray, StringArray, TimestampNanosecondArray, - UnionArray, + Array, ArrayRef, BinaryArray, Float64Array, Int32Array, LargeBinaryArray, + LargeStringArray, NullArray, StringArray, TimestampNanosecondArray, UnionArray, }; use arrow::buffer::ScalarBuffer; -use arrow::datatypes::{ - DataType, Field, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode, -}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit, UnionFields}; use arrow::record_batch::RecordBatch; use datafusion::logical_expr::{create_udf, ColumnarValue, Expr, ScalarUDF, Volatility}; use datafusion::physical_plan::ExecutionPlan; @@ -112,7 +109,7 @@ impl TestContext { info!("Registering metadata table tables"); register_metadata_tables(test_ctx.session_ctx()).await; } - "union_datatype.slt" => { + "union_function.slt" => { info!("Registering tables with union column"); register_union_tables(test_ctx.session_ctx()) } @@ -371,36 +368,8 @@ fn create_example_udf() -> ScalarUDF { } fn register_union_tables(ctx: &SessionContext) { - sparse_1_1_single_field(ctx); - sparse_1_2_empty(ctx); - sparse_1_3a_null_target(ctx); - sparse_1_3b_null_target(ctx); - sparse_2_all_types_match(ctx); - sparse_3_1_none_match_target_can_contain_null_mask(ctx); - sparse_3_2_none_match_cant_contain_null_mask_union_target(ctx); - sparse_4_1_1_target_with_nulls(ctx); - sparse_4_1_2_target_without_nulls(ctx); - sparse_4_2_some_match_target_cant_contain_null_mask(ctx); - dense_1_1_both_empty(ctx); - dense_1_2_empty_union_target_non_empty(ctx); - dense_2_non_empty_union_target_empty(ctx); - dense_3_1_null_target_smaller_len(ctx); - dense_3_2_null_target_equal_len(ctx); - dense_3_3_null_target_bigger_len(ctx); - 
dense_4_1a_single_type_sequential_offsets_equal_len(ctx); - dense_4_2a_single_type_sequential_offsets_bigger(ctx); - dense_4_3a_single_type_non_sequential(ctx); - dense_4_1b_empty_siblings_sequential_equal_len(ctx); - dense_4_2b_empty_siblings_sequential_bigger_len(ctx); - dense_4_3b_empty_sibling_non_sequential(ctx); - dense_4_1c_all_types_match_sequential_equal_len(ctx); - dense_4_2c_all_types_match_sequential_bigger_len(ctx); - dense_4_3c_all_types_match_non_sequential(ctx); - dense_5_1a_none_match_less_len(ctx); - dense_5_1b_cant_contain_null_mask(ctx); - dense_5_2_none_match_equal_len(ctx); - dense_5_3_none_match_greater_len(ctx); - dense_6_some_matches(ctx); + sparse_union(ctx); + dense_union(ctx); empty_sparse_union(ctx); empty_dense_union(ctx); } @@ -425,7 +394,7 @@ fn register_union_table( ctx.register_batch(table_name, batch).unwrap(); } -fn sparse_1_1_single_field(ctx: &SessionContext) { +fn sparse_union(ctx: &SessionContext) { let union = UnionArray::try_new( //single field UnionFields::new(vec![3], vec![Field::new("int", DataType::Int32, false)]), @@ -437,803 +406,20 @@ fn sparse_1_1_single_field(ctx: &SessionContext) { ) .unwrap(); - register_union_table( - ctx, - union, - "sparse_1_1_single_field", - Int32Array::from(vec![1, 2]), - ); -} - -fn sparse_1_2_empty(ctx: &SessionContext) { - let union = UnionArray::try_new( - // multiple fields - UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), //target type is not Null - Field::new("int", DataType::Int32, false), - ], - ), - ScalarBuffer::from(vec![]), //empty union - None, // sparse - vec![ - Arc::new(StringArray::new_null(0)), - Arc::new(Int32Array::new_null(0)), - ], - ) - .unwrap(); - - register_union_table(ctx, union, "sparse_1_2_empty", StringArray::new_null(0)); -} - -fn sparse_1_3a_null_target(ctx: &SessionContext) { - let union = UnionArray::try_new( - // multiple fields - UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), - 
Field::new("null", DataType::Null, true), - ], - ), - ScalarBuffer::from(vec![1]), //not empty - None, // sparse - vec![ - Arc::new(StringArray::new_null(1)), - Arc::new(NullArray::new(1)), // null data type - ], - ) - .unwrap(); - - register_union_table(ctx, union, "sparse_1_3a_null_target", NullArray::new(1)); -} - -fn sparse_1_3b_null_target(ctx: &SessionContext) { - let union = UnionArray::try_new( - // multiple fields - UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), //target type is not Null - Field::new("int", DataType::Int32, false), - ], - ), - ScalarBuffer::from(vec![1]), //not empty - None, // sparse - vec![ - Arc::new(StringArray::new_null(1)), //all null - Arc::new(Int32Array::new_null(1)), - ], - ) - .unwrap(); - - register_union_table( - ctx, - union, - "sparse_1_3b_null_target", - StringArray::new_null(1), - ); -} - -fn sparse_2_all_types_match(ctx: &SessionContext) { - let union = UnionArray::try_new( - //multiple fields - UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), - Field::new("int", DataType::Int32, false), - ], - ), - ScalarBuffer::from(vec![3, 3]), // all types match - None, //sparse - vec![ - Arc::new(StringArray::new_null(2)), - Arc::new(Int32Array::from(vec![1, 4])), // not null - ], - ) - .unwrap(); - - register_union_table( - ctx, - union, - "sparse_2_all_types_match", - Int32Array::from(vec![1, 4]), - ); -} - -fn sparse_3_1_none_match_target_can_contain_null_mask(ctx: &SessionContext) { - let union = UnionArray::try_new( - //multiple fields - UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), - Field::new("int", DataType::Int32, false), - ], - ), - ScalarBuffer::from(vec![1, 1, 1, 1]), // none match - None, // sparse - vec![ - Arc::new(StringArray::new_null(4)), - Arc::new(Int32Array::from(vec![None, Some(4), None, Some(8)])), // target is not null - ], - ) - .unwrap(); - - register_union_table( - ctx, - union, - 
"sparse_3_1_none_match_target_can_contain_null_mask", - Int32Array::new_null(4), - ); -} - -fn sparse_3_2_none_match_cant_contain_null_mask_union_target(ctx: &SessionContext) { - let target_fields = - UnionFields::new([10], [Field::new("bool", DataType::Boolean, true)]); - - let target_data_type = DataType::Union(target_fields.clone(), UnionMode::Sparse); - - let union = UnionArray::try_new( - //multiple fields - UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), - Field::new("union", target_data_type.clone(), false), - ], - ), - ScalarBuffer::from(vec![1, 1]), // none match - None, //sparse - vec![ - Arc::new(StringArray::new_null(2)), - //target is not null - Arc::new( - UnionArray::try_new( - target_fields.clone(), - ScalarBuffer::from(vec![10, 10]), - None, - vec![Arc::new(BooleanArray::from(vec![true, false]))], - ) - .unwrap(), - ), - ], - ) - .unwrap(); - - register_union_table( - ctx, - union, - "sparse_3_2_none_match_cant_contain_null_mask_union_target", - new_null_array(&target_data_type, 2), - ); -} - -fn sparse_4_1_1_target_with_nulls(ctx: &SessionContext) { - let union = UnionArray::try_new( - //multiple fields - UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), - Field::new("int", DataType::Int32, false), - ], - ), - ScalarBuffer::from(vec![3, 3, 1, 1]), // multiple selected types - None, // sparse - vec![ - Arc::new(StringArray::new_null(4)), - Arc::new(Int32Array::from(vec![None, Some(4), None, Some(8)])), // target with nulls - ], - ) - .unwrap(); - - register_union_table( - ctx, - union, - "sparse_4_1_1_target_with_nulls", - Int32Array::from(vec![None, Some(4), None, None]), - ); -} - -fn sparse_4_1_2_target_without_nulls(ctx: &SessionContext) { - let union = UnionArray::try_new( - //multiple fields - UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), - Field::new("int", DataType::Int32, false), - ], - ), - ScalarBuffer::from(vec![1, 3, 3]), // 
multiple selected types - None, // sparse - vec![ - Arc::new(StringArray::new_null(3)), - Arc::new(Int32Array::from(vec![2, 4, 8])), // target without nulls - ], - ) - .unwrap(); - - register_union_table( - ctx, - union, - "sparse_4_1_2_target_without_nulls", - Int32Array::from(vec![None, Some(4), Some(8)]), - ); -} - -fn sparse_4_2_some_match_target_cant_contain_null_mask(ctx: &SessionContext) { - let target_fields = - UnionFields::new([10], [Field::new("bool", DataType::Boolean, true)]); - - let union = UnionArray::try_new( - //multiple fields - UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), - Field::new( - "union", - DataType::Union(target_fields.clone(), UnionMode::Sparse), - false, - ), - ], - ), - ScalarBuffer::from(vec![3, 1]), // some types match, but not all - None, //sparse - vec![ - Arc::new(StringArray::new_null(2)), - Arc::new( - UnionArray::try_new( - target_fields.clone(), - ScalarBuffer::from(vec![10, 10]), - None, - vec![Arc::new(BooleanArray::from(vec![true, false]))], - ) - .unwrap(), - ), - ], - ) - .unwrap(); - - register_union_table( - ctx, - union, - "sparse_4_2_some_match_target_cant_contain_null_mask", - UnionArray::try_new( - target_fields, - ScalarBuffer::from(vec![10, 10]), - None, - vec![Arc::new(BooleanArray::from(vec![Some(true), None]))], - ) - .unwrap(), - ); -} - -fn dense_1_1_both_empty(ctx: &SessionContext) { - let union = UnionArray::try_new( - UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), - Field::new("int", DataType::Int32, false), - ], - ), - ScalarBuffer::from(vec![]), //empty union - Some(ScalarBuffer::from(vec![])), // dense - vec![ - Arc::new(StringArray::new_null(0)), //empty target - Arc::new(Int32Array::new_null(0)), - ], - ) - .unwrap(); - - register_union_table(ctx, union, "dense_1_1_both_empty", StringArray::new_null(0)); -} - -fn dense_1_2_empty_union_target_non_empty(ctx: &SessionContext) { - let union = UnionArray::try_new( - 
UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), - Field::new("int", DataType::Int32, false), - ], - ), - ScalarBuffer::from(vec![]), //empty union - Some(ScalarBuffer::from(vec![])), // dense - vec![ - Arc::new(StringArray::new_null(1)), //non empty target - Arc::new(Int32Array::new_null(0)), - ], - ) - .unwrap(); - - register_union_table( - ctx, - union, - "dense_1_2_empty_union_target_non_empty", - StringArray::new_null(0), - ); -} - -fn dense_2_non_empty_union_target_empty(ctx: &SessionContext) { - let union = UnionArray::try_new( - UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), - Field::new("int", DataType::Int32, false), - ], - ), - ScalarBuffer::from(vec![3, 3]), //non empty union - Some(ScalarBuffer::from(vec![0, 1])), // dense - vec![ - Arc::new(StringArray::new_null(0)), //empty target - Arc::new(Int32Array::new_null(2)), - ], - ) - .unwrap(); - - register_union_table( - ctx, - union, - "dense_2_non_empty_union_target_empty", - StringArray::new_null(2), - ); + register_union_table(ctx, union, "sparse_union", Int32Array::from(vec![1, 2])); } -fn dense_3_1_null_target_smaller_len(ctx: &SessionContext) { +fn dense_union(ctx: &SessionContext) { let union = UnionArray::try_new( - UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), - Field::new("int", DataType::Int32, false), - ], - ), - ScalarBuffer::from(vec![3, 3]), //non empty union - Some(ScalarBuffer::from(vec![0, 0])), //dense - vec![ - Arc::new(StringArray::new_null(1)), //smaller target - Arc::new(Int32Array::new_null(2)), - ], - ) - .unwrap(); - - register_union_table( - ctx, - union, - "dense_3_1_null_target_smaller_len", - StringArray::new_null(2), - ); -} - -fn dense_3_2_null_target_equal_len(ctx: &SessionContext) { - let union = UnionArray::try_new( - UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), - Field::new("int", DataType::Int32, false), - ], - ), - 
ScalarBuffer::from(vec![3, 3]), //non empty union - Some(ScalarBuffer::from(vec![0, 0])), //dense - vec![ - Arc::new(StringArray::new_null(2)), //equal len - Arc::new(Int32Array::new_null(2)), - ], - ) - .unwrap(); - - register_union_table( - ctx, - union, - "dense_3_2_null_target_equal_len", - StringArray::new_null(2), - ); -} - -fn dense_3_3_null_target_bigger_len(ctx: &SessionContext) { - let union = UnionArray::try_new( - UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), - Field::new("int", DataType::Int32, false), - ], - ), - ScalarBuffer::from(vec![3, 3]), //non empty union - Some(ScalarBuffer::from(vec![0, 0])), //dense - vec![ - Arc::new(StringArray::new_null(3)), //bigger len - Arc::new(Int32Array::new_null(3)), - ], - ) - .unwrap(); - - register_union_table( - ctx, - union, - "dense_3_3_null_target_bigger_len", - StringArray::new_null(2), - ); -} - -fn dense_4_1a_single_type_sequential_offsets_equal_len(ctx: &SessionContext) { - let union = UnionArray::try_new( - // single field - UnionFields::new(vec![1], vec![Field::new("str", DataType::Utf8, false)]), - ScalarBuffer::from(vec![1, 1]), //non empty union - Some(ScalarBuffer::from(vec![0, 1])), //sequential - vec![ - Arc::new(StringArray::from_iter_values(["a1", "b2"])), //equal len, non null - ], - ) - .unwrap(); - - register_union_table( - ctx, - union, - "dense_4_1a_single_type_sequential_offsets_equal_len", - StringArray::from_iter_values(["a1", "b2"]), - ); -} - -fn dense_4_2a_single_type_sequential_offsets_bigger(ctx: &SessionContext) { - let union = UnionArray::try_new( - // single field - UnionFields::new(vec![1], vec![Field::new("str", DataType::Utf8, false)]), - ScalarBuffer::from(vec![1, 1]), //non empty union - Some(ScalarBuffer::from(vec![0, 1])), //sequential - vec![ - Arc::new(StringArray::from_iter_values(["a1", "b2", "c3"])), //equal len, non null - ], - ) - .unwrap(); - - register_union_table( - ctx, - union, - 
"dense_4_2a_single_type_sequential_offsets_bigger", - StringArray::from_iter_values(["a1", "b2"]), - ); -} - -fn dense_4_3a_single_type_non_sequential(ctx: &SessionContext) { - let union = UnionArray::try_new( - // single field - UnionFields::new(vec![1], vec![Field::new("str", DataType::Utf8, false)]), - ScalarBuffer::from(vec![1, 1]), //non empty union - Some(ScalarBuffer::from(vec![0, 2])), //non sequential - vec![ - Arc::new(StringArray::from_iter_values(["a1", "b2", "c3"])), //equal len, non null - ], - ) - .unwrap(); - - register_union_table( - ctx, - union, - "dense_4_3a_single_type_non_sequential", - StringArray::from_iter_values(["a1", "c3"]), - ); -} - -fn dense_4_1b_empty_siblings_sequential_equal_len(ctx: &SessionContext) { - let union = UnionArray::try_new( - // multiple fields - UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), - Field::new("int", DataType::Int32, false), - ], - ), - ScalarBuffer::from(vec![1, 1]), //non empty union - Some(ScalarBuffer::from(vec![0, 1])), //sequential - vec![ - Arc::new(StringArray::from(vec!["a", "b"])), //equal len, non null - Arc::new(Int32Array::new_null(0)), //empty sibling - ], - ) - .unwrap(); - - register_union_table( - ctx, - union, - "dense_4_1b_empty_siblings_sequential_equal_len", - StringArray::from(vec!["a", "b"]), - ); -} - -fn dense_4_2b_empty_siblings_sequential_bigger_len(ctx: &SessionContext) { - let union = UnionArray::try_new( - // multiple fields - UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), - Field::new("int", DataType::Int32, false), - ], - ), - ScalarBuffer::from(vec![1, 1]), //non empty union - Some(ScalarBuffer::from(vec![0, 1])), //sequential - vec![ - Arc::new(StringArray::from(vec!["a", "b", "c"])), //bigger len, non null - Arc::new(Int32Array::new_null(0)), //empty sibling - ], - ) - .unwrap(); - - register_union_table( - ctx, - union, - "dense_4_2b_empty_siblings_sequential_bigger_len", - 
StringArray::from(vec!["a", "b"]), - ); -} - -fn dense_4_3b_empty_sibling_non_sequential(ctx: &SessionContext) { - let union = UnionArray::try_new( - // multiple fields - UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), - Field::new("int", DataType::Int32, false), - ], - ), - ScalarBuffer::from(vec![1, 1]), //non empty union - Some(ScalarBuffer::from(vec![0, 2])), //non sequential - vec![ - Arc::new(StringArray::from(vec!["a", "b", "c"])), //non null - Arc::new(Int32Array::new_null(0)), //empty sibling - ], - ) - .unwrap(); - - register_union_table( - ctx, - union, - "dense_4_3b_empty_sibling_non_sequential", - StringArray::from(vec!["a", "c"]), - ); -} - -fn dense_4_1c_all_types_match_sequential_equal_len(ctx: &SessionContext) { - let union = UnionArray::try_new( - // multiple fields - UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), - Field::new("int", DataType::Int32, false), - ], - ), - ScalarBuffer::from(vec![1, 1]), //all types match - Some(ScalarBuffer::from(vec![0, 1])), //sequential - vec![ - Arc::new(StringArray::from(vec!["a1", "b2"])), //equal len - Arc::new(Int32Array::new_null(2)), //non empty sibling - ], - ) - .unwrap(); - - register_union_table( - ctx, - union, - "dense_4_1c_all_types_match_sequential_equal_len", - StringArray::from(vec!["a1", "b2"]), - ); -} - -fn dense_4_2c_all_types_match_sequential_bigger_len(ctx: &SessionContext) { - let union = UnionArray::try_new( - // multiple fields - UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), - Field::new("int", DataType::Int32, false), - ], - ), - ScalarBuffer::from(vec![1, 1]), //all types match - Some(ScalarBuffer::from(vec![0, 1])), //sequential - vec![ - Arc::new(StringArray::from(vec!["a1", "b2", "b3"])), //bigger len - Arc::new(Int32Array::new_null(2)), //non empty sibling - ], - ) - .unwrap(); - - register_union_table( - ctx, - union, - "dense_4_2c_all_types_match_sequential_bigger_len", 
- StringArray::from(vec!["a1", "b2"]), - ); -} - -fn dense_4_3c_all_types_match_non_sequential(ctx: &SessionContext) { - let union = UnionArray::try_new( - // multiple fields - UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), - Field::new("int", DataType::Int32, false), - ], - ), - ScalarBuffer::from(vec![1, 1]), //all types match - Some(ScalarBuffer::from(vec![0, 2])), //non sequential - vec![ - Arc::new(StringArray::from(vec!["a1", "b2", "b3"])), - Arc::new(Int32Array::new_null(2)), //non empty sibling - ], - ) - .unwrap(); - - register_union_table( - ctx, - union, - "dense_4_3c_all_types_match_non_sequential", - StringArray::from(vec!["a1", "b3"]), - ); -} - -fn dense_5_1a_none_match_less_len(ctx: &SessionContext) { - let union = UnionArray::try_new( - // multiple fields - UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), - Field::new("int", DataType::Int32, false), - ], - ), - ScalarBuffer::from(vec![3, 3, 3, 3, 3]), //none matches - Some(ScalarBuffer::from(vec![0, 0, 0, 1, 1])), // dense - vec![ - Arc::new(StringArray::from(vec!["a1", "b2", "c3"])), // less len - Arc::new(Int32Array::from(vec![1, 2])), - ], - ) - .unwrap(); - - register_union_table( - ctx, - union, - "dense_5_1a_none_match_less_len", - StringArray::new_null(5), - ); -} - -fn dense_5_1b_cant_contain_null_mask(ctx: &SessionContext) { - let target_fields = - UnionFields::new([10], [Field::new("bool", DataType::Boolean, true)]); - - let union = UnionArray::try_new( - // multiple fields - UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), - Field::new( - "union", - DataType::Union(target_fields.clone(), UnionMode::Sparse), - false, - ), - ], - ), - ScalarBuffer::from(vec![1, 1, 1, 1, 1]), //none matches - Some(ScalarBuffer::from(vec![0, 0, 0, 1, 1])), // dense - vec![ - Arc::new(StringArray::from(vec!["a1", "b2", "c3"])), // less len - Arc::new( - UnionArray::try_new( - target_fields.clone(), - 
ScalarBuffer::from(vec![10]), - None, - vec![Arc::new(BooleanArray::from(vec![true]))], - ) - .unwrap(), - ), // non empty - ], - ) - .unwrap(); - - register_union_table( - ctx, - union, - "dense_5_1b_cant_contain_null_mask", - new_null_array(&DataType::Union(target_fields, UnionMode::Sparse), 5), - ); -} - -fn dense_5_2_none_match_equal_len(ctx: &SessionContext) { - let union = UnionArray::try_new( - // multiple fields - UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), - Field::new("int", DataType::Int32, false), - ], - ), - ScalarBuffer::from(vec![3, 3, 3, 3, 3]), //none matches - Some(ScalarBuffer::from(vec![0, 0, 0, 1, 1])), // dense - vec![ - Arc::new(StringArray::from(vec!["a1", "b2", "c3", "d4", "e5"])), // equal len - Arc::new(Int32Array::from(vec![1, 2])), - ], - ) - .unwrap(); - - register_union_table( - ctx, - union, - "dense_5_2_none_match_equal_len", - StringArray::new_null(5), - ); -} - -fn dense_5_3_none_match_greater_len(ctx: &SessionContext) { - let union = UnionArray::try_new( - // multiple fields - UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), - Field::new("int", DataType::Int32, false), - ], - ), - ScalarBuffer::from(vec![3, 3, 3, 3, 3]), //none matches - Some(ScalarBuffer::from(vec![0, 0, 0, 1, 1])), // dense - vec![ - Arc::new(StringArray::from(vec!["a1", "b2", "c3", "d4", "e5", "f6"])), // greater len - Arc::new(Int32Array::from(vec![1, 2])), //non null - ], - ) - .unwrap(); - - register_union_table( - ctx, - union, - "dense_5_3_none_match_greater_len", - StringArray::new_null(5), - ); -} - -fn dense_6_some_matches(ctx: &SessionContext) { - let union = UnionArray::try_new( - // multiple fields - UnionFields::new( - vec![1, 3], - vec![ - Field::new("str", DataType::Utf8, false), - Field::new("int", DataType::Int32, false), - ], - ), - ScalarBuffer::from(vec![3, 3, 1, 1, 1]), //some matches - Some(ScalarBuffer::from(vec![0, 1, 0, 1, 2])), // dense - vec![ - 
Arc::new(StringArray::from(vec!["a1", "b2", "c3"])), // non null - Arc::new(Int32Array::from(vec![1, 2])), - ], + //single field + UnionFields::new(vec![3], vec![Field::new("int", DataType::Int32, false)]), + ScalarBuffer::from(vec![3, 3]), + Some(vec![0, 1].into()), + vec![Arc::new(Int32Array::from(vec![1, 2]))], ) .unwrap(); - register_union_table( - ctx, - union, - "dense_6_some_matches", - Int32Array::from(vec![Some(1), Some(2), None, None, None]), - ); + register_union_table(ctx, union, "dense_union", Int32Array::from(vec![1, 2])); } fn empty_sparse_union(ctx: &SessionContext) { diff --git a/datafusion/sqllogictest/test_files/union_datatype.slt b/datafusion/sqllogictest/test_files/union_datatype.slt deleted file mode 100644 index 10e1b692ba0b..000000000000 --- a/datafusion/sqllogictest/test_files/union_datatype.slt +++ /dev/null @@ -1,255 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -########## -## UNION DataType Tests -########## - - -query II? -select union_extract(my_union, 'int'), expected, my_union from sparse_1_1_single_field; ----- -1 1 {int=1} -2 2 {int=2} - -query IT? -select union_extract(my_union, 'int'), expected, my_union from sparse_1_2_empty; ----- - -query ??? 
-select union_extract(my_union, 'null'), expected, my_union from sparse_1_3a_null_target; ----- -NULL NULL {str=} - -query IT? -select union_extract(my_union, 'int'), expected, my_union from sparse_1_3b_null_target; ----- -NULL NULL {str=} - -query II? -select union_extract(my_union, 'int'), expected, my_union from sparse_2_all_types_match; ----- -1 1 {int=1} -4 4 {int=4} - -query II? -select union_extract(my_union, 'int'), expected, my_union from sparse_3_1_none_match_target_can_contain_null_mask; ----- -NULL NULL {str=} -NULL NULL {str=} -NULL NULL {str=} -NULL NULL {str=} - -query ??? -select union_extract(my_union, 'union'), expected, my_union from sparse_3_2_none_match_cant_contain_null_mask_union_target; ----- -{bool=} {bool=} {str=} -{bool=} {bool=} {str=} - -query II? -select union_extract(my_union, 'int'), expected, my_union from sparse_4_1_1_target_with_nulls; ----- -NULL NULL {int=} -4 4 {int=4} -NULL NULL {str=} -NULL NULL {str=} - -query II? -select union_extract(my_union, 'int'), expected, my_union from sparse_4_1_2_target_without_nulls; ----- -NULL NULL {str=} -4 4 {int=4} -8 8 {int=8} - -query ??? -select union_extract(my_union, 'union'), expected, my_union from sparse_4_2_some_match_target_cant_contain_null_mask; ----- -{bool=true} {bool=true} {union={bool=true}} -{bool=} {bool=} {str=} - -query IT? -select union_extract(my_union, 'int'), expected, my_union from dense_1_1_both_empty; ----- - -query IT? -select union_extract(my_union, 'int'), expected, my_union from dense_1_2_empty_union_target_non_empty; ----- - -query IT? -select union_extract(my_union, 'int'), expected, my_union from dense_2_non_empty_union_target_empty; ----- -NULL NULL {int=} -NULL NULL {int=} - -query IT? -select union_extract(my_union, 'int'), expected, my_union from dense_3_1_null_target_smaller_len; ----- -NULL NULL {int=} -NULL NULL {int=} - -query IT? 
-select union_extract(my_union, 'int'), expected, my_union from dense_3_2_null_target_equal_len; ----- -NULL NULL {int=} -NULL NULL {int=} - -query IT? -select union_extract(my_union, 'int'), expected, my_union from dense_3_3_null_target_bigger_len; ----- -NULL NULL {int=} -NULL NULL {int=} - -query TT? -select union_extract(my_union, 'str'), expected, my_union from dense_4_1a_single_type_sequential_offsets_equal_len; ----- -a1 a1 {str=a1} -b2 b2 {str=b2} - -query TT? -select union_extract(my_union, 'str'), expected, my_union from dense_4_2a_single_type_sequential_offsets_bigger; ----- -a1 a1 {str=a1} -b2 b2 {str=b2} - -query TT? -select union_extract(my_union, 'str'), expected, my_union from dense_4_3a_single_type_non_sequential; ----- -a1 a1 {str=a1} -c3 c3 {str=c3} - -query TT? -select union_extract(my_union, 'str'), expected, my_union from dense_4_1b_empty_siblings_sequential_equal_len; ----- -a a {str=a} -b b {str=b} - - -query TT? -select union_extract(my_union, 'str'), expected, my_union from dense_4_2b_empty_siblings_sequential_bigger_len; ----- -a a {str=a} -b b {str=b} - - -query TT? -select union_extract(my_union, 'str'), expected, my_union from dense_4_3b_empty_sibling_non_sequential; ----- -a a {str=a} -c c {str=c} - - -query TT? -select union_extract(my_union, 'str'), expected, my_union from dense_4_1c_all_types_match_sequential_equal_len; ----- -a1 a1 {str=a1} -b2 b2 {str=b2} - - -query TT? -select union_extract(my_union, 'str'), expected, my_union from dense_4_2c_all_types_match_sequential_bigger_len; ----- -a1 a1 {str=a1} -b2 b2 {str=b2} - - -query TT? -select union_extract(my_union, 'str'), expected, my_union from dense_4_3c_all_types_match_non_sequential; ----- -a1 a1 {str=a1} -b3 b3 {str=b3} - - -query TT? -select union_extract(my_union, 'str'), expected, my_union from dense_5_1a_none_match_less_len; ----- -NULL NULL {int=1} -NULL NULL {int=1} -NULL NULL {int=1} -NULL NULL {int=2} -NULL NULL {int=2} - - -query ??? 
-select union_extract(my_union, 'union'), expected, my_union from dense_5_1b_cant_contain_null_mask; ----- -{bool=} {bool=} {str=a1} -{bool=} {bool=} {str=a1} -{bool=} {bool=} {str=a1} -{bool=} {bool=} {str=b2} -{bool=} {bool=} {str=b2} - -query TT? -select union_extract(my_union, 'str'), expected, my_union from dense_5_2_none_match_equal_len; ----- -NULL NULL {int=1} -NULL NULL {int=1} -NULL NULL {int=1} -NULL NULL {int=2} -NULL NULL {int=2} - - -query TT? -select union_extract(my_union, 'str'), expected, my_union from dense_5_3_none_match_greater_len; ----- -NULL NULL {int=1} -NULL NULL {int=1} -NULL NULL {int=1} -NULL NULL {int=2} -NULL NULL {int=2} - - -query II? -select union_extract(my_union, 'int'), expected, my_union from dense_6_some_matches; ----- -1 1 {int=1} -2 2 {int=2} -NULL NULL {str=a1} -NULL NULL {str=b2} -NULL NULL {str=c3} - - -query error DataFusion error: Execution error: field int not found on union -select union_extract(my_union, 'int'), expected, my_union from empty_sparse_union; - - -query error DataFusion error: Execution error: field int not found on union -select union_extract(my_union, 'int'), expected, my_union from empty_dense_union; - - -query error -select union_extract() from empty_dense_union; - - -query error -select union_extract(my_union) from empty_dense_union; - - -query error -select union_extract('a') from empty_dense_union; - - - -query error DataFusion error: Execution error: union_extract first argument must be a union, got Utf8 instead -select union_extract('a', my_union) from empty_dense_union; - - -query error DataFusion error: Execution error: union_extract second argument must be a non\-null string literal, got Int64 instead -select union_extract(my_union, 1) from empty_dense_union; - - -query error -select union_extract(my_union, 'a', 'b') from empty_dense_union; diff --git a/datafusion/sqllogictest/test_files/union_function.slt b/datafusion/sqllogictest/test_files/union_function.slt new file mode 100644 index 
000000000000..ae61c508dcbc --- /dev/null +++ b/datafusion/sqllogictest/test_files/union_function.slt @@ -0,0 +1,97 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +########## +## UNION DataType Tests +########## + + +query II? +select union_extract(my_union, 'int'), expected, my_union from sparse_union; +---- +1 1 {int=1} +2 2 {int=2} + + +query II? +select union_extract(my_union, 'int'), expected, my_union from dense_union; +---- +1 1 {int=1} +2 2 {int=2} + + + +query error DataFusion error: Execution error: field int not found on union +select union_extract(my_union, 'int'), expected, my_union from empty_sparse_union; + + +query error DataFusion error: Execution error: field int not found on union +select union_extract(my_union, 'int'), expected, my_union from empty_dense_union; + + +query error +select union_extract() from empty_dense_union; +---- +DataFusion error: Error during planning: Error during planning: union_extract does not support zero arguments. No function matches the given name and argument types 'union_extract()'. You might need to add explicit type casts. 
+ Candidate functions: + union_extract(UserDefined) + + + +query error +select union_extract(my_union) from empty_dense_union; +---- +DataFusion error: Error during planning: Execution error: User-defined coercion failed with Execution("union_extract expects 2 arguments, got 1 instead") No function matches the given name and argument types 'union_extract(Union([], Dense))'. You might need to add explicit type casts. + Candidate functions: + union_extract(UserDefined) + + + +query error +select union_extract('a') from empty_dense_union; +---- +DataFusion error: Error during planning: Execution error: User-defined coercion failed with Execution("union_extract expects 2 arguments, got 1 instead") No function matches the given name and argument types 'union_extract(Utf8)'. You might need to add explicit type casts. + Candidate functions: + union_extract(UserDefined) + + + + +query error +select union_extract('a', my_union) from empty_dense_union; +---- +DataFusion error: Error during planning: Execution error: User-defined coercion failed with Execution("union_extract first argument must be a union, got Utf8 instead") No function matches the given name and argument types 'union_extract(Utf8, Union([], Dense))'. You might need to add explicit type casts. + Candidate functions: + union_extract(UserDefined) + + + +query error +select union_extract(my_union, 1) from empty_dense_union; +---- +DataFusion error: Error during planning: Execution error: User-defined coercion failed with Execution("union_extract second argument must be a non-null string literal, got Int64 instead") No function matches the given name and argument types 'union_extract(Union([], Dense), Int64)'. You might need to add explicit type casts. 
+ Candidate functions: + union_extract(UserDefined) + + + +query error +select union_extract(my_union, 'a', 'b') from empty_dense_union; +---- +DataFusion error: Error during planning: Execution error: User-defined coercion failed with Execution("union_extract expects 2 arguments, got 3 instead") No function matches the given name and argument types 'union_extract(Union([], Dense), Utf8, Utf8)'. You might need to add explicit type casts. + Candidate functions: + union_extract(UserDefined) From 0e5cfd942eb44df0c0ba5ba57bd0eb05a38fd021 Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Sun, 19 Jan 2025 21:27:12 -0300 Subject: [PATCH 04/15] docs: add union functions section --- datafusion/expr/src/udf.rs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 51c42b5c4c30..679027cbd8e0 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -877,6 +877,7 @@ pub mod scalar_doc_sections { DOC_SECTION_STRUCT, DOC_SECTION_MAP, DOC_SECTION_HASHING, + DOC_SECTION_UNION, DOC_SECTION_OTHER, ] } @@ -951,4 +952,10 @@ The following regular expression functions are supported:"#, label: "Other Functions", description: None, }; + + pub const DOC_SECTION_UNION: DocSection = DocSection { + include: true, + label: "Union Functions", + description: None, + }; } From 959ed962339e65e07a9d9bb2cf0925af6235d3c8 Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Sun, 19 Jan 2025 21:28:39 -0300 Subject: [PATCH 05/15] docs: simplify union_extract docs --- .../source/user-guide/sql/scalar_functions.md | 44 +++++-------------- 1 file changed, 10 insertions(+), 34 deletions(-) diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 2a23a7c2fd46..63d89579e1f3 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -4351,40 +4351,16 
@@ union_extract(union, field_name) - **field_name**: Literal string, name of the field to extract ``` -❯ select my_union from table_with_union; -+----------+ -| my_union | -+----------+ -| {a=1} | -| {b=3.0} | -| {a=4} | -| {b=} | -| {a=} | -+----------+ - -❯ select union_extract(my_union, 'a'); -+------------------------------+ -| union_extract(my_union, 'a') | -+------------------------------+ -| 1 | -| | -| 4 | -| | -| | -+------------------------------+ - -❯ select union_extract(my_union, 'b'); -+------------------------------+ -| union_extract(my_union, 'b') | -+------------------------------+ -| | -| 3.0 | -| | -| | -| | -+------------------------------+ - - +❯ select union_column, union_extract(union_column, 'a'), union_extract(union_column, 'b') from table_with_union; ++--------------+----------------------------------+----------------------------------+ +| union_column | union_extract(union_column, 'a') | union_extract(union_column, 'b') | ++--------------+----------------------------------+----------------------------------+ +| {a=1} | 1 | | +| {b=3.0} | | 3.0 | +| {a=4} | 4 | | +| {b=} | | | +| {a=} | | | ++--------------+----------------------------------+----------------------------------+ ``` ## Other Functions From 30940f7c77dc650e032193cab31d8b1a36d486dd Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Sun, 19 Jan 2025 21:30:50 -0300 Subject: [PATCH 06/15] test: simplify union_extract sqllogictests --- datafusion/sqllogictest/src/test_context.rs | 84 +++---------------- .../test_files/union_function.slt | 58 +++++-------- 2 files changed, 35 insertions(+), 107 deletions(-) diff --git a/datafusion/sqllogictest/src/test_context.rs b/datafusion/sqllogictest/src/test_context.rs index a03bc9ed64d8..ab3006b1c15c 100644 --- a/datafusion/sqllogictest/src/test_context.rs +++ b/datafusion/sqllogictest/src/test_context.rs @@ -23,7 +23,7 @@ use std::sync::Arc; use arrow::array::{ Array, ArrayRef, BinaryArray, Float64Array, 
Int32Array, LargeBinaryArray, - LargeStringArray, NullArray, StringArray, TimestampNanosecondArray, UnionArray, + LargeStringArray, StringArray, TimestampNanosecondArray, UnionArray, }; use arrow::buffer::ScalarBuffer; use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit, UnionFields}; @@ -115,8 +115,8 @@ impl TestContext { register_metadata_tables(test_ctx.session_ctx()).await; } "union_function.slt" => { - info!("Registering tables with union column"); - register_union_tables(test_ctx.session_ctx()) + info!("Registering table with union column"); + register_union_table(test_ctx.session_ctx()) } _ => { info!("Using default SessionContext"); @@ -408,81 +408,23 @@ fn create_example_udf() -> ScalarUDF { ) } -fn register_union_tables(ctx: &SessionContext) { - sparse_union(ctx); - dense_union(ctx); - empty_sparse_union(ctx); - empty_dense_union(ctx); -} - -fn register_union_table( - ctx: &SessionContext, - union: UnionArray, - table_name: &str, - expected: impl Array + 'static, -) { - let schema = Schema::new(vec![ - Field::new("my_union", union.data_type().clone(), false), - Field::new("expected", expected.data_type().clone(), true), - ]); - - let batch = RecordBatch::try_new( - Arc::new(schema.clone()), - vec![Arc::new(union), Arc::new(expected)], - ) - .unwrap(); - - ctx.register_batch(table_name, batch).unwrap(); -} - -fn sparse_union(ctx: &SessionContext) { +fn register_union_table(ctx: &SessionContext) { let union = UnionArray::try_new( - //single field - UnionFields::new(vec![3], vec![Field::new("int", DataType::Int32, false)]), - ScalarBuffer::from(vec![3, 3]), // non empty, every type id must match - None, //sparse - vec![ - Arc::new(Int32Array::from(vec![1, 2])), // not null - ], - ) - .unwrap(); - - register_union_table(ctx, union, "sparse_union", Int32Array::from(vec![1, 2])); -} - -fn dense_union(ctx: &SessionContext) { - let union = UnionArray::try_new( - //single field UnionFields::new(vec![3], vec![Field::new("int", DataType::Int32, 
false)]), ScalarBuffer::from(vec![3, 3]), - Some(vec![0, 1].into()), - vec![Arc::new(Int32Array::from(vec![1, 2]))], - ) - .unwrap(); - - register_union_table(ctx, union, "dense_union", Int32Array::from(vec![1, 2])); -} - -fn empty_sparse_union(ctx: &SessionContext) { - let union = UnionArray::try_new( - UnionFields::empty(), - ScalarBuffer::from(vec![]), None, - vec![], + vec![Arc::new(Int32Array::from(vec![1, 2]))], ) .unwrap(); - register_union_table(ctx, union, "empty_sparse_union", NullArray::new(0)) -} + let schema = Schema::new(vec![Field::new( + "union_column", + union.data_type().clone(), + false, + )]); -fn empty_dense_union(ctx: &SessionContext) { - let union = UnionArray::try_new( - UnionFields::empty(), - ScalarBuffer::from(vec![]), - Some(ScalarBuffer::from(vec![])), - vec![], - ) - .unwrap(); + let batch = + RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(union)]).unwrap(); - register_union_table(ctx, union, "empty_dense_union", NullArray::new(0)) + ctx.register_batch("union_table", batch).unwrap(); } diff --git a/datafusion/sqllogictest/test_files/union_function.slt b/datafusion/sqllogictest/test_files/union_function.slt index ae61c508dcbc..95a923d765f2 100644 --- a/datafusion/sqllogictest/test_files/union_function.slt +++ b/datafusion/sqllogictest/test_files/union_function.slt @@ -20,78 +20,64 @@ ########## -query II? -select union_extract(my_union, 'int'), expected, my_union from sparse_union; +query ?I +select union_column, union_extract(union_column, 'int') from union_table; ---- -1 1 {int=1} -2 2 {int=2} - - -query II? 
-select union_extract(my_union, 'int'), expected, my_union from dense_union; ----- -1 1 {int=1} -2 2 {int=2} +{int=1} 1 +{int=2} 2 -query error DataFusion error: Execution error: field int not found on union -select union_extract(my_union, 'int'), expected, my_union from empty_sparse_union; +query error DataFusion error: Execution error: field bool not found on union +select union_extract(union_column, 'bool') from union_table; -query error DataFusion error: Execution error: field int not found on union -select union_extract(my_union, 'int'), expected, my_union from empty_dense_union; - query error -select union_extract() from empty_dense_union; +select union_extract() from union_table; ---- -DataFusion error: Error during planning: Error during planning: union_extract does not support zero arguments. No function matches the given name and argument types 'union_extract()'. You might need to add explicit type casts. +DataFusion error: Error during planning: union_extract does not support zero arguments. No function matches the given name and argument types 'union_extract()'. You might need to add explicit type casts. Candidate functions: - union_extract(UserDefined) + union_extract(Any, Any) query error -select union_extract(my_union) from empty_dense_union; +select union_extract(union_column) from union_table; ---- -DataFusion error: Error during planning: Execution error: User-defined coercion failed with Execution("union_extract expects 2 arguments, got 1 instead") No function matches the given name and argument types 'union_extract(Union([], Dense))'. You might need to add explicit type casts. +DataFusion error: Error during planning: The function expected 2 arguments but received 1 No function matches the given name and argument types 'union_extract(Union([(3, Field { name: "int", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} })], Sparse))'. You might need to add explicit type casts. 
Candidate functions: - union_extract(UserDefined) + union_extract(Any, Any) query error -select union_extract('a') from empty_dense_union; +select union_extract('a') from union_table; ---- -DataFusion error: Error during planning: Execution error: User-defined coercion failed with Execution("union_extract expects 2 arguments, got 1 instead") No function matches the given name and argument types 'union_extract(Utf8)'. You might need to add explicit type casts. +DataFusion error: Error during planning: The function expected 2 arguments but received 1 No function matches the given name and argument types 'union_extract(Utf8)'. You might need to add explicit type casts. Candidate functions: - union_extract(UserDefined) + union_extract(Any, Any) query error -select union_extract('a', my_union) from empty_dense_union; +select union_extract('a', union_column) from union_table; ---- -DataFusion error: Error during planning: Execution error: User-defined coercion failed with Execution("union_extract first argument must be a union, got Utf8 instead") No function matches the given name and argument types 'union_extract(Utf8, Union([], Dense))'. You might need to add explicit type casts. - Candidate functions: - union_extract(UserDefined) +DataFusion error: Execution error: union_extract first argument must be a union, got Utf8 instead query error -select union_extract(my_union, 1) from empty_dense_union; +select union_extract(union_column, 1) from union_table; ---- -DataFusion error: Error during planning: Execution error: User-defined coercion failed with Execution("union_extract second argument must be a non-null string literal, got Int64 instead") No function matches the given name and argument types 'union_extract(Union([], Dense), Int64)'. You might need to add explicit type casts. 
- Candidate functions: - union_extract(UserDefined) +DataFusion error: Execution error: union_extract second argument must be a non-null string literal, got Int64 instead query error -select union_extract(my_union, 'a', 'b') from empty_dense_union; +select union_extract(union_column, 'a', 'b') from union_table; ---- -DataFusion error: Error during planning: Execution error: User-defined coercion failed with Execution("union_extract expects 2 arguments, got 3 instead") No function matches the given name and argument types 'union_extract(Union([], Dense), Utf8, Utf8)'. You might need to add explicit type casts. +DataFusion error: Error during planning: The function expected 2 arguments but received 3 No function matches the given name and argument types 'union_extract(Union([(3, Field { name: "int", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} })], Sparse), Utf8, Utf8)'. You might need to add explicit type casts. Candidate functions: - union_extract(UserDefined) + union_extract(Any, Any) From fad85ea5c9114d2b64cb062500a263b517ae2405 Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Sun, 19 Jan 2025 21:36:05 -0300 Subject: [PATCH 07/15] refactor(union_extract): new udf api, docs macro, use any signature --- .../functions/src/core/union_extract.rs | 137 ++++++++++-------- 1 file changed, 74 insertions(+), 63 deletions(-) diff --git a/datafusion/functions/src/core/union_extract.rs b/datafusion/functions/src/core/union_extract.rs index acb364d4ef05..8054179f16dc 100644 --- a/datafusion/functions/src/core/union_extract.rs +++ b/datafusion/functions/src/core/union_extract.rs @@ -21,9 +21,30 @@ use datafusion_common::cast::as_union_array; use datafusion_common::{ exec_datafusion_err, exec_err, internal_err, ExprSchema, Result, ScalarValue, }; -use datafusion_expr::{ColumnarValue, Expr}; +use datafusion_doc::Documentation; +use datafusion_expr::{ColumnarValue, Expr, ScalarFunctionArgs}; use 
datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; - +use datafusion_macros::user_doc; + +#[user_doc( + doc_section(include = "true", label = "Union Functions"), + description = "Returns the value of the given field when selected, or NULL otherwise.", + syntax_example = "union_extract(union, field_name)", + sql_example = r#"```sql +❯ select union_column, union_extract(union_column, 'a'), union_extract(union_column, 'b') from table_with_union; ++--------------+----------------------------------+----------------------------------+ +| union_column | union_extract(union_column, 'a') | union_extract(union_column, 'b') | ++--------------+----------------------------------+----------------------------------+ +| {a=1} | 1 | | +| {b=3.0} | | 3.0 | +| {a=4} | 4 | | +| {b=} | | | +| {a=} | | | ++--------------+----------------------------------+----------------------------------+ +```"#, + standard_argument(name = "union", prefix = "Union"), + standard_argument(name = "field_name", prefix = "String") +)] #[derive(Debug)] pub struct UnionExtractFun { signature: Signature, @@ -38,7 +59,7 @@ impl Default for UnionExtractFun { impl UnionExtractFun { pub fn new() -> Self { Self { - signature: Signature::user_defined(Volatility::Immutable), + signature: Signature::any(2, Volatility::Immutable), } } } @@ -93,7 +114,9 @@ impl ScalarUDFImpl for UnionExtractFun { Ok(field.data_type().clone()) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let args = args.args; + if args.len() != 2 { return exec_err!( "union_extract expects 2 arguments, got {} instead", @@ -101,30 +124,27 @@ impl ScalarUDFImpl for UnionExtractFun { ); } - let union = &args[0]; - let target_name = match &args[1] { ColumnarValue::Scalar(ScalarValue::Utf8(Some(target_name))) => Ok(target_name), ColumnarValue::Scalar(ScalarValue::Utf8(None)) => exec_err!("union_extract second argument must be a non-null string literal, got a null instead"), 
_ => exec_err!("union_extract second argument must be a non-null string literal, got {} instead", &args[1].data_type()), }; - match union { + match &args[0] { ColumnarValue::Array(array) => { - let _union_array = as_union_array(&array).map_err(|_| { + let union_array = as_union_array(&array).map_err(|_| { exec_datafusion_err!( "union_extract first argument must be a union, got {} instead", array.data_type() ) })?; - // Ok(arrow::compute::kernels::union_extract::union_extract( - // &union_array, - // target_name, - // )?) - Ok(ColumnarValue::Array(std::sync::Arc::new( - arrow::array::Int32Array::from(vec![1, 2]), - ))) + Ok(ColumnarValue::Array( + arrow::compute::kernels::union_extract::union_extract( + union_array, + target_name?, + )?, + )) } ColumnarValue::Scalar(ScalarValue::Union(value, fields, _)) => { let target_name = target_name?; @@ -146,29 +166,8 @@ impl ScalarUDFImpl for UnionExtractFun { } } - fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - if arg_types.len() != 2 { - return exec_err!( - "union_extract expects 2 arguments, got {} instead", - arg_types.len() - ); - } - - if !matches!(arg_types[0], DataType::Union(_, _)) { - return exec_err!( - "union_extract first argument must be a union, got {} instead", - arg_types[0] - ); - } - - if !matches!(arg_types[1], DataType::Utf8) { - return exec_err!( - "union_extract second argument must be a non-null string literal, got {} instead", - arg_types[1] - ); - } - - Ok(arg_types.to_vec()) + fn documentation(&self) -> Option<&Documentation> { + self.doc() } } @@ -184,7 +183,7 @@ mod tests { use arrow::datatypes::{DataType, Field, UnionFields, UnionMode}; use datafusion_common::{Result, ScalarValue}; - use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; use super::UnionExtractFun; @@ -201,36 +200,48 @@ mod tests { ], ); - let result = fun.invoke(&[ - ColumnarValue::Scalar(ScalarValue::Union( - None, - fields.clone(), - 
UnionMode::Dense, - )), - ColumnarValue::Scalar(ScalarValue::new_utf8("str")), - ])?; + let result = fun.invoke_with_args(ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::Union( + None, + fields.clone(), + UnionMode::Dense, + )), + ColumnarValue::Scalar(ScalarValue::new_utf8("str")), + ], + number_rows: 1, + return_type: &DataType::Utf8, + })?; assert_scalar(result, ScalarValue::Utf8(None)); - let result = fun.invoke(&[ - ColumnarValue::Scalar(ScalarValue::Union( - Some((3, Box::new(ScalarValue::Int32(Some(42))))), - fields.clone(), - UnionMode::Dense, - )), - ColumnarValue::Scalar(ScalarValue::new_utf8("str")), - ])?; + let result = fun.invoke_with_args(ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::Union( + Some((3, Box::new(ScalarValue::Int32(Some(42))))), + fields.clone(), + UnionMode::Dense, + )), + ColumnarValue::Scalar(ScalarValue::new_utf8("str")), + ], + number_rows: 1, + return_type: &DataType::Utf8, + })?; assert_scalar(result, ScalarValue::Utf8(None)); - let result = fun.invoke(&[ - ColumnarValue::Scalar(ScalarValue::Union( - Some((1, Box::new(ScalarValue::new_utf8("42")))), - fields.clone(), - UnionMode::Dense, - )), - ColumnarValue::Scalar(ScalarValue::new_utf8("str")), - ])?; + let result = fun.invoke_with_args(ScalarFunctionArgs { + args: vec![ + ColumnarValue::Scalar(ScalarValue::Union( + Some((1, Box::new(ScalarValue::new_utf8("42")))), + fields.clone(), + UnionMode::Dense, + )), + ColumnarValue::Scalar(ScalarValue::new_utf8("str")), + ], + number_rows: 1, + return_type: &DataType::Utf8, + })?; assert_scalar(result, ScalarValue::new_utf8("42")); From 4cf32e7783dcb57d56f2a533f897026dfed96977 Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Sun, 19 Jan 2025 22:49:10 -0300 Subject: [PATCH 08/15] fix: remove user_doc include attribute --- datafusion/functions/src/core/union_extract.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/datafusion/functions/src/core/union_extract.rs b/datafusion/functions/src/core/union_extract.rs index 8054179f16dc..d01880549bc4 100644 --- a/datafusion/functions/src/core/union_extract.rs +++ b/datafusion/functions/src/core/union_extract.rs @@ -27,7 +27,7 @@ use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; #[user_doc( - doc_section(include = "true", label = "Union Functions"), + doc_section(label = "Union Functions"), description = "Returns the value of the given field when selected, or NULL otherwise.", syntax_example = "union_extract(union, field_name)", sql_example = r#"```sql From 718323a4e79f0a67ca6220431cd98904cddf3d8c Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Mon, 20 Jan 2025 01:55:28 -0300 Subject: [PATCH 09/15] fix: generate docs --- datafusion/functions/src/core/union_extract.rs | 2 +- docs/source/user-guide/sql/scalar_functions.md | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/datafusion/functions/src/core/union_extract.rs b/datafusion/functions/src/core/union_extract.rs index d01880549bc4..7b313d3a2974 100644 --- a/datafusion/functions/src/core/union_extract.rs +++ b/datafusion/functions/src/core/union_extract.rs @@ -43,7 +43,7 @@ use datafusion_macros::user_doc; +--------------+----------------------------------+----------------------------------+ ```"#, standard_argument(name = "union", prefix = "Union"), - standard_argument(name = "field_name", prefix = "String") + argument(name = "field_name", description = "String expression to operate on. 
Must be a constant.") )] #[derive(Debug)] pub struct UnionExtractFun { diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 63d89579e1f3..ea4e2195f5c8 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -4347,10 +4347,12 @@ union_extract(union, field_name) #### Arguments -- **union**: Union expression to extract the field from -- **field_name**: Literal string, name of the field to extract +- **union**: Union expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **field_name**: String expression to operate on. Must be a constant. -``` +#### Example + +```sql ❯ select union_column, union_extract(union_column, 'a'), union_extract(union_column, 'b') from table_with_union; +--------------+----------------------------------+----------------------------------+ | union_column | union_extract(union_column, 'a') | union_extract(union_column, 'b') | From 593a22bed13a999bf3381aed0f7bb269c5d08e70 Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Mon, 20 Jan 2025 01:58:31 -0300 Subject: [PATCH 10/15] fix: manually trim sqllogictest generated errors --- .../test_files/union_function.slt | 48 +++---------------- 1 file changed, 6 insertions(+), 42 deletions(-) diff --git a/datafusion/sqllogictest/test_files/union_function.slt b/datafusion/sqllogictest/test_files/union_function.slt index 95a923d765f2..2e77d0e6b861 100644 --- a/datafusion/sqllogictest/test_files/union_function.slt +++ b/datafusion/sqllogictest/test_files/union_function.slt @@ -19,65 +19,29 @@ ## UNION DataType Tests ########## - query ?I select union_column, union_extract(union_column, 'int') from union_table; ---- {int=1} 1 {int=2} 2 - - query error DataFusion error: Execution error: field bool not found on union select union_extract(union_column, 'bool') from union_table; - - -query error +query error 
DataFusion error: Error during planning: union_extract does not support zero arguments select union_extract() from union_table; ----- -DataFusion error: Error during planning: union_extract does not support zero arguments. No function matches the given name and argument types 'union_extract()'. You might need to add explicit type casts. - Candidate functions: - union_extract(Any, Any) - - -query error +query error DataFusion error: Error during planning: The function expected 2 arguments but received 1 select union_extract(union_column) from union_table; ----- -DataFusion error: Error during planning: The function expected 2 arguments but received 1 No function matches the given name and argument types 'union_extract(Union([(3, Field { name: "int", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} })], Sparse))'. You might need to add explicit type casts. - Candidate functions: - union_extract(Any, Any) - - -query error +query error DataFusion error: Error during planning: The function expected 2 arguments but received 1 select union_extract('a') from union_table; ----- -DataFusion error: Error during planning: The function expected 2 arguments but received 1 No function matches the given name and argument types 'union_extract(Utf8)'. You might need to add explicit type casts. 
- Candidate functions: - union_extract(Any, Any) - - - -query error +query error DataFusion error: Execution error: union_extract first argument must be a union, got Utf8 instead select union_extract('a', union_column) from union_table; ----- -DataFusion error: Execution error: union_extract first argument must be a union, got Utf8 instead - - -query error +query error DataFusion error: Execution error: union_extract second argument must be a non\-null string literal, got Int64 instead select union_extract(union_column, 1) from union_table; ----- -DataFusion error: Execution error: union_extract second argument must be a non-null string literal, got Int64 instead - - -query error +query error DataFusion error: Error during planning: The function expected 2 arguments but received 3 select union_extract(union_column, 'a', 'b') from union_table; ----- -DataFusion error: Error during planning: The function expected 2 arguments but received 3 No function matches the given name and argument types 'union_extract(Union([(3, Field { name: "int", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} })], Sparse), Utf8, Utf8)'. You might need to add explicit type casts. 
- Candidate functions: - union_extract(Any, Any) From 203df651bb3ab601ab98e6f04a44b94954fc8f64 Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Mon, 20 Jan 2025 02:13:06 -0300 Subject: [PATCH 11/15] fix: fmt --- datafusion/functions/src/core/union_extract.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/datafusion/functions/src/core/union_extract.rs b/datafusion/functions/src/core/union_extract.rs index 7b313d3a2974..ce8cde04181b 100644 --- a/datafusion/functions/src/core/union_extract.rs +++ b/datafusion/functions/src/core/union_extract.rs @@ -43,7 +43,10 @@ use datafusion_macros::user_doc; +--------------+----------------------------------+----------------------------------+ ```"#, standard_argument(name = "union", prefix = "Union"), - argument(name = "field_name", description = "String expression to operate on. Must be a constant.") + argument( + name = "field_name", + description = "String expression to operate on. Must be a constant." + ) )] #[derive(Debug)] pub struct UnionExtractFun { From 21d0548554a4209ca862837a725649b8e64a8634 Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Sat, 8 Feb 2025 16:21:39 -0300 Subject: [PATCH 12/15] docs: add union functions section description --- datafusion/expr/src/udf.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 679027cbd8e0..7095a8e68c07 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -956,6 +956,6 @@ The following regular expression functions are supported:"#, pub const DOC_SECTION_UNION: DocSection = DocSection { include: true, label: "Union Functions", - description: None, + description: Some("Functions to work with the union data type, also known as tagged unions, variant types, enums or sum types. 
Note: Not related to the SQL UNION operator"), }; } From 0802e0927561d04e45eec09da322f35b456a2f7a Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Sat, 8 Feb 2025 16:47:39 -0300 Subject: [PATCH 13/15] docs: update functions docs --- datafusion/expr/src/udf.rs | 1 + docs/source/user-guide/sql/scalar_functions.md | 2 ++ 2 files changed, 3 insertions(+) diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 60b0561db909..54a692f2f3aa 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -997,6 +997,7 @@ pub mod scalar_doc_sections { DOC_SECTION_STRUCT, DOC_SECTION_MAP, DOC_SECTION_HASHING, + DOC_SECTION_UNION, DOC_SECTION_OTHER, ] } diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index bbb954e31af5..5e8a89e9aa8c 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -4341,6 +4341,8 @@ sha512(expression) ## Union Functions +Functions to work with the union data type, also known as tagged unions, variant types, enums or sum types. 
Note: Not related to the SQL UNION operator + - [union_extract](#union_extract) ### `union_extract` From 5acb570cc68b1d5cbc0c527e95edc00b04e65ab5 Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Sat, 8 Feb 2025 16:50:29 -0300 Subject: [PATCH 14/15] docs: clarify union_extract description Co-authored-by: Bruce Ritchie --- datafusion/functions/src/core/union_extract.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/functions/src/core/union_extract.rs b/datafusion/functions/src/core/union_extract.rs index ce8cde04181b..aa1206b96a78 100644 --- a/datafusion/functions/src/core/union_extract.rs +++ b/datafusion/functions/src/core/union_extract.rs @@ -28,7 +28,7 @@ use datafusion_macros::user_doc; #[user_doc( doc_section(label = "Union Functions"), - description = "Returns the value of the given field when selected, or NULL otherwise.", + description = "Returns the value of the given field in the union when selected, or NULL otherwise.", syntax_example = "union_extract(union, field_name)", sql_example = r#"```sql ❯ select union_column, union_extract(union_column, 'a'), union_extract(union_column, 'b') from table_with_union; From f11194fbe97ff4efb3c41b50661af073626535fd Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Sat, 8 Feb 2025 18:32:49 -0300 Subject: [PATCH 15/15] fix: use return_type_from_args, tests, docs --- .../functions/src/core/union_extract.rs | 25 ++++++++----------- .../test_files/union_function.slt | 8 +++--- .../source/user-guide/sql/scalar_functions.md | 2 +- 3 files changed, 15 insertions(+), 20 deletions(-) diff --git a/datafusion/functions/src/core/union_extract.rs b/datafusion/functions/src/core/union_extract.rs index aa1206b96a78..d54627f73598 100644 --- a/datafusion/functions/src/core/union_extract.rs +++ b/datafusion/functions/src/core/union_extract.rs @@ -19,10 +19,10 @@ use arrow::array::Array; use arrow::datatypes::{DataType, FieldRef, 
UnionFields}; use datafusion_common::cast::as_union_array; use datafusion_common::{ - exec_datafusion_err, exec_err, internal_err, ExprSchema, Result, ScalarValue, + exec_datafusion_err, exec_err, internal_err, Result, ScalarValue, }; use datafusion_doc::Documentation; -use datafusion_expr::{ColumnarValue, Expr, ScalarFunctionArgs}; +use datafusion_expr::{ColumnarValue, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; @@ -85,36 +85,31 @@ impl ScalarUDFImpl for UnionExtractFun { internal_err!("union_extract should return type from exprs") } - fn return_type_from_exprs( - &self, - args: &[Expr], - _: &dyn ExprSchema, - arg_types: &[DataType], - ) -> Result { - if args.len() != 2 { + fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { + if args.arg_types.len() != 2 { return exec_err!( "union_extract expects 2 arguments, got {} instead", - args.len() + args.arg_types.len() ); } - let DataType::Union(fields, _) = &arg_types[0] else { + let DataType::Union(fields, _) = &args.arg_types[0] else { return exec_err!( "union_extract first argument must be a union, got {} instead", - arg_types[0] + args.arg_types[0] ); }; - let Expr::Literal(ScalarValue::Utf8(Some(field_name))) = &args[1] else { + let Some(ScalarValue::Utf8(Some(field_name))) = &args.scalar_arguments[1] else { return exec_err!( "union_extract second argument must be a non-null string literal, got {} instead", - arg_types[1] + args.arg_types[1] ); }; let field = find_field(fields, field_name)?.1; - Ok(field.data_type().clone()) + Ok(ReturnInfo::new_nullable(field.data_type().clone())) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { diff --git a/datafusion/sqllogictest/test_files/union_function.slt b/datafusion/sqllogictest/test_files/union_function.slt index 2e77d0e6b861..9c70b1011f58 100644 --- a/datafusion/sqllogictest/test_files/union_function.slt +++ 
b/datafusion/sqllogictest/test_files/union_function.slt @@ -28,13 +28,13 @@ select union_column, union_extract(union_column, 'int') from union_table; query error DataFusion error: Execution error: field bool not found on union select union_extract(union_column, 'bool') from union_table; -query error DataFusion error: Error during planning: union_extract does not support zero arguments +query error DataFusion error: Error during planning: 'union_extract' does not support zero arguments select union_extract() from union_table; -query error DataFusion error: Error during planning: The function expected 2 arguments but received 1 +query error DataFusion error: Error during planning: The function 'union_extract' expected 2 arguments but received 1 select union_extract(union_column) from union_table; -query error DataFusion error: Error during planning: The function expected 2 arguments but received 1 +query error DataFusion error: Error during planning: The function 'union_extract' expected 2 arguments but received 1 select union_extract('a') from union_table; query error DataFusion error: Execution error: union_extract first argument must be a union, got Utf8 instead @@ -43,5 +43,5 @@ select union_extract('a', union_column) from union_table; query error DataFusion error: Execution error: union_extract second argument must be a non\-null string literal, got Int64 instead select union_extract(union_column, 1) from union_table; -query error DataFusion error: Error during planning: The function expected 2 arguments but received 3 +query error DataFusion error: Error during planning: The function 'union_extract' expected 2 arguments but received 3 select union_extract(union_column, 'a', 'b') from union_table; diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 5e8a89e9aa8c..6ebca7613660 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -4347,7 
+4347,7 @@ Functions to work with the union data type, also know as tagged unions, variant ### `union_extract` -Returns the value of the given field when selected, or NULL otherwise. +Returns the value of the given field in the union when selected, or NULL otherwise. ``` union_extract(union, field_name)