From 65ab5179106a143864bdf433c08c1fb6ac226043 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sat, 13 Aug 2022 07:17:03 -0400 Subject: [PATCH] Rename `array()` function to `make_array()`, extend `array[]` --- datafusion/core/src/logical_plan/mod.rs | 22 ++-- datafusion/core/src/prelude.rs | 8 +- datafusion/core/tests/sql/functions.rs | 113 ++++++++++++++---- datafusion/expr/src/built_in_function.rs | 12 +- datafusion/expr/src/expr_fn.rs | 4 +- datafusion/expr/src/function.rs | 12 +- .../physical-expr/src/array_expressions.rs | 41 ++++--- datafusion/physical-expr/src/functions.rs | 2 +- datafusion/sql/src/planner.rs | 75 ++---------- 9 files changed, 157 insertions(+), 132 deletions(-) diff --git a/datafusion/core/src/logical_plan/mod.rs b/datafusion/core/src/logical_plan/mod.rs index 87a02ae0118c0..39d3af4a20ce3 100644 --- a/datafusion/core/src/logical_plan/mod.rs +++ b/datafusion/core/src/logical_plan/mod.rs @@ -27,11 +27,11 @@ pub use datafusion_common::{ Column, DFField, DFSchema, DFSchemaRef, ExprSchema, ToDFSchema, }; pub use datafusion_expr::{ - abs, acos, and, approx_distinct, approx_percentile_cont, array, ascii, asin, atan, - atan2, avg, bit_length, btrim, call_fn, case, cast, ceil, character_length, chr, - coalesce, col, combine_filters, concat, concat_expr, concat_ws, concat_ws_expr, cos, - count, count_distinct, create_udaf, create_udf, date_part, date_trunc, digest, - exists, exp, expr_rewriter, + abs, acos, and, approx_distinct, approx_percentile_cont, ascii, asin, atan, atan2, + avg, bit_length, btrim, call_fn, case, cast, ceil, character_length, chr, coalesce, + col, combine_filters, concat, concat_expr, concat_ws, concat_ws_expr, cos, count, + count_distinct, create_udaf, create_udf, date_part, date_trunc, digest, exists, exp, + expr_rewriter, expr_rewriter::{ normalize_col, normalize_col_with_schemas, normalize_cols, replace_col, rewrite_sort_cols_by_aggs, unnormalize_col, unnormalize_cols, ExprRewritable, @@ -50,11 +50,11 @@ pub use datafusion_expr::{ StringifiedPlan, Subquery, TableScan, ToStringifiedPlan, Union, UserDefinedLogicalNode, Values, }, - lower, lpad, ltrim, max, md5, min, not_exists, not_in_subquery, now, now_expr, - nullif, octet_length, or, power, random, regexp_match, regexp_replace, repeat, - replace, reverse, right, round, rpad, rtrim, scalar_subquery, sha224, sha256, sha384, - sha512, signum, sin, split_part, sqrt, starts_with, strpos, substr, sum, tan, to_hex, - to_timestamp_micros, to_timestamp_millis, to_timestamp_seconds, translate, trim, - trunc, unalias, upper, when, Expr, ExprSchemable, Literal, Operator, + lower, lpad, ltrim, make_array, max, md5, min, not_exists, not_in_subquery, now, + now_expr, nullif, octet_length, or, power, random, regexp_match, regexp_replace, + repeat, replace, reverse, right, round, rpad, rtrim, scalar_subquery, sha224, sha256, + sha384, sha512, signum, sin, split_part, sqrt, starts_with, strpos, substr, sum, tan, + to_hex, to_timestamp_micros, to_timestamp_millis, to_timestamp_seconds, translate, + trim, trunc, unalias, upper, when, Expr, ExprSchemable, Literal, Operator, }; pub use datafusion_optimizer::expr_simplifier::{ExprSimplifiable, SimplifyInfo}; diff --git a/datafusion/core/src/prelude.rs b/datafusion/core/src/prelude.rs index edae225d87474..a4cc8e5c3c927 100644 --- a/datafusion/core/src/prelude.rs +++ b/datafusion/core/src/prelude.rs @@ -31,10 +31,10 @@ pub use crate::execution::options::{ AvroReadOptions, CsvReadOptions, NdJsonReadOptions, ParquetReadOptions, }; pub use crate::logical_plan::{ - approx_percentile_cont, array, ascii, avg, bit_length, btrim, cast, character_length, - chr, coalesce, col, concat, concat_ws, count, create_udf, date_part, date_trunc, - digest, exists, from_unixtime, in_list, in_subquery, initcap, left, length, lit, - lower, lpad, ltrim, max, md5, min, not_exists, not_in_subquery, now, octet_length, + approx_percentile_cont, ascii, avg, bit_length, btrim, cast, character_length, chr, + coalesce, col, concat, concat_ws, count, create_udf, date_part, date_trunc, digest, + exists, from_unixtime, in_list, in_subquery, initcap, left, length, lit, lower, lpad, + ltrim, make_array, max, md5, min, not_exists, not_in_subquery, now, octet_length, random, regexp_match, regexp_replace, repeat, replace, reverse, right, rpad, rtrim, scalar_subquery, sha224, sha256, sha384, sha512, split_part, starts_with, strpos, substr, sum, to_hex, translate, trim, upper, Column, Expr, JoinType, Partitioning, diff --git a/datafusion/core/tests/sql/functions.rs b/datafusion/core/tests/sql/functions.rs index 88c00b45ef1cf..aa5a6725dcf59 100644 --- a/datafusion/core/tests/sql/functions.rs +++ b/datafusion/core/tests/sql/functions.rs @@ -111,8 +111,8 @@ async fn query_concat() -> Result<()> { Ok(()) } -#[tokio::test] -async fn query_array() -> Result<()> { +// Return a session context with table "test" registered with 2 columns +fn array_context() -> SessionContext { let schema = Arc::new(Schema::new(vec![ Field::new("c1", DataType::Utf8, false), Field::new("c2", DataType::Int32, true), @@ -124,43 +124,110 @@ async fn query_array() -> Result<()> { Arc::new(StringArray::from_slice(&["", "a", "aa", "aaa"])), Arc::new(Int32Array::from(vec![Some(0), Some(1), None, Some(3)])), ], - )?; + ) + .unwrap(); - let table = MemTable::try_new(schema, vec![vec![data]])?; + let table = MemTable::try_new(schema, vec![vec![data]]).unwrap(); let ctx = SessionContext::new(); - ctx.register_table("test", Arc::new(table))?; - let sql = "SELECT array(c1, cast(c2 as varchar)) FROM test"; + ctx.register_table("test", Arc::new(table)).unwrap(); + ctx +} + +#[tokio::test] +async fn query_array() { + let ctx = array_context(); + let sql = "SELECT array[c1, cast(c2 as varchar)] FROM test"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+--------------------------------------+", - "| array(test.c1,CAST(test.c2 AS Utf8)) |", - "+--------------------------------------+", - "| [, 0] |", - "| [a, 1] |", - "| [aa, ] |", - "| [aaa, 3] |", - "+--------------------------------------+", + "+----------+", + "| array |", + "+----------+", + "| [, 0] |", + "| [a, 1] |", + "| [aa, ] |", + "| [aaa, 3] |", + "+----------+", + ]; + assert_batches_eq!(expected, &actual); +} + +#[tokio::test] +async fn query_make_array() { + let ctx = array_context(); + let sql = "SELECT make_array(c1, cast(c2 as varchar)) FROM test"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+------------------------------------------+", + "| makearray(test.c1,CAST(test.c2 AS Utf8)) |", + "+------------------------------------------+", + "| [, 0] |", + "| [a, 1] |", + "| [aa, ] |", + "| [aaa, 3] |", + "+------------------------------------------+", ]; assert_batches_eq!(expected, &actual); - Ok(()) } #[tokio::test] -async fn query_array_scalar() -> Result<()> { +async fn query_array_scalar() { let ctx = SessionContext::new(); - let sql = "SELECT array(1, 2, 3);"; + let sql = "SELECT array[1, 2, 3];"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ - "+-----------------------------------+", - "| array(Int64(1),Int64(2),Int64(3)) |", - "+-----------------------------------+", - "| [1, 2, 3] |", - "+-----------------------------------+", + "+-----------+", + "| array |", + "+-----------+", + "| [1, 2, 3] |", + "+-----------+", + ]; + assert_batches_eq!(expected, &actual); + + // alternate syntax format + let sql = "SELECT [1, 2, 3];"; + let actual = execute_to_batches(&ctx, sql).await; + assert_batches_eq!(expected, &actual); +} + +#[tokio::test] +async fn query_array_scalar_bad_types() { + let ctx = SessionContext::new(); + + // no common type to coerce to, should error + let err = plan_and_collect(&ctx, "SELECT [1, true, null]") + .await + .unwrap_err(); + assert_eq!(err.to_string(), "Error during planning: Coercion from [Int64, Boolean, Null] to the signature VariadicEqual failed.",); +} + +#[tokio::test] +async fn query_array_scalar_coerce() { + let ctx = SessionContext::new(); + + // The planner should be able to coerce this to all integers + // https://github.com/apache/arrow-datafusion/issues/3170 + let err = plan_and_collect(&ctx, "SELECT [1, 2, '3']") + .await + .unwrap_err(); + assert_eq!(err.to_string(), "Error during planning: Coercion from [Int64, Int64, Utf8] to the signature VariadicEqual failed.",); +} + +#[tokio::test] +async fn query_make_array_scalar() { + let ctx = SessionContext::new(); + + let sql = "SELECT make_array(1, 2, 3);"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+---------------------------------------+", + "| makearray(Int64(1),Int64(2),Int64(3)) |", + "+---------------------------------------+", + "| [1, 2, 3] |", + "+---------------------------------------+", ]; assert_batches_eq!(expected, &actual); - Ok(()) } #[tokio::test] diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 532699a37cbbb..45214266fccf5 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -71,9 +71,11 @@ pub enum BuiltinScalarFunction { /// trunc Trunc, - // string functions + // array functions /// construct an array from columns - Array, + MakeArray, + + // string functions /// ascii Ascii, /// bit_length @@ -204,7 +206,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Sqrt => Volatility::Immutable, BuiltinScalarFunction::Tan => Volatility::Immutable, BuiltinScalarFunction::Trunc => Volatility::Immutable, - BuiltinScalarFunction::Array => Volatility::Immutable, + BuiltinScalarFunction::MakeArray => Volatility::Immutable, BuiltinScalarFunction::Ascii => Volatility::Immutable, BuiltinScalarFunction::BitLength => Volatility::Immutable, BuiltinScalarFunction::Btrim => Volatility::Immutable, @@ -297,8 +299,10 @@ impl FromStr for BuiltinScalarFunction { // conditional functions "coalesce" => BuiltinScalarFunction::Coalesce, + // array functions + "make_array" => BuiltinScalarFunction::MakeArray, + // string functions - "array" => BuiltinScalarFunction::Array, "ascii" => BuiltinScalarFunction::Ascii, "bit_length" => BuiltinScalarFunction::BitLength, "btrim" => BuiltinScalarFunction::Btrim, diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 09ac0c2870413..1731d42640d4b 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -382,9 +382,9 @@ scalar_expr!(FromUnixtime, from_unixtime, unixtime); unary_scalar_expr!(ArrowTypeof, arrow_typeof, "data type"); /// Returns an array of fixed size with each argument on it. -pub fn array(args: Vec) -> Expr { +pub fn make_array(args: Vec) -> Expr { Expr::ScalarFunction { - fun: built_in_function::BuiltinScalarFunction::Array, + fun: built_in_function::BuiltinScalarFunction::MakeArray, args, } } diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 5cf42fbd21243..1d7de6b651ebc 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -21,8 +21,8 @@ use crate::nullif::SUPPORTED_NULLIF_TYPES; use crate::type_coercion::data_types; use crate::ColumnarValue; use crate::{ - array_expressions, conditional_expressions, struct_expressions, Accumulator, - BuiltinScalarFunction, Signature, TypeSignature, + conditional_expressions, struct_expressions, Accumulator, BuiltinScalarFunction, + Signature, TypeSignature, }; use arrow::datatypes::{DataType, Field, IntervalUnit, TimeUnit}; use datafusion_common::{DataFusionError, Result}; @@ -96,7 +96,7 @@ pub fn return_type( // the return type of the built in function. // Some built-in functions' return type depends on the incoming type. match fun { - BuiltinScalarFunction::Array => Ok(DataType::FixedSizeList( + BuiltinScalarFunction::MakeArray => Ok(DataType::FixedSizeList( Box::new(Field::new("item", input_expr_types[0].clone(), true)), input_expr_types.len() as i32, )), @@ -267,12 +267,8 @@ pub fn return_type( pub fn signature(fun: &BuiltinScalarFunction) -> Signature { // note: the physical expression must accept the type returned by this function or the execution panics. - // for now, the list is small, as we do not have many built-in functions. match fun { - BuiltinScalarFunction::Array => Signature::variadic( - array_expressions::SUPPORTED_ARRAY_TYPES.to_vec(), - fun.volatility(), - ), + BuiltinScalarFunction::MakeArray => Signature::variadic_equal(fun.volatility()), BuiltinScalarFunction::Struct => Signature::variadic( struct_expressions::SUPPORTED_STRUCT_TYPES.to_vec(), fun.volatility(), diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 84e6732e39997..216ccef46d433 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -34,7 +34,10 @@ macro_rules! downcast_vec { }}; } -macro_rules! array { +/// Create an array of FixedSizeList from a set of individual Arrays +/// where each element in the output FixedSizeList is the result of +/// concatenating the corresponding values in the input Arrays +macro_rules! make_fixed_size_list { ($ARGS:expr, $ARRAY_TYPE:ident, $BUILDER_TYPE:ident) => {{ // downcast all arguments to their common format let args = @@ -59,7 +62,7 @@ macro_rules! array { }}; } -fn array_array(args: &[ArrayRef]) -> Result { +fn arrays_to_fixed_size_list_array(args: &[ArrayRef]) -> Result { // do not accept 0 arguments. if args.is_empty() { return Err(DataFusionError::Internal( @@ -68,19 +71,21 @@ fn array_array(args: &[ArrayRef]) -> Result { } let res = match args[0].data_type() { - DataType::Utf8 => array!(args, StringArray, StringBuilder), - DataType::LargeUtf8 => array!(args, LargeStringArray, LargeStringBuilder), - DataType::Boolean => array!(args, BooleanArray, BooleanBuilder), - DataType::Float32 => array!(args, Float32Array, Float32Builder), - DataType::Float64 => array!(args, Float64Array, Float64Builder), - DataType::Int8 => array!(args, Int8Array, Int8Builder), - DataType::Int16 => array!(args, Int16Array, Int16Builder), - DataType::Int32 => array!(args, Int32Array, Int32Builder), - DataType::Int64 => array!(args, Int64Array, Int64Builder), - DataType::UInt8 => array!(args, UInt8Array, UInt8Builder), - DataType::UInt16 => array!(args, UInt16Array, UInt16Builder), - DataType::UInt32 => array!(args, UInt32Array, UInt32Builder), - DataType::UInt64 => array!(args, UInt64Array, UInt64Builder), + DataType::Utf8 => make_fixed_size_list!(args, StringArray, StringBuilder), + DataType::LargeUtf8 => { + make_fixed_size_list!(args, LargeStringArray, LargeStringBuilder) + } + DataType::Boolean => make_fixed_size_list!(args, BooleanArray, BooleanBuilder), + DataType::Float32 => make_fixed_size_list!(args, Float32Array, Float32Builder), + DataType::Float64 => make_fixed_size_list!(args, Float64Array, Float64Builder), + DataType::Int8 => make_fixed_size_list!(args, Int8Array, Int8Builder), + DataType::Int16 => make_fixed_size_list!(args, Int16Array, Int16Builder), + DataType::Int32 => make_fixed_size_list!(args, Int32Array, Int32Builder), + DataType::Int64 => make_fixed_size_list!(args, Int64Array, Int64Builder), + DataType::UInt8 => make_fixed_size_list!(args, UInt8Array, UInt8Builder), + DataType::UInt16 => make_fixed_size_list!(args, UInt16Array, UInt16Builder), + DataType::UInt32 => make_fixed_size_list!(args, UInt32Array, UInt32Builder), + DataType::UInt64 => make_fixed_size_list!(args, UInt64Array, UInt64Builder), data_type => { return Err(DataFusionError::NotImplemented(format!( "Array is not implemented for type '{:?}'.", @@ -92,7 +97,7 @@ fn array_array(args: &[ArrayRef]) -> Result { } /// put values in an array. -pub fn array(values: &[ColumnarValue]) -> Result { +pub fn make_array(values: &[ColumnarValue]) -> Result { let arrays: Vec = values .iter() .map(|x| match x { @@ -100,5 +105,7 @@ pub fn array(values: &[ColumnarValue]) -> Result { ColumnarValue::Scalar(scalar) => scalar.to_array().clone(), }) .collect(); - Ok(ColumnarValue::Array(array_array(arrays.as_slice())?)) + Ok(ColumnarValue::Array(arrays_to_fixed_size_list_array( + arrays.as_slice(), + )?)) } diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index dde0ee0a06bef..c23a54e908b6d 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -322,7 +322,7 @@ pub fn create_physical_fun( } // string functions - BuiltinScalarFunction::Array => Arc::new(array_expressions::array), + BuiltinScalarFunction::MakeArray => Arc::new(array_expressions::make_array), BuiltinScalarFunction::Struct => Arc::new(struct_expressions::struct_expr), BuiltinScalarFunction::Ascii => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 9f974aa4c055d..7df98f6dddbe1 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -1667,7 +1667,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { fractional_seconds_precision, ), - SQLExpr::Array(arr) => self.sql_array_literal(arr.elem, schema), + SQLExpr::Array(arr) => self.sql_array_expr(arr.elem, schema), SQLExpr::Identifier(id) => { if id.value.starts_with('@') { @@ -2360,50 +2360,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .is_ok() } - fn sql_array_literal( - &self, - elements: Vec, - schema: &DFSchema, - ) -> Result { - let mut values = Vec::with_capacity(elements.len()); - - for element in elements { - let value = - self.sql_expr_to_logical_expr(element, schema, &mut HashMap::new())?; - match value { - Expr::Literal(scalar) => { - values.push(scalar); - } - _ => { - return Err(DataFusionError::NotImplemented(format!( - "Arrays with elements other than literal are not supported: {}", - value - ))); - } - } - } - - let data_types: HashSet = - values.iter().map(|e| e.get_datatype()).collect(); - - if data_types.is_empty() { - Ok(Expr::Literal(ScalarValue::List( - None, - Box::new(Field::new("item", DataType::Utf8, true)), - ))) - } else if data_types.len() > 1 { - Err(DataFusionError::NotImplemented(format!( - "Arrays with different types are not supported: {:?}", - data_types, - ))) - } else { - let data_type = values[0].get_datatype(); + fn sql_array_expr(&self, elements: Vec, schema: &DFSchema) -> Result { + let args: Vec = elements + .into_iter() + .map(|expr| self.sql_expr_to_logical_expr(expr, schema, &mut HashMap::new())) + .collect::>()?; - Ok(Expr::Literal(ScalarValue::List( - Some(values), - Box::new(Field::new("item", data_type, true)), - ))) - } + let fun = BuiltinScalarFunction::MakeArray; + // follow postgres convention and name result "array" + Ok(Expr::ScalarFunction { fun, args }.alias("array")) } } @@ -2594,7 +2559,6 @@ fn parse_sql_number(n: &str) -> Result { #[cfg(test)] mod tests { use super::*; - use crate::assert_contains; use sqlparser::dialect::{Dialect, GenericDialect, MySqlDialect}; use std::any::Any; @@ -3330,25 +3294,12 @@ mod tests { ); } - #[test] - fn select_array_no_common_type() { - let sql = "SELECT [1, true, null]"; - let err = logical_plan(sql).expect_err("query should have failed"); - - // HashSet doesn't guarantee order - assert_contains!( - err.to_string(), - r#"Arrays with different types are not supported: "# - ); - } - #[test] fn select_array_non_literal_type() { - let sql = "SELECT [now()]"; - let err = logical_plan(sql).expect_err("query should have failed"); - assert_eq!( - r#"NotImplemented("Arrays with elements other than literal are not supported: now()")"#, - format!("{:?}", err) + quick_test( + "SELECT [now()]", + "Projection: makearray(now()) AS array\ + \n EmptyRelation", ); }