From 2b30e33c0edc5f564cd9330101533c6243acb2a9 Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Sun, 30 Aug 2020 06:57:57 +0200 Subject: [PATCH 1/4] Added support for array physical expression. --- rust/datafusion/README.md | 2 + rust/datafusion/src/logical_plan/mod.rs | 8 + .../src/physical_plan/array_expressions.rs | 108 +++++++++++ .../datafusion/src/physical_plan/functions.rs | 82 ++++++++- rust/datafusion/src/physical_plan/mod.rs | 1 + rust/datafusion/src/prelude.rs | 2 +- rust/datafusion/tests/sql.rs | 172 ++++++++++-------- 7 files changed, 299 insertions(+), 76 deletions(-) create mode 100644 rust/datafusion/src/physical_plan/array_expressions.rs diff --git a/rust/datafusion/README.md b/rust/datafusion/README.md index bf3c60f3f43ac..5405f269d2813 100644 --- a/rust/datafusion/README.md +++ b/rust/datafusion/README.md @@ -61,6 +61,8 @@ DataFusion includes a simple command-line interactive SQL utility. See the [CLI - [ ] Basic date functions - [ ] Basic time functions - [x] Basic timestamp functions +- nested functions + - [x] Array of columns - [x] Sorting - [ ] Nested types - [ ] Lists diff --git a/rust/datafusion/src/logical_plan/mod.rs b/rust/datafusion/src/logical_plan/mod.rs index e0f5d9d65d86f..e37bd1003d66e 100644 --- a/rust/datafusion/src/logical_plan/mod.rs +++ b/rust/datafusion/src/logical_plan/mod.rs @@ -623,6 +623,14 @@ pub fn concat(args: Vec) -> Expr { } } +/// returns an array of fixed size with each argument on it. +pub fn array(args: Vec) -> Expr { + Expr::ScalarFunction { + fun: functions::BuiltinScalarFunction::Array, + args, + } +} + /// Creates a new UDF with a specific signature and specific return type. /// This is a helper function to create a new UDF. /// The function `create_udf` returns a subset of all possible `ScalarFunction`: diff --git a/rust/datafusion/src/physical_plan/array_expressions.rs b/rust/datafusion/src/physical_plan/array_expressions.rs new file mode 100644 index 0000000000000..79fb64e795ae7 --- /dev/null +++ b/rust/datafusion/src/physical_plan/array_expressions.rs @@ -0,0 +1,108 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Array expressions + +use crate::error::{ExecutionError, Result}; +use arrow::array::*; +use arrow::datatypes::DataType; +use std::sync::Arc; + +macro_rules! downcast_vec { + ($ARGS:expr, $ARRAY_TYPE:ident) => {{ + $ARGS + .iter() + .map(|e| match e.as_any().downcast_ref::<$ARRAY_TYPE>() { + Some(array) => Ok(array), + _ => Err(ExecutionError::General("failed to downcast".to_string())), + }) + }}; +} + +macro_rules! array { + ($ARGS:expr, $ARRAY_TYPE:ident, $BUILDER_TYPE:ident) => {{ + // downcast all arguments to their common format + let args = + downcast_vec!($ARGS, $ARRAY_TYPE).collect::>>()?; + + let mut builder = FixedSizeListBuilder::<$BUILDER_TYPE>::new( + <$BUILDER_TYPE>::new(args[0].len()), + args.len() as i32, + ); + // for each entry in the array + for index in 0..args[0].len() { + for arg in &args { + if arg.is_null(index) { + builder.values().append_null()?; + } else { + builder.values().append_value(arg.value(index))?; + } + } + builder.append(true)?; + } + Ok(Arc::new(builder.finish())) + }}; +} + +/// put values in an array. +pub fn array(args: &[ArrayRef]) -> Result { + // do not accept 0 arguments. + if args.len() == 0 { + return Err(ExecutionError::InternalError( + "array requires at least one argument".to_string(), + )); + } + + match args[0].data_type() { + DataType::Utf8 => array!(args, StringArray, StringBuilder), + DataType::LargeUtf8 => array!(args, LargeStringArray, LargeStringBuilder), + DataType::Boolean => array!(args, BooleanArray, BooleanBuilder), + DataType::Float32 => array!(args, Float32Array, Float32Builder), + DataType::Float64 => array!(args, Float64Array, Float64Builder), + DataType::Int8 => array!(args, Int8Array, Int8Builder), + DataType::Int16 => array!(args, Int16Array, Int16Builder), + DataType::Int32 => array!(args, Int32Array, Int32Builder), + DataType::Int64 => array!(args, Int64Array, Int64Builder), + DataType::UInt8 => array!(args, UInt8Array, UInt8Builder), + DataType::UInt16 => array!(args, UInt16Array, UInt16Builder), + DataType::UInt32 => array!(args, UInt32Array, UInt32Builder), + DataType::UInt64 => array!(args, UInt64Array, UInt64Builder), + data_type => Err(ExecutionError::NotImplemented(format!( + "Array is not implemented for type '{:?}'.", + data_type + ))), + } +} + +/// Currently supported types by the array function. +/// The order of these types correspond to the order on which coercion applies +/// This should thus be from least informative to most informative +pub static SUPPORTED_ARRAY_TYPES: &'static [DataType] = &[ + DataType::Boolean, + DataType::UInt8, + DataType::UInt16, + DataType::UInt32, + DataType::UInt64, + DataType::Int8, + DataType::Int16, + DataType::Int32, + DataType::Int64, + DataType::Float32, + DataType::Float64, + DataType::Utf8, + DataType::LargeUtf8, +]; diff --git a/rust/datafusion/src/physical_plan/functions.rs b/rust/datafusion/src/physical_plan/functions.rs index af02c6d2cefa7..ace965dfbb6b5 100644 --- a/rust/datafusion/src/physical_plan/functions.rs +++ b/rust/datafusion/src/physical_plan/functions.rs @@ -34,6 +34,7 @@ use super::{ PhysicalExpr, }; use crate::error::{ExecutionError, Result}; +use crate::physical_plan::array_expressions; use crate::physical_plan::datetime_expressions; use crate::physical_plan::math_expressions; use crate::physical_plan::string_expressions; @@ -118,6 +119,8 @@ pub enum BuiltinScalarFunction { Concat, /// to_timestamp ToTimestamp, + /// construct an array from columns + Array, } impl fmt::Display for BuiltinScalarFunction { @@ -151,6 +154,7 @@ impl FromStr for BuiltinScalarFunction { "length" => BuiltinScalarFunction::Length, "concat" => BuiltinScalarFunction::Concat, "to_timestamp" => BuiltinScalarFunction::ToTimestamp, + "array" => BuiltinScalarFunction::Array, _ => { return Err(ExecutionError::General(format!( "There is no built-in function named {}", @@ -189,6 +193,10 @@ pub fn return_type( BuiltinScalarFunction::ToTimestamp => { Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)) } + BuiltinScalarFunction::Array => Ok(DataType::FixedSizeList( + Box::new(arg_types[0].clone()), + arg_types.len() as i32, + )), _ => Ok(DataType::Float64), } } @@ -225,6 +233,7 @@ pub fn create_physical_expr( BuiltinScalarFunction::ToTimestamp => { |args| Ok(Arc::new(datetime_expressions::to_timestamp(args)?)) } + BuiltinScalarFunction::Array => |args| Ok(array_expressions::array(args)?), }); // coerce let args = coerce(args, input_schema, &signature(fun))?; @@ -251,6 +260,9 @@ fn signature(fun: &BuiltinScalarFunction) -> Signature { BuiltinScalarFunction::Length => Signature::Uniform(1, vec![DataType::Utf8]), BuiltinScalarFunction::Concat => Signature::Variadic(vec![DataType::Utf8]), BuiltinScalarFunction::ToTimestamp => Signature::Uniform(1, vec![DataType::Utf8]), + BuiltinScalarFunction::Array => { + Signature::Variadic(array_expressions::SUPPORTED_ARRAY_TYPES.to_vec()) + } // math expressions expect 1 argument of type f64 or f32 // priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we // return the best approximation for it (in f64). @@ -341,10 +353,7 @@ mod tests { error::Result, logical_plan::ScalarValue, physical_plan::expressions::lit, }; use arrow::{ - array::{ - ArrayRef, Float64Array, Int32Array, PrimitiveArrayOps, StringArray, - StringArrayOps, - }, + array::{ArrayRef, FixedSizeListArray, Float64Array, Int32Array, StringArray, PrimitiveArrayOps, StringArrayOps}, datatypes::Field, record_batch::RecordBatch, }; @@ -432,4 +441,69 @@ mod tests { Ok(()) } } + + fn generic_test_array( + value1: ScalarValue, + value2: ScalarValue, + expected_type: DataType, + expected: &str, + ) -> Result<()> { + // any type works here: we evaluate against a literal of `value` + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let columns: Vec = vec![Arc::new(Int32Array::from(vec![1]))]; + + let expr = create_physical_expr( + &BuiltinScalarFunction::Array, + &vec![lit(value1.clone()), lit(value2.clone())], + &schema, + )?; + + // type is correct + assert_eq!( + expr.data_type(&schema)?, + // type equals to a common coercion + DataType::FixedSizeList(Box::new(expected_type), 2) + ); + + // evaluate works + let result = + expr.evaluate(&RecordBatch::try_new(Arc::new(schema.clone()), columns)?)?; + + // downcast works + let result = result + .as_any() + .downcast_ref::() + .unwrap(); + + // value is correct + assert_eq!(format!("{:?}", result.value(0)), expected); + + Ok(()) + } + + #[test] + fn test_array() -> Result<()> { + generic_test_array( + ScalarValue::Utf8("aa".to_string()), + ScalarValue::Utf8("aa".to_string()), + DataType::Utf8, + "StringArray\n[\n \"aa\",\n \"aa\",\n]", + )?; + + // different types, to validate that casting happens + generic_test_array( + ScalarValue::UInt32(1), + ScalarValue::UInt64(1), + DataType::UInt64, + "PrimitiveArray\n[\n 1,\n 1,\n]", + )?; + + // different types (another order), to validate that casting happens + generic_test_array( + ScalarValue::UInt64(1), + ScalarValue::UInt32(1), + DataType::UInt64, + "PrimitiveArray\n[\n 1,\n 1,\n]", + ) + } } diff --git a/rust/datafusion/src/physical_plan/mod.rs b/rust/datafusion/src/physical_plan/mod.rs index 99ce8d6d4223b..f71b2792ff401 100644 --- a/rust/datafusion/src/physical_plan/mod.rs +++ b/rust/datafusion/src/physical_plan/mod.rs @@ -131,6 +131,7 @@ pub trait Accumulator: Debug { } pub mod aggregates; +pub mod array_expressions; pub mod common; pub mod csv; pub mod datetime_expressions; diff --git a/rust/datafusion/src/prelude.rs b/rust/datafusion/src/prelude.rs index 1b68347d6e26a..aac2ebf71f1f6 100644 --- a/rust/datafusion/src/prelude.rs +++ b/rust/datafusion/src/prelude.rs @@ -28,6 +28,6 @@ pub use crate::dataframe::DataFrame; pub use crate::execution::context::{ExecutionConfig, ExecutionContext}; pub use crate::logical_plan::{ - avg, col, concat, count, create_udf, length, lit, max, min, sum, + array, avg, col, concat, count, create_udf, length, lit, max, min, sum, }; pub use crate::physical_plan::csv::CsvReadOptions; diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index 9de59fbd405a5..d4629a04b6996 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -609,6 +609,74 @@ fn execute(ctx: &mut ExecutionContext, sql: &str) -> Vec { result_str(&results) } +fn array_str(array: &Arc, row_index: usize) -> String { + if array.is_null(row_index) { + return "NULL".to_string(); + } + + match array.data_type() { + DataType::Int8 => { + let array = array.as_any().downcast_ref::().unwrap(); + format!("{:?}", array.value(row_index)) + } + DataType::Int16 => { + let array = array.as_any().downcast_ref::().unwrap(); + format!("{:?}", array.value(row_index)) + } + DataType::Int32 => { + let array = array.as_any().downcast_ref::().unwrap(); + format!("{:?}", array.value(row_index)) + } + DataType::Int64 => { + let array = array.as_any().downcast_ref::().unwrap(); + format!("{:?}", array.value(row_index)) + } + DataType::UInt8 => { + let array = array.as_any().downcast_ref::().unwrap(); + format!("{:?}", array.value(row_index)) + } + DataType::UInt16 => { + let array = array.as_any().downcast_ref::().unwrap(); + format!("{:?}", array.value(row_index)) + } + DataType::UInt32 => { + let array = array.as_any().downcast_ref::().unwrap(); + format!("{:?}", array.value(row_index)) + } + DataType::UInt64 => { + let array = array.as_any().downcast_ref::().unwrap(); + format!("{:?}", array.value(row_index)) + } + DataType::Float32 => { + let array = array.as_any().downcast_ref::().unwrap(); + format!("{:?}", array.value(row_index)) + } + DataType::Float64 => { + let array = array.as_any().downcast_ref::().unwrap(); + format!("{:?}", array.value(row_index)) + } + DataType::Utf8 => { + let array = array.as_any().downcast_ref::().unwrap(); + format!("{:?}", array.value(row_index)) + } + DataType::Boolean => { + let array = array.as_any().downcast_ref::().unwrap(); + format!("{:?}", array.value(row_index)) + } + DataType::FixedSizeList(_, n) => { + let array = array.as_any().downcast_ref::().unwrap(); + + let mut r = Vec::with_capacity(*n as usize); + for i in 0..*n { + let array = array.value(row_index); + r.push(array_str(&array, i as usize)); + } + format!("[{}]", r.join(",")) + } + _ => "???".to_string(), + } +} + fn result_str(results: &[RecordBatch]) -> Vec { let mut result = vec![]; for batch in results { @@ -620,76 +688,7 @@ fn result_str(results: &[RecordBatch]) -> Vec { } let column = batch.column(column_index); - match column.data_type() { - DataType::Int8 => { - let array = column.as_any().downcast_ref::().unwrap(); - str.push_str(&format!("{:?}", array.value(row_index))); - } - DataType::Int16 => { - let array = column.as_any().downcast_ref::().unwrap(); - str.push_str(&format!("{:?}", array.value(row_index))); - } - DataType::Int32 => { - let array = column.as_any().downcast_ref::().unwrap(); - str.push_str(&format!("{:?}", array.value(row_index))); - } - DataType::Int64 => { - let array = column.as_any().downcast_ref::().unwrap(); - str.push_str(&format!("{:?}", array.value(row_index))); - } - DataType::UInt8 => { - let array = column.as_any().downcast_ref::().unwrap(); - str.push_str(&format!("{:?}", array.value(row_index))); - } - DataType::UInt16 => { - let array = - column.as_any().downcast_ref::().unwrap(); - str.push_str(&format!("{:?}", array.value(row_index))); - } - DataType::UInt32 => { - let array = - column.as_any().downcast_ref::().unwrap(); - str.push_str(&format!("{:?}", array.value(row_index))); - } - DataType::UInt64 => { - let array = - column.as_any().downcast_ref::().unwrap(); - str.push_str(&format!("{:?}", array.value(row_index))); - } - DataType::Float32 => { - let array = - column.as_any().downcast_ref::().unwrap(); - str.push_str(&format!("{:?}", array.value(row_index))); - } - DataType::Float64 => { - let array = - column.as_any().downcast_ref::().unwrap(); - str.push_str(&format!("{:?}", array.value(row_index))); - } - DataType::Utf8 => { - let array = - column.as_any().downcast_ref::().unwrap(); - let s = if array.is_null(row_index) { - "NULL" - } else { - array.value(row_index) - }; - - str.push_str(&format!("{:?}", s)); - } - DataType::Boolean => { - let array = - column.as_any().downcast_ref::().unwrap(); - let s = if array.is_null(row_index) { - "NULL".to_string() - } else { - format!("{:?}", array.value(row_index)) - }; - - str.push_str(&s); - } - _ => str.push_str("???"), - } + str.push_str(&array_str(column, row_index)); } result.push(str); } @@ -762,7 +761,38 @@ fn query_concat() -> Result<()> { ctx.register_table("test", Box::new(table)); let sql = "SELECT concat(c1, '-hi-', cast(c2 as varchar)) FROM test"; let actual = execute(&mut ctx, sql); - let expected = vec!["\"-hi-0\"", "\"a-hi-1\"", "\"NULL\"", "\"aaa-hi-3\""]; + let expected = vec!["\"-hi-0\"", "\"a-hi-1\"", "NULL", "\"aaa-hi-3\""]; + assert_eq!(expected, actual); + Ok(()) +} + +#[test] +fn query_array() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Utf8, false), + Field::new("c2", DataType::Int32, true), + ])); + + let data = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(StringArray::from(vec!["", "a", "aa", "aaa"])), + Arc::new(Int32Array::from(vec![Some(0), Some(1), None, Some(3)])), + ], + )?; + + let table = MemTable::new(schema, vec![vec![data]])?; + + let mut ctx = ExecutionContext::new(); + ctx.register_table("test", Box::new(table)); + let sql = "SELECT array(c1, cast(c2 as varchar)) FROM test"; + let actual = execute(&mut ctx, sql); + let expected = vec![ + "[\"\",\"0\"]", + "[\"a\",\"1\"]", + "[\"aa\",NULL]", + "[\"aaa\",\"3\"]", + ]; assert_eq!(expected, actual); Ok(()) } From 2a102e13b996f34a45ab5860cbeca7c6ab64f60b Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Thu, 17 Sep 2020 04:53:45 +0200 Subject: [PATCH 2/4] Simplified code. --- rust/datafusion/tests/sql.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index d4629a04b6996..43dbec73e14f4 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -665,10 +665,10 @@ fn array_str(array: &Arc, row_index: usize) -> String { } DataType::FixedSizeList(_, n) => { let array = array.as_any().downcast_ref::().unwrap(); + let array = array.value(row_index); let mut r = Vec::with_capacity(*n as usize); for i in 0..*n { - let array = array.value(row_index); r.push(array_str(&array, i as usize)); } format!("[{}]", r.join(",")) From 303a90f70d7d6eaf85caaac84a4f30f0ebbe4cd6 Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Fri, 18 Sep 2020 04:06:23 +0200 Subject: [PATCH 3/4] Added comment to help readability. --- rust/datafusion/tests/sql.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index 43dbec73e14f4..ce2600081eb04 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -609,10 +609,14 @@ fn execute(ctx: &mut ExecutionContext, sql: &str) -> Vec { result_str(&results) } + +/// Converts an array's value at `row_index` to a string. fn array_str(array: &Arc, row_index: usize) -> String { if array.is_null(row_index) { return "NULL".to_string(); } + // beyond this point, we can assume that `array...downcast().value(row_index)` is valid, + // due to the `if` above. match array.data_type() { DataType::Int8 => { From 1aa71291a3e76c7957283597c69d54546408e9af Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Fri, 18 Sep 2020 04:12:33 +0200 Subject: [PATCH 4/4] Formatting. --- rust/datafusion/src/physical_plan/functions.rs | 5 ++++- rust/datafusion/tests/sql.rs | 1 - 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/rust/datafusion/src/physical_plan/functions.rs b/rust/datafusion/src/physical_plan/functions.rs index ace965dfbb6b5..95bd252e1db90 100644 --- a/rust/datafusion/src/physical_plan/functions.rs +++ b/rust/datafusion/src/physical_plan/functions.rs @@ -353,7 +353,10 @@ mod tests { error::Result, logical_plan::ScalarValue, physical_plan::expressions::lit, }; use arrow::{ - array::{ArrayRef, FixedSizeListArray, Float64Array, Int32Array, StringArray, PrimitiveArrayOps, StringArrayOps}, + array::{ + ArrayRef, FixedSizeListArray, Float64Array, Int32Array, PrimitiveArrayOps, + StringArray, StringArrayOps, + }, datatypes::Field, record_batch::RecordBatch, }; diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index ce2600081eb04..a87dc79ea33bf 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -609,7 +609,6 @@ fn execute(ctx: &mut ExecutionContext, sql: &str) -> Vec { result_str(&results) } - /// Converts an array's value at `row_index` to a string. fn array_str(array: &Arc, row_index: usize) -> String { if array.is_null(row_index) {