From dff7cba2820ddcfd41d1b0c08e158acd74a5bd2d Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies Date: Wed, 30 Nov 2022 11:24:28 +0000 Subject: [PATCH] Loosen nullability restrictions added in #3205 (#3226) --- arrow-array/src/array/mod.rs | 2 +- arrow-array/src/array/struct_array.rs | 22 ++-- arrow-cast/src/cast.rs | 3 +- arrow-data/src/data.rs | 139 ++++++++++++++++++++++---- arrow/src/row/mod.rs | 35 +------ 5 files changed, 137 insertions(+), 64 deletions(-) diff --git a/arrow-array/src/array/mod.rs b/arrow-array/src/array/mod.rs index 0f9a2ce59291..5fc44d8965e4 100644 --- a/arrow-array/src/array/mod.rs +++ b/arrow-array/src/array/mod.rs @@ -916,7 +916,7 @@ mod tests { #[test] fn test_null_struct() { let struct_type = - DataType::Struct(vec![Field::new("data", DataType::Int64, true)]); + DataType::Struct(vec![Field::new("data", DataType::Int64, false)]); let array = new_null_array(&struct_type, 9); let a = array.as_any().downcast_ref::().unwrap(); diff --git a/arrow-array/src/array/struct_array.rs b/arrow-array/src/array/struct_array.rs index 7d88cc5c6deb..ffcb3731a889 100644 --- a/arrow-array/src/array/struct_array.rs +++ b/arrow-array/src/array/struct_array.rs @@ -227,13 +227,6 @@ impl From> for StructArray { field_value.data().data_type(), "the field data types must match the array data in a StructArray" ); - // Check nullability of child arrays - if !field_type.is_nullable() { - assert!( - field_value.null_count() == 0, - "non-nullable field cannot have null values" - ); - } }, ); @@ -241,6 +234,10 @@ impl From> for StructArray { .child_data(field_values.into_iter().map(|a| a.into_data()).collect()) .len(length); let array_data = unsafe { array_data.build_unchecked() }; + + // We must validate nullability + array_data.validate_nulls().unwrap(); + Self::from(array_data) } } @@ -283,13 +280,6 @@ impl From<(Vec<(Field, ArrayRef)>, Buffer)> for StructArray { field_value.data().data_type(), "the field data types must match the array data in a StructArray" ); - 
// Check nullability of child arrays - if !field_type.is_nullable() { - assert!( - field_value.null_count() == 0, - "non-nullable field cannot have null values" - ); - } }, ); @@ -298,6 +288,10 @@ impl From<(Vec<(Field, ArrayRef)>, Buffer)> for StructArray { .child_data(field_values.into_iter().map(|a| a.into_data()).collect()) .len(length); let array_data = unsafe { array_data.build_unchecked() }; + + // We must validate nullability + array_data.validate_nulls().unwrap(); + Self::from(array_data) } } diff --git a/arrow-cast/src/cast.rs b/arrow-cast/src/cast.rs index 23be8839593c..aa40ad425a5e 100644 --- a/arrow-cast/src/cast.rs +++ b/arrow-cast/src/cast.rs @@ -6714,7 +6714,8 @@ mod tests { cast_from_null_to_other(&data_type); // Cast null from and to struct - let data_type = DataType::Struct(vec![Field::new("data", DataType::Int64, true)]); + let data_type = + DataType::Struct(vec![Field::new("data", DataType::Int64, false)]); cast_from_null_to_other(&data_type); } diff --git a/arrow-data/src/data.rs b/arrow-data/src/data.rs index 811696e4dd17..22ef8d187b9d 100644 --- a/arrow-data/src/data.rs +++ b/arrow-data/src/data.rs @@ -19,6 +19,7 @@ //! common attributes and operations for Arrow array. use crate::{bit_iterator::BitSliceIterator, bitmap::Bitmap}; +use arrow_buffer::bit_chunk_iterator::BitChunks; use arrow_buffer::{bit_util, ArrowNativeType, Buffer, MutableBuffer}; use arrow_schema::{ArrowError, DataType, IntervalUnit, UnionMode}; use half::f16; @@ -618,7 +619,7 @@ impl ArrayData { /// are within the bounds of the values buffer). /// /// See [ArrayData::validate_full] to validate fully the offset content - /// and the validitiy of utf8 data + /// and the validity of utf8 data pub fn validate(&self) -> Result<(), ArrowError> { // Need at least this mich space in each buffer let len_plus_offset = self.len + self.offset; @@ -961,26 +962,19 @@ impl ArrayData { /// 3. All String data is valid UTF-8 /// 4. 
All dictionary offsets are valid /// - /// Does not (yet) check - /// 1. Union type_ids are valid see [#85](https://github.com/apache/arrow-rs/issues/85) - /// Note calls `validate()` internally + /// Internally this calls: + /// + /// * [`Self::validate`] + /// * [`Self::validate_nulls`] + /// * [`Self::validate_values`] + /// + /// And then for each child [`ArrayData`] calls [`ArrayData::validate_full`] + /// pub fn validate_full(&self) -> Result<(), ArrowError> { // Check all buffer sizes prior to looking at them more deeply in this function self.validate()?; - let null_bitmap_buffer = self - .null_bitmap - .as_ref() - .map(|null_bitmap| null_bitmap.buffer_ref()); - - let actual_null_count = count_nulls(null_bitmap_buffer, self.offset, self.len); - if actual_null_count != self.null_count { - return Err(ArrowError::InvalidArgumentError(format!( - "null_count value ({}) doesn't match actual number of nulls in array ({})", - self.null_count, actual_null_count - ))); - } - + self.validate_nulls()?; self.validate_values()?; // validate all children recursively @@ -999,6 +993,117 @@ impl ArrayData { Ok(()) } + /// Validates that the null count is correct and that any + /// nullability requirements of its children are correct + pub fn validate_nulls(&self) -> Result<(), ArrowError> { + let nulls = self.null_buffer(); + + let actual_null_count = count_nulls(nulls, self.offset, self.len); + if actual_null_count != self.null_count { + return Err(ArrowError::InvalidArgumentError(format!( + "null_count value ({}) doesn't match actual number of nulls in array ({})", + self.null_count, actual_null_count + ))); + } + + // In general non-nullable children should not contain nulls, however, for certain + // types, such as StructArray and FixedSizeList, nulls in the parent take up + // space in the child. 
As such we permit nulls in the children in the corresponding + // positions for such types + match &self.data_type { + DataType::List(f) | DataType::LargeList(f) | DataType::Map(f, _) => { + if !f.is_nullable() { + self.validate_non_nullable(None, 0, &self.child_data[0])? + } + } + DataType::FixedSizeList(field, len) => { + let child = &self.child_data[0]; + if !field.is_nullable() { + match nulls { + Some(nulls) => { + let element_len = *len as usize; + let mut buffer = + MutableBuffer::new_null(element_len * self.len); + + for i in 0..self.len { + if !bit_util::get_bit(nulls.as_ref(), self.offset + i) { + continue; + } + for j in 0..element_len { + bit_util::set_bit( + buffer.as_mut(), + i * element_len + j, + ) + } + } + let mask = buffer.into(); + self.validate_non_nullable(Some(&mask), 0, child)?; + } + None => self.validate_non_nullable(None, 0, child)?, + } + } + } + DataType::Struct(fields) => { + for (field, child) in fields.iter().zip(&self.child_data) { + if !field.is_nullable() { + self.validate_non_nullable(nulls, self.offset, child)? 
+ } + } + } + _ => {} + } + + Ok(()) + } + + /// Verifies that `data` contains no nulls not present in `mask` + fn validate_non_nullable( + &self, + mask: Option<&Buffer>, + offset: usize, + data: &ArrayData, + ) -> Result<(), ArrowError> { + let mask = match mask { + Some(mask) => mask.as_ref(), + None => return match data.null_count { + 0 => Ok(()), + _ => Err(ArrowError::InvalidArgumentError(format!( + "non-nullable child of type {} contains nulls not present in parent {}", + data.data_type(), + self.data_type + ))), + }, + }; + + match data.null_buffer() { + Some(nulls) => { + let mask = BitChunks::new(mask, offset, data.len); + let nulls = BitChunks::new(nulls.as_ref(), data.offset, data.len); + mask + .iter() + .zip(nulls.iter()) + .chain(std::iter::once(( + mask.remainder_bits(), + nulls.remainder_bits(), + ))).try_for_each(|(m, c)| { + if (m & !c) != 0 { + return Err(ArrowError::InvalidArgumentError(format!( + "non-nullable child of type {} contains nulls not present in parent", + data.data_type() + ))) + } + Ok(()) + }) + } + None => Ok(()), + } + } + + /// Validates the values stored within this [`ArrayData`] are valid + /// without recursing into child [`ArrayData`] + /// + /// Does not (yet) check + /// 1. 
Union type_ids are valid see [#85](https://github.com/apache/arrow-rs/issues/85) pub fn validate_values(&self) -> Result<(), ArrowError> { match &self.data_type { DataType::Utf8 => self.validate_utf8::(), diff --git a/arrow/src/row/mod.rs b/arrow/src/row/mod.rs index cff49740fb15..0b28a930c299 100644 --- a/arrow/src/row/mod.rs +++ b/arrow/src/row/mod.rs @@ -1127,36 +1127,11 @@ unsafe fn decode_column( } } Codec::Struct(converter, _) => { - let child_fields = match &field.data_type { - DataType::Struct(f) => f, - _ => unreachable!(), - }; - let (null_count, nulls) = fixed::decode_nulls(rows); rows.iter_mut().for_each(|row| *row = &row[1..]); let children = converter.convert_raw(rows, validate_utf8)?; - let child_data = child_fields - .iter() - .zip(&children) - .map(|(f, c)| { - let data = c.data().clone(); - match f.is_nullable() { - true => data, - false => { - assert_eq!(data.null_count(), null_count); - // Need to strip out null buffer if any as this is created - // as an artifact of the row encoding process that encodes - // nulls from the parent struct array in the children - data.into_builder() - .null_count(0) - .null_bit_buffer(None) - .build_unchecked() - } - } - }) - .collect(); - + let child_data = children.iter().map(|c| c.data().clone()).collect(); let builder = ArrayDataBuilder::new(field.data_type.clone()) .len(rows.len()) .null_count(null_count) @@ -1585,11 +1560,8 @@ mod tests { let back = converter.convert_rows(&r2).unwrap(); assert_eq!(back.len(), 1); assert_eq!(&back[0], &s2); - let back_s = as_struct_array(&back[0]); - for c in back_s.columns() { - // Children should not contain nulls - assert_eq!(c.null_count(), 0); - } + + back[0].data().validate_full().unwrap(); } #[test] @@ -1858,6 +1830,7 @@ mod tests { let back = converter.convert_rows(&rows).unwrap(); for (actual, expected) in back.iter().zip(&arrays) { + actual.data().validate_full().unwrap(); assert_eq!(actual, expected) } }