Skip to content

Commit

Permalink
[MINOR] Improve performance of create_hashes (apache#6816)
Browse files Browse the repository at this point in the history
* Only rehash col >=1, specialize primitive hasher

* Fmt

* Clippy

* Typo

* Typo

* Clippy

* Add docs, assertion

* Fmt

---------

Co-authored-by: Daniël Heres <[email protected]>
  • Loading branch information
2 people authored and 2010YOUY01 committed Jul 5, 2023
1 parent 9c05c08 commit 80b1ff6
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 23 deletions.
1 change: 1 addition & 0 deletions datafusion/core/src/physical_plan/aggregates/row_hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,7 @@ impl GroupedHashAggregateStream {
// actually the same key value as the group in
// existing_idx (aka group_values @ row)
let group_state = &group_states[*group_idx];

group_rows.row(row) == group_state.group_by_values.row()
});

Expand Down
96 changes: 73 additions & 23 deletions datafusion/physical-expr/src/hash_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,35 +84,87 @@ macro_rules! hash_float_value {
}
hash_float_value!((half::f16, u16), (f32, u32), (f64, u64));

/// Builds hash values of PrimitiveArray and writes them into `hashes_buffer`
/// If `rehash==true` this combines the previous hash value in the buffer
/// with the new hash using `combine_hashes`
fn hash_array_primitive<T>(
array: &PrimitiveArray<T>,
random_state: &RandomState,
hashes_buffer: &mut [u64],
rehash: bool,
) where
T: ArrowPrimitiveType,
<T as arrow_array::ArrowPrimitiveType>::Native: HashValue,
{
if array.null_count() == 0 {
if rehash {
for (hash, &value) in hashes_buffer.iter_mut().zip(array.values().iter()) {
*hash = combine_hashes(value.hash_one(random_state), *hash);
}
} else {
for (hash, &value) in hashes_buffer.iter_mut().zip(array.values().iter()) {
*hash = value.hash_one(random_state);
}
}
} else if rehash {
for (i, hash) in hashes_buffer.iter_mut().enumerate() {
if !array.is_null(i) {
let value = unsafe { array.value_unchecked(i) };
*hash = combine_hashes(value.hash_one(random_state), *hash);
}
}
} else {
for (i, hash) in hashes_buffer.iter_mut().enumerate() {
if !array.is_null(i) {
let value = unsafe { array.value_unchecked(i) };
*hash = value.hash_one(random_state);
}
}
}
}

/// Hashes one array into the `hashes_buffer`
/// If `rehash==true` this combines the previous hash value in the buffer
/// with the new hash using `combine_hashes`
fn hash_array<T>(
array: T,
random_state: &RandomState,
hashes_buffer: &mut [u64],
multi_col: bool,
rehash: bool,
) where
T: ArrayAccessor,
T::Item: HashValue,
{
assert_eq!(
hashes_buffer.len(),
array.len(),
"hashes_buffer and array should be of equal length"
);

if array.null_count() == 0 {
if multi_col {
if rehash {
for (i, hash) in hashes_buffer.iter_mut().enumerate() {
*hash = combine_hashes(array.value(i).hash_one(random_state), *hash);
let value = unsafe { array.value_unchecked(i) };
*hash = combine_hashes(value.hash_one(random_state), *hash);
}
} else {
for (i, hash) in hashes_buffer.iter_mut().enumerate() {
*hash = array.value(i).hash_one(random_state);
let value = unsafe { array.value_unchecked(i) };
*hash = value.hash_one(random_state);
}
}
} else if multi_col {
} else if rehash {
for (i, hash) in hashes_buffer.iter_mut().enumerate() {
if !array.is_null(i) {
*hash = combine_hashes(array.value(i).hash_one(random_state), *hash);
let value = unsafe { array.value_unchecked(i) };
*hash = combine_hashes(value.hash_one(random_state), *hash);
}
}
} else {
for (i, hash) in hashes_buffer.iter_mut().enumerate() {
if !array.is_null(i) {
*hash = array.value(i).hash_one(random_state);
let value = unsafe { array.value_unchecked(i) };
*hash = value.hash_one(random_state);
}
}
}
Expand Down Expand Up @@ -208,34 +260,32 @@ pub fn create_hashes<'a>(
random_state: &RandomState,
hashes_buffer: &'a mut Vec<u64>,
) -> Result<&'a mut Vec<u64>> {
// combine hashes with `combine_hashes` if we have more than 1 column

let multi_col = arrays.len() > 1;

for col in arrays {
for (i, col) in arrays.iter().enumerate() {
let array = col.as_ref();
// combine hashes with `combine_hashes` for all columns besides the first
let rehash = i >= 1;
downcast_primitive_array! {
array => hash_array(array, random_state, hashes_buffer, multi_col),
DataType::Null => hash_null(random_state, hashes_buffer, multi_col),
DataType::Boolean => hash_array(as_boolean_array(array)?, random_state, hashes_buffer, multi_col),
DataType::Utf8 => hash_array(as_string_array(array)?, random_state, hashes_buffer, multi_col),
DataType::LargeUtf8 => hash_array(as_largestring_array(array), random_state, hashes_buffer, multi_col),
DataType::Binary => hash_array(as_generic_binary_array::<i32>(array)?, random_state, hashes_buffer, multi_col),
DataType::LargeBinary => hash_array(as_generic_binary_array::<i64>(array)?, random_state, hashes_buffer, multi_col),
array => hash_array_primitive(array, random_state, hashes_buffer, rehash),
DataType::Null => hash_null(random_state, hashes_buffer, rehash),
DataType::Boolean => hash_array(as_boolean_array(array)?, random_state, hashes_buffer, rehash),
DataType::Utf8 => hash_array(as_string_array(array)?, random_state, hashes_buffer, rehash),
DataType::LargeUtf8 => hash_array(as_largestring_array(array), random_state, hashes_buffer, rehash),
DataType::Binary => hash_array(as_generic_binary_array::<i32>(array)?, random_state, hashes_buffer, rehash),
DataType::LargeBinary => hash_array(as_generic_binary_array::<i64>(array)?, random_state, hashes_buffer, rehash),
DataType::FixedSizeBinary(_) => {
let array: &FixedSizeBinaryArray = array.as_any().downcast_ref().unwrap();
hash_array(array, random_state, hashes_buffer, multi_col)
hash_array(array, random_state, hashes_buffer, rehash)
}
DataType::Decimal128(_, _) => {
let array = as_primitive_array::<Decimal128Type>(array)?;
hash_array(array, random_state, hashes_buffer, multi_col)
hash_array_primitive(array, random_state, hashes_buffer, rehash)
}
DataType::Decimal256(_, _) => {
let array = as_primitive_array::<Decimal256Type>(array)?;
hash_array(array, random_state, hashes_buffer, multi_col)
hash_array_primitive(array, random_state, hashes_buffer, rehash)
}
DataType::Dictionary(_, _) => downcast_dictionary_array! {
array => hash_dictionary(array, random_state, hashes_buffer, multi_col)?,
array => hash_dictionary(array, random_state, hashes_buffer, rehash)?,
_ => unreachable!()
}
_ => {
Expand Down

0 comments on commit 80b1ff6

Please sign in to comment.