Skip to content

Commit

Permalink
Support Binary arrays in starts_with, ends_with and contains (#…
Browse files Browse the repository at this point in the history
…6926)

* add binary support in arrow-string

* cleanup

* remove BinaryArrayType

* format

* avoid duplicate code even more

* try to add back as much original code as possible

* reorder functions to have less diff

* remove use of impl Trait from function return type

* remove use of impl Trait from function return type

* revert like and predicate changes and add back as separate files

* format

* run arrow

* update comment
  • Loading branch information
rluvaton authored Jan 22, 2025
1 parent b8fc91d commit 6fd4607
Show file tree
Hide file tree
Showing 5 changed files with 590 additions and 17 deletions.
23 changes: 23 additions & 0 deletions arrow-array/src/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,29 @@ impl<'a> StringArrayType<'a> for &'a StringViewArray {
}
}

/// A trait for Arrow String Arrays, currently three types are supported:
/// - `BinaryArray`
/// - `LargeBinaryArray`
/// - `BinaryViewArray`
///
/// This trait helps to abstract over the different types of binary arrays
/// so that we don't need to duplicate the implementation for each type.
pub trait BinaryArrayType<'a>: ArrayAccessor<Item = &'a [u8]> + Sized {
/// Constructs a new iterator
fn iter(&self) -> ArrayIter<Self>;
}

impl<'a, O: OffsetSizeTrait> BinaryArrayType<'a> for &'a GenericBinaryArray<O> {
fn iter(&self) -> ArrayIter<Self> {
GenericBinaryArray::<O>::iter(self)
}
}
impl<'a> BinaryArrayType<'a> for &'a BinaryViewArray {
fn iter(&self) -> ArrayIter<Self> {
BinaryViewArray::iter(self)
}
}

impl PartialEq for dyn Array + '_ {
fn eq(&self, other: &Self) -> bool {
self.to_data().eq(&other.to_data())
Expand Down
166 changes: 166 additions & 0 deletions arrow-string/src/binary_like.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Provide SQL's CONTAINS, STARTS_WITH, ENDS_WITH operators for Arrow's binary arrays
use crate::binary_predicate::BinaryPredicate;

use arrow_array::cast::AsArray;
use arrow_array::*;
use arrow_schema::*;
use arrow_select::take::take;

#[derive(Debug)]
pub(crate) enum Op {
Contains,
StartsWith,
EndsWith,
}

impl TryFrom<crate::like::Op> for Op {
type Error = ArrowError;

fn try_from(value: crate::like::Op) -> Result<Self, Self::Error> {
match value {
crate::like::Op::Contains => Ok(Op::Contains),
crate::like::Op::StartsWith => Ok(Op::StartsWith),
crate::like::Op::EndsWith => Ok(Op::EndsWith),
_ => Err(ArrowError::InvalidArgumentError(format!(
"Invalid binary operation: {value}"
))),
}
}
}

impl std::fmt::Display for Op {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Op::Contains => write!(f, "CONTAINS"),
Op::StartsWith => write!(f, "STARTS_WITH"),
Op::EndsWith => write!(f, "ENDS_WITH"),
}
}
}

pub(crate) fn binary_apply<'a, 'i, T: BinaryArrayType<'a> + 'a>(
op: Op,
l: T,
l_s: bool,
l_v: Option<&'a dyn AnyDictionaryArray>,
r: T,
r_s: bool,
r_v: Option<&'a dyn AnyDictionaryArray>,
) -> Result<BooleanArray, ArrowError> {
let l_len = l_v.map(|l| l.len()).unwrap_or(l.len());
if r_s {
let idx = match r_v {
Some(dict) if dict.null_count() != 0 => return Ok(BooleanArray::new_null(l_len)),
Some(dict) => dict.normalized_keys()[0],
None => 0,
};
if r.is_null(idx) {
return Ok(BooleanArray::new_null(l_len));
}
op_scalar::<T>(op, l, l_v, r.value(idx))
} else {
match (l_s, l_v, r_v) {
(true, None, None) => {
let v = l.is_valid(0).then(|| l.value(0));
op_binary(op, std::iter::repeat(v), r.iter())
}
(true, Some(l_v), None) => {
let idx = l_v.is_valid(0).then(|| l_v.normalized_keys()[0]);
let v = idx.and_then(|idx| l.is_valid(idx).then(|| l.value(idx)));
op_binary(op, std::iter::repeat(v), r.iter())
}
(true, None, Some(r_v)) => {
let v = l.is_valid(0).then(|| l.value(0));
op_binary(op, std::iter::repeat(v), vectored_iter(r, r_v))
}
(true, Some(l_v), Some(r_v)) => {
let idx = l_v.is_valid(0).then(|| l_v.normalized_keys()[0]);
let v = idx.and_then(|idx| l.is_valid(idx).then(|| l.value(idx)));
op_binary(op, std::iter::repeat(v), vectored_iter(r, r_v))
}
(false, None, None) => op_binary(op, l.iter(), r.iter()),
(false, Some(l_v), None) => op_binary(op, vectored_iter(l, l_v), r.iter()),
(false, None, Some(r_v)) => op_binary(op, l.iter(), vectored_iter(r, r_v)),
(false, Some(l_v), Some(r_v)) => {
op_binary(op, vectored_iter(l, l_v), vectored_iter(r, r_v))
}
}
}
}

