Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix panic in median "AggregateState is not a scalar aggregate" #4488

Merged
merged 3 commits into from
Dec 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions datafusion/common/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -721,7 +721,7 @@ impl std::hash::Hash for ScalarValue {
/// dictionary array
#[inline]
fn get_dict_value<K: ArrowDictionaryKeyType>(
array: &ArrayRef,
array: &dyn Array,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❤️

index: usize,
) -> (&ArrayRef, Option<usize>) {
let dict_array = as_dictionary_array::<K>(array).unwrap();
Expand Down Expand Up @@ -1963,7 +1963,7 @@ impl ScalarValue {
}

fn get_decimal_value_from_array(
array: &ArrayRef,
array: &dyn Array,
index: usize,
precision: u8,
scale: i8,
Expand All @@ -1978,7 +1978,7 @@ impl ScalarValue {
}

/// Converts a value in `array` at `index` into a ScalarValue
pub fn try_from_array(array: &ArrayRef, index: usize) -> Result<Self> {
pub fn try_from_array(array: &dyn Array, index: usize) -> Result<Self> {
// handle NULL value
if !array.is_valid(index) {
return array.data_type().try_into();
Expand Down
70 changes: 67 additions & 3 deletions datafusion/core/tests/sqllogictests/test_files/aggregate.slt
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ SELECT stddev_pop(c2) FROM aggregate_test_100
1.3665650368716449

# csv_query_stddev_2
query R
query R
SELECT stddev_pop(c6) FROM aggregate_test_100
----
5.114326382039172e18
Expand Down Expand Up @@ -216,6 +216,70 @@ SELECT approx_median(a) FROM median_f64_nan
----
NaN

# median_multi
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I ported the tests to sqllogictest, since most of the other aggregate tests had already been ported as well

# test case for https://github.com/apache/arrow-datafusion/issues/3105
# has an intermediate grouping
statement ok
create table cpu (host string, usage float) as select * from (values
('host0', 90.1),
('host1', 90.2),
('host1', 90.4)
);

query CI rowsort
select host, median(usage) from cpu group by host;
----
host1 90.3
host0 90.1

query CI
select median(usage) from cpu;
----
90.2


statement ok
drop table cpu;

# median_multi_odd

# data is not sorted and has an odd number of values per group
statement ok
create table cpu (host string, usage float) as select * from (values
('host0', 90.2),
('host1', 90.1),
('host1', 90.5),
('host0', 90.5),
('host1', 90.0),
('host1', 90.3),
('host0', 87.9),
('host1', 89.3)
);

query CI rowsort
select host, median(usage) from cpu group by host;
----
host0 90.2
host1 90.1


statement ok
drop table cpu;

# median_multi_even
# data is not sorted and has an even number of values per group
statement ok
create table cpu (host string, usage float) as select * from (values ('host0', 90.2), ('host1', 90.1), ('host1', 90.5), ('host0', 90.5), ('host1', 90.0), ('host1', 90.3), ('host1', 90.2), ('host1', 90.3));

query CI rowsort
select host, median(usage) from cpu group by host;
----
host1 90.25
host0 90.35

statement ok
drop table cpu

# csv_query_external_table_count
query I
SELECT COUNT(c12) FROM aggregate_test_100
Expand Down Expand Up @@ -818,7 +882,7 @@ select c2, sum(c3) sum_c3, avg(c3) avg_c3, max(c3) max_c3, min(c3) min_c3, count
# SELECT array_agg(c13 ORDER BY c1) FROM aggregate_test_100;

# csv_query_array_cube_agg_with_overflow
query TIIRIII
query TIIRIII
select c1, c2, sum(c3) sum_c3, avg(c3) avg_c3, max(c3) max_c3, min(c3) min_c3, count(c3) count_c3 from aggregate_test_100 group by CUBE (c1,c2) order by c1, c2
----
a 1 -88 -17.6 83 -85 5
Expand Down Expand Up @@ -870,7 +934,7 @@ e 847 40.333333333333336 120 -95 21
# query IIII
# SELECT count(nanos), count(micros), count(millis), count(secs) FROM t
# ----
# 3 3 3 3
# 3 3 3 3

# aggregate_timestamps_min
# query TTTT
Expand Down
Loading