Skip to content

Commit

Permalink
feat(expr): introduce deprecated to #[function] macro (#11189)
Browse files Browse the repository at this point in the history
Signed-off-by: Runji Wang <[email protected]>
  • Loading branch information
wangrunji0408 authored Jul 25, 2023
1 parent c975492 commit 3d18d39
Show file tree
Hide file tree
Showing 17 changed files with 51 additions and 51 deletions.
2 changes: 2 additions & 0 deletions src/expr/macro/src/gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ impl FunctionAttr {
} else {
self.generate_build_fn()?
};
let deprecated = self.deprecated;
Ok(quote! {
#[ctor::ctor]
fn #ctor_name() {
Expand All @@ -79,6 +80,7 @@ impl FunctionAttr {
inputs_type: &[#(#args),*],
ret_type: #ret,
build: #build_fn,
deprecated: #deprecated,
}) };
}
})
Expand Down
1 change: 1 addition & 0 deletions src/expr/macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,7 @@ struct FunctionAttr {
init_state: Option<String>,
prebuild: Option<String>,
type_infer: Option<String>,
deprecated: bool,
user_fn: UserFunctionAttr,
}

Expand Down
9 changes: 9 additions & 0 deletions src/expr/macro/src/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ impl FunctionAttr {
init_state: find_argument(attr, "init_state"),
prebuild: find_argument(attr, "prebuild"),
type_infer: find_argument(attr, "type_infer"),
deprecated: find_name(attr, "deprecated"),
user_fn,
})
}
Expand Down Expand Up @@ -179,3 +180,11 @@ fn find_argument(attr: &syn::AttributeArgs, name: &str) -> Option<String> {
Some(lit_str.value())
})
}

/// Find name `#[xxx(.., name)]`.
fn find_name(attr: &syn::AttributeArgs, name: &str) -> bool {
attr.iter().any(|n| {
let syn::NestedMeta::Meta(syn::Meta::Path(path)) = n else { return false };
path.is_ident(name)
})
}
3 changes: 2 additions & 1 deletion src/expr/src/agg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ pub fn build(agg: AggCall) -> Result<BoxedAggState> {
func: agg.kind,
inputs_type: &args,
ret_type,
set_returning: false
set_returning: false,
deprecated: false,
}
))
})?;
Expand Down
7 changes: 1 addition & 6 deletions src/expr/src/expr/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ use super::expr_some_all::SomeAllExpression;
use super::expr_udf::UdfExpression;
use super::expr_vnode::VnodeExpression;
use crate::expr::expr_proctime::ProcTimeExpression;
use crate::expr::expr_to_timestamp_const_tmpl::build_to_timestamp_expr_legacy;
use crate::expr::{
BoxedExpression, Expression, InputRefExpression, LiteralExpression, TryFromExprNodeBoxed,
};
Expand Down Expand Up @@ -81,11 +80,6 @@ pub fn build_from_prost(prost: &ExprNode) -> Result<BoxedExpression> {
.map(build_from_prost)
.try_collect()?;

// deprecated exprs not in signature map just for backward compatibility
if func_type == E::ToTimestamp1 && ret_type == DataType::Timestamp {
return build_to_timestamp_expr_legacy(ret_type, children);
}

build_func(func_type, ret_type, children)
}
}
Expand All @@ -111,6 +105,7 @@ pub fn build_func(
inputs_type: &args,
ret_type: (&ret_type).into(),
set_returning: false,
deprecated: false,
}
))
})?;
Expand Down
1 change: 1 addition & 0 deletions src/expr/src/expr/expr_to_timestamp_const_tmpl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ fn build_to_timestamp_expr(
}

/// Support building the variant returning timestamp without time zone for backward compatibility.
#[build_function("to_timestamp1(varchar, varchar) -> timestamp", deprecated)]
pub fn build_to_timestamp_expr_legacy(
return_type: DataType,
children: Vec<BoxedExpression>,
Expand Down
1 change: 1 addition & 0 deletions src/expr/src/sig/agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ impl fmt::Debug for AggFuncSig {
inputs_type: self.inputs_type,
ret_type: self.ret_type,
set_returning: false,
deprecated: false,
}
.fmt(f)
}
Expand Down
24 changes: 15 additions & 9 deletions src/expr/src/sig/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
use std::collections::HashMap;
use std::fmt;
use std::ops::Deref;
use std::sync::LazyLock;

use risingwave_common::types::{DataType, DataTypeName};
Expand Down Expand Up @@ -53,15 +52,20 @@ impl FuncSigMap {
}

/// Returns a function signature with the same type, argument types and return type.
/// Deprecated functions are included.
pub fn get(&self, ty: PbType, args: &[DataTypeName], ret: DataTypeName) -> Option<&FuncSign> {
let v = self.0.get(&(ty, args.len()))?;
v.iter()
.find(|d| d.inputs_type == args && d.ret_type == ret)
}

