Skip to content

Commit

Permalink
minor: remove prefix in type_coercion (#6283)
Browse files Browse the repository at this point in the history
  • Loading branch information
jackwener authored May 8, 2023
1 parent 4b041b5 commit 2e9beeb
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 61 deletions.
97 changes: 45 additions & 52 deletions datafusion/expr/src/type_coercion/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,10 +316,10 @@ pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<D
fn string_numeric_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
use arrow::datatypes::DataType::*;
match (lhs_type, rhs_type) {
(Utf8, _) if DataType::is_numeric(rhs_type) => Some(Utf8),
(LargeUtf8, _) if DataType::is_numeric(rhs_type) => Some(LargeUtf8),
(_, Utf8) if DataType::is_numeric(lhs_type) => Some(Utf8),
(_, LargeUtf8) if DataType::is_numeric(lhs_type) => Some(LargeUtf8),
(Utf8, _) if is_numeric(rhs_type) => Some(Utf8),
(LargeUtf8, _) if is_numeric(rhs_type) => Some(LargeUtf8),
(_, Utf8) if is_numeric(lhs_type) => Some(Utf8),
(_, LargeUtf8) if is_numeric(lhs_type) => Some(LargeUtf8),
_ => None,
}
}
Expand All @@ -344,7 +344,9 @@ fn comparison_binary_numeric_coercion(
// that the coercion does not lose information via truncation
match (lhs_type, rhs_type) {
// support decimal data type for comparison operation
(d1 @ Decimal128(_, _), d2 @ Decimal128(_, _)) => get_wider_decimal_type(d1, d2),
(Decimal128(_, _), Decimal128(_, _)) => {
get_wider_decimal_type(lhs_type, rhs_type)
}
(Decimal128(_, _), _) => get_comparison_common_decimal_type(lhs_type, rhs_type),
(_, Decimal128(_, _)) => get_comparison_common_decimal_type(rhs_type, lhs_type),
(Float64, _) | (_, Float64) => Some(Float64),
Expand Down Expand Up @@ -390,23 +392,22 @@ fn get_comparison_common_decimal_type(
decimal_type: &DataType,
other_type: &DataType,
) -> Option<DataType> {
use arrow::datatypes::DataType::*;
let other_decimal_type = &match other_type {
// This conversion rule is from spark
// https://github.com/apache/spark/blob/1c81ad20296d34f137238dadd67cc6ae405944eb/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala#L127
DataType::Int8 => DataType::Decimal128(3, 0),
DataType::Int16 => DataType::Decimal128(5, 0),
DataType::Int32 => DataType::Decimal128(10, 0),
DataType::Int64 => DataType::Decimal128(20, 0),
DataType::Float32 => DataType::Decimal128(14, 7),
DataType::Float64 => DataType::Decimal128(30, 15),
Int8 => Decimal128(3, 0),
Int16 => Decimal128(5, 0),
Int32 => Decimal128(10, 0),
Int64 => Decimal128(20, 0),
Float32 => Decimal128(14, 7),
Float64 => Decimal128(30, 15),
_ => {
return None;
}
};
match (decimal_type, &other_decimal_type) {
(d1 @ DataType::Decimal128(_, _), d2 @ DataType::Decimal128(_, _)) => {
get_wider_decimal_type(d1, d2)
}
(d1 @ Decimal128(_, _), d2 @ Decimal128(_, _)) => get_wider_decimal_type(d1, d2),
_ => None,
}
}
Expand All @@ -433,14 +434,15 @@ fn get_wider_decimal_type(
/// Convert the numeric data type to the decimal data type.
/// Now, we just support the signed integer type and floating-point type.
fn coerce_numeric_type_to_decimal(numeric_type: &DataType) -> Option<DataType> {
use arrow::datatypes::DataType::*;
match numeric_type {
DataType::Int8 => Some(DataType::Decimal128(3, 0)),
DataType::Int16 => Some(DataType::Decimal128(5, 0)),
DataType::Int32 => Some(DataType::Decimal128(10, 0)),
DataType::Int64 => Some(DataType::Decimal128(20, 0)),
Int8 => Some(Decimal128(3, 0)),
Int16 => Some(Decimal128(5, 0)),
Int32 => Some(Decimal128(10, 0)),
Int64 => Some(Decimal128(20, 0)),
// TODO if we convert the floating-point data to the decimal type, it maybe overflow.
DataType::Float32 => Some(DataType::Decimal128(14, 7)),
DataType::Float64 => Some(DataType::Decimal128(30, 15)),
Float32 => Some(Decimal128(14, 7)),
Float64 => Some(Decimal128(30, 15)),
_ => None,
}
}
Expand Down Expand Up @@ -605,48 +607,38 @@ pub fn decimal_op_mathematics_type(

/// Determine if at least of one of lhs and rhs is numeric, and the other must be NULL or numeric
fn both_numeric_or_null_and_numeric(lhs_type: &DataType, rhs_type: &DataType) -> bool {
use arrow::datatypes::DataType::*;
match (lhs_type, rhs_type) {
(_, DataType::Null) => is_numeric(lhs_type),
(DataType::Null, _) => is_numeric(rhs_type),
(
DataType::Dictionary(_, lhs_value_type),
DataType::Dictionary(_, rhs_value_type),
) => is_numeric(lhs_value_type) && is_numeric(rhs_value_type),
(DataType::Dictionary(_, value_type), _) => {
is_numeric(value_type) && is_numeric(rhs_type)
}
(_, DataType::Dictionary(_, value_type)) => {
is_numeric(lhs_type) && is_numeric(value_type)
(_, Null) => is_numeric(lhs_type),
(Null, _) => is_numeric(rhs_type),
(Dictionary(_, lhs_value_type), Dictionary(_, rhs_value_type)) => {
is_numeric(lhs_value_type) && is_numeric(rhs_value_type)
}
(Dictionary(_, value_type), _) => is_numeric(value_type) && is_numeric(rhs_type),
(_, Dictionary(_, value_type)) => is_numeric(lhs_type) && is_numeric(value_type),
_ => is_numeric(lhs_type) && is_numeric(rhs_type),
}
}

/// Determine if at least of one of lhs and rhs is decimal, and the other must be NULL or decimal
fn both_decimal(lhs_type: &DataType, rhs_type: &DataType) -> bool {
use arrow::datatypes::DataType::*;
match (lhs_type, rhs_type) {
(_, DataType::Null) => is_decimal(lhs_type),
(DataType::Null, _) => is_decimal(rhs_type),
(DataType::Decimal128(_, _), DataType::Decimal128(_, _)) => true,
(DataType::Dictionary(_, value_type), _) => {
is_decimal(value_type) && is_decimal(rhs_type)
}
(_, DataType::Dictionary(_, value_type)) => {
is_decimal(lhs_type) && is_decimal(value_type)
}
(_, Null) => is_decimal(lhs_type),
(Null, _) => is_decimal(rhs_type),
(Decimal128(_, _), Decimal128(_, _)) => true,
(Dictionary(_, value_type), _) => is_decimal(value_type) && is_decimal(rhs_type),
(_, Dictionary(_, value_type)) => is_decimal(lhs_type) && is_decimal(value_type),
_ => false,
}
}

/// Determine if at least of one of lhs and rhs is decimal
pub fn any_decimal(lhs_type: &DataType, rhs_type: &DataType) -> bool {
use arrow::datatypes::DataType::*;
match (lhs_type, rhs_type) {
(DataType::Dictionary(_, value_type), _) => {
is_decimal(value_type) || is_decimal(rhs_type)
}
(_, DataType::Dictionary(_, value_type)) => {
is_decimal(lhs_type) || is_decimal(value_type)
}
(Dictionary(_, value_type), _) => is_decimal(value_type) || is_decimal(rhs_type),
(_, Dictionary(_, value_type)) => is_decimal(lhs_type) || is_decimal(value_type),
(_, _) => is_decimal(lhs_type) || is_decimal(rhs_type),
}
}
Expand All @@ -661,21 +653,22 @@ fn dictionary_coercion(
rhs_type: &DataType,
preserve_dictionaries: bool,
) -> Option<DataType> {
use arrow::datatypes::DataType::*;
match (lhs_type, rhs_type) {
(
DataType::Dictionary(_lhs_index_type, lhs_value_type),
DataType::Dictionary(_rhs_index_type, rhs_value_type),
Dictionary(_lhs_index_type, lhs_value_type),
Dictionary(_rhs_index_type, rhs_value_type),
) => comparison_coercion(lhs_value_type, rhs_value_type),
(d @ DataType::Dictionary(_, value_type), other_type)
| (other_type, d @ DataType::Dictionary(_, value_type))
(d @ Dictionary(_, value_type), other_type)
| (other_type, d @ Dictionary(_, value_type))
if preserve_dictionaries && value_type.as_ref() == other_type =>
{
Some(d.clone())
}
(DataType::Dictionary(_index_type, value_type), _) => {
(Dictionary(_index_type, value_type), _) => {
comparison_coercion(value_type, rhs_type)
}
(_, DataType::Dictionary(_index_type, value_type)) => {
(_, Dictionary(_index_type, value_type)) => {
comparison_coercion(lhs_type, value_type)
}
_ => None,
Expand Down
5 changes: 2 additions & 3 deletions datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -562,9 +562,8 @@ fn coerce_window_frame(
// The above op will be rewrite to the binary op when creating the physical op.
fn get_casted_expr_for_bool_op(expr: &Expr, schema: &DFSchemaRef) -> Result<Expr> {
let left_type = expr.get_type(schema)?;
let right_type = DataType::Boolean;
let coerced_type = coerce_types(&left_type, &Operator::IsDistinctFrom, &right_type)?;
expr.clone().cast_to(&coerced_type, schema)
coerce_types(&left_type, &Operator::IsDistinctFrom, &DataType::Boolean)?;
expr.clone().cast_to(&DataType::Boolean, schema)
}

/// Returns `expressions` coerced to types compatible with
Expand Down
12 changes: 6 additions & 6 deletions datafusion/physical-expr/src/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2208,10 +2208,10 @@ mod tests {
]);
let a = $A_ARRAY::from($A_VEC);
let b = $B_ARRAY::from($B_VEC);
let result_type = coerce_types(&$A_TYPE, &$OP, &$B_TYPE)?;
let common_type = coerce_types(&$A_TYPE, &$OP, &$B_TYPE)?;

let left = try_cast(col("a", &schema)?, &schema, result_type.clone())?;
let right = try_cast(col("b", &schema)?, &schema, result_type)?;
let left = try_cast(col("a", &schema)?, &schema, common_type.clone())?;
let right = try_cast(col("b", &schema)?, &schema, common_type)?;

// verify that we can construct the expression
let expression = binary(left, $OP, right, &schema)?;
Expand Down Expand Up @@ -3687,10 +3687,10 @@ mod tests {
) -> Result<()> {
let left_type = left.data_type();
let right_type = right.data_type();
let result_type = coerce_types(left_type, &op, right_type)?;
let common_type = coerce_types(left_type, &op, right_type)?;

let left_expr = try_cast(col("a", schema)?, schema, result_type.clone())?;
let right_expr = try_cast(col("b", schema)?, schema, result_type)?;
let left_expr = try_cast(col("a", schema)?, schema, common_type.clone())?;
let right_expr = try_cast(col("b", schema)?, schema, common_type)?;
let arithmetic_op = binary_simple(left_expr, op, right_expr, schema);
let data: Vec<ArrayRef> = vec![left.clone(), right.clone()];
let batch = RecordBatch::try_new(schema.clone(), data)?;
Expand Down

0 comments on commit 2e9beeb

Please sign in to comment.