From c192d48a79d5f68d7e9702bd7391a8cbe71b97ff Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Wed, 24 Aug 2022 09:25:30 -0600
Subject: [PATCH 1/3] save

---
 .../src/single_distinct_to_groupby.rs         |  8 +-
 .../optimizer/tests/integration-test.rs       | 90 +++++++++++++++----
 2 files changed, 77 insertions(+), 21 deletions(-)

diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs
index 1769314ebc0c..2141d69cf9c2 100644
--- a/datafusion/optimizer/src/single_distinct_to_groupby.rs
+++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs
@@ -72,7 +72,11 @@ fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
             let new_aggr_expr = aggr_expr
                 .iter()
                 .map(|agg_expr| match agg_expr {
-                    Expr::AggregateFunction { fun, args, .. } => {
+                    Expr::AggregateFunction {
+                        fun,
+                        args,
+                        distinct,
+                    } => {
                         // is_single_distinct_agg ensure args.len=1
                         if group_fields_set
                             .insert(args[0].name(input.schema()).unwrap())
@@ -83,7 +87,7 @@ fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
                             Expr::AggregateFunction {
                                 fun: fun.clone(),
                                 args: vec![col(SINGLE_DISTINCT_ALIAS)],
-                                distinct: false,
+                                distinct: *distinct,
                             }
                         }
                         _ => agg_expr.clone(),
diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs
index b9d4d3b6333c..ce08aec5025f 100644
--- a/datafusion/optimizer/tests/integration-test.rs
+++ b/datafusion/optimizer/tests/integration-test.rs
@@ -17,7 +17,11 @@
 
 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
 use datafusion_common::{DataFusionError, Result};
-use datafusion_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource};
+use datafusion_expr::logical_plan::builder::LogicalTableSource;
+use datafusion_expr::{
+    col, count, count_distinct, AggregateUDF, LogicalPlan, LogicalPlanBuilder, ScalarUDF,
+    TableSource,
+};
 use datafusion_optimizer::common_subexpr_eliminate::CommonSubexprEliminate;
 use datafusion_optimizer::decorrelate_scalar_subquery::DecorrelateScalarSubquery;
 use datafusion_optimizer::decorrelate_where_exists::DecorrelateWhereExists;
@@ -56,7 +60,71 @@ fn distribute_by() -> Result<()> {
     Ok(())
 }
 
+#[test]
+fn count_distinct_multi_sql() -> Result<()> {
+    let sql = "SELECT COUNT(col_int32) AS num, COUNT(DISTINCT col_int32) AS num_distinct FROM test";
+    let plan = test_sql(sql)?;
+    let expected = "Projection: #COUNT(test.col_int32) AS num, #COUNT(DISTINCT test.col_int32) AS num_distinct\
+    \n  Aggregate: groupBy=[[]], aggr=[[COUNT(#test.col_int32), COUNT(DISTINCT #test.col_int32)]]\
+    \n    TableScan: test projection=[col_int32]";
+    assert_eq!(expected, format!("{:?}", plan));
+    Ok(())
+}
+
+#[test]
+fn count_distinct_multi_plan_builder() -> Result<()> {
+    let schema_provider = MySchemaProvider {};
+    let table_name: TableReference = "test".into();
+    let table = schema_provider.get_table_provider(table_name)?;
+    let table_source = LogicalTableSource::new(table.schema());
+
+    let plan = LogicalPlanBuilder::scan("test", Arc::new(table_source), None)?
+        .aggregate(
+            vec![col("test.col_int32")],
+            vec![
+                count(col("test.col_int32")),
+                count_distinct(col("test.col_int32")),
+            ],
+        )?
+        .project(vec![col("test.col_int32")])?
+        .build()?;
+
+    println!("{}", plan.display_indent());
+
+    let plan = optimize_plan(&plan)?;
+
+    let expected = "Projection: #COUNT(test.col_int32) AS num, #COUNT(DISTINCT test.col_int32) AS num_distinct\
+    \n  Aggregate: groupBy=[[]], aggr=[[COUNT(#test.col_int32), COUNT(DISTINCT #test.col_int32)]]\
+    \n    TableScan: test projection=[col_int32]";
+    assert_eq!(expected, format!("{:?}", plan));
+    Ok(())
+}
+
+fn optimize_plan(plan: &LogicalPlan) -> Result<LogicalPlan> {
+    let mut config = OptimizerConfig::new().with_skip_failing_rules(false);
+    let optimizer = create_optimizer();
+    optimizer.optimize(&plan, &mut config, &observe)
+}
+
 fn test_sql(sql: &str) -> Result<LogicalPlan> {
+    let optimizer = create_optimizer();
+
+    // parse the SQL
+    let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ...
+    let ast: Vec<Statement> = Parser::parse_sql(&dialect, sql).unwrap();
+    let statement = &ast[0];
+
+    // create a logical query plan
+    let schema_provider = MySchemaProvider {};
+    let sql_to_rel = SqlToRel::new(&schema_provider);
+    let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap();
+
+    // optimize the logical plan
+    let mut config = OptimizerConfig::new().with_skip_failing_rules(false);
+    optimizer.optimize(&plan, &mut config, &observe)
+}
+
+fn create_optimizer() -> Optimizer {
     let rules: Vec<Arc<dyn OptimizerRule + Send + Sync>> = vec![
         // Simplify expressions first to maximize the chance
         // of applying other optimizations
@@ -78,29 +146,13 @@ fn test_sql(sql: &str) -> Result<LogicalPlan> {
     ];
     let optimizer = Optimizer::new(rules);
-
-    // parse the SQL
-    let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ...
-    let ast: Vec<Statement> = Parser::parse_sql(&dialect, sql).unwrap();
-    let statement = &ast[0];
-
-    // create a logical query plan
-    let schema_provider = MySchemaProvider {};
-    let sql_to_rel = SqlToRel::new(&schema_provider);
-    let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap();
-
-    // optimize the logical plan
-    let mut config = OptimizerConfig::new().with_skip_failing_rules(false);
-    optimizer.optimize(&plan, &mut config, &observe)
+    optimizer
 }
 
 struct MySchemaProvider {}
 
 impl ContextProvider for MySchemaProvider {
-    fn get_table_provider(
-        &self,
-        name: TableReference,
-    ) -> datafusion_common::Result<Arc<dyn TableSource>> {
+    fn get_table_provider(&self, name: TableReference) -> Result<Arc<dyn TableSource>> {
         let table_name = name.table();
         if table_name.starts_with("test") {
             let schema = Schema::new_with_metadata(

From b5defe901efe89471787e9bf14c4a8783d8c467e Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Wed, 24 Aug 2022 09:47:25 -0600
Subject: [PATCH 2/3] proto

---
 datafusion/proto/proto/datafusion.proto |  1 +
 datafusion/proto/src/from_proto.rs      |  2 +-
 datafusion/proto/src/lib.rs             | 22 ++++++++++++++++++++++
 datafusion/proto/src/to_proto.rs        |  5 ++++-
 4 files changed, 28 insertions(+), 2 deletions(-)

diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index 7b08e4f40456..0b4a43e83e71 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -474,6 +474,7 @@ enum AggregateFunction {
 message AggregateExprNode {
   AggregateFunction aggr_function = 1;
   repeated LogicalExprNode expr = 2;
+  bool distinct = 3;
 }
 
 message AggregateUDFExprNode {
diff --git a/datafusion/proto/src/from_proto.rs b/datafusion/proto/src/from_proto.rs
index 63d9fe2b79c0..12f94ce3620e 100644
--- a/datafusion/proto/src/from_proto.rs
+++ b/datafusion/proto/src/from_proto.rs
@@ -890,7 +890,7 @@ pub fn parse_expr(
                     .iter()
                     .map(|e| parse_expr(e, registry))
                     .collect::<Result<Vec<_>, _>>()?,
-                distinct: false, // TODO
+                distinct: expr.distinct,
             })
         }
         ExprType::Alias(alias) => Ok(Expr::Alias(
diff --git a/datafusion/proto/src/lib.rs b/datafusion/proto/src/lib.rs
index eecca1b6ad59..c843de630289 100644
--- a/datafusion/proto/src/lib.rs
+++ b/datafusion/proto/src/lib.rs
@@ -912,6 +912,28 @@ mod roundtrip_tests {
         roundtrip_expr_test(test_expr, ctx);
     }
 
+    #[test]
+    fn roundtrip_count() {
+        let test_expr = Expr::AggregateFunction {
+            fun: AggregateFunction::Count,
+            args: vec![col("bananas")],
+            distinct: false,
+        };
+        let ctx = SessionContext::new();
+        roundtrip_expr_test(test_expr, ctx);
+    }
+
+    #[test]
+    fn roundtrip_count_distinct() {
+        let test_expr = Expr::AggregateFunction {
+            fun: AggregateFunction::Count,
+            args: vec![col("bananas")],
+            distinct: true,
+        };
+        let ctx = SessionContext::new();
+        roundtrip_expr_test(test_expr, ctx);
+    }
+
     #[test]
     fn roundtrip_approx_percentile_cont() {
         let test_expr = Expr::AggregateFunction {
diff --git a/datafusion/proto/src/to_proto.rs b/datafusion/proto/src/to_proto.rs
index a022769dcab1..d3f68b3b4276 100644
--- a/datafusion/proto/src/to_proto.rs
+++ b/datafusion/proto/src/to_proto.rs
@@ -502,7 +502,9 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
                 }
             }
             Expr::AggregateFunction {
-                ref fun, ref args, ..
+                ref fun,
+                ref args,
+                ref distinct,
             } => {
                 let aggr_function = match fun {
                     AggregateFunction::ApproxDistinct => {
@@ -550,6 +552,7 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
                         .iter()
                         .map(|v| v.try_into())
                         .collect::<Result<Vec<_>, _>>()?,
+                    distinct: *distinct,
                 };
                 Self {
                     expr_type: Some(ExprType::AggregateExpr(aggregate_expr)),

From 87b98d8a122cb8ecbbd3c5fa0ff0a1ddbf29ad1f Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Wed, 24 Aug 2022 10:13:36 -0600
Subject: [PATCH 3/3] revert some changes

---
 .../src/single_distinct_to_groupby.rs         |  8 +-
 .../optimizer/tests/integration-test.rs       | 90 ++++---------------
 2 files changed, 21 insertions(+), 77 deletions(-)

diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs
index 2141d69cf9c2..a4d6619f2a4f 100644
--- a/datafusion/optimizer/src/single_distinct_to_groupby.rs
+++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs
@@ -72,11 +72,7 @@ fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
             let new_aggr_expr = aggr_expr
                 .iter()
                 .map(|agg_expr| match agg_expr {
-                    Expr::AggregateFunction {
-                        fun,
-                        args,
-                        distinct,
-                    } => {
+                    Expr::AggregateFunction { fun, args, .. } => {
                         // is_single_distinct_agg ensure args.len=1
                         if group_fields_set
                             .insert(args[0].name(input.schema()).unwrap())
@@ -87,7 +83,7 @@ fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
                             Expr::AggregateFunction {
                                 fun: fun.clone(),
                                 args: vec![col(SINGLE_DISTINCT_ALIAS)],
-                                distinct: *distinct,
+                                distinct: false, // intentional to remove distinct here
                             }
                         }
                         _ => agg_expr.clone(),
diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs
index ce08aec5025f..b9d4d3b6333c 100644
--- a/datafusion/optimizer/tests/integration-test.rs
+++ b/datafusion/optimizer/tests/integration-test.rs
@@ -17,11 +17,7 @@
 
 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
 use datafusion_common::{DataFusionError, Result};
-use datafusion_expr::logical_plan::builder::LogicalTableSource;
-use datafusion_expr::{
-    col, count, count_distinct, AggregateUDF, LogicalPlan, LogicalPlanBuilder, ScalarUDF,
-    TableSource,
-};
+use datafusion_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource};
 use datafusion_optimizer::common_subexpr_eliminate::CommonSubexprEliminate;
 use datafusion_optimizer::decorrelate_scalar_subquery::DecorrelateScalarSubquery;
 use datafusion_optimizer::decorrelate_where_exists::DecorrelateWhereExists;
@@ -60,71 +56,7 @@ fn distribute_by() -> Result<()> {
     Ok(())
 }
 
-#[test]
-fn count_distinct_multi_sql() -> Result<()> {
-    let sql = "SELECT COUNT(col_int32) AS num, COUNT(DISTINCT col_int32) AS num_distinct FROM test";
-    let plan = test_sql(sql)?;
-    let expected = "Projection: #COUNT(test.col_int32) AS num, #COUNT(DISTINCT test.col_int32) AS num_distinct\
-    \n  Aggregate: groupBy=[[]], aggr=[[COUNT(#test.col_int32), COUNT(DISTINCT #test.col_int32)]]\
-    \n    TableScan: test projection=[col_int32]";
-    assert_eq!(expected, format!("{:?}", plan));
-    Ok(())
-}
-
-#[test]
-fn count_distinct_multi_plan_builder() -> Result<()> {
-    let schema_provider = MySchemaProvider {};
-    let table_name: TableReference = "test".into();
-    let table = schema_provider.get_table_provider(table_name)?;
-    let table_source = LogicalTableSource::new(table.schema());
-
-    let plan = LogicalPlanBuilder::scan("test", Arc::new(table_source), None)?
-        .aggregate(
-            vec![col("test.col_int32")],
-            vec![
-                count(col("test.col_int32")),
-                count_distinct(col("test.col_int32")),
-            ],
-        )?
-        .project(vec![col("test.col_int32")])?
-        .build()?;
-
-    println!("{}", plan.display_indent());
-
-    let plan = optimize_plan(&plan)?;
-
-    let expected = "Projection: #COUNT(test.col_int32) AS num, #COUNT(DISTINCT test.col_int32) AS num_distinct\
-    \n  Aggregate: groupBy=[[]], aggr=[[COUNT(#test.col_int32), COUNT(DISTINCT #test.col_int32)]]\
-    \n    TableScan: test projection=[col_int32]";
-    assert_eq!(expected, format!("{:?}", plan));
-    Ok(())
-}
-
-fn optimize_plan(plan: &LogicalPlan) -> Result<LogicalPlan> {
-    let mut config = OptimizerConfig::new().with_skip_failing_rules(false);
-    let optimizer = create_optimizer();
-    optimizer.optimize(&plan, &mut config, &observe)
-}
-
 fn test_sql(sql: &str) -> Result<LogicalPlan> {
-    let optimizer = create_optimizer();
-
-    // parse the SQL
-    let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ...
-    let ast: Vec<Statement> = Parser::parse_sql(&dialect, sql).unwrap();
-    let statement = &ast[0];
-
-    // create a logical query plan
-    let schema_provider = MySchemaProvider {};
-    let sql_to_rel = SqlToRel::new(&schema_provider);
-    let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap();
-
-    // optimize the logical plan
-    let mut config = OptimizerConfig::new().with_skip_failing_rules(false);
-    optimizer.optimize(&plan, &mut config, &observe)
-}
-
-fn create_optimizer() -> Optimizer {
     let rules: Vec<Arc<dyn OptimizerRule + Send + Sync>> = vec![
         // Simplify expressions first to maximize the chance
         // of applying other optimizations
@@ -146,13 +78,29 @@ fn create_optimizer() -> Optimizer {
     ];
     let optimizer = Optimizer::new(rules);
-    optimizer
+
+    // parse the SQL
+    let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ...
+    let ast: Vec<Statement> = Parser::parse_sql(&dialect, sql).unwrap();
+    let statement = &ast[0];
+
+    // create a logical query plan
+    let schema_provider = MySchemaProvider {};
+    let sql_to_rel = SqlToRel::new(&schema_provider);
+    let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap();
+
+    // optimize the logical plan
+    let mut config = OptimizerConfig::new().with_skip_failing_rules(false);
+    optimizer.optimize(&plan, &mut config, &observe)
 }
 
 struct MySchemaProvider {}
 
 impl ContextProvider for MySchemaProvider {
-    fn get_table_provider(&self, name: TableReference) -> Result<Arc<dyn TableSource>> {
+    fn get_table_provider(
+        &self,
+        name: TableReference,
+    ) -> datafusion_common::Result<Arc<dyn TableSource>> {
         let table_name = name.table();
         if table_name.starts_with("test") {
             let schema = Schema::new_with_metadata(
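Taken together, the three patches mean that a COUNT(DISTINCT ...) expression keeps its distinct flag when it is serialized through datafusion-proto, while the SingleDistinctToGroupBy rule still performs its own rewrite of the single-distinct case. The sketch below is only an illustration of the expression and builder APIs the patches already use (col, count, count_distinct, LogicalPlanBuilder, LogicalTableSource); the one-column schema and the table name "test" are assumptions made for the example, not part of the patches.

use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion_expr::logical_plan::builder::LogicalTableSource;
use datafusion_expr::{col, count, count_distinct, Expr, LogicalPlanBuilder};

fn main() -> datafusion_common::Result<()> {
    // Hypothetical one-column table, mirroring the `test` table in the integration tests.
    let schema = Schema::new(vec![Field::new("col_int32", DataType::Int32, true)]);
    let table_source = Arc::new(LogicalTableSource::new(Arc::new(schema)));

    // A plain COUNT next to a COUNT(DISTINCT), as in `count_distinct_multi_plan_builder`.
    let plan = LogicalPlanBuilder::scan("test", table_source, None)?
        .aggregate(
            Vec::<Expr>::new(), // no GROUP BY
            vec![count(col("col_int32")), count_distinct(col("col_int32"))],
        )?
        .build()?;
    println!("{}", plan.display_indent());

    // `count_distinct` desugars to the Expr::AggregateFunction variant with
    // `distinct: true`, the flag that patch 2 now carries through the protobuf
    // AggregateExprNode instead of hard-coding it to false on deserialization.
    if let Expr::AggregateFunction { distinct, .. } = count_distinct(col("col_int32")) {
        assert!(distinct);
    }
    Ok(())
}

When such a plan is later optimized, the SingleDistinctToGroupBy rewrite from patch 3 deliberately emits the inner aggregate with distinct: false, which is why that hard-coded value is kept there behind an explanatory comment.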