Skip to content

Commit

Permalink
feat: attach Diagnostic to 'wrong number of arguments' error
Browse files Browse the repository at this point in the history
  • Loading branch information
Ian Lai authored and Ian Lai committed Feb 18, 2025
1 parent 3e6d70e commit 774b774
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 7 deletions.
27 changes: 27 additions & 0 deletions datafusion/common/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ pub mod string_utils;

use crate::error::{_exec_datafusion_err, _internal_datafusion_err, _internal_err};
use crate::{DataFusionError, Result, ScalarValue};
use crate::{Diagnostic, Span};
use arrow::array::{
cast::AsArray, Array, ArrayRef, FixedSizeListArray, LargeListArray, ListArray,
OffsetSizeTrait,
Expand Down Expand Up @@ -944,6 +945,32 @@ pub fn take_function_args<const N: usize, T>(
})
}

/// Converts `args` into a fixed-size array of exactly `N` elements, attaching
/// a [`Diagnostic`] to the error when the argument count does not match.
///
/// # Arguments
///
/// * `function_name` - name of the function, used in the error and diagnostic text
/// * `args` - the arguments supplied to the function
/// * `function_call_site` - span passed through to the diagnostic, if known
///
/// # Errors
///
/// When `args` does not yield exactly `N` items, returns the error built by
/// `_exec_datafusion_err!` with a "Wrong number of arguments" diagnostic
/// attached to it.
pub fn take_function_args_with_span<const N: usize, T>(
    function_name: &str,
    args: impl IntoIterator<Item = T>,
    function_call_site: Option<Span>,
) -> Result<[T; N]> {
    let collected: Vec<T> = args.into_iter().collect();

    // The `Vec` -> array conversion hands the original vector back on a
    // length mismatch, which lets us report how many arguments were supplied.
    match <[T; N]>::try_from(collected) {
        Ok(exact) => Ok(exact),
        Err(rejected) => {
            let noun = if N == 1 { "argument" } else { "arguments" };
            let count_error = _exec_datafusion_err!(
                "{} function requires {} {}, got {}",
                function_name,
                N,
                noun,
                rejected.len()
            );

            let message = format!(
                "Wrong number of arguments for {} function call",
                function_name
            );
            let call_diagnostic = Diagnostic::new_error(message, function_call_site);

            Err(count_error.with_diagnostic(call_diagnostic))
        }
    }
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
20 changes: 18 additions & 2 deletions datafusion/functions-aggregate/src/sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ use arrow::datatypes::{
DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
};
use arrow::{array::ArrayRef, datatypes::Field};
use datafusion_common::Spans;
use datafusion_common::{
exec_err, not_impl_err, utils::take_function_args, Result, ScalarValue,
exec_err, not_impl_err, utils::take_function_args_with_span, Result, ScalarValue,
};
use datafusion_expr::function::AccumulatorArgs;
use datafusion_expr::function::StateFieldsArgs;
Expand Down Expand Up @@ -97,14 +98,28 @@ macro_rules! downcast_sum {
/// State for the `Sum` aggregate function: its signature plus any source
/// locations recorded for it.
#[derive(Debug)]
pub struct Sum {
/// Signature of the function; `Sum::new` builds it as user-defined and immutable.
signature: Signature,

/// Original source code location, if known
pub spans: Spans,
}

impl Sum {
/// Creates a `Sum` with a user-defined, immutable signature and a freshly
/// constructed `Spans` (presumably empty until a caller records locations).
pub fn new() -> Self {
Self {
signature: Signature::user_defined(Volatility::Immutable),
spans: Spans::new(),
}
}

/// Returns a shared reference to the source locations recorded for this
/// function, if any.
pub fn spans(&self) -> &Spans {
&self.spans
}

/// Returns a mutable reference to the set of locations in the SQL query
/// recorded for this function, if known.
pub fn spans_mut(&mut self) -> &mut Spans {
&mut self.spans
}
}

impl Default for Sum {
Expand All @@ -127,7 +142,8 @@ impl AggregateUDFImpl for Sum {
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
let [args] = take_function_args(self.name(), arg_types)?;
let [args] =
take_function_args_with_span(self.name(), arg_types, self.spans().first())?;

// Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
// smallint, int, bigint, real, double precision, decimal, or interval.
Expand Down
74 changes: 69 additions & 5 deletions datafusion/sql/tests/cases/diagnostic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::collections::HashMap;

use datafusion_common::{Diagnostic, Location, Result, Span};
use datafusion_expr::test::function_stub::sum_udaf;
use datafusion_sql::planner::{ParserOptions, SqlToRel};
use regex::Regex;
use sqlparser::ast::{Expr as SQLExpr, SelectItem, SetExpr, Statement};
use sqlparser::{dialect::GenericDialect, parser::Parser};
use std::collections::HashMap;

use crate::{MockContextProvider, MockSessionState};

Expand All @@ -36,14 +36,22 @@ fn do_query(sql: &'static str) -> Diagnostic {
..ParserOptions::default()
};

let state = MockSessionState::default();
let state = MockSessionState::default().with_aggregate_function(sum_udaf());
let context = MockContextProvider { state };
let sql_to_rel = SqlToRel::new_with_options(&context, options);
match sql_to_rel.sql_statement_to_plan(statement) {
Ok(_) => panic!("expected error"),
// Err(err) => match err.diagnostic() {
// Some(diag) => diag.clone(),
// None => panic!("expected diagnostic"),
// },
Err(err) => match err.diagnostic() {
Some(diag) => diag.clone(),
None => panic!("expected diagnostic"),
None => {
// Use dbg! to inspect the contents of the error
dbg!(&err);
panic!("expected diagnostic")
}
},
}
}
Expand Down Expand Up @@ -130,6 +138,62 @@ fn get_spans(query: &'static str) -> HashMap<String, Span> {
spans
}

/// Debugging aid: parses a `sum` call with too many arguments, prints the
/// parsed function expression, then prints whatever the planner returns.
/// It makes no assertions about the planner's result.
#[test]
fn trace_function_call_error() {
    let sql = "SELECT /*a*/sum(1, 2)/*a*/";
    let dialect = GenericDialect {};
    let statement = Parser::new(&dialect)
        .try_with_sql(sql)
        .expect("unable to create parser")
        .parse_statement()
        .expect("unable to parse query");

    let options = ParserOptions {
        collect_spans: true,
        ..ParserOptions::default()
    };
    let context = MockContextProvider {
        state: MockSessionState::default().with_aggregate_function(sum_udaf()),
    };
    let sql_to_rel = SqlToRel::new_with_options(&context, options);

    // Trace how the parsed statement is structured before planning it.
    let query = match &statement {
        Statement::Query(query) => query,
        _ => panic!("Expected Query"),
    };
    if let SetExpr::Select(select) = query.body.as_ref() {
        println!("Processing SELECT statement");

        // Only unnamed projection items are inspected.
        let unnamed_exprs = select.projection.iter().filter_map(|item| match item {
            SelectItem::UnnamedExpr(expr) => Some(expr),
            _ => None,
        });
        for expr in unnamed_exprs {
            println!("Processing expression: {:?}", expr);
            // This is where function handling is entered.
            if let SQLExpr::Function(func) = expr {
                println!("Function name: {:?}", func.name);
                println!("Function args: {:?}", func.args);
            }
        }
    }

    let result = sql_to_rel.sql_statement_to_plan(statement);
    println!("Final result: {:?}", result);
}

/// Calling `sum` with two arguments must produce a diagnostic whose message
/// names the function and whose span covers the region between the `/*a*/`
/// markers in the query.
#[test]
fn test_wrong_argument_number() -> Result<()> {
    let query = "SELECT /*a*/sum(1, 2)/*a*/";
    let expected_message = "Wrong number of arguments for sum function call";

    let marked_spans = get_spans(query);
    let diagnostic = do_query(query);

    assert_eq!(diagnostic.message, expected_message);
    assert_eq!(diagnostic.span, Some(marked_spans["a"]));
    Ok(())
}

#[test]
fn test_table_not_found() -> Result<()> {
let query = "SELECT * FROM /*a*/personx/*a*/";
Expand Down

0 comments on commit 774b774

Please sign in to comment.