diff --git a/datafusion/core/src/logical_plan/mod.rs b/datafusion/core/src/logical_plan/mod.rs index bcf0a161447f..cb30acf1aff2 100644 --- a/datafusion/core/src/logical_plan/mod.rs +++ b/datafusion/core/src/logical_plan/mod.rs @@ -43,7 +43,7 @@ pub use expr::{ count, count_distinct, create_udaf, create_udf, date_part, date_trunc, digest, exists, exp, exprlist_to_fields, floor, in_list, in_subquery, initcap, left, length, lit, lit_timestamp_nano, ln, log10, log2, lower, lpad, ltrim, max, md5, min, - not_exists, not_in_subquery, now, now_expr, nullif, octet_length, or, random, + not_exists, not_in_subquery, now, now_expr, nullif, octet_length, or, power, random, regexp_match, regexp_replace, repeat, replace, reverse, right, round, rpad, rtrim, scalar_subquery, sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt, starts_with, strpos, substr, sum, tan, to_hex, to_timestamp_micros, diff --git a/datafusion/core/src/physical_plan/functions.rs b/datafusion/core/src/physical_plan/functions.rs index 60cb33a80f86..20917fa9b4d1 100644 --- a/datafusion/core/src/physical_plan/functions.rs +++ b/datafusion/core/src/physical_plan/functions.rs @@ -293,6 +293,10 @@ pub fn create_physical_fun( BuiltinScalarFunction::Sqrt => Arc::new(math_expressions::sqrt), BuiltinScalarFunction::Tan => Arc::new(math_expressions::tan), BuiltinScalarFunction::Trunc => Arc::new(math_expressions::trunc), + BuiltinScalarFunction::Power => { + Arc::new(|args| make_scalar_function(math_expressions::power)(args)) + } + // string functions BuiltinScalarFunction::Array => Arc::new(array_expressions::array), BuiltinScalarFunction::Ascii => Arc::new(|args| match args[0].data_type() { diff --git a/datafusion/core/tests/sql/functions.rs b/datafusion/core/tests/sql/functions.rs index 171ea23be8d6..857781aa35a3 100644 --- a/datafusion/core/tests/sql/functions.rs +++ b/datafusion/core/tests/sql/functions.rs @@ -445,3 +445,129 @@ async fn case_builtin_math_expression() { assert_batches_sorted_eq!(expected, &results); } } + +#[tokio::test] +async fn test_power() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("i32", DataType::Int16, true), + Field::new("i64", DataType::Int64, true), + Field::new("f32", DataType::Float32, true), + Field::new("f64", DataType::Float64, true), + ])); + + let data = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int16Array::from(vec![ + Some(2), + Some(5), + Some(0), + Some(-14), + None, + ])), + Arc::new(Int64Array::from(vec![ + Some(2), + Some(5), + Some(0), + Some(-14), + None, + ])), + Arc::new(Float32Array::from(vec![ + Some(1.0), + Some(2.5), + Some(0.0), + Some(-14.5), + None, + ])), + Arc::new(Float64Array::from(vec![ + Some(1.0), + Some(2.5), + Some(0.0), + Some(-14.5), + None, + ])), + ], + )?; + + let table = MemTable::try_new(schema, vec![vec![data]])?; + + let ctx = SessionContext::new(); + ctx.register_table("test", Arc::new(table))?; + let sql = r"SELECT power(i32, exp_i) as power_i32, + power(i64, exp_f) as power_i64, + power(f32, exp_i) as power_f32, + power(f64, exp_f) as power_f64, + power(2, 3) as power_int_scalar, + power(2.5, 3.0) as power_float_scalar + FROM (select test.*, 3 as exp_i, 3.0 as exp_f from test) a"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+-----------+-----------+-----------+-----------+------------------+--------------------+", + "| power_i32 | power_i64 | power_f32 | power_f64 | power_int_scalar | power_float_scalar |", + "+-----------+-----------+-----------+-----------+------------------+--------------------+", + "| 8 | 8 | 1 | 1 | 8 | 15.625 |", + "| 125 | 125 | 15.625 | 15.625 | 8 | 15.625 |", + "| 0 | 0 | 0 | 0 | 8 | 15.625 |", + "| -2744 | -2744 | -3048.625 | -3048.625 | 8 | 15.625 |", + "| | | | | 8 | 15.625 |", + "+-----------+-----------+-----------+-----------+------------------+--------------------+", + ]; + assert_batches_eq!(expected, &actual); + //dbg!(actual[0].schema().fields()); + assert_eq!( + actual[0] + .schema() + .field_with_name("power_i32") + .unwrap() + .data_type() + .to_owned(), + DataType::Int64 + ); + assert_eq!( + actual[0] + .schema() + .field_with_name("power_i64") + .unwrap() + .data_type() + .to_owned(), + DataType::Float64 + ); + assert_eq!( + actual[0] + .schema() + .field_with_name("power_f32") + .unwrap() + .data_type() + .to_owned(), + DataType::Float64 + ); + assert_eq!( + actual[0] + .schema() + .field_with_name("power_f64") + .unwrap() + .data_type() + .to_owned(), + DataType::Float64 + ); + assert_eq!( + actual[0] + .schema() + .field_with_name("power_int_scalar") + .unwrap() + .data_type() + .to_owned(), + DataType::Int64 + ); + assert_eq!( + actual[0] + .schema() + .field_with_name("power_float_scalar") + .unwrap() + .data_type() + .to_owned(), + DataType::Float64 + ); + + Ok(()) +} diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 7cc03546131e..17df179ed400 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -54,6 +54,8 @@ pub enum BuiltinScalarFunction { Log10, /// log2 Log2, + /// power + Power, /// round Round, /// signum @@ -184,6 +186,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Log => Volatility::Immutable, BuiltinScalarFunction::Log10 => Volatility::Immutable, BuiltinScalarFunction::Log2 => Volatility::Immutable, + BuiltinScalarFunction::Power => Volatility::Immutable, BuiltinScalarFunction::Round => Volatility::Immutable, BuiltinScalarFunction::Signum => Volatility::Immutable, BuiltinScalarFunction::Sin => Volatility::Immutable, @@ -267,6 +270,7 @@ impl FromStr for BuiltinScalarFunction { "log" => BuiltinScalarFunction::Log, "log10" => BuiltinScalarFunction::Log10, "log2" => BuiltinScalarFunction::Log2, + "power" => BuiltinScalarFunction::Power, "round" => BuiltinScalarFunction::Round, "signum" => BuiltinScalarFunction::Signum, "sin" => BuiltinScalarFunction::Sin, diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 0a3b2ebd05fb..9ed8c536bc4a 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -282,6 +282,7 @@ unary_scalar_expr!(Log2, log2); unary_scalar_expr!(Log10, log10); unary_scalar_expr!(Ln, ln); unary_scalar_expr!(NullIf, nullif); +scalar_expr!(Power, power, base, exponent); // string functions scalar_expr!(Ascii, ascii, string); diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 93c5d0e12fce..385e247bd3a6 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -217,6 +217,11 @@ pub fn return_type( } }), + BuiltinScalarFunction::Power => match &input_expr_types[0] { + DataType::Int64 => Ok(DataType::Int64), + _ => Ok(DataType::Float64), + }, + BuiltinScalarFunction::Abs | BuiltinScalarFunction::Acos | BuiltinScalarFunction::Asin @@ -505,6 +510,13 @@ pub fn signature(fun: &BuiltinScalarFunction) -> Signature { fun.volatility(), ), BuiltinScalarFunction::Random => Signature::exact(vec![], fun.volatility()), + BuiltinScalarFunction::Power => Signature::one_of( + vec![ + TypeSignature::Exact(vec![DataType::Int64, DataType::Int64]), + TypeSignature::Exact(vec![DataType::Float64, DataType::Float64]), + ], + fun.volatility(), + ), // math expressions expect 1 argument of type f64 or f32 // priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we // return the best approximation for it (in f64). diff --git a/datafusion/physical-expr/src/math_expressions.rs b/datafusion/physical-expr/src/math_expressions.rs index b16a59634f50..7f41268154a9 100644 --- a/datafusion/physical-expr/src/math_expressions.rs +++ b/datafusion/physical-expr/src/math_expressions.rs @@ -17,12 +17,14 @@ //! Math expressions -use arrow::array::{Float32Array, Float64Array}; +use arrow::array::ArrayRef; +use arrow::array::{Float32Array, Float64Array, Int64Array}; use arrow::datatypes::DataType; use datafusion_common::ScalarValue; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::ColumnarValue; use rand::{thread_rng, Rng}; +use std::any::type_name; use std::iter; use std::sync::Arc; @@ -86,6 +88,33 @@ macro_rules! math_unary_function { }; } +macro_rules! downcast_arg { + ($ARG:expr, $NAME:expr, $ARRAY_TYPE:ident) => {{ + $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| { + DataFusionError::Internal(format!( + "could not cast {} to {}", + $NAME, + type_name::<$ARRAY_TYPE>() + )) + })? + }}; +} + +macro_rules! make_function_inputs2 { + ($ARG1: expr, $ARG2: expr, $NAME1:expr, $NAME2: expr, $ARRAY_TYPE:ident, $FUNC: block) => {{ + let arg1 = downcast_arg!($ARG1, $NAME1, $ARRAY_TYPE); + let arg2 = downcast_arg!($ARG2, $NAME2, $ARRAY_TYPE); + + arg1.iter() + .zip(arg2.iter()) + .map(|(a1, a2)| match (a1, a2) { + (Some(a1), Some(a2)) => Some($FUNC(a1, a2.try_into().ok()?)), + _ => None, + }) + .collect::<$ARRAY_TYPE>() + }}; +} + math_unary_function!("sqrt", sqrt); math_unary_function!("sin", sin); math_unary_function!("cos", cos); @@ -120,6 +149,33 @@ pub fn random(args: &[ColumnarValue]) -> Result { Ok(ColumnarValue::Array(Arc::new(array))) } +pub fn power(args: &[ArrayRef]) -> Result { + match args[0].data_type() { + DataType::Float64 => Ok(Arc::new(make_function_inputs2!( + &args[0], + &args[1], + "base", + "exponent", + Float64Array, + { f64::powf } + )) as ArrayRef), + + DataType::Int64 => Ok(Arc::new(make_function_inputs2!( + &args[0], + &args[1], + "base", + "exponent", + Int64Array, + { i64::pow } + )) as ArrayRef), + + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function power", + other + ))), + } +} + #[cfg(test)] mod tests { diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 1e5a797cede0..1651a70191ea 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -184,6 +184,7 @@ enum ScalarFunction { Trim=61; Upper=62; Coalesce=63; + Power=64; } message ScalarFunctionNode { diff --git a/datafusion/proto/src/from_proto.rs b/datafusion/proto/src/from_proto.rs index daad11d7a96f..37466dae207d 100644 --- a/datafusion/proto/src/from_proto.rs +++ b/datafusion/proto/src/from_proto.rs @@ -31,9 +31,9 @@ use datafusion::{ logical_plan::{ abs, acos, ascii, asin, atan, ceil, character_length, chr, concat_expr, concat_ws_expr, cos, digest, exp, floor, left, ln, log10, log2, now_expr, nullif, - random, regexp_replace, repeat, replace, reverse, right, round, signum, sin, - split_part, sqrt, starts_with, strpos, substr, tan, to_hex, to_timestamp_micros, - to_timestamp_millis, to_timestamp_seconds, translate, trunc, + power, random, regexp_replace, repeat, replace, reverse, right, round, signum, + sin, split_part, sqrt, starts_with, strpos, substr, tan, to_hex, + to_timestamp_micros, to_timestamp_millis, to_timestamp_seconds, translate, trunc, window_frames::{WindowFrame, WindowFrameBound, WindowFrameUnits}, Column, DFField, DFSchema, DFSchemaRef, Expr, Operator, }, @@ -466,6 +466,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Translate => Self::Translate, ScalarFunction::RegexpMatch => Self::RegexpMatch, ScalarFunction::Coalesce => Self::Coalesce, + ScalarFunction::Power => Self::Power, } } } @@ -1243,6 +1244,10 @@ pub fn parse_expr( .map(|expr| parse_expr(expr, registry)) .collect::, _>>()?, )), + ScalarFunction::Power => Ok(power( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), _ => Err(proto_error( "Protobuf deserialization error: Unsupported scalar function", )), diff --git a/datafusion/proto/src/to_proto.rs b/datafusion/proto/src/to_proto.rs index 9c2e11b6da92..03a9f6b10432 100644 --- a/datafusion/proto/src/to_proto.rs +++ b/datafusion/proto/src/to_proto.rs @@ -1069,6 +1069,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Translate => Self::Translate, BuiltinScalarFunction::RegexpMatch => Self::RegexpMatch, BuiltinScalarFunction::Coalesce => Self::Coalesce, + BuiltinScalarFunction::Power => Self::Power, }; Ok(scalar_function) diff --git a/dev/build-ballista-docker.sh b/dev/build-ballista-docker.sh old mode 100755 new mode 100644 index bc028da9e716..7add135d9c18 --- a/dev/build-ballista-docker.sh +++ b/dev/build-ballista-docker.sh @@ -21,4 +21,4 @@ set -e . ./dev/build-set-env.sh docker build -t ballista-base:$BALLISTA_VERSION -f dev/docker/ballista-base.dockerfile . -docker build -t ballista:$BALLISTA_VERSION -f dev/docker/ballista.dockerfile . +docker build -t ballista:$BALLISTA_VERSION -f dev/docker/ballista.dockerfile . \ No newline at end of file diff --git a/dev/docker/ballista.dockerfile b/dev/docker/ballista.dockerfile index a0a6ac94ad7c..c452e4684844 100644 --- a/dev/docker/ballista.dockerfile +++ b/dev/docker/ballista.dockerfile @@ -25,7 +25,7 @@ ARG RELEASE_FLAG=--release FROM ballista-base:0.6.0 AS base WORKDIR /tmp/ballista RUN apt-get -y install cmake -RUN cargo install cargo-chef --version 0.1.23 +RUN cargo install cargo-chef --version 0.1.34 FROM base as planner ADD Cargo.toml . @@ -105,4 +105,4 @@ COPY benchmarks/queries/ /queries/ ENV RUST_LOG=info ENV RUST_BACKTRACE=full -CMD ["/executor"] +CMD ["/executor"] \ No newline at end of file