From 992b93c633080616c2b7bfa9a9a9077129d62099 Mon Sep 17 00:00:00 2001
From: Adam Binford
Date: Thu, 29 Aug 2024 09:28:37 -0400
Subject: [PATCH] fix: Support type coercion for ScalarUDFs (#865)

* Add type coercion to ScalarUDFs and support dictionaries in lists
* Formatting
* Remove unused var
* Cleanup planner
* Update comment for failing test
* Add struct tests
* Change back row count

(cherry picked from commit 7484588b115dfa842957fbd7de0b26ecba08c7b1)
---
 .../apache/comet/vector/CometListVector.java | 10 ++--
 .../org/apache/comet/vector/CometVector.java |  2 +-
 .../core/src/execution/datafusion/planner.rs | 52 ++++++++++++++-----
 .../apache/comet/serde/QueryPlanSerde.scala  |  6 +--
 .../apache/comet/CometExpressionSuite.scala  | 11 ++++
 5 files changed, 60 insertions(+), 21 deletions(-)

diff --git a/common/src/main/java/org/apache/comet/vector/CometListVector.java b/common/src/main/java/org/apache/comet/vector/CometListVector.java
index 1c8f3e658c..752495c0d8 100644
--- a/common/src/main/java/org/apache/comet/vector/CometListVector.java
+++ b/common/src/main/java/org/apache/comet/vector/CometListVector.java
@@ -21,6 +21,7 @@
 import org.apache.arrow.vector.*;
 import org.apache.arrow.vector.complex.ListVector;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
 import org.apache.arrow.vector.util.TransferPair;
 import org.apache.spark.sql.vectorized.ColumnVector;
 import org.apache.spark.sql.vectorized.ColumnarArray;
@@ -30,13 +31,16 @@ public class CometListVector extends CometDecodedVector {
   final ListVector listVector;
   final ValueVector dataVector;
   final ColumnVector dataColumnVector;
+  final DictionaryProvider dictionaryProvider;
 
-  public CometListVector(ValueVector vector, boolean useDecimal128) {
+  public CometListVector(
+      ValueVector vector, boolean useDecimal128, DictionaryProvider dictionaryProvider) {
     super(vector, vector.getField(), useDecimal128);
 
     this.listVector = ((ListVector) vector);
     this.dataVector = listVector.getDataVector();
-    this.dataColumnVector = getVector(dataVector, useDecimal128);
+    this.dictionaryProvider = dictionaryProvider;
+    this.dataColumnVector = getVector(dataVector, useDecimal128, dictionaryProvider);
   }
 
   @Override
@@ -52,6 +56,6 @@ public CometVector slice(int offset, int length) {
     TransferPair tp = this.valueVector.getTransferPair(this.valueVector.getAllocator());
     tp.splitAndTransfer(offset, length);
 
-    return new CometListVector(tp.getTo(), useDecimal128);
+    return new CometListVector(tp.getTo(), useDecimal128, dictionaryProvider);
   }
 }
diff --git a/common/src/main/java/org/apache/comet/vector/CometVector.java b/common/src/main/java/org/apache/comet/vector/CometVector.java
index 3a5ae4476d..3b0ca35bf9 100644
--- a/common/src/main/java/org/apache/comet/vector/CometVector.java
+++ b/common/src/main/java/org/apache/comet/vector/CometVector.java
@@ -239,7 +239,7 @@ protected static CometVector getVector(
     } else if (vector instanceof MapVector) {
       return new CometMapVector(vector, useDecimal128, dictionaryProvider);
     } else if (vector instanceof ListVector) {
-      return new CometListVector(vector, useDecimal128);
+      return new CometListVector(vector, useDecimal128, dictionaryProvider);
     } else {
       DictionaryEncoding dictionaryEncoding = vector.getField().getDictionary();
       CometPlainVector cometVector = new CometPlainVector(vector, useDecimal128);
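Note on the Java changes above: CometListVector now receives the DictionaryProvider and forwards it when wrapping its child (data) vector, so a list whose element values are dictionary-encoded can be decoded; in Java Arrow the dictionary values live outside the vector, in the provider. As an illustration of the data shape this enables, here is a minimal arrow-rs sketch (illustration only, not code from this patch) of a list with dictionary-encoded child values; unlike Java Arrow, arrow-rs embeds the dictionary in the array, so a plain cast materializes the values:

    use std::sync::Arc;

    use arrow::array::{Array, DictionaryArray, Int32Array, ListArray, StringArray};
    use arrow::buffer::OffsetBuffer;
    use arrow::compute::cast;
    use arrow::datatypes::{DataType, Field, Int32Type};

    fn main() {
        // Dictionary-encoded child values: ["a", "b", "a"] stored as keys
        // [0, 1, 0] over the dictionary ["a", "b"].
        let keys = Int32Array::from(vec![0, 1, 0]);
        let values = Arc::new(StringArray::from(vec!["a", "b"]));
        let dict = DictionaryArray::<Int32Type>::try_new(keys, values).unwrap();

        // A single list wrapping those values: [["a", "b", "a"]].
        let field = Arc::new(Field::new("item", dict.data_type().clone(), true));
        let offsets = OffsetBuffer::from_lengths([3]);
        let list = ListArray::new(field, offsets, Arc::new(dict), None);

        // Decoding the child values is a cast from the dictionary type to its
        // value type; the result is a plain StringArray of length 3.
        let decoded = cast(list.values(), &DataType::Utf8).unwrap();
        assert_eq!(decoded.len(), 3);
    }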
diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs
index 45de2bca97..76483cc1f8 100644
--- a/native/core/src/execution/datafusion/planner.rs
+++ b/native/core/src/execution/datafusion/planner.rs
@@ -49,6 +49,7 @@ use crate::{
         serde::to_arrow_datatype,
     },
 };
+use arrow::compute::CastOptions;
 use arrow_schema::{DataType, Field, Schema, TimeUnit, DECIMAL128_MAX_PRECISION};
 use datafusion::functions_aggregate::bit_and_or_xor::{bit_and_udaf, bit_or_udaf, bit_xor_udaf};
 use datafusion::functions_aggregate::min_max::max_udaf;
@@ -1777,23 +1778,48 @@ impl PhysicalPlanner {
                     .map(|x| x.data_type(input_schema.as_ref()))
                     .collect::<Result<Vec<_>, _>>()?;
 
-                let data_type = match expr.return_type.as_ref().map(to_arrow_datatype) {
-                    Some(t) => t,
-                    None => {
-                        let fun_name = match fun_name.as_str() {
-                            "read_side_padding" => "rpad", // use the same return type as rpad
-                            other => other,
-                        };
-                        self.session_ctx
-                            .udf(fun_name)?
-                            .inner()
-                            .return_type(&input_expr_types)?
-                    }
-                };
+                let (data_type, coerced_input_types) =
+                    match expr.return_type.as_ref().map(to_arrow_datatype) {
+                        Some(t) => (t, input_expr_types.clone()),
+                        None => {
+                            let fun_name = match fun_name.as_ref() {
+                                "read_side_padding" => "rpad", // use the same return type as rpad
+                                other => other,
+                            };
+                            let func = self.session_ctx.udf(fun_name)?;
+
+                            let coerced_types = func
+                                .coerce_types(&input_expr_types)
+                                .unwrap_or_else(|_| input_expr_types.clone());
+
+                            let data_type = func.inner().return_type(&coerced_types)?;
+
+                            (data_type, coerced_types)
+                        }
+                    };
 
                 let fun_expr =
                     create_comet_physical_fun(fun_name, data_type.clone(), &self.session_ctx.state())?;
 
+                let args = args
+                    .into_iter()
+                    .zip(input_expr_types.into_iter().zip(coerced_input_types))
+                    .map(|(expr, (from_type, to_type))| {
+                        if !from_type.equals_datatype(&to_type) {
+                            Arc::new(CastExpr::new(
+                                expr,
+                                to_type,
+                                Some(CastOptions {
+                                    safe: false,
+                                    ..Default::default()
+                                }),
+                            ))
+                        } else {
+                            expr
+                        }
+                    })
+                    .collect::<Vec<_>>();
+
                 let scalar_expr: Arc<dyn PhysicalExpr> = Arc::new(ScalarFunctionExpr::new(
                     fun_name,
                     fun_expr,
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 717ea8911d..596bc945cf 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -2450,13 +2450,11 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
           .build()
       }
 
-      // datafusion's make_array only supports nullable element types
-      // https://github.com/apache/datafusion/issues/11923
-      case array @ CreateArray(children, _) if array.dataType.containsNull =>
+      case CreateArray(children, _) =>
         val childExprs = children.map(exprToProto(_, inputs, binding))
 
         if (childExprs.forall(_.isDefined)) {
-          scalarExprToProtoWithReturnType("make_array", array.dataType, childExprs: _*)
+          scalarExprToProto("make_array", childExprs: _*)
         } else {
           withInfo(expr, "unsupported arguments for CreateArray", children: _*)
           None
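Note on the planner change above: instead of computing the return type from the raw argument types, the planner now asks the UDF for its coerced argument types via coerce_types, falls back to the original types when the UDF implements no coercion, and wraps only the arguments whose types changed in a CastExpr before computing the return type. A standalone sketch of that flow against DataFusion's public API (this mirrors the patch but is not part of it; it assumes a default SessionContext, which registers the built-in make_array UDF):

    use datafusion::arrow::datatypes::DataType;
    use datafusion::error::Result;
    use datafusion::prelude::SessionContext;

    fn main() -> Result<()> {
        let ctx = SessionContext::new();
        let func = ctx.udf("make_array")?;

        // A dictionary-encoded string column next to a plain string column,
        // the shape produced when one Parquet column is dictionary-encoded.
        let input_types = vec![
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
            DataType::Utf8,
        ];

        // Ask the UDF what argument types it actually wants; fall back to the
        // original types when it implements no coercion, as the patch does.
        let coerced = func
            .coerce_types(&input_types)
            .unwrap_or_else(|_| input_types.clone());

        // The planner wraps each argument whose type changed in a CastExpr.
        for (from, to) in input_types.iter().zip(&coerced) {
            if !from.equals_datatype(to) {
                println!("would cast {from} -> {to}");
            }
        }

        // The return type is then computed from the coerced types.
        println!("return type: {}", func.inner().return_type(&coerced)?);
        Ok(())
    }

Coercing up front is also what lets the Scala side drop the nullable-only restriction on CreateArray: make_array sees plain, unified argument types, so no dictionary- or nullability-specific handling is needed in the function itself.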
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index c86cfa84ec..4261b6218e 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -2079,6 +2079,17 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       val df = spark.read.parquet(path.toString)
       checkSparkAnswerAndOperator(df.select(array(col("_2"), col("_3"), col("_4"))))
       checkSparkAnswerAndOperator(df.select(array(col("_4"), col("_11"), lit(null))))
+      checkSparkAnswerAndOperator(
+        df.select(array(array(col("_4")), array(col("_4"), lit(null)))))
+      checkSparkAnswerAndOperator(df.select(array(col("_8"), col("_13"))))
+      // This ends up returning empty strings instead of nulls for the last element.
+      // Fixed by https://github.com/apache/datafusion/commit/27304239ef79b50a443320791755bf74eed4a85d
+      // checkSparkAnswerAndOperator(df.select(array(col("_8"), col("_13"), lit(null))))
+      checkSparkAnswerAndOperator(df.select(array(array(col("_8")), array(col("_13")))))
+      checkSparkAnswerAndOperator(df.select(array(col("_8"), col("_8"), lit(null))))
+      checkSparkAnswerAndOperator(df.select(array(struct("_4"), struct("_4"))))
+      // Fixed by https://github.com/apache/datafusion/commit/140f7cec78febd73d3db537a816badaaf567530a
+      // checkSparkAnswerAndOperator(df.select(array(struct(col("_8").alias("a")), struct(col("_13").alias("a")))))
     }
   }
 }
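A closing note on the CastOptions { safe: false } used when wrapping arguments in the planner: arrow's cast kernel defaults to "safe" casting, which turns values that fail to cast into nulls, while safe: false surfaces the failure as an error, which is the behavior the planner opts into for UDF argument coercion. A small sketch of the difference, assuming arrow-rs's documented cast semantics (not code from this patch):

    use arrow::array::{Array, Int64Array};
    use arrow::compute::{cast_with_options, CastOptions};
    use arrow::datatypes::DataType;

    fn main() {
        let big = Int64Array::from(vec![i64::MAX]);

        // The default "safe" cast converts values that fail to cast into nulls...
        let safe = CastOptions::default();
        let as_int32 = cast_with_options(&big, &DataType::Int32, &safe).unwrap();
        assert!(as_int32.is_null(0));

        // ...while safe: false surfaces the overflow as an error instead.
        let strict = CastOptions {
            safe: false,
            ..Default::default()
        };
        assert!(cast_with_options(&big, &DataType::Int32, &strict).is_err());
    }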