diff --git a/Cargo.lock b/Cargo.lock index f34a623a65254..1f4eac68794a1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6207,6 +6207,7 @@ dependencies = [ "dyn-clone", "easy-ext", "either", + "expect-test", "futures-async-stream", "futures-util", "hex", diff --git a/src/expr/Cargo.toml b/src/expr/Cargo.toml index 5ddc587b25361..9bd5fa2ecaf6f 100644 --- a/src/expr/Cargo.toml +++ b/src/expr/Cargo.toml @@ -56,6 +56,7 @@ workspace-hack = { path = "../workspace-hack" } [dev-dependencies] criterion = "0.4" +expect-test = "1" serde_json = "1" [[bench]] diff --git a/src/expr/src/sig/func.rs b/src/expr/src/sig/func.rs index 67d9e21c780cf..c933fce200993 100644 --- a/src/expr/src/sig/func.rs +++ b/src/expr/src/sig/func.rs @@ -105,3 +105,96 @@ pub unsafe fn _register(desc: FuncSign) { /// vector. The calls are guaranteed to be sequential. The vector will be drained and moved into /// `FUNC_SIG_MAP` on the first access of `FUNC_SIG_MAP`. static mut FUNC_SIG_MAP_INIT: Vec = Vec::new(); + +#[cfg(test)] +mod tests { + use std::collections::BTreeMap; + + use itertools::Itertools; + + use super::*; + + #[test] + fn test_func_sig_map() { + // convert FUNC_SIG_MAP to a more convenient map for testing + let mut new_map: BTreeMap, Vec>> = + BTreeMap::new(); + for ((func, num_args), sigs) in &FUNC_SIG_MAP.0 { + for sig in sigs { + // validate the FUNC_SIG_MAP is consistent + assert_eq!(func, &sig.func); + assert_eq!(num_args, &sig.inputs_type.len()); + + new_map + .entry(*func) + .or_default() + .entry(sig.inputs_type.to_vec()) + .or_default() + .push(sig.clone()); + } + } + + let duplicated: BTreeMap<_, Vec<_>> = new_map + .into_iter() + .filter_map(|(k, funcs_with_same_name)| { + let funcs_with_same_name_type: Vec<_> = funcs_with_same_name + .into_values() + .filter_map(|v| { + if v.len() > 1 { + Some( + format!( + "{:}({:?}) -> {:?}", + v[0].func.as_str_name(), + v[0].inputs_type.iter().format(", "), + v.iter().map(|sig| sig.ret_type).format("/") + ) + .to_ascii_lowercase(), + ) + } else { + None + } + }) + .collect(); + if !funcs_with_same_name_type.is_empty() { + Some((k, funcs_with_same_name_type)) + } else { + None + } + }) + .collect(); + + // This snapshot shows the function signatures without a unique match. Frontend has to + // handle them specially without relying on FuncSigMap. + let expected = expect_test::expect![[r#" + { + Cast: [ + "cast(boolean) -> int32/varchar", + "cast(int16) -> int256/decimal/float64/float32/int64/int32/varchar", + "cast(int32) -> int256/int16/decimal/float64/float32/int64/boolean/varchar", + "cast(int64) -> int256/int32/int16/decimal/float64/float32/varchar", + "cast(float32) -> decimal/int64/int32/int16/float64/varchar", + "cast(float64) -> decimal/float32/int64/int32/int16/varchar", + "cast(decimal) -> float64/float32/int64/int32/int16/varchar", + "cast(date) -> timestamp/varchar", + "cast(varchar) -> date/time/timestamp/jsonb/interval/int256/float32/float64/decimal/int16/int32/int64/varchar/boolean/bytea/list", + "cast(time) -> interval/varchar", + "cast(timestamp) -> date/time/varchar", + "cast(interval) -> time/varchar", + "cast(list) -> varchar/list", + "cast(jsonb) -> boolean/float64/float32/decimal/int64/int32/int16/varchar", + "cast(int256) -> float64/varchar", + ], + ArrayAccess: [ + "array_access(list, int32) -> boolean/int16/int32/int64/int256/float32/float64/decimal/serial/date/time/timestamp/timestamptz/interval/varchar/bytea/jsonb/list/struct", + ], + ArrayLength: [ + "array_length(list) -> int64/int32", + ], + Cardinality: [ + "cardinality(list) -> int64/int32", + ], + } + "#]]; + expected.assert_debug_eq(&duplicated); + } +} diff --git a/src/expr/src/vector_op/array_length.rs b/src/expr/src/vector_op/array_length.rs index 9b63a80dbb32c..e4e44179e76a9 100644 --- a/src/expr/src/vector_op/array_length.rs +++ b/src/expr/src/vector_op/array_length.rs @@ -59,8 +59,9 @@ use crate::ExprError; /// query error type unknown /// select array_length(null); /// ``` -#[function("array_length(list) -> int64")] -fn array_length(array: ListRef<'_>) -> Result { +#[function("array_length(list) -> int32")] +#[function("array_length(list) -> int64")] // for compatibility with plans from old version +fn array_length>(array: ListRef<'_>) -> Result { array .len() .try_into() @@ -126,8 +127,8 @@ fn array_length(array: ListRef<'_>) -> Result { /// statement error /// select array_length(array[null, array[2]], 2); /// ``` -#[function("array_length(list, int32) -> int64")] -fn array_length_d(array: ListRef<'_>, d: i32) -> Result, ExprError> { +#[function("array_length(list, int32) -> int32")] +fn array_length_of_dim(array: ListRef<'_>, d: i32) -> Result, ExprError> { match d { ..=0 => Ok(None), 1 => array_length(array).map(Some), diff --git a/src/expr/src/vector_op/cardinality.rs b/src/expr/src/vector_op/cardinality.rs index 10314857cf2c5..3fc8470295d7e 100644 --- a/src/expr/src/vector_op/cardinality.rs +++ b/src/expr/src/vector_op/cardinality.rs @@ -15,6 +15,8 @@ use risingwave_common::array::ListRef; use risingwave_expr_macro::function; +use crate::ExprError; + /// Returns the total number of elements in the array. /// /// ```sql @@ -57,7 +59,12 @@ use risingwave_expr_macro::function; /// query error type unknown /// select cardinality(null); /// ``` -#[function("cardinality(list) -> int64")] -fn cardinality(array: ListRef<'_>) -> i64 { - array.flatten().len() as _ +#[function("cardinality(list) -> int32")] +#[function("cardinality(list) -> int64")] // for compatibility with plans from old version +fn cardinality>(array: ListRef<'_>) -> Result { + array + .flatten() + .len() + .try_into() + .map_err(|_| ExprError::NumericOverflow) } diff --git a/src/frontend/src/expr/type_inference/func.rs b/src/frontend/src/expr/type_inference/func.rs index 159825432cc6f..dd4ce14494479 100644 --- a/src/frontend/src/expr/type_inference/func.rs +++ b/src/frontend/src/expr/type_inference/func.rs @@ -604,7 +604,7 @@ fn infer_type_for_special( arg1.cast_implicit_mut(DataType::Int32)?; } - Ok(Some(DataType::Int64)) + Ok(Some(DataType::Int32)) } ExprType::StringToArray => { ensure_arity!("string_to_array", 2 <= | inputs | <= 3); @@ -619,7 +619,7 @@ fn infer_type_for_special( ensure_arity!("cardinality", | inputs | == 1); inputs[0].ensure_array_type()?; - Ok(Some(DataType::Int64)) + Ok(Some(DataType::Int32)) } ExprType::TrimArray => { ensure_arity!("trim_array", | inputs | == 2);