From e183c0c789a9193ec52a5f825af6b4ec5a0f7e31 Mon Sep 17 00:00:00 2001
From: Huaxin Gao
Date: Sun, 21 Apr 2024 09:08:21 -0700
Subject: [PATCH 1/5] feat: Support Variance

---
 EXPRESSIONS.md                                |   5 +-
 .../execution/datafusion/expressions/mod.rs   |   1 +
 .../datafusion/expressions/variance.rs        | 259 ++++++++++++++++++
 core/src/execution/datafusion/planner.rs      |  23 ++
 core/src/execution/proto/expr.proto           |  14 +
 .../apache/comet/serde/QueryPlanSerde.scala   |  38 ++-
 .../comet/exec/CometAggregateSuite.scala      |  52 ++++
 7 files changed, 390 insertions(+), 2 deletions(-)
 create mode 100644 core/src/execution/datafusion/expressions/variance.rs

diff --git a/EXPRESSIONS.md b/EXPRESSIONS.md
index 45c36844b..f0a2f6955 100644
--- a/EXPRESSIONS.md
+++ b/EXPRESSIONS.md
@@ -103,4 +103,7 @@ The following Spark expressions are currently available:
 + BitXor
 + BoolAnd
 + BoolOr
-+ Covariance
++ CovPopulation
++ CovSample
++ VariancePop
++ VarianceSamp

diff --git a/core/src/execution/datafusion/expressions/mod.rs b/core/src/execution/datafusion/expressions/mod.rs
index 799790c9f..78763fc2a 100644
--- a/core/src/execution/datafusion/expressions/mod.rs
+++ b/core/src/execution/datafusion/expressions/mod.rs
@@ -34,3 +34,4 @@ pub mod subquery;
 pub mod sum_decimal;
 pub mod temporal;
 mod utils;
+pub mod variance;

diff --git a/core/src/execution/datafusion/expressions/variance.rs b/core/src/execution/datafusion/expressions/variance.rs
new file mode 100644
index 000000000..88cd0d859
--- /dev/null
+++ b/core/src/execution/datafusion/expressions/variance.rs
@@ -0,0 +1,259 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Defines physical expressions that can be evaluated at runtime during query execution
+
+use std::{any::Any, sync::Arc};
+
+use crate::execution::datafusion::expressions::{stats::StatsType, utils::down_cast_any_ref};
+use arrow::{
+    array::{ArrayRef, Float64Array},
+    compute::cast,
+    datatypes::{DataType, Field},
+};
+use datafusion::logical_expr::Accumulator;
+use datafusion_common::{downcast_value, DataFusionError, Result, ScalarValue};
+use datafusion_physical_expr::{expressions::format_state_name, AggregateExpr, PhysicalExpr};
+
+/// VAR_SAMP and VAR_POP aggregate expressions.
+/// The implementation is mostly the same as DataFusion's. We keep our own copy because
+/// DataFusion uses UInt64 for the `count` state field while Spark uses Double, and we
+/// also add `null_on_divide_by_zero` to stay consistent with Spark's implementation.
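+///
+/// A sketch of the update rule (Welford's online algorithm, as implemented in
+/// `update_batch` below); the primed symbols are informal shorthand, not part of
+/// any API:
+///
+///   count' = count + 1
+///   mean'  = mean + (x - mean) / count'
+///   m2'    = m2 + (x - mean) * (x - mean')
+///
+/// VAR_POP is then m2 / count and VAR_SAMP is m2 / (count - 1).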
+#[derive(Debug)]
+pub struct Variance {
+    name: String,
+    expr: Arc<dyn PhysicalExpr>,
+    stats_type: StatsType,
+    null_on_divide_by_zero: bool,
+}
+
+impl Variance {
+    /// Create a new VARIANCE aggregate function
+    pub fn new(
+        expr: Arc<dyn PhysicalExpr>,
+        name: impl Into<String>,
+        data_type: DataType,
+        stats_type: StatsType,
+        null_on_divide_by_zero: bool,
+    ) -> Self {
+        // The variance result only supports the Float64 data type.
+        assert!(matches!(data_type, DataType::Float64));
+        Self {
+            name: name.into(),
+            expr,
+            stats_type,
+            null_on_divide_by_zero,
+        }
+    }
+}
+
+impl AggregateExpr for Variance {
+    /// Return a reference to Any that can be used for downcasting
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn field(&self) -> Result<Field> {
+        Ok(Field::new(&self.name, DataType::Float64, true))
+    }
+
+    fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+        Ok(Box::new(VarianceAccumulator::try_new(
+            self.stats_type,
+            self.null_on_divide_by_zero,
+        )?))
+    }
+
+    fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+        Ok(Box::new(VarianceAccumulator::try_new(
+            self.stats_type,
+            self.null_on_divide_by_zero,
+        )?))
+    }
+
+    fn state_fields(&self) -> Result<Vec<Field>> {
+        Ok(vec![
+            Field::new(
+                format_state_name(&self.name, "count"),
+                DataType::Float64,
+                true,
+            ),
+            Field::new(
+                format_state_name(&self.name, "mean"),
+                DataType::Float64,
+                true,
+            ),
+            Field::new(format_state_name(&self.name, "m2"), DataType::Float64, true),
+        ])
+    }
+
+    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+        vec![self.expr.clone()]
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+}
+
+impl PartialEq<dyn Any> for Variance {
+    fn eq(&self, other: &dyn Any) -> bool {
+        down_cast_any_ref(other)
+            .downcast_ref::<Self>()
+            .map(|x| {
+                self.name == x.name && self.expr.eq(&x.expr) && self.stats_type == x.stats_type
+            })
+            .unwrap_or(false)
+    }
+}
+
+/// An accumulator to compute variance
+#[derive(Debug)]
+pub struct VarianceAccumulator {
+    m2: f64,
+    mean: f64,
+    count: f64,
+    stats_type: StatsType,
+    null_on_divide_by_zero: bool,
+}
+
+impl VarianceAccumulator {
+    /// Creates a new `VarianceAccumulator`
+    pub fn try_new(s_type: StatsType, null_on_divide_by_zero: bool) -> Result<Self> {
+        Ok(Self {
+            m2: 0_f64,
+            mean: 0_f64,
+            count: 0_f64,
+            stats_type: s_type,
+            null_on_divide_by_zero,
+        })
+    }
+
+    pub fn get_count(&self) -> f64 {
+        self.count
+    }
+
+    pub fn get_mean(&self) -> f64 {
+        self.mean
+    }
+
+    pub fn get_m2(&self) -> f64 {
+        self.m2
+    }
+}
+
+impl Accumulator for VarianceAccumulator {
+    fn state(&mut self) -> Result<Vec<ScalarValue>> {
+        Ok(vec![
+            ScalarValue::from(self.count),
+            ScalarValue::from(self.mean),
+            ScalarValue::from(self.m2),
+        ])
+    }
+
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        let values = &cast(&values[0], &DataType::Float64)?;
+        let arr = downcast_value!(values, Float64Array).iter().flatten();
+
+        for value in arr {
+            let new_count = self.count + 1.0;
+            let delta1 = value - self.mean;
+            let new_mean = delta1 / new_count + self.mean;
+            let delta2 = value - new_mean;
+            let new_m2 = self.m2 + delta1 * delta2;
+
+            self.count += 1.0;
+            self.mean = new_mean;
+            self.m2 = new_m2;
+        }
+
+        Ok(())
+    }
+
+    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        let values = &cast(&values[0], &DataType::Float64)?;
+        let arr = downcast_value!(values, Float64Array).iter().flatten();
+
+        for value in arr {
+            let new_count = self.count - 1.0;
+            let delta1 = self.mean - value;
+            let new_mean = delta1 / new_count + self.mean;
+            let delta2 = new_mean - value;
+            let new_m2 = self.m2 - delta1 * delta2;
+
+            self.count -= 1.0;
+            self.mean = new_mean;
+            self.m2 = new_m2;
+        }
+
+        Ok(())
+    }
+
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        let counts = downcast_value!(states[0], Float64Array);
+        let means = downcast_value!(states[1], Float64Array);
+        let m2s = downcast_value!(states[2], Float64Array);
+
+        for i in 0..counts.len() {
+            let c = counts.value(i);
+            if c == 0_f64 {
+                continue;
+            }
+            let new_count = self.count + c;
+            let new_mean = self.mean * self.count / new_count + means.value(i) * c / new_count;
+            let delta = self.mean - means.value(i);
+            let new_m2 = self.m2 + m2s.value(i) + delta * delta * self.count * c / new_count;
+
+            self.count = new_count;
+            self.mean = new_mean;
+            self.m2 = new_m2;
+        }
+        Ok(())
+    }
+
+    fn evaluate(&mut self) -> Result<ScalarValue> {
+        let count = match self.stats_type {
+            StatsType::Population => self.count,
+            StatsType::Sample => {
+                if self.count > 0.0 {
+                    self.count - 1.0
+                } else {
+                    self.count
+                }
+            }
+        };
+
+        Ok(ScalarValue::Float64(match self.count {
+            count if count == 0.0 => None,
+            count if count == 1.0 => {
+                if let StatsType::Population = self.stats_type {
+                    Some(0.0)
+                } else if self.null_on_divide_by_zero {
+                    None
+                } else {
+                    Some(f64::NAN)
+                }
+            }
+            _ => Some(self.m2 / count),
+        }))
+    }
+
+    fn size(&self) -> usize {
+        std::mem::size_of_val(self)
+    }
+}

diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs
index 5c379d43d..2bf2e4576 100644
--- a/core/src/execution/datafusion/planner.rs
+++ b/core/src/execution/datafusion/planner.rs
@@ -75,6 +75,7 @@ use crate::{
             subquery::Subquery,
             sum_decimal::SumDecimal,
             temporal::{DateTruncExec, HourExec, MinuteExec, SecondExec, TimestampTruncExec},
+            variance::Variance,
             NormalizeNaNAndZero,
         },
         operators::expand::CometExpandExec,
@@ -1235,6 +1236,28 @@ impl PhysicalPlanner {
                     StatsType::Population,
                 )))
             }
+            AggExprStruct::VarianceSample(expr) => {
+                let child = self.create_expr(expr.child.as_ref().unwrap(), schema.clone())?;
+                let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
+                Ok(Arc::new(Variance::new(
+                    child,
+                    "variance",
+                    datatype,
+                    StatsType::Sample,
+                    expr.null_on_divide_by_zero,
+                )))
+            }
+            AggExprStruct::VariancePopulation(expr) => {
+                let child = self.create_expr(expr.child.as_ref().unwrap(), schema.clone())?;
+                let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
+                Ok(Arc::new(Variance::new(
+                    child,
+                    "variance_pop",
+                    datatype,
+                    StatsType::Population,
+                    expr.null_on_divide_by_zero,
+                )))
+            }
         }
     }

diff --git a/core/src/execution/proto/expr.proto b/core/src/execution/proto/expr.proto
index afe75ecb4..ba18aee78 100644
--- a/core/src/execution/proto/expr.proto
+++ b/core/src/execution/proto/expr.proto
@@ -94,6 +94,8 @@ message AggExpr {
     BitXorAgg bitXorAgg = 11;
     CovSample covSample = 12;
     CovPopulation covPopulation = 13;
+    VarianceSample varianceSample = 14;
+    VariancePopulation variancePopulation = 15;
   }
 }
@@ -165,6 +167,18 @@ message CovPopulation {
   DataType datatype = 4;
 }

+message VarianceSample {
+  Expr child = 1;
+  bool null_on_divide_by_zero = 2;
+  DataType datatype = 3;
+}
+
+message VariancePopulation {
+  Expr child = 1;
+  bool null_on_divide_by_zero = 2;
+  DataType datatype = 3;
+}
+
 message Literal {
   oneof value {
     bool bool_val = 1;

diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 9a12930fc..4c3914bee 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -23,7 +23,7 @@ import scala.collection.JavaConverters._
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, Count, CovPopulation, CovSample, Final, First, Last, Max, Min, Partial, Sum}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, Count, CovPopulation, CovSample, Final, First, Last, Max, Min, Partial, Sum, VariancePop, VarianceSamp}
 import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
 import org.apache.spark.sql.catalyst.optimizer.{BuildRight, NormalizeNaNAndZero}
 import org.apache.spark.sql.catalyst.plans._
@@ -464,6 +464,42 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
         } else {
           None
         }
+      case variance @ VarianceSamp(child, nullOnDivideByZero) =>
+        val childExpr = exprToProto(child, inputs, binding)
+        val dataType = serializeDataType(variance.dataType)
+
+        if (childExpr.isDefined && dataType.isDefined) {
+          val varBuilder = ExprOuterClass.VarianceSample.newBuilder()
+          varBuilder.setChild(childExpr.get)
+          varBuilder.setNullOnDivideByZero(nullOnDivideByZero)
+          varBuilder.setDatatype(dataType.get)
+
+          Some(
+            ExprOuterClass.AggExpr
+              .newBuilder()
+              .setVarianceSample(varBuilder)
+              .build())
+        } else {
+          None
+        }
+      case variancePop @ VariancePop(child, nullOnDivideByZero) =>
+        val childExpr = exprToProto(child, inputs, binding)
+        val dataType = serializeDataType(variancePop.dataType)
+
+        if (childExpr.isDefined && dataType.isDefined) {
+          val varBuilder = ExprOuterClass.VariancePopulation.newBuilder()
+          varBuilder.setChild(childExpr.get)
+          varBuilder.setNullOnDivideByZero(nullOnDivideByZero)
+          varBuilder.setDatatype(dataType.get)
+
+          Some(
+            ExprOuterClass.AggExpr
+              .newBuilder()
+              .setVariancePopulation(varBuilder)
+              .build())
+        } else {
+          None
+        }
       case fn =>
         val msg = s"unsupported Spark aggregate function: ${fn.prettyName}"
         emitWarning(msg)

diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
index f6415cbfc..bd4042ec1 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
@@ -1117,6 +1117,46 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }

+  test("var_pop and var_samp") {
+    withSQLConf(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") {
+      Seq(true, false).foreach { cometColumnShuffleEnabled =>
+        withSQLConf(
+          CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> cometColumnShuffleEnabled.toString) {
+          Seq(true, false).foreach { dictionary =>
+            withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
+              Seq(true, false).foreach { nullOnDivideByZero =>
+                withSQLConf(
+                  "spark.sql.legacy.statisticalAggregate" -> nullOnDivideByZero.toString) {
+                  val table = "test"
+                  withTable(table) {
+                    sql(s"create table $table(col1 int, col2 int, col3 int, col4 float, col5 double, col6 int) using parquet")
+                    sql(s"insert into $table values(1, null, null, 1.1, 2.2, 1)," +
+                      " (2, null, null, 3.4, 5.6, 1), (3, null, 4, 7.9, 2.4, 2)")
+                    val expectedNumOfCometAggregates = 2
+                    checkSparkAnswerWithTolAndNumOfAggregates(
+                      "SELECT var_samp(col1), var_samp(col2), var_samp(col3), var_samp(col4), var_samp(col5) FROM test",
+                      expectedNumOfCometAggregates)
+                    checkSparkAnswerWithTolAndNumOfAggregates(
+                      "SELECT var_pop(col1), var_pop(col2), var_pop(col3), var_pop(col4), var_pop(col5) FROM test",
+                      expectedNumOfCometAggregates)
+                    checkSparkAnswerAndNumOfAggregates(
+                      "SELECT var_samp(col1), var_samp(col2), var_samp(col3), var_samp(col4), var_samp(col5)" +
+                        " FROM test GROUP BY col6",
+                      expectedNumOfCometAggregates)
+                    checkSparkAnswerAndNumOfAggregates(
+                      "SELECT var_pop(col1), var_pop(col2), var_pop(col3), var_pop(col4), var_pop(col5)" +
+                        " FROM test GROUP BY col6",
+                      expectedNumOfCometAggregates)
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+
   protected def checkSparkAnswerAndNumOfAggregates(query: String, numAggregates: Int): Unit = {
     val df = sql(query)
     checkSparkAnswer(df)
@@ -1126,6 +1166,18 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       s"Expected $numAggregates Comet aggregate operators, but found $actualNumAggregates")
   }

+  protected def checkSparkAnswerWithTolAndNumOfAggregates(
+      query: String,
+      numAggregates: Int,
+      absTol: Double = 1e-6): Unit = {
+    val df = sql(query)
+    checkSparkAnswerWithTol(df, absTol)
+    val actualNumAggregates = getNumCometHashAggregate(df)
+    assert(
+      actualNumAggregates == numAggregates,
+      s"Expected $numAggregates Comet aggregate operators, but found $actualNumAggregates")
+  }
+
   def getNumCometHashAggregate(df: DataFrame): Int = {
     val sparkPlan = stripAQEPlan(df.queryExecution.executedPlan)
     sparkPlan.collect { case s: CometHashAggregateExec => s }.size

From 9c1c808ca58d05212ac22ce451abd9f853e3799c Mon Sep 17 00:00:00 2001
From: Huaxin Gao
Date: Tue, 23 Apr 2024 17:58:15 -0700
Subject: [PATCH 2/5] Add StatisticsType in expr.proto

---
 core/src/execution/datafusion/planner.rs      | 40 ++++++++++---------
 core/src/execution/proto/expr.proto           | 17 ++++----
 .../apache/comet/serde/QueryPlanSerde.scala   | 10 +++--
 3 files changed, 35 insertions(+), 32 deletions(-)

diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs
index 2bf2e4576..bc2d9bed3 100644
--- a/core/src/execution/datafusion/planner.rs
+++ b/core/src/execution/datafusion/planner.rs
@@ -1236,27 +1236,29 @@ impl PhysicalPlanner {
                     StatsType::Population,
                 )))
             }
-            AggExprStruct::VarianceSample(expr) => {
+            AggExprStruct::Variance(expr) => {
                 let child = self.create_expr(expr.child.as_ref().unwrap(), schema.clone())?;
                 let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
-                Ok(Arc::new(Variance::new(
-                    child,
-                    "variance",
-                    datatype,
-                    StatsType::Sample,
-                    expr.null_on_divide_by_zero,
-                )))
-            }
-            AggExprStruct::VariancePopulation(expr) => {
-                let child = self.create_expr(expr.child.as_ref().unwrap(), schema.clone())?;
-                let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
-                Ok(Arc::new(Variance::new(
-                    child,
-                    "variance_pop",
-                    datatype,
-                    StatsType::Population,
-                    expr.null_on_divide_by_zero,
-                )))
+                match expr.stats_type {
+                    0 => Ok(Arc::new(Variance::new(
+                        child,
+                        "variance",
+                        datatype,
+                        StatsType::Sample,
+                        expr.null_on_divide_by_zero,
+                    ))),
+                    1 => Ok(Arc::new(Variance::new(
+                    child,
+                    "variance_pop",
+                    datatype,
+                    StatsType::Population,
+                    expr.null_on_divide_by_zero,
+                    ))),
+                    stats_type => Err(ExecutionError::GeneralError(format!(
+                        "Unknown StatisticsType {:?} for Variance",
+                        stats_type
+                    ))),
+                }
             }
         }
     }

diff --git a/core/src/execution/proto/expr.proto b/core/src/execution/proto/expr.proto
index ba18aee78..042a981f4 100644
--- a/core/src/execution/proto/expr.proto
+++ b/core/src/execution/proto/expr.proto
@@ -94,11 +94,15 @@ message AggExpr {
     BitXorAgg bitXorAgg = 11;
     CovSample covSample = 12;
    CovPopulation covPopulation = 13;
-    VarianceSample varianceSample = 14;
-    VariancePopulation variancePopulation = 15;
+    Variance variance = 14;
   }
 }

+enum StatisticsType {
+  SAMPLE = 0;
+  POPULATION = 1;
+}
+
 message Count {
   repeated Expr children = 1;
 }
@@ -167,16 +171,11 @@ message CovPopulation {
   DataType datatype = 4;
 }

-message VarianceSample {
-  Expr child = 1;
-  bool null_on_divide_by_zero = 2;
-  DataType datatype = 3;
-}
-
-message VariancePopulation {
+message Variance {
   Expr child = 1;
   bool null_on_divide_by_zero = 2;
   DataType datatype = 3;
+  StatisticsType stats_type = 4;
 }

diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 4c3914bee..fddd47294 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -469,15 +469,16 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
         val dataType = serializeDataType(variance.dataType)

         if (childExpr.isDefined && dataType.isDefined) {
-          val varBuilder = ExprOuterClass.VarianceSample.newBuilder()
+          val varBuilder = ExprOuterClass.Variance.newBuilder()
           varBuilder.setChild(childExpr.get)
           varBuilder.setNullOnDivideByZero(nullOnDivideByZero)
           varBuilder.setDatatype(dataType.get)
+          varBuilder.setStatsTypeValue(0)

           Some(
             ExprOuterClass.AggExpr
               .newBuilder()
-              .setVarianceSample(varBuilder)
+              .setVariance(varBuilder)
               .build())
         } else {
           None
@@ -487,15 +488,16 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
         val dataType = serializeDataType(variancePop.dataType)

         if (childExpr.isDefined && dataType.isDefined) {
-          val varBuilder = ExprOuterClass.VariancePopulation.newBuilder()
+          val varBuilder = ExprOuterClass.Variance.newBuilder()
           varBuilder.setChild(childExpr.get)
           varBuilder.setNullOnDivideByZero(nullOnDivideByZero)
           varBuilder.setDatatype(dataType.get)
+          varBuilder.setStatsTypeValue(1)

           Some(
             ExprOuterClass.AggExpr
               .newBuilder()
-              .setVariancePopulation(varBuilder)
+              .setVariance(varBuilder)
               .build())
         } else {
           None

From 862d20a178a32e480b7e64c6a9aab8a25d9cfaa1 Mon Sep 17 00:00:00 2001
From: Huaxin Gao
Date: Tue, 23 Apr 2024 18:14:09 -0700
Subject: [PATCH 3/5] add explainPlan info and fix fmt

---
 core/src/execution/datafusion/planner.rs      | 10 +++++-----
 .../apache/comet/serde/QueryPlanSerde.scala   |  2 ++
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs
index bc2d9bed3..72174790b 100644
--- a/core/src/execution/datafusion/planner.rs
+++ b/core/src/execution/datafusion/planner.rs
@@ -1248,11 +1248,11 @@ impl PhysicalPlanner {
                         expr.null_on_divide_by_zero,
                     ))),
                     1 => Ok(Arc::new(Variance::new(
-                    child,
-                    "variance_pop",
-                    datatype,
-                    StatsType::Population,
-                    expr.null_on_divide_by_zero,
+                        child,
+                        "variance_pop",
+                        datatype,
+                        StatsType::Population,
+                        expr.null_on_divide_by_zero,
                     ))),
                     stats_type => Err(ExecutionError::GeneralError(format!(
                         "Unknown StatisticsType {:?} for Variance",
                         stats_type

diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index fddd47294..d08fb6b90 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -481,6 +481,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
               .setVariance(varBuilder)
               .build())
         } else {
+          withInfo(aggExpr, child)
           None
         }
       case variancePop @ VariancePop(child, nullOnDivideByZero) =>
@@ -500,6 +501,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
               .setVariance(varBuilder)
               .build())
         } else {
+          withInfo(aggExpr, child)
           None
         }
       case fn =>

From d11f47e0ce244ea1d357d93734ea388266752ffe Mon Sep 17 00:00:00 2001
From: Huaxin Gao
Date: Wed, 24 Apr 2024 13:34:00 -0700
Subject: [PATCH 4/5] remove unnecessary cast

---
 core/src/execution/datafusion/expressions/variance.rs | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/core/src/execution/datafusion/expressions/variance.rs b/core/src/execution/datafusion/expressions/variance.rs
index 88cd0d859..58446a773 100644
--- a/core/src/execution/datafusion/expressions/variance.rs
+++ b/core/src/execution/datafusion/expressions/variance.rs
@@ -167,8 +167,7 @@ impl Accumulator for VarianceAccumulator {
     }

     fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
-        let values = &cast(&values[0], &DataType::Float64)?;
-        let arr = downcast_value!(values, Float64Array).iter().flatten();
+        let arr = downcast_value!(&values[0], Float64Array).iter().flatten();

         for value in arr {
             let new_count = self.count + 1.0;
@@ -186,8 +185,7 @@ impl Accumulator for VarianceAccumulator {
     }

     fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
-        let values = &cast(&values[0], &DataType::Float64)?;
-        let arr = downcast_value!(values, Float64Array).iter().flatten();
+        let arr = downcast_value!(&values[0], Float64Array).iter().flatten();

         for value in arr {
             let new_count = self.count - 1.0;

From 86f68a465fb23e74fe5b935b2a35fa9d63a1ae58 Mon Sep 17 00:00:00 2001
From: Huaxin Gao
Date: Wed, 24 Apr 2024 13:39:28 -0700
Subject: [PATCH 5/5] remove unused import

---
 core/src/execution/datafusion/expressions/variance.rs | 1 -
 1 file changed, 1 deletion(-)

diff --git a/core/src/execution/datafusion/expressions/variance.rs b/core/src/execution/datafusion/expressions/variance.rs
index 58446a773..6aae01ed8 100644
--- a/core/src/execution/datafusion/expressions/variance.rs
+++ b/core/src/execution/datafusion/expressions/variance.rs
@@ -22,7 +22,6 @@ use std::{any::Any, sync::Arc};
 use crate::execution::datafusion::expressions::{stats::StatsType, utils::down_cast_any_ref};
 use arrow::{
     array::{ArrayRef, Float64Array},
-    compute::cast,
     datatypes::{DataType, Field},
 };
 use datafusion::logical_expr::Accumulator;
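A minimal standalone sketch (not part of the patch series; all names are local to
this example) of the accumulator math used above -- the Welford update from
update_batch and the pairwise combination from merge_batch -- useful for
sanity-checking that merging partial states matches a single pass over the data:

// Plain Rust, no DataFusion dependencies.
#[derive(Clone, Copy, Default)]
struct State {
    count: f64,
    mean: f64,
    m2: f64,
}

// Mirrors VarianceAccumulator::update_batch for a single value.
fn update(mut s: State, x: f64) -> State {
    let new_count = s.count + 1.0;
    let delta1 = x - s.mean;
    let new_mean = delta1 / new_count + s.mean;
    let delta2 = x - new_mean;
    s.m2 += delta1 * delta2;
    s.mean = new_mean;
    s.count = new_count;
    s
}

// Mirrors VarianceAccumulator::merge_batch for one partial state.
fn merge(a: State, b: State) -> State {
    if b.count == 0.0 {
        return a;
    }
    let count = a.count + b.count;
    let mean = a.mean * a.count / count + b.mean * b.count / count;
    let delta = a.mean - b.mean;
    let m2 = a.m2 + b.m2 + delta * delta * a.count * b.count / count;
    State { count, mean, m2 }
}

fn main() {
    let xs = [1.0, 2.0, 3.0, 7.9];
    let single = xs.iter().fold(State::default(), |s, &x| update(s, x));
    let left = xs[..2].iter().fold(State::default(), |s, &x| update(s, x));
    let right = xs[2..].iter().fold(State::default(), |s, &x| update(s, x));
    let merged = merge(left, right);
    // Merging partial states reproduces the single-pass m2
    // (up to floating-point rounding).
    assert!((single.m2 - merged.m2).abs() < 1e-9);
    println!(
        "var_pop = {}, var_samp = {}",
        merged.m2 / merged.count,        // m2 / n
        merged.m2 / (merged.count - 1.0) // m2 / (n - 1)
    );
}

This is why merge_batch can combine per-partition states without revisiting the
input, which is what allows the Comet hash aggregate to run in partial/final mode.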