diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index 1769314ebc0c..a4d6619f2a4f 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -83,7 +83,7 @@ fn optimize(plan: &LogicalPlan) -> Result { Expr::AggregateFunction { fun: fun.clone(), args: vec![col(SINGLE_DISTINCT_ALIAS)], - distinct: false, + distinct: false, // intentional to remove distict here } } _ => agg_expr.clone(), diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 7b08e4f40456..0b4a43e83e71 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -474,6 +474,7 @@ enum AggregateFunction { message AggregateExprNode { AggregateFunction aggr_function = 1; repeated LogicalExprNode expr = 2; + bool distinct = 3; } message AggregateUDFExprNode { diff --git a/datafusion/proto/src/from_proto.rs b/datafusion/proto/src/from_proto.rs index 63d9fe2b79c0..12f94ce3620e 100644 --- a/datafusion/proto/src/from_proto.rs +++ b/datafusion/proto/src/from_proto.rs @@ -890,7 +890,7 @@ pub fn parse_expr( .iter() .map(|e| parse_expr(e, registry)) .collect::, _>>()?, - distinct: false, // TODO + distinct: expr.distinct, }) } ExprType::Alias(alias) => Ok(Expr::Alias( diff --git a/datafusion/proto/src/lib.rs b/datafusion/proto/src/lib.rs index eecca1b6ad59..c843de630289 100644 --- a/datafusion/proto/src/lib.rs +++ b/datafusion/proto/src/lib.rs @@ -912,6 +912,28 @@ mod roundtrip_tests { roundtrip_expr_test(test_expr, ctx); } + #[test] + fn roundtrip_count() { + let test_expr = Expr::AggregateFunction { + fun: AggregateFunction::Count, + args: vec![col("bananas")], + distinct: false, + }; + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); + } + + #[test] + fn roundtrip_count_distinct() { + let test_expr = Expr::AggregateFunction { + fun: AggregateFunction::Count, + args: vec![col("bananas")], + distinct: true, + }; + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); + } + #[test] fn roundtrip_approx_percentile_cont() { let test_expr = Expr::AggregateFunction { diff --git a/datafusion/proto/src/to_proto.rs b/datafusion/proto/src/to_proto.rs index a022769dcab1..d3f68b3b4276 100644 --- a/datafusion/proto/src/to_proto.rs +++ b/datafusion/proto/src/to_proto.rs @@ -502,7 +502,9 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { } } Expr::AggregateFunction { - ref fun, ref args, .. + ref fun, + ref args, + ref distinct, } => { let aggr_function = match fun { AggregateFunction::ApproxDistinct => { @@ -550,6 +552,7 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { .iter() .map(|v| v.try_into()) .collect::, _>>()?, + distinct: *distinct, }; Self { expr_type: Some(ExprType::AggregateExpr(aggregate_expr)),