Skip to content

Commit

Permalink
refactor(binder): unify array function type checking (#10257)
Browse files Browse the repository at this point in the history
  • Loading branch information
xiangjinwu authored Jun 9, 2023
1 parent d5da2d9 commit 981b40e
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 62 deletions.
2 changes: 1 addition & 1 deletion src/expr/src/vector_op/array_length.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ use crate::ExprError;
/// ----
/// 1
///
/// query error unknown type
/// query error type unknown
/// select array_length(null);
/// ```
#[function("array_length(list) -> int64")]
Expand Down
2 changes: 1 addition & 1 deletion src/expr/src/vector_op/cardinality.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ use risingwave_expr_macro::function;
/// ----
/// 1
///
/// query error unknown type
/// query error type unknown
/// select cardinality(null);
/// ```
#[function("cardinality(list) -> int64")]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
Bind error: failed to bind expression: ARRAY[1] || s
Caused by:
Bind error: Cannot append integer[] to varchar
Bind error: Cannot append varchar to integer[]
- name: jsonb || jsonb -> jsonb
sql: |
select '1'::jsonb || '2'::jsonb;
Expand Down
13 changes: 5 additions & 8 deletions src/frontend/src/binder/expr/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -591,14 +591,11 @@ impl Binder {
(
"array_ndims",
guard_by_len(1, raw(|_binder, inputs| {
let input = &inputs[0];
if input.is_untyped() {
return Err(ErrorCode::BindError("could not determine polymorphic type because input has type unknown".into()).into());
}
match input.return_type().array_ndims() {
0 => Err(ErrorCode::BindError("array_ndims expects an array".into()).into()),
n => Ok(ExprImpl::literal_int(n.try_into().map_err(|_| ErrorCode::BindError("array_ndims integer overflow".into()))?))
}
inputs[0].ensure_array_type()?;

let n = inputs[0].return_type().array_ndims()
.try_into().map_err(|_| ErrorCode::BindError("array_ndims integer overflow".into()))?;
Ok(ExprImpl::literal_int(n))
})),
),
(
Expand Down
15 changes: 14 additions & 1 deletion src/frontend/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use fixedbitset::FixedBitSet;
use futures::FutureExt;
use paste::paste;
use risingwave_common::array::ListValue;
use risingwave_common::error::Result as RwResult;
use risingwave_common::error::{ErrorCode, Result as RwResult};
use risingwave_common::types::{DataType, Datum, Scalar};
use risingwave_expr::agg::AggKind;
use risingwave_expr::expr::build_from_prost;
Expand Down Expand Up @@ -244,6 +244,19 @@ impl ExprImpl {
FunctionCall::cast_mut(self, target, CastContext::Implicit)
}

/// Ensure the return type of this expression is an array of some type.
pub fn ensure_array_type(&self) -> Result<(), ErrorCode> {
if self.is_untyped() {
return Err(ErrorCode::BindError(
"could not determine polymorphic type because input has type unknown".into(),
));
}
match self.return_type() {
DataType::List(_) => Ok(()),
t => Err(ErrorCode::BindError(format!("expects array but got {t}"))),
}
}

/// Shorthand to enforce implicit cast to boolean
pub fn enforce_bool_clause(self, clause: &str) -> RwResult<ExprImpl> {
if self.is_untyped() {
Expand Down
75 changes: 25 additions & 50 deletions src/frontend/src/expr/type_inference/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -503,8 +503,8 @@ fn infer_type_for_special(
Ok(casted) => Ok(Some(casted)),
Err(_) => Err(ErrorCode::BindError(format!(
"Cannot append {} to {}",
inputs[0].return_type(),
inputs[1].return_type()
inputs[1].return_type(),
inputs[0].return_type()
))
.into()),
}
Expand Down Expand Up @@ -580,79 +580,54 @@ fn infer_type_for_special(
}
ExprType::ArrayDistinct => {
ensure_arity!("array_distinct", | inputs | == 1);
let ret_type = inputs[0].return_type();
if inputs[0].is_untyped() {
return Err(ErrorCode::BindError(
"could not determine polymorphic type because input has type unknown"
.to_string(),
)
.into());
}
match ret_type {
DataType::List(list_elem_type) => Ok(Some(DataType::List(list_elem_type))),
_ => Ok(None),
}
inputs[0].ensure_array_type()?;

Ok(Some(inputs[0].return_type()))
}
ExprType::ArrayDims => {
ensure_arity!("array_dims", | inputs | == 1);
if inputs[0].is_untyped() {
return Ok(None);
}
match inputs[0].return_type() {
DataType::List(box DataType::List(_)) => Err(ErrorCode::BindError(
inputs[0].ensure_array_type()?;

if let DataType::List(box DataType::List(_)) = inputs[0].return_type() {
return Err(ErrorCode::BindError(
"array_dims for dimensions greater than 1 not supported".into(),
)
.into()),
DataType::List(_) => Ok(Some(DataType::Varchar)),
_ => Ok(None),
.into());
}
Ok(Some(DataType::Varchar))
}
ExprType::ArrayLength => {
ensure_arity!("array_length", 1 <= | inputs | <= 2);
let return_type = inputs[0].return_type();

if inputs[0].is_untyped() {
return Err(ErrorCode::BindError(
"Cannot find length for unknown type".to_string(),
)
.into());
}
inputs[0].ensure_array_type()?;

if let Some(arg1) = inputs.get_mut(1) {
arg1.cast_implicit_mut(DataType::Int32)?;
}

match return_type {
DataType::List(_list_elem_type) => Ok(Some(DataType::Int64)),
_ => Ok(None),
Ok(Some(DataType::Int64))
}
ExprType::StringToArray => {
ensure_arity!("string_to_array", 2 <= | inputs | <= 3);

if !inputs.iter().all(|e| e.return_type() == DataType::Varchar) {
return Ok(None);
}

Ok(Some(DataType::List(Box::new(DataType::Varchar))))
}
ExprType::StringToArray => Ok(Some(DataType::List(Box::new(DataType::Varchar)))),
ExprType::Cardinality => {
ensure_arity!("cardinality", | inputs | == 1);
let return_type = inputs[0].return_type();
inputs[0].ensure_array_type()?;

if inputs[0].is_untyped() {
return Err(ErrorCode::BindError(
"Cannot get cardinality of unknown type".to_string(),
)
.into());
}

match return_type {
DataType::List(_list_elem_type) => Ok(Some(DataType::Int64)),
_ => Ok(None),
}
Ok(Some(DataType::Int64))
}
ExprType::TrimArray => {
ensure_arity!("trim_array", | inputs | == 2);
inputs[0].ensure_array_type()?;

inputs[1].cast_implicit_mut(DataType::Int32)?;

match inputs[0].return_type() {
DataType::List(typ) => Ok(Some(DataType::List(typ))),
_ => Ok(None),
}
Ok(Some(inputs[0].return_type()))
}
ExprType::Vnode => {
ensure_arity!("vnode", 1 <= | inputs |);
Expand Down

0 comments on commit 981b40e

Please sign in to comment.