From 9c1c808ca58d05212ac22ce451abd9f853e3799c Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Tue, 23 Apr 2024 17:58:15 -0700 Subject: [PATCH] Add StatisticsType in expr.poto --- core/src/execution/datafusion/planner.rs | 40 ++++++++++--------- core/src/execution/proto/expr.proto | 17 ++++---- .../apache/comet/serde/QueryPlanSerde.scala | 10 +++-- 3 files changed, 35 insertions(+), 32 deletions(-) diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs index 2bf2e4576..bc2d9bed3 100644 --- a/core/src/execution/datafusion/planner.rs +++ b/core/src/execution/datafusion/planner.rs @@ -1236,27 +1236,29 @@ impl PhysicalPlanner { StatsType::Population, ))) } - AggExprStruct::VarianceSample(expr) => { + AggExprStruct::Variance(expr) => { let child = self.create_expr(expr.child.as_ref().unwrap(), schema.clone())?; let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap()); - Ok(Arc::new(Variance::new( - child, - "variance", - datatype, - StatsType::Sample, - expr.null_on_divide_by_zero, - ))) - } - AggExprStruct::VariancePopulation(expr) => { - let child = self.create_expr(expr.child.as_ref().unwrap(), schema.clone())?; - let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap()); - Ok(Arc::new(Variance::new( - child, - "variance_pop", - datatype, - StatsType::Population, - expr.null_on_divide_by_zero, - ))) + match expr.stats_type { + 0 => Ok(Arc::new(Variance::new( + child, + "variance", + datatype, + StatsType::Sample, + expr.null_on_divide_by_zero, + ))), + 1 => Ok(Arc::new(Variance::new( + child, + "variance_pop", + datatype, + StatsType::Population, + expr.null_on_divide_by_zero, + ))), + stats_type => Err(ExecutionError::GeneralError(format!( + "Unknown StatisticsType {:?} for Variance", + stats_type + ))), + } } } } diff --git a/core/src/execution/proto/expr.proto b/core/src/execution/proto/expr.proto index ba18aee78..042a981f4 100644 --- a/core/src/execution/proto/expr.proto +++ b/core/src/execution/proto/expr.proto @@ -94,11 +94,15 @@ message AggExpr { BitXorAgg bitXorAgg = 11; CovSample covSample = 12; CovPopulation covPopulation = 13; - VarianceSample varianceSample = 14; - VariancePopulation variancePopulation = 15; + Variance variance = 14; } } +enum StatisticsType { + SAMPLE = 0; + POPULATION = 1; +} + message Count { repeated Expr children = 1; } @@ -167,16 +171,11 @@ message CovPopulation { DataType datatype = 4; } -message VarianceSample { - Expr child = 1; - bool null_on_divide_by_zero = 2; - DataType datatype = 3; -} - -message VariancePopulation { +message Variance { Expr child = 1; bool null_on_divide_by_zero = 2; DataType datatype = 3; + StatisticsType stats_type = 4; } message Literal { diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 4c3914bee..fddd47294 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -469,15 +469,16 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { val dataType = serializeDataType(variance.dataType) if (childExpr.isDefined && dataType.isDefined) { - val varBuilder = ExprOuterClass.VarianceSample.newBuilder() + val varBuilder = ExprOuterClass.Variance.newBuilder() varBuilder.setChild(childExpr.get) varBuilder.setNullOnDivideByZero(nullOnDivideByZero) varBuilder.setDatatype(dataType.get) + varBuilder.setStatsTypeValue(0) Some( ExprOuterClass.AggExpr .newBuilder() - .setVarianceSample(varBuilder) + .setVariance(varBuilder) .build()) } else { None @@ -487,15 +488,16 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { val dataType = serializeDataType(variancePop.dataType) if (childExpr.isDefined && dataType.isDefined) { - val varBuilder = ExprOuterClass.VariancePopulation.newBuilder() + val varBuilder = ExprOuterClass.Variance.newBuilder() varBuilder.setChild(childExpr.get) varBuilder.setNullOnDivideByZero(nullOnDivideByZero) varBuilder.setDatatype(dataType.get) + varBuilder.setStatsTypeValue(1) Some( ExprOuterClass.AggExpr .newBuilder() - .setVariancePopulation(varBuilder) + .setVariance(varBuilder) .build()) } else { None