Skip to content

Commit

Permalink
feat: separate get_result_type and coerce_type
Browse files Browse the repository at this point in the history
  • Loading branch information
jackwener committed May 4, 2023
1 parent 5d3802f commit 89b6e22
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 9 deletions.
5 changes: 1 addition & 4 deletions datafusion/core/src/physical_plan/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2053,10 +2053,7 @@ mod tests {
];
for case in cases {
let logical_plan = test_csv_scan().await?.project(vec![case.clone()]);
let message = format!(
"Expression {case:?} expected to error due to impossible coercion"
);
assert!(logical_plan.is_err(), "{}", message);
assert!(logical_plan.is_ok());
}
Ok(())
}
Expand Down
51 changes: 50 additions & 1 deletion datafusion/expr/src/type_coercion/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ pub fn binary_operator_data_type(
let result_type = if !any_decimal(lhs_type, rhs_type) {
// validate that it is possible to perform the operation on incoming types.
// (or the return datatype cannot be inferred)
coerce_types(lhs_type, op, rhs_type)?
get_result_type(lhs_type, op, rhs_type)?
} else {
let (coerced_lhs_type, coerced_rhs_type) =
math_decimal_coercion(lhs_type, rhs_type);
Expand Down Expand Up @@ -106,6 +106,55 @@ pub fn binary_operator_data_type(
}
}

pub fn get_result_type(
lhs_type: &DataType,
op: &Operator,
rhs_type: &DataType,
) -> Result<DataType> {
let result = match op {
Operator::And
| Operator::Or
| Operator::Eq
| Operator::NotEq
| Operator::Lt
| Operator::Gt
| Operator::GtEq
| Operator::LtEq
| Operator::IsDistinctFrom
| Operator::IsNotDistinctFrom => Some(DataType::Boolean),
Operator::Plus | Operator::Minus
if is_datetime(lhs_type)
|| is_datetime(rhs_type)
|| is_interval(lhs_type)
|| is_interval(rhs_type) =>
{
temporal_add_sub_coercion(lhs_type, rhs_type, op)
}
Operator::BitwiseAnd
| Operator::BitwiseOr
| Operator::BitwiseXor
| Operator::BitwiseShiftRight
| Operator::BitwiseShiftLeft
| Operator::Plus
| Operator::Minus
| Operator::Modulo
| Operator::Divide
| Operator::Multiply
| Operator::RegexMatch
| Operator::RegexIMatch
| Operator::RegexNotMatch
| Operator::RegexNotIMatch
| Operator::StringConcat => coerce_types(lhs_type, op, rhs_type).ok(),
};

match result {
None => Err(DataFusionError::Plan(format!(
"there isn't result type for {lhs_type:?} {op} {rhs_type:?}"
))),
Some(t) => Ok(t),
}
}

/// Coercion rules for all binary operators. Returns the output type
/// of applying `op` to an argument of `lhs_type` and `rhs_type`.
///
Expand Down
4 changes: 2 additions & 2 deletions datafusion/physical-expr/src/expressions/datetime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use arrow::datatypes::{DataType, Schema};
use arrow::record_batch::RecordBatch;

use datafusion_common::{DataFusionError, Result};
use datafusion_expr::type_coercion::binary::coerce_types;
use datafusion_expr::type_coercion::binary::get_result_type;
use datafusion_expr::{ColumnarValue, Operator};
use std::any::Any;
use std::fmt::{Display, Formatter};
Expand Down Expand Up @@ -77,7 +77,7 @@ impl PhysicalExpr for DateTimeIntervalExpr {
}

fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
coerce_types(
get_result_type(
&self.lhs.data_type(input_schema)?,
&Operator::Minus,
&self.rhs.data_type(input_schema)?,
Expand Down
4 changes: 2 additions & 2 deletions datafusion/physical-expr/src/intervals/cp_solver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use std::sync::Arc;

use arrow_schema::DataType;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::type_coercion::binary::coerce_types;
use datafusion_expr::type_coercion::binary::get_result_type;
use datafusion_expr::Operator;
use petgraph::graph::NodeIndex;
use petgraph::stable_graph::{DefaultIx, StableGraph};
Expand Down Expand Up @@ -260,7 +260,7 @@ fn comparison_operator_target(
op: &Operator,
right_datatype: &DataType,
) -> Result<Interval> {
let datatype = coerce_types(left_datatype, &Operator::Minus, right_datatype)?;
let datatype = get_result_type(left_datatype, &Operator::Minus, right_datatype)?;
let unbounded = IntervalBound::make_unbounded(&datatype)?;
let zero = ScalarValue::new_zero(&datatype)?;
Ok(match *op {
Expand Down

0 comments on commit 89b6e22

Please sign in to comment.