Skip to content

Commit

Permalink
Specialized string filter
Browse files Browse the repository at this point in the history
  • Loading branch information
tustvold committed Feb 2, 2022
1 parent 19cc6ee commit c5912ef
Showing 1 changed file with 134 additions and 0 deletions.
134 changes: 134 additions & 0 deletions arrow/src/compute/kernels/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ use crate::error::{ArrowError, Result};
use crate::record_batch::RecordBatch;
use crate::util::bit_chunk_iterator::{UnalignedBitChunk, UnalignedBitChunkIterator};
use crate::util::bit_util;
use num::Zero;
use std::ops::AddAssign;
use std::sync::Arc;
use TimeUnit::*;

Expand Down Expand Up @@ -471,6 +473,20 @@ fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result<Array
DataType::Duration(TimeUnit::Nanosecond) => {
downcast_filter!(DurationNanosecondType, values, predicate)
}
DataType::Utf8 => {
let values = values
.as_any()
.downcast_ref::<GenericStringArray<i32>>()
.unwrap();
Ok(Arc::new(filter_string::<i32>(values, predicate)))
}
DataType::LargeUtf8 => {
let values = values
.as_any()
.downcast_ref::<GenericStringArray<i64>>()
.unwrap();
Ok(Arc::new(filter_string::<i64>(values, predicate)))
}
_ => {
// fallback to using MutableArrayData
let mut mutable = MutableArrayData::new(
Expand Down Expand Up @@ -616,6 +632,7 @@ where
assert_eq!(data.child_data().len(), 0);

let values = data.buffer::<T::Native>(0);
assert!(values.len() >= predicate.filter.len());

let mut buffer = MutableBuffer::with_capacity(predicate.count * T::get_byte_width());

Expand Down Expand Up @@ -654,6 +671,123 @@ where
PrimitiveArray::from(data)
}

/// [`FilterString`] is created from a source [`GenericStringArray`] and can be
/// used to build a new [`GenericStringArray`] by copying values from the source
///
/// TODO(raphael): Could this be used for the take kernel as well?
struct FilterString<'a, OffsetSize> {
src_offsets: &'a [OffsetSize],
src_values: &'a [u8],
dst_offsets: MutableBuffer,
dst_values: MutableBuffer,
cur_offset: OffsetSize,
}

impl<'a, OffsetSize> FilterString<'a, OffsetSize>
where
OffsetSize: Zero + AddAssign + StringOffsetSizeTrait,
{
fn new(capacity: usize, array: &'a GenericStringArray<OffsetSize>) -> Self {
let bytes_offset = (capacity + 1) * std::mem::size_of::<OffsetSize>();
let mut offsets = MutableBuffer::new(bytes_offset);
let values = MutableBuffer::new(0);
let cur_offset = OffsetSize::zero();
offsets.push(cur_offset);

Self {
src_offsets: array.value_offsets(),
src_values: &array.data().buffers()[1],
dst_offsets: offsets,
dst_values: values,
cur_offset,
}
}

/// Returns the byte offset at `idx`
#[inline]
fn get_value_offset(&self, idx: usize) -> usize {
self.src_offsets[idx].to_usize().expect("illegal offset")
}

/// Returns the start and end of the value at index `idx` along with its length
#[inline]
fn get_value_range(&self, idx: usize) -> (usize, usize, OffsetSize) {
// These can only fail if `array` contains invalid data
let start = self.get_value_offset(idx);
let end = self.get_value_offset(idx + 1);
let len = OffsetSize::from_usize(end - start).expect("illegal offset range");
(start, end, len)
}

/// Extends the in-progress array by the indexes in the provided iterator
fn extend_idx(&mut self, iter: impl Iterator<Item = usize>) {
for idx in iter {
let (start, end, len) = self.get_value_range(idx);
self.cur_offset += len;
self.dst_offsets.push(self.cur_offset); // push_unchecked?
self.dst_values
.extend_from_slice(&self.src_values[start..end]);
}
}

/// Extends the in-progress array by the ranges in the provided iterator
fn extend_slices(&mut self, iter: impl Iterator<Item = (usize, usize)>) {
for slice in iter {
// These can only fail if `array` contains invalid data
for idx in slice.0..slice.1 {
let (_, _, len) = self.get_value_range(idx);
self.cur_offset += len;
self.dst_offsets.push(self.cur_offset); // push_unchecked?
}

let start = self.get_value_offset(slice.0);
let end = self.get_value_offset(slice.1);
self.dst_values
.extend_from_slice(&self.src_values[start..end]);
}
}
}

/// `filter` implementation for string arrays
///
/// Note: NULLs with a non-zero slot length in `array` will have the corresponding
/// data copied across. This allows handling the null mask separately from the data
fn filter_string<OffsetSize>(
array: &GenericStringArray<OffsetSize>,
predicate: &FilterPredicate,
) -> GenericStringArray<OffsetSize>
where
OffsetSize: Zero + AddAssign + StringOffsetSizeTrait,
{
let data = array.data();
assert_eq!(data.buffers().len(), 2);
assert_eq!(data.child_data().len(), 0);
let mut filter = FilterString::new(predicate.count, array);

match &predicate.iterator {
FilterIterator::SlicesIterator => {
filter.extend_slices(SlicesIterator::new(&predicate.filter))
}
FilterIterator::Slices(slices) => filter.extend_slices(slices.iter().cloned()),
FilterIterator::IndexIterator => {
filter.extend_idx(IndexIterator::new(&predicate.filter))
}
FilterIterator::Indices(indices) => filter.extend_idx(indices.iter().cloned()),
}

let mut builder = ArrayDataBuilder::new(data.data_type().clone())
.len(predicate.count)
.add_buffer(filter.dst_offsets.into())
.add_buffer(filter.dst_values.into());

if let Some((null_count, nulls)) = filter_null_mask(data, predicate) {
builder = builder.null_count(null_count).null_bit_buffer(nulls);
}

let data = unsafe { builder.build_unchecked() };
GenericStringArray::from(data)
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down

0 comments on commit c5912ef

Please sign in to comment.