Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix some bugs in TypeCoercion rule #3407

Merged
merged 8 commits into from
Sep 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions datafusion/expr/src/binary_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,12 @@ pub fn binary_operator_data_type(

/// Coercion rules for all binary operators. Returns the output type
/// of applying `op` to an argument of `lhs_type` and `rhs_type`.
///
/// TODO this function is trying to serve two purposes at once; it determines the result type
/// of the binary operation and also determines how the inputs can be coerced but this
/// results in inconsistencies in some cases (particular around date + interval)
///
/// Tracking issue is https://github.com/apache/arrow-datafusion/issues/3419
pub fn coerce_types(
lhs_type: &DataType,
op: &Operator,
Expand Down Expand Up @@ -516,6 +522,8 @@ fn temporal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataTyp
use arrow::datatypes::DataType::*;
use arrow::datatypes::TimeUnit;
match (lhs_type, rhs_type) {
(Date64, Date32) => Some(Date64),
(Date32, Date64) => Some(Date64),
(Utf8, Date32) => Some(Date32),
(Date32, Utf8) => Some(Date32),
(Utf8, Date64) => Some(Date64),
Expand Down
117 changes: 94 additions & 23 deletions datafusion/optimizer/src/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@
//! Optimizer rule for type validation and coercion

use crate::{OptimizerConfig, OptimizerRule};
use datafusion_common::{DFSchema, DFSchemaRef, Result};
use datafusion_expr::binary_rule::coerce_types;
use arrow::datatypes::DataType;
use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result};
use datafusion_expr::binary_rule::{coerce_types, comparison_coercion};
use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion};
use datafusion_expr::logical_plan::builder::build_join_schema;
use datafusion_expr::logical_plan::JoinType;
use datafusion_expr::type_coercion::data_types;
use datafusion_expr::utils::from_plan;
use datafusion_expr::{Expr, LogicalPlan};
use datafusion_expr::{ExprSchemable, Signature};
use std::sync::Arc;

#[derive(Default)]
pub struct TypeCoercion {}
Expand Down Expand Up @@ -54,17 +54,19 @@ impl OptimizerRule for TypeCoercion {
.map(|p| self.optimize(p, optimizer_config))
.collect::<Result<Vec<_>>>()?;

let schema = match new_inputs.len() {
1 => new_inputs[0].schema().clone(),
2 => DFSchemaRef::new(build_join_schema(
new_inputs[0].schema(),
new_inputs[1].schema(),
&JoinType::Inner,
)?),
_ => DFSchemaRef::new(DFSchema::empty()),
};
// get schema representing all available input fields. This is used for data type
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

// resolution only, so order does not matter here
let schema = new_inputs.iter().map(|input| input.schema()).fold(
DFSchema::empty(),
|mut lhs, rhs| {
lhs.merge(rhs);
lhs
},
);

let mut expr_rewrite = TypeCoercionRewriter { schema };
let mut expr_rewrite = TypeCoercionRewriter {
schema: Arc::new(schema),
};

let new_expr = plan
.expressions()
Expand All @@ -87,14 +89,55 @@ impl ExprRewriter for TypeCoercionRewriter {

fn mutate(&mut self, expr: Expr) -> Result<Expr> {
match expr {
Expr::BinaryExpr { left, op, right } => {
Expr::BinaryExpr {
ref left,
op,
ref right,
} => {
let left_type = left.get_type(&self.schema)?;
let right_type = right.get_type(&self.schema)?;
let coerced_type = coerce_types(&left_type, &op, &right_type)?;
Ok(Expr::BinaryExpr {
left: Box::new(left.cast_to(&coerced_type, &self.schema)?),
op,
right: Box::new(right.cast_to(&coerced_type, &self.schema)?),
match (&left_type, &right_type) {
(
DataType::Date32 | DataType::Date64 | DataType::Timestamp(_, _),
&DataType::Interval(_),
) => {
// this is a workaround for https://github.com/apache/arrow-datafusion/issues/3419
Ok(expr.clone())
}
_ => {
let coerced_type = coerce_types(&left_type, &op, &right_type)?;
Ok(Expr::BinaryExpr {
left: Box::new(
left.clone().cast_to(&coerced_type, &self.schema)?,
),
op,
right: Box::new(
right.clone().cast_to(&coerced_type, &self.schema)?,
),
})
}
}
}
Expr::Between {
expr,
negated,
low,
high,
} => {
let expr_type = expr.get_type(&self.schema)?;
let low_type = low.get_type(&self.schema)?;
let coerced_type = comparison_coercion(&expr_type, &low_type)
.ok_or_else(|| {
DataFusionError::Internal(format!(
"Failed to coerce types {} and {} in BETWEEN expression",
expr_type, low_type
))
})?;
Ok(Expr::Between {
expr: Box::new(expr.cast_to(&coerced_type, &self.schema)?),
negated,
low: Box::new(low.cast_to(&coerced_type, &self.schema)?),
high: Box::new(high.cast_to(&coerced_type, &self.schema)?),
})
}
Expr::ScalarUDF { fun, args } => {
Expand Down Expand Up @@ -145,12 +188,12 @@ mod test {
use crate::type_coercion::TypeCoercion;
use crate::{OptimizerConfig, OptimizerRule};
use arrow::datatypes::DataType;
use datafusion_common::{DFSchema, Result};
use datafusion_common::{DFSchema, Result, ScalarValue};
use datafusion_expr::{
lit,
logical_plan::{EmptyRelation, Projection},
Expr, LogicalPlan, ReturnTypeFunction, ScalarFunctionImplementation, ScalarUDF,
Signature, Volatility,
Expr, LogicalPlan, Operator, ReturnTypeFunction, ScalarFunctionImplementation,
ScalarUDF, Signature, Volatility,
};
use std::sync::Arc;

Expand Down Expand Up @@ -244,6 +287,34 @@ mod test {
Ok(())
}

#[test]
fn binary_op_date32_add_interval() -> Result<()> {
//CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("386547056640")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❤️

let expr = Expr::BinaryExpr {
left: Box::new(Expr::Cast {
expr: Box::new(lit("1998-03-18")),
data_type: DataType::Date32,
}),
op: Operator::Plus,
right: Box::new(Expr::Literal(ScalarValue::IntervalDayTime(Some(
386547056640,
)))),
};
let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
schema: Arc::new(DFSchema::empty()),
}));
let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty, None)?);
let rule = TypeCoercion::new();
let mut config = OptimizerConfig::default();
let plan = rule.optimize(&plan, &mut config)?;
assert_eq!(
"Projection: CAST(Utf8(\"1998-03-18\") AS Date32) + IntervalDayTime(\"386547056640\")\n EmptyRelation",
&format!("{:?}", plan)
);
Ok(())
}

fn empty() -> Arc<LogicalPlan> {
Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
Expand Down
50 changes: 50 additions & 0 deletions datafusion/optimizer/tests/integration-test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,15 @@ use datafusion_optimizer::filter_null_join_keys::FilterNullJoinKeys;
use datafusion_optimizer::filter_push_down::FilterPushDown;
use datafusion_optimizer::limit_push_down::LimitPushDown;
use datafusion_optimizer::optimizer::Optimizer;
use datafusion_optimizer::pre_cast_lit_in_comparison::PreCastLitInComparisonExpressions;
use datafusion_optimizer::projection_push_down::ProjectionPushDown;
use datafusion_optimizer::reduce_outer_join::ReduceOuterJoin;
use datafusion_optimizer::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate;
use datafusion_optimizer::scalar_subquery_to_join::ScalarSubqueryToJoin;
use datafusion_optimizer::simplify_expressions::SimplifyExpressions;
use datafusion_optimizer::single_distinct_to_groupby::SingleDistinctToGroupBy;
use datafusion_optimizer::subquery_filter_to_join::SubqueryFilterToJoin;
use datafusion_optimizer::type_coercion::TypeCoercion;
use datafusion_optimizer::{OptimizerConfig, OptimizerRule};
use datafusion_sql::planner::{ContextProvider, SqlToRel};
use datafusion_sql::sqlparser::ast::Statement;
Expand All @@ -56,11 +58,56 @@ fn distribute_by() -> Result<()> {
Ok(())
}

#[test]
fn intersect() -> Result<()> {
let sql = "SELECT col_int32, col_utf8 FROM test \
INTERSECT SELECT col_int32, col_utf8 FROM test \
INTERSECT SELECT col_int32, col_utf8 FROM test";
let plan = test_sql(sql)?;
let expected =
"Semi Join: #test.col_int32 = #test.col_int32, #test.col_utf8 = #test.col_utf8\
\n Distinct:\
\n Semi Join: #test.col_int32 = #test.col_int32, #test.col_utf8 = #test.col_utf8\
\n Distinct:\
\n TableScan: test projection=[col_int32, col_utf8]\
\n TableScan: test projection=[col_int32, col_utf8]\
\n TableScan: test projection=[col_int32, col_utf8]";
assert_eq!(expected, format!("{:?}", plan));
Ok(())
}

#[test]
fn between_date32_plus_interval() -> Result<()> {
let sql = "SELECT count(1) FROM test \
WHERE col_date32 between '1998-03-18' AND cast('1998-03-18' as date) + INTERVAL '90 days'";
let plan = test_sql(sql)?;
let expected =
"Projection: #COUNT(UInt8(1))\n Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\
\n Filter: #test.col_date32 >= CAST(Utf8(\"1998-03-18\") AS Date32) AND #test.col_date32 <= Date32(\"10393\")\
\n TableScan: test projection=[col_date32]";
assert_eq!(expected, format!("{:?}", plan));
Ok(())
}

#[test]
fn between_date64_plus_interval() -> Result<()> {
let sql = "SELECT count(1) FROM test \
WHERE col_date64 between '1998-03-18' AND cast('1998-03-18' as date) + INTERVAL '90 days'";
let plan = test_sql(sql)?;
let expected =
"Projection: #COUNT(UInt8(1))\n Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\
\n Filter: #test.col_date64 >= CAST(Utf8(\"1998-03-18\") AS Date64) AND #test.col_date64 <= CAST(Date32(\"10393\") AS Date64)\
\n TableScan: test projection=[col_date64]";
assert_eq!(expected, format!("{:?}", plan));
Ok(())
}

fn test_sql(sql: &str) -> Result<LogicalPlan> {
let rules: Vec<Arc<dyn OptimizerRule + Sync + Send>> = vec![
// Simplify expressions first to maximize the chance
// of applying other optimizations
Arc::new(SimplifyExpressions::new()),
Arc::new(PreCastLitInComparisonExpressions::new()),
Arc::new(DecorrelateWhereExists::new()),
Arc::new(DecorrelateWhereIn::new()),
Arc::new(ScalarSubqueryToJoin::new()),
Expand All @@ -73,6 +120,7 @@ fn test_sql(sql: &str) -> Result<LogicalPlan> {
Arc::new(FilterNullJoinKeys::default()),
Arc::new(ReduceOuterJoin::new()),
Arc::new(FilterPushDown::new()),
Arc::new(TypeCoercion::new()),
Arc::new(LimitPushDown::new()),
Arc::new(SingleDistinctToGroupBy::new()),
];
Expand Down Expand Up @@ -107,6 +155,8 @@ impl ContextProvider for MySchemaProvider {
vec![
Field::new("col_int32", DataType::Int32, true),
Field::new("col_utf8", DataType::Utf8, true),
Field::new("col_date32", DataType::Date32, true),
Field::new("col_date64", DataType::Date64, true),
],
HashMap::new(),
);
Expand Down