From bddbc713c191205b82669defaf5b73f23b668b37 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Sun, 21 Apr 2024 09:08:21 -0700 Subject: [PATCH] feat: Support Variance --- EXPRESSIONS.md | 5 +- .../execution/datafusion/expressions/mod.rs | 1 + .../datafusion/expressions/variance.rs | 259 ++++++++++++++++++ core/src/execution/datafusion/planner.rs | 23 ++ core/src/execution/proto/expr.proto | 14 + .../apache/comet/serde/QueryPlanSerde.scala | 38 ++- .../comet/exec/CometAggregateSuite.scala | 52 ++++ 7 files changed, 390 insertions(+), 2 deletions(-) create mode 100644 core/src/execution/datafusion/expressions/variance.rs diff --git a/EXPRESSIONS.md b/EXPRESSIONS.md index 45c36844b2..f0a2f69551 100644 --- a/EXPRESSIONS.md +++ b/EXPRESSIONS.md @@ -103,4 +103,7 @@ The following Spark expressions are currently available: + BitXor + BoolAnd + BoolOr - + Covariance + + CovPopulation + + CovSample + + VariancePop + + VarianceSamp diff --git a/core/src/execution/datafusion/expressions/mod.rs b/core/src/execution/datafusion/expressions/mod.rs index 799790c9fc..78763fc2a1 100644 --- a/core/src/execution/datafusion/expressions/mod.rs +++ b/core/src/execution/datafusion/expressions/mod.rs @@ -34,3 +34,4 @@ pub mod subquery; pub mod sum_decimal; pub mod temporal; mod utils; +pub mod variance; diff --git a/core/src/execution/datafusion/expressions/variance.rs b/core/src/execution/datafusion/expressions/variance.rs new file mode 100644 index 0000000000..88cd0d8598 --- /dev/null +++ b/core/src/execution/datafusion/expressions/variance.rs @@ -0,0 +1,259 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines physical expressions that can evaluated at runtime during query execution + +use std::{any::Any, sync::Arc}; + +use crate::execution::datafusion::expressions::{stats::StatsType, utils::down_cast_any_ref}; +use arrow::{ + array::{ArrayRef, Float64Array}, + compute::cast, + datatypes::{DataType, Field}, +}; +use datafusion::logical_expr::Accumulator; +use datafusion_common::{downcast_value, DataFusionError, Result, ScalarValue}; +use datafusion_physical_expr::{expressions::format_state_name, AggregateExpr, PhysicalExpr}; + +/// VAR_SAMP and VAR_POP aggregate expression +/// The implementation mostly is the same as the DataFusion's implementation. The reason +/// we have our own implementation is that DataFusion has UInt64 for state_field `count`, +/// while Spark has Double for count. Also we have added `null_on_divide_by_zero` +/// to be consistent with Spark's implementation. +#[derive(Debug)] +pub struct Variance { + name: String, + expr: Arc, + stats_type: StatsType, + null_on_divide_by_zero: bool, +} + +impl Variance { + /// Create a new VARIANCE aggregate function + pub fn new( + expr: Arc, + name: impl Into, + data_type: DataType, + stats_type: StatsType, + null_on_divide_by_zero: bool, + ) -> Self { + // the result of variance just support FLOAT64 data type. + assert!(matches!(data_type, DataType::Float64)); + Self { + name: name.into(), + expr, + stats_type, + null_on_divide_by_zero, + } + } +} + +impl AggregateExpr for Variance { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result { + Ok(Field::new(&self.name, DataType::Float64, true)) + } + + fn create_accumulator(&self) -> Result> { + Ok(Box::new(VarianceAccumulator::try_new( + self.stats_type, + self.null_on_divide_by_zero, + )?)) + } + + fn create_sliding_accumulator(&self) -> Result> { + Ok(Box::new(VarianceAccumulator::try_new( + self.stats_type, + self.null_on_divide_by_zero, + )?)) + } + + fn state_fields(&self) -> Result> { + Ok(vec![ + Field::new( + format_state_name(&self.name, "count"), + DataType::Float64, + true, + ), + Field::new( + format_state_name(&self.name, "mean"), + DataType::Float64, + true, + ), + Field::new(format_state_name(&self.name, "m2"), DataType::Float64, true), + ]) + } + + fn expressions(&self) -> Vec> { + vec![self.expr.clone()] + } + + fn name(&self) -> &str { + &self.name + } +} + +impl PartialEq for Variance { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| { + self.name == x.name && self.expr.eq(&x.expr) && self.stats_type == x.stats_type + }) + .unwrap_or(false) + } +} + +/// An accumulator to compute variance +#[derive(Debug)] +pub struct VarianceAccumulator { + m2: f64, + mean: f64, + count: f64, + stats_type: StatsType, + null_on_divide_by_zero: bool, +} + +impl VarianceAccumulator { + /// Creates a new `VarianceAccumulator` + pub fn try_new(s_type: StatsType, null_on_divide_by_zero: bool) -> Result { + Ok(Self { + m2: 0_f64, + mean: 0_f64, + count: 0_f64, + stats_type: s_type, + null_on_divide_by_zero, + }) + } + + pub fn get_count(&self) -> f64 { + self.count + } + + pub fn get_mean(&self) -> f64 { + self.mean + } + + pub fn get_m2(&self) -> f64 { + self.m2 + } +} + +impl Accumulator for VarianceAccumulator { + fn state(&mut self) -> Result> { + Ok(vec![ + ScalarValue::from(self.count), + ScalarValue::from(self.mean), + ScalarValue::from(self.m2), + ]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = &cast(&values[0], &DataType::Float64)?; + let arr = downcast_value!(values, Float64Array).iter().flatten(); + + for value in arr { + let new_count = self.count + 1.0; + let delta1 = value - self.mean; + let new_mean = delta1 / new_count + self.mean; + let delta2 = value - new_mean; + let new_m2 = self.m2 + delta1 * delta2; + + self.count += 1.0; + self.mean = new_mean; + self.m2 = new_m2; + } + + Ok(()) + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = &cast(&values[0], &DataType::Float64)?; + let arr = downcast_value!(values, Float64Array).iter().flatten(); + + for value in arr { + let new_count = self.count - 1.0; + let delta1 = self.mean - value; + let new_mean = delta1 / new_count + self.mean; + let delta2 = new_mean - value; + let new_m2 = self.m2 - delta1 * delta2; + + self.count -= 1.0; + self.mean = new_mean; + self.m2 = new_m2; + } + + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let counts = downcast_value!(states[0], Float64Array); + let means = downcast_value!(states[1], Float64Array); + let m2s = downcast_value!(states[2], Float64Array); + + for i in 0..counts.len() { + let c = counts.value(i); + if c == 0_f64 { + continue; + } + let new_count = self.count + c; + let new_mean = self.mean * self.count / new_count + means.value(i) * c / new_count; + let delta = self.mean - means.value(i); + let new_m2 = self.m2 + m2s.value(i) + delta * delta * self.count * c / new_count; + + self.count = new_count; + self.mean = new_mean; + self.m2 = new_m2; + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + let count = match self.stats_type { + StatsType::Population => self.count, + StatsType::Sample => { + if self.count > 0.0 { + self.count - 1.0 + } else { + self.count + } + } + }; + + Ok(ScalarValue::Float64(match self.count { + count if count == 0.0 => None, + count if count == 1.0 => { + if let StatsType::Population = self.stats_type { + Some(0.0) + } else if self.null_on_divide_by_zero { + None + } else { + Some(f64::NAN) + } + } + _ => Some(self.m2 / count), + })) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } +} diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs index ca926bf183..381be106ed 100644 --- a/core/src/execution/datafusion/planner.rs +++ b/core/src/execution/datafusion/planner.rs @@ -75,6 +75,7 @@ use crate::{ subquery::Subquery, sum_decimal::SumDecimal, temporal::{DateTruncExec, HourExec, MinuteExec, SecondExec, TimestampTruncExec}, + variance::Variance, NormalizeNaNAndZero, }, operators::expand::CometExpandExec, @@ -1219,6 +1220,28 @@ impl PhysicalPlanner { StatsType::Population, ))) } + AggExprStruct::VarianceSample(expr) => { + let child = self.create_expr(expr.child.as_ref().unwrap(), schema.clone())?; + let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap()); + Ok(Arc::new(Variance::new( + child, + "variance", + datatype, + StatsType::Sample, + expr.null_on_divide_by_zero, + ))) + } + AggExprStruct::VariancePopulation(expr) => { + let child = self.create_expr(expr.child.as_ref().unwrap(), schema.clone())?; + let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap()); + Ok(Arc::new(Variance::new( + child, + "variance_pop", + datatype, + StatsType::Population, + expr.null_on_divide_by_zero, + ))) + } } } diff --git a/core/src/execution/proto/expr.proto b/core/src/execution/proto/expr.proto index 1a6c29cf35..3210803822 100644 --- a/core/src/execution/proto/expr.proto +++ b/core/src/execution/proto/expr.proto @@ -94,6 +94,8 @@ message AggExpr { BitXorAgg bitXorAgg = 11; CovSample covSample = 12; CovPopulation covPopulation = 13; + VarianceSample varianceSample = 14; + VariancePopulation variancePopulation = 15; } } @@ -165,6 +167,18 @@ message CovPopulation { DataType datatype = 4; } +message VarianceSample { + Expr child = 1; + bool null_on_divide_by_zero = 2; + DataType datatype = 3; +} + +message VariancePopulation { + Expr child = 1; + bool null_on_divide_by_zero = 2; + DataType datatype = 3; +} + message Literal { oneof value { bool bool_val = 1; diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 555ab4084c..d0adf3134a 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -23,7 +23,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, Count, CovPopulation, CovSample, Final, First, Last, Max, Min, Partial, Sum} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, Count, CovPopulation, CovSample, Final, First, Last, Max, Min, Partial, Sum, VariancePop, VarianceSamp} import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.optimizer.{BuildRight, NormalizeNaNAndZero} import org.apache.spark.sql.catalyst.plans._ @@ -426,6 +426,42 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } else { None } + case variance @ VarianceSamp(child, nullOnDivideByZero) => + val childExpr = exprToProto(child, inputs, binding) + val dataType = serializeDataType(variance.dataType) + + if (childExpr.isDefined && dataType.isDefined) { + val varBuilder = ExprOuterClass.VarianceSample.newBuilder() + varBuilder.setChild(childExpr.get) + varBuilder.setNullOnDivideByZero(nullOnDivideByZero) + varBuilder.setDatatype(dataType.get) + + Some( + ExprOuterClass.AggExpr + .newBuilder() + .setVarianceSample(varBuilder) + .build()) + } else { + None + } + case variancePop @ VariancePop(child, nullOnDivideByZero) => + val childExpr = exprToProto(child, inputs, binding) + val dataType = serializeDataType(variancePop.dataType) + + if (childExpr.isDefined && dataType.isDefined) { + val varBuilder = ExprOuterClass.VariancePopulation.newBuilder() + varBuilder.setChild(childExpr.get) + varBuilder.setNullOnDivideByZero(nullOnDivideByZero) + varBuilder.setDatatype(dataType.get) + + Some( + ExprOuterClass.AggExpr + .newBuilder() + .setVariancePopulation(varBuilder) + .build()) + } else { + None + } case fn => emitWarning(s"unsupported Spark aggregate function: $fn") None diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index f6415cbfc6..bd4042ec11 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -1117,6 +1117,46 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("var_pop and var_samp") { + withSQLConf(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") { + Seq(true, false).foreach { cometColumnShuffleEnabled => + withSQLConf( + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> cometColumnShuffleEnabled.toString) { + Seq(true, false).foreach { dictionary => + withSQLConf("parquet.enable.dictionary" -> dictionary.toString) { + Seq(true, false).foreach { nullOnDivideByZero => + withSQLConf( + "spark.sql.legacy.statisticalAggregate" -> nullOnDivideByZero.toString) { + val table = "test" + withTable(table) { + sql(s"create table $table(col1 int, col2 int, col3 int, col4 float, col5 double, col6 int) using parquet") + sql(s"insert into $table values(1, null, null, 1.1, 2.2, 1)," + + " (2, null, null, 3.4, 5.6, 1), (3, null, 4, 7.9, 2.4, 2)") + val expectedNumOfCometAggregates = 2 + checkSparkAnswerWithTolAndNumOfAggregates( + "SELECT var_samp(col1), var_samp(col2), var_samp(col3), var_samp(col4), var_samp(col5) FROM test", + expectedNumOfCometAggregates) + checkSparkAnswerWithTolAndNumOfAggregates( + "SELECT var_pop(col1), var_pop(col2), var_pop(col3), var_pop(col4), var_samp(col5) FROM test", + expectedNumOfCometAggregates) + checkSparkAnswerAndNumOfAggregates( + "SELECT var_samp(col1), var_samp(col2), var_samp(col3), var_samp(col4), var_samp(col5)" + + " FROM test GROUP BY col6", + expectedNumOfCometAggregates) + checkSparkAnswerAndNumOfAggregates( + "SELECT var_pop(col1), var_pop(col2), var_pop(col3), var_pop(col4), var_samp(col5)" + + " FROM test GROUP BY col6", + expectedNumOfCometAggregates) + } + } + } + } + } + } + } + } + } + protected def checkSparkAnswerAndNumOfAggregates(query: String, numAggregates: Int): Unit = { val df = sql(query) checkSparkAnswer(df) @@ -1126,6 +1166,18 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { s"Expected $numAggregates Comet aggregate operators, but found $actualNumAggregates") } + protected def checkSparkAnswerWithTolAndNumOfAggregates( + query: String, + numAggregates: Int, + absTol: Double = 1e-6): Unit = { + val df = sql(query) + checkSparkAnswerWithTol(df, absTol) + val actualNumAggregates = getNumCometHashAggregate(df) + assert( + actualNumAggregates == numAggregates, + s"Expected $numAggregates Comet aggregate operators, but found $actualNumAggregates") + } + def getNumCometHashAggregate(df: DataFrame): Int = { val sparkPlan = stripAQEPlan(df.queryExecution.executedPlan) sparkPlan.collect { case s: CometHashAggregateExec => s }.size