From 7fcf314e7459b973cd908a485f18fcea8ae037f3 Mon Sep 17 00:00:00 2001 From: Adam Binford Date: Tue, 20 Aug 2024 18:55:53 -0400 Subject: [PATCH 01/20] Start messing with get array item --- .../apache/comet/serde/QueryPlanSerde.scala | 32 ++++++++++- .../apache/comet/CometExpressionSuite.scala | 57 +++++++++++++++++++ 2 files changed, 88 insertions(+), 1 deletion(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index caeda5167..06564b72e 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -2392,12 +2392,42 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim val childExprs = children.map(exprToProto(_, inputs, binding)) if (childExprs.forall(_.isDefined)) { - scalarExprToProtoWithReturnType("make_array", array.dataType, childExprs: _*) + scalarExprToProto("make_array", childExprs: _*) } else { withInfo(expr, "unsupported arguments for CreateArray", children: _*) None } + case get @ GetArrayItem(child, ordinal, _) => + val childExpr = exprToProto(child, inputs, binding) + + // DataFusion expects the indices to be int64 + val ordinalExpr = + exprToProto(Add(Cast(ordinal, LongType), Literal(1L)), inputs, binding) + // scalastyle:off println + println(ordinal.dataType) + + if (childExpr.isDefined && ordinalExpr.isDefined) { + scalarExprToProtoWithReturnType("array_element", get.dataType, childExpr, ordinalExpr) + } else { + withInfo(expr, "unsupported arguments for GetArrayItem", child, ordinal) + None + } + + case get @ ElementAt(child, ordinal, _, _) => + val childExpr = exprToProto(child, inputs, binding) + + // DataFusion expects the indices to be int64 + val ordinalExpr = + exprToProto(Cast(ordinal, LongType), inputs, binding) + + if (childExpr.isDefined && ordinalExpr.isDefined) { + scalarExprToProtoWithReturnType("array_element", get.dataType, childExpr, ordinalExpr) + } else { + withInfo(expr, "unsupported arguments for GetArrayItem", child, ordinal) + None + } + case _ => withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*) None diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 38945bbb2..5b4f48b58 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -2009,6 +2009,63 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { val df = spark.read.parquet(path.toString) checkSparkAnswerAndOperator(df.select(array(col("_2"), col("_3"), col("_4")))) checkSparkAnswerAndOperator(df.select(array(col("_4"), col("_11"), lit(null)))) + checkSparkAnswerAndOperator(df.select(array(col("_8"), col("_13")))) + checkSparkAnswerAndOperator(df.select(array(col("_13"), col("_13")))) + } + } + } + + test("GetArrayItem") { + Seq(true, false).foreach { dictionaryEnabled => + withTempDir { dir => + val path = new Path(dir.toURI.toString, "test.parquet") + makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000) + + Seq(false).foreach { ansiEnabled => + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString()) { + val df = spark.read + .parquet(path.toString) + + val stringArray = df.select(array(col("_8"), col("_13")).alias("arr")) + stringArray.show() + checkSparkAnswerAndOperator( + stringArray.select( + 
col("arr").getItem(-3), + col("arr").getItem(-4), + col("arr").getItem(1), + col("arr").getItem(2))) + + // stringArray.select( + // col("arr").getItem(-1), + // col("arr").getItem(-2), + // col("arr").getItem(1), + // col("arr").getItem(2)) + // .show() + + stringArray.select( + // element_at(col("arr"), -2), + // element_at(col("arr"), -1), + // element_at(col("arr"), 1), + element_at(col("arr"), lit(2))) + .show() + + // val intArray = df.select(array(col("_2"), col("_3"), col("_4")).alias("arr")) + // checkSparkAnswerAndOperator( + // intArray + // .select(col("arr").getItem(0), col("arr").getItem(1), col("arr").getItem(-1))) + + // intArray + // .select(col("arr").getItem(0), col("arr").getItem(1), col("arr").getItem(-1)) + // .explain() + + // checkSparkAnswerAndOperator( + // intArray.select( + // element_at(col("arr"), 1), + // element_at(col("arr"), 3), + // element_at(col("arr"), 4), + // element_at(col("arr"), -1))) + } + } } } } From 4a8037904362fe1068c298bd3abb4524c6497c0e Mon Sep 17 00:00:00 2001 From: Adam Binford Date: Wed, 21 Aug 2024 07:24:27 -0400 Subject: [PATCH 02/20] Add type coercion to ScalarUDFs and support dictionaries in lists --- .../apache/comet/vector/CometListVector.java | 10 ++- .../org/apache/comet/vector/CometVector.java | 2 +- .../core/src/execution/datafusion/planner.rs | 84 ++++++++++++++++--- .../apache/comet/serde/QueryPlanSerde.scala | 4 +- .../apache/comet/CometExpressionSuite.scala | 8 +- 5 files changed, 88 insertions(+), 20 deletions(-) diff --git a/common/src/main/java/org/apache/comet/vector/CometListVector.java b/common/src/main/java/org/apache/comet/vector/CometListVector.java index 1c8f3e658..752495c0d 100644 --- a/common/src/main/java/org/apache/comet/vector/CometListVector.java +++ b/common/src/main/java/org/apache/comet/vector/CometListVector.java @@ -21,6 +21,7 @@ import org.apache.arrow.vector.*; import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.util.TransferPair; import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.sql.vectorized.ColumnarArray; @@ -30,13 +31,16 @@ public class CometListVector extends CometDecodedVector { final ListVector listVector; final ValueVector dataVector; final ColumnVector dataColumnVector; + final DictionaryProvider dictionaryProvider; - public CometListVector(ValueVector vector, boolean useDecimal128) { + public CometListVector( + ValueVector vector, boolean useDecimal128, DictionaryProvider dictionaryProvider) { super(vector, vector.getField(), useDecimal128); this.listVector = ((ListVector) vector); this.dataVector = listVector.getDataVector(); - this.dataColumnVector = getVector(dataVector, useDecimal128); + this.dictionaryProvider = dictionaryProvider; + this.dataColumnVector = getVector(dataVector, useDecimal128, dictionaryProvider); } @Override @@ -52,6 +56,6 @@ public CometVector slice(int offset, int length) { TransferPair tp = this.valueVector.getTransferPair(this.valueVector.getAllocator()); tp.splitAndTransfer(offset, length); - return new CometListVector(tp.getTo(), useDecimal128); + return new CometListVector(tp.getTo(), useDecimal128, dictionaryProvider); } } diff --git a/common/src/main/java/org/apache/comet/vector/CometVector.java b/common/src/main/java/org/apache/comet/vector/CometVector.java index 3a5ae4476..3b0ca35bf 100644 --- a/common/src/main/java/org/apache/comet/vector/CometVector.java +++ b/common/src/main/java/org/apache/comet/vector/CometVector.java @@ -239,7 +239,7 
@@ protected static CometVector getVector( } else if (vector instanceof MapVector) { return new CometMapVector(vector, useDecimal128, dictionaryProvider); } else if (vector instanceof ListVector) { - return new CometListVector(vector, useDecimal128); + return new CometListVector(vector, useDecimal128, dictionaryProvider); } else { DictionaryEncoding dictionaryEncoding = vector.getField().getDictionary(); CometPlainVector cometVector = new CometPlainVector(vector, useDecimal128); diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs index fe6ef9f7b..df1d1b5a9 100644 --- a/native/core/src/execution/datafusion/planner.rs +++ b/native/core/src/execution/datafusion/planner.rs @@ -49,6 +49,7 @@ use crate::{ serde::to_arrow_datatype, }, }; +use arrow::compute::CastOptions; use arrow_schema::{DataType, Field, Schema, TimeUnit, DECIMAL128_MAX_PRECISION}; use datafusion::functions_aggregate::bit_and_or_xor::{bit_and_udaf, bit_or_udaf, bit_xor_udaf}; use datafusion::functions_aggregate::min_max::max_udaf; @@ -1719,6 +1720,33 @@ impl PhysicalPlanner { } } + fn infer_data_type( + &self, + fun_name: &str, + arg_types: &[DataType], + ) -> Result { + let fun_name = match fun_name { + "read_side_padding" => "rpad", // use the same return type as rpad + other => other, + }; + Ok(self + .session_ctx + .udf(fun_name)? + .inner() + .return_type(arg_types)?) + } + + fn coerce_types(&self, fun_name: &str, arg_types: &[DataType]) -> Option> { + let fun_name = match fun_name { + "read_side_padding" => "rpad", // use the same return type as rpad + other => other, + }; + self.session_ctx + .udf(fun_name) + .ok() + .and_then(|func| func.coerce_types(arg_types).ok()) + } + fn create_scalar_function_expr( &self, expr: &ScalarFunc, @@ -1736,23 +1764,53 @@ impl PhysicalPlanner { .map(|x| x.data_type(input_schema.as_ref())) .collect::, _>>()?; - let data_type = match expr.return_type.as_ref().map(to_arrow_datatype) { - Some(t) => t, - None => { - let fun_name = match fun_name.as_str() { - "read_side_padding" => "rpad", // use the same return type as rpad - other => other, - }; - self.session_ctx - .udf(fun_name)? - .inner() - .return_type(&input_expr_types)? 
-            }
-        };
+        let (data_type, coerced_input_types) =
+            match expr.return_type.as_ref().map(to_arrow_datatype) {
+                Some(t) => (t, input_expr_types),
+                None => {
+                    let fun_name = match fun_name.as_ref() {
+                        "read_side_padding" => "rpad", // use the same return type as rpad
+                        other => other,
+                    };
+                    let func = self.session_ctx.udf(fun_name)?;
+
+                    let coerced_types = func
+                        .coerce_types(&input_expr_types)
+                        .unwrap_or(input_expr_types);
+
+                    let data_type = func.inner().return_type(&coerced_types)?;
+
+                    (data_type, coerced_types)
+                }
+            };
 
         let fun_expr =
             create_comet_physical_fun(fun_name, data_type.clone(), &self.session_ctx.state())?;
 
+        let args = args
+            .into_iter()
+            .zip(coerced_input_types)
+            .map(|(expr, to_type)| {
+                if !expr
+                    .data_type(input_schema.as_ref())
+                    .unwrap()
+                    .equals_datatype(&to_type)
+                {
+                    println!("Casting {:?} to {:?}", expr, to_type);
+                    Arc::new(CastExpr::new(
+                        expr,
+                        to_type,
+                        Some(CastOptions {
+                            safe: false,
+                            ..Default::default()
+                        }),
+                    ))
+                } else {
+                    expr
+                }
+            })
+            .collect::<Vec<_>>();
+
         let scalar_expr: Arc<dyn PhysicalExpr> = Arc::new(ScalarFunctionExpr::new(
             fun_name,
             fun_expr,
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 5ef924f6a..92c6490fd 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -2388,11 +2388,11 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
 
     // datafusion's make_array only supports nullable element types
     // https://github.com/apache/datafusion/issues/11923
-    case array @ CreateArray(children, _) if array.dataType.containsNull =>
+    case array @ CreateArray(children, _) =>
       val childExprs = children.map(exprToProto(_, inputs, binding))
 
       if (childExprs.forall(_.isDefined)) {
-        scalarExprToProtoWithReturnType("make_array", array.dataType, childExprs: _*)
+        scalarExprToProto("make_array", childExprs: _*)
       } else {
         withInfo(expr, "unsupported arguments for CreateArray", children: _*)
         None
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 09d7ca979..20d35dc78 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -2003,10 +2003,16 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     Seq(true, false).foreach { dictionaryEnabled =>
       withTempDir { dir =>
         val path = new Path(dir.toURI.toString, "test.parquet")
-        makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000)
+        makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 1000)
         val df = spark.read.parquet(path.toString)
         checkSparkAnswerAndOperator(df.select(array(col("_2"), col("_3"), col("_4"))))
         checkSparkAnswerAndOperator(df.select(array(col("_4"), col("_11"), lit(null))))
+        checkSparkAnswerAndOperator(df.select(array(array(col("_4")), array(col("_4"), lit(null)))))
+        checkSparkAnswerAndOperator(df.select(array(col("_8"), col("_13"))))
+        // TODO: Some part of this converts the null to an empty string
+        // checkSparkAnswerAndOperator(df.select(array(col("_8"), col("_13"), lit(null))))
+        checkSparkAnswerAndOperator(df.select(array(array(col("_8")), array(col("_13")))))
+        checkSparkAnswerAndOperator(df.select(array(col("_8"), col("_8"), lit(null))))
       }
     }
   }
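
The heart of the patch above: when serde does not pin a return type, the planner now asks the UDF for its coerced argument types and wraps any argument whose actual type differs in a cast before building the ScalarFunctionExpr. A minimal, self-contained Scala sketch of that coerce-then-cast shape (MiniType and coerceTypes are hypothetical stand-ins, not the real DataFusion API):

// Stand-in types for the sketch; not DataFusion's DataType.
sealed trait MiniType
case object MiniInt32 extends MiniType
case object MiniInt64 extends MiniType
case object MiniUtf8 extends MiniType

// Stand-in for ScalarUDF::coerce_types: promote 32-bit ints to 64-bit.
def coerceTypes(args: Seq[MiniType]): Seq[MiniType] = args.map {
  case MiniInt32 => MiniInt64
  case other => other
}

val input = Seq[MiniType](MiniInt32, MiniUtf8)
val coerced = coerceTypes(input)
// Mirrors the planner: only wrap arguments whose type actually changed.
input.zip(coerced).foreach {
  case (from, to) if from != to => println(s"wrap argument in cast: $from -> $to")
  case (from, _) => println(s"keep argument as $from")
}

From 7e396a2d54deadd464e26d6f6ccde8fd46af591b Mon Sep 17 00:00:00 2001
From: 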
Adam Binford Date: Wed, 21 Aug 2024 07:50:51 -0400 Subject: [PATCH 03/20] Formatting --- native/core/src/execution/datafusion/planner.rs | 1 - .../src/test/scala/org/apache/comet/CometExpressionSuite.scala | 3 ++- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs index df1d1b5a9..73b6b2a19 100644 --- a/native/core/src/execution/datafusion/planner.rs +++ b/native/core/src/execution/datafusion/planner.rs @@ -1796,7 +1796,6 @@ impl PhysicalPlanner { .unwrap() .equals_datatype(&to_type) { - println!("Casting {:?} to {:?}", expr, to_type); Arc::new(CastExpr::new( expr, to_type, diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 20d35dc78..a9d55cfab 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -2007,7 +2007,8 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { val df = spark.read.parquet(path.toString) checkSparkAnswerAndOperator(df.select(array(col("_2"), col("_3"), col("_4")))) checkSparkAnswerAndOperator(df.select(array(col("_4"), col("_11"), lit(null)))) - checkSparkAnswerAndOperator(df.select(array(array(col("_4")), array(col("_4"), lit(null))))) + checkSparkAnswerAndOperator( + df.select(array(array(col("_4")), array(col("_4"), lit(null))))) checkSparkAnswerAndOperator(df.select(array(col("_8"), col("_13")))) // TODO: Some part of this converts the null to an empty string // checkSparkAnswerAndOperator(df.select(array(col("_8"), col("_13"), lit(null)))) From 4f42fcbc3f75b41f85287439d03ef2a2c7318bf5 Mon Sep 17 00:00:00 2001 From: Adam Binford Date: Wed, 21 Aug 2024 17:17:28 +0000 Subject: [PATCH 04/20] Remove unused var --- .../main/scala/org/apache/comet/serde/QueryPlanSerde.scala | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 92c6490fd..16b7b6640 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -2386,9 +2386,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim .build() } - // datafusion's make_array only supports nullable element types - // https://github.com/apache/datafusion/issues/11923 - case array @ CreateArray(children, _) => + case CreateArray(children, _) => val childExprs = children.map(exprToProto(_, inputs, binding)) if (childExprs.forall(_.isDefined)) { From aecb44f3fd86fdeea371793aae5a66a7519caa48 Mon Sep 17 00:00:00 2001 From: Adam Binford Date: Thu, 22 Aug 2024 07:34:21 -0400 Subject: [PATCH 05/20] Cleanup planner --- .../core/src/execution/datafusion/planner.rs | 41 +++---------------- 1 file changed, 5 insertions(+), 36 deletions(-) diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs index 73b6b2a19..998d08b25 100644 --- a/native/core/src/execution/datafusion/planner.rs +++ b/native/core/src/execution/datafusion/planner.rs @@ -1720,33 +1720,6 @@ impl PhysicalPlanner { } } - fn infer_data_type( - &self, - fun_name: &str, - arg_types: &[DataType], - ) -> Result { - let fun_name = match fun_name { - "read_side_padding" => "rpad", // use the same return type as rpad - other => other, - 
};
-        Ok(self
-            .session_ctx
-            .udf(fun_name)?
-            .inner()
-            .return_type(arg_types)?)
-    }
-
-    fn coerce_types(&self, fun_name: &str, arg_types: &[DataType]) -> Option<Vec<DataType>> {
-        let fun_name = match fun_name {
-            "read_side_padding" => "rpad", // use the same return type as rpad
-            other => other,
-        };
-        self.session_ctx
-            .udf(fun_name)
-            .ok()
-            .and_then(|func| func.coerce_types(arg_types).ok())
-    }
-
     fn create_scalar_function_expr(
         &self,
         expr: &ScalarFunc,
@@ -1766,7 +1739,7 @@ impl PhysicalPlanner {
 
         let (data_type, coerced_input_types) =
             match expr.return_type.as_ref().map(to_arrow_datatype) {
-                Some(t) => (t, input_expr_types),
+                Some(t) => (t, input_expr_types.clone()),
                 None => {
                     let fun_name = match fun_name.as_ref() {
                         "read_side_padding" => "rpad", // use the same return type as rpad
                         other => other,
                     };
                     let func = self.session_ctx.udf(fun_name)?;
 
                     let coerced_types = func
                         .coerce_types(&input_expr_types)
-                        .unwrap_or(input_expr_types);
+                        .unwrap_or_else(|_| input_expr_types.clone());
 
                     let data_type = func.inner().return_type(&coerced_types)?;
 
                     (data_type, coerced_types)
@@ -1789,13 +1762,9 @@ impl PhysicalPlanner {
 
         let args = args
             .into_iter()
-            .zip(coerced_input_types)
-            .map(|(expr, to_type)| {
-                if !expr
-                    .data_type(input_schema.as_ref())
-                    .unwrap()
-                    .equals_datatype(&to_type)
-                {
+            .zip(input_expr_types.into_iter().zip(coerced_input_types))
+            .map(|(expr, (from_type, to_type))| {
+                if !from_type.equals_datatype(&to_type) {
                     Arc::new(CastExpr::new(
                         expr,
                         to_type,

From 8b392ee8d2ac208acdd502a74bcc0bac16d784dc Mon Sep 17 00:00:00 2001
From: Adam Binford
Date: Fri, 23 Aug 2024 18:11:38 +0000
Subject: [PATCH 06/20] Update comment for failing test

---
 .../src/test/scala/org/apache/comet/CometExpressionSuite.scala | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index a9d55cfab..8a4bf1267 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -2010,7 +2010,8 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
         checkSparkAnswerAndOperator(
           df.select(array(array(col("_4")), array(col("_4"), lit(null)))))
         checkSparkAnswerAndOperator(df.select(array(col("_8"), col("_13"))))
-        // TODO: Some part of this converts the null to an empty string
+        // This ends up returning empty strings instead of nulls for the last element
+        // Fixed by https://github.com/apache/datafusion/commit/27304239ef79b50a443320791755bf74eed4a85d
         // checkSparkAnswerAndOperator(df.select(array(col("_8"), col("_13"), lit(null))))
         checkSparkAnswerAndOperator(df.select(array(array(col("_8")), array(col("_13")))))
         checkSparkAnswerAndOperator(df.select(array(col("_8"), col("_8"), lit(null))))
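
The comment updated above tracks a real semantic gap: Spark's array() keeps a lit(null) argument as NULL, so producing an empty string there diverges from Spark. A minimal spark-shell style sketch of the behavior the disabled assertion guards (df is the all-types DataFrame from the test above):

// Spark semantics to preserve: lit(null) must stay NULL inside the array.
df.select(array(col("_8"), col("_13"), lit(null)).alias("arr"))
  .selectExpr("arr[2] IS NULL AS last_is_null")
  .show()
// Spark returns last_is_null = true for every row; the pre-fix native
// make_array produced "" instead, which is why the check is commented out.

From d7783ec00cc5edfc6b3f0f11dbe2d624d228d1f8 Mon Sep 17 00:00:00 2001
From: Adam Binford
Date: Sat, 24 Aug 2024 07:46:21 -0400
Subject: [PATCH 07/20] Add struct tests

---
 .../src/test/scala/org/apache/comet/CometExpressionSuite.scala | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 52997d1d0..9c735a3c7 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -2014,6 +2014,9 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
         // 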
checkSparkAnswerAndOperator(df.select(array(col("_8"), col("_13"), lit(null)))) checkSparkAnswerAndOperator(df.select(array(array(col("_8")), array(col("_13"))))) checkSparkAnswerAndOperator(df.select(array(col("_8"), col("_8"), lit(null)))) + checkSparkAnswerAndOperator(df.select(array(struct("_4"), struct("_4")))) + // Fixed by https://github.com/apache/datafusion/commit/140f7cec78febd73d3db537a816badaaf567530a + // checkSparkAnswerAndOperator(df.select(array(struct(col("_8").alias("a")), struct(col("_13").alias("a"))))) } } } From 1259267cbe04362a3167b01e778519f22157dcb1 Mon Sep 17 00:00:00 2001 From: Adam Binford Date: Sun, 25 Aug 2024 21:36:58 -0400 Subject: [PATCH 08/20] Support array extraction --- .../apache/spark/sql/comet/util/Utils.scala | 3 +- .../core/src/execution/datafusion/planner.rs | 22 +- native/proto/src/proto/expr.proto | 9 + native/spark-expr/src/array.rs | 313 ++++++++++++++++++ native/spark-expr/src/lib.rs | 2 + .../apache/comet/serde/QueryPlanSerde.scala | 60 +++- .../sql/comet/CometSparkToColumnarExec.scala | 2 +- .../apache/comet/CometExpressionSuite.scala | 94 +++--- 8 files changed, 445 insertions(+), 60 deletions(-) create mode 100644 native/spark-expr/src/array.rs diff --git a/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala b/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala index 8d6a63343..d92df66af 100644 --- a/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala +++ b/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala @@ -27,6 +27,7 @@ import scala.collection.JavaConverters._ import org.apache.arrow.c.CDataDictionaryProvider import org.apache.arrow.vector.{BigIntVector, BitVector, DateDayVector, DecimalVector, FieldVector, FixedSizeBinaryVector, Float4Vector, Float8Vector, IntVector, SmallIntVector, TimeStampMicroTZVector, TimeStampMicroVector, TinyIntVector, ValueVector, VarBinaryVector, VarCharVector, VectorSchemaRoot} +import org.apache.arrow.vector.complex.ListVector import org.apache.arrow.vector.complex.MapVector import org.apache.arrow.vector.complex.StructVector import org.apache.arrow.vector.dictionary.DictionaryProvider @@ -259,7 +260,7 @@ object Utils { case v @ (_: BitVector | _: TinyIntVector | _: SmallIntVector | _: IntVector | _: BigIntVector | _: Float4Vector | _: Float8Vector | _: VarCharVector | _: DecimalVector | _: DateDayVector | _: TimeStampMicroTZVector | _: VarBinaryVector | - _: FixedSizeBinaryVector | _: TimeStampMicroVector | _: StructVector) => + _: FixedSizeBinaryVector | _: TimeStampMicroVector | _: StructVector | _: ListVector) => v.asInstanceOf[FieldVector] case _ => throw new SparkException(s"Unsupported Arrow Vector for $reason: ${valueVector.getClass}") diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs index aa957e7ac..01fe8c2a1 100644 --- a/native/core/src/execution/datafusion/planner.rs +++ b/native/core/src/execution/datafusion/planner.rs @@ -95,8 +95,8 @@ use datafusion_comet_proto::{ spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning}, }; use datafusion_comet_spark_expr::{ - Cast, CreateNamedStruct, DateTruncExpr, GetStructField, HourExpr, IfExpr, MinuteExpr, RLike, - SecondExpr, TimestampTruncExpr, + ArrayExtract, Cast, CreateNamedStruct, DateTruncExpr, GetStructField, HourExpr, IfExpr, + MinuteExpr, RLike, SecondExpr, TimestampTruncExpr, }; use datafusion_common::scalar::ScalarStructBuilder; use datafusion_common::{ @@ -651,6 +651,24 @@ impl 
PhysicalPlanner { self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&input_schema))?; Ok(Arc::new(GetStructField::new(child, expr.ordinal as usize))) } + ExprStruct::ArrayExtract(expr) => { + let child = + self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&input_schema))?; + let ordinal = + self.create_expr(expr.ordinal.as_ref().unwrap(), Arc::clone(&input_schema))?; + let default_value = expr + .default_value + .as_ref() + .map(|e| self.create_expr(e, Arc::clone(&input_schema))) + .transpose()?; + Ok(Arc::new(ArrayExtract::new( + child, + ordinal, + default_value, + expr.one_based, + expr.fail_on_error, + ))) + } expr => Err(ExecutionError::GeneralError(format!( "Not implemented: {:?}", expr diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index 50ab8f514..3a695e710 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -79,6 +79,7 @@ message Expr { BloomFilterMightContain bloom_filter_might_contain = 52; CreateNamedStruct create_named_struct = 53; GetStructField get_struct_field = 54; + ArrayExtract array_extract = 55; } } @@ -498,6 +499,14 @@ message GetStructField { int32 ordinal = 2; } +message ArrayExtract { + Expr child = 1; + Expr ordinal = 2; + bool one_based = 3; + Expr default_value = 4; + bool fail_on_error = 5; +} + enum SortDirection { Ascending = 0; Descending = 1; diff --git a/native/spark-expr/src/array.rs b/native/spark-expr/src/array.rs new file mode 100644 index 000000000..b6a1dce46 --- /dev/null +++ b/native/spark-expr/src/array.rs @@ -0,0 +1,313 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::{array::MutableArrayData, datatypes::ArrowNativeType, record_batch::RecordBatch}; +use arrow_array::{Array, GenericListArray, Int32Array, OffsetSizeTrait}; +use arrow_schema::{DataType, FieldRef, Schema}; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::{ + cast::{as_int32_array, as_large_list_array, as_list_array}, + DataFusionError, Result as DataFusionResult, ScalarValue, +}; +use datafusion_physical_expr::PhysicalExpr; +use std::{ + any::Any, + fmt::{Display, Formatter}, + hash::{Hash, Hasher}, + sync::Arc, +}; + +use crate::utils::down_cast_any_ref; + +#[derive(Debug, Hash)] +pub struct ArrayExtract { + child: Arc, + ordinal: Arc, + default_value: Option>, + one_based: bool, + fail_on_error: bool, +} + +impl ArrayExtract { + pub fn new( + child: Arc, + ordinal: Arc, + default_value: Option>, + one_based: bool, + fail_on_error: bool, + ) -> Self { + Self { + child, + ordinal, + default_value, + one_based, + fail_on_error, + } + } + + fn child_field(&self, input_schema: &Schema) -> DataFusionResult { + match self.child.data_type(input_schema)? 
{ + DataType::List(field) | DataType::LargeList(field) => Ok(field), + data_type => Err(DataFusionError::Internal(format!( + "Unexpected data type in ArrayExtract: {:?}", + data_type + ))), + } + } + + // fn one_based_index(&self, index: i32, len: ) +} + +impl PhysicalExpr for ArrayExtract { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, input_schema: &Schema) -> DataFusionResult { + Ok(self.child_field(input_schema)?.data_type().clone()) + } + + fn nullable(&self, input_schema: &Schema) -> DataFusionResult { + // Only non-nullable if fail_on_error is enabled and the element is non-nullable + Ok(!self.fail_on_error || self.child_field(input_schema)?.is_nullable()) + } + + fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult { + let child_value = self.child.evaluate(batch)?.into_array(batch.num_rows())?; + let ordinal_value = self.ordinal.evaluate(batch)?.into_array(batch.num_rows())?; + + let default_value = self + .default_value + .as_ref() + .map(|d| { + d.evaluate(batch).map(|value| match value { + // TODO: Handle converting this to dictionary + ColumnarValue::Scalar(scalar) => scalar, + _ => todo!(), + }) + }) + .unwrap_or(self.data_type(&batch.schema())?.try_into())?; + + let adjust_index = if self.one_based { + one_based_index + } else { + zero_based_index + }; + + match child_value.data_type() { + DataType::List(_) => { + let list_array = as_list_array(&child_value)?; + let index_array = as_int32_array(&ordinal_value)?; + + array_extract( + list_array, + index_array, + &default_value, + self.fail_on_error, + adjust_index, + ) + } + DataType::LargeList(_) => { + let list_array = as_large_list_array(&child_value)?; + let index_array = as_int32_array(&ordinal_value)?; + + array_extract( + list_array, + index_array, + &default_value, + self.fail_on_error, + adjust_index, + ) + } + data_type => Err(DataFusionError::Internal(format!( + "Unexpected child type for ArrayExtract: {:?}", + data_type + ))), + } + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.child, &self.ordinal] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> datafusion_common::Result> { + Ok(Arc::new(ArrayExtract::new( + Arc::clone(&children[0]), + Arc::clone(&children[1]), + self.default_value.clone(), + self.one_based, + self.fail_on_error, + ))) + } + + fn dyn_hash(&self, state: &mut dyn Hasher) { + let mut s = state; + self.child.hash(&mut s); + self.ordinal.hash(&mut s); + self.one_based.hash(&mut s); + self.default_value.hash(&mut s); + self.fail_on_error.hash(&mut s); + self.hash(&mut s); + } +} + +fn one_based_index(index: i32, len: usize) -> DataFusionResult> { + if index == 0 { + return Err(DataFusionError::Execution( + "Invalid index of 0 for one-based ArrayExtract".to_string(), + )); + } + + let abs_index = index.abs().as_usize(); + if abs_index <= len { + if index > 0 { + Ok(Some(index.as_usize() - 1)) + } else { + Ok(Some(len - abs_index)) + } + } else { + Ok(None) + } +} + +fn zero_based_index(index: i32, len: usize) -> DataFusionResult> { + if index < 0 { + Ok(None) + } else { + let positive_index = index.as_usize(); + if positive_index < len { + Ok(Some(positive_index)) + } else { + Ok(None) + } + } +} + +fn array_extract( + list_array: &GenericListArray, + index_array: &Int32Array, + default_value: &ScalarValue, + fail_on_error: bool, + adjust_index: impl Fn(i32, usize) -> DataFusionResult>, +) -> DataFusionResult { + let values = list_array.values(); + let offsets = list_array.offsets(); + + let data = values.to_data(); + + let default_data = 
default_value.to_array()?.to_data(); + + let mut mutable = MutableArrayData::new(vec![&data, &default_data], true, index_array.len()); + + for (offset_window, index) in offsets.windows(2).zip(index_array.values()) { + let start = offset_window[0].as_usize(); + let len = offset_window[1].as_usize() - start; + + if let Some(i) = adjust_index(*index, len)? { + mutable.extend(0, start + i, start + i + 1); + } else if fail_on_error { + return Err(DataFusionError::Execution( + "Index out of bounds for array".to_string(), + )); + } else { + mutable.extend(1, 0, 1); + } + } + + let data = mutable.freeze(); + Ok(ColumnarValue::Array(arrow::array::make_array(data))) +} + +impl Display for ArrayExtract { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "ArrayExtract [child: {:?}, ordinal: {:?}, one_based: {:?}, default_value: {:?}, fail_on_error: {:?}]", + self.child, self.ordinal, self.one_based, self.default_value, self.fail_on_error + ) + } +} + +impl PartialEq for ArrayExtract { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| { + self.child.eq(&x.child) + && self.ordinal.eq(&x.ordinal) + && (self.default_value.is_none() == x.default_value.is_none()) + && self + .default_value + .as_ref() + .zip(x.default_value.as_ref()) + .map(|(s, x)| s.eq(x)) + .unwrap_or(true) + && self.one_based.eq(&x.one_based) + && self.fail_on_error.eq(&x.fail_on_error) + }) + .unwrap_or(false) + } +} + +// #[cfg(test)] +// mod test { +// use super::CreateNamedStruct; +// use arrow_array::{Array, DictionaryArray, Int32Array, RecordBatch, StringArray}; +// use arrow_schema::{DataType, Field, Schema}; +// use datafusion_common::Result; +// use datafusion_expr::ColumnarValue; +// use datafusion_physical_expr_common::expressions::column::Column; +// use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +// use std::sync::Arc; + +// #[test] +// fn test_create_struct_from_dict_encoded_i32() -> Result<()> { +// let keys = Int32Array::from(vec![0, 1, 2]); +// let values = Int32Array::from(vec![0, 111, 233]); +// let dict = DictionaryArray::try_new(keys, Arc::new(values))?; +// let data_type = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Int32)); +// let schema = Schema::new(vec![Field::new("a", data_type, false)]); +// let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(dict)])?; +// let field_names = vec!["a".to_string()]; +// let x = CreateNamedStruct::new(vec![Arc::new(Column::new("a", 0))], field_names); +// let ColumnarValue::Array(x) = x.evaluate(&batch)? else { +// unreachable!() +// }; +// assert_eq!(3, x.len()); +// Ok(()) +// } + +// #[test] +// fn test_create_struct_from_dict_encoded_string() -> Result<()> { +// let keys = Int32Array::from(vec![0, 1, 2]); +// let values = StringArray::from(vec!["a".to_string(), "b".to_string(), "c".to_string()]); +// let dict = DictionaryArray::try_new(keys, Arc::new(values))?; +// let data_type = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); +// let schema = Schema::new(vec![Field::new("a", data_type, false)]); +// let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(dict)])?; +// let field_names = vec!["a".to_string()]; +// let x = CreateNamedStruct::new(vec![Arc::new(Column::new("a", 0))], field_names); +// let ColumnarValue::Array(x) = x.evaluate(&batch)? 
else { +// unreachable!() +// }; +// assert_eq!(3, x.len()); +// Ok(()) +// } +// } diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs index 9fb16f94d..baf1a0ef0 100644 --- a/native/spark-expr/src/lib.rs +++ b/native/spark-expr/src/lib.rs @@ -23,6 +23,7 @@ mod cast; mod error; mod if_expr; +mod array; mod kernels; mod regexp; pub mod scalar_funcs; @@ -33,6 +34,7 @@ pub mod timezone; pub mod utils; mod xxhash64; +pub use array::ArrayExtract; pub use cast::{spark_cast, Cast}; pub use error::{SparkError, SparkResult}; pub use if_expr::IfExpr; diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 978970e8d..4e4755d00 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -59,7 +59,10 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim logWarning(s"Comet native execution is disabled due to: $reason") } - def supportedDataType(dt: DataType, allowStruct: Boolean = false): Boolean = dt match { + def supportedDataType( + dt: DataType, + allowStruct: Boolean = false, + allowArray: Boolean = false): Boolean = dt match { case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType | _: DecimalType | _: DateType | _: BooleanType | _: NullType => @@ -67,6 +70,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim case dt if isTimestampNTZType(dt) => true case s: StructType if allowStruct => s.fields.map(_.dataType).forall(supportedDataType(_, allowStruct)) + case a: ArrayType if allowArray => + supportedDataType(a.elementType, allowStruct, allowArray) case dt => emitWarning(s"unsupported Spark data type: $dt") false @@ -2399,33 +2404,52 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim None } - case get @ GetArrayItem(child, ordinal, _) => + case GetArrayItem(child, ordinal, failOnError) => val childExpr = exprToProto(child, inputs, binding) - - // DataFusion expects the indices to be int64 - val ordinalExpr = - exprToProto(Add(Cast(ordinal, LongType), Literal(1L)), inputs, binding) - // scalastyle:off println - println(ordinal.dataType) + val ordinalExpr = exprToProto(ordinal, inputs, binding) if (childExpr.isDefined && ordinalExpr.isDefined) { - scalarExprToProtoWithReturnType("array_element", get.dataType, childExpr, ordinalExpr) + val arrayExtractBuilder = ExprOuterClass.ArrayExtract + .newBuilder() + .setChild(childExpr.get) + .setOrdinal(ordinalExpr.get) + .setOneBased(false) + .setFailOnError(failOnError) + + Some( + ExprOuterClass.Expr + .newBuilder() + .setArrayExtract(arrayExtractBuilder) + .build()) } else { withInfo(expr, "unsupported arguments for GetArrayItem", child, ordinal) None } - case get @ ElementAt(child, ordinal, _, _) => + case ElementAt(child, ordinal, defaultValue, failOnError) + if child.dataType.isInstanceOf[ArrayType] => val childExpr = exprToProto(child, inputs, binding) + val ordinalExpr = exprToProto(ordinal, inputs, binding) + val defaultExpr = defaultValue.flatMap(exprToProto(_, inputs, binding)) - // DataFusion expects the indices to be int64 - val ordinalExpr = - exprToProto(Cast(ordinal, LongType), inputs, binding) + if (childExpr.isDefined && ordinalExpr.isDefined && + defaultExpr.isDefined == defaultValue.isDefined) { + val arrayExtractBuilder = ExprOuterClass.ArrayExtract + 
.newBuilder() + .setChild(childExpr.get) + .setOrdinal(ordinalExpr.get) + .setOneBased(true) + .setFailOnError(failOnError) - if (childExpr.isDefined && ordinalExpr.isDefined) { - scalarExprToProtoWithReturnType("array_element", get.dataType, childExpr, ordinalExpr) + defaultExpr.foreach(arrayExtractBuilder.setDefaultValue(_)) + + Some( + ExprOuterClass.Expr + .newBuilder() + .setArrayExtract(arrayExtractBuilder) + .build()) } else { - withInfo(expr, "unsupported arguments for GetArrayItem", child, ordinal) + withInfo(expr, "unsupported arguments for ElementAt", child, ordinal) None } @@ -2933,7 +2957,9 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim withInfo(join, "SortMergeJoin is not enabled") None - case op if isCometSink(op) && op.output.forall(a => supportedDataType(a.dataType, true)) => + case op + if isCometSink(op) && op.output.forall(a => + supportedDataType(a.dataType, true, true)) => // These operators are source of Comet native execution chain val scanBuilder = OperatorOuterClass.Scan.newBuilder() scanBuilder.setSource(op.simpleStringWithNodeId()) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometSparkToColumnarExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometSparkToColumnarExec.scala index d615846ce..686146f3b 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometSparkToColumnarExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometSparkToColumnarExec.scala @@ -99,7 +99,7 @@ case class CometSparkToColumnarExec(child: SparkPlan) object CometSparkToColumnarExec extends DataTypeSupport { override def isAdditionallySupported(dt: DataType): Boolean = dt match { - case _: StructType => true + case _: StructType | _: ArrayType => true case _ => false } } diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 13804f12a..0c4b2c191 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -27,6 +27,7 @@ import scala.util.Random import org.apache.hadoop.fs.Path import org.apache.spark.sql.{CometTestBase, DataFrame, Row} +import org.apache.spark.sql.catalyst.optimizer.SimplifyExtractValueOps import org.apache.spark.sql.comet.CometProjectExec import org.apache.spark.sql.execution.{ColumnarToRowExec, InputAdapter, WholeStageCodegenExec} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper @@ -2021,55 +2022,70 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } - test("GetArrayItem") { + test("ArrayExtract") { + def assertBothThrow(df: DataFrame): Unit = { + checkSparkMaybeThrows(df) match { + case (Some(_), Some(_)) => () + case (spark, comet) => fail(s"Expected Spark and Comet to throw exception, but got\nSpark: $spark\nComet: $comet") + } + } + Seq(true, false).foreach { dictionaryEnabled => withTempDir { dir => val path = new Path(dir.toURI.toString, "test.parquet") - makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000) + makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 100) - Seq(false).foreach { ansiEnabled => - withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString()) { - val df = spark.read - .parquet(path.toString) + Seq(true, false).foreach { ansiEnabled => + withSQLConf( + CometConf.COMET_ANSI_MODE_ENABLED.key -> "true", + SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString(), + // Prevent the 
optimizer from collapsing an extract value of a create array + SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> SimplifyExtractValueOps.ruleName) { + val df = spark.read.parquet(path.toString) - val stringArray = df.select(array(col("_8"), col("_13")).alias("arr")) - stringArray.show() + val stringArray = df.select(array(col("_8"), col("_8")).alias("arr")) checkSparkAnswerAndOperator( stringArray.select( - col("arr").getItem(-3), - col("arr").getItem(-4), + col("arr").getItem(0), + col("arr").getItem(1))) + + checkSparkAnswerAndOperator( + stringArray.select( + element_at(col("arr"), -2), + element_at(col("arr"), -1), + element_at(col("arr"), 1), + element_at(col("arr"), 2))) + + // 0 is an invalid index for element_at + assertBothThrow(stringArray.select(element_at(col("arr"), 0))) + + if (ansiEnabled) { + assertBothThrow(stringArray.select(col("arr").getItem(-1))) + assertBothThrow(stringArray.select(col("arr").getItem(2))) + assertBothThrow(stringArray.select(element_at(col("arr"), -3))) + assertBothThrow(stringArray.select(element_at(col("arr"), 3))) + } else { + checkSparkAnswerAndOperator(stringArray.select(col("arr").getItem(-1))) + checkSparkAnswerAndOperator(stringArray.select(col("arr").getItem(2))) + checkSparkAnswerAndOperator(stringArray.select(element_at(col("arr"), -3))) + checkSparkAnswerAndOperator(stringArray.select(element_at(col("arr"), 3))) + } + + val intArray = df.select(array(col("_4"), col("_4"), col("_4")).alias("arr")) + checkSparkAnswerAndOperator( + intArray.select( + col("arr").getItem(0), col("arr").getItem(1), col("arr").getItem(2))) - // stringArray.select( - // col("arr").getItem(-1), - // col("arr").getItem(-2), - // col("arr").getItem(1), - // col("arr").getItem(2)) - // .show() - - stringArray.select( - // element_at(col("arr"), -2), - // element_at(col("arr"), -1), - // element_at(col("arr"), 1), - element_at(col("arr"), lit(2))) - .show() - - // val intArray = df.select(array(col("_2"), col("_3"), col("_4")).alias("arr")) - // checkSparkAnswerAndOperator( - // intArray - // .select(col("arr").getItem(0), col("arr").getItem(1), col("arr").getItem(-1))) - - // intArray - // .select(col("arr").getItem(0), col("arr").getItem(1), col("arr").getItem(-1)) - // .explain() - - // checkSparkAnswerAndOperator( - // intArray.select( - // element_at(col("arr"), 1), - // element_at(col("arr"), 3), - // element_at(col("arr"), 4), - // element_at(col("arr"), -1))) + checkSparkAnswerAndOperator( + intArray.select( + element_at(col("arr"), 1), + element_at(col("arr"), 2), + element_at(col("arr"), 3), + element_at(col("arr"), -1), + element_at(col("arr"), -2), + element_at(col("arr"), -3))) } } } From a62103cadb50cae95be2ffa60ffe2968235295c0 Mon Sep 17 00:00:00 2001 From: Adam Binford Date: Sun, 25 Aug 2024 21:53:39 -0400 Subject: [PATCH 09/20] Formatting --- .../apache/comet/CometExpressionSuite.scala | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 0c4b2c191..5e7b072e5 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -2025,9 +2025,11 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { test("ArrayExtract") { def assertBothThrow(df: DataFrame): Unit = { checkSparkMaybeThrows(df) match { - case (Some(_), Some(_)) => () - case (spark, comet) => fail(s"Expected Spark and Comet 
to throw exception, but got\nSpark: $spark\nComet: $comet") - } + case (Some(_), Some(_)) => () + case (spark, comet) => + fail( + s"Expected Spark and Comet to throw exception, but got\nSpark: $spark\nComet: $comet") + } } Seq(true, false).foreach { dictionaryEnabled => @@ -2045,9 +2047,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { val stringArray = df.select(array(col("_8"), col("_8")).alias("arr")) checkSparkAnswerAndOperator( - stringArray.select( - col("arr").getItem(0), - col("arr").getItem(1))) + stringArray.select(col("arr").getItem(0), col("arr").getItem(1))) checkSparkAnswerAndOperator( stringArray.select( @@ -2073,10 +2073,8 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { val intArray = df.select(array(col("_4"), col("_4"), col("_4")).alias("arr")) checkSparkAnswerAndOperator( - intArray.select( - col("arr").getItem(0), - col("arr").getItem(1), - col("arr").getItem(2))) + intArray + .select(col("arr").getItem(0), col("arr").getItem(1), col("arr").getItem(2))) checkSparkAnswerAndOperator( intArray.select( From e8639a65e041a83db33f40f6dd43595646d8609a Mon Sep 17 00:00:00 2001 From: Adam Binford Date: Wed, 28 Aug 2024 20:48:40 -0400 Subject: [PATCH 10/20] Remove array support that isn't needed yet --- .../scala/org/apache/spark/sql/comet/util/Utils.scala | 2 +- native/spark-expr/src/array.rs | 2 -- .../scala/org/apache/comet/serde/QueryPlanSerde.scala | 11 ++--------- 3 files changed, 3 insertions(+), 12 deletions(-) diff --git a/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala b/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala index d92df66af..061a48831 100644 --- a/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala +++ b/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala @@ -260,7 +260,7 @@ object Utils { case v @ (_: BitVector | _: TinyIntVector | _: SmallIntVector | _: IntVector | _: BigIntVector | _: Float4Vector | _: Float8Vector | _: VarCharVector | _: DecimalVector | _: DateDayVector | _: TimeStampMicroTZVector | _: VarBinaryVector | - _: FixedSizeBinaryVector | _: TimeStampMicroVector | _: StructVector | _: ListVector) => + _: FixedSizeBinaryVector | _: TimeStampMicroVector | _: StructVector) => v.asInstanceOf[FieldVector] case _ => throw new SparkException(s"Unsupported Arrow Vector for $reason: ${valueVector.getClass}") diff --git a/native/spark-expr/src/array.rs b/native/spark-expr/src/array.rs index b6a1dce46..7dee90225 100644 --- a/native/spark-expr/src/array.rs +++ b/native/spark-expr/src/array.rs @@ -68,8 +68,6 @@ impl ArrayExtract { ))), } } - - // fn one_based_index(&self, index: i32, len: ) } impl PhysicalExpr for ArrayExtract { diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 4e4755d00..d7aea3fbc 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -59,10 +59,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim logWarning(s"Comet native execution is disabled due to: $reason") } - def supportedDataType( - dt: DataType, - allowStruct: Boolean = false, - allowArray: Boolean = false): Boolean = dt match { + def supportedDataType(dt: DataType, allowStruct: Boolean = false): Boolean = dt match { case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType | _: DoubleType | _: StringType 
| _: BinaryType | _: TimestampType | _: DecimalType | _: DateType | _: BooleanType | _: NullType => @@ -70,8 +67,6 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim case dt if isTimestampNTZType(dt) => true case s: StructType if allowStruct => s.fields.map(_.dataType).forall(supportedDataType(_, allowStruct)) - case a: ArrayType if allowArray => - supportedDataType(a.elementType, allowStruct, allowArray) case dt => emitWarning(s"unsupported Spark data type: $dt") false @@ -2957,9 +2952,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim withInfo(join, "SortMergeJoin is not enabled") None - case op - if isCometSink(op) && op.output.forall(a => - supportedDataType(a.dataType, true, true)) => + case op if isCometSink(op) && op.output.forall(a => supportedDataType(a.dataType, true)) => // These operators are source of Comet native execution chain val scanBuilder = OperatorOuterClass.Scan.newBuilder() scanBuilder.setSource(op.simpleStringWithNodeId()) From bd2c516d4a5e2b4c8ca0e19a27448e5061c6a145 Mon Sep 17 00:00:00 2001 From: Adam Binford Date: Wed, 28 Aug 2024 20:54:17 -0400 Subject: [PATCH 11/20] Remove unused import --- .../src/main/scala/org/apache/spark/sql/comet/util/Utils.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala b/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala index 061a48831..8d6a63343 100644 --- a/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala +++ b/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala @@ -27,7 +27,6 @@ import scala.collection.JavaConverters._ import org.apache.arrow.c.CDataDictionaryProvider import org.apache.arrow.vector.{BigIntVector, BitVector, DateDayVector, DecimalVector, FieldVector, FixedSizeBinaryVector, Float4Vector, Float8Vector, IntVector, SmallIntVector, TimeStampMicroTZVector, TimeStampMicroVector, TinyIntVector, ValueVector, VarBinaryVector, VarCharVector, VectorSchemaRoot} -import org.apache.arrow.vector.complex.ListVector import org.apache.arrow.vector.complex.MapVector import org.apache.arrow.vector.complex.StructVector import org.apache.arrow.vector.dictionary.DictionaryProvider From 24dfa53d5acbbc7f2b536bda59592ffd4fd06864 Mon Sep 17 00:00:00 2001 From: Adam Binford Date: Thu, 29 Aug 2024 19:48:33 -0400 Subject: [PATCH 12/20] Cleanup --- native/proto/src/proto/expr.proto | 4 ++-- native/spark-expr/src/array.rs | 8 ++++---- .../scala/org/apache/comet/CometExpressionSuite.scala | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index 6c5136ad1..579627e6a 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -512,8 +512,8 @@ message GetStructField { message ArrayExtract { Expr child = 1; Expr ordinal = 2; - bool one_based = 3; - Expr default_value = 4; + Expr default_value = 3; + bool one_based = 4; bool fail_on_error = 5; } diff --git a/native/spark-expr/src/array.rs b/native/spark-expr/src/array.rs index 7dee90225..7f2f662c9 100644 --- a/native/spark-expr/src/array.rs +++ b/native/spark-expr/src/array.rs @@ -159,8 +159,8 @@ impl PhysicalExpr for ArrayExtract { let mut s = state; self.child.hash(&mut s); self.ordinal.hash(&mut s); - self.one_based.hash(&mut s); self.default_value.hash(&mut s); + self.one_based.hash(&mut s); self.fail_on_error.hash(&mut s); self.hash(&mut s); } @@ -176,7 +176,7 @@ fn one_based_index(index: i32, len: 
usize) -> DataFusionResult> { let abs_index = index.abs().as_usize(); if abs_index <= len { if index > 0 { - Ok(Some(index.as_usize() - 1)) + Ok(Some(abs_index - 1)) } else { Ok(Some(len - abs_index)) } @@ -237,8 +237,8 @@ impl Display for ArrayExtract { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!( f, - "ArrayExtract [child: {:?}, ordinal: {:?}, one_based: {:?}, default_value: {:?}, fail_on_error: {:?}]", - self.child, self.ordinal, self.one_based, self.default_value, self.fail_on_error + "ArrayExtract [child: {:?}, ordinal: {:?}, default_value: {:?}, one_based: {:?}, fail_on_error: {:?}]", + self.child, self.ordinal, self.default_value, self.one_based, self.fail_on_error ) } } diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 1d70884f6..70fc67277 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -2077,7 +2077,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { Seq(true, false).foreach { dictionaryEnabled => withTempDir { dir => val path = new Path(dir.toURI.toString, "test.parquet") - makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 1000) + makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000) val df = spark.read.parquet(path.toString) checkSparkAnswerAndOperator(df.select(array(col("_2"), col("_3"), col("_4")))) checkSparkAnswerAndOperator(df.select(array(col("_4"), col("_11"), lit(null)))) From ba04a452496062bf0bba13b7deda6da9e2b834a1 Mon Sep 17 00:00:00 2001 From: Adam Binford Date: Thu, 29 Aug 2024 21:00:47 -0400 Subject: [PATCH 13/20] Cast default value if needed --- native/spark-expr/src/array.rs | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/native/spark-expr/src/array.rs b/native/spark-expr/src/array.rs index 7f2f662c9..205c19366 100644 --- a/native/spark-expr/src/array.rs +++ b/native/spark-expr/src/array.rs @@ -93,11 +93,19 @@ impl PhysicalExpr for ArrayExtract { .as_ref() .map(|d| { d.evaluate(batch).map(|value| match value { - // TODO: Handle converting this to dictionary - ColumnarValue::Scalar(scalar) => scalar, - _ => todo!(), + ColumnarValue::Scalar(scalar) + if !scalar.data_type().equals_datatype(child_value.data_type()) => + { + scalar.cast_to(child_value.data_type()) + } + ColumnarValue::Scalar(scalar) => Ok(scalar), + v => Err(DataFusionError::Execution(format!( + "Expected scalar default value for ArrayExtract, got {:?}", + v + ))), }) }) + .transpose()? 
.unwrap_or(self.data_type(&batch.schema())?.try_into())?; let adjust_index = if self.one_based { @@ -212,7 +220,7 @@ fn array_extract( let default_data = default_value.to_array()?.to_data(); - let mut mutable = MutableArrayData::new(vec![&data, &default_data], true, index_array.len()); + let mut mutable = MutableArrayData::new(vec![&data, &default_data], false, index_array.len()); for (offset_window, index) in offsets.windows(2).zip(index_array.values()) { let start = offset_window[0].as_usize(); From dea566167d4191478dc99c69f57a02aaf3c14c32 Mon Sep 17 00:00:00 2001 From: Adam Binford Date: Fri, 30 Aug 2024 07:43:30 -0400 Subject: [PATCH 14/20] Rename and fix null array issue --- .../core/src/execution/datafusion/planner.rs | 6 ++-- native/proto/src/proto/expr.proto | 4 +-- native/spark-expr/src/lib.rs | 4 +-- native/spark-expr/src/{array.rs => list.rs} | 28 +++++++++-------- .../apache/comet/serde/QueryPlanSerde.scala | 8 ++--- .../apache/comet/CometExpressionSuite.scala | 30 +++++++++---------- 6 files changed, 41 insertions(+), 39 deletions(-) rename native/spark-expr/src/{array.rs => list.rs} (92%) diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs index 6f966eb77..832621d6a 100644 --- a/native/core/src/execution/datafusion/planner.rs +++ b/native/core/src/execution/datafusion/planner.rs @@ -95,7 +95,7 @@ use datafusion_comet_proto::{ spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning}, }; use datafusion_comet_spark_expr::{ - ArrayExtract, Cast, CreateNamedStruct, DateTruncExpr, GetStructField, HourExpr, IfExpr, + Cast, CreateNamedStruct, DateTruncExpr, GetStructField, HourExpr, IfExpr, ListExtract, MinuteExpr, RLike, SecondExpr, TimestampTruncExpr, ToJson, }; use datafusion_common::scalar::ScalarStructBuilder; @@ -660,7 +660,7 @@ impl PhysicalPlanner { let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?; Ok(Arc::new(ToJson::new(child, &expr.timezone))) } - ExprStruct::ArrayExtract(expr) => { + ExprStruct::ListExtract(expr) => { let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&input_schema))?; let ordinal = @@ -670,7 +670,7 @@ impl PhysicalPlanner { .as_ref() .map(|e| self.create_expr(e, Arc::clone(&input_schema))) .transpose()?; - Ok(Arc::new(ArrayExtract::new( + Ok(Arc::new(ListExtract::new( child, ordinal, default_value, diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index 579627e6a..88940f386 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -80,7 +80,7 @@ message Expr { CreateNamedStruct create_named_struct = 53; GetStructField get_struct_field = 54; ToJson to_json = 55; - ArrayExtract array_extract = 56; + ListExtract list_extract = 56; } } @@ -509,7 +509,7 @@ message GetStructField { int32 ordinal = 2; } -message ArrayExtract { +message ListExtract { Expr child = 1; Expr ordinal = 2; Expr default_value = 3; diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs index 73105585b..c4b1c99ba 100644 --- a/native/spark-expr/src/lib.rs +++ b/native/spark-expr/src/lib.rs @@ -23,8 +23,8 @@ mod cast; mod error; mod if_expr; -mod array; mod kernels; +mod list; mod regexp; pub mod scalar_funcs; pub mod spark_hash; @@ -35,10 +35,10 @@ mod to_json; pub mod utils; mod xxhash64; -pub use array::ArrayExtract; pub use cast::{spark_cast, Cast}; pub use error::{SparkError, SparkResult}; pub use if_expr::IfExpr; +pub use list::ListExtract; pub use regexp::RLike; pub 
From dea566167d4191478dc99c69f57a02aaf3c14c32 Mon Sep 17 00:00:00 2001
From: Adam Binford
Date: Fri, 30 Aug 2024 07:43:30 -0400
Subject: [PATCH 14/20] Rename and fix null array issue

---
 .../core/src/execution/datafusion/planner.rs  |  6 ++--
 native/proto/src/proto/expr.proto             |  4 +--
 native/spark-expr/src/lib.rs                  |  4 +--
 native/spark-expr/src/{array.rs => list.rs}   | 28 +++++++++--------
 .../apache/comet/serde/QueryPlanSerde.scala   |  8 ++---
 .../apache/comet/CometExpressionSuite.scala   | 30 +++++++++----------
 6 files changed, 41 insertions(+), 39 deletions(-)
 rename native/spark-expr/src/{array.rs => list.rs} (92%)

diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs
index 6f966eb77..832621d6a 100644
--- a/native/core/src/execution/datafusion/planner.rs
+++ b/native/core/src/execution/datafusion/planner.rs
@@ -95,7 +95,7 @@ use datafusion_comet_proto::{
     spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning},
 };
 use datafusion_comet_spark_expr::{
-    ArrayExtract, Cast, CreateNamedStruct, DateTruncExpr, GetStructField, HourExpr, IfExpr,
+    Cast, CreateNamedStruct, DateTruncExpr, GetStructField, HourExpr, IfExpr, ListExtract,
     MinuteExpr, RLike, SecondExpr, TimestampTruncExpr, ToJson,
 };
 use datafusion_common::scalar::ScalarStructBuilder;
@@ -660,7 +660,7 @@ impl PhysicalPlanner {
                 let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?;
                 Ok(Arc::new(ToJson::new(child, &expr.timezone)))
             }
-            ExprStruct::ArrayExtract(expr) => {
+            ExprStruct::ListExtract(expr) => {
                 let child =
                     self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&input_schema))?;
                 let ordinal =
@@ -670,7 +670,7 @@ impl PhysicalPlanner {
                     .as_ref()
                     .map(|e| self.create_expr(e, Arc::clone(&input_schema)))
                     .transpose()?;
-                Ok(Arc::new(ArrayExtract::new(
+                Ok(Arc::new(ListExtract::new(
                     child,
                     ordinal,
                     default_value,
diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto
index 579627e6a..88940f386 100644
--- a/native/proto/src/proto/expr.proto
+++ b/native/proto/src/proto/expr.proto
@@ -80,7 +80,7 @@ message Expr {
     CreateNamedStruct create_named_struct = 53;
     GetStructField get_struct_field = 54;
     ToJson to_json = 55;
-    ArrayExtract array_extract = 56;
+    ListExtract list_extract = 56;
   }
 }
 
@@ -509,7 +509,7 @@ message GetStructField {
   int32 ordinal = 2;
 }
 
-message ArrayExtract {
+message ListExtract {
   Expr child = 1;
   Expr ordinal = 2;
   Expr default_value = 3;
diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs
index 73105585b..c4b1c99ba 100644
--- a/native/spark-expr/src/lib.rs
+++ b/native/spark-expr/src/lib.rs
@@ -23,8 +23,8 @@ mod cast;
 mod error;
 mod if_expr;
 
-mod array;
 mod kernels;
+mod list;
 mod regexp;
 pub mod scalar_funcs;
 pub mod spark_hash;
@@ -35,10 +35,10 @@ mod to_json;
 pub mod utils;
 mod xxhash64;
 
-pub use array::ArrayExtract;
 pub use cast::{spark_cast, Cast};
 pub use error::{SparkError, SparkResult};
 pub use if_expr::IfExpr;
+pub use list::ListExtract;
 pub use regexp::RLike;
 pub use structs::{CreateNamedStruct, GetStructField};
 pub use temporal::{DateTruncExpr, HourExpr, MinuteExpr, SecondExpr, TimestampTruncExpr};
diff --git a/native/spark-expr/src/array.rs b/native/spark-expr/src/list.rs
similarity index 92%
rename from native/spark-expr/src/array.rs
rename to native/spark-expr/src/list.rs
index 205c19366..7149e9c35 100644
--- a/native/spark-expr/src/array.rs
+++ b/native/spark-expr/src/list.rs
@@ -34,7 +34,7 @@ use std::{
 use crate::utils::down_cast_any_ref;
 
 #[derive(Debug, Hash)]
-pub struct ArrayExtract {
+pub struct ListExtract {
     child: Arc<dyn PhysicalExpr>,
     ordinal: Arc<dyn PhysicalExpr>,
     default_value: Option<Arc<dyn PhysicalExpr>>,
@@ -42,7 +42,7 @@ pub struct ArrayExtract {
     fail_on_error: bool,
 }
 
-impl ArrayExtract {
+impl ListExtract {
     pub fn new(
         child: Arc<dyn PhysicalExpr>,
         ordinal: Arc<dyn PhysicalExpr>,
@@ -63,14 +63,14 @@ impl ArrayExtract {
         match self.child.data_type(input_schema)? {
             DataType::List(field) | DataType::LargeList(field) => Ok(field),
             data_type => Err(DataFusionError::Internal(format!(
-                "Unexpected data type in ArrayExtract: {:?}",
+                "Unexpected data type in ListExtract: {:?}",
                 data_type
             ))),
         }
     }
 }
 
-impl PhysicalExpr for ArrayExtract {
+impl PhysicalExpr for ListExtract {
     fn as_any(&self) -> &dyn Any {
         self
     }
@@ -100,7 +100,7 @@ impl PhysicalExpr for ArrayExtract {
                     }
                     ColumnarValue::Scalar(scalar) => Ok(scalar),
                     v => Err(DataFusionError::Execution(format!(
-                        "Expected scalar default value for ArrayExtract, got {:?}",
+                        "Expected scalar default value for ListExtract, got {:?}",
                         v
                     ))),
                 })
@@ -140,7 +140,7 @@ impl PhysicalExpr for ArrayExtract {
                 )
             }
             data_type => Err(DataFusionError::Internal(format!(
-                "Unexpected child type for ArrayExtract: {:?}",
+                "Unexpected child type for ListExtract: {:?}",
                 data_type
             ))),
         }
@@ -154,7 +154,7 @@ impl PhysicalExpr for ArrayExtract {
         self: Arc<Self>,
         children: Vec<Arc<dyn PhysicalExpr>>,
     ) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
-        Ok(Arc::new(ArrayExtract::new(
+        Ok(Arc::new(ListExtract::new(
             Arc::clone(&children[0]),
             Arc::clone(&children[1]),
             self.default_value.clone(),
@@ -177,7 +177,7 @@ impl PhysicalExpr for ArrayExtract {
 fn one_based_index(index: i32, len: usize) -> DataFusionResult<Option<usize>> {
     if index == 0 {
         return Err(DataFusionError::Execution(
-            "Invalid index of 0 for one-based ArrayExtract".to_string(),
+            "Invalid index of 0 for one-based ListExtract".to_string(),
         ));
     }
 
@@ -220,14 +220,16 @@ fn array_extract<O: OffsetSizeTrait>(
 
     let default_data = default_value.to_array()?.to_data();
 
-    let mut mutable = MutableArrayData::new(vec![&data, &default_data], false, index_array.len());
+    let mut mutable = MutableArrayData::new(vec![&data, &default_data], true, index_array.len());
 
-    for (offset_window, index) in offsets.windows(2).zip(index_array.values()) {
+    for (row, (offset_window, index)) in offsets.windows(2).zip(index_array.values()).enumerate() {
         let start = offset_window[0].as_usize();
         let len = offset_window[1].as_usize() - start;
 
         if let Some(i) = adjust_index(*index, len)? {
             mutable.extend(0, start + i, start + i + 1);
+        } else if list_array.is_null(row) {
+            mutable.extend_nulls(1);
         } else if fail_on_error {
             return Err(DataFusionError::Execution(
                 "Index out of bounds for array".to_string(),
@@ -241,17 +243,17 @@ fn array_extract<O: OffsetSizeTrait>(
     Ok(ColumnarValue::Array(arrow::array::make_array(data)))
 }
 
-impl Display for ArrayExtract {
+impl Display for ListExtract {
     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
         write!(
             f,
-            "ArrayExtract [child: {:?}, ordinal: {:?}, default_value: {:?}, one_based: {:?}, fail_on_error: {:?}]",
+            "ListExtract [child: {:?}, ordinal: {:?}, default_value: {:?}, one_based: {:?}, fail_on_error: {:?}]",
            self.child, self.ordinal, self.default_value, self.one_based, self.fail_on_error
         )
     }
 }
 
-impl PartialEq<dyn Any> for ArrayExtract {
+impl PartialEq<dyn Any> for ListExtract {
     fn eq(&self, other: &dyn Any) -> bool {
         down_cast_any_ref(other)
             .downcast_ref::<Self>()
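The null fix above leans on Arrow's `MutableArrayData`: source 0 is the flattened list values, source 1 the materialized default, and `use_nulls` is now `true` so genuinely null slots can be emitted. A reduced, self-contained sketch of that pattern, with made-up values:

```rust
use arrow::array::{make_array, Array, Int32Array, MutableArrayData};

fn main() {
    let values = Int32Array::from(vec![10, 20, 30]); // flattened list elements
    let default = Int32Array::from(vec![-1]); // one-element default array
    let (values_data, default_data) = (values.to_data(), default.to_data());

    // use_nulls = true keeps a null buffer so extend_nulls can emit
    // genuinely null slots, which the is_null(row) branch relies on
    let mut out = MutableArrayData::new(vec![&values_data, &default_data], true, 3);
    out.extend(0, 1, 2); // in-range index: copy values[1]
    out.extend_nulls(1); // the list row itself was null
    out.extend(1, 0, 1); // out of range without ANSI: copy the default

    let result = make_array(out.freeze());
    assert_eq!(result.len(), 3);
    assert!(result.is_null(1));
}
```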
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index dc0232f45..f3fb58aae 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -2465,7 +2465,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
         val ordinalExpr = exprToProto(ordinal, inputs, binding)
 
         if (childExpr.isDefined && ordinalExpr.isDefined) {
-          val arrayExtractBuilder = ExprOuterClass.ArrayExtract
+          val listExtractBuilder = ExprOuterClass.ListExtract
             .newBuilder()
             .setChild(childExpr.get)
             .setOrdinal(ordinalExpr.get)
@@ -2475,7 +2475,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
           Some(
             ExprOuterClass.Expr
               .newBuilder()
-              .setArrayExtract(arrayExtractBuilder)
+              .setListExtract(listExtractBuilder)
               .build())
         } else {
           withInfo(expr, "unsupported arguments for GetArrayItem", child, ordinal)
@@ -2490,7 +2490,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
         if (childExpr.isDefined && ordinalExpr.isDefined &&
           defaultExpr.isDefined == defaultValue.isDefined) {
-          val arrayExtractBuilder = ExprOuterClass.ArrayExtract
+          val listExtractBuilder = ExprOuterClass.ListExtract
             .newBuilder()
             .setChild(childExpr.get)
             .setOrdinal(ordinalExpr.get)
@@ -2502,7 +2502,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
           Some(
             ExprOuterClass.Expr
               .newBuilder()
-              .setArrayExtract(arrayExtractBuilder)
+              .setListExtract(listExtractBuilder)
               .build())
         } else {
           withInfo(expr, "unsupported arguments for ElementAt", child, ordinal)
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 70fc67277..f5016e13b 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -2096,7 +2096,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
-  test("ArrayExtract") {
+  test("ListExtract") {
     def assertBothThrow(df: DataFrame): Unit = {
       checkSparkMaybeThrows(df) match {
         case (Some(_), Some(_)) => ()
@@ -2119,45 +2119,45 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
           SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> SimplifyExtractValueOps.ruleName) {
           val df = spark.read.parquet(path.toString)
 
-          val stringArray = df.select(array(col("_8"), col("_8")).alias("arr"))
+          val stringArray = df.select(array(col("_8"), col("_8"), lit(null)).alias("arr"))
           checkSparkAnswerAndOperator(
-            stringArray.select(col("arr").getItem(0), col("arr").getItem(1)))
+            stringArray.select(col("arr").getItem(0), col("arr").getItem(1), col("arr").getItem(2)))
 
           checkSparkAnswerAndOperator(
             stringArray.select(
+              element_at(col("arr"), -3),
               element_at(col("arr"), -2),
               element_at(col("arr"), -1),
               element_at(col("arr"), 1),
-              element_at(col("arr"), 2)))
+              element_at(col("arr"), 2),
+              element_at(col("arr"), 3)))
 
           // 0 is an invalid index for element_at
           assertBothThrow(stringArray.select(element_at(col("arr"), 0)))
 
           if (ansiEnabled) {
             assertBothThrow(stringArray.select(col("arr").getItem(-1)))
-            assertBothThrow(stringArray.select(col("arr").getItem(2)))
-            assertBothThrow(stringArray.select(element_at(col("arr"), -3)))
-            assertBothThrow(stringArray.select(element_at(col("arr"), 3)))
+            assertBothThrow(stringArray.select(col("arr").getItem(3)))
+            assertBothThrow(stringArray.select(element_at(col("arr"), -4)))
+            assertBothThrow(stringArray.select(element_at(col("arr"), 4)))
           } else {
             checkSparkAnswerAndOperator(stringArray.select(col("arr").getItem(-1)))
-            checkSparkAnswerAndOperator(stringArray.select(col("arr").getItem(2)))
-            checkSparkAnswerAndOperator(stringArray.select(element_at(col("arr"), -3)))
-            checkSparkAnswerAndOperator(stringArray.select(element_at(col("arr"), 3)))
+            checkSparkAnswerAndOperator(stringArray.select(col("arr").getItem(3)))
+            checkSparkAnswerAndOperator(stringArray.select(element_at(col("arr"), -4)))
+            checkSparkAnswerAndOperator(stringArray.select(element_at(col("arr"), 4)))
           }
 
-          val intArray = df.select(array(col("_4"), col("_4"), col("_4")).alias("arr"))
+          val intArray = df.select(when(col("_4").isNotNull, array(col("_4"), col("_4"))).alias("arr"))
           checkSparkAnswerAndOperator(
             intArray
-              .select(col("arr").getItem(0), col("arr").getItem(1), col("arr").getItem(2)))
+              .select(col("arr").getItem(0), col("arr").getItem(1)))
 
           checkSparkAnswerAndOperator(
             intArray.select(
               element_at(col("arr"), 1),
               element_at(col("arr"), 2),
-              element_at(col("arr"), 3),
               element_at(col("arr"), -1),
-              element_at(col("arr"), -2),
-              element_at(col("arr"), -3)))
+              element_at(col("arr"), -2)))
         }
       }
     }
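One behavioral detail in this patch is worth spelling out: in the extraction loop, the new null-row check runs before the `fail_on_error` check, so ANSI mode only raises for genuinely out-of-range indices, while rows whose whole array is null (the `when(isNotNull, ...)` columns in the test above) still produce null. A schematic sketch of that decision order, with illustrative names rather than the real API:

```rust
// Sketch of the branch order in the extraction loop after this patch:
// a resolved index copies the element; an unresolved index first checks
// whether the list row itself is null, then whether ANSI applies, and
// only then falls back to the default value.
fn resolve(adjusted: Option<usize>, row_is_null: bool, fail_on_error: bool)
    -> Result<&'static str, String> {
    match adjusted {
        Some(_) => Ok("copy element"),
        None if row_is_null => Ok("extend_nulls"),
        None if fail_on_error => Err("Index out of bounds for array".to_string()),
        None => Ok("copy default value"),
    }
}

fn main() {
    assert_eq!(resolve(Some(0), false, true).unwrap(), "copy element");
    assert_eq!(resolve(None, true, true).unwrap(), "extend_nulls"); // null row wins over ANSI
    assert!(resolve(None, false, true).is_err()); // ANSI: out of range is an error
    assert_eq!(resolve(None, false, false).unwrap(), "copy default value");
}
```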
col("_8"), lit(null)).alias("arr")) checkSparkAnswerAndOperator( - stringArray.select(col("arr").getItem(0), col("arr").getItem(1))) + stringArray.select(col("arr").getItem(0), col("arr").getItem(1), col("arr").getItem(2))) checkSparkAnswerAndOperator( stringArray.select( + element_at(col("arr"), -3), element_at(col("arr"), -2), element_at(col("arr"), -1), element_at(col("arr"), 1), - element_at(col("arr"), 2))) + element_at(col("arr"), 2), + element_at(col("arr"), 3))) // 0 is an invalid index for element_at assertBothThrow(stringArray.select(element_at(col("arr"), 0))) if (ansiEnabled) { assertBothThrow(stringArray.select(col("arr").getItem(-1))) - assertBothThrow(stringArray.select(col("arr").getItem(2))) - assertBothThrow(stringArray.select(element_at(col("arr"), -3))) - assertBothThrow(stringArray.select(element_at(col("arr"), 3))) + assertBothThrow(stringArray.select(col("arr").getItem(3))) + assertBothThrow(stringArray.select(element_at(col("arr"), -4))) + assertBothThrow(stringArray.select(element_at(col("arr"), 4))) } else { checkSparkAnswerAndOperator(stringArray.select(col("arr").getItem(-1))) - checkSparkAnswerAndOperator(stringArray.select(col("arr").getItem(2))) - checkSparkAnswerAndOperator(stringArray.select(element_at(col("arr"), -3))) - checkSparkAnswerAndOperator(stringArray.select(element_at(col("arr"), 3))) + checkSparkAnswerAndOperator(stringArray.select(col("arr").getItem(3))) + checkSparkAnswerAndOperator(stringArray.select(element_at(col("arr"), -4))) + checkSparkAnswerAndOperator(stringArray.select(element_at(col("arr"), 4))) } - val intArray = df.select(array(col("_4"), col("_4"), col("_4")).alias("arr")) + val intArray = df.select(when(col("_4").isNotNull, array(col("_4"), col("_4"))).alias("arr")) checkSparkAnswerAndOperator( intArray - .select(col("arr").getItem(0), col("arr").getItem(1), col("arr").getItem(2))) + .select(col("arr").getItem(0), col("arr").getItem(1))) checkSparkAnswerAndOperator( intArray.select( element_at(col("arr"), 1), element_at(col("arr"), 2), - element_at(col("arr"), 3), element_at(col("arr"), -1), - element_at(col("arr"), -2), - element_at(col("arr"), -3))) + element_at(col("arr"), -2))) } } } From 8a198a810e15c0d1fdc436641f85ec7ebe454b77 Mon Sep 17 00:00:00 2001 From: Adam Binford Date: Fri, 30 Aug 2024 16:47:26 +0000 Subject: [PATCH 15/20] Formatting --- .../test/scala/org/apache/comet/CometExpressionSuite.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index f5016e13b..3701be5fb 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -2121,7 +2121,8 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { val stringArray = df.select(array(col("_8"), col("_8"), lit(null)).alias("arr")) checkSparkAnswerAndOperator( - stringArray.select(col("arr").getItem(0), col("arr").getItem(1), col("arr").getItem(2))) + stringArray + .select(col("arr").getItem(0), col("arr").getItem(1), col("arr").getItem(2))) checkSparkAnswerAndOperator( stringArray.select( @@ -2147,7 +2148,8 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { checkSparkAnswerAndOperator(stringArray.select(element_at(col("arr"), 4))) } - val intArray = df.select(when(col("_4").isNotNull, array(col("_4"), col("_4"))).alias("arr")) + val intArray = + 
df.select(when(col("_4").isNotNull, array(col("_4"), col("_4"))).alias("arr")) checkSparkAnswerAndOperator( intArray .select(col("arr").getItem(0), col("arr").getItem(1))) From 6a94ebaff822b8d1f67742ba355f976ac1db9ba7 Mon Sep 17 00:00:00 2001 From: Adam Binford Date: Fri, 30 Aug 2024 17:59:11 -0400 Subject: [PATCH 16/20] Rename function and add default test --- native/spark-expr/src/list.rs | 97 ++++++++++++++++++----------------- 1 file changed, 49 insertions(+), 48 deletions(-) diff --git a/native/spark-expr/src/list.rs b/native/spark-expr/src/list.rs index 7149e9c35..e339b8e67 100644 --- a/native/spark-expr/src/list.rs +++ b/native/spark-expr/src/list.rs @@ -119,7 +119,7 @@ impl PhysicalExpr for ListExtract { let list_array = as_list_array(&child_value)?; let index_array = as_int32_array(&ordinal_value)?; - array_extract( + list_extract( list_array, index_array, &default_value, @@ -131,7 +131,7 @@ impl PhysicalExpr for ListExtract { let list_array = as_large_list_array(&child_value)?; let index_array = as_int32_array(&ordinal_value)?; - array_extract( + list_extract( list_array, index_array, &default_value, @@ -206,7 +206,7 @@ fn zero_based_index(index: i32, len: usize) -> DataFusionResult> { } } -fn array_extract( +fn list_extract( list_array: &GenericListArray, index_array: &Int32Array, default_value: &ScalarValue, @@ -274,48 +274,49 @@ impl PartialEq for ListExtract { } } -// #[cfg(test)] -// mod test { -// use super::CreateNamedStruct; -// use arrow_array::{Array, DictionaryArray, Int32Array, RecordBatch, StringArray}; -// use arrow_schema::{DataType, Field, Schema}; -// use datafusion_common::Result; -// use datafusion_expr::ColumnarValue; -// use datafusion_physical_expr_common::expressions::column::Column; -// use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -// use std::sync::Arc; - -// #[test] -// fn test_create_struct_from_dict_encoded_i32() -> Result<()> { -// let keys = Int32Array::from(vec![0, 1, 2]); -// let values = Int32Array::from(vec![0, 111, 233]); -// let dict = DictionaryArray::try_new(keys, Arc::new(values))?; -// let data_type = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Int32)); -// let schema = Schema::new(vec![Field::new("a", data_type, false)]); -// let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(dict)])?; -// let field_names = vec!["a".to_string()]; -// let x = CreateNamedStruct::new(vec![Arc::new(Column::new("a", 0))], field_names); -// let ColumnarValue::Array(x) = x.evaluate(&batch)? else { -// unreachable!() -// }; -// assert_eq!(3, x.len()); -// Ok(()) -// } - -// #[test] -// fn test_create_struct_from_dict_encoded_string() -> Result<()> { -// let keys = Int32Array::from(vec![0, 1, 2]); -// let values = StringArray::from(vec!["a".to_string(), "b".to_string(), "c".to_string()]); -// let dict = DictionaryArray::try_new(keys, Arc::new(values))?; -// let data_type = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); -// let schema = Schema::new(vec![Field::new("a", data_type, false)]); -// let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(dict)])?; -// let field_names = vec!["a".to_string()]; -// let x = CreateNamedStruct::new(vec![Arc::new(Column::new("a", 0))], field_names); -// let ColumnarValue::Array(x) = x.evaluate(&batch)? 
From 1c7e5c03ba419b06d36c994c1fa1338aed654107 Mon Sep 17 00:00:00 2001
From: Adam Binford
Date: Fri, 30 Aug 2024 18:05:10 -0400
Subject: [PATCH 17/20] Update test

---
 native/spark-expr/src/list.rs | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/native/spark-expr/src/list.rs b/native/spark-expr/src/list.rs
index e339b8e67..2812a8abd 100644
--- a/native/spark-expr/src/list.rs
+++ b/native/spark-expr/src/list.rs
@@ -295,7 +295,7 @@ mod test {
         let null_default = ScalarValue::Int32(None);
 
         let ColumnarValue::Array(result) =
-            list_extract(&list, &indices, &null_default, false, zero_based_index).unwrap()
+            list_extract(&list, &indices, &null_default, false, zero_based_index)?
         else {
             unreachable!()
         };
@@ -308,7 +308,7 @@ mod test {
         let zero_default = ScalarValue::Int32(Some(0));
 
         let ColumnarValue::Array(result) =
-            list_extract(&list, &indices, &zero_default, false, zero_based_index).unwrap()
+            list_extract(&list, &indices, &zero_default, false, zero_based_index)?
         else {
             unreachable!()
         };
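Because the test already returns `Result<()>`, `?` lets a failure surface as an `Err` carrying the `DataFusionError` instead of panicking mid-test. A minimal illustration of the pattern:

```rust
use datafusion_common::Result;

// With a Result-returning test, `?` propagates the error cleanly where
// `.unwrap()` would panic and discard the error's context.
#[test]
fn propagates_errors() -> Result<()> {
    let value: Result<i32> = Ok(41);
    assert_eq!(value? + 1, 42); // an Err here fails the test with its message
    Ok(())
}
```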
From 278b691ee1dfa9d3954d0f15d3c73a00fc6b277b Mon Sep 17 00:00:00 2001
From: Adam Binford
Date: Fri, 30 Aug 2024 18:16:50 -0400
Subject: [PATCH 18/20] Update supported expressions

---
 docs/source/user-guide/expressions.md | 12 +++++++-----
 docs/spark_expressions_support.md     |  5 +++--
 2 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md
index c2b372690..3a179054c 100644
--- a/docs/source/user-guide/expressions.md
+++ b/docs/source/user-guide/expressions.md
@@ -184,11 +184,13 @@ The following Spark expressions are currently available. Any known compatibility
 
 ## Complex Types
 
-| Expression        | Notes |
-| ----------------- | ----- |
-| CreateNamedStruct |       |
-| GetElementAt      |       |
-| StructsToJson     |       |
+| Expression        | Notes       |
+| ----------------- | ----------- |
+| CreateNamedStruct |             |
+| ElementAt         | Arrays only |
+| GetArrayItem      |             |
+| GetStructField    |             |
+| StructsToJson     |             |
 
 ## Other
 
diff --git a/docs/spark_expressions_support.md b/docs/spark_expressions_support.md
index 8fb975862..297b15d2c 100644
--- a/docs/spark_expressions_support.md
+++ b/docs/spark_expressions_support.md
@@ -80,7 +80,7 @@
 - [x] variance
 
 ### array_funcs
-- [ ] array
+- [x] array
 - [ ] array_append
 - [ ] array_compact
 - [ ] array_contains
@@ -97,8 +97,9 @@
 - [ ] array_union
 - [ ] arrays_overlap
 - [ ] arrays_zip
+- [x] element_at
 - [ ] flatten
-- [ ] get
+- [x] get
 - [ ] sequence
 - [ ] shuffle
 - [ ] slice

From d1c3c7d948c676132570405aa08acd46d2a7cb7c Mon Sep 17 00:00:00 2001
From: Adam Binford
Date: Tue, 3 Sep 2024 14:39:18 +0000
Subject: [PATCH 19/20] Remove array support in CometSparkToColumnarExec

---
 .../org/apache/spark/sql/comet/CometSparkToColumnarExec.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometSparkToColumnarExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometSparkToColumnarExec.scala
index 686146f3b..d615846ce 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometSparkToColumnarExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometSparkToColumnarExec.scala
@@ -99,7 +99,7 @@ case class CometSparkToColumnarExec(child: SparkPlan)
 
 object CometSparkToColumnarExec extends DataTypeSupport {
   override def isAdditionallySupported(dt: DataType): Boolean = dt match {
-    case _: StructType | _: ArrayType => true
+    case _: StructType => true
     case _ => false
   }
 }

From 28743be373e265238090986d1ff4c86b685b0cf7 Mon Sep 17 00:00:00 2001
From: Adam Binford
Date: Wed, 4 Sep 2024 14:51:58 +0000
Subject: [PATCH 20/20] Add length check to with_new_children

---
 native/spark-expr/src/list.rs | 19 +++++++++++--------
 1 file changed, 11 insertions(+), 8 deletions(-)

diff --git a/native/spark-expr/src/list.rs b/native/spark-expr/src/list.rs
index 2812a8abd..ec953b66e 100644
--- a/native/spark-expr/src/list.rs
+++ b/native/spark-expr/src/list.rs
@@ -21,7 +21,7 @@ use arrow_schema::{DataType, FieldRef, Schema};
 use datafusion::logical_expr::ColumnarValue;
 use datafusion_common::{
     cast::{as_int32_array, as_large_list_array, as_list_array},
-    DataFusionError, Result as DataFusionResult, ScalarValue,
+    internal_err, DataFusionError, Result as DataFusionResult, ScalarValue,
 };
 use datafusion_physical_expr::PhysicalExpr;
 use std::{
@@ -154,13 +154,16 @@ impl PhysicalExpr for ListExtract {
         self: Arc<Self>,
         children: Vec<Arc<dyn PhysicalExpr>>,
     ) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
-        Ok(Arc::new(ListExtract::new(
-            Arc::clone(&children[0]),
-            Arc::clone(&children[1]),
-            self.default_value.clone(),
-            self.one_based,
-            self.fail_on_error,
-        )))
+        match children.len() {
+            2 => Ok(Arc::new(ListExtract::new(
+                Arc::clone(&children[0]),
+                Arc::clone(&children[1]),
+                self.default_value.clone(),
+                self.one_based,
+                self.fail_on_error,
+            ))),
+            _ => internal_err!("ListExtract should have exactly two children"),
+        }
     }
 
     fn dyn_hash(&self, state: &mut dyn Hasher) {