Skip to content

Commit

Permalink
fix: Support type coercion for ScalarUDFs (apache#865)
Browse files Browse the repository at this point in the history
* Add type coercion to ScalarUDFs and support dictionaries in lists

* Formatting

* Remove unused var

* Cleanup planner

* Update comment for failing test

* Add struct tests

* Change back row count

(cherry picked from commit 7484588)
  • Loading branch information
Kimahriman authored and huaxingao committed Aug 29, 2024
1 parent 7c6372d commit 992b93c
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import org.apache.arrow.vector.*;
import org.apache.arrow.vector.complex.ListVector;
import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.util.TransferPair;
import org.apache.spark.sql.vectorized.ColumnVector;
import org.apache.spark.sql.vectorized.ColumnarArray;
Expand All @@ -30,13 +31,16 @@ public class CometListVector extends CometDecodedVector {
final ListVector listVector;
final ValueVector dataVector;
final ColumnVector dataColumnVector;
final DictionaryProvider dictionaryProvider;

public CometListVector(ValueVector vector, boolean useDecimal128) {
public CometListVector(
ValueVector vector, boolean useDecimal128, DictionaryProvider dictionaryProvider) {
super(vector, vector.getField(), useDecimal128);

this.listVector = ((ListVector) vector);
this.dataVector = listVector.getDataVector();
this.dataColumnVector = getVector(dataVector, useDecimal128);
this.dictionaryProvider = dictionaryProvider;
this.dataColumnVector = getVector(dataVector, useDecimal128, dictionaryProvider);
}

@Override
Expand All @@ -52,6 +56,6 @@ public CometVector slice(int offset, int length) {
TransferPair tp = this.valueVector.getTransferPair(this.valueVector.getAllocator());
tp.splitAndTransfer(offset, length);

return new CometListVector(tp.getTo(), useDecimal128);
return new CometListVector(tp.getTo(), useDecimal128, dictionaryProvider);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ protected static CometVector getVector(
} else if (vector instanceof MapVector) {
return new CometMapVector(vector, useDecimal128, dictionaryProvider);
} else if (vector instanceof ListVector) {
return new CometListVector(vector, useDecimal128);
return new CometListVector(vector, useDecimal128, dictionaryProvider);
} else {
DictionaryEncoding dictionaryEncoding = vector.getField().getDictionary();
CometPlainVector cometVector = new CometPlainVector(vector, useDecimal128);
Expand Down
52 changes: 39 additions & 13 deletions native/core/src/execution/datafusion/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ use crate::{
serde::to_arrow_datatype,
},
};
use arrow::compute::CastOptions;
use arrow_schema::{DataType, Field, Schema, TimeUnit, DECIMAL128_MAX_PRECISION};
use datafusion::functions_aggregate::bit_and_or_xor::{bit_and_udaf, bit_or_udaf, bit_xor_udaf};
use datafusion::functions_aggregate::min_max::max_udaf;
Expand Down Expand Up @@ -1777,23 +1778,48 @@ impl PhysicalPlanner {
.map(|x| x.data_type(input_schema.as_ref()))
.collect::<Result<Vec<_>, _>>()?;

let data_type = match expr.return_type.as_ref().map(to_arrow_datatype) {
Some(t) => t,
None => {
let fun_name = match fun_name.as_str() {
"read_side_padding" => "rpad", // use the same return type as rpad
other => other,
};
self.session_ctx
.udf(fun_name)?
.inner()
.return_type(&input_expr_types)?
}
};
let (data_type, coerced_input_types) =
match expr.return_type.as_ref().map(to_arrow_datatype) {
Some(t) => (t, input_expr_types.clone()),
None => {
let fun_name = match fun_name.as_ref() {
"read_side_padding" => "rpad", // use the same return type as rpad
other => other,
};
let func = self.session_ctx.udf(fun_name)?;

let coerced_types = func
.coerce_types(&input_expr_types)
.unwrap_or_else(|_| input_expr_types.clone());

let data_type = func.inner().return_type(&coerced_types)?;

(data_type, coerced_types)
}
};

let fun_expr =
create_comet_physical_fun(fun_name, data_type.clone(), &self.session_ctx.state())?;

let args = args
.into_iter()
.zip(input_expr_types.into_iter().zip(coerced_input_types))
.map(|(expr, (from_type, to_type))| {
if !from_type.equals_datatype(&to_type) {
Arc::new(CastExpr::new(
expr,
to_type,
Some(CastOptions {
safe: false,
..Default::default()
}),
))
} else {
expr
}
})
.collect::<Vec<_>>();

let scalar_expr: Arc<dyn PhysicalExpr> = Arc::new(ScalarFunctionExpr::new(
fun_name,
fun_expr,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2450,13 +2450,11 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
.build()
}

// datafusion's make_array only supports nullable element types
// https://github.com/apache/datafusion/issues/11923
case array @ CreateArray(children, _) if array.dataType.containsNull =>
case CreateArray(children, _) =>
val childExprs = children.map(exprToProto(_, inputs, binding))

if (childExprs.forall(_.isDefined)) {
scalarExprToProtoWithReturnType("make_array", array.dataType, childExprs: _*)
scalarExprToProto("make_array", childExprs: _*)
} else {
withInfo(expr, "unsupported arguments for CreateArray", children: _*)
None
Expand Down
11 changes: 11 additions & 0 deletions spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2079,6 +2079,17 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
val df = spark.read.parquet(path.toString)
checkSparkAnswerAndOperator(df.select(array(col("_2"), col("_3"), col("_4"))))
checkSparkAnswerAndOperator(df.select(array(col("_4"), col("_11"), lit(null))))
checkSparkAnswerAndOperator(
df.select(array(array(col("_4")), array(col("_4"), lit(null)))))
checkSparkAnswerAndOperator(df.select(array(col("_8"), col("_13"))))
// This ends up returning empty strings instead of nulls for the last element
// Fixed by https://github.com/apache/datafusion/commit/27304239ef79b50a443320791755bf74eed4a85d
// checkSparkAnswerAndOperator(df.select(array(col("_8"), col("_13"), lit(null))))
checkSparkAnswerAndOperator(df.select(array(array(col("_8")), array(col("_13")))))
checkSparkAnswerAndOperator(df.select(array(col("_8"), col("_8"), lit(null))))
checkSparkAnswerAndOperator(df.select(array(struct("_4"), struct("_4"))))
// Fixed by https://github.com/apache/datafusion/commit/140f7cec78febd73d3db537a816badaaf567530a
// checkSparkAnswerAndOperator(df.select(array(struct(col("_8").alias("a")), struct(col("_13").alias("a")))))
}
}
}
Expand Down

0 comments on commit 992b93c

Please sign in to comment.