#[inline(never)]
fn op_scalar<'a, T: BinaryArrayType<'a>>(
op: Op,
l: T,
l_v: Option<&dyn AnyDictionaryArray>,
r: &[u8],
) -> Result<BooleanArray, ArrowError> {
let r = match op {
Op::Contains => BinaryPredicate::contains(r).evaluate_array(l, false),
Op::StartsWith => BinaryPredicate::StartsWith(r).evaluate_array(l, false),
Op::EndsWith => BinaryPredicate::EndsWith(r).evaluate_array(l, false),
};

Ok(match l_v {
Some(v) => take(&r, v.keys(), None)?.as_boolean().clone(),
None => r,
})
}

fn vectored_iter<'a, T: BinaryArrayType<'a> + 'a>(
a: T,
a_v: &'a dyn AnyDictionaryArray,
) -> impl Iterator<Item = Option<&'a [u8]>> + 'a {
let nulls = a_v.nulls();
let keys = a_v.normalized_keys();
keys.into_iter().enumerate().map(move |(idx, key)| {
if nulls.map(|n| n.is_null(idx)).unwrap_or_default() || a.is_null(key) {
return None;
}
Some(a.value(key))
})
}

#[inline(never)]
fn op_binary<'a>(
op: Op,
l: impl Iterator<Item = Option<&'a [u8]>>,
r: impl Iterator<Item = Option<&'a [u8]>>,
) -> Result<BooleanArray, ArrowError> {
match op {
Op::Contains => Ok(l
.zip(r)
.map(|(l, r)| Some(bytes_contains(l?, r?)))
.collect()),
Op::StartsWith => Ok(l
.zip(r)
.map(|(l, r)| Some(BinaryPredicate::StartsWith(r?).evaluate(l?)))
.collect()),
Op::EndsWith => Ok(l
.zip(r)
.map(|(l, r)| Some(BinaryPredicate::EndsWith(r?).evaluate(l?)))
.collect()),
}
}

fn bytes_contains(haystack: &[u8], needle: &[u8]) -> bool {
memchr::memmem::find(haystack, needle).is_some()
}
178 changes: 178 additions & 0 deletions arrow-string/src/binary_predicate.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow_array::{Array, ArrayAccessor, BinaryViewArray, BooleanArray};
use arrow_buffer::BooleanBuffer;
use memchr::memmem::Finder;
use std::iter::zip;

