diff --git a/datafusion/functions-nested/src/extract.rs b/datafusion/functions-nested/src/extract.rs index c87a96dca7a4..e1d536382487 100644 --- a/datafusion/functions-nested/src/extract.rs +++ b/datafusion/functions-nested/src/extract.rs @@ -487,7 +487,17 @@ where // 0 ~ len - 1 let adjusted_zero_index = if index < 0 { if let Ok(index) = index.try_into() { - index + len + // When index < 0 and -index > length, index is clamped to the beginning of the list. + // Otherwise, when index < 0, the index is counted from the end of the list. + // + // Note, we actually test the contrapositive, index < -length, because negating a + // negative will panic if the negative is equal to the smallest representable value + // while negating a positive is always safe. + if index < (O::zero() - O::one()) * len { + O::zero() + } else { + index + len + } } else { return exec_err!("array_slice got invalid index: {}", index); } @@ -575,7 +585,7 @@ where "array_slice got invalid stride: {:?}, it cannot be 0", stride ); - } else if (from <= to && stride.is_negative()) + } else if (from < to && stride.is_negative()) || (from > to && stride.is_positive()) { // return empty array @@ -587,7 +597,7 @@ where internal_datafusion_err!("array_slice got invalid stride: {}", stride) })?; - if from <= to { + if from <= to && stride > O::zero() { assert!(start + to <= end); if stride.eq(&O::one()) { // stride is default to 1 diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index ff701b55407c..357c58f441ed 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -1941,12 +1941,12 @@ select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -4 query ?? select array_slice(make_array(1, 2, 3, 4, 5), -7, -2), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -7, -3); ---- -[] [] +[1, 2, 3, 4] [h, e, l] query ?? select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -7, -2), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -7, -3); ---- -[] [] +[1, 2, 3, 4] [h, e, l] # array_slice scalar function #20 (with negative indexes; nested array) query ?? @@ -1993,6 +1993,28 @@ select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2, ---- [2, 3, 4] [h, e] +# array_slice scalar function #24 (with first negative index larger than len) +query ?? +select array_slice(make_array(1, 2, 3, 4, 5), -2147483648, 1), list_slice(make_array('h', 'e', 'l', 'l', 'o'), -2147483648, 1); +---- +[1] [h] + +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -9223372036854775808, 1), list_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -9223372036854775808, 1); +---- +[1] [h] + +# array_slice scalar function #25 (with negative step and equal indexes) +query ?? +select array_slice(make_array(1, 2, 3, 4, 5), 2, 2, -1), list_slice(make_array('h', 'e', 'l', 'l', 'o'), 2, 2, -1); +---- +[2] [e] + +query ?? +select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2, 2, -1), list_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 2, 2, -1); +---- +[2] [e] + # array_slice with columns query ? select array_slice(column1, column2, column3) from slices;