diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 0a99c331826c..89e82fa952bb 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1340,6 +1340,92 @@ mod tests { use super::*; + async fn assert_logical_expr_schema_eq_physical_expr_schema( + df: DataFrame, + ) -> Result<()> { + let logical_expr_dfschema = df.schema(); + let logical_expr_schema = SchemaRef::from(logical_expr_dfschema.to_owned()); + let batches = df.collect().await?; + let physical_expr_schema = batches[0].schema(); + assert_eq!(logical_expr_schema, physical_expr_schema); + Ok(()) + } + + #[tokio::test] + async fn test_array_agg_ord_schema() -> Result<()> { + let ctx = SessionContext::new(); + + let create_table_query = r#" + CREATE TABLE test_table ( + "double_field" DOUBLE, + "string_field" VARCHAR + ) AS VALUES + (1.0, 'a'), + (2.0, 'b'), + (3.0, 'c') + "#; + ctx.sql(create_table_query).await?; + + let query = r#"SELECT + array_agg("double_field" ORDER BY "string_field") as "double_field", + array_agg("string_field" ORDER BY "string_field") as "string_field" + FROM test_table"#; + + let result = ctx.sql(query).await?; + assert_logical_expr_schema_eq_physical_expr_schema(result).await?; + Ok(()) + } + + #[tokio::test] + async fn test_array_agg_schema() -> Result<()> { + let ctx = SessionContext::new(); + + let create_table_query = r#" + CREATE TABLE test_table ( + "double_field" DOUBLE, + "string_field" VARCHAR + ) AS VALUES + (1.0, 'a'), + (2.0, 'b'), + (3.0, 'c') + "#; + ctx.sql(create_table_query).await?; + + let query = r#"SELECT + array_agg("double_field") as "double_field", + array_agg("string_field") as "string_field" + FROM test_table"#; + + let result = ctx.sql(query).await?; + assert_logical_expr_schema_eq_physical_expr_schema(result).await?; + Ok(()) + } + + #[tokio::test] + async fn test_array_agg_distinct_schema() -> Result<()> { + let ctx = SessionContext::new(); + + let create_table_query = r#" + CREATE TABLE test_table ( + "double_field" DOUBLE, + "string_field" VARCHAR + ) AS VALUES + (1.0, 'a'), + (2.0, 'b'), + (2.0, 'a') + "#; + ctx.sql(create_table_query).await?; + + let query = r#"SELECT + array_agg(distinct "double_field") as "double_field", + array_agg(distinct "string_field") as "string_field" + FROM test_table"#; + + let result = ctx.sql(query).await?; + assert_logical_expr_schema_eq_physical_expr_schema(result).await?; + Ok(()) + } + #[tokio::test] async fn select_columns() -> Result<()> { // build plan using Table API diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs b/datafusion/physical-expr/src/aggregate/array_agg.rs index 4dccbfef07f8..91d5c867d312 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg.rs @@ -34,9 +34,14 @@ use std::sync::Arc; /// ARRAY_AGG aggregate expression #[derive(Debug)] pub struct ArrayAgg { + /// Column name name: String, + /// The DataType for the input expression input_data_type: DataType, + /// The input expression expr: Arc, + /// If the input expression can have NULLs + nullable: bool, } impl ArrayAgg { @@ -45,11 +50,13 @@ impl ArrayAgg { expr: Arc, name: impl Into, data_type: DataType, + nullable: bool, ) -> Self { Self { name: name.into(), - expr, input_data_type: data_type, + expr, + nullable, } } } @@ -62,8 +69,9 @@ impl AggregateExpr for ArrayAgg { fn field(&self) -> Result { Ok(Field::new_list( &self.name, + // This should be the same as return type of AggregateFunction::ArrayAgg Field::new("item", self.input_data_type.clone(), true), - false, + self.nullable, )) } @@ -77,7 +85,7 @@ impl AggregateExpr for ArrayAgg { Ok(vec![Field::new_list( format_state_name(&self.name, "array_agg"), Field::new("item", self.input_data_type.clone(), true), - false, + self.nullable, )]) } @@ -184,7 +192,6 @@ mod tests { use super::*; use crate::expressions::col; use crate::expressions::tests::aggregate; - use crate::generic_test_op; use arrow::array::ArrayRef; use arrow::array::Int32Array; use arrow::datatypes::*; @@ -195,6 +202,30 @@ mod tests { use datafusion_common::DataFusionError; use datafusion_common::Result; + macro_rules! test_op { + ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr) => { + test_op!($ARRAY, $DATATYPE, $OP, $EXPECTED, $EXPECTED.data_type()) + }; + ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr, $EXPECTED_DATATYPE:expr) => {{ + let schema = Schema::new(vec![Field::new("a", $DATATYPE, true)]); + + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![$ARRAY])?; + + let agg = Arc::new(<$OP>::new( + col("a", &schema)?, + "bla".to_string(), + $EXPECTED_DATATYPE, + true, + )); + let actual = aggregate(&batch, agg)?; + let expected = ScalarValue::from($EXPECTED); + + assert_eq!(expected, actual); + + Ok(()) as Result<(), DataFusionError> + }}; + } + #[test] fn array_agg_i32() -> Result<()> { let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); @@ -208,7 +239,7 @@ mod tests { ])]); let list = ScalarValue::List(Arc::new(list)); - generic_test_op!(a, DataType::Int32, ArrayAgg, list, DataType::Int32) + test_op!(a, DataType::Int32, ArrayAgg, list, DataType::Int32) } #[test] @@ -264,7 +295,7 @@ mod tests { let array = ScalarValue::iter_to_array(vec![l1, l2, l3]).unwrap(); - generic_test_op!( + test_op!( array, DataType::List(Arc::new(Field::new_list( "item", diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs index 9b391b0c42cf..1efae424cc69 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs @@ -40,6 +40,8 @@ pub struct DistinctArrayAgg { input_data_type: DataType, /// The input expression expr: Arc, + /// If the input expression can have NULLs + nullable: bool, } impl DistinctArrayAgg { @@ -48,12 +50,14 @@ impl DistinctArrayAgg { expr: Arc, name: impl Into, input_data_type: DataType, + nullable: bool, ) -> Self { let name = name.into(); Self { name, - expr, input_data_type, + expr, + nullable, } } } @@ -67,8 +71,9 @@ impl AggregateExpr for DistinctArrayAgg { fn field(&self) -> Result { Ok(Field::new_list( &self.name, + // This should be the same as return type of AggregateFunction::ArrayAgg Field::new("item", self.input_data_type.clone(), true), - false, + self.nullable, )) } @@ -82,7 +87,7 @@ impl AggregateExpr for DistinctArrayAgg { Ok(vec![Field::new_list( format_state_name(&self.name, "distinct_array_agg"), Field::new("item", self.input_data_type.clone(), true), - false, + self.nullable, )]) } @@ -238,6 +243,7 @@ mod tests { col("a", &schema)?, "bla".to_string(), datatype, + true, )); let actual = aggregate(&batch, agg)?; @@ -255,6 +261,7 @@ mod tests { col("a", &schema)?, "bla".to_string(), datatype, + true, )); let mut accum1 = agg.create_accumulator()?; diff --git a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs index a53d53107add..9ca83a781a01 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs @@ -48,10 +48,17 @@ use itertools::izip; /// and that can merge aggregations from multiple partitions. #[derive(Debug)] pub struct OrderSensitiveArrayAgg { + /// Column name name: String, + /// The DataType for the input expression input_data_type: DataType, - order_by_data_types: Vec, + /// The input expression expr: Arc, + /// If the input expression can have NULLs + nullable: bool, + /// Ordering data types + order_by_data_types: Vec, + /// Ordering requirement ordering_req: LexOrdering, } @@ -61,13 +68,15 @@ impl OrderSensitiveArrayAgg { expr: Arc, name: impl Into, input_data_type: DataType, + nullable: bool, order_by_data_types: Vec, ordering_req: LexOrdering, ) -> Self { Self { name: name.into(), - expr, input_data_type, + expr, + nullable, order_by_data_types, ordering_req, } @@ -82,8 +91,9 @@ impl AggregateExpr for OrderSensitiveArrayAgg { fn field(&self) -> Result { Ok(Field::new_list( &self.name, + // This should be the same as return type of AggregateFunction::ArrayAgg Field::new("item", self.input_data_type.clone(), true), - false, + self.nullable, )) } @@ -99,13 +109,13 @@ impl AggregateExpr for OrderSensitiveArrayAgg { let mut fields = vec![Field::new_list( format_state_name(&self.name, "array_agg"), Field::new("item", self.input_data_type.clone(), true), - false, + self.nullable, // This should be the same as field() )]; let orderings = ordering_fields(&self.ordering_req, &self.order_by_data_types); fields.push(Field::new_list( format_state_name(&self.name, "array_agg_orderings"), Field::new("item", DataType::Struct(Fields::from(orderings)), true), - false, + self.nullable, )); Ok(fields) } diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index 6568457bc234..596197b4eebe 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -114,13 +114,16 @@ pub fn create_aggregate_expr( ), (AggregateFunction::ArrayAgg, false) => { let expr = input_phy_exprs[0].clone(); + let nullable = expr.nullable(input_schema)?; + if ordering_req.is_empty() { - Arc::new(expressions::ArrayAgg::new(expr, name, data_type)) + Arc::new(expressions::ArrayAgg::new(expr, name, data_type, nullable)) } else { Arc::new(expressions::OrderSensitiveArrayAgg::new( expr, name, data_type, + nullable, ordering_types, ordering_req.to_vec(), )) @@ -132,10 +135,13 @@ pub fn create_aggregate_expr( "ARRAY_AGG(DISTINCT ORDER BY a ASC) order-sensitive aggregations are not available" ); } + let expr = input_phy_exprs[0].clone(); + let is_expr_nullable = expr.nullable(input_schema)?; Arc::new(expressions::DistinctArrayAgg::new( - input_phy_exprs[0].clone(), + expr, name, data_type, + is_expr_nullable, )) } (AggregateFunction::Min, _) => Arc::new(expressions::Min::new( @@ -432,8 +438,8 @@ mod tests { assert_eq!( Field::new_list( "c1", - Field::new("item", data_type.clone(), true,), - false, + Field::new("item", data_type.clone(), true), + true, ), result_agg_phy_exprs.field().unwrap() ); @@ -471,8 +477,8 @@ mod tests { assert_eq!( Field::new_list( "c1", - Field::new("item", data_type.clone(), true,), - false, + Field::new("item", data_type.clone(), true), + true, ), result_agg_phy_exprs.field().unwrap() );