Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Functions] Support Arithmetic function COT() #6925

Merged
merged 13 commits into from
Jul 14, 2023
10 changes: 8 additions & 2 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ pub enum BuiltinScalarFunction {
Tanh,
/// trunc
Trunc,
/// cot
Cot,

// array functions
/// array_append
Expand Down Expand Up @@ -322,6 +324,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Sinh => Volatility::Immutable,
BuiltinScalarFunction::Sqrt => Volatility::Immutable,
BuiltinScalarFunction::Cbrt => Volatility::Immutable,
BuiltinScalarFunction::Cot => Volatility::Immutable,
BuiltinScalarFunction::Tan => Volatility::Immutable,
BuiltinScalarFunction::Tanh => Volatility::Immutable,
BuiltinScalarFunction::Trunc => Volatility::Immutable,
Expand Down Expand Up @@ -764,7 +767,8 @@ impl BuiltinScalarFunction {
| BuiltinScalarFunction::Cbrt
| BuiltinScalarFunction::Tan
| BuiltinScalarFunction::Tanh
| BuiltinScalarFunction::Trunc => match input_expr_types[0] {
| BuiltinScalarFunction::Trunc
| BuiltinScalarFunction::Cot => match input_expr_types[0] {
Float32 => Ok(Float32),
_ => Ok(Float64),
},
Expand Down Expand Up @@ -1112,7 +1116,8 @@ impl BuiltinScalarFunction {
| BuiltinScalarFunction::Sqrt
| BuiltinScalarFunction::Tan
| BuiltinScalarFunction::Tanh
| BuiltinScalarFunction::Trunc => {
| BuiltinScalarFunction::Trunc
| BuiltinScalarFunction::Cot => {
// math expressions expect 1 argument of type f64 or f32
// priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we
// return the best approximation for it (in f64).
Expand Down Expand Up @@ -1142,6 +1147,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] {
BuiltinScalarFunction::Cbrt => &["cbrt"],
BuiltinScalarFunction::Ceil => &["ceil"],
BuiltinScalarFunction::Cos => &["cos"],
BuiltinScalarFunction::Cot => &["cot"],
BuiltinScalarFunction::Cosh => &["cosh"],
BuiltinScalarFunction::Degrees => &["degrees"],
BuiltinScalarFunction::Exp => &["exp"],
Expand Down
2 changes: 2 additions & 0 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,7 @@ scalar_expr!(Cbrt, cbrt, num, "cube root of a number");
scalar_expr!(Sin, sin, num, "sine");
scalar_expr!(Cos, cos, num, "cosine");
scalar_expr!(Tan, tan, num, "tangent");
scalar_expr!(Cot, cot, num, "cotangent");
scalar_expr!(Sinh, sinh, num, "hyperbolic sine");
scalar_expr!(Cosh, cosh, num, "hyperbolic cosine");
scalar_expr!(Tanh, tanh, num, "hyperbolic tangent");
Expand Down Expand Up @@ -912,6 +913,7 @@ mod test {
test_unary_scalar_expr!(Sin, sin);
test_unary_scalar_expr!(Cos, cos);
test_unary_scalar_expr!(Tan, tan);
test_unary_scalar_expr!(Cot, cot);
test_unary_scalar_expr!(Sinh, sinh);
test_unary_scalar_expr!(Cosh, cosh);
test_unary_scalar_expr!(Tanh, tanh);
Expand Down
3 changes: 3 additions & 0 deletions datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,9 @@ pub fn create_physical_fun(
BuiltinScalarFunction::Log => {
Arc::new(|args| make_scalar_function(math_expressions::log)(args))
}
BuiltinScalarFunction::Cot => {
Arc::new(|args| make_scalar_function(math_expressions::cot)(args))
}

// array functions
BuiltinScalarFunction::ArrayAppend => {
Expand Down
79 changes: 79 additions & 0 deletions datafusion/physical-expr/src/math_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,39 @@ pub fn log(args: &[ArrayRef]) -> Result<ArrayRef> {
}
}

///cot SQL function
pub fn cot(args: &[ArrayRef]) -> Result<ArrayRef> {
match args[0].data_type() {
DataType::Float64 => Ok(Arc::new(make_function_scalar_inputs!(
&args[0],
"x",
Float64Array,
{ compute_cot64 }
)) as ArrayRef),

DataType::Float32 => Ok(Arc::new(make_function_scalar_inputs!(
&args[0],
"x",
Float32Array,
{ compute_cot32 }
)) as ArrayRef),

other => Err(DataFusionError::Internal(format!(
"Unsupported data type {other:?} for function cot"
))),
}
}

fn compute_cot32(x: f32) -> f32 {
let a = f32::tan(x);
1.0 / a
}

fn compute_cot64(x: f64) -> f64 {
let a = f64::tan(x);
1.0 / a
}

#[cfg(test)]
mod tests {

Expand Down Expand Up @@ -739,4 +772,50 @@ mod tests {
assert_eq!(ints.value(2), 75);
assert_eq!(ints.value(3), 16);
}

#[test]
fn test_cot_f32() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

let args: Vec<ArrayRef> =
vec![Arc::new(Float32Array::from(vec![12.1, 30.0, 90.0, -30.0]))];
let result = cot(&args).expect("failed to initialize function cot");
let floats =
as_float32_array(&result).expect("failed to initialize function cot");

let expected = Float32Array::from(vec![
-1.986_460_4,
-0.156_119_96,
-0.501_202_8,
0.156_119_96,
]);

let eps = 1e-6;
assert_eq!(floats.len(), 4);
assert!((floats.value(0) - expected.value(0)).abs() < eps);
assert!((floats.value(1) - expected.value(1)).abs() < eps);
assert!((floats.value(2) - expected.value(2)).abs() < eps);
assert!((floats.value(3) - expected.value(3)).abs() < eps);
}

#[test]
fn test_cot_f64() {
let args: Vec<ArrayRef> =
vec![Arc::new(Float64Array::from(vec![12.1, 30.0, 90.0, -30.0]))];
let result = cot(&args).expect("failed to initialize function cot");
let floats =
as_float64_array(&result).expect("failed to initialize function cot");

let expected = Float64Array::from(vec![
-1.986_458_685_881_4,
-0.156_119_952_161_6,
-0.501_202_783_380_1,
0.156_119_952_161_6,
]);

let eps = 1e-12;
assert_eq!(floats.len(), 4);
assert!((floats.value(0) - expected.value(0)).abs() < eps);
assert!((floats.value(1) - expected.value(1)).abs() < eps);
assert!((floats.value(2) - expected.value(2)).abs() < eps);
assert!((floats.value(3) - expected.value(3)).abs() < eps);
}
}
1 change: 1 addition & 0 deletions datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,7 @@ enum ScalarFunction {
ArrayContains = 100;
Encode = 101;
Decode = 102;
Cot = 103;
}

message ScalarFunctionNode {
Expand Down
3 changes: 3 additions & 0 deletions datafusion/proto/src/generated/pbjson.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions datafusion/proto/src/generated/prost.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 4 additions & 2 deletions datafusion/proto/src/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ use datafusion_expr::{
array_fill, array_length, array_ndims, array_position, array_positions,
array_prepend, array_remove, array_replace, array_to_string, ascii, asin, asinh,
atan, atan2, atanh, bit_length, btrim, cardinality, cbrt, ceil, character_length,
chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, current_date, current_time,
date_bin, date_part, date_trunc, degrees, digest, exp,
chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, current_date,
current_time, date_bin, date_part, date_trunc, degrees, digest, exp,
expr::{self, InList, Sort, WindowFunction},
factorial, floor, from_unixtime, gcd, lcm, left, ln, log, log10, log2,
logical_plan::{PlanType, StringifiedPlan},
Expand Down Expand Up @@ -417,6 +417,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction {
ScalarFunction::Sin => Self::Sin,
ScalarFunction::Cos => Self::Cos,
ScalarFunction::Tan => Self::Tan,
ScalarFunction::Cot => Self::Cot,
ScalarFunction::Asin => Self::Asin,
ScalarFunction::Acos => Self::Acos,
ScalarFunction::Atan => Self::Atan,
Expand Down Expand Up @@ -1473,6 +1474,7 @@ pub fn parse_expr(
)),
ScalarFunction::CurrentDate => Ok(current_date()),
ScalarFunction::CurrentTime => Ok(current_time()),
ScalarFunction::Cot => Ok(cot(parse_expr(&args[0], registry)?)),
_ => Err(proto_error(
"Protobuf deserialization error: Unsupported scalar function",
)),
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/src/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1344,6 +1344,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction {
BuiltinScalarFunction::Sin => Self::Sin,
BuiltinScalarFunction::Cos => Self::Cos,
BuiltinScalarFunction::Tan => Self::Tan,
BuiltinScalarFunction::Cot => Self::Cot,
BuiltinScalarFunction::Sinh => Self::Sinh,
BuiltinScalarFunction::Cosh => Self::Cosh,
BuiltinScalarFunction::Tanh => Self::Tanh,
Expand Down