Skip to content

Commit

Permalink
feat: attach Diagnostic to 'wrong number of arguments' error
Browse files Browse the repository at this point in the history
  • Loading branch information
Ian Lai authored and Ian Lai committed Feb 17, 2025
1 parent 3e6d70e commit 2614079
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 2 deletions.
31 changes: 31 additions & 0 deletions datafusion/common/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ pub mod string_utils;

use crate::error::{_exec_datafusion_err, _internal_datafusion_err, _internal_err};
use crate::{DataFusionError, Result, ScalarValue};
use crate::{Diagnostic, Span};
use arrow::array::{
cast::AsArray, Array, ArrayRef, FixedSizeListArray, LargeListArray, ListArray,
OffsetSizeTrait,
Expand Down Expand Up @@ -944,6 +945,36 @@ pub fn take_function_args<const N: usize, T>(
})
}

pub fn take_function_args_with_span<const N: usize, T>(
function_name: &str,
args: impl IntoIterator<Item = T>,
function_call_site: Option<Span>,
) -> Result<[T; N]> {
let args = args.into_iter().collect::<Vec<_>>();
args.try_into().map_err(|v: Vec<T>| {
let base_error = _exec_datafusion_err!(
"{} function requires {} {}, got {}",
function_name,
N,
if N == 1 { "argument" } else { "arguments" },
v.len()
);

if let Some(span) = function_call_site {
let diagnostic = Diagnostic::new_error(
format!(
"Wrong number of arguments for {} function call",
function_name
),
Some(span),
);
base_error.with_diagnostic(diagnostic)
} else {
base_error
}
})
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
20 changes: 18 additions & 2 deletions datafusion/functions-aggregate/src/sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ use arrow::datatypes::{
DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
};
use arrow::{array::ArrayRef, datatypes::Field};
use datafusion_common::Spans;
use datafusion_common::{
exec_err, not_impl_err, utils::take_function_args, Result, ScalarValue,
exec_err, not_impl_err, utils::take_function_args_with_span, Result, ScalarValue,
};
use datafusion_expr::function::AccumulatorArgs;
use datafusion_expr::function::StateFieldsArgs;
Expand Down Expand Up @@ -97,14 +98,28 @@ macro_rules! downcast_sum {
#[derive(Debug)]
pub struct Sum {
signature: Signature,

/// Original source code location, if known
pub spans: Spans,
}

impl Sum {
pub fn new() -> Self {
Self {
signature: Signature::user_defined(Volatility::Immutable),
spans: Spans::new(),
}
}

pub fn spans(&self) -> &Spans {
&self.spans
}

/// Returns a mutable reference to the set of locations in the SQL query
/// where this column appears, if known.
pub fn spans_mut(&mut self) -> &mut Spans {
&mut self.spans
}
}

impl Default for Sum {
Expand All @@ -127,7 +142,8 @@ impl AggregateUDFImpl for Sum {
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
let [args] = take_function_args(self.name(), arg_types)?;
let [args] =
take_function_args_with_span(self.name(), arg_types, self.spans().first())?;

// Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
// smallint, int, bigint, real, double precision, decimal, or interval.
Expand Down
13 changes: 13 additions & 0 deletions datafusion/sql/tests/cases/diagnostic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,19 @@ fn get_spans(query: &'static str) -> HashMap<String, Span> {
spans
}

#[test]
fn test_wrong_argument_number() -> Result<()> {
let query = "SELECT /*a*/sum(1, 2)/*a*/";
let spans = get_spans(query);
let diag = do_query(query);
assert_eq!(
diag.message,
"Wrong number of arguments for sum function call"
);
assert_eq!(diag.span, Some(spans["a"]));
Ok(())
}

#[test]
fn test_table_not_found() -> Result<()> {
let query = "SELECT * FROM /*a*/personx/*a*/";
Expand Down

0 comments on commit 2614079

Please sign in to comment.