Skip to content

Commit

Permalink
Fix incorrect results in BitAnd GroupsAccumulator (#6957)
Browse files Browse the repository at this point in the history
Fix accumulator
  • Loading branch information
alamb authored Jul 13, 2023
1 parent e0cc8c8 commit c07c26c
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 119 deletions.
184 changes: 117 additions & 67 deletions datafusion/core/tests/sqllogictests/test_files/aggregate.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1420,65 +1420,95 @@ select var(sq.column1), var_pop(sq.column1), stddev(sq.column1), stddev_pop(sq.c
2 1 1.414213562373 1


# sum / count for all nulls
statement ok
create table the_nulls as values (null::bigint, 1), (null::bigint, 1), (null::bigint, 2);

# counts should be zeros (even for nulls)
query II
SELECT count(column1), column2 from the_nulls group by column2 order by column2;
----
0 1
0 2

# sums should be null
query II
SELECT sum(column1), column2 from the_nulls group by column2 order by column2;
# aggregates on empty tables
statement ok
CREATE TABLE empty (column1 bigint, column2 int);

# no group by column
query IIRIIIII
SELECT
count(column1), -- counts should be zero, even for nulls
sum(column1), -- other aggregates should be null
avg(column1),
min(column1),
max(column1),
bit_and(column1),
bit_or(column1),
bit_xor(column1)
FROM empty
----
0 NULL NULL NULL NULL NULL NULL NULL

# Same query but with grouping (no groups, so no output)
query IIRIIIIII
SELECT
count(column1),
sum(column1),
avg(column1),
min(column1),
max(column1),
bit_and(column1),
bit_or(column1),
bit_xor(column1),
column2
FROM empty
GROUP BY column2
ORDER BY column2;
----
NULL 1
NULL 2

# avg should be null
query RI
SELECT avg(column1), column2 from the_nulls group by column2 order by column2;
----
NULL 1
NULL 2

# bit_and should be null
query II
SELECT bit_and(column1), column2 from the_nulls group by column2 order by column2;
----
NULL 1
NULL 2
statement ok
drop table empty

# bit_or should be null
query II
SELECT bit_or(column1), column2 from the_nulls group by column2 order by column2;
----
NULL 1
NULL 2
# aggregates on all nulls
statement ok
CREATE TABLE the_nulls
AS VALUES
(null::bigint, 1),
(null::bigint, 1),
(null::bigint, 2);

# bit_xor should be null
query II
SELECT bit_xor(column1), column2 from the_nulls group by column2 order by column2;
select * from the_nulls
----
NULL 1
NULL 2

# min should be null
query II
SELECT min(column1), column2 from the_nulls group by column2 order by column2;
----
NULL 1
NULL 2

# max should be null
query II
SELECT max(column1), column2 from the_nulls group by column2 order by column2;
----
NULL 1
NULL 2
# no group by column
query IIRIIIII
SELECT
count(column1), -- counts should be zero, even for nulls
sum(column1), -- other aggregates should be null
avg(column1),
min(column1),
max(column1),
bit_and(column1),
bit_or(column1),
bit_xor(column1)
FROM the_nulls
----
0 NULL NULL NULL NULL NULL NULL NULL

# Same query but with grouping
query IIRIIIIII
SELECT
count(column1), -- counts should be zero, even for nulls
sum(column1), -- other aggregates should be null
avg(column1),
min(column1),
max(column1),
bit_and(column1),
bit_or(column1),
bit_xor(column1),
column2
FROM the_nulls
GROUP BY column2
ORDER BY column2;
----
0 NULL NULL NULL NULL NULL NULL NULL 1
0 NULL NULL NULL NULL NULL NULL NULL 2


statement ok
Expand All @@ -1489,29 +1519,49 @@ create table bit_aggregate_functions (
c1 SMALLINT NOT NULL,
c2 SMALLINT NOT NULL,
c3 SMALLINT,
tag varchar
)
as values
(5, 10, 11),
(33, 11, null),
(9, 12, null);

# query_bit_and
query III
SELECT bit_and(c1), bit_and(c2), bit_and(c3) FROM bit_aggregate_functions
----
1 8 11

# query_bit_or
query III
SELECT bit_or(c1), bit_or(c2), bit_or(c3) FROM bit_aggregate_functions
----
45 15 11
(5, 10, 11, 'A'),
(33, 11, null, 'B'),
(9, 12, null, 'A');

# query_bit_and, query_bit_or, query_bit_xor
query IIIIIIIII
SELECT
bit_and(c1),
bit_and(c2),
bit_and(c3),
bit_or(c1),
bit_or(c2),
bit_or(c3),
bit_xor(c1),
bit_xor(c2),
bit_xor(c3)
FROM bit_aggregate_functions
----
1 8 11 45 15 11 45 13 11

# query_bit_and, query_bit_or, query_bit_xor, with group
query IIIIIIIIIT
SELECT
bit_and(c1),
bit_and(c2),
bit_and(c3),
bit_or(c1),
bit_or(c2),
bit_or(c3),
bit_xor(c1),
bit_xor(c2),
bit_xor(c3),
tag
FROM bit_aggregate_functions
GROUP BY tag
ORDER BY tag
----
1 8 11 13 14 11 12 6 11 A
33 11 NULL 33 11 NULL 33 11 NULL B

# query_bit_xor
query III
SELECT bit_xor(c1), bit_xor(c2), bit_xor(c3) FROM bit_aggregate_functions
----
45 13 11

statement ok
create table bool_aggregate_functions (
Expand Down
83 changes: 32 additions & 51 deletions datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,16 @@ use arrow::compute::{bit_and, bit_or, bit_xor};
use datafusion_row::accessor::RowAccessor;

/// Creates a [`PrimitiveGroupsAccumulator`] with the specified
/// [`ArrowPrimitiveType`] which applies `$FN` to each element
/// [`ArrowPrimitiveType`] that initailizes each accumulator to $START
/// and applies `$FN` to each element
///
/// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType
macro_rules! instantiate_primitive_accumulator {
($SELF:expr, $PRIMTYPE:ident, $FN:expr) => {{
Ok(Box::new(PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new(
&$SELF.data_type,
$FN,
)))
macro_rules! instantiate_accumulator {
($SELF:expr, $START:expr, $PRIMTYPE:ident, $FN:expr) => {{
Ok(Box::new(
PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new(&$SELF.data_type, $FN)
.with_starting_value($START),
))
}};
}

Expand Down Expand Up @@ -279,35 +280,31 @@ impl AggregateExpr for BitAnd {
use std::ops::BitAndAssign;
match self.data_type {
DataType::Int8 => {
instantiate_primitive_accumulator!(self, Int8Type, |x, y| x
.bitand_assign(y))
instantiate_accumulator!(self, -1, Int8Type, |x, y| x.bitand_assign(y))
}
DataType::Int16 => {
instantiate_primitive_accumulator!(self, Int16Type, |x, y| x
.bitand_assign(y))
instantiate_accumulator!(self, -1, Int16Type, |x, y| x.bitand_assign(y))
}
DataType::Int32 => {
instantiate_primitive_accumulator!(self, Int32Type, |x, y| x
.bitand_assign(y))
instantiate_accumulator!(self, -1, Int32Type, |x, y| x.bitand_assign(y))
}
DataType::Int64 => {
instantiate_primitive_accumulator!(self, Int64Type, |x, y| x
.bitand_assign(y))
instantiate_accumulator!(self, -1, Int64Type, |x, y| x.bitand_assign(y))
}
DataType::UInt8 => {
instantiate_primitive_accumulator!(self, UInt8Type, |x, y| x
instantiate_accumulator!(self, u8::MAX, UInt8Type, |x, y| x
.bitand_assign(y))
}
DataType::UInt16 => {
instantiate_primitive_accumulator!(self, UInt16Type, |x, y| x
instantiate_accumulator!(self, u16::MAX, UInt16Type, |x, y| x
.bitand_assign(y))
}
DataType::UInt32 => {
instantiate_primitive_accumulator!(self, UInt32Type, |x, y| x
instantiate_accumulator!(self, u32::MAX, UInt32Type, |x, y| x
.bitand_assign(y))
}
DataType::UInt64 => {
instantiate_primitive_accumulator!(self, UInt64Type, |x, y| x
instantiate_accumulator!(self, u64::MAX, UInt64Type, |x, y| x
.bitand_assign(y))
}

Expand Down Expand Up @@ -517,36 +514,28 @@ impl AggregateExpr for BitOr {
use std::ops::BitOrAssign;
match self.data_type {
DataType::Int8 => {
instantiate_primitive_accumulator!(self, Int8Type, |x, y| x
.bitor_assign(y))
instantiate_accumulator!(self, 0, Int8Type, |x, y| x.bitor_assign(y))
}
DataType::Int16 => {
instantiate_primitive_accumulator!(self, Int16Type, |x, y| x
.bitor_assign(y))
instantiate_accumulator!(self, 0, Int16Type, |x, y| x.bitor_assign(y))
}
DataType::Int32 => {
instantiate_primitive_accumulator!(self, Int32Type, |x, y| x
.bitor_assign(y))
instantiate_accumulator!(self, 0, Int32Type, |x, y| x.bitor_assign(y))
}
DataType::Int64 => {
instantiate_primitive_accumulator!(self, Int64Type, |x, y| x
.bitor_assign(y))
instantiate_accumulator!(self, 0, Int64Type, |x, y| x.bitor_assign(y))
}
DataType::UInt8 => {
instantiate_primitive_accumulator!(self, UInt8Type, |x, y| x
.bitor_assign(y))
instantiate_accumulator!(self, 0, UInt8Type, |x, y| x.bitor_assign(y))
}
DataType::UInt16 => {
instantiate_primitive_accumulator!(self, UInt16Type, |x, y| x
.bitor_assign(y))
instantiate_accumulator!(self, 0, UInt16Type, |x, y| x.bitor_assign(y))
}
DataType::UInt32 => {
instantiate_primitive_accumulator!(self, UInt32Type, |x, y| x
.bitor_assign(y))
instantiate_accumulator!(self, 0, UInt32Type, |x, y| x.bitor_assign(y))
}
DataType::UInt64 => {
instantiate_primitive_accumulator!(self, UInt64Type, |x, y| x
.bitor_assign(y))
instantiate_accumulator!(self, 0, UInt64Type, |x, y| x.bitor_assign(y))
}

_ => Err(DataFusionError::NotImplemented(format!(
Expand Down Expand Up @@ -756,36 +745,28 @@ impl AggregateExpr for BitXor {
use std::ops::BitXorAssign;
match self.data_type {
DataType::Int8 => {
instantiate_primitive_accumulator!(self, Int8Type, |x, y| x
.bitxor_assign(y))
instantiate_accumulator!(self, 0, Int8Type, |x, y| x.bitxor_assign(y))
}
DataType::Int16 => {
instantiate_primitive_accumulator!(self, Int16Type, |x, y| x
.bitxor_assign(y))
instantiate_accumulator!(self, 0, Int16Type, |x, y| x.bitxor_assign(y))
}
DataType::Int32 => {
instantiate_primitive_accumulator!(self, Int32Type, |x, y| x
.bitxor_assign(y))
instantiate_accumulator!(self, 0, Int32Type, |x, y| x.bitxor_assign(y))
}
DataType::Int64 => {
instantiate_primitive_accumulator!(self, Int64Type, |x, y| x
.bitxor_assign(y))
instantiate_accumulator!(self, 0, Int64Type, |x, y| x.bitxor_assign(y))
}
DataType::UInt8 => {
instantiate_primitive_accumulator!(self, UInt8Type, |x, y| x
.bitxor_assign(y))
instantiate_accumulator!(self, 0, UInt8Type, |x, y| x.bitxor_assign(y))
}
DataType::UInt16 => {
instantiate_primitive_accumulator!(self, UInt16Type, |x, y| x
.bitxor_assign(y))
instantiate_accumulator!(self, 0, UInt16Type, |x, y| x.bitxor_assign(y))
}
DataType::UInt32 => {
instantiate_primitive_accumulator!(self, UInt32Type, |x, y| x
.bitxor_assign(y))
instantiate_accumulator!(self, 0, UInt32Type, |x, y| x.bitxor_assign(y))
}
DataType::UInt64 => {
instantiate_primitive_accumulator!(self, UInt64Type, |x, y| x
.bitxor_assign(y))
instantiate_accumulator!(self, 0, UInt64Type, |x, y| x.bitxor_assign(y))
}

_ => Err(DataFusionError::NotImplemented(format!(
Expand Down
Loading

0 comments on commit c07c26c

Please sign in to comment.