feat: Implement Spark-compatible CAST from floating-point/double to decimal #384

Merged 11 commits on May 9, 2024
11 changes: 11 additions & 0 deletions core/src/errors.rs
@@ -72,6 +72,13 @@ pub enum CometError {
        to_type: String,
    },

    #[error("[NUMERIC_VALUE_OUT_OF_RANGE] {value} cannot be represented as Decimal({precision}, {scale}). If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error, and return NULL instead.")]
    NumericValueOutOfRange {
        value: String,
        precision: u8,
        scale: i8,
    },
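Assuming `CometError` derives `thiserror::Error` (as the `#[error(...)]` attributes suggest), the derived `Display` impl interpolates the variant's fields into a Spark-style message. A minimal sketch with illustrative values:

```rust
let err = CometError::NumericValueOutOfRange {
    value: "1.23456789E9".to_string(),
    precision: 10,
    scale: 2,
};
// Rendered by the Display impl generated from the #[error(...)] attribute:
assert!(err.to_string().starts_with(
    "[NUMERIC_VALUE_OUT_OF_RANGE] 1.23456789E9 cannot be represented as Decimal(10, 2)"
));
```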

#[error("[CAST_OVERFLOW] The value {value} of the type \"{from_type}\" cannot be cast to \"{to_type}\" \
due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary \
set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
@@ -208,6 +215,10 @@ impl jni::errors::ToException for CometError {
class: "org/apache/spark/SparkException".to_string(),
msg: self.to_string(),
},
CometError::NumericValueOutOfRange { .. } => Exception {
class: "org/apache/spark/SparkException".to_string(),
msg: self.to_string(),
},
CometError::NumberIntFormat { source: s } => Exception {
class: "java/lang/NumberFormatException".to_string(),
msg: s.to_string(),
60 changes: 58 additions & 2 deletions core/src/execution/datafusion/expressions/cast.rs
@@ -25,7 +25,7 @@ use std::{
use crate::errors::{CometError, CometResult};
use arrow::{
    compute::{cast_with_options, CastOptions},
    datatypes::TimestampMicrosecondType,
    datatypes::{Decimal128Type, DecimalType, TimestampMicrosecondType},
    record_batch::RecordBatch,
    util::display::FormatOptions,
};
@@ -39,7 +39,7 @@ use chrono::{TimeZone, Timelike};
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::{internal_err, Result as DataFusionResult, ScalarValue};
use datafusion_physical_expr::PhysicalExpr;
use num::{traits::CheckedNeg, CheckedSub, Integer, Num};
use num::{traits::CheckedNeg, CheckedSub, Integer, Num, ToPrimitive};
use regex::Regex;

use crate::execution::datafusion::expressions::utils::{
@@ -332,6 +332,10 @@ impl Cast {
            (DataType::Float32, DataType::LargeUtf8) => {
                Self::spark_cast_float32_to_utf8::<i64>(&array, self.eval_mode)?
            }

            (DataType::Float64, DataType::Decimal128(precision, scale)) => {
                Self::cast_float64_to_decimal128(&array, *precision, *scale, self.eval_mode)?
            }
            _ => {
                // when we have no Spark-specific casting we delegate to DataFusion
                cast_with_options(&array, to_type, &CAST_OPTIONS)?
@@ -395,6 +399,58 @@ impl Cast {
        Ok(cast_array)
    }

    fn cast_float64_to_decimal128(
        array: &dyn Array,
        precision: u8,
        scale: i8,
        eval_mode: EvalMode,
    ) -> CometResult<ArrayRef> {
        let input = array.as_any().downcast_ref::<Float64Array>().unwrap();
        let mut cast_array = PrimitiveArray::<Decimal128Type>::builder(input.len());

        // A decimal with scale `s` stores `value * 10^s` as its unscaled i128.
        let mul = 10_f64.powi(scale as i32);

        for i in 0..input.len() {
            if input.is_null(i) {
                cast_array.append_null();
            } else {
                let input_value = input.value(i);
                let value = (input_value * mul).round().to_i128();

                match value {
                    Some(v) => {
                        if Decimal128Type::validate_decimal_precision(v, precision).is_err() {
                            if eval_mode == EvalMode::Ansi {
                                return Err(CometError::NumericValueOutOfRange {
                                    value: input_value.to_string(),
                                    precision,
                                    scale,
                                });
                            }
                            // legacy mode: overflow becomes NULL, matching Spark
                            cast_array.append_null();
                        } else {
                            cast_array.append_value(v);
                        }
                    }
                    None => {
                        if eval_mode == EvalMode::Ansi {
                            return Err(CometError::NumericValueOutOfRange {
                                value: input_value.to_string(),
                                precision,
                                scale,
                            });
                        } else {
                            cast_array.append_null();
                        }
                    }
                }
            }
        }

        let res = Arc::new(
            cast_array
                .with_precision_and_scale(precision, scale)?
                .finish(),
        ) as ArrayRef;
        Ok(res)
    }

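To make the scaling step concrete: a `Decimal128(p, s)` value is an unscaled `i128` equal to the logical value times `10^s`, which is why the function multiplies by `10^scale`, rounds, and only then validates the digit count against `p`. A standalone sketch with illustrative values, not tied to the code above:

```rust
let input_value = 3.14159_f64;
let (precision, scale) = (10_u8, 2_i8);

// Shift `scale` fractional digits into the integer part, then round
// half away from zero (the behavior of f64::round).
let mul = 10_f64.powi(scale as i32);
let unscaled = (input_value * mul).round() as i128;
assert_eq!(unscaled, 314); // stored as 314, rendered as 3.14

// decimal(10, 2) holds at most 10 significant digits, so the unscaled
// value must stay below 10^10 in magnitude to pass validation.
assert!(unscaled.abs() < 10_i128.pow(precision as u32));
```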
    fn spark_cast_float64_to_utf8<OffsetSize>(
        from: &dyn Array,
        _eval_mode: EvalMode,
@@ -233,7 +233,7 @@ object CometCast {

  private def canCastFromDouble(toType: DataType): SupportLevel = toType match {
    case DataTypes.BooleanType | DataTypes.FloatType => Compatible()
    case _: DecimalType => Incompatible(Some("No overflow check"))
    case _: DecimalType => Compatible()
    case _ => Unsupported
  }

3 changes: 1 addition & 2 deletions spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -402,8 +402,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
    castTest(generateDoubles(), DataTypes.FloatType)
  }

  ignore("cast DoubleType to DecimalType(10,2)") {
    // Comet should have failed with [NUMERIC_VALUE_OUT_OF_RANGE]
  test("cast DoubleType to DecimalType(10,2)") {
    castTest(generateDoubles(), DataTypes.createDecimalType(10, 2))
  }

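For intuition on what the re-enabled test now covers (hypothetical value, not taken from `generateDoubles()`): a double whose scaled form needs more than ten digits overflows `decimal(10,2)`, so the cast raises `NUMERIC_VALUE_OUT_OF_RANGE` under ANSI mode and yields `NULL` under the default mode, matching Spark on both paths. A rough sketch:

```rust
// 123456789.555 scaled by 10^2 rounds to roughly 1.23e10, an 11-digit
// unscaled value, which fails validation against precision 10.
let unscaled = (123456789.555_f64 * 100.0).round() as i128;
assert!(unscaled >= 10_i128.pow(10));
// EvalMode::Ansi -> Err(CometError::NumericValueOutOfRange { .. })
// otherwise      -> the cast appends NULL for this row
```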