/// Returns all function signatures with the same type and number of arguments.
pub fn get_with_arg_nums(&self, ty: PbType, nargs: usize) -> &[FuncSign] {
self.0.get(&(ty, nargs)).map_or(&[], Deref::deref)
/// Deprecated functions are excluded.
pub fn get_with_arg_nums(&self, ty: PbType, nargs: usize) -> Vec<&FuncSign> {
match self.0.get(&(ty, nargs)) {
Some(v) => v.iter().filter(|d| !d.deprecated).collect(),
None => vec![],
}
}
}

Expand All @@ -72,6 +76,9 @@ pub struct FuncSign {
pub inputs_type: &'static [DataTypeName],
pub ret_type: DataTypeName,
pub build: fn(return_type: DataType, children: Vec<BoxedExpression>) -> Result<BoxedExpression>,
/// Whether the function is deprecated and should not be used in the frontend.
/// For backward compatibility, it is still available in the backend.
pub deprecated: bool,
}

impl fmt::Debug for FuncSign {
Expand All @@ -81,6 +88,7 @@ impl fmt::Debug for FuncSign {
inputs_type: self.inputs_type,
ret_type: self.ret_type,
set_returning: false,
deprecated: self.deprecated,
}
.fmt(f)
}
Expand Down Expand Up @@ -124,6 +132,10 @@ mod tests {
// validate the FUNC_SIG_MAP is consistent
assert_eq!(func, &sig.func);
assert_eq!(num_args, &sig.inputs_type.len());
// exclude deprecated functions
if sig.deprecated {
continue;
}

new_map
.entry(*func)
Expand Down Expand Up @@ -187,12 +199,6 @@ mod tests {
ArrayAccess: [
"array_access(list, int32) -> boolean/int16/int32/int64/int256/float32/float64/decimal/serial/date/time/timestamp/timestamptz/interval/varchar/bytea/jsonb/list/struct",
],
ArrayLength: [
"array_length(list) -> int64/int32",
],
Cardinality: [
"cardinality(list) -> int64/int32",
],
}
"#]];
expected.assert_debug_eq(&duplicated);
Expand Down
6 changes: 4 additions & 2 deletions src/expr/src/sig/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,18 @@ pub(crate) struct FuncSigDebug<'a, T> {
pub inputs_type: &'a [DataTypeName],
pub ret_type: DataTypeName,
pub set_returning: bool,
pub deprecated: bool,
}

impl<'a, T: std::fmt::Display> std::fmt::Debug for FuncSigDebug<'a, T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let s = format!(
"{}({:?}) -> {}{:?}",
"{}({:?}) -> {}{:?}{}",
self.func,
self.inputs_type.iter().format(", "),
if self.set_returning { "setof " } else { "" },
self.ret_type
self.ret_type,
if self.deprecated { " [deprecated]" } else { "" },
)
.to_ascii_lowercase();

Expand Down
1 change: 1 addition & 0 deletions src/expr/src/sig/table_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ impl fmt::Debug for FuncSign {
inputs_type: self.inputs_type,
ret_type: self.ret_type,
set_returning: true,
deprecated: false,
}
.fmt(f)
}
Expand Down
1 change: 1 addition & 0 deletions src/expr/src/table_function/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ pub fn build(
inputs_type: &args,
ret_type: (&return_type).into(),
set_returning: true,
deprecated: false,
}
))
})?;
Expand Down
4 changes: 2 additions & 2 deletions src/expr/src/vector_op/array_length.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ use crate::ExprError;
/// ----
/// 1
///
/// query error type unknown
/// query error Cannot implicitly cast
/// select array_length(null);
/// ```
#[function("array_length(list) -> int32")]
#[function("array_length(list) -> int64")] // for compatibility with plans from old version
#[function("array_length(list) -> int64", deprecated)]
fn array_length<T: TryFrom<usize>>(array: ListRef<'_>) -> Result<T, ExprError> {
array
.len()
Expand Down
4 changes: 2 additions & 2 deletions src/expr/src/vector_op/cardinality.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ use crate::ExprError;
/// ----
/// 1
///
/// query error type unknown
/// query error Cannot implicitly cast
/// select cardinality(null);
/// ```
#[function("cardinality(list) -> int32")]
#[function("cardinality(list) -> int64")] // for compatibility with plans from old version
#[function("cardinality(list) -> int64", deprecated)]
fn cardinality<T: TryFrom<usize>>(array: ListRef<'_>) -> Result<T, ExprError> {
array
.flatten()
Expand Down
2 changes: 1 addition & 1 deletion src/expr/src/vector_op/position.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ use risingwave_expr_macro::function;
/// ----
/// 4
/// ```
#[function("strpos(varchar, varchar) -> int32")] // backward compatibility with old proto
#[function("strpos(varchar, varchar) -> int32", deprecated)]
#[function("position(varchar, varchar) -> int32")]
pub fn position(str: &str, sub_str: &str) -> i32 {
match str.find(sub_str) {
Expand Down
3 changes: 2 additions & 1 deletion src/expr/src/window_function/state/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ pub fn create_window_state(call: &WindowFuncCall) -> Result<Box<dyn WindowState
func: kind,
inputs_type: &args,
ret_type: call.return_type.clone().into(),
set_returning: false
set_returning: false,
deprecated: false,
}
)));
}
Expand Down
25 changes: 5 additions & 20 deletions src/frontend/src/expr/type_inference/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -596,16 +596,6 @@ fn infer_type_for_special(
}
Ok(Some(DataType::Varchar))
}
ExprType::ArrayLength => {
ensure_arity!("array_length", 1 <= | inputs | <= 2);
inputs[0].ensure_array_type()?;

if let Some(arg1) = inputs.get_mut(1) {
arg1.cast_implicit_mut(DataType::Int32)?;
}

Ok(Some(DataType::Int32))
}
ExprType::StringToArray => {
ensure_arity!("string_to_array", 2 <= | inputs | <= 3);

Expand All @@ -615,12 +605,6 @@ fn infer_type_for_special(

Ok(Some(DataType::List(Box::new(DataType::Varchar))))
}
ExprType::Cardinality => {
ensure_arity!("cardinality", | inputs | == 1);
inputs[0].ensure_array_type()?;

Ok(Some(DataType::Int32))
}
ExprType::TrimArray => {
ensure_arity!("trim_array", | inputs | == 2);
inputs[0].ensure_array_type()?;
Expand Down Expand Up @@ -683,7 +667,7 @@ fn infer_type_name<'a>(
}
}

