From 774b7743f056fc9e27b6b0a781e611ba62a30251 Mon Sep 17 00:00:00 2001 From: Ian Lai Date: Tue, 18 Feb 2025 10:15:09 +0000 Subject: [PATCH] feat: attach Diagnostic to 'wrong number of arguments' error --- datafusion/common/src/utils/mod.rs | 27 +++++++++ datafusion/functions-aggregate/src/sum.rs | 20 +++++- datafusion/sql/tests/cases/diagnostic.rs | 74 +++++++++++++++++++++-- 3 files changed, 114 insertions(+), 7 deletions(-) diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index cb77cc8e79b1..7912a7768873 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -24,6 +24,7 @@ pub mod string_utils; use crate::error::{_exec_datafusion_err, _internal_datafusion_err, _internal_err}; use crate::{DataFusionError, Result, ScalarValue}; +use crate::{Diagnostic, Span}; use arrow::array::{ cast::AsArray, Array, ArrayRef, FixedSizeListArray, LargeListArray, ListArray, OffsetSizeTrait, @@ -944,6 +945,32 @@ pub fn take_function_args( }) } +pub fn take_function_args_with_span( + function_name: &str, + args: impl IntoIterator, + function_call_site: Option, +) -> Result<[T; N]> { + let args = args.into_iter().collect::>(); + args.try_into().map_err(|v: Vec| { + let base_error = _exec_datafusion_err!( + "{} function requires {} {}, got {}", + function_name, + N, + if N == 1 { "argument" } else { "arguments" }, + v.len() + ); + + let diagnostic = Diagnostic::new_error( + format!( + "Wrong number of arguments for {} function call", + function_name + ), + function_call_site, + ); + base_error.with_diagnostic(diagnostic) + }) +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/functions-aggregate/src/sum.rs b/datafusion/functions-aggregate/src/sum.rs index 76a1315c2d88..15b2812c90d5 100644 --- a/datafusion/functions-aggregate/src/sum.rs +++ b/datafusion/functions-aggregate/src/sum.rs @@ -33,8 +33,9 @@ use arrow::datatypes::{ DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, }; use arrow::{array::ArrayRef, datatypes::Field}; +use datafusion_common::Spans; use datafusion_common::{ - exec_err, not_impl_err, utils::take_function_args, Result, ScalarValue, + exec_err, not_impl_err, utils::take_function_args_with_span, Result, ScalarValue, }; use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::function::StateFieldsArgs; @@ -97,14 +98,28 @@ macro_rules! downcast_sum { #[derive(Debug)] pub struct Sum { signature: Signature, + + /// Original source code location, if known + pub spans: Spans, } impl Sum { pub fn new() -> Self { Self { signature: Signature::user_defined(Volatility::Immutable), + spans: Spans::new(), } } + + pub fn spans(&self) -> &Spans { + &self.spans + } + + /// Returns a mutable reference to the set of locations in the SQL query + /// where this column appears, if known. + pub fn spans_mut(&mut self) -> &mut Spans { + &mut self.spans + } } impl Default for Sum { @@ -127,7 +142,8 @@ impl AggregateUDFImpl for Sum { } fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - let [args] = take_function_args(self.name(), arg_types)?; + let [args] = + take_function_args_with_span(self.name(), arg_types, self.spans().first())?; // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc // smallint, int, bigint, real, double precision, decimal, or interval. diff --git a/datafusion/sql/tests/cases/diagnostic.rs b/datafusion/sql/tests/cases/diagnostic.rs index 9dae2d0c3e93..eecea5ab8498 100644 --- a/datafusion/sql/tests/cases/diagnostic.rs +++ b/datafusion/sql/tests/cases/diagnostic.rs @@ -14,13 +14,13 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. - -use std::collections::HashMap; - use datafusion_common::{Diagnostic, Location, Result, Span}; +use datafusion_expr::test::function_stub::sum_udaf; use datafusion_sql::planner::{ParserOptions, SqlToRel}; use regex::Regex; +use sqlparser::ast::{Expr as SQLExpr, SelectItem, SetExpr, Statement}; use sqlparser::{dialect::GenericDialect, parser::Parser}; +use std::collections::HashMap; use crate::{MockContextProvider, MockSessionState}; @@ -36,14 +36,22 @@ fn do_query(sql: &'static str) -> Diagnostic { ..ParserOptions::default() }; - let state = MockSessionState::default(); + let state = MockSessionState::default().with_aggregate_function(sum_udaf()); let context = MockContextProvider { state }; let sql_to_rel = SqlToRel::new_with_options(&context, options); match sql_to_rel.sql_statement_to_plan(statement) { Ok(_) => panic!("expected error"), + // Err(err) => match err.diagnostic() { + // Some(diag) => diag.clone(), + // None => panic!("expected diagnostic"), + // }, Err(err) => match err.diagnostic() { Some(diag) => diag.clone(), - None => panic!("expected diagnostic"), + None => { + // 使用 dbg! 來查看錯誤的內容 + dbg!(&err); + panic!("expected diagnostic") + } }, } } @@ -130,6 +138,62 @@ fn get_spans(query: &'static str) -> HashMap { spans } +#[test] +fn trace_function_call_error() { + let dialect = GenericDialect {}; + let sql = "SELECT /*a*/sum(1, 2)/*a*/"; + let statement = Parser::new(&dialect) + .try_with_sql(sql) + .expect("unable to create parser") + .parse_statement() + .expect("unable to parse query"); + + let options = ParserOptions { + collect_spans: true, + ..ParserOptions::default() + }; + + let state = MockSessionState::default().with_aggregate_function(sum_udaf()); + let context = MockContextProvider { state }; + let sql_to_rel = SqlToRel::new_with_options(&context, options); + + // 追蹤函數處理過程 + match statement { + Statement::Query(ref query) => { + if let SetExpr::Select(select) = query.body.as_ref() { + println!("Processing SELECT statement"); + for expr in &select.projection { + if let SelectItem::UnnamedExpr(expr) = expr { + println!("Processing expression: {:?}", expr); + // 這裡會進到函數處理 + if let SQLExpr::Function(func) = expr { + println!("Function name: {:?}", func.name); + println!("Function args: {:?}", func.args); + } + } + } + } + } + _ => panic!("Expected Query"), + } + + let result = sql_to_rel.sql_statement_to_plan(statement); + println!("Final result: {:?}", result); +} + +#[test] +fn test_wrong_argument_number() -> Result<()> { + let query = "SELECT /*a*/sum(1, 2)/*a*/"; + let spans = get_spans(query); + let diag = do_query(query); + assert_eq!( + diag.message, + "Wrong number of arguments for sum function call" + ); + assert_eq!(diag.span, Some(spans["a"])); + Ok(()) +} + #[test] fn test_table_not_found() -> Result<()> { let query = "SELECT * FROM /*a*/personx/*a*/";