Skip to content

Commit

Permalink
fix(expr): array_length and cardinality returns int32 as in Postg…
Browse files Browse the repository at this point in the history
…reSQL (#10267)

Co-authored-by: xxchan <[email protected]>
  • Loading branch information
2 people authored and Little-Wallace committed Jun 12, 2023
1 parent a913bf4 commit a2fc4be
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 9 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions src/expr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ workspace-hack = { path = "../workspace-hack" }

[dev-dependencies]
criterion = "0.4"
expect-test = "1"
serde_json = "1"

[[bench]]
Expand Down
93 changes: 93 additions & 0 deletions src/expr/src/sig/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,96 @@ pub unsafe fn _register(desc: FuncSign) {
/// vector. The calls are guaranteed to be sequential. The vector will be drained and moved into
/// `FUNC_SIG_MAP` on the first access of `FUNC_SIG_MAP`.
static mut FUNC_SIG_MAP_INIT: Vec<FuncSign> = Vec::new();

#[cfg(test)]
mod tests {
use std::collections::BTreeMap;

use itertools::Itertools;

use super::*;

#[test]
fn test_func_sig_map() {
// convert FUNC_SIG_MAP to a more convenient map for testing
let mut new_map: BTreeMap<PbType, BTreeMap<Vec<DataTypeName>, Vec<FuncSign>>> =
BTreeMap::new();
for ((func, num_args), sigs) in &FUNC_SIG_MAP.0 {
for sig in sigs {
// validate the FUNC_SIG_MAP is consistent
assert_eq!(func, &sig.func);
assert_eq!(num_args, &sig.inputs_type.len());

new_map
.entry(*func)
.or_default()
.entry(sig.inputs_type.to_vec())
.or_default()
.push(sig.clone());
}
}

let duplicated: BTreeMap<_, Vec<_>> = new_map
.into_iter()
.filter_map(|(k, funcs_with_same_name)| {
let funcs_with_same_name_type: Vec<_> = funcs_with_same_name
.into_values()
.filter_map(|v| {
if v.len() > 1 {
Some(
format!(
"{:}({:?}) -> {:?}",
v[0].func.as_str_name(),
v[0].inputs_type.iter().format(", "),
v.iter().map(|sig| sig.ret_type).format("/")
)
.to_ascii_lowercase(),
)
} else {
None
}
})
.collect();
if !funcs_with_same_name_type.is_empty() {
Some((k, funcs_with_same_name_type))
} else {
None
}
})
.collect();

// This snapshot shows the function signatures without a unique match. Frontend has to
// handle them specially without relying on FuncSigMap.
let expected = expect_test::expect![[r#"
{
Cast: [
"cast(boolean) -> int32/varchar",
"cast(int16) -> int256/decimal/float64/float32/int64/int32/varchar",
"cast(int32) -> int256/int16/decimal/float64/float32/int64/boolean/varchar",
"cast(int64) -> int256/int32/int16/decimal/float64/float32/varchar",
"cast(float32) -> decimal/int64/int32/int16/float64/varchar",
"cast(float64) -> decimal/float32/int64/int32/int16/varchar",
"cast(decimal) -> float64/float32/int64/int32/int16/varchar",
"cast(date) -> timestamp/varchar",
"cast(varchar) -> date/time/timestamp/jsonb/interval/int256/float32/float64/decimal/int16/int32/int64/varchar/boolean/bytea/list",
"cast(time) -> interval/varchar",
"cast(timestamp) -> date/time/varchar",
"cast(interval) -> time/varchar",
"cast(list) -> varchar/list",
"cast(jsonb) -> boolean/float64/float32/decimal/int64/int32/int16/varchar",
"cast(int256) -> float64/varchar",
],
ArrayAccess: [
"array_access(list, int32) -> boolean/int16/int32/int64/int256/float32/float64/decimal/serial/date/time/timestamp/timestamptz/interval/varchar/bytea/jsonb/list/struct",
],
ArrayLength: [
"array_length(list) -> int64/int32",
],
Cardinality: [
"cardinality(list) -> int64/int32",
],
}
"#]];
expected.assert_debug_eq(&duplicated);
}
}
9 changes: 5 additions & 4 deletions src/expr/src/vector_op/array_length.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,9 @@ use crate::ExprError;
/// query error type unknown
/// select array_length(null);
/// ```
#[function("array_length(list) -> int64")]
fn array_length(array: ListRef<'_>) -> Result<i64, ExprError> {
#[function("array_length(list) -> int32")]
#[function("array_length(list) -> int64")] // for compatibility with plans from old version
fn array_length<T: TryFrom<usize>>(array: ListRef<'_>) -> Result<T, ExprError> {
array
.len()
.try_into()
Expand Down Expand Up @@ -126,8 +127,8 @@ fn array_length(array: ListRef<'_>) -> Result<i64, ExprError> {
/// statement error
/// select array_length(array[null, array[2]], 2);
/// ```
#[function("array_length(list, int32) -> int64")]
fn array_length_d(array: ListRef<'_>, d: i32) -> Result<Option<i64>, ExprError> {
#[function("array_length(list, int32) -> int32")]
fn array_length_of_dim(array: ListRef<'_>, d: i32) -> Result<Option<i32>, ExprError> {
match d {
..=0 => Ok(None),
1 => array_length(array).map(Some),
Expand Down
13 changes: 10 additions & 3 deletions src/expr/src/vector_op/cardinality.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
use risingwave_common::array::ListRef;
use risingwave_expr_macro::function;

use crate::ExprError;

/// Returns the total number of elements in the array.
///
/// ```sql
Expand Down Expand Up @@ -57,7 +59,12 @@ use risingwave_expr_macro::function;
/// query error type unknown
/// select cardinality(null);
/// ```
#[function("cardinality(list) -> int64")]
fn cardinality(array: ListRef<'_>) -> i64 {
array.flatten().len() as _
#[function("cardinality(list) -> int32")]
#[function("cardinality(list) -> int64")] // for compatibility with plans from old version
fn cardinality<T: TryFrom<usize>>(array: ListRef<'_>) -> Result<T, ExprError> {
array
.flatten()
.len()
.try_into()
.map_err(|_| ExprError::NumericOverflow)
}
4 changes: 2 additions & 2 deletions src/frontend/src/expr/type_inference/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -604,7 +604,7 @@ fn infer_type_for_special(
arg1.cast_implicit_mut(DataType::Int32)?;
}

Ok(Some(DataType::Int64))
Ok(Some(DataType::Int32))
}
ExprType::StringToArray => {
ensure_arity!("string_to_array", 2 <= | inputs | <= 3);
Expand All @@ -619,7 +619,7 @@ fn infer_type_for_special(
ensure_arity!("cardinality", | inputs | == 1);
inputs[0].ensure_array_type()?;

Ok(Some(DataType::Int64))
Ok(Some(DataType::Int32))
}
ExprType::TrimArray => {
ensure_arity!("trim_array", | inputs | == 2);
Expand Down

0 comments on commit a2fc4be

Please sign in to comment.