let mut candidates = top_matches(candidates, inputs);
let mut candidates = top_matches(&candidates, inputs);

if candidates.is_empty() {
return Err(ErrorCode::NotImplemented(
Expand All @@ -706,7 +690,7 @@ fn infer_type_name<'a>(

match &candidates[..] {
[] => unreachable!(),
[sig] => Ok(sig),
[sig] => Ok(*sig),
_ => Err(ErrorCode::BindError(format!(
"function {:?}{:?} is not unique\nHINT: Could not choose a best candidate function. You might need to add explicit type casts.",
func_type,
Expand Down Expand Up @@ -763,7 +747,7 @@ fn implicit_ok(source: DataTypeName, target: DataTypeName, eq_ok: bool) -> bool
/// [rule 4c src]: https://github.com/postgres/postgres/blob/86a4dc1e6f29d1992a2afa3fac1a0b0a6e84568c/src/backend/parser/parse_func.c#L1062-L1104
/// [rule 4d src]: https://github.com/postgres/postgres/blob/86a4dc1e6f29d1992a2afa3fac1a0b0a6e84568c/src/backend/parser/parse_func.c#L1106-L1153
fn top_matches<'a>(
candidates: &'a [FuncSign],
candidates: &[&'a FuncSign],
inputs: &[Option<DataTypeName>],
) -> Vec<&'a FuncSign> {
let mut best_exact = 0;
Expand Down Expand Up @@ -795,7 +779,7 @@ fn top_matches<'a>(
best_candidates.clear();
}
if n_exact == best_exact && n_preferred == best_preferred {
best_candidates.push(sig);
best_candidates.push(*sig);
}
}
best_candidates
Expand Down Expand Up @@ -1226,6 +1210,7 @@ mod tests {
inputs_type: formals,
ret_type: DUMMY_RET,
build: |_, _| unreachable!(),
deprecated: false,
});
}
let result = infer_type_name(&sig_map, DUMMY_FUNC, inputs);
Expand Down
8 changes: 1 addition & 7 deletions src/tests/sqlsmith/src/sql_gen/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ use risingwave_expr::sig::agg::{agg_func_sigs, AggFuncSig as RwAggFuncSig};
use risingwave_expr::sig::cast::{cast_sigs, CastContext, CastSig as RwCastSig};
use risingwave_expr::sig::func::{func_sigs, FuncSign as RwFuncSig};
use risingwave_frontend::expr::ExprType;
use risingwave_pb::expr::expr_node::PbType;
use risingwave_sqlparser::ast::{BinaryOperator, DataType as AstDataType, StructField};

pub(super) fn data_type_to_ast_data_type(data_type: &DataType) -> AstDataType {
Expand Down Expand Up @@ -187,12 +186,7 @@ pub(crate) static FUNC_TABLE: LazyLock<HashMap<DataType, Vec<FuncSig>>> = LazyLo
.iter()
.all(|t| *t != DataTypeName::Timestamptz)
&& !FUNC_BAN_LIST.contains(&func.func)
&& (func.func != PbType::Cardinality
|| !(func.inputs_type[0] == DataTypeName::List
&& func.ret_type == DataTypeName::Int64))
&& (func.func != PbType::ArrayLength
|| !(func.inputs_type[0] == DataTypeName::List
&& func.ret_type == DataTypeName::Int64))
&& !func.deprecated // deprecated functions are not accepted by frontend
})
.filter_map(|func| func.try_into().ok())
.for_each(|func: FuncSig| funcs.entry(func.ret_type.clone()).or_default().push(func));
Expand Down

0 comments on commit 3d18d39

Please sign in to comment.