Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix null coercion of scalars #2172

Open
wants to merge 10 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion docs/concepts/compute.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,10 @@ codes.
* `take(Array, indices: Array) -> Array`
* Take the specified nullable indices from the array.
* `filter(Array, mask: Mask) -> Array`
* Filter the array based on the given mask.
* Filter the array based on the given mask.

## Type Coercion

To maximise compatibility with compute engines, Vortex does not perform any implicit type coercion in its compute
functions or expressions. The exception to this is upcasting the nullability of input data types. For example,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about list(u32)? compared to list(u32?).

it is valid to compare a `u32` and `u32?` array, resulting in a `bool?` array.
2 changes: 1 addition & 1 deletion vortex-array/src/builders/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ pub fn builder_with_capacity(dtype: &DType, capacity: usize) -> Box<dyn ArrayBui
pub trait ArrayBuilderExt: ArrayBuilder {
/// A generic function to append a scalar to the builder.
fn append_scalar(&mut self, scalar: &Scalar) -> VortexResult<()> {
if !scalar.dtype().eq_ignore_nullability(self.dtype()) {
if scalar.dtype() != self.dtype() {
gatesn marked this conversation as resolved.
Show resolved Hide resolved
vortex_bail!(
"Builder has dtype {:?}, scalar has {:?}",
self.dtype(),
Expand Down
25 changes: 13 additions & 12 deletions vortex-array/src/compute/binary_numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ pub fn binary_numeric(lhs: &Array, rhs: &Array, op: BinaryNumericOperator) -> Vo
}
if !matches!(lhs.dtype(), DType::Primitive(_, _))
|| !matches!(rhs.dtype(), DType::Primitive(_, _))
|| lhs.dtype() != rhs.dtype()
|| !lhs.dtype().eq_ignore_nullability(rhs.dtype())
{
vortex_bail!(
"Numeric operations are only supported on two arrays sharing the same primitive-type: {} {}",
Expand All @@ -114,16 +114,14 @@ pub fn binary_numeric(lhs: &Array, rhs: &Array, op: BinaryNumericOperator) -> Vo
// Check if LHS supports the operation directly.
if let Some(fun) = lhs.vtable().binary_numeric_fn() {
if let Some(result) = fun.binary_numeric(lhs, rhs, op)? {
check_numeric_result(&result, lhs, rhs);
return Ok(result);
return Ok(check_numeric_result(result, lhs, rhs));
}
}

// Check if RHS supports the operation directly.
if let Some(fun) = rhs.vtable().binary_numeric_fn() {
if let Some(result) = fun.binary_numeric(rhs, lhs, op.swap())? {
check_numeric_result(&result, lhs, rhs);
return Ok(result);
return Ok(check_numeric_result(result, lhs, rhs));
}
}

Expand All @@ -135,7 +133,11 @@ pub fn binary_numeric(lhs: &Array, rhs: &Array, op: BinaryNumericOperator) -> Vo
);

// If neither side implements the trait, then we delegate to Arrow compute.
arrow_numeric(lhs.clone(), rhs.clone(), op)
Ok(check_numeric_result(
arrow_numeric(lhs.clone(), rhs.clone(), op)?,
lhs,
rhs,
))
}

/// Implementation of `BinaryBooleanFn` using the Arrow crate.
Expand All @@ -146,8 +148,8 @@ fn arrow_numeric(lhs: Array, rhs: Array, operator: BinaryNumericOperator) -> Vor
let nullable = lhs.dtype().is_nullable() || rhs.dtype().is_nullable();
let len = lhs.len();

let left = Datum::try_new(lhs.clone())?;
let right = Datum::try_new(rhs.clone())?;
let left = Datum::try_new(lhs)?;
let right = Datum::try_new(rhs)?;

let array = match operator {
BinaryNumericOperator::Add => arrow_arith::numeric::add(&left, &right)?,
Expand All @@ -158,13 +160,11 @@ fn arrow_numeric(lhs: Array, rhs: Array, operator: BinaryNumericOperator) -> Vor
BinaryNumericOperator::RDiv => arrow_arith::numeric::div(&right, &left)?,
};

let result = from_arrow_array_with_len(array, len, nullable)?;
check_numeric_result(&result, &lhs, &rhs);
Ok(result)
from_arrow_array_with_len(array, len, nullable)
}

#[inline(always)]
fn check_numeric_result(result: &Array, lhs: &Array, rhs: &Array) {
fn check_numeric_result(result: Array, lhs: &Array, rhs: &Array) -> Array {
debug_assert_eq!(
result.len(),
lhs.len(),
Expand All @@ -181,6 +181,7 @@ fn check_numeric_result(result: &Array, lhs: &Array, rhs: &Array) {
"Numeric operation dtype mismatch {}",
rhs.encoding()
);
result
}

#[cfg(feature = "test-harness")]
Expand Down
17 changes: 14 additions & 3 deletions vortex-array/src/compute/boolean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,21 @@ pub fn or_kleene(lhs: impl AsRef<Array>, rhs: impl AsRef<Array>) -> VortexResult

pub fn binary_boolean(lhs: &Array, rhs: &Array, op: BinaryOperator) -> VortexResult<Array> {
if lhs.len() != rhs.len() {
vortex_bail!("Boolean operations aren't supported on arrays of different lengths")
vortex_bail!(
"Boolean operations aren't supported on arrays of different lengths: {} and {}",
lhs.len(),
rhs.len()
)
}
if !lhs.dtype().is_boolean() || !rhs.dtype().is_boolean() {
vortex_bail!("Boolean operations are only supported on boolean arrays")
if !lhs.dtype().is_boolean()
|| !rhs.dtype().is_boolean()
|| !lhs.dtype().eq_ignore_nullability(rhs.dtype())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: could simplify as !matches!(lhs.dtype(), DType::Boolean(_)) || ...

{
vortex_bail!(
"Boolean operations are only supported on boolean arrays: {} and {}",
lhs.dtype(),
rhs.dtype()
)
}

// If LHS is constant, then we make sure it's on the RHS.
Expand Down
9 changes: 7 additions & 2 deletions vortex-array/src/compute/compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,17 @@ pub fn compare(
vortex_bail!("Compare operations only support arrays of the same length");
}
if !left.dtype().eq_ignore_nullability(right.dtype()) {
vortex_bail!("Compare operations only support arrays of the same type");
vortex_bail!(
"Cannot compare different DTypes {} and {}",
left.dtype(),
right.dtype()
);
}

// TODO(ngates): no reason why not
if left.dtype().is_struct() {
vortex_bail!(
"Compare does not support arrays with Strcut DType, got: {} and {}",
"Compare does not support arrays with Struct DType, got: {} and {}",
left.dtype(),
right.dtype()
)
Expand Down
2 changes: 1 addition & 1 deletion vortex-buffer/src/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::str::Utf8Error;
use crate::ByteBuffer;

/// A wrapper around a [`ByteBuffer`] that guarantees that the buffer contains valid UTF-8.
#[derive(Clone, PartialEq, Eq, PartialOrd, Hash)]
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct BufferString(ByteBuffer);

impl BufferString {
Expand Down
2 changes: 1 addition & 1 deletion vortex-datafusion/src/persistent/sink.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ mod tests {
assert_eq!(df.clone().count().await.unwrap(), 0);
let my_tbl = session.table("my_tbl").await.unwrap();

// Its valuable to have two insert code paths because they actually behave slightly differently
// It's valuable to have two insert code paths because they actually behave slightly differently
let values = Values {
schema: Arc::new(my_tbl.schema().clone()),
values: vec![vec![
Expand Down
13 changes: 11 additions & 2 deletions vortex-layout/src/layouts/chunked/stats_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,12 @@ impl StatsAccumulator {
stats.sort_by_key(|s| u8::from(*s));
let builders = stats
.iter()
.map(|s| builder_with_capacity(&s.dtype(&dtype).as_nullable(), 1024))
.map(|s| {
// We always store stats nullable in the stats table because some chunks may be
// missing the statistic.
let stat_dtype = s.dtype(&dtype).as_nullable();
builder_with_capacity(&stat_dtype, 1024)
})
.collect();
Self {
column_dtype: dtype,
Expand All @@ -146,9 +151,13 @@ impl StatsAccumulator {
}

pub fn push_chunk(&mut self, array: &Array) -> VortexResult<()> {
if &self.column_dtype != array.dtype() {
vortex_bail!("Chunk dtype does not match expected stats column dtype");
}

for (s, builder) in self.stats.iter().zip_eq(self.builders.iter_mut()) {
if let Some(v) = array.statistics().compute(*s) {
builder.append_scalar(&Scalar::new(s.dtype(array.dtype()), v))?;
builder.append_scalar(&Scalar::new(s.dtype(array.dtype()).as_nullable(), v))?;
gatesn marked this conversation as resolved.
Show resolved Hide resolved
} else {
builder.append_null();
}
Expand Down
22 changes: 16 additions & 6 deletions vortex-scalar/src/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,29 @@ use vortex_error::{vortex_bail, vortex_err, VortexError, VortexExpect as _, Vort

use crate::{InnerScalarValue, Scalar, ScalarValue};

#[derive(Debug, Hash, PartialEq, Eq)]
#[derive(Debug, Hash)]
pub struct BinaryScalar<'a> {
dtype: &'a DType,
value: Option<ByteBuffer>,
}

/// Ord is not implemented since it's undefined for different nullability
impl PartialEq for BinaryScalar<'_> {
fn eq(&self, other: &Self) -> bool {
self.dtype.eq_ignore_nullability(other.dtype) && self.value == other.value
}
}

impl Eq for BinaryScalar<'_> {}

impl PartialOrd for BinaryScalar<'_> {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
if self.dtype != other.dtype {
return None;
}
self.value.partial_cmp(&other.value)
Some(self.value.cmp(&other.value))
}
}

impl Ord for BinaryScalar<'_> {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.value.cmp(&other.value)
}
}

Expand Down
33 changes: 28 additions & 5 deletions vortex-scalar/src/bool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,29 @@ use vortex_error::{vortex_bail, vortex_err, VortexError, VortexExpect as _, Vort

use crate::{InnerScalarValue, Scalar, ScalarValue};

#[derive(Debug, Hash, PartialEq, Eq)]
#[derive(Debug, Hash)]
pub struct BoolScalar<'a> {
dtype: &'a DType,
value: Option<bool>,
}

impl PartialEq for BoolScalar<'_> {
fn eq(&self, other: &Self) -> bool {
self.dtype.eq_ignore_nullability(other.dtype) && self.value == other.value
}
}

impl Eq for BoolScalar<'_> {}

impl PartialOrd for BoolScalar<'_> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
if self.dtype != other.dtype {
return None;
}
self.value.partial_cmp(&other.value)
Some(self.value.cmp(&other.value))
}
}

impl Ord for BoolScalar<'_> {
fn cmp(&self, other: &Self) -> Ordering {
self.value.cmp(&other.value)
}
}

Expand Down Expand Up @@ -132,11 +143,23 @@ impl From<bool> for ScalarValue {

#[cfg(test)]
mod test {
use vortex_dtype::Nullability::*;

use super::*;

#[test]
fn into_from() {
let scalar: Scalar = false.into();
assert!(!bool::try_from(&scalar).unwrap());
}

#[test]
fn equality() {
assert_eq!(&Scalar::bool(true, Nullable), &Scalar::bool(true, Nullable));
// Equality ignores nullability
assert_eq!(
&Scalar::bool(true, Nullable),
&Scalar::bool(true, NonNullable)
);
}
}
75 changes: 51 additions & 24 deletions vortex-scalar/src/datafusion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,53 +104,80 @@ impl TryFrom<Scalar> for ScalarValue {
impl From<ScalarValue> for Scalar {
fn from(value: ScalarValue) -> Scalar {
match value {
ScalarValue::Null => Some(Scalar::null(DType::Null)),
ScalarValue::Boolean(b) => b.map(Scalar::from),
ScalarValue::Float16(f) => f.map(Scalar::from),
ScalarValue::Float32(f) => f.map(Scalar::from),
ScalarValue::Float64(f) => f.map(Scalar::from),
ScalarValue::Int8(i) => i.map(Scalar::from),
ScalarValue::Int16(i) => i.map(Scalar::from),
ScalarValue::Int32(i) => i.map(Scalar::from),
ScalarValue::Int64(i) => i.map(Scalar::from),
ScalarValue::UInt8(i) => i.map(Scalar::from),
ScalarValue::UInt16(i) => i.map(Scalar::from),
ScalarValue::UInt32(i) => i.map(Scalar::from),
ScalarValue::UInt64(i) => i.map(Scalar::from),
ScalarValue::Utf8(s) | ScalarValue::Utf8View(s) | ScalarValue::LargeUtf8(s) => {
s.as_ref().map(|s| Scalar::from(s.as_str()))
}
ScalarValue::Null => Scalar::null(DType::Null),
ScalarValue::Boolean(b) => b
.map(Scalar::from)
.unwrap_or_else(|| Scalar::null(DType::Bool(Nullability::Nullable))),
ScalarValue::Float16(f) => f.map(Scalar::from).unwrap_or_else(|| {
Scalar::null(DType::Primitive(PType::F16, Nullability::Nullable))
}),
ScalarValue::Float32(f) => f.map(Scalar::from).unwrap_or_else(|| {
Scalar::null(DType::Primitive(PType::F32, Nullability::Nullable))
}),
ScalarValue::Float64(f) => f.map(Scalar::from).unwrap_or_else(|| {
Scalar::null(DType::Primitive(PType::F64, Nullability::Nullable))
}),
ScalarValue::Int8(i) => i.map(Scalar::from).unwrap_or_else(|| {
Scalar::null(DType::Primitive(PType::I8, Nullability::Nullable))
}),
ScalarValue::Int16(i) => i.map(Scalar::from).unwrap_or_else(|| {
Scalar::null(DType::Primitive(PType::I16, Nullability::Nullable))
}),
ScalarValue::Int32(i) => i.map(Scalar::from).unwrap_or_else(|| {
Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable))
}),
ScalarValue::Int64(i) => i.map(Scalar::from).unwrap_or_else(|| {
Scalar::null(DType::Primitive(PType::I64, Nullability::Nullable))
}),
ScalarValue::UInt8(i) => i.map(Scalar::from).unwrap_or_else(|| {
Scalar::null(DType::Primitive(PType::U8, Nullability::Nullable))
}),
ScalarValue::UInt16(i) => i.map(Scalar::from).unwrap_or_else(|| {
Scalar::null(DType::Primitive(PType::U16, Nullability::Nullable))
}),
ScalarValue::UInt32(i) => i.map(Scalar::from).unwrap_or_else(|| {
Scalar::null(DType::Primitive(PType::U32, Nullability::Nullable))
}),
ScalarValue::UInt64(i) => i.map(Scalar::from).unwrap_or_else(|| {
Scalar::null(DType::Primitive(PType::U64, Nullability::Nullable))
}),
ScalarValue::Utf8(s) | ScalarValue::Utf8View(s) | ScalarValue::LargeUtf8(s) => s
.as_ref()
.map(|s| Scalar::from(s.as_str()))
.unwrap_or_else(|| Scalar::null(DType::Utf8(Nullability::Nullable))),
ScalarValue::Binary(b)
| ScalarValue::BinaryView(b)
| ScalarValue::LargeBinary(b)
| ScalarValue::FixedSizeBinary(_, b) => b
.as_ref()
.map(|b| Scalar::binary(ByteBuffer::from(b.clone()), Nullability::Nullable)),
.map(|b| Scalar::binary(ByteBuffer::from(b.clone()), Nullability::Nullable))
.unwrap_or_else(|| Scalar::null(DType::Binary(Nullability::Nullable))),
ScalarValue::Date32(v)
| ScalarValue::Time32Second(v)
| ScalarValue::Time32Millisecond(v) => v.map(|i| {
| ScalarValue::Time32Millisecond(v) => {
let ext_dtype = make_temporal_ext_dtype(&value.data_type())
.with_nullability(Nullability::Nullable);
Scalar::new(
DType::Extension(Arc::new(ext_dtype)),
crate::ScalarValue(InnerScalarValue::Primitive(PValue::I32(i))),
v.map(|i| crate::ScalarValue(InnerScalarValue::Primitive(PValue::I32(i))))
.unwrap_or_else(crate::ScalarValue::null),
)
}),
}
ScalarValue::Date64(v)
| ScalarValue::Time64Microsecond(v)
| ScalarValue::Time64Nanosecond(v)
| ScalarValue::TimestampSecond(v, _)
| ScalarValue::TimestampMillisecond(v, _)
| ScalarValue::TimestampMicrosecond(v, _)
| ScalarValue::TimestampNanosecond(v, _) => v.map(|i| {
| ScalarValue::TimestampNanosecond(v, _) => {
let ext_dtype = make_temporal_ext_dtype(&value.data_type());
Scalar::new(
DType::Extension(Arc::new(ext_dtype.with_nullability(Nullability::Nullable))),
crate::ScalarValue(InnerScalarValue::Primitive(PValue::I64(i))),
v.map(|i| crate::ScalarValue(InnerScalarValue::Primitive(PValue::I64(i))))
.unwrap_or_else(crate::ScalarValue::null),
)
}),
}
_ => unimplemented!("Can't convert {value:?} value to a Vortex scalar"),
}
.unwrap_or_else(|| Scalar::null(DType::Null))
}
}
Loading
Loading