/// A binary based predicate
pub enum BinaryPredicate<'a> {
Contains(Finder<'a>),
StartsWith(&'a [u8]),
EndsWith(&'a [u8]),
}

impl<'a> BinaryPredicate<'a> {
pub fn contains(needle: &'a [u8]) -> Self {
Self::Contains(Finder::new(needle))
}

/// Evaluate this predicate against the given haystack
pub fn evaluate(&self, haystack: &[u8]) -> bool {
match self {
Self::Contains(finder) => finder.find(haystack).is_some(),
Self::StartsWith(v) => starts_with(haystack, v, equals_kernel),
Self::EndsWith(v) => ends_with(haystack, v, equals_kernel),
}
}

/// Evaluate this predicate against the elements of `array`
///
/// If `negate` is true the result of the predicate will be negated
#[inline(never)]
pub fn evaluate_array<'i, T>(&self, array: T, negate: bool) -> BooleanArray
where
T: ArrayAccessor<Item = &'i [u8]>,
{
match self {
Self::Contains(finder) => BooleanArray::from_unary(array, |haystack| {
finder.find(haystack).is_some() != negate
}),
Self::StartsWith(v) => {
if let Some(view_array) = array.as_any().downcast_ref::<BinaryViewArray>() {
let nulls = view_array.logical_nulls();
let values = BooleanBuffer::from(
view_array
.prefix_bytes_iter(v.len())
.map(|haystack| equals_bytes(haystack, v, equals_kernel) != negate)
.collect::<Vec<_>>(),
);
BooleanArray::new(values, nulls)
} else {
BooleanArray::from_unary(array, |haystack| {
starts_with(haystack, v, equals_kernel) != negate
})
}
}
Self::EndsWith(v) => {
if let Some(view_array) = array.as_any().downcast_ref::<BinaryViewArray>() {
let nulls = view_array.logical_nulls();
let values = BooleanBuffer::from(
view_array
.suffix_bytes_iter(v.len())
.map(|haystack| equals_bytes(haystack, v, equals_kernel) != negate)
.collect::<Vec<_>>(),
);
BooleanArray::new(values, nulls)
} else {
BooleanArray::from_unary(array, |haystack| {
ends_with(haystack, v, equals_kernel) != negate
})
}
}
}
}
}

fn equals_bytes(lhs: &[u8], rhs: &[u8], byte_eq_kernel: impl Fn((&u8, &u8)) -> bool) -> bool {
lhs.len() == rhs.len() && zip(lhs, rhs).all(byte_eq_kernel)
}

/// This is faster than `[u8]::starts_with` for small slices.
/// See <https://github.com/apache/arrow-rs/issues/6107> for more details.
fn starts_with(
haystack: &[u8],
needle: &[u8],
byte_eq_kernel: impl Fn((&u8, &u8)) -> bool,
) -> bool {
if needle.len() > haystack.len() {
false
} else {
zip(haystack, needle).all(byte_eq_kernel)
}
}
/// This is faster than `[u8]::ends_with` for small slices.
/// See <https://github.com/apache/arrow-rs/issues/6107> for more details.
fn ends_with(haystack: &[u8], needle: &[u8], byte_eq_kernel: impl Fn((&u8, &u8)) -> bool) -> bool {
if needle.len() > haystack.len() {
false
} else {
zip(haystack.iter().rev(), needle.iter().rev()).all(byte_eq_kernel)
}
}

fn equals_kernel((n, h): (&u8, &u8)) -> bool {
n == h
}

#[cfg(test)]
mod tests {
use super::BinaryPredicate;

#[test]
fn test_contains() {
assert!(BinaryPredicate::contains(b"hay").evaluate(b"haystack"));
assert!(BinaryPredicate::contains(b"haystack").evaluate(b"haystack"));
assert!(BinaryPredicate::contains(b"h").evaluate(b"haystack"));
assert!(BinaryPredicate::contains(b"k").evaluate(b"haystack"));
assert!(BinaryPredicate::contains(b"stack").evaluate(b"haystack"));
assert!(BinaryPredicate::contains(b"sta").evaluate(b"haystack"));
assert!(BinaryPredicate::contains(b"stack").evaluate(b"hay\0stack"));
assert!(BinaryPredicate::contains(b"\0s").evaluate(b"hay\0stack"));
assert!(BinaryPredicate::contains(b"\0").evaluate(b"hay\0stack"));
assert!(BinaryPredicate::contains(b"a").evaluate(b"a"));
// not matching
assert!(!BinaryPredicate::contains(b"hy").evaluate(b"haystack"));
assert!(!BinaryPredicate::contains(b"stackx").evaluate(b"haystack"));
assert!(!BinaryPredicate::contains(b"x").evaluate(b"haystack"));
assert!(!BinaryPredicate::contains(b"haystack haystack").evaluate(b"haystack"));
}

#[test]
fn test_starts_with() {
assert!(BinaryPredicate::StartsWith(b"hay").evaluate(b"haystack"));
assert!(BinaryPredicate::StartsWith(b"h\0ay").evaluate(b"h\0aystack"));
assert!(BinaryPredicate::StartsWith(b"haystack").evaluate(b"haystack"));
assert!(BinaryPredicate::StartsWith(b"ha").evaluate(b"haystack"));
assert!(BinaryPredicate::StartsWith(b"h").evaluate(b"haystack"));
assert!(BinaryPredicate::StartsWith(b"").evaluate(b"haystack"));

assert!(!BinaryPredicate::StartsWith(b"stack").evaluate(b"haystack"));
assert!(!BinaryPredicate::StartsWith(b"haystacks").evaluate(b"haystack"));
assert!(!BinaryPredicate::StartsWith(b"HAY").evaluate(b"haystack"));
assert!(!BinaryPredicate::StartsWith(b"h\0ay").evaluate(b"haystack"));
assert!(!BinaryPredicate::StartsWith(b"hay").evaluate(b"h\0aystack"));
}

#[test]
fn test_ends_with() {
assert!(BinaryPredicate::EndsWith(b"stack").evaluate(b"haystack"));
assert!(BinaryPredicate::EndsWith(b"st\0ack").evaluate(b"hayst\0ack"));
assert!(BinaryPredicate::EndsWith(b"haystack").evaluate(b"haystack"));
assert!(BinaryPredicate::EndsWith(b"ck").evaluate(b"haystack"));
assert!(BinaryPredicate::EndsWith(b"k").evaluate(b"haystack"));
assert!(BinaryPredicate::EndsWith(b"").evaluate(b"haystack"));

assert!(!BinaryPredicate::EndsWith(b"hay").evaluate(b"haystack"));
assert!(!BinaryPredicate::EndsWith(b"STACK").evaluate(b"haystack"));
assert!(!BinaryPredicate::EndsWith(b"haystacks").evaluate(b"haystack"));
assert!(!BinaryPredicate::EndsWith(b"xhaystack").evaluate(b"haystack"));
assert!(!BinaryPredicate::EndsWith(b"st\0ack").evaluate(b"haystack"));
assert!(!BinaryPredicate::EndsWith(b"stack").evaluate(b"hayst\0ack"));
}
}
2 changes: 2 additions & 0 deletions arrow-string/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
#![warn(missing_docs)]
//! Arrow string kernels
mod binary_like;
mod binary_predicate;
pub mod concat_elements;
pub mod length;
pub mod like;
Expand Down
Loading

0 comments on commit 6fd4607

Please sign in to comment.