From c082f3b11ab360d3a0987426a816b624abcd8e7d Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sun, 5 Nov 2023 11:31:29 +0800 Subject: [PATCH 01/10] fix schema Signed-off-by: jayzhan211 --- datafusion/core/src/dataframe/mod.rs | 29 ++++++++++++++ .../physical-expr/src/aggregate/array_agg.rs | 36 +++++++++++++++--- .../src/aggregate/array_agg_ordered.rs | 19 ++++++--- .../physical-expr/src/aggregate/build_in.rs | 10 ++++- test_array_agg.parquet | Bin 0 -> 2350 bytes 5 files changed, 82 insertions(+), 12 deletions(-) create mode 100644 test_array_agg.parquet diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 0a99c331826c..ecdeb4e55c3b 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1332,14 +1332,43 @@ mod tests { use datafusion_physical_expr::expressions::Column; use crate::execution::context::SessionConfig; + use crate::execution::options::ParquetReadOptions; use crate::physical_plan::ColumnarValue; use crate::physical_plan::Partitioning; use crate::physical_plan::PhysicalExpr; + use crate::test_util::parquet_test_data; use crate::test_util::{register_aggregate_csv, test_table, test_table_with_name}; use crate::{assert_batches_sorted_eq, execution::context::SessionContext}; use super::*; + #[tokio::test] + async fn test_array_agg_schema() -> Result<()> { + let ctx = SessionContext::new(); + ctx.register_parquet( + "test", + &format!("{}/test_array_agg.parquet", parquet_test_data()), + ParquetReadOptions::default(), + ) + .await?; + + ctx.register_table("test_table", ctx.table("test").await?.into_view())?; + + let query = r#"SELECT + array_agg("double_field" ORDER BY "string_field") as "double_field", + array_agg("string_field" ORDER BY "string_field") as "string_field" + FROM test_table"#; + + let result = ctx.sql(query).await?; + let logical_expr_dfschema = result.schema(); + let logical_expr_schema = SchemaRef::from(logical_expr_dfschema.to_owned()); + let batches = result.collect().await?; + let physical_expr_schema = batches[0].schema(); + assert_eq!(logical_expr_schema, physical_expr_schema); + + Ok(()) + } + #[tokio::test] async fn select_columns() -> Result<()> { // build plan using Table API diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs b/datafusion/physical-expr/src/aggregate/array_agg.rs index 4dccbfef07f8..acefc1552fe8 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg.rs @@ -37,6 +37,7 @@ pub struct ArrayAgg { name: String, input_data_type: DataType, expr: Arc, + is_expr_nullable: bool, } impl ArrayAgg { @@ -45,11 +46,13 @@ impl ArrayAgg { expr: Arc, name: impl Into, data_type: DataType, + is_expr_nullable: bool, ) -> Self { Self { name: name.into(), expr, input_data_type: data_type, + is_expr_nullable, } } } @@ -62,7 +65,7 @@ impl AggregateExpr for ArrayAgg { fn field(&self) -> Result { Ok(Field::new_list( &self.name, - Field::new("item", self.input_data_type.clone(), true), + Field::new("item", self.input_data_type.clone(), self.is_expr_nullable), false, )) } @@ -76,7 +79,7 @@ impl AggregateExpr for ArrayAgg { fn state_fields(&self) -> Result> { Ok(vec![Field::new_list( format_state_name(&self.name, "array_agg"), - Field::new("item", self.input_data_type.clone(), true), + Field::new("item", self.input_data_type.clone(), self.is_expr_nullable), false, )]) } @@ -184,7 +187,6 @@ mod tests { use super::*; use crate::expressions::col; use crate::expressions::tests::aggregate; - use crate::generic_test_op; 
use arrow::array::ArrayRef; use arrow::array::Int32Array; use arrow::datatypes::*; @@ -195,6 +197,30 @@ mod tests { use datafusion_common::DataFusionError; use datafusion_common::Result; + macro_rules! test_op { + ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr) => { + test_op!($ARRAY, $DATATYPE, $OP, $EXPECTED, $EXPECTED.data_type()) + }; + ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr, $EXPECTED_DATATYPE:expr) => {{ + let schema = Schema::new(vec![Field::new("a", $DATATYPE, true)]); + + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![$ARRAY])?; + + let agg = Arc::new(<$OP>::new( + col("a", &schema)?, + "bla".to_string(), + $EXPECTED_DATATYPE, + true, + )); + let actual = aggregate(&batch, agg)?; + let expected = ScalarValue::from($EXPECTED); + + assert_eq!(expected, actual); + + Ok(()) as Result<(), DataFusionError> + }}; + } + #[test] fn array_agg_i32() -> Result<()> { let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); @@ -208,7 +234,7 @@ mod tests { ])]); let list = ScalarValue::List(Arc::new(list)); - generic_test_op!(a, DataType::Int32, ArrayAgg, list, DataType::Int32) + test_op!(a, DataType::Int32, ArrayAgg, list, DataType::Int32) } #[test] @@ -264,7 +290,7 @@ mod tests { let array = ScalarValue::iter_to_array(vec![l1, l2, l3]).unwrap(); - generic_test_op!( + test_op!( array, DataType::List(Arc::new(Field::new_list( "item", diff --git a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs index a53d53107add..04b860dc3a4d 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs @@ -50,6 +50,7 @@ use itertools::izip; pub struct OrderSensitiveArrayAgg { name: String, input_data_type: DataType, + is_expr_nullable: bool, order_by_data_types: Vec, expr: Arc, ordering_req: LexOrdering, @@ -61,6 +62,7 @@ impl OrderSensitiveArrayAgg { expr: Arc, name: impl Into, input_data_type: DataType, + is_expr_nullable: bool, order_by_data_types: Vec, ordering_req: LexOrdering, ) -> Self { @@ -68,6 +70,7 @@ impl OrderSensitiveArrayAgg { name: name.into(), expr, input_data_type, + is_expr_nullable, order_by_data_types, ordering_req, } @@ -82,8 +85,8 @@ impl AggregateExpr for OrderSensitiveArrayAgg { fn field(&self) -> Result { Ok(Field::new_list( &self.name, - Field::new("item", self.input_data_type.clone(), true), - false, + Field::new("item", self.input_data_type.clone(), self.is_expr_nullable), + true, // This should be aligned with the return type of `AggregateFunction::ArrayAgg` )) } @@ -98,14 +101,18 @@ impl AggregateExpr for OrderSensitiveArrayAgg { fn state_fields(&self) -> Result> { let mut fields = vec![Field::new_list( format_state_name(&self.name, "array_agg"), - Field::new("item", self.input_data_type.clone(), true), - false, + Field::new("item", self.input_data_type.clone(), self.is_expr_nullable), + true, // This should be the same as field() )]; let orderings = ordering_fields(&self.ordering_req, &self.order_by_data_types); fields.push(Field::new_list( format_state_name(&self.name, "array_agg_orderings"), - Field::new("item", DataType::Struct(Fields::from(orderings)), true), - false, + Field::new( + "item", + DataType::Struct(Fields::from(orderings)), + self.is_expr_nullable, + ), + true, )); Ok(fields) } diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index 6568457bc234..1f30935122dd 100644 --- 
a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -114,13 +114,21 @@ pub fn create_aggregate_expr( ), (AggregateFunction::ArrayAgg, false) => { let expr = input_phy_exprs[0].clone(); + let is_expr_nullable = expr.nullable(input_schema)?; + println!("null: {:?}", is_expr_nullable); if ordering_req.is_empty() { - Arc::new(expressions::ArrayAgg::new(expr, name, data_type)) + Arc::new(expressions::ArrayAgg::new( + expr, + name, + data_type, + is_expr_nullable, + )) } else { Arc::new(expressions::OrderSensitiveArrayAgg::new( expr, name, data_type, + is_expr_nullable, ordering_types, ordering_req.to_vec(), )) diff --git a/test_array_agg.parquet b/test_array_agg.parquet new file mode 100644 index 0000000000000000000000000000000000000000..e5006c8868907c6c24e8a953146af4178724d9d0 GIT binary patch literal 2350 zcmcguOK;*<6uuCKiK0l2sxblyRgpD=vLGoYkQqiJ&0JtGrjSWU!oyV+`2`qk+|wC-#Pajn={Nn8v8v2ztmZ3`%@w$(vSjT!U(oKxFH1m zR}8yI>5CU3kV#~d$);>(bAk>+N<1WiZVl_tktDOb!0LLAZuM8N{t$}&7)t%UmHKHb zx%BW5mc*+(ibq{#pqUaALCnsx@OtWq*(IXY2+vSG32N&Xp%Q}Xr&BYb|hy;3T4kE z#~DuuiW|*w^|3{yT<$1ylwkohJCDAi%pP$^S@s^qcX6}%uKhBYqWG`dUz0WKmodiY zaFO*5#yMQ<;|j)oTv zVk=uMS1yrwqAj+kqfYjLxO@e$S8hAU&Ah@BOjuXjc^mwX#0!DM28`T9YZWqDr%3$@ zW1MrdQUI5;ztBdXYP|h}O*;J^%R;$U^FZ+6L1G_Q<`-%tkTj0XF z$QQV-IW;=j0CP~UP5d9Rvx-n-s3W(ERdnTSC38AF*Gb)=CZ19?dTP~}Kz=8XpDFW& zfm-fnf>@lP@|95+a$44~nsG^9vRU!w&)llj8sw+;MAI;2KJXpT&E9|6DWALMe4dl} zX`ypL&iZsFwksL=SRnlcwVm2s=sm?6(X)Wufv(-CGkSJeu~WO1+6ABX=xV9=(%ZqDCoUl(s(ioi=2QYGlVb6Mu(`}AfpLPvDl=-5!vbd zP!>1H4f=luvtljIJy4w!JD5+pU6J{Nd9IRqghiEv0BxBXl_(%e26AFo9LdKmUFKVh zvjSO{fCSv#zzjwk#rJK=`QF#OL&dfa>5FN9{^HpWUP*~Ge1geW__xOhoxwk)e*nEQ BhHwA? literal 0 HcmV?d00001 From 20216e6f2028c3d5f0b2569d792e239d5cf63463 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sun, 5 Nov 2023 11:43:45 +0800 Subject: [PATCH 02/10] cleanup Signed-off-by: jayzhan211 --- datafusion/physical-expr/src/aggregate/build_in.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index 1f30935122dd..c7faf6522102 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -115,7 +115,7 @@ pub fn create_aggregate_expr( (AggregateFunction::ArrayAgg, false) => { let expr = input_phy_exprs[0].clone(); let is_expr_nullable = expr.nullable(input_schema)?; - println!("null: {:?}", is_expr_nullable); + if ordering_req.is_empty() { Arc::new(expressions::ArrayAgg::new( expr, From d0de60773c7325cbe6e7d636baa7b06c7d9c6d5a Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sun, 5 Nov 2023 11:48:36 +0800 Subject: [PATCH 03/10] upd parquet-testing Signed-off-by: jayzhan211 --- parquet-testing | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parquet-testing b/parquet-testing index e45cd23f784a..2a856f246390 160000 --- a/parquet-testing +++ b/parquet-testing @@ -1 +1 @@ -Subproject commit e45cd23f784aab3d6bf0701f8f4e621469ed3be7 +Subproject commit 2a856f246390385a1e1f8d05c52ed4e110cf35fa From c8c530d9411454e769d240fdd8a80edf99617434 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sun, 5 Nov 2023 12:27:55 +0800 Subject: [PATCH 04/10] avoid parquet file Signed-off-by: jayzhan211 --- datafusion/core/src/dataframe/mod.rs | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index ecdeb4e55c3b..7cbba1ea6d0f 100644 
--- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1332,11 +1332,9 @@ mod tests { use datafusion_physical_expr::expressions::Column; use crate::execution::context::SessionConfig; - use crate::execution::options::ParquetReadOptions; use crate::physical_plan::ColumnarValue; use crate::physical_plan::Partitioning; use crate::physical_plan::PhysicalExpr; - use crate::test_util::parquet_test_data; use crate::test_util::{register_aggregate_csv, test_table, test_table_with_name}; use crate::{assert_batches_sorted_eq, execution::context::SessionContext}; @@ -1345,14 +1343,17 @@ mod tests { #[tokio::test] async fn test_array_agg_schema() -> Result<()> { let ctx = SessionContext::new(); - ctx.register_parquet( - "test", - &format!("{}/test_array_agg.parquet", parquet_test_data()), - ParquetReadOptions::default(), - ) - .await?; - ctx.register_table("test_table", ctx.table("test").await?.into_view())?; + let create_table_query = r#" + CREATE TABLE test_table ( + "double_field" DOUBLE, + "string_field" VARCHAR + ) AS VALUES + (1.0, 'a'), + (2.0, 'b'), + (3.0, 'c') + "#; + ctx.sql(create_table_query).await?; let query = r#"SELECT array_agg("double_field" ORDER BY "string_field") as "double_field", From c0968b7c7323dae42b9abc2ad9bf929379b6a20b Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sun, 5 Nov 2023 12:29:10 +0800 Subject: [PATCH 05/10] reset parquet-testing Signed-off-by: jayzhan211 --- parquet-testing | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parquet-testing b/parquet-testing index 2a856f246390..e45cd23f784a 160000 --- a/parquet-testing +++ b/parquet-testing @@ -1 +1 @@ -Subproject commit 2a856f246390385a1e1f8d05c52ed4e110cf35fa +Subproject commit e45cd23f784aab3d6bf0701f8f4e621469ed3be7 From 719fa2c85bdfa1a4b9ad0d8ec78cfb0ae6e1367d Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sun, 5 Nov 2023 12:30:05 +0800 Subject: [PATCH 06/10] remove file Signed-off-by: jayzhan211 --- test_array_agg.parquet | Bin 2350 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 test_array_agg.parquet diff --git a/test_array_agg.parquet b/test_array_agg.parquet deleted file mode 100644 index e5006c8868907c6c24e8a953146af4178724d9d0..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 2350 zcmcguOK;*<6uuCKiK0l2sxblyRgpD=vLGoYkQqiJ&0JtGrjSWU!oyV+`2`qk+|wC-#Pajn={Nn8v8v2ztmZ3`%@w$(vSjT!U(oKxFH1m zR}8yI>5CU3kV#~d$);>(bAk>+N<1WiZVl_tktDOb!0LLAZuM8N{t$}&7)t%UmHKHb zx%BW5mc*+(ibq{#pqUaALCnsx@OtWq*(IXY2+vSG32N&Xp%Q}Xr&BYb|hy;3T4kE z#~DuuiW|*w^|3{yT<$1ylwkohJCDAi%pP$^S@s^qcX6}%uKhBYqWG`dUz0WKmodiY zaFO*5#yMQ<;|j)oTv zVk=uMS1yrwqAj+kqfYjLxO@e$S8hAU&Ah@BOjuXjc^mwX#0!DM28`T9YZWqDr%3$@ zW1MrdQUI5;ztBdXYP|h}O*;J^%R;$U^FZ+6L1G_Q<`-%tkTj0XF z$QQV-IW;=j0CP~UP5d9Rvx-n-s3W(ERdnTSC38AF*Gb)=CZ19?dTP~}Kz=8XpDFW& zfm-fnf>@lP@|95+a$44~nsG^9vRU!w&)llj8sw+;MAI;2KJXpT&E9|6DWALMe4dl} zX`ypL&iZsFwksL=SRnlcwVm2s=sm?6(X)Wufv(-CGkSJeu~WO1+6ABX=xV9=(%ZqDCoUl(s(ioi=2QYGlVb6Mu(`}AfpLPvDl=-5!vbd zP!>1H4f=luvtljIJy4w!JD5+pU6J{Nd9IRqghiEv0BxBXl_(%e26AFo9LdKmUFKVh zvjSO{fCSv#zzjwk#rJK=`QF#OL&dfa>5FN9{^HpWUP*~Ge1geW__xOhoxwk)e*nEQ BhHwA? 
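For reference, a condensed sketch of the schema check that patches 01-06 converge on: the Arrow schema derived from the DataFrame's logical plan must equal the schema of the record batches the physical plan actually produces. This is not part of the patch itself; it assumes the datafusion and tokio crates, and the table definition and query simply mirror the ones used in the test.

use datafusion::arrow::datatypes::SchemaRef;
use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();

    // In-memory table, mirroring the CREATE TABLE ... AS VALUES used in the test
    // (the earlier parquet-based setup is dropped in patches 04-06).
    ctx.sql(
        r#"CREATE TABLE test_table ("double_field" DOUBLE, "string_field" VARCHAR)
           AS VALUES (1.0, 'a'), (2.0, 'b'), (3.0, 'c')"#,
    )
    .await?;

    let df = ctx
        .sql(
            r#"SELECT array_agg("double_field" ORDER BY "string_field") AS "double_field"
               FROM test_table"#,
        )
        .await?;

    // Arrow schema derived from the logical plan ...
    let logical_schema = SchemaRef::from(df.schema().to_owned());
    // ... must match the schema reported by the executed batches.
    let batches = df.collect().await?;
    assert_eq!(logical_schema, batches[0].schema());

    Ok(())
}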
From 3a29f4de8a57132e5db690ca77c1471b95a59417 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sun, 5 Nov 2023 14:50:28 +0800 Subject: [PATCH 07/10] fix Signed-off-by: jayzhan211 --- .../physical-expr/src/aggregate/array_agg.rs | 9 +++++---- .../src/aggregate/array_agg_distinct.rs | 11 +++++++++-- .../src/aggregate/array_agg_ordered.rs | 17 +++++++---------- .../physical-expr/src/aggregate/build_in.rs | 5 ++++- 4 files changed, 25 insertions(+), 17 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs b/datafusion/physical-expr/src/aggregate/array_agg.rs index acefc1552fe8..ee9d58385023 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg.rs @@ -65,8 +65,9 @@ impl AggregateExpr for ArrayAgg { fn field(&self) -> Result { Ok(Field::new_list( &self.name, - Field::new("item", self.input_data_type.clone(), self.is_expr_nullable), - false, + // This should be the same as return type of AggregateFunction::ArrayAgg + Field::new("item", self.input_data_type.clone(), true), + self.is_expr_nullable, )) } @@ -79,8 +80,8 @@ impl AggregateExpr for ArrayAgg { fn state_fields(&self) -> Result> { Ok(vec![Field::new_list( format_state_name(&self.name, "array_agg"), - Field::new("item", self.input_data_type.clone(), self.is_expr_nullable), - false, + Field::new("item", self.input_data_type.clone(), true), + self.is_expr_nullable, )]) } diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs index 9b391b0c42cf..99ece1bf5ef1 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs @@ -40,6 +40,8 @@ pub struct DistinctArrayAgg { input_data_type: DataType, /// The input expression expr: Arc, + /// Whether the input expression can produce NULL values + is_expr_nullable: bool, } impl DistinctArrayAgg { @@ -48,12 +50,14 @@ impl DistinctArrayAgg { expr: Arc, name: impl Into, input_data_type: DataType, + is_expr_nullable: bool, ) -> Self { let name = name.into(); Self { name, expr, input_data_type, + is_expr_nullable, } } } @@ -67,8 +71,9 @@ impl AggregateExpr for DistinctArrayAgg { fn field(&self) -> Result { Ok(Field::new_list( &self.name, + // This should be the same as return type of AggregateFunction::ArrayAgg Field::new("item", self.input_data_type.clone(), true), - false, + self.is_expr_nullable, )) } @@ -82,7 +87,7 @@ impl AggregateExpr for DistinctArrayAgg { Ok(vec![Field::new_list( format_state_name(&self.name, "distinct_array_agg"), Field::new("item", self.input_data_type.clone(), true), - false, + self.is_expr_nullable, )]) } @@ -238,6 +243,7 @@ mod tests { col("a", &schema)?, "bla".to_string(), datatype, + true, )); let actual = aggregate(&batch, agg)?; @@ -255,6 +261,7 @@ mod tests { col("a", &schema)?, "bla".to_string(), datatype, + true, )); let mut accum1 = agg.create_accumulator()?; diff --git a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs index 04b860dc3a4d..19ba7b9ef05f 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs @@ -85,8 +85,9 @@ impl AggregateExpr for OrderSensitiveArrayAgg { fn field(&self) -> Result { Ok(Field::new_list( &self.name, - Field::new("item", self.input_data_type.clone(), self.is_expr_nullable), - true, // This should be aligned with the return type of 
`AggregateFunction::ArrayAgg` + // This should be the same as return type of AggregateFunction::ArrayAgg + Field::new("item", self.input_data_type.clone(), true), + self.is_expr_nullable, )) } @@ -101,18 +102,14 @@ impl AggregateExpr for OrderSensitiveArrayAgg { fn state_fields(&self) -> Result> { let mut fields = vec![Field::new_list( format_state_name(&self.name, "array_agg"), - Field::new("item", self.input_data_type.clone(), self.is_expr_nullable), - true, // This should be the same as field() + Field::new("item", self.input_data_type.clone(), true), + self.is_expr_nullable, // This should be the same as field() )]; let orderings = ordering_fields(&self.ordering_req, &self.order_by_data_types); fields.push(Field::new_list( format_state_name(&self.name, "array_agg_orderings"), - Field::new( - "item", - DataType::Struct(Fields::from(orderings)), - self.is_expr_nullable, - ), - true, + Field::new("item", DataType::Struct(Fields::from(orderings)), true), + self.is_expr_nullable, )); Ok(fields) } diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index c7faf6522102..ee847e0f652d 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -140,10 +140,13 @@ pub fn create_aggregate_expr( "ARRAY_AGG(DISTINCT ORDER BY a ASC) order-sensitive aggregations are not available" ); } + let expr = input_phy_exprs[0].clone(); + let is_expr_nullable = expr.nullable(input_schema)?; Arc::new(expressions::DistinctArrayAgg::new( - input_phy_exprs[0].clone(), + expr, name, data_type, + is_expr_nullable, )) } (AggregateFunction::Min, _) => Arc::new(expressions::Min::new( From 690c5e92813e32d8b836f27162b197fcb6ee22fe Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sun, 5 Nov 2023 15:15:28 +0800 Subject: [PATCH 08/10] fix test Signed-off-by: jayzhan211 --- datafusion/core/src/dataframe/mod.rs | 62 ++++++++++++++++++- .../physical-expr/src/aggregate/build_in.rs | 8 +-- 2 files changed, 65 insertions(+), 5 deletions(-) diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 7cbba1ea6d0f..8b456c286aed 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1341,7 +1341,7 @@ mod tests { use super::*; #[tokio::test] - async fn test_array_agg_schema() -> Result<()> { + async fn test_array_agg_ord_schema() -> Result<()> { let ctx = SessionContext::new(); let create_table_query = r#" @@ -1370,6 +1370,66 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_array_agg_schema() -> Result<()> { + let ctx = SessionContext::new(); + + let create_table_query = r#" + CREATE TABLE test_table ( + "double_field" DOUBLE, + "string_field" VARCHAR + ) AS VALUES + (1.0, 'a'), + (2.0, 'b'), + (3.0, 'c') + "#; + ctx.sql(create_table_query).await?; + + let query = r#"SELECT + array_agg("double_field") as "double_field", + array_agg("string_field") as "string_field" + FROM test_table"#; + + let result = ctx.sql(query).await?; + let logical_expr_dfschema = result.schema(); + let logical_expr_schema = SchemaRef::from(logical_expr_dfschema.to_owned()); + let batches = result.collect().await?; + let physical_expr_schema = batches[0].schema(); + assert_eq!(logical_expr_schema, physical_expr_schema); + + Ok(()) + } + + #[tokio::test] + async fn test_array_agg_distinct_schema() -> Result<()> { + let ctx = SessionContext::new(); + + let create_table_query = r#" + CREATE TABLE test_table ( + "double_field" DOUBLE, + "string_field" VARCHAR 
+ ) AS VALUES + (1.0, 'a'), + (2.0, 'b'), + (2.0, 'a') + "#; + ctx.sql(create_table_query).await?; + + let query = r#"SELECT + array_agg(distinct "double_field") as "double_field", + array_agg(distinct "string_field") as "string_field" + FROM test_table"#; + + let result = ctx.sql(query).await?; + let logical_expr_dfschema = result.schema(); + let logical_expr_schema = SchemaRef::from(logical_expr_dfschema.to_owned()); + let batches = result.collect().await?; + let physical_expr_schema = batches[0].schema(); + assert_eq!(logical_expr_schema, physical_expr_schema); + + Ok(()) + } + #[tokio::test] async fn select_columns() -> Result<()> { // build plan using Table API diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index ee847e0f652d..9c0cfbdf6ebe 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -443,8 +443,8 @@ mod tests { assert_eq!( Field::new_list( "c1", - Field::new("item", data_type.clone(), true,), - false, + Field::new("item", data_type.clone(), true), + true, ), result_agg_phy_exprs.field().unwrap() ); @@ -482,8 +482,8 @@ mod tests { assert_eq!( Field::new_list( "c1", - Field::new("item", data_type.clone(), true,), - false, + Field::new("item", data_type.clone(), true), + true, ), result_agg_phy_exprs.field().unwrap() ); From 072c993585830525f804dd76716778bff3b8481f Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sun, 5 Nov 2023 15:21:28 +0800 Subject: [PATCH 09/10] cleanup Signed-off-by: jayzhan211 --- datafusion/core/src/dataframe/mod.rs | 32 ++++++++++++---------------- 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 8b456c286aed..89e82fa952bb 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1340,6 +1340,17 @@ mod tests { use super::*; + async fn assert_logical_expr_schema_eq_physical_expr_schema( + df: DataFrame, + ) -> Result<()> { + let logical_expr_dfschema = df.schema(); + let logical_expr_schema = SchemaRef::from(logical_expr_dfschema.to_owned()); + let batches = df.collect().await?; + let physical_expr_schema = batches[0].schema(); + assert_eq!(logical_expr_schema, physical_expr_schema); + Ok(()) + } + #[tokio::test] async fn test_array_agg_ord_schema() -> Result<()> { let ctx = SessionContext::new(); @@ -1361,12 +1372,7 @@ mod tests { FROM test_table"#; let result = ctx.sql(query).await?; - let logical_expr_dfschema = result.schema(); - let logical_expr_schema = SchemaRef::from(logical_expr_dfschema.to_owned()); - let batches = result.collect().await?; - let physical_expr_schema = batches[0].schema(); - assert_eq!(logical_expr_schema, physical_expr_schema); - + assert_logical_expr_schema_eq_physical_expr_schema(result).await?; Ok(()) } @@ -1391,12 +1397,7 @@ mod tests { FROM test_table"#; let result = ctx.sql(query).await?; - let logical_expr_dfschema = result.schema(); - let logical_expr_schema = SchemaRef::from(logical_expr_dfschema.to_owned()); - let batches = result.collect().await?; - let physical_expr_schema = batches[0].schema(); - assert_eq!(logical_expr_schema, physical_expr_schema); - + assert_logical_expr_schema_eq_physical_expr_schema(result).await?; Ok(()) } @@ -1421,12 +1422,7 @@ mod tests { FROM test_table"#; let result = ctx.sql(query).await?; - let logical_expr_dfschema = result.schema(); - let logical_expr_schema = SchemaRef::from(logical_expr_dfschema.to_owned()); - let 
batches = result.collect().await?; - let physical_expr_schema = batches[0].schema(); - assert_eq!(logical_expr_schema, physical_expr_schema); - + assert_logical_expr_schema_eq_physical_expr_schema(result).await?; Ok(()) } From 4dfa5977a8bb9c24bfc58714fd6d8b32bdaa2261 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Thu, 9 Nov 2023 19:49:27 +0800 Subject: [PATCH 10/10] rename and upd docstring Signed-off-by: jayzhan211 --- .../physical-expr/src/aggregate/array_agg.rs | 16 +++++++++----- .../src/aggregate/array_agg_distinct.rs | 14 ++++++------ .../src/aggregate/array_agg_ordered.rs | 22 ++++++++++++------- .../physical-expr/src/aggregate/build_in.rs | 11 +++------- 4 files changed, 34 insertions(+), 29 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs b/datafusion/physical-expr/src/aggregate/array_agg.rs index ee9d58385023..91d5c867d312 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg.rs @@ -34,10 +34,14 @@ use std::sync::Arc; /// ARRAY_AGG aggregate expression #[derive(Debug)] pub struct ArrayAgg { + /// Column name name: String, + /// The DataType for the input expression input_data_type: DataType, + /// The input expression expr: Arc, - is_expr_nullable: bool, + /// If the input expression can have NULLs + nullable: bool, } impl ArrayAgg { @@ -46,13 +50,13 @@ impl ArrayAgg { expr: Arc, name: impl Into, data_type: DataType, - is_expr_nullable: bool, + nullable: bool, ) -> Self { Self { name: name.into(), - expr, input_data_type: data_type, - is_expr_nullable, + expr, + nullable, } } } @@ -67,7 +71,7 @@ impl AggregateExpr for ArrayAgg { &self.name, // This should be the same as return type of AggregateFunction::ArrayAgg Field::new("item", self.input_data_type.clone(), true), - self.is_expr_nullable, + self.nullable, )) } @@ -81,7 +85,7 @@ impl AggregateExpr for ArrayAgg { Ok(vec![Field::new_list( format_state_name(&self.name, "array_agg"), Field::new("item", self.input_data_type.clone(), true), - self.is_expr_nullable, + self.nullable, )]) } diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs index 99ece1bf5ef1..1efae424cc69 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs @@ -40,8 +40,8 @@ pub struct DistinctArrayAgg { input_data_type: DataType, /// The input expression expr: Arc, - /// Whether the input expression can produce NULL values - is_expr_nullable: bool, + /// If the input expression can have NULLs + nullable: bool, } impl DistinctArrayAgg { @@ -50,14 +50,14 @@ impl DistinctArrayAgg { expr: Arc, name: impl Into, input_data_type: DataType, - is_expr_nullable: bool, + nullable: bool, ) -> Self { let name = name.into(); Self { name, - expr, input_data_type, - is_expr_nullable, + expr, + nullable, } } } @@ -73,7 +73,7 @@ impl AggregateExpr for DistinctArrayAgg { &self.name, // This should be the same as return type of AggregateFunction::ArrayAgg Field::new("item", self.input_data_type.clone(), true), - self.is_expr_nullable, + self.nullable, )) } @@ -87,7 +87,7 @@ impl AggregateExpr for DistinctArrayAgg { Ok(vec![Field::new_list( format_state_name(&self.name, "distinct_array_agg"), Field::new("item", self.input_data_type.clone(), true), - self.is_expr_nullable, + self.nullable, )]) } diff --git a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs 
b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs index 19ba7b9ef05f..9ca83a781a01 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs @@ -48,11 +48,17 @@ use itertools::izip; /// and that can merge aggregations from multiple partitions. #[derive(Debug)] pub struct OrderSensitiveArrayAgg { + /// Column name name: String, + /// The DataType for the input expression input_data_type: DataType, - is_expr_nullable: bool, - order_by_data_types: Vec, + /// The input expression expr: Arc, + /// If the input expression can have NULLs + nullable: bool, + /// Ordering data types + order_by_data_types: Vec, + /// Ordering requirement ordering_req: LexOrdering, } @@ -62,15 +68,15 @@ impl OrderSensitiveArrayAgg { expr: Arc, name: impl Into, input_data_type: DataType, - is_expr_nullable: bool, + nullable: bool, order_by_data_types: Vec, ordering_req: LexOrdering, ) -> Self { Self { name: name.into(), - expr, input_data_type, - is_expr_nullable, + expr, + nullable, order_by_data_types, ordering_req, } @@ -87,7 +93,7 @@ impl AggregateExpr for OrderSensitiveArrayAgg { &self.name, // This should be the same as return type of AggregateFunction::ArrayAgg Field::new("item", self.input_data_type.clone(), true), - self.is_expr_nullable, + self.nullable, )) } @@ -103,13 +109,13 @@ impl AggregateExpr for OrderSensitiveArrayAgg { let mut fields = vec![Field::new_list( format_state_name(&self.name, "array_agg"), Field::new("item", self.input_data_type.clone(), true), - self.is_expr_nullable, // This should be the same as field() + self.nullable, // This should be the same as field() )]; let orderings = ordering_fields(&self.ordering_req, &self.order_by_data_types); fields.push(Field::new_list( format_state_name(&self.name, "array_agg_orderings"), Field::new("item", DataType::Struct(Fields::from(orderings)), true), - self.is_expr_nullable, + self.nullable, )); Ok(fields) } diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index 9c0cfbdf6ebe..596197b4eebe 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -114,21 +114,16 @@ pub fn create_aggregate_expr( ), (AggregateFunction::ArrayAgg, false) => { let expr = input_phy_exprs[0].clone(); - let is_expr_nullable = expr.nullable(input_schema)?; + let nullable = expr.nullable(input_schema)?; if ordering_req.is_empty() { - Arc::new(expressions::ArrayAgg::new( - expr, - name, - data_type, - is_expr_nullable, - )) + Arc::new(expressions::ArrayAgg::new(expr, name, data_type, nullable)) } else { Arc::new(expressions::OrderSensitiveArrayAgg::new( expr, name, data_type, - is_expr_nullable, + nullable, ordering_types, ordering_req.to_vec(), ))
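Taken together, the series settles on one rule for ArrayAgg, DistinctArrayAgg, and OrderSensitiveArrayAgg: the inner "item" field of the returned List stays nullable, matching the return type of AggregateFunction::ArrayAgg, while the outer List field takes the nullability of the input expression (expr.nullable(input_schema)?). Below is a minimal sketch of that rule outside the DataFusion sources; it assumes only the arrow-schema crate, and the function name is illustrative.

use arrow_schema::{DataType, Field};

/// Output field for an ARRAY_AGG-style aggregate over an expression of
/// `input_data_type`, where `nullable` says whether that expression can be NULL.
fn array_agg_field(name: &str, input_data_type: DataType, nullable: bool) -> Field {
    Field::new_list(
        name,
        // List items stay nullable, matching AggregateFunction::ArrayAgg's return type.
        Field::new("item", input_data_type, true),
        // The outer list column inherits the input expression's nullability.
        nullable,
    )
}

fn main() {
    // A NOT NULL Int32 input yields a non-nullable List with nullable items.
    println!("{:?}", array_agg_field("c1", DataType::Int32, false));
    // A nullable Utf8 input yields a nullable List with nullable items.
    println!("{:?}", array_agg_field("c2", DataType::Utf8, true));
}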