diff --git a/datafusion/__init__.py b/datafusion/__init__.py index c854f3f9d..df53b396a 100644 --- a/datafusion/__init__.py +++ b/datafusion/__init__.py @@ -213,6 +213,8 @@ def udaf(accum, input_type, return_type, state_type, volatility, name=None): ) if name is None: name = accum.__qualname__.lower() + if isinstance(input_type, pa.lib.DataType): + input_type = [input_type] return AggregateUDF( name=name, accumulator=accum, diff --git a/src/udaf.rs b/src/udaf.rs index 596ed6904..018cd0b6c 100644 --- a/src/udaf.rs +++ b/src/udaf.rs @@ -148,14 +148,14 @@ impl PyAggregateUDF { fn new( name: &str, accumulator: PyObject, - input_type: PyArrowType, + input_type: PyArrowType>, return_type: PyArrowType, state_type: PyArrowType>, volatility: &str, ) -> PyResult { let function = create_udaf( name, - vec![input_type.0], + input_type.0, Arc::new(return_type.0), parse_volatility(volatility)?, to_rust_accumulator(accumulator),