Skip to content

Commit

Permalink
Migrate code from invoke to invoke_batch. (#13345)
Browse files Browse the repository at this point in the history
* migrate UDF invoke to invoke_batch

* fix

---------

Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
irenjj and alamb authored Nov 14, 2024
1 parent f35ab75 commit 66180fa
Show file tree
Hide file tree
Showing 10 changed files with 322 additions and 229 deletions.
281 changes: 171 additions & 110 deletions datafusion/functions/src/datetime/date_bin.rs

Large diffs are not rendered by default.

28 changes: 17 additions & 11 deletions datafusion/functions/src/datetime/date_trunc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ mod tests {

use arrow::array::cast::as_primitive_array;
use arrow::array::types::TimestampNanosecondType;
use arrow::array::TimestampNanosecondArray;
use arrow::array::{Array, TimestampNanosecondArray};
use arrow::compute::kernels::cast_utils::string_to_timestamp_nanos;
use arrow::datatypes::{DataType, TimeUnit};
use datafusion_common::ScalarValue;
Expand Down Expand Up @@ -724,12 +724,15 @@ mod tests {
.map(|s| Some(string_to_timestamp_nanos(s).unwrap()))
.collect::<TimestampNanosecondArray>()
.with_timezone_opt(tz_opt.clone());
#[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch
let batch_size = input.len();
let result = DateTruncFunc::new()
.invoke(&[
ColumnarValue::Scalar(ScalarValue::from("day")),
ColumnarValue::Array(Arc::new(input)),
])
.invoke_batch(
&[
ColumnarValue::Scalar(ScalarValue::from("day")),
ColumnarValue::Array(Arc::new(input)),
],
batch_size,
)
.unwrap();
if let ColumnarValue::Array(result) = result {
assert_eq!(
Expand Down Expand Up @@ -883,12 +886,15 @@ mod tests {
.map(|s| Some(string_to_timestamp_nanos(s).unwrap()))
.collect::<TimestampNanosecondArray>()
.with_timezone_opt(tz_opt.clone());
#[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch
let batch_size = input.len();
let result = DateTruncFunc::new()
.invoke(&[
ColumnarValue::Scalar(ScalarValue::from("hour")),
ColumnarValue::Array(Arc::new(input)),
])
.invoke_batch(
&[
ColumnarValue::Scalar(ScalarValue::from("hour")),
ColumnarValue::Array(Arc::new(input)),
],
batch_size,
)
.unwrap();
if let ColumnarValue::Array(result) = result {
assert_eq!(
Expand Down
102 changes: 58 additions & 44 deletions datafusion/functions/src/datetime/make_date.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,13 +234,15 @@ mod tests {

#[test]
fn test_make_date() {
#[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch
let res = MakeDateFunc::new()
.invoke(&[
ColumnarValue::Scalar(ScalarValue::Int32(Some(2024))),
ColumnarValue::Scalar(ScalarValue::Int64(Some(1))),
ColumnarValue::Scalar(ScalarValue::UInt32(Some(14))),
])
.invoke_batch(
&[
ColumnarValue::Scalar(ScalarValue::Int32(Some(2024))),
ColumnarValue::Scalar(ScalarValue::Int64(Some(1))),
ColumnarValue::Scalar(ScalarValue::UInt32(Some(14))),
],
1,
)
.expect("that make_date parsed values without error");

if let ColumnarValue::Scalar(ScalarValue::Date32(date)) = res {
Expand All @@ -249,13 +251,15 @@ mod tests {
panic!("Expected a scalar value")
}

#[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch
let res = MakeDateFunc::new()
.invoke(&[
ColumnarValue::Scalar(ScalarValue::Int64(Some(2024))),
ColumnarValue::Scalar(ScalarValue::UInt64(Some(1))),
ColumnarValue::Scalar(ScalarValue::UInt32(Some(14))),
])
.invoke_batch(
&[
ColumnarValue::Scalar(ScalarValue::Int64(Some(2024))),
ColumnarValue::Scalar(ScalarValue::UInt64(Some(1))),
ColumnarValue::Scalar(ScalarValue::UInt32(Some(14))),
],
1,
)
.expect("that make_date parsed values without error");

if let ColumnarValue::Scalar(ScalarValue::Date32(date)) = res {
Expand All @@ -264,13 +268,15 @@ mod tests {
panic!("Expected a scalar value")
}

#[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch
let res = MakeDateFunc::new()
.invoke(&[
ColumnarValue::Scalar(ScalarValue::Utf8(Some("2024".to_string()))),
ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some("1".to_string()))),
ColumnarValue::Scalar(ScalarValue::Utf8(Some("14".to_string()))),
])
.invoke_batch(
&[
ColumnarValue::Scalar(ScalarValue::Utf8(Some("2024".to_string()))),
ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some("1".to_string()))),
ColumnarValue::Scalar(ScalarValue::Utf8(Some("14".to_string()))),
],
1,
)
.expect("that make_date parsed values without error");

if let ColumnarValue::Scalar(ScalarValue::Date32(date)) = res {
Expand All @@ -282,13 +288,16 @@ mod tests {
let years = Arc::new((2021..2025).map(Some).collect::<Int64Array>());
let months = Arc::new((1..5).map(Some).collect::<Int32Array>());
let days = Arc::new((11..15).map(Some).collect::<UInt32Array>());
#[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch
let batch_size = years.len();
let res = MakeDateFunc::new()
.invoke(&[
ColumnarValue::Array(years),
ColumnarValue::Array(months),
ColumnarValue::Array(days),
])
.invoke_batch(
&[
ColumnarValue::Array(years),
ColumnarValue::Array(months),
ColumnarValue::Array(days),
],
batch_size,
)
.expect("that make_date parsed values without error");

if let ColumnarValue::Array(array) = res {
Expand All @@ -308,45 +317,50 @@ mod tests {
//

// invalid number of arguments
#[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch
let res = MakeDateFunc::new()
.invoke(&[ColumnarValue::Scalar(ScalarValue::Int32(Some(1)))]);
.invoke_batch(&[ColumnarValue::Scalar(ScalarValue::Int32(Some(1)))], 1);
assert_eq!(
res.err().unwrap().strip_backtrace(),
"Execution error: make_date function requires 3 arguments, got 1"
);

// invalid type
#[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch
let res = MakeDateFunc::new().invoke(&[
ColumnarValue::Scalar(ScalarValue::IntervalYearMonth(Some(1))),
ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)),
ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)),
]);
let res = MakeDateFunc::new().invoke_batch(
&[
ColumnarValue::Scalar(ScalarValue::IntervalYearMonth(Some(1))),
ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)),
ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)),
],
1,
);
assert_eq!(
res.err().unwrap().strip_backtrace(),
"Arrow error: Cast error: Casting from Interval(YearMonth) to Int32 not supported"
);

// overflow of month
#[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch
let res = MakeDateFunc::new().invoke(&[
ColumnarValue::Scalar(ScalarValue::Int32(Some(2023))),
ColumnarValue::Scalar(ScalarValue::UInt64(Some(u64::MAX))),
ColumnarValue::Scalar(ScalarValue::Int32(Some(22))),
]);
let res = MakeDateFunc::new().invoke_batch(
&[
ColumnarValue::Scalar(ScalarValue::Int32(Some(2023))),
ColumnarValue::Scalar(ScalarValue::UInt64(Some(u64::MAX))),
ColumnarValue::Scalar(ScalarValue::Int32(Some(22))),
],
1,
);
assert_eq!(
res.err().unwrap().strip_backtrace(),
"Arrow error: Cast error: Can't cast value 18446744073709551615 to type Int32"
);

// overflow of day
#[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch
let res = MakeDateFunc::new().invoke(&[
ColumnarValue::Scalar(ScalarValue::Int32(Some(2023))),
ColumnarValue::Scalar(ScalarValue::Int32(Some(22))),
ColumnarValue::Scalar(ScalarValue::UInt32(Some(u32::MAX))),
]);
let res = MakeDateFunc::new().invoke_batch(
&[
ColumnarValue::Scalar(ScalarValue::Int32(Some(2023))),
ColumnarValue::Scalar(ScalarValue::Int32(Some(22))),
ColumnarValue::Scalar(ScalarValue::UInt32(Some(u32::MAX))),
],
1,
);
assert_eq!(
res.err().unwrap().strip_backtrace(),
"Arrow error: Cast error: Can't cast value 4294967295 to type Int32"
Expand Down
60 changes: 36 additions & 24 deletions datafusion/functions/src/datetime/to_char.rs
Original file line number Diff line number Diff line change
Expand Up @@ -384,9 +384,11 @@ mod tests {
];

for (value, format, expected) in scalar_data {
#[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch
let result = ToCharFunc::new()
.invoke(&[ColumnarValue::Scalar(value), ColumnarValue::Scalar(format)])
.invoke_batch(
&[ColumnarValue::Scalar(value), ColumnarValue::Scalar(format)],
1,
)
.expect("that to_char parsed values without error");

if let ColumnarValue::Scalar(ScalarValue::Utf8(date)) = result {
Expand Down Expand Up @@ -459,12 +461,15 @@ mod tests {
];

for (value, format, expected) in scalar_array_data {
#[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch
let batch_size = format.len();
let result = ToCharFunc::new()
.invoke(&[
ColumnarValue::Scalar(value),
ColumnarValue::Array(Arc::new(format) as ArrayRef),
])
.invoke_batch(
&[
ColumnarValue::Scalar(value),
ColumnarValue::Array(Arc::new(format) as ArrayRef),
],
batch_size,
)
.expect("that to_char parsed values without error");

if let ColumnarValue::Scalar(ScalarValue::Utf8(date)) = result {
Expand Down Expand Up @@ -585,12 +590,15 @@ mod tests {
];

for (value, format, expected) in array_scalar_data {
#[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch
let batch_size = value.len();
let result = ToCharFunc::new()
.invoke(&[
ColumnarValue::Array(value as ArrayRef),
ColumnarValue::Scalar(format),
])
.invoke_batch(
&[
ColumnarValue::Array(value as ArrayRef),
ColumnarValue::Scalar(format),
],
batch_size,
)
.expect("that to_char parsed values without error");

if let ColumnarValue::Array(result) = result {
Expand All @@ -602,12 +610,15 @@ mod tests {
}

for (value, format, expected) in array_array_data {
#[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch
let batch_size = value.len();
let result = ToCharFunc::new()
.invoke(&[
ColumnarValue::Array(value),
ColumnarValue::Array(Arc::new(format) as ArrayRef),
])
.invoke_batch(
&[
ColumnarValue::Array(value),
ColumnarValue::Array(Arc::new(format) as ArrayRef),
],
batch_size,
)
.expect("that to_char parsed values without error");

if let ColumnarValue::Array(result) = result {
Expand All @@ -623,20 +634,21 @@ mod tests {
//

// invalid number of arguments
#[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch
let result = ToCharFunc::new()
.invoke(&[ColumnarValue::Scalar(ScalarValue::Int32(Some(1)))]);
.invoke_batch(&[ColumnarValue::Scalar(ScalarValue::Int32(Some(1)))], 1);
assert_eq!(
result.err().unwrap().strip_backtrace(),
"Execution error: to_char function requires 2 arguments, got 1"
);

// invalid type
#[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch
let result = ToCharFunc::new().invoke(&[
ColumnarValue::Scalar(ScalarValue::Int32(Some(1))),
ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)),
]);
let result = ToCharFunc::new().invoke_batch(
&[
ColumnarValue::Scalar(ScalarValue::Int32(Some(1))),
ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)),
],
1,
);
assert_eq!(
result.err().unwrap().strip_backtrace(),
"Execution error: Format for `to_char` must be non-null Utf8, received Timestamp(Nanosecond, None)"
Expand Down
Loading

0 comments on commit 66180fa

Please sign in to comment.