Skip to content

Commit

Permalink
Implementing math power function for SQL (#2324)
Browse files Browse the repository at this point in the history
* Implementing POWER function

* Delete pv.yaml

* Delete build-ballista-docker.sh

* Delete ballista.dockerfile

* Aligning with latest upstream changes

* Re-adding Docker files

* Formatting

* Leaving only 64bit types

* Adding tests, remove type conversion

* fix for cast

* Update functions.rs
  • Loading branch information
comphead authored Apr 28, 2022
1 parent e596236 commit c3c02cf
Show file tree
Hide file tree
Showing 12 changed files with 218 additions and 8 deletions.
2 changes: 1 addition & 1 deletion datafusion/core/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ pub use expr::{
count, count_distinct, create_udaf, create_udf, date_part, date_trunc, digest,
exists, exp, exprlist_to_fields, floor, in_list, in_subquery, initcap, left, length,
lit, lit_timestamp_nano, ln, log10, log2, lower, lpad, ltrim, max, md5, min,
not_exists, not_in_subquery, now, now_expr, nullif, octet_length, or, random,
not_exists, not_in_subquery, now, now_expr, nullif, octet_length, or, power, random,
regexp_match, regexp_replace, repeat, replace, reverse, right, round, rpad, rtrim,
scalar_subquery, sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt,
starts_with, strpos, substr, sum, tan, to_hex, to_timestamp_micros,
Expand Down
4 changes: 4 additions & 0 deletions datafusion/core/src/physical_plan/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,10 @@ pub fn create_physical_fun(
BuiltinScalarFunction::Sqrt => Arc::new(math_expressions::sqrt),
BuiltinScalarFunction::Tan => Arc::new(math_expressions::tan),
BuiltinScalarFunction::Trunc => Arc::new(math_expressions::trunc),
BuiltinScalarFunction::Power => {
Arc::new(|args| make_scalar_function(math_expressions::power)(args))
}

// string functions
BuiltinScalarFunction::Array => Arc::new(array_expressions::array),
BuiltinScalarFunction::Ascii => Arc::new(|args| match args[0].data_type() {
Expand Down
126 changes: 126 additions & 0 deletions datafusion/core/tests/sql/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -445,3 +445,129 @@ async fn case_builtin_math_expression() {
assert_batches_sorted_eq!(expected, &results);
}
}

/// End-to-end check of the SQL `power` function.
///
/// Registers an in-memory table with one nullable column per numeric type,
/// runs `power` over column/column and literal/literal argument pairs, and
/// verifies both the rendered values and the data type of every result column.
#[tokio::test]
async fn test_power() -> Result<()> {
    // NOTE(review): the column named "i32" is backed by 16-bit integers
    // (`DataType::Int16` / `Int16Array`). Either the name or the type looks
    // unintended -- confirm before relying on the name.
    let schema = Arc::new(Schema::new(vec![
        Field::new("i32", DataType::Int16, true),
        Field::new("i64", DataType::Int64, true),
        Field::new("f32", DataType::Float32, true),
        Field::new("f64", DataType::Float64, true),
    ]));

    // Every column holds: a small positive value, a larger positive value,
    // zero, a negative value and a NULL.
    let data = RecordBatch::try_new(
        schema.clone(),
        vec![
            Arc::new(Int16Array::from(vec![
                Some(2),
                Some(5),
                Some(0),
                Some(-14),
                None,
            ])),
            Arc::new(Int64Array::from(vec![
                Some(2),
                Some(5),
                Some(0),
                Some(-14),
                None,
            ])),
            Arc::new(Float32Array::from(vec![
                Some(1.0),
                Some(2.5),
                Some(0.0),
                Some(-14.5),
                None,
            ])),
            Arc::new(Float64Array::from(vec![
                Some(1.0),
                Some(2.5),
                Some(0.0),
                Some(-14.5),
                None,
            ])),
        ],
    )?;

    let table = MemTable::try_new(schema, vec![vec![data]])?;

    let ctx = SessionContext::new();
    ctx.register_table("test", Arc::new(table))?;
    let sql = r"SELECT power(i32, exp_i) as power_i32,
power(i64, exp_f) as power_i64,
power(f32, exp_i) as power_f32,
power(f64, exp_f) as power_f64,
power(2, 3) as power_int_scalar,
power(2.5, 3.0) as power_float_scalar
FROM (select test.*, 3 as exp_i, 3.0 as exp_f from test) a";
    let actual = execute_to_batches(&ctx, sql).await;

    // Each cell is padded to the width of its column so that every row is
    // exactly as long as the border lines; the comparison is textual.
    let expected = vec![
        "+-----------+-----------+-----------+-----------+------------------+--------------------+",
        "| power_i32 | power_i64 | power_f32 | power_f64 | power_int_scalar | power_float_scalar |",
        "+-----------+-----------+-----------+-----------+------------------+--------------------+",
        "| 8         | 8         | 1         | 1         | 8                | 15.625             |",
        "| 125       | 125       | 15.625    | 15.625    | 8                | 15.625             |",
        "| 0         | 0         | 0         | 0         | 8                | 15.625             |",
        "| -2744     | -2744     | -3048.625 | -3048.625 | 8                | 15.625             |",
        "|           |           |           |           | 8                | 15.625             |",
        "+-----------+-----------+-----------+-----------+------------------+--------------------+",
    ];
    assert_batches_eq!(expected, &actual);

    // Result types: only the integer/integer calls stay Int64; every call
    // involving a float argument is expected to produce Float64.
    let result_schema = actual[0].schema();
    for (column, expected_type) in [
        ("power_i32", DataType::Int64),
        ("power_i64", DataType::Float64),
        ("power_f32", DataType::Float64),
        ("power_f64", DataType::Float64),
        ("power_int_scalar", DataType::Int64),
        ("power_float_scalar", DataType::Float64),
    ] {
        assert_eq!(
            result_schema
                .field_with_name(column)
                .unwrap()
                .data_type()
                .to_owned(),
            expected_type,
            "unexpected result type for column {}",
            column
        );
    }

    Ok(())
}
4 changes: 4 additions & 0 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ pub enum BuiltinScalarFunction {
Log10,
/// log2
Log2,
/// power
Power,
/// round
Round,
/// signum
Expand Down Expand Up @@ -184,6 +186,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Log => Volatility::Immutable,
BuiltinScalarFunction::Log10 => Volatility::Immutable,
BuiltinScalarFunction::Log2 => Volatility::Immutable,
BuiltinScalarFunction::Power => Volatility::Immutable,
BuiltinScalarFunction::Round => Volatility::Immutable,
BuiltinScalarFunction::Signum => Volatility::Immutable,
BuiltinScalarFunction::Sin => Volatility::Immutable,
Expand Down Expand Up @@ -267,6 +270,7 @@ impl FromStr for BuiltinScalarFunction {
"log" => BuiltinScalarFunction::Log,
"log10" => BuiltinScalarFunction::Log10,
"log2" => BuiltinScalarFunction::Log2,
"power" => BuiltinScalarFunction::Power,
"round" => BuiltinScalarFunction::Round,
"signum" => BuiltinScalarFunction::Signum,
"sin" => BuiltinScalarFunction::Sin,
Expand Down
1 change: 1 addition & 0 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ unary_scalar_expr!(Log2, log2);
unary_scalar_expr!(Log10, log10);
unary_scalar_expr!(Ln, ln);
unary_scalar_expr!(NullIf, nullif);
scalar_expr!(Power, power, base, exponent);

// string functions
scalar_expr!(Ascii, ascii, string);
Expand Down
12 changes: 12 additions & 0 deletions datafusion/expr/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,11 @@ pub fn return_type(
}
}),

BuiltinScalarFunction::Power => match &input_expr_types[0] {
DataType::Int64 => Ok(DataType::Int64),
_ => Ok(DataType::Float64),
},

BuiltinScalarFunction::Abs
| BuiltinScalarFunction::Acos
| BuiltinScalarFunction::Asin
Expand Down Expand Up @@ -505,6 +510,13 @@ pub fn signature(fun: &BuiltinScalarFunction) -> Signature {
fun.volatility(),
),
BuiltinScalarFunction::Random => Signature::exact(vec![], fun.volatility()),
BuiltinScalarFunction::Power => Signature::one_of(
vec![
TypeSignature::Exact(vec![DataType::Int64, DataType::Int64]),
TypeSignature::Exact(vec![DataType::Float64, DataType::Float64]),
],
fun.volatility(),
),
// math expressions expect 1 argument of type f64 or f32
// priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we
// return the best approximation for it (in f64).
Expand Down
58 changes: 57 additions & 1 deletion datafusion/physical-expr/src/math_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@

//! Math expressions
use arrow::array::{Float32Array, Float64Array};
use arrow::array::ArrayRef;
use arrow::array::{Float32Array, Float64Array, Int64Array};
use arrow::datatypes::DataType;
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::ColumnarValue;
use rand::{thread_rng, Rng};
use std::any::type_name;
use std::iter;
use std::sync::Arc;

Expand Down Expand Up @@ -86,6 +88,33 @@ macro_rules! math_unary_function {
};
}

/// Downcasts `$ARG` (anything exposing `as_any()`) to a reference to the
/// concrete array type `$ARRAY_TYPE`.
///
/// On a type mismatch the enclosing function returns early (via `?`) with a
/// `DataFusionError::Internal` naming the argument (`$NAME`) and the target type.
macro_rules! downcast_arg {
    ($ARG:expr, $NAME:expr, $ARRAY_TYPE:ident) => {{
        let cast_result = match $ARG.as_any().downcast_ref::<$ARRAY_TYPE>() {
            Some(typed) => Ok(typed),
            // The message is only built when the downcast actually fails.
            None => Err(DataFusionError::Internal(format!(
                "could not cast {} to {}",
                $NAME,
                type_name::<$ARRAY_TYPE>()
            ))),
        };
        cast_result?
    }};
}

/// Applies the binary function `$FUNC` element-wise to two arrays of the same
/// concrete type `$ARRAY_TYPE` and collects the results into a new `$ARRAY_TYPE`.
///
/// Both inputs are downcast with `downcast_arg!` (`$NAME1` / `$NAME2` are used
/// in its error message). An output slot is null when either input is null,
/// or when the second value cannot be converted (`try_into`) to the type
/// `$FUNC` expects for its second parameter.
macro_rules! make_function_inputs2 {
    ($ARG1: expr, $ARG2: expr, $NAME1:expr, $NAME2: expr, $ARRAY_TYPE:ident, $FUNC: block) => {{
        let lhs = downcast_arg!($ARG1, $NAME1, $ARRAY_TYPE);
        let rhs = downcast_arg!($ARG2, $NAME2, $ARRAY_TYPE);

        lhs.iter()
            .zip(rhs.iter())
            .map(|(left, right)| {
                // `?` short-circuits to a null slot when either side is null.
                let (x, y) = left.zip(right)?;
                Some($FUNC(x, y.try_into().ok()?))
            })
            .collect::<$ARRAY_TYPE>()
    }};
}

math_unary_function!("sqrt", sqrt);
math_unary_function!("sin", sin);
math_unary_function!("cos", cos);
Expand Down Expand Up @@ -120,6 +149,33 @@ pub fn random(args: &[ColumnarValue]) -> Result<ColumnarValue> {
Ok(ColumnarValue::Array(Arc::new(array)))
}

pub fn power(args: &[ArrayRef]) -> Result<ArrayRef> {
match args[0].data_type() {
DataType::Float64 => Ok(Arc::new(make_function_inputs2!(
&args[0],
&args[1],
"base",
"exponent",
Float64Array,
{ f64::powf }
)) as ArrayRef),

DataType::Int64 => Ok(Arc::new(make_function_inputs2!(
&args[0],
&args[1],
"base",
"exponent",
Int64Array,
{ i64::pow }
)) as ArrayRef),

other => Err(DataFusionError::Internal(format!(
"Unsupported data type {:?} for function power",
other
))),
}
}

#[cfg(test)]
mod tests {

Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ enum ScalarFunction {
Trim=61;
Upper=62;
Coalesce=63;
Power=64;
}

message ScalarFunctionNode {
Expand Down
11 changes: 8 additions & 3 deletions datafusion/proto/src/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ use datafusion::{
logical_plan::{
abs, acos, ascii, asin, atan, ceil, character_length, chr, concat_expr,
concat_ws_expr, cos, digest, exp, floor, left, ln, log10, log2, now_expr, nullif,
random, regexp_replace, repeat, replace, reverse, right, round, signum, sin,
split_part, sqrt, starts_with, strpos, substr, tan, to_hex, to_timestamp_micros,
to_timestamp_millis, to_timestamp_seconds, translate, trunc,
power, random, regexp_replace, repeat, replace, reverse, right, round, signum,
sin, split_part, sqrt, starts_with, strpos, substr, tan, to_hex,
to_timestamp_micros, to_timestamp_millis, to_timestamp_seconds, translate, trunc,
window_frames::{WindowFrame, WindowFrameBound, WindowFrameUnits},
Column, DFField, DFSchema, DFSchemaRef, Expr, Operator,
},
Expand Down Expand Up @@ -466,6 +466,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction {
ScalarFunction::Translate => Self::Translate,
ScalarFunction::RegexpMatch => Self::RegexpMatch,
ScalarFunction::Coalesce => Self::Coalesce,
ScalarFunction::Power => Self::Power,
}
}
}
Expand Down Expand Up @@ -1243,6 +1244,10 @@ pub fn parse_expr(
.map(|expr| parse_expr(expr, registry))
.collect::<Result<Vec<_>, _>>()?,
)),
ScalarFunction::Power => Ok(power(
parse_expr(&args[0], registry)?,
parse_expr(&args[1], registry)?,
)),
_ => Err(proto_error(
"Protobuf deserialization error: Unsupported scalar function",
)),
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/src/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1069,6 +1069,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction {
BuiltinScalarFunction::Translate => Self::Translate,
BuiltinScalarFunction::RegexpMatch => Self::RegexpMatch,
BuiltinScalarFunction::Coalesce => Self::Coalesce,
BuiltinScalarFunction::Power => Self::Power,
};

Ok(scalar_function)
Expand Down
2 changes: 1 addition & 1 deletion dev/build-ballista-docker.sh
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@ set -e

. ./dev/build-set-env.sh
docker build -t ballista-base:$BALLISTA_VERSION -f dev/docker/ballista-base.dockerfile .
docker build -t ballista:$BALLISTA_VERSION -f dev/docker/ballista.dockerfile .
docker build -t ballista:$BALLISTA_VERSION -f dev/docker/ballista.dockerfile .
4 changes: 2 additions & 2 deletions dev/docker/ballista.dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ ARG RELEASE_FLAG=--release
FROM ballista-base:0.6.0 AS base
WORKDIR /tmp/ballista
RUN apt-get -y install cmake
RUN cargo install cargo-chef --version 0.1.23
RUN cargo install cargo-chef --version 0.1.34

FROM base as planner
ADD Cargo.toml .
Expand Down Expand Up @@ -105,4 +105,4 @@ COPY benchmarks/queries/ /queries/
ENV RUST_LOG=info
ENV RUST_BACKTRACE=full

CMD ["/executor"]
CMD ["/executor"]

0 comments on commit c3c02cf

Please sign in to comment.